Source code for pegasus.plotting.plot_library

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from scipy.sparse import issparse
from pandas.api.types import is_numeric_dtype, is_categorical_dtype, is_list_like
from scipy.stats import zscore
from sklearn.metrics import adjusted_mutual_info_score
from natsort import natsorted


import anndata
from pegasusio import UnimodalData, MultimodalData

from typing import List, Tuple, Union, Optional, Callable

import logging
logger = logging.getLogger(__name__)

from pegasus.tools import X_from_rep
from .plot_utils import _transform_basis, _get_nrows_and_ncols, _get_marker_size, _get_dot_size, _get_subplot_layouts, _get_legend_ncol, _get_palette, _plot_cluster_labels_in_heatmap, RestrictionParser, DictWithDefault, _generate_categories, _plot_corners


def scatter(
    data: Union[MultimodalData, UnimodalData, anndata.AnnData],
    attrs: Union[str, List[str]],
    basis: Optional[str] = "umap",
    matkey: Optional[str] = None,
    restrictions: Optional[Union[str, List[str]]] = None,
    show_background: Optional[bool] = False,
    alpha: Optional[Union[float, List[float]]] = 1.0,
    legend_loc: Optional[Union[str, List[str]]] = "right margin",
    legend_ncol: Optional[int] = None,
    palettes: Optional[Union[str, List[str]]] = None,
    cmaps: Optional[Union[str, List[str]]] = "YlOrRd",
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    nrows: Optional[int] = None,
    ncols: Optional[int] = None,
    panel_size: Optional[Tuple[float, float]] = (4, 4),
    left: Optional[float] = 0.2,
    bottom: Optional[float] = 0.15,
    wspace: Optional[float] = 0.4,
    hspace: Optional[float] = 0.15,
    return_fig: Optional[bool] = False,
    dpi: Optional[float] = 300.0,
    **kwargs,
) -> Union[plt.Figure, None]:
    """Generate scatter plots for different attributes.

    Parameters
    ----------
    data: ``pegasusio.MultimodalData``
        Use current selected modality in data.
    attrs: ``str`` or ``List[str]``
        Color scatter plots by attrs. Each attribute in attrs should be one key in data.obs or
        data.var_names (e.g. one gene). If one attribute is categorical, a palette will be used to
        color each category separately. Otherwise, a color map will be used.
    basis: ``str``, optional, default: ``umap``
        Basis to be used to generate scatter plots. Can be either 'umap', 'tsne', 'fitsne', 'fle',
        'net_tsne', 'net_fitsne', 'net_umap' or 'net_fle'.
    matkey: ``str``, optional, default: None
        If matkey is set, select matrix with matkey as keyword in the current modality. Only works
        for MultimodalData or UnimodalData objects. The previously selected matrix is restored
        before returning (also if plotting raises).
    restrictions: ``str`` or ``List[str]``, optional, default: None
        A list of restrictions to subset data for plotting. There are two types of restrictions:
        global restriction and attribute-specific restriction. Global restriction applies to all
        attributes in ``attrs`` and takes the format of 'key:value,value...', or
        'key:~value,value...'. This restriction selects cells whose ``data.obs[key]`` values
        belong to 'value,value...' (or do not belong to them if '~' is present).
        Attribute-specific restriction takes the format of 'attr:key:value,value...', or
        'attr:key:~value,value...'. It only applies to one attribute 'attr'. If 'attr' and 'key'
        are the same, one can use '.' to replace 'key' (e.g. ``cluster_labels:.:value1,value2``).
    show_background: ``bool``, optional, default: False
        Only applicable if `restrictions` is set. By default, only data points selected are shown.
        If show_background is True, data points that are not selected will also be shown.
    alpha: ``float`` or ``List[float]``, optional, default: ``1.0``
        Alpha value for blending, from 0.0 (transparent) to 1.0 (opaque). If this is a list, the
        length must match attrs, which means we set a separate alpha value for each attribute.
    legend_loc: ``str`` or ``List[str]``, optional, default: ``right margin``
        Legend location. Can be either "right margin" or "on data". If a list is provided, set
        'legend_loc' for each attribute in 'attrs' separately.
    legend_ncol: ``int``, optional, default: None
        Only applicable if legend_loc == "right margin". Set number of columns used to show
        legends.
    palettes: ``str`` or ``List[str]``, optional, default: None
        Used for setting colors for every category in categorical attributes. Each string in
        ``palettes`` takes the format of 'attr:color1,color2,...,colorn'. 'attr' is the
        categorical attribute and 'color1' - 'colorn' are the colors for each category in 'attr'
        (e.g. 'cluster_labels:black,blue,red,...,yellow'). If there is only one categorical
        attribute in 'attrs', ``palettes`` can be set as a single string and the 'attr' keyword
        can be omitted (e.g. "blue,yellow,red").
    cmaps: ``str`` or ``List[str]``, optional, default: ``YlOrRd``
        Used for setting colormap for numeric attributes. Each string in ``cmaps`` takes the
        format of 'colormap' or 'attr:colormap'. 'colormap' sets the default colormap for all
        numeric attributes. 'attr:colormap' overwrites attribute 'attr's colormap as 'colormap'.
    vmin: ``float``, optional, default: None
        Minimum value to show on a numeric scatter plot (feature plot).
    vmax: ``float``, optional, default: None
        Maximum value to show on a numeric scatter plot (feature plot).
    nrows: ``int``, optional, default: None
        Number of rows in the figure. If not set, pegasus will figure it out automatically.
    ncols: ``int``, optional, default: None
        Number of columns in the figure. If not set, pegasus will figure it out automatically.
    panel_size: `tuple`, optional (default: `(4, 4)`)
        The panel size (width, height) in inches.
    left: `float`, optional (default: `0.2`)
        This parameter sets the figure's left margin as a fraction of panel's width
        (left * panel_size[0]).
    bottom: `float`, optional (default: `0.15`)
        This parameter sets the figure's bottom margin as a fraction of panel's height
        (bottom * panel_size[1]).
    wspace: `float`, optional (default: `0.4`)
        This parameter sets the width between panels and also the figure's right margin as a
        fraction of panel's width (wspace * panel_size[0]).
    hspace: `float`, optional (default: `0.15`)
        This parameter sets the height between panels and also the figure's top margin as a
        fraction of panel's height (hspace * panel_size[1]).
    return_fig: ``bool``, optional, default: ``False``
        Return a ``Figure`` object if ``True``; return ``None`` otherwise.
    dpi: ``float``, optional, default: 300.0
        The resolution of the figure in dots-per-inch.
    kwargs
        Not used by this function.

    Returns
    -------
    `Figure` object
        A ``matplotlib.figure.Figure`` object containing the scatter plots if
        ``return_fig == True``

    Raises
    ------
    KeyError
        If an attribute is neither in ``data.obs`` nor ``data.var_names``, or if a numeric
        attribute has no colormap.

    Examples
    --------
    >>> pg.scatter(data, attrs=['louvain_labels', 'Channel'], basis='fitsne')
    >>> pg.scatter(data, attrs=['CD14', 'TRAC'], basis='umap')
    """
    if not is_list_like(attrs):
        attrs = [attrs]
    nattrs = len(attrs)

    if isinstance(data, (MultimodalData, UnimodalData)):
        cur_matkey = data.current_matrix()
    if matkey is not None:
        assert isinstance(data, (MultimodalData, UnimodalData))
        data.select_matrix(matkey)

    try:
        x = data.obsm[f"X_{basis}"][:, 0]
        y = data.obsm[f"X_{basis}"][:, 1]
        # four corners of the plot
        corners = np.array(np.meshgrid([x.min(), x.max()], [y.min(), y.max()])).T.reshape(-1, 2)

        basis = _transform_basis(basis)
        marker_size = _get_marker_size(x.size)
        nrows, ncols = _get_nrows_and_ncols(nattrs, nrows, ncols)
        fig, axes = _get_subplot_layouts(
            nrows=nrows,
            ncols=ncols,
            panel_size=panel_size,
            dpi=dpi,
            left=left,
            bottom=bottom,
            wspace=wspace,
            hspace=hspace,
            squeeze=False,
        )

        if not is_list_like(alpha):
            alpha = [alpha] * nattrs
        if not is_list_like(legend_loc):
            legend_loc = [legend_loc] * nattrs
        legend_fontsize = [5 if loc_name == "on data" else 10 for loc_name in legend_loc]

        palettes = DictWithDefault(palettes)
        cmaps = DictWithDefault(cmaps)
        restr_obj = RestrictionParser(restrictions)
        restr_obj.calc_default(data)

        for i in range(nrows):
            for j in range(ncols):
                ax = axes[i, j]
                ax.grid(False)
                ax.set_xticks([])
                ax.set_yticks([])

                pos = i * ncols + j
                if pos < nattrs:
                    attr = attrs[pos]
                    if attr in data.obs:
                        values = data.obs[attr].values
                    else:
                        try:
                            loc = data.var_names.get_loc(attr)
                        except KeyError:
                            raise KeyError(f"{attr} is neither in data.obs nor data.var_names!") from None
                        values = data.X[:, loc].toarray().ravel() if issparse(data.X) else data.X[:, loc]

                    selected = restr_obj.get_satisfied(data, attr)

                    if is_numeric_dtype(values):
                        cmap = cmaps.get(attr, squeeze=True)
                        if cmap is None:
                            raise KeyError(f"Please set colormap for attribute {attr} or set a default colormap!")

                        _plot_corners(ax, corners, marker_size)
                        img = ax.scatter(
                            x[selected],
                            y[selected],
                            c=values[selected],
                            s=marker_size,
                            marker=".",
                            alpha=alpha[pos],
                            edgecolors="none",
                            cmap=cmap,
                            vmin=vmin,
                            vmax=vmax,
                            rasterized=True,
                        )
                        # Put the colorbar in its own axes just to the right of this panel.
                        ax_left, ax_bottom, ax_width, ax_height = ax.get_position().bounds
                        rect = [ax_left + ax_width * (1.0 + 0.05), ax_bottom, ax_width * 0.1, ax_height]
                        ax_colorbar = fig.add_axes(rect)
                        fig.colorbar(img, cax=ax_colorbar)
                    else:
                        labels, with_background = _generate_categories(values, selected)
                        label_size = labels.categories.size

                        palette = palettes.get(attr)
                        if palette is None:
                            palette = _get_palette(label_size, with_background=with_background, show_background=show_background)
                        elif with_background:
                            # The background category (empty label) is listed first in labels.categories.
                            palette = ["gainsboro" if show_background else "white"] + list(palette)

                        text_list = []
                        for k, cat in enumerate(labels.categories):
                            idx = labels == cat
                            if idx.sum() > 0:
                                scatter_kwargs = {"marker": ".", "alpha": alpha[pos], "edgecolors": "none", "rasterized": True}

                                if cat != "":
                                    if legend_loc[pos] != "on data":
                                        scatter_kwargs["label"] = cat
                                    else:
                                        text_list.append((np.median(x[idx]), np.median(y[idx]), cat))

                                if cat != "" or show_background:
                                    ax.scatter(
                                        x[idx],
                                        y[idx],
                                        c=palette[k],
                                        s=marker_size,
                                        **scatter_kwargs,
                                    )
                                else:
                                    # Hidden background: still pin the axis extents.
                                    _plot_corners(ax, corners, marker_size)

                        if legend_loc[pos] == "right margin":
                            legend = ax.legend(
                                loc="center left",
                                bbox_to_anchor=(1, 0.5),
                                frameon=False,
                                fontsize=legend_fontsize[pos],
                                ncol=_get_legend_ncol(label_size, legend_ncol),
                            )
                            # matplotlib 3.7 renamed Legend.legendHandles to legend_handles;
                            # the old name was removed in 3.9.
                            handles = legend.legend_handles if hasattr(legend, "legend_handles") else legend.legendHandles
                            for handle in handles:
                                handle.set_sizes([300.0])
                        elif legend_loc[pos] == "on data":
                            for px, py, txt in text_list:
                                ax.text(px, py, txt, fontsize=legend_fontsize[pos], fontweight="bold", ha="center", va="center")

                    ax.set_title(attr)
                else:
                    ax.set_frame_on(False)

                if i == nrows - 1:
                    ax.set_xlabel(f"{basis}1")
                if j == 0:
                    ax.set_ylabel(f"{basis}2")
    finally:
        # Reset current matrix if needed, even when plotting failed.
        if not isinstance(data, anndata.AnnData):
            if cur_matkey != data.current_matrix():
                data.select_matrix(cur_matkey)

    return fig if return_fig else None
