#! /usr/bin/env python3
# coding=utf-8


import torch
from torch.utils.data import DataLoader, RandomSampler

from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    AdamW,
    get_linear_schedule_with_warmup,
    set_seed,
)

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter
from datasets import load_dataset

import random
import ast
import multiprocessing
import ptan
import numpy as np

from configs.il_config import ILConfig
from rl_lib.rl import decode_chain_sampling, decode_chain_argmax, run_eval
from rl_lib.utils import *
from rl_lib.reward import compute_new_reward

logger = logging.getLogger(__name__)

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter

# Reward threshold used by the training loop: unless config.disable_skip is
# set, a sample whose argmax-decoded reward is below this value (or is NaN)
# is skipped and contributes nothing to the batch loss.
SAFE_LOSS = 0.01

# set_global_logging_level(logging.ERROR, ["transformers", "nlp", "torch", "tensorflow", "tensorboard", "wandb"])

# Run configuration read by every stage of the script below.
# (The module-level `logger` is already bound above; it is not re-created here.)
config = ILConfig()

if __name__ == '__main__':

    ### Setup Part ###
    # Logging: processes with local_rank -1 or 0 log at INFO, all others at WARN.
    is_main_process = config.local_rank in [-1, 0]
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if is_main_process else logging.WARN,
    )

    # TensorBoard writer; the configured log name is appended to the run comment.
    writer = SummaryWriter(comment="-" + config.log_name)

    logger.info("Training/evaluation parameters %s", config)

    # Resuming: continue from the last entry returned by sorted_ckpts, and
    # refuse to continue when the output directory holds no checkpoint.
    if config.should_continue:
        sorted_checkpoints = sorted_ckpts(config.output_dir)
        if len(sorted_checkpoints) == 0:
            raise ValueError(
                "Used --should_continue but no checkpoint was found in --output_dir.")
        config.model_name_or_path = sorted_checkpoints[-1]

    # A non-empty output directory may only be reused when the run either
    # overwrites it or continues from it.
    if os.path.exists(config.output_dir) and os.listdir(config.output_dir):
        if not (config.overwrite_output_dir or config.should_continue):
            raise ValueError(
                "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                    config.output_dir
                )
            )

    # Device selection: this script always runs on the CPU.
    device = torch.device("cpu")
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, 16-bits training: %s",
        config.local_rank,
        device,
        config.n_gpu,
        config.fp16,
    )

    # Seed the random number generators for reproducibility.
    set_seed(config.seed)

    ### Model ###
    # The model config and weights share one cache directory under the output dir.
    model_cache_dir = os.path.join(config.output_dir, config.cache_dir)
    lm_config = AutoConfig.from_pretrained(config.config_name,
                                           cache_dir=model_cache_dir)

    # The tokenizer is pinned to 'facebook/bart-base'; the config-driven variant
    # is kept below for reference.
    # tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name,
    #                                           cache_dir=os.path.join(config.output_dir, config.cache_dir))
    tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')

    model = AutoModelForSeq2SeqLM.from_pretrained(
        config.model_name_or_path,
        from_tf=False,
        config=lm_config,
        cache_dir=model_cache_dir,
    )
    model.to(device)
    # model.resize_token_embeddings(len(tokenizer) + 1)

    ### Dataset ###
    # All CSV splits are cached in the same directory under the output dir.
    dataset_cache_dir = os.path.join(config.output_dir, config.cache_dir)

    # Train and eval splits are loaded together from their CSV files.
    train_eval_files = {'train': config.train_data_path,
                        'eval': config.eval_data_path}
    dataset = load_dataset('csv',
                           data_files=train_eval_files,
                           cache_dir=dataset_cache_dir)
    train_dataset = dataset['train']
    eval_dataset = dataset['eval']

    # The test split comes from its own CSV file.
    test_files = {'test': config.test_data_path}
    test_dataset = load_dataset('csv',
                                data_files=test_files,
                                cache_dir=dataset_cache_dir)['test']

    def preprocess_function(examples):
        """Tokenize one training row into model inputs and reference labels.

        ``examples['sent1']`` is the source sentence. ``examples['sent2']`` is a
        string holding a Python literal that is parsed with ``ast.literal_eval``
        and iterated as the reference sentences (presumably a list of strings --
        confirm against the training CSV).

        Every sentence is encoded with the enclosing scope's ``tokenizer`` as a
        'pt' tensor, padded/truncated to 256 tokens.

        Returns a dict with:
          - ``input_ids``: encoding of the source sentence
          - ``input_tokens``: the raw source sentence
          - ``labels_ids``: list with one encoding per reference sentence
          - ``labels_tokens``: the parsed reference sentences
        """
        encode_kwargs = {
            'return_tensors': 'pt',
            'max_length': 256,
            'padding': 'max_length',
            'truncation': True,
        }

        input_ids = tokenizer.encode(examples['sent1'], **encode_kwargs)
        # input_ids = [[(l if l != tokenizer.pad_token_id else -100) for l in input_ids]]

        # Parse the serialized references once and reuse the result for both
        # the token ids and the raw strings.
        references = ast.literal_eval(examples['sent2'])
        labels = [tokenizer.encode(sent, **encode_kwargs) for sent in references]

        model_inputs = {
            'input_ids': input_ids,
            'input_tokens': examples['sent1'],
            'labels_ids': labels,
            'labels_tokens': references,
        }
        # model_inputs['spice_score'] = ast.literal_eval(examples['spice_score'])
        return model_inputs

    def test_process(examples):
        """Tokenize one test row (a single source and a single reference).

        Both ``examples['sent1']`` and ``examples['sent2']`` are encoded with the
        enclosing scope's ``tokenizer`` as 'pt' tensors, padded/truncated to 256
        tokens. The returned dict carries the encodings under ``input_ids`` /
        ``labels_ids`` and the raw strings under ``input_tokens`` /
        ``labels_tokens``.
        """
        encode_options = {
            'return_tensors': 'pt',
            'max_length': 256,
            'padding': 'max_length',
            'truncation': True,
        }
        source_text = examples['sent1']
        target_text = examples['sent2']
        return {
            'input_ids': tokenizer.encode(source_text, **encode_options),
            'input_tokens': source_text,
            'labels_ids': tokenizer.encode(target_text, **encode_options),
            'labels_tokens': target_text,
        }

    # Optional subsampling for quick runs. The requested sizes are clamped to
    # the actual dataset lengths, because selecting indices past the end of a
    # dataset is an error.
    if config.max_train_samples is not None:
        train_limit = min(config.max_train_samples, len(train_dataset))
        train_dataset = train_dataset.select(range(train_limit))
        # use top 50 (or fewer, if the test split is smaller) to eval
        test_limit = min(50, len(test_dataset))
        test_dataset = test_dataset.select(range(test_limit))

    # Tokenize both splits in parallel, one worker per CPU core; cached results
    # are reused unless config.overwrite_cache is set.
    train_dataset = train_dataset.map(
        preprocess_function,
        num_proc=multiprocessing.cpu_count(),
        load_from_cache_file=not config.overwrite_cache,
    )

    test_dataset = test_dataset.map(
        test_process,
        num_proc=multiprocessing.cpu_count(),
        load_from_cache_file=not config.overwrite_cache,
    )

    def rl_collator(batch):
        """Collate training examples into a dict of per-field lists.

        Each of the four fields is gathered across ``batch`` in order; values
        are passed through unchanged (no stacking or padding), and any other
        keys on the examples are ignored.
        """
        fields = ('input_ids', 'labels_ids', 'input_tokens', 'labels_tokens')
        return {field: [example[field] for example in batch] for field in fields}

    def test_collator(batch):
        return {
            'input_ids': [x['input_ids'] for x in batch],
            'labels_ids': [x['labels_ids'] for x in batch],
            'input_tokens': [x['input_tokens'] for x in batch],
            'labels_tokens': [x['labels_tokens'] for x in batch],
        }

    # Both loaders sample examples in random order and drop the last batch when
    # it is smaller than the configured batch size.
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config.per_gpu_train_batch_size,
        sampler=train_sampler,
        collate_fn=rl_collator,
        drop_last=True,
    )

    test_sampler = RandomSampler(test_dataset)
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=config.per_gpu_eval_batch_size,
        sampler=test_sampler,
        collate_fn=test_collator,
        drop_last=True,
    )

    ### Optimizer and Schedule (linear warmup and decay) ###
    # Optimizer updates per epoch, as implied by the accumulation setting.
    updates_per_epoch = len(train_dataloader) // config.gradient_accumulation_steps

    # total steps: an explicit max_steps wins and the epoch count is derived
    # from it; otherwise the total follows from the configured epoch count.
    if config.max_steps > 0:
        t_total = config.max_steps
        config.num_train_epochs = config.max_steps // updates_per_epoch + 1
    else:
        t_total = updates_per_epoch * config.num_train_epochs

    # optimizer and scheduler: parameters whose name contains one of the
    # `no_decay` markers are exempt from weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    decay_params, no_decay_params = [], []
    for param_name, param in model.named_parameters():
        if any(nd in param_name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    optimizer_grouped_parameters = [
        {"params": decay_params, "weight_decay": config.weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
    optimizer = AdamW(
        optimizer_grouped_parameters,
        lr=config.learning_rate,
        eps=config.adam_epsilon,
    )
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=t_total,
    )

    # Restore optimizer and scheduler state when the model path contains both
    # saved state files.
    if config.model_name_or_path:
        optimizer_state_path = os.path.join(config.model_name_or_path, "optimizer.pt")
        scheduler_state_path = os.path.join(config.model_name_or_path, "scheduler.pt")
        if os.path.isfile(optimizer_state_path) and os.path.isfile(scheduler_state_path):
            optimizer.load_state_dict(torch.load(optimizer_state_path))
            scheduler.load_state_dict(torch.load(scheduler_state_path))

    # fp16 initialization: apex is imported only when mixed precision is requested.
    if config.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=config.fp16_opt_level)

    ### Logging before Training ###
    logger.info("***** Running training *****")
    # Report the number of training examples (dataset length), not the number
    # of batches the dataloader yields.
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", config.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                config.per_gpu_train_batch_size)
    # Effective batch size = per-GPU batch x accumulation steps x world size;
    # the world size is only queried when local_rank is not -1.
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        config.per_gpu_train_batch_size
        * config.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if config.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                config.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    # Training progress counters; they stay at zero for a fresh run.
    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0

    # When the model path exists on disk, try to read the step count from the
    # numeric suffix after the last "-" in the path and derive how many epochs
    # and in-epoch steps that corresponds to.
    if config.model_name_or_path and os.path.exists(config.model_name_or_path):
        try:
            checkpoint_suffix = config.model_name_or_path.split("-")[-1].split("/")[0]
            global_step = int(checkpoint_suffix)
            updates_in_epoch = len(train_dataloader) // config.gradient_accumulation_steps
            epochs_trained, steps_trained_in_current_epoch = divmod(
                global_step, updates_in_epoch)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info(
                "  Continuing training from global step %d", global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            # The path has no numeric suffix, so it is not a resumable checkpoint.
            logger.info("  Starting fine-tuning.")

    ### Training ###
    # Policy-gradient fine-tuning. For every input the model is decoded once
    # with decode_chain_argmax and config.rl_sample_count times with
    # decode_chain_sampling; the advantage of a sampled sequence is its reward
    # minus the reward of the argmax sequence for the same input.
    with ptan.common.utils.TBMeanTracker(writer, 5) as tb_tracker:
        batch_idx = 0  # batches processed in this run; used as the TensorBoard step
        best_reward = None

        # Start at the epoch recovered from the checkpoint name so a resumed
        # run does not repeat epochs that were already completed.
        for epoch in range(epochs_trained, int(config.num_train_epochs)):

            # make sure these vars are cleared each epoch
            total_samples = 0
            skipped_samples = 0
            rewards_argmax = []
            rewards_sample = []

            for batch in train_dataloader:

                # batch is a dict of per-example lists (see rl_collator):
                #
                # input_ids: (batch, token ids)
                # input_tokens: (batch, string)
                # labels_ids: (batch, list of token ids)
                # labels_tokens: (batch, list of strings)

                # skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                batch_idx += 1

                logger.info("Epoch: %d; Batch: %d", epoch + 1, batch_idx)

                optimizer.zero_grad()

                # Flattened over every sampled sequence of this batch:
                rl_policies = []    # one policy tensor per sampled sequence
                rl_actions = []     # chosen action for every decoded token
                rl_advantages = []  # advantage repeated for every decoded token

                for idx, labels_tokens in enumerate(batch['labels_tokens']):
                    total_samples += 1
                    curr_input_ids = torch.tensor(
                        batch['input_ids'][idx], device=device)

                    # Baseline: reward of the argmax-decoded sequence.
                    argmax_token_probs, argmax_actions, argmax_tokens = decode_chain_argmax(
                        model, tokenizer, curr_input_ids, config)
                    argmax_reward = compute_new_reward(
                        argmax_tokens, labels_tokens, device)
                    rewards_argmax.append(argmax_reward)

                    # Skip samples whose baseline reward is below SAFE_LOSS or
                    # NaN (x != x is only true for NaN).
                    if not config.disable_skip:
                        if argmax_reward < SAFE_LOSS or argmax_reward != argmax_reward:
                            skipped_samples += 1
                            print("skipped")
                            continue

                    # Periodically log an input, up to three of its references
                    # and the argmax output.
                    if total_samples % 500 == 0:
                        logger.info("Input: %s", batch['input_tokens'][idx])
                        ref = " ~~|~~ ".join(random.sample(batch['labels_tokens'][idx],
                                                           k=min(3, len(batch['labels_tokens'][idx]))))
                        logger.info("Refer: %s", ref)
                        logger.info("Argmax: %s, reward=%.4f",
                                    argmax_tokens, argmax_reward)

                    for _ in range(config.rl_sample_count):  # DEFAULT 4
                        sample_token_probs, sample_actions, sample_tokens = decode_chain_sampling(
                            model, tokenizer, curr_input_ids, config)
                        sample_reward = compute_new_reward(
                            sample_tokens, labels_tokens, device)

                        rl_policies.append(sample_token_probs)
                        rl_actions.extend(sample_actions)
                        adv = sample_reward - argmax_reward

                        # adv is broadcast to all tokens in the sequence
                        rl_advantages.extend([adv] * len(sample_actions))
                        rewards_sample.append(sample_reward)

                # Every sample of the batch was skipped: nothing to optimise.
                if not rl_policies:
                    continue

                # NOTE(review): if decode_chain_sampling returns tensors that are
                # detached from the model graph, requires_grad_ lets backward()
                # run without ever reaching the model parameters -- confirm.
                policies_v = torch.cat(rl_policies).requires_grad_(True)
                actions_t = torch.LongTensor(rl_actions).to(device)
                # Advantages are constants in the loss; no gradient is needed.
                adv_v = torch.FloatTensor(rl_advantages).to(device)
                # log_prob_v = F.softmax(policies_v, dim=1)  # TODO: double check this. already log?
                log_prob_v = policies_v

                # numpy like index: per decoded token, the policy entry of the
                # action that was actually taken
                lp_a = log_prob_v[range(len(rl_actions)), actions_t]

                # Scale every token by its own advantage. The advantages are
                # reshaped to lp_a's rank first: multiplying a 1-D lp_a by
                # adv_v[:, None] would broadcast to an N x N outer product.
                adv_b = adv_v.reshape(-1, *([1] * (lp_a.dim() - 1)))
                log_prob_actions_v = adv_b * lp_a
                # NOTE(review): the loss is not negated. If policies_v holds
                # log-probabilities and a higher reward is better, minimising it
                # lowers the likelihood of above-baseline samples -- confirm the
                # sign convention of decode_chain_* / compute_new_reward.
                loss_v = log_prob_actions_v.mean()

                # NaN loss: drop this batch without touching the weights.
                if loss_v != loss_v:
                    continue

                loss_v.backward()
                optimizer.step()
                # Advance the learning-rate schedule and the global step
                # together with every optimizer update.
                # NOTE(review): t_total is derived with
                # config.gradient_accumulation_steps, but this loop updates on
                # every batch; with a value > 1 the schedule ends early.
                scheduler.step()
                global_step += 1

                tb_tracker.track("advantage", adv_v, batch_idx)
                tb_tracker.track("loss_total", loss_v, batch_idx)

                logger.info("Loss=%.6f", loss_v.item())

                ### run eval ###
                if batch_idx % 200 == 0:
                    run_eval(test_dataloader, model, tokenizer,
                             batch_idx, device, config)

                    # writer.add_scalar("BERT Score during evaluation", mean_reward_eval, batch_idx)
                    writer.add_scalar("reward_argmax", np.mean(
                        rewards_argmax), batch_idx)
                    writer.add_scalar("reward_sample", np.mean(
                        rewards_sample), batch_idx)
                    writer.add_scalar(
                        "skipped_samples", skipped_samples / total_samples, batch_idx)
                    # Log the (1-based) epoch number against the batch index.
                    writer.add_scalar("epoch", epoch + 1, batch_idx)

                    # logger.info("Batch %d, average reward of EVAL: %.3f", batch, mean_reward_eval)

            # if best_reward is None or best_reward < mean_reward_eval:
            #     best_reward = mean_reward_eval
            #     logger.info("Best reward updated: %.4f", best_reward)

            # save model checkpoint every epoch; the directory name carries the
            # global step so successive epochs do not overwrite each other and
            # the step can be recovered from the name when resuming
            output_dir = os.path.join(
                config.output_dir, "{}-{}".format("rl_checkpoint", global_step))
            os.makedirs(output_dir, exist_ok=True)
            # Unwrap a wrapper that exposes the real model as `.module`.
            model_to_save = model.module if hasattr(model, "module") else model
            model_to_save.save_pretrained(output_dir)
            tokenizer.save_pretrained(output_dir)
            torch.save(config, os.path.join(output_dir, "training_args.bin"))
            logger.info("Saving model checkpoint to %s", output_dir)

            # delete unnecessary checkpoints
            rotate_ckpts(config.output_dir,
                         config.save_total_limit, "rl_checkpoint")

            torch.save(optimizer.state_dict(), os.path.join(
                output_dir, "optimizer.pt"))
            torch.save(scheduler.state_dict(), os.path.join(
                output_dir, "scheduler.pt"))
            logger.info(
                "Saving optimizer and scheduler states to %s", output_dir)

    writer.close()
