from __future__ import print_function
import argparse
from tqdm import tqdm
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import sys
sys.path.append('./')

from utils.misc import *
from utils.test_helpers import *
from utils.prepare_dataset import *
from utils.rotation import *
from utils.prepare_attack_dataset import *
from adv_test_calls.advtest_TTT import * 
from shutil import copyfile, rmtree
from utils.prepare_corruption_dataset import *
from defense.DANN import *
import os



def _fpa_seeds(args):
    """Return ``(seeds, dir_name)`` for an FPA run.

    If ``args.private_seed`` is set, the defender's private randomness is
    treated as known to the attacker: every iteration reuses the same fixed
    seed, and that seed is appended to the output directory name. Otherwise
    ``args.n_iter`` seeds are drawn from NumPy's global RNG after seeding it
    with a fixed value, so the sequence is reproducible across runs.
    """
    dir_name = 'DANN_FPA_RI'
    if args.private_seed:
        private_seed = 140739  # arbitrary fixed value
        seeds = np.array([private_seed] * args.n_iter)
        dir_name = dir_name + str(private_seed)
    else:
        # Arbitrary fixed value; note this reseeds NumPy's *global* RNG.
        np.random.seed(10419487)
        seeds = np.random.randint(100000, size=args.n_iter)
    return seeds, dir_name


def _target_loaders(name, batch_size):
    """Wrap the attack data written to ``attack_data/<name>`` in DataLoaders.

    Returns ``(train_loader, test_loader)`` built over ``train.npy`` and
    ``test.npy``; neither loader shuffles. The ``attack_data/<name>``
    directory is deleted once both datasets have been constructed.
    """
    train_data = ADVDataset('attack_data/{}/train.npy'.format(name))
    test_data = ADVDataset('attack_data/{}/test.npy'.format(name))
    # Comment out this line if you want to keep the attack data on disk.
    # NOTE(review): assumes ADVDataset loads the .npy files eagerly rather
    # than memory-mapping them -- confirm.
    rmtree('attack_data/{}'.format(name))
    return tuple(
        torch.utils.data.DataLoader(
            dataset=data, batch_size=batch_size, shuffle=False, num_workers=8)
        for data in (train_data, test_data))


def _train_dann(model, optimizer, n_epoch, train_source_loader,
                train_target_loader, test_source_loader, test_target_loader,
                target_dataset_name):
    """Train *model* with DANN for *n_epoch* epochs.

    After every epoch the model is evaluated on the source test set, the
    target test set and the target training set. Returns the three per-epoch
    result lists ``(train_target, test_source, test_target)``.
    """
    source_dataset_name = 'cifar10'
    train_target_acc, test_source_acc, test_target_acc = [], [], []
    for epoch in range(n_epoch):
        train_one_epoch(model, train_source_loader, train_target_loader,
                        optimizer, epoch, n_epoch)
        test_source_acc.append(test_one_epoch(
            model, test_source_loader, source_dataset_name, epoch))
        test_target_acc.append(test_one_epoch(
            model, test_target_loader, target_dataset_name, epoch))
        print("Test-Time Adaptation accuracy")
        train_target_acc.append(test_one_epoch(
            model, train_target_loader, target_dataset_name, epoch))
    return train_target_acc, test_source_acc, test_target_acc


