# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning the library models for multiple choice (Bert, Roberta, XLNet)."""


import argparse
import glob
import logging
import os
import pdb
import random
import sys
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any

import numpy as np
import torch
import ujson as json
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from transformers import (
    WEIGHTS_NAME,
    AdamW,
    get_linear_schedule_with_warmup,
    get_constant_schedule_with_warmup,
)

import utils_common
from utils_common import TaskType, get_task_type, get_task_processor, load_and_cache_examples_for_task, MODEL_CLASSES
from augment_datafiles_with_cls import convert_dataset_to_frozencls
from maxva import MAdam, LaMAdam


# Prefer the TensorBoard writer that ships with PyTorch; fall back to the
# standalone tensorboardX package when that import is unavailable.
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter


# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)


@dataclass
class BatchObserverInput:
    """Per-batch snapshot that ``train`` hands to ``BatchObserver.run``.

    ``train`` builds one instance for every batch, after the backward pass
    (and gradient clipping) and before the optimizer step.
    """
    # Index of the current epoch (the value yielded by the epoch loop).
    epoch_num: int
    # Number of optimizer steps completed before this batch's update.
    global_step_num: int
    # The batch as fed to the model: a tuple of tensors already moved to
    # ``args.device``.
    batch: Any
    # Loss for this batch, after multi-GPU averaging and division by the
    # gradient-accumulation factor (when those apply).
    loss: Any
    # Logits returned by the model for this batch.
    logits: Any
    # The model being trained (may be wrapped in DataParallel / DDP).
    model: Any


class BatchObserver:
    """No-op per-batch hook; subclass and override ``run`` to inspect training."""

    def __init__(self):
        """The base observer keeps no state."""

    def run(self, batchinfo: BatchObserverInput):
        """Receive one batch snapshot; the base implementation ignores it."""
        return None


def simple_accuracy(preds, labels):
    """Return the fraction of positions where ``preds`` equals ``labels``.

    The element-wise comparison result must expose ``.mean()``; callers in
    this file pass NumPy arrays.
    """
    matches = preds == labels
    return matches.mean()


def set_seed(args):
    """Seed the Python, NumPy and PyTorch random generators from ``args.seed``.

    The CUDA generators are seeded as well whenever ``args.n_gpu`` is positive.
    """
    seed = args.seed
    logger.info("Setting global seed to {}".format(seed))
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

def make_optimizer_grouped_params(model, param_filters):
    """Bucket model parameters into optimizer parameter groups.

    Args:
        model: Object exposing ``named_parameters()`` yielding
            ``(name, parameter)`` pairs.
        param_filters: Sequence of ``(predicate, options)`` pairs. Every
            predicate is called with the parameter name; the options of all
            matching filters are merged in order, so later filters override
            earlier ones on key clashes.

    Returns:
        A list of dicts, one per distinct merged-options combination (in
        first-seen order). Each dict holds the merged options plus a
        ``'params'`` list with the parameters that share them.
    """
    # (merged options, parameters sharing them), kept in first-seen order.
    buckets = []
    for pname, tensor in model.named_parameters():
        merged = {}
        for matches, overrides in param_filters:
            if matches(pname):
                merged.update(overrides)

        # Reuse the bucket whose options are identical, if there is one.
        members = next((group for opts, group in buckets if opts == merged), None)
        if members is None:
            members = []
            buckets.append((merged, members))
        members.append(tensor)

    return [{**opts, 'params': members} for opts, members in buckets]


def run_model_batch_unfrozen(args, model, batch):
    """Run a full forward pass of ``model`` on one batch.

    Args:
        args: Namespace providing ``model_type``.
        model: Callable accepting the keyword inputs assembled below and
            returning an indexable whose first two items are loss and logits.
        batch: Indexable of inputs: ``[0]`` input ids, ``[1]`` attention mask,
            ``[2]`` token type ids, ``[3]`` labels and, for ``'t5'`` only,
            ``[5]`` the choices' input ids.

    Returns:
        ``(loss, logits)`` taken from the first two model outputs.
    """
    # Only these model types receive segment ids; all others get None.
    uses_segment_ids = args.model_type in ["bert", "xlnet"]
    inputs = {
        "input_ids": batch[0],
        "attention_mask": batch[1],
        "token_type_ids": batch[2] if uses_segment_ids else None,
        "labels": batch[3],
    }
    if args.model_type == 't5':
        inputs['choices_input_ids'] = batch[5]

    outputs = model(**inputs)
    return outputs[0], outputs[1]