def scatter_groups(
    data: Union[MultimodalData, UnimodalData, anndata.AnnData],
    attr: str,
    groupby: str,
    basis: Optional[str] = "umap",
    matkey: Optional[str] = None,
    restrictions: Optional[Union[str, List[str]]] = None,
    show_full: Optional[bool] = True,
    categories: Optional[List[str]] = None,
    alpha: Optional[float] = 1.0,
    legend_loc: Optional[str] = "right margin",
    legend_ncol: Optional[int] = None,
    palette: Optional[str] = None,
    cmap: Optional[str] = "YlOrRd",
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    nrows: Optional[int] = None,
    ncols: Optional[int] = None,
    panel_size: Optional[Tuple[float, float]] = (4, 4),
    left: Optional[float] = 0.2,
    bottom: Optional[float] = 0.15,
    wspace: Optional[float] = 0.4,
    hspace: Optional[float] = 0.15,
    return_fig: Optional[bool] = False,
    dpi: Optional[float] = 300.0,
    **kwargs,
) -> Union[plt.Figure, None]:
    """Generate scatter plots of attribute 'attr' for each category in attribute 'groupby'.

    Optionally show a scatter plot containing data points from all categories in 'groupby'.

    Parameters
    ----------
    data: ``pegasusio.MultimodalData``
        Use current selected modality in data.
    attr: ``str``
        Color scatter plots by attribute 'attr'. This attribute should be one key in data.obs or
        data.var_names (e.g. one gene). If it is categorical (or non-numeric), a palette will be
        used to color each category separately. Otherwise, a color map will be used.
    groupby: ``str``
        Generate separate scatter plots of 'attr' for data points in each category in 'groupby',
        which should be a key in data.obs representing one categorical variable.
    basis: ``str``, optional, default: ``umap``
        Basis to be used to generate scatter plots. Can be either 'umap', 'tsne', 'fitsne', 'fle',
        'net_tsne', 'net_fitsne', 'net_umap' or 'net_fle'.
    matkey: ``str``, optional, default: None
        If matkey is set, select matrix with matkey as keyword in the current modality. Only works
        for MultimodalData or UnimodalData objects. The previously selected matrix is restored
        before returning (also if plotting raises).
    restrictions: ``str`` or ``List[str]``, optional, default: None
        A list of restrictions to subset data for plotting. Each restriction takes the format of
        'key:value,value...', or 'key:~value,value...'. This restriction selects cells whose
        ``data.obs[key]`` values belong to 'value,value...' (or do not belong to them if '~' is
        present).
    show_full: ``bool``, optional, default: True
        Show the scatter plot with all categories in 'groupby' as the first plot.
    categories: ``List[str]``, optional, default: None
        Redefine group structure based on attribute 'groupby'. If 'categories' is not None, each
        string in the list takes the format of 'category_name:value,value', or
        'category_name:~value,value...', where 'category_name' refers to the new category name,
        'value' refers to one of the categories in 'groupby' and '~' refers to excluding values.
    alpha: ``float``, optional, default: ``1.0``
        Alpha value for blending, from 0.0 (transparent) to 1.0 (opaque).
    legend_loc: ``str``, optional, default: ``right margin``
        Legend location. Can be either "right margin" or "on data".
    legend_ncol: ``int``, optional, default: None
        Only applicable if legend_loc == "right margin". Set number of columns used to show
        legends.
    palette: ``str``, optional, default: None
        Used for setting colors for one categorical attribute (e.g. "black,blue,red,...,yellow").
    cmap: ``str``, optional, default: ``YlOrRd``
        Used for setting colormap for one numeric attribute.
    vmin: ``float``, optional, default: None
        Minimum value to show on a numeric scatter plot (feature plot).
    vmax: ``float``, optional, default: None
        Maximum value to show on a numeric scatter plot (feature plot).
    nrows: ``int``, optional, default: None
        Number of rows in the figure. If not set, pegasus will figure it out automatically.
    ncols: ``int``, optional, default: None
        Number of columns in the figure. If not set, pegasus will figure it out automatically.
    panel_size: `tuple`, optional (default: `(4, 4)`)
        The panel size (width, height) in inches.
    left: `float`, optional (default: `0.2`)
        This parameter sets the figure's left margin as a fraction of panel's width
        (left * panel_size[0]).
    bottom: `float`, optional (default: `0.15`)
        This parameter sets the figure's bottom margin as a fraction of panel's height
        (bottom * panel_size[1]).
    wspace: `float`, optional (default: `0.4`)
        This parameter sets the width between panels and also the figure's right margin as a
        fraction of panel's width (wspace * panel_size[0]).
    hspace: `float`, optional (default: `0.15`)
        This parameter sets the height between panels and also the figure's top margin as a
        fraction of panel's height (hspace * panel_size[1]).
    return_fig: ``bool``, optional, default: ``False``
        Return a ``Figure`` object if ``True``; return ``None`` otherwise.
    dpi: ``float``, optional, default: 300.0
        The resolution of the figure in dots-per-inch.
    kwargs
        Not used by this function.

    Returns
    -------
    `Figure` object
        A ``matplotlib.figure.Figure`` object containing the scatter plots if
        ``return_fig == True``

    Raises
    ------
    KeyError
        If 'attr' is neither in ``data.obs`` nor ``data.var_names``.

    Examples
    --------
    >>> pg.scatter_groups(data, attr='louvain_labels', groupby='Individual', basis='tsne', nrows = 2, ncols = 4, alpha = 0.5)
    >>> pg.scatter_groups(data, attr='anno', groupby='Channel', basis='umap', categories=['new_cat1:channel1,channel2', 'new_cat2:channel3'])
    """
    if isinstance(data, (MultimodalData, UnimodalData)):
        cur_matkey = data.current_matrix()
    if matkey is not None:
        assert isinstance(data, (MultimodalData, UnimodalData))
        data.select_matrix(matkey)

    try:
        x = data.obsm[f"X_{basis}"][:, 0]
        y = data.obsm[f"X_{basis}"][:, 1]
        # four corners of the plot (computed before restrictions so all panels share extents)
        corners = np.array(np.meshgrid([x.min(), x.max()], [y.min(), y.max()])).T.reshape(-1, 2)

        basis = _transform_basis(basis)
        marker_size = _get_marker_size(x.size)

        if attr in data.obs:
            values = data.obs[attr].values
        else:
            try:
                loc = data.var_names.get_loc(attr)
            except KeyError:
                raise KeyError(f"{attr} is neither in data.obs nor data.var_names!") from None
            values = data.X[:, loc].toarray().ravel() if issparse(data.X) else data.X[:, loc]

        # `pd.api.types.is_categorical_dtype` is deprecated (pandas 2.1); check the dtype directly.
        is_cat = isinstance(getattr(values, "dtype", None), pd.CategoricalDtype)
        if (not is_cat) and (not is_numeric_dtype(values)):
            values = pd.Categorical(values, categories=natsorted(np.unique(values)))
            is_cat = True

        assert groupby in data.obs
        groups = data.obs[groupby].values
        if not isinstance(getattr(groups, "dtype", None), pd.CategoricalDtype):
            groups = pd.Categorical(groups, categories=natsorted(np.unique(groups)))

        restr_obj = RestrictionParser(restrictions)
        restr_obj.calc_default(data)
        selected = restr_obj.get_satisfied(data)
        nsel = selected.sum()

        if nsel < data.shape[0]:
            x = x[selected]
            y = y[selected]
            values = values[selected]
            groups = groups[selected]

        # One boolean column per panel: True where the cell belongs to that panel's group.
        df_g = pd.DataFrame()
        if show_full:
            df_g["All"] = np.ones(nsel, dtype=bool)
        if categories is None:
            for cat in groups.categories:
                df_g[cat] = groups == cat
        else:
            cat_obj = RestrictionParser(categories)
            for cat, idx in cat_obj.next_category(groups):
                df_g[cat] = idx

        nrows, ncols = _get_nrows_and_ncols(df_g.shape[1], nrows, ncols)
        fig, axes = _get_subplot_layouts(
            nrows=nrows,
            ncols=ncols,
            panel_size=panel_size,
            dpi=dpi,
            left=left,
            bottom=bottom,
            wspace=wspace,
            hspace=hspace,
            squeeze=False,
        )

        legend_fontsize = 5 if legend_loc == "on data" else 10

        if is_cat:
            labels = values
            label_size = labels.categories.size
            palette = _get_palette(label_size) if palette is None else np.array(palette.split(","))
            legend_ncol = _get_legend_ncol(label_size, legend_ncol)

        for i in range(nrows):
            for j in range(ncols):
                ax = axes[i, j]
                ax.grid(False)
                ax.set_xticks([])
                ax.set_yticks([])
                gid = i * ncols + j
                if gid < df_g.shape[1]:
                    in_group = df_g.iloc[:, gid].values
                    # Invisible corner points keep every panel's extents identical.
                    _plot_corners(ax, corners, marker_size)
                    if is_cat:
                        text_list = []
                        for k, cat in enumerate(labels.categories):
                            idx = np.logical_and(in_group, labels == cat)
                            if idx.sum() > 0:
                                scatter_kwargs = {"marker": ".", "alpha": alpha, "edgecolors": "none", "rasterized": True}

                                if legend_loc != "on data":
                                    scatter_kwargs["label"] = str(cat)
                                else:
                                    text_list.append((np.median(x[idx]), np.median(y[idx]), str(cat)))

                                ax.scatter(
                                    x[idx],
                                    y[idx],
                                    c=palette[k],
                                    s=marker_size,
                                    **scatter_kwargs,
                                )

                        if legend_loc == "right margin":
                            legend = ax.legend(
                                loc="center left",
                                bbox_to_anchor=(1, 0.5),
                                frameon=False,
                                fontsize=legend_fontsize,
                                ncol=legend_ncol,
                            )
                            # matplotlib 3.7 renamed Legend.legendHandles to legend_handles;
                            # the old name was removed in 3.9.
                            handles = legend.legend_handles if hasattr(legend, "legend_handles") else legend.legendHandles
                            for handle in handles:
                                handle.set_sizes([300.0])
                        elif legend_loc == "on data":
                            for px, py, txt in text_list:
                                ax.text(px, py, txt, fontsize=legend_fontsize, fontweight="bold", ha="center", va="center")
                    else:
                        img = ax.scatter(
                            x[in_group],
                            y[in_group],
                            s=marker_size,
                            c=values[in_group],
                            marker=".",
                            alpha=alpha,
                            edgecolors="none",
                            cmap=cmap,
                            vmin=vmin,
                            vmax=vmax,
                            rasterized=True,
                        )
                        # Put the colorbar in its own axes just to the right of this panel.
                        ax_left, ax_bottom, ax_width, ax_height = ax.get_position().bounds
                        rect = [ax_left + ax_width * (1.0 + 0.05), ax_bottom, ax_width * 0.1, ax_height]
                        ax_colorbar = fig.add_axes(rect)
                        fig.colorbar(img, cax=ax_colorbar)

                    ax.set_title(str(df_g.columns[gid]))
                else:
                    ax.set_frame_on(False)

                if i == nrows - 1:
                    ax.set_xlabel(basis + "1")
                if j == 0:
                    ax.set_ylabel(basis + "2")
    finally:
        # Reset current matrix if needed, even when plotting failed.
        if not isinstance(data, anndata.AnnData):
            if cur_matkey != data.current_matrix():
                data.select_matrix(cur_matkey)

    return fig if return_fig else None
