from __future__ import print_function
from matplotlib.pyplot import axis
from numpy.lib.function_base import append
import random
import torch
from torch import logit
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import config as cf
from datasets import ImagenetNoise

import torchvision.transforms as transforms

import os
import argparse

from torch.utils.data import Dataset
from PIL import Image
import numpy as np
from PIL import Image

from utils import prepare_dset, maha
from networks import *

from utils import get_pretrained_model
import matplotlib.pyplot as plt



def setup_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds, in this order: PyTorch's default generator, PyTorch's CUDA
    generators (all devices), NumPy's global generator and the stdlib
    ``random`` module.

    Args:
        seed: Integer seed handed unchanged to each generator.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    # cuDNN deterministic mode is left disabled:
    #  torch.backends.cudnn.deterministic = True
# Fix every RNG seed before any other work happens.
setup_seed(20)

parser = argparse.ArgumentParser(description='Ensemble Training')

# --- Pretrained model settings -------------------------------------------
# Destination passed to np.save for the computed statistics.
parser.add_argument(
    '--maha_file',
    type=str,
    default='./ssl/maha_dict.npy',
)
# Architecture tag; prefixes such as 'clip' / 'hug' select how the model is
# queried during feature extraction.
parser.add_argument(
    '-a', '--arch',
    default='resnet50',
    metavar='ARCH',
)
parser.add_argument(
    '--pretrained',
    type=str,
    default='',
    help='path to moco pretrained checkpoint',
)
parser.add_argument(
    '--pretrained_model',
    type=str,
    default='vit',
    help='SSL feature map type',
)

# --- Data settings ---------------------------------------------------------
parser.add_argument(
    '--batch_size',
    type=int,
    default=1024,
)
parser.add_argument(
    '--dataset',
    type=str,
    default='cifar10',
    help='cifar10/cifar100',
)
parser.add_argument(
    '--num_classes',
    type=int,
    default=10,
)
parser.add_argument(
    '--random_state',
    default=0,
    type=int,
)

# --- Label / input noise settings -------------------------------------------
parser.add_argument(
    '--ynoise_type',
    type=str,
    default='symmetric',
    help='symmetric/pairflip',
)
parser.add_argument(
    '--ynoise_rate',
    type=float,
    default=0.0,
    help='label noise rate',
)
parser.add_argument(
    '--xnoise_type',
    type=str,
    default='blur',
    help='gaussian/blur',
)
parser.add_argument(
    '--xnoise_arg',
    type=float,
    default=1,
)
parser.add_argument(
    '--xnoise_rate',
    type=float,
    default=0.0,
)
parser.add_argument(
    '--trigger_size',
    default=3,
    type=int,
)
parser.add_argument(
    '--trigger_ratio',
    default=0.,
    type=float,
)

args = parser.parse_args()

# GPU selection through CUDA_VISIBLE_DEVICES is disabled; no --gpu option is
# defined above.
# os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

# Module-level run settings.
use_cuda = torch.cuda.is_available()
best_acc = 0
batch_size = args.batch_size
# Custom_Dataset class
class Custom_Dataset(Dataset):
    """Map-style dataset over in-memory image arrays and their labels.

    Every item is a triple ``(image, label, index)``; the index is returned
    so callers can trace a sample back to its position in ``x`` / ``y``.

    Args:
        x: Indexable collection of image arrays.
        y: Indexable collection of labels, indexed in step with ``x``.
        data_set: Layout tag. ``'cifar'`` arrays are handed to
            ``Image.fromarray`` as stored; ``'svhn'`` arrays are transposed
            with axes ``(1, 2, 0)`` first (presumably channel-first storage
            -- TODO confirm).
        transform: Optional callable applied to the PIL image. When ``None``
            the PIL image itself is returned.
    """

    def __init__(self, x, y, data_set, transform=None):
        self.x_data = x
        self.y_data = y
        self.data = data_set
        self.transform = transform

    def __len__(self):
        """Return the number of samples (length of the image collection)."""
        return len(self.x_data)

    def __getitem__(self, idx):
        """Return ``(image, label, idx)`` for the sample at position ``idx``.

        Raises:
            ValueError: If the dataset was built with a layout tag other
                than ``'cifar'`` or ``'svhn'``.
        """
        if self.data == 'cifar':
            pixels = self.x_data[idx]
        elif self.data == 'svhn':
            pixels = np.transpose(self.x_data[idx], (1, 2, 0))
        else:
            # Fail with a clear message instead of hitting an unbound local.
            raise ValueError(
                "unsupported data_set {!r}; expected 'cifar' or 'svhn'".format(self.data)
            )

        img = Image.fromarray(pixels)

        # ``transform`` defaults to None, so only call it when one was given.
        x = img if self.transform is None else self.transform(img)

        return x, self.y_data[idx], idx


