Source code for pegasus.plotting.plot_library

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt

from scipy.sparse import issparse
from pandas.api.types import is_numeric_dtype, is_categorical_dtype, is_list_like
from scipy.stats import zscore
from sklearn.metrics import adjusted_mutual_info_score
from natsort import natsorted

import anndata
from pegasusio import UnimodalData, MultimodalData

from typing import List, Tuple, Union, Optional, Callable

import logging
logger = logging.getLogger(__name__)

from pegasus.tools import X_from_rep, slicing, largest_variance_from_random_matrix
from .plot_utils import (
    _transform_basis,
    _get_nrows_and_ncols,
    _get_marker_size,
    _get_dot_size,
    _get_subplot_layouts,
    _get_legend_ncol,
    _get_palette,
    RestrictionParser,
    DictWithDefault,
    _generate_categories,
    _plot_corners,
    _plot_spots,
    _get_valid_attrs,
)


def scatter(
    data: Union[MultimodalData, UnimodalData, anndata.AnnData],
    attrs: Optional[Union[str, List[str]]] = None,
    basis: Optional[Union[str, List[str]]] = "umap",
    components: Optional[Union[Tuple[int, int], List[Tuple[int, int]]]] = (1, 2),
    matkey: Optional[str] = None,
    restrictions: Optional[Union[str, List[str]]] = None,
    show_background: Optional[bool] = False,
    fix_corners: Optional[bool] = True,
    alpha: Optional[Union[float, List[float]]] = 1.0,
    legend_loc: Optional[Union[str, List[str]]] = "right margin",
    legend_fontsize: Optional[Union[int, List[int]]] = 10,
    legend_ncol: Optional[str] = None,
    palettes: Optional[Union[str, List[str]]] = None,
    cmaps: Optional[Union[str, List[str]]] = "YlOrRd",
    vmin: Optional[Union[float, List[float]]] = None,
    vmax: Optional[Union[float, List[float]]] = None,
    nrows: Optional[int] = None,
    ncols: Optional[int] = None,
    panel_size: Optional[Tuple[float, float]] = (4, 4),
    left: Optional[float] = 0.2,
    bottom: Optional[float] = 0.15,
    wspace: Optional[float] = 0.4,
    hspace: Optional[float] = 0.15,
    marker_size: Optional[float] = None,
    scale_factor: Optional[float] = None,
    aspect: Optional[str] = "auto",
    return_fig: Optional[bool] = False,
    dpi: Optional[float] = 300.0,
    show_neg_for_sig: Optional[bool] = False,
    **kwargs,
) -> Union[plt.Figure, None]:
    """Generate scatter plots for different attributes.

    One panel is drawn for every combination of attribute, basis and component pair.

    Parameters
    ----------

    data: ``pegasusio.MultimodalData``
        Use current selected modality in data.
    attrs: ``str`` or ``List[str]``, default: None
        Color scatter plots by attrs. Each attribute in attrs can be one key in data.obs, data.var_names (e.g. one gene) or data.obsm (attribute has the format of 'obsm_key@component', like 'X_pca@0'). If one attribute is categorical, a palette will be used to color each category separately. Otherwise, a color map will be used. If no attributes are provided, plot the basis for all data.
    basis: ``str`` or ``List[str]``, optional, default: ``umap``
        Basis to be used to generate scatter plots. Can be either 'pca', 'diffmap', 'umap', 'tsne', 'fitsne', 'fle', 'net_tsne', 'net_fitsne', 'net_umap' or 'net_fle'. If `basis` is a list, each element in `attrs` will be plotted for each basis in `basis`.
    components: ``Tuple[int, int]`` or ``List[Tuple[int, int]]``, optional, default: ``(1, 2)``
        Components in basis to be used (1-based). Default to the first two components. If `components` is a list, for each element in `attrs` and each `basis`, all components enumeration will be plotted.
    matkey: ``str``, optional, default: None
        If matkey is set, select matrix with matkey as keyword in the current modality. Only works for MultimodalData or UnimodalData objects.
    restrictions: ``str`` or ``List[str]``, optional, default: None
        A list of restrictions to subset data for plotting. There are two types of restrictions: global restriction and attribute-specific restriction. Global restriction applies to all attributes in ``attrs`` and takes the format of 'key:value,value...', or 'key:~value,value...'. This restriction selects cells with the ``data.obs[key]`` values belong to 'value,value...' (or not belong to if '~' shows). Attribute-specific restriction takes the format of 'attr:key:value,value...', or 'attr:key:~value,value...'. It only applies to one attribute 'attr'. If 'attr' and 'key' are the same, one can use '.' to replace 'key' (e.g. ``cluster_labels:.:value1,value2``).
    show_background: ``bool``, optional, default: False
        Only applicable if `restrictions` is set. By default, only data points selected are shown. If show_background is True, data points that are not selected will also be shown.
    fix_corners: ``bool``, optional, default: True
        If True, fix the corners of the plots as defined using all data points.
    alpha: ``float`` or ``List[float]``, optional, default: ``1.0``
        Alpha value for blending, from 0.0 (transparent) to 1.0 (opaque). If this is a list, the length must match attrs, which means we set a separate alpha value for each attribute.
    legend_loc: ``str`` or ``List[str]``, optional, default: ``right margin``
        Legend location. Can be either "right margin" or "on data". If a list is provided, set 'legend_loc' for each attribute in 'attrs' separately.
    legend_fontsize: ``int`` or ``List[int]``, optional, default: ``10``
        Legend fontsize. If a list is provided, set 'legend_fontsize' for each attribute in 'attrs' separately.
    legend_ncol: ``str``, optional, default: None
        Only applicable if legend_loc == "right margin". Set number of columns used to show legends.
    palettes: ``str`` or ``List[str]``, optional, default: None
        Used for setting colors for every categories in categorical attributes. Each string in ``palettes`` takes the format of 'attr:color1,color2,...,colorn'. 'attr' is the categorical attribute and 'color1' - 'colorn' are the colors for each category in 'attr' (e.g. 'cluster_labels:black,blue,red,...,yellow'). If there is only one categorical attribute in 'attrs', ``palettes`` can be set as a single string and the 'attr' keyword can be omitted (e.g. "blue,yellow,red").
    cmaps: ``str`` or ``List[str]``, optional, default: ``YlOrRd``
        Used for setting colormap for numeric attributes. Each string in ``cmaps`` takes the format of 'colormap' or 'attr:colormap'. 'colormap' sets the default colormap for all numeric attributes. 'attr:colormap' overwrites attribute 'attr's colormap as 'colormap'.
    vmin: ``float`` or ``List[float]``, optional, default: None
        Minimum value to show on a numeric scatter plot (feature plot). If a ``list``, it is indexed by the position of the attribute in ``attrs``.
    vmax: ``float`` or ``List[float]``, optional, default: None
        Maximum value to show on a numeric scatter plot (feature plot). If a ``list``, it is indexed by the position of the attribute in ``attrs``.
    nrows: ``int``, optional, default: None
        Number of rows in the figure. If not set, pegasus will figure it out automatically.
    ncols: ``int``, optional, default: None
        Number of columns in the figure. If not set, pegasus will figure it out automatically.
    panel_size: `tuple`, optional (default: `(4, 4)`)
        The panel size (width, height) in inches.
    left: `float`, optional (default: `0.2`)
        This parameter sets the figure's left margin as a fraction of panel's width (left * panel_size[0]).
    bottom: `float`, optional (default: `0.15`)
        This parameter sets the figure's bottom margin as a fraction of panel's height (bottom * panel_size[1]).
    wspace: `float`, optional (default: `0.4`)
        This parameter sets the width between panels and also the figure's right margin as a fraction of panel's width (wspace * panel_size[0]).
    hspace: `float`, optional (default: `0.15`)
        This parameter sets the height between panels and also the figure's top margin as a fraction of panel's height (hspace * panel_size[1]).
    marker_size: ``float``, optional (default: ``None``)
        Manually set the marker size in the plot. If ``None``, automatically adjust the marker size to the plot size.
    scale_factor: ``float``, optional (default: ``None``)
        Manually set the scale factor in the plot if it's not ``None``. This is used by generating the spatial plots for 10x Visium data.
    aspect: ``str``, optional (default: ``auto``)
        Set the aspect of the axis scaling, i.e. the ratio of y-unit to x-unit. Set ``auto`` to fill the position rectangle with data; ``equal`` for the same scaling for x and y. It applies to all subplots.
    return_fig: ``bool``, optional, default: ``False``
        Return a ``Figure`` object if ``True``; return ``None`` otherwise.
    dpi: ``float``, optional, default: 300.0
        The resolution of the figure in dots-per-inch.
    show_neg_for_sig: ``bool``, optional, default: False
        For signature scores (i.e. attribute type registered as 'signature'), if we should show negative scores or show them as zeros. Default is False (i.e. show them as zeros).
    **kwargs
        Accepted for call compatibility; not used by this function.

    Returns
    -------

    `Figure` object
        A ``matplotlib.figure.Figure`` object containing the scatter plots if ``return_fig == True``

    Raises
    ------
    KeyError
        If a requested basis is missing from ``data.obsm``, has fewer components than requested,
        or a numeric attribute has no colormap.

    Examples
    --------
    >>> pg.scatter(data, attrs=['louvain_labels', 'Channel'], basis='fitsne')
    >>> pg.scatter(data, attrs=['CD14', 'TRAC'], basis='umap')
    """
    if not isinstance(data, anndata.AnnData):
        cur_matkey = data.current_matrix()
    if matkey is not None:
        assert not isinstance(data, anndata.AnnData)
        data.select_matrix(matkey)

    # Everything below runs under try/finally so that the caller's matrix
    # selection is restored even when plotting raises.
    try:
        if attrs is None:
            attrs = ['_all']  # default, plot all points
            if palettes is None:
                palettes = '_all:slategrey'
        elif not is_list_like(attrs):
            attrs = [attrs]

        # Select only valid attributes
        attrs = _get_valid_attrs(data, attrs)

        if isinstance(basis, str):
            basis = [basis]
        if isinstance(components, tuple):
            components = [components]

        # check validity for basis and components
        max_comp = max(max(x[0], x[1]) for x in components)
        for basis_key in basis:
            rep = f"X_{basis_key}"
            if rep not in data.obsm:
                raise KeyError(f"Basis {basis_key} does not exist!")
            if data.obsm[rep].shape[1] < max_comp:
                raise KeyError(f"Basis {basis_key} only has {data.obsm[rep].shape[1]} components, less than max component {max_comp} specified in components!")

        nattrs = len(attrs)
        nbasis = len(basis)
        ncomps = len(components)
        nfigs = nattrs * nbasis * ncomps

        # Axes can only be shared when every panel shows the same coordinates.
        share_xy = (nbasis == 1) and (ncomps == 1)

        # Broadcast scalar per-attribute options to one entry per attribute.
        if not is_list_like(alpha):
            alpha = [alpha] * nattrs
        if not is_list_like(legend_loc):
            legend_loc = [legend_loc] * nattrs
        if not is_list_like(legend_fontsize):
            legend_fontsize = [legend_fontsize] * nattrs

        palettes = DictWithDefault(palettes)
        cmaps = DictWithDefault(cmaps)
        restr_obj = RestrictionParser(restrictions)
        restr_obj.calc_default(data)

        global_marker_size = None

        nrows, ncols = _get_nrows_and_ncols(nfigs, nrows, ncols)
        fig, axes = _get_subplot_layouts(
            nrows=nrows,
            ncols=ncols,
            panel_size=panel_size,
            dpi=dpi,
            left=left,
            bottom=bottom,
            wspace=wspace,
            hspace=hspace,
            squeeze=False,
            sharex=share_xy,
            sharey=share_xy,
        )

        for i in range(nrows):
            for j in range(ncols):
                ax = axes[i, j]
                ax.grid(False)
                ax.set_xticks([])
                ax.set_yticks([])
                ax.set_aspect(aspect)
                if i * ncols + j >= nfigs:
                    # Unused trailing panel: hide its frame.
                    ax.set_frame_on(False)

        # Panel index = offset of the (basis, components) pair + attr_id * (nbasis * ncomps).
        offset_start_ = 0
        offset_inc_ = nbasis * ncomps
        for basis_key in basis:
            basis_ = _transform_basis(basis_key)
            for comp_key in components:
                # Components are 1-based in the public interface.
                x = data.obsm[f"X_{basis_key}"][:, comp_key[0]-1]
                y = data.obsm[f"X_{basis_key}"][:, comp_key[1]-1]

                # four corners of the plot
                corners = np.array(np.meshgrid([x.min(), x.max()], [y.min(), y.max()])).T.reshape(-1, 2)

                if global_marker_size is None:
                    global_marker_size = _get_marker_size(x.size) if marker_size is None else marker_size

                x_label = f"{basis_}{comp_key[0]}"
                y_label = f"{basis_}{comp_key[1]}"

                pos = offset_start_
                for attr_id, attr in enumerate(attrs):
                    i = pos // ncols
                    j = pos % ncols
                    ax = axes[i, j]

                    if attr == '_all':  # if default
                        values = pd.Categorical.from_codes(np.zeros(data.shape[0], dtype=int), categories=['cell'])
                    elif attr in data.obs:
                        values = data.obs[attr].values
                        if data.get_attr_type(attr) == "signature" and (not show_neg_for_sig):
                            # Copy first so that data.obs is not modified in place.
                            values = values.copy()
                            values[values < 0.0] = 0.0
                    elif attr in data.var_names:
                        loc = data.var_names.get_loc(attr)
                        values = slicing(data.X, col = loc)
                    else:
                        obsm_key, sep, component = attr.partition("@")
                        values = data.obsm[obsm_key][:, int(component)]

                    selected = restr_obj.get_satisfied(data, attr)

                    local_marker_size = global_marker_size
                    if (marker_size is None) and (not fix_corners) and (is_numeric_dtype(values) or (not show_background)):
                        # Only the selected points are drawn, so size markers by their count.
                        local_marker_size = _get_marker_size(selected.sum())

                    if is_numeric_dtype(values):
                        # Numeric attribute
                        cmap = cmaps.get(attr, squeeze = True)
                        if cmap is None:
                            raise KeyError(f"Please set colormap for attribute {attr} or set a default colormap!")

                        if fix_corners:
                            _plot_corners(ax, corners, local_marker_size)

                        cur_vmin = vmin if not isinstance(vmin, list) else vmin[attr_id]
                        cur_vmax = vmax if not isinstance(vmax, list) else vmax[attr_id]

                        if scale_factor is None:
                            img = ax.scatter(
                                x[selected],
                                y[selected],
                                c=values[selected],
                                s=local_marker_size,
                                marker=".",
                                alpha=alpha[attr_id],
                                edgecolors="none",
                                cmap=cmap,
                                vmin=cur_vmin,
                                vmax=cur_vmax,
                                rasterized=True,
                            )
                        else:
                            img = _plot_spots(
                                x[selected] * scale_factor,
                                y[selected] * scale_factor,
                                c=values[selected],
                                s=local_marker_size,
                                alpha=alpha[attr_id],
                                edgecolors="none",
                                cmap=cmap,
                                vmin=cur_vmin,
                                vmax=cur_vmax,
                                rasterized=True,
                                ax=ax,
                            )

                        # Colorbar goes in a thin axes just right of the panel. Dedicated
                        # names avoid shadowing the `left`/`bottom` parameters.
                        cb_left, cb_bottom, cb_width, cb_height = ax.get_position().bounds
                        rect = [cb_left + cb_width * (1.0 + 0.05), cb_bottom, cb_width * 0.1, cb_height]
                        ax_colorbar = fig.add_axes(rect)
                        fig.colorbar(img, cax=ax_colorbar)

                    else:
                        # Categorical attribute
                        labels, with_background = _generate_categories(values, restr_obj.get_satisfied(data, attr))
                        label_size = labels.categories.size

                        palette = palettes.get(attr)
                        if palette is None:
                            palette = _get_palette(label_size, with_background=with_background, show_background=show_background)
                        elif with_background:
                            # Prepend the color used for the "" (background) category.
                            palette = ["gainsboro" if show_background else "white"] + list(palette)

                        text_list = []
                        for k, cat in enumerate(labels.categories):
                            idx = labels == cat
                            if idx.sum() > 0:
                                scatter_kwargs = {"alpha": alpha[attr_id], "edgecolors": "none", "rasterized": True}

                                if cat != "":
                                    if (legend_loc[attr_id] != "on data") and (scale_factor is None):
                                        scatter_kwargs["label"] = cat
                                    else:
                                        text_list.append((np.median(x[idx]), np.median(y[idx]), cat))

                                # The "" category holds the unselected points; draw it only on request.
                                if cat != "" or show_background:
                                    if scale_factor is None:
                                        ax.scatter(
                                            x[idx],
                                            y[idx],
                                            c=palette[k],
                                            s=local_marker_size,
                                            marker=".",
                                            **scatter_kwargs,
                                        )
                                    else:
                                        _plot_spots(
                                            x[idx] * scale_factor,
                                            y[idx] * scale_factor,
                                            c=palette[k],
                                            s=local_marker_size,
                                            ax=ax,
                                            **scatter_kwargs,
                                        )
                                else:
                                    # Background is hidden: presumably the corner markers keep
                                    # the axis extent tied to the full data set.
                                    if fix_corners:
                                        _plot_corners(ax, corners, local_marker_size)

                        if attr != '_all':
                            if legend_loc[attr_id] == "right margin":
                                if scale_factor is not None:
                                    # Spot plots carry no labels, so add empty proxy artists for the legend.
                                    for k, cat in enumerate(labels.categories):
                                        ax.scatter([], [], c=palette[k], label=cat)
                                legend = ax.legend(
                                    loc="center left",
                                    bbox_to_anchor=(1, 0.5),
                                    frameon=False,
                                    fontsize=legend_fontsize[attr_id],
                                    ncol=_get_legend_ncol(label_size, legend_ncol),
                                )
                                # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles
                                # and 3.9 removed the old name; support both.
                                handles = getattr(legend, "legend_handles", None)
                                if handles is None:
                                    handles = legend.legendHandles
                                for handle in handles:
                                    handle.set_sizes([300.0 if scale_factor is None else 100.0])
                            elif legend_loc[attr_id] == "on data":
                                # Labels sit at the per-category median position.
                                for px, py, txt in text_list:
                                    ax.text(px, py, txt, fontsize=legend_fontsize[attr_id], fontweight = "bold", ha = "center", va = "center")

                    if attr != '_all':
                        ax.set_title(attr)

                    # With shared axes only the bottom-most panel of each column gets an x label
                    # and only the first column gets a y label.
                    if (share_xy and (i + 1) * ncols + j >= nfigs) or (not share_xy):
                        ax.set_xlabel(x_label)

                    if (share_xy and j == 0) or (not share_xy):
                        ax.set_ylabel(y_label)

                    pos += offset_inc_
                offset_start_ += 1
    finally:
        # Reset current matrix if needed.
        if not isinstance(data, anndata.AnnData):
            if cur_matkey != data.current_matrix():
                data.select_matrix(cur_matkey)

    return fig if return_fig else None
