import os
import sys
import time
import argparse
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from tqdm import tqdm
import timm
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models
from attacks import VNIFGSM, MIFGSM, DIFGSM, SINIFGSM
from collections import OrderedDict
import cv2

import models
from utils import AverageMeter, time_file_str

# Disabled alternative, kept as a bare string literal (it is evaluated and
# discarded, so it has no effect): it would derive the architecture list from
# the lowercase callables exposed by the local `models` package.
'''model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))'''
# Accepted values for the --arch option (source model); 'ensemble' selects the
# un-normalized 224x224 pipeline in main().
model_names = ['resnet50', 'resnet101', 'inception_v3', 'vgg16_bn', 'densenet121', 'ensemble']

def print_log(print_string, log):
    """Echo a message to stdout and append it, newline-terminated, to ``log``.

    Args:
        print_string: Object to report; rendered with ``str.format``.
        log: Writable text file-like object exposing ``write`` and ``flush``.

    Both stdout and ``log`` are flushed right away so output stays visible
    while a long run is still in progress.
    """
    message = "{:}".format(print_string)

    print(message)
    sys.stdout.flush()

    log.write(message + '\n')
    log.flush()

def parse_option():
    """Parse the command line and return the populated argument namespace.

    Besides the declared options, two derived fields are set on the result:

    * ``use_cuda`` -- ``torch.cuda.is_available()``; the script refuses to run
      without CUDA.
    * ``prefix`` -- the user-supplied prefix with ``time_file_str()`` appended,
      plus ``".targeted"`` when ``--targeted`` is given.

    Returns:
        argparse.Namespace: parsed and post-processed options.

    Raises:
        RuntimeError: if CUDA is not available.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('data', metavar='DIR')
    # type=int so a value given on the command line is not passed on as a str.
    parser.add_argument("--num_classes", type=int, default=1000)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--workers", type=int, default=4)

    ## for the source model
    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50', choices=model_names)
    parser.add_argument('--models_dir', type=str)
    parser.add_argument('--num_models', type=int, default=3)

    ## for the target model
    parser.add_argument("--target_model", default='inception_v3',
                        choices=['inception_v3', 'inception_v4', 'resnet50', 'resnet101',
                                 'vgg16_bn', 'densenet121',
                                 'vit_base_patch16_224.augreg2_in21k_ft_in1k',
                                 'vit_small_patch16_224.augreg_in21k_ft_in1k'])

    ## attack
    parser.add_argument("--attack", type=str, choices=["mifgsm", "vnifgsm", "difgsm", "sinifgsm"], default="mifgsm")
    parser.add_argument("--eps", type=float, default=16/255)
    parser.add_argument("--steps", type=int, default=10)
    parser.add_argument("--step_size", type=float, default=1.6/255)
    parser.add_argument("--decay", type=float, default=1.0) # only for mifgsm
    parser.add_argument("--targeted", action='store_true')

    # main() joins paths onto save_dir unconditionally, so a missing value
    # used to surface later as a TypeError; require it up front instead.
    parser.add_argument("--save_dir", type=str, required=True)
    parser.add_argument("--prefix", type=str, default="attack_inc_v3_", help="the name used for logging")
    parser.add_argument("--attack_source_model", action="store_true")

    args = parser.parse_args()
    args.use_cuda = torch.cuda.is_available()
    # Explicit raise rather than `assert`, which is stripped under `python -O`.
    if not args.use_cuda:
        raise RuntimeError("CUDA not available")
    args.prefix += time_file_str()
    if args.targeted:
        args.prefix += ".targeted"
    return args

def _build_transform(arch):
    """Return ``(transform, mean, std)`` for the source architecture ``arch``.

    ``mean`` and ``std`` are the per-channel normalization constants baked into
    the transform, or ``None`` for ``'ensemble'``, whose pipeline ends at
    ``ToTensor`` without normalization.

    Raises:
        NotImplementedError: if ``arch`` matches none of the known families.
    """
    if "resnet" in arch or "vgg" in arch or arch == 'densenet121':
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
        ])
        return transform, mean, std
    if arch == 'inception_v3':
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
        transform = transforms.Compose([
            transforms.Resize(size=341, interpolation=transforms.InterpolationMode.BICUBIC, max_size=None, antialias=None),
            transforms.CenterCrop(size=(299, 299)),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
        ])
        return transform, mean, std
    if arch == 'ensemble':
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor()
        ])
        return transform, None, None
    raise NotImplementedError


def _collect_source_checkpoints(models_dir):
    """Return ``(paths, archs)`` for every ``*.pth.tar`` file in ``models_dir``.

    The third dot-separated field of each file name is taken as the
    architecture name. Both lists are empty when ``models_dir`` is falsy.
    """
    paths = []
    archs = []
    if models_dir:
        # os.listdir order is arbitrary; sort so runs are reproducible.
        for filename in sorted(os.listdir(models_dir)):
            if filename.endswith(".pth.tar"):
                paths.append(os.path.join(models_dir, filename))
                archs.append(filename.split('.')[2])
    return paths, archs


def _build_attack(args, source_model_path, source_model_arch, device):
    """Instantiate the attack selected by ``args.attack``.

    Raises:
        NotImplementedError: if ``args.attack`` is not a supported name.
    """
    common_args = (source_model_path, source_model_arch, device, args.num_models, args.num_classes)
    common_kwargs = dict(eps=args.eps, alpha=args.step_size, steps=args.steps)
    if args.attack == "vnifgsm":
        return VNIFGSM(*common_args, **common_kwargs, N=20, beta=1.5)
    if args.attack == "mifgsm":
        return MIFGSM(*common_args, **common_kwargs, decay=args.decay)
    if args.attack == "difgsm":
        return DIFGSM(*common_args, **common_kwargs, diversity_prob=0.7)
    if args.attack == "sinifgsm":
        return SINIFGSM(*common_args, **common_kwargs)
    raise NotImplementedError


def _run_evaluation(args, device, log):
    """Log the configuration, then evaluate the target model on benign and adversarial data."""
    # version information
    print_log("PyThon  version : {}".format(sys.version.replace('\n', ' ')), log)
    print_log("PyTorch version : {}".format(torch.__version__), log)
    print_log("cuDNN   version : {}".format(torch.backends.cudnn.version()), log)
    print_log("Vision  version : {}".format(torchvision.__version__), log)
    print_log("args: {}".format(args), log)

    print_log("target_model: {}".format(args.target_model), log)
    print_log("models_dir: {}".format(args.models_dir), log)
    print_log("num_models: {}".format(args.num_models), log)

    print_log("attack: {}".format(args.attack), log)
    print_log("eps: {}".format(args.eps), log)
    print_log("steps: {}".format(args.steps), log)
    print_log("step_size: {}".format(args.step_size), log)
    print_log("decay: {}".format(args.decay), log)
    print_log("targeted: {}".format(args.targeted), log)

    transform, normalize_mean, normalize_std = _build_transform(args.arch)

    val_dataset = datasets.ImageFolder(
        root=args.data,
        transform=transform
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True
    )

    source_model_path, source_model_arch = _collect_source_checkpoints(args.models_dir)
    if args.arch != 'ensemble':
        # A single-architecture source is passed to the attack as a plain
        # string rather than as a per-checkpoint list.
        source_model_arch = args.arch

    if normalize_mean is None:
        target_model = models.TimmModel(args.target_model, do_inverse_normalize=False).to(device)
    else:
        target_model = models.TimmModel(args.target_model, mean=normalize_mean, std=normalize_std).to(device)

    criterion = nn.CrossEntropyLoss().cuda()

    target_acc = validate(val_loader, target_model, criterion, log)
    print_log(f"accuracy of the target model on benign data is {target_acc}", log)

    attack = _build_attack(args, source_model_path, source_model_arch, device)
    if normalize_mean is not None:
        attack.set_normalization_used(normalize_mean, normalize_std)
    if args.targeted:
        attack.set_mode_targeted_random()

    if args.attack_source_model:
        source_model_2 = models.__dict__[args.arch](pretrained=True)
        source_model_2 = source_model_2.to(device)
        source_acc = validate(val_loader, source_model_2, criterion, log, attack=attack, targeted=args.targeted)
        if not args.targeted:
            print_log(f"accuracy of the source model on adversarial data is {source_acc}", log)
        else:
            print_log(f"success rate of the targeted attack on the source model is {source_acc}", log)

    # NOTE(review): sample_save_dir=None keeps sample dumping switched off, as
    # in the original call; the per-run sample directory is still created.
    target_acc_adv = validate(val_loader, target_model, criterion, log, attack=attack, targeted=args.targeted, sample_save_dir=None)
    if not args.targeted:
        # Relative accuracy drop with respect to the benign accuracy.
        success_rate = (target_acc - target_acc_adv) / target_acc
        print_log(f"accuracy of the target model on adversarial data is {target_acc_adv}", log)
        print_log(f"success rate of the attack is {success_rate}", log)
    else:
        print_log(f"success rate of the targeted attack on the target model is {target_acc_adv}", log)


def main(args):
    """Run the transfer-attack evaluation described by ``args``.

    Creates ``args.save_dir`` and a per-run sample directory, opens the run's
    log file, evaluates the target model on benign data, builds the selected
    attack, and reports the target model's accuracy (untargeted) or attack
    success rate (targeted) on adversarial data. With
    ``args.attack_source_model`` the source architecture is evaluated too.

    Args:
        args: Namespace as returned by ``parse_option``.

    Raises:
        RuntimeError: if ``args.use_cuda`` is false.
        NotImplementedError: for an unsupported ``args.arch`` or ``args.attack``.
    """
    # Explicit raise: an `assert` would be stripped under `python -O`,
    # leaving `device` unbound.
    if not args.use_cuda:
        raise RuntimeError("CUDA not available")
    device = "cuda"

    os.makedirs(args.save_dir, exist_ok=True)
    sample_save_dir = os.path.join(args.save_dir, '{}.{}'.format(args.arch, args.prefix))
    os.makedirs(sample_save_dir, exist_ok=True)

    log_path = os.path.join(args.save_dir, '{}.{}.log'.format(args.arch, args.prefix))
    # Context manager so the log is closed even if evaluation raises.
    with open(log_path, 'w') as log:
        _run_evaluation(args, device, log)

def accuracy(output, target, topk=(1,)):
    """Compute precision@k, in percent, for each k in ``topk``.

    Args:
        output: Score tensor indexed as (sample, class); the k highest scores
            along dim 1 are the predictions.
        target: 1-D tensor of ground-truth class indices, one per sample.
        topk: Iterable of k values to evaluate.

    Returns:
        list: one single-element float tensor per entry of ``topk`` holding
        the percentage of samples whose target is among the top-k predictions.
    """
    largest_k = max(topk)
    num_samples = target.size(0)

    # Indices of the best `largest_k` classes, laid out as (largest_k, num_samples).
    top_indices = output.topk(largest_k, dim=1, largest=True, sorted=True)[1]
    ranked = top_indices.t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))

    percent_per_hit = 100.0 / num_samples
    return [
        hits[:k].reshape((-1,)).float().sum(0, keepdim=True).mul_(percent_per_hit)
        for k in topk
    ]

def _save_adversarial_samples(attack, adv_input, sample_save_dir):
    """Write every image of an adversarial batch to ``sample_save_dir`` as ``sample<j>.png``."""
    denormalized_input = attack.inverse_normalize(adv_input)
    for j, image in enumerate(denormalized_input):
        sample_save_path = os.path.join(sample_save_dir, f"sample{j}.png")
        # detach() so the NumPy conversion also works if the tensor tracks gradients.
        save_image = image.detach().cpu().numpy().transpose((1, 2, 0))
        save_image = (save_image * 255).astype('uint8')
        # Reverse the channel axis before handing the array to OpenCV
        # (presumably RGB -> BGR; confirm against the attack's output layout).
        save_image = save_image[:, :, ::-1]
        cv2.imwrite(sample_save_path, save_image)


def validate(val_loader, model, criterion, log, attack=None, targeted=False, sample_save_dir=None):
    """Evaluate ``model`` on ``val_loader``, optionally on adversarial inputs.

    Args:
        val_loader: Iterable yielding ``(inputs, labels)`` batches.
        model: Model to evaluate; switched to eval mode.
        criterion: Loss applied to the model output and the reference labels.
        log: Open log file passed on to ``print_log``.
        attack: Optional callable ``attack(inputs, labels)`` returning
            ``(adversarial_inputs, target_labels)``. When given, the model is
            evaluated on the adversarial inputs.
        targeted: With an attack, score predictions against the attack's
            target labels instead of the ground truth, so the returned figure
            is the targeted success rate rather than accuracy.
        sample_save_dir: If set (and an attack is given), the first
            adversarial batch is written there as PNG files.

    Returns:
        The running top-1 average (percent) accumulated in ``AverageMeter``.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    score_against_targets = attack is not None and bool(targeted)

    end = time.time()
    for i, (images, labels) in enumerate(val_loader):
        images = images.cuda()
        labels = labels.cuda()

        reference = labels
        if attack is not None:
            # The attack runs outside no_grad: it needs gradients to craft
            # the perturbation.
            images, target_labels = attack(images, labels)
            if i == 0 and sample_save_dir is not None:
                _save_adversarial_samples(attack, images, sample_save_dir)
            if score_against_targets:
                reference = target_labels

        # `Variable(..., volatile=True)` has been a no-op since PyTorch 0.4, so
        # the forward pass used to build an autograd graph; no_grad is the
        # supported way to disable that during evaluation.
        with torch.no_grad():
            output = model(images)
            loss = criterion(output, reference)
            prec1, prec5 = accuracy(output, reference, topk=(1, 5))

        # measure accuracy and record loss
        n = images.size(0)
        losses.update(loss.item(), n)
        top1.update(prec1[0], n)
        top5.update(prec5[0], n)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % 10 == 0:
            print_log('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                i, len(val_loader), batch_time=batch_time, loss=losses,
                top1=top1, top5=top5), log)

    print_log(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Error@1 {error1:.3f}'.format(top1=top1, top5=top5,
                                                                                           error1=100 - top1.avg), log)

    return top1.avg

# Script entry point: parse the command line, then run the evaluation.
# `args` is bound at module level here and handed to main() explicitly.
if __name__ == "__main__":
    args = parse_option()
    main(args)