def compo_plot(
    data: Union[MultimodalData, UnimodalData, anndata.AnnData],
    groupby: str,
    condition: str,
    style: Optional[str] = "frequency",
    restrictions: Optional[Union[str, List[str]]] = None,
    xlabel: Optional[str] = None,
    panel_size: Optional[Tuple[float, float]] = (6, 4),
    left: Optional[float] = 0.15,
    bottom: Optional[float] = 0.15,
    wspace: Optional[float] = 0.3,
    hspace: Optional[float] = 0.15,
    return_fig: Optional[bool] = False,
    dpi: Optional[float] = 300.0,
    **kwargs,
) -> Union[plt.Figure, None]:
    """Draw a composition bar plot of 'condition' within every category of 'groupby'.

    A composition plot is a bar plot showing, for each cluster, the cell composition across
    different conditions. It is a quick way to assess library quality and batch effects.

    Parameters
    ----------
    data : `AnnData` or `UnimodalData` or `MultimodalData` object
        Single cell expression data.
    groupby : `str`
        A categorical variable in data.obs used to categorize the cells, e.g. cell type.
    condition: `str`
        A categorical variable in data.obs used to calculate frequencies within each category
        defined by 'groupby', e.g. donor.
    style: `str`, optional (default: `frequency`)
        Either `frequency` or `normalized`. `frequency` shows, for each category in 'groupby',
        the percentage of its cells coming from each 'condition' (bars are stacked).
        `normalized` shows, for each 'condition', the percentage of its cells that fall in each
        category of 'groupby' (bars are not stacked). Any other value fails an assertion.
    restrictions: ``str`` or ``List[str]``, optional, default: None
        A list of restrictions to subset data for plotting. Each restriction takes the format of
        'key:value,value...', or 'key:~value,value...', selecting cells whose ``data.obs[key]``
        values are (or, with '~', are not) among the listed values.
    xlabel: `str`, optional (default None)
        Label for the horizontal axis. If None, use 'groupby'.
    panel_size: `tuple`, optional (default: `(6, 4)`)
        The plot size (width, height) in inches.
    left: `float`, optional (default: `0.15`)
        Left margin as a fraction of the panel's width (left * panel_size[0]).
    bottom: `float`, optional (default: `0.15`)
        Bottom margin as a fraction of the panel's height (bottom * panel_size[1]).
    wspace: `float`, optional (default: `0.3`)
        Width between panels and right margin as a fraction of the panel's width
        (wspace * panel_size[0]).
    hspace: `float`, optional (default: `0.15`)
        Height between panels and top margin as a fraction of the panel's height
        (hspace * panel_size[1]).
    return_fig: ``bool``, optional, default: ``False``
        Return a ``Figure`` object if ``True``; return ``None`` otherwise.
    dpi: ``float``, optional, default: ``300.0``
        The resolution in dots per inch.
    kwargs
        Not used by this function.

    Returns
    -------
    `Figure` object
        A ``matplotlib.figure.Figure`` object containing the composition plot if
        ``return_fig == True``

    Examples
    --------
    >>> fig = pg.compo_plot(data, 'louvain_labels', 'Donor', style = 'normalized')
    """
    # Single panel: _get_subplot_layouts defaults to nrows = ncols = 1.
    fig, ax = _get_subplot_layouts(panel_size=panel_size, dpi=dpi, left=left, bottom=bottom, wspace=wspace, hspace=hspace)

    parser = RestrictionParser(restrictions)
    parser.calc_default(data)
    keep = parser.get_satisfied(data)

    counts = pd.crosstab(data.obs.loc[keep, groupby], data.obs.loc[keep, condition])
    counts = counts.reindex(
        index=natsorted(counts.index.values),
        columns=natsorted(counts.columns.values),
    )

    is_stacked = style == "frequency"
    if is_stacked:
        # Each row (groupby category) sums to 100.
        row_totals = counts.sum(axis=1)
        percentages = counts.div(row_totals, axis=0) * 100.0
    else:
        assert style == "normalized"
        # Each column (condition) sums to 100.
        col_totals = counts.sum(axis=0)
        percentages = counts.div(col_totals, axis=1) * 100.0

    bar_colors = _get_palette(percentages.shape[1])
    percentages.plot(kind="bar", stacked=is_stacked, legend=False, color=bar_colors, ax=ax)

    ax.grid(False)
    ax.set_xlabel(groupby if xlabel is None else xlabel)
    ax.set_ylabel("Percentage")
    ax.legend(loc="center left", bbox_to_anchor=(1.05, 0.5))

    # Tilt the tick labels when at least one category name is long.
    longest = max(map(len, percentages.index.astype(str)))
    if longest >= 5:
        ax.set_xticklabels(ax.get_xticklabels(), rotation=-45, ha='left')

    if return_fig:
        return fig
    return None
def violin(
    data: Union[MultimodalData, UnimodalData, anndata.AnnData],
    attrs: Union[str, List[str]],
    groupby: str,
    hue: Optional[str] = None,
    matkey: Optional[str] = None,
    stripplot: bool = False,
    scale: str = 'width',
    panel_size: Optional[Tuple[float, float]] = (8, 0.5),
    left: Optional[float] = 0.15,
    bottom: Optional[float] = 0.15,
    wspace: Optional[float] = 0.1,
    ylabel: Optional[str] = None,
    return_fig: Optional[bool] = False,
    dpi: Optional[float] = 300.0,
    **kwargs,
) -> Union[plt.Figure, None]:
    """Generate a stacked violin plot.

    Parameters
    ----------
    data: ``AnnData`` or ``MultimodalData`` or ``UnimodalData`` object
        Single-cell expression data.
    attrs: ``str`` or ``List[str]``
        Cell attributes or features to plot. Cell attributes must exist in ``data.obs`` and must
        be numeric. Features must exist in ``data.var_names``.
    groupby: ``str``
        A categorical variable in data.obs that is used to categorize the cells, e.g. Clusters.
    hue: ``str``, optional, default: None
        'hue' should be a categorical variable in data.obs that has only two levels. Setting
        'hue' shows split violin plots.
    matkey: ``str``, optional, default: ``None``
        If matkey is set, select matrix with matkey as keyword in the current modality. Only works
        for MultimodalData or UnimodalData objects. The previously selected matrix is restored
        before returning (including the early return for a missing feature).
    stripplot: ``bool``, optional, default: ``False``
        Attach a stripplot (with a fixed ``jitter=True``) to the violin plot or not. This option
        is automatically turned off if 'hue' is set.
    scale: ``str``, optional, default: ``width``
        The method used to scale the width of each violin:
            - If ``width``, each violin will have the same width.
            - If ``area``, each violin will have the same area.
            - If ``count``, the width of the violins will be scaled by the number of observations
              in that bin.
    panel_size: ``Tuple[float, float]``, optional, default: ``(8, 0.5)``
        The size (width, height) in inches of each violin panel.
    left: ``float``, optional, default: ``0.15``
        This parameter sets the figure's left margin as a fraction of panel's width
        (left * panel_size[0]).
    bottom: ``float``, optional, default: ``0.15``
        This parameter sets the figure's bottom margin as a fraction of panel's height
        (bottom * panel_size[1]).
    wspace: ``float``, optional, default: ``0.1``
        This parameter sets the width between panels and also the figure's right margin as a
        fraction of panel's width (wspace * panel_size[0]).
    ylabel: ``str``, optional, default: ``None``
        Y-axis label. No label to show if ``None``.
    return_fig: ``bool``, optional, default: ``False``
        Return a ``Figure`` object if ``True``; return ``None`` otherwise.
    dpi: ``float``, optional, default: ``300.0``
        The resolution in dots per inch.
    kwargs
        Are passed to ``seaborn.violinplot``.

    Returns
    -------
    ``Figure`` object
        A ``matplotlib.figure.Figure`` object containing the violin plot if
        ``return_fig == True``. ``None`` is returned (and a warning logged) if a requested
        feature is not found.

    Examples
    --------
    >>> pg.violin(data, attrs=['CD14', 'TRAC', 'CD34'], groupby='louvain_labels')
    """
    if not is_list_like(attrs):
        attrs = [attrs]

    if isinstance(data, (MultimodalData, UnimodalData)):
        cur_matkey = data.current_matrix()
    if matkey is not None:
        assert isinstance(data, (MultimodalData, UnimodalData))
        data.select_matrix(matkey)

    try:
        # Validate every attribute before creating the figure, so a missing feature does not
        # leave an empty figure behind.
        obs_keys = []
        genes = []
        for key in attrs:
            if key in data.obs:
                assert is_numeric_dtype(data.obs[key])
                obs_keys.append(key)
            else:
                if key not in data.var_names:
                    logger.warning(
                        "Cannot find gene %s. Please make sure all genes are included in data.var_names before running this function!",
                        key,
                    )
                    return None
                genes.append(key)

        nrows = len(attrs)
        fig, axes = _get_subplot_layouts(
            nrows=nrows,
            ncols=1,
            panel_size=panel_size,
            dpi=dpi,
            left=left,
            bottom=bottom,
            wspace=wspace,
            hspace=0,
            squeeze=False,
            sharey=False,
        )

        # Long-form table: one row per cell; columns "label", optional hue, then each attribute.
        df_list = [pd.DataFrame({"label": data.obs[groupby].values})]
        if hue is not None:
            df_list.append(pd.DataFrame({hue: data.obs[hue].values}))
            stripplot = False
        if obs_keys:
            df_list.append(data.obs[obs_keys].reset_index(drop=True))
        if genes:
            expr_mat = data[:, genes].X
            # X may be sparse or dense; only sparse matrices have .toarray().
            expr_mat = expr_mat.toarray() if issparse(expr_mat) else np.asarray(expr_mat)
            df_list.append(pd.DataFrame(data=expr_mat, columns=genes))
        df = pd.concat(df_list, axis=1)

        for i, attr in enumerate(attrs):
            ax = axes[i, 0]
            if stripplot:
                sns.stripplot(x="label", y=attr, hue=hue, data=df, ax=ax, size=1, color="k", jitter=True)
            sns.violinplot(x="label", y=attr, hue=hue, data=df, inner=None, linewidth=1, ax=ax, cut=0, scale=scale, split=True, **kwargs)

            ax.grid(False)
            if hue is not None:
                if i == 0:
                    ax.legend(loc="center left", bbox_to_anchor=(1.02, 0.5))
                else:
                    legend = ax.get_legend()
                    if legend is not None:
                        legend.set_visible(False)

            # Only the bottom panel shows the x-axis label and the category tick labels.
            if i < nrows - 1:
                ax.set_xlabel("")
            else:
                ax.set_xlabel(groupby)
                ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
            ax.set_ylabel(attr, labelpad=8, rotation=0, horizontalalignment='right', fontsize='medium')
            ax.tick_params(axis='y', right=True, left=False, labelright=True, labelleft=False, labelsize='small')

        if ylabel is not None:
            fig.text(0.02, 0.5, ylabel, rotation="vertical", fontsize="xx-large")
    finally:
        # Reset current matrix if needed, on every exit path.
        if not isinstance(data, anndata.AnnData):
            if data.current_matrix() != cur_matkey:
                data.select_matrix(cur_matkey)

    return fig if return_fig else None