def scatter_groups(
    data: Union[MultimodalData, UnimodalData, anndata.AnnData],
    attr: str,
    groupby: str,
    basis: Optional[str] = "umap",
    matkey: Optional[str] = None,
    restrictions: Optional[Union[str, List[str]]] = None,
    show_full: Optional[bool] = True,
    categories: Optional[List[str]] = None,
    alpha: Optional[float] = 1.0,
    legend_loc: Optional[str] = "right margin",
    legend_ncol: Optional[str] = None,
    palette: Optional[str] = None,
    cmap: Optional[str] = "YlOrRd",
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    nrows: Optional[int] = None,
    ncols: Optional[int] = None,
    panel_size: Optional[Tuple[float, float]] = (4, 4),
    left: Optional[float] = 0.2,
    bottom: Optional[float] = 0.15,
    wspace: Optional[float] = 0.4,
    hspace: Optional[float] = 0.15,
    return_fig: Optional[bool] = False,
    dpi: Optional[float] = 300.0,
    show_neg_for_sig: Optional[bool] = False,
    **kwargs,
) -> Union[plt.Figure, None]:
    """Generate scatter plots of attribute 'attr' for each category in attribute 'groupby'.

    Optionally show a scatter plot containing data points from all categories in 'groupby'.

    Parameters
    ----------

    data: ``pegasusio.MultimodalData``
        Use current selected modality in data.
    attr: ``str``
        Color scatter plots by attribute 'attr'. This attribute should be one key in data.obs, data.var_names (e.g. one gene) or data.obsm (attribute has the format of 'obsm_key@component', like 'X_pca@0'). If it is categorical, a palette will be used to color each category separately. Otherwise, a color map will be used.
    groupby: ``str``
        Generate separate scatter plots of 'attr' for data points in each category in 'groupby', which should be a key in data.obs representing one categorical variable.
    basis: ``str``, optional, default: ``umap``
        Basis to be used to generate scatter plots. Can be either 'umap', 'tsne', 'fitsne', 'fle', 'net_tsne', 'net_fitsne', 'net_umap' or 'net_fle'. The first two components of the basis are plotted.
    matkey: ``str``, optional, default: None
        If matkey is set, select matrix with matkey as keyword in the current modality. Only works for MultimodalData or UnimodalData objects.
    restrictions: ``str`` or ``List[str]``, optional, default: None
        A list of restrictions to subset data for plotting. Each restriction takes the format of 'key:value,value...', or 'key:~value,value...'. This restriction selects cells with the ``data.obs[key]`` values belong to 'value,value...' (or not belong to if '~' shows).
    show_full: ``bool``, optional, default: True
        Show the scatter plot with all categories in 'groupby' as the first plot.
    categories: ``List[str]``, optional, default: None
        Redefine group structure based on attribute 'groupby'. If 'categories' is not None, each string in the list takes the format of 'category_name:value,value', or 'category_name:~value,value...', where 'category_name' refers to new category name, 'value' refers to one of the category in 'groupby' and '~' refers to exclude values.
    alpha: ``float``, optional, default: ``1.0``
        Alpha value for blending, from 0.0 (transparent) to 1.0 (opaque).
    legend_loc: ``str``, optional, default: ``right margin``
        Legend location. Can be either "right margin" or "on data".
    legend_ncol: ``str``, optional, default: None
        Only applicable if legend_loc == "right margin". Set number of columns used to show legends.
    palette: ``str``, optional, default: None
        Used for setting colors for one categorical attribute (e.g. "black,blue,red,...,yellow").
    cmap: ``str``, optional, default: ``YlOrRd``
        Used for setting colormap for one numeric attribute.
    vmin: ``float``, optional, default: None
        Minimum value to show on a numeric scatter plot (feature plot).
    vmax: ``float``, optional, default: None
        Maximum value to show on a numeric scatter plot (feature plot).
    nrows: ``int``, optional, default: None
        Number of rows in the figure. If not set, pegasus will figure it out automatically.
    ncols: ``int``, optional, default: None
        Number of columns in the figure. If not set, pegasus will figure it out automatically.
    panel_size: `tuple`, optional (default: `(4, 4)`)
        The panel size (width, height) in inches.
    left: `float`, optional (default: `0.2`)
        This parameter sets the figure's left margin as a fraction of panel's width (left * panel_size[0]).
    bottom: `float`, optional (default: `0.15`)
        This parameter sets the figure's bottom margin as a fraction of panel's height (bottom * panel_size[1]).
    wspace: `float`, optional (default: `0.4`)
        This parameter sets the width between panels and also the figure's right margin as a fraction of panel's width (wspace * panel_size[0]).
    hspace: `float`, optional (default: `0.15`)
        This parameter sets the height between panels and also the figure's top margin as a fraction of panel's height (hspace * panel_size[1]).
    return_fig: ``bool``, optional, default: ``False``
        Return a ``Figure`` object if ``True``; return ``None`` otherwise.
    dpi: ``float``, optional, default: 300.0
        The resolution of the figure in dots-per-inch.
    show_neg_for_sig: ``bool``, optional, default: False
        For signature scores (i.e. attribute type registered as 'signature'), if we should show negative scores or show them as zeros. Default is False (i.e. show them as zeros).
    **kwargs
        Accepted for call compatibility; not used by this function.

    Returns
    -------

    `Figure` object
        A ``matplotlib.figure.Figure`` object containing the scatter plots if ``return_fig == True``

    Raises
    ------
    KeyError
        If ``attr`` is not found in data.obs, data.var_names or data.obsm.

    Examples
    --------
    >>> pg.scatter_groups(data, attr='louvain_labels', groupby='Individual', basis='tsne', nrows = 2, ncols = 4, alpha = 0.5)
    >>> pg.scatter_groups(data, attr='anno', groupby='Channel', basis='umap', categories=['new_cat1:channel1,channel2', 'new_cat2:channel3'])
    """
    if not isinstance(data, anndata.AnnData):
        cur_matkey = data.current_matrix()
    if matkey is not None:
        assert not isinstance(data, anndata.AnnData)
        data.select_matrix(matkey)

    # Everything below runs under try/finally so that the caller's matrix
    # selection is restored even when plotting raises.
    try:
        x = data.obsm[f"X_{basis}"][:, 0]
        y = data.obsm[f"X_{basis}"][:, 1]

        # four corners of the plot (computed on all points, before any restriction)
        corners = np.array(np.meshgrid([x.min(), x.max()], [y.min(), y.max()])).T.reshape(-1, 2)

        basis = _transform_basis(basis)
        marker_size = _get_marker_size(x.size)

        if attr in data.obs:
            values = data.obs[attr].values
            if data.get_attr_type(attr) == "signature" and (not show_neg_for_sig):
                # Copy first so that data.obs is not modified in place.
                values = values.copy()
                values[values < 0.0] = 0.0
        elif attr in data.var_names:
            loc = data.var_names.get_loc(attr)
            values = slicing(data.X, col = loc)
        else:
            obsm_key, sep, component = attr.partition("@")
            if (sep != "@") or (obsm_key not in data.obsm) or (not component.isdigit()):
                raise KeyError(f"{attr} is not in data.obs, data.var_names or data.obsm!")
            values = data.obsm[obsm_key][:, int(component)]

        # pandas deprecated is_categorical_dtype (2.1); check the dtype class directly.
        is_cat = isinstance(values.dtype, pd.CategoricalDtype)
        if (not is_cat) and (not is_numeric_dtype(values)):
            # Non-numeric, non-categorical values (e.g. strings) become naturally sorted categories.
            values = pd.Categorical(values, categories=natsorted(np.unique(values)))
            is_cat = True

        assert groupby in data.obs
        groups = data.obs[groupby].values
        if not isinstance(groups.dtype, pd.CategoricalDtype):
            groups = pd.Categorical(groups, categories=natsorted(np.unique(groups)))

        restr_obj = RestrictionParser(restrictions)
        restr_obj.calc_default(data)
        selected = restr_obj.get_satisfied(data)

        nsel = selected.sum()
        if nsel < data.shape[0]:
            x = x[selected]
            y = y[selected]
            values = values[selected]
            groups = groups[selected]

        # One boolean column per panel: membership of each (selected) point in that group.
        df_g = pd.DataFrame()
        if show_full:
            df_g["All"] = np.ones(nsel, dtype=bool)
        if categories is None:
            for cat in groups.categories:
                df_g[cat] = groups == cat
        else:
            cat_obj = RestrictionParser(categories)
            for cat, idx in cat_obj.next_category(groups):
                df_g[cat] = idx

        nrows, ncols = _get_nrows_and_ncols(df_g.shape[1], nrows, ncols)
        fig, axes = _get_subplot_layouts(
            nrows=nrows,
            ncols=ncols,
            panel_size=panel_size,
            dpi=dpi,
            left=left,
            bottom=bottom,
            wspace=wspace,
            hspace=hspace,
            squeeze=False,
        )

        legend_fontsize = 5 if legend_loc == 'on data' else 10

        if is_cat:
            labels = values
            label_size = labels.categories.size
            palette = _get_palette(label_size) if palette is None else np.array(palette.split(","))
            legend_ncol = _get_legend_ncol(label_size, legend_ncol)

        for i in range(nrows):
            for j in range(ncols):
                ax = axes[i, j]
                ax.grid(False)
                ax.set_xticks([])
                ax.set_yticks([])

                gid = i * ncols + j
                if gid < df_g.shape[1]:
                    # Membership mask of this panel's group; invariant across categories.
                    in_group = df_g.iloc[:, gid].values
                    if is_cat:
                        text_list = []
                        for k, cat in enumerate(labels.categories):
                            idx = np.logical_and(in_group, labels == cat)
                            # NOTE(review): corners are re-plotted for every category; this looks
                            # loop-invariant but is kept in place to preserve the drawn artists.
                            _plot_corners(ax, corners, marker_size)

                            if idx.sum() > 0:
                                scatter_kwargs = {"marker": ".", "alpha": alpha, "edgecolors": "none", "rasterized": True}

                                if legend_loc != "on data":
                                    scatter_kwargs["label"] = str(cat)
                                else:
                                    text_list.append((np.median(x[idx]), np.median(y[idx]), str(cat)))

                                ax.scatter(
                                    x[idx],
                                    y[idx],
                                    c=palette[k],
                                    s=marker_size,
                                    **scatter_kwargs,
                                )

                        if legend_loc == "right margin":
                            legend = ax.legend(
                                loc="center left",
                                bbox_to_anchor=(1, 0.5),
                                frameon=False,
                                fontsize=legend_fontsize,
                                ncol=legend_ncol,
                            )
                            # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles
                            # and 3.9 removed the old name; support both.
                            handles = getattr(legend, "legend_handles", None)
                            if handles is None:
                                handles = legend.legendHandles
                            for handle in handles:
                                handle.set_sizes([300.0])
                        elif legend_loc == "on data":
                            # Labels sit at the per-category median position.
                            for px, py, txt in text_list:
                                ax.text(px, py, txt, fontsize=legend_fontsize, fontweight = "bold", ha = "center", va = "center")
                    else:
                        _plot_corners(ax, corners, marker_size)

                        img = ax.scatter(
                            x[in_group],
                            y[in_group],
                            s=marker_size,
                            c=values[in_group],
                            marker=".",
                            alpha=alpha,
                            edgecolors="none",
                            cmap=cmap,
                            vmin=vmin,
                            vmax=vmax,
                            rasterized=True,
                        )
                        # Colorbar goes in a thin axes just right of the panel. Dedicated
                        # names avoid shadowing the `left`/`bottom` parameters.
                        cb_left, cb_bottom, cb_width, cb_height = ax.get_position().bounds
                        rect = [cb_left + cb_width * (1.0 + 0.05), cb_bottom, cb_width * 0.1, cb_height]
                        ax_colorbar = fig.add_axes(rect)
                        fig.colorbar(img, cax=ax_colorbar)

                    ax.set_title(str(df_g.columns[gid]))
                else:
                    # Unused trailing panel: hide its frame.
                    ax.set_frame_on(False)

                if i == nrows - 1:
                    ax.set_xlabel(basis + "1")
                if j == 0:
                    ax.set_ylabel(basis + "2")
    finally:
        # Reset current matrix if needed.
        if not isinstance(data, anndata.AnnData):
            if cur_matkey != data.current_matrix():
                data.select_matrix(cur_matkey)

    return fig if return_fig else None
