import logging

import anndata
import numpy as np
import pandas as pd
import scipy
from numba import jit

logging.basicConfig(level=logging.INFO)
import multiprocessing as mp

from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm

from .kh import *


def delve_fs(
    adata=None,
    k: int = 10,
    num_subsamples: int = 1000,
    n_clusters: int = 5,
    null_iterations: int = 1000,
    random_state: int = 0,
    n_random_state: int = 10,
    n_pcs=None,
    n_jobs: int = -1,
):
    """Performs DELVE feature selection
        - step 1: identifies dynamic seed features to construct a between-cell affinity graph according to dynamic cell state progression
        - step 2: ranks features according to their total variation in signal along the approximate trajectory graph using the Laplacian score

    Parameters
    adata: anndata.AnnData
        annotated data object where adata.X is the attribute for preprocessed data (dimensions = cells x features)
    k: int (default = 10)
        number of nearest neighbors for between cell affinity kNN graph construction
    num_subsamples: int (default = 1000)
        number of neighborhoods to subsample when estimating feature dynamics
    n_clusters: int (default = 5)
        number of feature modules
    null_iterations: int (default = 1000)
        number of iterations for gene-wise permutation testing
    random_state: int (default = 0)
        random seed parameter
    n_random_state: int (default = 10)
        number of kmeans clustering initializations
    n_pcs: int (default = None)
        number of principal components to compute pairwise Euclidean distances for between-cell affinity graph construction. If None, uses adata.X
    n_jobs = int (default = -1)
        number of tasks
    ----------
    Returns
    delta_mean: pd.DataFrame
        dataframe containing average pairwise change in expression of all features across subsampled neighborhoods (dimensions = num_subsamples x features)
    modules: pd.DataFrame
        dataframe containing feature-cluster assignments and permutation p-values (dimensions = features x 2)
    selected_features: pd.DataFrame
        dataframe containing ranked features and Laplacian scores following feature selection (dimensions = features x 1)
    ----------
    """
    if n_jobs == -1:
        n_jobs = mp.cpu_count()
    elif n_jobs < -1:
        n_jobs = mp.cpu_count() + 1 + n_jobs

    X, feature_names, obs_names = parse_input(adata)  # parse anndata

    try:
        logging.info("Step 1: identifying dynamic feature modules")
        sub_idx, _, delta_mean, modules = seed_select(
            X=X,
            feature_names=feature_names,
            obs_names=obs_names,
            k=k,
            num_subsamples=num_subsamples,
            n_clusters=n_clusters,
            null_iterations=null_iterations,
            random_state=random_state,
            n_random_state=n_random_state,
            n_pcs=n_pcs,
            n_jobs=n_jobs,
        )

        logging.info("Step 2: performing feature selection")
        if modules is None:
            logging.warning("No modules found, returning None")
            return None, None, None

        dyn_feats = np.asarray(modules.index[modules["cluster_id"] != "static"])
        selected_features = feature_select(
            X=X[sub_idx, :],
            feature_names=feature_names,
            dyn_feats=dyn_feats,
            k=k,
            n_pcs=n_pcs,
            n_jobs=n_jobs,
        )
        return delta_mean, modules, selected_features

    except (TypeError, AttributeError):  # no dynamic seed features were identified
        return None, None, None


