# -*- coding: utf-8 -*-

import random
import time
import warnings
import sys
import argparse
import copy
import os
import torch
import torchvision
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
from torch.cuda.amp import *
from torch.optim import SGD
import torch.utils.data
from torch.utils.data import DataLoader
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torch.nn.functional as F
import os.path as osp
import gc
import utils
from network import ImageClassifier
import backbone as BackboneNetwork
from utils import ContinuousDataloader
from transforms import ResizeImage
from lr_scheduler import LrScheduler
from data_list import ImageList
from Loss import *

def one_hot_embedding(labels, num_classes=65):
    """Convert class indices to one-hot rows.

    Args:
        labels: Class indices used to index an identity matrix (an integer
            tensor, a list of ints, or a single int).
        num_classes: Width of each one-hot row.

    Returns:
        A float tensor whose rows are the rows of the
        ``num_classes x num_classes`` identity matrix selected by ``labels``.
        The result lives on the labels' CUDA device when ``labels`` is a CUDA
        tensor, otherwise on the current CUDA device when one is available,
        and on the CPU only when CUDA is unavailable.
    """
    if torch.is_tensor(labels) and labels.is_cuda:
        device = labels.device
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        # An unconditional .cuda() would raise here; fall back so the helper
        # stays usable on CPU-only machines.
        device = torch.device("cpu")
    # Allocate directly on the target device instead of building on the CPU
    # and copying.
    y = torch.eye(num_classes, device=device)
    return y[labels]

def get_current_time():
    """Return the current local time formatted as ``YYYY-MM-DD_HH-MM-SS``."""
    now = time.localtime(time.time())
    return time.strftime('%Y-%m-%d_%H-%M-%S', now)

