from __future__ import print_function
import os
import json

from torchvision import datasets, transforms
from torchvision.datasets.folder import ImageFolder, default_loader

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform


from genericpath import exists
from torch import nn
import os, sys
import copy
import hashlib
import errno
import torchvision
import torch
from torch.utils.data import TensorDataset
import torchvision.transforms as transforms
import config as cf 
from networks import Wide_ResNet
# from datasets import CIFAR10Noise, CIFAR100Noise, TinyImagenetNoise, SVHN, LSUN

from transformers import BeitFeatureExtractor, BeitForImageClassification, BeitConfig, ViTForImageClassification, DeiTForImageClassification, CLIPModel
import torchvision.models as torchvision_models
from pytorch_pretrained_vit import ViT

from networks import *
import clip

def build_dataset(is_train, args):
    """Build a torchvision dataset and report its number of classes.

    Args:
        is_train: Selects the training split when true, otherwise the
            test/validation split.
        args: Namespace providing ``data_set`` (``'CIFAR10'``, ``'CIFAR100'``
            or ``'IMNET'``) and ``data_path`` (root directory of the data).

    Returns:
        A ``(dataset, nb_classes)`` tuple.

    Raises:
        KeyError: If ``cf.mean`` / ``cf.std`` have no entry for the
            lower-cased dataset name.
        ValueError: If ``args.data_set`` is not one of the supported names.
    """
    stats_key = args.data_set.lower()
    normalize = transforms.Normalize(cf.mean[stats_key], cf.std[stats_key])

    # NOTE(review): the same random crop/flip pipeline is applied to both
    # splits (``build_transform`` is not used here) -- confirm this is intended
    # for evaluation.
    transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    if args.data_set == 'CIFAR10':
        dataset = datasets.CIFAR10(args.data_path, train=is_train, transform=transform)
        nb_classes = 10
    elif args.data_set == 'CIFAR100':
        dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)
        nb_classes = 100
    elif args.data_set == 'IMNET':
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000
    else:
        # Previously this fell through to an UnboundLocalError at ``return``.
        raise ValueError("Unsupported data_set: {!r}".format(args.data_set))

    return dataset, nb_classes



def build_transform(is_train, args):
    """Return the image transform for training or evaluation.

    For training, the pipeline comes from timm's ``create_transform``; when
    ``args.input_size`` is 32 or less its first step is replaced with a padded
    ``RandomCrop``. For evaluation, inputs larger than 32 are resized (keeping
    the 256/224 ratio) and centre-cropped, then converted to a tensor and
    normalised with the ImageNet default mean/std.
    """
    needs_resize = args.input_size > 32

    if not is_train:
        steps = []
        if needs_resize:
            # keep the same ratio w.r.t. 224 images
            resized_edge = int((256 / 224) * args.input_size)
            steps.append(transforms.Resize(resized_edge, interpolation=3))
            steps.append(transforms.CenterCrop(args.input_size))
        steps.append(transforms.ToTensor())
        steps.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
        return transforms.Compose(steps)

    # this should always dispatch to transforms_imagenet_train
    train_pipeline = create_transform(
        input_size=args.input_size,
        is_training=True,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation=args.train_interpolation,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
    )
    if not needs_resize:
        # swap RandomResizedCropAndInterpolation for a plain RandomCrop
        train_pipeline.transforms[0] = transforms.RandomCrop(
            args.input_size, padding=4)
    return train_pipeline

def _set_identity_head(model, linear_keyword, dim=768):
    """Overwrite the weight of ``model.<linear_keyword>`` with a dim x dim identity."""
    w = torch.Tensor(dim, dim)
    nn.init.eye_(w)
    getattr(model, linear_keyword).weight.data = w


