import os
import random
import sys
import argparse
import os.path as osp

import optuna
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from common.vision.transforms import MultipleApply
from common.utils import logging
from common.utils.data import ForeverDataIterator
from common.utils.analysis import tsne
from dalib.adaptation.cdan import ImageClassifier
from dalib.modules.domain_discriminator import DomainDiscriminator
from dalib.adaptation.cdan import ConditionalDomainAdversarialLoss
from dalib.adaptation.dann import DomainAdversarialLoss
from dalib.adaptation.mcc import MinimumClassConfusionLoss
from dalib.adaptation.proto import ProtoLoss
from dalib.adaptation.self_ensemble import L2ConsistencyLoss, CEConsistencyLoss
from dalib.adaptation.mdd import MarginDisparityDiscrepancy, ClassifierHead

from utils.common_utils import save_list
import scripts.utils as utils
from utils.eval_utils import construct_metric_dataloader, ClassConcatDataset, ClassBatchSampler
from evaluation import image_classification_test, compute_metrics


# Share tensors between processes through the file system.
# NOTE(review): presumably chosen to avoid file-descriptor exhaustion with many
# DataLoader workers - confirm before changing.
torch.multiprocessing.set_sharing_strategy('file_system')

# Module-level logger obtained from the project's logging helper.
logger = logging.get_logger(__name__)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


