import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"


import argparse
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
from numpy.random import uniform, choice, normal
from torch import nn, optim, Tensor, manual_seed, argmax
from torch.utils.data import TensorDataset, DataLoader
from torchmetrics.classification import Accuracy, MulticlassConfusionMatrix
from pytorch_lightning.utilities.model_summary import ModelSummary
import pytorch_lightning as pl
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split
from torch.autograd import Variable
import pandas as pd
import torch

import os
import torch
os.environ['HF_HOME'] = '/home/username/scratch/'
HF_TOKEN = ''  # fill your token
# Only export a token when one is filled in above, so an HF_TOKEN that is
# already exported in the shell environment is not overwritten with ''.
if HF_TOKEN:
    os.environ['HF_TOKEN'] = HF_TOKEN

import timm
import torch
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from huggingface_hub import login

import time
import skimage.io
import numpy as np
import pandas as pd
import cv2
import PIL.Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SubsetRandomSampler, RandomSampler, SequentialSampler
from sklearn.model_selection import StratifiedKFold
import matplotlib.pyplot as plt
from sklearn.metrics import cohen_kappa_score
from tqdm import tqdm_notebook as tqdm

import glob
import pickle  # was missing: pickle.load below raised NameError

# NOTE(review): pickle.load executes arbitrary code from the file; only load a
# trusted, locally produced camera_filename.pickle.
with open('camera_filename.pickle', 'rb') as f:
    file_list = pickle.load(f)

# Derive labels from the file names: any name containing 'normal' -> 0/'normal',
# everything else -> 1/'tumor'.
# NOTE(review): this is a substring test, so a name containing e.g. 'abnormal'
# would also be labelled 'normal' — confirm the naming scheme rules that out.
label_st = ['normal' if 'normal' in name else 'tumor' for name in file_list]
label = [0 if state == 'normal' else 1 for state in label_st]

df_train = pd.DataFrame({
    'image_id': file_list,
    'label': label,
    'label_st': label_st,
})


class CAMERADataset(Dataset):
    """Map-style dataset returning ``(image_tensor, label_tensor)`` pairs.

    Each item is read with ``torch.load`` from ``/images_pkl/<image_id>.pkl`` and
    only index 0 of the stored tensor's leading axis is returned.

    Args:
        df: DataFrame with ``image_id`` (str) and ``label`` columns. Its index is
            reset, so positions ``0..len-1`` address rows.
        image_size, n_tiles, tile_mode, rand, transform: stored on the instance
            for API compatibility; ``__getitem__`` does not read any of them.
    """

    def __init__(self, df, image_size, n_tiles=1, tile_mode=0, rand=False, transform=None):
        self.df = df.reset_index(drop=True)
        # Kept only so callers can inspect them; not used when loading items.
        self.image_size, self.n_tiles, self.tile_mode = image_size, n_tiles, tile_mode
        self.rand, self.transform = rand, transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        record = self.df.iloc[index]
        # NOTE(review): torch.load unpickles the file; only point this at trusted data.
        stored = torch.load("/images_pkl/" + record.image_id + '.pkl')
        # The stored tensor is indexed as 4-D; keep the first entry of its leading
        # axis (presumably a batch/tile axis of size 1 — confirm).
        first_image = stored[0, :, :, :]
        return first_image, torch.tensor(record['label'])



import sklearn.model_selection

# Two successive splits with sklearn's default test_size (0.25):
# test = 25% of all rows, valid = 25% of the remainder, train = the rest.
train_index, test_index = sklearn.model_selection.train_test_split(df_train.index, random_state=2024)
train_index, valid_index = sklearn.model_selection.train_test_split(train_index, random_state=2024)

# CAMERADataset stores but never applies ``transform``, and ``transform`` is not
# defined until the backbone is created further down, so pass None here
# (passing the undefined name raised NameError).
# NOTE: 224*224 is forwarded as ``image_size`` but CAMERADataset ignores it.
train_dataset_new = CAMERADataset(df_train.loc[train_index], 224*224, 1, 0, transform=None)
valid_dataset = CAMERADataset(df_train.loc[valid_index], 224*224, 1, 0, transform=None)
test_dataset = CAMERADataset(df_train.loc[test_index], 224*224, 1, 0, transform=None)

