import anndata as ad
import hdf5plugin
import scipy
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import numpy as np
import pandas as pd
import math
import logging
import os
import scanpy as sc
import pickle
import argparse
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler
from tqdm import tqdm
from copy import deepcopy
# from dance.utils import set_seed
from scipy.sparse import csr_matrix
from CellPLM.utils.eval import downstream_eval
from CellPLM.utils.data import XDict, stratified_sample_genes_by_sparsity
from CellPLM.utils.mask import InputDropoutMaskBuilder
from CellPLM.model import OmicsFormer
from gears import PertData
import random
from gears.data_utils import get_dropout_non_zero_genes, rank_genes_groups_by_cov
from pathlib import Path

def set_seed(rndseed, cuda: bool = True, extreme_mode: bool = False):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        rndseed: Seed value; also exported as ``PYTHONHASHSEED``.
        cuda: When true, the CUDA generators are seeded as well.
        extreme_mode: When true, cuDNN benchmarking is switched off and
            cuDNN deterministic mode is switched on.
    """
    os.environ["PYTHONHASHSEED"] = str(rndseed)

    # Same seeding order as before: stdlib, NumPy, then PyTorch (CPU).
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(rndseed)

    if cuda:
        for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seeder(rndseed)

    if extreme_mode:
        cudnn = torch.backends.cudnn
        cudnn.benchmark = False
        cudnn.deterministic = True

    print(f"Setting global random seed to {rndseed}")

def get_activateion_score(pert_label, ori_x, perted_x):
    """Compute one score per column from original and perturbed values.

    All reductions run over dimension 0. The score for a column is
    ``mean(perted_x - ori_x) * mean(ori_x) / G_M`` where ``G_M`` is the
    ``pert_label``-weighted sum of ``ori_x`` divided by the sum of
    ``pert_label``.

    Args:
        pert_label: Tensor of weights multiplied element-wise with ``ori_x``.
        ori_x: Tensor of original values.
        perted_x: Tensor of perturbed values, same shape as ``ori_x``.

    Returns:
        Tensor of scores with dimension 0 reduced away.
    """
    weighted_total = (pert_label * ori_x).sum(dim=0)
    weight_total = pert_label.sum(dim=0)
    weighted_mean = weighted_total / weight_total

    mean_shift = (perted_x - ori_x).mean(dim=0)
    mean_original = ori_x.mean(dim=0)

    return mean_shift * mean_original / weighted_mean

def _build_x_dict(x_seq, cell_idx, labels, device):
    """Wrap one prepared batch into the ``XDict`` passed to the model."""
    return XDict({
        'x_seq': x_seq.to(device),
        'coord': coord[cell_idx.long()].to(device),
        'label': labels.to(device),
    })


def main(task=None, config=None):
    """Fine-tune an ``OmicsFormer`` on the prepared splits and print test metrics.

    Besides its arguments, the function reads module-level names prepared by
    the ``__main__`` section of this script: ``args``, ``pretrained_gene_list``,
    ``train_list`` / ``valid_list`` / ``test_list`` (with their ``*_batch_list``
    and ``*_label_list`` companions), ``input_gene_list``, ``coord``,
    ``data_genes``, ``local_batches_test`` and ``top_de_dict``.

    Args:
        task: Name passed to ``downstream_eval``; falls back to
            ``config['head_type']`` when ``None``.
        config: Dict of model/optimisation settings. It is mutated in place
            (``pert_fill`` and ``gene_list`` are set) and expanded as keyword
            arguments for ``OmicsFormer``.

    Raises:
        TypeError: If ``config`` is ``None``.
        RuntimeError: If an evaluated epoch has no non-empty validation batch.
    """
    if config is None:
        raise TypeError("main() requires a config dict")

    if task is None:
        task = config['head_type']
    print('task name',task)

    config['pert_fill'] = args.pert_fill
    config["gene_list"] = pretrained_gene_list
    device = torch.device('cuda')

    # The model must exist before weights can be loaded into it.
    model = OmicsFormer(**config)

    if args.load_pre:
        pretrained_file = f'../ckpt/{args.pre_model}.best.ckpt'
        checkpoint = torch.load(pretrained_file, map_location='cpu')
        prefix = 'module.'
        # Strip the "module." prefix only from keys that actually carry it.
        pretrained_model_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in checkpoint['model_state_dict'].items()
        }
        try:
            model.load_state_dict(pretrained_model_dict)
        except RuntimeError:
            # Strict loading failed: only load params that are in the model
            # and match the size.
            model_dict = model.state_dict()
            pretrained_dict = {
                k: v
                for k, v in pretrained_model_dict.items()
                if k in model_dict and v.shape == model_dict[k].shape
            }
            for k, v in pretrained_dict.items():
                print(f"Loading params {k} with shape {v.shape}")
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)

    print(model)
    model.to(device)
    optim = torch.optim.AdamW(model.parameters(), lr=config['lr'], weight_decay=config['wd'])
    # NOTE(review): the scheduler is constructed but never stepped, as before.
    scheduler = ReduceLROnPlateau(optim, 'min', patience=10, factor=0.9)

    train_loss = []
    valid_loss = []
    best_state = None  # state dict of the epoch with the lowest validation loss
    loss = out_dict = None

    for epoch in (pbar := tqdm(range(config['epochs']))):
        # ---- training pass ----
        model.train()
        epoch_loss = []
        for i, x_seq in enumerate(train_list):
            if len(x_seq) == 0:
                continue
            x_dict = _build_x_dict(x_seq, train_batch_list[i], train_label_list[i], device)
            out_dict, loss = model(x_dict, input_gene_list[i])
            optim.zero_grad()
            loss.backward()
            optim.step()
            epoch_loss.append(loss.item())
        train_loss.append(sum(epoch_loss) / len(epoch_loss))

        # Checkpoint selection and early stopping only make sense on epochs
        # that actually produced a new validation loss.
        if epoch % args.eval_interval != 0:
            continue

        # ---- validation pass ----
        model.eval()
        epoch_loss = []
        with torch.no_grad():
            for i, x_seq in enumerate(valid_list):
                if len(x_seq) == 0:
                    continue
                x_dict = _build_x_dict(x_seq, valid_batch_list[i], valid_label_list[i], device)
                out_dict, loss = model(x_dict, input_gene_list[i])
                # Record the loss only for batches that were evaluated.
                epoch_loss.append(loss.item())
        if not epoch_loss:
            raise RuntimeError("No non-empty validation batch to evaluate.")
        valid_loss.append(sum(epoch_loss) / len(epoch_loss))

        pbar.set_description(f'Epoch {epoch} | Train loss: {train_loss[-1]:.4f} | Valid loss: {valid_loss[-1]:.4f}')

        if min(valid_loss) == valid_loss[-1]:
            best_state = deepcopy(model.state_dict())
        if args.early_stop:
            # Stop once the best loss is older than the last `valid_epochs` evaluations.
            if epoch > 0 and min(valid_loss[-config['valid_epochs']:]) != min(valid_loss):
                print('Early stopped.')
                break

    # Inference
    if args.load_best and best_state is not None:
        model.load_state_dict(best_state)
    model.eval()

    pred = []
    label = []
    control = []
    test_batches = []
    pert_idx = {}
    n_data_genes = len(data_genes)

    with torch.no_grad():
        for i, x_seq in enumerate(test_list):
            if len(x_seq) == 0:
                continue
            x_dict = _build_x_dict(x_seq, test_batch_list[i], test_label_list[i], device)
            out_dict, loss = model(x_dict, input_gene_list[i])
            # Keep the leading `data_genes` columns of the input. A negative
            # end index computed from the length difference collapses to an
            # empty slice when no extra columns were appended.
            control.append(x_seq.to_dense()[:, :n_data_genes])
            label.append(test_label_list[i])
            pred.append(out_dict['pred'].cpu())
            test_batches.append(local_batches_test[i])
            pert_idx[input_gene_list[i][-1]] = out_dict['pred'].shape[0]

    # Drop references to the model and its last outputs before clearing the cache.
    del model
    loss = out_dict = None
    torch.cuda.empty_cache()
    print(pert_idx)

    if task == 'perturbation_prediction':
        pred = torch.cat(pred).to(device)
        label = torch.cat(label).to(device)
        control = torch.cat(control).to(device)
        test_batches = np.concatenate(test_batches)

        save_dir = Path(args.save_dir)
        # Create the output directory so the results are not lost after training.
        save_dir.mkdir(parents=True, exist_ok=True)
        np.savez(
            save_dir / f'pert_{args.dataset}_{args.seed}_{args.split}.npz',
            pred=pred.detach().cpu().numpy(),
            label=label.detach().cpu().numpy(),
            control=control.detach().cpu().numpy(),
            batch=test_batches,
        )

        scores_delta = downstream_eval(task, pred, label, top_de_dict=top_de_dict, batch_labels=test_batches, control_level=control)
        scores = downstream_eval(task, pred, label, top_de_dict=top_de_dict)
        scores_batch = downstream_eval(task, pred, label, top_de_dict=top_de_dict, batch_labels=test_batches)

        print(f'scores: {scores}')
        print(f'scores_batch: {scores_batch}')
        print(f'scores_delta: {scores_delta}')
        print({
            'final_rmse': scores['all_rmse'],
            'final_corr': scores['all_corr'],
            'final_cos': scores['all_cos'],
            'final_de_rmse': scores['top_de_rmse'],
            'final_de_corr': scores['top_de_corr'],
            'final_de_cos': scores['top_de_cos'],

            'final_rmse_batch': scores_batch['all_rmse'],
            'final_corr_batch': scores_batch['all_corr'],
            'final_cos_batch': scores_batch['all_cos'],
            'final_de_rmse_batch': scores_batch['top_de_rmse'],
            'final_de_corr_batch': scores_batch['top_de_corr'],
            'final_de_cos_batch': scores_batch['top_de_cos'],

            'final_rmse_delta': scores_delta['all_rmse'],
            'final_corr_delta': scores_delta['all_corr'],
            'final_cos_delta': scores_delta['all_cos'],
            'final_de_rmse_delta': scores_delta['top_de_rmse'],
            'final_de_corr_delta': scores_delta['top_de_corr'],
            'final_de_cos_delta': scores_delta['top_de_cos'],
        })

    del pred, label, control, best_state
    torch.cuda.empty_cache()

def create_sparse_tensor(x):
    """Convert a CSR matrix into a float32 sparse torch tensor.

    Args:
        x: Object exposing ``indptr``, ``indices``, ``data`` and ``shape``
            (as a ``scipy.sparse.csr_matrix`` does).

    Returns:
        The result of building a torch CSR tensor from ``x``, calling
        ``to_sparse()`` on it, and casting to ``float``.
    """
    n_rows, n_cols = x.shape[0], x.shape[1]
    csr = torch.sparse_csr_tensor(x.indptr, x.indices, x.data, (n_rows, n_cols))
    converted = csr.to_sparse()
    return converted.float()

if __name__ == '__main__':
    def str2bool(v):
        """Parse a command-line boolean such as ``yes``/``no`` or ``1``/``0``.

        Raises:
            argparse.ArgumentTypeError: If ``v`` is not a recognised spelling.
        """
        if isinstance(v, bool):
            return v
        lowered = v.lower()
        if lowered in ("yes", "true", "t", "y", "1"):
            return True
        if lowered in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError("Boolean value expected.")

    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default='perturbation_prediction')
    parser.add_argument("--dataset", type=str, default='adamson') # norman, dixit, adamson
    parser.add_argument("--split", type=str, default='single') # single, combo_seen0, combo_seen1, combo_seen2
    parser.add_argument("--latent_mod", type=str, default='gmvae')
    parser.add_argument("--pre_model", type=str, default='20230926_75M_50M')
    parser.add_argument("--batch_feat", action='store_true')
    parser.add_argument("--seed", type=int, default=10)
    parser.add_argument("--epochs", type=int, default=500)
    parser.add_argument("--pert_fill", type=float, default=-100)
    parser.add_argument("--tune", action='store_true')
    parser.add_argument("--use_rng", action='store_true')
    parser.add_argument("--rng_seed", type=int, default=None)
    parser.add_argument("--model_dropout", type=float, default=0.7)
    parser.add_argument("--mask_node_rate", type=float, default=0.)
    parser.add_argument("--mask_feature_rate", type=float, default=0.)
    parser.add_argument("--valid_epochs", type=int, default=20)
    parser.add_argument("--early_stop", type=str2bool, default=True)
    parser.add_argument("--load_best", type=str2bool, default=True)
    parser.add_argument("--eval_interval", type=int, default=1)
    parser.add_argument("--info_interval", type=int, default=5)
    parser.add_argument("--load_pre", type=str2bool, default=True)
    parser.add_argument("--save_dir", type=str, default='save_as')
    args = parser.parse_args()
    print(args)
    set_seed(args.seed)
    task = args.task
    torch.set_num_threads(12)

    # NOTE: pickle can execute arbitrary code while loading; only point
    # --pre_model at config files from a trusted source.
    with open(f"../ckpt/{args.pre_model}.config.pkl", "rb") as openfile:
        config = pickle.load(openfile)
    if args.load_pre:
        pretrained_gene_list = config['gene_list']
    pert_data_loader = PertData('./data/pert_data')
    pert_data_loader.load(data_path=f'./data/pert_data/{args.dataset}')
    my_adata = pert_data_loader.adata
    print(pert_data_loader.adata)

    data = pert_data_loader.prepare_split(split=args.split, seed=args.seed)
    data.var_names_make_unique()

    # Conditions dropped per dataset. Every entry needs its own comma:
    # adjacent string literals are silently concatenated by Python.
    hutcells_exclude = [
        'AKAP12+ctrl', 'APOBEC3D+ctrl', 'BICDL2+ctrl', 'CD2+ctrl', 'CD247+ctrl',
        'CD27+ctrl', 'CD28+ctrl', 'CEACAM1+ctrl', 'CNR2+ctrl', 'EMP1+ctrl',
        'FOXL2NB+ctrl', 'GRAP+ctrl', 'IFNG+ctrl', 'IL1R1+ctrl', 'IL2+ctrl',
        'IL2RB+ctrl', 'IL2RG+ctrl', 'IL9R+ctrl', 'LAT+ctrl', 'LAT2+ctrl', 'MUC1+ctrl',
        'NLRC3+ctrl', 'NOTCH1+ctrl', 'OTUD7A+ctrl', 'P2RY14+ctrl', 'PIK3AP1+ctrl',
        'PLCG2+ctrl', 'SLA2+ctrl', 'TAGAP+ctrl', 'TNFRSF9+ctrl',
    ]
    exclude_pert = {
        'adamson': ['SRPR+ctrl', 'SLMO2+ctrl', 'TIMM23+ctrl', 'AMIGO3+ctrl', 'KCTD16+ctrl'],
        'norman': ['RHOXF2BB+ctrl', 'LYL1+IER5L', 'ctrl+IER5L', 'KIAA1804+ctrl', 'IER5L+ctrl', 'RHOXF2BB+ZBTB25', 'RHOXF2BB+SET'],
        'datlingerbock2021': [],
        'gse203592': ['zMulti+ctrl', 'PDCD1+ctrl', 'zNone+ctrl'],
        'hutcellscrispra': list(hutcells_exclude),
        'full_hutcellscrispra': list(hutcells_exclude),
    }
    if args.dataset in exclude_pert:
        data = data[~data.obs.condition.isin(exclude_pert[args.dataset])]
        print(f"Excluding {exclude_pert[args.dataset]}")
    print(data)

    # Map each `gene_name` entry to its var-index entry (later rows win).
    gene_dicts = {}
    for e_name, g_name in zip(list(data.var.index), list(data.var['gene_name'])):
        gene_dicts[g_name] = e_name

    gene_list = data.var.index.to_list()
    if args.load_pre:
        # Set membership keeps the filter linear in the number of genes.
        known_genes = set(pretrained_gene_list)
        gene_list = [x for x in gene_list if x in known_genes]
        data = data[:, gene_list]
    else:
        pert_genes = list(np.unique(data.obs['gene']))
        if 'CTRL1' in pert_genes:
            pert_genes.remove('CTRL1')
        pretrained_gene_list = list(set(gene_list + pert_genes))
    pretrained_gene_set = set(pretrained_gene_list)

    data_genes = np.array(gene_list)
    rank_genes_groups_by_cov(data, groupby='condition_name', covariate='cell_type', control_group='ctrl_1',
                             n_genes=len(data.var), key_added = 'rank_genes_groups_cov_all')
    data = get_dropout_non_zero_genes(data)
    gene2idx = {j:i for i,j in enumerate(data.var.index)}
    top_de_dict = data.uns['top_non_dropout_de_20']
    for k in top_de_dict.keys():
        top_de_dict[k] = np.vectorize(gene2idx.get)(top_de_dict[k])

    control_data = data[data.obs.control == 1]
    pert_data = data[data.obs.control == 0]

    print(pert_data.obs['condition'])
    batch_labels = LabelEncoder().fit_transform(pert_data.obs['condition'])
    print(batch_labels)
    # Per cell: the condition split on '+' with the 'ctrl' entries removed.
    pert = [str(x).split('+') for x in pert_data.obs['condition']]
    flags = [(np.array(x) != 'ctrl') for x in pert]
    pert = [np.array(pert[i])[flags[i]].astype(str) for i in range(len(pert))]

    train_list = []
    valid_list = []
    test_list = []
    train_label_list = []
    valid_label_list = []
    test_label_list = []
    train_batch_list = []
    valid_batch_list = []
    test_batch_list = []
    input_gene_list = []
    local_batches_test = []
    local_batches_valid = []
    local_batches_train = []

    if args.use_rng:
        seed = args.rng_seed if args.rng_seed is not None else args.seed
        rng = np.random.default_rng(seed=seed)
        choose = rng.choice
    else:
        choose = np.random.choice

    # First column of each gene in `data_genes` (same result as list.index).
    gene_to_col = {}
    for col, gene in enumerate(data_genes.tolist()):
        gene_to_col.setdefault(gene, col)

    # Densify the control matrix once; every split below samples fresh rows
    # from it, and fancy indexing copies, so it is never modified.
    control_x = control_data.X.toarray()

    def build_split(y, cell_idx, pert_cols):
        """Build (input, cell indices, labels) for one split of one condition.

        Rows are sampled without replacement from the full control matrix,
        then the columns in ``pert_cols`` are overwritten with the values of
        the selected cells in ``y``. Returns three empty lists when
        ``cell_idx`` is empty.
        """
        if len(cell_idx) == 0:
            return [], [], []
        sampled = choose(len(control_x), len(cell_idx), replace=False)
        x_in = control_x[sampled]
        for col in pert_cols:
            x_in[:, [col]] = y[cell_idx, col].unsqueeze(-1).numpy()
        return (
            create_sparse_tensor(csr_matrix(x_in)),
            torch.tensor(cell_idx).int(),
            y[cell_idx],
        )

    # Split name in `obs.split` -> (inputs, cell indices, labels, local labels).
    # Order matters: random draws happen train, then val, then test.
    split_targets = {
        'train': (train_list, train_batch_list, train_label_list, local_batches_train),
        'val': (valid_list, valid_batch_list, valid_label_list, local_batches_valid),
        'test': (test_list, test_batch_list, test_label_list, local_batches_test),
    }

    for batch in tqdm(range(batch_labels.max() + 1)):
        in_batch = batch_labels == batch
        y = torch.tensor(pert_data[in_batch].X.toarray()).float()
        first_cell = np.arange(len(pert))[in_batch].tolist()[0]
        pert_label = np.array([gene_dicts[p] for p in pert[first_cell]])

        # Skip conditions whose perturbed genes are missing from the gene list.
        if not all(p in pretrained_gene_set for p in pert_label):
            continue

        input_gene_list.append(data_genes)
        pert_cols = [gene_to_col[p] for p in pert_label]
        batch_splits = pert_data.obs.split.values[in_batch]

        for split_name, (x_out, idx_out, y_out, local_out) in split_targets.items():
            cell_idx = np.arange(len(batch_splits))[batch_splits == split_name]
            local_out.append(batch_labels[in_batch])
            x_seq, cells, labels = build_split(y, cell_idx, pert_cols)
            x_out.append(x_seq)
            idx_out.append(cells)
            y_out.append(labels)

    print(len(train_list),len(valid_list),len(test_list))
    if not args.batch_feat:
        batch_labels = torch.zeros(pert_data.shape[0]).int()
    coord = torch.zeros(pert_data.shape[0], 2) - 1
    # Number of columns of the perturbed data, independent of the loop above.
    out_dim = pert_data.shape[1]
    del data, control_x


    config['mask_type'] = 'input' # 'hidden'
    config['dec_mod'] = 'mlp' #'resmlp' # mlp
    config['dec_hid'] = 256 # 128
    config['dec_layers'] = 2 # 3 # 2
    config['model_dropout'] = args.model_dropout # 0.7
    config['mask_node_rate'] = 0.# args.model_dropout # 0.5
    config['mask_feature_rate'] = 0.# args.model_dropout # 0.5
    config['drop_node_rate'] = 0.3
    config['epochs'] = args.epochs
    config['lr'] = 1e-4 #
    config['wd'] = 1e-8 # 1e-8
    config['latent_mod'] = 'gmvae'
    config['w_li'] = 1.
    config['w_en'] = 1.
    config['head_type'] = task
    config['out_dim'] = out_dim
    config['valid_epochs'] = args.valid_epochs # 20
    config['eval_interval'] = args.eval_interval # 1
    main(task, config)