def seed_select(
    X=None,
    feature_names=None,
    obs_names=None,
    k: int = 10,
    num_subsamples: int = 1000,
    n_clusters: int = 5,
    null_iterations: int = 1000,
    random_state: int = 0,
    n_random_state: int = 10,
    n_pcs=None,
    n_jobs: int = -1,
):
    """Identifies dynamic seed clusters
    Parameters
    X: np.ndarray (default = None)
        array containing normalized and preprocessed data (dimensions = cells x features)
    feature_names: np.ndarray (default = None)
        array containing feature names
    obs_names: np.ndarray (default = None)
        array containing cell names
    k: int (default = 10)
        number of nearest neighbors for between cell affinity kNN graph construction
    num_subsamples: int (default = 1000)
        number of neighborhoods to subsample when estimating feature dynamics
    n_clusters: int (default = 5)
        number of feature modules
    null_iterations: int (default = 1000)
        number of iterations for gene-wise permutation testing
    random_state: int (default = 0)
        random seed parameter
    n_random_state: int (default = 10)
        number of kmeans clustering initializations
    n_pcs: int (default = None)
        number of principal components to compute pairwise Euclidean distances for between-cell affinity graph construction. If None, uses adata.X
    n_jobs = int (default = -1)
        number of tasks
    ----------
    Returns
    sub_idx: np.andarray
        array containing indices of subsampled neighborhoods
    adata_sub: anndata.AnnData
        annotated data object containing subsampled means (dimensions = num_subsamples x features)
    delta_mean: pd.DataFrame
        dataframe containing average pairwise change in expression of all features across subsampled neighborhoods (dimensions = num_subsamples x features)
    modules: pd.DataFrame
        dataframe containing feature-cluster assignments and permutation p-values (dimensions = features x 2)
    ----------
    """
    if n_jobs == -1:
        n_jobs = mp.cpu_count()
    else:
        n_jobs == mp.cpu_count() + 1 + n_jobs

    # multiprocessing crashes a lot in this function
    # p = mp.Pool(n_jobs)

    np.random.seed(random_state)
    random_state_arr = np.random.randint(0, 1000000, n_random_state)

    logging.info("estimating feature dynamics")
    sub_idx, adata_sub, delta_mean = delta_exp(
        X=X,
        feature_names=feature_names,
        obs_names=obs_names,
        k=k,
        num_subsamples=num_subsamples,
        random_state=random_state,
        n_pcs=n_pcs,
        n_jobs=n_jobs,
    )

    # identify modules
    mapping_df = pd.DataFrame(index=feature_names)
    pval_df = pd.DataFrame(index=feature_names)
    dyn_feats = []
    random_state_idx = []
    results_collected = []

    logging.info(f"Running {n_random_state} clustering iterations sequentially")
    for state in tqdm(random_state_arr, desc="clustering and permutation testing"):
        result = _run_cluster(
            delta_mean, feature_names, n_clusters, null_iterations, state
        )
        if result is not None:
            mapping_df = pd.concat([mapping_df, result[0]], axis=1)
            pval_df = pd.concat([pval_df, result[1]], axis=1)
            dyn_feats.append(result[2])
            random_state_idx.append(result[3])

    logging.info(f"Completed {len(dyn_feats)} successful clustering runs")
    # try:
    #     for result in tqdm(
    #         p.imap(
    #             partial(
    #                 _run_cluster, delta_mean, feature_names, n_clusters, null_iterations
    #             ),
    #             random_state_arr,
    #         ),
    #         total=n_random_state,
    #         desc="clustering features and performing feature-wise permutation testing",
    #     ):
    #         results_collected.append(result)
    #         if result is not None:
    #             mapping_df = pd.concat([mapping_df, result[0]], axis=1)
    #             pval_df = pd.concat([pval_df, result[1]], axis=1)
    #             dyn_feats.append(result[2])
    #             random_state_idx.append(result[3])
    # finally:
    #     p.close()
    #     p.join()

    if len(dyn_feats) == 0:
        logging.warning(
            "No feature clusters have a dynamic variance greater than null. Consider changing the number of clusters or the subsampling size."
        )
        logging.warning(
            f"All {n_random_state} random state runs failed to find dynamic clusters"
        )
        return None, None, None, None
    else:
        dyn_feats = list(np.unique(list(set.intersection(*map(set, dyn_feats)))))
        if len(dyn_feats) == 0:
            logging.warning(
                "No features were considered dynamically-expressed across runs."
            )
            logging.warning(
                f"Found dynamic features in individual runs but none were consistent across all {n_random_state} runs"
            )
            return None, None, None, None
        else:
            modules = _annotate_clusters(
                mapping_df=mapping_df,
                dyn_feats=dyn_feats,
                pval_df=pval_df,
                random_state_idx=random_state_idx[-1],
            )
            n_dynamic_clusters = len(
                np.unique(modules["cluster_id"][modules["cluster_id"] != "static"])
            )
            logging.info(f"identified {n_dynamic_clusters} dynamic cluster(s)")
            return sub_idx, adata_sub, delta_mean, modules


