import anndata as ad
import hdf5plugin
import scipy
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import numpy as np
import pandas as pd
import math
import logging
import scanpy as sc
import pickle
import argparse
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau
from tqdm import tqdm
from copy import deepcopy
from dance.utils import set_seed
from scipy.sparse import csr_matrix
from CellPLM.utils.eval import downstream_eval
from CellPLM.utils.data import XDict, clean_batches
from CellPLM.utils.mask import InputDropoutMaskBuilder
from CellPLM.model import OmicsFormer


def _make_x_dict(x_seq, cell_idx, batch_no, device):
    """Wrap one minibatch of batch ``batch_no`` into the ``XDict`` fed to the model.

    ``x_seq`` is the (already masked) minibatch input; ``cell_idx`` holds the
    row indices of that minibatch and is used to pick the matching rows of the
    module-level ``coord_list`` and ``label_list``. Every tensor is moved to
    ``device``.
    """
    return XDict({
        'x_seq': x_seq.to(device),
        'coord': coord_list[batch_no][cell_idx].to(device),
        'label': label_list[batch_no][cell_idx].to(device),
        'gene_mask': all_genes_idx.to(device),
    })


def _predict_batch(model, batch_no, device):
    """Run ``model`` over every validation minibatch of batch ``batch_no``.

    Reads the module-level ``valid_list`` / ``valid_batch_idx``. The caller is
    responsible for ``model.eval()`` and ``torch.no_grad()``.

    Returns:
        ``(pred, label, losses)`` where ``pred`` and ``label`` are the
        concatenated minibatch outputs indexed by the concatenated minibatch
        indices, and ``losses`` is the list of per-minibatch loss values.
    """
    preds, labels, cell_order, losses = [], [], [], []
    for x_seq, cell_idx in zip(valid_list[batch_no], valid_batch_idx[batch_no]):
        x_dict = _make_x_dict(x_seq, cell_idx, batch_no, device)
        out_dict, loss = model(x_dict, gene_list)
        preds.append(out_dict['pred'])
        labels.append(x_dict['label'])
        cell_order.append(cell_idx)
        losses.append(loss.item())
    cell_order = torch.cat(cell_order)
    return torch.cat(preds)[cell_order], torch.cat(labels)[cell_order], losses


