# import gdown
import torch.nn.functional as F
from robustbench import load_model
from torch import nn
import torch
import os
import wget
from pathlib import Path
from modelZoo.ViTPytorch.models.modeling import VisionTransformer, CONFIGS
from modelZoo.robustOverfitting.preactresnet import PreActResNet18
import argparse
from modelZoo.ECCV2020OSAD.misc.utils import init_model
from modelZoo.ECCV2020OSAD.models import DenoiseResnet
from modelZoo.ALOE.models import densenet as dn
from modelZoo.informativeOutlierMining.eval_ood_detection import get_model_ATOM


def added_embedding_Rade2021Helper_R18(self, x):
    """Return flattened penultimate features of a robustbench Rade2021Helper R18 model.

    ``self`` is the robustbench model itself; the function is attached to it as a
    plain (unbound) attribute, so it receives the model explicitly.

    The input is optionally zero-padded by ``self.padding`` on all four sides,
    standardized with the model's own ``mean``/``std``, and run through the stem,
    the four residual stages and the final BN+ReLU. Then it is 4x4 average-pooled
    and flattened to one row per sample.
    """
    pad = self.padding
    if pad > 0:
        x = F.pad(x, (pad, pad, pad, pad))
    feats = self.conv_2d((x - self.mean) / self.std)
    for stage in (self.layer_0, self.layer_1, self.layer_2, self.layer_3):
        feats = stage(feats)
    feats = self.relu(self.batchnorm(feats))
    pooled = F.avg_pool2d(feats, 4)
    return pooled.view(pooled.size(0), -1)


def added_embedding_ViT_adv(self, x):
    """Return ``self.forward_features(x)`` unchanged; ``self`` is passed explicitly."""
    return self.forward_features(x)


def added_embedding_ViT_L(self, x):
    """Return the first-token output of ``self.transformer`` as a 2-D batch of features.

    ``self.transformer(x)`` returns a pair whose second element (attention
    weights) is discarded. The first element is indexed at position 0 along
    dim 1 (presumably the [CLS] token; confirm against the ViT implementation)
    and flattened to one row per sample.
    """
    hidden, _ = self.transformer(x)
    first_token = hidden[:, 0]
    return first_token.view(first_token.size(0), -1)


def added_embedding_preactresnet(self, x):
    """Flattened penultimate features of a PreActResNet wrapped in a container with ``.module``.

    In this file ``self`` is an ``nn.DataParallel`` around ``PreActResNet18``,
    so every layer is reached through ``self.module``. The final feature map goes
    through BN + ReLU, is 4x4 average-pooled and flattened per sample.
    """
    net = self.module
    feats = net.conv1(x)
    for stage in (net.layer1, net.layer2, net.layer3, net.layer4):
        feats = stage(feats)
    pooled = F.avg_pool2d(F.relu(net.bn(feats)), 4)
    return pooled.view(pooled.size(0), -1)

def added_embedding_preactresnet_tiny_imagenet(self, x):
    """Flattened penultimate features of an unwrapped PreActResNet (no ``.module`` indirection).

    Same pipeline as the DataParallel variant: stem, four stages, BN + ReLU,
    4x4 average pooling, then flatten per sample.
    """
    feats = self.conv1(x)
    for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
        feats = stage(feats)
    pooled = F.avg_pool2d(F.relu(self.bn(feats)), 4)
    return pooled.view(pooled.size(0), -1)


def added_embedding_DenseNet3(self, x):
    """Return flattened penultimate features of a DenseNet3 (three dense blocks).

    Args:
        self: the DenseNet3 module, passed explicitly (this is not a bound method).
        x: input batch; assumed already normalized by the caller (e.g. by
            Wraper). Shape assumed (N, C, H, W); confirm against the data loaders.

    Returns:
        Tensor of shape (-1, self.in_planes): the post-BN/ReLU feature map
        after 8x8 average pooling, flattened.
    """
    # Follow the model's device instead of hard-coding CUDA, so CPU-only runs
    # and models placed on a non-default GPU both work. ``Tensor.to`` never
    # mutates ``x`` in place, so the previous defensive ``clone()`` is not needed.
    x = x.to(next(self.parameters()).device)
    out = self.conv1(x)
    out = self.trans1(self.block1(out))
    out = self.trans2(self.block2(out))
    out = self.block3(out)
    out = self.relu(self.bn1(out))
    out = F.avg_pool2d(out, 8)
    return out.view(-1, self.in_planes)