def get_pretrained_model(args):
    """Create the backbone selected by ``args.arch`` and optionally load weights.

    For the ViT (``base_21k*``), Hugging Face (``hug_*``) and generic
    torchvision branches the final linear layer's weight is replaced by a
    768x768 identity matrix. CLIP and the ``resnet34/50/101`` branches are
    returned without touching a head.

    Args:
        args: Namespace providing ``arch`` (architecture name) and
            ``pretrained`` (path to a checkpoint, or a falsy value). When a
            checkpoint is loaded, ``args.start_epoch`` is set to 0.

    Returns:
        The constructed model.

    Raises:
        ValueError: If a checkpoint file is given for an architecture that has
            no replaceable linear head (CLIP / resnet34/50/101); previously
            this surfaced as a ``NameError``.
        KeyError: If ``args.arch`` is not a known torchvision model name in the
            fallback branch.
    """
    arch = args.arch
    # create model
    print("=> creating model '{}'".format(arch))

    # Hugging Face classifiers whose ``classifier`` is replaced by Linear(768, 768).
    hf_backbones = {
        'hug_beit_21k': (BeitForImageClassification,
                         'microsoft/beit-base-patch16-224-pt22k'),
        'hug_deit': (DeiTForImageClassification,
                     "facebook/deit-base-distilled-patch16-224"),
        'hug_beit_21k_fine_tune': (BeitForImageClassification,
                                   'microsoft/beit-base-patch16-224-pt22k-ft22k'),
        'hug_vit_21k': (ViTForImageClassification,
                        'google/vit-base-patch16-224-in21k'),
        'hug_mae': (ViTForImageClassification, 'facebook/vit-mae-base'),
    }
    clip_checkpoints = {
        'clip32': "openai/clip-vit-base-patch32",
        'clip16': "openai/clip-vit-base-patch16",
    }

    linear_keyword = None  # stays None for models without a replaced head
    # The more specific prefix must be tested first: 'base_21k_fine_tune' also
    # starts with 'base_21k', which used to make this branch unreachable.
    if arch.startswith('base_21k_fine_tune'):
        model = ViT('B_16_imagenet1k', pretrained=True, num_classes=768)
        linear_keyword = 'fc'
    elif arch.startswith('base_21k'):
        model = ViT('B_16', pretrained=True, num_classes=768)
        linear_keyword = 'fc'
    elif arch in hf_backbones:
        model_cls, hub_name = hf_backbones[arch]
        model = model_cls.from_pretrained(hub_name)
        model.classifier = nn.Linear(768, 768)
        linear_keyword = 'classifier'
    elif arch in clip_checkpoints:
        model = CLIPModel.from_pretrained(clip_checkpoints[arch])
    elif arch == 'clip_r50':
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # the preprocessing callable returned by clip.load is not used here
        model, _preprocess = clip.load("RN50", device=device)
    elif arch in ('resnet34', 'resnet50', 'resnet101'):
        model = getattr(resnet, arch)(pretrained=True, num_classes=1000)
    else:
        model = torchvision_models.__dict__[arch](num_classes=768)
        linear_keyword = 'fc'

    if linear_keyword is not None:
        _set_identity_head(model, linear_keyword)

    # Zero the head bias; the name-prefix test is kept exactly as before, so a
    # torchvision arch whose name starts with 'resnet' keeps its bias.
    if not arch.startswith('clip') and not arch.startswith('resnet'):
        getattr(model, linear_keyword).bias.data.zero_()

    # load from pre-trained, before DistributedDataParallel constructor
    if args.pretrained:
        if os.path.isfile(args.pretrained):
            if linear_keyword is None:
                raise ValueError(
                    "Loading a checkpoint is not supported for arch '{}': "
                    "it has no replaceable linear head".format(arch))
            print("=> loading checkpoint '{}'".format(args.pretrained))
            checkpoint = torch.load(args.pretrained, map_location="cpu")

            # rename moco pre-trained keys
            state_dict = checkpoint['state_dict']
            head_prefix = 'module.base_encoder.%s' % linear_keyword
            for k in list(state_dict.keys()):
                # retain only base_encoder up to before the embedding layer
                if k.startswith('module.base_encoder') and not k.startswith(head_prefix):
                    # remove prefix
                    state_dict[k[len("module.base_encoder."):]] = state_dict[k]
                # delete renamed or unused k
                del state_dict[k]

            args.start_epoch = 0
            msg = model.load_state_dict(state_dict, strict=False)
            # only the freshly initialised head may be missing from the checkpoint
            assert set(msg.missing_keys) == {"%s.weight" % linear_keyword, "%s.bias" % linear_keyword}

            print("=> loaded pre-trained model '{}'".format(args.pretrained))
        else:
            print("=> no checkpoint found at '{}'".format(args.pretrained))

    return model

from sklearn.metrics import roc_auc_score
import numpy as np

