import argparse
import datetime
import random
import copy
import json
import os

import torch
import numpy as np
from tensorboardX import SummaryWriter

from data_utils import get_dataset, is_textdata
from model_utils import get_model
from train_utils import train_val_test



def get_parser():
    '''
        Build the command-line parser for the TOFU experiments.

        Returns an argparse.ArgumentParser. Options are registered in the
        order in which they are listed below, which is also the order in
        which they show up in --help.
    '''
    parser = argparse.ArgumentParser(
        description='TOFU: transfer of unstable features')

    def add_options(*specs):
        # each spec is (flag, type, default) for a value-taking option
        for flag, kind, default in specs:
            parser.add_argument(flag, type=kind, default=default)

    def add_switch(flag, help_text):
        # boolean switch that is off unless given on the command line
        parser.add_argument(flag, action='store_true', default=False,
                            help=help_text)

    add_options(
        ('--cuda', int, 0),

        # data sample
        ('--num_epochs', int, 1000),
        ('--num_batches', int, 100),
        ('--batch_size', int, 50),
        ('--num_query', int, 20),
        ('--seed', int, 0),

        # model
        ('--hidden_dim', int, 300),
        ('--weight_decay', float, 0.001),
        ('--dropout', float, 0.1),

        # dataset
        ('--dataset', str, ''),
        ('--dataset_remaining', str, ''),
        ('--tar_dataset', str, ''),
        ('--val', str, 'in_domain'),

        # method specification
        ('--num_clusters', int, 2),
        ('--nickname', str, ''),
        ('--method', str, 'ours'),
        ('--tar_method', str, 'erm'),
    )

    # method specification (switches)
    add_switch('--transfer_ebd',
               'whether to transfer the ebd function learned from the source task')
    add_switch('--train_ebd',
               'whether to finetune the ebd function on the target task')
    add_switch('--balance',
               'whether to use balance sampling')

    add_options(
        # optimization
        ('--lr', float, 0.001),
        ('--thres', float, 0.3),
        ('--anneal_iters', float, 1),
        ('--clip_grad', float, 1),
        ('--patience', int, 20),

        # result file
        ('--results_path', str, ''),
    )

    return parser


def print_args(args):
    '''
        Print every parsed argument as NAME=value (one per line, sorted by
        attribute name, name upper-cased), followed by an ASCII-art banner.

        args: object holding the arguments as instance attributes, e.g. the
        argparse.Namespace returned by parse_args().
    '''
    print("Parameters:")
    for attr, value in sorted(vars(args).items()):
        print("\t{}={}".format(attr.upper(), value))

    # Every backslash of the drawing is spelled "\\" so that the literal
    # holds no invalid escape sequences (such as "\." or "\ "), which newer
    # Python versions warn about. The printed text is unchanged.
    print('''
    (Credit: Maija Haavisto)                        /
                                 _,.------....___,.' ',.-.
                              ,-'          _,.--"        |
                            ,'         _.-'              .
                           /   ,     ,'                   `
                          .   /     /                     ``.
                          |  |     .                       \\.\\
                ____      |___._.  |       __               \\ `.
              .'    `---""       ``"-.--"'`  \\               .  \\
             .  ,            __               `              |   .
             `,'         ,-"'  .               \\             |    L
            ,'          '    _.'                -._          /    |
           ,`-.    ,".   `--'                      >.      ,'     |
          . .'\\'   `-'       __    ,  ,-.         /  `.__.-      ,'
          ||:, .           ,'  ;  /  / \\ `        `.    .      .'/
          j|:D  \\          `--'  ' ,'_  . .         `.__, \\   , /
         / L:_  |                 .  "' :_;                `.'.'
         .    ""'                  """""'                    V
          `.                                 .    `.   _,..  `
            `,_   .    .                _,-'/    .. `,'   __  `
             ) \\`._        ___....----"'  ,'   .'  \\ |   '  \\  .
            /   `. "`-.--"'         _,' ,'     `---' |    `./  |
           .   _  `""'--.._____..--"   ,             '         |
           | ." `. `-.                /-.           /          ,
           | `._.'    `,_            ;  /         ,'          .
          .'          /| `-.        . ,'         ,           ,
          '-.__ __ _,','    '`-..___;-...__   ,.'\\ ____.___.'
          `"^--'..'   '-`-^-'"--    `-^-'`.''"""""`.,^.`.--' mh
    ''')


def set_seed(seed):
    '''
        Seed every random number generator this script has access to so
        that runs are repeatable: Python's random module, NumPy, and
        PyTorch (CPU and all CUDA devices).

        seed: integer seed shared by all generators.
    '''
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # seeds every visible GPU rather than only the current one
    torch.cuda.manual_seed_all(seed)


