import argparse
import datetime
import random
import copy
import json
import os

import torch
import numpy as np
from tensorboardX import SummaryWriter

from data_utils import get_dataset, is_textdata
from model_utils import get_model
from training.multitask import multitask, finetune



def get_parser():
    '''
        Build the command-line parser for a TOFU source -> target run.

        Options are registered in a fixed order (device, data sampling,
        model, datasets, method, optimization, output), which is also the
        order they appear in ``--help`` and in the parsed namespace.

        Returns:
            argparse.ArgumentParser: parser ready for ``parse_args()``.
    '''
    parser = argparse.ArgumentParser(
        description='TOFU: transfer of unstable features')

    def add_typed(options):
        # each entry is (flag, type converter, default value)
        for flag, kind, default in options:
            parser.add_argument(flag, type=kind, default=default)

    add_typed([
        ('--cuda', int, 0),
        # data sample
        ('--num_epochs', int, 1000),
        ('--num_batches', int, 100),
        ('--batch_size', int, 50),
        ('--num_query', int, 20),
        ('--seed', int, 0),
        # model
        ('--hidden_dim', int, 300),
        ('--weight_decay', float, 0.001),
        ('--dropout', float, 0.0),
        # dataset
        ('--dataset', str, ''),
        ('--dataset_2', str, ''),
        ('--dataset_3', str, ''),
        ('--tar_dataset', str, ''),
        ('--val', str, 'in_domain'),
        # method specification
        ('--num_clusters', int, 2),
        ('--nickname', str, ''),
        ('--method', str, 'ours'),
        ('--tar_method', str, 'erm'),
    ])

    # boolean switches that control how the embedding function is reused
    switches = (
        ('--transfer_ebd',
         'whether to transfer the ebd function learned from the source task'),
        ('--train_ebd',
         'whether to finetune the ebd function on the target task'),
    )
    for flag, text in switches:
        parser.add_argument(flag, action='store_true', default=False,
                            help=text)

    add_typed([
        # optimization
        ('--lr', float, 0.001),
        ('--l_regret', float, 1),
        ('--thres', float, 0.3),
        ('--anneal_iters', float, 1),
        ('--clip_grad', float, 1),
        ('--patience', int, 20),
        # result file
        ('--results_path', str, ''),
    ])

    return parser


def print_args(args):
    '''
        Print every parsed argument, then an ASCII-art banner.

        All attributes of ``args`` are printed (not a filtered subset),
        sorted by name, one per tab-indented line as ``NAME=value`` with the
        name upper-cased.

        Args:
            args: an ``argparse.Namespace`` (anything exposing ``__dict__``).
    '''
    print("Parameters:")
    for attr, value in sorted(args.__dict__.items()):
        print("\t{}={}".format(attr.upper(), value))
    # Raw string: the art is full of backslashes, and in a normal literal
    # sequences like "\." or "\ " are invalid escapes (SyntaxWarning on
    # Python 3.12+). Every backslash below is printed exactly as written.
    print(r'''
    (Credit: Maija Haavisto)                        /
                                 _,.------....___,.' ',.-.
                              ,-'          _,.--"        |
                            ,'         _.-'              .
                           /   ,     ,'                   `
                          .   /     /                     ``.
                          |  |     .                       \.\
                ____      |___._.  |       __               \ `.
              .'    `---""       ``"-.--"'`  \               .  \
             .  ,            __               `              |   .
             `,'         ,-"'  .               \             |    L
            ,'          '    _.'                -._          /    |
           ,`-.    ,".   `--'                      >.      ,'     |
          . .'\'   `-'       __    ,  ,-.         /  `.__.-      ,'
          ||:, .           ,'  ;  /  / \ `        `.    .      .'/
          j|:D  \          `--'  ' ,'_  . .         `.__, \   , /
         / L:_  |                 .  "' :_;                `.'.'
         .    ""'                  """""'                    V
          `.                                 .    `.   _,..  `
            `,_   .    .                _,-'/    .. `,'   __  `
             ) \`._        ___....----"'  ,'   .'  \ |   '  \  .
            /   `. "`-.--"'         _,' ,'     `---' |    `./  |
           .   _  `""'--.._____..--"   ,             '         |
           | ." `. `-.                /-.           /          ,
           | `._.'    `,_            ;  /         ,'          .
          .'          /| `-.        . ,'         ,           ,
          '-.__ __ _,','    '`-..___;-...__   ,.'\ ____.___.'
          `"^--'..'   '-`-^-'"--    `-^-'`.''"""""`.,^.`.--' mh
    ''')


