import argparse
import logging
import os
import pdb
import sys
import time

import torch
import ujson as json

from utils_train import (
        train,
        run_model_batch_frozen,
        run_model_batch_unfrozen,
        init_default_device,
        set_seed,
        do_final_evaluation,
        load_and_cache_examples,
        )
import utils_common
from utils_common import TaskType, get_task_type, get_task_processor, MODEL_CLASSES
from ActiveLearningSampler import ACTIVE_LEARNING_CLASSES
from augment_datafiles_with_cls import convert_dataset_to_frozencls

import transformers

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter

logger = logging.getLogger(__name__)

def parse_cli_args():
    """Build the argument parser and return the resolved run arguments.

    Two invocation styles are supported:

    * A single positional argument ending in ``.json``: the file is loaded
      and used verbatim as the argument namespace. Every parser option
      (except ``help``) must be present as a key and no unknown keys are
      allowed. Values are taken as-is from the JSON; argparse type
      conversion and ``choices`` validation are not applied in this mode.
    * Anything else: regular ``argparse`` command-line parsing.

    After parsing, ``task_name`` and ``model_type`` are lower-cased and
    ``run_config`` is resolved to a dict: loaded from the ``run_config`` key
    of the JSON file at ``config_name`` when it is ``None``, loaded from a
    JSON file when it is a string path, or used directly when already a dict.

    Returns:
        argparse.Namespace with all options populated.

    Raises:
        ValueError: on unknown or missing keys in a JSON config, when
            ``local_rank`` is not -1, when ``run_config`` is ``None`` and
            ``config_name`` is empty, or when ``run_config`` does not resolve
            to a dict.
    """
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
    )
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES[TaskType.MULTIPLE_CHOICE].keys()),
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list from https://huggingface.co/transformers/pretrained_models.html",
    )
    parser.add_argument(
        "--task_name",
        default=None,
        type=str,
        required=True,
        help="The name of the task to train selected in the list: " + ", ".join(utils_common.get_all_task_keys()),
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    # Other parameters
    parser.add_argument(
        "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3",
    )
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help="The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.",
    )
    parser.add_argument(
        "--oversize_example_method",
        default='truncate',
        choices=['truncate', 'prune'],
        help="How to handle examples that are longer than max_seq_length.  truncate=truncate the example, prune=remove the example from the dataset",
    )
    parser.add_argument(
        "--num_train_retries",
        default=1,
        type=int,
        help="Number of times to redo training on each AL dataset.  The best one (based on dev acc) is used to sample the next batch.",
    )
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument("--do_test", action="store_true", help="Whether to run test on the test set")
    parser.add_argument(
        "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
    )
    parser.add_argument(
        "--evaluate_on_train_set", action="store_true", help="Run an additional evaluation on the train set."
    )
    parser.add_argument(
        "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
    )

    parser.add_argument(
        "--model_parallel", action="store_true", help="Set this flag if you want to use model parallelism (only works for some models, like t5; raises error if not supported)."
    )
    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
    parser.add_argument(
        "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    #parser.add_argument("--hidden_dropout_override", default=None, type=float, help="Hidden dropout probability for model.")
    #parser.add_argument("--attn_dropout_override", default=None, type=float, help="Attention dropout probability for model.")
    parser.add_argument(
        "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
    )
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
    parser.add_argument("--no_lr_decay", action="store_true", help="Don't apply linear lr decay over training")

    parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
    parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
    )
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
    )
    parser.add_argument(
        "--no_model_save", action="store_true", help="Don't save the final model to disk"
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")

    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument("--run_config", type=str, default=None, help="One of: (1) None (loads from config_name), (2) a path to a JSON file, or (3) a config dict (if args are being parsed from a JSON file)")

    if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
        # Config-file mode: the JSON object replaces argparse entirely.
        all_options = [x.dest for x in parser._actions]
        with open(sys.argv[1]) as f:
            config = json.load(f)
        # CLI is for convenience; for config files we expect *all* values to be given
        missing_args = [opt for opt in all_options if opt not in config and opt != 'help']
        unmatched_args = [opt for opt in config.keys() if opt not in all_options]

        if len(unmatched_args) != 0:
            raise ValueError("Unmatched keys: {}".format(', '.join(unmatched_args)))
        if len(missing_args) != 0:
            raise ValueError("Missing required args: {}".format(', '.join(map(str, missing_args))))

        args = argparse.Namespace(**config)
    else:
        args = parser.parse_args()

    if args.local_rank != -1:
        raise ValueError("Data-parallel training is not currently supported")

    args.task_name = args.task_name.lower()
    args.model_type = args.model_type.lower()

    if args.run_config is None:
        # Without this guard an empty config_name (its default) would reach
        # open('') and fail with an uninformative FileNotFoundError.
        if not args.config_name:
            raise ValueError(
                "run_config was not given, so config_name must be the path to a JSON file containing a 'run_config' entry"
            )
        with open(args.config_name) as f:
            _config = json.load(f)
            args.run_config = _config.get('run_config', dict())
    elif isinstance(args.run_config, str):
        with open(args.run_config) as f:
            args.run_config = json.load(f)

    if not isinstance(args.run_config, dict):
        raise ValueError("Failed to parse/load run_config; should be a dict but got {}".format(str(type(args.run_config))))

    return args


def main():
    """Run the active-learning training loop and the final evaluation.

    Steps, in order:

    1. Parse arguments, validate/create the output directory, initialise the
       device, logging and random seed.
    2. Load config, tokenizer and model for the task, then the training set
       (optionally converted with ``convert_dataset_to_frozencls`` when
       ``run_config['optimizer']['freeze_core_weights']`` is set).
    3. Build the active-learning sampler selected by
       ``run_config['al_config']['score_method']``.
    4. Repeatedly train ``num_train_retries`` fresh models on the sampler's
       currently labeled dataset, hand the model with the best dev accuracy
       to the sampler, append a record to ``selected_indexes.jsonl``, and
       acquire the next batch until the sampler's label budget is reached.
    5. Save the collected dataset (if ``al_config['output_dataset_path']`` is
       set), the training args, the config and, unless ``no_model_save`` is
       given, the model and tokenizer.

    Returns:
        Whatever ``do_final_evaluation(args, model, tokenizer)`` returns.

    Raises:
        ValueError: if the output directory is non-empty without
            ``overwrite_output_dir``, if ``do_train`` is not set, if
            ``num_train_retries`` is smaller than 1, if ALPS is requested for
            a model type other than bert/roberta, or if the labeled dataset
            grows past the sampler's maximum.
    """
    args = parse_cli_args()

    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )

    # Create output directory if needed
    if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir)

    init_default_device(args)

    # Setup logging
    logging.root.handlers = []
    logging.basicConfig(
        format='%(asctime)s [%(levelname)s] (%(name)s):  %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
        handlers=[
            logging.FileHandler(os.path.join(args.output_dir, 'logging_output.log')),
            logging.StreamHandler(),
        ],
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank,
        args.device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        raise ValueError("Distributed training not supported for this script")

    config_class, model_class, tokenizer_class = MODEL_CLASSES[get_task_type(args.task_name)][args.model_type]
    num_labels = len(get_task_processor(args.task_name).get_labels())
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=args.task_name,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    #if args.attn_dropout_override is not None:
    #    config.attention_probs_dropout_prob = args.attn_dropout_override
    #if args.hidden_dropout_override is not None:
    #    config.hidden_dropout_prob = args.hidden_dropout_override
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    model = model_class.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )

    model.to(args.device)

    logger.info("Training/evaluation parameters %s", args)

    # Training
    if not args.do_train:
        logger.error("No point running without do_train here")
        raise ValueError("No point running without do_train")

    # With zero trials nothing is trained and the bookkeeping below would
    # fail on unbound names, so reject it up front with a clear message.
    if args.num_train_retries < 1:
        raise ValueError("num_train_retries must be at least 1 (got {})".format(args.num_train_retries))

    tb_writer = SummaryWriter(os.path.join(args.output_dir, 'tfevents'))

    train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, datasplit='train')

    optimizer_config = args.run_config.get('optimizer', dict())
    freeze_core_weights = optimizer_config.get('freeze_core_weights', False)
    run_model_batch_fn = run_model_batch_unfrozen
    if freeze_core_weights:
        run_model_batch_fn = run_model_batch_frozen
        train_dataset = convert_dataset_to_frozencls(model, train_dataset, get_task_type(args.task_name), batch_size=args.per_gpu_eval_batch_size, device=args.device, model_type=args.model_type)

    al_config = args.run_config['al_config']
    acquisition_start_time = time.perf_counter()
    if al_config['score_method'] == 'alps':
        # ALPS is given a language-modeling head model instead of the task model.
        if args.model_type not in ['bert', 'roberta']:
            raise ValueError("Need an MLM-type model (bert or roberta) to use ALPS")
        model = transformers.AutoModelWithLMHead.from_pretrained(args.model_name_or_path)
        model.to(args.device)
    al_sampler = ACTIVE_LEARNING_CLASSES[al_config['score_method']](al_config, model, train_dataset, args, run_model_batch_fn)
    num_acquisitions = 1
    acquisition_time = time.perf_counter() - acquisition_start_time

    orig_output_dir = args.output_dir

    while True:
        best_train_results = {'best_dev_acc': -1}
        all_train_results = []

        for train_trial_num in range(args.num_train_retries):
            # Every trial starts from the pretrained weights again.
            model = model_class.from_pretrained(
                args.model_name_or_path,
                from_tf=bool(".ckpt" in args.model_name_or_path),
                config=config,
                cache_dir=args.cache_dir if args.cache_dir else None,
            )
            model.to(args.device)
            if args.model_parallel:
                logger.info('parallelizing model')
                model.parallelize()

            #al_sampler.model = model

            args.output_dir = os.path.join(orig_output_dir, 'tmp_train_output_{}_r{}'.format(num_acquisitions, train_trial_num))
            # exist_ok so that a rerun with --overwrite_output_dir does not
            # crash on per-trial directories left over from an earlier run.
            os.makedirs(args.output_dir, exist_ok=True)

            current_dataset = al_sampler.get_human_labeled_dataset()
            train_start_time = time.perf_counter()
            train_results = train(args, current_dataset, model, tokenizer)
            train_results['time_elapsed'] = time.perf_counter() - train_start_time
            all_train_results.append(train_results)
            logger.info(" global_step = %s, average loss = %s", train_results['global_step'], train_results['train_loss'])

            tb_writer.add_scalar("bysize_avg_train_loss", train_results['train_loss'], len(current_dataset))
            tb_writer.add_scalar("bysize_best_eval_acc", train_results['best_dev_acc'], len(current_dataset))
            tb_writer.add_scalar("bysize_best_eval_train_loss", train_results['best_dev_train_loss'], len(current_dataset))
            tb_writer.add_scalar("bysize_final_train_loss", train_results['final_train_loss'], len(current_dataset))
            tb_writer.add_scalar("bysize_test_acc", train_results['test_acc'], len(current_dataset))
            tb_writer.add_scalar("num_acquisitions", num_acquisitions, len(current_dataset))

            if train_results['best_dev_acc'] > best_train_results['best_dev_acc']:
                logger.info('Got best dev acc ({}).  Updating model to use for AL.'.format(train_results['best_dev_acc']))
                best_train_results = train_results
                al_sampler.model = model

                # Don't need to move the model around if this was the last (or only) trial anyway
                if train_trial_num != (args.num_train_retries-1):
                    al_sampler.model = al_sampler.model.to(torch.device('cpu'))

        selected_indexes = [al_sampler.data_source[idx][4].item() for idx in al_sampler.get_human_labeled_indices()]
        entry = {
                'num_acquisitions': num_acquisitions,
                'last_acquisition_elapsed': acquisition_time,
                'labeled_dataset_size': len(current_dataset),
                'selected_indexes': sorted(selected_indexes),
                'all_train_results': all_train_results,
                'bysize_best_eval_train_loss': best_train_results['best_dev_train_loss'],
                # Same source key as the TensorBoard scalar of the same name.
                'bysize_final_train_loss': best_train_results['final_train_loss'],
                'bysize_avg_train_loss': best_train_results['train_loss'],
                'bysize_best_eval_acc': best_train_results['best_dev_acc'],
                'bysize_test_acc': best_train_results['test_acc'],
                'bysize_train_loss': best_train_results['train_loss'],
            }
        # One JSON line per acquisition round, appended to a shared log.
        with open(os.path.join(orig_output_dir, 'selected_indexes.jsonl'), 'a') as f:
            f.write(json.dumps(entry) + '\n')

        if len(current_dataset) == al_sampler.get_max_labels():
            break

        if len(current_dataset) > al_sampler.get_max_labels():
            raise ValueError("Somehow exceeded max_labels.  This is not supposed to happen.")

        # The best model may have been parked on the CPU above; bring it back
        # before it is used to score the next batch.
        if getattr(al_sampler.model, 'device', next(iter(al_sampler.model.parameters())).device) != args.device:
            al_sampler.model = al_sampler.model.to(args.device)

        acquisition_start_time = time.perf_counter()
        al_sampler.acquire_batch(min(al_sampler.refill_increment, al_sampler.get_max_labels() - len(al_sampler.get_human_labeled_indices())))
        num_acquisitions += 1
        acquisition_time = time.perf_counter() - acquisition_start_time

    tb_writer.close()

    args.output_dir = orig_output_dir

    logger.info('Saving collected dataset')
    output_dataset_path = al_config.get('output_dataset_path', None)
    if output_dataset_path is not None:
        # Best effort: a failure to save the dataset is logged, not fatal.
        try:
            torch.save(al_sampler.get_human_labeled_dataset(), output_dataset_path)
        except Exception as e:
            logger.error("Failed to save dataset:")
            logger.exception(e)

    torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
    config.save_pretrained(args.output_dir)

    # NOTE(review): `model` here is the model from the last training trial of
    # the last round, which is not necessarily the trial with the best dev
    # accuracy when num_train_retries > 1 (that one is al_sampler.model).
    # Confirm this is the intended model to save and evaluate.

    # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        model_to_save = (
            model.module if hasattr(model, "module") else model
        )  # Take care of distributed/parallel training
        if not args.no_model_save:
            # Save a trained model, configuration and tokenizer using `save_pretrained()`.
            # They can then be reloaded using `from_pretrained()`
            logger.info("Saving model checkpoint to %s", args.output_dir)
            model_to_save.save_pretrained(args.output_dir)
            tokenizer.save_pretrained(args.output_dir)

            # Load a trained model and vocabulary that you have fine-tuned
            model = model_class.from_pretrained(args.output_dir)
            tokenizer = tokenizer_class.from_pretrained(args.output_dir)
            model.to(args.device)

    return do_final_evaluation(args, model, tokenizer)

if __name__ == '__main__':
    try:
        main()
    except Exception:
        # Only real errors are handled here: SystemExit (e.g. argparse --help
        # or usage errors) and KeyboardInterrupt propagate untouched. The
        # traceback is logged and the process exits non-zero so that callers
        # and job schedulers can detect the failure.
        logger.exception('Error occurred during execution of main():')
        sys.exit(1)