def heatmap(
    data: Union[MultimodalData, UnimodalData, anndata.AnnData],
    genes: Union[str, List[str]],
    groupby: str,
    matkey: Optional[str] = None,
    on_average: bool = True,
    switch_axes: bool = False,
    row_cluster: Optional[bool] = None,
    col_cluster: Optional[bool] = None,
    panel_size: Tuple[float, float] = (10, 10),
    return_fig: Optional[bool] = False,
    dpi: Optional[float] = 300.0,
    **kwargs,
) -> Union[plt.Figure, None]:
    """Generate a heatmap.

    Parameters
    -----------
    data: ``AnnData`` or ``MultimodalData`` or ``UnimodalData`` object
        Single-cell expression data.
    genes: ``str`` or ``List[str]``
        Features to plot. A single string is treated as a one-element list.
    groupby: ``str``
        A categorical variable in data.obs that is used to categorize the cells, e.g. Clusters.
    matkey: ``str``, optional, default: ``None``
        If matkey is set, select matrix with matkey as keyword in the current modality. Only works
        for MultimodalData or UnimodalData objects. The previously selected matrix is restored
        before returning (also if plotting raises).
    on_average: ``bool``, optional, default: ``True``
        If ``True``, plot cluster average gene expression (i.e. show a Matrixplot); otherwise,
        plot a general heatmap with one row/column per cell.
    switch_axes: ``bool``, optional, default: ``False``
        By default, X axis is for genes, and Y axis for clusters. If this parameter is ``True``,
        switch the axes. Moreover, with ``on_average`` being ``False``, if ``switch_axes`` is
        ``False``, ``row_cluster`` is enforced to be ``False``; if ``switch_axes`` is ``True``,
        ``col_cluster`` is enforced to be ``False``.
    row_cluster: ``bool``, optional, default: ``None``
        Cluster rows and generate a row-wise dendrogram. If ``None``, it is set to
        ``switch_axes``.
    col_cluster: ``bool``, optional, default: ``None``
        Cluster columns and generate a column-wise dendrogram. If ``None``, it is set to
        ``not switch_axes``.
    panel_size: ``Tuple[float, float]``, optional, default: ``(10, 10)``
        Overall size of the heatmap in ``(width, height)`` form.
    return_fig: ``bool``, optional, default: ``False``
        Return a ``Figure`` object if ``True``; return ``None`` otherwise.
    dpi: ``float``, optional, default: ``300.0``
        The resolution in dots per inch.
    kwargs
        Are passed to ``seaborn.clustermap``. If ``on_average`` is ``True`` and no ``cmap`` is
        given, ``cmap='Reds'`` is used.

    .. _colormap documentation: https://matplotlib.org/3.1.0/tutorials/colors/colormaps.html

    Returns
    -------
    ``Figure`` object
        A ``matplotlib.figure.Figure`` object containing the heatmap if ``return_fig == True``

    Examples
    --------
    >>> pg.heatmap(data, genes=['CD14', 'TRAC', 'CD34'], groupby='louvain_labels')
    """
    # A bare string would otherwise break `columns=genes` in the DataFrame constructor.
    if not is_list_like(genes):
        genes = [genes]

    if isinstance(data, (MultimodalData, UnimodalData)):
        cur_matkey = data.current_matrix()
    if matkey is not None:
        assert isinstance(data, (MultimodalData, UnimodalData))
        data.select_matrix(matkey)

    try:
        if row_cluster is None:
            row_cluster = bool(switch_axes)
        if col_cluster is None:
            col_cluster = not switch_axes

        expr = data[:, genes].X
        # X may be sparse or dense; only sparse matrices have .toarray().
        expr = expr.toarray() if issparse(expr) else np.asarray(expr)
        df = pd.DataFrame(expr, index=data.obs.index, columns=genes)
        df["cluster_name"] = data.obs[groupby]

        if on_average:
            if "cmap" not in kwargs:
                kwargs["cmap"] = "Reds"
            df = df.groupby("cluster_name").mean()
            cluster_ids = df.index
            cell_colors = None
        else:
            # Per-cell heatmap: never cluster along the cell axis, so cells stay grouped.
            row_cluster = False if not switch_axes else row_cluster
            col_cluster = False if switch_axes else col_cluster

            cluster_ids = pd.Categorical(data.obs[groupby])
            idx = cluster_ids.argsort()
            # organize df by category order
            df = df.iloc[idx, :].drop(columns=["cluster_name"])

            cell_colors = np.zeros(df.shape[0], dtype=object)
            palette = _get_palette(cluster_ids.categories.size)

            cluster_ids = cluster_ids[idx]
            for k, cat in enumerate(cluster_ids.categories):
                cell_colors[np.isin(cluster_ids, cat)] = palette[k]

        if not switch_axes:
            cg = sns.clustermap(
                data=df,
                row_colors=cell_colors,
                col_colors=None,
                row_cluster=row_cluster,
                col_cluster=col_cluster,
                linewidths=0,
                yticklabels=cluster_ids if on_average else [],
                xticklabels=genes,
                figsize=panel_size,
                **kwargs,
            )
            cg.ax_heatmap.set_ylabel("")
        else:
            cg = sns.clustermap(
                data=df.T,
                row_colors=None,
                col_colors=cell_colors,
                row_cluster=row_cluster,
                col_cluster=col_cluster,
                linewidths=0,
                yticklabels=genes,
                xticklabels=cluster_ids if on_average else [],
                figsize=panel_size,
                **kwargs,
            )
            cg.ax_heatmap.set_xlabel("")

        if row_cluster:
            cg.ax_heatmap.yaxis.tick_right()
        else:
            cg.ax_heatmap.yaxis.tick_left()

        cg.ax_row_dendrogram.set_visible(row_cluster)
        cg.cax.tick_params(labelsize=10)
        # Newer seaborn deprecates ClusterGrid.fig in favor of ClusterGrid.figure.
        fig = cg.figure if hasattr(cg, "figure") else cg.fig
        fig.dpi = dpi

        if not row_cluster:
            # Move the colorbar to the right-side.
            color_box = cg.ax_heatmap.get_position()
            color_box.x0 = color_box.x1 + 0.04
            color_box.x1 = color_box.x0 + 0.02
            cg.cax.set_position(color_box)
            cg.cax.yaxis.set_ticks_position("right")
        else:
            # Avoid overlap of colorbar and row dendrogram.
            color_box = cg.cax.get_position()
            square_plot = cg.ax_heatmap.get_position()
            if square_plot.y1 > color_box.y0:
                y_diff = square_plot.y1 - color_box.y0
                color_box.y0 = square_plot.y1
                color_box.y1 += y_diff
                cg.cax.set_position(color_box)

        if not on_average:
            orientation = "left" if not switch_axes else "top"
            _plot_cluster_labels_in_heatmap(cg.ax_heatmap, cluster_ids, orientation)
    finally:
        # Reset current matrix if needed, even when plotting failed.
        if not isinstance(data, anndata.AnnData):
            if cur_matkey != data.current_matrix():
                data.select_matrix(cur_matkey)

    return fig if return_fig else None