def added_embedding_DenseNet3_ATOM_implementation(self, x):
    """Penultimate DenseNet3 features for the ATOM model, applying its own normalizer.

    When ``self.normalizer`` is set, each of the first three channels is
    standardized with ``normalizer.mean[c]`` / ``normalizer.std[c]`` on a copy
    of ``x``, so the caller's tensor is never modified. The result is the
    post-BN/ReLU map after 8x8 average pooling, reshaped to (-1, self.in_planes).
    """
    norm = self.normalizer
    if norm is not None:
        # Work on a copy: the per-channel assignments below are in-place.
        x = x.clone()
        for c in range(3):
            x[:, c, :, :] = (x[:, c, :, :] - norm.mean[c]) / norm.std[c]

    feats = self.trans1(self.block1(self.conv1(x)))
    feats = self.trans2(self.block2(feats))
    feats = self.relu(self.bn1(self.block3(feats)))
    return F.avg_pool2d(feats, 8).view(-1, self.in_planes)


# Local cache directory (relative to the working directory) for downloaded checkpoints.
_MODELS_DIR = "modelZoo/modelsData"

# Per-channel input statistics expected by the checkpoints loaded below.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)  # equals np.mean(train_set.train_data, axis=(0,1,2))/255
_CIFAR10_STD = (0.2471, 0.2435, 0.2616)  # equals np.std(train_set.train_data, axis=(0,1,2))/255
_CIFAR100_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
_CIFAR100_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
_HALF_STATS = (0.5, 0.5, 0.5)


def _ensure_models_dir():
    """Best-effort creation of the local checkpoint cache directory (and its parent)."""
    try:
        os.makedirs(_MODELS_DIR, exist_ok=True)
    except OSError:
        # Deliberately best-effort: some loaders (e.g. ATOM) never use this
        # directory, so failing to create it must not abort them. Loaders that
        # download into it will fail loudly on their own.
        pass


def _download_if_missing(saved_file_address, file_url):
    """Download ``file_url`` to ``saved_file_address`` unless that file already exists."""
    if not Path(saved_file_address).is_file():
        wget.download(url=file_url, out=saved_file_address)


def _load_checkpoint(saved_file_address):
    """Deserialize a checkpoint onto the CPU.

    ``load_state_dict`` copies tensors into the target parameters wherever they
    live. Mapping to the CPU therefore keeps CUDA-saved checkpoints loadable on
    CPU-only hosts and avoids a transient extra copy in GPU memory.
    """
    return torch.load(saved_file_address, map_location="cpu")


def _load_densenet_ood(saved_file_address, file_url, num_classes, mean, std, device, train_mode):
    """Build ``dn.DenseNet3(100, num_classes, 12, ...)``, load its checkpoint and wrap it.

    The returned Wraper normalizes inputs with ``mean``/``std`` and is in eval mode.
    """
    _download_if_missing(saved_file_address, file_url)
    inner_model = dn.DenseNet3(100, num_classes, 12, reduction=0.5,
                               bottleneck=True, dropRate=0.0)
    checkpoint = _load_checkpoint(saved_file_address)
    inner_model.load_state_dict(checkpoint['state_dict'])
    model = Wraper(inner_model, mean, std, device, added_embedding_DenseNet3, train_mode)
    model.set_normalize(True)
    model.eval_mode()
    return model


def _load_vit_l32(saved_file_address, file_url, num_classes, device):
    """Build a ViT-L_32 (224 input), load its fine-tuned weights and wrap it.

    The model's output is a sequence, so ``multiple_output=0`` selects the first
    element as the logits.
    """
    config = CONFIGS["ViT-L_32"]
    model = VisionTransformer(config, 224, zero_head=True, num_classes=num_classes).to(device)
    _download_if_missing(saved_file_address, file_url)
    model.load_state_dict(_load_checkpoint(saved_file_address))
    model = Wraper(model, _HALF_STATS, _HALF_STATS, device, added_embedding_ViT_L, "clean",
                   multiple_output=0)
    model.set_normalize(True)
    model.eval_mode()
    return model