def predict_maha(model, epoch, args):
    """Collect model outputs over the training loader and save ``maha`` statistics.

    Runs ``epoch`` full passes over the module-level ``trainloader`` with
    gradients disabled, stacks every batch's model output together with the
    first element of its label pair (presumably the possibly-noisy label --
    TODO confirm against the dataset), passes both stacked arrays to ``maha``
    and writes the value it returns to ``args.maha_file`` with ``np.save``.

    Relies on module-level names defined later in this script:
    ``trainloader``, ``up_sample`` and ``num_classes``. Inputs are moved with
    ``.cuda()``, so a CUDA device is required.

    Args:
        model: Feature extractor. Called directly, or through
            ``encode_image`` / ``get_image_features`` for ``clip*`` archs.
        epoch: Number of passes over ``trainloader``.
        args: Parsed command-line namespace; ``dataset``, ``arch`` and
            ``maha_file`` are read.

    Raises:
        ValueError: If no batch was processed (``epoch`` is 0 or the loader
            is empty).
    """
    # Resolve the per-architecture switches once instead of on every batch.
    is_clip = args.arch.startswith('clip')
    uses_encode_image = args.arch in ('clip_r50', 'clip_r101')
    has_logits_attr = args.arch.startswith('hug')
    needs_upsample = args.dataset != 'imagenet'

    # Collect per-batch arrays and concatenate once at the end; concatenating
    # onto a growing array every batch copies all earlier rows each time.
    feature_chunks = []
    label_chunks = []
    n_rows = 0

    with torch.no_grad():
        for i in range(epoch):
            print(i)
            # Each batch unpacks as (ignored, (inputs, unused), (targets, unused)).
            for _, (inputs, _), (targets, _) in trainloader:
                inputs = inputs.cuda()
                if needs_upsample:
                    inputs = up_sample(inputs)

                if not is_clip:
                    output = model(inputs)
                elif uses_encode_image:
                    output = model.encode_image(inputs)
                else:
                    output = model.get_image_features(inputs)

                if has_logits_attr:
                    # 'hug*' archs return an object carrying a .logits tensor.
                    output = output.logits

                logits = output.detach().cpu().numpy()
                labels_np = targets.detach().cpu().numpy()

                feature_chunks.append(logits)
                label_chunks.append(labels_np)

                # Progress: shape the stacked feature array has so far.
                n_rows += logits.shape[0]
                print((n_rows,) + logits.shape[1:])

    if not feature_chunks:
        raise ValueError(
            'no batches were processed; check `epoch` and the training loader'
        )

    feature_all = np.concatenate(feature_chunks, axis=0)
    labels_all = np.concatenate(label_chunks, axis=0)

    maha_intermediate_dict = maha(feature_all, labels_all, indist_classes=num_classes)
    print(feature_all.shape)
    # NOTE(review): if ``maha`` returns a dict, np.save pickles it and
    # np.load will need allow_pickle=True -- confirm against the consumer.
    np.save(args.maha_file, maha_intermediate_dict)





def compute_mean_cov(pretrain_model, args, epochs=3):
    """Prepare the pretrained model for inference and run the feature pass.

    Moves the model to CUDA, wraps it in ``torch.nn.DataParallel`` unless
    ``args.arch`` starts with ``'clip'``, switches it to evaluation mode and
    hands it to ``predict_maha``. No training takes place.

    Args:
        pretrain_model: Model used for feature extraction; moved to the GPU
            with ``.cuda()``.
        args: Parsed command-line namespace; ``args.arch`` decides whether
            the model is wrapped, and the namespace is forwarded to
            ``predict_maha``.
        epochs: Number of passes over the training data forwarded to
            ``predict_maha``. Defaults to 3, the previously hard-coded value.
    """
    pretrain_model.cuda()

    # 'clip*' models are queried through encode_image / get_image_features
    # rather than a plain call, presumably why they are left unwrapped: a
    # DataParallel wrapper does not expose those methods.
    if not args.arch.startswith('clip'):
        pretrain_model = torch.nn.DataParallel(pretrain_model)

    # Put every architecture in evaluation mode, CLIP variants included, so
    # layers such as dropout / batch norm run in inference mode.
    pretrain_model.eval()

    predict_maha(pretrain_model, epochs, args)
# Data loading
print('\n[Phase 1] : Data Preparation')

if args.dataset == 'imagenet':
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    # Training split: augmented crops, with the noise options from the CLI.
    imagenet_train_tf = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    trainset = ImagenetNoise(
        transform=imagenet_train_tf,
        xnoise_rate=args.xnoise_rate,
        xnoise_arg=args.xnoise_arg,
        xnoise_type=args.xnoise_type,
        ynoise_type=args.ynoise_type,
        ynoise_rate=args.ynoise_rate,
        random_state=args.random_state,
        num_classes=args.num_classes,
    )

    # Evaluation split: centre crop, no noise arguments passed.
    imagenet_test_tf = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    testset = ImagenetNoise(
        train=False,
        transform=imagenet_test_tf,
        num_classes=args.num_classes,
    )

    num_classes = args.num_classes
else:
    # Every other dataset is built by the shared helper; the class count is
    # read from the training set it returns.
    trainset, testset, trainvalset = prepare_dset(args)
    num_classes = trainset.nb_classes

# Neither loader shuffles; only the training loader uses worker processes.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
)

print('| Building and loading pretrained model type [' + args.arch + ']')

# Resizes inputs to 224x224; applied by predict_maha for non-imagenet data.
up_sample = nn.Upsample(size=(224, 224), mode='bilinear')

# Model construction is delegated to get_pretrained_model; the earlier
# hard-coded resnet34 / resnet50 / resnet101 branches are no longer used.
pretrain_model = get_pretrained_model(args)

compute_mean_cov(pretrain_model, args)

