import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
import scanpy as sc
import pandas as pd
import pickle

import timm
import torch
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from huggingface_hub import login

import time
import skimage.io
import numpy as np
import pandas as pd
import cv2
import PIL.Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SubsetRandomSampler, RandomSampler, SequentialSampler
from sklearn.model_selection import StratifiedKFold
import matplotlib.pyplot as plt
from sklearn.metrics import cohen_kappa_score
from tqdm import tqdm_notebook as tqdm


import argparse
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
from numpy.random import uniform, choice, normal
from torch import nn, optim, Tensor, manual_seed, argmax
from torch.utils.data import TensorDataset, DataLoader
from torchmetrics.classification import Accuracy, MulticlassConfusionMatrix
from pytorch_lightning.utilities.model_summary import ModelSummary
import pytorch_lightning as pl
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split
from torch.autograd import Variable
import pandas as pd
import torch

import os
import torch
os.environ['HF_HOME'] = '/home/username/scratch/'
os.environ['HF_TOKEN'] = '' #fill your token

import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from huggingface_hub import login

from PIL import Image
from matplotlib import cm
import numpy as np

import pickle

import timm
from PIL import Image
from torchvision import transforms
import torch

# Architecture arguments for the UNI2-h checkpoint (used here as the example
# backbone); they are forwarded unchanged to timm.create_model.
timm_kwargs = {
    'img_size': 224,
    'patch_size': 14,
    'depth': 24,
    'num_heads': 24,
    'init_values': 1e-5,
    'embed_dim': 1536,
    'mlp_ratio': 2.66667*2,
    'num_classes': 0,
    'no_embed_class': True,
    'mlp_layer': timm.layers.SwiGLUPacked,
    'act_layer': torch.nn.SiLU,
    'reg_tokens': 8,
    'dynamic_img_size': True,
}

# pretrained=True is needed to load the UNI weights (and to download them the
# first time the script runs).
model = timm.create_model("hf-hub:MahmoodLab/UNI2-h", pretrained=True, **timm_kwargs)

# Build the preprocessing transform from the checkpoint's own data config.
_data_config = resolve_data_config(model.pretrained_cfg, model=model)
transform = create_transform(**_data_config)

# Put the backbone in inference mode.
model.eval()


class REGDataset(Dataset):
    """Pairs precomputed image features with gene-expression targets.

    Args:
        adata: annotated expression object; it is column-subset to ``gene``
            and its observations are looked up by ``str(index)``.
        image: array indexed by integer position; each entry is converted
            with ``torch.from_numpy``.
        transform: stored on the instance but not applied in ``__getitem__``.
        gene: variable name (or list of names) selecting the target columns.

    Note: a later class of the same name in this file replaces this one.
    """

    def __init__(self, adata, image, transform, gene='NOSTRIN'):
        self.adata = adata[:, gene]
        self.image = image
        self.transform = transform

    def __len__(self):
        # One sample per observation of the gene-subset expression object.
        n_samples = self.adata.shape[0]
        return n_samples

    def __getitem__(self, index):
        features = torch.from_numpy(self.image[index])
        # Observations are addressed by the string form of the position.
        row = self.adata[str(index)]
        expression = torch.FloatTensor(row.X)
        return features, expression[0, :]

# Target gene panel (50 names). gene_list['genes'] is used below to select the
# expression columns for PCA, for the regression targets and for evaluation.
gene_list = {"genes": [
        "KRT15",
        "CRCT1",
        "RHCG",
        "IGHG1",
        "GGCT",
        "ASPRV1",
        "JCHAIN",
        "PI3",
        "SLURP1",
        "CA2",
        "SSFA2",
        "SPRR2E",
        "LCE3D",
        "MT1X",
        "IGHG4",
        "S100A7A",
        "HERC6",
        "WARS",
        "RPL22L1",
        "SDR16C5",
        "IL1RN",
        "MX1",
        "GGH",
        "CLEC2B",
        "CAST",
        "PTGS1",
        "IGFBP6",
        "PRSS8",
        "LAP3",
        "TXNL4A",
        "SERPINE2",
        "IGLC2",
        "KRT75",
        "C9orf3",
        "IGHG3",
        "SPRR2D",
        "IFIT1",
        "PPL",
        "LSM5",
        "TMEM256",
        "KRT2",
        "SPRR2G",
        "SNRPD1",
        "SPRR2A",
        "OAS1",
        "GBA",
        "GADD45GIP1",
        "LYZ",
        "SPINT1",
        "NAGK"
    ],}


