import time
import scanpy as sc
import numpy as np
import argparse
import warnings
import os
import gc
import argparse
import json
import random
import math
import random
from functools import reduce
import numpy as np
import pandas as pd
from scipy import sparse
from sklearn.model_selection import train_test_split, ShuffleSplit, StratifiedShuffleSplit, StratifiedKFold
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, precision_recall_fscore_support, classification_report
import torch
from torch import nn
from torch.optim import Adam, SGD, AdamW
from torch.nn import functional as F
from performer_pytorch import PerformerLM
import scanpy as sc
import anndata as ad
from general_utils import *
import pickle as pkl
# Suppress every warning category for the whole run.
warnings.filterwarnings('ignore')

# Command-line interface. All three options are required strings.
# NOTE: --model_path is used further down as a raw string prefix (file names
# are appended with `+`), so it has to end with a path separator.
parser = argparse.ArgumentParser(description='Process single-cell data.')
for _flag, _help in (
    ('--input_adata', 'Path to input .h5ad file'),
    ('--output_adata', 'Path to output .npy file'),
    ('--model_path', 'Path to scBERT pretrained model checkpoint'),
):
    parser.add_argument(_flag, type=str, required=True, help=_help)
args = parser.parse_args()

# Load the query dataset and the Panglao reference. The reference's
# var_names define the column order of the matrix built below.
adams = sc.read(args.input_adata)
panglao = sc.read_h5ad(args.model_path+'panglao_10000.h5ad')

# Empty (cells x reference genes) matrix; LIL format so columns can be
# assigned one at a time.
counts = sparse.lil_matrix((adams.X.shape[0],panglao.X.shape[1]),dtype=np.float32)
ref = panglao.var_names.tolist()
obj = adams.var_names.tolist()

# Gene name -> column in the query matrix. Building the mapping once makes
# each lookup O(1) instead of scanning the `obj` list for every reference
# gene. `setdefault` keeps the FIRST occurrence of a duplicated name, which
# is what `list.index` returned.
obj_col = {}
for col, gene in enumerate(obj):
    obj_col.setdefault(gene, col)

# Copy every reference gene that is present in the query into its reference
# column; genes missing from the query stay all-zero.
for i, gene in enumerate(ref):
    loc = obj_col.get(gene)
    if loc is not None:
        counts[:,i] = adams.X[:,loc]

counts = counts.tocsr()

# Wrap the aligned counts in a new AnnData carrying the query's cell
# annotations and the reference's unstructured metadata.
data = ad.AnnData(X=counts)
data.var_names = ref
data.obs_names = adams.obs_names
data.obs = adams.obs
data.uns = panglao.uns

# Preprocessing: drop cells with fewer than 200 detected genes, scale each
# cell to a total of 1e4, then apply log2(1 + x).
# NOTE(review): filter_cells may remove cells, so later outputs can have
# fewer entries than the input file has cells — confirm consumers expect it.
sc.pp.filter_cells(data, min_genes=200)
sc.pp.normalize_total(data, target_sum=1e4)
sc.pp.log1p(data, base=2)

# Run-level constants.
SEED = 2021
EPOCHS = 100
POS_EMBED_USING = True

# Token vocabulary size; CLASS - 2 is also the ceiling expression values are
# clipped to before being fed to the model.
# NOTE(review): presumably 5 expression bins plus 2 extra token ids — confirm.
CLASS = 5 + 2

# One token per gene plus one extra position appended to every sequence.
SEQ_LEN = data.n_vars + 1

# Confidence threshold below which a prediction is reported as unassigned;
# a threshold of 0 disables that behaviour.
UNASSIGN = False
if UNASSIGN:
    UNASSIGN_THRES = 0.5
else:
    UNASSIGN_THRES = 0

# Prefer the GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def _read_pickle(path):
    """Unpickle and return the object stored at *path*.

    NOTE: unpickling executes arbitrary code embedded in the file; only
    point this at files from a trusted source.
    """
    with open(path, 'rb') as handle:
        return pkl.load(handle)


# `label_dict` is indexed with predicted class ids later on and `label` holds
# the labels the class frequencies are computed from.
# NOTE(review): both are presumably numpy arrays saved at training time — confirm.
label_dict = _read_pickle(args.model_path+'label_dict')
label = _read_pickle(args.model_path+'label')

# Per-class weight (1 - frequency) ** 2, so rarer classes weigh more.
_, _class_counts = np.unique(label, return_counts=True)
class_num = _class_counts.tolist()
_n_labels = sum(class_num)
class_weight = torch.tensor([(1 - (x / _n_labels)) ** 2 for x in class_num])
label = torch.from_numpy(label)

# From here on `data` is the bare expression matrix, not the AnnData object.
data = data.X

# Performer language model over binned expression tokens.
model = PerformerLM(
    num_tokens = CLASS,
    dim = 200,
    depth = 6,
    max_seq_len = SEQ_LEN,
    heads = 10,
    local_attn_heads = 0,
    g2v_position_emb = True
)

# Swap the output projection for a head with one output per known label.
# NOTE(review): `Identity` is not defined in the visible code; presumably it
# comes from the `general_utils` star import — confirm.
model.to_out = Identity(dropout=0., h_dim=128, out_dim=label_dict.shape[0])
path = args.model_path+'panglao_pretrain.pth'

# Load onto the CPU first: without map_location, a checkpoint saved from a
# CUDA device cannot be deserialized on a machine with no GPU, even though
# `device` falls back to "cpu". The model is moved to `device` below.
ckpt = torch.load(path, map_location='cpu')
model.load_state_dict(ckpt['model_state_dict'])

# Inference only: freeze every parameter.
for param in model.parameters():
    param.requires_grad = False

model = model.to(device)

# Despite its name, `batch_size` is the number of cells: the loop below runs
# one forward pass per cell.
batch_size = data.shape[0]
model.eval()
pred_finals = []      # predicted class id for every cell, in row order
novel_indices = []    # rows whose top probability is below UNASSIGN_THRES

_max_token = CLASS - 2            # ceiling applied to expression values
_extra_token = torch.tensor([0])  # single 0 appended to every sequence

with torch.no_grad():
    for index in range(batch_size):
        # Densify one row, clip it, truncate to integer token ids and append
        # the extra trailing token.
        tokens = data[index].toarray()[0]
        tokens[tokens > _max_token] = _max_token
        tokens = torch.cat((torch.from_numpy(tokens).long(), _extra_token))
        tokens = tokens.to(device).unsqueeze(0)  # add the batch dimension

        pred_prob = F.softmax(model(tokens), dim=-1)
        pred_finals.append(pred_prob.argmax(dim=-1).item())

        # Record the cell as novel when even the best class is not confident.
        if pred_prob.max().item() < UNASSIGN_THRES:
            novel_indices.append(index)

# Translate class ids to label names, replacing low-confidence rows with the
# 'Unassigned' marker, then write the result with numpy.
_unassigned_rows = set(novel_indices)
pred_list = [
    'Unassigned' if row in _unassigned_rows else name
    for row, name in enumerate(label_dict[pred_finals].tolist())
]

np.save(args.output_adata, pred_list)