def spatial(
    data: Union[MultimodalData, UnimodalData, anndata.AnnData],
    attrs: Optional[Union[str, List[str]]] = None,
    basis: str = 'spatial',
    resolution: str = 'hires',
    cmaps: Optional[Union[str, List[str]]] = 'viridis',
    restrictions: Optional[Union[str, List[str]]] = None,
    show_background: Optional[bool] = False,
    palettes: Optional[Union[str, List[str]]] = None,
    vmin: Optional[Union[float, List[float]]] = None,
    vmax: Optional[Union[float, List[float]]] = None,
    alpha: Union[float, List[float]] = 1.0,
    alpha_img: float = 1.0,
    nrows: Optional[int] = None,
    ncols: Optional[int] = None,
    y_flip: bool = True,
    margin_percent: float = 0.05,
    aspect: Optional[str] = "equal",
    dpi: float = 300.0,
    return_fig: bool = False,
    **kwargs,
) -> Union[plt.Figure, None]:
    """Scatter plot on spatial coordinates.

    This function is inspired by SCANPY's `pl.spatial <https://scanpy.readthedocs.io/en/latest/generated/scanpy.pl.spatial.html#scanpy-pl-spatial>`_ function.

    Parameters
    ----------

    data: ``pegasusio.MultimodalData`` or ``pegasusio.UnimodalData`` or ``anndata.AnnData``
        Use current selected modality in data.
    attrs: ``str`` or ``List[str]``, optional, default: ``None``
        Color scatter plots by attrs. Each attribute should be one key in data.obs, data.var_names (e.g. one gene) or data.obsm (attribute has the format of 'obsm_key@component', like 'X_pca@0'). If it is categorical, a palette will be used to color each category separately. Otherwise, a color map will be used. If ``None``, just plot data points of the same color.
    basis: ``str``, optional, default: ``spatial``
        Basis to be used to generate spatial plots. Must be the 2D array showing the spatial coordinates of data points.
    resolution: ``str``, optional, default: ``hires``
        Use the spatial image whose value is specified in ``data.img['image_id']`` to show in background.
        For 10X Visium data, user can either specify ``hires`` or ``lowres`` to use High or Low resolution spatial images, respectively.
        Alternatively, if ``data.img`` does not exist, then no spatial image will be shown.
    cmaps: ``str`` or ``List[str]``, optional, default: ``viridis``
        The colormap(s) for plotting numeric attributes. The default ``viridis`` colormap theme follows the spatial plot function in SCANPY (``scanpy.pl.spatial``).
    restrictions: ``str`` or ``List[str]``, optional, default: None
        A list of restrictions to subset data for plotting. There are two types of restrictions: global restriction and attribute-specific restriction. Global restriction applies to all attributes in ``attrs`` and takes the format of 'key:value,value...', or 'key:~value,value...'. This restriction selects cells with the ``data.obs[key]`` values belong to 'value,value...' (or not belong to if '~' shows). Attribute-specific restriction takes the format of 'attr:key:value,value...', or 'attr:key:~value,value...'. It only applies to one attribute 'attr'. If 'attr' and 'key' are the same, one can use '.' to replace 'key' (e.g. ``cluster_labels:.:value1,value2``).
    show_background: ``bool``, optional, default: False
        Only applicable if `restrictions` is set. By default, only data points selected are shown. If show_background is True, data points that are not selected will also be shown.
    palettes: ``str`` or ``List[str]``, optional, default: None
        Used for setting colors for every categories in categorical attributes. Each string in ``palettes`` takes the format of 'attr:color1,color2,...,colorn'. 'attr' is the categorical attribute and 'color1' - 'colorn' are the colors for each category in 'attr' (e.g. 'cluster_labels:black,blue,red,...,yellow'). If there is only one categorical attribute in 'attrs', ``palettes`` can be set as a single string and the 'attr' keyword can be omitted (e.g. "blue,yellow,red").
    vmin: ``float``, optional, default: ``None``
        Minimum value to show on a numeric scatter plot (feature plot).
    vmax: ``float``, optional, default: ``None``
        Maximum value to show on a numeric scatter plot (feature plot).
    alpha: ``float`` or ``List[float]``, optional, default: ``1.0``
        Alpha value for blending the attribute layers, from 0.0 (transparent) to 1.0 (opaque). If this is a list, the length must match attrs, which means we set a separate alpha value for each attribute.
    alpha_img: ``float``, optional, default: ``1.0``
        Alpha value for blending the background spatial image, from 0.0 (transparent) to 1.0 (opaque).
    nrows: ``int``, optional, default: ``None``
        Number of rows in the figure. If not set, pegasus will figure it out automatically.
    ncols: ``int``, optional, default: ``None``
        Number of columns in the figure. If not set, pegasus will figure it out automatically.
    y_flip: ``bool``, optional, default: ``True``
        Set to ``True`` if flipping the y axis is needed. This is for the case when y-coordinate origin starts from the top.
        If a spatial image is shown (``resolution`` is in effect), the y axis is always flipped and this parameter is ignored.
    margin_percent: ``float``, optional, default: ``0.05``
        The margin is set to ``margin_percent``*100% of the smaller edge of the image size in each of the 4 sides.
    aspect: ``str``, optional (default: ``equal``)
        Set the aspect of the axis scaling, i.e. the ratio of y-unit to x-unit. Set ``auto`` to fill the position rectangle with data; ``equal`` for the same scaling for x and y. It applies to all subplots.
    dpi: ``float``, optional, default: ``300.0``
        The resolution of the figure in dots-per-inch.
    return_fig: ``bool``, optional, default: ``False``
        Return a ``Figure`` object if ``True``; return ``None`` otherwise.
    **kwargs
        Accepted for call compatibility; not used by this function.

    Returns
    -------

    ``Figure`` object
        A ``matplotlib.figure.Figure`` object containing the spatial plots if ``return_fig == True``

    Examples
    --------
    >>> pg.spatial(data, attrs=['louvain_labels', 'Channel'])
    >>> pg.spatial(data, attrs=['CD14', 'TRAC'], resolution='lowres')
    """
    assert f"X_{basis}" in data.obsm.keys(), f"'X_{basis}' coordinates do not exist!"

    # ``anndata.AnnData`` objects have no ``img`` attribute; treat that like "no image"
    # instead of failing with AttributeError.
    img_table = getattr(data, "img", None)
    if img_table is None:
        resolution = None
    elif resolution:
        assert resolution in img_table['image_id'].values, f"'{resolution}' image does not exist!"

    if attrs is not None:
        if not is_list_like(attrs):
            attrs = [attrs]
        # Select only valid attributes
        attrs = _get_valid_attrs(data, attrs)

    # With attrs=None, scatter() draws a single panel of all points.
    nattrs = len(attrs) if attrs is not None else 1

    if resolution:
        image_item = img_table.loc[img_table['image_id']==resolution]
        image_obj = image_item['data'].iat[0]
        scale_factor = image_item['scale_factor'].iat[0]
        spot_radius = image_item['spot_diameter'].iat[0] * 0.5
    else:
        scale_factor = None
        spot_radius = None

    # Always request the figure from scatter(); whether it is returned to the
    # caller is decided at the end of this function.
    fig = scatter(
        data=data,
        attrs=attrs,
        basis=basis,
        marker_size=spot_radius,
        scale_factor=scale_factor,
        cmaps=cmaps,
        restrictions=restrictions,
        show_background=show_background,
        palettes=palettes,
        vmin=vmin,
        vmax=vmax,
        nrows=nrows,
        ncols=ncols,
        dpi=dpi,
        alpha=alpha,
        aspect=aspect,
        return_fig=True,
    )

    if scale_factor is None:
        scale_factor = 1.0

    # Axis limits are the scaled coordinate range plus a margin on every side.
    coords = data.obsm[f"X_{basis}"]
    coord_x = (coords[:, 0].min() * scale_factor, coords[:, 0].max() * scale_factor)
    coord_y = (coords[:, 1].min() * scale_factor, coords[:, 1].max() * scale_factor)
    margin_offset = min(np.abs(coord_x[1] - coord_x[0]), np.abs(coord_y[1] - coord_y[0])) * margin_percent

    # The first ``nattrs`` entries of ``fig.axes`` are taken as the attribute panels.
    # NOTE(review): relies on scatter() creating panel axes before any colorbar axes - confirm.
    for i in range(nattrs):
        ax = fig.axes[i]
        if resolution:
            ax.imshow(image_obj, alpha=alpha_img)
        ax.set_xlim(coord_x[0]-margin_offset, coord_x[1]+margin_offset)
        if resolution or y_flip:
            # Inverted limits put the y origin at the top, matching image coordinates.
            ax.set_ylim(coord_y[1]+margin_offset, coord_y[0]-margin_offset)
        else:
            ax.set_ylim(coord_y[0]-margin_offset, coord_y[1]+margin_offset)

    return fig if return_fig else None
def compo_plot(
    data: Union[MultimodalData, UnimodalData, anndata.AnnData],
    groupby: str,
    condition: str,
    style: Optional[str] = "frequency",
    restrictions: Optional[Union[str, List[str]]] = None,
    switch_axes: Optional[bool] = False,
    groupby_label: Optional[str] = None,
    sort_function: Union[Callable[[List[str]], List[str]], str] = 'natsorted',
    panel_size: Optional[Tuple[float, float]] = (6, 4),
    palette: Optional[List[str]] = None,
    color_unused: bool = False,
    left: Optional[float] = 0.15,
    bottom: Optional[float] = 0.15,
    wspace: Optional[float] = 0.3,
    hspace: Optional[float] = 0.15,
    return_fig: Optional[bool] = False,
    dpi: Optional[float] = 300.0,
    **kwargs,
) -> Union[plt.Figure, None]:
    """Generate a composition plot, which shows the percentage of cells from each condition for every cluster.

    This function is used to generate composition plots, which are bar plots showing the cell compositions (from different conditions) for each cluster. This type of plots is useful to fast assess library quality and batch effects.

    Parameters
    ----------

    data : ``AnnData`` or ``UnimodalData`` or ``MultimodalData`` object
        Single cell expression data.
    groupby : ``str``
        A categorical variable in data.obs that is used to categorize the cells, e.g. cell type.
    condition: ``str``
        A categorical variable in data.obs that is used to calculate frequency within each category defined by ``groupby``, e.g. donor.
    style: ``str``, optional (default: ``frequency``)
        Composition plot style. Can be either ``frequency``, or ``normalized``. Within each cluster, the ``frequency`` style show the percentage of cells from each ``condition`` within each category in ``groupby`` (stacked), the ``normalized`` style shows for each category in ``groupby`` the percentage of cells that are also in each ``condition`` over all cells that are in the same ``condition`` (not stacked).
    restrictions: ``str`` or ``List[str]``, optional, default: None
        A list of restrictions to subset data for plotting. Each restriction takes the format of 'key:value,value...', or 'key:~value,value...'. This restriction selects cells with the ``data.obs[key]`` values belong to 'value,value...' (or not belong to if '~' shows).
    switch_axes: ``bool``, optional, default: ``False``
        By default, X axis is for groupby, and Y axis for frequencies with respect to condition. If this parameter is ``True``, switch the axes.
    groupby_label: ``str``, optional (default ``None``)
        Label for the axis displaying ``groupby`` categories. If ``None``, use ``groupby``.
    sort_function: ``Union[Callable[List[str], List[str]], str]``, optional, default: ``natsorted``
        Function used for sorting both groupby and condition labels. If ``natsorted``, apply natsorted function to sort by natural order. If ``None``, don't sort. Otherwise, a callable function will be applied to the labels for sorting.
    panel_size: ``tuple``, optional (default: ``(6, 4)``)
        The plot size (width, height) in inches.
    palette: ``List[str]``, optional (default: ``None``)
        Used for setting colors for categories in ``condition``. Within the list, each string is the color for one category.
    color_unused: ``bool``, optional (default: ``False``)
        If ``True``, the palette is sized against all categories of ``data.obs[condition]`` instead of only the conditions left after applying ``restrictions``, and a user-provided ``palette`` is indexed by each plotted condition's position among those categories.
    left: ``float``, optional (default: ``0.15``)
        This parameter sets the figure's left margin as a fraction of panel's width (left * panel_size[0]).
    bottom: ``float``, optional (default: ``0.15``)
        This parameter sets the figure's bottom margin as a fraction of panel's height (bottom * panel_size[1]).
    wspace: ``float``, optional (default: ``0.3``)
        This parameter sets the width between panels and also the figure's right margin as a fraction of panel's width (wspace * panel_size[0]).
    hspace: ``float``, optional (default: ``0.15``)
        This parameter sets the height between panels and also the figure's top margin as a fraction of panel's height (hspace * panel_size[1]).
    return_fig: ``bool``, optional, default: ``False``
        Return a ``Figure`` object if ``True``; return ``None`` otherwise.
    dpi: ``float``, optional, default: ``300.0``
        The resolution in dots per inch.
    kwargs
        Accepted for interface compatibility; not used by this function.

    Returns
    -------

    ``Figure`` object
        A ``matplotlib.figure.Figure`` object containing the composition plot if ``return_fig == True``

    Examples
    --------
    >>> fig = pg.compo_plot(data, 'louvain_labels', 'Donor', style = 'normalized')
    """
    if groupby_label is None:
        groupby_label = groupby

    fig, ax = _get_subplot_layouts(panel_size=panel_size, dpi=dpi, left=left, bottom=bottom, wspace=wspace, hspace=hspace) # default nrows = 1 & ncols = 1

    restr_obj = RestrictionParser(restrictions)
    restr_obj.calc_default(data)
    selected = restr_obj.get_satisfied(data)

    # Rows are groupby categories, columns are condition categories, values are cell counts.
    df = pd.crosstab(data.obs.loc[selected, groupby], data.obs.loc[selected, condition])

    index_values = df.index.tolist()
    column_values = df.columns.tolist()
    if sort_function == "natsorted":
        sort_function = natsorted
    if callable(sort_function):
        index_values = sort_function(index_values)
        column_values = sort_function(column_values)
    if switch_axes:
        # Horizontal bars are drawn bottom-up, so reverse to keep the sorted order top-down.
        index_values.reverse()
    df = df.reindex(index = index_values, columns = column_values)

    if style == "frequency":
        # Percentage of each condition within one groupby category (rows sum to 100).
        df = df.div(df.sum(axis=1), axis=0) * 100.0
    else:
        assert style == "normalized", "style must be either 'frequency' or 'normalized'!"
        # Percentage of each groupby category within one condition (columns sum to 100).
        df = df.div(df.sum(axis=0), axis=1) * 100.0

    if color_unused:
        all_categories = data.obs[condition].cat.categories
        if palette is None:
            # NOTE(review): unlike the user-palette branch below, the default palette is not
            # re-indexed by category position here -- confirm whether that is intended.
            color_list = _get_palette(all_categories.size)
        else:
            assert len(palette) >= all_categories.size, "The palette provided has fewer colors than needed!"
            color_idx = df.columns.map(all_categories.get_loc)
            # Index element-wise: a plain Python list cannot be indexed by a pandas Index.
            color_list = [palette[i] for i in color_idx]
    else:
        if palette is None:
            color_list = _get_palette(df.shape[1])
        else:
            assert len(palette) >= df.shape[1], "The palette provided has fewer colors than needed!"
            color_list = palette[0:df.shape[1]]

    df.plot(
        kind = "bar" if not switch_axes else "barh",
        stacked = style == "frequency",
        legend = False,
        color = color_list,
        ax = ax,
    )

    ax.grid(False)
    if not switch_axes:
        ax.set_xlabel(groupby_label)
        ax.set_ylabel("Percentage")
    else:
        ax.set_xlabel("Percentage")
        ax.set_ylabel(groupby_label)
    ax.legend(loc="center left", bbox_to_anchor=(1.05, 0.5))

    # Slant long groupby labels, but only when they are actually on the X axis; with
    # switched axes the X axis holds auto-located percentage ticks that must not be frozen.
    if not switch_axes:
        longest_label = max((len(str(label)) for label in df.index), default=0)
        if longest_label >= 5:
            ax.set_xticklabels(ax.get_xticklabels(), rotation=-45, ha='left')

    return fig if return_fig else None