def main(args: argparse.Namespace, config):
    """Train the classifier and log its accuracy on the two test loaders.

    Builds the transforms, datasets, loaders, model, optimizer and losses
    from ``args``, then runs ``args.epochs`` training epochs. After each
    epoch (for ``args.data == "domainnet"`` only from epoch 5 onwards) the
    model is evaluated on both test loaders and the best accuracy seen for
    each is tracked, printed and written to the log stream.

    Args:
        args: Parsed command-line options. ``args.class_names`` is assigned
            here from the value returned by ``utils.get_dataset``.
        config: Mapping whose ``"out_file"`` entry is a writable text stream
            that receives the accuracy log.
    """
    torch.multiprocessing.set_sharing_strategy('file_system')
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        # cuDNN autotuning can select a different algorithm from run to run,
        # which defeats the point of seeding; keep it off for seeded runs.
        cudnn.benchmark = False
    else:
        cudnn.benchmark = True

    # load data
    train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
                                                random_horizontal_flip=not args.no_hflip,
                                                random_color_jitter=False, resize_size=args.resize_size,
                                                norm_mean=args.norm_mean, norm_std=args.norm_std)
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)

    # The two validation datasets are returned by get_dataset but are not
    # used by this script; evaluation runs on the test datasets only.
    (train_source1_dataset, train_source2_dataset, train_target_dataset,
     val_dataset1, val_dataset2, test_dataset1, test_dataset2,
     num_classes, args.class_names) = utils.get_dataset(
        args.data, args.root, args.source1, args.source2, args.target, args.target1, args.target2,
        train_transform, val_transform)
    train_source1_loader = DataLoader(train_source1_dataset, batch_size=args.batch_size,
                                      shuffle=True, num_workers=args.workers, drop_last=True)
    train_source2_loader = DataLoader(train_source2_dataset, batch_size=args.batch_size,
                                      shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    test_loader1 = DataLoader(test_dataset1, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader2 = DataLoader(test_dataset2, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    train_source1_iter = ContinuousDataloader(train_source1_loader)
    train_source2_iter = ContinuousDataloader(train_source2_loader)
    train_target_iter = ContinuousDataloader(train_target_loader)

    # load model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = BackboneNetwork.__dict__[args.arch](pretrained=True)
    classifier = ImageClassifier(backbone, num_classes).cuda()

    all_parameters = classifier.get_parameters()
    optimizer = SGD(all_parameters, args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    lr_sheduler = LrScheduler(optimizer, init_lr=args.lr, gamma=0.001, decay_rate=0.75)

    # define loss function
    NWD_adv = NuclearWassersteinDiscrepancy_MMDA(classifier.head).cuda()
    edl_criterion = EDL_Loss(args)

    out_file = config["out_file"]
    best_acc1 = 0.0
    best_acc2 = 0.0
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source1_iter, train_source2_iter, train_target_iter, classifier, optimizer, NWD_adv,
              lr_sheduler, epoch, num_classes, edl_criterion, args)

        if args.data == "domainnet":
            # DomainNet runs skip evaluation during the first five epochs.
            if epoch >= 5:
                # evaluate on test set
                acc1 = validate(test_loader1, classifier)
                acc2 = validate(test_loader2, classifier)
                # Only the best accuracies are tracked; no weights are kept,
                # since nothing in this function reads or saves them.
                best_acc1 = max(acc1, best_acc1)
                best_acc2 = max(acc2, best_acc2)
                print("epoch= {:02d},  acc1={:.3f}, best_acc1 = {:.3f}".format(epoch, acc1, best_acc1))
                print("epoch= {:02d},  acc2={:.3f}, best_acc2 = {:.3f}".format(epoch, acc2, best_acc2))
                out_file.write("epoch = {:02d},  acc1 = {:.3f}, best_acc1 = {:.3f}".format(epoch, acc1, best_acc1) + '\n')
                out_file.write("epoch = {:02d},  acc2 = {:.3f}, best_acc2 = {:.3f}".format(epoch, acc2, best_acc2) + '\n')
                out_file.flush()
        else:
            # evaluate on test set
            acc1 = validate(test_loader1, classifier)
            acc2 = validate(test_loader2, classifier)
            # NOTE: checkpoint saving on improvement is not performed here.
            best_acc1 = max(acc1, best_acc1)
            best_acc2 = max(acc2, best_acc2)
            print("epoch = {:02d},  acc1={:.3f}, best_acc1 = {:.3f}".format(epoch, acc1, best_acc1))
            print("epoch = {:02d},  acc2={:.3f}, best_acc2 = {:.3f}".format(epoch, acc2, best_acc2))
            print("best_acc = {:.3f}".format((best_acc1+best_acc2)/2))
            # The first value on each line is the current epoch's accuracy,
            # so it is labelled acc1/acc2 (same format as the DomainNet log).
            out_file.write("epoch = {:02d},  acc1 = {:.3f}, best_acc1 = {:.3f}".format(epoch, acc1, best_acc1) + '\n')
            out_file.write("epoch = {:02d},  acc2 = {:.3f}, best_acc2 = {:.3f}".format(epoch, acc2, best_acc2) + '\n')
            out_file.write("epoch = {:02d},  best_acc = {:.3f}".format(epoch, ((best_acc1+best_acc2)/2)) + '\n')
            out_file.flush()

    # final summary
    print("best_acc1 = {:.3f}".format(best_acc1))
    print("best_acc2 = {:.3f}".format(best_acc2))
    print("best_acc = {:.3f}".format((best_acc1+best_acc2)/2))
    out_file.write("best_acc1 = {:.3f}".format(best_acc1) + '\n')
    out_file.write("best_acc2 = {:.3f}".format(best_acc2) + '\n')
    out_file.write("best_acc = {:.3f}".format(((best_acc1+best_acc2)/2)) + '\n')
    out_file.flush()

def train(train_source1_iter: ContinuousDataloader, train_source2_iter: ContinuousDataloader, train_target_iter: ContinuousDataloader, model: ImageClassifier,
        optimizer: SGD, NWD_adv, lr_sheduler: LrScheduler, epoch: int, num_classes: int, edl_criterion, args: argparse.Namespace):
    """Run ``args.iters_per_epoch`` mixed-precision training steps.

    Each step draws one batch from both source iterators and the target
    iterator, forwards them through ``model`` as a single concatenated batch,
    and minimises::

        edl_loss - nwd_tradeoff * rho * loss_nwd - MI_tradeoff * (MI_item1 - MI_item2)

    where ``rho`` is the fraction of total training iterations completed.

    Args:
        train_source1_iter: Iterator yielding batches for the first source.
        train_source2_iter: Iterator yielding batches for the second source.
        train_target_iter: Iterator yielding target batches (labels unused).
        model: Network called as ``model(x)``, returning ``(logits, features)``.
        optimizer: Optimizer stepped through the gradient scaler.
        NWD_adv: Module called on the features to produce ``loss_nwd``.
        lr_sheduler: Scheduler stepped once per iteration.
        epoch: Zero-based index of the current epoch.
        num_classes: Divisor applied to both KL terms.
        edl_criterion: Callable returning two loss terms for a labelled batch.
        args: Reads ``iters_per_epoch``, ``epochs``, ``nwd_tradeoff``,
            ``MI_tradeoff`` and ``print_freq``.
    """
    # NOTE(review): a new GradScaler is created on every call, so the loss
    # scale does not carry over between epochs — confirm this is intended.
    scaler = GradScaler(enabled=True)

    # switch to train mode
    model.train()
    NWD_adv.train()
    max_iters = args.iters_per_epoch * args.epochs
    for i in range(args.iters_per_epoch):
        current_iter = i + args.iters_per_epoch * epoch
        # Training progress in [0, 1); ramps up the weight of the NWD term.
        rho = current_iter / max_iters
        lr_sheduler.step()

        x_s1, labels_s1 = next(train_source1_iter)[:2]
        x_s2, labels_s2 = next(train_source2_iter)[:2]
        x_t, _ = next(train_target_iter)[:2]

        x_s1 = x_s1.cuda()
        x_s2 = x_s2.cuda()
        x_t = x_t.cuda()
        labels_s1 = labels_s1.cuda()
        labels_s2 = labels_s2.cuda()

        optimizer.zero_grad()
        with autocast(enabled=True):
            # get features and logit outputs
            x = torch.cat((x_s1, x_s2, x_t), dim=0)
            y, f = model(x)
            y_s1, y_s2, y_t = y.chunk(3, dim=0)

            # compute loss
            # The trailing 40 is forwarded unchanged to the criterion —
            # presumably an annealing length; confirm against EDL_Loss.
            loss_cls_s1, loss_kl_s1 = edl_criterion(y_s1, labels_s1, epoch+1, 40)
            loss_kl_s1 = loss_kl_s1/num_classes
            loss_cls_s2, loss_kl_s2 = edl_criterion(y_s2, labels_s2, epoch+1, 40)
            loss_kl_s2 = loss_kl_s2/num_classes
            edl_loss = (loss_cls_s1 + loss_kl_s1 + loss_cls_s2 + loss_kl_s2)
            loss_nwd = NWD_adv(f)
            MI_item1, MI_item2 = MI(y_t)

            # The parentheses are required: without them the MI line is a
            # separate expression statement whose value is thrown away, so
            # the MI term would never contribute to the loss.
            total_loss = (
                edl_loss
                - args.nwd_tradeoff * rho * loss_nwd
                - args.MI_tradeoff * (MI_item1 - MI_item2)
            )

        # compute gradient and do SGD step (scaled for mixed precision)
        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # print training log; "cls_loss" reports edl_loss, which includes
        # the KL terms
        if i % args.print_freq == 0:
            print("Epoch: [{:02d}][{}/{}]	total_loss:{:.3f}	cls_loss:{:.3f}	  nwd_loss:{:.3f}  kl_loss:{:.3f}".format(
                epoch, i, args.iters_per_epoch, total_loss, edl_loss, loss_nwd, loss_kl_s1 + loss_kl_s2))

def validate(val_loader: DataLoader, model: "ImageClassifier") -> float:
    """Return the top-1 accuracy of ``model`` over ``val_loader`` in percent.

    Args:
        val_loader: Iterable of batches whose first two items are the input
            tensor and the integer label tensor.
        model: Module called as ``model(images)`` that returns a pair whose
            first element holds the per-class scores.

    Returns:
        Accuracy in the range ``[0, 100]``.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    # switch to evaluate mode
    model.eval()

    # Send inputs to the device the model lives on instead of assuming the
    # current CUDA device; a model without parameters falls back to CUDA when
    # available.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Collect per-batch results and concatenate once at the end; growing a
    # tensor with torch.cat inside the loop re-copies it on every batch.
    outputs = []
    labels = []
    with torch.no_grad():
        for data in val_loader:
            images, target = data[:2]
            images = images.to(device)
            target = target.to(device)
            # get logit outputs
            output, _ = model(images)
            outputs.append(output.float())
            labels.append(target.float())

    if not outputs:
        raise ValueError("validate() received an empty loader; accuracy is undefined")

    all_output = torch.cat(outputs, 0)
    all_label = torch.cat(labels, 0)
    _, predict = torch.max(all_output, 1)
    accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
    accuracy = accuracy * 100.0
    print(' accuracy:{:.3f}'.format(accuracy))
    return accuracy

# Script entry point: parse the command line, open a timestamped log file in
# the output directory, record the run configuration, then start training.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Style Adaptation and Uncertainty Estimation for Multi-source Blended-target Domain Adaptation')
    parser.add_argument('root', metavar='DIR', help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='OfficeHome', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: OfficeHome)')
    parser.add_argument('-s1', '--source1', help='source1 domain(s)', nargs='+')
    parser.add_argument('-s2', '--source2', help='source2 domain(s)', nargs='+')
    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
    parser.add_argument('-t1', '--target1', help='target1 domain(s)', nargs='+')
    parser.add_argument('-t2', '--target2', help='target2 domain(s)', nargs='+')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    parser.add_argument('--resize-size', type=int, default=224,
                        help='the image size after resizing')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.08 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    # store_true so the flag is False unless given; main() passes
    # `not args.no_hflip`, i.e. flipping is on by default and the flag
    # disables it, as the help text says.
    parser.add_argument('--no-hflip', action='store_true',
                        help='no random horizontal flipping during training')
    parser.add_argument('--norm-mean', type=float, nargs='+',
                        default=(0.485, 0.456, 0.406), help='normalization mean')
    parser.add_argument('--norm-std', type=float, nargs='+',
                        default=(0.229, 0.224, 0.225), help='normalization std')
    parser.add_argument('--bottleneck-dim', default=1024, type=int)
    parser.add_argument('--arch', type=str, default='resnet50', choices=['resnet50', 'resnet101'])
    # const covers a bare `--gpu_id` with no value, which would otherwise
    # yield None and fail when assigned to CUDA_VISIBLE_DEVICES.
    parser.add_argument('--gpu_id', type=str, nargs='?', const='0', default='0', help="device id to run")
    parser.add_argument('--output_dir', type=str, default='log/SAUE/office31', help="output directory of logs")
    parser.add_argument('--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=40, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('--iters-per-epoch', default=500, type=int, help='Number of iterations per epoch')
    parser.add_argument('--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--batch-size', default=32, type=int, metavar='N', help='mini-batch size (default: 32)')
    parser.add_argument('--lr', default=0.01, type=float, metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('--weight-decay', default=1e-3, type=float, metavar='W', help='weight decay (default: 1e-3)', dest='weight_decay')
    parser.add_argument('--seed', default=2024, type=int, help='seed for initializing training. ')
    parser.add_argument('--nwd_tradeoff', type=float, default=1.0, help="hyper-parameter: gamma_d")
    parser.add_argument('--trade_off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    parser.add_argument('--MI_tradeoff', type=float, default=0.1, help="hyper-parameter: beta")
    parser.add_argument('--temp', type=float, default=10.0, help="temperature scaling parameter")

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    log_path = osp.join(args.output_dir, get_current_time() + "_" + "_log.txt")

    import PIL
    # The context manager closes the log even if training raises.
    with open(log_path, "w") as out_file:
        config = {"out_file": out_file}

        out_file.write("train_SAUE.py\n")
        out_file.write("PIL version: {}\ntorch version: {}\ntorchvision version: {}\n".format(PIL.__version__, torch.__version__, torchvision.__version__))
        # Echo every parsed option to both the console and the log.
        for arg in vars(args):
            print("{} = {}".format(arg, getattr(args, arg)))
            out_file.write(str("{} = {}".format(arg, getattr(args, arg))) + "\n")
        out_file.flush()
        main(args, config)