# Unshuffled, batch-size-1 loaders so embedding order matches dataset order.
train_dataloader = torch.utils.data.DataLoader(train_dataset_new, batch_size=1, shuffle=False, num_workers=32)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=32)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=32)


# Extract frozen UNI2-h embeddings for every training image; they feed the
# k-fold logistic-regression probe that mines misclassified samples below.
emb_list = []
label_list = []
# Architecture kwargs for the UNI2-h ViT loaded from the Hugging Face hub.
timm_kwargs = {
   'img_size': 224,
   'patch_size': 14,
   'depth': 24,
   'num_heads': 24,
   'init_values': 1e-5,
   'embed_dim': 1536,
   'mlp_ratio': 2.66667*2,
   'num_classes': 0,
   'no_embed_class': True,
   'mlp_layer': timm.layers.SwiGLUPacked,
   'act_layer': torch.nn.SiLU,
   'reg_tokens': 8,
   'dynamic_img_size': True
  }
model = timm.create_model("hf-hub:MahmoodLab/UNI2-h", pretrained=True, **timm_kwargs)
transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
model.eval()
model.cuda()
with torch.no_grad():
    for x, y in train_dataloader:
        emb_list.append(model(x.cuda()).cpu().numpy())
        label_list.append(y.numpy())

# Concatenate along the batch axis so this is correct for any loader batch size
# (indexing ``[:, 0, :]`` on a stacked array only works when batch_size == 1).
train_data = np.concatenate(emb_list, axis=0)
# Labels kept as an (N, 1) column; the k-fold code reads them back with ``.T[0]``.
train_list = np.concatenate(label_list, axis=0).reshape(-1, 1)


import sklearn.linear_model
from sklearn.model_selection import KFold

# Mine "hard" training samples with a 3-fold logistic-regression probe on the
# frozen embeddings. Every out-of-fold misclassification is recorded by its row
# position in ``train_data`` (== iteration order of ``train_dataset_new``).
kf = KFold(n_splits=3, shuffle=True, random_state=2024)

wrong_sample = []
# KFold only uses len(X) to build folds, so passing the embeddings directly yields
# the same splits as passing a list of row numbers.
for train_id, test_id in kf.split(train_data):
    # Labels are stored as an (N, 1) column; .T[0] flattens them to (N,).
    train_label = train_list[train_id].T[0]
    val_label = train_list[test_id].T[0]

    # Named ``clf`` so the module-level backbone ``model`` is not clobbered.
    clf = sklearn.linear_model.LogisticRegression()
    clf.fit(train_data[train_id], train_label)
    pred_label = clf.predict(train_data[test_id])
    # Keep the held-out positions whose prediction disagrees with the label.
    wrong_sample.extend(test_id[pred_label != val_label])



# Rebuild the three split datasets now that ``transform`` exists
# (CAMERADataset stores the transform but does not apply it).
train_dataset_new, valid_dataset, test_dataset = [
    CAMERADataset(df_train.loc[split_rows], 224 * 224, 1, 0, transform=transform)
    for split_rows in (train_index, valid_index, test_index)
]

def generate_new_dataset(origina_data,label_list):
    input_list = []
    output_list = []
    array_list = []
    for idx, (i,j) in enumerate(origina_data):
        if idx in label_list:
            array_list.append(1)
        else:
            array_list.append(0)
        input_list.append(i)
        output_list.append(j)
        
    d1 = torch.tensor(np.array(input_list))
    d2 = torch.tensor(np.array(output_list))
    d3 = torch.tensor(np.array(array_list))
    return TensorDataset(d1,d2,d3)

# Attach the "misclassified by the k-fold probe" flag to each training sample
# (positions deduplicated and sorted ascending).
hard_sample_positions = sorted(set(wrong_sample))
train_dataset_update = generate_new_dataset(train_dataset_new, hard_sample_positions)