def _load_madry_cifar(saved_file_address, file_url, num_classes, mean, std, device):
    """Build a PGD-adversarially-trained PreActResNet18 for CIFAR and wrap it."""
    model = PreActResNet18(num_classes=num_classes)
    _download_if_missing(saved_file_address, file_url)
    # Wrapped in DataParallel before loading. The checkpoint presumably stores
    # "module."-prefixed keys, and added_embedding_preactresnet reaches the
    # layers through ``.module``.
    model = nn.DataParallel(model).cuda()
    model.load_state_dict(_load_checkpoint(saved_file_address))
    model = Wraper(model, mean, std, device, added_embedding_preactresnet, "adversarial")
    model.set_normalize(True)
    model.eval_mode()
    return model


def _load_osad(encoder_address, encoder_url, clsfier_address, clsfier_url, nclass):
    """Restore the ECCV2020OSAD encoder/classifier pair and wrap it in OSADWraper."""
    args = get_ECCV2020OSAD_args()
    _download_if_missing(encoder_address, encoder_url)
    encoder = init_model(net=DenoiseResnet.ResnetEncoder(denoisemean=args['denoisemean'],
                                                         latent_size=args['latent_size'],
                                                         denoise=args['denoise']),
                         init_type=args['init_type'], restore=encoder_address,
                         parallel_reload=args['parallel_train'])

    _download_if_missing(clsfier_address, clsfier_url)
    clsfier = init_model(net=DenoiseResnet.NorClassifier(latent_size=args['latent_size'], num_classes=nclass),
                         init_type=args['init_type'], restore=clsfier_address,
                         parallel_reload=args['parallel_train'])

    model = OSADWraper(encoder, clsfier)
    model.eval_mode()
    return model