def DANN_FPA(args):
    """Repeatedly attack a DANN model and re-adapt a fresh DANN to the attack.

    The attack is generated against the current model on the fog corruption
    data (corruption levels combined, per the original notes). Each
    iteration:

    1. loads the current model from ``<out_dir>/ckpt.pth``;
    2. generates PGD attack data (``nb_iter=7``) against the model's
       feature extractor + classifier on the fog train and test splits;
    3. trains a freshly initialised DANN (CIFAR-10 source, attacked data as
       target) for a fixed number of epochs;
    4. overwrites ``<out_dir>/ckpt.pth`` with the new model, which is the
       next iteration's attack target. It also saves the model together
       with the attack results and per-epoch evaluations to
       ``<out_dir>/ckpt<iter>.pth``.

    ``<out_dir>`` is ``./results/<args.pretrain_dir>/<dir_name>``; see
    ``_fpa_seeds`` for how ``dir_name`` and the per-iteration seeds depend on
    ``args.private_seed``.

    When ``args.resume_epoch == 0``, the run starts from
    ``./results/<args.pretrain_dir>/ckpt.pth``, which is also copied to
    ``ckpt0.pth``. Otherwise it continues from the existing
    ``<out_dir>/ckpt.pth``, with iterations numbered from
    ``args.resume_epoch + 1``.
    """
    lr = 3e-4
    n_epoch = 100

    seeds, dir_name = _fpa_seeds(args)
    out_dir = './results/' + args.pretrain_dir + '/' + dir_name
    ckpt_path = out_dir + '/ckpt.pth'
    # Create the directory the checkpoints below are actually written to.
    os.makedirs(out_dir, exist_ok=True)
    print('seeds:')
    print(seeds)

    net, _, _, _ = build_model(args)
    if args.resume_epoch > 0:
        # Weights are loaded from ckpt_path at the top of the loop.
        model = DANNWrapper(net)
    else:
        ckpt = torch.load('./results/' + args.pretrain_dir + '/ckpt.pth')
        if args.wrap_DANN:
            # Checkpoint stores a bare network under 'net'; wrap after loading.
            net.load_state_dict(ckpt['net'])
            model = DANNWrapper(net)
        else:
            model = DANNWrapper(net)
            model.load_state_dict(ckpt['model'])
        torch.save({'model': model.state_dict()}, ckpt_path)
        copyfile(ckpt_path, out_dir + '/ckpt0.pth')

    _, test_source_loader = prepare_test_data(args)
    _, train_source_loader = prepare_train_data(args)
    first_iter = args.resume_epoch + 1
    for FPA_iter in range(first_iter, first_iter + args.n_iter):
        # --- Attack generation against the current model ---
        ckpt = torch.load(ckpt_path)
        model.load_state_dict(ckpt['model'])
        # From the second iteration on, the attacked model is already on the
        # GPU (moved below); do the same on the first iteration so every
        # attack sees a model on the same device.
        model = model.cuda()
        name = args.pretrain_dir + '/' + dir_name + "_{}_fog_pgd8".format(FPA_iter)
        (_, train_loader), (_, test_loader) = prepare_fog_data()
        dann_classifier = nn.Sequential(model.feature, model.classifier)
        prepare_pgd_attack_data(args, train_loader, dann_classifier, name,
                                train=True, nb_iter=7)
        test_cls_acc, test_adv_acc = prepare_pgd_attack_data(
            args, test_loader, dann_classifier, name, train=False, nb_iter=7)
        train_target_loader, test_target_loader = _target_loaders(
            name, args.batch_size)

        # --- DANN adaptation from random initialization ---
        # NOTE(review): seeds has args.n_iter entries and is indexed from 0
        # at first_iter. Without --private_seed, a resumed run therefore may
        # not reuse the seeds a fresh run used for the same iteration numbers.
        init_random_seed(seeds[FPA_iter - first_iter])
        net, _, _, _ = build_model(args)
        # Move to the GPU before building the optimizer (PyTorch recommendation).
        model = DANNWrapper(net).cuda()
        optimizer = optim.Adam(model.parameters(), lr=lr)

        target_dataset_name = 'cifar10c-FPA-RI{}-fog-pgd8'.format(FPA_iter)
        train_target_acc, test_source_acc, test_target_acc = _train_dann(
            model, optimizer, n_epoch, train_source_loader,
            train_target_loader, test_source_loader, test_target_loader,
            target_dataset_name)

        torch.save({'model': model.state_dict()}, ckpt_path)
        torch.save({'model': model.state_dict(),
                    'test_cls_acc': test_cls_acc,
                    'test_adv_acc': test_adv_acc,
                    'train_target_acc': train_target_acc,
                    'test_source_acc': test_source_acc,
                    'test_target_acc': test_target_acc,
                    }, out_dir + '/ckpt{}.pth'.format(FPA_iter))
        print(train_target_acc)
            
if __name__ == '__main__': 
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='cifar10')
    parser.add_argument('--dataroot', default='/nobackup/yguo/datasets/')
    parser.add_argument('--shared', default='layer2')
    ########################################################################
    parser.add_argument('--depth', default=26, type=int)
    parser.add_argument('--width', default=1, type=int)
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--group_norm', default=0, type=int)
    parser.add_argument('--fix_bn', action='store_false')
    parser.add_argument('--fix_ssh', action='store_false')
    ########################################################################
    parser.add_argument('--lr', default=0.001, type=float)
    parser.add_argument('--niter', default=1, type=int)
    parser.add_argument('--online', action='store_true')
    parser.add_argument('--threshold', default=1, type=float)
    parser.add_argument('--dset_size', default=0, type=int)
    ########################################################################
    parser.add_argument('--n_iter', default = 20, type = int)
    parser.add_argument('--private_seed', action='store_true')
    parser.add_argument('--pretrain_dir', default = 'DANN_cifar10_cifar10c') 
    parser.add_argument('--resume_epoch', default = 0, type = int)
    parser.add_argument('--wrap_DANN', action='store_true', help='Whether or not wrap the model in DANN') 

    args = parser.parse_args()
    DANN_FPA(args) 