def violin(
    data: Union[MultimodalData, UnimodalData, anndata.AnnData],
    attrs: Union[str, List[str]],
    groupby: str,
    hue: Optional[str] = None,
    matkey: Optional[str] = None,
    stripplot: Optional[bool] = False,
    stripsize: int = 1,
    inner: Optional[str] = None,
    scale: Optional[str] = 'width',
    panel_size: Optional[Tuple[float, float]] = (8, 0.5),
    palette: Optional[List[str]] = None,
    left: Optional[float] = 0.15,
    bottom: Optional[float] = 0.15,
    wspace: Optional[float] = 0.1,
    ylabel: Optional[str] = None,
    return_fig: Optional[bool] = False,
    dpi: Optional[float] = 300.0,
    **kwargs,
) -> Union[plt.Figure, None]:
    """
    Generate a stacked violin plot.

    Parameters
    ----------
    data: ``AnnData`` or ``MultimodalData`` or ``UnimodalData`` object
        Single-cell expression data.
    attrs: ``str`` or ``List[str]``
        Cell attributes or features to plot.
        Cell attributes must exist in ``data.obs`` and must be numeric.
        Features must exist in ``data.var``.
    groupby: ``str``
        A categorical variable in data.obs that is used to categorize the cells, e.g. Clusters.
    hue: ``str``, optional, default: None
        'hue' should be a categorical variable in data.obs that has only two levels. Set 'hue' will show us split violin plots.
    matkey: ``str``, optional, default: ``None``
        If matkey is set, select matrix with matkey as keyword in the current modality. Only works for MultimodalData or UnimodalData objects.
        The previously selected matrix is restored before this function exits, including when it exits with an error.
    stripplot: ``bool``, optional, default: ``False``
        Attach a stripplot to the violinplot or not. This option will be automatically turn off if 'hue' is set.
    stripsize: ``int``, optional, default: ``1``
        Marker size passed to ``seaborn.stripplot`` as ``size``. Only used if ``stripplot`` is ``True``.
    inner: ``str``, optional, default: ``None``
        Representation of the datapoints in the violin interior:
            - If ``box``, draw a miniature boxplot.
            - If ``quartiles``, draw the quartiles of the distribution.
            - If ``point`` or ``stick``, show each underlying datapoint.
            - If ``None``, will draw unadorned violins.
    scale: ``str``, optional, default: ``width``
        The method used to scale the width of each violin (passed to ``seaborn.violinplot`` as ``density_norm``):
            - If ``width``, each violin will have the same width.
            - If ``area``, each violin will have the same area.
            - If ``count``, the width of the violins will be scaled by the number of observations in that bin.
    panel_size: ``Tuple[float, float]``, optional, default: ``(8, 0.5)``
        The size (width, height) in inches of each violin panel.
    palette: ``List[str]``, optional (default: ``None``)
        Used for setting colors for categories in ``groupby``. Within the list, each string is the color for one category.
    left: ``float``, optional, default: ``0.15``
        This parameter sets the figure's left margin as a fraction of panel's width (left * panel_size[0]).
    bottom: ``float``, optional, default: ``0.15``
        This parameter sets the figure's bottom margin as a fraction of panel's height (bottom * panel_size[1]).
    wspace: ``float``, optional, default: ``0.1``
        This parameter sets the width between panels and also the figure's right margin as a fraction of panel's width (wspace * panel_size[0]).
    ylabel: ``str``, optional, default: ``None``
        Y-axis label. No label to show if ``None``.
    return_fig: ``bool``, optional, default: ``False``
        Return a ``Figure`` object if ``True``; return ``None`` otherwise.
    dpi: ``float``, optional, default: ``300.0``
        The resolution in dots per inch.
    kwargs
        Are passed to ``seaborn.violinplot``. The ``hue``, ``split`` and (when ``hue`` is ``None``) ``legend`` entries are overwritten by this function.

    Returns
    -------

    ``Figure`` object
        A ``matplotlib.figure.Figure`` object containing the violin plot if ``return_fig == True``

    Examples
    --------
    >>> pg.violin(data, attrs=['CD14', 'TRAC', 'CD34'], groupby='louvain_labels')
    """
    if not is_list_like(attrs):
        attrs = [attrs]

    is_anndata = isinstance(data, anndata.AnnData)
    if not is_anndata:
        # Remember the caller's matrix selection so it can be restored on exit.
        cur_matkey = data.current_matrix()

    try:
        if matkey is not None:
            assert not is_anndata
            data.select_matrix(matkey)

        # Filter out attributes not existing in the data
        attrs = _get_valid_attrs(data, attrs)

        nrows = len(attrs)
        fig, axes = _get_subplot_layouts(nrows=nrows, ncols=1, panel_size=panel_size, dpi=dpi, left=left, bottom=bottom, wspace=wspace, hspace=0, squeeze=False, sharey=False)

        # Split the requested attributes into data.obs columns and features (genes).
        obs_keys = []
        genes = []
        for key in attrs:
            if key in data.obs:
                assert is_numeric_dtype(data.obs[key])
                obs_keys.append(key)
            else:
                genes.append(key)

        df_list = [pd.DataFrame({"label": data.obs[groupby].values})]

        if hue is not None:
            df_list.append(pd.DataFrame({hue: data.obs[hue].values}))
            stripplot = False
            kwargs['hue'] = hue
            kwargs['split'] = True
        else:
            # Color violins by their own group label and suppress the resulting legend.
            kwargs['hue'] = "label"
            kwargs['legend'] = False
            kwargs['split'] = False

        if len(obs_keys) > 0:
            df_list.append(data.obs[obs_keys].reset_index(drop=True))
        if len(genes) > 0:
            expr_mat = slicing(data[:, genes].X)
            df_list.append(pd.DataFrame(data=expr_mat, columns=genes))
        df = pd.concat(df_list, axis = 1)

        for i, attr in enumerate(attrs):
            ax = axes[i, 0]
            if stripplot:
                sns.stripplot(x="label", y=attr, hue = hue, data=df, ax=ax, size=stripsize, color="k", jitter=True)
            sns.violinplot(x="label", y=attr, data=df, inner=inner, linewidth=1, ax=ax, cut=0, density_norm=scale, palette=palette, **kwargs)
            ax.grid(False)

            if hue is not None:
                # Show a single legend, attached to the top panel.
                if i == 0:
                    ax.legend(loc="center left", bbox_to_anchor=(1.02, 0.5))
                else:
                    ax.get_legend().set_visible(False)

            if i < nrows - 1:
                ax.set_xlabel("")
            else:
                ax.set_xlabel(groupby)
                ax.set_xticks(ax.get_xticks())    # Get rid of the UserWarning: set_ticklabels() should only be used with a fixed number of ticks
                ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
            ax.set_ylabel(attr, labelpad=8, rotation=0, horizontalalignment='right', fontsize='medium')
            ax.tick_params(axis='y', right=True, left=False, labelright=True, labelleft=False, labelsize='small')

        if ylabel is not None:
            fig.text(0.02, 0.5, ylabel, rotation="vertical", fontsize="xx-large")
    finally:
        # Reset current matrix if needed, on both the success and the error path.
        if not is_anndata and data.current_matrix() != cur_matkey:
            data.select_matrix(cur_matkey)

    return fig if return_fig else None