def get_model(model_name, dataset, device):
    """Load a pretrained model, downloading its checkpoint into modelZoo/modelsData on first use.

    Supported ``(model_name, dataset)`` pairs:
      * "AOE":  cifar10, cifar100
      * "ALOE": cifar10, cifar100, tiny_imagenet
      * "ViT-L_32": cifar10, cifar100, tiny_imagenet
      * "Rade2021Helper_R18_extra": cifar10 / "Rade2021Helper_R18_ddpm": cifar100
      * "adversarial_train_Madry": cifar10, cifar100, tiny_imagenet
      * "open-set": cifar10, cifar100
      * "ATOM": cifar10, cifar100

    Most models are returned wrapped (Wraper / OSADWraper) and in eval mode.
    Every returned model exposes ``embedding``, called as
    ``model.embedding(model, x)``. The robustbench (Rade) models are returned
    unwrapped with that function attached. The ATOM models are returned as
    produced by ``get_model_ATOM``.

    Args:
        model_name: one of the names above.
        dataset: "cifar10", "cifar100" or "tiny_imagenet".
        device: torch device the model (and its normalization statistics) is moved to.

    Raises:
        ValueError: if the (model_name, dataset) combination is not supported.
    """
    _ensure_models_dir()
    key = (model_name, dataset)

    # --- DenseNet OOD detectors (AOE / ALOE) ---
    if key == ("AOE", "cifar100"):
        return _load_densenet_ood('modelZoo/modelsData/cifar100_AOE.pth', "https://bit.ly/3wzrEdE",
                                  100, _CIFAR100_MEAN, _CIFAR100_STD, device, "AOE")
    if key == ("AOE", "cifar10"):
        return _load_densenet_ood('modelZoo/modelsData/cifar10_AOE.pth', "https://bit.ly/3wBll9r",
                                  10, _CIFAR10_MEAN, _CIFAR10_STD, device, "AOE")
    if key == ("ALOE", "cifar100"):
        return _load_densenet_ood('modelZoo/modelsData/cifar100_ALOE.pth', "https://bit.ly/3ahRk5Z",
                                  100, _CIFAR100_MEAN, _CIFAR100_STD, device, "ALOE")
    if key == ("ALOE", "cifar10"):
        return _load_densenet_ood('modelZoo/modelsData/cifar10_ALOE.pth', "https://bit.ly/38C342U",
                                  10, _CIFAR10_MEAN, _CIFAR10_STD, device, "ALOE")
    if key == ("ALOE", "tiny_imagenet"):
        return _load_densenet_ood('modelZoo/modelsData/ALOE_tinyImageNet.pth',
                                  "https://dl.dropboxusercontent.com/s/c0cgpt9s0uvrqdj/ALOE_tinyImageNet.pth",
                                  200, _HALF_STATS, _HALF_STATS, device, "ALOE")

    # --- Clean ViT-L/32 classifiers ---
    if key == ("ViT-L_32", "cifar10"):
        return _load_vit_l32('modelZoo/modelsData/cifar10-10000_128_comp_checkpoint.bin',
                             "https://bit.ly/3ySdYfy", 10, device)
    if key == ("ViT-L_32", "cifar100"):
        return _load_vit_l32('modelZoo/modelsData/cifar100-10000_128_comp_checkpoint.bin',
                             "https://bit.ly/3lAYODm", 100, device)
    if key == ("ViT-L_32", "tiny_imagenet"):
        return _load_vit_l32('modelZoo/modelsData/tiny_imagenet-10000_128_comp_checkpoint.bin',
                             "https://dl.dropboxusercontent.com/s/lzfu974nnnthq2g/"
                             "tiny_imagenet-10000_128_comp_checkpoint.bin",
                             200, device)

    # NOTE: an "adv ViT" loader (built on get_args_adv_ViT) used to live here as
    # commented-out code; it is not supported and falls through to ValueError.

    # --- robustbench adversarially-trained ResNet-18s ---
    if key in (("Rade2021Helper_R18_extra", "cifar10"), ("Rade2021Helper_R18_ddpm", "cifar100")):
        model = load_model(model_name=model_name, dataset=dataset, threat_model='Linf',
                           model_dir="modelZoo/modelsData").to(device)
        # Attached as an unbound function, so it is called as
        # model.embedding(model, x), matching Wraper.embedding's convention.
        model.embedding = added_embedding_Rade2021Helper_R18
        model.train_mode = "adversarial"
        model.eval()
        return model

    # --- Madry-style adversarial training (PreActResNet18) ---
    if key == ("adversarial_train_Madry", "cifar10"):
        return _load_madry_cifar('modelZoo/modelsData/cifar10_madry.pth', "https://bit.ly/3Nt2e79",
                                 10, _CIFAR10_MEAN, _CIFAR10_STD, device)
    if key == ("adversarial_train_Madry", "cifar100"):
        return _load_madry_cifar('modelZoo/modelsData/cifar100_madry.pth', "https://bit.ly/3LwnLdX",
                                 100, _CIFAR100_MEAN, _CIFAR100_STD, device)
    if key == ("adversarial_train_Madry", "tiny_imagenet"):
        saved_file_address = 'modelZoo/modelsData/Madry_tiny_imagenet.pth'
        _download_if_missing(saved_file_address,
                             "https://dl.dropboxusercontent.com/s/jh784rvzwh5qn7o/Madry_tinyimage_net.pth")
        model = PreActResNet18(num_classes=200).to(device)
        model.load_state_dict(_load_checkpoint(saved_file_address))
        # No normalization statistics: inputs are fed to the network unchanged.
        model = Wraper(model, None, None, device, added_embedding_preactresnet_tiny_imagenet, "adversarial")
        model.eval_mode()
        return model

    # --- ECCV 2020 open-set adversarial defense ---
    if key == ("open-set", "cifar10"):
        return _load_osad('modelZoo/modelsData/Encoder-cifar10-final.pt', "https://bit.ly/3NxBKlc",
                          'modelZoo/modelsData/NorClsfier-cifar10-final.pt', "https://bit.ly/3sMqmcT",
                          10)
    if key == ("open-set", "cifar100"):
        return _load_osad('modelZoo/modelsData/Encoder-cifar100-final.pt', "https://bit.ly/3G9ah6I",
                          'modelZoo/modelsData/NorClsfier-cifar100-final.pt', "https://bit.ly/3z8F1DB",
                          100)

    # --- ATOM (informative outlier mining) ---
    if key == ("ATOM", "cifar10"):
        return get_model_ATOM("CIFAR-10", "atom", "ATOM", 100, 'densenet')
    if key == ("ATOM", "cifar100"):
        return get_model_ATOM("CIFAR-100", "atom", "ATOM", 100, 'densenet')

    raise ValueError(
        f"Unsupported model/dataset combination: model_name={model_name!r}, dataset={dataset!r}"
    )