def run_model_batch_frozen(args, model, batch):
    """Run only the classification head on precomputed features.

    Used when the core weights are frozen: ``batch[0]`` holds the features
    fed to ``model.classifier`` (presumably the pooled encoder outputs
    produced by ``convert_dataset_to_frozencls`` -- confirm against that
    helper) and ``batch[1]`` holds the labels.

    Args:
        args: Namespace providing ``task_name`` and ``model_type``.
        model: Model exposing ``classifier`` and, for multiple-choice tasks,
            ``dropout``.
        batch: Indexable with features at ``[0]`` and labels at ``[1]``.

    Returns:
        ``(loss, logits)``; ``loss`` is ``None`` when the labels are ``None``.
    """
    task_type = get_task_type(args.task_name)
    is_multiple_choice = task_type == TaskType.MULTIPLE_CHOICE
    features, labels = batch[0], batch[1]
    num_choices = features.shape[1]

    hidden = features.unsqueeze(1)
    if is_multiple_choice:
        # Flatten everything but the last dimension so each choice is
        # scored independently, then apply the model's dropout.
        hidden = model.dropout(hidden.view(-1, hidden.shape[-1]))

    logits = model.classifier(hidden)
    if args.model_type == 'bert':
        logits = logits.squeeze(1)
    if is_multiple_choice:
        # One row per example, one column per choice.
        logits = logits.view(-1, num_choices)

    if labels is None:
        return None, logits
    loss_fct = torch.nn.CrossEntropyLoss()
    return loss_fct(logits, labels), logits


def init_optimizer(args, model, optimizer_config, t_total):
    """Create the optimizer, the LR scheduler and, for fp16, the apex AMP handle.

    Args:
        args: Namespace providing ``weight_decay``, ``model_type``,
            ``learning_rate``, ``adam_epsilon``, ``max_grad_norm``,
            ``no_lr_decay``, ``warmup_steps``, ``fp16`` and ``fp16_opt_level``.
        model: Model whose parameters are optimized.
        optimizer_config: Dict of optimizer settings. Recognised keys:
            ``freeze_core_weights``, ``type`` (``'adamw'``, ``'maxva_madam'``
            or ``'maxva_lamadam'``), ``beta1``, ``beta2``, ``beta2_range``,
            ``beta_max`` and ``beta_min``.
        t_total: Total number of training steps for the linear schedule.

    Returns:
        ``(model, optimizer, scheduler, amp)``; ``amp`` is ``None`` unless
        ``args.fp16`` is set.

    Raises:
        ValueError: Unknown optimizer type, or freezing requested for a model
            type other than ``'bert'``/``'roberta'``.
        ImportError: ``args.fp16`` is set but apex is not installed.
    """
    freeze_core_weights = optimizer_config.get('freeze_core_weights', False)

    # Parameters whose name contains one of these get no weight decay.
    no_decay = ['bias', 'LayerNorm.weight']

    def is_decay_exempt(pname):
        return any(nd in pname for nd in no_decay)

    param_filters = [
        (is_decay_exempt, {'weight_decay': 0.0}),
        (lambda pname: not is_decay_exempt(pname), {'weight_decay': args.weight_decay}),
    ]

    if freeze_core_weights:
        if args.model_type not in ['bert', 'roberta']:
            raise ValueError("Don't know how to freeze weights for model type: {}".format(args.model_type))
        encoder_strings = ['encoder.'] if args.model_type == 'bert' else ['roberta.']
        # Tag matching parameters with lr == 0 so they can be dropped below.
        param_filters.append((lambda pname: any(tag in pname for tag in encoder_strings), {'lr': 0.0}))

    optimizer_grouped_parameters = make_optimizer_grouped_params(model, param_filters)

    if freeze_core_weights:
        trainable_groups = []
        for group in optimizer_grouped_parameters:
            if 'lr' in group and group['lr'] == 0.0:
                # Turn off grad calculation for params not being updated
                for frozen_param in group['params']:
                    frozen_param.requires_grad = False
            else:
                trainable_groups.append(group)
        optimizer_grouped_parameters = trainable_groups

    print('grouped param stats')
    for group in optimizer_grouped_parameters:
        settings = {k: v for k, v in group.items() if k != 'params'}
        print(len(group['params']), settings)

    optimizer_type = optimizer_config.get('type', 'adamw')
    if optimizer_type == 'adamw':
        beta1 = optimizer_config.get('beta1', 0.9)
        beta2 = optimizer_config.get('beta2', 0.999)
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
            eps=args.adam_epsilon,
            betas=(beta1, beta2),
        )
    elif optimizer_type == 'maxva_madam':
        optimizer = MAdam(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
            eps=args.adam_epsilon,
            beta1=optimizer_config.get('beta1', 0.9),
            beta2_range=optimizer_config.get('beta2_range', (0.5, 1)),
            adamw=False,
            max_grad_norm=args.max_grad_norm,
        )
    elif optimizer_type == 'maxva_lamadam':
        optimizer = LaMAdam(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
            eps=args.adam_epsilon,
            beta=optimizer_config.get('beta_max', 0.98),
            beta_min=optimizer_config.get('beta_min', 0.5),
        )
    else:
        raise ValueError("Unknown optimizer {}".format(optimizer_type))

    # Warmup followed by either a constant or a linearly decaying rate.
    if args.no_lr_decay:
        scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps)
    else:
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=t_total,
        )

    amp = None
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    return model, optimizer, scheduler, amp


