import os
# Set before any CUDA work starts. PyTorch's reproducibility notes ask for a
# cuBLAS workspace setting such as ":4096:8" when deterministic algorithms are
# requested; the Trainer at the end of this script uses deterministic=True.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import scanpy as sc
import pandas as pd
import pickle

import timm
import torch
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from huggingface_hub import login

import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SubsetRandomSampler, RandomSampler, SequentialSampler
from sklearn.model_selection import StratifiedKFold
import matplotlib.pyplot as plt
from sklearn.metrics import cohen_kappa_score
from tqdm import tqdm_notebook as tqdm

import argparse
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
from numpy.random import uniform, choice, normal
from torch import nn, optim, Tensor, manual_seed, argmax
from torch.utils.data import TensorDataset, DataLoader
from torchmetrics.classification import Accuracy, MulticlassConfusionMatrix
from pytorch_lightning.utilities.model_summary import ModelSummary
import pytorch_lightning as pl
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split
from torch.autograd import Variable
import pandas as pd
import torch

import os
import torch
# Hugging Face settings, exported before the gated UNI2-h weights are requested.
os.environ['HF_HOME'] = './'
os.environ['HF_TOKEN'] = '' # fill in your own Hugging Face access token

import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from huggingface_hub import login

from PIL import Image
from matplotlib import cm
import numpy as np

import pickle

import timm
from PIL import Image
from torchvision import transforms
import torch
import OrdinalEntropy

# Architecture settings for UNI2-h. pretrained=True is required so that the UNI
# weights are loaded (they are downloaded on first use).
timm_kwargs = dict(
    img_size=224,
    patch_size=14,
    depth=24,
    num_heads=24,
    init_values=1e-5,
    embed_dim=1536,
    mlp_ratio=2.66667*2,
    num_classes=0,
    no_embed_class=True,
    mlp_layer=timm.layers.SwiGLUPacked,
    act_layer=torch.nn.SiLU,
    reg_tokens=8,
    dynamic_img_size=True,
)
model = timm.create_model("hf-hub:MahmoodLab/UNI2-h", pretrained=True, **timm_kwargs)
# Derive the preprocessing pipeline from the model's own pretrained config.
_data_config = resolve_data_config(model.pretrained_cfg, model=model)
transform = create_transform(**_data_config)
model.eval()