def main(args, opt_param_paser=None):
    """Train a domain-adaptation model and compute metrics after every epoch.

    The function seeds the RNGs, builds the datasets/loaders, the classifier,
    the (method-dependent) discriminator and training loss, optionally
    pretrains the classifier on the source domain, then alternates one
    training epoch with an evaluation pass for ``args.epochs`` epochs.

    Args:
        args: Parsed argument namespace (see ``get_args``). It is mutated
            here: ``args.num_classes`` and ``args.class_names`` are filled in
            from the dataset, ``args.metrics`` is replaced by a reduced list
            when ``args.log`` contains ``'search'``, and ``args.pretrain`` is
            set to the checkpoint path when pretraining is run here.
        opt_param_paser: Optional hyper-parameter-search helper. Only its
            ``trial`` (``report`` / ``should_prune``) and ``opt_metric``
            attributes are used. When given, the run can be pruned.

    Returns:
        A list with one metric dict per completed epoch. The same list is
        written to ``<args.log>/metric_scores.xlsx`` via ``save_list``.

    Raises:
        NotImplementedError: If ``args.method`` is not a supported method.
        optuna.exceptions.TrialPruned: If ``opt_param_paser`` is given and
            either ``compute_metrics`` raises ``ValueError`` or the trial
            asks to be pruned.
    """
    import importlib
    import pickle

    os.makedirs(args.log, exist_ok=True)
    logging.setup_logging(args.log)

    method = args.method.lower()
    is_search = 'search' in args.log
    if is_search:
        # Hyper-parameter search runs only compute this reduced metric set.
        args.metrics = ['accuracy', 'test_loss', 'source_accuracy', 'ACM']

    # Seed every RNG in use and make cuDNN deterministic.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    cudnn.deterministic = True
    cudnn.benchmark = False  # True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=not args.no_hflip,
                                                random_color_jitter=False, resize_size=args.resize_size,
                                                norm_mean=args.norm_mean, norm_std=args.norm_std)
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)

    # The "consist" method passes one extra transform to get_dataset. Build
    # the transform list first so the datasets are only constructed once
    # (the previous code built them twice and threw the first result away).
    dataset_transforms = [train_transform, val_transform]
    if method == "consist":
        import torchvision.transforms as T
        from common.vision.transforms import ResizeImage
        cosist_transform = T.Compose([
            ResizeImage(256),
            T.RandomResizedCrop(224),
            T.RandomHorizontalFlip(),
            T.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0),
            utils.GaussBlur(img_blur=0.5),
            T.ToTensor(),
            T.Normalize(mean=args.norm_mean, std=args.norm_std)
        ])
        dataset_transforms.append(MultipleApply([cosist_transform, val_transform]))
    (train_source_dataset, train_target_dataset, test_source_dataset, test_target_dataset,
     args.num_classes, args.class_names) = utils.get_dataset(
        args.data, args.root, args.source, args.target, *dataset_transforms)

    dset_loaders = {}
    dset_loaders["source"] = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                        shuffle=True, num_workers=args.workers, drop_last=True)
    dset_loaders["target"] = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                        shuffle=True, num_workers=args.workers, drop_last=True)
    dset_loaders["source_test"] = DataLoader(test_source_dataset, batch_size=args.batch_size, shuffle=False,
                                             num_workers=args.workers, drop_last=False)
    dset_loaders["target_test"] = DataLoader(test_target_dataset, batch_size=args.batch_size, shuffle=False,
                                             num_workers=args.workers, drop_last=False)

    if 'ACM' in args.metrics:
        # The ACM metric needs an augmented view of the target test set.
        cosist_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                                     random_color_jitter=True, random_gaussblur=True,
                                                     resize_size=args.resize_size,
                                                     norm_mean=args.norm_mean, norm_std=args.norm_std)
        _, _, _, consist_dataset, _, _ = \
            utils.get_dataset(args.data, args.root, args.source, args.target, cosist_transform, cosist_transform)
        dset_loaders["target_aug"] = DataLoader(consist_dataset, batch_size=args.batch_size, shuffle=False,
                                                num_workers=args.workers, drop_last=False)

    train_source_iter = ForeverDataIterator(dset_loaders["source"])
    train_target_iter = ForeverDataIterator(dset_loaders["target"])

    logger.info(f"training: {len(train_source_dataset)} source samples, {len(train_target_dataset)} target samples")
    logger.info(f"testing: {len(test_source_dataset)} source samples, {len(test_target_dataset)} target samples")

    # create model
    logger.info("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone, args.num_classes, bottleneck_dim=args.bottleneck_dim,
                                 mlp_classifier=args.mlp_classifier, width=args.bottleneck_dim,
                                 pool_layer=pool_layer, finetune=not args.scratch,
                                 classifier_norm=args.classifier_norm).to(device)
    classifier_feature_dim = classifier.features_dim
    cls_params = classifier.get_parameters(lr_multi_B=args.lr_multi_B, lr_multi_G=args.lr_multi_G,
                                           lr_multi_C=args.lr_multi_C)

    # Method-dependent auxiliary head. Every method other than CDAN/MDD gets
    # a plain feature-level discriminator, as in the original code.
    if method == "cdan":
        if args.randomized:
            domain_discri = DomainDiscriminator(args.randomized_dim, hidden_size=1024).to(device)
        else:
            domain_discri = DomainDiscriminator(classifier_feature_dim * args.num_classes, hidden_size=1024).to(device)
    elif method == "mdd":
        domain_discri = ClassifierHead(args.num_classes, bottleneck_dim=args.bottleneck_dim,
                                       mlp_classifier=True, width=args.bottleneck_dim).to(device)
    else:
        domain_discri = DomainDiscriminator(classifier_feature_dim, hidden_size=1024).to(device)

    dc_params = domain_discri.get_parameters(lr_multi_D=args.lr_multi_D)
    all_params = cls_params + dc_params
    # define optimizer and lr scheduler
    # The optimizer's base lr is 1 and the lambda below multiplies in args.lr.
    optimizer = SGD(all_params, lr=1, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # define loss function
    if method == "cdan":
        train_metric = ConditionalDomainAdversarialLoss(
            domain_discri, entropy_conditioning=args.entropy,
            num_classes=args.num_classes, features_dim=classifier_feature_dim, randomized=args.randomized,
            randomized_dim=args.randomized_dim
        ).to(device)
    elif method == "dann":
        train_metric = DomainAdversarialLoss(domain_discri).to(device)
    elif method == "mcc":
        train_metric = MinimumClassConfusionLoss(temperature=args.temperature).to(device)
    elif method == "mdd":
        train_metric = MarginDisparityDiscrepancy(domain_discri, args.margin).to(device)
    elif method == "source_only":
        train_metric = None
    elif method == "proto":
        train_metric = ProtoLoss(nav_t=args.nav_t, s_par=args.s_par,
                                 assign_type=args.assign_type, cost_type=args.cost_type,
                                 balance_type=args.balance_type).to(device)
    elif method == "consist":
        train_metric = CEConsistencyLoss().to(device)
    else:
        raise NotImplementedError

    if args.pretrain is None and args.pretrain_epochs > 0:
        # first pretrain the classifier with source data
        print("Pretraining the model on source domain.")
        args.pretrain = osp.join(args.log, 'pretrain')
        pretrain_optimizer = Adam(cls_params, args.pretrain_lr)
        pretrain_lr_scheduler = LambdaLR(pretrain_optimizer,
                                         lambda x: args.pretrain_lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
        # start pretraining
        for epoch in range(args.pretrain_epochs):
            # pretrain for one epoch
            utils.pretrain(train_source_iter, classifier, pretrain_optimizer, pretrain_lr_scheduler, epoch, args, device)
            # validate to show pretrain process
            classifier.eval()
            # Same call convention as the evaluation in the main loop below:
            # one loader in, (accuracy, loss, data) out.
            temp_acc, _, _ = image_classification_test(args, dset_loaders['target_test'], classifier, train_metric)
            print(f"pretrain epoch {epoch}, target acc {temp_acc}")
        torch.save(classifier.state_dict(), args.pretrain)
        print("Pretraining process is done.")
    # NOTE(review): a checkpoint passed via --pretrain is only loaded when
    # --pretrain-epochs > 0; with the default of 0 it is ignored. Confirm
    # this gating is intended.
    if args.pretrain_epochs > 0:
        checkpoint = torch.load(args.pretrain, map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # Resolve the per-method training routine once. Every method lives in
    # scripts.<method> except "consist", which lives in scripts.self_ensemble.
    # `method` was already validated by the loss selection above.
    train_module = "scripts.self_ensemble" if method == "consist" else f"scripts.{method}"
    train = importlib.import_module(train_module).train

    # start training
    best_acc = 0.
    all_metric_scores = []
    summary_writer = SummaryWriter(osp.join(args.log, 'train_logs'))
    try:
        for epoch in range(args.epochs):
            # Index 1 selects the second optimizer parameter group.
            logger.info(f"lr: {lr_scheduler.get_last_lr()[1]}")

            # train for one epoch
            if method == "source_only":
                train(train_source_iter, classifier, optimizer, lr_scheduler, epoch, args, device)
            elif method == "consist":
                # The classifier is passed in two positional slots, as before.
                train(train_source_iter, train_target_iter, classifier, classifier, train_metric, optimizer,
                      lr_scheduler, epoch, args)
            else:
                train(train_source_iter, train_target_iter, classifier, train_metric, optimizer, lr_scheduler,
                      epoch, args)

            # evaluate on validation set
            classifier.eval()
            if train_metric is not None:
                train_metric.eval()
            logger.info("Begin evaluation and compute metrics...")
            temp_acc, test_loss, target_test_data = image_classification_test(
                args, dset_loaders['target_test'], classifier, train_metric)
            dset_loaders['target_test_data'] = target_test_data

            source_acc, _, source_test_data = image_classification_test(
                args, dset_loaders['source_test'], classifier, None)
            dset_loaders['source_test_data'] = source_test_data

            if temp_acc > best_acc:
                best_acc = temp_acc

            if 'ACM' in args.metrics:
                aug_acc, _, target_aug_data = image_classification_test(
                    args, dset_loaders['target_aug'], classifier, None)
                dset_loaders['target_aug_data'] = target_aug_data

            dset_loaders = construct_metric_dataloader(args, dset_loaders, split_ratio=0.66)

            # t-SNE plotting is currently disabled:
            # tSNE_filename = osp.join(args.log, f'TSNE_{epoch}.png')
            # tsne.visualize(dset_loaders['source_test_data'][0][:5000],
            #                dset_loaders['target_test_data'][0][:5000], tSNE_filename)
            # logger.info(f"Saving t-SNE to {tSNE_filename}")
            try:
                metric_scores = compute_metrics(args, dset_loaders, classifier, k_fold=args.k_fold)
            except ValueError:
                print("metric_scores is NaN!!")
                # Under a search the trial is pruned; otherwise training stops
                # early and the scores gathered so far are still saved below.
                if opt_param_paser is not None:
                    raise optuna.exceptions.TrialPruned()
                else:
                    break
            metric_scores["accuracy"] = temp_acc
            metric_scores["test_loss"] = test_loss
            metric_scores["source_accuracy"] = source_acc
            if is_search:
                metric_scores['search_metric'] = metric_scores['ACM']
            for k, v in metric_scores.items():
                logger.info(f"epoch: {epoch}/{args.epochs}, metric {k}: {v:.5f}")
                summary_writer.add_scalar(f"eval/{k}", v, epoch)
            # Added after the logging loop so the epoch index is not logged as a metric.
            metric_scores["train_epoch"] = epoch
            all_metric_scores.append(metric_scores)

            if args.save_checkpoints:
                # torch.save(classifier.state_dict(), osp.join(args.log, f"epoch{epoch}_model.pth"))
                with open(osp.join(args.log, f"epoch{epoch}_data.pkl"), mode="wb") as f:
                    pickle.dump([dset_loaders['source_test_data'], dset_loaders['target_test_data'],
                                 classifier.head.state_dict()], f)

            if opt_param_paser:
                # Report intermediate objective function values for a given step.
                opt_param_paser.trial.report(metric_scores[opt_param_paser.opt_metric], epoch)
                # Handle pruning based on the intermediate value.
                if opt_param_paser.trial.should_prune():
                    save_list(all_metric_scores, osp.join(args.log, 'metric_scores.xlsx'))
                    raise optuna.exceptions.TrialPruned()
    finally:
        # Close the event writer on every exit path, including pruning and
        # the early break above.
        summary_writer.close()

    logger.info("best_acc = {:.3f}".format(best_acc))

    save_list(all_metric_scores, osp.join(args.log, 'metric_scores.xlsx'))
    return all_metric_scores


def convert_value(value, v):
    """Convert the string ``v`` to the type of the reference value ``value``.

    Args:
        value: Reference value whose type decides the conversion. Supported
            types are ``bool``, ``str``, ``int``, ``float`` and ``list``.
            For lists, the first element decides the element type.
        v: The string to convert. Lists are written as comma-separated items,
            optionally wrapped in square brackets (e.g. ``"[1, 2, 3]"``).

    Returns:
        ``v`` converted to the same type as ``value``. An empty list literal
        (``"[]"`` or ``""``) yields ``[]``.

    Raises:
        ValueError: If ``value`` has an unsupported type, if ``v`` cannot be
            parsed as that type (including a boolean that is neither
            "true" nor "false"), or if a non-empty list is requested but
            ``value`` is an empty list, so the element type is unknown.
    """
    # bool must be checked before int because bool is a subclass of int.
    if isinstance(value, bool):
        text = v.strip().lower()
        if text == "true":
            return True
        if text == "false":
            return False
        # Previously fell through and silently returned None.
        raise ValueError(f"Cannot interpret {v!r} as a boolean (expected true/false)")
    if isinstance(value, str):
        return str(v)
    if isinstance(value, int):
        return int(v)
    if isinstance(value, float):
        return float(v)
    if isinstance(value, list):
        inner = v.strip().strip("[]").strip()
        if not inner:
            return []
        if not value:
            raise ValueError(f"Cannot infer the element type for {v!r} from an empty list")
        return [convert_value(value[0], item.strip()) for item in inner.split(",")]
    raise ValueError(f"Don't support for config type: {type(value)}")


def get_args(parser):
    """Register the shared options on ``parser`` and build the final namespace.

    The shared options are added first. The command line is then parsed once
    just to learn ``--method``, and the matching method script's
    ``parse_args`` is called with the same parser to produce the final
    namespace (presumably it registers the method-specific options - verify
    against ``scripts.<method>.parse_args``). Finally ``args.metrics`` is set
    to the full metric list and any ``--opts KEY VALUE ...`` overrides are
    applied.

    Args:
        parser: An ``argparse.ArgumentParser`` to extend and parse with.

    Returns:
        The namespace returned by the method script's ``parse_args``, with
        ``metrics`` set and ``--opts`` overrides applied.

    Raises:
        ValueError: If ``--opts`` does not hold complete KEY VALUE pairs, or
            if a value cannot be converted (see ``convert_value``).
        AttributeError: If an ``--opts`` key is not an existing option.
    """
    import importlib

    # parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
    parser.add_argument('--method', type=str, default='CDAN',
                        choices=['source_only', 'DANN', 'CDAN', 'MCC', 'MDD', 'Proto', 'consist'],
                        help="The DA training method")
    parser.add_argument('--metric_batch_size', default=64, type=int, help='mini-batch size for metrics')
    parser.add_argument('--metric_train_epochs', default=10, type=int, help='train_epochs for metrics')
    parser.add_argument('--k_fold', default=0, type=int, help='k_fold for cross validation ')
    parser.add_argument('--lr_multi_D', type=float, default=1.0, help="The lr_multi_D")
    parser.add_argument('--lr_multi_B', type=float, default=0.1, help="The lr_multi_B")
    parser.add_argument('--lr_multi_G', type=float, default=1.0, help="The lr_multi_G")
    parser.add_argument('--lr_multi_C', type=float, default=1.0, help="The lr_multi_C")
    parser.add_argument('--mlp_classifier', action='store_true', default=False,
                        help='use MLP classifier')
    parser.add_argument('--classifier_norm', action='store_true', default=False,
                        help='normalize features and weights of the classifier')
    parser.add_argument('--imagenet_test', action='store_true', default=False,
                        help='computer domain distance with imagenet')
    parser.add_argument('--save_checkpoints', action='store_true', default=False,
                        help='save the checkpoints and features of each epoch')
    parser.add_argument('--pretrain', type=str, default=None,
                        help='pretrain checkpoint for classification model')
    parser.add_argument('--pretrain-lr', '--pretrain-learning-rate', default=3e-5, type=float,
                        help='initial pretrain learning rate', dest='pretrain_lr')
    parser.add_argument('--pretrain-epochs', default=0, type=int, metavar='N',
                        help='number of total epochs(pretrain) to run')
    parser.add_argument(
        "--opts",
        help="See ./utils/config.py for all options",
        default=None,
        nargs=argparse.REMAINDER,
    )

    # First pass: only --method is needed here; options that are not
    # registered yet are tolerated by parse_known_args.
    args, _ = parser.parse_known_args()

    # Lower-cased method name -> module that provides parse_args.
    method_modules = {
        "cdan": "scripts.cdan",
        "dann": "scripts.dann",
        "mcc": "scripts.mcc",
        "mdd": "scripts.mdd",
        "source_only": "scripts.source_only",
        "proto": "scripts.proto",
        "consist": "scripts.self_ensemble",
    }
    parse_args = importlib.import_module(method_modules[args.method.lower()]).parse_args
    args = parse_args(parser)

    general_metrics = ['accuracy', 'test_loss', 'source_accuracy', 'split_accuracy']
    discrepancy_metrics = ['a_distance', 'MCD', 'MDD']
    assign_cost_metrics = ['entropy', 'clustering_l2', 'clustering_cos', 'mlp_metrics']
    image_metrics = ['ACM']
    other_metrics = ['DEV', 'SND', 'BNM']
    args.metrics = general_metrics + discrepancy_metrics + assign_cost_metrics + image_metrics + other_metrics

    # Apply the KEY VALUE overrides given after --opts on top of the parsed values.
    if args.opts is not None:
        if len(args.opts) % 2 != 0:
            # zip() would silently drop a trailing key that has no value.
            raise ValueError(f"--opts expects KEY VALUE pairs, got an odd number of items: {args.opts}")
        for k, v in zip(args.opts[0::2], args.opts[1::2]):
            # The current value's type decides how the string is converted.
            setattr(args, k, convert_value(getattr(args, k), v))
    return args

if __name__ == '__main__':
    # Script entry point: build the CLI parser, resolve the full argument
    # namespace for the selected method, then run training.
    cli_parser = argparse.ArgumentParser(description='train for Unsupervised Domain Adaptation')
    main(get_args(cli_parser))