class LGMLoss(nn.Module):
    """Gaussian-mixture style classification head with a margin on the true class.

    Holds one learnable centre and one learnable per-dimension log-covariance
    vector for each class.

    Args:
        num_classes: Number of classes (rows of ``centers`` / ``log_covs``).
        feat_dim: Dimensionality of the input features.
        alpha: Margin; the distance to the labelled class is scaled by
            ``1 + alpha`` when computing ``margin_logits``.
    """

    def __init__(self, num_classes, feat_dim, alpha):
        super(LGMLoss, self).__init__()
        self.feat_dim = feat_dim
        self.num_classes = num_classes
        self.alpha = alpha

        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))
        self.log_covs = nn.Parameter(torch.zeros(num_classes, feat_dim))

    def forward(self, feat, label):
        """Compute plain logits, margin logits and the likelihood regulariser.

        Args:
            feat: Feature tensor of shape ``(batch, feat_dim)``.
            label: Integer class indices of shape ``(batch,)``; any integer
                dtype is accepted (it is cast to int64).

        Returns:
            ``(logits, margin_logits, likelihood)`` where the first two have
            shape ``(batch, num_classes)`` and ``likelihood`` is a scalar.
        """
        batch_size = feat.shape[0]
        # scatter_ / index_select need int64 indices; cast once and reuse.
        label = label.long()

        log_covs = torch.unsqueeze(self.log_covs, dim=0)  # 1*c*d
        covs = torch.exp(log_covs)  # 1*c*d, broadcast against n*c*d below
        diff = torch.unsqueeze(feat, dim=1) - torch.unsqueeze(self.centers, dim=0)
        dist = torch.sum(torch.mul(diff, torch.div(diff, covs)), dim=-1)  # eq.(18)

        # Build the (1 + alpha * onehot) mask on the same device/dtype as the
        # distances so the module also works when the inputs live on a GPU.
        y_onehot = dist.new_zeros((batch_size, self.num_classes))
        y_onehot.scatter_(1, torch.unsqueeze(label, dim=-1), self.alpha)
        y_onehot = y_onehot + 1.0
        margin_dist = torch.mul(dist, y_onehot)

        slog_covs = torch.sum(log_covs, dim=-1)  # 1*c, broadcasts over the batch
        margin_logits = -0.5 * (slog_covs + margin_dist)  # eq.(17)
        logits = -0.5 * (slog_covs + dist)

        # Squared Euclidean distance of every sample to its own class centre.
        cdiff = feat - torch.index_select(self.centers, dim=0, index=label)
        cdist = cdiff.pow(2).sum(1).sum(0) / 2.0

        # squeeze only the leading broadcast axis so num_classes == 1 stays 1-D
        per_class_logdet = torch.squeeze(slog_covs, dim=0)
        reg = 0.5 * torch.sum(torch.index_select(per_class_logdet, dim=0, index=label))
        likelihood = (1.0 / batch_size) * (cdist + reg)

        return logits, margin_logits, likelihood

class LGMLoss_fix(nn.Module):
    """Gaussian-mixture style classification head with a margin on the true class.

    Holds one learnable centre and one learnable per-dimension log-covariance
    vector for each class.

    Args:
        num_classes: Number of classes (rows of ``centers`` / ``log_covs``).
        feat_dim: Dimensionality of the input features.
        alpha: Margin; the distance to the labelled class is scaled by
            ``1 + alpha`` when computing ``margin_logits``.
    """

    def __init__(self, num_classes, feat_dim, alpha):
        super(LGMLoss_fix, self).__init__()
        self.feat_dim = feat_dim
        self.num_classes = num_classes
        self.alpha = alpha

        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))
        self.log_covs = nn.Parameter(torch.zeros(num_classes, feat_dim))

    def forward(self, feat, label):
        """Compute plain logits, margin logits and the likelihood regulariser.

        Args:
            feat: Feature tensor of shape ``(batch, feat_dim)``.
            label: Integer class indices of shape ``(batch,)``; any integer
                dtype is accepted (it is cast to int64).

        Returns:
            ``(logits, margin_logits, likelihood)`` where the first two have
            shape ``(batch, num_classes)`` and ``likelihood`` is a scalar.
        """
        batch_size = feat.shape[0]
        # scatter_ / index_select need int64 indices; cast once and reuse.
        label = label.long()

        log_covs = torch.unsqueeze(self.log_covs, dim=0)  # 1*c*d
        covs = torch.exp(log_covs)  # 1*c*d, broadcast against n*c*d below
        diff = torch.unsqueeze(feat, dim=1) - torch.unsqueeze(self.centers, dim=0)
        dist = torch.sum(torch.mul(diff, torch.div(diff, covs)), dim=-1)  # eq.(18)

        # Build the (1 + alpha * onehot) mask on the same device/dtype as the
        # distances so the module also works when the inputs live on a GPU.
        y_onehot = dist.new_zeros((batch_size, self.num_classes))
        y_onehot.scatter_(1, torch.unsqueeze(label, dim=-1), self.alpha)
        y_onehot = y_onehot + 1.0
        margin_dist = torch.mul(dist, y_onehot)

        slog_covs = torch.sum(log_covs, dim=-1)  # 1*c, broadcasts over the batch
        margin_logits = -0.5 * (slog_covs + margin_dist)  # eq.(17)
        logits = -0.5 * (slog_covs + dist)

        # Squared Euclidean distance of every sample to its own class centre.
        cdiff = feat - torch.index_select(self.centers, dim=0, index=label)
        cdist = cdiff.pow(2).sum(1).sum(0) / 2.0

        # squeeze only the leading broadcast axis so num_classes == 1 stays 1-D
        per_class_logdet = torch.squeeze(slog_covs, dim=0)
        reg = 0.5 * torch.sum(torch.index_select(per_class_logdet, dim=0, index=label))
        likelihood = (1.0 / batch_size) * (cdist + reg)

        return logits, margin_logits, likelihood