def feature_select(
    X=None,
    feature_names=None,
    dyn_feats=None,
    k: int = 10,
    n_pcs=None,
    n_jobs: int = -1,
):
    """Ranks features along dynamic seed graph using the Laplacian score: https://papers.nips.cc/paper/2005/file/b5b03f06271f8917685d14cea7c6c50a-Paper.pdf
    Parameters
    X: np.ndarray (default = None)
        array containing normalized and preprocessed data (dimensions = cells x features)
    feature_names: np.ndarray (default = None)
        array containing feature names
    dyn_feats: np.ndarray (default = None)
        array containing features that are dynamically expressed. Can consider replacing this with a set of known regulators.
    k: int (default = 10)
        number of nearest neighbors for between cell affinity kNN graph construction
    n_pcs: int (default = None)
        number of principal components to compute pairwise Euclidean distances for between-cell affinity graph construction. If None, uses adata.X
    n_jobs = int (default = -1)
        number of tasks
    ----------
    Returns
    selected_features: pd.DataFrame
        dataframe containing ranked features and Laplacian scores for feature selection (dimensions = features x 1)
    ----------
    """
    f_idx = np.where(np.isin(feature_names, dyn_feats) == True)[
        0
    ]  # index of feature names to construct seed graph
    W = construct_affinity(
        X=X[:, f_idx], k=k, n_pcs=n_pcs, n_jobs=n_jobs
    )  # constructs graph using dynamic seed features
    scores = laplacian_score(X=X, W=W)
    selected_features = pd.DataFrame(scores, index=feature_names, columns=["DELVE"])
    selected_features = selected_features.sort_values(by="DELVE", ascending=True)

    return selected_features


def delta_exp(
    X=None,
    feature_names=None,
    obs_names=None,
    k: int = 10,
    num_subsamples: int = 1000,
    random_state: int = 0,
    n_pcs=None,
    n_jobs: int = -1,
):
    """Estimates change in expression of features across representative cellular neighborhoods
    Parameters
    X: np.ndarray (default = None)
        array containing normalized and preprocessed data (dimensions = cells x features)
    feature_names: np.ndarray (default = None)
        array containing feature names
    obs_names: np.ndarray (default = None)
        array containing cell names
    k: int (default = 10)
        number of nearest neighbors for between cell affinity kNN graph construction
    num_subsamples: int (default = 1000)
        number of neighborhoods to subsample when estimating feature dynamics
    random_state: int (default = 0)
        random seed parameter
    n_pcs: int (default = None)
        number of principal components for between-cell affinity graph computation. if None, uses adata.X to find pairwise Euclidean distances
    n_jobs = int (default = -1)
        number of tasks
    ----------
    Returns
    sub_idx: np.ndarray
        array containing indices of subsampled neighborhoods
    adata_sub: anndata.AnnData
        annotated data object containing subsampled means (dimensions = num_subsamples x features)
    delta_mean: pd.DataFrame (dimensions = num_subsamples x features)
        array containing average pairwise change in expression of all features across subsampled neighborhoods (dimensions = num_subsamples x features)
    ----------
    """
    # construct between cell affinity kNN graph according to all profiled features
    W = construct_affinity(X=X, k=k, n_pcs=n_pcs, n_jobs=-1)

    # compute neighborhood means
    n_bool = W.astype(bool)
    n_mean = (X.transpose() @ n_bool) / np.asarray(n_bool.sum(1)).reshape(1, -1)
    n_mean = pd.DataFrame(n_mean.transpose(), index=obs_names, columns=feature_names)

    # perform subsampling of means to get representative neighborhoods using kernel herding sketching: https://dl.acm.org/doi/abs/10.1145/3535508.3545539, https://github.com/CompCy-lab/SketchKH
    sub_idx, adata_sub = sketch(
        anndata.AnnData(n_mean),
        num_subsamples=num_subsamples,
        frequency_seed=random_state,
        n_jobs=n_jobs,
    )

    # NOTE: orignal code is cubic in memory, not scalable
    # compute the average pairwise change in the expression across all neighborhoods for all features
    # subsampled_means = np.asarray(adata_sub.X, dtype=np.float32)
    # delta_mean = subsampled_means.reshape(
    #     -1, 1, subsampled_means.shape[1]
    # ) - subsampled_means.reshape(1, -1, subsampled_means.shape[1])
    # delta_mean = delta_mean.sum(axis=1) * (1 / (subsampled_means.shape[0] - 1))

    subsampled_means = np.asarray(adata_sub.X, dtype=np.float32)
    delta_mean = pairwise_neighborhoods_difference(subsampled_means)

    delta_mean = pd.DataFrame(
        delta_mean[np.argsort(adata_sub.obs.index)],
        index=adata_sub.obs.index[np.argsort(adata_sub.obs.index)],
        columns=adata_sub.var_names,
    )  # resort according to subsampled indices

    return sub_idx[0], adata_sub, delta_mean