# NOTE: the per-seed training loop rebuilds all three loaders (training with
# batch_size=256) before using them.
train_dataloader = DataLoader(train_dataset_update, batch_size=128,
                              shuffle=True, num_workers=1)
valid_dataloader = DataLoader(valid_dataset, batch_size=32,
                              shuffle=False, num_workers=1)
test_dataloader = DataLoader(test_dataset, batch_size=4,
                             shuffle=False, num_workers=1)



import sys
# Make the project-local UNI helper package importable. These are hard-coded,
# machine-specific paths; adjust them per environment.
sys.path.append('/home/username/piusername/UNI/')
sys.path.append('/home/username/piusername/UNI/diff_sample_code/paper_code/')
# Of these helpers, this script only calls get_relative_maha_distance and
# MahaDistNormalizer (both inside the Mahalanobis-weighted training loss).
from diff_sample_code.paper_code.utils import check_dir, prepare_dset, update_print, get_relative_maha_distance, maha, \
    get_pretrained_model, get_maha_distance, MahaDistNormalizer, ranking_loss


# Shared normalizer. Its .run(distances, -1., 1.) is applied to each training
# batch's relative Mahalanobis distances; the exact normalization is defined in
# the external utils module (not visible here).
maha_normalizer = MahaDistNormalizer()


import argparse
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
from numpy.random import uniform, choice, normal
from torch import nn, optim, Tensor, manual_seed, argmax
from torch.utils.data import TensorDataset, DataLoader
from torchmetrics.classification import Accuracy, MulticlassConfusionMatrix
from pytorch_lightning.utilities.model_summary import ModelSummary
import pytorch_lightning as pl
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split
from torch.autograd import Variable

# Load the precomputed Gaussian statistics (global and per-class means and
# inverse covariances) used for the relative Mahalanobis sample weights.
# NOTE(review): allow_pickle=True unpickles the file; only load trusted files.
maha_intermediate_dict = np.load("/home/username/piusername/UNI/diff_sample_code/paper_code/ssl/maha_dict_univ2_camera_512.npy", allow_pickle=True)
# The .npy holds a pickled dict wrapped in a 0-d object array; unwrap it once.
_maha_stats = maha_intermediate_dict.item()
class_cov_invs = _maha_stats['class_cov_invs']
class_means = _maha_stats['class_means']
cov_invs = _maha_stats['cov_inv']
means = _maha_stats['mean']