def heatmap(
    data: Union[MultimodalData, UnimodalData, anndata.AnnData],
    attrs: Union[str, List[str]],
    groupby: Optional[str] = None,
    matkey: Optional[str] = None,
    gene_zscore: Optional[bool] = True,
    on_average: Optional[bool] = True,
    switch_axes: Optional[bool] = False,
    attrs_cluster: Optional[bool] = False,
    attrs_dendrogram: Optional[bool] = True,
    attrs_method: Optional[str] = 'ward',
    attrs_optimal_ordering: Optional[bool] = True,
    xlabel_size: Optional[float] = 10.0,
    ylabel_size: Optional[float] = 10.0,
    legend_fontsize: Optional[float] = 10.0,
    xlabel_rotation: Optional[float] = 90.0,
    ylabel_rotation: Optional[float] = 0.0,
    groupby_cluster: Optional[bool] = True,
    groupby_dendrogram: Optional[bool] = True,
    groupby_method: Optional[str] = 'ward',
    groupby_optimal_ordering: Optional[bool] = True,
    groupby_precomputed_linkage: Optional[np.ndarray] = None,
    show_sample_name: Optional[bool] = None,
    cbar_labelsize: Optional[float] = 10.0,
    panel_size: Tuple[float, float] = (10, 10),
    return_fig: Optional[bool] = False,
    dpi: Optional[float] = 300.0,
    **kwargs,
) -> Union[plt.Figure, None]:
    """
    Generate a heatmap.

    Parameters
    -----------

    data: ``AnnData`` or ``MultimodalData`` or ``UnimodalData`` object
        Single-cell expression data.
    attrs: ``str`` or ``List[str]``
        Cell attributes or features to plot.
        Cell attributes must exist in ``data.obs`` and must be numeric.
        Features must exist in ``data.var``.
        By default, attrs are plotted as columns.
    groupby: ``str``, optional, default: ``None``
        A categorical variable in data.obs that is used to categorize the cells, e.g. Clusters.
        By default, data.obs['groupby'] is plotted as rows.
        If ``None``, use data.obs_names instead.
    matkey: ``str``, optional, default: ``None``
        If matkey is set, select matrix with matkey as keyword in the current modality. Only works for MultimodalData or UnimodalData objects.
        The previously selected matrix is restored before this function exits, including on early return or error.
    gene_zscore: ``bool``, optional, default: ``True``
        If ``True``, compute and then plot z scores for gene expression.
    on_average: ``bool``, optional, default: ``True``
        If ``True``, plot cluster average gene expression or z score (i.e. show a Matrixplot); otherwise, plot a general heatmap.
    switch_axes: ``bool``, optional, default: ``False``
        By default, X axis is for attributes, and Y axis for clusters. If this parameter is ``True``, switch the axes.
        Moreover, with ``on_average`` being ``False``, if ``switch_axes`` is ``False``, ``row_cluster`` is enforced to be ``False``; if ``switch_axes`` is ``True``, ``col_cluster`` is enforced to be ``False``.
    attrs_cluster: ``bool``, optional, default: ``False``
        Cluster attributes and generate a attribute-wise dendrogram.
    attrs_dendrogram: ``bool``, optional, default: ``True``
        Only matters if attrs_cluster is True. Show the dendrogram if this option is True.
    attrs_method: ``str``, optional, default: ``ward``
        Linkage method for attrs, choosing from ``single``, ``complete``, ``average``, ``weighted``, ``centroid``, ``median`` and ``ward``.
    attrs_optimal_ordering: ``bool``, optional, default: ``True``
        Parameter for scipy.cluster.hierarchy.linkage. If ``True``, the attrs linkage matrix will be reordered so that the distance between successive leaves is minima.
    xlabel_size: ``float``, optional, default: 10.0
        Fontsize for x-axis labels.
    ylabel_size: ``float``, optional, default: 10.0
        Fontsize for y-axis labels.
    legend_fontsize: ``float``, optional, default: 10.0
        Fontsize for legend labels.
    xlabel_rotation: ``float``, optional, default: 90.0
        Rotation of x-axis labels.
    ylabel_rotation: ``float``, optional, default: 0.0
        Rotation of y-axis labels.
    groupby_cluster: ``bool``, optional, default: ``True``
        Cluster data.obs['groupby'] and generate a cluster-wise dendrogram.
    groupby_dendrogram: ``bool``, optional, default: ``True``
        Only matters if groupby_cluster is True. Show the dendrogram if this option is True.
    groupby_method: ``str``, optional, default: ``ward``
        Linkage method for groupby, choosing from ``single``, ``complete``, ``average``, ``weighted``, ``centroid``, ``median`` and ``ward``.
    groupby_optimal_ordering: ``bool``, optional, default: ``True``
        Parameter for scipy.cluster.hierarchy.linkage. If ``True``, the groupby linkage matrix will be reordered so that the distance between successive leaves is minima.
    groupby_precomputed_linkage: ``np.array``, optional, default: ``None``
        Pass a precomputed linkage.
    show_sample_name: ``bool``, optional, default: ``None``
        If show sample names as tick labels. If ``None``, show_sample_name == ``True`` if groupby == ``None`` and otherwise show_sample_name == ``False``.
    cbar_labelsize: ``float``, optional, default: 10.0
        Fontsize of the color bar.
    panel_size: ``Tuple[float, float]``, optional, default: ``(10, 10)``
        Overall size of the heatmap in ``(width, height)`` form.
    return_fig: ``bool``, optional, default: ``False``
        Return a ``Figure`` object if ``True``; return ``None`` otherwise.
    dpi: ``float``, optional, default: ``300.0``
        The resolution in dots per inch.
    kwargs
        Are passed to ``seaborn.heatmap``.

    .. _colormap documentation: https://matplotlib.org/3.1.0/tutorials/colors/colormaps.html

    Returns
    -------

    ``Figure`` object
        A ``matplotlib.figure.Figure`` object containing the heatmap if ``return_fig == True``;
        Otherwise, A ``seaborn.matrix.ClusterGrid`` object is returned.
        ``None`` is returned (after logging a warning) if a requested gene is not in ``data.var_names``.

    Examples
    --------
    >>> pg.heatmap(data, attrs=['CD14', 'TRAC', 'CD34'], groupby='leiden_labels')

    """
    is_anndata = isinstance(data, anndata.AnnData)
    if not is_anndata:
        # Remember the caller's matrix selection so it can be restored on exit.
        cur_matkey = data.current_matrix()

    try:
        if matkey is not None:
            assert not is_anndata
            data.select_matrix(matkey)

        if isinstance(attrs, str):
            attrs = [attrs]

        # Filter out attributes not existing in the data
        attrs = _get_valid_attrs(data, attrs)

        # Split the requested attributes into data.obs columns and features (genes).
        obs_keys = []
        genes = []
        for key in attrs:
            if key in data.obs:
                assert is_numeric_dtype(data.obs[key])
                obs_keys.append(key)
            else:
                if key not in data.var_names:
                    logger.warning(f"Cannot find gene {key}. Please make sure all genes are included in data.var_names before running this function!")
                    return None
                genes.append(key)

        df_list = []
        if len(obs_keys) > 0:
            df_list.append(data.obs[obs_keys].reset_index(drop=True))
        if len(genes) > 0:
            expr_mat = slicing(data[:, genes].X)
            if gene_zscore:
                expr_mat = zscore(expr_mat, ddof=1)
            df_list.append(pd.DataFrame(data=expr_mat, columns=genes))
        df = pd.concat(df_list, axis = 1)
        df.index = data.obs_names
        attr_names = df.columns.values

        if show_sample_name is None:
            show_sample_name = groupby is None
        groupby_tick_labels = df.index if show_sample_name else []

        cluster_ids = None
        cell_colors = None
        if groupby is not None:
            cluster_ids = data.obs[groupby].values
            if not isinstance(cluster_ids, pd.Categorical):
                cluster_ids = pd.Categorical(cluster_ids)
            else:
                cluster_ids = cluster_ids.remove_unused_categories()

            if on_average:
                kwargs.setdefault('cmap', 'Reds')
                df['cluster_name'] = cluster_ids
                df = df.groupby(by='cluster_name', observed=True).mean()
                cluster_ids = df.index
                groupby_tick_labels = cluster_ids
            else:
                if not groupby_cluster:
                    # Stable sort keeps the original cell order within each category.
                    idx = cluster_ids.argsort(kind = 'mergesort')
                    df = df.iloc[idx, :] # organize df by category order
                    cluster_ids = cluster_ids[idx]

                # One color per cell, determined by the cell's category.
                cell_colors = np.zeros(df.shape[0], dtype=object)
                palette = _get_palette(cluster_ids.categories.size)
                for k, cat in enumerate(cluster_ids.categories):
                    cell_colors[cluster_ids == cat] = palette[k]

        from scipy.cluster.hierarchy import linkage

        groupby_linkage = None
        if groupby_cluster:
            if groupby_precomputed_linkage is not None:
                groupby_linkage = groupby_precomputed_linkage
            else:
                groupby_linkage = linkage(df, groupby_method, optimal_ordering = groupby_optimal_ordering)

        attrs_linkage = None
        if attrs_cluster:
            attrs_linkage = linkage(df.T, attrs_method, optimal_ordering = attrs_optimal_ordering)

        if not switch_axes:
            cg = sns.clustermap(
                data=df,
                row_colors=cell_colors,
                col_colors=None,
                row_cluster=groupby_cluster,
                col_cluster=attrs_cluster,
                row_linkage=groupby_linkage,
                col_linkage=attrs_linkage,
                linewidths=0,
                yticklabels=groupby_tick_labels,
                xticklabels=attr_names,
                figsize=panel_size,
                **kwargs,
            )
            cg.ax_heatmap.set_ylabel("")
            cg.ax_heatmap.tick_params(axis='x', labelsize=xlabel_size, labelrotation=xlabel_rotation)
            cg.ax_heatmap.tick_params(axis='y', labelsize=ylabel_size, labelrotation=ylabel_rotation)
        else:
            cg = sns.clustermap(
                data=df.T,
                row_colors=None,
                col_colors=cell_colors,
                row_cluster=attrs_cluster,
                col_cluster=groupby_cluster,
                row_linkage=attrs_linkage,
                col_linkage=groupby_linkage,
                linewidths=0,
                yticklabels=attr_names,
                xticklabels=groupby_tick_labels,
                figsize=panel_size,
                **kwargs,
            )
            cg.ax_heatmap.set_xlabel("")
            cg.ax_heatmap.tick_params(axis='y', labelsize=ylabel_size, labelrotation=ylabel_rotation)
            cg.ax_heatmap.tick_params(axis='x', labelsize=xlabel_size, labelrotation=xlabel_rotation)

        # Which dendrogram ends up on rows vs. columns depends on the axis orientation.
        show_row_dendrogram = (attrs_cluster and attrs_dendrogram) if switch_axes else (groupby_cluster and groupby_dendrogram)
        show_col_dendrogram = (groupby_cluster and groupby_dendrogram) if switch_axes else (attrs_cluster and attrs_dendrogram)

        if show_row_dendrogram:
            cg.ax_heatmap.yaxis.tick_right()
            cg.ax_row_dendrogram.set_visible(True)
            # Avoid overlap of colorbar and row dendrogram.
            color_box = cg.ax_cbar.get_position()
            square_plot = cg.ax_heatmap.get_position()
            if square_plot.y1 > color_box.y0:
                y_diff = square_plot.y1 - color_box.y0
                color_box.y0 = square_plot.y1
                color_box.y1 += y_diff
                cg.ax_cbar.set_position(color_box)
        else:
            cg.ax_heatmap.yaxis.tick_left()
            cg.ax_row_dendrogram.set_visible(False)
            # Move the colorbar to the right-side.
            color_box = cg.ax_heatmap.get_position()
            color_box.x0 = color_box.x1 + 0.04
            color_box.x1 = color_box.x0 + 0.02
            cg.ax_cbar.set_position(color_box)
            cg.ax_cbar.yaxis.set_ticks_position("right")

        cg.ax_heatmap.xaxis.tick_bottom()
        cg.ax_col_dendrogram.set_visible(show_col_dendrogram)

        cg.ax_cbar.tick_params(labelsize=cbar_labelsize)
        cg.fig.dpi = dpi

        if (groupby is not None) and (not on_average):
            if groupby_cluster:
                # Cells are reordered by clustering, so identify categories with a color legend.
                from matplotlib.patches import Patch
                legend_elements = [Patch(color = color, label = label) for color, label in zip(palette, cluster_ids.categories)]
                cg.ax_heatmap.legend(handles=legend_elements, loc='lower left', bbox_to_anchor = (1.02, 1.02), fontsize = legend_fontsize)
            else:
                # Cells are sorted by category, so label the middle of each category's color band.
                values = cluster_ids.value_counts().values
                ticks = np.cumsum(values) - values / 2
                labels = cluster_ids.categories
                if not switch_axes:
                    cg.ax_row_colors.yaxis.tick_left()
                    cg.ax_row_colors.set_yticks(ticks)
                    cg.ax_row_colors.set_yticklabels(labels)
                    cg.ax_row_colors.tick_params(axis='y', left = False, length=10)
                else:
                    cg.ax_col_colors.xaxis.tick_top()
                    cg.ax_col_colors.set_xticks(ticks)
                    cg.ax_col_colors.set_xticklabels(labels, rotation=45)
                    cg.ax_col_colors.tick_params(axis='x', top = False, labelsize = xlabel_size, length=10)
    finally:
        # Restore the caller's matrix selection on every exit path (normal return,
        # the early return for a missing gene, or an exception).
        if not is_anndata and cur_matkey != data.current_matrix():
            data.select_matrix(cur_matkey)

    return cg.fig if return_fig else cg
def dotplot(
    data: Union[MultimodalData, UnimodalData, anndata.AnnData],
    genes: Union[str, List[str]],
    groupby: str,
    reduce_function: Union[str, Callable[[np.ndarray], float]] = "mean",
    show_only_expressed: bool = True,
    fraction_min: float = 0,
    fraction_max: float = None,
    dot_min: int = 0,
    dot_max: int = 20,
    switch_axes: bool = False,
    cmap: Union[str, List[str], Tuple[str]] = 'Reds',
    sort_function: Union[Callable[[List[str]], List[str]], str] = 'natsorted',
    grid: bool = True,
    return_fig: Optional[bool] = False,
    dpi: Optional[float] = 300.0,
    **kwds,
) -> Union[plt.Figure, None]:
    """
    Generate a dot plot.

    Parameters
    ----------

    data: ``AnnData`` or ``UnimodalData`` or ``MultimodalData`` object
        Single cell expression data.
    genes: ``str`` or ``List[str]``
        Features to plot.
    groupby: ``str``
        A categorical variable in data.obs that is used to categorize the cells, e.g. Clusters.
    reduce_function: ``Union[str, Callable[[np.ndarray], float]]``, optional, default: ``"mean"``
        Function to calculate statistic on expression data. Default is mean.
    show_only_expressed: ``bool``, optional, default: `True`
        If ``True``, the statistic is calculated over only cells expressing the selected genes; otherwise, it's calculated over all cells.
    fraction_min: ``float``, optional, default: ``0``.
        Minimum fraction of expressing cells to consider.
    fraction_max: ``float``, optional, default: ``None``.
        Maximum fraction of expressing cells to consider. If ``None``, use the maximum value from data.
    dot_min: ``int``, optional, default: ``0``.
        Minimum size in pixels for dots.
    dot_max: ``int``, optional, default: ``20``.
        Maximum size in pixels for dots.
    switch_axes: ``bool``, optional, default: ``False``.
        If ``True``, switch X and Y axes.
    cmap: ``str`` or ``List[str]`` or ``Tuple[str]``, optional, default: ``Reds``
        Color map.
    sort_function: ``Union[Callable[List[str], List[str]], str]``, optional, default: ``natsorted``
        Function used for sorting groupby labels. If ``natsorted``, apply natsorted function to sort by natural order. If ``None``, don't sort. Otherwise, a callable function will be applied to the labels for sorting.
    grid: ``bool``, optional, default: ``True``
        If ``True``, plot grids.
    return_fig: ``bool``, optional, default: ``False``
        Return a ``Figure`` object if ``True``; return ``None`` otherwise.
    dpi: ``float``, optional, default: ``300.0``
        The resolution in dots per inch.
    **kwds:
        Are passed to ``matplotlib.pyplot.scatter``.

    Returns
    -------

    ``Figure`` object
        A ``matplotlib.figure.Figure`` object containing the dot plot if ``return_fig == True``.
        ``None`` is also returned (after logging a warning) if ``data.obs[groupby]`` contains NaN values.

    Examples
    --------
    >>> pg.dotplot(data, genes = ['CD14', 'TRAC', 'CD34'], groupby = 'louvain_labels')
    """
    # Global style change; the ``finally`` below resets it on every exit path.
    sns.set(font_scale=0.7, style='whitegrid')

    try:
        if not is_list_like(genes):
            genes = [genes]

        # Select only genes existing in the data
        genes = _get_valid_attrs(data, genes)

        # User-supplied scatter keywords take precedence over ``cmap``.
        keywords = {"cmap": cmap, **kwds}

        X = slicing(data[:, genes].X)
        df = pd.DataFrame(data=X, columns=genes)
        df[groupby] = data.obs[groupby].values

        if df[groupby].isna().sum() > 0:
            logger.warning(f"Detected NaN values in attribute '{groupby}'! Please check if '{groupby}' is set correctly.")
            return None

        series = df[groupby].value_counts()
        idx = series == 0
        if idx.sum() > 0:
            # Labels may be non-string categories (e.g. integers), so convert before joining.
            logger.warning(f"The following categories contain no cells and are removed: {','.join(map(str, series.index[idx]))}.")

        def non_zero(g):
            # Fraction of cells in the group with a non-zero (and non-missing) value.
            return np.count_nonzero(g.fillna(0.0)) / g.shape[0]

        if show_only_expressed:
            df.set_index(groupby, inplace=True)
            is_expressed = df > 0.0
            # Mask non-expressing cells as NaN so the statistic only covers expressing cells.
            summarized_df = df.mask(~is_expressed).groupby(level=0, observed=True).aggregate([reduce_function, non_zero])
        else:
            summarized_df = df.groupby(by=groupby, observed=True).aggregate([reduce_function, non_zero])

        row_indices = summarized_df.index.tolist()
        if sort_function == "natsorted":
            row_indices = natsorted(row_indices)
        elif callable(sort_function):
            row_indices = sort_function(row_indices)
        # Y axis runs bottom-up, so reverse to show the first label at the top.
        row_indices.reverse()
        summarized_df = summarized_df.loc[row_indices]

        # Columns alternate (gene, statistic), (gene, non_zero) for each gene.
        # Genes on columns, groupby on rows
        mean_df = summarized_df.iloc[:, 0::2]
        fraction_df = summarized_df.iloc[:, 1::2]

        y, x = np.indices(mean_df.shape)
        y = y.flatten()
        x = x.flatten()
        fraction = fraction_df.values.flatten()
        if fraction_max is None:
            fraction_max = fraction.max()
        pixels = _get_dot_size(fraction, fraction_min, fraction_max, dot_min, dot_max)
        summary_values = mean_df.values.flatten()

        xlabel = list(genes)
        ylabel = [str(label) for label in summarized_df.index]

        xticks = genes
        yticks = summarized_df.index.map(str).values

        if switch_axes:
            x, y = y[::-1], x[::-1]
            xlabel, ylabel = ylabel[::-1], xlabel[::-1]
            xticks, yticks = yticks[::-1], xticks[::-1]

        dotplot_df = pd.DataFrame(data=dict(x=x, y=y, value=summary_values, pixels=pixels, fraction=fraction,
                        xlabel=np.array(xlabel)[x], ylabel=np.array(ylabel)[y]))

        import matplotlib.gridspec as gridspec

        # Figure size in pixels, derived from dot size, tick counts and label lengths.
        width = int(np.ceil(((dot_max + 1) + 4) * len(xticks) + dotplot_df['ylabel'].str.len().max()) + dot_max + 100)
        height = int(np.ceil(((dot_max + 1) + 4) * len(yticks) + dotplot_df['xlabel'].str.len().max()) + 50)
        fig = plt.figure(figsize=(1.1 * width / 100.0, height / 100.0), dpi=dpi)
        gs = gridspec.GridSpec(3, 11, figure = fig)

        # Main plot
        mainplot_col_grid = -2 if len(xlabel) < 10 else -1
        ax = fig.add_subplot(gs[:, :mainplot_col_grid])

        sc = ax.scatter(x='x', y='y', c='value', s='pixels', data=dotplot_df, linewidth=0.5, edgecolors='black', **keywords)

        for side in ("top", "bottom", "left", "right"):
            ax.spines[side].set_color('black')
        if not grid:
            ax.grid(False)

        if not switch_axes:
            ax.set_ylabel(str(groupby))
            ax.set_xlabel('')
        else:
            ax.set_ylabel('')
            ax.set_xlabel(str(groupby))

        ax.set_xlim(-1, len(xticks))
        ax.set_ylim(-1, len(yticks))
        ax.set_xticks(range(len(xticks)))
        ax.set_xticklabels(xticks)
        ax.set_yticks(range(len(yticks)))
        ax.set_yticklabels(yticks)
        plt.xticks(rotation=90)

        plt.colorbar(sc)

        # Size legend: choose a tick step from the span of fractions being displayed.
        size_range = fraction_max - fraction_min
        if 0.3 < size_range <= 0.6:
            size_legend_step = 0.1
        elif size_range <= 0.3:
            size_legend_step = 0.05
        else:
            size_legend_step = 0.2

        # Skip a zero-sized first tick when the minimum fraction is not positive.
        first_tick = fraction_min if fraction_min > 0 else fraction_min + size_legend_step
        size_ticks = np.arange(first_tick, fraction_max + size_legend_step, size_legend_step)

        legend_row_grid = 1 if height / 3 > 100 else 3
        ax2 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs[0:legend_row_grid, -1])
        size_legend = fig.add_subplot(ax2[0])
        size_tick_pixels = _get_dot_size(size_ticks, fraction_min, fraction_max, dot_min, dot_max)
        size_tick_labels = ["{:.0%}".format(tick) for tick in size_ticks]
        size_legend.scatter(x=np.repeat(0, len(size_ticks)), y=np.arange(0, len(size_ticks)), s=size_tick_pixels, c='black', linewidth=0.5)
        size_legend.title.set_text("Fraction of\nexpressing cells")
        size_legend.set_xlim(-0.1, 0.1)
        size_legend.set_xticks([])

        ymin, ymax = size_legend.get_ylim()
        size_legend.set_ylim(ymin, ymax + 0.5)

        size_legend.set_yticks(np.arange(len(size_ticks)))
        size_legend.set_yticklabels(size_tick_labels)
        size_legend.tick_params(axis='y', labelleft=False, labelright=True)

        for side in ("top", "bottom", "left", "right"):
            size_legend.spines[side].set_visible(False)
        size_legend.grid(False)

        return fig if return_fig else None
    finally:
        # Reset global settings, including on the early return above and on errors.
        matplotlib.rc_file_defaults()
