import os
# cuBLAS needs a fixed workspace configuration for reproducible results when
# deterministic algorithms are enforced (the Trainer below uses
# deterministic=True); it is set here, before torch is imported.
os.environ.update({"CUBLAS_WORKSPACE_CONFIG": ":4096:8"})

import argparse
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
from numpy.random import uniform, choice, normal
from torch import nn, optim, Tensor, manual_seed, argmax
from torch.utils.data import TensorDataset, DataLoader
from torchmetrics.classification import Accuracy, MulticlassConfusionMatrix
from pytorch_lightning.utilities.model_summary import ModelSummary
import pytorch_lightning as pl
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split
from torch.autograd import Variable
import pandas as pd
import torch


import os
import torch
# Hugging Face cache directory (set before the timm / huggingface_hub imports below).
os.environ['HF_HOME'] = '/home/username/scratch/'
# Fill in your Hugging Face access token here, or export HF_TOKEN in the shell.
# An empty value is deliberately NOT written to the environment, so a token that
# was exported beforehand is not overwritten with '' (UNI2-h is a gated model).
hf_token = ''  # fill your token
if hf_token:
    os.environ['HF_TOKEN'] = hf_token

import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from huggingface_hub import login

from PIL import Image
from matplotlib import cm
import numpy as np


import torch
import torchvision
import os
from os.path import join as j_
from PIL import Image
import pandas as pd
import numpy as np

# Compute device: the current CUDA device when one is available, otherwise the CPU.
# NOTE: the model code below calls .cuda() directly instead of using this.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


import timm
import torch
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from huggingface_hub import login

import time
import skimage.io
import numpy as np
import pandas as pd
import cv2
import PIL.Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SubsetRandomSampler, RandomSampler, SequentialSampler
from sklearn.model_selection import StratifiedKFold
import matplotlib.pyplot as plt
from sklearn.metrics import cohen_kappa_score
from tqdm import tqdm_notebook as tqdm


import glob
import pickle  # was used below without being imported (NameError)

# camera_filename.pickle holds the list of image ids (file stems). It was
# presumably produced by the commented-out scan below -- TODO confirm:
# file_list = []
# for item in glob.glob("/home/username/piusername/camera17/CAMELYON16/images_pkl/*"):
#     if 'test' not in item:
#         file_list.append(item.split('/')[-1].split('.')[0])

# NOTE: pickle.load can execute arbitrary code; only load a trusted, locally produced file.
with open('camera_filename.pickle', 'rb') as f:
    file_list = pickle.load(f)

# Ids containing 'normal' are class 0 ('normal'); every other id is class 1 ('tumor').
label_st = ['normal' if 'normal' in name else 'tumor' for name in file_list]
label = [0 if name == 'normal' else 1 for name in label_st]

df_train = pd.DataFrame({'image_id': file_list, 'label': label, 'label_st': label_st})

class CAMERADataset(Dataset):
    """Map-style dataset serving pre-extracted CAMELYON16 image tensors.

    Each row of ``df`` must provide ``image_id`` and ``label``. The sample for a
    row is read with ``torch.load`` from ``<image_dir>/<image_id>.pkl``; only the
    first entry along dim 0 of the stored tensor is returned, together with the
    label as a tensor.

    ``image_size``, ``n_tiles``, ``tile_mode``, ``rand`` and ``transform`` are
    stored on the instance but are not used by ``__getitem__`` (the stored
    tensors are returned as-is).

    Args:
        df: DataFrame with ``image_id`` and ``label`` columns; its index is reset.
        image_size: stored only.
        n_tiles: stored only.
        tile_mode: stored only.
        rand: stored only.
        transform: stored only.
        image_dir: directory holding the ``<image_id>.pkl`` files.
    """

    def __init__(self,
                 df,
                 image_size,
                 n_tiles=1,
                 tile_mode=0,
                 rand=False,
                 transform=None,
                 image_dir="/home/username/piusername/camera17/CAMELYON16/images_pkl/",
                ):

        self.df = df.reset_index(drop=True)
        self.image_size = image_size
        self.n_tiles = n_tiles
        self.tile_mode = tile_mode
        self.rand = rand
        self.transform = transform
        self.image_dir = image_dir

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        row = self.df.iloc[index]
        path = os.path.join(self.image_dir, row.image_id + '.pkl')
        images = torch.load(path)
        # Keep only the first entry along dim 0; assumes the stored tensor is 4-D,
        # presumably (tiles, C, H, W) -- TODO confirm.
        images = images[0, :, :, :]
        return images, torch.tensor(row['label'])