# In[9]:


# Load the normalized expression splits.
adata_train = sc.read("/home/username/scratch/train_data/adata_skin_norm_train.h5ad")
adata_valid = sc.read("/home/username/scratch/train_data/adata_skin_norm_valid.h5ad")
adata_test = sc.read("/home/username/scratch/train_data/adata_skin_norm_test.h5ad")

# Rename every observation to its position ("0", "1", ...) so that the
# datasets below can look a sample up with str(index).
for _split in (adata_train, adata_valid, adata_test):
    _split.obs_names = list(map(str, range(len(_split))))


# Precomputed image feature arrays, one per split.
image_train = np.load("/home/username/scratch/train_data/skin_train_imagefeature_univ2.npy")
image_valid = np.load("/home/username/scratch/train_data/skin_valid_imagefeature_univ2.npy")
image_test = np.load("/home/username/scratch/train_data/skin_test_imagefeature_univ2.npy")

# Batch labels for the training split: the first 3838 observations are batch
# '0' and the next 3650 are batch '1'.
# NOTE(review): assumes adata_train has exactly 7488 observations in this
# order -- confirm against the data file.
adata_train.obs['batch'] = ['0'] * 3838 + ['1'] * 3650


# In[10]:


from sklearn.decomposition import PCA

# Fit the 20-component PCA on the training split only, then project the
# validation and test splits with the SAME fitted components. Re-fitting on
# each split would give every split its own axes, so the coordinates would
# not be comparable across splits.
pca = PCA(n_components=20)
adata_train.obsm['X_pca'] = pca.fit_transform(adata_train[:, gene_list['genes']].X)
adata_valid.obsm['X_pca'] = pca.transform(adata_valid[:, gene_list['genes']].X)
adata_test.obsm['X_pca'] = pca.transform(adata_test[:, gene_list['genes']].X)

# Minimal option namespace; only `scale` is consulted at this point.
parser = argparse.ArgumentParser()
args = argparse.Namespace()
args.scale = False

# Optional scaling of the training expression matrix (off by default).
if args.scale:
    sc.pp.scale(adata_train)

import sklearn.cluster

import sys
sys.path.append('/home/username/piusername/UNI/')
sys.path.append('/home/username/piusername/UNI/diff_sample_code/paper_code/')
from diff_sample_code.paper_code.utils import check_dir, prepare_dset, update_print, get_relative_maha_distance, maha, \
    get_pretrained_model, get_maha_distance, MahaDistNormalizer, ranking_loss

class MahaDistNormalizer:
    """Linear rescaler that tracks a running min/max over everything it has seen.

    Each call to ``run`` first widens the stored range with the extremes of
    the new input, then maps that stored range onto ``[left, right]``.

    Note: this definition shadows the ``MahaDistNormalizer`` imported from
    ``diff_sample_code.paper_code.utils`` earlier in the file.
    """

    def __init__(self):
        super().__init__()
        # Sentinels so the first call to run() always replaces both bounds.
        self.min = 1e20
        self.max = -1e10

    def run(self, x, left, right):
        """Update the running range with ``x`` and rescale ``x`` to [left, right].

        Args:
            x: array or tensor exposing ``.min()`` and ``.max()``.
            left: value the running minimum is mapped to.
            right: value the running maximum is mapped to.

        Returns:
            ``x`` mapped linearly so that the running min/max land on
            ``left``/``right``.
        """
        self.min = min(x.min(), self.min)
        self.max = max(x.max(), self.max)
        span = self.max - self.min
        scale = (right - left) / span
        return left + scale * (x - self.min)
    
    