class ranking_loss(nn.Module):
    """Pairwise margin ranking loss: mean of max(0, -t * (a - b) + margin)."""

    def __init__(self):
        super().__init__()

    def forward(self, rank_input1, rank_input2, target_rank, margin):
        """Return the mean hinge penalty over all pairs.

        Args:
            rank_input1: Scores of the first element of each pair.
            rank_input2: Scores of the second element of each pair.
            target_rank: Sign multiplier applied to the score difference.
            margin: Offset added before the hinge.
        """
        score_gap = rank_input1 - rank_input2
        penalty = -target_rank * score_gap + margin
        floor = torch.zeros_like(rank_input1)
        hinge = torch.max(floor, penalty)
        return torch.mean(hinge)

class LogitsMinMax:
    """Keeps the running minimum and maximum over every input passed to ``run``."""

    def __init__(self):
        super().__init__()
        # Sentinel starting values; any realistic input replaces them.
        self.min = 1e20
        self.max = -1e10

    def run(self, x):
        """Fold ``x`` into the running extrema and return ``(min, max)``."""
        batch_low = x.min()
        batch_high = x.max()
        self.min = min(batch_low, self.min)
        self.max = max(batch_high, self.max)
        return self.min, self.max

class MahaDistNormalizer:
    """Rescales values into ``[left, right]`` using running min/max statistics."""

    def __init__(self,):
        super(MahaDistNormalizer, self).__init__()
        # Sentinel starting values; any realistic input replaces them.
        self.min, self.max = 1e20, -1e10

    def run(self, x, left, right):
        """Update the running extrema with ``x`` and linearly rescale ``x``.

        Args:
            x: Array/tensor exposing ``min()`` and ``max()``.
            left: Output value assigned to the running minimum.
            right: Output value assigned to the running maximum.

        Returns:
            ``left + (right - left) * (x - min) / (max - min)``. When the
            running range is zero (all values seen so far are equal) every
            element maps to ``left`` instead of producing NaN/inf.
        """
        self.min = min(x.min(), self.min)
        self.max = max(x.max(), self.max)
        span = self.max - self.min
        offset = x - self.min
        if span == 0:
            # Degenerate range: avoid dividing by zero, keep x's shape/type.
            return left + 0 * offset
        k = (right - left) / span
        return left + k * offset

# class MahaDistNormalizer:
#     def __init__(self,):
#         super(MahaDistNormalizer, self).__init__()
#         self.min, self.max = 1e20, -1e10

#     def run(self, x):
#         self.min = min(x.min(), self.min)
#         self.max = max(x.max(), self.max)
#         return (x - self.min) / (self.max - self.min)

