import io
import numpy as np
from Bio import Phylo


"""
Components:

- A generator for latent cell state vectors Z, with adjustable levels of correlation between components
- An invertible process to convert latent cell state vectors to gene expression values X
- A tree defining the relationships between cell types
- A minimum correlation between cell-type-specific weights
- A process to convert the tree and minimum correlation into a covariance matrix
- A sampler for the cell-specific weights (W) given the covariance matrix
- A mapping between Z and phenotype (Y) using the cell-specific weights
"""

def generate_binary_tree(depth):
    """
    Build a perfectly balanced binary tree with 2**depth leaves.

    :param depth: the number of levels in the binary tree to be generated (must be >= 0)
    :return: tuple ``(newick, leaves)``. ``newick`` is the Newick string for the tree, with
             every branch length set to 1.00. ``leaves`` is the list of leaf names
             ``['leaf_0', ..., 'leaf_{2**depth - 1}']`` in left-to-right order.
    :raises ValueError: if depth is negative
    """
    if depth < 0:
        raise ValueError('depth must be non-negative, got {}'.format(depth))
    num_leaves = 2 ** depth
    leaves = ['leaf_{}'.format(i) for i in range(num_leaves)]
    tree = recursive_pairing(leaves)

    return tree, leaves


def recursive_pairing(node_list):
    """
    Pair up nodes level by level and return the Newick string of the resulting tree.

    Adjacent nodes ``(node_list[0], node_list[1])``, ``(node_list[2], node_list[3])``, ...
    are joined into subtrees with branch length 1.00. The process repeats on the new,
    half-length list until one node remains, and that node is wrapped as the root.

    :param node_list: non-empty sequence of node strings (leaf names or Newick subtrees).
                      Its length must be a power of two so that every level pairs evenly.
    :return: Newick string terminated by ';'
    :raises ValueError: if node_list is empty or some level has an odd number of nodes
    """
    if not node_list:
        # an empty list would otherwise recurse on [] forever
        raise ValueError('node_list must contain at least one node')
    if len(node_list) == 1:
        return '({}:1.00);'.format(node_list[0])
    if len(node_list) % 2 != 0:
        raise ValueError('odd number of nodes')
    paired = ['({}:1.00,{}:1.00)'.format(left, right)
              for left, right in zip(node_list[::2], node_list[1::2])]
    return recursive_pairing(paired)