def my_softmax(X):
    """Return ``exp(X - X.max())`` divided by its largest entry plus 0.001.

    Despite the name this is not a softmax: the denominator is the maximum
    exponential (plus a small constant), not the sum over a dimension. The
    maximum is taken over the whole tensor, so the largest output is
    ``1 / 1.001`` and every other entry is smaller.

    Args:
        X: tensor of scores. It is NOT modified; the shift by the maximum is
            done out of place so the caller's tensor stays intact.

    Returns:
        Tensor with the same shape as ``X``.
    """
    # Out-of-place shift: the previous `X -= X.max()` silently overwrote the
    # argument passed in by the caller.
    shifted = X - X.max()
    X_exp = shifted.exp()
    max_data = X_exp.max() + 0.001
    return X_exp / max_data

maha_normalizer = MahaDistNormalizer()

# Precomputed Mahalanobis statistics, stored as a pickled dict inside a 0-d
# object array.
# NOTE(security): allow_pickle=True executes pickle on load; only point this
# at a file you produced yourself or otherwise trust.
maha_intermediate_dict = np.load("/home/username/piusername/UNI/diff_sample_code/paper_code/ssl/maha_dict_univ2_skin_harmonyfix.npy", allow_pickle=True)

# Unwrap the dict once instead of calling .item() for every key.
_maha_stats = maha_intermediate_dict.item()
class_cov_invs = _maha_stats['class_cov_invs']
class_means = _maha_stats['class_means']
cov_invs = _maha_stats['cov_inv']
means = _maha_stats['mean']


# maha_normalizer.min = np.min(alldata)
# maha_normalizer.max = np.max(alldata)

from sklearn.decomposition import PCA

# Fit the 20-component PCA on the training split only and reuse the fitted
# components for validation and test. Calling fit_transform on each split
# would produce three unrelated coordinate systems, while the KMeans model
# below is fitted on the training coordinates and then applied to the others.
pca = PCA(n_components=20, random_state=0)
adata_train.obsm['X_pca'] = pca.fit_transform(adata_train[:, gene_list['genes']].X)
adata_valid.obsm['X_pca'] = pca.transform(adata_valid[:, gene_list['genes']].X)
adata_test.obsm['X_pca'] = pca.transform(adata_test[:, gene_list['genes']].X)

# With more than one training batch, replace the training PCA coordinates by
# their Harmony-integrated version; either way store them as float64.
# NOTE(review): validation/test coordinates are not Harmony-corrected --
# confirm this asymmetry is intended.
if len(set(adata_train.obs['batch'])) > 1:
    sc.external.pp.harmony_integrate(adata_train, key='batch')
    adata_train.obsm['X_pca'] = adata_train.obsm['X_pca_harmony'].astype('double')
else:
    adata_train.obsm['X_pca'] = adata_train.obsm['X_pca'].astype('double')

import sklearn.cluster
import numpy as np
from sklearn.datasets import make_blobs
import sklearn.cluster
from sklearn.cluster import KMeans
from sklearn.preprocessing import KBinsDiscretizer
import numpy as np
# Initialize the discretizer
from sklearn.preprocessing import KBinsDiscretizer
import numpy as np
# Initialize the discretizer
from sklearn.metrics import silhouette_samples, silhouette_score

# Pick the number of KMeans clusters (3..14) that maximizes the mean
# silhouette score on the training PCA coordinates.
# The former `n_bins = int(np.max(...)) + 1` line was removed: its value was
# overwritten below before ever being read.
_train_pca = adata_train.obsm['X_pca']
search_data = {}
for test_cluster in range(3, 15):
    kbd = sklearn.cluster.KMeans(n_clusters=test_cluster, random_state=0, n_init="auto").fit(_train_pca)
    pred_label = kbd.predict(_train_pca)
    avg_sil = silhouette_score(_train_pca, pred_label)
    search_data[test_cluster] = avg_sil

# First candidate with the highest score wins (np.argmax keeps the first max).
max_avg = np.argmax(list(search_data.values()))
n_bins = list(search_data.keys())[max_avg]
# n_bins = 3