def main(task=None, config=None):
    """Fine-tune (or zero-shot evaluate) an ``OmicsFormer`` on the denoising task.

    The function relies on module-level state prepared by the script entry
    point: ``args``, ``gene_list``, ``all_genes_idx``, ``seq_list``,
    ``train_list``, ``train_batch_idx``, ``valid_list``, ``valid_batch_idx``,
    ``coord_list``, ``label_list``, ``valid_mask_list`` and ``test_mask_list``.

    Args:
        task: Accepted for call compatibility only; the task is always forced
            to ``'denoising'``.
        config: Keyword arguments for ``OmicsFormer``. The keys ``'lr'``,
            ``'wd'``, ``'patience'`` and ``'epochs'`` are additionally read
            here to configure the optimiser, scheduler and training length.
            Must not be ``None`` in practice, since it is unpacked with ``**``.

    Returns:
        None. Per-epoch losses/metrics and the final scores are printed.
    """
    # This script only handles denoising, so any caller-supplied task is ignored.
    task = 'denoising'

    device = torch.device('cuda')

    model = OmicsFormer(**config)
    if not args.from_scratch:
        pretrained_file = f'../ckpt/{args.pre_model}.best.ckpt'
        pretrained_model_dict = torch.load(pretrained_file)['model_state_dict']
        try:
            model.load_state_dict(pretrained_model_dict)
        except RuntimeError:
            # Strict loading failed (missing/unexpected keys or shape mismatch):
            # fall back to loading only params that are in the model and match the size.
            model_dict = model.state_dict()
            pretrained_dict = {
                k: v
                for k, v in pretrained_model_dict.items()
                if k in model_dict and v.shape == model_dict[k].shape
            }
            for k, v in pretrained_dict.items():
                print(f"Loading params {k} with shape {v.shape}")
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)
    model.objective.layers[0].downstream = 'denoising'
    print(model)
    model.to(device)
    optim = torch.optim.AdamW(model.parameters(), lr=config['lr'], weight_decay=config['wd'])
    scheduler = ReduceLROnPlateau(optim, 'min', patience=config['patience'], factor=0.9)

    num_batches = len(seq_list)
    train_loss = []
    valid_loss = []
    valid_metric = []
    best_state = None  # state dict of the epoch with the lowest validation RMSE
    for epoch in tqdm(range(config['epochs'])):
        model.train()
        if args.zeroshot:
            # Zero-shot: skip fine-tuning entirely and go straight to inference.
            break

        # ---- training pass ----
        batch_loss = []
        for i in range(num_batches):
            minibatch_loss = []
            for x_seq, cell_idx in zip(train_list[i], train_batch_idx[i]):
                x_dict = _make_x_dict(x_seq, cell_idx, i, device)
                out_dict, loss = model(x_dict, gene_list)
                optim.zero_grad()
                loss.backward()
                optim.step()
                minibatch_loss.append(loss.item())
            batch_loss.append(sum(minibatch_loss) / len(minibatch_loss))
        train_loss.append(sum(batch_loss) / len(batch_loss))
        # The LR schedule is driven by the training loss, not the validation metric.
        scheduler.step(train_loss[-1])

        # ---- validation pass ----
        with torch.no_grad():
            model.eval()
            batch_loss = []
            # Collected across ALL batches so the model-selection metric is the
            # mean validation RMSE, not just the value of the last batch.
            valid_epoch = []
            for i in range(num_batches):
                pred, label, minibatch_loss = _predict_batch(model, i, device)
                batch_loss.append(sum(minibatch_loss) / len(minibatch_loss))

                # Same predictions, scored on the validation and test masks respectively.
                valid_scores = downstream_eval(task, pred, label, eval_mask=valid_mask_list[i])
                valid_epoch.append(valid_scores['rmse'])
                test_scores = downstream_eval(task, pred, label, eval_mask=test_mask_list[i])

        valid_loss.append(sum(batch_loss) / len(batch_loss))
        valid_metric.append(sum(valid_epoch) / len(valid_epoch))

        # The printed scores are those of the last batch processed above.
        print(f'Epoch {epoch} | Train loss: {train_loss[-1]:.4f} | Valid loss: {valid_loss[-1]:.4f}')
        print(f'Valid RMSE: {valid_scores["rmse"]:.4f} | Test RMSE: {test_scores["rmse"]:.4f}')
        print(f'Valid Corr: {valid_scores["corr"]:.4f} | Test Corr: {test_scores["corr"]:.4f}')

        if min(valid_metric) == valid_metric[-1]:
            best_state = deepcopy(model.state_dict())

    # Inference
    # best_state stays None in zero-shot mode or when no epoch recorded a best
    # checkpoint; in that case the current weights are evaluated as-is.
    if best_state is not None:
        model.load_state_dict(best_state)
    final_pred = []
    final_label = []
    model.eval()
    with torch.no_grad():
        for i in range(num_batches):
            pred, label, _ = _predict_batch(model, i, device)
            final_pred.append(pred.cpu())
            final_label.append(label.cpu())
    del model
    torch.cuda.empty_cache()

    pred = torch.cat(final_pred)
    label = torch.cat(final_label)
    # NOTE(review): only the first batch's test mask is used while predictions
    # are concatenated over all batches -- fine for a single batch; confirm
    # before running with several.
    scores = downstream_eval(task, pred, label, eval_mask=test_mask_list[0])
    print(scores)
    del pred, label
    torch.cuda.empty_cache()

def create_sparse_tensor(x):
    """Convert a SciPy CSR matrix into a float32 sparse torch tensor.

    The CSR buffers of ``x`` (``indptr``, ``indices``, ``data``) are handed to
    ``torch.sparse_csr_tensor``; the result is then converted with
    ``to_sparse()`` and cast with ``float()``.
    """
    n_rows, n_cols = x.shape[0], x.shape[1]
    csr_tensor = torch.sparse_csr_tensor(
        x.indptr,
        x.indices,
        x.data,
        (n_rows, n_cols),
    )
    sparse_tensor = csr_tensor.to_sparse()
    return sparse_tensor.float()

