import anndata as ad
import hdf5plugin
import scipy
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import numpy as np
import pandas as pd
import math
import logging
import scanpy as sc
import pickle
import argparse
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau
from tqdm import tqdm
from copy import deepcopy
from collections import Counter
from dance.utils import set_seed
from scipy.sparse import csr_matrix
from CellPLM.utils.eval import downstream_eval
from CellPLM.utils.data import XDict, clean_batches, stratified_sample_genes_by_sparsity
from CellPLM.model import OmicsFormer
from CellPLM.head import setup_head
import json
import ipdb

def _average_scores(score_dicts):
    """Return the per-metric mean over a non-empty list of score dicts."""
    return {
        name: sum(scores[name] for scores in score_dicts) / len(score_dicts)
        for name in score_dicts[0]
    }


def main(task=None):
    """Fine-tune a pretrained OmicsFormer and report test F1.

    Reads the module-level globals ``data``, ``labels``, ``order``, ``args``,
    ``config``, ``batch_labels``, ``train_idx``, ``valid_idx`` and
    ``test_idx``. It (re)binds the globals ``gene_list``, ``seq_list``,
    ``order_list``, ``coord_list`` and ``label_list`` with the per-batch
    inputs built from the highly-variable-gene subset of ``data``.

    Args:
        task: Name passed to ``downstream_eval``. Defaults to
            ``config['head_type']`` when ``None``.

    Raises:
        NotImplementedError: If ``config['scheduler']`` is ``'cos'``; no
            cosine scheduler is constructed anywhere in this script.

    The function returns nothing. It prints per-epoch metrics and, at the
    end, a summary dict with the best test F1, the test F1 at the epoch with
    the best validation F1, that epoch, the seed and the dataset name.
    Model weights are neither saved nor returned.
    """
    global config, gene_list, batch_labels, seq_list, order_list, coord_list, label_list, train_idx, valid_idx, test_idx

    # Fail before any expensive work: the training loop would otherwise hit
    # an undefined `scheduler` right after the first optimizer step.
    if config['scheduler'] == 'cos':
        raise NotImplementedError(
            "config['scheduler'] == 'cos' is not supported: this script only "
            "constructs a ReduceLROnPlateau ('plat') scheduler."
        )

    # ---- Build per-batch model inputs from the HVG subset -----------------
    seq_list = []
    order_list = []
    coord_list = []
    label_list = []
    data_hvg = data.copy()
    if config['hvg'] < data_hvg.shape[1]:
        sc.pp.highly_variable_genes(data_hvg, n_top_genes=config['hvg'], subset=True, flavor='seurat_v3')
    gene_list = data_hvg.var.index.tolist()
    for batch in tqdm(range(batch_labels.max() + 1)):
        in_batch = batch_labels == batch
        x = data_hvg[in_batch].X.astype(float)
        seq_list.append(create_sparse_tensor(x))
        order_list.append(order[in_batch])
        # Every coordinate is filled with -1 (zeros minus one).
        coord_list.append(torch.zeros(x.shape[0], 2) - 1)
        label_list.append(torch.from_numpy(labels[in_batch]).float())

    if task is None:
        task = config['head_type']
    # Prefer the GPU but keep the script runnable on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ---- Model construction and partial weight loading --------------------
    model = OmicsFormer(**config)
    pretrained_file = f'../ckpt/{args.pre_model}.best.ckpt'
    # Load onto CPU first so a checkpoint saved from a GPU can be read on any
    # host; the model is moved to `device` afterwards.
    pretrained_model_dict = torch.load(pretrained_file, map_location='cpu')['model_state_dict']
    model_dict = model.state_dict()
    # Only copy parameters whose name and shape both match the new model.
    pretrained_dict = {
        k: v
        for k, v in pretrained_model_dict.items()
        if k in model_dict and v.shape == model_dict[k].shape
    }
    for k, v in pretrained_dict.items():
        print(f"Loading params {k} with shape {v.shape}")
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    print(model)
    model.to(device)

    # The embedder gets a 10x smaller learning rate and a separate weight decay.
    optim = torch.optim.AdamW([
        {'params': list(model.embedder.parameters()), 'lr': config['lr']*0.1,
         'weight_decay': 1e-10},
        {'params': list(model.encoder.parameters()) + list(model.head.parameters()) + list(model.latent.parameters()), 'lr': config['lr'],
         'weight_decay': config['wd']},
    ])

    # Any scheduler name other than 'plat' trains without a scheduler.
    scheduler = None
    if config['scheduler'] == 'plat':
        scheduler = ReduceLROnPlateau(optim, 'min', patience=config['patience'], factor=0.95)

    eval_dict = {}
    if task == 'annotation':
        eval_dict['num_classes'] = len(np.unique(labels))
    train_loss = []
    valid_loss = []
    valid_metric = []
    test_metric = []
    final_test = -1
    final_epoch = -1
    warmup_epochs = 30

    for epoch in tqdm(range(config['epochs'])):
        epoch_loss = []
        train_scores = []
        model.train()

        # Linear warm-up of every param group except the embedder's.
        if epoch < warmup_epochs:
            for param_group in optim.param_groups[1:]:
                param_group['lr'] = config['lr'] * (epoch+1)/warmup_epochs

        for i, seq in enumerate(seq_list):
            input_dict = {
                'coord': coord_list[i].to(device),
                'label': label_list[i].to(device),
                'loss_mask': train_idx[i].to(device).bool(),
                'x_seq': seq.to(device),
            }
            out_dict, loss = model(XDict(input_dict), gene_list)
            with torch.no_grad():
                train_scores.append(downstream_eval(task, out_dict['pred'], out_dict['label'], **eval_dict))

            optim.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 2.0)
            optim.step()
            epoch_loss.append(loss.item())

            # NOTE: the plateau scheduler is stepped once per batch on the
            # training loss, not once per epoch on a validation metric.
            if scheduler is not None:
                scheduler.step(loss.item())

        train_scores = _average_scores(train_scores)
        train_loss.append(sum(epoch_loss) / len(epoch_loss))

        # ---- Evaluation over every batch, with the full loss mask ---------
        model.eval()
        with torch.no_grad():
            epoch_loss = []
            valid_pred = []
            valid_lb = []
            test_pred = []
            test_lb = []
            for i, seq in enumerate(seq_list):
                n_cells = label_list[i].shape[0]
                input_dict = {
                    'coord': coord_list[i].to(device),
                    'label': label_list[i].to(device),
                    'loss_mask': torch.ones(n_cells, dtype=torch.bool, device=device),
                    'x_seq': seq.to(device),
                }
                out_dict, loss = model(XDict(input_dict), gene_list)
                epoch_loss.append(loss.item())
                valid_pred.append(out_dict['pred'][valid_idx[i]])
                valid_lb.append(out_dict['label'][valid_idx[i]])
                test_pred.append(out_dict['pred'][test_idx[i]])
                test_lb.append(out_dict['label'][test_idx[i]])

            # Score once over all batches, rather than re-scoring every
            # cumulative prefix inside the loop and averaging those.
            valid_scores = downstream_eval(task, torch.cat(valid_pred), torch.cat(valid_lb), **eval_dict)
            test_scores = downstream_eval(task, torch.cat(test_pred), torch.cat(test_lb), **eval_dict)
        valid_loss.append(sum(epoch_loss) / len(epoch_loss))
        valid_metric.append(valid_scores['f1_score'])
        test_metric.append(test_scores['f1_score'])

        if task == 'annotation':
            print(f'Epoch {epoch} | Train loss: {train_loss[-1]:.4f} | Valid loss: {valid_loss[-1]:.4f}')
            print(
                f'Train ACC: {train_scores["acc"]:.4f} | valid ACC: {valid_scores["acc"]:.4f} | test ACC: {test_scores["acc"]:.4f}')
            print(
                f'Train f1: {train_scores["f1_score"]:.4f} | valid f1: {valid_scores["f1_score"]:.4f} | test f1: {test_scores["f1_score"]:.4f}')
            print(
                f'Train pre: {train_scores["precision"]:.4f} | valid pre: {valid_scores["precision"]:.4f} | test pre: {test_scores["precision"]:.4f}')

        # Track the test score at the best validation epoch so far.
        if max(valid_metric) == valid_metric[-1]:
            final_test = test_scores["f1_score"]
            final_epoch = epoch

        # Stop once the best validation score is older than the last
        # `config['es']` epochs.
        if max(valid_metric) != max(valid_metric[-config['es']:]):
            print('Early stopped.')
            break

    print({'max_test_f1': max(test_metric, default=-1), 'final_test_f1': final_test, 'final_epoch': final_epoch, 'seed': args.seed, 'dataset': args.dataset})

