from __future__ import print_function
from matplotlib.pyplot import axis
from numpy.lib.function_base import append
import random
import torch
from torch import logit
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import config as cf
from datasets import ImagenetNoise

import torchvision.transforms as transforms

import os
import argparse

from torch.utils.data import Dataset
from PIL import Image
import numpy as np
from PIL import Image

from utils import prepare_dset, maha
from networks import *

from utils import get_pretrained_model
import matplotlib.pyplot as plt

import torch
import torchvision
import os
from os.path import join as j_
from PIL import Image
import pandas as pd
import numpy as np


import os
import torch

import sys
# Make the local MUSK checkout importable and point the Hugging Face cache at
# scratch storage.
sys.path.append('/home/username/piusername/MUSK/')
os.environ['HF_HOME'] = '/home/username/scratch/'
# Only provide an empty HF_TOKEN fallback. Assigning '' unconditionally would
# overwrite a token the user already exported (e.g. in a job script), and the
# hf-hub model download further down may need authentication.
os.environ.setdefault('HF_TOKEN', '')

def setup_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG, the PyTorch CPU
    generator and the generators of all CUDA devices with the same ``seed``.
    ``torch.backends.cudnn.deterministic`` is deliberately left untouched, so
    cuDNN kernels may still be non-deterministic.
    """
    random.seed(seed)
    np.random.seed(seed)
    for seed_torch in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_torch(seed)
setup_seed(20)

# Command-line interface. Every entry is (flags, add_argument keyword args);
# they are registered in this exact order so --help output is unchanged.
parser = argparse.ArgumentParser(description='Ensemble Training')
_CLI_OPTIONS = (
    # pretrained models setting
    (('--maha_file',), dict(default='./ssl/maha_dict.npy', type=str)),
    (('-a', '--arch'), dict(metavar='ARCH', default='resnet50')),
    (('--pretrained',), dict(default='', type=str,
                             help='path to moco pretrained checkpoint')),
    (('--pretrained_model',), dict(default='vit', type=str, help='SSL feature map type')),
    # data settings
    (('--batch_size',), dict(default=1024, type=int)),
    (('--dataset',), dict(default='cifar10', type=str, help='cifar10/cifar100')),
    # NOTE(review): default 10 is forwarded to maha() as indist_classes; the
    # PANDA isup_grade labels loaded later may span a different number of
    # classes -- confirm the value passed on the command line.
    (('--num_classes',), dict(default=10, type=int)),
    (('--random_state',), dict(type=int, default=0)),
    # noise / trigger settings (apparently unused by the active code path)
    (('--ynoise_type',), dict(default='symmetric', type=str, help='symmetric/pairflip')),
    (('--ynoise_rate',), dict(default=0.0, type=float, help='label noise rate')),
    (('--xnoise_type',), dict(default='blur', type=str, help='gaussian/blur')),
    (('--xnoise_arg',), dict(default=1, type=float)),
    (('--xnoise_rate',), dict(default=0.0, type=float)),
    (('--trigger_size',), dict(type=int, default=3)),
    (('--trigger_ratio',), dict(type=float, default=0.)),
    # name of the module-level DataLoader that predict_maha iterates over
    (('--select_dataloader',), dict(type=str, default="trainloader")),
)
for _flags, _kwargs in _CLI_OPTIONS:
    parser.add_argument(*_flags, **_kwargs)

args = parser.parse_args()

# Hyper-parameter settings derived from the CLI.
num_classes = args.num_classes
use_cuda = torch.cuda.is_available()
best_acc = 0
batch_size = args.batch_size
# Custom_Dataset class
class Custom_Dataset(Dataset):
    """In-memory image dataset built from raw image arrays and labels.

    Args:
        x: indexable collection of images. For ``data_set='cifar'`` each item
            is passed straight to ``PIL.Image.fromarray`` (so presumably HWC);
            for ``data_set='svhn'`` each item is channel-first and is
            transposed with axes ``(1, 2, 0)`` before conversion.
        y: indexable collection of labels aligned with ``x``.
        data_set: ``'cifar'`` or ``'svhn'``; selects the array layout.
        transform: optional callable applied to the PIL image. When ``None``
            the PIL image itself is returned.

    Each item is an ``(image, label, index)`` triple; the index lets callers
    map a sample back to its position in ``x``.

    Raises:
        ValueError: from ``__getitem__`` when ``data_set`` is not supported.
    """

    def __init__(self, x, y, data_set, transform=None):
        self.x_data = x
        self.y_data = y
        self.data = data_set
        self.transform = transform

    def __len__(self):
        return len(self.x_data)

    def __getitem__(self, idx):
        if self.data == 'cifar':
            img = Image.fromarray(self.x_data[idx])
        elif self.data == 'svhn':
            # SVHN arrays are channel-first; PIL expects channel-last.
            img = Image.fromarray(np.transpose(self.x_data[idx], (1, 2, 0)))
        else:
            # Previously this fell through and crashed with UnboundLocalError.
            raise ValueError(
                "Unsupported data_set {!r}; expected 'cifar' or 'svhn'".format(self.data)
            )

        # transform defaults to None, so it must be optional here.
        x = img if self.transform is None else self.transform(img)

        return x, self.y_data[idx], idx


def _forward_features(model, inputs, args):
    """Return the raw forward output of ``model`` on ``inputs`` for ``args.arch``.

    CLIP variants use their image-encoder entry points. MUSK is called with
    its global-embedding options and indexed at 0. Every other architecture
    (including ``'uni'``) uses a plain ``model(inputs)`` call.
    """
    arch = args.arch
    if arch.startswith('clip'):
        if arch in ('clip_r50', 'clip_r101'):
            return model.encode_image(inputs)
        return model.get_image_features(inputs)
    if arch == 'musk':
        return model(
            inputs,
            with_head=False,
            out_norm=False,
            ms_aug=True,
            return_global=True,
        )[0]
    return model(inputs)


def predict_maha(model, epoch, args):
    """Extract features with ``model`` and save the statistics returned by ``maha``.

    The DataLoader named by ``args.select_dataloader`` is iterated ``epoch``
    times under ``torch.no_grad()``. Each batch is processed as follows:

    - it is moved to CUDA;
    - it is resized with the module-level ``up_sample`` unless
      ``args.dataset == 'imagenet'``;
    - it is passed through the architecture-specific forward
      (``_forward_features``);
    - ``.logits`` is used for archs starting with ``'hug'``.

    Features and labels from all batches are concatenated and handed to
    ``maha(..., indist_classes=num_classes)``. The result is written to
    ``args.maha_file`` with ``np.save``.

    Args:
        model: feature extractor (already on the GPU).
        epoch: number of passes over the loader.
        args: parsed CLI namespace.

    Raises:
        ValueError: if the loader yielded no batches at all.
    """
    feature_chunks = []
    label_chunks = []
    n_rows = 0
    for i in range(epoch):
        print(i)
        # NOTE(security): the loader is looked up by eval()-ing a command-line
        # string (e.g. 'trainloader'). This executes arbitrary expressions, so
        # only run the script with trusted arguments.
        loader = eval(args.select_dataloader)
        with torch.no_grad():
            for inputs, targets in loader:
                inputs = inputs.cuda()
                if args.dataset != 'imagenet':
                    inputs = up_sample(inputs)
                output = _forward_features(model, inputs, args)
                if args.arch.startswith('hug'):
                    # Presumably a Hugging Face output object wrapping the tensor.
                    output = output.logits
                logits = output.cpu().data.numpy()
                # Collect per-batch arrays and concatenate once at the end.
                # Concatenating inside the loop re-copied everything gathered
                # so far on every batch, which is quadratic in the batch count.
                feature_chunks.append(logits)
                label_chunks.append(targets.cpu().data.numpy())
                n_rows += logits.shape[0]
                # Same progress line as before: running (rows, *feature_dims).
                print((n_rows,) + logits.shape[1:])
    if not feature_chunks:
        raise ValueError(
            "No batches were produced by {!r} over {} pass(es); cannot compute "
            "statistics.".format(args.select_dataloader, epoch)
        )
    feature_all = np.concatenate(feature_chunks, axis=0)
    labels_all = np.concatenate(label_chunks, axis=0)
    maha_intermediate_dict = maha(feature_all, labels_all, indist_classes=num_classes)
    print(feature_all.shape)
    np.save(args.maha_file, maha_intermediate_dict)





def compute_mean_cov(pretrain_model, args, epochs=3):
    """Prepare ``pretrain_model`` for inference and collect feature statistics.

    The model is moved to the GPU. Non-CLIP models are wrapped in
    ``torch.nn.DataParallel``. Every model is switched to eval mode, then
    handed to ``predict_maha`` for ``epochs`` passes over the selected loader.

    Args:
        pretrain_model: feature extractor module.
        args: parsed CLI namespace; ``args.arch`` decides the DataParallel wrap.
        epochs: number of loader passes forwarded to ``predict_maha``. The
            default of 3 matches the previously hard-coded value.
    """
    pretrain_model.cuda()
    if not args.arch.startswith('clip'):
        pretrain_model = torch.nn.DataParallel(pretrain_model)
    # Inference only: put every architecture, CLIP included, in eval mode.
    # Previously eval() was reached only in the non-CLIP branch.
    pretrain_model.eval()
    predict_maha(pretrain_model, epochs, args)


# Data upload / preparation
print('\n[Phase 1] : Data Preparation')


# if args.dataset != 'imagenet':
#     trainset, testset, trainvalset = prepare_dset(args)
#     num_classes = trainset.nb_classes
# else:
#     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                                      std=[0.229, 0.224, 0.225])
#     trainset = ImagenetNoise(
#         transform=transforms.Compose([
#             transforms.Resize(256),
#             transforms.RandomResizedCrop(224),
#             transforms.RandomHorizontalFlip(),
#             transforms.ToTensor(),
#             normalize,
#         ]),
#         xnoise_rate=args.xnoise_rate,
#         xnoise_arg=args.xnoise_arg,
#         xnoise_type=args.xnoise_type,
#         ynoise_type=args.ynoise_type,
#         ynoise_rate=args.ynoise_rate,
#         random_state=args.random_state,
#         num_classes=args.num_classes
#     )
#     num_classes = args.num_classes
#     testset = ImagenetNoise(
#         train=False,
#         transform=transforms.Compose([
#             transforms.Resize(256),
#             transforms.CenterCrop(224),
#             transforms.ToTensor(),
#             normalize,
#         ]),
#         num_classes=args.num_classes
#     )



# trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False,num_workers=4)
# testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)


import timm
import torch
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from huggingface_hub import login

import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from huggingface_hub import login


# pretrained=True needed to load UNI weights (and download weights for the first time)
# init_values need to be passed in to successfully load LayerScale parameters (e.g. - block.0.ls1.gamma)
# model = timm.create_model("hf-hub:MahmoodLab/UNI", pretrained=True, init_values=1e-5, dynamic_img_size=True)
# transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
# model.eval()


# Architecture arguments timm needs to rebuild UNI2-h (patch size 14,
# 224x224 input, 8 register tokens, SwiGLU MLP) before loading hub weights.
timm_kwargs = dict(
    img_size=224,
    patch_size=14,
    depth=24,
    num_heads=24,
    init_values=1e-5,
    embed_dim=1536,
    mlp_ratio=2.66667 * 2,
    num_classes=0,
    no_embed_class=True,
    mlp_layer=timm.layers.SwiGLUPacked,
    act_layer=torch.nn.SiLU,
    reg_tokens=8,
    dynamic_img_size=True,
)
# NOTE(review): UNI2-h is built with pretrained weights here only to derive its
# preprocessing `transform`, and it is deleted right away. PANDADataset (as
# currently written) does not apply `transform`. Confirm this load is needed.
model = timm.create_model("hf-hub:MahmoodLab/UNI2-h", pretrained=True, **timm_kwargs)
transform = create_transform(
    **resolve_data_config(model.pretrained_cfg, model=model)
)
model.eval()
del model

import timm
import torch
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from huggingface_hub import login

# # pretrained=True needed to load UNI weights (and download weights for the first time)
# # using UNI2-h as example
# timm_kwargs = {
#    'img_size': 224, 
#    'patch_size': 14, 
#    'depth': 24,
#    'num_heads': 24,
#    'init_values': 1e-5, 
#    'embed_dim': 1536,
#    'mlp_ratio': 2.66667*2,
#    'num_classes': 0, 
#    'no_embed_class': True,
#    'mlp_layer': timm.layers.SwiGLUPacked, 
#    'act_layer': torch.nn.SiLU, 
#    'reg_tokens': 8, 
#    'dynamic_img_size': True
#   }

# model = timm.create_model("hf-hub:MahmoodLab/UNI2-h", pretrained=True, **timm_kwargs)
# transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
# model.eval()

# model = timm.create_model("hf_hub:prov-gigapath/prov-gigapath", pretrained=True)


# transform = transforms.Compose(
#     [
#         transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
#         transforms.CenterCrop(224),
#         transforms.ToTensor(),
#         transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
#     ]
# )

# from musk import utils, modeling
# from timm.models import create_model
# import torch
# model = create_model("musk_large_patch16_384")
# utils.load_model_and_may_interpolate("hf_hub:xiangjx/musk", model, 'model|module', '')
# model.to(device="cuda")

# import torchvision
# from PIL import Image
# from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD

# transform = torchvision.transforms.Compose([
#     torchvision.transforms.Resize(384, interpolation=3, antialias=True),
#     torchvision.transforms.CenterCrop((384, 384)),
#     torchvision.transforms.ToTensor(),
#     torchvision.transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD)
# ])


# print('| Building and loading pretrained model type [' + args.arch + ']')

# dataroot = '../../assets/data/tcga_luadlusc'

# # create some image folder datasets for train/test and their data laoders
# train_dataset = torchvision.datasets.ImageFolder(j_(dataroot, 'train'), transform=transform)
# test_dataset = torchvision.datasets.ImageFolder(j_(dataroot, 'test'), transform=transform)
# trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
# testloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

import traceback
from abc import abstractmethod
import timm
import torch

class InferenceEncoder(torch.nn.Module):
    """Base wrapper for encoders used only for inference.

    Subclasses implement ``_build(weights_path, **build_kwargs)`` and return a
    ``(model, precision)`` pair. ``forward`` simply delegates to that model.

    Args:
        weights_path: passed positionally to ``_build``; subclasses may ignore it.
        **build_kwargs: forwarded to ``_build`` as keyword arguments.
    """

    def __init__(self, weights_path=None, **build_kwargs):
        super(InferenceEncoder, self).__init__()

        self.weights_path = weights_path
        self.model, self.precision = self._build(weights_path, **build_kwargs)

    def forward(self, x):
        return self.model(x)

    @abstractmethod
    def _build(self, weights_path=None, **build_kwargs):
        """Construct the wrapped model.

        Must return ``(model, precision)``, where ``precision`` is a torch dtype
        (e.g. ``torch.float32``). The signature accepts the positional
        ``weights_path`` that ``__init__`` passes. ``abstractmethod`` is only
        advisory because ``torch.nn.Module`` is not an ``ABC``, so the base
        implementation raises explicitly.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(
            "{} must implement _build(weights_path, **build_kwargs)".format(type(self).__name__)
        )