# Refit with the selected cluster count and label every split with it.
kbd = sklearn.cluster.KMeans(n_clusters=n_bins, random_state=0, n_init="auto").fit(_train_pca)
print("optimal clustering number", n_bins)
adata_train.obsm['bin'] = kbd.predict(adata_train.obsm['X_pca'])
adata_valid.obsm['bin'] = kbd.predict(adata_valid.obsm['X_pca'].astype('double'))
adata_test.obsm['bin'] = kbd.predict(adata_test.obsm['X_pca'].astype('double'))


class REGDataset(Dataset):
    """Image features paired with expression targets and a cluster label.

    Args:
        adata: annotated expression object; it is column-subset to ``gene``
            and must carry ``obsm['bin']``. Observations are looked up by
            ``str(index)``.
        image: array indexed by integer position; each entry is converted
            with ``torch.from_numpy``.
        transform: stored on the instance but not applied in ``__getitem__``.
        gene: variable name (or list of names) selecting the target columns.
    """

    def __init__(self, adata, image, transform, gene='NOSTRIN'):
        self.adata = adata[:, gene]
        self.image = image
        self.transform = transform

    def __len__(self):
        # One sample per observation of the gene-subset expression object.
        n_samples = self.adata.shape[0]
        return n_samples

    def __getitem__(self, index):
        features = torch.from_numpy(self.image[index])
        # Observations are addressed by the string form of the position.
        row = self.adata[str(index)]
        expression = torch.FloatTensor(row.X)[0, :]
        cluster = torch.FloatTensor(row.obsm['bin']).long()
        return features, expression, cluster
    
    

# One dataset per split, all restricted to the target gene panel.
_panel = gene_list['genes']
train_dataset_new = REGDataset(adata_train, image_train, transform, _panel)
valid_dataset = REGDataset(adata_valid, image_valid, transform, _panel)
test_dataset = REGDataset(adata_test, image_test, transform, _panel)

# Sequential (unshuffled) loaders that keep the final partial batch.
_loader_options = dict(batch_size=512, shuffle=False, drop_last=False)
train_dataloader = DataLoader(train_dataset_new, **_loader_options)
valid_dataloader = DataLoader(valid_dataset, **_loader_options)
test_dataloader = DataLoader(test_dataset, **_loader_options)

# Relative Mahalanobis distances of all training features w.r.t. their
# cluster labels; the fixed-range normalizer defined next reads this array.
alldata = get_relative_maha_distance(
    image_train, cov_invs, class_cov_invs, means, class_means,
    adata_train.obsm['bin'],
)

class MahaDistNormalizer:
    """Linear rescaler with a range frozen to the extremes of ``alldata``.

    The bounds are read once from the module-level ``alldata`` array at
    construction time; ``run`` does not update them.
    """

    def __init__(self):
        super().__init__()
        self.min = np.min(alldata)
        self.max = np.max(alldata)

    def run(self, x, left, right):
        """Map ``x`` linearly so the stored min/max land on ``left``/``right``."""
        span = self.max - self.min
        k = (right - left) / span
        return left + k * (x - self.min)


maha_normalizer = MahaDistNormalizer()