def maha_distance(xs, cov_inv_in, mean_in, norm_type=None):
    """Reduce the per-dimension Mahalanobis terms of ``xs`` to one score per row.

    The element-wise terms are ``((xs - mean) @ cov_inv) * (xs - mean)``.

    Args:
        xs: Samples; broadcast against ``mean_in`` reshaped to ``(1, d)``.
        cov_inv_in: Inverse covariance matrix.
        mean_in: Mean vector (NumPy array).
        norm_type: ``None`` or ``"L2"`` sums the terms (squared Mahalanobis
            distance), ``"L1"`` sums the square roots of their absolute
            values, ``"Linfty"`` takes their maximum.

    Returns:
        1-D array with one score per row of ``xs``.

    Raises:
        ValueError: If ``norm_type`` is not one of the supported values
            (previously ``None`` was returned silently).
    """
    diffs = xs - mean_in.reshape([1, -1])
    second_powers = np.matmul(diffs, cov_inv_in) * diffs

    if norm_type in (None, "L2"):
        return np.sum(second_powers, axis=1)
    if norm_type == "L1":
        return np.sum(np.sqrt(np.abs(second_powers)), axis=1)
    if norm_type == "Linfty":
        return np.max(second_powers, axis=1)
    raise ValueError("Unsupported norm_type: {!r}".format(norm_type))

def likelihood(xs, cov_inv_in, mean_in, prior=1.):
    """Return ``0.5 * (x - mean)^T cov_inv (x - mean) + log(prior)`` per row.

    Args:
        xs: Samples, one per row.
        cov_inv_in: Inverse covariance matrix.
        mean_in: Mean vector, broadcast against ``xs``.
        prior: Class prior whose logarithm is added to every score.

    Returns:
        1-D array with one score per row of ``xs``.
    """
    # NOTE(review): the quadratic term enters with a positive sign, so larger
    # values mean *further* from the mean -- confirm callers expect that.
    centered = xs - mean_in
    quad_form = ((centered @ cov_inv_in) * centered).sum(axis=1)
    return 0.5 * quad_form + np.log(prior)

def _inv_or_pinv(matrix):
    """Invert ``matrix``, falling back to the pseudo-inverse if it is singular."""
    try:
        return np.linalg.inv(matrix)
    except np.linalg.LinAlgError:
        return np.linalg.pinv(matrix)


def maha(
    indist_train_embeds_in,
    indist_train_labels_in,
    subtract_mean = False,
    normalize_to_unity = False,
    indist_classes = 100,
    ):
    """Fit the Gaussian statistics used for Mahalanobis-distance scoring.

    Args:
        indist_train_embeds_in: Training embeddings, one row per sample. The
            array is never modified (the original implementation subtracted
            the mean in place when ``subtract_mean`` was set).
        indist_train_labels_in: NumPy array of integer class labels, compared
            element-wise against ``0 .. indist_classes - 1``.
        subtract_mean: Subtract the global training mean first.
        normalize_to_unity: Scale every embedding to unit L2 norm.
        indist_classes: Number of classes to fit.

    Returns:
        Dict with:
            ``"mean"`` / ``"cov_inv"``: mean and inverse covariance of all
                (pre-processed) embeddings.
            ``"class_means"``: list of per-class mean vectors.
            ``"class_cov_invs"``: list holding, for every class, the same
                inverse of the covariance averaged over classes.
    """
    embeds = indist_train_embeds_in

    if subtract_mean:
        all_train_mean = np.mean(indist_train_embeds_in, axis=0, keepdims=True)
        # Out-of-place on purpose: ``-=`` would overwrite the caller's array.
        embeds = embeds - all_train_mean

    if normalize_to_unity:
        embeds = embeds / np.linalg.norm(embeds, axis=1, keepdims=True)

    # Constant offset added to the centred data before np.cov, kept from the
    # original implementation.
    tiny = 1e-15

    # full train single fit
    mean = np.mean(embeds, axis=0)
    cov = np.cov((embeds - (mean.reshape([1, -1]))).T + tiny)
    cov_inv = _inv_or_pinv(cov)
    if np.isnan(cov_inv).any():
        cov_inv = np.linalg.pinv(cov)

    # per-class means and covariances
    class_means = []
    class_covs = []
    for c in range(indist_classes):
        members = embeds[indist_train_labels_in == c]
        mean_now = np.mean(members, axis=0)
        class_means.append(mean_now)
        class_covs.append(np.cov((members - (mean_now.reshape([1, -1]))).T + tiny))

    # Every class shares the inverse of the class-averaged covariance. The
    # per-class inverses the original computed were discarded, so they are no
    # longer calculated.
    shared_cov_inv = _inv_or_pinv(np.mean(np.stack(class_covs, axis=0), axis=0))
    class_cov_invs = [shared_cov_inv] * len(class_covs)

    return {
        "class_cov_invs": class_cov_invs,
        "class_means": class_means,
        "cov_inv": cov_inv,
        "mean": mean,
    }