class ResNet50InferenceEncoder(InferenceEncoder):
    """timm ``resnet50.tv_in1k`` used as a truncated feature extractor.

    By default timm returns only feature stage index 3 (``features_only`` with
    ``out_indices=[3]``). That map is global-average-pooled to a vector
    unless ``pool=False``.
    """

    def _build(
        self,
        _,
        pretrained=True,
        timm_kwargs=None,
        pool=True
    ):
        """Create the timm backbone and the optional pooling layer.

        Args:
            _: ``weights_path`` from ``InferenceEncoder.__init__``; ignored.
                Weights come from timm when ``pretrained`` is True.
            pretrained: forwarded to ``timm.create_model``.
            timm_kwargs: extra ``timm.create_model`` arguments. ``None`` means
                ``{"features_only": True, "out_indices": [3], "num_classes": 0}``,
                built fresh per call instead of a shared mutable default.
            pool: add an ``AdaptiveAvgPool2d(1)`` applied in ``forward``.

        Returns:
            ``(model, torch.float32)``.
        """
        import timm

        if timm_kwargs is None:
            timm_kwargs = {"features_only": True, "out_indices": [3], "num_classes": 0}

        model = timm.create_model("resnet50.tv_in1k", pretrained=pretrained, **timm_kwargs)
        precision = torch.float32
        self.pool = torch.nn.AdaptiveAvgPool2d(1) if pool else None

        return model, precision

    def forward(self, x):
        out = self.forward_features(x)
        if self.pool is not None:
            # Pool to 1x1 and drop the two trailing singleton spatial dims.
            out = self.pool(out).squeeze(-1).squeeze(-1)
        return out

    def forward_features(self, x):
        """Run the backbone and unwrap the single feature map timm returns as a list."""
        out = self.model(x)
        if isinstance(out, list):
            # Explicit check instead of `assert` (which is stripped under -O).
            if len(out) != 1:
                raise RuntimeError(
                    "expected exactly one feature map from timm, got {}".format(len(out))
                )
            out = out[0]
        return out

