import anndata as ad
import hdf5plugin
import scipy
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import numpy as np
import pandas as pd
import math
import logging
import os
import scanpy as sc
import pickle
import argparse
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
from copy import deepcopy
from dance.utils import set_seed
from scipy.sparse import csr_matrix
from CellPLM.utils.eval import imputation_eval
from CellPLM.utils.data import XDict, stratified_sample_genes_by_sparsity, data_setup
from CellPLM.utils.mask import InputDropoutMaskBuilder
from CellPLM.model import OmicsFormer

def main(config=None):
    """Fine-tune (or zero-shot evaluate) an OmicsFormer for gene imputation.

    Relies on module-level state prepared by the ``__main__`` block: ``args``,
    the per-batch lists (``seq_list``, ``coord_list``, ``label_list``,
    ``dataset_list``), ``input_genes`` and the gene-index arrays.

    Training masks every batch to the input genes. The weights with the lowest
    validation loss are kept. Every 5 epochs, and once at the end, the model is
    scored on predicting the held-out target genes with ``imputation_eval``.

    Args:
        config: Mapping of hyper-parameters. It is passed as keyword arguments
            to ``OmicsFormer`` and must also provide ``val_num``, ``lr``,
            ``wd``, ``patience`` and ``epochs``.

    Raises:
        ValueError: if ``config`` is ``None``.
        RuntimeError: if training ran but no best checkpoint was recorded
            (e.g. ``epochs == 0`` or a NaN validation loss).
    """
    global seq_list, batch_list, batch_labels, order_list, dataset_list, label_list, coord_list, input_genes, target_genes_idx, \
        input_genes_idx, all_genes_idx, label_target_genes_idx, label_all_genes_idx, label_input_genes_idx

    if config is None:
        raise ValueError('main() requires a config mapping')

    # Fall back to CPU so the script can still run without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    config = dict(config)
    val_num = config['val_num']

    model = OmicsFormer(**config)
    if not args.from_scratch:
        pretrained_file = f'../ckpt/{args.pre_model}.best.ckpt'
        # Load onto CPU: load_state_dict copies values into the (still CPU)
        # model below, so the checkpoint's original device does not matter.
        pretrained_model_dict = torch.load(pretrained_file, map_location='cpu')['model_state_dict']
        model_dict = model.state_dict()
        # Only take parameters whose name and shape match this model; any other
        # entry keeps the value from the newly constructed model.
        pretrained_dict = {
            k: v
            for k, v in pretrained_model_dict.items()
            if k in model_dict and v.shape == model_dict[k].shape
        }
        for k, v in pretrained_dict.items():
            print(f"Loading params {k} with shape {v.shape}")
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

    # NOTE: because of the ``global`` statement above, these rebind the
    # module-level names to device tensors.
    input_genes_idx = torch.tensor(input_genes_idx).to(device)
    all_genes_idx = torch.tensor(all_genes_idx).to(device)
    target_genes_idx = torch.tensor(target_genes_idx).to(device)
    label_target_genes_idx = torch.tensor(label_target_genes_idx).to(device)
    label_all_genes_idx = torch.tensor(label_all_genes_idx).to(device)

    model.objective.layers[0].downstream = 'imputation'
    print(model)
    model.to(device)
    optim = torch.optim.AdamW(model.parameters(), lr=config['lr'], weight_decay=config['wd'])
    scheduler = ReduceLROnPlateau(optim, 'min', patience=config['patience'], factor=0.9)

    def _batch(i, gene_mask, y_gene_mask):
        """Move batch ``i`` to ``device`` and wrap it in an ``XDict``."""
        return XDict({'x_seq': seq_list[i].to(device),
                      'coord': coord_list[i].to(device),
                      'label': label_list[i].to(device),
                      'gene_mask': gene_mask,
                      'y_gene_mask': y_gene_mask})

    train_loss = []
    valid_loss = []
    test_loss = []
    best_state = None

    for epoch in tqdm(range(config['epochs'])):
        if args.zeroshot:
            break
        # Linear learning-rate warm-up over the first 5 epochs.
        if epoch < 5:
            for param_group in optim.param_groups:
                param_group['lr'] = config['lr'] * (epoch + 1) / 5

        # Training: every batch (query and ref) is masked to the input genes.
        epoch_loss = []
        model.train()
        for i in range(len(seq_list)):
            x_dict = _batch(i, input_genes_idx, label_input_genes_idx)
            _, loss = model(x_dict, input_genes)
            optim.zero_grad()
            loss.backward()
            optim.step()
            epoch_loss.append(loss.item())
        train_loss.append(math.sqrt(sum(epoch_loss) / len(epoch_loss)))

        # Validation on the last ``val_num`` batches: indices -1 .. -val_num
        # (``-0`` would wrongly select the first batch).
        # NOTE(review): these batches are also part of the training loop above.
        model.eval()
        with torch.no_grad():
            epoch_loss = []
            for i in range(1, val_num + 1):
                x_dict = _batch(-i, input_genes_idx, label_input_genes_idx)
                _, loss = model(x_dict, input_genes)
                epoch_loss.append(loss.item())
            valid_loss.append(math.sqrt(sum(epoch_loss) / len(epoch_loss)))

        scheduler.step(valid_loss[-1])
        print(f'Epoch {epoch} | Train loss: {train_loss[-1]:.4f} | Valid loss: {valid_loss[-1]:.4f}')

        if min(valid_loss) == valid_loss[-1]:
            best_state = deepcopy(model.state_dict())

        # Periodic target-gene evaluation on query batches only.
        if epoch % 5 == 0:
            with torch.no_grad():
                epoch_loss = []
                pred_list = []
                y_list = []
                for i in range(len(seq_list)):
                    if dataset_list[i] == 'ref':
                        continue
                    x_dict = _batch(i, target_genes_idx, label_target_genes_idx)
                    out_dict, _ = model(x_dict, input_genes)
                    pred = out_dict['pred']
                    # Per-cell size factor that scales the label row sum to 1e4.
                    sf = 1e4 / x_dict['label'].sum(1, keepdim=True)
                    y = x_dict['label'][:, label_target_genes_idx]
                    # Raw-scale MSE restricted to non-zero ground-truth entries.
                    loss = F.mse_loss(pred[y != 0], y[y != 0])
                    epoch_loss.append(loss.item())
                    pred_list.append(torch.log1p(pred * sf))
                    y_list.append(torch.log1p(y * sf))
                test_loss.append(math.sqrt(sum(epoch_loss) / len(epoch_loss)))
                scores = imputation_eval(torch.cat(pred_list), torch.cat(y_list))
                print(f'test: {scores}')

    # Inference
    if not args.zeroshot:
        if best_state is None:
            raise RuntimeError('no best checkpoint was recorded; check config["epochs"] '
                               'and the validation loss')
        model.load_state_dict(best_state)
    # Required in the zero-shot path, where the training loop never ran.
    model.eval()
    pred_list = []
    y_list = []
    emb_list = []  # latent outputs; collected but not consumed in this function
    with torch.no_grad():
        for i in range(len(seq_list)):
            x_dict = _batch(i, target_genes_idx, label_target_genes_idx)
            out_dict, _ = model(x_dict, input_genes)
            sf = 1e4 / x_dict['label'].sum(1, keepdim=True)
            y = torch.log1p(x_dict['label'][:, label_target_genes_idx] * sf)
            pred_list.append(torch.log1p(out_dict['pred'] * sf))
            y_list.append(y)
            emb_list.append(out_dict['latent'].cpu())
    # NOTE(review): unlike the periodic 'test' evaluation, these final scores
    # also include 'ref' batches — confirm that is intended.
    scores = imputation_eval(torch.cat(pred_list), torch.cat(y_list))
    print(f'final: {scores}')