def get_maha_distance(train_logits, class_cov_invs, class_means, targets, norm_name = "L2"):
    """Score each sample against the statistics of its own target class.

    For sample ``i`` the distance is computed with
    ``class_cov_invs[targets[i]]`` and ``class_means[targets[i]]``.

    Returns:
        Array stacking one ``maha_distance`` result per entry of ``targets``.
    """
    per_sample = []
    for idx, cls in enumerate(targets):
        row = train_logits[idx].reshape([1, -1])
        per_sample.append(
            maha_distance(row, class_cov_invs[cls], class_means[cls], norm_name))
    return np.array(per_sample)

def get_gda_posterior(train_logits, class_cov_invs, class_means, num_classes):
    """Evaluate ``likelihood`` of all samples under every class.

    Returns:
        Array whose first axis indexes the class; its shape is also printed.
    """
    per_class = []
    for cls in range(num_classes):
        per_class.append(
            likelihood(train_logits, class_cov_invs[cls], class_means[cls]))
    scores = np.array(per_class)

    print(scores.shape)
    return scores

def get_maha_distance_cov(train_logits, cov_invs, class_means, targets, norm_name = "L2"):
    """Score each sample against its target-class mean with one shared covariance.

    Unlike ``get_maha_distance``, the same ``cov_invs`` matrix is used for
    every sample; only the mean depends on ``targets[i]``.

    Returns:
        Array stacking one ``maha_distance`` result per entry of ``targets``.
    """
    per_sample = []
    for idx, cls in enumerate(targets):
        row = train_logits[idx].reshape([1, -1])
        per_sample.append(maha_distance(row, cov_invs, class_means[cls], norm_name))
    return np.array(per_sample)

def get_maha_predict(train_logits, class_cov_invs, class_means, num_classes, norm_name = "L2"):
    """Predict, for every sample, the class with the smallest distance.

    Returns:
        Array of class indices (argmin over the class axis).
    """
    distances_by_class = []
    for cls in range(num_classes):
        distances_by_class.append(
            maha_distance(train_logits, class_cov_invs[cls], class_means[cls], norm_name))
    stacked = np.stack(distances_by_class, axis=0)
    return np.argmin(stacked, axis=0)

def get_relative_maha_predict(train_logits,cov_invs, class_cov_invs, means, class_means, num_classes, norm_name = "L2"):
    """Return relative Mahalanobis scores of every sample for every class.

    The score is the class-conditional distance minus the distance to the
    global (class-agnostic) Gaussian described by ``means`` / ``cov_invs``.

    Note:
        Despite the name, this returns the score matrix, not class
        predictions; take ``np.argmin(result, axis=0)`` for predictions. The
        original code computed that argmin and then discarded it.

    Returns:
        Array with one row per class and one column per sample (assuming
        ``train_logits`` is a 2-D ``(n_samples, dim)`` array).
    """
    # Distance of each sample to the global Gaussian, laid out as a row
    # vector so it broadcasts over the class axis.
    maha_0 = np.array(
        [maha_distance(sample, cov_invs, means, norm_name) for sample in train_logits]
    ).reshape(1, -1)

    out_totrainclasses = np.array(
        [maha_distance(train_logits, class_cov_invs[c], class_means[c], norm_name)
         for c in range(num_classes)]
    )
    return out_totrainclasses - maha_0

# def get_relative_maha_distance(train_logits, cov_invs, means, class_means, targets, norm_name = "L2"):
#   maha_0 = np.array([maha_distance(train_logits[i],cov_invs,means,norm_name) for i in range(len(targets))])
#   maha_k = np.array([maha_distance(train_logits[i].reshape([1,-1]),cov_invs,class_means[targets[i]],norm_name) for i in range(len(targets))])
#   print(maha_0.shape,maha_k.shape)
#   scores = maha_k - maha_0

#   return scores

def get_relative_maha_distance(train_logits, cov_invs, class_cov_invs, means, class_means, targets, norm_name = "L2"):
    """Relative Mahalanobis distance of each sample to its target class.

    For sample ``i`` this is the distance under the statistics of class
    ``targets[i]`` minus the distance under the global ``means`` / ``cov_invs``.
    """
    global_part = []
    class_part = []
    for idx in range(len(targets)):
        global_part.append(maha_distance(train_logits[idx], cov_invs, means, norm_name))
    for idx in range(len(targets)):
        cls = targets[idx]
        class_part.append(
            maha_distance(train_logits[idx].reshape([1, -1]),
                          class_cov_invs[cls], class_means[cls], norm_name))
    return np.array(class_part) - np.array(global_part)
  
