# Adapted from: https://openreview.net/forum?id=4WPhXYMK6N&noteId=k9e7SavlwP
from __future__ import print_function
from matplotlib.pyplot import axis
from numpy.lib.function_base import append
import random
import torch
from torch import logit
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import config as cf
from datasets import ImagenetNoise

import torchvision.transforms as transforms

import os
import argparse

from torch.utils.data import Dataset
from PIL import Image
import numpy as np
from PIL import Image

from utils import prepare_dset, maha
from networks import *

from utils import get_pretrained_model
import matplotlib.pyplot as plt

import torch
import torchvision
import os
from os.path import join as j_
from PIL import Image
import pandas as pd
import numpy as np

import os
import torch

import sys
# Make the (presumably local) MUSK checkout importable for `from musk import ...`
# and point the Hugging Face cache at scratch storage.
sys.path.append('/home/username/piusername/MUSK/')
os.environ['HF_HOME'] = '/home/username/scratch/'
# Only fall back to an empty token: assigning unconditionally would discard a
# real HF_TOKEN the user already exported in the shell.
os.environ.setdefault('HF_TOKEN', '')

def setup_seed(seed):
    """Seed Python's ``random``, NumPy and torch (CPU and all CUDA devices) with *seed*."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Optional: uncomment to force deterministic cuDNN kernels as well.
    #  torch.backends.cudnn.deterministic = True


setup_seed(20)

parser = argparse.ArgumentParser(description='Ensemble Training')

# (flags, keyword options) for every command-line argument. They are registered
# in this exact order so `--help` lists them as before.
_CLI_ARGUMENTS = (
    # pretrained model settings
    (('--maha_file',), {'default': './ssl/maha_dict.npy', 'type': str}),
    (('-a', '--arch'), {'metavar': 'ARCH', 'default': 'resnet50'}),
    (('--pretrained',), {'default': '', 'type': str,
                         'help': 'path to moco pretrained checkpoint'}),
    (('--pretrained_model',), {'default': 'vit', 'type': str,
                               'help': 'SSL feature map type'}),
    # data settings
    (('--batch_size',), {'default': 1024, 'type': int}),
    (('--dataset',), {'default': 'cifar10', 'type': str, 'help': 'cifar10/cifar100'}),
    (('--num_classes',), {'default': 10, 'type': int}),
    (('--random_state',), {'type': int, 'default': 0}),
    # label/input noise options (not read anywhere in the visible code of this file)
    (('--ynoise_type',), {'default': 'symmetric', 'type': str, 'help': 'symmetric/pairflip'}),
    (('--ynoise_rate',), {'default': 0.0, 'type': float, 'help': 'label noise rate'}),
    (('--xnoise_type',), {'default': 'blur', 'type': str, 'help': 'gaussian/blur'}),
    (('--xnoise_arg',), {'default': 1, 'type': float}),
    (('--xnoise_rate',), {'default': 0.0, 'type': float}),
    (('--trigger_size',), {'type': int, 'default': 3}),
    (('--trigger_ratio',), {'type': float, 'default': 0.}),
    # name of the module-level loader that predict_maha iterates over
    (('--select_dataloader',), {'type': str, 'default': "trainloader"}),
)
for _flags, _options in _CLI_ARGUMENTS:
    parser.add_argument(*_flags, **_options)

args = parser.parse_args()

# Frequently used settings, copied out of the parsed namespace.
num_classes = args.num_classes
batch_size = args.batch_size

# Hyper Parameter settings
use_cuda = torch.cuda.is_available()
best_acc = 0
# Custom_Dataset class
class Custom_Dataset(Dataset):
    """Expose in-memory image arrays and labels as a ``Dataset`` of ``(x, y, idx)``.

    Args:
        x: indexable collection of image arrays. For ``data_set == 'cifar'`` each
            item goes to ``Image.fromarray`` unchanged; for ``'svhn'`` it is first
            transposed with axes ``(1, 2, 0)`` (presumably CHW -> HWC).
        y: indexable collection of labels aligned with ``x``.
        data_set: ``'cifar'`` or ``'svhn'``; selects how an item becomes a PIL image.
        transform: optional callable applied to the PIL image. When ``None`` the
            PIL image itself is returned.
    """

    def __init__(self, x, y, data_set, transform=None):
        self.x_data = x
        self.y_data = y
        self.data = data_set
        self.transform = transform

    def __len__(self):
        return len(self.x_data)

    def __getitem__(self, idx):
        """Return ``(image, label, idx)`` for sample *idx*.

        The index is returned too so callers can map batch entries back to rows.

        Raises:
            ValueError: if ``data_set`` is neither ``'cifar'`` nor ``'svhn'``.
        """
        raw = self.x_data[idx]
        if self.data == 'cifar':
            img = Image.fromarray(raw)
        elif self.data == 'svhn':
            img = Image.fromarray(np.transpose(raw, (1, 2, 0)))
        else:
            # Fail with a clear message instead of touching an unbound `img`.
            raise ValueError(
                f"Unsupported data_set {self.data!r}; expected 'cifar' or 'svhn'"
            )

        # The constructor allows transform=None, so only call it when given.
        x = img if self.transform is None else self.transform(img)

        return x, self.y_data[idx], idx


def _resolve_dataloader(name):
    """Return the module-level ``DataLoader`` bound to the global *name*.

    The name is looked up in ``globals()`` instead of being passed to ``eval``,
    so a command-line value cannot execute arbitrary code.

    Raises:
        ValueError: if no module-level ``DataLoader`` with that name exists.
    """
    loader = globals().get(name)
    if not isinstance(loader, torch.utils.data.DataLoader):
        raise ValueError(
            f"--select_dataloader={name!r} does not name a module-level DataLoader "
            "(e.g. 'trainloader' or 'testloader')"
        )
    return loader


def _extract_features(model, inputs, args):
    """Run *model* on one batch the way backbone ``args.arch`` expects.

    - ``clip_r50`` / ``clip_r101``: ``model.encode_image(inputs)``
    - any other ``clip*``: ``model.get_image_features(inputs)``
    - ``musk``: element ``[0]`` of ``model(inputs, with_head=False,
      out_norm=False, ms_aug=True, return_global=True)``
    - anything else (including ``uni``): ``model(inputs)``
    """
    arch = args.arch
    if arch in ('clip_r50', 'clip_r101'):
        return model.encode_image(inputs)
    if arch.startswith('clip'):
        return model.get_image_features(inputs)
    if arch == 'musk':
        return model(
            inputs,
            with_head=False,
            out_norm=False,
            ms_aug=True,
            return_global=True,
        )[0]
    return model(inputs)


def predict_maha(model, epoch, args):
    """Collect model features over *epoch* loader passes and save Mahalanobis stats.

    Iterates ``epoch`` times over the module-level loader named by
    ``args.select_dataloader``. Each batch's inputs are moved to CUDA and fed
    through the model, and the outputs and labels are accumulated. The results
    go to ``maha(..., indist_classes=args.num_classes)``, and the returned object
    is written to ``args.maha_file`` with ``np.save``.

    Args:
        model: feature extractor; how it is invoked depends on ``args.arch``.
        epoch: number of full passes over the loader.
        args: parsed CLI namespace (reads ``select_dataloader``, ``dataset``,
            ``arch``, ``num_classes`` and ``maha_file``).

    Raises:
        ValueError: if the loader name is unknown or no batch was produced.
    """
    loader = _resolve_dataloader(args.select_dataloader)
    # 'hug*' archs presumably return Hugging Face-style output objects; the
    # tensor is read from their `.logits` attribute.
    unwrap_logits = args.arch.startswith('hug')

    # Keep per-batch arrays and concatenate once at the end; concatenating
    # inside the loop would re-copy the growing matrix on every batch.
    feature_chunks = []
    label_chunks = []
    n_rows = 0
    with torch.no_grad():
        for i in range(epoch):
            print(i)
            for inputs, targets in loader:
                inputs = inputs.cuda()
                if args.dataset != 'imagenet':
                    # `up_sample` is not defined in this file; presumably it
                    # comes from `from networks import *`.
                    inputs = up_sample(inputs)
                output = _extract_features(model, inputs, args)
                if unwrap_logits:
                    output = output.logits
                features = output.detach().cpu().numpy()
                feature_chunks.append(features)
                label_chunks.append(targets.cpu().numpy())
                n_rows += features.shape[0]
                # Progress: running shape of the accumulated feature matrix.
                print((n_rows,) + features.shape[1:])

    if not feature_chunks:
        raise ValueError(
            f"No batches were produced from {args.select_dataloader!r} "
            f"over {epoch} pass(es); nothing to compute statistics from"
        )
    feature_all = np.concatenate(feature_chunks, axis=0)
    labels_all = np.concatenate(label_chunks, axis=0)

    maha_intermediate_dict = maha(feature_all, labels_all, indist_classes=args.num_classes)
    print(feature_all.shape)
    np.save(args.maha_file, maha_intermediate_dict)





def compute_mean_cov(pretrain_model, args, epochs=3):
    """Prepare *pretrain_model* for inference and compute Mahalanobis statistics.

    Moves the model to CUDA, wraps non-CLIP models in ``nn.DataParallel``,
    switches the model to eval mode and delegates to ``predict_maha``. No
    training happens here.

    Args:
        pretrain_model: the feature extractor.
        args: parsed CLI namespace (``args.arch`` decides whether to wrap).
        epochs: number of feature-extraction passes over the loader (default 3,
            as before).
    """
    pretrain_model.cuda()
    if not args.arch.startswith('clip'):
        # CLIP models stay unwrapped, presumably because predict_maha calls
        # their encode_image / get_image_features methods directly, and a
        # DataParallel wrapper does not expose those methods.
        pretrain_model = torch.nn.DataParallel(pretrain_model)
    # Apply eval mode on every path, CLIP included, so dropout/batch-norm
    # behave deterministically during feature extraction.
    pretrain_model.eval()
    # NOTE(review): with an unshuffled loader and a deterministic transform,
    # repeated passes presumably yield identical features (duplicated samples);
    # confirm whether more than one pass is intended.
    predict_maha(pretrain_model, epochs, args)
# Data Upload
print('\n[Phase 1] : Data Preparation')

import warnings

import timm
import torch
import torchvision
from huggingface_hub import login
from PIL import Image
from timm.data import resolve_data_config
from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.data.transforms_factory import create_transform
from timm.models import create_model

# `modeling` is not referenced below; it is presumably imported so that MUSK's
# timm model names (e.g. "musk_large_patch16_384") get registered. Confirm
# before removing it.
from musk import utils, modeling

# Alternative backbone (UNI), kept for reference:
# pretrained=True needed to load UNI weights (and download weights for the first time)
# init_values need to be passed in to successfully load LayerScale parameters (e.g. - block.0.ls1.gamma)
# model = timm.create_model("hf-hub:MahmoodLab/UNI", pretrained=True, init_values=1e-5, dynamic_img_size=True)
# transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
# model.eval()

# The feature extractor loaded here is always MUSK, regardless of --arch.
model = create_model("musk_large_patch16_384")
# `utils` here is musk.utils (bound by the import above), not the local
# `utils` module that provides `maha`.
utils.load_model_and_may_interpolate("hf_hub:xiangjx/musk", model, 'model|module', '')
model.to(device="cuda")

if args.arch != 'musk':
    # predict_maha passes MUSK's forward keyword arguments (and takes element
    # [0] of the result) only when --arch is 'musk'.
    warnings.warn(
        f"A MUSK model is loaded but --arch={args.arch!r}; pass --arch musk so "
        "predict_maha calls it with the MUSK-specific arguments."
    )

# Resize the shorter side to 384 (bicubic), center-crop to 384x384, and
# normalise with the Inception mean/std.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(
        384,
        interpolation=torchvision.transforms.InterpolationMode.BICUBIC,
        antialias=True,
    ),
    torchvision.transforms.CenterCrop((384, 384)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
])


print('| Building and loading pretrained model type [' + args.arch + ']')

dataroot = '../../assets/data/tcga_luadlusc'

# Create ImageFolder datasets for train/test and their (unshuffled) data loaders.
train_dataset = torchvision.datasets.ImageFolder(j_(dataroot, 'train'), transform=transform)
test_dataset = torchvision.datasets.ImageFolder(j_(dataroot, 'test'), transform=transform)
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


compute_mean_cov(model, args)

