# Source code for scplot.plot

from typing import Union, List, Tuple, Callable, Set

import colorcet
import holoviews as hv
import hvplot.pandas
import numpy as np
import pandas as pd
import scipy.sparse
from anndata import AnnData
from holoviews import dim
from holoviews.plotting.bokeh.callbacks import LinkCallback
from holoviews.plotting.links import Link
from natsort import natsorted
from pandas.api.types import CategoricalDtype


# def sort_by_values(summarized_df):
#     # sort rows by expression
#     sorted_df = summarized_df.sort_values(axis=0, by=list(summarized_df.columns.values), ascending=False)
#     indices = [summarized_df.index.get_loc(c) for c in sorted_df.index]
#     return indices

# Doesn't work for scatter plots colored by categorical variable
class __BrushLinkRange(Link):
    """Link type whose bokeh callback copies selected indices from a source plot to a target plot."""

    # NOTE(review): presumably tells HoloViews that this link must always be
    # constructed with a target element - confirm against the Link base class.
    _requires_target = True


class __BrushLinkCallbackRange(LinkCallback):
    """Bokeh callback for ``__BrushLinkRange`` that mirrors the selection of the source onto the target."""

    # Model, handles and change event looked up on the source plot.
    source_model = 'selected'
    source_handles = ['cds', 'glyph']
    on_source_changes = ['indices']
    # Model and handles looked up on the target plot.
    target_model = 'selected'
    target_handles = ['cds', 'glyph']

    # JavaScript executed on a source change: overwrite the target's selected
    # indices with the source's selected indices.
    source_code = """
       
        target_selected.indices = source_selected.indices;
    """


# Use the callback above when a __BrushLinkRange is rendered with the bokeh backend.
__BrushLinkRange.register_callback('bokeh', __BrushLinkCallbackRange)


class __BrushLink(Link):
    """Link type whose bokeh callback copies selected indices from a source plot to a target plot."""

    # NOTE(review): presumably tells HoloViews that this link must always be
    # constructed with a target element - confirm against the Link base class.
    _requires_target = True


class __BrushLinkCallback(LinkCallback):
    """Bokeh callback for ``__BrushLink`` that mirrors the selection of the source onto the target."""

    # Model, handles and change event looked up on the source plot.
    # Unlike __BrushLinkCallbackRange, only the 'cds' handle is requested and
    # no target handles are declared.
    source_model = 'selected'
    source_handles = ['cds']
    on_source_changes = ['indices']
    target_model = 'selected'

    # JavaScript executed on a source change: overwrite the target's selected
    # indices with the source's selected indices.
    source_code = """
        target_selected.indices = source_selected.indices;
    """


# Use the callback above when a __BrushLink is rendered with the bokeh backend.
__BrushLink.register_callback('bokeh', __BrushLinkCallback)


def __get_marker_size(count):
    """
    Pick a marker size in pixels that shrinks as the number of points grows.

    Args:
        count: Number of points that will be drawn.

    Returns:
        ``budget / count`` capped at 12, where the budget is 240000 for more than
        300000 points and 120000 otherwise. Returns the cap (12) when ``count`` is
        not positive, e.g. for an empty data frame.
    """
    max_size = 12
    if count <= 0:
        # Nothing to draw: avoid a ZeroDivisionError (or a negative size).
        return max_size
    budget = 240000.0 if count > 300000 else 120000.0
    return min(max_size, budget / count)


def __sort_category(df, by):
    """
    Convert column ``by`` of ``df`` to an ordered categorical, in place.

    A non-categorical column is first converted to a categorical. If the categorical is
    not already ordered, its categories are put in natural sort order and marked as
    ordered. An already ordered categorical is left untouched.

    Args:
        df: Data frame that is modified in place.
        by: Name of the column to convert.
    """
    # isinstance on the dtype replaces the deprecated pd.api.types.is_categorical_dtype
    if not isinstance(df[by].dtype, CategoricalDtype):
        df[by] = df[by].astype('category')
    if not df[by].dtype.ordered:
        df[by] = df[by].astype(CategoricalDtype(natsorted(df[by].dtype.categories), ordered=True))


def __auto_bin(df, nbins, width, height):
    """
    Resolve the automatic bin count.

    Args:
        df: Data frame about to be plotted; only its row count is used.
        nbins: Requested number of bins. -1 requests automatic binning.
        width: Plot width in pixels.
        height: Plot height in pixels.

    Returns:
        ``nbins`` unchanged, unless ``nbins`` is -1 and ``df`` has at least 500000 rows,
        in which case half of the smaller plot dimension is returned (at least 200).
    """
    if nbins != -1 or df.shape[0] < 500000:
        return nbins
    smallest_side = min(width, height)
    return int(max(200, smallest_side / 2))


def __create_hover_tool(df, keywords: dict, exclude: List, current: str = None, whitelist: List = None):
    """
   Add a hover tool and the columns it shows to ``keywords``.

   ``keywords['hover_cols']`` is set to the selected columns and a bokeh ``HoverTool`` is appended to
   ``keywords['tools']``. If bokeh or holoviews cannot be imported, ``keywords`` is left untouched.

   Args:
       df: Data frame whose columns are candidates for the tooltip.
       keywords: Keyword dict that is updated in place.
       exclude: List of columns in df to exclude.
       current: Key in df that is plotted to show 1st in tooltip
       whitelist: If given, only columns in this list are shown.
   """

    try:
        import bokeh.models
        import holoviews.core.util

        def as_tooltip(name):
            # (label, '@{sanitized column}') pair understood by bokeh
            return name, '@{' + holoviews.core.util.dimension_sanitizer(name) + '}'

        selected = (column for column in df.columns
                    if column not in exclude and column != current
                    and (whitelist is None or column in whitelist))
        # drop duplicated column names while keeping the original order
        hover_cols = list(dict.fromkeys(selected))
        keywords['hover_cols'] = hover_cols

        shown = hover_cols if current is None else [current] + hover_cols
        hover = bokeh.models.HoverTool(tooltips=[as_tooltip(name) for name in shown])
        keywords['tools'] = keywords.get('tools', []) + [hover]
    except ModuleNotFoundError:
        pass