class entropy_ce(nn.Module):
    """Cross-entropy blended with a prediction-entropy term.

    The loss is ``(1 - e_lambda) * mean(CE) - e_lambda * mean(H)``, where
    ``H`` is the entropy of the softmax over the class dimension.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x_input, y_target, e_lambda):
        """Compute the blended loss.

        Args:
            x_input: logits of shape (batch, classes).
            y_target: integer class indices of shape (batch,).
            e_lambda: weight of the entropy term.

        Returns:
            Scalar loss tensor.
        """
        # The class dimension is named explicitly; calling softmax without
        # `dim` is deprecated and emits a warning on every call.
        p = F.softmax(x_input, dim=1)
        log_p = F.log_softmax(x_input, dim=1)
        entropy = -torch.sum(p * log_p, dim=1).reshape(-1, 1)

        # One-hot width follows the logits so the loss does not depend on a
        # module-level configuration object.
        y_onehot = F.one_hot(y_target, num_classes=x_input.shape[1])
        ce = -torch.sum(log_p * y_onehot, 1)
        return (1 - e_lambda) * torch.mean(ce) - e_lambda * torch.mean(entropy)

import OrdinalEntropy


# Features and cluster labels used by the cross-validated probe below.
train_data = image_train
train_list = adata_train.obsm['bin']


import sklearn.linear_model

model = sklearn.linear_model.LogisticRegression()


import sklearn
import numpy as np
from sklearn.model_selection import KFold

nsplit = 3
kf = KFold(n_splits=nsplit, shuffle=True, random_state=2024)

# Collect the positions of training samples whose cluster label a
# logistic-regression probe, trained on the other folds, gets wrong.
wrong_sample = []
for itd, (train_id, test_id) in enumerate(kf.split(train_data)):
    x_train, train_label = train_data[train_id], train_list[train_id]
    x_val, val_label = train_data[test_id], train_list[test_id]

    model = sklearn.linear_model.LogisticRegression()
    model.fit(x_train, train_label)
    pred_label = model.predict(x_val)

    # Held-out positions where prediction and label disagree.
    mismatched = val_label != pred_label
    wrong_sample.extend(test_id[mismatched])

# Rebuild the per-split datasets.
train_dataset_new = REGDataset(adata_train, image_train, transform, gene_list['genes'])
valid_dataset = REGDataset(adata_valid, image_valid, transform, gene_list['genes'])
test_dataset = REGDataset(adata_test, image_test, transform, gene_list['genes'])

def generate_new_dataset(origina_data, label_list):
    """Materialize a dataset into tensors and add a per-sample 0/1 flag.

    Args:
        origina_data: iterable yielding ``(input, output, label)`` triples.
        label_list: positions (in iteration order) whose flag should be 1.

    Returns:
        ``TensorDataset(inputs, outputs, labels, flags)`` -- note the flag
        tensor comes last.
    """
    # Build the lookup set once: membership tests against the original list
    # cost O(len(label_list)) for every sample.
    flagged = {int(position) for position in label_list}

    input_list = []
    output_list = []
    array_list = []
    corr_list = []
    for idx, (i, j, k) in enumerate(origina_data):
        array_list.append(1 if idx in flagged else 0)
        input_list.append(i)
        output_list.append(j)
        corr_list.append(k)

    d1 = torch.tensor(np.array(input_list))
    d2 = torch.tensor(np.array(output_list))
    d3 = torch.tensor(np.array(array_list))
    d4 = torch.tensor(np.array(corr_list))
    return TensorDataset(d1, d2, d4, d3)

# Training set with the "misclassified by the probe" flag attached.
train_dataset_update = generate_new_dataset(train_dataset_new, sorted(set(wrong_sample)))

# Experiment configuration; this replaces the earlier `args` namespace.
parser = argparse.ArgumentParser()
args = argparse.Namespace(
    method='weight_ord_joint_wer',
    epsilon=1.0,
    alpha=0.1,  # 0.05 by default, 0.1 is best
    fgamma=1.0,
    epsilon_p=2.0,
    e_lambda=0.3,
    gene_num=50,
    class_num=n_bins,
    weight=1.0,
    lr=1e-3,
    T=1.0,
    d_lambda=1.0,
    class_weight=1.0,
    n_split=nsplit,
)

class entropy_ce(nn.Module):
    """Cross-entropy blended with a prediction-entropy term.

    The loss is ``(1 - e_lambda) * mean(CE) - e_lambda * mean(H)``, where
    ``H`` is the entropy of the softmax over the class dimension.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x_input, y_target, e_lambda):
        """Compute the blended loss.

        Args:
            x_input: logits of shape (batch, classes).
            y_target: integer class indices of shape (batch,).
            e_lambda: weight of the entropy term.

        Returns:
            Scalar loss tensor.
        """
        # The class dimension is named explicitly; calling softmax without
        # `dim` is deprecated and emits a warning on every call.
        p = F.softmax(x_input, dim=1)
        log_p = F.log_softmax(x_input, dim=1)
        entropy = -torch.sum(p * log_p, dim=1).reshape(-1, 1)

        # One-hot width follows the logits so the loss does not depend on a
        # module-level configuration object.
        y_onehot = F.one_hot(y_target, num_classes=x_input.shape[1])
        ce = -torch.sum(log_p * y_onehot, 1)
        return (1 - e_lambda) * torch.mean(ce) - e_lambda * torch.mean(entropy)
    
