import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler
import numpy as np
from utils.utils import *
from utils import convmix

def train_src(args, featurizer, classifier, data_loader_train, data_loader_valid, ul1_data_loader_valid, ul2_data_loader_valid, test_data_loader):
    """Pre-train ``featurizer`` and ``classifier`` on the source domain, then save both.

    Optimizes cross-entropy with SGD for ``args.epoch_pre`` epochs over
    ``data_loader_train``. When ``args.plot_pre`` is set, the four evaluation
    loaders are scored with ``eval_src`` after every epoch and the resulting
    curves are passed to ``make_plot_pre``. Finally both modules are written
    with ``save_model`` under names built by ``_checkpoint_titles``.

    Args:
        args: experiment namespace. Fields read here: ``lr_pre``, ``wd_pre``,
            ``momentum``, ``schedule_pre`` ('step', 'poly', or anything else
            for a constant lr), ``power_pre`` (poly only), ``augmentation``
            ('none' or 'convmix'), ``aug_ogcat``, ``dataset``, ``epoch_pre``,
            ``plot_pre``, plus the fields used in checkpoint names.
        featurizer: module whose output is fed to ``classifier``.
        classifier: module producing class logits from the features.
        data_loader_train: yields ``(images, labels, domains)`` training
            batches; ``domains`` is ignored.
        data_loader_valid, ul1_data_loader_valid, ul2_data_loader_valid,
        test_data_loader: loaders evaluated via ``eval_src`` when
            ``args.plot_pre`` is set.

    Returns:
        ``(featurizer, classifier)``: the same objects, trained in place.

    Raises:
        ValueError: if ``args.augmentation`` is not supported, or if
            ``args.schedule_pre == 'poly'`` and ``args.dataset`` has no known
            poly-decay horizon.
    """
    featurizer.train()
    classifier.train()

    optimizer = optim.SGD([
        {'params': featurizer.parameters()},
        {'params': classifier.parameters()}
    ], lr=args.lr_pre, weight_decay=args.wd_pre, momentum=args.momentum)

    scheduler = None  # set only for the 'step' schedule
    max_iter = None   # total iterations the 'poly' schedule decays over
    if args.schedule_pre == 'step':
        # Stepped once per epoch below, so the lr drops 10x every 1000 epochs.
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.1)
    elif args.schedule_pre == 'poly':
        # Record each group's initial lr as 'lr0'; lr_scheduler() rescales from it.
        optimizer = op_copy(optimizer)
        max_iter = _poly_max_epoch(args.dataset) * len(data_loader_train)

    criterion = nn.CrossEntropyLoss()

    # Validated up front so an unsupported value fails before any training work.
    aug = _build_augmentation(args)

    src_valid_loss_list = []
    src_valid_acc_list = []
    ul1_valid_acc_list = []
    ul2_valid_acc_list = []
    test_acc_list = []

    iter_num = 0
    for _ in range(args.epoch_pre):
        for images, labels, _domains in data_loader_train:

            iter_num += 1
            if args.schedule_pre == 'poly':
                lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter, power=args.power_pre)

            images = images.cuda()
            labels = labels.cuda()

            optimizer.zero_grad()

            org_images = images
            if aug is not None:
                images = aug(images)
            if args.aug_ogcat:
                # Train on the original batch concatenated with the (possibly
                # augmented) batch; labels are duplicated to match.
                images = torch.cat([org_images, images])
                labels = torch.cat([labels, labels])

            preds = classifier(featurizer(images))
            loss = criterion(preds, labels)

            loss.backward()
            optimizer.step()

        if scheduler is not None:
            scheduler.step()

        if args.plot_pre:
            src_valid_loss, src_valid_acc = eval_src(featurizer, classifier, data_loader_valid, False)
            _, ul1_valid_acc = eval_src(featurizer, classifier, ul1_data_loader_valid, False)
            _, ul2_valid_acc = eval_src(featurizer, classifier, ul2_data_loader_valid, False)
            _, test_acc = eval_src(featurizer, classifier, test_data_loader, False)
            src_valid_loss_list.append(src_valid_loss)
            src_valid_acc_list.append(src_valid_acc)
            ul1_valid_acc_list.append(ul1_valid_acc)
            ul2_valid_acc_list.append(ul2_valid_acc)
            test_acc_list.append(test_acc)

    if args.plot_pre:
        make_plot_pre(args, src_valid_loss_list, src_valid_acc_list, ul1_valid_acc_list, ul2_valid_acc_list, test_acc_list)

    # save final model
    title_featurizer, title_classifier = _checkpoint_titles(args)
    save_model(args, featurizer, title_featurizer)
    save_model(args, classifier, title_classifier)

    return featurizer, classifier

def _build_augmentation(args):
    """Return the batch augmentation named by ``args.augmentation``, or None for 'none'."""
    if args.augmentation == 'none':
        return None
    if args.augmentation == 'convmix':
        return convmix.ConvMix(args)
    raise ValueError("Unsupported augmentation {!r}; expected 'none' or 'convmix'".format(args.augmentation))