# Fallback categorical palettes used by __get_scanpy_colors when the matplotlib color cycle is too short.
# NOTE(review): the names suggest these mirror scanpy's default palettes - confirm against scanpy.
# 102 colors, used for up to 102 categories.
scanpy_default_102 = ['#FFFF00', '#1CE6FF', '#FF34FF', '#FF4A46', '#008941', '#006FA6', '#A30059', '#FFDBE5', '#7A4900',
                      '#0000A6', '#63FFAC', '#B79762', '#004D43', '#8FB0FF', '#997D87', '#5A0007', '#809693', '#6A3A4C',
                      '#1B4400', '#4FC601', '#3B5DFF', '#4A3B53', '#FF2F80', '#61615A', '#BA0900', '#6B7900', '#00C2A0',
                      '#FFAA92', '#FF90C9', '#B903AA', '#D16100', '#DDEFFF', '#000035', '#7B4F4B', '#A1C299', '#300018',
                      '#0AA6D8', '#013349', '#00846F', '#372101', '#FFB500', '#C2FFED', '#A079BF', '#CC0744', '#C0B9B2',
                      '#C2FF99', '#001E09', '#00489C', '#6F0062', '#0CBD66', '#EEC3FF', '#456D75', '#B77B68', '#7A87A1',
                      '#788D66', '#885578', '#FAD09F', '#FF8A9A', '#D157A0', '#BEC459', '#456648', '#0086ED', '#886F4C',
                      '#34362D', '#B4A8BD', '#00A6AA', '#452C2C', '#636375', '#A3C8C9', '#FF913F', '#938A81', '#575329',
                      '#00FECF', '#B05B6F', '#8CD0FF', '#3B9700', '#04F757', '#C8A1A1', '#1E6E00', '#7900D7', '#A77500',
                      '#6367A9', '#A05837', '#6B002C', '#772600', '#D790FF', '#9B9700', '#549E79', '#FFF69F', '#201625',
                      '#72418F', '#BC23FF', '#99ADC0', '#3A2465', '#922329', '#5B4534', '#FDE8DC', '#404E55', '#0089A3',
                      '#CB7E98', '#A4E804', '#324E72']
# 20 colors, used for up to 20 categories.
scanpy_default_20 = ['#1f77b4', '#ff7f0e', '#279e68', '#d62728', '#aa40fc', '#8c564b', '#e377c2', '#b5bd61',
                     '#17becf', '#aec7e8', '#ffbb78', '#98df8a', '#ff9896', '#c5b0d5', '#c49c94', '#f7b6d2',
                     '#dbdb8d', '#9edae5', '#ad494a', '#8c6d31']
# 28 colors, used for 21 to 28 categories.
scanpy_default_28 = ['#023fa5', '#7d87b9', '#bec1d4', '#d6bcc0', '#bb7784', '#8e063b', '#4a6fe3', '#8595e1',
                     '#b5bbe3', '#e6afb9', '#e07b91', '#d33f6a', '#11c638', '#8dd593', '#c6dec7', '#ead3c6',
                     '#f0b98d', '#ef9708', '#0fcfc0', '#9cded6', '#d5eae7', '#f3e1eb', '#f6c4e1', '#f79cd4',
                     '#7f7f7f', '#c7c7c7', '#1CE6FF', '#336600']


def __get_scanpy_colors(series):
    """
    Choose a list of colors for the categories of a categorical series.

    Args:
        series: Categorical pandas series.

    Returns:
        The first colors of the matplotlib color cycle when it has enough entries. Otherwise the
        smallest of the module level fallback palettes that is large enough (which may contain more
        colors than there are categories), or all 'grey' when there are more categories than any
        palette provides.
    """
    from matplotlib import rcParams
    n_categories = len(series.cat.categories)

    # prefer the default matplotlib palette if it has enough colors
    if len(rcParams['axes.prop_cycle'].by_key()['color']) >= n_categories:
        cycle = rcParams['axes.prop_cycle']()
        return [next(cycle)['color'] for _ in range(n_categories)]

    if n_categories <= 20:
        return scanpy_default_20
    if n_categories <= 28:
        return scanpy_default_28
    if n_categories <= len(scanpy_default_102):  # 102 colors
        return scanpy_default_102
    return ['grey' for _ in range(n_categories)]


def __fix_cmap(df, key, palette=None):
    """
    Return a category -> color dict that covers every category of ``df[key]``.

    ``df[key]`` is converted to a categorical in place if it is not one already (booleans are
    converted to the strings 'True'/'False' first because hvplot does not handle boolean type
    for colors).

    Args:
        df: Data frame containing ``key``. May be modified in place as described above.
        key: Column in ``df`` holding the categories.
        palette: Either a dict mapping category to color, a list/tuple of colors that is assigned
            to the categories in order (cycling if it is too short), or None. The input is never
            modified.

    Returns:
        A new dict mapping every category to a color. Categories that are missing from ``palette``
        get colors from ``colorcet.b_glasbey_category10``.
    """
    series = df[key]
    if not isinstance(series.dtype, CategoricalDtype):
        # hvplot does not handle boolean type for colors
        if pd.api.types.is_bool_dtype(series):
            series = series.astype(str)
        series = series.astype('category')
        df[key] = series
    categories = series.cat.categories

    if palette is None:
        # default argument: every category is filled in from the fallback colors below
        palette = {}
    elif isinstance(palette, (list, tuple)):
        n_given = len(palette)
        palette = {c: palette[i % n_given] for i, c in enumerate(categories)} if n_given > 0 else {}
    else:
        palette = palette.copy()

    # fill in missing values in palette; the fallback colors are only looked up when needed
    fallback = None
    n_missing = 0
    for c in categories:
        if c not in palette:
            if fallback is None:
                fallback = colorcet.b_glasbey_category10
            palette[c] = fallback[n_missing % len(fallback)]
            n_missing += 1
    return palette


def __get_category_cmap(adata, df, key):
    """
    Build a category -> color dict for ``df[key]``.

    ``df[key]`` is converted to a categorical in place if it is not one already (booleans are
    converted to the strings 'True'/'False' first).

    Args:
        adata: Annotated data matrix. ``adata.uns['<key>_colors']`` is used when it exists and has
            exactly one color per category.
        df: Data frame containing ``key``. May be modified in place as described above.
        key: Column in ``df`` holding the categories.

    Returns:
        Dict mapping each category to a color. Colors are assigned in category order and are reused
        cyclically if there are fewer colors than categories.
    """
    series = df[key]
    # isinstance on the dtype replaces the deprecated pd.api.types.is_categorical_dtype
    if not isinstance(series.dtype, CategoricalDtype):
        # hvplot does not handle boolean type for colors
        if pd.api.types.is_bool_dtype(series):
            series = series.astype(str)
        series = series.astype('category')
        df[key] = series
    categories = series.cat.categories

    colors = None
    color_key = f"{key}_colors"
    if color_key in adata.uns and len(adata.uns[color_key]) == len(categories):
        colors = adata.uns[color_key]
    if colors is None:
        colors = __get_scanpy_colors(series)
    ncolors = len(colors)
    return {category: colors[i % ncolors] for i, category in enumerate(categories)}


def __create_bounds_stream(source):
    """Create a ``hv.streams.BoundsXY`` stream attached to ``source`` and return it."""
    return hv.streams.BoundsXY(source=source)


def get_bounds(plot):
    """
    Return the bounds of the ``bounds_stream`` attached to a plot.

    Args:
        plot: A plot carrying a ``bounds_stream`` attribute, or a 1x1 layout containing one.

    Returns:
        ``plot.bounds_stream.bounds``, or None if the plot has no ``bounds_stream`` attribute.

    Raises:
        ValueError: If ``plot`` is a layout that is not 1x1.
    """
    if isinstance(plot, hv.Layout):
        if plot.shape != (1, 1):
            raise ValueError('Please select the plot in the layout')
        # unwrap the only plot in the layout
        plot = plot[0, 0]
    if not hasattr(plot, 'bounds_stream'):
        return None
    return plot.bounds_stream.bounds


