from __future__ import print_function, absolute_import
import argparse
import os.path as osp
import random
import numpy as np
import sys
import time

import torch
from torch.backends import cudnn
from torch.utils.data import DataLoader
from adnt import datasets
from adnt import models
from adnt.adnt_trainer import Trainer
from adnt.evaluators import extract_features_labels
from adnt.utils.data import IterLoader
from adnt.utils.data import transforms as T
from adnt.utils.data.preprocessor import Preprocessor
from adnt.utils.logging import Logger
from adnt.utils.serialization import save_checkpoint
from adnt.utils.lr_scheduler import WarmupMultiStepLR


def get_data(name, data_dir, height, width, batch_size, workers,
             subset, target=False, iters=200,
             get_fname=False):
    """Create a dataset subset and a matching data loader.

    Args:
        name: dataset name passed to ``datasets.create``; also the
            sub-directory of ``data_dir`` that holds the data.
        data_dir: directory containing all datasets.
        height, width: spatial size images are resized (and cropped) to.
        batch_size: mini-batch size.
        workers: number of DataLoader worker processes.
        subset: subset/domain identifier forwarded to ``datasets.create``.
        target: False (default) builds an augmented, shuffled loader wrapped in
            ``IterLoader``; True builds a plain, ordered loader without
            augmentation.
        iters: ``IterLoader`` length; only used when ``target`` is False.
        get_fname: forwarded to ``Preprocessor``; only used when ``target`` is
            True.

    Returns:
        ``(dataset, data_loader)``.
    """
    dataset = datasets.create(name, osp.join(data_dir, name), subset=subset)
    samples = sorted(dataset.train)
    normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])

    if target:
        # Deterministic pipeline: resize + normalize only, samples in order.
        transformer = T.Compose([
            T.Resize((height, width), interpolation=3),
            T.ToTensor(),
            normalizer,
        ])
        loader = DataLoader(
            Preprocessor(samples, root=dataset.root, transform=transformer,
                         get_fname=get_fname),
            batch_size=batch_size,
            num_workers=workers,
            shuffle=False,
            pin_memory=True,
        )
        return dataset, loader

    # Augmented pipeline: random flip + pad/random-crop, shuffled, and the
    # incomplete final batch dropped.
    transformer = T.Compose([
        T.Resize((height, width), interpolation=3),
        T.RandomHorizontalFlip(p=0.5),
        T.Pad(10),
        T.RandomCrop((height, width)),
        T.ToTensor(),
        normalizer,
    ])
    base_loader = DataLoader(
        Preprocessor(samples, root=dataset.root, transform=transformer),
        batch_size=batch_size,
        num_workers=workers,
        sampler=None,
        shuffle=True,
        pin_memory=True,
        drop_last=True,
    )
    return dataset, IterLoader(base_loader, length=iters)


def get_train_loader(dataset, height, width, batch_size, workers,
                     iters=200, train_set=None, soft=False):
    """Build an augmented, shuffled training loader wrapped in ``IterLoader``.

    Args:
        dataset: dataset object exposing ``train`` and ``root``.
        height, width: spatial size images are resized / cropped to.
        batch_size: mini-batch size.
        workers: number of DataLoader worker processes.
        iters: ``IterLoader`` length.
        train_set: optional explicit list of samples; defaults to
            ``dataset.train`` when None.
        soft: forwarded to ``Preprocessor`` as ``get_soft``.

    Returns:
        An ``IterLoader`` over a shuffled ``DataLoader`` with ``drop_last=True``.
    """
    normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    train_transformer = T.Compose([
        T.Resize((height, width), interpolation=3),
        T.RandomHorizontalFlip(p=0.5),
        T.Pad(10),
        T.RandomCrop((height, width)),
        T.ToTensor(),
        normalizer,
        # Random erasing runs last, i.e. on the already-normalized tensor.
        T.RandomErasing(probability=0.5, mean=[0.485, 0.456, 0.406]),
    ])

    samples = sorted(dataset.train if train_set is None else train_set)

    train_loader = IterLoader(
        DataLoader(Preprocessor(samples, root=dataset.root,
                                transform=train_transformer,
                                mutual=False, get_soft=soft),
                   batch_size=batch_size, num_workers=workers, sampler=None,
                   shuffle=True, pin_memory=True, drop_last=True),
        length=iters)

    return train_loader



def main():
    """Parse command-line arguments, optionally seed all RNGs, then train.

    Uses the module-level ``parser`` created in the ``__main__`` block, so this
    function is only callable when the file is run as a script.
    """
    args = parser.parse_args()

    seed = args.seed
    if seed is not None:
        # Seed Python, NumPy and PyTorch (CPU and CUDA) and request
        # deterministic cuDNN kernels without autotuning.
        for seed_fn in (random.seed, np.random.seed,
                        torch.manual_seed, torch.cuda.manual_seed):
            seed_fn(seed)
        cudnn.deterministic = True
        cudnn.benchmark = False

    main_worker(args)