def dotplot(
    data: Union[MultimodalData, UnimodalData, anndata.AnnData],
    genes: Union[str, List[str]],
    groupby: str,
    reduce_function: Callable[[np.ndarray], float] = np.mean,
    fraction_min: float = 0,
    fraction_max: Optional[float] = None,
    dot_min: int = 0,
    dot_max: int = 20,
    switch_axes: bool = False,
    cmap: Union[str, List[str], Tuple[str]] = 'Reds',
    sort_function: Optional[Callable[[pd.DataFrame], List[int]]] = None,
    grid: bool = True,
    return_fig: Optional[bool] = False,
    dpi: Optional[float] = 300.0,
    **kwds,
) -> Union[plt.Figure, None]:
    """
    Generate a dot plot.

    Each dot summarizes one (group, feature) pair: its color is the value of
    ``reduce_function`` over the group's cells, and its size reflects the fraction
    of the group's cells with a non-zero value for that feature.

    Parameters
    ----------
    data: ``AnnData`` or ``UnimodalData`` or ``MultimodalData`` object
        Single cell expression data.
    genes: ``str`` or ``List[str]``
        Features to plot. A single feature name is treated as a one-element list.
    groupby: ``str``
        A categorical variable in data.obs that is used to categorize the cells, e.g. Clusters.
    reduce_function: ``Callable[[np.ndarray], float]``, optional, default: ``np.mean``
        Function to calculate statistic on expression data. Default is mean.
    fraction_min: ``float``, optional, default: ``0``.
        Minimum fraction of expressing cells to consider.
    fraction_max: ``float``, optional, default: ``None``.
        Maximum fraction of expressing cells to consider. If ``None``, use the maximum value from data.
    dot_min: ``int``, optional, default: ``0``.
        Minimum size in pixels for dots.
    dot_max: ``int``, optional, default: ``20``.
        Maximum size in pixels for dots.
    switch_axes: ``bool``, optional, default: ``False``.
        If ``True``, switch X and Y axes.
    cmap: ``str`` or ``List[str]`` or ``Tuple[str]``, optional, default: ``Reds``
        Color map.
    sort_function: ``Callable[[pd.DataFrame], List[int]]``, optional, default: ``None``
        Function used for ordering the groups. It receives the per-group summary table and
        its return value is passed to ``DataFrame.iloc``, so it must give row positions.
        If ``None``, groups are natural-sorted in reverse order.
    grid: ``bool``, optional, default: ``True``
        If ``True``, plot grids.
    return_fig: ``bool``, optional, default: ``False``
        Return a ``Figure`` object if ``True``; return ``None`` otherwise.
    dpi: ``float``, optional, default: ``300.0``
        The resolution in dots per inch.
    **kwds:
        Are passed to ``matplotlib.axes.Axes.scatter`` for the main plot.

    Returns
    -------
    ``Figure`` object
        A ``matplotlib.figure.Figure`` object containing the dot plot if ``return_fig == True``

    Examples
    --------
    >>> pg.dotplot(data, genes = ['CD14', 'TRAC', 'CD34'], groupby = 'louvain_labels')
    """
    # Wrap a single feature name into a list; everything below iterates over ``genes``.
    genes = list(genes) if is_list_like(genes) else [genes]

    keywords = dict(cmap=cmap)
    keywords.update(kwds)

    X = data[:, genes].X
    if issparse(X):
        X = X.toarray()
    df = pd.DataFrame(data=X, columns=genes)
    df[groupby] = data.obs[groupby].values

    def non_zero(g):
        # Fraction of the group's cells with a non-zero value.
        return np.count_nonzero(g) / g.shape[0]

    summarized_df = df.groupby(groupby).aggregate([reduce_function, non_zero])

    if sort_function is not None:
        row_indices = sort_function(summarized_df)
        summarized_df = summarized_df.iloc[row_indices]
    else:
        summarized_df = summarized_df.loc[natsorted(summarized_df.index, reverse=True)]

    # The aggregated columns alternate per gene: even positions hold the reduced
    # statistic, odd positions hold the non-zero fraction.
    mean_df = summarized_df[list(summarized_df.columns[0::2])]
    fraction_df = summarized_df[list(summarized_df.columns[1::2])]

    # Genes on columns, groupby on rows
    y, x = np.indices(mean_df.shape)
    y = y.flatten()
    x = x.flatten()
    fraction = fraction_df.values.flatten()
    if fraction_max is None:
        fraction_max = fraction.max()
    pixels = _get_dot_size(fraction, fraction_min, fraction_max, dot_min, dot_max)
    summary_values = mean_df.values.flatten()

    xlabel = list(genes)
    ylabel = [str(label) for label in summarized_df.index]
    xticks = genes
    yticks = summarized_df.index.map(str).values

    if switch_axes:
        x, y = y, x
        xlabel, ylabel = ylabel, xlabel
        xticks, yticks = yticks, xticks

    dotplot_df = pd.DataFrame(
        data=dict(
            x=x,
            y=y,
            value=summary_values,
            pixels=pixels,
            fraction=fraction,
            xlabel=np.array(xlabel)[x],
            ylabel=np.array(ylabel)[y],
        )
    )

    import matplotlib.gridspec as gridspec

    # The seaborn style is global state: apply it only while drawing and always restore
    # the defaults, even if plotting raises.
    sns.set(font_scale=0.7, style='whitegrid')
    try:
        width = int(np.ceil(((dot_max + 1) + 4) * len(xticks) + dotplot_df['ylabel'].str.len().max()) + dot_max + 100)
        height = int(np.ceil(((dot_max + 1) + 4) * len(yticks) + dotplot_df['xlabel'].str.len().max()) + 50)
        fig = plt.figure(figsize=(1.1 * width / 100.0, height / 100.0), dpi=dpi)
        gs = gridspec.GridSpec(3, 11, figure=fig)

        # Main plot
        mainplot_col_grid = -2 if len(xlabel) < 10 else -1
        ax = fig.add_subplot(gs[:, :mainplot_col_grid])
        sc = ax.scatter(x='x', y='y', c='value', s='pixels', data=dotplot_df, linewidth=0.5, edgecolors='black', **keywords)
        for side in ("top", "bottom", "left", "right"):
            ax.spines[side].set_color('black')
        if not grid:
            ax.grid(False)

        if not switch_axes:
            ax.set_ylabel(str(groupby))
            ax.set_xlabel('')
        else:
            ax.set_ylabel('')
            ax.set_xlabel(str(groupby))

        ax.set_xlim(-1, len(xticks))
        ax.set_ylim(-1, len(yticks))
        ax.set_xticks(range(len(xticks)))
        ax.set_xticklabels(xticks, rotation=90)
        ax.set_yticks(range(len(yticks)))
        ax.set_yticklabels(yticks)

        fig.colorbar(sc, ax=ax)

        # Size legend: pick a tick step that suits the displayed fraction range.
        size_range = fraction_max - fraction_min
        if size_range <= 0.3:
            size_legend_step = 0.05
        elif size_range <= 0.6:
            size_legend_step = 0.1
        else:
            size_legend_step = 0.2
        # Start at fraction_min when it is positive; otherwise start one step above it.
        size_start = fraction_min if fraction_min > 0 else fraction_min + size_legend_step
        size_ticks = np.arange(size_start, fraction_max + size_legend_step, size_legend_step)

        legend_row_grid = 1 if height / 3 > 100 else 3
        ax2 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs[0:legend_row_grid, -1])
        size_legend = fig.add_subplot(ax2[0])
        size_tick_pixels = _get_dot_size(size_ticks, fraction_min, fraction_max, dot_min, dot_max)
        size_tick_labels = ["{:.0%}".format(tick) for tick in size_ticks]
        size_legend.scatter(
            x=np.repeat(0, len(size_ticks)),
            y=np.arange(0, len(size_ticks)),
            s=size_tick_pixels,
            c='black',
            linewidth=0.5,
        )
        size_legend.title.set_text("Fraction of\nexpressing cells")
        size_legend.set_xlim(-0.1, 0.1)
        size_legend.set_xticks([])

        ymin, ymax = size_legend.get_ylim()
        size_legend.set_ylim(ymin, ymax + 0.5)

        size_legend.set_yticks(np.arange(len(size_ticks)))
        size_legend.set_yticklabels(size_tick_labels)
        size_legend.tick_params(axis='y', labelleft=False, labelright=True)

        for side in ("top", "bottom", "left", "right"):
            size_legend.spines[side].set_visible(False)
        size_legend.grid(False)
    finally:
        # Reset global settings.
        sns.reset_orig()

    return fig if return_fig else None
def dendrogram(
    data: Union[MultimodalData, UnimodalData, anndata.AnnData],
    groupby: str,
    rep: str = 'pca',
    genes: Optional[List[str]] = None,
    correlation_method: str = 'pearson',
    n_clusters: Optional[int] = None,
    affinity: str = 'euclidean',
    linkage: str = 'complete',
    compute_full_tree: Union[str, bool] = 'auto',
    distance_threshold: Optional[float] = 0,
    panel_size: Tuple[float, float] = (6, 6),
    orientation: str = 'top',
    color_threshold: Optional[float] = None,
    return_fig: Optional[bool] = False,
    dpi: Optional[float] = 300.0,
    **kwargs,
) -> Union[plt.Figure, None]:
    """
    Generate a dendrogram on hierarchical clustering result.

    The metrics used here are consistent with SCANPY's dendrogram_ implementation.

    *scikit-learn* `Agglomerative Clustering`_ implementation is used for hierarchical clustering.

    .. _dendrogram: https://scanpy.readthedocs.io/en/stable/api/scanpy.tl.dendrogram.html
    .. _Agglomerative Clustering: https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AgglomerativeClustering.html

    Parameters
    ----------
    data: ``MultimodalData``, ``UnimodalData``, or ``AnnData`` object
        Single cell expression data.
    groupby: ``str``
        Categorical cell attribute to plot, which must exist in ``data.obs``.
    rep: ``str``, optional, default: ``pca``
        Cell embedding to use. It only works when ``genes`` is ``None``, and its key ``"X_"+rep`` must exist in ``data.obsm``. By default, use PCA coordinates.
    genes: ``List[str]``, optional, default: ``None``
        List of genes to use. Gene names must exist in ``data.var``. If set, use the values in ``data.X`` for these genes; if set as ``None``, use the embedding specified in ``rep``.
    correlation_method: ``str``, optional, default: ``pearson``
        Method of correlation between categories specified in ``data.obs``. Available options are: ``pearson``, ``kendall``, ``spearman``. See `pandas corr documentation`_ for details.
    n_clusters: ``int``, optional, default: ``None``
        The number of clusters to find, used by hierarchical clustering. It must be ``None`` if ``distance_threshold`` is not ``None``.
    affinity: ``str``, optional, default: ``euclidean``
        Metric used to compute the linkage, used by hierarchical clustering. It is forwarded to scikit-learn as ``metric``
        (or as ``affinity`` on scikit-learn versions that predate that rename). Valid values for metric are:

            - From scikit-learn: ``cityblock``, ``cosine``, ``euclidean``, ``l1``, ``l2``, ``manhattan``.
            - From scipy.spatial.distance: ``braycurtis``, ``canberra``, ``chebyshev``, ``correlation``, ``dice``, ``hamming``, ``jaccard``, ``kulsinski``, ``mahalanobis``, ``minkowski``, ``rogerstanimoto``, ``russellrao``, ``seuclidean``, ``sokalmichener``, ``sokalsneath``, ``sqeuclidean``, ``yule``.

        See `scikit-learn distance documentation <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.pairwise_distances.html>`_ for details.
    linkage: ``str``, optional, default: ``complete``
        Which linkage criterion to use, used by hierarchical clustering. Below are available options:

            - ``ward`` minimizes the variance of the clusters being merged.
            - ``average`` uses the average of the distances of each observation of the two sets.
            - ``complete`` uses the maximum distances between all observations of the two sets. (Default)
            - ``single`` uses the minimum of the distances between all observations of the two sets.

        See `scikit-learn documentation`_ for details.
    compute_full_tree: ``str`` or ``bool``, optional, default: ``auto``
        Stop early the construction of the tree at ``n_clusters``, used by hierarchical clustering. It must be ``True`` if ``distance_threshold`` is not ``None``.
        By default, this option is ``auto``, which is ``True`` if and only if ``distance_threshold`` is not ``None``, or ``n_clusters`` is less than ``max(100, 0.02 * n_groups)``, where ``n_groups`` is the number of categories in ``data.obs[groupby]``.
    distance_threshold: ``float``, optional, default: ``0``
        The linkage distance threshold above which, clusters will not be merged. If not ``None``, ``n_clusters`` must be ``None`` and ``compute_full_tree`` must be ``True``.
    panel_size: ``Tuple[float, float]``, optional, default: ``(6, 6)``
        The size (width, height) in inches of figure.
    orientation: ``str``, optional, default: ``top``
        The direction to plot the dendrogram. Available options are: ``top``, ``bottom``, ``left``, ``right``. See `scipy dendrogram documentation`_ for explanation.
    color_threshold: ``float``, optional, default: ``None``
        Threshold for coloring clusters. See `scipy dendrogram documentation`_ for explanation.
    return_fig: ``bool``, optional, default: ``False``
        Return a ``Figure`` object if ``True``; return ``None`` otherwise.
    dpi: ``float``, optional, default: ``300.0``
        The resolution in dots per inch.
    **kwargs:
        Are passed to ``scipy.cluster.hierarchy.dendrogram``.

    .. _scikit-learn documentation: https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AgglomerativeClustering.html
    .. _scipy dendrogram documentation: https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.dendrogram.html
    .. _pandas corr documentation: https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.corr.html

    Returns
    -------
    ``Figure`` object
        A ``matplotlib.figure.Figure`` object containing the dendrogram if ``return_fig == True``

    Examples
    --------
    >>> pg.dendrogram(data, genes=data.var_names, groupby='louvain_labels')
    >>> pg.dendrogram(data, rep='pca', groupby='louvain_labels')
    """
    import inspect

    from sklearn.cluster import AgglomerativeClustering
    from scipy.cluster.hierarchy import dendrogram as scipy_dendrogram

    if genes is None:
        embed_df = pd.DataFrame(X_from_rep(data, rep))
    else:
        sub_data = data[:, genes]
        X = sub_data.X.toarray() if issparse(sub_data.X) else sub_data.X
        embed_df = pd.DataFrame(X)
    embed_df.set_index(data.obs[groupby], inplace=True)

    # Average each group's cells. observed=True drops unused categories, which would
    # otherwise show up as all-NaN rows and break the correlation/clustering below.
    mean_df = embed_df.groupby(level=0, observed=True).mean()

    # Group-by-group correlation matrix is the input to the hierarchical clustering.
    corr_mat = mean_df.T.corr(method=correlation_method)

    sk_params = inspect.signature(AgglomerativeClustering).parameters
    # scikit-learn 1.2 renamed ``affinity`` to ``metric``; 1.4 removed ``affinity``.
    clusterer_kwargs = {('metric' if 'metric' in sk_params else 'affinity'): affinity}
    if 'compute_distances' in sk_params:
        # Ensures ``distances_`` exists even when distance_threshold is None.
        clusterer_kwargs['compute_distances'] = True

    clusterer = AgglomerativeClustering(
        n_clusters=n_clusters,
        linkage=linkage,
        compute_full_tree=compute_full_tree,
        distance_threshold=distance_threshold,
        **clusterer_kwargs,
    )
    clusterer.fit(corr_mat)

    # Convert the fitted tree into a SciPy linkage matrix:
    # one row per merge = [child_a, child_b, merge distance, number of leaves below].
    n_samples = len(clusterer.labels_)
    counts = np.zeros(clusterer.children_.shape[0])
    for i, merge in enumerate(clusterer.children_):
        current_count = 0
        for child_idx in merge:
            if child_idx < n_samples:
                current_count += 1  # Leaf node
            else:
                current_count += counts[child_idx - n_samples]
        counts[i] = current_count

    linkage_matrix = np.column_stack([clusterer.children_, clusterer.distances_, counts]).astype(float)

    fig, ax = _get_subplot_layouts(panel_size=panel_size, dpi=dpi)
    scipy_dendrogram(
        linkage_matrix,
        labels=mean_df.index.tolist(),  # rows of mean_df are the leaves, in this order
        orientation=orientation,
        color_threshold=color_threshold,
        ax=ax,
        **kwargs,
    )
    # Leaf labels sit on the x-axis for top/bottom orientations and on the y-axis otherwise.
    if orientation in ('top', 'bottom'):
        ax.tick_params(axis='x', labelrotation=90, labelsize=10)
    else:
        ax.tick_params(axis='y', labelsize=10)
    fig.tight_layout()

    return fig if return_fig else None
