from __future__ import print_function
import argparse
from tqdm import tqdm
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import sys
sys.path.append('./')

from utils.misc import *
from utils.test_helpers import *
from utils.prepare_dataset import *
from utils.rotation import *
from utils.prepare_attack_dataset import *
from adv_test_calls.advtest_TTT import * 
from shutil import copyfile, rmtree
from utils.prepare_corruption_dataset import *
from defense.DANN import *
from utils.prepare_dann_attack_dataset import *
import os

def DANN_bilevel_joint(args, dann_attack=False, corruption='fog'):
    """Alternating bilevel attack: train a DANN on adversarially perturbed target data.

    Every training step perturbs the current target batch against the current
    model (either with ``dann_perturb`` or with an L-inf PGD attack on the
    feature extractor + classifier) and then updates the model on the source
    classification loss plus the source/target domain losses.

    Checkpoints are written under
    ``./results/pretrain/<args.pretrain_dir>/<corruption>/<dir_name>/``:
    ``ckpt.pth`` holds the latest model weights and ``ckpt<epoch>.pth`` holds
    the source/target accuracy of that epoch (``ckpt0.pth`` is a copy of the
    initial weights when not resuming).

    Args:
        args: parsed command-line namespace. This function reads
            ``pretrain_dir``, ``private_seed``, ``n_iter``, ``resume_epoch``
            and ``wrap_DANN``; the namespace is also forwarded to
            ``build_model``, ``prepare_test_data`` and ``prepare_train_data``.
            NOTE: ``args.pretrain_dir`` is modified in place (the corruption
            name is appended), so calling this twice with the same namespace
            appends it twice.
        dann_attack: if True, perturb target batches with ``dann_perturb``;
            otherwise use ``LinfPGDAttack`` against the label classifier.
        corruption: name of the corruption passed to
            ``prepare_corruption_data`` and used as a sub-directory name.
    """
    # Generate the attack w.r.t. the pretrained model on the corruption
    # dataset. Here the corruption dataset with different levels combined is
    # used. The alternating bilevel optimization attack can be viewed as one
    # epoch FPA.
    #
    # The private seed specifies whether the defender's private randomness is
    # known to the attacker. If it is known, the attacker will always specify
    # the DANN with the same private seed.
    if dann_attack:
        dir_name = 'DANN_bilevel_joint_dann_attack'
    else:
        dir_name = 'DANN_bilevel_joint'
    args.pretrain_dir = args.pretrain_dir+'/{}'.format(corruption)
    private_seed = 140739 if args.private_seed else None  # arbitrary constant

    # Generate the per-iteration seeds. In this function they are only
    # printed; seeding numpy's global RNG is kept as a side effect.
    if private_seed is None:
        seed = 10419487  # arbitrary constant
        np.random.seed(seed)
        seeds = np.random.randint(100000, size=args.n_iter)
    else:
        seeds = np.array([private_seed]*args.n_iter)
        dir_name = dir_name+str(private_seed)

    pretrain_root = './results/pretrain/'+args.pretrain_dir
    run_dir = pretrain_root+'/'+dir_name
    if not os.path.exists(run_dir):
        os.makedirs(run_dir)
    print('seeds:')
    print(seeds)

    net, _, _, _ = build_model(args)

    if args.resume_epoch > 0:
        # Resuming: weights are restored from <run_dir>/ckpt.pth below.
        model = DANNWrapper(net)
    else:
        # Fresh run: start from the pretrained checkpoint and seed the run
        # directory with it.
        ckpt = torch.load(pretrain_root+'/ckpt.pth')
        if args.wrap_DANN:
            # The checkpoint stores the bare network; wrap it afterwards.
            net.load_state_dict(ckpt['net'])
            model = DANNWrapper(net)
        else:
            model = DANNWrapper(net)
            model.load_state_dict(ckpt['model'])
        torch.save({'model': model.state_dict()}, run_dir+'/ckpt.pth')
        copyfile(run_dir+'/ckpt.pth', run_dir+'/ckpt0.pth')

    # Source data preparation. The test loader is not used here; the call is
    # kept because its side effects (if any) are not visible from this file.
    prepare_test_data(args)
    _, train_source_loader = prepare_train_data(args)

    # Load the model used for generating the adversarial samples.
    ckpt = torch.load(run_dir+'/ckpt.pth')
    model.load_state_dict(ckpt['model'])
    (_, train_target_loader), _ = prepare_corruption_data(corruption)

    lr = 3e-4
    n_epoch = 100  # only used to scale the alpha schedule below
    model = model.cuda()
    model.train()
    loss_class = torch.nn.CrossEntropyLoss().cuda()
    loss_domain = torch.nn.CrossEntropyLoss().cuda()
    dataloader_source = train_source_loader
    dataloader_target = train_target_loader
    len_dataloader = min(len(dataloader_source), len(dataloader_target))

    # setup optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Attack configuration. These objects do not change between steps, so
    # they are built once instead of once per batch. Float literals keep the
    # divisions exact even under Python 2 integer division.
    eps = 16.0 / 255
    eps_iter = 4.0 / 255
    nb_iter = 7
    adversary = None
    if not dann_attack:
        # The Sequential shares its sub-modules with ``model``, so it always
        # reflects the current weights.
        dann_classifier = nn.Sequential(model.feature, model.classifier)
        adversary = LinfPGDAttack(
            dann_classifier, loss_fn=nn.CrossEntropyLoss().cuda(), eps=eps,
            nb_iter=nb_iter, eps_iter=eps_iter, rand_init=True,
            clip_min=-1.0, clip_max=1.0, targeted=False)

    for epoch in range(1+args.resume_epoch, args.n_iter+1+args.resume_epoch):
        n_correct_source = 0
        n_correct_target = 0
        # Count the samples actually seen per domain: the last batch may be
        # short and the two loaders may use different batch sizes.
        n_total_source = 0
        n_total_target = 0
        data_source_iter = iter(dataloader_source)
        data_target_iter = iter(dataloader_target)
        for batch_idx in range(len_dataloader):
            # alpha is handed to the model on every forward pass
            # (presumably the gradient-reversal weight; confirm in DANNWrapper).
            p = float(batch_idx + epoch * len_dataloader) / n_epoch / len_dataloader
            alpha = 2. / (1. + np.exp(-10 * p)) - 1

            # Source batch, domain label 0.
            s_img, s_label = next(data_source_iter)
            s_batch_size = s_img.shape[0]
            s_domain_label = torch.zeros(s_batch_size, dtype=torch.long).cuda()
            s_img = s_img.cuda()
            s_label = s_label.cuda()

            # Target batch, domain label 1.
            t_img, t_label = next(data_target_iter)
            t_batch_size = t_img.shape[0]
            t_domain_label = torch.ones(t_batch_size, dtype=torch.long).cuda()
            t_img, t_label = t_img.cuda(), t_label.cuda()

            # Perturb the target batch against the current model.
            if dann_attack:
                t_img = dann_perturb(t_img, t_label, alpha, model,
                                     nb_iter, eps_iter=eps_iter, eps=eps)
            else:
                with ctx_noparamgrad_and_eval(model):
                    t_img = adversary.perturb(t_img, t_label)

            cat_img = torch.cat((s_img, t_img), 0)
            class_output, domain_output = model(
                input_data=cat_img, alpha=alpha)

            s_class_output = class_output[:s_batch_size]
            t_class_output = class_output[s_batch_size:]
            s_domain_output = domain_output[:s_batch_size]
            t_domain_output = domain_output[s_batch_size:]

            # Target labels are never used in the training loss, only for
            # the attack above and for the accuracy statistics below.
            err_s_label = loss_class(s_class_output, s_label)
            err_s_domain = loss_domain(s_domain_output, s_domain_label)
            err_t_domain = loss_domain(t_domain_output, t_domain_label)
            err = err_t_domain + err_s_domain + err_s_label

            pred_source = s_class_output.detach().argmax(dim=1)
            pred_target = t_class_output.detach().argmax(dim=1)
            n_correct_source += pred_source.eq(s_label.view(-1)).sum().item()
            n_correct_target += pred_target.eq(t_label.view(-1)).sum().item()
            n_total_source += s_batch_size
            n_total_target += t_batch_size

            optimizer.zero_grad()
            err.backward()
            optimizer.step()

            step = batch_idx + 1
            if step % 100 == 0:
                print('epoch: %d, [iter: %d / all %d], err_s_label: %f, err_s_domain: %f, err_t_domain: %f'
                    % (epoch, step, len_dataloader, err_s_label.item(),
                        err_s_domain.item(), err_t_domain.item()))

        # max(..., 1) guards against empty loaders; np.float64 keeps the
        # stored accuracies the same type as before.
        source_acc = np.float64(n_correct_source) / max(n_total_source, 1)
        target_acc = np.float64(n_correct_target) / max(n_total_target, 1)
        print('epoch: %d, source acc %f, target acc %f' % (epoch, source_acc, target_acc))
        torch.save({'model': model.state_dict()}, run_dir+'/ckpt.pth')
        torch.save({
                    'source_acc': source_acc,
                    'target_acc': target_acc,
                    }, run_dir+'/ckpt{}.pth'.format(epoch))