parser = argparse.ArgumentParser()  # NOTE(review): never used in this script
# Hyper-parameters read by Model.compute_loss and the loss modules.
args = argparse.Namespace(
    method='wer',  # 'wer' -> weighted_binary_entropy_ce, 'wpoly' -> weighted_poly
    epsilon=1.0,
    alpha=0.1,  # 0.05 by default, 0.1 is best
    fgamma=1.0,
    epsilon_p=2.0,
    e_lambda=0.3,  # weight of the entropy term in the 'wer' loss
    reverse=False,  # if True, Mahalanobis weights w are replaced by 2 - w
    T=5.0,  # temperature dividing the normalized distances before my_softmax
)
all_out_list = {}  # seed -> list of test-set predictions
class entropy_ce(nn.Module):
    """Cross-entropy blended with an entropy bonus.

    ``loss = (1 - e_lambda) * mean(CE) - e_lambda * mean(H(softmax(x)))``, where
    ``H`` is the Shannon entropy of each row's predicted distribution.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x_input, y_target, e_lambda):
        """Compute the blended loss.

        Args:
            x_input: 2-D logits; dim 1 is the class axis.
            y_target: int64 class indices, one per row of ``x_input``.
            e_lambda: weight given to the entropy term.

        Returns:
            Scalar loss tensor.
        """
        # Explicit dim=1: calling softmax without ``dim`` is deprecated.
        p = F.softmax(x_input, dim=1)
        log_p = F.log_softmax(x_input, dim=1)
        entropy = -torch.sum(p * log_p, dim=1).reshape(-1, 1)

        # Class count follows the logits instead of being hard-coded to 2.
        one_hot = F.one_hot(y_target, num_classes=x_input.size(1))
        ce = -torch.sum(log_p * one_hot, 1)
        return (1 - e_lambda) * torch.mean(ce) - e_lambda * torch.mean(entropy)
    
    
class weighted_binary_entropy_ce(nn.Module):
    """Cross-entropy minus a per-sample weighted binary entropy of the true-class probability.

    With ``p_t`` the softmax probability of the target class (plus 1e-4):
    ``loss = mean(CE) - mean(e_lambda * weight * H_b(p_t))`` where
    ``H_b(p) = -(p*log p + (1-p)*log(1-p))`` with both logs clamped at 1e-5.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x_input, y_target, weight, e_lambda):
        """Compute the weighted loss.

        Args:
            x_input: 2-D logits; dim 1 is the class axis.
            y_target: int64 class indices, one per row.
            weight: per-sample weights, reshaped to a column (one per row).
            e_lambda: global scale of the entropy term.

        Returns:
            Scalar loss tensor.
        """
        weight = weight.reshape(-1, 1)
        # Explicit dim=1: calling softmax without ``dim`` is deprecated.
        p = F.softmax(x_input, dim=1)
        y_onehot = F.one_hot(y_target, num_classes=x_input.size(1))

        pt = torch.sum(p * y_onehot, dim=1) + 1e-4
        # Clamps keep the logs finite when pt hits 0 or 1 (1 - pt can be slightly
        # negative because of the +1e-4 offset).
        entropy = -(pt * torch.log(torch.clamp(pt, min=1e-5))
                    + torch.log(torch.clamp(1 - pt, min=1e-5)) * (1 - pt)).reshape(-1, 1)
        weighted_entropy = (e_lambda * weight) * entropy

        ce = -torch.sum(F.log_softmax(x_input, 1) * y_onehot, 1)
        return torch.mean(ce) - torch.mean(weighted_entropy)
    
class weighted_entropy_ce(nn.Module):
    """Cross-entropy minus a per-sample weighted Shannon entropy of the prediction.

    ``loss = mean(CE) - mean(e_lambda * weight * H(softmax(x)))``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x_input, y_target, weight, e_lambda):
        """Compute the weighted loss.

        Args:
            x_input: 2-D logits; dim 1 is the class axis.
            y_target: int64 class indices, one per row.
            weight: per-sample weights, reshaped to a column (one per row).
            e_lambda: global scale of the entropy term.

        Returns:
            Scalar loss tensor.
        """
        weight = weight.reshape(-1, 1)
        # Explicit dim=1: calling softmax without ``dim`` is deprecated.
        p = F.softmax(x_input, dim=1)
        log_p = F.log_softmax(x_input, dim=1)
        entropy = -torch.sum(p * log_p, dim=1).reshape(-1, 1)
        weighted_entropy = (e_lambda * weight) * entropy

        y_onehot = F.one_hot(y_target, num_classes=x_input.size(1))
        ce = -torch.sum(log_p * y_onehot, 1)
        return torch.mean(ce) - torch.mean(weighted_entropy)
    
class weighted_poly(nn.Module):
    """Cross-entropy plus a per-sample weighted ``(1 - p_t)`` term.

    ``loss = mean(CE_i + weight_i * (1 - p_t,i))`` where ``p_t`` is the softmax
    probability of the target class.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x_input, y_target, weight):
        """Compute the weighted loss.

        Args:
            x_input: 2-D logits; dim 1 is the class axis.
            y_target: int64 class indices, one per row.
            weight: per-sample weights, one per row (flattened).

        Returns:
            Scalar loss tensor.
        """
        weight = weight.reshape(-1)
        # Explicit dim=1: calling softmax without ``dim`` is deprecated.
        p = F.softmax(x_input, dim=1)
        y_onehot = F.one_hot(y_target, num_classes=x_input.size(1))
        pt = torch.sum(p * y_onehot, dim=1)
        poly_loss = weight * (1. - pt)

        ce_loss = -torch.sum(F.log_softmax(x_input, 1) * y_onehot, 1)
        # Both terms are shape (batch,). Previously (batch,) + (batch, 1) broadcast
        # to a batch x batch matrix before averaging.
        return torch.mean(ce_loss + poly_loss)