def hvfplot(
    data: Union[MultimodalData, UnimodalData, anndata.AnnData],
    top_n: int = 20,
    panel_size: Optional[Tuple[float, float]] = (6, 4),
    return_fig: Optional[bool] = False,
    dpi: Optional[float] = 300.0,
) -> Union[plt.Figure, None]:
    """
    Generate highly variable feature plot.
    Only works for HVGs returned by ``highly_variable_features`` method with ``flavor=='pegasus'``.

    Only features with ``data.var["robust"]`` set are plotted. The columns ``mean``, ``var``,
    ``hvf_loess``, ``highly_variable_features`` and ``hvf_rank`` must exist in ``data.var``.

    Parameters
    -----------
    data: ``MultimodalData``, ``UnimodalData``, or ``anndata.AnnData`` object.
        Single cell expression data.
    top_n: ``int``, optional, default: ``20``
        Number of top highly variable features to show names. Features with the smallest
        ``hvf_rank`` are labeled; if fewer features are available, all of them are labeled.
    panel_size: ``Tuple[float, float]``, optional, default: ``(6, 4)``
        The size (width, height) in inches of figure.
    return_fig: ``bool``, optional, default: ``False``
        Return a ``Figure`` object if ``True``; return ``None`` otherwise.
    dpi: ``float``, optional, default: ``300.0``
        The resolution in dots per inch.

    Returns
    --------
    ``Figure`` object
        A ``matplotlib.figure.Figure`` object containing the highly variable feature plot if ``return_fig == True``

    Examples
    ---------
    >>> pg.hvfplot(data)
    >>> pg.hvfplot(data, top_n=10, dpi=150)
    """
    robust_idx = data.var["robust"].values
    robust_var = data.var.loc[robust_idx]

    # Use plain NumPy arrays so that all integer indexing below is positional
    # (integer keys on a label-indexed Series are deprecated in pandas).
    x = robust_var["mean"].values
    y = robust_var["var"].values
    fitted = robust_var["hvf_loess"].values
    hvg_index = robust_var["highly_variable_features"].values.astype(bool)
    hvg_rank = robust_var["hvf_rank"].values
    gene_symbols = data.var_names[robust_idx]

    fig, ax = _get_subplot_layouts(panel_size=panel_size, dpi=dpi)

    ax.scatter(x[hvg_index], y[hvg_index], s=5, c='b', marker='o', linewidth=0.5, alpha=0.5, label='highly variable features')
    ax.scatter(x[~hvg_index], y[~hvg_index], s=5, c='k', marker='o', linewidth=0.5, alpha=0.5, label='other features')
    ax.legend(loc='upper right', fontsize=5)

    ax.set_xlabel("Mean log expression")
    ax.set_ylabel("Variance of log expression")

    # Draw the fitted trend as a line, so plot it in order of increasing mean.
    order = np.argsort(x)
    ax.plot(x[order], fitted[order], "r-", linewidth=1)

    # Label the best-ranked features; slicing clamps top_n to the number available.
    ord_rank = np.argsort(hvg_rank)
    texts = [
        ax.text(x[pos], y[pos], gene_symbols[pos], fontsize=5)
        for pos in ord_rank[:max(top_n, 0)]
    ]

    if texts:
        from adjustText import adjust_text
        adjust_text(texts, arrowprops=dict(arrowstyle='-', color='k', lw=0.5))

    return fig if return_fig else None
def qcviolin(
    data: Union[MultimodalData, UnimodalData, anndata.AnnData],
    plot_type: str,
    min_genes_before_filt: Optional[int] = 100,
    n_violin_per_panel: Optional[int] = 8,
    panel_size: Optional[Tuple[float, float]] = (6, 4),
    left: Optional[float] = 0.2,
    bottom: Optional[float] = 0.15,
    wspace: Optional[float] = 0.3,
    hspace: Optional[float] = 0.35,
    return_fig: Optional[bool] = False,
    dpi: Optional[float] = 300.0,
) -> Union[plt.Figure, None]:
    """
    Plot quality control statistics (before filtration vs. after filtration) as violin plots. Require statistics such as "n_genes", "n_counts" and "percent_mito" precomputed.

    Cells counted as "filtered" are those with ``data.obs["passed_qc"]`` set. If ``data.obs`` has no
    ``Channel`` column, one with a single empty-string channel is added. The combined plotting table is
    cached in ``data.uns["df_qcplot"]`` and reused on later calls, so ``min_genes_before_filt`` only
    takes effect when that cache is built.

    Parameters
    -----------
    data: ``MultimodalData``, ``UnimodalData``, or ``anndata.AnnData`` object.
        Single cell expression data.
    plot_type: ``str``
        Choose from ``gene``, ``count`` and ``mito``, which shows number of expressed genes, number of UMIs and percentage of mitochondrial rate.
    min_genes_before_filt: ``int``, optional, default: 100
        If data loaded are raw data (i.e. min(n_genes) == 0), filter out cell barcodes with less than ``min_genes_before_filt`` for better visual effects.
    n_violin_per_panel: ``int``, optional, default: 8
        Number of violin plots (samples) shown in one panel.
    panel_size: `tuple`, optional (default: `(6, 4)`)
        The panel size (width, height) in inches.
    left: `float`, optional (default: `0.2`)
        This parameter sets the figure's left margin as a fraction of panel's width (left * panel_size[0]).
    bottom: `float`, optional (default: `0.15`)
        This parameter sets the figure's bottom margin as a fraction of panel's height (bottom * panel_size[1]).
    wspace: `float`, optional (default: `0.3`)
        This parameter sets the width between panels and also the figure's right margin as a fraction of panel's width (wspace * panel_size[0]).
    hspace: `float`, optional (default: `0.35`)
        This parameter sets the height between panels and also the figure's top margin as a fraction of panel's height (hspace * panel_size[1]).
    return_fig: ``bool``, optional, default: ``False``
        Return a ``Figure`` object if ``True``; return ``None`` otherwise.
    dpi: ``float``, optional, default: ``300.0``
        The resolution in dots per inch.

    Returns
    --------
    ``Figure`` object
        A ``matplotlib.figure.Figure`` object containing the violin plots if ``return_fig == True``

    Examples
    ---------
    >>> pg.qcviolin(data, "mito", dpi = 500)
    """
    pt2attr = {"gene": "n_genes", "count": "n_counts", "mito": "percent_mito"}
    pt2ylab = {
        "gene": "Number of expressed genes",
        "count": "Number of UMIs",
        "mito": "Percentage of mitochondrial UMIs",
    }

    if "df_qcplot" not in data.uns:
        if "Channel" not in data.obs:
            data.obs["Channel"] = pd.Categorical([""] * data.shape[0])

        target_cols = np.array(["Channel", "n_genes", "n_counts", "percent_mito"])
        target_cols = target_cols[np.isin(target_cols, data.obs.columns)]

        # Raw data (min n_genes == 0) would be dominated by empty barcodes; drop them for display.
        df = data.obs[data.obs["n_genes"] >= min_genes_before_filt] if data.obs["n_genes"].min() == 0 else data.obs
        df_plot_before = df[target_cols].copy()
        df_plot_before.reset_index(drop=True, inplace=True)
        df_plot_before["status"] = "original"

        df_plot_after = data.obs.loc[data.obs["passed_qc"], target_cols].copy()
        df_plot_after.reset_index(drop=True, inplace=True)
        df_plot_after["status"] = "filtered"
        df_qcplot = pd.concat((df_plot_before, df_plot_after), axis=0)

        df_qcplot["status"] = pd.Categorical(df_qcplot["status"].values, categories=["original", "filtered"])
        # Convert values and categories the same way: building string categories while keeping
        # non-string values (e.g. integer channel IDs) would turn every value into NaN.
        channel_str = df_qcplot["Channel"].astype(str)
        df_qcplot["Channel"] = pd.Categorical(channel_str.values, categories=natsorted(channel_str.unique()))

        data.uns["df_qcplot"] = df_qcplot

    df_qcplot = data.uns["df_qcplot"]

    if pt2attr[plot_type] not in df_qcplot:
        logger.warning(f"Cannot find qc metric {pt2attr[plot_type]}!")
        return None

    channels = df_qcplot["Channel"].cat.categories
    n_channels = channels.size
    n_panels = (n_channels - 1) // n_violin_per_panel + 1

    nrows, ncols = _get_nrows_and_ncols(n_panels, None, None)
    fig, axes = _get_subplot_layouts(
        nrows=nrows,
        ncols=ncols,
        panel_size=panel_size,
        dpi=dpi,
        left=left,
        bottom=bottom,
        wspace=wspace,
        hspace=hspace,
        sharex=False,
        sharey=False,
        squeeze=False,
    )

    for i in range(nrows):
        for j in range(ncols):
            ax = axes[i, j]
            ax.grid(False)
            panel_no = i * ncols + j
            if panel_no >= n_panels:
                # Hide grid cells that have no channels to show.
                ax.set_frame_on(False)
                ax.set_xticks([])
                ax.set_yticks([])
                continue

            start = panel_no * n_violin_per_panel
            end = min(start + n_violin_per_panel, n_channels)
            panel_channels = channels[start:end]

            if start == 0 and end == n_channels:
                df_plot = df_qcplot
            else:
                df_plot = df_qcplot[np.isin(df_qcplot["Channel"], panel_channels)].copy()
                # Restrict categories so only this panel's channels get violins.
                df_plot["Channel"] = pd.Categorical(df_plot["Channel"].values, categories=natsorted(panel_channels))

            sns.violinplot(
                x="Channel",
                y=pt2attr[plot_type],
                hue="status",
                data=df_plot,
                split=True,
                linewidth=0.5,
                cut=0,
                inner=None,
                ax=ax,
            )

            ax.set_xlabel("Channel")
            ax.set_ylabel(pt2ylab[plot_type])
            ax.legend(loc="upper right", fontsize=8)
            # Long channel names would overlap; rotate them.
            if max(len(str(name)) for name in panel_channels) >= 5:
                ax.set_xticklabels(ax.get_xticklabels(), fontsize=8, rotation=-45)

    return fig if return_fig else None