def obtain_cluster_centers(model, data_loaders):
    """Compute one mean feature vector per label.

    Args:
        model: network passed to ``extract_features_labels``.
        data_loaders: loader(s) passed to ``extract_features_labels``.

    Returns:
        ``(cluster_classifier, cluster_numbers)``, where ``cluster_classifier``
        concatenates the per-label means (``keepdim=True``) along dim 0 in
        iteration order, and ``cluster_numbers`` maps each label to the number
        of feature tensors collected for it.
    """
    # NOTE(review): assumes extract_features_labels yields (label, list of
    # feature tensors) pairs. If it returned a plain dict, iterating it would
    # yield keys only — confirm against its implementation.
    grouped = extract_features_labels(model, data_loaders)

    centers = []
    counts = {}
    for lbl, feats in grouped:
        counts[lbl] = len(feats)
        stacked = torch.cat(feats, dim=0)
        centers.append(stacked.mean(dim=0, keepdim=True))

    return torch.cat(centers, dim=0), counts


def main_worker(args):
    """Build loaders, model, optimizer and memory bank, then train and evaluate.

    Args:
        args: ``argparse.Namespace`` from the parser in the ``__main__`` block.
            ``start_epoch`` (default 0) and ``merge_features`` (default False)
            are optional; every other attribute read below must be present.

    Side effects:
        Replaces ``sys.stdout`` with a ``Logger`` writing to
        ``<logs_dir>/log.txt``, and saves a checkpoint to ``logs_dir`` every
        10 epochs.
    """
    # The parser is not guaranteed to define these two options; fall back to
    # defaults instead of failing with AttributeError before training starts.
    start_epoch = getattr(args, 'start_epoch', 0)
    merge_features = getattr(args, 'merge_features', False)
    start_time = time.time()

    # main() turns cuDNN autotuning off when a seed is given, for
    # reproducibility. Only enable it for unseeded runs so that setting is not
    # silently undone here.
    if args.seed is None:
        cudnn.benchmark = True

    sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    # Target domain: an ordered, un-augmented loader for evaluation and an
    # augmented IterLoader for training. get_data already builds the dataset,
    # so reuse it rather than constructing it a second time.
    dataset, test_loader = get_data(args.dataset, args.data_dir, args.height,
                                    args.width, args.batch_size, args.workers,
                                    subset=args.dataset_target, target=True)
    # Forward --iters; previously it was ignored and get_data's default of
    # 200 iterations per epoch was always used.
    _, train_loader = get_data(args.dataset, args.data_dir, args.height,
                               args.width, args.batch_size, args.workers,
                               subset=args.dataset_target, target=False,
                               iters=args.iters)

    # Per-domain loaders for the trainer, with the target domain always last.
    # Both lists are defined on every path: previously `testloaders` existed
    # only with --use-source, and Trainer(...) below raised NameError otherwise.
    # NOTE(review): assumes Trainer accepts a single-element list when no
    # source domains are used — confirm.
    dataloaders = [train_loader]
    testloaders = [test_loader]
    if args.use_source:
        source_train_loaders = []
        source_test_loaders = []
        for subset in args.dataset_sources:
            _, source_loader = get_data(args.dataset, args.data_dir, args.height,
                                        args.width, args.batch_size, args.workers,
                                        subset=subset, target=False,
                                        iters=args.iters)
            _, source_test_loader = get_data(args.dataset, args.data_dir, args.height,
                                             args.width, args.batch_size, args.workers,
                                             subset=subset, target=True)
            source_train_loaders.append(source_loader)
            source_test_loaders.append(source_test_loader)
        dataloaders = source_train_loaders + dataloaders
        testloaders = source_test_loaders + testloaders

    num_classes = dataset.num_classes
    print('----------')
    print(num_classes)
    print('----------')

    # Create model
    model = models.create(args.arch, num_classes=num_classes,
                          num_domains=len(args.dataset_sources) + 1)
    model.cuda()

    # Optimizer: one parameter group per trainable tensor.
    # NOTE(review): parameters whose name contains 'classifier' get 0.1x the
    # base LR, whereas the --lr help text says the pretrained parameters are
    # the ones scaled down — confirm which is intended.
    params = []
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue
        lr_scale = 0.1 if 'classifier' in key else 1
        params.append({"params": [value], "lr": args.lr * lr_scale,
                       "weight_decay": args.weight_decay})

    optimizer = torch.optim.Adam(params)

    print('---Getting Cluster Centers---')
    cluster_centers, cluster_numbers = obtain_cluster_centers(model, test_loader)
    memory = models.create('MemoryBank', feature_centers=cluster_centers,
                           cluster_numbers=cluster_numbers)

    trainer = Trainer(model, memory=memory, num_classes=num_classes, dataloaders=testloaders)

    best_accu = 0
    best_epoch = 0
    lr_scheduler = WarmupMultiStepLR(optimizer, milestones=args.milestones, gamma=args.gamma,
                                     warmup_factor=args.warmup_factor, warmup_iters=args.warmup_step)

    for epoch in range(start_epoch, args.epochs):
        # NOTE(review): the scheduler is stepped at the start of each epoch,
        # before any optimizer.step(); PyTorch >= 1.1 recommends the reverse
        # order. Kept as-is because changing it shifts the LR schedule.
        lr_scheduler.step()
        if args.use_source:
            for loader in dataloaders:
                loader.new_epoch()
            trainer.train(epoch, dataloaders, optimizer=optimizer,
                          print_freq=args.print_freq, train_iters=len(dataloaders[0]),
                          merge=merge_features)
        else:
            train_loader.new_epoch()
            trainer.train(epoch, train_loader, optimizer=optimizer,
                          print_freq=args.print_freq, train_iters=len(train_loader))

        prec_accu = trainer.test(test_loader)
        prec1, prec2, prec3, prec4, prec5 = [p.item() for p in prec_accu]
        is_best = prec1 > best_accu
        best_accu = max(prec1, best_accu)
        best_epoch = epoch if is_best else best_epoch

        print('============ Test on target datasets {} ============\n'
              'Epoch: [{}]\t'
              'Top1 Prec {:.2%}\t '
              'Top2 Prec {:.2%}\t '
              'Top3 Prec {:.2%}\t '
              'Top4 Prec {:.2%}\t '
              'Top5 Prec {:.2%}\t '
              'Best Top1 {:.2%}\t '
              'Best Epoch [{}]\n'
              '========================================================'
              .format(args.dataset_target, epoch,
                      prec1, prec2, prec3, prec4,
                      prec5, best_accu, best_epoch)
              )

        if (epoch + 1) % 10 == 0:
            save_checkpoint({
                'state_dict': model.state_dict(),
                'epoch': epoch + 1,
                'Prec': prec_accu
            },
                is_best=False,
                fpath=osp.join(args.logs_dir, str(epoch + 1) + '_checkpoint.pth.tar'))

    print('✿✿ヽ(°▽°)ノ✿ (*^▽^*)----Finish Training----(*^▽^*) ✿ヽ(°▽°)ノ✿✿')
    total_time = time.time() - start_time

    # Whole-second breakdown. The seconds were previously computed with
    # `% 3500` (typo) and could also round up to "60s".
    hours, remainder = divmod(int(total_time), 3600)
    minutes, seconds = divmod(remainder, 60)
    print('Training Spend Time {:.0f}h {:.0f}min {:.0f}s'
          .format(hours, minutes, seconds))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="ADNT Training")
    # data
    parser.add_argument('--use-source', action='store_true')
    parser.add_argument('--dataset', type=str, default='office_home')
    # nargs='+' makes `-ds CL PR` produce a list. With a bare type=str, a
    # command-line value was one string, and the training code iterated over
    # its characters.
    # NOTE(review): `choices=datasets.names()` is applied to subset codes such
    # as 'CL' / 'AR' — confirm datasets.names() actually lists them.
    parser.add_argument('-ds', '--dataset-sources', type=str, nargs='+',
                        default=['CL', 'PR', 'RW'], choices=datasets.names())
    parser.add_argument('-dt', '--dataset-target', type=str, default='AR',
                        choices=datasets.names())
    parser.add_argument('-b', '--batch-size', type=int, default=64)
    parser.add_argument('-j', '--workers', type=int, default=4)
    parser.add_argument('--height', type=int, default=224,
                        help="input height")
    parser.add_argument('--width', type=int, default=224,
                        help="input width")
    # model
    parser.add_argument('-a', '--arch', type=str, default='resnet50',
                        choices=models.names())
    # optimizer
    parser.add_argument('--warmup-step', type=int, default=5)
    parser.add_argument('--gamma', type=float, default=0.1)
    parser.add_argument('--milestones', nargs='+', type=int, default=[10, 15])
    parser.add_argument('--warmup-factor', type=float, default=0.01)
    # NOTE(review): the training code scales 'classifier' parameters by 0.1,
    # which contradicts this help text — confirm which one is intended.
    parser.add_argument('--lr', type=float, default=0.00035,
                        help="learning rate of new parameters, for pretrained "
                             "parameters it is 10 times smaller than this")
    # NOTE(review): --momentum, --alpha and --checkpoint are not read anywhere
    # in this file.
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--alpha', type=float, default=0.999)
    parser.add_argument('--weight-decay', type=float, default=5e-4)
    parser.add_argument('--epochs', type=int, default=40)
    parser.add_argument('--iters', type=int, default=400)
    # training configs
    parser.add_argument('--checkpoint', type=str, default='', metavar='PATH')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--print-freq', type=int, default=10)
    # The training loop reads args.start_epoch and args.merge_features, which
    # were previously never defined and caused an AttributeError.
    parser.add_argument('--start-epoch', type=int, default=0,
                        help="first epoch index of the training loop")
    parser.add_argument('--merge-features', action='store_true',
                        help="forwarded to Trainer.train(merge=...) when "
                             "--use-source is set")
    # path
    working_dir = osp.dirname(osp.abspath(__file__))
    parser.add_argument('--data-dir', type=str, metavar='PATH',
                        default=osp.join('/root/autodl-tmp', 'datasets'))
    parser.add_argument('--logs-dir', type=str, metavar='PATH',
                        default=osp.join(working_dir, 'logs'))

    main()