def __to_list(vals):
    """
    Coerce a value to a list.

    Returns:
        ``vals`` itself if it is already a list, a list copy of a tuple, ``tolist()`` of a
        numpy array, and a one element list wrapping anything else.
    """
    if isinstance(vals, list):
        return vals
    if isinstance(vals, np.ndarray):
        return vals.tolist()
    if isinstance(vals, tuple):
        return list(vals)
    return [vals]


def __size_legend(size_min, size_max, dot_min, dot_max, size_tick_labels_format, size_ticks):
    """
    Build a legend explaining dot sizes: a column of black dots with a text label next to each.

    Args:
        size_min: Data value drawn with ``dot_min`` pixels.
        size_max: Data value drawn with ``dot_max`` pixels.
        dot_min: Minimum dot size in pixels.
        dot_max: Maximum dot size in pixels.
        size_tick_labels_format: Format string applied to each tick to create its label.
        size_ticks: Data values to show in the legend.

    Returns:
        Overlay of the dots and their labels.
    """
    n_ticks = len(size_ticks)
    tick_pixels = np.interp(size_ticks, (size_min, size_max), (dot_min, dot_max))
    tick_text = [size_tick_labels_format.format(tick) for tick in size_ticks]

    dot_data = {'x': np.repeat(0.15, n_ticks), 'y': np.arange(n_ticks, 0, -1), 'size': tick_pixels}
    dots = hv.Points(dot_data, vdims='size')
    dots = dots.opts(xaxis=None, color='black', yaxis=None, size=dim('size'))

    text_data = {'x': np.repeat(0.3, n_ticks), 'y': np.arange(n_ticks, 0, -1), 'text': tick_text}
    text = hv.Labels(text_data, ['x', 'y'], 'text')
    text = text.opts(text_align='left', text_font_size='9pt')

    overlay = dots * text
    overlay.opts(width=dot_max + 100, height=int(n_ticks * (dot_max + 12)), xlim=(0, 1),
                 ylim=(0, n_ticks + 1), invert_yaxis=True, shared_axes=False, show_frame=False)
    return overlay


def __get_raw(adata, use_raw):
    """
    Select the object to read expression values from.

    Args:
        adata: Annotated data matrix.
        use_raw: True to require ``adata.raw``, False to use ``adata``, None to use ``adata.raw``
            when it is present.

    Returns:
        ``adata.raw`` or ``adata``.

    Raises:
        ValueError: If ``use_raw`` is truthy but ``adata.raw`` is None.
    """
    wants_raw = use_raw or (use_raw is None and adata.raw is not None)
    if not wants_raw:
        return adata
    if adata.raw is None:
        raise ValueError('Raw data not found')
    return adata.raw


def __get_df(adata, adata_raw, keys, df=None, is_obs=None):
    """
    Collect the values for ``keys`` into a data frame.

    Args:
        adata: Annotated data matrix providing ``obs`` and ``var``.
        adata_raw: Object used to look up features in ``var_names`` (``adata`` or ``adata.raw``).
        keys: List of keys. A key is either a numpy array of values, a name in
            ``adata_raw.var_names``, a field of ``adata.obs`` or a field of ``adata.var``. Array keys
            are replaced in this list, in place, by the string of their position, which is also the
            name of the column they are stored in.
        df: Optional data frame to add columns to. A new data frame with an ``id`` column is created
            when omitted.
        is_obs: Whether rows are observations (True) or variables (False). Required when ``df`` is
            given. When omitted it is inferred from the first key: an array is an observation vector
            if its length equals ``adata.shape[0]``, any other key unless it is a field of
            ``adata.var``.

    Returns:
        The data frame. When ``keys`` is empty this is ``df`` if given, a data frame holding only
        the ``id`` column if ``is_obs`` was given, and None otherwise.

    Raises:
        ValueError: If ``df`` is given without ``is_obs`` or a key can not be found.
    """
    if df is not None and is_obs is None:
        raise ValueError('Please provide is_obs when df is provided.')

    def new_frame(obs_axis):
        ids = adata.obs.index.values if obs_axis else adata.var.index.values
        return pd.DataFrame(data=dict(id=ids))

    for i, key in enumerate(keys):
        if df is None:
            if is_obs is None:
                # only guess the axis when the caller did not state it
                if isinstance(key, np.ndarray):
                    is_obs = len(key) == adata.shape[0]
                else:
                    is_obs = key not in adata.var
            df = new_frame(is_obs)
        if isinstance(key, np.ndarray):
            values = key
            key = str(i)
            keys[i] = key  # callers read the generated column name back from keys
        elif key in adata_raw.var_names and is_obs:
            X = adata_raw.obs_vector(key)
            if scipy.sparse.issparse(X):
                X = X.toarray()
            values = X
        elif key in adata.obs and is_obs:
            values = adata.obs[key].values
        elif key in adata.var and not is_obs:
            values = adata.var[key].values
        else:
            raise ValueError('{} not found'.format(key))
        df[key] = values

    if df is None and is_obs is not None:
        # no keys were requested but the axis is known: return the ids instead of None
        df = new_frame(is_obs)
    return df


def mode_and_purity(x):
    """
    Return the most frequent value of a series and the fraction of entries equal to it.

    Args:
        x: pandas series.

    Returns:
        Tuple of (most frequent value, its count divided by the number of counted entries).
        ``(None, nan)`` if there is nothing to count.
    """
    value_counts = x.value_counts(sort=False)
    largest = value_counts.nlargest(1)
    if len(largest) == 0:
        return None, float('nan')
    # positional access: the index of `largest` holds the values of x, so `largest[0]`
    # would be a label lookup that fails for e.g. integer data
    purity = largest.iloc[0] / value_counts.sum()
    return largest.index[0], purity


def __bin(df, nbins, coordinate_columns, reduce_function, coordinate_column_to_range=None):
    # replace coordinates with bin
    for view_column_name in coordinate_columns:  # add view column _bin
        values = df[view_column_name].values
        view_column_range = coordinate_column_to_range.get(view_column_name,
            None) if coordinate_column_to_range is not None else None
        column_min = values.min() if view_column_range is None else view_column_range[0]
        column_max = values.max() if view_column_range is None else view_column_range[1]
        df[view_column_name] = np.floor(
            np.interp(values, [column_min, column_max], [0, nbins - 1])).astype(int)

    agg_func = {}
    for column in df:
        if column not in coordinate_columns:
            if column == 'count':
                agg_func[column] = 'sum'
            elif pd.api.types.is_numeric_dtype(df[column]):
                agg_func[column] = reduce_function
            else:  # pd.api.types.is_categorical_dtype(df[column]):
                agg_func[column] = mode_and_purity
    return df.groupby(coordinate_columns, as_index=False).agg(agg_func), df[coordinate_columns]