if __name__ == '__main__':
    # Supported datasets: name -> (h5ad file under ../data, var column that
    # becomes the new var index).
    dataset_specs = {
        '5k_pbmc': ('5k_pbmc_filtered.h5ad', 'gene_ids'),
        'jurkat': ('jurkat_filtered.h5ad', 'ENSG'),
        '293t': ('293t_filtered.h5ad', 'ENSG'),
    }

    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default='denoising')
    # Restricting the choices turns an unknown dataset into a clear CLI error
    # instead of a NameError on an undefined ``data`` further down.
    parser.add_argument("--dataset", type=str, default='5k_pbmc', choices=list(dataset_specs))
    parser.add_argument("--from_scratch", action='store_true')
    parser.add_argument("--pre_model", type=str, default='20230926_75M_50M')
    parser.add_argument("--batch_size", type=int, default=70000)
    parser.add_argument("--patience", type=int, default=10)
    parser.add_argument("--seed", type=int, default=10)
    parser.add_argument("--tune", action='store_true')
    parser.add_argument("--zeroshot", action='store_true')
    args = parser.parse_args()
    set_seed(args.seed)
    batch_size = args.batch_size
    # NOTE: --patience is parsed but the scheduler patience actually used is
    # the hard-coded config['patience'] set at the bottom of this block.
    patience = args.patience
    torch.set_num_threads(12)

    # Data Setup
    task = args.task
    dataset_name, var_index_column = dataset_specs[args.dataset]
    data = ad.read_h5ad(f'../data/{dataset_name}')
    data.obs['batch'] = 0  # every cell is placed in a single batch
    data.var.index = data.var[var_index_column].tolist()
    data.var_names_make_unique()

    # NOTE(security): pickle.load executes arbitrary code embedded in the file;
    # only load config files from a trusted checkpoint directory.
    with open(f"../ckpt/{args.pre_model}.config.pkl", "rb") as openfile:
        config = pickle.load(openfile)
    pretrained_gene_list = config['gene_list']

    # Keep only genes known to the pretrained model (set lookup instead of a
    # linear list scan per gene).
    pretrained_gene_set = set(pretrained_gene_list)
    gene_list = [x for x in data.var.index.to_list() if x in pretrained_gene_set]
    data = data[:, gene_list]
    sc.pp.filter_genes(data, min_cells=data.shape[0] * 0.05)
    sc.pp.filter_cells(data, min_counts=1)
    # NOTE(review): gene_list is not refreshed after filter_genes, so it may
    # contain genes no longer present in ``data`` -- confirm the model expects
    # the pre-filter list.
    data.raw = data
    print(data.shape)

    # Position of each remaining gene in the pretrained gene list; setdefault
    # keeps the first occurrence, matching list.index semantics.
    pretrained_gene_pos = {}
    for pos, gene in enumerate(pretrained_gene_list):
        pretrained_gene_pos.setdefault(gene, pos)
    all_genes_idx = torch.tensor([pretrained_gene_pos[x] for x in data.var.index])

    mask_builder = InputDropoutMaskBuilder(input_drop_type="mar", valid_drop_rate=0.1,
                                           test_drop_rate=0.1, seed=args.seed)
    order = np.arange(data.shape[0])
    batch_labels = LabelEncoder().fit_transform(data.obs['batch'])
    train_batch_idx = []
    valid_batch_idx = []
    train_list = []
    valid_list = []
    valid_mask_list = []
    test_mask_list = []
    seq_list = []
    batch_list = []
    order_list = []
    coord_list = []
    label_list = []

    for batch in tqdm(range(batch_labels.max() + 1)):
        in_batch = batch_labels == batch
        x_raw = data[in_batch].raw.X.astype(float)
        x = data[in_batch].X.astype(float)
        train_mask, valid_mask, test_mask = mask_builder.apply_mask(create_sparse_tensor(x_raw))

        # The train-masked input is the same for every minibatch, so densify
        # and mask it once per batch instead of once per minibatch.
        x_masked = csr_matrix(x.toarray() * train_mask)

        # Shuffled minibatches for training; the row indices are stored so the
        # matching coords/labels can be looked up later.
        train_batch = []
        train_minibatch_list = []
        train_loader = DataLoader(range(x.shape[0]), batch_size=batch_size, shuffle=True)
        for minibatch in train_loader:
            train_batch.append(create_sparse_tensor(x_masked[minibatch]))
            train_minibatch_list.append(minibatch)
        train_list.append(train_batch)
        train_batch_idx.append(train_minibatch_list)

        # Validation uses the same train-masked input, in original row order.
        valid_batch = []
        valid_minibatch_list = []
        valid_loader = DataLoader(range(x.shape[0]), batch_size=batch_size, shuffle=False)
        for minibatch in valid_loader:
            valid_batch.append(create_sparse_tensor(x_masked[minibatch]))
            valid_minibatch_list.append(minibatch)
        valid_list.append(valid_batch)
        valid_batch_idx.append(valid_minibatch_list)

        valid_mask_list.append(valid_mask)
        test_mask_list.append(test_mask)

        seq_list.append(create_sparse_tensor(x_masked))
        order_list.append(order[in_batch])
        batch_list.append(torch.from_numpy(batch_labels[in_batch]))
        # Every coordinate is filled with -1.
        coord_list.append(torch.zeros(x.shape[0], 2) - 1)
        # The unmasked matrix is the denoising target (.toarray() replaces the
        # deprecated sparse ``.A`` attribute).
        label_list.append(torch.from_numpy(x_raw.toarray()))

    del data, order, x, x_masked

    # Fine-tuning hyper-parameters; these override whatever the pickled
    # pretraining config contained for the same keys.
    config['mask_node_rate'] = 0.95
    config['mask_feature_rate'] = 0.25
    config['epochs'] = 2000
    config['lr'] = 1e-3
    config['wd'] = 1e-7
    config['patience'] = 200.
    main(task, config)