# Evaluate the alternating bilevel optimization attack with a DANN trained for 100 epochs
def eval_bilevel(args, corruption = 'fog', dann_attack = False):
    """Evaluate the bilevel attack by training a fresh DANN on the attacked data.

    The model saved by ``DANN_bilevel_joint`` is loaded and used to generate
    PGD attack data for the train and test splits of the corruption dataset
    (via ``prepare_pgd_attack_data``). A newly built DANN is then trained for
    100 epochs with the source training data and the attacked target test
    data, and tested on the attacked target test data after every epoch.

    Args:
        args: parsed command-line namespace. This function reads
            ``pretrain_dir`` and ``batch_size``; the namespace is also
            forwarded to the project helpers called below.
            NOTE: ``args.pretrain_dir`` is modified in place (the corruption
            name is appended).
        corruption: name of the corruption passed to
            ``prepare_corruption_data`` and used in directory names.
        dann_attack: selects which run directory of ``DANN_bilevel_joint``
            the attacker checkpoint is read from.
    """
    if dann_attack:
        dir_name = 'DANN_bilevel_joint_dann_attack'
    else:
        dir_name = 'DANN_bilevel_joint'
    args.pretrain_dir = args.pretrain_dir+'/{}'.format(corruption)

    # --- Attacker: load the bilevel-trained model and generate attack data.
    net, _, _, _ = build_model(args)
    model = DANNWrapper(net)
    name = args.pretrain_dir+'/'+dir_name+"_{}_{}_pgd8".format(101,corruption)
    ckpt = torch.load('./results/pretrain/' +
                      args.pretrain_dir+'/'+dir_name+'/ckpt.pth')
    model.load_state_dict(ckpt['model'])
    (_, train_loader), (_, test_loader) = prepare_corruption_data(corruption)
    dann_classifier = nn.Sequential(model.feature, model.classifier)
    prepare_pgd_attack_data(
        args, train_loader, dann_classifier, name, train=True, nb_iter=7)
    prepare_pgd_attack_data(
        args, test_loader, dann_classifier, name, train=False, nb_iter=7)

    target_test_data = ADVDataset('attack_data/{}/test.npy'.format(name))

    # --- Defender: a freshly built DANN, trained from scratch.
    net, _, _, _ = build_model(args)
    model = DANNWrapper(net)
    # The source test loader is not used here; the call is kept because its
    # side effects (if any) are not visible from this file.
    prepare_test_data(args)
    _, train_source_loader = prepare_train_data(args)

    #### Data Preparation
    test_target_loader = torch.utils.data.DataLoader(
        dataset=target_test_data,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=8)

    # Model Preparation
    init_random_seed(0)

    n_epoch = 100

    # Move the model to the GPU *before* building the optimizer, as the
    # PyTorch documentation for Module.cuda() recommends.
    model = model.cuda()
    optimizer = optim.Adam(model.parameters(), lr=3e-4)

    target_dataset_name = 'cifar10c-{}-bilevel'.format(corruption)
    # DANN training
    for epoch in range(n_epoch):
        train_one_epoch(model, train_source_loader,
                        test_target_loader, optimizer, epoch, n_epoch)
        test_one_epoch(model, test_target_loader, target_dataset_name, epoch)