def get_maha_distance_feature(feature, cov_invs, means):
    """Squared Mahalanobis distance of every row of ``feature`` to ``means``.

    Computes ``sum_j ((x - mu) @ cov_invs)_j * (x - mu)_j`` for each row.
    """
    centered = feature - means
    projected = np.dot(centered, cov_invs)
    return np.sum(projected * centered, axis=1)

def get_transforms(train=True, dataset='cifar10'):
    """Return the preprocessing pipeline for ``dataset``.

    Every pipeline ends with ``ToTensor`` followed by normalisation using
    ``cf.mean[dataset]`` / ``cf.std[dataset]``. What precedes it depends on
    the dataset:

    * ``'ti'``: random-resized-crop(56) + flip for training, resize(64) +
      centre-crop(56) otherwise.
    * ``'lsun'``: resize to 224x224 regardless of ``train``.
    * anything else: padded random-crop(32) + flip for training, nothing extra
      otherwise.
    """
    normalize = transforms.Normalize(cf.mean[dataset], cf.std[dataset])  # meanstd transformation

    if dataset == "lsun":
        leading = [transforms.Resize([224, 224])]
    elif dataset == 'ti':
        if train:
            leading = [
                transforms.RandomResizedCrop(56),
                transforms.RandomHorizontalFlip(),
            ]
        else:
            leading = [
                transforms.Resize(64),
                transforms.CenterCrop(56),
            ]
    elif train:
        leading = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
        ]
    else:
        leading = []

    return transforms.Compose(leading + [transforms.ToTensor(), normalize])


# DSETS = {
#     'cifar10': CIFAR10Noise,
#     'cifar100': CIFAR100Noise,
#     'ti': TinyImagenetNoise,
#     'svhn': SVHN,
#     'lsun': LSUN,
# }

# Registry mapping a dataset name to its dataset class; looked up by
# prepare_dset and prepare_dset_test. It is currently empty, so those
# lookups raise KeyError until entries are registered here.
DSETS = {
}
def prepare_dset(args):
    """Build the noisy training set, the test set and an un-augmented train copy.

    The dataset class is looked up in ``DSETS`` by ``args.dataset``. The
    third return value is a deep copy of the training set whose transform is
    replaced by the test-time transform.

    Returns:
        ``(trainset, testset, trainvalset)``.
    """
    train_tf = get_transforms(train=True, dataset=args.dataset)
    eval_tf = get_transforms(train=False, dataset=args.dataset)
    dataset_cls = DSETS[args.dataset]
    print(f"| Preparing {args.dataset} dataset with noisy img...")

    noise_options = dict(
        xnoise_type=args.xnoise_type,
        xnoise_rate=args.xnoise_rate,
        xnoise_arg=args.xnoise_arg,
        ynoise_type=args.ynoise_type,
        ynoise_rate=args.ynoise_rate,
        random_state=args.random_state,
    )
    trainset = dataset_cls(download=True, train=True, transform=train_tf,
                           **noise_options)
    testset = dataset_cls(download=True, train=False, transform=eval_tf)

    # Same (noisy) training data, but without data augmentation.
    trainvalset = copy.deepcopy(trainset)
    trainvalset.transform = eval_tf
    return trainset, testset, trainvalset


def prepare_dset_multi(args):
    """Build Tiny-ImageNet splits with several input-noise types at once.

    Training data comes from ``TinyImagenetNoiseMulti`` and test data from
    ``TinyImagenetNoise``. The third return value is a deep copy of the
    training set whose transform is replaced by the test-time transform.

    NOTE(review): neither dataset class is imported explicitly in the visible
    part of this file -- confirm they are provided by a wildcard import.

    Returns:
        ``(trainset, testset, trainvalset)``.
    """
    train_tf = get_transforms(train=True, dataset=args.dataset)
    eval_tf = get_transforms(train=False, dataset=args.dataset)

    noise_options = dict(
        xnoise_types=args.xnoise_types,
        xnoise_rates=args.xnoise_rates,
        xnoise_args=args.xnoise_args,
        ynoise_type=args.ynoise_type,
        ynoise_rate=args.ynoise_rate,
        random_state=args.random_state,
    )
    trainset = TinyImagenetNoiseMulti(download=True, train=True,
                                      transform=train_tf, **noise_options)
    testset = TinyImagenetNoise(download=True, train=False, transform=eval_tf)

    # Same (noisy) training data, but without data augmentation.
    trainvalset = copy.deepcopy(trainset)
    trainvalset.transform = eval_tf
    return trainset, testset, trainvalset