def volcano(
    data: Union[MultimodalData, UnimodalData, anndata.AnnData],
    cluster_id: str,
    de_key: str = "de_res",
    de_test: str = 'mwu',
    qval_threshold: float = 0.05,
    log2fc_threshold: float = 1.0,
    top_n: int = 20,
    panel_size: Optional[Tuple[float, float]] = (6, 4),
    return_fig: Optional[bool] = False,
    dpi: Optional[float] = 300.0,
) -> Union[plt.Figure, None]:
    """
    Generate Volcano plots (-log10 p value vs. log2 fold change) for visualizing DE results.

    Parameters
    -----------
    data: ``MultimodalData``, ``UnimodalData``, or ``anndata.AnnData`` object.
        Single cell expression data.
    cluster_id: ``str``
        Cluster ID for the cluster we want to show DE results. There are two cases:

            * If ``condition`` is ``None`` in ``pg.de_analysis``: Just specify one cluster
              label in the cluster attribute used in ``pg.de_analysis``.

            * If ``condition`` is not ``None`` in ``pg.de_analysis``: Specify cluster ID in
              this format: **"cluster_label:cond_level"**, where **cluster_label** is the
              cluster label, and **cond_level** is the condition ID. And this shows result of
              cells within the cluster under the specific condition.

    de_key: ``str``, optional, default: ``de_res``
        The varm keyword for DE results. data.varm[de_key] should store the full DE result table.
    de_test: ``str``, optional, default: ``mwu``
        Which DE test results to show. Use MWU test result by default.
    qval_threshold: ``float``, optional, default: 0.05.
        Selected FDR rate. A horizontal line indicating this rate will be shown in the figure.
        If no gene passes this threshold, a warning is logged, no gene is marked significant and the line is omitted.
    log2fc_threshold: ``float``, optional, default: 1.0
        Log2 fold change threshold to highlight biologically interesting genes. Two vertical lines representing negative and positive log2 fold change will be shown.
    top_n: ``int``, optional, default: ``20``
        Number of top DE genes to show names. Genes are ranked by Log2 fold change.
    panel_size: ``Tuple[float, float]``, optional, default: ``(6, 4)``
        The size (width, height) in inches of figure.
    return_fig: ``bool``, optional, default: ``False``
        Return a ``Figure`` object if ``True``; return ``None`` otherwise.
    dpi: ``float``, optional, default: ``300.0``
        The resolution in dots per inch.

    Returns
    --------
    ``Figure`` object
        A ``matplotlib.figure.Figure`` object containing the volcano plot if ``return_fig == True``

    Examples
    ---------
    >>> pg.volcano(data, cluster_id = '1', dpi=200)
    """
    if de_key not in data.varm:
        logger.warning(f"Cannot find DE results '{de_key}'. Please conduct DE analysis first!")
        return None

    de_res = data.varm[de_key]

    fcstr = f"{cluster_id}:log2FC"
    pstr = f"{cluster_id}:{de_test}_pval"
    qstr = f"{cluster_id}:{de_test}_qval"

    columns = de_res.dtype.names
    if (fcstr not in columns) or (pstr not in columns) or (qstr not in columns):
        logger.warning(f"Please conduct DE test {de_test} first!")
        return None

    log2fc = de_res[fcstr]
    # Copy first: selecting a field of a structured array returns a view, so patching
    # zeros in place would silently modify data.varm[de_key].
    pvals = de_res[pstr].copy()
    pvals[pvals == 0.0] = 1e-45  # very small pvalue to avoid log10 0
    neglog10p = -np.log10(pvals)

    # The significance cutoff on the y-axis is the smallest -log10(p) among genes
    # whose q-value passes the threshold.
    qval_pass = de_res[qstr] <= qval_threshold
    if qval_pass.any():
        yconst = neglog10p[qval_pass].min()
    else:
        logger.warning(f"No gene has q-value <= {qval_threshold}; no gene will be marked as significant.")
        yconst = np.inf

    fig, ax = _get_subplot_layouts(panel_size=panel_size, dpi=dpi)

    idxsig = neglog10p >= yconst
    idxnsig = neglog10p < yconst
    idxfc = np.abs(log2fc) >= log2fc_threshold
    idxnfc = ~idxfc

    point_groups = [
        (idxnsig & idxnfc, 'k', "NS"),
        (idxnsig & idxfc, 'g', r"Log$_2$ FC"),
        (idxsig & idxnfc, 'b', r"q-value"),
        (idxsig & idxfc, 'r', r"q-value and log$_2$ FC"),
    ]
    for mask, color, label in point_groups:
        ax.scatter(log2fc[mask], neglog10p[mask], s=5, c=color, marker='o', linewidths=0.5, alpha=0.5, label=label)

    ax.set_xlabel(r"Log$_2$ fold change")
    ax.set_ylabel(r"$-$Log$_{10}$ $P$")

    legend = ax.legend(
        loc="center",
        bbox_to_anchor=(0.5, 1.1),
        frameon=False,
        fontsize=8,
        ncol=4,
    )
    # ``Legend.legendHandles`` was renamed ``legend_handles`` in matplotlib 3.7 and removed in 3.9.
    handles = getattr(legend, "legend_handles", None)
    if handles is None:
        handles = legend.legendHandles
    for handle in handles:  # adjust legend size
        handle.set_sizes([50.0])

    if np.isfinite(yconst):
        ax.axhline(y=yconst, c='k', lw=0.5, ls='--')
    ax.axvline(x=-log2fc_threshold, c='k', lw=0.5, ls='--')
    ax.axvline(x=log2fc_threshold, c='k', lw=0.5, ls='--')

    # Label the top_n most up-regulated significant genes (largest fold change first),
    # then the top_n most down-regulated ones (most negative fold change first).
    up_idx = np.where(idxsig & (log2fc >= log2fc_threshold))[0]
    down_idx = np.where(idxsig & (log2fc <= -log2fc_threshold))[0]
    top_up = up_idx[np.argsort(log2fc[up_idx])[::-1][0:top_n]]
    top_down = down_idx[np.argsort(log2fc[down_idx])[0:top_n]]

    texts = [
        ax.text(log2fc[gid], neglog10p[gid], data.var_names[gid], fontsize=5)
        for gid in np.concatenate((top_up, top_down))
    ]

    if texts:
        from adjustText import adjust_text
        adjust_text(texts, arrowprops=dict(arrowstyle='-', color='k', lw=0.5))

    return fig if return_fig else None
def rank_plot(
    data: Union[MultimodalData, UnimodalData, anndata.AnnData],
    panel_size: Optional[Tuple[float, float]] = (6, 4),
    return_fig: Optional[bool] = False,
    dpi: Optional[float] = 300.0,
    **kwargs,
) -> Union[plt.Figure, None]:
    """Generate a barcode rank plot, which shows the total UMIs against barcode rank (in descending order with respect to total UMIs)

    Parameters
    ----------
    data : `AnnData` or `UnimodalData` or `MultimodalData` object
        The main data object. ``data.X`` may be a sparse or a dense matrix.
    panel_size: `tuple`, optional (default: `(6, 4)`)
        The plot size (width, height) in inches.
    return_fig: ``bool``, optional, default: ``False``
        Return a ``Figure`` object if ``True``; return ``None`` otherwise.
    dpi: ``float``, optional, default: ``300.0``
        The resolution in dots per inch.
    **kwargs
        Accepted for call compatibility with the other plotting functions; not used.

    Returns
    -------
    `Figure` object
        A ``matplotlib.figure.Figure`` object containing the barcode rank plot if ``return_fig == True``

    Examples
    --------
    >>> fig = pg.rank_plot(data, show = False, dpi = 500)
    """
    from matplotlib.ticker import FuncFormatter

    fig, ax = _get_subplot_layouts(panel_size=panel_size, dpi=dpi)  # default nrows = 1 & ncols = 1

    # Total UMIs per barcode. np.asarray(...).ravel() flattens both the np.matrix
    # returned by a sparse sum and the 1-D array returned by a dense sum
    # (`.A1` only exists on np.matrix).
    numis = np.asarray(data.X.sum(axis=1)).ravel()
    ranks = np.arange(1, numis.size + 1)
    ax.scatter(ranks, np.sort(numis)[::-1], c='lightgrey', s=5)

    # Base 10 is the default for log scales; the old `basex`/`basey` keywords
    # were removed from matplotlib, so they are not passed.
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_xlabel("Barcode rank")
    ax.set_ylabel("Total UMIs")

    label_arr = ['1', '10', '100', '1000', '10K', '100K', '1M', '10M', '100M']

    def _tick_formatter(max_value):
        """Build a major-tick formatter that labels powers of ten up to max_value."""
        def _format(value, pos):
            # Ticks beyond the data range, non-positive ticks, and exponents
            # outside label_arr are left unlabeled instead of raising.
            if value <= 0 or value > max_value:
                return ''
            exponent = int(round(np.log10(value)))
            if 0 <= exponent < len(label_arr):
                return label_arr[exponent]
            return ''
        return FuncFormatter(_format)

    # A formatter is evaluated per tick value, so labels stay correct even if
    # the locator recomputes ticks (unlike set_*ticklabels on unfixed ticks).
    ax.xaxis.set_major_formatter(_tick_formatter(numis.size))
    ax.yaxis.set_major_formatter(_tick_formatter(numis.max() if numis.size > 0 else 0))

    return fig if return_fig else None