def violin(adata: AnnData, keys: Union[str, List[str], Tuple[str]], by: str = None, width: int = 300,
           cmap: Union[str, List[str], Tuple[str]] = None, cols: int = None, use_raw: bool = None,
           **kwds) -> hv.core.element.Element:
    """
    Generate a violin plot.

    Args:
        adata: Annotated data matrix.
        keys: Keys for accessing variables of adata.var_names, field of adata.var, or field of adata.obs
        by: Group plot by specified observation.
        width: Plot width.
        cmap: Color map name (hv.plotting.list_cmaps()) or a list of hex colors. See
            http://holoviews.org/user_guide/Styling_Plots.html for more information.
        cols: Number of columns for laying out multiple plots. Defaults to 3.
        use_raw: Use `raw` attribute of `adata` if present.

    Returns:
        Layout with one violin plot per key. The plotted data frame is attached as ``df``.
    """
    n_cols = 3 if cols is None else cols
    is_grouped = by is not None
    adata_raw = __get_raw(adata, use_raw)
    keys = __to_list(keys)
    df = __get_df(adata, adata_raw, keys + ([by] if is_grouped else []))
    if is_grouped:
        if cmap is None:
            cmap = __get_category_cmap(adata, df, by)
        else:
            cmap = __fix_cmap(df, by, cmap)
    keywords = dict(padding=0.02, cmap=cmap, rot=90)
    keywords.update(kwds)
    if is_grouped:
        __sort_category(df, by)
    plots = [df.hvplot.violin(key, width=width, by=by, violin_color=by, **keywords) for key in keys]
    layout = hv.Layout(plots).cols(n_cols)
    layout.df = df
    return layout
def heatmap(adata: AnnData, keys: Union[str, List[str], Tuple[str]], by: str,
            reduce_function: Callable[[np.ndarray], float] = np.mean, use_raw: bool = None,
            cmap: Union[str, List[str], Tuple[str]] = 'Reds', **kwds) -> hv.core.element.Element:
    """
    Generate a heatmap.

    Args:
        adata: Annotated data matrix.
        keys: Keys for accessing variables of adata.var_names
        by: Group plot by specified observation.
        cmap: Color map name (hv.plotting.list_cmaps()) or a list of hex colors. See
            http://holoviews.org/user_guide/Styling_Plots.html for more information.
        reduce_function: Function to summarize an element in the heatmap
        use_raw: Use `raw` attribute of `adata` if present.

    Returns:
        Heatmap with features on the x axis and the categories of ``by`` on the y axis.

    Raises:
        ValueError: If no keys are given.
    """
    adata_raw = __get_raw(adata, use_raw)
    keys = __to_list(keys)
    if len(keys) == 0:
        raise ValueError('Please provide at least one key')
    keywords = dict(colorbar=True, xlabel='', cmap=cmap, ylabel=str(by), rot=90)
    keywords.update(kwds)

    group_values = adata.obs[by].values
    frames = []
    for key in keys:
        X = adata_raw.obs_vector(key)
        if scipy.sparse.issparse(X):
            X = X.toarray()
        frame = pd.DataFrame(X, columns=['value'])
        frame['feature'] = key
        frame[by] = group_values
        frames.append(frame)
    # concatenate once instead of re-copying the accumulated frame for every key
    df = pd.concat(frames)

    __sort_category(df, by)
    # keep features in the order they were requested
    df['feature'] = df['feature'].astype(CategoricalDtype(keys, ordered=True))
    return df.hvplot.heatmap(x='feature', y=by, C='value', reduce_function=reduce_function, **keywords)
def scatter(adata: AnnData, x: str, y: str, color: str = None, size: Union[int, str] = None, dot_min=2,
            dot_max=14, use_raw: bool = None, sort: bool = True, width: int = 400, height: int = 400,
            nbins: int = -1, reduce_function: Callable[[np.array], float] = np.max,
            cmap: Union[str, List[str], Tuple[str]] = None, palette: Union[str, List[str], Tuple[str]] = None,
            **kwds) -> hv.core.element.Element:
    """
    Generate a scatter plot.

    Args:
        adata: Annotated data matrix.
        x: Key for accessing variables of adata.var_names, field of adata.var, or field of adata.obs
        y: Key for accessing variables of adata.var_names, field of adata.var, or field of adata.obs
        cmap: Color map for continuous variables.
        palette: Color map for categorical variables.
        color: Field in .var_names, adata.var, or adata.obs to color the points by.
        sort: Plot higher color by values on top of lower values.
        width: Chart width.
        height: Chart height.
        size: Field in .var_names, adata.var, or adata.obs to size the points by or a pixel size.
        dot_min: Minimum dot size when sizing points by a field.
        dot_max: Maximum dot size when sizing points by a field.
        use_raw: Use `raw` attribute of `adata` if present.
        nbins: Number of bins used to summarize plot on a grid. Useful for large datasets. Negative one
            means automatically bin the plot.
        reduce_function: Function used to summarize overlapping cells if nbins is specified
    """
    # everything is delegated to the shared scatter/line implementation
    options = dict(color=color, size=size, dot_min=dot_min, dot_max=dot_max, use_raw=use_raw, sort=sort,
                   width=width, height=height, nbins=nbins, reduce_function=reduce_function, cmap=cmap,
                   palette=palette, is_scatter=True)
    return __scatter(adata=adata, x=x, y=y, **options, **kwds)
def line(adata: AnnData, x: str, y: str, use_raw: bool = None, width: int = 400, height: int = 400,
         nbins: int = None, reduce_function: Callable[[np.array], float] = np.max,
         **kwds) -> hv.core.element.Element:
    """
    Generate a line plot.

    Args:
        adata: Annotated data matrix.
        x: Key for accessing variables of adata.var_names, field of adata.var, or field of adata.obs
        y: Key for accessing variables of adata.var_names, field of adata.var, or field of adata.obs
        use_raw: Use `raw` attribute of `adata` if present.
        width: Chart width.
        height: Chart height.
        nbins: Number of bins used to summarize plot on a grid. Useful for large datasets.
        reduce_function: Function used to summarize overlapping cells if nbins is specified
    """
    return __scatter(adata=adata, x=x, y=y, use_raw=use_raw, sort=False, width=width, height=height,
                     nbins=nbins, reduce_function=reduce_function, is_scatter=False, **kwds)