class weighted_poly(nn.Module):
    """Cross-entropy plus a per-sample weighted ``(1 - p_target)`` term."""

    def __init__(self):
        super().__init__()

    def forward(self, x_input, y_target, weight):
        """Compute ``mean(CE) + mean(weight * (1 - p_target))``.

        Args:
            x_input: logits of shape (batch, classes).
            y_target: integer class indices of shape (batch,).
            weight: one weight per sample; reshaped to (batch, 1).

        Returns:
            Scalar loss tensor.
        """
        weight = weight.reshape(-1, 1)

        # The class dimension is named explicitly; calling softmax without
        # `dim` is deprecated and emits a warning on every call.
        p = F.softmax(x_input, dim=1)
        # One-hot width follows the logits rather than a global setting.
        y_onehot = F.one_hot(y_target, num_classes=x_input.shape[1])
        pt = torch.sum(p * y_onehot, dim=1).reshape(-1, 1)
        poly_loss = weight * (1. - pt)

        log_p = F.log_softmax(x_input, 1)
        ce_loss = -torch.sum(log_p * y_onehot, 1)

        # ce_loss is (batch,) and poly_loss is (batch, 1); adding them first
        # broadcasts to a (batch, batch) matrix. Its mean equals the sum of
        # the two means, so compute that directly without the quadratic
        # intermediate.
        return torch.mean(ce_loss) + torch.mean(poly_loss)
    
class weighted_entropy_ce(nn.Module):
    """Cross-entropy minus a per-sample weighted prediction entropy."""

    def __init__(self):
        super().__init__()

    def forward(self, x_input, y_target, weight, e_lambda):
        """Compute ``mean(CE) - mean(e_lambda * weight * H)``.

        Args:
            x_input: logits of shape (batch, classes).
            y_target: integer class indices of shape (batch,).
            weight: one weight per sample; reshaped to (batch, 1).
            e_lambda: global scale applied to every sample weight.

        Returns:
            Scalar loss tensor.
        """
        weight = weight.reshape(-1, 1)

        # The class dimension is named explicitly; calling softmax without
        # `dim` is deprecated and emits a warning on every call.
        p = F.softmax(x_input, dim=1)
        log_p = F.log_softmax(x_input, dim=1)
        entropy = -torch.sum(p * log_p, dim=1).reshape(-1, 1)

        # Scale each sample's entropy by its own weight.
        weight_beta = e_lambda * weight
        entropy = weight_beta * entropy

        # One-hot width follows the logits rather than a global setting.
        y_onehot = F.one_hot(y_target, num_classes=x_input.shape[1])
        ce = -torch.sum(log_p * y_onehot, 1)
        return torch.mean(ce) - torch.mean(entropy)



import lightning
import glob
import shutil