class DataHandler:
    """Simulate expression (X), latent states (Z) and phenotypes (Y) for a set of cell types.

    Construction samples everything eagerly. It fills ``latents_by_cell`` and
    ``phenotypes_by_cell`` with one array per cell-type index
    ``0 .. len(pp_ordered_names) - 1``.
    """

    def __init__(self, feat_dim, latent_dim, pop_size, latent_correlation, parent_path, cell_tree, pp_ordered_names,
                 pheno_noise=0.1, tanh=False, k_sparse=0, seed=None, mask_dim=0):
        """
        :param feat_dim: vector dimension of gene expression values X
        :param latent_dim: vector dimension of Z
        :param pop_size: number of cells to sample per cell type
        :param latent_correlation: how correlated the components of Z should be. When > 0, the
            expression -> latent map weights are drawn with a sampled latent covariance; otherwise
            they are i.i.d. standard normal.
        :param parent_path: 2-D array whose first axis indexes tree edges. It combines the
            per-edge delta weights into per-cell-type phenotype weights (presumably a 0/1
            edge-on-path indicator of shape (num_edges, num_cell_types); confirm with caller).
        :param cell_tree: Newick string describing the cell-type tree
        :param pp_ordered_names: cell-type names; their count sets the number of cell types
        :param pheno_noise: std of the Gaussian noise added to both the latents and the phenotypes
        :param tanh: if True, apply tanh to the expression -> latent map
        :param k_sparse: number of non-zero delta entries per tree edge (0 means no deltas)
        :param seed: optional seed for NumPy's global random state
        :param mask_dim: number of expression features zeroed out per cell type
        """
        if seed is not None:
            np.random.seed(seed)

        self.feat_dim = feat_dim
        self.latent_dim = latent_dim
        self.ordered_names = pp_ordered_names
        self.num_cell_types = len(pp_ordered_names)
        self.pop_size = pop_size
        self.latent_correlation = latent_correlation
        self.pheno_noise = pheno_noise
        self.parent_path = parent_path
        self.tanh = tanh
        self.k_sparse = k_sparse
        self.latents_by_cell, self.phenotypes_by_cell = list(), list()
        # std of the entries of the sampled latent covariance around latent_correlation
        self.scale = 0.2

        # generating the gene expression data (feature array X)
        self.expression_array = self.gene_expression_sampler(mask_dim=mask_dim)

        # parsing the cell-type tree (stored on the instance; not used further in this class)
        self.tree = Phylo.read(io.StringIO(cell_tree), "newick")

        # generating the weights that map gene expression to latents
        if self.latent_correlation > 0:
            latent_cov = self._latent_covariance(self.latent_dim, self.latent_correlation, self.scale)
            self.latent_map_weights = np.random.multivariate_normal(
                mean=np.zeros(self.latent_dim), cov=latent_cov, size=self.feat_dim)
        else:
            self.latent_map_weights = np.random.normal(size=(feat_dim, latent_dim))

        # generating the varying weights mapping latents to phenotype
        self.phenotype_weights = self.generate_phenotype_weights()

        # converting the latent cell states to phenotypes
        self.calculate_latents_phenotypes()

    @staticmethod
    def _latent_covariance(latent_dim, correlation, scale):
        """Sample a symmetric (latent_dim, latent_dim) matrix with unit diagonal and entries in [0, 1].

        Off-diagonal entries are drawn from N(correlation, scale**2) and clipped to [0, 1].
        The result is NOT guaranteed to be positive semi-definite. In that case
        np.random.multivariate_normal emits a RuntimeWarning by default.
        """
        raw = np.random.normal(loc=correlation, scale=scale, size=(latent_dim, latent_dim))
        # np.clip returns a new array, so its result must be kept (clipping a temporary is a no-op)
        cov = np.clip(raw, 0.0, 1.0)
        # mirror the upper triangle so the matrix is symmetric, as a covariance must be
        cov = np.triu(cov) + np.triu(cov, k=1).T
        np.fill_diagonal(cov, 1.0)
        return cov

    def gene_expression_sampler(self, covariance=None, mask_dim=0):
        """
        Sample zero-mean Gaussian expression vectors for every cell of every cell type.

        :param covariance: optional (feat_dim, feat_dim) covariance of the features; identity if None
        :param mask_dim: number of randomly chosen features zeroed out for all cells of each cell type
        :return: np array of size (cell_types, pop_size, exp_dim)
        """
        mean = np.zeros(self.feat_dim, dtype=float)
        if covariance is None:
            covariance = np.identity(self.feat_dim)
        gene_exp = np.random.multivariate_normal(mean=mean, size=(self.num_cell_types, self.pop_size), cov=covariance)
        if mask_dim > 0:
            for cell_idx in range(self.num_cell_types):
                # a fresh set of masked features for each cell type, shared by all its cells
                mask_idx = np.random.choice(self.feat_dim, size=mask_dim, replace=False)
                gene_exp[cell_idx, :, mask_idx] = 0.0
        return gene_exp

    def generate_phenotype_weights(self):
        """
        Sample per-cell-type weights mapping a latent vector to a scalar phenotype.

        Every cell type shares dense root weights. On top of them it receives the
        parent_path-weighted combination of per-edge delta vectors, each of which has
        k_sparse non-zero entries drawn from N(0, 0.25**2).

        :return: array of shape (parent_path.shape[1], latent_dim)
        """
        # the dense root weights
        root_weights = np.random.normal(size=self.latent_dim)
        # generating the sparsity pattern of the delta matrix
        num_edges = self.parent_path.shape[0]
        if self.k_sparse == 0:
            # all-zero pattern (p=0). Kept as a binomial call rather than np.zeros in case
            # replacing it perturbs seeded random streams on some NumPy versions.
            delta_mat = np.random.binomial(n=1, p=0, size=(num_edges, self.latent_dim))
        else:
            delta_mat = np.zeros(dtype=float, shape=(num_edges, self.latent_dim))
            # picking the k entries in each row to be non-zero
            for row in range(num_edges):
                idx = np.random.choice(self.latent_dim, size=self.k_sparse, replace=False)
                delta_mat[row, idx] = 1.0

        # scaling the values (and flipping some negative)
        delta_values = np.random.normal(loc=0., scale=0.25, size=delta_mat.shape)
        delta_mat = delta_mat * delta_values
        # cell-type-specific weights = root weights + combination of the relevant edge deltas
        pheno_weights = np.matmul(delta_mat.T, self.parent_path).T + root_weights
        return pheno_weights

    def latent_map(self, gene_exp, tanh=False):
        """
        Linearly map expression vectors to latent vectors, optionally squashed by tanh.

        :param gene_exp: array of size (num_cells, gene_exp)
        :param tanh: if True, apply np.tanh to the mapped values
        :return: array of size (num_cells, latent_vector)
        """
        latents = np.matmul(gene_exp, self.latent_map_weights)
        return np.tanh(latents) if tanh else latents

    def calculate_latents_phenotypes(self):
        """
        Compute noisy latents and noisy phenotypes for every cell type.

        Appends one (pop_size, latent_dim) latent array and one (pop_size,) phenotype array
        per cell type to latents_by_cell / phenotypes_by_cell. Both get Gaussian noise with
        std pheno_noise. Calling this again appends a second set rather than replacing the first.
        """
        for type_idx in range(self.num_cell_types):
            cell_type_latents = self.latent_map(self.expression_array[type_idx], tanh=self.tanh)
            # adding noise to the latents
            cell_type_latents = cell_type_latents + np.random.normal(scale=self.pheno_noise, size=cell_type_latents.shape)

            self.latents_by_cell.append(cell_type_latents)
            cell_phenos = np.matmul(cell_type_latents, self.phenotype_weights[type_idx])
            # adding noise to phenotypes
            cell_phenos = cell_phenos + np.random.normal(scale=self.pheno_noise, size=len(cell_phenos))
            self.phenotypes_by_cell.append(cell_phenos)