def __scatter(adata: AnnData, x: str, y: str, color=None, size: Union[int, str] = None, dot_min=2, dot_max=14,
              use_raw: bool = None, sort: bool = True, width: int = 400, height: int = 400, nbins: int = None,
              reduce_function: Callable[[np.array], float] = np.max,
              cmap: Union[str, List[str], Tuple[str]] = None, palette: Union[str, List[str], Tuple[str]] = None,
              is_scatter=True, **kwds) -> hv.core.element.Element:
    """
    Generate a scatter or line plot.

    Args:
        adata: Annotated data matrix.
        x: Key for accessing variables of adata.var_names, field of adata.var, or field of adata.obs
        y: Key for accessing variables of adata.var_names, field of adata.var, or field of adata.obs
        cmap: Color map for continuous variables.
        palette: Color map for categorical variables.
        color: Field in .var_names, adata.var, or adata.obs to color the points by. Ignored for line plots.
        sort: Plot higher color by values on top of lower values.
        size: Field in .var_names, adata.var, or adata.obs to size the points by or a pixel size. Ignored
            for line plots.
        dot_min: Minimum dot size when sizing points by a field.
        dot_max: Maximum dot size when sizing points by a field.
        use_raw: Use `raw` attribute of `adata` if present.
        width: Chart width.
        height: Chart height.
        nbins: Number of bins used to summarize plot on a grid. Useful for large datasets.
        reduce_function: Function used to summarize overlapping cells if nbins is specified
        is_scatter: True for a scatter plot, False for a line plot.

    Returns:
        The plot, laid out next to a size legend when points are sized by a field. The (binned) data
        frame is attached as ``df``.
    """
    adata_raw = __get_raw(adata, use_raw)
    # a plain number is a fixed pixel size, not a field to look up
    is_fixed_size = is_scatter and isinstance(size, (int, float, np.integer, np.floating))
    is_size_by = size is not None and is_scatter and not is_fixed_size
    is_color_by = color is not None and is_scatter
    keywords = dict(fontsize=dict(title=9), nonselection_alpha=0.1, padding=0.02, xaxis=True, yaxis=True,
                    width=width, height=height, alpha=1, tools=['box_select'])
    keywords.update(kwds)
    if is_fixed_size:
        keywords['size'] = size
    keys = [x, y]
    if is_color_by:
        keys.append(color)
    if is_size_by:
        keys.append(size)
    df = __get_df(adata, adata_raw, keys)
    # keys might have been modified by __get_df if key was an array instead of a string
    x = keys[0]
    y = keys[1]
    if is_color_by:
        color = keys[2]
    if is_size_by:
        size = keys[3 if is_color_by else 2]
    nbins = __auto_bin(df, nbins, width, height)
    bin_data = nbins is not None and nbins > 0
    df_with_coords = df
    # copy so that a list passed in by the caller is not modified
    hover_cols = list(keywords.get('hover_cols', []))
    if bin_data:
        df['count'] = 1.0
        hover_cols.append('count')
        df, df_with_coords = __bin(df, nbins=nbins, coordinate_columns=[x, y], reduce_function=reduce_function)
    else:
        hover_cols.append('id')
    keywords['hover_cols'] = hover_cols

    if is_color_by:
        # binned categorical columns hold (mode, purity) tuples; split them into two columns
        is_categorical = bin_data and pd.api.types.is_object_dtype(df[color])
        if is_categorical:
            df = pd.DataFrame(df[color].tolist(), columns=[color, str(color) + '_purity']).join(df, rsuffix='orig_')
        is_color_by_numeric = not pd.api.types.is_bool_dtype(df[color]) and pd.api.types.is_numeric_dtype(
            df[color])
        __fix_scatter_colors(adata, df, color, is_color_by_numeric, cmap, palette, keywords)
        if is_color_by_numeric:
            keywords['c'] = color
            keywords.update(dict(colorbar=True))
            if sort:
                df = df.sort_values(by=color)
        else:
            keywords['by'] = color

    if is_size_by:
        size_min = df[size].min()
        size_max = df[size].max()
        df['pixels'] = np.interp(df[size], (size_min, size_max), (dot_min, dot_max))
        keywords['s'] = 'pixels'
        keywords['hover_cols'].append(size)

    if is_scatter:
        p = df.hvplot.scatter(x=x, y=y, **keywords)
    else:  # line plot
        df = df.sort_values(by=x)
        p = df.hvplot.line(x=x, y=y, **keywords)

    if is_size_by:
        return_value = (p + __size_legend(size_min=size_min, size_max=size_max, dot_min=dot_min, dot_max=dot_max,
                                          size_tick_labels_format='{0:.1f}',
                                          size_ticks=np.array([size_min, (size_min + size_max) / 2, size_max])))
    else:
        return_value = p
    return_value.df = df_with_coords
    return return_value
def dotplot(adata: AnnData, keys: Union[str, List[str], Tuple[str]], by: str,
            reduce_function: Callable[[np.ndarray], float] = np.mean, fraction_min: float = 0,
            fraction_max: float = None, dot_min: int = 1, dot_max: int = 26, use_raw: bool = None,
            cmap: Union[str, List[str], Tuple[str]] = 'Reds',
            sort_function: Callable[[pd.DataFrame], List[str]] = None, **kwds) -> hv.core.element.Element:
    """
    Generate a dot plot.

    Args:
        adata: Annotated data matrix.
        keys: Keys for accessing variables of adata.var_names
        by: Group plot by specified observation.
        cmap: Color map name (hv.plotting.list_cmaps()) or a list of hex colors. See
            http://holoviews.org/user_guide/Styling_Plots.html for more information.
        reduce_function: Function to summarize an element in the heatmap
        fraction_min: Minimum fraction expressed value.
        fraction_max: Maximum fraction expressed value. Defaults to the largest observed fraction.
        dot_min: Minimum pixel dot size.
        dot_max: Maximum pixel dot size.
        use_raw: Use `raw` attribute of `adata` if present.
        sort_function: Optional function that accepts summarized data frame and returns a list of row
            indices in the order to render in the heatmap.

    Returns:
        Layout of the dot plot and a size legend. The plotted data frame is attached as ``df``.
    """
    adata_raw = __get_raw(adata, use_raw)
    keys = __to_list(keys)
    keywords = dict(colorbar=True, ylabel=str(by), xlabel='', padding=0, rot=90, cmap=cmap)
    keywords.update(kwds)

    X = adata_raw[:, keys].X
    if scipy.sparse.issparse(X):
        X = X.toarray()
    df = pd.DataFrame(data=X, columns=keys)
    df[by] = adata.obs[by].values

    def non_zero(g):
        return np.count_nonzero(g) / g.shape[0]

    summarized_df = df.groupby(by).aggregate([reduce_function, non_zero])
    if sort_function is None:
        summarized_df = summarized_df.loc[natsorted(summarized_df.index, reverse=True)]
    else:
        # sort categories
        summarized_df = summarized_df.iloc[sort_function(summarized_df)]

    # the two aggregates alternate for every feature: summary value first, fraction second
    mean_columns = list(summarized_df.columns[0::2])
    frac_columns = list(summarized_df.columns[1::2])
    # features on columns, by on rows
    fraction_df = summarized_df[frac_columns]
    mean_df = summarized_df[mean_columns]

    row_index, column_index = np.indices(mean_df.shape)
    y = row_index.flatten()
    x = column_index.flatten()
    fraction = fraction_df.values.flatten()
    if fraction_max is None:
        fraction_max = fraction.max()
    pixels = np.interp(fraction, (fraction_min, fraction_max), (dot_min, dot_max))
    pixels = pixels * pixels  # hvplot takes the sqrt of size
    summary_values = mean_df.values.flatten()

    xlabel = list(keys)
    ylabel = [str(category) for category in summarized_df.index]
    dotplot_df = pd.DataFrame(
        data=dict(x=x, y=y, value=summary_values, pixels=pixels, fraction=fraction, xlabel=np.array(xlabel)[x],
                  ylabel=np.array(ylabel)[y]))
    xticks = list(enumerate(xlabel))
    yticks = list(enumerate(ylabel))

    # note we take the max label string length as an approximation of width of labels in pixels
    cell_size = (dot_max + 1) + 4
    keywords['width'] = int(
        np.ceil(cell_size * len(xticks) + dotplot_df['ylabel'].str.len().max()) + dot_max + 100)
    keywords['height'] = int(np.ceil(cell_size * len(yticks) + dotplot_df['xlabel'].str.len().max()) + 50)

    try:
        import bokeh.models
        keywords['hover_cols'] = ['fraction', 'xlabel', 'ylabel']
        keywords['tools'] = [bokeh.models.HoverTool(tooltips=[
            ('fraction', '@fraction'),
            ('value', '@value'),
            ('x', '@xlabel'),
            ('y', '@ylabel')
        ])]
    except ModuleNotFoundError:
        pass

    p = dotplot_df.hvplot.scatter(x='x', y='y', xlim=(-1, len(xticks)), ylim=(-1, len(yticks)), c='value',
                                  s='pixels', xticks=xticks, yticks=yticks, **keywords)

    size_range = fraction_max - fraction_min
    if size_range <= 0.3:
        size_legend_step = 0.05
    elif size_range <= 0.6:
        size_legend_step = 0.1
    else:
        size_legend_step = 0.2
    # the first tick is fraction_min itself when that is positive, else one step above it
    first_tick = fraction_min if fraction_min > 0 else fraction_min + size_legend_step
    size_ticks = np.arange(first_tick, fraction_max + size_legend_step, size_legend_step)
    result = p + __size_legend(size_min=fraction_min, size_max=fraction_max, dot_min=dot_min, dot_max=dot_max,
                               size_tick_labels_format='{:.0%}', size_ticks=size_ticks)
    result.df = dotplot_df
    return result