@jit
def pairwise_neighborhoods_difference(subsampled_means):
    # compute mean pairwise difference one gene at a time
    n_neighborhoods = subsampled_means.shape[0]
    n_genes = subsampled_means.shape[1]
    delta_mean = np.empty((n_neighborhoods, n_genes), dtype=np.float32)

    for i in range(n_genes):
        gene_vec = subsampled_means[:, i]
        delta_mean_one_gene = gene_vec[:, np.newaxis] - gene_vec
        delta_mean[:, i] = delta_mean_one_gene.sum(axis=1) / (n_neighborhoods - 1)

    return delta_mean


def laplacian_score(X=None, W=None):
    """Computes the Laplacian score: https://papers.nips.cc/paper/2005/file/b5b03f06271f8917685d14cea7c6c50a-Paper.pdf
    Parameters
    X: np.ndarray (default = None)
        array containing normalized and preprocessed data (dimensions = cells x features)
    W: np.ndarray (default = None)
        adjacency matrix containing between-cell affinity weights
    ----------
    Returns
    l_score: np.ndarray
        array containing laplacian score for all features (dimensions = features)
    ----------
    """
    n_samples, n_features = X.shape

    # compute degree matrix
    D = np.array(W.sum(axis=1))
    D = scipy.sparse.diags(np.transpose(D), [0])

    # compute graph laplacian
    L = D - W.toarray()

    # ones vector: 1 = [1,···,1]'
    ones = np.ones((n_samples, n_features))

    # feature vector: fr = [fr1,...,frm]'
    fr = X.copy()

    # construct fr_t = fr - (fr' D 1/ 1' D 1) 1
    numerator = np.matmul(np.matmul(np.transpose(fr), D.toarray()), ones)
    denomerator = np.matmul(np.matmul(np.transpose(ones), D.toarray()), ones)
    ratio = numerator / denomerator
    ratio = ratio[:, 0]
    ratio = np.tile(ratio, (n_samples, 1))
    fr_t = fr - ratio

    # compute laplacian score Lr = fr_t' L fr_t / fr_t' D fr_t
    l_score = np.matmul(np.matmul(np.transpose(fr_t), L), fr_t) / np.matmul(
        np.dot(np.transpose(fr_t), D.toarray()), fr_t
    )
    l_score = np.diag(l_score)

    return l_score