def create_sparse_tensor(x, i):
    """Build the ``i``-th sample as a coalesced float32 sparse COO tensor.

    ``x`` holds four parallel sequences of CSR components: crow indices,
    column indices, values and shapes (each shape must support ``.tolist()``).
    The CSR tensor is assembled first, then converted to COO layout.
    """
    crow_indices = x[0][i]
    col_indices = x[1][i]
    values = x[2][i]
    shape = x[3][i].tolist()
    csr = torch.sparse_csr_tensor(crow_indices, col_indices, values, shape)
    coo = csr.to_sparse()
    return coo.float().coalesce()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default='imputation')
    # Restricting choices avoids an unbound query/ref dataset further down.
    parser.add_argument("--dataset", type=str, default='lung', choices=['lung', 'liver'])
    parser.add_argument("--pre_model", type=str, default='20230926_75M_50M')
    parser.add_argument("--seed", type=int, default=10)
    parser.add_argument("--tune", action='store_true')
    parser.add_argument("--from_scratch", action='store_true')
    parser.add_argument("--zeroshot", action='store_true')
    args = parser.parse_args()
    set_seed(args.seed)
    torch.set_num_threads(12)

    # Data Setup
    task = args.task
    # (query file, reference file) for each --dataset choice.
    dataset_files = {
        'lung': ('HumanLungCancerPatient2_filtered_ensg.h5ad', 'GSE131907_Lung_ensg.h5ad'),
        'liver': ('HumanLiverCancerPatient2_filtered_ensg.h5ad', 'GSE151530_Liver_ensg.h5ad'),
    }
    query_dataset, ref_dataset = dataset_files[args.dataset]
    query_data = ad.read_h5ad(f'../data/{query_dataset}')
    ref_data = ad.read_h5ad(f'../data/{ref_dataset}')

    query_data = query_data[query_data.obs.fov.isin(range(1, 101))]  # subset first 100 fovs

    # NOTE: pickle.load executes arbitrary code from the file; only use
    # trusted checkpoint configs.
    with open(f"../ckpt/{args.pre_model}.config.pkl", "rb") as openfile:
        config = pickle.load(openfile)
    pretrained_gene_list = config['gene_list']
    # Set lookup keeps the vocabulary filter linear in the number of genes.
    pretrained_genes = set(pretrained_gene_list)
    ref_gene_list = [x for x in ref_data.var.index if x in pretrained_genes]
    query_gene_list = [x for x in query_data.var.index if x in pretrained_genes]
    ref_data = ref_data[:, ref_gene_list]
    query_data = query_data[:, query_gene_list]

    sc.pp.filter_genes(ref_data, min_cells=1)
    sc.pp.filter_cells(ref_data, min_genes=1)
    sc.pp.filter_genes(query_data, min_cells=1)
    sc.pp.filter_cells(query_data, min_genes=1)

    # Intersect the genes that survived filtering. The pre-filter name lists
    # may still contain genes filter_genes just removed. Keep the query's gene
    # order so the result is deterministic; iterating a set of strings gives a
    # different order from one interpreter run to the next.
    ref_genes_kept = set(ref_data.var.index)
    query_gene_list = [x for x in query_data.var.index if x in ref_genes_kept]
    query_data = query_data[:, query_gene_list]
    ref_data = ref_data[:, query_gene_list]

    ref_data.obs['batch'] = ref_data.obs['Sample'].astype('str')
    query_data.obs['batch'] = query_data.obs['fov'].astype('str')
    query_data.obs['x_FOV_px'] = query_data.obs['center_x']
    query_data.obs['y_FOV_px'] = query_data.obs['center_y']
    query_data.obs['platform'] = 'cosmx'
    ref_data.obs['platform'] = '10x'
    ref_data.obs['cell_type'] = 'NA'
    ref_data.obs['Dataset'] = 'ref'
    query_data.obs['Dataset'] = 'query'

    data = query_data.concatenate(ref_data, join='inner', batch_key='null', index_unique=None)
    sc.pp.filter_cells(data, min_genes=1)
    print(data.shape)
    obs = data.obs

    target_genes = stratified_sample_genes_by_sparsity(query_data, seed=args.seed)
    del ref_data, query_data
    label_target_genes_idx = target_genes_idx = np.concatenate([np.where(data.var.index == x)[0] for x in target_genes])
    target_idx_set = set(target_genes_idx.tolist())
    label_input_genes_idx = input_genes_idx = [i for i in range(data.shape[1]) if i not in target_idx_set]
    label_all_genes_idx = all_genes_idx = np.arange(data.shape[1])
    input_genes = data.var.index[input_genes_idx].to_list()
    input_data = data[:, input_genes_idx]
    seq_list, batch_list, batch_labels, order_list, dataset_list, coord_list, _ = data_setup(input_data)
    seq_list = [create_sparse_tensor(seq_list, i) for i in range(len(seq_list[0]))]
    # Dense per-batch labels over all kept genes, in data_setup's batch order.
    label_list = [torch.from_numpy(data.X[i].todense()).float() for i in order_list]

    # Position of each gene in the pretrained vocabulary. setdefault keeps the
    # first occurrence, matching list.index semantics.
    pretrained_gene_pos = {}
    for pos, gene in enumerate(pretrained_gene_list):
        pretrained_gene_pos.setdefault(gene, pos)
    all_genes_idx = np.array([pretrained_gene_pos[x] for x in data.var.index])
    target_genes_idx = all_genes_idx[target_genes_idx]
    input_genes_idx = all_genes_idx[input_genes_idx]

    config['mask_node_rate'] = 0.95
    config['mask_feature_rate'] = 0.25
    config['epochs'] = 400
    config['lr'] = 1e-3
    config['wd'] = 1e-6
    config['patience'] = 10
    config['val_num'] = 1
    main(config)