class Model(pl.LightningModule):
    """Shared MLP trunk with a regression head and a classification head.

    The trunk maps 1536-d input features to 512-d; ``regress_head`` outputs
    ``args.gene_num`` non-negative values (Softplus) and ``class_head``
    outputs ``args.class_num`` logits. The class is defined once, outside the
    seed loop, because nothing in it depends on the seed.

    Args:
        k: stored as ``num_classes``; not used by the layers.
        optimizer: 'Adam' (learning rate ``args.lr``) or 'SGD' (0.1).
        dropout_rate: accepted for interface compatibility; unused.
        fullmodel: stored on the instance; unused by the methods here.
    """

    def __init__(self, k, optimizer='Adam', dropout_rate=0, fullmodel=False):
        super().__init__()
        self.num_classes = k
        # Layer creation order is kept so seeded initialization is unchanged.
        self.model = nn.Sequential(*[nn.Linear(1536, 512), nn.ReLU(), nn.Linear(512, 512)])
        self.regress_head = nn.Sequential(*[nn.Linear(512, args.gene_num), nn.Softplus()])
        self.class_head = nn.Linear(512, args.class_num)
        # Define other attributes
        self.optimizer = optimizer
        self.lr = {'Adam': args.lr, 'SGD': 0.1}[optimizer]
        self.test_pred = []  # collect predictions
        self.prob = []  # store prediction probability
        self.fullmodel = fullmodel
        self.class_weight = args.class_weight

    def _sample_weights(self, inputs, targets, idx, args):
        """Return per-sample weights from the relative Mahalanobis distance.

        The distances are rescaled to [-1, 1] by ``maha_normalizer``, divided
        by ``args.T`` and passed through ``my_softmax``. When ``idx`` is
        given, samples flagged with 1 get weight 1.0 in column 0.
        """
        with torch.no_grad():
            maha_distance = get_relative_maha_distance(
                inputs.cpu().data.numpy(), cov_invs, class_cov_invs,
                means, class_means, targets.cpu().data.numpy())
            maha_distance = torch.from_numpy(maha_distance)
            weight = maha_normalizer.run(maha_distance, -1., 1.)
            weight = my_softmax(weight / args.T)

            if idx is not None:  # option 2
                index_list = torch.argwhere(idx.cpu() == 1)
                if len(index_list) >= 1:
                    index_list = index_list.T[0]
                    weight[index_list, 0] = 1.0
        return weight

    def compute_loss(self, inputs, probs, preds, target_class, target_cont, args, traintestind, idx):
        """Compute the training or evaluation loss.

        Args:
            inputs: input features of the batch.
            probs: class logits from ``classify``.
            preds: regression outputs from ``forward``.
            target_class: cluster labels; column 0 is used.
            target_cont: continuous regression targets.
            args: configuration namespace (method, T, d_lambda, e_lambda).
            traintestind: True for training; False returns only the
                regression criterion.
            idx: optional 0/1 flag per sample, or None.

        Returns:
            Scalar loss tensor.
        """
        criterion_reg = OrdinalEntropy.BalancedPearsonCorrelationLoss()

        # Validation/test: regression criterion only.
        if not traintestind:
            return criterion_reg(preds, target_cont)

        targets = target_class.T[0]

        if args.method in ('weight_ord_joint_wpoly', 'weight_ord_joint_wer'):
            weight = self._sample_weights(inputs, targets, idx, args).to(preds.device)
            joint = criterion_reg(preds, target_cont) + args.d_lambda * OrdinalEntropy.ordinal_entropy(self.model(inputs), target_cont)
            if args.method == 'weight_ord_joint_wpoly':
                return joint + weighted_poly()(probs, targets, weight)
            return joint + self.class_weight * weighted_entropy_ce()(probs, targets, weight, args.e_lambda)

        # Fallback for any other method. The original referenced an undefined
        # `criterion_class` here and raised NameError.
        # NOTE(review): plain cross-entropy is assumed to be what was intended
        # -- confirm.
        return criterion_reg(preds, target_cont) + F.cross_entropy(probs, targets)

    def forward(self, x):
        """Return the regression output for ``x``."""
        return self.regress_head(self.model(x))

    def classify(self, x):
        """Return the class logits for ``x``."""
        return self.class_head(self.model(x))

    def configure_optimizers(self):
        if self.optimizer == 'Adam':
            return optim.Adam(self.parameters(), lr=self.lr)
        return optim.SGD(self.parameters(), lr=self.lr)

    def training_step(self, batch, batch_idx):
        x, y, label, diff = batch
        preds = self.forward(x)
        logits = self.classify(x)
        loss = self.compute_loss(x, logits, preds, label, y, args, traintestind=True, idx=diff)
        self.log('loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y, label = batch
        preds = self.forward(x)
        logits = self.classify(x)
        loss = self.compute_loss(x, logits, preds, label, y, args, traintestind=False, idx=None)
        self.log('val_loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        x, y, label = batch
        preds = self.forward(x)
        logits = self.classify(x)
        loss = self.compute_loss(x, logits, preds, label, y, args, traintestind=False, idx=None)
        self.log('test_loss', loss)
        # Collect predictions
        self.test_pred.extend(preds.cpu().numpy())


frame_list = []
for seed in range(0, 5):
    # Seed immediately before model construction so initialization and data
    # shuffling are reproducible per seed.
    pl.seed_everything(seed, workers=True)
    print("the seed is", seed)

    model_class = Model(k=args.gene_num, fullmodel=False).cuda()

    early_stopping_callback = pl.callbacks.early_stopping.EarlyStopping(
        monitor='val_loss',
        patience=10,
        min_delta=0.005,
        mode='min',
    )

    # Start every seed from an empty checkpoint directory.
    dirpath = "./classify_ft_regression_weightord_new/"
    experiment = dirpath
    if os.path.exists(experiment):
        shutil.rmtree(experiment)
    os.makedirs(experiment, exist_ok=True)

    train_dataloader = torch.utils.data.DataLoader(train_dataset_update, batch_size=512, shuffle=True, num_workers=1)
    valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=1024, shuffle=False, num_workers=1)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1024, shuffle=False, num_workers=1, drop_last=False)

    model_checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=dirpath,
        filename='best_model',
        monitor='val_loss',
        mode='min',
    )
    trainer = pl.Trainer(
        max_epochs=300,
        enable_model_summary=False,  # summary printed already
        callbacks=[
            early_stopping_callback,
            model_checkpoint_callback,
        ],
        accelerator='gpu', devices=1, deterministic=True,
    )

    trainer.fit(model_class, train_dataloader, valid_dataloader)

    # Evaluate the best checkpoint (by val_loss) on the test split. The
    # unused Model that used to be instantiated here was dropped.
    best_checkpoint = trainer.checkpoint_callback.best_model_path
    print(best_checkpoint)
    model = Model.load_from_checkpoint(best_checkpoint, k=args.gene_num, fullmodel=False)
    test_data = trainer.test(model, test_dataloader)

    # Make sure the output directory exists, otherwise np.save fails only
    # after the whole training run has finished.
    os.makedirs("./uni_regression", exist_ok=True)
    np.save(f"./uni_regression/fulllistUNI_{args.method}_lr{args.lr}_largemodel_real512_validreg_nsplit{args.n_split}_classweight{args.class_weight}_nsplit{nsplit}_dlambda{args.d_lambda}_pca20_seed{seed}", model.test_pred)