def _poly_max_epoch(dataset):
    """Return the epoch horizon the 'poly' lr schedule decays over for ``dataset``.

    The horizon is fixed per dataset and does not depend on ``args.epoch_pre``.
    """
    horizons = {'PACS': 500, 'OfficeHome': 300, 'Digits': 2000}
    try:
        return horizons[dataset]
    except KeyError:
        raise ValueError("No poly-schedule horizon defined for dataset {!r}; expected one of {}".format(dataset, sorted(horizons))) from None

def _checkpoint_titles(args):
    """Return ``(featurizer_name, classifier_name)`` checkpoint file names for this run.

    The names encode the training configuration and are byte-for-byte identical
    to the historical format. Keep them stable: other code presumably loads
    checkpoints by these names (assumption, verify before changing).
    """
    augmented = args.augmentation != 'none'

    config = ""
    if augmented:
        config += "{}_ogcat_{}_wn_{}_".format(args.augmentation, args.aug_ogcat, args.classifier_wn)
    config += "ep{}_lr{}_sch_{}_".format(args.epoch_pre, args.lr_pre, args.schedule_pre)
    # Historical quirk: the poly power is only encoded for augmented runs.
    if augmented and args.schedule_pre == 'poly':
        config += "pw{}_".format(args.power_pre)
    config += "wd{}_{}-{}".format(args.wd_pre, args.dataset, args.order[0])

    suffix = ""
    if args.fea_norm:
        suffix += '_feanorm_temp{}'.format(args.classifier_temp)
    elif args.fea_norm2:
        suffix += '_feanorm2_temp{}'.format(args.classifier_temp)
    if args.no_bias:
        suffix += '_nobias'
    suffix += '.pt'

    return ("source-featurizer-pretrain_" + config + suffix,
            "source-classifier-pretrain_" + config + suffix)

def eval_src(featurizer, classifier, data_loader, print_acc, device='cuda'):
    """Evaluate ``classifier(featurizer(images))`` on every batch of ``data_loader``.

    Both modules are put in eval mode (Dropout/BatchNorm) for the pass and
    restored to the train/eval mode they had on entry afterwards, also if the
    forward pass raises.

    Args:
        featurizer: module whose output is fed to ``classifier``.
        classifier: module producing class logits from the features.
        data_loader: iterable of ``(images, labels, domains)`` batches that also
            supports ``len()`` (number of batches) and has ``.dataset`` (used
            for the sample count). ``domains`` is ignored.
        print_acc: if true, print the average loss and accuracy.
        device: device each batch is moved to before the forward pass
            (default ``'cuda'``, equivalent to ``Tensor.cuda()``).

    Returns:
        ``(loss, acc)``. ``loss`` is a Python float: the mean over batches of
        the per-batch mean cross-entropy. ``acc`` is a 0-dim float tensor:
        correct predictions divided by ``len(data_loader.dataset)``.

    Raises:
        ValueError: if ``data_loader`` has no batches or its dataset is empty.
    """
    num_batches = len(data_loader)
    num_samples = len(data_loader.dataset)
    if num_batches == 0 or num_samples == 0:
        raise ValueError("eval_src: data_loader is empty, cannot compute loss/accuracy")

    # Remember the callers' modes so they are restored instead of being forced to train().
    featurizer_was_training = featurizer.training
    classifier_was_training = classifier.training
    featurizer.eval()
    classifier.eval()

    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    correct = 0
    try:
        with torch.no_grad():
            for images, labels, _domains in data_loader:
                images = images.to(device)
                labels = labels.to(device)
                preds = classifier(featurizer(images))
                # criterion returns the batch mean; averaged over batches below.
                total_loss += criterion(preds, labels).item()
                correct += preds.argmax(dim=1).eq(labels).sum().item()
    finally:
        featurizer.train(featurizer_was_training)
        classifier.train(classifier_was_training)

    loss = total_loss / num_batches
    # Returned as a 0-dim float tensor, the same type this function always returned.
    acc = torch.tensor(correct / num_samples)

    if print_acc:
        print("Avg Loss = {}, Avg Accuracy = {:.2%}".format(loss, acc.item()))

    return loss, acc
    
def lr_scheduler(optimizer, iter_num, max_iter, power=0.75, gamma=10):
    """Apply inverse-polynomial learning-rate decay to every param group, in place.

    Each group's lr becomes ``lr0 * (1 + gamma * iter_num / max_iter) ** (-power)``.
    The base rate ``lr0`` must already be stored in every group (see ``op_copy``).

    Returns:
        The same ``optimizer`` object, for chaining.
    """
    base = 1 + gamma * iter_num / max_iter
    scale = base ** (-power)
    for group in optimizer.param_groups:
        group['lr'] = scale * group['lr0']
    return optimizer

def op_copy(optimizer):
    """Snapshot each param group's current learning rate as its base rate.

    Sets ``group['lr0'] = group['lr']`` for every group, overwriting any earlier
    ``'lr0'``, so that ``lr_scheduler`` can later rescale from it. Mutates
    ``optimizer`` in place and returns it.
    """
    groups = optimizer.param_groups
    for group in groups:
        group.update(lr0=group['lr'])
    return optimizer