def clone_cpu_model_state(model):
    """Return a detached CPU copy of ``model.state_dict()``.

    Every tensor is cloned, so later updates to the model do not affect the
    returned snapshot. Key order follows the state dict.
    """
    snapshot = OrderedDict()
    for key, tensor in model.state_dict().items():
        snapshot[key] = tensor.clone().cpu().detach()
    return snapshot


def train(args, train_dataset, model, tokenizer, batch_observer=None):
    """Fine-tune ``model`` on ``train_dataset`` and return summary statistics.

    Args:
        args: Run configuration namespace. Reads, among others,
            ``run_config``, ``local_rank``, ``n_gpu``, ``output_dir``,
            ``per_gpu_train_batch_size``, ``gradient_accumulation_steps``,
            ``max_steps``, ``num_train_epochs``, ``logging_steps``,
            ``save_steps``, ``fp16``, ``max_grad_norm``,
            ``evaluate_during_training``, ``evaluate_on_train_set`` and
            ``do_test``. Sets ``args.train_batch_size`` and, when
            ``max_steps > 0``, overwrites ``args.num_train_epochs``.
        train_dataset: Dataset to train on. When
            ``run_config['optimizer']['freeze_core_weights']`` is set it is
            first converted with ``convert_dataset_to_frozencls``.
        model: Model to train; updated in place.
        tokenizer: Passed to ``evaluate`` and used to save the vocabulary next
            to model checkpoints.
        batch_observer: Optional ``BatchObserver`` whose ``run`` is called once
            per batch after the backward pass. Defaults to a no-op observer.

    Returns:
        Dict with ``global_step``, ``train_loss`` (accumulated loss divided by
        the number of optimizer steps), ``best_dev_train_loss``,
        ``final_train_loss``, ``best_steps``, ``best_dev_acc`` and, if a test
        evaluation ran, ``test_acc``.

    Raises:
        ValueError: The configured train sampler type is unknown.

    Side effects:
        Writes TensorBoard events, ``best_dev_acc.txt`` and periodic
        checkpoints under ``args.output_dir`` (main process only). If an
        in-training dev evaluation recorded a best accuracy, the model weights
        are restored to that snapshot before returning; otherwise the final
        trained weights are kept.
    """
    if batch_observer is None:
        batch_observer = BatchObserver()

    optimizer_config = args.run_config.get('optimizer', dict())
    freeze_core_weights = optimizer_config.get('freeze_core_weights', False)

    run_model_batch_fn = run_model_batch_unfrozen
    if freeze_core_weights:
        run_model_batch_fn = run_model_batch_frozen
        train_dataset = convert_dataset_to_frozencls(
            model,
            train_dataset,
            get_task_type(args.task_name),
            batch_size=args.per_gpu_eval_batch_size,
            device=args.device,
            model_type=args.model_type,
        )

    # Only the main process writes TensorBoard events and checkpoints.
    is_main_process = args.local_rank in [-1, 0]
    tb_writer = None
    if is_main_process:
        tb_writer = SummaryWriter(os.path.join(args.output_dir, 'tfevents'))

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    # Resolve the sampler type into a local so the caller's run_config is not
    # modified as a side effect.
    sampler_config = args.run_config.get('train_sampler', {'type': 'default'})
    sampler_type = sampler_config['type']
    if sampler_type == 'default':
        sampler_type = 'random' if args.local_rank == -1 else 'distributed'

    if sampler_type == 'random':
        train_sampler = RandomSampler(train_dataset)
    elif sampler_type == 'distributed':
        train_sampler = DistributedSampler(train_dataset)
    else:
        raise ValueError("Unknown sampler type")
    logger.info('Using sampler type {}'.format(type(train_sampler)))
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    model, optimizer, scheduler, amp = init_optimizer(args, model, optimizer_config, t_total)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    # Stays None until a dev evaluation improves on best_dev_acc, so that a
    # run without in-training evaluation keeps its trained weights instead of
    # being reset to the pre-training state at the end.
    best_model_state = None
    best_dev_acc = 0.0
    best_dev_train_loss = None
    final_train_loss = None
    test_acc = None
    best_steps = 0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=not is_main_process)
    for epoch_num in train_iterator:
        if isinstance(train_sampler, DistributedSampler):
            # Without this every epoch reuses the same shuffling order.
            train_sampler.set_epoch(epoch_num)
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not is_main_process)
        for step, batch in enumerate(epoch_iterator):
            # evaluate() may have switched the model to eval mode.
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            loss, logits = run_model_batch_fn(args, model, batch)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            batch_observer.run(
                BatchObserverInput(
                    epoch_num=epoch_num,
                    global_step_num=global_step,
                    batch=batch,
                    loss=loss,
                    logits=logits,
                    model=model,
                )
            )

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if is_main_process and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Mean loss since the previous logging point.
                    window_loss = (tr_loss - logging_loss) / args.logging_steps
                    final_train_loss = window_loss
                    # Log metrics
                    if (
                        args.local_rank == -1 and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                        if results["eval_acc"] > best_dev_acc:
                            best_dev_acc = results["eval_acc"]
                            best_dev_train_loss = window_loss
                            best_steps = global_step
                            best_model_state = clone_cpu_model_state(model)
                            with open(os.path.join(args.output_dir, 'best_dev_acc.txt'), 'w') as f:
                                f.write(str(best_dev_acc) + '\n')
                            if args.do_test:
                                results_test = evaluate(args, model, tokenizer, datasplit='test')
                                test_acc = results_test['eval_acc']
                                for key, value in results_test.items():
                                    tb_writer.add_scalar("test_{}".format(key), value, global_step)
                                logger.info(
                                    "test acc: %s, loss: %s, global steps: %s",
                                    str(results_test["eval_acc"]),
                                    str(results_test["eval_loss"]),
                                    str(global_step),
                                )
                        if args.evaluate_on_train_set:
                            results = evaluate(args, model, tokenizer, datasplit='train')
                            for key, value in results.items():
                                tb_writer.add_scalar("eval_on_train_{}".format(key), value, global_step)
                    tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", window_loss, global_step)
                    logger.info(
                        "Average loss: %s at global step: %s",
                        str(window_loss),
                        str(global_step),
                    )
                    logging_loss = tr_loss

                if is_main_process and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_vocabulary(output_dir)
                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

            # Stop as soon as max_steps optimizer steps are done; a strict
            # ">" comparison would run one update beyond the schedule.
            if args.max_steps > 0 and global_step >= args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step >= args.max_steps:
            train_iterator.close()
            break

    if tb_writer is not None:
        tb_writer.close()

    train_results = {
        'global_step': global_step,
        # max(..., 1) avoids a ZeroDivisionError when no optimizer step ran
        # (e.g. fewer batches than gradient_accumulation_steps).
        'train_loss': tr_loss / max(global_step, 1),
        'best_dev_train_loss': best_dev_train_loss,
        'final_train_loss': final_train_loss,
        'best_steps': best_steps,
        'best_dev_acc': best_dev_acc,
    }
    if test_acc is not None:
        train_results['test_acc'] = test_acc

    # Roll back to the best dev checkpoint only if one was actually recorded.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return train_results

# Cache of evaluation datasets keyed by (task name, data split), shared by
# all calls to evaluate() within the process.
eval_datasets = dict()
def evaluate(args, model, tokenizer, prefix="", datasplit='dev'):
    """Evaluate ``model`` on one data split and write a results file.

    Args:
        args: Run configuration namespace. Reads ``task_name``,
            ``output_dir``, ``run_config``, ``per_gpu_eval_batch_size``,
            ``n_gpu``, ``local_rank``, ``device``, ``model_type`` and the
            fields echoed into the results file. Sets ``args.eval_batch_size``.
        model: Model to evaluate; it is left in eval mode.
        tokenizer: Tokenizer used when the dataset has to be loaded.
        prefix: Label used only in log messages.
        datasplit: Split to evaluate, e.g. ``'dev'``, ``'test'`` or
            ``'train'``.

    Returns:
        Dict with ``eval_acc`` and ``eval_loss`` (mean of per-batch losses).

    Raises:
        ValueError: The evaluation dataset yields no batches.

    Side effects:
        Caches the loaded dataset in the module-level ``eval_datasets`` and
        writes ``is_test_<true|false>_eval_results.txt`` into
        ``args.output_dir``.
    """
    global eval_datasets
    eval_task_names = (args.task_name,)
    eval_outputs_dirs = (args.output_dir,)

    optimizer_config = args.run_config.get('optimizer', dict())
    freeze_core_weights = optimizer_config.get('freeze_core_weights', False)
    run_model_batch_fn = run_model_batch_frozen if freeze_core_weights else run_model_batch_unfrozen
    # Position of the labels inside a batch for the two dataset layouts.
    label_index = 1 if freeze_core_weights else 3

    results = {}
    for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
        cache_key = (eval_task, datasplit)
        if cache_key not in eval_datasets:
            dataset = load_and_cache_examples(args, eval_task, tokenizer, datasplit=datasplit)
            if freeze_core_weights:
                dataset = convert_dataset_to_frozencls(
                    model,
                    dataset,
                    get_task_type(args.task_name),
                    batch_size=args.per_gpu_eval_batch_size,
                    device=args.device,
                    model_type=args.model_type,
                )
            # Cache only the final form, so a failed conversion cannot leave
            # a dataset in the wrong layout behind.
            eval_datasets[cache_key] = dataset
        eval_dataset = eval_datasets[cache_key]

        if args.local_rank in [-1, 0]:
            os.makedirs(eval_output_dir, exist_ok=True)

        args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
        # Note that DistributedSampler samples randomly
        eval_sampler = SequentialSampler(eval_dataset)
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

        # multi-gpu evaluate; skip if the caller already passed a wrapped model
        if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
            model = torch.nn.DataParallel(model)

        # Eval!
        logger.info("***** Running evaluation {} *****".format(prefix))
        logger.info("  Num examples = %d", len(eval_dataset))
        logger.info("  Batch size = %d", args.eval_batch_size)
        eval_loss = 0.0
        nb_eval_steps = 0
        # Collect per-batch arrays and concatenate once at the end instead of
        # re-copying the growing array on every batch.
        pred_chunks = []
        label_chunks = []
        model.eval()
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            batch = tuple(t.to(args.device) for t in batch)

            with torch.no_grad():
                tmp_eval_loss, logits = run_model_batch_fn(args, model, batch)
                eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1
            pred_chunks.append(logits.detach().cpu().numpy())
            label_chunks.append(batch[label_index].detach().cpu().numpy())

        if nb_eval_steps == 0:
            raise ValueError(
                "No evaluation batches for task {!r}, split {!r}".format(eval_task, datasplit)
            )

        eval_loss = eval_loss / nb_eval_steps
        preds = np.argmax(np.concatenate(pred_chunks, axis=0), axis=1)
        out_label_ids = np.concatenate(label_chunks, axis=0)
        acc = simple_accuracy(preds, out_label_ids)
        result = {"eval_acc": acc, "eval_loss": eval_loss}
        results.update(result)

        output_eval_file = os.path.join(eval_output_dir, "is_test_" + str((datasplit=='test')).lower() + "_eval_results.txt")

        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results {} *****".format(str(prefix) + " is test:" + str((datasplit=='test'))))
            writer.write("model           =%s\n" % str(args.model_name_or_path))
            writer.write(
                "total batch size=%d\n"
                % (
                    args.per_gpu_train_batch_size
                    * args.gradient_accumulation_steps
                    * (torch.distributed.get_world_size() if args.local_rank != -1 else 1)
                )
            )
            writer.write("train num epochs=%d\n" % args.num_train_epochs)
            writer.write("fp16            =%s\n" % args.fp16)
            writer.write("max seq length  =%d\n" % args.max_seq_length)
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
    return results

def load_and_cache_examples(args, task, tokenizer, datasplit='train'):
    """Load the dataset for ``task``/``datasplit`` from ``args.data_dir``.

    The file name comes from the task processor's
    ``get_standard_datasplit_filename``; the loading itself is delegated to
    ``load_and_cache_examples_for_task``.
    """
    task_processor = get_task_processor(task)
    split_filename = task_processor.get_standard_datasplit_filename(datasplit)
    split_path = os.path.join(args.data_dir, split_filename)
    return load_and_cache_examples_for_task(
        task,
        args,
        split_path,
        task_processor,
        tokenizer,
        datasplit=datasplit,
    )

def init_default_device(args):
    """Select the compute device and record it on ``args``.

    With ``args.local_rank == -1`` or ``args.no_cuda`` set, a single device is
    used: CUDA when available and not disabled, otherwise the CPU. In every
    other case the process binds to GPU ``args.local_rank`` and initialises
    the NCCL process group.

    Sets ``args.device`` and ``args.n_gpu`` (always 1) and returns the device.
    """
    if args.local_rank != -1 and not args.no_cuda:
        # Distributed run: the backend takes care of synchronizing nodes/GPUs.
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
    else:
        cuda_usable = torch.cuda.is_available() and not args.no_cuda
        device = torch.device("cuda" if cuda_usable else "cpu")

    # Pinned to one device per process in both modes, rather than
    # torch.cuda.device_count().
    args.n_gpu = 1
    args.device = device
    return device

def do_final_evaluation(args, model, tokenizer):
    """Run the final dev and/or test evaluation on the main process.

    For ``args.do_eval`` the dev split is evaluated, on every checkpoint
    found under ``args.output_dir`` when ``args.eval_all_checkpoints`` is set.
    For ``args.do_test`` the test split is evaluated on ``args.output_dir``
    only. Unless the run trained with ``no_model_save`` (and without
    ``eval_all_checkpoints``), the model is reloaded from each checkpoint
    directory before evaluation. When ``args.do_train`` is false,
    ``args.output_dir`` is overwritten with ``args.model_name_or_path``.

    Returns:
        Dict mapping ``"<metric>_<step suffix>"`` to values. Dev and test
        metrics share metric names, so with equal suffixes the test values
        replace the dev values.
    """
    _, model_class, _ = MODEL_CLASSES[get_task_type(args.task_name)][args.model_type]
    results = {}

    # (enabling flag on args, data split, whether eval_all_checkpoints applies)
    phases = (
        ("do_eval", "dev", True),
        ("do_test", "test", False),  # can not use all checkpoints to do test!!
    )
    for flag_name, datasplit, may_scan_checkpoints in phases:
        if not (getattr(args, flag_name) and args.local_rank in [-1, 0]):
            continue

        if not args.do_train:
            args.output_dir = args.model_name_or_path
        checkpoints = [args.output_dir]
        if may_scan_checkpoints and args.eval_all_checkpoints:
            weight_files = sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
            checkpoints = [os.path.dirname(path) for path in weight_files]
            logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
        logger.info("Evaluate the following checkpoints: %s", checkpoints)

        for checkpoint in checkpoints:
            step_suffix = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
            prefix = checkpoint.split("/")[-1] if "checkpoint" in checkpoint else ""

            # A run that trained without saving has only the in-memory model.
            use_in_memory_model = args.do_train and args.no_model_save and (not args.eval_all_checkpoints)
            if not use_in_memory_model:
                model = model_class.from_pretrained(checkpoint)
                model.to(args.device)

            split_metrics = evaluate(args, model, tokenizer, prefix=prefix, datasplit=datasplit)
            for metric_name, value in split_metrics.items():
                results[metric_name + "_{}".format(step_suffix)] = value
    return results