def scatter_matrix(adata: AnnData, keys: Union[str, List[str], Tuple[str]], color=None, use_raw: bool = None,
                   **kwds) -> hv.core.element.Element:
    """
    Generate a scatter plot matrix.

    Args:
        adata: Annotated data matrix.
        keys: Key for accessing variables of adata.var_names or a field of adata.obs
        color: Key in adata.obs to color points by.
        use_raw: Use `raw` attribute of `adata` if present.

    Returns:
        The scatter matrix. The plotted data frame is attached as ``df``.
    """
    adata_raw = __get_raw(adata, use_raw)
    # work on a copy: the list is extended below and must not leak into the caller's list
    keys = list(__to_list(keys))
    if color is not None:
        keys.append(color)
    df = __get_df(adata, adata_raw, keys)
    if color is not None and pd.api.types.is_bool_dtype(df[color]):
        df[color] = df[color].astype(str).astype('category')
    p = hvplot.scatter_matrix(df, c=color, **kwds)
    p.df = df
    return p
def __fix_scatter_colors(adata, df_to_plot, key, is_color_by_numeric, cmap, palette, keywords):
    """
    Configure the color related plot keywords, in place, for coloring a scatter plot by ``key``.

    Numeric keys set ``keywords['cmap']`` (``cmap`` or 'viridis') and remove ``keywords['color']``.
    Other keys are converted to an ordered categorical, get a ``__color`` column holding the color of
    every row, set ``keywords['color']`` to that column name and remove ``keywords['cmap']``.

    Args:
        adata: Annotated data matrix, used to look up stored category colors.
        df_to_plot: Data frame to plot. Modified in place for non-numeric keys.
        key: Column to color by.
        is_color_by_numeric: Whether ``key`` is a numeric column.
        cmap: Color map for continuous variables.
        palette: Color map for categorical variables.
        keywords: Plot keywords to update.
    """
    # 'c' does not show clickable legend, but respects dict color map, 'by' shows legend but does not
    # respect color map, use array of colors
    if is_color_by_numeric:
        keyword_to_set, keyword_to_remove = 'cmap', 'color'
        value = 'viridis' if cmap is None else cmap
    else:
        __sort_category(df_to_plot, key)  # for legend
        if palette is None:
            category_colors = __get_category_cmap(adata, df_to_plot, key)
        else:
            category_colors = __fix_cmap(df_to_plot, key, palette)
        df_to_plot['__color'] = df_to_plot[key].apply(lambda x: category_colors[x])
        keyword_to_set, keyword_to_remove = 'color', 'cmap'
        value = '__color'
    keywords[keyword_to_set] = value
    keywords.pop(keyword_to_remove, None)