import sklearn.model_selection


# Random (non-stratified) splits with train_test_split's default 25% hold-out:
# first a test set, then the remainder is split again into train / validation.
train_index, test_index = sklearn.model_selection.train_test_split(df_train.index, random_state=2024)
train_index, valid_index = sklearn.model_selection.train_test_split(train_index, random_state=2024)

# CAMERADataset ignores `transform` (the stored tensors are returned as-is), and
# no `transform` exists yet at this point of the script (it is first created
# inside the seed loop), so none is passed here.
train_dataset_new = CAMERADataset(df_train.loc[train_index], 224*224, 1, 0)
valid_dataset = CAMERADataset(df_train.loc[valid_index], 224*224, 1, 0)
test_dataset = CAMERADataset(df_train.loc[test_index], 224*224, 1, 0)

# These loaders are recreated with num_workers=1 inside the training loop below.
train_dataloader = torch.utils.data.DataLoader(train_dataset_new, batch_size=32, shuffle=True, num_workers=32)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=32, shuffle=False, num_workers=32)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=32)

from typing import Optional, Dict, Any, Union, List
import numpy as np
from sklearn.metrics import (
    balanced_accuracy_score,
    roc_auc_score,
    accuracy_score,
    cohen_kappa_score,
    classification_report,
)
import torch.nn.functional as F
def ECELoss(logits, labels, n_bins = 15):
    """
    Calculate the Expected Calibration Error (ECE) of a classifier.

    (This isn't necessary for temperature scaling, just a cool metric.)
    The input is the model's logits, NOT softmax scores: softmax is applied here.
    Confidences (the max softmax probability per sample) are split into
    ``n_bins`` equal-width bins over [0, 1]. In every non-empty bin the gap
    ``| avg_confidence_in_bin - accuracy_in_bin |`` is computed, and the result
    is the average of those gaps weighted by the fraction of samples in the bin.

    See: Naeini, Mahdi Pakdaman, Gregory F. Cooper, and Milos Hauskrecht.
    "Obtaining Well Calibrated Probabilities Using Bayesian Binning." AAAI.
    2015.

    Args:
        logits: array-like or tensor of shape (n_samples, n_classes).
        labels: n_samples integer class labels (array-like, tensor, or a list of
            0-d tensors).
        n_bins: number of equal-width confidence bins.

    Returns:
        torch.Tensor of shape (1,) holding the ECE, on the device of ``logits``.
    """
    # as_tensor avoids the copy (and warning) torch.tensor() incurs for tensor input.
    logits = torch.as_tensor(logits)
    # Align devices so predictions.eq(labels) works for GPU logits too.
    labels = torch.as_tensor(labels, device=logits.device)

    bin_boundaries = torch.linspace(0, 1, n_bins + 1)
    softmaxes = F.softmax(logits, dim=1)
    confidences, predictions = torch.max(softmaxes, 1)
    accuracies = predictions.eq(labels)

    ece = torch.zeros(1, device=logits.device)
    for bin_lower, bin_upper in zip(bin_boundaries[:-1], bin_boundaries[1:]):
        # Bins are half-open (lower, upper]; |confidence - accuracy| per bin.
        in_bin = confidences.gt(bin_lower.item()) & confidences.le(bin_upper.item())
        prop_in_bin = in_bin.float().mean()
        if prop_in_bin.item() > 0:
            accuracy_in_bin = accuracies[in_bin].float().mean()
            avg_confidence_in_bin = confidences[in_bin].mean()
            ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

    return ece

