import os
import json
import torch
import datetime
import argparse

import numpy as np

import torch.multiprocessing as mp

from src.utils import str2bool
from src.runners import EmoSentRunner


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--given_model_id', default=None)

    parser.add_argument('--pt', default=False, type=str2bool)
    parser.add_argument('--pt_test', default=False, type=str2bool)
    parser.add_argument('--train', default=True, type=str2bool)
    parser.add_argument('--dev', default=False, type=str2bool)
    parser.add_argument('--test', default=False, type=str2bool)
    parser.add_argument('--inspect', default=False, type=str2bool)

    parser.add_argument('--token_mask_pct', default=0.15, type=float)
    parser.add_argument('--entity_attn', default=True, type=str2bool)
    parser.add_argument('--e2e_attn', default=False, type=str2bool)
    parser.add_argument('--e2e_attn_masking', default=False, type=str2bool)
    parser.add_argument('--e2e_attn_mask_prob', default=0.15, type=float)
    parser.add_argument('--use_queue', default=True, type=str2bool)

    parser.add_argument('--torch_amp', default=True, type=str2bool)
    parser.add_argument('--freeze_roberta', default=False, type=str2bool)
    parser.add_argument('--add_cls_to_entity', default=False, type=str2bool)
    parser.add_argument('--v2_attn', default=True, type=str2bool)

    parser.add_argument('--use_meld', default=False, type=str2bool)
    parser.add_argument('--use_iemocap', default=False, type=str2bool)
    parser.add_argument('--use_emory_nlp', default=False, type=str2bool)

    parser.add_argument('--do_sentiment', default=False, type=str2bool)
    parser.add_argument('--objective', default='moco')
    parser.add_argument('--supcon_loss_fn', default='cosine_sim')
    parser.add_argument('--project_embds_for_contrast', default=True, type=str2bool)
    parser.add_argument('--contrastive_dim', default=64)
    parser.add_argument('--queue_size', default=32, help='Queue size per emotion label', type=int)

    parser.add_argument('--out', default='/home/czh/nvme1/SportsAnalytics/out/emo_sent_clf',
                        help='Directory to put output')

    parser.add_argument('--meld_dir', default='/home/czh/md0/datasets/MELD/', help='Directory to find MELD data')
    parser.add_argument('--iemocap_dir', default='/home/czh/md0/datasets/IEMOCAP_full_release/formatted/',
                        help='Directory to find IEMOCAP data')
    parser.add_argument('--emory_nlp_dir', default='/home/czh/md0/datasets/emory_nlp_emotion_detection/',
                        help='Directory to find Emory NLP data')

    parser.add_argument('--epochs', default=50, type=int, help='# epochs to train for')
    parser.add_argument('--batch_size', default=1, type=int, help='Batch size to use')
    parser.add_argument('--lr', default=1e-5, type=float, help='Learning rate')
    parser.add_argument('--l2', default=0.01, type=float)
    parser.add_argument('--max_seq_len', default=512, type=int)
    parser.add_argument('--max_n_utterances', default=50, type=int)
    parser.add_argument('--general_dropout_prob', default=0.25, type=float)
    parser.add_argument('--n_grad_accum', default=1, type=int)
    parser.add_argument('--n_warmup_iters', default=250, type=int)

    # logging parms
    parser.add_argument('--seed', default=16, type=int)
    parser.add_argument('--ckpt_file', default=None)
    parser.add_argument('--ckpt_file_tmplt', default='model_{}e.pt')
    parser.add_argument('--print_every', default=1, type=int)
    parser.add_argument('--save_model_every', default=5, type=int)
    parser.add_argument('--log_every', default=15, type=int)
    parser.add_argument('--summary_every', default=3, type=int)
    parser.add_argument('--dev_every', default=1, type=int)
    parser.add_argument('--grad_summary', default=False, type=str2bool)
    parser.add_argument('--grad_summary_every', default=5, type=int)
    parser.add_argument('--arg_out_file', default='args.txt', help='File to write cli args to')

    # hardware parms
    parser.add_argument('--gpus', default=[0], help='Which GPUs to use', type=int, nargs='+')
    parser.add_argument('--port', default='12345', help='Port to use for DDP')
    parser.add_argument('--on_cpu', default=False, type=str2bool)
    parser.add_argument('--n_data_workers', default=3, help='# threads used to fetch data', type=int)
    parser.add_argument('--n_data_workers_else', default=3, help='# threads used to fetch data', type=int)

    args = parser.parse_args()
    args.world_size = len(args.gpus)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.random.manual_seed(args.seed)

    run_modes = []

    if args.train:
        run_modes.append('train')
        if args.given_model_id is not None:
            curr_time = args.given_model_id
        else:
            curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        print('*' * len('* Model ID: {} *'.format(curr_time)))
        print('* Model ID: {} *'.format(curr_time))
        print('*' * len('* Model ID: {} *'.format(curr_time)))

        args.out = os.path.join(args.out, curr_time)
        os.makedirs(args.out)

        args.tb_dir = os.path.join(args.out, 'tb_dir')
        if not os.path.exists(args.tb_dir):
            os.makedirs(args.tb_dir)

        args.model_save_dir = os.path.join(args.out, 'models')
        os.makedirs(args.model_save_dir)

        args.model_log_dir = os.path.join(args.out, 'logs')
        os.makedirs(args.model_log_dir)

        args.arg_out_file = os.path.join(args.out, args.arg_out_file)
        args_d = vars(args)

        with open(args.arg_out_file, 'w+') as f:
            f.write(json.dumps(args_d, indent=2))
    elif args.pt:
        run_modes.append('pt')
        if args.given_model_id is not None:
            curr_time = args.given_model_id
        else:
            curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        print('*' * len('* Model ID: {} *'.format(curr_time)))
        print('* Model ID: {} *'.format(curr_time))
        print('*' * len('* Model ID: {} *'.format(curr_time)))

        args.out = os.path.join(args.out, curr_time)
        os.makedirs(args.out)

        args.tb_dir = os.path.join(args.out, 'tb_dir')
        if not os.path.exists(args.tb_dir):
            os.makedirs(args.tb_dir)

        args.model_save_dir = os.path.join(args.out, 'models')
        os.makedirs(args.model_save_dir)

        args.model_log_dir = os.path.join(args.out, 'logs')
        os.makedirs(args.model_log_dir)

        args.arg_out_file = os.path.join(args.out, args.arg_out_file)
        args_d = vars(args)

        with open(args.arg_out_file, 'w+') as f:
            f.write(json.dumps(args_d, indent=2))
    if args.pt_test:
        run_modes.append('pt_test')

        if not args.pt:
            curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
            print('*' * len('* Model ID: {} *'.format(curr_time)))
            print('* Model ID: {} *'.format(curr_time))
            print('*' * len('* Model ID: {} *'.format(curr_time)))

            args.out = os.path.join(args.out, curr_time)
            os.makedirs(args.out)

            args.tb_dir = os.path.join(args.out, 'tb_dir')
            if not os.path.exists(args.tb_dir):
                os.makedirs(args.tb_dir)

            args.model_save_dir = os.path.join(args.out, 'models')
            os.makedirs(args.model_save_dir)

            args.model_log_dir = os.path.join(args.out, 'logs')
            os.makedirs(args.model_log_dir)

            args.arg_out_file = os.path.join(args.out, args.arg_out_file)
            args_d = vars(args)

            with open(args.arg_out_file, 'w+') as f:
                f.write(json.dumps(args_d, indent=2))

    if args.dev:
        run_modes.append('dev')
    if args.test:
        run_modes.append('test')
    if args.inspect:
        run_modes.append('inspect')

    if (args.dev or args.test or args.inspect) and not args.train:
        args.out = os.path.dirname(args.ckpt_file)
        if os.path.basename(args.out) == 'models':
            args.out = os.path.dirname(args.out)
        print('args.out: {}'.format(args.out))

        args.tb_dir = os.path.join(args.out, 'tb_dir')
        if not os.path.exists(args.tb_dir):
            os.makedirs(args.tb_dir)

        args.model_save_dir = os.path.join(args.out, 'models')
        if not os.path.exists(args.model_save_dir):
            os.makedirs(args.model_save_dir)

        args.model_log_dir = os.path.join(args.out, 'logs')
        if not os.path.exists(args.model_log_dir):
            os.makedirs(args.model_log_dir)

    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = args.port
    for mode in run_modes:
        print('Creating {} distributed models for {}...'.format(len(args.gpus), mode))
        mp.spawn(EmoSentRunner, nprocs=len(args.gpus), args=(mode, args))

    print('All done :)')

