Source code for pegasus.tools.hvf_selection

import time
import numpy as np
import pandas as pd
from pandas.api.types import is_categorical_dtype

from scipy.sparse import issparse, csr_matrix
from collections import defaultdict
from joblib import Parallel, delayed
import skmisc.loess as sl
from typing import List
from pegasusio import MultimodalData

import logging
logger = logging.getLogger(__name__)

from pegasusio import timer


def estimate_feature_statistics(data: MultimodalData, consider_batch: bool) -> None:
    """ Estimate feature (gene) statistics per channel, such as mean, var etc.
    """
    assert issparse(data.X)

    if consider_batch:
        start = time.perf_counter()
        # Test for 'Channel' and 'Channels' here as well as in highly_variable_features,
        # to cover the case where batch correction is performed without feature selection.
        if "Channel" not in data.obs:
            data.obs["Channel"] = ""

        if "Channels" not in data.uns:
            data.uns["Channels"] = data.obs["Channel"].cat.categories.values if is_categorical_dtype(data.obs["Channel"]) else data.obs["Channel"].unique()

        if data.uns["Channels"].size == 1:
            return None

        if "Group" not in data.obs:
            data.obs["Group"] = "one_group"

        if "Groups" not in data.uns:
            data.uns["Groups"] = data.obs["Group"].cat.categories.values if is_categorical_dtype(data.obs["Group"]) else data.obs["Group"].unique()

        channels = data.uns["Channels"]
        groups = data.uns["Groups"]

        ncells = np.zeros(channels.size)
        means = np.zeros((data.shape[1], channels.size))
        partial_sum = np.zeros((data.shape[1], channels.size))

        group_dict = defaultdict(list)
        for i, channel in enumerate(channels):
            idx = np.isin(data.obs["Channel"], channel)
            mat = data.X[idx].astype(np.float64)
            ncells[i] = mat.shape[0]

            if ncells[i] == 0:
                continue

            if ncells[i] == 1:
                means[:, i] = mat.toarray()[0]
            else:
                means[:, i] = mat.mean(axis=0).A1
                m2 = mat.power(2).sum(axis=0).A1
                partial_sum[:, i] = m2 - ncells[i] * (means[:, i] ** 2)

            # Each channel belongs to exactly one group; read it from the channel's first cell.
            group = data.obs["Group"].iloc[idx.nonzero()[0][0]]
            group_dict[group].append(i)

        # The exact quantity m2 - n * mean^2 is non-negative; values below 1e-6 are
        # floating-point noise, so zero them to keep the sqrt below valid.
        partial_sum[partial_sum < 1e-6] = 0.0

        overall_means = np.dot(means, ncells) / data.shape[0]
        batch_adjusted_vars = np.zeros(data.shape[1])

        # Per-group statistics; c2gid maps each channel index to its group index.
        c2gid = np.zeros(channels.size, dtype=int)
        gncells = np.zeros(groups.size)
        gmeans = np.zeros((data.shape[1], groups.size))
        gstds = np.zeros((data.shape[1], groups.size))

        for i, group in enumerate(groups):
            gchannels = group_dict[group]
            c2gid[gchannels] = i
            gncells[i] = ncells[gchannels].sum()
            gmeans[:, i] = np.dot(means[:, gchannels], ncells[gchannels]) / gncells[i]
            gstds[:, i] = (
                partial_sum[:, gchannels].sum(axis=1) / gncells[i]
            ) ** 0.5  # calculate std
            if groups.size > 1:
                batch_adjusted_vars += gncells[i] * (
                    (gmeans[:, i] - overall_means) ** 2
                )

        data.varm["means"] = means
        data.varm["partial_sum"] = partial_sum
        data.uns["ncells"] = ncells

        data.varm["gmeans"] = gmeans
        data.varm["gstds"] = gstds
        data.uns["gncells"] = gncells
        data.uns["c2gid"] = c2gid

        data.var["mean"] = overall_means
        data.var["var"] = (batch_adjusted_vars + partial_sum.sum(axis=1)) / (
            data.shape[0] - 1.0
        )
        end = time.perf_counter()
        logger.info(
            "Estimation on feature statistics per channel is finished. Time spent = {:.2f}s.".format(
                end - start
            )
        )
    else:
        mean = data.X.mean(axis=0).A1
        m2 = data.X.power(2).sum(axis=0).A1
        var = (m2 - data.X.shape[0] * (mean ** 2)) / (data.X.shape[0] - 1)

        data.var["mean"] = mean
        data.var["var"] = var


def fit_loess(x: List[float], y: List[float], span: float, degree: int) -> object:
    try:
        lobj = sl.loess(x, y, span=span, degree=degree)
        lobj.fit()
        return lobj
    except ValueError:
        return None
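

# Illustrative sketch (not part of the upstream pegasus module): fit_loess returns
# None when skmisc raises a ValueError (e.g. a span too small for the data), which
# is what allows select_hvf_pegasus below to widen the span and retry.
def _demo_fit_loess() -> None:
    rng = np.random.default_rng(0)
    x = np.linspace(0.0, 1.0, 100)
    y = np.sin(2.0 * np.pi * x) + rng.normal(scale=0.1, size=x.size)
    lobj = fit_loess(x, y, span=0.3, degree=2)
    if lobj is not None:
        smoothed = lobj.outputs.fitted_values  # fitted trend at the input x values
        assert smoothed.size == x.size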


def select_hvf_pegasus(
    data: MultimodalData, consider_batch: bool, n_top: int = 2000, span: float = 0.02
) -> None:
    """ Select highly variable features using the pegasus method
    """
    if "robust" not in data.var:
        raise ValueError("Please run `qc_metrics` to identify robust genes")

    estimate_feature_statistics(data, consider_batch)

    robust_idx = data.var["robust"].values
    hvf_index = np.zeros(robust_idx.sum(), dtype=bool)

    mean = data.var.loc[robust_idx, "mean"]
    var = data.var.loc[robust_idx, "var"]

    span_value = span
    while True:
        lobj = fit_loess(mean, var, span=span_value, degree=2)
        if lobj is not None:
            break
        span_value += 0.01
    if span_value > span:
        logger.warning("Leoss span is adjusted from {:.2f} to {:.2f} to avoid fitting errors.".format(span, span_value))

    rank1 = np.zeros(hvf_index.size, dtype=int)
    rank2 = np.zeros(hvf_index.size, dtype=int)

    delta = var - lobj.outputs.fitted_values
    fc = var / lobj.outputs.fitted_values

    rank1[np.argsort(delta)[::-1]] = range(hvf_index.size)
    rank2[np.argsort(fc)[::-1]] = range(hvf_index.size)
    hvf_rank = rank1 + rank2

    hvf_index[np.argsort(hvf_rank)[:n_top]] = True

    data.var["hvf_loess"] = 0.0
    data.var.loc[robust_idx, "hvf_loess"] = lobj.outputs.fitted_values

    data.var["hvf_rank"] = -1
    data.var.loc[robust_idx, "hvf_rank"] = hvf_rank
    data.var["highly_variable_features"] = False
    data.var.loc[robust_idx, "highly_variable_features"] = hvf_index
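

# Illustrative sketch (not part of the upstream pegasus module): the double-argsort
# assignment above gives each gene its position in the descending order of a score,
# so the highest-scoring gene receives rank 0 and the ranks from delta and fc can
# be summed directly.
def _demo_descending_rank() -> None:
    delta = np.array([0.5, 2.0, 1.0])
    rank = np.zeros(delta.size, dtype=int)
    rank[np.argsort(delta)[::-1]] = range(delta.size)
    assert list(rank) == [2, 0, 1]  # the largest value gets rank 0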


def select_hvf_seurat_single(
    X: "csr_matrix",
    n_top: int,
    min_disp: float,
    max_disp: float,
    min_mean: float,
    max_mean: float,
) -> List[int]:
    """ HVF selection for one channel using Seurat method
    """
    X = X.copy().expm1()
    mean = X.mean(axis=0).A1
    m2 = X.power(2).sum(axis=0).A1
    var = (m2 - X.shape[0] * (mean ** 2)) / (X.shape[0] - 1)

    dispersion = np.full(X.shape[1], np.nan)
    idx_valid = (mean > 0.0) & (var > 0.0)
    dispersion[idx_valid] = var[idx_valid] / mean[idx_valid]

    mean = np.log1p(mean)
    dispersion = np.log(dispersion)

    df = pd.DataFrame({"log_dispersion": dispersion, "bin": pd.cut(mean, bins=20)})
    log_disp_groups = df.groupby("bin")["log_dispersion"]
    log_disp_mean = log_disp_groups.mean()
    log_disp_std = log_disp_groups.std(ddof=1)
    log_disp_zscore = (
        df["log_dispersion"].values - log_disp_mean.loc[df["bin"]].values
    ) / log_disp_std.loc[df["bin"]].values
    # Bins containing a single gene have NaN std (ddof=1), and genes with zero mean
    # or variance have NaN dispersion; treat all such z-scores as 0.
    log_disp_zscore[np.isnan(log_disp_zscore)] = 0.0

    hvf_rank = np.full(X.shape[1], -1, dtype=int)
    ords = np.argsort(log_disp_zscore)[::-1]

    if n_top is None:
        hvf_rank[ords] = range(X.shape[1])
        idx = np.logical_and.reduce(
            (
                mean > min_mean,
                mean < max_mean,
                log_disp_zscore > min_disp,
                log_disp_zscore < max_disp,
            )
        )
        hvf_rank[~idx] = -1
    else:
        hvf_rank[ords[:n_top]] = range(n_top)

    return hvf_rank
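

# Illustrative sketch (not part of the upstream pegasus module): the Seurat flavor
# z-scores each gene's log dispersion against genes in the same mean-expression bin,
# so variability is judged relative to similarly abundant genes. A minimal version
# of that binned z-score on synthetic vectors:
def _demo_binned_zscore() -> None:
    rng = np.random.default_rng(0)
    mean = rng.uniform(0.0, 5.0, size=200)
    log_disp = rng.normal(size=200)
    df = pd.DataFrame({"ld": log_disp, "bin": pd.cut(mean, bins=20)})
    grp = df.groupby("bin")["ld"]
    z = (df["ld"].values - grp.mean().loc[df["bin"]].values) / grp.std(ddof=1).loc[df["bin"]].values
    assert z.size == mean.size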


def select_hvf_seurat_multi(
    X: "csr_matrix",
    channels: List[str],
    cell2channel: List[str],
    n_top: int,
    n_jobs: int,
    min_disp: float,
    max_disp: float,
    min_mean: float,
    max_mean: float,
) -> List[int]:
    Xs = []
    for channel in channels:
        Xs.append(X[np.isin(cell2channel, channel)])

    from joblib import effective_n_jobs

    n_jobs = effective_n_jobs(n_jobs)

    res_arr = np.array(
        Parallel(n_jobs=n_jobs)(
            delayed(select_hvf_seurat_single)(
                Xs[i], n_top, min_disp, max_disp, min_mean, max_mean
            )
            for i in range(channels.size)
        )
    )
    selected = res_arr >= 0
    shared = selected.sum(axis=0)
    cands = (shared > 0).nonzero()[0]
    import numpy.ma as ma

    median_rank = ma.median(ma.masked_array(res_arr, mask=~selected), axis=0).data
    cands = sorted(cands, key=lambda x: median_rank[x])
    cands = sorted(cands, key=lambda x: shared[x], reverse=True)

    hvf_rank = np.full(X.shape[1], -1, dtype=int)
    n_select = min(n_top, len(cands))  # there may be fewer candidates than n_top
    hvf_rank[cands[:n_select]] = range(n_select)

    return hvf_rank
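

# Illustrative sketch (not part of the upstream pegasus module): per-channel ranks
# are combined by masking the -1 "not selected" entries and taking the median rank
# over the channels that did select each gene; genes selected by more channels are
# then preferred via the stable sort on `shared`.
def _demo_masked_median_rank() -> None:
    import numpy.ma as ma

    res_arr = np.array([[0, -1, 1],
                        [1, 0, -1]])  # 2 channels x 3 genes
    selected = res_arr >= 0
    median_rank = ma.median(ma.masked_array(res_arr, mask=~selected), axis=0).data
    assert np.allclose(median_rank, [0.5, 0.0, 1.0])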


def select_hvf_seurat(
    data: MultimodalData,
    consider_batch: bool,
    n_top: int,
    min_disp: float,
    max_disp: float,
    min_mean: float,
    max_mean: float,
    n_jobs: int,
) -> None:
    """ Select highly variable features using Seurat method.
    """

    robust_idx = data.var["robust"].values
    X = data.X[:, robust_idx]

    hvf_rank = (
        select_hvf_seurat_multi(
            X,
            data.uns["Channels"],
            data.obs["Channel"],
            n_top,
            n_jobs=n_jobs,
            min_disp=min_disp,
            max_disp=max_disp,
            min_mean=min_mean,
            max_mean=max_mean,
        )
        if consider_batch
        else select_hvf_seurat_single(
            X,
            n_top=n_top,
            min_disp=min_disp,
            max_disp=max_disp,
            min_mean=min_mean,
            max_mean=max_mean,
        )
    )

    hvf_index = hvf_rank >= 0

    data.var["hvf_rank"] = -1
    data.var.loc[robust_idx, "hvf_rank"] = hvf_rank
    data.var["highly_variable_features"] = False
    data.var.loc[robust_idx, "highly_variable_features"] = hvf_index


@timer(logger=logger)
def highly_variable_features(
    data: MultimodalData,
    consider_batch: bool,
    flavor: str = "pegasus",
    n_top: int = 2000,
    span: float = 0.02,
    min_disp: float = 0.5,
    max_disp: float = np.inf,
    min_mean: float = 0.0125,
    max_mean: float = 7,
    n_jobs: int = -1,
) -> None:
    """ Highly variable features (HVF) selection. The input data should be logarithmized.

    Parameters
    ----------
    data: ``pegasusio.MultimodalData``
        Annotated data matrix with rows for cells and columns for genes.

    consider_batch: ``bool``
        Whether to consider batch effects or not.

    flavor: ``str``, optional, default: ``"pegasus"``
        The HVF selection method to use. Available choices are ``"pegasus"`` or ``"Seurat"``.

    n_top: ``int``, optional, default: ``2000``
        Number of genes to be selected as HVF. If ``None``, no gene will be selected.

    span: ``float``, optional, default: ``0.02``
        Only applicable when ``flavor`` is ``"pegasus"``. The smoothing factor used by the *scikit-misc* loess model in the pegasus HVF selection method.

    min_disp: ``float``, optional, default: ``0.5``
        Minimum normalized dispersion.

    max_disp: ``float``, optional, default: ``np.inf``
        Maximum normalized dispersion. Set it to ``np.inf`` for infinity bound.

    min_mean: ``float``, optional, default: ``0.0125``
        Minimum mean.

    max_mean: ``float``, optional, default: ``7``
        Maximum mean.

    n_jobs: ``int``, optional, default: ``-1``
        Number of threads to be used during calculation. If ``-1``, all available threads will be used.

    Returns
    -------
    ``None``

    Update ``data.var``:

        * ``highly_variable_features``: replaced with a Boolean array indicating the selected highly variable features.

    Examples
    --------
    >>> pg.highly_variable_features(data, consider_batch = False)
    """
    if "Channels" not in data.uns:
        if "Channel" not in data.obs:
            data.obs["Channel"] = ""
        data.uns["Channels"] = data.obs["Channel"].cat.categories.values if is_categorical_dtype(data.obs["Channel"]) else data.obs["Channel"].unique()

    if data.uns["Channels"].size == 1 and consider_batch:
        consider_batch = False
        logger.warning(
            "Data only contains one channel; no need to consider batch when selecting highly variable features."
        )

    if flavor == "pegasus":
        select_hvf_pegasus(data, consider_batch, n_top=n_top, span=span)
    else:
        assert flavor == "Seurat"
        select_hvf_seurat(
            data,
            consider_batch,
            n_top=n_top,
            min_disp=min_disp,
            max_disp=max_disp,
            min_mean=min_mean,
            max_mean=max_mean,
            n_jobs=n_jobs,
        )

    logger.info(
        "{} highly variable features have been selected.".format(
            data.var["highly_variable_features"].sum()
        )
    )
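

# Illustrative sketch (not part of the upstream pegasus module): typical use of the
# selection result, assuming `data` has already been normalized and log-transformed
# and robust genes have been flagged (e.g. via pg.qc_metrics).
def _demo_usage(data: MultimodalData) -> None:
    highly_variable_features(data, consider_batch=False, flavor="pegasus", n_top=2000)
    hvf_names = data.var.index[data.var["highly_variable_features"]]
    assert len(hvf_names) <= 2000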