def get_eval_metrics(
    targets_all: Union[List[int], np.ndarray],
    preds_all: Union[List[int], np.ndarray],
    probs_all: Optional[Union[List[List[float]], np.ndarray]] = None,
    get_report: bool = True,
    prefix: str = "",
    roc_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """
    Calculate evaluation metrics and return them in a dictionary.

    Args:
        targets_all (array-like): True target values.
        preds_all (array-like): Predicted target values.
        probs_all (array-like, optional): Per-class scores of shape
            (n_samples, 2). Raw logits are expected (ECELoss applies a softmax);
            for AUROC the scores are softmax-normalised first, which leaves the
            result unchanged when they are already probabilities. Defaults to None.
        get_report (bool, optional): Whether to include the classification report in the results. Defaults to True.
        prefix (str, optional): Prefix to add to the result keys. Defaults to "".
        roc_kwargs (dict, optional): Additional keyword arguments forwarded to
            ``roc_auc_score``. Defaults to None (no extra arguments).

    Returns:
        dict: ``acc``, ``bacc``, ``kappa`` (quadratic-weighted) and ``weighted_f1``;
        plus ``report`` if ``get_report``; plus ``auroc`` and ``eceloss`` if
        ``probs_all`` is given. Every key carries ``prefix``.
    """
    roc_kwargs = {} if roc_kwargs is None else roc_kwargs

    bacc = balanced_accuracy_score(targets_all, preds_all)
    kappa = cohen_kappa_score(targets_all, preds_all, weights="quadratic")
    acc = accuracy_score(targets_all, preds_all)
    cls_rep = classification_report(targets_all, preds_all, output_dict=True, zero_division=0)

    eval_metrics = {
        f"{prefix}acc": acc,
        f"{prefix}bacc": bacc,
        f"{prefix}kappa": kappa,
        f"{prefix}weighted_f1": cls_rep["weighted avg"]["f1-score"],
    }

    if get_report:
        eval_metrics[f"{prefix}report"] = cls_rep

    if probs_all is not None:
        scores = torch.as_tensor(np.asarray(probs_all))
        # Positive-class probability. For two columns softmax is strictly
        # increasing in (s1 - s0), so rows that already sum to 1 rank the same,
        # while raw logits (column 1 alone) would otherwise rank incorrectly.
        pos_prob = F.softmax(scores, dim=1)[:, 1].numpy()
        eval_metrics[f"{prefix}auroc"] = roc_auc_score(targets_all, pos_prob, **roc_kwargs)
        # ECELoss applies its own softmax, so it receives the raw scores; use the
        # targets passed in (previously a global `test_label` was read here).
        eval_metrics[f"{prefix}eceloss"] = ECELoss(scores, targets_all).item()

    return eval_metrics

import argparse
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
from numpy.random import uniform, choice, normal
from torch import nn, optim, Tensor, manual_seed, argmax
from torch.utils.data import TensorDataset, DataLoader
from torchmetrics.classification import Accuracy, MulticlassConfusionMatrix
from pytorch_lightning.utilities.model_summary import ModelSummary
import pytorch_lightning as pl
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split
from torch.autograd import Variable
# model = timm.create_model("hf-hub:MahmoodLab/UNI2-h", pretrained=True, **timm_kwargs)
# transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
# model.eval()


# The parser is created but never used: the configuration is hard-coded below.
parser = argparse.ArgumentParser()
args = argparse.Namespace(
    method='ce',     # loss used by Model.compute_loss: 'ce', 'ls', 'l1', 'focal', 'poly' or 'entropy'
    epsilon=1.0,     # label-smoothing strength ('ls')
    alpha=0.1,       # weight of the L1 logit-norm term ('l1'); 0.05 by default, 0.1 is best
    fgamma=1.0,      # focal-loss exponent ('focal')
    epsilon_p=2.0,   # poly-loss coefficient ('poly')
    e_lambda=0.3,    # entropy-bonus weight ('entropy')
)

class entropy_ce(nn.Module):
    """Cross-entropy with an entropy bonus.

    loss = (1 - e_lambda) * mean(CE) - e_lambda * mean(H(softmax(logits)))
    """

    def __init__(self):
        super(entropy_ce, self).__init__()

    def forward(self, x_input, y_target, e_lambda):
        """Compute the loss.

        Args:
            x_input: logits of shape (batch, n_classes).
            y_target: integer class indices of shape (batch,).
            e_lambda: mixing weight between the CE term and the entropy bonus.

        Returns:
            Scalar loss tensor.
        """
        # Explicit dim=1: the implicit-dim softmax form is deprecated.
        p = F.softmax(x_input, dim=1)
        log_p = F.log_softmax(x_input, dim=1)
        # Shannon entropy of each sample's predicted distribution.
        entropy = -torch.sum(p * log_p, dim=1)

        # Number of classes comes from the logits instead of being fixed to 2.
        one_hot = F.one_hot(y_target, num_classes=x_input.size(1))
        ce = -torch.sum(log_p * one_hot, dim=1)

        return (1 - e_lambda) * torch.mean(ce) - e_lambda * torch.mean(entropy)

# One fine-tuning run per seed; each seed's test metrics are collected in frame_list.
frame_list = []
for seed in range(0, 5):
    pl.seed_everything(seed, workers=True)
    print("the seed is", seed)
    # pretrained=True needed to load UNI weights (and download weights for the first time)
    # using UNI2-h as example
    timm_kwargs = {
        'img_size': 224,
        'patch_size': 14,
        'depth': 24,
        'num_heads': 24,
        'init_values': 1e-5,
        'embed_dim': 1536,
        'mlp_ratio': 2.66667*2,
        'num_classes': 0,
        'no_embed_class': True,
        'mlp_layer': timm.layers.SwiGLUPacked,
        'act_layer': torch.nn.SiLU,
        'reg_tokens': 8,
        'dynamic_img_size': True,
    }
    # Fresh backbone for every seed; Model.__init__ below reads this module-level `model`.
    model = timm.create_model("hf-hub:MahmoodLab/UNI2-h", pretrained=True, **timm_kwargs)
    transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
    model.eval()
    # NOTE(review): Trainer.fit switches the whole LightningModule (backbone
    # included) to train mode, so this eval() does not persist through training.

    class Model(pl.LightningModule):
        """MLP head (1536 -> 512 -> k) on top of the backbone `model`.

        Args:
            k: number of classes.
            optimizer: 'Adam' (lr 1e-3) or 'SGD' (lr 0.1).
            dropout_rate: accepted but not used.
            fullmodel: if True, gradients flow into the backbone; otherwise the
                backbone forward runs under torch.no_grad() and only the head learns.
        """

        def __init__(self, k, optimizer='Adam', dropout_rate=0, fullmodel=False):
            super().__init__()
            self.pretrain_model = model.cuda()
            self.num_classes = k
            # 1536 matches embed_dim in timm_kwargs.
            self.model = nn.Sequential(*[nn.Linear(1536, 512), nn.ReLU(), nn.Linear(512, k)]).cuda()
            # Define other attributes
            self.loss = nn.CrossEntropyLoss()  # not used; compute_loss builds its own criterion
            self.optimizer = optimizer
            self.lr = {'Adam': 0.001, 'SGD': 0.1}[optimizer]
            self.accuracy = Accuracy(task="multiclass", num_classes=k)
            self.test_pred = []  # predicted labels, appended by test_step
            self.prob = []  # raw logits per test sample, appended by test_step
            self.confusion_matrix = MulticlassConfusionMatrix(num_classes=k)
            self.fullmodel = fullmodel

        def compute_loss(self, inputs, outputs, targets, args):
            """Return the scalar loss selected by ``args.method``.

            Methods: 'ce' cross-entropy; 'ls' label smoothing; 'l1' CE plus an
            L1 penalty on the logits; 'focal' focal loss; 'poly' CE plus
            epsilon_p * (1 - p_target); 'entropy' CE with an entropy bonus.

            Args:
                inputs: the input batch; only its size along dim 0 is used ('ls').
                outputs: logits of shape (batch, num_classes).
                targets: integer class indices of shape (batch,).
                args: namespace with ``method`` and the hyper-parameter that
                    method reads (epsilon, alpha, fgamma, epsilon_p, e_lambda).

            Raises:
                ValueError: for an unknown ``args.method`` (this used to surface
                    as an UnboundLocalError).
            """
            criterion = nn.CrossEntropyLoss().cuda()
            if args.method == 'ce':
                ce_loss = criterion(outputs, targets)
            elif args.method == 'ls':
                # One-hot targets mixed with the uniform distribution.
                target_bi = torch.zeros(inputs.size(0), self.num_classes).scatter_(1, targets.cpu().view(-1, 1).long(), 1)
                target_bi = target_bi.cuda()
                epsilon = args.epsilon
                target_bi_smooth = (1.0 - epsilon) * target_bi + epsilon / self.num_classes
                ce_loss = -torch.mean(torch.sum(F.log_softmax(outputs, dim=1) * target_bi_smooth, dim=1))
            elif args.method == 'l1':
                loss_cla = criterion(outputs, targets)
                loss_f1_norm = torch.mean(torch.norm(outputs, p=1, dim=1))
                ce_loss = loss_cla + args.alpha * loss_f1_norm
            elif args.method == 'focal':
                logpt = F.log_softmax(outputs, dim=1).gather(1, targets.view(-1, 1)).view(-1)
                # The modulating factor is treated as a constant (no gradient through pt).
                pt = logpt.exp().detach()
                weights = (1 - pt) ** (args.fgamma)
                ce_loss = -torch.mean(weights * logpt)
            elif args.method == 'poly':
                p = F.softmax(outputs, dim=1)
                x_input = F.log_softmax(outputs, 1)
                y_target = F.one_hot(targets, num_classes=self.num_classes)
                pt = torch.sum(p * y_target, dim=1)
                ce_loss = -torch.sum(x_input * y_target, 1)
                poly_loss = args.epsilon_p * (1. - pt)
                ce_loss = torch.mean(ce_loss + poly_loss)
            elif args.method == 'entropy':
                criterion = entropy_ce()
                ce_loss = criterion(outputs, targets, args.e_lambda)
            else:
                raise ValueError(f"Unknown loss method: {args.method!r}")

            return ce_loss

        def forward(self, x):
            """Backbone features -> head logits; the backbone is frozen unless fullmodel."""
            if self.fullmodel:
                feature = self.pretrain_model(x)
            else:
                with torch.no_grad():
                    feature = self.pretrain_model(x)
            out = self.model(feature)

            return out

        def configure_optimizers(self):
            # self.parameters() includes the backbone; with fullmodel=False its
            # parameters get no gradients because forward runs it under no_grad.
            if self.optimizer == 'Adam':
                optimizer = optim.Adam(self.parameters(), lr=self.lr)
            else:
                optimizer = optim.SGD(self.parameters(), lr=self.lr)
            return optimizer

        def training_step(self, batch, batch_idx):
            x, y = batch
            logits = self.forward(x)
            loss = self.compute_loss(x, logits, y, args)
            self.log('loss', loss)
            # Track accuracy
            y_pred = argmax(logits, dim=1)
            acc = self.accuracy(y_pred, y)
            self.log('accuracy', acc)
            return loss

        def validation_step(self, batch, batch_idx):
            x, y = batch
            logits = self.forward(x)
            loss = self.compute_loss(x, logits, y, args)
            self.log('val_loss', loss)
            # Track accuracy
            y_pred = argmax(logits, dim=1)
            acc = self.accuracy(y_pred, y)
            self.log('val_accuracy', acc)

        def test_step(self, batch, batch_idx):
            x, y = batch
            # Evaluate model
            logits = self.forward(x)
            # Track loss
            loss = self.compute_loss(x, logits, y, args)
            self.log('test_loss', loss)
            # Track accuracy
            y_pred = argmax(logits, dim=1)  # find label with highest probability
            acc = self.accuracy(y_pred, y)
            self.log('test_accuracy', acc)
            # Collect predictions and raw logits for the metrics computed after testing.
            self.test_pred.extend(y_pred.cpu().numpy())
            self.prob.extend(logits.cpu().numpy())
            # Update confusion matrix
            self.confusion_matrix.update(y_pred, y)

    # Create a PyTorch Lightning trainer and add callbacks
    pl.seed_everything(seed, workers=True)
    model_class = Model(k=2, fullmodel=False).cuda()

    import lightning
    import glob
    import shutil

    early_stopping_callback = pl.callbacks.early_stopping.EarlyStopping(
        monitor='val_loss',
        patience=10,
        min_delta=0.005,
        mode='min',
    )

    # Start every seed with an empty checkpoint directory, so best_model.ckpt
    # below always comes from the current seed.
    dirpath = "./classify_ft_new_camera_univ2/"
    experiment = dirpath
    if os.path.exists(experiment):
        shutil.rmtree(experiment)
    os.makedirs(experiment)

    train_dataloader = torch.utils.data.DataLoader(train_dataset_new, batch_size=32, shuffle=True, num_workers=1)
    valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=32, shuffle=False, num_workers=1)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=1)
    model_checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=dirpath,
        filename='best_model',
        monitor='val_loss',
        mode='min',
    )
    trainer = pl.Trainer(
        max_epochs=300,
        enable_model_summary=False,  # summary printed already
        callbacks=[
            early_stopping_callback,
            model_checkpoint_callback,
        ],
        accelerator='gpu', devices=1, deterministic=True,
    )

    trainer.fit(model_class, train_dataloader, valid_dataloader)

    best_checkpoint = trainer.checkpoint_callback.best_model_path
    print(best_checkpoint)
    # load_from_checkpoint is a classmethod and must be called on the class
    # (recent Lightning releases reject calls on an instance). It builds a new
    # Model around this seed's backbone `model` and restores the best weights.
    model = Model.load_from_checkpoint(best_checkpoint, k=2, fullmodel=False)
    test_data = trainer.test(model, test_dataloader)

    # Ground-truth labels in dataset order. The test loader does not shuffle, so
    # they line up with model.test_pred / model.prob. Reading the DataFrame avoids
    # reloading every test image from disk just to get its label.
    test_label = test_dataset.df['label'].tolist()

    # model.prob holds raw logits (see test_step).
    out = get_eval_metrics(test_label, model.test_pred, np.array(model.prob), get_report=False)
    frame_list.append(pd.Series(out, name=seed))

# One column per seed, one row per metric.
results = pd.concat(frame_list, axis=1)
print(results)