def create_sparse_tensor(x):
    """Convert a 2-D matrix into a coalesced float32 sparse COO tensor.

    Args:
        x: A SciPy sparse matrix in any format, or anything else that
            ``scipy.sparse.csr_matrix`` accepts (for example a dense NumPy
            array). CSR input is used as-is; other inputs are converted to
            CSR first.

    Returns:
        A coalesced ``torch.sparse_coo`` tensor of dtype ``torch.float32``
        with the same shape as ``x``.
    """
    # Coerce to CSR so the indptr/indices/data attributes always exist.
    csr = csr_matrix(x)
    n_rows, n_cols = csr.shape
    csr_tensor = torch.sparse_csr_tensor(csr.indptr, csr.indices, csr.data, (n_rows, n_cols))
    return csr_tensor.to_sparse().float().coalesce()

if __name__ == '__main__':
    # Script entry point: parse arguments, load and align the dataset with
    # the pretrained gene list, build train/valid/test masks, override the
    # pretrained config with fine-tuning settings, then run `main`.
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default='MS')
    parser.add_argument("--seed", type=int, default=10)
    parser.add_argument("--epochs", type=int, default=2000)
    parser.add_argument("--pre_model", type=str, default='20230926_75M_50M')
    args = parser.parse_args()
    set_seed(args.seed)

    DATASET = args.dataset
    if DATASET == 'hPancreas':
        data_train = ad.read_h5ad('../data/demo_train.h5ad')
        data_test = ad.read_h5ad('../data/demo_test.h5ad')
        train_num = data_train.shape[0]
        data = ad.concat([data_train, data_test])
        data.obs['celltype'] = data.obs['Celltype']
        data.obs['batch'] = 0

        # Imported lazily: only this dataset needs the gene-ID lookup.
        import mygene
        mg = mygene.MyGeneInfo()
        # Replace the var index with the 'ensembl.gene' value looked up for
        # each gene symbol; symbols without one become '0'.
        gene_query = mg.querymany(data.var.index.tolist(), scopes='symbol', fields='ensembl.gene', as_dataframe=True, species='human')
        data.var.index = gene_query.reset_index().drop_duplicates(subset='query')['ensembl.gene'].fillna('0').tolist()
        data.var_names_make_unique()

    elif DATASET == 'MS':
        data_train = ad.read_h5ad('../data/c_data.h5ad')
        data_test = ad.read_h5ad('../data/filtered_ms_adata.h5ad')
        data_train.var = data_train.var.set_index('index_column')
        data_test.var = data_test.var.set_index('index_column')
        train_num = data_train.shape[0]
        data = ad.concat([data_train, data_test])
        data.obs['batch'] = 0
        data.var_names_make_unique()

    else:
        # Without this, an unknown dataset name surfaced later as a
        # NameError on `data`.
        raise ValueError(f"Unsupported --dataset {DATASET!r}; expected 'hPancreas' or 'MS'.")

    # NOTE: pickle can execute arbitrary code on load; only point this at
    # checkpoint configs from a trusted source.
    with open(f"../ckpt/{args.pre_model}.config.pkl", "rb") as openfile:
        config = pickle.load(openfile)
    pretrain_gene_list = config['gene_list']
    # Keep only genes known to the pretrained model, in dataset order.
    # A set makes each membership test O(1).
    pretrain_gene_set = set(pretrain_gene_list)
    gene_list = [gene for gene in data.var.index.tolist() if gene in pretrain_gene_set]
    data = data[:, gene_list]
    print(data.shape)

    # NOTE(review): train_rate / valid_rate are not used by the split below,
    # which takes `train_fraction` of the training file instead.
    train_rate = 0.8
    valid_rate = 0.1
    train_fraction = 0.85
    order = np.arange(data.shape[0])
    labels = LabelEncoder().fit_transform(data.obs['celltype'])
    batch_labels = LabelEncoder().fit_transform(data.obs['batch'])
    out_dim = len(np.unique(labels))

    train_idx = []
    valid_idx = []
    test_idx = []

    # The concatenated data holds the `train_num` training-file rows first.
    # Those rows are split at random into train/valid; the remaining rows
    # form the test set.
    for batch in tqdm(range(batch_labels.max() + 1)):
        # Row count of this batch, taken from the mask instead of slicing
        # the AnnData object.
        n_cells = int((batch_labels == batch).sum())
        n_fit = int(train_num * train_fraction)

        tr = torch.randperm(train_num).long()

        train_mask = torch.zeros(n_cells, dtype=torch.bool)
        train_mask[tr[:n_fit]] = True
        train_idx.append(train_mask)

        valid_mask = torch.zeros(n_cells, dtype=torch.bool)
        valid_mask[tr[n_fit:train_num]] = True
        valid_idx.append(valid_mask)

        test_mask = torch.zeros(n_cells, dtype=torch.bool)
        test_mask[train_num:] = True
        test_idx.append(test_mask)

    # Fine-tuning overrides applied on top of the pretrained config.
    config['es'] = 200
    config['lr'] = 5e-3
    config['wd'] = 1e-7
    config['scheduler'] = 'plat'
    config['drop_node_rate'] = 0.3
    config['dec_layers'] = 1
    config['model_dropout'] = 0.5
    config['mask_node_rate'] = 0.75
    config['mask_feature_rate'] = 0.25
    config['dec_mod'] = 'mlp'
    config['latent_mod'] = 'ae'
    config['epochs'] = args.epochs
    config['head_type'] = 'annotation'
    config['out_dim'] = out_dim
    config['max_batch_size'] = 70000
    config['hvg'] = 3000
    config['batch_num'] = batch_labels.max() + 1
    config['patience'] = 20

    main('annotation')