def my_softmax(X):
    """Exponentiate relative to the maximum and rescale by the largest value.

    Despite the name this is not a softmax: it returns
    ``exp(X - max(X)) / (max(exp(X - max(X))) + 0.001)``. For finite input, the
    largest output is therefore ``1 / 1.001`` and all outputs lie in ``(0, 1)``.
    The input tensor is not modified.
    """
    # Out-of-place subtraction: the old ``X -= ...`` mutated the caller's tensor.
    shifted = X - X.max()
    shifted_exp = shifted.exp()
    max_data = shifted_exp.max() + 0.001
    return shifted_exp / max_data

import shutil


class Model(pl.LightningModule):
    """UNI2-h backbone plus a two-layer MLP classification head.

    The backbone is the module-level ``model`` (a timm model) looked up when an
    instance is constructed, so ``model`` must be (re)created before each
    instantiation. With ``fullmodel=False`` the backbone forward in ``forward``
    runs under ``torch.no_grad()``.

    Args:
        k: number of output classes.
        optimizer: ``'Adam'`` (lr 1e-3) or ``'SGD'`` (lr 0.1).
        dropout_rate: accepted but unused.
        fullmodel: if True, gradients also flow through the backbone.
    """

    def __init__(self, k, optimizer='Adam', dropout_rate=0, fullmodel=False):
        super().__init__()
        self.pretrain_model = model.cuda()
        self.num_classes = k
        # 1536 matches timm_kwargs['embed_dim'] of the UNI2-h backbone.
        self.model = nn.Sequential(nn.Linear(1536, 512), nn.ReLU(), nn.Linear(512, k)).cuda()
        self.loss = nn.CrossEntropyLoss()
        self.optimizer = optimizer
        self.lr = {'Adam': 0.001, 'SGD': 0.1}[optimizer]
        self.accuracy = Accuracy(task="multiclass", num_classes=k)
        self.test_pred = []  # predicted class index per test sample
        self.prob = []  # raw test logits per sample (NOT softmax probabilities)
        self.confusion_matrix = MulticlassConfusionMatrix(num_classes=k)
        self.fullmodel = fullmodel

    def compute_loss(self, inputs, outputs, targets, args, idx=None):
        """Return the loss for one batch.

        Without ``idx`` (validation/test) this is plain cross-entropy. With
        ``idx`` (training), per-sample weights come from
        ``get_relative_maha_distance`` on the backbone features, passed through
        ``maha_normalizer.run(..., -1., 1.)`` and ``my_softmax(. / args.T)``.
        Samples flagged ``idx == 1`` get weight 1.0 in column 0, and
        ``args.reverse`` maps weights ``w -> 2 - w``. ``args.method`` selects
        ``weighted_binary_entropy_ce`` ('wer') or ``weighted_poly`` ('wpoly').

        Raises:
            ValueError: if ``idx`` is given and ``args.method`` is not 'wer' or 'wpoly'
                (this previously surfaced as an UnboundLocalError).
        """
        if idx is None:
            return self.loss(outputs, targets)
        if args.method not in ('wer', 'wpoly'):
            raise ValueError(f"unsupported args.method {args.method!r}; expected 'wer' or 'wpoly'")

        # NOTE(review): the backbone is run again here even though forward() already
        # computed these features; caching them would save one backbone pass per step.
        with torch.no_grad():
            pre_feature = self.pretrain_model(inputs)
            maha_distance = get_relative_maha_distance(
                pre_feature.cpu().numpy(), cov_invs, class_cov_invs,
                means, class_means, targets.cpu().numpy())
            maha_weight = maha_normalizer.run(torch.from_numpy(maha_distance), -1., 1.)
            maha_weight = my_softmax(maha_weight / args.T)
            # Samples mined as misclassified by the k-fold probe get full weight.
            hard_rows = torch.argwhere(idx.cpu() == 1)
            if len(hard_rows) >= 1:
                maha_weight[hard_rows.T[0], 0] = 1.0

        if args.reverse:
            maha_weight = 2. * torch.ones_like(maha_weight) - maha_weight
        maha_weight = maha_weight.cuda()

        if args.method == 'wer':
            return weighted_binary_entropy_ce()(outputs, targets, maha_weight, args.e_lambda)
        return weighted_poly()(outputs, targets, maha_weight)

    def forward(self, x):
        if self.fullmodel:
            feature = self.pretrain_model(x)
        else:
            with torch.no_grad():
                feature = self.pretrain_model(x)
        return self.model(feature)

    def configure_optimizers(self):
        optimizer_cls = optim.Adam if self.optimizer == 'Adam' else optim.SGD
        return optimizer_cls(self.parameters(), lr=self.lr)

    def training_step(self, batch, batch_idx):
        x, y, idx = batch  # idx: 1 for samples mined as misclassified, else 0
        logits = self.forward(x)
        loss = self.compute_loss(x, logits, y, args, idx)
        self.log('loss', loss)
        self.log('accuracy', self.accuracy(argmax(logits, dim=1), y))
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        self.log('val_loss', self.compute_loss(x, logits, y, args, idx=None))
        self.log('val_accuracy', self.accuracy(argmax(logits, dim=1), y))

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        self.log('test_loss', self.compute_loss(x, logits, y, args, idx=None))
        y_pred = argmax(logits, dim=1)  # label with the highest logit
        self.log('test_accuracy', self.accuracy(y_pred, y))
        self.test_pred.extend(y_pred.cpu().numpy())
        self.prob.extend(logits.cpu().numpy())
        self.confusion_matrix.update(y_pred, y)