def dendrogram(
    data: Union[MultimodalData, UnimodalData, anndata.AnnData],
    groupby: str,
    rep: str = 'pca',
    genes: Optional[List[str]] = None,
    correlation_method: str = 'pearson',
    n_clusters: Optional[int] = None,
    affinity: str = 'euclidean',
    linkage: str = 'complete',
    compute_full_tree: Union[str, bool] = 'auto',
    distance_threshold: Optional[float] = 0,
    panel_size: Tuple[float, float] = (6, 6),
    orientation: str = 'top',
    color_threshold: Optional[float] = None,
    return_fig: Optional[bool] = False,
    dpi: Optional[float] = 300.0,
    **kwargs,
) -> Union[plt.Figure, None]:
    """
    Generate a dendrogram on hierarchical clustering result.

    The metrics used here are consistent with SCANPY's dendrogram_ implementation.

    *scikit-learn* `Agglomerative Clustering`_ implementation is used for hierarchical clustering.

    .. _dendrogram: https://scanpy.readthedocs.io/en/stable/api/scanpy.tl.dendrogram.html
    .. _Agglomerative Clustering: https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AgglomerativeClustering.html

    Parameters
    ----------

    data: ``MultimodalData``, ``UnimodalData``, or ``AnnData`` object
        Single cell expression data.

    genes: ``List[str]``, optional, default: ``None``
        List of genes to use. Gene names must exist in ``data.var``. If set, use the counts in ``data.X`` for plotting; if set as ``None``, use the embedding specified in ``rep`` for plotting.

    rep: ``str``, optional, default: ``pca``
        Cell embedding to use. It only works when ``genes`` is ``None``, and its key ``"X_"+rep`` must exist in ``data.obsm``. By default, use PCA coordinates.

    groupby: ``str``
        Categorical cell attribute to plot, which must exist in ``data.obs``.

    correlation_method: ``str``, optional, default: ``pearson``
        Method of correlation between categories specified in ``data.obs``. Available options are: ``pearson``, ``kendall``, ``spearman``. See `pandas corr documentation <https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.corr.html>`_ for details.

    n_clusters: ``int``, optional, default: ``None``
        The number of clusters to find, used by hierarchical clustering. It must be ``None`` if ``distance_threshold`` is not ``None``.

    affinity: ``str``, optional, default: ``euclidean``
        Metric used to compute the linkage, used by hierarchical clustering. Valid values for metric are:
            - From scikit-learn: ``cityblock``, ``cosine``, ``euclidean``, ``l1``, ``l2``, ``manhattan``.
            - From scipy.spatial.distance: ``braycurtis``, ``canberra``, ``chebyshev``, ``correlation``, ``dice``, ``hamming``, ``jaccard``, ``kulsinski``, ``mahalanobis``, ``minkowski``, ``rogerstanimoto``, ``russellrao``, ``seuclidean``, ``sokalmichener``, ``sokalsneath``, ``sqeuclidean``, ``yule``.
        Default is the euclidean distance. See `scikit-learn distance documentation <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.pairwise_distances.html>`_ for details.

    linkage: ``str``, optional, default: ``complete``
        Which linkage criterion to use, used by hierarchical clustering. Below are available options:
            - ``ward`` minimizes the variance of the clusters being merged.
            - ``average`` uses the average of the distances of each observation of the two sets.
            - ``complete`` uses the maximum distances between all observations of the two sets. (Default)
            - ``single`` uses the minimum of the distances between all observations of the two sets.
        See `scikit-learn documentation <https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AgglomerativeClustering.html>`_ for details.

    compute_full_tree: ``str`` or ``bool``, optional, default: ``auto``
        Stop early the construction of the tree at ``n_clusters``, used by hierarchical clustering. It must be ``True`` if ``distance_threshold`` is not ``None``.
        By default, this option is ``auto``, which is ``True`` if and only if ``distance_threshold`` is not ``None``, or ``n_clusters`` is less than ``min(100, 0.02 * n_groups)``, where ``n_groups`` is the number of categories in ``data.obs[groupby]``.

    distance_threshold: ``float``, optional, default: ``0``
        The linkage distance threshold above which, clusters will not be merged. If not ``None``, ``n_clusters`` must be ``None`` and ``compute_full_tree`` must be ``True``.

    panel_size: ``Tuple[float, float]``, optional, default: ``(6, 6)``
        The size (width, height) in inches of figure.

    orientation: ``str``, optional, default: ``top``
        The direction to plot the dendrogram. Available options are: ``top``, ``bottom``, ``left``, ``right``. See `scipy dendrogram documentation <https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.dendrogram.html>`_ for explanation.

    color_threshold: ``float``, optional, default: ``None``
        Threshold for coloring clusters. See `scipy dendrogram documentation <https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.dendrogram.html>`_ for explanation.

    return_fig: ``bool``, optional, default: ``False``
        Return a ``Figure`` object if ``True``; return ``None`` otherwise.

    dpi: ``float``, optional, default: ``300.0``
        The resolution in dots per inch.

    **kwargs:
        Are passed to ``scipy.cluster.hierarchy.dendrogram``.

    Returns
    -------

    ``Figure`` object
        A ``matplotlib.figure.Figure`` object containing the dendrogram if ``return_fig == True``

    Examples
    --------
    >>> pg.dendrogram(data, genes=data.var_names, groupby='louvain_labels')
    >>> pg.dendrogram(data, rep='pca', groupby='louvain_labels')
    """
    import inspect

    from sklearn.cluster import AgglomerativeClustering
    from scipy.cluster.hierarchy import dendrogram as scipy_dendrogram

    # Per-cell feature table, indexed by the grouping attribute.
    if genes is None:
        embed_df = pd.DataFrame(X_from_rep(data, rep))
    else:
        embed_df = pd.DataFrame(slicing(data[:, genes].X))
    embed_df.set_index(data.obs[groupby], inplace=True)

    # One mean profile per category, then category-by-category correlation.
    mean_df = embed_df.groupby(level=0).mean()
    mean_df.index = mean_df.index.astype('category')
    corr_mat = mean_df.T.corr(method=correlation_method)

    # scikit-learn renamed ``affinity`` to ``metric``; pick whichever keyword
    # the installed version accepts so both old and new releases work.
    accepted = inspect.signature(AgglomerativeClustering).parameters
    cluster_kwargs = {
        "n_clusters": n_clusters,
        "linkage": linkage,
        "compute_full_tree": compute_full_tree,
        "distance_threshold": distance_threshold,
        ("metric" if "metric" in accepted else "affinity"): affinity,
    }
    # ``distances_`` is needed below to build the linkage matrix; without this
    # flag it is only populated when ``distance_threshold`` is not None.
    if "compute_distances" in accepted:
        cluster_kwargs["compute_distances"] = True

    clusterer = AgglomerativeClustering(**cluster_kwargs)
    clusterer.fit(corr_mat)

    # Number of original observations under each merge node (scipy linkage format).
    counts = np.zeros(clusterer.children_.shape[0])
    n_samples = len(clusterer.labels_)
    for i, merge in enumerate(clusterer.children_):
        current_count = 0
        for child_idx in merge:
            if child_idx < n_samples:
                current_count += 1  # Leaf node
            else:
                current_count += counts[child_idx - n_samples]
        counts[i] = current_count

    linkage_matrix = np.column_stack(
        [clusterer.children_, clusterer.distances_, counts]
    ).astype(float)

    fig, ax = _get_subplot_layouts(panel_size=panel_size, dpi=dpi)
    scipy_dendrogram(
        linkage_matrix,
        labels=mean_df.index.categories,
        orientation=orientation,
        color_threshold=color_threshold,
        ax=ax,
        **kwargs,
    )
    # Category labels sit on the x axis for vertical trees and on the y axis
    # for horizontal ones.
    if orientation in ('top', 'bottom'):
        plt.xticks(rotation=90, fontsize=10)
    else:
        plt.yticks(fontsize=10)
    plt.tight_layout()

    return fig if return_fig else None
def hvfplot(
    data: Union[MultimodalData, UnimodalData, anndata.AnnData],
    top_n: int = 20,
    panel_size: Optional[Tuple[float, float]] = (6, 4),
    return_fig: Optional[bool] = False,
    dpi: Optional[float] = 300.0,
) -> Union[plt.Figure, None]:
    """
    Generate highly variable feature plot.
    Only works for HVGs returned by ``highly_variable_features`` method with ``flavor=='pegasus'``.

    Parameters
    -----------

    data: ``MultimodalData``, ``UnimodalData``, or ``anndata.AnnData`` object.
        Single cell expression data.

    top_n: ``int``, optional, default: ``20``
        Number of top highly variable features to show names. If larger than the number of robust features, all robust features are labeled.

    panel_size: ``Tuple[float, float]``, optional, default: ``(6, 4)``
        The size (width, height) in inches of figure.

    return_fig: ``bool``, optional, default: ``False``
        Return a ``Figure`` object if ``True``; return ``None`` otherwise.

    dpi: ``float``, optional, default: ``300.0``
        The resolution in dots per inch.

    Returns
    --------

    ``Figure`` object
        A ``matplotlib.figure.Figure`` object containing the scatter plot if ``return_fig == True``

    Examples
    ---------
    >>> pg.hvfplot(data)
    >>> pg.hvfplot(data, top_n=10, dpi=150)
    """
    robust_idx = data.var["robust"].values
    robust_var = data.var.loc[robust_idx]

    # Pull plain arrays: every lookup below is positional, and positional
    # integer indexing on a label-indexed Series is ambiguous/deprecated.
    x = robust_var["mean"].to_numpy()
    y = robust_var["var"].to_numpy()
    fitted = robust_var["hvf_loess"].to_numpy()
    hvg_index = robust_var["highly_variable_features"].to_numpy(dtype=bool)
    hvg_rank = robust_var["hvf_rank"].to_numpy()
    gene_symbols = data.var_names[robust_idx]

    fig, ax = _get_subplot_layouts(panel_size=panel_size, dpi=dpi)
    ax.scatter(x[hvg_index], y[hvg_index], s=5, c='b', marker='o', linewidth=0.5, alpha=0.5, label='highly variable features')
    ax.scatter(x[~hvg_index], y[~hvg_index], s=5, c='k', marker='o', linewidth=0.5, alpha=0.5, label='other features')
    ax.legend(loc='best', fontsize=5)

    ax.set_xlabel("Mean log expression")
    ax.set_ylabel("Variance of log expression")

    # Fitted curve, drawn left to right along increasing mean.
    order = np.argsort(x)
    ax.plot(x[order], fitted[order], "r-", linewidth=1)

    # Label the best-ranked features; slicing clamps to what is available.
    ord_rank = np.argsort(hvg_rank)
    texts = [
        ax.text(x[pos], y[pos], gene_symbols[pos], fontsize=5)
        for pos in ord_rank[:max(top_n, 0)]
    ]

    from adjustText import adjust_text
    adjust_text(texts, arrowprops=dict(arrowstyle='-', color='k', lw=0.5))

    return fig if return_fig else None