# Feature extractor passed to compute_mean_cov at the end of the script.
model = ResNet50InferenceEncoder()
model.eval()

# Older versions of timm have compatibility issues. Please ensure that you use
# a newer version, e.g.: pip install "timm>=1.0.3"
# (quote the requirement; unquoted, the shell treats ">" as an output redirect).


data_dir = '/home/username/scratch/panda_data/'
df_train = pd.read_csv(os.path.join(data_dir, 'train.csv'))

# A bare `value_counts()` expression is discarded when run as a script
# (notebook leftover); print it so the grade distribution is actually shown.
print(df_train['isup_grade'].value_counts())

# Default number of tiles per sample for PANDADataset.
n_tiles = 1


class PANDADataset(Dataset):
    """PANDA dataset backed by pre-saved image tensors (one ``.pkl`` per image).

    Each row of ``df`` must provide ``image_id`` and ``isup_grade``. For a row,
    ``<image_dir>/<image_id>.pkl`` is read with ``torch.load``. The loaded
    object is indexed as a 4-D tensor and only entry 0 of its first dimension
    is returned.

    Args:
        df: pandas DataFrame with ``image_id`` and ``isup_grade`` columns; its
            index is reset so positional lookup works.
        image_size, n_tiles, tile_mode, rand: stored on the instance but not
            used by ``__getitem__``.
        transform: stored but NOT applied. The ``.pkl`` files presumably
            already hold preprocessed tensors (TODO confirm).
        image_dir: directory holding the ``.pkl`` files. The default is the
            previously hard-coded location.

    Each item is ``(images[0], torch.tensor(isup_grade))``.
    """

    def __init__(self,
                 df,
                 image_size,
                 n_tiles=n_tiles,
                 tile_mode=0,
                 rand=False,
                 transform=None,
                 image_dir="/home/username/scratch/panda_data/train_images_pkl/",
                ):

        self.df = df.reset_index(drop=True)
        self.image_size = image_size
        self.n_tiles = n_tiles
        self.tile_mode = tile_mode
        self.rand = rand
        self.transform = transform
        self.image_dir = image_dir

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        row = self.df.iloc[index]
        img_id = row.image_id

        # NOTE(security): torch.load unpickles the file; only point image_dir
        # at trusted data.
        images = torch.load(os.path.join(self.image_dir, img_id + '.pkl'))
        images = images[0, :, :, :]
        label = row['isup_grade']

        return images, torch.tensor(label)

import sklearn.model_selection

# Random train/test split of the slide table (train_test_split's default
# test_size), reproducible via random_state.
train_index, test_index = sklearn.model_selection.train_test_split(df_train.index, random_state=2024)

train_dataset_new = PANDADataset(df_train.loc[train_index], 224*224, 1, 0, transform=transform)
test_dataset = PANDADataset(df_train.loc[test_index], 224*224, 1, 0, transform=transform)

# NOTE(review): these loaders use a fixed batch size of 512, so --batch_size
# has no effect here -- confirm whether that is intended.
trainloader = torch.utils.data.DataLoader(train_dataset_new, batch_size=512, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=512, shuffle=False)
# predict_maha eval()s --select_dataloader. Expose the test loader under the
# name matching `trainloader`, so `--select_dataloader testloader` works
# instead of raising NameError.
testloader = test_dataloader

# Resize applied in predict_maha when args.dataset != 'imagenet'.
up_sample = nn.Upsample(size=(224,224), mode='bilinear')

compute_mean_cov(model,args)

