import torchdiffeq
import torch
import torch.nn as nn
import numpy as np
import anndata as ad
import scanpy as sc
import pandas as pd
import random
import os
from copy import deepcopy
import os
import logging
import json
from constants import *

# When CUDA use is requested, set PYTORCH_CUDA_ALLOC_CONF in this process's
# environment to "expandable_segments:True".
# NOTE(review): `use_cuda` is not defined in this file; presumably it is
# provided by `from constants import *` — confirm.
if use_cuda == True:
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"

def set_all_seeds(seed):
    """Seed the random number generators used in this module.

    Seeds Python's ``random``, NumPy's global generator, and PyTorch (default
    and CUDA generators), stores the seed in the ``PYTHONHASHSEED``
    environment variable, and switches cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Same seeders as before, applied in the same order.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True
    print("All Seeds Set Successfully.")

def normalize_genes_to_median_total(adata):
    """Rescale every gene so its total count equals the median gene total.

    The total of each gene (column of ``adata.X``) is computed across all
    cells, and each column is multiplied by ``median_total / gene_total``.
    Genes whose total is zero are multiplied by 0. ``adata.X`` is modified
    in place and nothing is returned.

    The scaling is applied as one vectorised operation instead of a Python
    loop over genes; this builds one temporary array the size of ``adata.X``.

    Args:
        adata: Object exposing a 2-D array ``X`` laid out as (cells, genes).
            NOTE(review): assumes a dense array; a sparse ``X`` was not
            supported by the previous loop either — confirm with callers.
    """
    # Total count per gene across all cells, and the median of those totals.
    gene_totals = np.sum(adata.X, axis=0)
    median_total = np.median(gene_totals)

    # Per-gene scaling factor; entries stay 0 where the gene total is zero,
    # which avoids a division by zero.
    scaling_factors = np.zeros(gene_totals.shape, dtype=float)
    np.divide(median_total, gene_totals, out=scaling_factors, where=gene_totals != 0)

    # Slice assignment keeps the update in place (same array object, same
    # dtype), as the column-by-column assignment did.
    adata.X[:, :] = adata.X * scaling_factors

def read_data(file_path,boot = False):
    """Read an ``.h5ad`` file into memory, optionally as a bootstrap resample.

    The file is opened in backed read-only mode. With ``boot=False`` the whole
    dataset is loaded with ``to_memory()``. With ``boot=True`` as many cells
    as the file contains are drawn with replacement (``np.random.choice``) and
    a new AnnData is built from the sampled rows of ``X`` and ``obs`` together
    with the original ``var``; the sampled ``obs`` gets a fresh 0..N-1 index.

    The backing file handle is closed before returning (also on error), so
    repeated calls do not accumulate open HDF5 handles.

    Args:
        file_path: Path to the ``.h5ad`` file.
        boot: If True, return a bootstrap resample of the cells.

    Returns:
        The AnnData object loaded (or resampled) from ``file_path``.
    """
    backed_atlas = ad.read_h5ad(file_path, backed='r')
    try:
        if not boot:
            return backed_atlas.to_memory()

        n_cells = backed_atlas.n_obs
        indices = np.random.choice(n_cells, size=n_cells, replace=True)

        # Read each distinct row once (np.unique also returns them sorted),
        # then expand back to the sampled order, duplicates included.
        unique_indices, inverse_indices = np.unique(indices, return_inverse=True)
        data_unique = backed_atlas.X[unique_indices, :]
        sampled_data = data_unique[inverse_indices, :]

        obs_unique = backed_atlas.obs.iloc[unique_indices].reset_index(drop=True)
        sampled_obs = obs_unique.iloc[inverse_indices].reset_index(drop=True)

        subsampled_atlas = ad.AnnData(X=sampled_data, obs=sampled_obs, var=backed_atlas.var)
        # NOTE(review): the object above is built from in-memory pieces; the
        # to_memory() call is kept from the original — confirm it is needed
        # with the installed anndata version.
        return subsampled_atlas.to_memory()
    finally:
        # Everything returned has been loaded by now, so release the handle.
        backed_atlas.file.close()

def load_tf_atlas(file_path, N_genes = 500, trajectory_inference = False , heldout_TF = None, boot = False, generalizability = False, percentage_interv = 1):
    """Load the TF atlas and build control, train, validation and test sets.

    The atlas is read with ``read_data`` and log1p-transformed. Cells are
    assigned to a TF by the last ``'-'``-separated token of ``obs['TF']``;
    cells whose token is ``'mCherry'`` form the initial (control) distribution.

    Changes relative to the previous implementation (the split logic and the
    order of random draws are unchanged):
      * ``heldout_TF`` defaults to ``None`` (meaning ``['HOXD12']``) so the
        list returned as ``GRN_test`` is never a shared default object.
      * Cells are indexed by TF in a single pass instead of re-scanning the
        'TF' column once per TF.
      * The data is concatenated once for highly-variable-gene selection
        (previously twice, the second copy only for its ``var_names``).
      * The gene subset is computed once up front, so it is defined even
        when no TF is selected.
      * The selected genes are sorted, so the gene order no longer depends
        on per-process string hashing.

    Args:
        file_path: Path to the ``.h5ad`` atlas, passed to ``read_data``.
        N_genes: ``n_top_genes`` for ``sc.pp.highly_variable_genes``; unused
            when ``boot`` or ``generalizability`` is set.
        trajectory_inference: If True, ``heldout_TF`` is used as the test set;
            otherwise the test set is empty.
        heldout_TF: TFs excluded from training/validation. ``None`` means
            ``['HOXD12']``.
        boot: Passed to ``read_data`` (bootstrap resampling). Also restricts
            the TF list to the evaluated GRN genes.
        generalizability: Restricts the TF list to the evaluated GRN genes.
        percentage_interv: Fraction of TFs kept (drawn with ``random.sample``)
            when neither ``boot`` nor ``generalizability`` is set.

    Returns:
        Tuple ``(initial_dist, train_dict, test_dict, validation_dict,
        valid_genes_index_dict, GRN_train, GRN_test)``. The dicts map TF name
        to AnnData; ``valid_genes_index_dict`` maps gene name to its column.

    Raises:
        ValueError: If ``obs['Group']`` contains no ``'mCherry'`` entry.
    """
    if heldout_TF is None:
        heldout_TF = ['HOXD12']

    GRN8_genes = ['CDX1', 'CDX2','HOXD11' ,'CDX4', 'HOXB7', 'HOXC9', 'HOXC10' ,'HOXD9' , 'HOXA10', 'HOXC11', 'HOXC12' , 'HOXD12', 'HOXC13']
    GRN4_genes = ['GRHL1', 'GRHL3', 'TFAP2C', 'TEAD1', 'TEAD2', 'TEAD3', 'TEAD4']
    GRN5_genes = ['FLI1','JUNB', 'JUN', 'JUND', 'ETV2']
    GRN_evaluated = GRN4_genes + GRN5_genes + GRN8_genes

    #marker genes in the Anterior Posterior Axis
    Anterior_Posterior_genes = ['SHH','HOXD4', 'HOXD12', 'HOXD10', 'HOXD3', 'HOXD8', 'HOXD11', 'HOXD13', 'HOXD9', 'HOXD1','BMP2', 'TBX2', 'TBX3', 'SALL1', 'SALL2', 'SALL3', 'SALL4','OTX2']
    Endothelial_genes = ['PECAM1', 'CDH5', 'VWF', 'KDR', 'FLT1', 'TEK', 'CLDN5']
    Trophoblasts_genes = ['KRT7', 'HLA-G', 'TFAP2C', 'GATA3', 'MMP2', 'CSH1']

    marker_genes = list(set(Anterior_Posterior_genes+Endothelial_genes+Trophoblasts_genes))

    TF_Atlas = read_data(file_path, boot = boot)
    sc.pp.log1p(TF_Atlas)

    # Index the cells once by the TF token (text after the last '-').
    rows_by_tf = {}
    for row, tf_string in enumerate(TF_Atlas.obs['TF']):
        rows_by_tf.setdefault(tf_string.split('-')[-1], []).append(row)

    def cells_for(tf_name):
        # Boolean row mask, equivalent to the former obs['TF'].isin(...) filter.
        mask = np.zeros(TF_Atlas.n_obs, dtype=bool)
        mask[np.asarray(rows_by_tf.get(tf_name, []), dtype=np.intp)] = True
        return TF_Atlas[mask]

    initial_dist = cells_for('mCherry')

    tf_list = list(TF_Atlas.obs['Group'].unique())
    if 'mCherry' not in tf_list:
        raise ValueError("No 'mCherry' entry found in obs['Group'].")
    tf_list.remove('mCherry')

    if boot or generalizability:
        tf_list = GRN_evaluated
    else:
        #truncate TF_list for ablation studies
        new_size = int(len(tf_list) * percentage_interv)
        tf_list = random.sample(tf_list, new_size)

    tf_list_train_validation = {tf for tf in tf_list if tf not in heldout_TF}
    GRN_train = [tf for tf in GRN_evaluated if tf not in heldout_TF]
    if trajectory_inference:
        GRN_test = heldout_TF
    else:
        GRN_test = []

    logging.info("GRN Analysis. GRN genes valuated: ")
    for tf in GRN_evaluated:
        logging.info(tf)

    #Setting up a dictionary with the filtered data
    filtered_data_dict = {tf: cells_for(tf) for tf in tf_list}

    #genes to be enforced into the expression vector
    selected_genes = set(tf_list + marker_genes + GRN_evaluated)

    if not (boot or generalizability):
        #Find the N number of most variable genes across all selected TFs
        all_filtered_data = ad.concat([filtered_data_dict[tf] for tf in tf_list], axis=0)
        sc.pp.highly_variable_genes(all_filtered_data, flavor='seurat', n_top_genes=N_genes)
        top_genes = all_filtered_data.var_names[all_filtered_data.var['highly_variable']]
        selected_genes.update(top_genes)

    # Sorted (via str, in case a non-string label slipped in) so the gene
    # order is the same in every process.
    top_genes_updated = pd.Index(sorted(selected_genes, key=str))

    logging.info("There are %.2f genes selected for evaluation.",
                 len(top_genes_updated))

    # Every per-TF view shares the atlas's variables, so the gene subset is
    # the same for all of them and is computed once.
    valid_genes = [gene for gene in top_genes_updated if gene in TF_Atlas.var_names]

    initial_dist = initial_dist[:, valid_genes]
    valid_genes_index_dict = {gene: index for index, gene in enumerate(valid_genes)}

    train_dict = {}
    test_dict = {}
    validation_dict = {}

    #Set up training and validation for revolver training
    for tf, adata_all_genes in filtered_data_dict.items():
        #Scale down the cells to the selected genes
        adata = adata_all_genes[:, valid_genes]

        #Skip TFs that are all zeroes (or have no cells) on the selected genes
        if np.all(adata.X == 0):
            continue

        if tf in tf_list_train_validation:
            if adata.shape[0] >= 100: #TFs with at least 100 cells get a validation split
                indices = np.random.permutation(adata.shape[0])

                train_end = int(0.8 * adata.shape[0])  # 80% of the cells for training

                train_dict[tf] = adata[indices[:train_end]].copy()
                validation_dict[tf] = adata[indices[train_end:]].copy()  # remaining 20%
            else:
                train_dict[tf] = adata.copy()
        if tf in GRN_test:
            test_dict[tf] = adata.copy()

    return initial_dist, train_dict, test_dict, validation_dict, valid_genes_index_dict, GRN_train, GRN_test
        

def load_tf_atlas_conformal_inference(file_path, N_genes = 500, trajectory_inference = False, boot = False, generalizability = False):
    """Load the TF atlas and split TFs into train, calibration and test sets.

    The atlas is read with ``read_data`` and log1p-transformed. Cells are
    assigned to a TF by the last ``'-'``-separated token of ``obs['TF']``;
    cells whose token is ``'mCherry'`` form the initial (control) distribution.
    Sixty TFs are drawn with ``random.sample``: the first 30 become the
    calibration set, the other 30 the test set, and all remaining TFs are
    used for training.

    Changes relative to the previous implementation:
      * TFs that are all zero on the selected genes are skipped. Previously
        they were deleted from a dict they had never been added to, which
        raised ``KeyError`` whenever such a TF existed.
      * Cells are indexed by TF in a single pass instead of re-scanning the
        'TF' column once per TF.
      * The data is concatenated once for highly-variable-gene selection
        (previously twice, the second copy only for its ``var_names``).
      * The gene subset is computed once up front, and the selected genes
        are sorted so the gene order no longer depends on per-process
        string hashing.

    Args:
        file_path: Path to the ``.h5ad`` atlas, passed to ``read_data``.
        N_genes: ``n_top_genes`` for ``sc.pp.highly_variable_genes``.
        trajectory_inference: Accepted but not used by this function.
        boot: Passed to ``read_data`` (bootstrap resampling).
        generalizability: Accepted but not used by this function.

    Returns:
        Tuple ``(initial_dist, train_dict, test_dict, calibration_dict,
        valid_genes_index_dict)``. The dicts map TF name to AnnData;
        ``valid_genes_index_dict`` maps gene name to its column.

    Raises:
        ValueError: If ``obs['Group']`` contains no ``'mCherry'`` entry, if
            fewer than 60 TFs are available (raised by ``random.sample``), or
            if a TF falls into none of the three sets.
    """
    GRN8_genes = ['CDX1', 'CDX2','HOXD11' ,'CDX4', 'HOXB7', 'HOXC9', 'HOXC10' ,'HOXD9' , 'HOXA10', 'HOXC11', 'HOXC12' , 'HOXD12', 'HOXC13']
    GRN4_genes = ['GRHL1', 'GRHL3', 'TFAP2C', 'TEAD1', 'TEAD2', 'TEAD3', 'TEAD4']
    GRN5_genes = ['FLI1','JUNB', 'JUN', 'JUND', 'ETV2']
    GRN_evaluated = GRN4_genes + GRN5_genes + GRN8_genes

    #marker genes in the Anterior Posterior Axis
    Anterior_Posterior_genes = ['SHH','HOXD4', 'HOXD12', 'HOXD10', 'HOXD3', 'HOXD8', 'HOXD11', 'HOXD13', 'HOXD9', 'HOXD1','BMP2', 'TBX2', 'TBX3', 'SALL1', 'SALL2', 'SALL3', 'SALL4','OTX2']
    Endothelial_genes = ['PECAM1', 'CDH5', 'VWF', 'KDR', 'FLT1', 'TEK', 'CLDN5']
    Trophoblasts_genes = ['KRT7', 'HLA-G', 'TFAP2C', 'GATA3', 'MMP2', 'CSH1']

    marker_genes = list(set(Anterior_Posterior_genes+Endothelial_genes+Trophoblasts_genes))

    TF_Atlas = read_data(file_path, boot = boot)
    sc.pp.log1p(TF_Atlas)

    # Index the cells once by the TF token (text after the last '-').
    rows_by_tf = {}
    for row, tf_string in enumerate(TF_Atlas.obs['TF']):
        rows_by_tf.setdefault(tf_string.split('-')[-1], []).append(row)

    def cells_for(tf_name):
        # Boolean row mask, equivalent to the former obs['TF'].isin(...) filter.
        mask = np.zeros(TF_Atlas.n_obs, dtype=bool)
        mask[np.asarray(rows_by_tf.get(tf_name, []), dtype=np.intp)] = True
        return TF_Atlas[mask]

    initial_dist = cells_for('mCherry')

    tf_list = list(TF_Atlas.obs['Group'].unique())
    if 'mCherry' not in tf_list:
        raise ValueError("No 'mCherry' entry found in obs['Group'].")
    tf_list.remove('mCherry')

    #Choose the list of TFs for train, calibration, and test.
    random_selection = random.sample(tf_list, k=60)
    tf_calibration = set(random_selection[:30])
    tf_test = set(random_selection[30:])
    tf_train = set(tf_list) - tf_calibration - tf_test

    logging.info("GRN Analysis. GRN genes valuated: ")
    for tf in GRN_evaluated:
        logging.info(tf)

    #Setting up a dictionary with the filtered data
    filtered_data_dict = {tf: cells_for(tf) for tf in tf_list}

    #Find the N number of most variable genes across all TFs
    all_filtered_data = ad.concat([filtered_data_dict[tf] for tf in tf_list], axis=0)
    sc.pp.highly_variable_genes(all_filtered_data, flavor='seurat', n_top_genes=N_genes)
    top_genes = all_filtered_data.var_names[all_filtered_data.var['highly_variable']]

    #genes to be enforced into the expression vector, plus the variable genes
    selected_genes = set(tf_list + marker_genes + GRN_evaluated)
    selected_genes.update(top_genes)

    # Sorted (via str, in case a non-string label slipped in) so the gene
    # order is the same in every process.
    top_genes_updated = pd.Index(sorted(selected_genes, key=str))

    logging.info("There are %.2f genes selected for evaluation.",
                 len(top_genes_updated))

    # Every per-TF view shares the atlas's variables, so the gene subset is
    # the same for all of them and is computed once.
    valid_genes = [gene for gene in top_genes_updated if gene in TF_Atlas.var_names]

    initial_dist = initial_dist[:, valid_genes]
    valid_genes_index_dict = {gene: index for index, gene in enumerate(valid_genes)}

    train_dict = {}
    test_dict = {}
    calibration_dict = {}

    for tf, adata_all_genes in filtered_data_dict.items():
        #Scale down the cells to the selected genes
        adata = adata_all_genes[:, valid_genes]

        #Skip TFs that are all zeroes (or have no cells) on the selected genes
        if np.all(adata.X == 0):
            continue

        if tf in tf_train:
            train_dict[tf] = adata.copy()
        elif tf in tf_calibration:
            calibration_dict[tf] = adata.copy()
        elif tf in tf_test:
            test_dict[tf] = adata.copy()
        else:
            raise ValueError(f"Error in Configuring Train, Calibration, and Test Dictionaries.")

    return initial_dist, train_dict, test_dict, calibration_dict, valid_genes_index_dict
        

def load_data_stable(file_path, interv_file):
    """Load simulated data and split each intervention 80/20 into train/validation.

    The data is read with ``read_data`` and log1p-transformed. The JSON file
    at ``interv_file`` supplies the intervention names (its keys), which are
    matched against ``obs['Group']``. An intervention is skipped when it has
    no cells (previously an ``IndexError``) or when all of its cells have
    identical expression. The remaining ones are shuffled with
    ``np.random.permutation`` and split 80% / 20%.

    Args:
        file_path: Path to the ``.h5ad`` file, passed to ``read_data``.
        interv_file: Path to a JSON file whose keys are intervention names.

    Returns:
        Tuple ``(initial_dist, validation_dict, train_dict, simulated_data,
        over_expression_dict)`` — note that validation comes before train.
        ``initial_dist`` holds the cells whose group is ``'control'``.
    """
    simulated_data = read_data(file_path)

    #sc.pp.normalize_total(simulated_data, target_sum=1e5)
    sc.pp.log1p(simulated_data)

    train_dict = {}
    validation_dict = {}

    with open(interv_file, 'r') as json_file:
        over_expression_dict = json.load(json_file)

    for group in over_expression_dict:
        adata_filtered = simulated_data[simulated_data.obs['Group'] == group, :]

        n_cells = adata_filtered.shape[0]
        if n_cells == 0:
            # X[0] below would raise IndexError for an empty group.
            logging.warning("No cells found for intervention %s; skipping.", group)
            continue

        #Skip groups in which every cell has the same expression vector
        if np.all(adata_filtered.X == adata_filtered.X[0]):
            continue

        indices = np.random.permutation(n_cells)
        train_end = int(0.8 * n_cells)  # 80% of the cells for training

        train_dict[group] = adata_filtered[indices[:train_end]].copy()
        validation_dict[group] = adata_filtered[indices[train_end:]].copy()  # remaining 20%

    initial_dist = simulated_data[simulated_data.obs['Group'] == 'control', :]

    return initial_dist, validation_dict, train_dict, simulated_data, over_expression_dict


def load_data_stable_test(file_path, interv_file):
    """Load simulated data and collect the cells of each intervention for testing.

    The data is read with ``read_data`` and log1p-transformed. The JSON file
    at ``interv_file`` supplies the intervention names (its keys), which are
    matched against ``obs['Group']``. An intervention is skipped when it has
    no cells (previously an ``IndexError``) or when all of its cells have
    identical expression.

    Args:
        file_path: Path to the ``.h5ad`` file, passed to ``read_data``.
        interv_file: Path to a JSON file whose keys are intervention names.

    Returns:
        Dict mapping intervention name to a copy of its cells.
    """
    simulated_data = read_data(file_path)

    sc.pp.log1p(simulated_data)

    test_dict = {}

    with open(interv_file, 'r') as json_file:
        over_expression_dict = json.load(json_file)

    for group in over_expression_dict:
        adata_filtered = simulated_data[simulated_data.obs['Group'] == group, :]

        if adata_filtered.shape[0] == 0:
            # X[0] below would raise IndexError for an empty group.
            logging.warning("No cells found for intervention %s; skipping.", group)
            continue

        #Skip groups in which every cell has the same expression vector
        if np.all(adata_filtered.X == adata_filtered.X[0]):
            continue

        test_dict[group] = adata_filtered.copy()

    return test_dict


def load_data_brownian(train_dict,mean, std_dev):
    """Create a noise-perturbed ("diffused") copy of every TF's cells.

    For each TF, Gaussian noise drawn with ``np.random.normal(mean, std_dev)``
    is added to every entry of ``X``. The noise is drawn with one vectorised
    call per TF rather than one call per cell; the global NumPy generator
    yields the same values in the same row-major order either way, and the
    result is cast back to the dtype of ``X`` as before.

    Args:
        train_dict: Dict mapping TF name to AnnData.
            NOTE(review): assumes each ``X`` is a dense 2-D array — confirm.
        mean: Mean of the Gaussian noise.
        std_dev: Standard deviation of the Gaussian noise.

    Returns:
        Tuple ``(diffussion_TF_Ann_dict, all_data)``: the per-TF diffused
        AnnData objects, whose ``obs['Group']`` is set to
        ``"<tf> Diffused"``, and their concatenation along the cell axis.
    """
    diffussion_TF_Ann_dict = {}
    for tf, source in train_dict.items():
        expression = source.X

        # One draw for the whole matrix instead of one per cell.
        delta_W = np.random.normal(mean, std_dev, expression.shape)

        # Write into an array of X's dtype so the cast matches the former
        # row-by-row assignment.
        diffussion_TF_X = np.zeros_like(expression)
        diffussion_TF_X[...] = expression + delta_W

        diffussion_TF_Ann = ad.AnnData(X=diffussion_TF_X, obs=source.obs)
        diffussion_TF_Ann.var = source.var
        diffussion_TF_Ann.obs['Group'] = pd.Categorical([str(tf) + " Diffused"]*len(diffussion_TF_Ann.obs))
        diffussion_TF_Ann_dict[tf] = diffussion_TF_Ann

    all_data = ad.concat(list(diffussion_TF_Ann_dict.values()), axis=0)
    return diffussion_TF_Ann_dict, all_data


def load_data_dcdfg(DAG_knock_init, file_path):
    """Load DCDFG-style data and split each knockdown group 80/20.

    The data is read with ``read_data`` (no log transform is applied here).
    Every distinct value of ``obs['Group']`` is treated as a knockdown label
    and parsed into a list of floats: surrounding ``[`` / ``]`` are stripped,
    the text is split on whitespace, leading/trailing ``.`` are removed from
    each token, and NaN values are dropped. The group equal to
    ``DAG_knock_init`` becomes the initial distribution; every other group is
    skipped when all of its cells are identical, otherwise shuffled with
    ``np.random.permutation`` and split 80% / 20%.

    Args:
        DAG_knock_init: Label of the group used as the initial distribution.
        file_path: Path to the ``.h5ad`` file, passed to ``read_data``.

    Returns:
        Tuple ``(initial_dist, validation_dict, train_dict, simulated_data,
        over_expression_dict)`` — note that validation comes before train.
        ``over_expression_dict`` maps each group label to its parsed floats.

    Raises:
        ValueError: If no group equals ``DAG_knock_init`` (previously an
            ``UnboundLocalError`` at the return statement), or if a label
            token cannot be parsed as a float.
    """
    simulated_data = read_data(file_path)

    train_dict = {}
    validation_dict = {}

    # Distinct group labels, as a list
    group_list = list(simulated_data.obs['Group'].unique())

    #over_expression_dict is a string list conversion table
    over_expression_dict = {}

    initial_dist = None

    for knockdown in group_list:
        numbers = knockdown.strip('[]').split()

        # Convert the split strings to floats, dropping NaN entries
        parsed_list = [float(num.strip('.')) for num in numbers if num]
        over_expression_dict[knockdown] = [x for x in parsed_list if not np.isnan(x)]

        logging.info(knockdown)
        adata_filtered = simulated_data[simulated_data.obs['Group'] == knockdown, :]

        if knockdown == DAG_knock_init:
            initial_dist = adata_filtered
            continue

        #Skip groups in which every cell has the same expression vector
        if np.all(adata_filtered.X == adata_filtered.X[0]):
            continue

        n_cells = adata_filtered.shape[0]
        indices = np.random.permutation(n_cells)
        train_end = int(0.8 * n_cells)  # 80% of the cells for training

        train_dict[knockdown] = adata_filtered[indices[:train_end]].copy()
        validation_dict[knockdown] = adata_filtered[indices[train_end:]].copy()  # remaining 20%

    if initial_dist is None:
        raise ValueError(
            f"Initial group {DAG_knock_init!r} was not found in obs['Group']."
        )

    return initial_dist, validation_dict, train_dict, simulated_data, over_expression_dict