def set_seed(seed):
    '''
        Seed every random number generator the script relies on.

        Covers Python's ``random`` module, NumPy's global generator, the
        PyTorch CPU generator and the generators of all CUDA devices. Draws
        from any of these libraries are then repeatable from ``--seed``.

        Args:
            seed (int): value used for all generators.
    '''
    # `random` is imported by this file but was previously left unseeded.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds every visible GPU, not only the current one;
    # like manual_seed it is deferred safely when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)


if __name__ == '__main__':
    parser = get_parser()
    args = parser.parse_args()

    torch.cuda.set_device(args.cuda)

    print_args(args)

    set_seed(args.seed)

    # Snapshot the configuration now: args.dataset is overwritten below,
    # and the results file should record the value given on the command
    # line (the comma-separated list of source datasets).
    res = dict(sorted(vars(args).items()))

    # get data loaders for training and testing:
    # env 0 and 1 are used for source training
    # env 2 is used for source validation (not useful for TOFU, but useful
    # for direct transfer to decide when to stop early on the source task)
    # env 3 is used for source testing, just to see how good the source
    # classifier is

    src_tasks = args.dataset.split(',')
    src_train_datasets = []
    src_test_datasets = []
    src_models = []
    src_opts = []

    for dataset in src_tasks:
        # get_model receives the whole args namespace, so point
        # args.dataset at the source task currently being built.
        args.dataset = dataset

        # get source dataset:
        train_src_data, test_src_data = get_dataset(args.dataset, args.val)

        # define source model; text datasets also pass their vocab
        if is_textdata(args.dataset):
            src_model, src_opt = get_model(args, train_src_data.vocab)
        else:
            src_model, src_opt = get_model(args)

        src_train_datasets.append(train_src_data)
        src_test_datasets.append(test_src_data)
        src_models.append(src_model)
        src_opts.append(src_opt)

    # get target dataset and define target model.
    # NOTE(review): for text data the target reuses the vocab of the *last*
    # source dataset built in the loop above (train_src_data); confirm this
    # is intended when several source tasks are given.
    args.dataset = args.tar_dataset
    if is_textdata(args.dataset):
        train_tar_data, test_tar_data = get_dataset(
            args.dataset, args.val, True, train_src_data.vocab)
        tar_model, tar_opt = get_model(args, train_src_data.vocab)
    else:
        train_tar_data, test_tar_data = get_dataset(
            args.dataset, args.val, True)
        tar_model, tar_opt = get_model(args)

    # start training
    # %y, %m and %d already zero-pad to two digits; width flags such as
    # '%02y' are a glibc extension and are not portable across C libraries.
    print("{}, Start training {} on train env".format(
        datetime.datetime.now().strftime('%y/%m/%d %H:%M:%S'),
        args.method), flush=True)

    # 'finetune' / 'reuse' runs go through finetune(); everything else
    # goes through multitask(). Both are called with the same arguments.
    trainer = finetune if args.nickname in ('finetune', 'reuse') else multitask
    cur_res = trainer(src_train_datasets, src_test_datasets, train_tar_data,
                      test_tar_data, src_models, tar_model, src_opts,
                      tar_opt, args)

    for key in ('src_test_acc', 'tar_val_acc', 'tar_test_acc'):
        res[key] = cur_res[key]

    if args.results_path:
        with open(args.results_path, 'w', encoding='utf-8') as f:
            json.dump(res, f)