def construct_affinity(
    X=None, k: int = 10, radius: int = 3, n_pcs=None, n_jobs: int = -1
):
    """Computes between cell affinity knn graph using heat kernel
    Parameters
    X: np.ndarray (default = None)
        Data (dimensions = cells x features)
    k: int (default = None)
        Number of nearest neighbors
    radius: int (default = 3)
        Neighbor to compute per cell distance for heat kernel bandwidth parameter
    n_pcs: int (default = None)
        number of principal components to compute pairwise Euclidean distances for between-cell affinity graph construction. If None, uses adata.X
    n_jobs: int (default = -1)
        Number of tasks
    ----------
    Returns
    W: np.ndarray
        sparse symmetric matrix containing between cell similarity (dimensions = cells x cells)
    ----------
    """
    if n_pcs is not None:
        n_comp = min(n_pcs, X.shape[1])
        pca_op = PCA(n_components=n_comp, random_state=0)
        X_ = pca_op.fit_transform(X)
    else:
        X_ = X.copy()

    # find kNN
    knn_tree = NearestNeighbors(
        n_neighbors=k, algorithm="ball_tree", metric="euclidean", n_jobs=n_jobs
    ).fit(X_)
    dist, nn = knn_tree.kneighbors()  # dist = cells x knn (no self interactions)

    # transform distances using heat kernel
    s = heat_kernel(dist, radius=radius)  # -||x_i - x_j||^2 / 2*sigma_i**2
    rows = np.repeat(np.arange(X.shape[0]), k)
    cols = nn.reshape(-1)
    W = scipy.sparse.csr_matrix(
        (s.reshape(-1), (rows, cols)), shape=(X.shape[0], X.shape[0])
    )

    # make symmetric
    bigger = W.transpose() > W
    W = W - W.multiply(bigger) + W.transpose().multiply(bigger)

    return W


def heat_kernel(dist=None, radius: int = 3):
    """Transforms distances into weights using heat kernel
    Parameters
    dist: np.ndarray (default = None)
        distance matrix (dimensions = cells x k)
    radius: np.int (default = 3)
        defines the per-cell bandwidth parameter (distance to the radius nn)
    ----------
    Returns
    s: np.ndarray
        array containing between cell similarity (dimensions = cells x k)
    ----------
    """
    sigma = dist[
        :, [radius]
    ]  # per cell bandwidth parameter (distance to the radius nn)
    s = np.exp(-1 * (dist**2) / (2.0 * sigma**2))  # -||x_i - x_j||^2 / 2*sigma_i**2
    return s


def parse_input(adata: anndata.AnnData):
    """Accesses and parses data from adata object
    Parameters
    adata: anndata.AnnData
        annotated data object where adata.X is the attribute for preprocessed data
    ----------
    Returns
    X: np.ndarray
        array of data (dimensions = cells x features)
    feature_names: np.ndarray
        array of feature names
    obs_names: np.ndarray
        array of cell names
    ----------
    """
    try:
        if isinstance(adata, anndata.AnnData):
            X = adata.X.copy()
        if isinstance(X, scipy.sparse.csr_matrix):
            X = np.asarray(X.todense())

        feature_names = np.asarray(adata.var_names)
        obs_names = np.asarray(adata.obs_names)
        return X, feature_names, obs_names
    except NameError:
        return None