def embedding(adata: AnnData, basis: Union[str, List[str], Tuple[str]],
              keys: Union[None, str, List[str], Tuple[str]] = None,
              cmap: Union[str, List[str], Tuple[str]] = None, palette: Union[str, List[str], Tuple[str]] = None,
              alpha: float = 1, size: float = None, width: int = 400, height: int = 400, sort: bool = True,
              cols: int = None, use_raw: bool = None, nbins: int = -1,
              reduce_function: Callable[[np.array], float] = np.max, brush_categorical: bool = False,
              legend: str = 'right', tooltips: Union[str, List[str], Tuple[str]] = None,
              legend_font_size: Union[int, str] = None, opacity_min: float = 0, opacity_max: float = 1,
              **kwds) -> hv.core.element.Element:
    """
    Generate an embedding plot.

    Args:
        adata: Annotated data matrix.
        keys: Key for accessing variables of adata.var_names or a field of adata.obs used to color the plot.
            Can also use `count` to plot cell count when binning. When omitted, the cell count is plotted.
        basis: String in adata.obsm containing coordinates.
        alpha: Points alpha value.
        size: Point pixel size.
        sort: Plot higher values on top of lower values. Disable for linked brushing.
        brush_categorical: Enable linked brushing on categorical variables (disables categorical legend).
        cmap: Color map for continuous variables.
        palette: Color map for categorical variables.
        nbins: Number of bins used to summarize plot on a grid. Useful for large datasets. Negative one
            means automatically bin the plot.
        reduce_function: Function used to summarize overlapping cells if nbins is specified.
        cols: Number of columns for laying out multiple plots
        width: Plot width.
        height: Plot height.
        tooltips: List of additional fields to show on hover.
        legend: 'top', 'bottom', 'left', 'right', or 'data' to draw labels for categorical features on
            the plot.
        legend_font_size: Font size for `labels_on_data`
        use_raw: Use `raw` attribute of `adata` if present.
        opacity_min: Minimum value for encoding categorical data purity when binning using opacity.
        opacity_max: Maximum value for encoding categorical data purity when binning using opacity.

    Returns:
        Layout with one plot per basis and key. The data frame of the last basis is attached as ``df``.
    """
    if keys is None:
        keys = []
    basis = __to_list(basis)
    adata_raw = __get_raw(adata, use_raw)
    keys = __to_list(keys)
    if tooltips is None:
        tooltips = []
    tooltips = __to_list(tooltips)
    if legend_font_size is None:
        legend_font_size = '12pt'
    labels_on_data = legend == 'data'
    if labels_on_data:
        legend_option = False
    elif legend in ('top', 'bottom', 'left', 'right'):
        # forward the requested position instead of collapsing it to True
        legend_option = legend
    else:
        legend_option = True
    keywords = dict(fontsize=dict(title=9, legend=legend_font_size), padding=0.02 if not labels_on_data else 0.05,
                    xaxis=False, yaxis=False, nonselection_alpha=0.1, tools=['box_select'], legend=legend_option)
    plots = []
    charts_to_brush = []
    keywords.update(kwds)
    data_df = __get_df(adata, adata_raw, keys + tooltips, is_obs=True)
    if data_df is None:
        # no keys and no tooltips were requested: start from the observation ids only
        data_df = pd.DataFrame(data=dict(id=adata.obs.index.values))
    density = len(keys) == 0
    if density:
        keys = ['count']

    for b in basis:
        df = data_df.copy()
        coordinate_columns = ['X_' + b + c for c in ['1', '2']]
        df = pd.concat((df, pd.DataFrame(adata.obsm['X_' + b][:, 0:2], columns=coordinate_columns)), axis=1)
        nbins = __auto_bin(df, nbins, width, height)
        df_with_coords = df
        bin_data = nbins is not None and nbins > 0
        if bin_data or density:
            df['count'] = 1.0
        if bin_data:
            df, df_with_coords = __bin(df, nbins=nbins, coordinate_columns=coordinate_columns,
                                       reduce_function=reduce_function)
        if size is None:
            size = __get_marker_size(df.shape[0])

        for key in keys:
            is_color_by_numeric = False
            # binned categorical columns hold (mode, purity) tuples; split them into two columns
            is_categorical_binned = bin_data and pd.api.types.is_object_dtype(df[key])
            if is_categorical_binned:
                df_to_plot = pd.DataFrame(df[key].tolist(), columns=[key, str(key) + '_purity']).join(
                    df, rsuffix='orig_')
            else:
                is_color_by_numeric = not pd.api.types.is_bool_dtype(df[key]) and pd.api.types.is_numeric_dtype(
                    df[key])
                df_to_plot = df
            if sort and is_color_by_numeric:
                df_to_plot = df.sort_values(by=key)

            # every plot gets its own keywords so that hover tools and color settings of one
            # key do not leak into the plots of the following keys
            plot_keywords = dict(keywords)
            __create_hover_tool(df_to_plot, plot_keywords, exclude=coordinate_columns, current=key)
            use_c = is_color_by_numeric or brush_categorical
            __fix_scatter_colors(adata, df_to_plot, key, is_color_by_numeric, cmap, palette, plot_keywords)
            if is_categorical_binned:
                purity = df_to_plot[str(key) + '_purity']
                df_to_plot['__point_opacity'] = np.interp(purity, (purity.min(), purity.max()),
                                                          (opacity_min, opacity_max))
            p = df_to_plot.hvplot.scatter(
                x=coordinate_columns[0],
                y=coordinate_columns[1],
                title=str(key),
                c=key if use_c else None,
                by=key if not use_c else None,
                size=size,
                alpha='__point_opacity' if is_categorical_binned else alpha,
                colorbar=is_color_by_numeric,
                width=width, height=height, **plot_keywords)
            bounds_stream = __create_bounds_stream(p)
            if not sort and not bin_data:
                charts_to_brush.append(p)
            if not is_color_by_numeric and labels_on_data:
                # draw each category name at the median position of its points
                labels_df = df_to_plot[[coordinate_columns[0], coordinate_columns[1], key]].groupby(key).aggregate(
                    np.median)
                labels = hv.Labels({('x', 'y'): labels_df, 'text': labels_df.index.values}, ['x', 'y'], 'text').opts(
                    text_font_size=legend_font_size)
                p = p * labels
            p.bounds_stream = bounds_stream
            plots.append(p)

    # link the selection of every pair of brushable charts in both directions
    for i in range(len(charts_to_brush)):
        for j in range(i):
            __BrushLinkRange(charts_to_brush[i], charts_to_brush[j])
            __BrushLinkRange(charts_to_brush[j], charts_to_brush[i])

    if cols is None:
        cols = 1 if width > 500 else 2
    layout = hv.Layout(plots).cols(cols)
    layout.df = df_with_coords
    return layout
def variable_feature_plot(adata: AnnData, **kwds) -> hv.core.element.Element:
    """
    Generate a variable feature plot.

    Args:
        adata: Annotated data matrix. If ``adata.var`` contains `hvf_loess`, `mean` is plotted against
            `var`, colored by `highly_variable_features`, with the `hvf_loess` fit drawn as a line.
            Otherwise `means` is plotted against `dispersions_norm`, colored by `highly_variable`.
        **kwds: Overrides for `x`, `y`, `y_fit`, `color`, `xlabel`, `ylabel`, `line_color` and `cmap`;
            all remaining options are passed on to the scatter plot.
    """
    if 'hvf_loess' in adata.var:
        keywords = dict(x='mean', y='var', y_fit='hvf_loess', color='highly_variable_features',
                        xlabel='Mean log expression', ylabel='Variance of log expression')
    else:
        keywords = dict(x='means', y='dispersions_norm', y_fit=None, color='highly_variable',
                        xlabel='Mean log expression', ylabel='Normalized dispersion')
    keywords.update(kwds)
    if 'cmap' not in keywords:
        keywords['cmap'] = {'True': '#e41a1c', 'False': '#bdbdbd'}
    x = keywords.pop('x')
    y = keywords.pop('y')
    color = keywords.pop('color')
    xlabel = keywords.pop('xlabel')
    ylabel = keywords.pop('ylabel')
    y_fit = keywords.pop('y_fit')
    line_color = keywords.pop('line_color', 'black')

    # the remaining keywords are forwarded in both branches so that caller options are always honored
    points = scatter(adata, x=x, y=y, color=color, xlabel=xlabel, ylabel=ylabel, **keywords)
    if y_fit is not None and y_fit in adata.var:
        return points * line(adata, x=x, y=y_fit, line_color=line_color)
    return points