class BalancedPearsonCorrelationLoss(torch.nn.Module):
    """Pearson-correlation loss balanced over both tensor axes, plus an MSE term.

    The correlation part is ``1 - r`` computed once along dim 0 and once along
    dim 1 of the 2D inputs, combined with the two relative weights. The plain
    mean squared error between predictions and targets is then added.
    """

    def __init__(
        self,
        rel_weight_gene: float = 1.0,
        rel_weight_cell: float = 1.0,
        norm_by: str = "mean",
        eps: float = 1e-8,
    ):
        """Initialise the loss.

        Parameters
        ----------
        rel_weight_gene: float = 1.0
            Relative weight of the correlation computed along dim 0.
        rel_weight_cell: float = 1.0
            Relative weight of the correlation computed along dim 1.
        norm_by: str = 'mean'
            Centre used before the cosine similarity. ``'mean'`` subtracts the
            mean; any other value (e.g. ``'nonzero_median'``) subtracts the
            median of the non-zero entries.
        eps: float = 1e-8
            Epsilon passed to ``cosine_similarity`` to avoid division by zero.
        """
        super().__init__()
        self.eps = eps
        self.norm_by = norm_by
        self.rel_weight_gene = rel_weight_gene
        self.rel_weight_cell = rel_weight_cell
        self.mse_loss = torch.nn.MSELoss()

    @staticmethod
    def _nonzero_median(tensor: torch.Tensor, dim: int, keepdim: bool = True) -> torch.Tensor:
        """Return the median of the non-zero entries of ``tensor`` along ``dim``.

        Zeros are masked as NaN so ``nanmedian`` skips them; a slice with no
        non-zero entry yields 0.
        """
        masked = torch.where(
            tensor != 0, tensor, torch.full_like(tensor, float("nan"))
        )
        medians = torch.nanmedian(masked, dim=dim, keepdim=keepdim).values
        return torch.nan_to_num(medians, nan=0.0)

    def forward(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute the balanced correlation loss plus MSE.

        Parameters
        ----------
        preds: torch.Tensor
            2D tensor of predictions.
        targets: torch.Tensor
            2D tensor of targets with the same shape as ``preds``.

        Returns
        -------
        torch.Tensor
            Scalar loss: weighted correlation term (scaled so that equal
            weights give a value in [0, 4]) plus ``MSELoss(preds, targets)``.
        """
        if self.norm_by == "mean":
            preds_avg_gene = preds.mean(dim=0, keepdim=True)
            targets_avg_gene = targets.mean(dim=0, keepdim=True)
            preds_avg_cell = preds.mean(dim=1, keepdim=True)
            targets_avg_cell = targets.mean(dim=1, keepdim=True)
        else:
            preds_avg_gene = self._nonzero_median(preds, 0, keepdim=True)
            targets_avg_gene = self._nonzero_median(targets, 0, keepdim=True)
            preds_avg_cell = self._nonzero_median(preds, 1, keepdim=True)
            targets_avg_cell = self._nonzero_median(targets, 1, keepdim=True)

        # Cosine similarity of centred vectors is the Pearson correlation.
        r_tss = torch.nn.functional.cosine_similarity(
            preds - preds_avg_gene,
            targets - targets_avg_gene,
            eps=self.eps,
            dim=0,
        )

        r_celltype = torch.nn.functional.cosine_similarity(
            preds - preds_avg_cell,
            targets - targets_avg_cell,
            eps=self.eps,
            dim=1,
        )

        loss = self.rel_weight_gene * (1 - r_tss.mean()) + self.rel_weight_cell * (
            1 - r_celltype.mean()
        )

        # norm the loss to 2 by half the sum of the relative weights
        loss = (loss * 2) / (self.rel_weight_gene + self.rel_weight_cell)

        return loss + self.mse_loss(preds, targets)

import argparse

# Command-line configuration. Every numeric option declares an explicit type so
# that a value given on the command line is parsed as a number rather than
# being left as a string (only the defaults were numeric before).
parser = argparse.ArgumentParser(description='Process some integers.')
# args = argparse.Namespace()
parser.add_argument('--method', default='joint_pcc_mse')
parser.add_argument('--epsilon', default=1.0, type=float)
parser.add_argument('--alpha', default=0.1, type=float)
parser.add_argument('--fgamma', default=1.0, type=float)
parser.add_argument('--epsilon_p', default=2.0, type=float)
parser.add_argument('--e_lambda', default=0.3, type=float)
# Used as the output width of the regression head, so it must be an int.
parser.add_argument('--gene_num', default=50, type=int)
parser.add_argument('--dataset_name', default='READ')
parser.add_argument('--d_lambda', default=1.0, type=float)
parser.add_argument('--lr', default=1e-4, type=float)
parser.add_argument('--batchsize', default=512, type=int)
parser.add_argument('--T', default=1.0, type=float)
parser.add_argument('--savepath', default="./running_out")
parser.add_argument('--class_weight', default=1e-3, type=float)
parser.add_argument('--nsplit', default=5, type=int)
parser.add_argument('--layersize', default=512, type=int)
args = parser.parse_args()
dataset_name = args.dataset_name

import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from huggingface_hub import login

from PIL import Image
from matplotlib import cm
import numpy as np

import pickle

class REGDataset(Dataset):
    """Pairs a row of precomputed image features with its expression values.

    NOTE(review): this class is redefined twice further down in this file, so
    this first definition is replaced before anything instantiates it.
    """

    def __init__(self,
                 adata,
                 image,
                 transform,
                 gene = 'NOSTRIN'
                ):
        # `transform` is stored but not applied anywhere in this class.
        self.transform = transform
        self.image = image
        # Keep only the requested gene column(s).
        self.adata = adata[:,gene]

    def __len__(self):
        n_rows = self.adata.shape[0]
        return n_rows

    def __getitem__(self, index):
        features = torch.from_numpy(self.image[index])
        # Observations are looked up by the string form of their position.
        row = self.adata[str(index)]
        expression = torch.FloatTensor(row.X)[0,:]
        return features, expression

# Dataset identifiers. Nothing else in the visible code reads this list; the
# dataset actually loaded is selected by args.dataset_name.
sample_list = ['IDC','READ', 'PRAD', 'LYMPH_IDC', 'COAD', 'CCRCC', 'Brain']

import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from huggingface_hub import login

from PIL import Image
from matplotlib import cm
import numpy as np

import pickle

import sys
sys.path.append('./')
sys.path.append('./diff_sample_code/paper_code/')

class REGDataset(Dataset):
    """Pairs a row of precomputed image features with expression values.

    NOTE(review): a later definition of ``REGDataset`` in this file replaces
    this one before any instance is created. This variant indexes the
    transposed matrix (``X.T[0]``), unlike the other two definitions.
    """

    def __init__(self,
                 adata,
                 image,
                 transform,
                 gene = 'NOSTRIN'
                ):
        # `transform` is stored but not applied anywhere in this class.
        self.transform = transform
        self.image = image
        # Keep only the requested gene column(s).
        self.adata = adata[:,gene]

    def __len__(self):
        n_rows = self.adata.shape[0]
        return n_rows

    def __getitem__(self, index):
        features = torch.from_numpy(self.image[index])
        # Observations are looked up by the string form of their position.
        row = self.adata[str(index)]
        expression = torch.FloatTensor(row.X.T[0])
        return features, expression

dataset_name = args.dataset_name

_SPLITS = ("train", "valid", "test")

# Gene panels that are hard-coded in this script; every other dataset reads its
# panel from ./hest-bench/<dataset_name>/var_50genes.json.
_BRAIN_GENES = [
    "ACTA2", "BST2", "CCND1", "COL18A1", "COL4A1",
    "COL4A2", "COL6A1", "CPD", "CREG1", "CTSH",
    "DDIT4", "DNAJB1", "ENG", "ERO1A", "FN1",
    "FSTL1", "FURIN", "GLRX", "HSPA1A", "IER3",
    "IFI27", "IGFBP4", "IGFBP7", "ISG15", "LGALS3",
    "MFGE8", "MFSD12", "MGP", "MT1X", "MYL9",
    "NDRG1", "NFKBIA", "NGRN", "NOTCH3", "NUPR1",
    "PFKP", "PI4KA", "PLXND1", "RPL11", "RPS3",
    "SNHG25", "SOD2", "SPTSSA", "TAGLN", "TAP1",
    "THY1", "TPM1", "TSPYL1", "VEGFA", "XBP1",
]
_SKIN_GENES = [
    "KRT15", "CRCT1", "RHCG", "IGHG1", "GGCT",
    "ASPRV1", "JCHAIN", "PI3", "SLURP1", "CA2",
    "SSFA2", "SPRR2E", "LCE3D", "MT1X", "IGHG4",
    "S100A7A", "HERC6", "WARS", "RPL22L1", "SDR16C5",
    "IL1RN", "MX1", "GGH", "CLEC2B", "CAST",
    "PTGS1", "IGFBP6", "PRSS8", "LAP3", "TXNL4A",
    "SERPINE2", "IGLC2", "KRT75", "C9orf3", "IGHG3",
    "SPRR2D", "IFIT1", "PPL", "LSM5", "TMEM256",
    "KRT2", "SPRR2G", "SNRPD1", "SPRR2A", "OAS1",
    "GBA", "GADD45GIP1", "LYZ", "SPINT1", "NAGK",
]


def _load_split_data(adata_paths, image_paths):
    """Load the train/valid/test expression tables and image-feature arrays.

    Parameters
    ----------
    adata_paths : sequence of str
        Paths of the three ``.h5ad`` files, in train/valid/test order.
    image_paths : sequence of str
        Paths of the three ``.npy`` feature files, in the same order.

    Returns
    -------
    tuple
        ``(adatas, images)``: two lists of three objects each.
    """
    adatas = []
    for path in adata_paths:
        adata = sc.read(path)
        # REGDataset looks observations up by the string form of their
        # position, so the observation names are reset to "0", "1", ...
        adata.obs_names = [str(i) for i in range(len(adata))]
        adatas.append(adata)
    images = [np.load(path) for path in image_paths]
    return adatas, images


# Choose file locations and, where one is hard-coded, the gene panel.
if dataset_name == 'IDC':
    _adata_paths = [f"./train_data/IDC_{split}.h5ad" for split in _SPLITS]
    _image_paths = [f"./train_data/IDC_{split}_imagefeature.npy" for split in _SPLITS]
    _builtin_genes = None
elif dataset_name == 'Brain':
    _adata_paths = [f"./train_data/adata_brain_norm_{split}.h5ad" for split in _SPLITS]
    _image_paths = [f"./train_data/brain_{split}_imagefeature_univ2.npy" for split in _SPLITS]
    _builtin_genes = _BRAIN_GENES
elif dataset_name == 'skin':
    _adata_paths = [f"./train_data/adata_skin_norm_{split}.h5ad" for split in _SPLITS]
    _image_paths = [f"./train_data/skin_{split}_imagefeature_univ2.npy" for split in _SPLITS]
    _builtin_genes = _SKIN_GENES
else:
    _adata_paths = [f"./train_data/{dataset_name}_norm_{split}.h5ad" for split in _SPLITS]
    _image_paths = [f"./train_data/{dataset_name}_{split}_imagefeature_univ2.npy" for split in _SPLITS]
    _builtin_genes = None

(adata_train, adata_valid, adata_test), (image_train, image_valid, image_test) = _load_split_data(
    _adata_paths, _image_paths
)

if _builtin_genes is None:
    gene_list = pd.read_json(f"./hest-bench/{dataset_name}/var_50genes.json")
else:
    gene_list = pd.DataFrame({"genes": _builtin_genes})


print(dataset_name)
import sys
from diff_sample_code.paper_code.utils import check_dir, prepare_dset, update_print, get_relative_maha_distance, maha, \
    get_pretrained_model, get_maha_distance, MahaDistNormalizer, ranking_loss

class MahaDistNormalizer:
    """Min-max rescaler that remembers the extremes seen across calls.

    This local definition replaces the ``MahaDistNormalizer`` imported just
    above it.
    """

    def __init__(self,):
        super().__init__()
        # Sentinels, overwritten by the first call to run().
        self.min = 1e20
        self.max = -1e10

    def run(self, x, left, right):
        """Update the running min/max with ``x`` and map ``x`` into [left, right].

        The mapping is linear, sending the running minimum to ``left`` and the
        running maximum to ``right``.
        """
        lowest = min(x.min(), self.min)
        highest = max(x.max(), self.max)
        self.min = lowest
        self.max = highest
        k = (right-left)/(self.max - self.min)
        return left+k*(x - self.min)
    
    

def my_softmax(X):
    """Return ``exp(X - max(X))`` scaled by its own maximum plus 0.001.

    Unlike a true softmax the result is normalised by the largest exponential
    (plus a small constant) rather than by the sum, so the largest entry maps
    to just under 1.

    The input tensor is left untouched. The previous ``X -= X.max()`` modified
    the caller's tensor in place.
    """
    shifted = X - X.max()
    X_exp = shifted.exp()
    max_data = X_exp.max() + 0.001
    return X_exp / max_data


# Attach a positional binary label to every split: 0 for the first half of the
# observations, 1 for the rest. Building the labels from the full length keeps
# one label per observation when the count is odd (concatenating two lists of
# n // 2 entries produced only n - 1 labels in that case).
for _adata in (adata_train, adata_valid, adata_test):
    _n_obs = len(_adata)
    _adata.obsm['bin'] = (np.arange(_n_obs) >= _n_obs // 2).astype(int)

# NOTE(review): the labels above take two values (0 and 1) while class_num is 1;
# confirm this is intended before using the one-hot based losses below.
args.class_num = 1

class REGDataset(Dataset):
    """Yields (image features, expression values, binary label) per observation.

    This is the last definition of ``REGDataset`` in the file and therefore the
    one used when the datasets are built.
    """

    def __init__(self,
                 adata,
                 image,
                 transform,
                 gene = 'NOSTRIN'
                ):
        # `transform` is stored but not applied anywhere in this class.
        self.transform = transform
        self.image = image
        # Keep only the requested gene column(s).
        self.adata = adata[:,gene]

    def __len__(self):
        n_rows = self.adata.shape[0]
        return n_rows

    def __getitem__(self, index):
        features = torch.from_numpy(self.image[index])
        # Observations are looked up by the string form of their position.
        row = self.adata[str(index)]
        expression = torch.FloatTensor(row.X)[0,:]
        label = torch.FloatTensor(row.obsm['bin']).long()
        return features, expression, label
    
    

# Wrap each split in a REGDataset restricted to the selected gene panel.
_gene_panel = gene_list['genes']
train_dataset_new = REGDataset(adata_train, image_train, transform, _gene_panel)
valid_dataset = REGDataset(adata_valid, image_valid, transform, _gene_panel)
test_dataset = REGDataset(adata_test, image_test, transform, _gene_panel)

# Ordered (unshuffled) loaders that keep the final partial batch.
_loader_options = dict(batch_size=512, shuffle=False, drop_last=False)
train_dataloader = torch.utils.data.DataLoader(train_dataset_new, **_loader_options)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, **_loader_options)
test_dataloader = torch.utils.data.DataLoader(test_dataset, **_loader_options)


class entropy_ce(nn.Module):
    """Cross-entropy with an entropy bonus on the predicted distribution."""

    def __init__(self):
        super(entropy_ce, self).__init__()

    def forward(self, x_input, y_target, e_lambda, num_classes=None):
        """Return ``(1 - e_lambda) * CE - e_lambda * mean entropy``.

        Parameters
        ----------
        x_input : torch.Tensor
            Logits; the class axis is dim 1.
        y_target : torch.Tensor
            Integer class indices, one per row of ``x_input``.
        e_lambda : float
            Mixing weight between the cross-entropy and the entropy term.
        num_classes : int, optional
            Width of the one-hot encoding. Defaults to the module-level
            ``args.class_num`` when omitted.
        """
        if num_classes is None:
            num_classes = args.class_num

        # The class axis is given explicitly; relying on the implicit softmax
        # dim is deprecated and picks dim 0 for 3D inputs.
        log_p = F.log_softmax(x_input, dim=1)
        p = F.softmax(x_input, dim=1)
        entropy = - torch.sum(p * log_p, dim=1).reshape(-1, 1)

        one_hot = F.one_hot(y_target, num_classes=num_classes)
        ce_loss = - torch.sum(log_p * one_hot, 1)
        loss = (1 - e_lambda) * torch.mean(ce_loss) - e_lambda * torch.mean(entropy)

        return loss

import OrdinalEntropy

# Inputs for the cross-validated probe: image features and positional labels.
train_data = image_train
train_list = adata_train.obsm['bin']

import sklearn.linear_model

model = sklearn.linear_model.LogisticRegression()

import sklearn
import numpy as np
from sklearn.model_selection import KFold

nsplit = args.nsplit
kf = KFold(n_splits=nsplit, shuffle=True, random_state=2024)
# kf.get_n_splits(df_train_update.index)

# Collect the indices of held-out samples that a per-fold logistic regression
# misclassifies.
wrong_sample = []
_all_positions = [pos for pos in range(len(train_data))]
for itd, (train_id, test_id) in enumerate(kf.split(_all_positions)):
    model = sklearn.linear_model.LogisticRegression()
    model.fit(train_data[train_id], train_list[train_id])

    val_label = train_list[test_id]
    pred_label = model.predict(train_data[test_id])
    wrong_sample.extend(
        test_id[pos]
        for pos, (truth, guess) in enumerate(zip(val_label, pred_label))
        if truth != guess
    )

# Build the three split datasets again (same arguments as the earlier
# construction), pairing each expression table with its image features.
_panel = gene_list['genes']
train_dataset_new, valid_dataset, test_dataset = (
    REGDataset(_split_adata, _split_image, transform, _panel)
    for _split_adata, _split_image in (
        (adata_train, image_train),
        (adata_valid, image_valid),
        (adata_test, image_test),
    )
)

def generate_new_dataset(origina_data,label_list):
    """Materialise a dataset into a TensorDataset with a per-sample flag.

    Parameters
    ----------
    origina_data : iterable
        Yields ``(input, output, label)`` triples.
    label_list : iterable of int
        Positions (in iteration order) to flag with 1; all others get 0.

    Returns
    -------
    TensorDataset
        Tensors in the order ``(inputs, outputs, labels, flags)``.
    """
    # A set makes each membership test O(1) instead of scanning the list.
    flagged = set(label_list)

    input_list = []
    output_list = []
    array_list = []
    corr_list = []
    for idx, (i, j, k) in enumerate(origina_data):
        array_list.append(1 if idx in flagged else 0)
        input_list.append(i)
        output_list.append(j)
        corr_list.append(k)

    d1 = torch.tensor(np.array(input_list))
    d2 = torch.tensor(np.array(output_list))
    d3 = torch.tensor(np.array(array_list))
    d4 = torch.tensor(np.array(corr_list))
    # The flag tensor goes last: (input, output, label, flag).
    return TensorDataset(d1, d2, d4, d3)

# Materialise the training set as a TensorDataset whose last tensor flags the
# samples misclassified by the cross-validated logistic regression above
# (indices deduplicated and sorted before being passed in).
train_dataset_update = generate_new_dataset(train_dataset_new, sorted(set(wrong_sample)))

import argparse
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
from numpy.random import uniform, choice, normal
from torch import nn, optim, Tensor, manual_seed, argmax
from torch.utils.data import TensorDataset, DataLoader
from torchmetrics.classification import Accuracy, MulticlassConfusionMatrix
from pytorch_lightning.utilities.model_summary import ModelSummary
import pytorch_lightning as pl
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split
from torch.autograd import Variable

class entropy_ce(nn.Module):
    """Cross-entropy with an entropy bonus on the predicted distribution."""

    def __init__(self):
        super(entropy_ce, self).__init__()

    def forward(self, x_input, y_target, e_lambda, num_classes=None):
        """Return ``(1 - e_lambda) * CE - e_lambda * mean entropy``.

        Parameters
        ----------
        x_input : torch.Tensor
            Logits; the class axis is dim 1.
        y_target : torch.Tensor
            Integer class indices, one per row of ``x_input``.
        e_lambda : float
            Mixing weight between the cross-entropy and the entropy term.
        num_classes : int, optional
            Width of the one-hot encoding. Defaults to the module-level
            ``args.class_num`` when omitted.
        """
        if num_classes is None:
            num_classes = args.class_num

        # The class axis is given explicitly; relying on the implicit softmax
        # dim is deprecated and picks dim 0 for 3D inputs.
        log_p = F.log_softmax(x_input, dim=1)
        p = F.softmax(x_input, dim=1)
        entropy = - torch.sum(p * log_p, dim=1).reshape(-1, 1)

        one_hot = F.one_hot(y_target, num_classes=num_classes)
        ce_loss = - torch.sum(log_p * one_hot, 1)
        loss = (1 - e_lambda) * torch.mean(ce_loss) - e_lambda * torch.mean(entropy)

        return loss
    
class weighted_poly(nn.Module):
    """Cross-entropy plus a per-sample weighted ``(1 - p_target)`` term."""

    def __init__(self):
        super(weighted_poly, self).__init__()

    def forward(self, x_input, y_target, weight, num_classes=None):
        """Return ``mean(CE) + mean(weight * (1 - p_target))``.

        Parameters
        ----------
        x_input : torch.Tensor
            Logits; the class axis is dim 1.
        y_target : torch.Tensor
            Integer class indices, one per row of ``x_input``.
        weight : torch.Tensor
            One weight per sample for the polynomial term.
        num_classes : int, optional
            Width of the one-hot encoding. Defaults to the module-level
            ``args.class_num`` when omitted.
        """
        if num_classes is None:
            num_classes = args.class_num

        one_hot = F.one_hot(y_target, num_classes=num_classes)

        # Probability assigned to the target class, one value per sample.
        p = F.softmax(x_input, dim=1)
        pt = torch.sum(p * one_hot, dim=1)

        # Both terms are kept 1-D. Adding a (B,) tensor to a (B, 1) tensor
        # broadcast to a B x B matrix whose mean equals the value returned
        # here, at quadratic memory cost.
        poly_loss = weight.reshape(-1) * (1. - pt)
        ce_loss = - torch.sum(F.log_softmax(x_input, dim=1) * one_hot, 1)

        return torch.mean(ce_loss + poly_loss)
    
class weighted_entropy_ce(nn.Module):
    """Cross-entropy minus a per-sample weighted entropy term."""

    def __init__(self):
        super(weighted_entropy_ce, self).__init__()

    def forward(self, x_input, y_target, weight, e_lambda, num_classes=None):
        """Return ``mean(CE) - mean(e_lambda * weight * entropy)``.

        Parameters
        ----------
        x_input : torch.Tensor
            Logits; the class axis is dim 1.
        y_target : torch.Tensor
            Integer class indices, one per row of ``x_input``.
        weight : torch.Tensor
            One weight per sample for the entropy term.
        e_lambda : float
            Global scale applied to ``weight``.
        num_classes : int, optional
            Width of the one-hot encoding. Defaults to the module-level
            ``args.class_num`` when omitted.
        """
        if num_classes is None:
            num_classes = args.class_num

        weight = weight.reshape(-1, 1)

        # The class axis is given explicitly; relying on the implicit softmax
        # dim is deprecated and picks dim 0 for 3D inputs.
        log_p = F.log_softmax(x_input, dim=1)
        p = F.softmax(x_input, dim=1)
        entropy = - torch.sum(p * log_p, dim=1).reshape(-1, 1)

        weight_beta = e_lambda * weight
        entropy = weight_beta * entropy

        one_hot = F.one_hot(y_target, num_classes=num_classes)
        ce_loss = - torch.sum(log_p * one_hot, 1)
        loss = torch.mean(ce_loss) - torch.mean(entropy)

        return loss

import glob
import shutil


class Model(pl.LightningModule):
    """MLP trunk on 1536-dimensional features with regression and class heads.

    Only the regression head contributes to the loss; the class head is
    evaluated in each step but its output is not used by ``compute_loss``.
    """

    def __init__(self, k, optimizer = 'Adam', dropout_rate = 0, fullmodel = False):
        """Build the network.

        Parameters
        ----------
        k : int
            Stored as ``num_classes``; the head widths come from ``args``.
        optimizer : str
            ``'Adam'`` (learning rate ``args.lr``) or ``'SGD'`` (0.1).
        dropout_rate : float
            Accepted for interface compatibility; not used.
        fullmodel : bool
            Stored on the instance; not used by this class.
        """
        super().__init__()
        self.num_classes = k
        self.model = nn.Sequential(*[nn.Linear(1536, 512), nn.ReLU(), nn.Linear(512,args.layersize)])
        self.regress_head = nn.Sequential(*[nn.Linear(args.layersize, args.gene_num), nn.Softplus()])
        self.class_head = nn.Linear(args.layersize, args.class_num)
        # Define other attributes
        self.optimizer = optimizer
        self.lr = {'Adam': args.lr, 'SGD': 0.1}[optimizer]
        self.test_pred = []  # collect predictions
        self.prob = [] #store prediction probability
        self.fullmodel = fullmodel

    def compute_loss(self, inputs, probs, preds, target_class, target_cont, args, traintestind, idx):
        """Return the regression loss selected by ``args.method``.

        ``'mse'`` and ``'huber'`` map to the corresponding torch losses;
        ``'pccmse'`` and the command-line default ``'joint_pcc_mse'`` both map
        to ``BalancedPearsonCorrelationLoss``. Previously the default method
        matched no branch and the function returned ``None``.

        Raises
        ------
        ValueError
            If ``args.method`` is not one of the supported names.
        """
        if args.method == 'mse':
            criterion_reg = nn.MSELoss()
        elif args.method == 'huber':
            criterion_reg = nn.HuberLoss()
        elif args.method in ('pccmse', 'joint_pcc_mse'):
            criterion_reg = BalancedPearsonCorrelationLoss()
        else:
            raise ValueError(
                f"Unsupported --method {args.method!r}; expected one of "
                "'mse', 'huber', 'pccmse', 'joint_pcc_mse'"
            )
        return criterion_reg(preds, target_cont)

    def forward(self, x):
        """Return the regression-head output for ``x``."""
        out = self.model(x)
        out = self.regress_head(out)
        return out

    def classify(self, x):
        """Return the class-head logits for ``x``."""
        out = self.model(x)
        out = self.class_head(out)
        return out

    def configure_optimizers(self):
        if self.optimizer == 'Adam':
            optimizer = optim.Adam(self.parameters(), lr=self.lr)
        else:
            optimizer = optim.SGD(self.parameters(), lr=self.lr)

        return optimizer

    def training_step(self, batch, batch_idx):
        # Training batches carry a fourth tensor: the misclassification flag.
        x, y, label, diff = batch
        preds = self.forward(x)
        logits = self.classify(x)
        loss = self.compute_loss(x,logits, preds,label,y,args,traintestind=True, idx = diff)
        self.log('loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y, label = batch
        preds = self.forward(x)
        logits = self.classify(x)
        loss = self.compute_loss(x,logits, preds,label,y,args,traintestind=False, idx = None)
        self.log('val_loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        x, y, label = batch
        preds = self.forward(x)
        logits = self.classify(x)
        loss = self.compute_loss(x,logits, preds,label,y,args,traintestind=False, idx =  None)
        self.log('test_loss', loss)
        # Collect predictions
        self.test_pred.extend(preds.cpu().numpy())


# Train, select the best checkpoint and save test predictions for five seeds.
frame_list = []
for seed in range(0,5):
    pl.seed_everything(seed, workers=True)
    print("the seed is", seed)

    model_class = Model(k=args.gene_num, fullmodel = False).cuda()

    early_stopping_callback = pl.callbacks.early_stopping.EarlyStopping(
        monitor = 'val_loss',
        patience = 10,
        min_delta = 0.005,
        mode = 'min',
    )

    # Start every seed from an empty checkpoint directory.
    dirpath = args.savepath
    if os.path.exists(dirpath):
        shutil.rmtree(dirpath)
    os.makedirs(dirpath)

    train_dataloader = torch.utils.data.DataLoader(train_dataset_update, batch_size=args.batchsize, shuffle=True, num_workers=1)
    valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=1024, shuffle=False, num_workers=1)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1024, shuffle=False, num_workers=1, drop_last=False)

    model_checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath = dirpath,
        filename = 'best_model',
        monitor = 'val_loss',
        mode = 'min',
    )
    trainer = pl.Trainer(
        max_epochs = 300,
        enable_model_summary = False,  # summary printed already
        callbacks = [
            early_stopping_callback,
            model_checkpoint_callback
        ],
        accelerator='gpu', devices=1, deterministic=True
    )

    trainer.fit(model_class, train_dataloader, valid_dataloader)

    # Evaluate the checkpoint with the lowest validation loss.
    best_checkpoint = trainer.checkpoint_callback.best_model_path
    print(best_checkpoint)
    model = Model.load_from_checkpoint(best_checkpoint, k=args.gene_num, fullmodel = False)
    test_data = trainer.test(model, test_dataloader)

    # np.save does not create missing directories; make sure the target exists
    # so a finished run cannot fail at the very last step.
    os.makedirs("./uni_regression_new", exist_ok=True)
    np.save(f"./uni_regression_new/fulllistUNI_{args.method}_lr{args.lr}_largemodel_{dataset_name}_real{args.batchsize}_classweight{args.class_weight}_weightbaseautord_layersize{args.layersize}_harmonyfix_nsplit{nsplit}_dlambda{args.d_lambda}_pca20_seed{seed}", model.test_pred)