def qcviolin(
    data: Union[MultimodalData, UnimodalData, anndata.AnnData],
    plot_type: str,
    min_genes_before_filt: Optional[int] = 100,
    n_violin_per_panel: Optional[int] = 8,
    panel_size: Optional[Tuple[float, float]] = (6, 4),
    left: Optional[float] = 0.2,
    bottom: Optional[float] = 0.15,
    wspace: Optional[float] = 0.3,
    hspace: Optional[float] = 0.35,
    return_fig: Optional[bool] = False,
    dpi: Optional[float] = 300.0,
) -> Union[plt.Figure, None]:
    """
    Plot quality control statistics (before filtration vs. after filtration) as violin plots. Require statistics such as "n_genes", "n_counts" and "percent_mito" precomputed.

    The combined before/after table is cached in ``data.uns["df_qcplot"]`` and reused on later calls. If ``data.obs`` has no ``Channel`` column, an empty-string ``Channel`` column is added to it.

    Parameters
    -----------

    data: ``MultimodalData``, ``UnimodalData``, or ``anndata.AnnData`` object.
        Single cell expression data.

    plot_type: ``str``
        Choose from ``gene``, ``count`` and ``mito``, which shows number of expressed genes, number of UMIs and percentage of mitochondrial rate.

    min_genes_before_filt: ``int``, optional, default: 100
        If data loaded are raw data (i.e. min(n_genes) == 0), filter out cell barcodes with less than ``min_genes_before_filt`` for better visual effects.

    n_violin_per_panel: ``int``, optional, default: 8
        Number of violin plots (samples) shown in one panel.

    panel_size: `tuple`, optional (default: `(6, 4)`)
        The panel size (width, height) in inches.

    left: `float`, optional (default: `0.2`)
        This parameter sets the figure's left margin as a fraction of panel's width (left * panel_size[0]).

    bottom: `float`, optional (default: `0.15`)
        This parameter sets the figure's bottom margin as a fraction of panel's height (bottom * panel_size[1]).

    wspace: `float`, optional (default: `0.3`)
        This parameter sets the width between panels and also the figure's right margin as a fraction of panel's width (wspace * panel_size[0]).

    hspace: `float`, optional (default: `0.35`)
        This parameter sets the height between panels and also the figure's top margin as a fraction of panel's height (hspace * panel_size[1]).

    return_fig: ``bool``, optional, default: ``False``
        Return a ``Figure`` object if ``True``; return ``None`` otherwise.

    dpi: ``float``, optional, default: ``300.0``
        The resolution in dots per inch.

    Returns
    --------

    ``Figure`` object
        A ``matplotlib.figure.Figure`` object containing the violin plots if ``return_fig == True``

    Examples
    ---------
    >>> pg.qcviolin(data, "mito", dpi = 500)
    """
    metric_column = {"gene": "n_genes", "count": "n_counts", "mito": "percent_mito"}
    metric_label = {
        "gene": "Number of expressed genes",
        "count": "Number of UMIs",
        "mito": "Percentage of mitochondrial UMIs",
    }

    # Build (and cache) the long-format before/after table on first use.
    if "df_qcplot" not in data.uns:
        if "Channel" not in data.obs:
            data.obs["Channel"] = pd.Categorical([""] * data.shape[0])

        wanted = np.array(["Channel", "n_genes", "n_counts", "percent_mito"])
        wanted = wanted[np.isin(wanted, data.obs.columns)]

        # Raw data (barcodes with zero genes) are thinned for readability.
        if data.obs["n_genes"].min() == 0:
            obs_before = data.obs[data.obs["n_genes"] >= min_genes_before_filt]
        else:
            obs_before = data.obs

        before = obs_before[wanted].copy()
        before.reset_index(drop=True, inplace=True)
        before["status"] = "original"

        after = data.obs.loc[data.obs["passed_qc"], wanted].copy()
        after.reset_index(drop=True, inplace=True)
        after["status"] = "filtered"

        combined = pd.concat((before, after), axis=0)
        combined["status"] = pd.Categorical(
            combined["status"].values, categories=["original", "filtered"]
        )
        combined["Channel"] = pd.Categorical(
            combined["Channel"].values,
            categories=natsorted(combined["Channel"].astype(str).unique()),
        )
        data.uns["df_qcplot"] = combined

    df_qcplot = data.uns["df_qcplot"]

    attr = metric_column[plot_type]
    if attr not in df_qcplot:
        logger.warning(f"Cannot find qc metric {pt2attr[plot_type]}!" if False else f"Cannot find qc metric {attr}!")
        return None
    ylabel = metric_label[plot_type]

    channels = df_qcplot["Channel"].cat.categories
    n_channels = channels.size
    n_panels = (n_channels - 1) // n_violin_per_panel + 1

    nrows, ncols = _get_nrows_and_ncols(n_panels, None, None)
    fig, axes = _get_subplot_layouts(
        nrows=nrows,
        ncols=ncols,
        panel_size=panel_size,
        dpi=dpi,
        left=left,
        bottom=bottom,
        wspace=wspace,
        hspace=hspace,
        sharex=False,
        sharey=False,
        squeeze=False,
    )

    for row, col in np.ndindex(nrows, ncols):
        ax = axes[row, col]
        ax.grid(False)
        panel_no = row * ncols + col

        if panel_no >= n_panels:
            # Unused grid cell: blank it out.
            ax.set_frame_on(False)
            ax.set_xticks([])
            ax.set_yticks([])
            continue

        start = panel_no * n_violin_per_panel
        end = min(start + n_violin_per_panel, n_channels)
        panel_channels = channels[start:end]

        if start == 0 and end == n_channels:
            # Everything fits in a single panel; plot the cached table as is.
            df_plot = df_qcplot
        else:
            in_panel = np.isin(df_qcplot["Channel"], panel_channels)
            df_plot = df_qcplot[in_panel].copy()
            df_plot["Channel"] = pd.Categorical(
                df_plot["Channel"].values, categories=natsorted(panel_channels)
            )

        sns.violinplot(
            x="Channel",
            y=attr,
            hue="status",
            data=df_plot,
            split=True,
            linewidth=0.5,
            cut=0,
            inner=None,
            ax=ax,
        )
        ax.set_xlabel("Channel")
        ax.set_ylabel(ylabel)
        ax.legend(loc="upper right", fontsize=8)

        # Long channel names are tilted so they do not overlap.
        if max(len(name) for name in panel_channels) >= 5:
            ax.set_xticklabels(ax.get_xticklabels(), fontsize=8, rotation=-45)

    return fig if return_fig else None
def volcano(
    data: Union[MultimodalData, UnimodalData, anndata.AnnData],
    cluster_id: str,
    de_key: str = "de_res",
    de_test: str = 'mwu',
    qval_threshold: float = 0.05,
    log2fc_threshold: float = 1.0,
    top_n: int = 20,
    panel_size: Optional[Tuple[float, float]] = (6, 4),
    return_fig: Optional[bool] = False,
    dpi: Optional[float] = 300.0,
) -> Union[plt.Figure, None]:
    """
    Generate Volcano plots (-log10 p value vs. log2 fold change) for visualizing DE results.

    Parameters
    -----------

    data: ``MultimodalData``, ``UnimodalData``, or ``anndata.AnnData`` object.
        Single cell expression data.

    cluster_id: ``str``
        Cluster ID for the cluster we want to show DE results. There are two cases:
            * If ``condition`` is ``None`` in ``pg.de_analysis``: Just specify one cluster label in the cluster attribute used in ``pg.de_analysis``.
            * If ``condition`` is not ``None`` in ``pg.de_analysis``: Specify cluster ID in this format: **"cluster_label:cond_level"**, where **cluster_label** is the cluster label, and **cond_level** is the condition ID. And this shows result of cells within the cluster under the specific condition.

    de_key: ``str``, optional, default: ``de_res``
        The varm keyword for DE results. data.varm[de_key] should store the full DE result table.

    de_test: ``str``, optional, default: ``mwu``
        Which DE test results to show. Use MWU test result by default.

    qval_threshold: ``float``, optional, default: 0.05.
        Selected FDR rate. A horizontal line indicating this rate will be shown in the figure. The line is omitted when no gene passes the threshold.

    log2fc_threshold: ``float``, optional, default: 1.0
        Log2 fold change threshold to highlight biologically interesting genes. Two vertical lines representing negative and positive log2 fold change will be shown.

    top_n: ``int``, optional, default: ``20``
        Number of top DE genes to show names. Genes are ranked by Log2 fold change.

    panel_size: ``Tuple[float, float]``, optional, default: ``(6, 4)``
        The size (width, height) in inches of figure.

    return_fig: ``bool``, optional, default: ``False``
        Return a ``Figure`` object if ``True``; return ``None`` otherwise.

    dpi: ``float``, optional, default: ``300.0``
        The resolution in dots per inch.

    Returns
    --------

    ``Figure`` object
        A ``matplotlib.figure.Figure`` object containing the volcano plot if ``return_fig == True``.
        ``None`` is also returned (with a warning) when the requested DE results are missing.

    Examples
    ---------
    >>> pg.volcano(data, cluster_id = '1', dpi=200)
    """
    if de_key not in data.varm:
        logger.warning(f"Cannot find DE results '{de_key}'. Please conduct DE analysis first!")
        return None

    de_res = data.varm[de_key]

    fcstr = f"{cluster_id}:log2FC"
    pstr = f"{cluster_id}:{de_test}_pval"
    qstr = f"{cluster_id}:{de_test}_qval"

    columns = de_res.dtype.names
    if any(col not in columns for col in (fcstr, pstr, qstr)):
        logger.warning(f"Please conduct DE test {de_test} first!")
        return None

    log2fc = de_res[fcstr]
    # Field access on the record array yields a view; copy before clamping so
    # the stored DE table in data.varm is not modified by plotting.
    pvals = de_res[pstr].copy()
    pvals[pvals == 0.0] = 1e-45  # very small pvalue to avoid log10 0
    neglog10p = -np.log10(pvals)

    # The horizontal cutoff is the smallest -log10(p) among genes passing FDR.
    passes_fdr = de_res[qstr] <= qval_threshold
    if passes_fdr.any():
        yconst = neglog10p[passes_fdr].min()
        idxsig = neglog10p >= yconst
        idxnsig = neglog10p < yconst
    else:
        # Nothing is significant: no cutoff line, every gene is non-significant.
        yconst = None
        idxsig = np.zeros(neglog10p.shape, dtype=bool)
        idxnsig = ~idxsig

    fig, ax = _get_subplot_layouts(panel_size=panel_size, dpi=dpi)

    idxfc = (log2fc <= -log2fc_threshold) | (log2fc >= log2fc_threshold)
    idxnfc = ~idxfc

    groups = (
        (idxnsig & idxnfc, 'k', "NS"),
        (idxnsig & idxfc, 'g', r"Log$_2$ FC"),
        (idxsig & idxnfc, 'b', r"q-value"),
        (idxsig & idxfc, 'r', r"q-value and log$_2$ FC"),
    )
    for mask, color, label in groups:
        ax.scatter(log2fc[mask], neglog10p[mask], s=5, c=color, marker='o', linewidths=0.5, alpha=0.5, label=label)

    ax.set_xlabel(r"Log$_2$ fold change")
    ax.set_ylabel(r"$-$Log$_{10}$ $P$")

    legend = ax.legend(
        loc="center",
        bbox_to_anchor=(0.5, 1.1),
        frameon=False,
        fontsize=8,
        ncol=4,
    )
    # ``legendHandles`` was renamed to ``legend_handles`` in newer Matplotlib.
    handles = getattr(legend, "legend_handles", None)
    if handles is None:
        handles = legend.legendHandles
    for handle in handles:  # adjust legend size
        handle.set_sizes([50.0])

    if yconst is not None:
        ax.axhline(y=yconst, c='k', lw=0.5, ls='--')
    ax.axvline(x=-log2fc_threshold, c='k', lw=0.5, ls='--')
    ax.axvline(x=log2fc_threshold, c='k', lw=0.5, ls='--')

    texts = []

    # Most up-regulated significant genes (largest log2FC first).
    idx = np.where(idxsig & (log2fc >= log2fc_threshold))[0]
    up_ids = idx[np.argsort(log2fc[idx])[::-1][0:top_n]]
    # Most down-regulated significant genes (smallest log2FC first).
    idx = np.where(idxsig & (log2fc <= -log2fc_threshold))[0]
    down_ids = idx[np.argsort(log2fc[idx])[0:top_n]]

    for gid in list(up_ids) + list(down_ids):
        texts.append(ax.text(log2fc[gid], neglog10p[gid], data.var_names[gid], fontsize=5))

    from adjustText import adjust_text
    adjust_text(texts, arrowprops=dict(arrowstyle='-', color='k', lw=0.5))

    return fig if return_fig else None
def rank_plot(
    data: Union[MultimodalData, UnimodalData, anndata.AnnData],
    panel_size: Optional[Tuple[float, float]] = (6, 4),
    return_fig: Optional[bool] = False,
    dpi: Optional[float] = 300.0,
    **kwargs,
) -> Union[plt.Figure, None]:
    """Generate a barcode rank plot, which shows the total UMIs against barcode rank (in descending order with respect to total UMIs)

    Parameters
    ----------

    data : `AnnData` or `UnimodalData` or `MultimodalData` object
        The main data object.
    panel_size: `tuple`, optional (default: `(6, 4)`)
        The plot size (width, height) in inches.
    return_fig: ``bool``, optional, default: ``False``
        Return a ``Figure`` object if ``True``; return ``None`` otherwise.
    dpi: ``float``, optional, default: ``300.0``
        The resolution in dots per inch.
    **kwargs:
        Accepted for interface compatibility; currently unused.

    Returns
    -------

    `Figure` object
        A ``matplotlib.figure.Figure`` object containing the rank plot if ``return_fig == True``

    Examples
    --------
    >>> fig = pg.rank_plot(data, dpi = 500)
    """
    fig, ax = _get_subplot_layouts(panel_size=panel_size, dpi=dpi)  # default nrows = 1 & ncols = 1

    # Row sums of a sparse matrix come back 2-D; flatten so dense and sparse
    # inputs are handled alike.
    numis = np.asarray(data.X.sum(axis=1)).ravel()
    ords = np.argsort(numis)[::-1]
    ranks = np.arange(1, numis.size + 1)
    ax.scatter(ranks, numis[ords], c='lightgrey', s=5)

    # Base 10 is Matplotlib's default for log scales; the old ``basex``/``basey``
    # keywords no longer exist in current releases.
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_xlabel("Barcode rank")
    ax.set_ylabel("Total UMIs")

    def _gen_ticklabels(ticks, max_value):
        """Label power-of-ten ticks within [1, max_value]; leave the rest blank."""
        label_arr = ['1', '10', '100', '1000', '10K', '100K', '1M', '10M', '100M']
        ticklabels = [''] * ticks.size
        for i, tick in enumerate(ticks):
            if tick <= 0 or tick > max_value:
                continue
            exponent = int(round(np.log10(tick)))
            if 0 <= exponent < len(label_arr):
                ticklabels[i] = label_arr[exponent]
        return ticklabels

    ax.set_xticklabels(_gen_ticklabels(ax.get_xticks(), numis.size))
    ax.set_yticklabels(_gen_ticklabels(ax.get_yticks(), numis.max()))

    return fig if return_fig else None
