from __future__ import print_function
import argparse
from tqdm import tqdm
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import sys
sys.path.append('./')

from utils.misc import *
from utils.test_helpers import *
from utils.prepare_dataset import *
from utils.rotation import *
from utils.prepare_attack_dataset import *
from utils.DANN_model import DANNWrapper

def train_one_epoch(model, dataloader_source, dataloader_target, optimizer, epoch, n_epoch, alpha_scale = 1):
    """Train the DANN objective for one epoch on paired source/target batches.

    Every step concatenates one source batch with one target batch, runs a
    single forward pass and minimises the sum of three cross-entropy terms:
    source classification, source domain (label 0) and target domain
    (label 1).

    Args:
        model: Module called as ``model(input_data=..., alpha=...)`` and
            returning ``(class_output, domain_output)``. It is not moved by
            this function, so it must already be on the device selected
            here (CUDA when available, otherwise CPU).
        dataloader_source: Sized iterable yielding ``(images, labels)``.
        dataloader_target: Sized iterable yielding ``(images, _)``; the
            second element of each batch is ignored.
        optimizer: Optimizer that updates the parameters of ``model``.
        epoch: Zero-based index of the current epoch.
        n_epoch: Total number of epochs, used to compute training progress.
        alpha_scale: Multiplier applied to the scheduled ``alpha`` before it
            is handed to the model.

    Returns:
        None. Progress is printed every 100 steps.
    """
    model.train()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    loss_class = torch.nn.CrossEntropyLoss().to(device)
    loss_domain = torch.nn.CrossEntropyLoss().to(device)

    # One step consumes a batch from each loader, so the shorter one bounds the epoch.
    len_dataloader = min(len(dataloader_source), len(dataloader_target))
    data_source_iter = iter(dataloader_source)
    data_target_iter = iter(dataloader_target)

    for i in range(len_dataloader):
        # p is the fraction of the whole training run completed so far;
        # alpha ramps smoothly from 0 towards 1 as p goes from 0 to 1.
        p = float(i + epoch * len_dataloader) / n_epoch / len_dataloader
        alpha = 2. / (1. + np.exp(-10 * p)) - 1

        # The built-in next() works for any iterator; the ``.next()`` method
        # only existed for Python 2 compatibility.
        s_img, s_label = next(data_source_iter)
        t_img, _ = next(data_target_iter)

        s_batch_size = s_img.shape[0]
        t_batch_size = t_img.shape[0]

        s_img = s_img.to(device)
        s_label = s_label.to(device)
        t_img = t_img.to(device)

        # Domain labels: 0 marks source samples, 1 marks target samples.
        s_domain_label = torch.zeros(s_batch_size, dtype=torch.long, device=device)
        t_domain_label = torch.ones(t_batch_size, dtype=torch.long, device=device)

        # Single forward pass over both domains; outputs are split by position.
        cat_img = torch.cat((s_img, t_img), 0)
        class_output, domain_output = model(input_data=cat_img, alpha=alpha*alpha_scale)

        s_class_output = class_output[:s_batch_size]
        s_domain_output = domain_output[:s_batch_size]
        t_domain_output = domain_output[s_batch_size:]

        err_s_label = loss_class(s_class_output, s_label)
        err_s_domain = loss_domain(s_domain_output, s_domain_label)
        err_t_domain = loss_domain(t_domain_output, t_domain_label)

        err = err_t_domain + err_s_domain + err_s_label

        optimizer.zero_grad()
        err.backward()
        optimizer.step()

        step = i + 1
        if step % 100 == 0:
            print('epoch: %d, [iter: %d / all %d], err_s_label: %f, err_s_domain: %f, err_t_domain: %f' \
                % (epoch, step, len_dataloader, err_s_label.item(),
                    err_s_domain.item(), err_t_domain.item()))

def test_one_epoch(model, dataloader, dataset_name, epoch):
    """Measure the classification accuracy of a DANN model on one dataset.

    The model is switched to eval mode and moved to CUDA when available
    (otherwise CPU); it is left in that state on return. The forward pass is
    run with ``alpha=0`` and without gradient tracking, and only the class
    output is used.

    Args:
        model: Module called as ``model(input_data=..., alpha=...)`` and
            returning ``(class_output, domain_output)``.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        dataset_name: Name shown in the printed report.
        epoch: Epoch number shown in the printed report (display only).

    Returns:
        Fraction of correctly classified samples as a float, or ``0.0`` when
        the dataloader yields no samples.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.eval()
    model = model.to(device)

    n_total = 0
    n_correct = 0

    # Evaluation needs no autograd graph; skipping it saves memory and time.
    with torch.no_grad():
        for t_img, t_label in dataloader:
            t_img = t_img.to(device)
            t_label = t_label.to(device)

            class_output, _ = model(input_data=t_img, alpha=0)
            pred = class_output.argmax(dim=1, keepdim=True)
            n_correct += pred.eq(t_label.view_as(pred)).sum().item()
            n_total += t_img.shape[0]

    # Guard against an empty dataloader instead of dividing by zero.
    accu = float(n_correct) / n_total if n_total else 0.0

    print('epoch: %d, accuracy of the %s dataset: %f'%(epoch, dataset_name, accu))
    return accu


# Despite its name this is an evaluation routine, not a pytest test case;
# opt out of test collection wherever the function gets imported.
test_one_epoch.__test__ = False

if __name__ == "__main__":
    # Command-line options; the parsed namespace is handed to build_model and
    # to the source-dataset preparation helpers.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='cifar10')
    parser.add_argument('--dataroot', default='/nobackup/yguo/datasets')
    parser.add_argument('--depth', default=26, type=int)
    parser.add_argument('--width', default=1, type=int)
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--group_norm', default=0, type=int)
    parser.add_argument('--shared', default='layer2')
    args = parser.parse_args()

    # Backbone network wrapped for domain-adversarial training.
    net, _, _, _ = build_model(args)
    model = DANNWrapper(net)

    # Source-domain loaders.
    _, test_source_loader = prepare_test_data(args)
    _, train_source_loader = prepare_train_data(args)

    # Target-domain loaders, read from the stored attack_data .npy files.
    target_train_data = ADVDataset('attack_data/prTTT_pgd8/train.npy')
    target_test_data = ADVDataset('attack_data/prTTT_pgd8/test.npy')
    train_target_loader = torch.utils.data.DataLoader(
        dataset=target_train_data,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=8)
    test_target_loader = torch.utils.data.DataLoader(
        dataset=target_test_data,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=8)

    # NOTE(review): the seed is set after the model has been built, so it
    # does not influence anything build_model initialised — confirm intended.
    init_random_seed(0)

    n_epoch = 100

    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    # NOTE(review): with n_epoch = 100 only the first milestone (75) is ever reached.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, [75, 125], gamma=0.1, last_epoch=-1)

    model = model.cuda()

    source_dataset_name = 'cifar10'
    target_dataset_name = 'cifar10-pgd8'

    # DANN training: one training pass, then accuracy on the source test set,
    # the target test set and the target training set.
    for epoch in range(n_epoch):
        train_one_epoch(model, train_source_loader, train_target_loader, optimizer, epoch, n_epoch)
        scheduler.step()
        test_one_epoch(model, test_source_loader, source_dataset_name, epoch)
        test_one_epoch(model, test_target_loader, target_dataset_name, epoch)
        print("Test-Time Adaptation accuracy")
        test_one_epoch(model, train_target_loader, target_dataset_name, epoch)

         