class InputNormalize(nn.Module):
    """Clamp inputs into [0, 1], then standardize them per channel as (x - mean) / std.

    ``new_mean`` / ``new_std`` are tensors with one entry per channel (in this
    file they come from 3-element lists). Two trailing singleton dimensions are
    appended so they broadcast over height and width, and they are moved to
    ``device`` once here. They are stored as plain attributes, not buffers, so
    a later ``module.to(...)`` does not move them.
    """

    def __init__(self, new_mean, new_std, device):
        super().__init__()
        self.new_mean = new_mean.unsqueeze(-1).unsqueeze(-1).to(device)
        self.new_std = new_std.unsqueeze(-1).unsqueeze(-1).to(device)

    def forward(self, x):
        # Inputs outside the valid pixel range (e.g. overshooting perturbations)
        # are clipped before standardization.
        return (x.clamp(0, 1) - self.new_mean) / self.new_std


class Wraper(nn.Module):
    """Uniform adapter around heterogeneous backbones.

    It can optionally normalize inputs before the wrapped model, select one
    element of a tuple/list output, and expose a model-specific feature
    extractor through ``embedding``.

    Args:
        inner_model: wrapped ``nn.Module``; moved to ``device`` here.
        mean, std: per-channel statistics, or ``None``. With ``None`` no
            normalizer is built, and ``forward`` would fail if
            ``set_normalize(True)`` were called.
        device: target device for the model and the normalization statistics.
        embedding_function: callable ``(inner_model, x) -> features``.
        train_mode: label such as "clean" or "adversarial" describing the model.
        multiple_output: if not ``None``, the index taken from the inner
            model's output in ``forward``.
    """

    def __init__(self, inner_model, mean, std, device, embedding_function, train_mode, multiple_output=None):
        super().__init__()
        self.inner_model = inner_model.to(device)
        # Normalization is opt-in; callers enable it via set_normalize(True).
        self.need_normalize = False
        has_stats = mean is not None
        self.mean = torch.tensor(mean) if has_stats else None
        self.std = torch.tensor(std) if has_stats else None
        self.normalize = InputNormalize(self.mean, self.std, device) if has_stats else None

        self.embedding_function = embedding_function
        self.train_mode = train_mode
        self.multiple_output = multiple_output

    def eval_mode(self):
        """Put the wrapper and (explicitly) the wrapped model into eval mode."""
        self.eval()
        self.inner_model.eval()

    def set_normalize(self, need_normalize):
        """Enable or disable input normalization in ``forward`` and ``embedding``."""
        self.need_normalize = need_normalize

    def forward(self, x):
        if self.need_normalize:
            x = self.normalize(x)
        result = self.inner_model(x)
        if self.multiple_output is not None:
            result = result[self.multiple_output]
        return result

    @staticmethod
    def embedding(self, x):
        # Declared static, so the wrapper is passed explicitly:
        # ``model.embedding(model, x)``.
        inputs = self.normalize(x) if self.need_normalize else x
        return self.embedding_function(self.inner_model, inputs)


