import pickle
import scanpy
import numpy as np
from scipy.cluster.hierarchy import to_tree
from utils import build_parent_path_mat, newick_to_adjacency_matrix, get_newick


class Simulator:
    """
    Load and preprocess single-cell RNA-seq data on construction, then simulate a
    phenotype from it on each call to :meth:`simulate`.

    After construction ``self.data_dict`` holds:

    * ``'X'`` -- expression of the x genes (cells x genes, columns in AnnData ``var`` order);
    * ``'Z'`` -- expression of the z genes (cells x ``len(z_genes)``, columns in ``var`` order);
    * ``'id_list'`` -- the cell (obs) names;
    * ``'path_dict'`` -- cell row index -> position of its cell type in ``'pp_ordered_nodes'``;
    * ``'pp'`` -- the matrix from ``utils.build_parent_path_mat``; :meth:`simulate` treats its
      rows as tree edges and its columns as the nodes of ``'pp_ordered_nodes'``;
    * ``'cell_types'`` -- the cleaned 'Granular cell type' label of each cell;
    * ``'pp_ordered_nodes'`` -- tree node names.
    """

    # Characters deleted from 'Granular cell type' labels (single translate pass).
    _LABEL_DELETE = str.maketrans('', '', ' (,)/')

    def __init__(self, z_genes, tf_file='data/tf_dict_38183', sc_file='data/processed_gtex.h5ad', regress_out=False):
        """
        Read the gene lists and single-cell data and build the cell-type tree.

        :param z_genes: genes whose expression forms ``Z``; each must be a key of the
            dict in ``tf_file`` and appear in the ``gene_name`` column of ``sc_file``'s ``var``.
        :param tf_file: pickle whose first element maps every z gene to an iterable of
            genes; their union (minus ``z_genes``) forms ``X``.
        :param sc_file: AnnData file read with ``scanpy.read``.
        :param regress_out: if True, regress 'Broad cell type' out of the expression.
        :raises AssertionError: if a z gene is missing from ``sc_file`` or ``Z`` does not
            end up with ``len(z_genes)`` columns.
        :raises ValueError: if a cell's label is not a node of the built tree.
        """
        self.tf_file = tf_file
        x_genes = self._read_x_genes(tf_file, z_genes)

        self.x_genes = x_genes
        self.z_genes = z_genes
        self.gene_list = list(x_genes) + list(z_genes)
        print(self.gene_list)

        ad = self._load_cells(sc_file, z_genes, self.gene_list, regress_out)

        # Strip ' ', '(', ',', ')' and '/' from the labels -- presumably so they stay
        # valid leaf names in the Newick string built by create_tree.
        ad.obs['Granular cell type'] = ad.obs['Granular cell type'].apply(
            lambda label: label.translate(self._LABEL_DELETE))

        # hierarchical clustering of the cell types gives the tree topology
        pp_ordered_nodes, parent_child_mat = self.create_tree(ad, groupby='Granular cell type')

        types = list(ad.obs['Granular cell type'])
        print(len(set(types)))
        print('{} samples'.format(len(types)))
        samples = list(ad.obs_names)

        X = ad[:, ad.var.gene_name.isin(x_genes)].to_df().to_numpy()

        Z = ad[:, ad.var.gene_name.isin(z_genes)].to_df().to_numpy()
        self.dim_z = Z.shape[1]
        assert self.dim_z == len(z_genes), 'incorrect Z dim'

        # tree -> matrix that simulate() uses to combine edge weights per node
        parent_path = build_parent_path_mat(parent_child_mat)

        pp_ordered_nodes = [node.name for node in pp_ordered_nodes]
        print(len(set(pp_ordered_nodes)))
        path_dict = self._map_to_nodes(types, pp_ordered_nodes)

        self.data_dict = {
            'X': X,
            'Z': Z,
            'id_list': samples,
            'path_dict': path_dict,
            'pp': parent_path,
            'cell_types': types,
            'pp_ordered_nodes': pp_ordered_nodes
        }

    @staticmethod
    def _read_x_genes(tf_file, z_genes):
        """Return every gene listed under any z gene in ``tf_file``, excluding the z genes."""
        with open(tf_file, 'rb') as f:
            # NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
            tf_dict = pickle.load(f)[0]
        x_genes = {x for z in z_genes for x in tf_dict[z]}
        x_genes.difference_update(z_genes)
        return x_genes

    @staticmethod
    def _load_cells(sc_file, z_genes, gene_list, regress_out):
        """
        Read ``sc_file`` and return its epithelial cells restricted to ``gene_list``.

        Only cells kept by ``scanpy.pp.filter_cells(min_genes=len(z_genes))`` on the
        z-gene columns survive (presumably: cells in which every z gene is detected).
        The result is an actual AnnData (not a view), so the in-place edits that follow
        do not rely on AnnData's implicit view-to-copy conversion.
        """
        ad = scanpy.read(sc_file)
        # build the lookup set once rather than once per z gene
        # NOTE(review): asserts are stripped under ``python -O``.
        available = set(ad.var.gene_name)
        for z in z_genes:
            assert z in available, '{} is missing from single cell file'.format(z)

        print('filtering to epithelial cells')
        ad = ad[ad.obs['Cell types level 3'].isin(['Epithelial'])]

        # subset to genes of interest, then keep only cells passing the z-gene filter
        ad = ad[:, ad.var.gene_name.isin(gene_list)]
        sub_ad = ad[:, ad.var.gene_name.isin(z_genes)].copy()
        scanpy.pp.filter_cells(sub_ad, min_genes=len(z_genes))
        ad = ad[sub_ad.obs_names, :].copy()

        if regress_out:
            scanpy.pp.regress_out(ad, keys='Broad cell type')
        return ad

    @staticmethod
    def _map_to_nodes(types, node_names):
        """
        Map each sample index to the first position of its label in ``node_names``.

        Same result as calling ``node_names.index(label)`` per sample (including a
        ``ValueError`` for unknown labels) but with a single pass over the nodes.
        """
        first_position = {}
        for position, name in enumerate(node_names):
            first_position.setdefault(name, position)
        path_dict = {}
        for sample_idx, label in enumerate(types):
            try:
                path_dict[sample_idx] = first_position[label]
            except KeyError:
                raise ValueError('{!r} is not a node of the cell-type tree'.format(label)) from None
        return path_dict

    def create_tree(self, ad, groupby='Broad cell type', name='dendrogram'):
        """
        Cluster the ``groupby`` categories of ``ad`` and convert the tree for the model.

        Side effects: stores scanpy's dendrogram in ``ad.uns`` and saves a dendrogram
        figure via ``scanpy.pl.dendrogram(..., save=True)``.

        :param ad: AnnData whose ``obs[groupby]`` column holds the labels.
        :param groupby: ``obs`` column to cluster.
        :param name: unused; kept for backward compatibility.
        :return: the result of ``utils.newick_to_adjacency_matrix`` on the Newick string
            (``__init__`` unpacks it as ``(ordered_nodes, parent_child_mat)``).
        """
        scanpy.tl.dendrogram(ad, groupby=groupby)
        scanpy.pl.dendrogram(ad, groupby, save=True)
        dendro = scanpy.tl.dendrogram(ad, groupby=groupby, inplace=False)  # could extract from previous instead
        # Leaf i of the linkage is category i in the original category order, while
        # 'categories_ordered' lists the leaves in plotting (left-to-right) order, as
        # given by 'categories_idx_ordered'. Re-index the labels by leaf id so the
        # Newick leaves carry their own category names.
        # NOTE(review): assumes utils.get_newick looks leaf labels up as
        # leaf_names[node.id] (the usual scipy-to-Newick recipe) -- confirm in utils.
        leaf_names = [None] * len(dendro['categories_ordered'])
        for leaf_id, label in zip(dendro['categories_idx_ordered'], dendro['categories_ordered']):
            leaf_names[leaf_id] = label
        tree = to_tree(dendro['linkage'])
        newick = get_newick(tree, "", tree.dist, leaf_names)
        return newick_to_adjacency_matrix(newick)

    def simulate(self, k_sparse=1, seed=0, normalize_y=True, test_edges=None, test_k=1):
        """
        Simulate one phenotype value per cell from ``Z`` and the cell-type tree.

        Node ``j``'s weight vector is ``root + sum_e pp[e, j] * delta[e]``, where ``root``
        is dense Gaussian and each edge row ``delta[e]`` has a few Gaussian non-zero
        entries. A cell's phenotype is its ``Z`` row dotted with the weight vector of
        its node (``path_dict``).

        :param k_sparse: non-zero entries per edge weight vector (0 for none).
        :param seed: if not None, reseeds numpy's *global* RNG (``np.random.seed``).
        :param normalize_y: center ``y`` and scale it to unit standard deviation; when
            the standard deviation is 0 (e.g. a single cell) ``y`` is only centered
            instead of becoming NaN.
        :param test_edges: optional edge (row) indices whose weight vectors are redrawn
            with ``test_k`` non-zero entries.
        :param test_k: non-zero entries for each of ``test_edges``.
        :return: 1-D float array with one value per row of ``Z``.
        """
        if seed is not None:
            np.random.seed(seed)
        parent_path = self.data_dict['pp']
        num_edges = parent_path.shape[0]

        # the dense root weights (drawn first; the draw order below is part of the
        # reproducible output for a given seed)
        root_weights = np.random.normal(size=self.dim_z)

        # sparsity pattern of the per-edge delta matrix
        if k_sparse == 0:
            # all zeros; kept as binomial(p=0) so the RNG stream matches earlier runs
            delta_mat = np.random.binomial(n=1, p=0, size=(num_edges, self.dim_z))
        else:
            delta_mat = np.zeros(dtype=float, shape=(num_edges, self.dim_z))
            for row in range(num_edges):
                delta_mat[row, np.random.choice(self.dim_z, size=k_sparse, replace=False)] = 1.0
        if test_edges is not None:
            for edge in test_edges:
                delta_mat[edge, :] = 0
                delta_mat[edge, np.random.choice(self.dim_z, size=test_k, replace=False)] = 1.0

        # scale the selected entries (and flip some negative)
        delta_mat = delta_mat * np.random.normal(loc=0., scale=1.0, size=delta_mat.shape)

        # per-node weights: root weights plus the pp-weighted edge deltas
        pheno_weights = np.matmul(delta_mat.T, parent_path).T + root_weights

        Z = self.data_dict['Z']
        path_dict = self.data_dict['path_dict']
        node_idx = np.array([path_dict[i] for i in range(Z.shape[0])], dtype=int)
        # row-wise dot product of each cell's Z with its node's weights
        y = np.einsum('ij,ij->i', Z, pheno_weights[node_idx])

        if normalize_y:
            scale = y.std()
            y = y - y.mean()
            if scale > 0:
                y = y / scale

        return y