def ridgeplot(
    data: Union[MultimodalData, UnimodalData],
    features: Union[str, List[str]],
    qc_attr: Optional[str] = None,
    overlap: Optional[float] = 0.5,
    panel_size: Optional[Tuple[float, float]] = (6, 4),
    return_fig: Optional[bool] = False,
    dpi: Optional[float] = 300.0,
    **kwargs,
) -> Union[plt.Figure, None]:
    """Generate ridge plots, up to 8 features can be shown in one figure.

    Expression values are read from the ``"arcsinh.transformed"`` matrix of ``data``.

    Parameters
    ----------
    data : `UnimodalData` or `MultimodalData` object
        CITE-Seq or Cyto data.
    features : `str` or `List[str]`
        One or more features to display (at most 8).
    qc_attr: `str`, optional, default None
        If not None, only data.obs[qc_attr] == True are used.
    overlap: `float`, default 0.5
        Overlap between adjacent ridge plots (top and bottom).
    panel_size: `tuple`, optional (default: `(6, 4)`)
        The plot size (width, height) in inches.
    return_fig: ``bool``, optional, default: ``False``
        Return a ``Figure`` object if ``True``; return ``None`` otherwise.
    dpi: ``float``, optional, default: ``300.0``
        The resolution in dots per inch.
    **kwargs
        Accepted for call compatibility with the other plotting functions; not used.

    Returns
    -------
    `Figure` object
        A ``matplotlib.figure.Figure`` object containing the ridge plot if ``return_fig == True``.
        ``None`` is also returned (with a warning) when more than 8 features are requested.

    Examples
    --------
    >>> fig = pg.ridgeplot(data, features = ['CD8', 'CD4', 'CD3'], show = False, dpi = 500)
    """
    if isinstance(features, str):
        features = [features]

    # Validate before touching global seaborn/matplotlib style so an early
    # return cannot leave the modified style behind.
    if len(features) > 8:
        logger.warning("Only up to 8 features are allowed in one ridge plot!")
        return None

    idx = data.obs[qc_attr].values if qc_attr is not None else np.ones(data.shape[0], dtype=bool)
    size = idx.sum()

    mat = data.get_matrix("arcsinh.transformed")  # fetched once, reused per feature
    exprs = []
    feats = []
    for feature in features:
        fid = data.var_names.get_loc(feature)
        col = mat[idx, fid]
        # Sparse indexing yields an (n, 1) sparse matrix; dense indexing yields a 1-D array.
        exprs.append(col.toarray()[:, 0] if issparse(col) else np.asarray(col).ravel())
        feats.append(np.repeat(feature, size))

    df = pd.DataFrame({"expression": np.concatenate(exprs), "feature": np.concatenate(feats)})

    # sns.set mutates global rcParams; restore them even if plotting raises.
    sns.set(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})
    try:
        g = sns.FacetGrid(df, row="feature", hue="feature", aspect=8, height=1.0)
        # NOTE(review): `shade` is deprecated in newer seaborn (use `fill`); kept
        # for compatibility with the seaborn versions this module targets.
        g.map(sns.kdeplot, "expression", clip_on=False, shade=True, alpha=1, lw=1.5)
        g.map(sns.kdeplot, "expression", clip_on=False, color="k", lw=1)
        g.map(plt.axhline, y=0, lw=1, clip_on=False)

        def _set_label(value, color, label):
            # Write the feature name to the left of each ridge.
            ax = plt.gca()
            ax.text(0, 0.2, label, color="k", ha="right", va="center", transform=ax.transAxes)

        g.map(_set_label, "expression")

        # A negative hspace makes adjacent ridges overlap.
        g.fig.subplots_adjust(hspace=-overlap)
        g.set_titles("")
        g.set_xlabels("")
        g.set(yticks=[])
        g.despine(bottom=True, left=True)

        g.fig.set_dpi(dpi)
        g.fig.set_figwidth(panel_size[0])
        g.fig.set_figheight(panel_size[1])
    finally:
        sns.reset_orig()

    return g.fig if return_fig else None


def doublet_plot(
    scores: List[float],
    codes: List[int],
    code_names: Optional[List[str]] = ["singlet", "singlet2", "doublet"],
    panel_size: Optional[Tuple[float, float]] = (4, 3),
    left: Optional[float] = 0.2,
    bottom: Optional[float] = 0.2,
    wspace: Optional[float] = 0.2,
    hspace: Optional[float] = 0.2,
    return_fig: Optional[bool] = False,
    dpi: Optional[float] = 300.0,
    **kwargs,
) -> Union[plt.Figure, None]:
    """Generate KDE plots for doublet predictions

    Four panels are drawn on a 2 x 2 grid:

    * top left: KDE of the doublet scores;
    * top right: KDE of the natural-log-transformed doublet scores;
    * bottom left: KDE of the log scores within each code (0, 1, 2), each scaled
      by that code's fraction of all scores;
    * bottom right: the overall log-score KDE overlaid with the per-code KDEs.

    Parameters
    ----------
    scores: `List[float]`
        Doublet scores. They should be positive, since their natural log is taken.
    codes: `List[int]`
        Doublet type codes, one per score. Codes 0, 1 and 2 are labeled with
        ``code_names[0]``, ``code_names[1]`` and ``code_names[2]`` respectively.
        A code with fewer than two scores gets no per-code curve.
    code_names: `List[str]`, optional (default: `["singlet", "singlet2", "doublet"]`)
        Doublet type names, indexed by code.
    panel_size: `tuple`, optional (default: `(4, 3)`)
        The plot size (width, height) in inches.
    left: `float`, optional (default: `0.2`)
        This parameter sets the figure's left margin as a fraction of panel's width (left * panel_size[0]).
    bottom: `float`, optional (default: `0.2`)
        This parameter sets the figure's bottom margin as a fraction of panel's height (bottom * panel_size[1]).
    wspace: `float`, optional (default: `0.2`)
        This parameter sets the width between panels and also the figure's right margin as a fraction of panel's width (wspace * panel_size[0]).
    hspace: `float`, optional (default: `0.2`)
        This parameter sets the height between panels and also the figure's top margin as a fraction of panel's height (hspace * panel_size[1]).
    return_fig: ``bool``, optional, default: ``False``
        Return a ``Figure`` object if ``True``; return ``None`` otherwise.
    dpi: ``float``, optional, default: ``300.0``
        The resolution in dots per inch.
    **kwargs
        Accepted for call compatibility with the other plotting functions; not used.

    Returns
    -------
    `Figure` object
        A ``matplotlib.figure.Figure`` object containing the four KDE panels if ``return_fig == True``
    """
    from scipy.stats import gaussian_kde

    # constants
    alpha = 0.7  # transparency of per-code curves in the overlay panel
    eps = 1e-6   # densities / standard deviations below this are treated as zero
    colors = ['orange', 'green', 'red']  # one color per code 0, 1, 2

    # Accept plain lists as the type hints advertise; the code below relies on
    # NumPy methods (.min(), .size) and element-wise comparisons.
    scores = np.asarray(scores, dtype=float)
    codes = np.asarray(codes)

    fig, axes = _get_subplot_layouts(nrows=2, ncols=2, panel_size=panel_size, dpi=dpi, left=left, bottom=bottom, wspace=wspace, hspace=hspace, sharex=False, sharey=False)

    # Panel 1, KDE of scores
    ax = axes[0, 0]
    x = np.linspace(scores.min() - 0.05, scores.max() + 0.05, 200)  # generate x coordinates
    y = gaussian_kde(scores)(x)
    ax.plot(x, y, '-', c='k')
    ax.set_ylim(bottom=0.0)
    ax.set_xlabel('doublet score')
    ax.set_ylabel('density')

    # Panel 2, KDE of log-transformed scores; x and y are reused by panels 3 and 4
    ax = axes[0, 1]
    scores = np.log(scores)
    x = np.linspace(scores.min() - 0.5, scores.max() + 0.5, 200)  # generate x coordinates
    y = gaussian_kde(scores)(x)
    ax.plot(x, y, '-', c='k')
    ax.set_ylim(bottom=0.0)
    ax.set_xlabel('log doublet score')
    ax.set_ylabel('density')

    def _code_kde(code):
        """Return (mask, density) for one code's log scores on grid x, or None if it has < 2 scores."""
        members = scores[codes == code]
        if members.size < 2:  # gaussian_kde needs at least two points
            return None
        if np.std(members) < eps:  # if too small, add jitters to avoid a singular covariance
            members = members + np.random.normal(scale=0.01, size=members.size)
        # Scale by the code's share of all scores so the curves sum to the overall density.
        density = gaussian_kde(members)(x) * (members.size * 1.0 / codes.size)
        return density > eps, density

    code_kdes = [_code_kde(code) for code in range(len(colors))]

    # Panel 3, KDE of each code
    ax = axes[1, 0]
    for code, res in enumerate(code_kdes):
        if res is not None:
            mask, density = res
            ax.plot(x[mask], density[mask], '-', c=colors[code], label=code_names[code])
    ax.set_ylim(bottom=0)
    ax.set_xlabel('log doublet score')
    ax.set_ylabel('density')
    ax.legend(loc='best')

    # Panel 4, overlay the overall KDE with code-specific KDEs
    ax = axes[1, 1]
    ax.plot(x, y, '-', c='k', label='overall')
    for code, res in enumerate(code_kdes):
        if res is not None:
            mask, density = res
            # Labels come from code_names, matching panel 3.
            ax.plot(x[mask], density[mask], '-', c=colors[code], alpha=alpha, label=code_names[code])
    ax.set_ylim(bottom=0)
    ax.set_xlabel('log doublet score')
    ax.set_ylabel('density')
    ax.legend(loc='best')

    return fig if return_fig else None