class OSADWraper(nn.Module):
    """Encoder + classifier pair restored from the modelZoo.ECCV2020OSAD networks.

    ``forward(x)`` returns ``norClsfier(encoder(x))``. ``embedding`` returns the
    encoder output and is called as ``model.embedding(model, x)``.
    """

    def __init__(self, encoder, NorClsfier):
        super().__init__()
        self.encoder = encoder
        self.norClsfier = NorClsfier
        self.train_mode = "adversarial"

    def eval_mode(self):
        """Put the wrapper and both sub-networks into eval mode."""
        self.eval()
        for part in (self.encoder, self.norClsfier):
            part.eval()

    def forward(self, x):
        return self.norClsfier(self.encoder(x))

    @staticmethod
    def embedding(self, x):
        return self.encoder(x)


def get_ECCV2020OSAD_args():
    """Return a fresh dict with the fixed configuration of the ECCV2020 OSAD networks.

    Only 'denoisemean', 'latent_size', 'denoise', 'init_type' and
    'parallel_train' are read by get_model in this file. The remaining keys
    appear to mirror command-line options of the upstream project.
    """
    return {
        'description': "AdvOpenset",
        'training_type': 'Test',
        'parallel_train': False,  # cifar10 svhn False; tinyimagenet True
        'datasetname': 'cifar10',  # cifar10 tinyimagenet svhn
        'split': '0',
        'imgsize': 32,  # cifar svhn 32 tinyimagenet 64
        'adv': 'FGSMattack',  # clean PGDattack FGSMattack
        'adv_iter': 5,
        'defense': 'Ours_FD',
        'denoisemean': 'gaussian',
        'init_type': 'normal',
        'defensesnapshot': 'f',
        'denoise': [True] * 5,
        'batchsize': 64,
        'latent_size': 512,
        'results_path': './results/',
        'manual_seed': None,
    }


class get_args_adv_ViT:
    """Hard-coded stand-in for the argparse namespace of an adversarial ViT training script.

    Every option becomes an instance attribute, in the order listed below.
    Options that appear to have been ``action='store_true'`` flags upstream are
    kept as the literal string ``'store_true'`` (a truthy value). Options whose
    argparse definitions were never translated (lr_decay_milestones,
    lr_schedule, opt_level, loss_scale, master_weights) are omitted.
    """

    def __init__(self):
        options = {
            'model': 'vit_base_patch16_224_in21k',
            'method': 'pgd',
            'run_dummy': 'store_true',
            'accum_steps': 1,
            'grad_clip': 1.0,
            'save_interval': 1000,
            'log_interval': 10,
            'prefetch': 1,  # 2
            'data': 'cifar10',
            'tfds_dir': '~/dataset/tar',
            'eval': 'store_true',
            'eval_steps': 1000,
            'batch_size': 128,
            'batch_size_eval': 512,
            'eval_interval': 100,
            'no_timm': 'store_true',
            'crop': 32,
            'resize': 32,
            'load': None,  # str
            'data_loader': 'torch',
            'no_inception_crop': 'store_true',
            'scratch': 'store_true',
            'custom_vit': 'store_true',
            'base_lr': 0.03,
            'warmup_steps': 500,
            'depth': 12,
            'optimizer': 'sgd',
            'attack_iters': 7,
            'patch': 4,
            'load_state_dict_only': 'store_true',
            'patch_embed_scratch': 'store_true',
            'num_layers': None,  # int
            'eval_restarts': 1,
            'eval_iters': 10,
            'downsample_factor': 'store_true',
            'eval_all': 'store_true',
            'eval_aa': 'store_true',
            'num_classes': 10,
            'data_dir': './data',
            'epochs': 20,
            'lr_min': 0.,
            'lr_max': 0.1,
            'lr_natural': 5e-4,
            'weight_decay': 2e-4,
            'momentum': 0.9,
            'epsilon': 8,
            'alpha': 10,
            'delta_init': 'random',
            'out_dir': 'output_dir',
            'dir': 'output_dir',
            'seed': 0,
            'early_stop': 'store_true',
            'pretrain_pos_only': 'store_true',
        }
        for name, value in options.items():
            setattr(self, name, value)

        # Presumably the batch must split evenly across gradient-accumulation steps.
        assert self.batch_size % self.accum_steps == 0