def volcano(adata: AnnData, basis: str = 'de_res', x: str = 'log_fold_change', y: str = 't_qval',
            x_cutoff: float = 1, y_cutoff: float = 0.05, cluster_ids: Union[List, Tuple, Set] = None,
            **kwds) -> hv.core.element.Element:
    """
    Generate a volcano plot.

    Args:
        adata: Annotated data matrix.
        basis: String in adata.varm containing statistics to plot.
        x: Field in basis to plot on x-axis. Field is assumed to end with :cluster_id (e.g. log_fold_change:1).
        y: Field in basis to plot on y-axis. Field is assumed to end with :cluster_id (e.g. t_qval:1).
        x_cutoff: Highlight items >= x_cutoff or <= -x_cutoff.
        y_cutoff: Highlight items <= y_cutoff.
        cluster_ids: Optional list of cluster ids to include. If unspecified, plots are shown for all clusters.
            Ids are compared as strings, so ``[1, 2]`` and ``['1', '2']`` are equivalent.
        **kwds: Plot options that override the defaults defined below.

    Returns:
        A one-column layout with one scatter plot per cluster that has both an x and a y field. The data frame
        used for plotting is attached to the layout as ``df``.
    """
    de_results = adata.varm[basis]
    names = de_results.dtype.names  # stat:cluster e.g. 'mwu_pval:13'
    keywords = dict(fontsize=dict(title=9), nonselection_line_color=None, line_color='black',
                    selection_line_color='black', line_width=0.3, nonselection_alpha=0.05, padding=0.02,
                    xaxis=True, yaxis=True, alpha=0.9, tools=['box_select'], hover_cols=['id'],
                    cmap={'Up': '#e41a1c', 'Down': '#377eb8', 'Not significant': '#bdbdbd'})
    keywords.update(kwds)

    # Cluster ids parsed from field names are strings; normalise the requested ids to match.
    requested_ids = None if cluster_ids is None else {str(c) for c in cluster_ids}

    # cluster id -> [x field name, y field name]
    cluster_to_xy = {}
    for name in names:
        # Match the statistic exactly so that e.g. x='percentage' does not pick up 'percentage_other:1'.
        stat, sep, cluster_id = name.rpartition(':')
        if not sep:
            continue  # not of the form stat:cluster
        if stat == x:
            xy_index = 0
        elif stat == y:
            xy_index = 1
        else:
            continue
        if requested_ids is not None and cluster_id not in requested_ids:
            continue
        cluster_to_xy.setdefault(cluster_id, [None, None])[xy_index] = name

    # First pass: build the complete frame; every plot below is created from this same frame.
    df = pd.DataFrame(dict(id=adata.var.index.values))
    filtered_cluster_ids = []
    for cluster_id, (x_name, y_name) in cluster_to_xy.items():
        if x_name is None or y_name is None:
            continue  # need both statistics to draw a cluster
        filtered_cluster_ids.append(cluster_id)
        x_column = '{}_{}'.format(x, cluster_id)
        y_column = '{}_{}'.format(y, cluster_id)
        y_log_column = '{}_{}_log'.format(y, cluster_id)
        status_column = '{}_status'.format(cluster_id)
        df[x_column] = de_results[x_name]
        df[y_column] = de_results[y_name]
        significant = df[y_column] <= y_cutoff
        df[status_column] = 'Not significant'
        df.loc[significant & (df[x_column] >= x_cutoff), status_column] = 'Up'
        # Inclusive bound, symmetric with the 'Up' test.
        df.loc[significant & (df[x_column] <= -x_cutoff), status_column] = 'Down'
        # Small offset avoids log10(0).
        df[y_log_column] = -np.log10(df[y_column] + 1e-12)

    # Second pass: one scatter plot per cluster.
    plots = []
    for cluster_id in filtered_cluster_ids:
        x_column = '{}_{}'.format(x, cluster_id)
        y_column = '{}_{}'.format(y, cluster_id)
        y_log_column = '{}_{}_log'.format(y, cluster_id)
        status_column = '{}_status'.format(cluster_id)
        __create_hover_tool(df, keywords, exclude=[], whitelist=['id', x_column, y_column])
        p = df.hvplot.scatter(x=x_column, y=y_log_column, title=str(cluster_id), c=status_column,
                              xlabel=str(x), ylabel='-log10 ' + str(y), **keywords)
        plots.append(p)

    # shared_datasource for linked brushing colors points incorrectly, so link every pair explicitly.
    for i, later in enumerate(plots):
        for earlier in plots[:i]:
            __BrushLink(later, earlier)
            __BrushLink(earlier, later)

    result = hv.Layout(plots).cols(1)
    result.df = df
    return result
def composition_plot(adata: AnnData, by: str, condition: str, stacked: bool = True, normalize: bool = True,
                     condition_sort_by: str = None, cmap: Union[str, List[str], Tuple[str]] = None,
                     **kwds) -> hv.core.element.Element:
    """
    Generate a composition plot, which shows the percentage of observations from every condition within each
    cluster (by).

    Args:
        adata: Annotated data matrix.
        by: Key for accessing variables of adata.var_names or a field of adata.obs used to group the data.
        condition: Key for accessing variables of adata.var_names or a field of adata.obs used to compute counts
            within a group.
        stacked: Whether bars are stacked.
        normalize: Normalize counts within each group to sum to one.
        condition_sort_by: Sort condition within each group by max, mean, natsorted, or None.
        cmap: Color map name (hv.plotting.list_cmaps()) or a list of hex colors.
            See http://holoviews.org/user_guide/Styling_Plots.html for more information.
        **kwds: Additional options passed to the bar plot; they override ``stacked`` and ``group_label``.

    Returns:
        The bar plot. The table of (optionally normalized) counts is attached to it as ``df``.
    """
    adata_raw = __get_raw(adata, False)
    keys = [by, condition]
    adata_df = __get_df(adata, adata_raw, keys)
    # Treat every key as categorical; non-categorical values are stringified first.
    for column in adata_df:
        if not isinstance(adata_df[column].dtype, CategoricalDtype):
            adata_df[column] = adata_df[column].astype(str).astype('category')

    if cmap is None:
        cmap = __get_category_cmap(adata_raw, adata_df, condition)
    else:
        cmap = __fix_cmap(adata_df, condition, cmap)

    keywords = dict(stacked=stacked, group_label=condition)
    keywords.update(kwds)
    invert = keywords.get('invert', False)
    if not invert and 'rot' not in keywords:
        keywords['rot'] = 90

    # One indicator column per condition level; summing them per group gives the counts.
    # Only the indicator columns are aggregated: summing the categorical key columns themselves
    # is not a valid operation and fails on current pandas.
    dummy_df = pd.get_dummies(adata_df[condition])
    df = dummy_df.groupby(adata_df[by], observed=False).sum()
    if normalize:
        # Divide each row by its total.
        df = df.div(df.sum(axis=1), axis=0)

    # Keep an explicit category order if there is one, otherwise sort groups naturally.
    index_dtype = df.index.dtype
    if not (isinstance(index_dtype, CategoricalDtype) and index_dtype.ordered):
        df = df.loc[natsorted(df.index)]

    secondary = np.asarray(dummy_df.columns)
    if condition_sort_by == 'max' or condition_sort_by == 'mean':
        if condition_sort_by == 'max':
            secondary_sort = df.values.max(axis=0)
        else:
            secondary_sort = df.values.mean(axis=0)
        # Largest first.
        index = np.flip(np.argsort(secondary_sort))
        secondary = secondary[index]
    elif condition_sort_by == 'natsorted':
        secondary = natsorted(secondary)
    secondary = list(secondary)

    p = df.hvplot.bar(by, secondary, cmap=cmap, **keywords)
    p.df = df
    return p