import scipy.stats
import sklearn.metrics

# Per-seed test metrics. Rows of `groundtruth` are observations and columns
# are the genes of the panel.
gene_l, cell_l, mse_l_all = [], [], []
groundtruth = adata_test[:, gene_list['genes']].X
for seed_l in range(0, 5):
    check0 = np.load(f"./uni_regression/fulllistUNI_{args.method}_lr{args.lr}_largemodel_real512_validreg_nsplit{args.n_split}_classweight{args.class_weight}_nsplit{nsplit}_dlambda{args.d_lambda}_pca20_seed{seed_l}.npy")

    # Row-wise: one Pearson r and one MSE per observation.
    row_pairs = list(zip(groundtruth, check0))
    row_pcc = [scipy.stats.pearsonr(i, j)[0] for i, j in row_pairs]
    mse_l = [sklearn.metrics.mean_squared_error(i, j) for i, j in row_pairs]
    gene_l.append(np.nanmean(row_pcc))
    mse_l_all.append(np.mean(mse_l))

    # Column-wise: one Pearson r per gene.
    pcc_l = [scipy.stats.pearsonr(i, j)[0] for i, j in zip(groundtruth.T, check0.T)]
    cell_l.append(np.nanmean(pcc_l))

#check result
df = pd.DataFrame({'pcc_gene': gene_l, 'pcc_cell': cell_l, 'mse': mse_l_all})
df.T