if __name__ == '__main__':
    # Command-line options as (flag, argparse keyword arguments) pairs,
    # registered in the order listed.
    cli_options = [
        ('--dataset', dict(default='cifar10')),
        ('--dataroot', dict(default='/nobackup/yguo/datasets/')),
        ('--shared', dict(default='layer2')),
        # Network architecture / normalisation.
        ('--depth', dict(default=26, type=int)),
        ('--width', dict(default=1, type=int)),
        ('--batch_size', dict(default=128, type=int)),
        ('--group_norm', dict(default=0, type=int)),
        # store_false: these default to True and passing the flag sets False.
        ('--fix_bn', dict(action='store_false')),
        ('--fix_ssh', dict(action='store_false')),
        # Optimisation.
        ('--lr', dict(default=0.001, type=float)),
        ('--online', dict(action='store_true')),
        ('--threshold', dict(default=1, type=float)),
        ('--dset_size', dict(default=0, type=int)),
        # Bilevel attack settings.
        ('--n_iter', dict(default=100, type=int)),
        ('--private_seed', dict(action='store_true')),
        ('--pretrain_dir', dict(default='DANN_cifar10_cifar10c')),
        ('--resume_epoch', dict(default=0, type=int)),
        ('--wrap_DANN', dict(action='store_true',
                             help='Whether or not wrap the model in DANN')),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # Perform the alternating bilevel optimization attack.
    DANN_bilevel_joint(args, dann_attack=True, corruption='brightness')
    # To evaluate the resulting attack instead, run:
    # eval_bilevel(args, dann_attack=True, corruption='brightness')

    # Recorded results (metric not stated here; presumably accuracy in %,
    # TODO confirm):
    #   Fog, PGD attack: 31.70
    #   Fog, DANN (v1, v2) attack: 34.73, 56.04
    #   Gaussian noise, PGD attack: 40.22
    #   Gaussian noise, DANN attack: 50.44, 69.47
    #   Brightness, PGD attack: 47.11
    #   Brightness, DANN (v1, v2) attack: 52.59, 59.55