frame_list = []
for seed in range(5):
    print("the seed is", seed)
    pl.seed_everything(seed, workers=True)
    # Fresh backbone per seed; Model.__init__ picks up this module-level ``model``.
    model = timm.create_model("hf-hub:MahmoodLab/UNI2-h", pretrained=True, **timm_kwargs)
    transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
    model.eval()

    # Re-seed so head initialization and loader shuffling are reproducible per seed.
    pl.seed_everything(seed, workers=True)
    model_class = Model(k=2, fullmodel=False).cuda()

    train_dataloader = torch.utils.data.DataLoader(train_dataset_update, batch_size=256, shuffle=True, num_workers=1)
    valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=32, shuffle=False, num_workers=1)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=1)

    early_stopping_callback = pl.callbacks.early_stopping.EarlyStopping(
        monitor='val_loss',
        patience=10,
        min_delta=0.005,
        mode='min',
    )

    # Start every seed from an empty checkpoint directory.
    dirpath = "./classify_ft_new_camerav2univ2/"
    shutil.rmtree(dirpath, ignore_errors=True)
    os.makedirs(dirpath)

    model_checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=dirpath,
        filename='best_model',
        monitor='val_loss',
        mode='min',
    )
    trainer = pl.Trainer(
        max_epochs=300,
        enable_model_summary=False,  # summary printed already
        callbacks=[early_stopping_callback, model_checkpoint_callback],
        accelerator='gpu', devices=1, deterministic=True,
    )
    trainer.fit(model_class, train_dataloader, valid_dataloader)

    best_checkpoint = trainer.checkpoint_callback.best_model_path
    print(best_checkpoint)
    # load_from_checkpoint is a classmethod (Lightning 2.x rejects calls on an
    # instance), so call it on the class instead of a throwaway Model().
    best_model = Model.load_from_checkpoint(best_checkpoint, k=2, fullmodel=False)
    test_data = trainer.test(best_model, test_dataloader)

    # Labels in dataset order, read from the dataframe rather than re-loading every image.
    test_label = test_dataset.df['label'].tolist()
    all_out_list[seed] = best_model.test_pred

    # NOTE(review): get_eval_metrics is neither defined nor imported in this file —
    # confirm where it comes from. It receives raw logits, not probabilities.
    out = get_eval_metrics(test_label, best_model.test_pred, np.array(best_model.prob))
    frame_list.append(pd.DataFrame(out).drop(columns=['report']).iloc[0])


pd.concat(frame_list, axis=1)

