import torch
import torch.nn as nn
import torch.optim as optim
import sys
sys.path.append('./')
from adv_test_calls.advtest_TTT import *
import os
from defense.DANN import *
from utils.prepare_corruption_dataset import *
from shutil import copyfile
from utils.prepare_attack_dataset import *
from utils.rotation import *
from utils.prepare_dataset import *
from utils.test_helpers import *
from utils.misc import *
import argparse
from tqdm import tqdm
from PIL import Image

# Apply DANN to CIFAR10 -> CIFAR10 (sanity check)
def DANN_pretrain(args, private_seed=140739):
    """Train a DANN-wrapped model with CIFAR10 as both source and target.

    The training loader doubles as the target-domain training loader and the
    test loader doubles as the target-domain test loader, so source and
    target are the same data (the CIFAR10 -> CIFAR10 sanity check).

    Args:
        args: Parsed command-line namespace. ``args.lr`` is read directly
            here; the whole namespace is also forwarded to
            ``prepare_test_data``, ``prepare_train_data`` and ``build_model``.
        private_seed: Seed handed to ``init_random_seed`` right before the
            model is built. Per the original author's note, it models whether
            the defender's private randomness is known to the attacker: if it
            is known, the attacker always instantiates the DANN with the same
            private seed.

    Returns:
        None. Creates ``./results/pretrain/DANN_cifar10/`` if needed, writes
        ``ckpt.pth`` there (model state dict plus the three per-epoch accuracy
        lists), and prints the train-target accuracy list.
    """
    # NOTE(review): the previous header comment described generating attacks
    # on the corruption dataset; nothing in this function does that.
    name = 'DANN_cifar10'
    result_dir = './results/pretrain/{}'.format(name)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(result_dir, exist_ok=True)

    # Source data preparation. The target loaders alias the source loaders.
    _, test_source_loader = prepare_test_data(args)
    _, train_source_loader = prepare_train_data(args)
    train_target_loader = train_source_loader
    test_target_loader = test_source_loader

    # Seed after the loaders exist and before the model is constructed.
    init_random_seed(private_seed)

    net, _, _, _ = build_model(args)
    model = DANNWrapper(net)
    n_epoch = 150

    # setup optimizer
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    # NOTE(review): this scheduler is never stepped in this function, so the
    # [75, 125] milestones only take effect if train_one_epoch manages the
    # learning rate itself -- confirm, and add scheduler.step() per epoch if
    # it does not.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, [75, 125], gamma=0.1, last_epoch=-1)
    model = model.cuda()

    source_dataset_name = 'cifar10'
    target_dataset_name = 'cifar10'

    # DANN training: one accuracy entry per epoch in each list.
    DANN_train_target_acc = []
    DANN_test_source_acc = []
    DANN_test_target_acc = []
    for epoch in range(n_epoch):
        train_one_epoch(model, train_source_loader,
                        train_target_loader, optimizer, epoch, n_epoch)
        DANN_test_source_acc.append(test_one_epoch(
            model, test_source_loader, source_dataset_name, epoch))
        DANN_test_target_acc.append(test_one_epoch(
            model, test_target_loader, target_dataset_name, epoch))
        print("Test-Time Adaptation accuracy")
        # Evaluated on the target *training* loader.
        DANN_train_target_acc.append(test_one_epoch(
            model, train_target_loader, target_dataset_name, epoch))

    # Single checkpoint write. The original saved a model-only dict to this
    # same path first and then immediately overwrote it with this superset.
    torch.save({'model': model.state_dict(),
                'train_target_acc': DANN_train_target_acc,
                'test_source_acc':  DANN_test_source_acc,
                'test_target_acc':  DANN_test_target_acc,
                }, result_dir + '/ckpt.pth')
    print(DANN_train_target_acc)


if __name__ == '__main__':
    # Command-line options as (flag, add_argument keyword arguments), kept in
    # registration order. DANN_pretrain reads args.lr itself and forwards the
    # whole namespace to the data/model helpers.
    cli_options = [
        ('--dataset', dict(default='cifar10')),
        ('--corruption', dict(default='fog')),
        ('--level', dict(default=5, type=int)),
        ('--dataroot', dict(default='/nobackup/yguo/datasets/')),
        ('--shared', dict(default='layer2')),
        # ------------------------------------------------------------------
        ('--depth', dict(default=26, type=int)),
        ('--width', dict(default=1, type=int)),
        ('--batch_size', dict(default=128, type=int)),
        ('--group_norm', dict(default=0, type=int)),
        # store_false: these default to True and passing the flag sets False.
        ('--fix_bn', dict(action='store_false')),
        ('--fix_ssh', dict(action='store_false')),
        # ------------------------------------------------------------------
        ('--lr', dict(default=0.1, type=float)),
        ('--online', dict(action='store_true')),
        ('--threshold', dict(default=1, type=float)),
        ('--dset_size', dict(default=0, type=int)),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    DANN_pretrain(args, private_seed=140739)