def prepare_dset_large(args):
    """Build train/validation datasets for large-scale data (WebVision only).

    Both splits are resized to 256 and normalised with the hard-coded mean/std
    below; training uses a random-resized-crop(224) plus horizontal flip,
    validation a centre-crop(224).

    Returns:
        ``(train_dataset, val_dataset)``.

    Raises:
        NotImplementedError: If ``args.dataset`` is not ``'webvision'``.
    """
    if args.dataset != 'webvision':
        raise NotImplementedError

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_steps = [
        transforms.Resize(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]
    eval_steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]

    train_dataset = webvision_dataset(
        transform=transforms.Compose(train_steps), mode="all", num_class=50)
    val_dataset = webvision_dataset(
        transform=transforms.Compose(eval_steps), mode="test", num_class=50)
    return train_dataset, val_dataset


def prepare_dset_test(dataset):
    """Build a test-only dataset, including synthetic noise datasets.

    Args:
        dataset: ``"gaussian_noise"`` or ``"uniform_noise"`` for 10000
            synthetic samples labelled 1, otherwise a key of ``DSETS``.

    Returns:
        A ``TensorDataset`` for the noise options, otherwise an instance of
        ``DSETS[dataset]`` built with ``train=False`` and the matching
        test-time transform.

    Raises:
        KeyError: If ``dataset`` is neither a noise option nor in ``DSETS``.
    """
    print(f"| Preparing {dataset}")
    if dataset == "gaussian_noise":
        # 10000 samples
        return TensorDataset(torch.randn(10000, 3, 224, 224), torch.ones(10000))
    if dataset == "uniform_noise":
        # 10000 samples
        # NOTE(review): this stays 32x32 while gaussian_noise is 224x224. The
        # original had an unreachable second ``return`` producing 224x224
        # right after this one; confirm which resolution is intended.
        return TensorDataset(torch.rand(10000, 3, 32, 32), torch.ones(10000))

    if dataset == "ti":
        transform_test = transforms.Compose([transforms.Resize(32),
                                            transforms.ToTensor(),
                                            transforms.Normalize(cf.mean[dataset], cf.std[dataset]),])
    else:
        transform_test = get_transforms(train=False, dataset=dataset)
    DSET = DSETS[dataset]
    testset = DSET(download=True,
                   train=False,
                   transform=transform_test)
    return testset
    

def adjust_learning_rate(optimizer, epoch, init_lr, steps):
    """Apply a step schedule: decay ``init_lr`` by 10x per milestone passed.

    ``steps`` is walked in order and the walk stops at the first milestone
    that ``epoch`` does not strictly exceed; the resulting rate is written to
    every parameter group of ``optimizer``.
    """
    current = init_lr
    for milestone in steps:
        if not epoch > milestone:
            break
        current = current * 0.1

    for group in optimizer.param_groups:
        group['lr'] = current

def check_dir(path):
    """Create ``path`` (with parents) unless something already exists there."""
    if os.path.exists(path):
        return
    os.makedirs(path, exist_ok=True)


def update_print(s):
    """Rewrite the current terminal line: carriage return, then ``s``, then flush."""
    stream = sys.stdout
    stream.write('\r')
    stream.write(s)
    stream.flush()

class AverageMeter():
    """Accumulates batches of per-sample values and reports their running mean."""

    def __init__(self):
        self.sum = 0.    # running total of all appended values
        self.cnt = 0     # number of values appended so far
        self.history = []  # every appended value, in order

    def append(self, x):
        """Add a batch of values.

        Args:
            x: Tensor with at least one dimension; ``x.shape[0]`` is counted
                as the number of samples. Tensors that require grad are
                accepted (they are detached before conversion to NumPy).
        """
        # detach() first: Tensor.numpy() raises on tensors that require grad.
        self.history += list(x.detach().cpu().numpy())
        self.cnt += x.shape[0]
        self.sum += x.sum().cpu().item()

    def get(self):
        """Return the mean of all appended values.

        Raises:
            ZeroDivisionError: If nothing has been appended yet.
        """
        return self.sum / self.cnt