def _run_cluster(delta_mean, feature_names, n_clusters, null_iterations, state):
    """Multiprocessing function for identifying feature modules and performing gene-wise permutation testing
    Parameters
    delta_mean: pd.DataFrame
        dataframe containing average pairwise change in expression of all features across subsampled neighborhoods (dimensions = num_subsamples x features)
    feature_names: np.ndarray (default = None)
        array containing feature names
    n_clusters: int (default = 5)
        number of feature modules
    null_iterations: int (default = 1000)
        number of iterations for gene-wise permutation testing
    state: int (default = 0)
        random seed parameter
    ----------
    Returns
    mapping_df: pd.DataFrame
        dataframe containing feature to cluster assignments
    pval_df: pd.DataFrame
        dataframe containing the permutation p-values
    dyn_feats: np.ndarray
        array containing features identified as dynamically-expressed following permutation testing
    state: int
        random seed parameter
    ----------
    """
    # perform clustering
    clusters = KMeans(
        n_clusters=n_clusters, random_state=state, init="k-means++", n_init=10
    ).fit_predict(delta_mean.transpose().values)

    feats = {i: feature_names[np.where(clusters == i)[0]] for i in np.unique(clusters)}

    # record feature-cluster assignment to find intersection across runs
    mapping = np.full((len(feature_names), 1), "NaN")
    for id, feature in feats.items():
        mapping[np.isin(feature_names, feature)] = str(id)
    mapping_df = pd.DataFrame(mapping, index=feature_names, columns=[state])

    # compute variance-based permutation test
    seed_var = np.array(
        [
            np.var(
                delta_mean.iloc[:, np.isin(feature_names, feats[i])], axis=1, ddof=1
            ).mean()
            for i in range(n_clusters)
        ]
    )

    null_var = []
    pval_df = pd.DataFrame(index=feature_names, columns=[state])

    for f in range(0, len(feats)):
        null_var_ = np.array(
            [
                np.var(
                    delta_mean.iloc[
                        :,
                        np.isin(
                            feature_names,
                            np.random.choice(
                                feature_names, len(feats[f]), replace=False
                            ),
                        ),
                    ],
                    axis=1,
                    ddof=1,
                ).mean()
                for i in range(null_iterations)
            ]
        )
        permutation_pval = 1 - (len(np.where(seed_var[f] > null_var_)[0]) + 1) / (
            null_iterations + 1
        )
        pval_df.loc[feats[f]] = permutation_pval
        null_var.append(np.mean(null_var_))

    dynamic_id = np.where(seed_var > np.array(null_var))[
        0
    ]  # select dynamic clusters over null variance threshold

    if len(dynamic_id) != 0:
        dyn_feats = np.concatenate(
            [
                v
                for k, v in feats.items()
                if k in np.array(list(feats.keys()))[dynamic_id]
            ]
        )
        return mapping_df, pval_df, dyn_feats, state
    else:
        return None


def _annotate_clusters(
    mapping_df=None, dyn_feats=None, pval_df=None, random_state_idx: int = None
):
    """Annotates clusters as dynamic or static according to feature-wise permutation testing within clusters
    Parameters
    mapping_df: pd.DataFrame
        dataframe containing feature-cluster ids from KMeans clustering across random trials (dimensions = features x n_random_state)
    dyn_feats:  np.ndarray
        array containing features considered to be dynamically expressed across runs
    random_state_idx:  int (default = None)
        id of random state column id in mapping DataFrame to obtain cluster ids
    ----------
    Returns
    modules: pd.DataFrame
        dataframe containing annotated feature-cluster assignment and permutation p-values (dimensions = features x 2)
    ----------
    """
    cluster_id = np.unique(mapping_df.values)
    dynamic_id = np.unique(mapping_df.loc[dyn_feats].loc[:, random_state_idx])
    static_id = cluster_id[~np.isin(cluster_id, dynamic_id)]

    cats = {id_: "static" for id_ in static_id}
    cats.update(
        {
            id_: f"dynamic {i}" if len(dynamic_id) > 1 else "dynamic"
            for i, id_ in enumerate(dynamic_id)
        }
    )

    modules = pd.Categorical(
        pd.Series(mapping_df.loc[:, random_state_idx].astype("str")).map(cats)
    )
    modules = pd.DataFrame(modules, index=mapping_df.index, columns=["cluster_id"])
    modules[~np.isin(modules.index, dyn_feats)] = "static"
    modules["cluster_permutation_pval"] = pval_df.median(
        1
    )  # median across all random trials
    return modules