if __name__ == '__main__':
    # Script entry point: optionally train on one or more source tasks,
    # then train/evaluate on the target task and dump the metrics as JSON.
    parser = get_parser()
    args = parser.parse_args()

    torch.cuda.set_device(args.cuda)

    print_args(args)

    set_seed(args.seed)

    # snapshot the command-line arguments (taken before args is mutated
    # below) so that they are written out together with the results
    res = dict(sorted(vars(args).items()))

    best_model_ebd = None
    partition = None
    # stays None when no source task is run; checked before the target stage
    train_src_data = None

    train_partition_loaders = []
    val_partition_loaders = []

    # specify source tasks, separated by comma
    if args.method == 'skip':
        # skip source task for erm method
        args.dataset = ''
    source_tasks = args.dataset.split(',')
    # args.dataset holds the task being trained, args.dataset_remaining the
    # ones still to come; both live on args, which is passed to the helpers
    args.dataset = source_tasks[0]
    args.dataset_remaining = source_tasks[1:]

    while True:
        if args.dataset == '':
            # no source task
            break

        #############################
        # Learning on the source task
        #############################

        # get source dataset:
        #
        # there are 4 envs:
        #
        # env 0 and 1: source training
        #
        # env 2: source validation. note that this is not useful for Tofu, but
        # useful for direct transfer to determine when to early stop on the
        # source task
        #
        # env 3 is used for source testing, just to see how well the source
        # classifier is
        #
        train_src_data, test_src_data = get_dataset(args.dataset, args.val,
                                                    target=False)

        # initialize model and optimizer based on the dataset and the method
        if is_textdata(args.dataset):
            model, opt = get_model(args, train_src_data.vocab)
        else:
            model, opt = get_model(args)

        # start training
        print("{}, Start training {} on train env for {}".format(
            datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
            args.method, args.dataset), flush=True)

        # save the partition results on the source task for learning the unstable
        # feature space
        cur_res, train_partition_loaders, val_partition_loaders = train_val_test(
            train_src_data, test_src_data, model, opt, args,
            train_partition_loaders=train_partition_loaders,
            val_partition_loaders=val_partition_loaders)

        # transfer the source robust model's feature extractor
        best_model_ebd = copy.deepcopy(model['ebd'].state_dict())

        # with several source tasks these entries are overwritten, so the
        # saved numbers are those of the last source task
        res['src_val_acc'] = cur_res['val']['acc']
        res['src_val_loss'] = cur_res['val']['loss']
        res['src_test_acc'] = cur_res['test']['acc']
        res['src_test_loss'] = cur_res['test']['loss']

        if args.method == 'erm' or args.method == 'metric':
            # this is for visualization
            # directly training a biased ERM model on the source data does not
            # necessarily split the target data based on the unstable features
            partition = model
        else:
            partition = cur_res['partition']

        if not args.dataset_remaining:
            break

        # move on to the next source task
        args.dataset = args.dataset_remaining[0]
        args.dataset_remaining = args.dataset_remaining[1:]

        # check number of partition loaders
        print('train partition loader numbers')
        print(len(train_partition_loaders))


    #############################
    # Transfer to the target task
    #############################

    # get target dataset:
    # the structure of the envs are the same as in the source dataset
    args.dataset = args.tar_dataset

    # text targets reuse the source vocabulary when anything learned on the
    # source task is carried over
    use_src_vocab = ((args.transfer_ebd or args.tar_method == 'ours')
                     and is_textdata(args.dataset))

    if train_src_data is None and (args.transfer_ebd or use_src_vocab):
        # without a source stage there is neither a source vocabulary nor a
        # trained embedding to transfer; stop with a clear message instead of
        # failing later with a NameError / load_state_dict(None)
        parser.error(
            '--transfer_ebd (and --tar_method ours on a text dataset) needs a '
            'source task: pass --dataset and do not use --method skip')

    if use_src_vocab:
        train_tar_data, test_tar_data = get_dataset(args.dataset, args.val,
                                                    True, train_src_data.vocab)
    else:
        train_tar_data, test_tar_data = get_dataset(args.dataset, args.val, True)

    args.method = args.tar_method

    # initialize model and optimizer based on the dataset and the method
    if is_textdata(args.dataset):
        if args.transfer_ebd:
            # for text data, if we want to transfer the feature representation
            # we also need to transfer the vocabulary
            model, opt = get_model(args, train_src_data.vocab)
        else:
            model, opt = get_model(args, train_tar_data.vocab)
    else:
        model, opt = get_model(args)

    # transfer the embedding function
    if args.transfer_ebd:
        model['ebd'].load_state_dict(best_model_ebd)

    # start training
    print("{}, Start training {} on train env for {}".format(
        datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
        args.method, args.dataset), flush=True)

    cur_res = train_val_test(
        train_tar_data, test_tar_data, model, opt, args,
        partition_model=partition, train_ebd=args.train_ebd)

    res['tar_val_acc'] = cur_res['val']['acc']
    res['tar_val_loss'] = cur_res['val']['loss']
    res['tar_test_acc'] = cur_res['test']['acc']
    res['tar_test_loss'] = cur_res['test']['loss']

    # an empty --results_path means "do not save"
    if args.results_path != '':
        with open(args.results_path, 'w') as f:
            json.dump(res, f)