def ridgeplot(
    data: Union[MultimodalData, UnimodalData],
    features: Union[str, List[str]],
    matrix_key: Optional[str] = None,
    donor_attr: Optional[str] = None,
    qc_attr: Optional[str] = None,
    overlap: Optional[float] = 0.5,
    left_adjust: Optional[float] = 0.35,
    panel_size: Optional[Tuple[float, float]] = (6, 4),
    return_fig: Optional[bool] = False,
    dpi: Optional[float] = 300.0,
    **kwargs,
) -> Union[plt.Figure, None]:
    """Generate ridge plots, up to 8 features can be shown in one figure.

    Parameters
    ----------

    data : `UnimodalData` or `MultimodalData` object
        Data matrix.
    features : `str` or `List[str]`
        One or more features to display.
    matrix_key: `str`, optional, default None
        Which matrix to search features for. If None, use the current matrix.
    donor_attr: `str`, optional, default None
        If not None, `features` must contain only one feature, plot this feature by donor indicated as `donor_attr`.
    qc_attr: `str`, optional, default None
        If not None, only data.obs[qc_attr] == True are used.
    overlap: `float`, default 0.5
        Overlap between adjacent ridge plots (top and bottom).
    left_adjust: `float`, default 0.35
        Left margin for displaying labels.
    panel_size: `tuple`, optional (default: `(6, 4)`)
        The plot size (width, height) in inches.
    return_fig: ``bool``, optional, default: ``False``
        Return a ``Figure`` object if ``True``; return ``None`` otherwise.
    dpi: ``float``, optional, default: ``300.0``
        The resolution in dots per inch.
    **kwargs:
        Accepted for interface compatibility; currently unused.

    Returns
    -------

    `Figure` object
        A ``matplotlib.figure.Figure`` object containing the ridge plot if ``return_fig == True``.
        ``None`` is also returned (with a warning) when the inputs are invalid.

    Examples
    --------
    >>> fig = pg.ridgeplot(data, features = ['CD8', 'CD4', 'CD3'], dpi = 500)
    >>> fig = pg.ridgeplot(data, features = 'CD3', donor_attr = 'assignment', dpi = 500)
    """
    idx = data.obs[qc_attr].values if qc_attr is not None else np.ones(data.shape[0], dtype=bool)

    if isinstance(features, str):
        features = [features]
    if len(features) > 8:
        logger.warning("At most 8 features are allowed to be plotted together!")
        return None

    if matrix_key is None:
        matrix_key = data.current_matrix()

    # Assemble the long-format table first, so that every early return happens
    # before the global seaborn style is touched.
    if donor_attr is None:
        matrix = data.get_matrix(matrix_key)
        size = idx.sum()
        exprs = []
        feats = []
        for feature in features:
            fid = data.var_names.get_loc(feature)
            exprs.append(slicing(matrix, idx, fid))
            feats.append(np.repeat(feature, size))
        df = pd.DataFrame({"expression": np.concatenate(exprs), "feature": np.concatenate(feats)})
    else:
        if len(features) != 1:
            logger.warning("When donor_attr is set, only one feature can be provided!")
            return None
        if donor_attr not in data.obs:
            logger.warning(f"{donor_attr} is not in data.obs!")
            return None
        feature = features[0]
        if feature not in data.var_names:
            logger.warning(f"Feature {feature} is not included in data.var_names!")
            return None
        fid = data.var_names.get_loc(feature)
        # Honor ``matrix_key`` here as well, as the non-donor branch does.
        df = pd.DataFrame({
            "expression": slicing(data.get_matrix(matrix_key), idx, fid),
            "feature": data.obs.loc[idx, donor_attr],
        })

    sns.set(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})
    try:
        g = sns.FacetGrid(df, row="feature", hue="feature", aspect=8, height=1.0)
        try:
            g.map(sns.kdeplot, "expression", clip_on=False, shade=True, alpha=1, lw=0)
            g.map(sns.kdeplot, "expression", clip_on=False, color="k", lw=0.5)
        except RuntimeError as err:
            # Degenerate (zero-variance) data: retry with a fixed bandwidth.
            if str(err).startswith("Selected KDE bandwidth is 0. Cannot estimate density."):
                g.map(sns.kdeplot, "expression", clip_on=False, shade=True, alpha=1, lw=0, bw=0.1)
                g.map(sns.kdeplot, "expression", clip_on=False, color="k", lw=0.5, bw=0.1)
            else:
                raise
        g.map(plt.axhline, y=0, lw=1, clip_on=False)

        def _set_label(value, color, label):
            # Write the facet's label to the left of its ridge.
            ax = plt.gca()
            ax.text(0, 0.2, label, color="k", ha="right", va="center", transform=ax.transAxes)

        g.map(_set_label, "expression")

        g.fig.subplots_adjust(hspace=-overlap, left=left_adjust)
        g.set_titles("")
        g.set_xlabels("")
        g.set_ylabels("")
        g.set(yticks=[])
        g.despine(bottom=True, left=True)

        if donor_attr is not None:
            g.fig.suptitle(features[0], x=0.0, y=0.98, ha="left")

        g.fig.set_dpi(dpi)
        g.fig.set_figwidth(panel_size[0])
        g.fig.set_figheight(panel_size[1])
    finally:
        # Always restore the global plotting style, even if plotting failed.
        sns.reset_orig()

    return g.fig if return_fig else None
def wordcloud(
    data: Union[MultimodalData, UnimodalData, anndata.AnnData],
    factor: int,
    max_words: Optional[int] = 20,
    random_state: Optional[int] = 0,
    colormap: Optional[str] = "hsv",
    width: Optional[int] = 800,
    height: Optional[int] = 400,
    panel_size: Optional[Tuple[float, float]] = (6, 4),
    return_fig: Optional[bool] = False,
    dpi: Optional[float] = 300.0,
    **kwargs,
) -> Union[plt.Figure, None]:
    """Generate one word cloud image for factor (starts from 0) in data.uns['W'].

    Each highly variable feature is used as a word whose frequency is its
    entry in column ``factor`` of ``data.uns['W']``.

    Parameters
    ----------

    data : ``AnnData`` or ``UnimodalData`` or ``MultimodalData`` object
        The main data object.
    factor: ``int``
        Which factor to plot. factor starts from 0.
    max_words: ``int``, optional, default: 20
        Maximum number of genes to show in the image.
    random_state: ``int``, optional, default: 0
        Random seed passing to WordCloud function.
    colormap: ``str``, optional, default: ``hsv``
        Color map for plotting words.
    width: ``int``, optional, default: 800
        Canvas width.
    height: ``int``, optional, default: 400
        Canvas height.
    panel_size: ``tuple``, optional, default: `(6, 4)`
        The plot size (width, height) in inches.
    return_fig: ``bool``, optional, default: ``False``
        Return a ``Figure`` object if ``True``; return ``None`` otherwise.
    dpi: ``float``, optional, default: ``300.0``
        The resolution in dots per inch.
    **kwargs:
        Accepted for interface compatibility; currently unused.

    Returns
    -------

    `Figure` object
        A ``matplotlib.figure.Figure`` object containing the word cloud if ``return_fig == True``

    Examples
    --------
    >>> fig = pg.wordcloud(data, factor=0)
    """
    fig, ax = _get_subplot_layouts(panel_size=panel_size, dpi=dpi)  # default nrows = 1 & ncols = 1

    assert 'W' in data.uns

    # Row i of W is taken to belong to the i-th highly variable feature.
    hvf_names = data.var_names[data.var['highly_variable_features']]
    frequencies = {
        hvf_names[row]: data.uns['W'][row, factor]
        for row in range(hvf_names.size)
    }

    from wordcloud import WordCloud

    cloud = WordCloud(
        background_color="white",
        max_words=max_words,
        random_state=random_state,
        colormap=colormap,
        width=width,
        height=height,
    )
    cloud.generate_from_frequencies(frequencies)

    ax.imshow(cloud)
    ax.axis('off')

    if return_fig:
        return fig
    return None
def _make_one_gsea_plot(df, ax, color, size=10, fontsize=5):
    """Draw a horizontal bar plot of the top ``size`` pathways on ``ax``.

    Rows of ``df`` are ranked by the ``'Log Q'`` column and then the
    ``'NES Abs'`` column, both descending. Bars show ``'Log Q'`` per
    ``'pathway'`` in the given ``color``. When ``df`` has no rows the axes
    are simply switched off. Always returns ``None``.
    """
    if not df.shape[0]:
        # Nothing to draw: hide the empty panel.
        ax.axis('off')
        return None

    ranked = df.sort_values(['Log Q', 'NES Abs'], ascending=False)
    top_rows = ranked.iloc[0:size]

    bar_ax = sns.barplot(x='Log Q', y='pathway', data=top_rows, color=color, ax=ax)
    bar_ax.set_xlabel('')
    bar_ax.set_ylabel('')
    bar_ax.tick_params(axis='y', labelsize=fontsize)
def plot_gsea(
    data: Union[MultimodalData, UnimodalData],
    gsea_keyword: Optional[str] = "fgsea_out",
    alpha: Optional[float] = 0.1,
    top_n: Optional[int] = 20,
    panel_size: Optional[Tuple[float, float]] = (6, 4),
    label_fontsize: Optional[int] = 5,
    return_fig: Optional[bool] = False,
    dpi: Optional[float] = 300.0,
    **kwargs,
) -> Union[plt.Figure, None]:
    """Generate GSEA barplots

    Parameters
    ----------

    data : ``UnimodalData`` or ``MultimodalData`` object
        The main data object.
    gsea_keyword: ``str``, optional, default: ``"fgsea_out"``
        Keyword in data.uns that stores the fGSEA results in pandas data frame.
    alpha: ``float``, optional, default: ``0.1``
        False discovery rate threshold.
    top_n: ``int``, optional, default: ``20``
        Only show top_n up/down regulated pathways, ranked by -log10(q-value) and then by absolute NES.
    panel_size: `tuple`, optional (default: `(6, 4)`)
        The plot size (width, height) in inches.
    label_fontsize: ``int``, optional, default: ``5``
        The font size of labels.
    return_fig: ``bool``, optional, default: ``False``
        Return a ``Figure`` object if ``True``; return ``None`` otherwise.
    dpi: ``float``, optional, default: ``300.0``
        The resolution in dots per inch.
    **kwargs:
        Accepted for interface compatibility; currently unused.

    Returns
    -------

    `Figure` object
        A ``matplotlib.figure.Figure`` object containing the bar plots if ``return_fig == True``. Each figure contains two panels. The top panel shows up-regulated pathways (color: red) and the bottom panel shows down-regulated pathways (color: green).

    Examples
    --------
    >>> fig = pg.plot_gsea(data, 'fgsea_out', dpi = 500)
    """
    df = data.uns[gsea_keyword]
    df = df[df['padj'] <= alpha].copy()
    df['Log Q'] = -np.log10(df['padj'])
    df['NES Abs'] = np.abs(df['NES'])
    df['pathway'] = df['pathway'].map(lambda name: name.replace('_', ' '))

    # Rank before truncating, so that "top_n" selects the most significant
    # pathways rather than whichever rows happen to come first in the table.
    df = df.sort_values(['Log Q', 'NES Abs'], ascending=False)

    fig, axes = _get_subplot_layouts(panel_size=panel_size, nrows=2, dpi=dpi, left=0.6, hspace=0.2, sharey=False)

    # ``size=top_n`` keeps the helper from capping the panel at its own default.
    df_up = df.loc[df['NES'] > 0].iloc[0:top_n]
    _make_one_gsea_plot(df_up, axes[0], color='red', size=top_n, fontsize=label_fontsize)

    df_dn = df.loc[df['NES'] < 0].iloc[0:top_n]
    _make_one_gsea_plot(df_dn, axes[1], color='green', size=top_n, fontsize=label_fontsize)

    axes[1].set_xlabel('-log10(q-value)')

    return fig if return_fig else None
def elbowplot(
    data: Union[MultimodalData, UnimodalData],
    rep: str = "pca",
    pval: str = "0.05",
    panel_size: Optional[Tuple[float, float]] = (6, 4),
    return_fig: Optional[bool] = False,
    dpi: Optional[float] = 300.0,
    **kwargs,
) -> Union[plt.Figure, None]:
    """Generate Elbowplot and suggest n_comps to select based on random matrix theory (see utils.largest_variance_from_random_matrix).

    Parameters
    ----------

    data : ``UnimodalData`` or ``MultimodalData`` object.
        The main data object.
    rep: ``str``, optional, default: ``pca``
        Representation to consider, either "pca" or "tsvd".
    pval: ``str``, optional (default: "0.05").
        P value cutoff on the null distribution (random matrix), choosing from "0.01" and "0.05".
    panel_size: `tuple`, optional (default: `(6, 4)`)
        The plot size (width, height) in inches.
    return_fig: ``bool``, optional, default: ``False``
        Return a ``Figure`` object if ``True``; return ``None`` otherwise.
    dpi: ``float``, optional, default: ``300.0``
        The resolution in dots per inch.
    **kwargs:
        Accepted for interface compatibility; currently unused.

    Returns
    -------

    `Figure` object
        A ``matplotlib.figure.Figure`` object containing the elbow plot if ``return_fig == True``.

    Update ``data.uns``:

        * ``{rep}_ncomps``: Recommended components to pick.

    Examples
    --------
    >>> fig = pg.elbowplot(data, dpi = 500)
    """
    assert rep in data.uns
    variances = data.uns[rep]["variance"]

    # Features that went into the representation: the flagged subset of
    # data.var when a feature key was recorded, otherwise all features.
    repf = data.uns[f"{rep}_features"]
    nfeatures = data.var[repf].sum() if repf is not None else data.shape[1]

    # Components whose variance exceeds the random-matrix threshold are kept.
    thre = largest_variance_from_random_matrix(data.shape[0], nfeatures, pval)
    ncomps = (variances > thre).sum()
    data.uns[f"{rep}_ncomps"] = ncomps
    logger.info(f"Selecting {ncomps} is recommended!")

    fig, ax = _get_subplot_layouts(panel_size=panel_size, dpi=dpi)
    ax.scatter(range(1, variances.size + 1), variances, s=8, c='k')
    ax.set_yscale('log')
    ax.set_xlabel(rep.upper())
    ax.set_ylabel("Variance")
    # Dashed line between the last recommended component and the next one.
    ax.axvline(x=ncomps + 0.5, ls="--", c="r", linewidth=1)

    return fig if return_fig else None