### This code is based on transformers/examples/seq2seq/finetune.py ###
### Link: https://github.com/huggingface/transformers/tree/master/examples/seq2seq ###

import argparse
import glob
import logging
import os
import time
import sys
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pytorch_lightning as pl

import torch
from torch.utils.data import DataLoader

from transformers import get_linear_schedule_with_warmup

from huggingface.lightning_base import BaseTransformer, add_generic_args, generic_train
from huggingface.callbacks import Seq2SeqLoggingCallback, get_rouge2_checkpoint_callback
from huggingface.utils import (
    use_task_specific_params,
    SummarizationDataset,
    lmap,
    flatten_list,
    pickle_save,
    save_git_info,
    freeze_params,
    calculate_rouge,
    get_git_info,
    ROUGE_KEYS,
    pickle_load,
    get_led, 
    get_red,
    get_bart,
)

from pytorch_lightning.utilities.cloud_io import load as pl_load

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Configure the root logger so that only WARNING and above are emitted by default.
logging.basicConfig(level=logging.WARNING)

class SummarizationModule(BaseTransformer):
    """Seq2seq summarization module built on ``BaseTransformer``.

    The module builds ``SummarizationDataset`` loaders for the train/val/test
    splits, computes the training loss, and at evaluation time generates
    summaries and scores them with ROUGE. Per-epoch metrics are accumulated in
    ``self.metrics`` and pickled to ``metrics.pkl`` under the output directory.
    """

    mode = "summarization"
    loss_names = ["loss"]

    def __init__(self, hparams, model=None, tokenizer=None, pad_to_length=False, **kwargs):
        """Set up the module from parsed hyper-parameters.

        Args:
            hparams: Namespace-like object carrying the options registered by
                ``add_model_specific_args`` (data_dir, max_*_length, n_train, ...).
            model: Optional pre-built model forwarded to ``BaseTransformer``.
            tokenizer: Optional pre-built tokenizer forwarded to ``BaseTransformer``.
            pad_to_length: Forwarded to ``SummarizationDataset``; also switches the
                pad id used in ``_generative_step`` to ``-100``.
            **kwargs: Extra keyword arguments forwarded to ``BaseTransformer``.

        Raises:
            AssertionError: If the training target length exceeds the validation
                or test target length.
        """
        super().__init__(
            hparams,
            num_labels=None,
            mode=self.mode,
            model=model,
            tokenizer=tokenizer,
            **kwargs
        )
        use_task_specific_params(self.model, "summarization")
        save_git_info(self.hparams.output_dir)
        self.metrics_save_path = Path(self.output_dir) / "metrics.pkl"
        self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
        self.step_count = 0
        self.metrics = {"train": [], "val": [], "test": []}

        # Keyword arguments shared by the datasets of all three splits.
        self.dataset_kwargs: dict = dict(
            data_dir=self.hparams.data_dir,
            max_source_length=self.hparams.max_source_length,
            prefix=self.model.config.prefix or "",
        )
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
        # A negative count (the CLI uses -1) means "use every example" -> None.
        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}

        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test": self.hparams.test_max_target_length,
        }
        assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"

        if self.hparams.freeze_embeds:
            self.freeze_embeds()
        if self.hparams.freeze_encoder:
            freeze_params(self.model.model.encoder)  # TODO: this will break for t5
        self.hparams.git_sha = get_git_info()["repo_sha"]
        self.num_workers = hparams.num_workers
        self.pad_to_length = pad_to_length
        # The generation cache is switched off whenever gradient checkpointing is on.
        self.use_cache = not hparams.grad_checkpointing

    def freeze_embeds(self):
        """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
        if self.model.config.model_type == "bart":
            freeze_params(self.model.model.shared)
            for d in [self.model.model.encoder, self.model.model.decoder]:
                freeze_params(d.embed_positions)
                freeze_params(d.embed_tokens)
        else:
            freeze_params(self.model.shared)
            for d in [self.model.encoder, self.model.decoder]:
                freeze_params(d.embed_tokens)

    def forward(self, input_ids, **kwargs):
        """Delegate the forward pass to the wrapped model."""
        return self.model(input_ids, **kwargs)

    def ids_to_clean_text(self, generated_ids):
        """Decode a batch of token-id sequences and strip surrounding whitespace.

        Args:
            generated_ids: Batch of token-id sequences accepted by
                ``tokenizer.batch_decode``.

        Returns:
            The decoded strings (special tokens skipped), each passed through
            ``str.strip``.
        """
        gen_text = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return lmap(str.strip, gen_text)

    def _step(self, batch: dict) -> Tuple:
        """Run a teacher-forced forward pass and return ``(loss,)``.

        The target sequence is shifted by one position: all but the last token
        are fed to the decoder, and all but the first are used as labels.
        """
        pad_token_id = self.tokenizer.pad_token_id
        source_ids, source_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
        y_ids = y[:, :-1].contiguous()
        lm_labels = y[:, 1:].clone()
        # Padding positions are set to -100 in the labels.
        lm_labels[y[:, 1:] == pad_token_id] = -100
        outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=y_ids, labels=lm_labels,)
        loss = outputs[0]
        return (loss,)

    def training_step(self, batch, batch_idx) -> Dict:
        """Compute the training loss for one batch and expose it for logging."""
        loss_tensors = self._step(batch)
        logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        return {"loss": loss_tensors[0], "log": logs}

    def validation_step(self, batch, batch_idx) -> Dict:
        """Generate and score summaries for one validation batch."""
        return self._generative_step(batch)

    def validation_epoch_end(self, outputs, prefix="val") -> Dict:
        """Aggregate per-batch outputs into epoch-level metrics and persist them.

        Args:
            outputs: List of dicts as returned by ``_generative_step``.
            prefix: Split name used to prefix metric keys ("val" or "test").

        Returns:
            Dict with the metrics under ``"log"``, the flattened predictions and
            targets, and the averaged loss / ROUGE-2 under ``{prefix}_loss`` and
            ``{prefix}_rouge``.
        """
        self.step_count += 1
        losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
        loss = losses["loss"]
        rouges = {k: np.array([x[k] for x in outputs]).mean() for k in ROUGE_KEYS + ["gen_time", "summ_len"]}
        rouge_tensor: torch.FloatTensor = torch.tensor(rouges["rouge2"]).type_as(loss)
        # Store the losses as plain Python numbers alongside the ROUGE scores.
        rouges.update({k: v.item() for k, v in losses.items()})
        losses.update(rouges)
        metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
        metrics["step_count"] = self.step_count
        self.save_metrics(metrics, prefix)  # writes to self.metrics_save_path
        preds = flatten_list([x["preds"] for x in outputs])
        target = flatten_list([x["target"] for x in outputs])
        return {
            "log": metrics,
            "preds": preds,
            f"{prefix}_loss": loss,
            f"{prefix}_rouge": rouge_tensor,
            "target": target,
        }

    def save_metrics(self, metrics, prefix) -> None:
        """Append ``metrics`` to the history of split ``prefix`` and re-pickle the history."""
        self.metrics[prefix].append(metrics)
        pickle_save(self.metrics, self.metrics_save_path)

    def _generative_step(self, batch: dict) -> dict:
        """Generate summaries for a batch and compute loss, ROUGE and timing.

        Returns:
            Dict with the loss, ``gen_time`` (generation seconds per example),
            ``summ_len`` (mean generated length), the decoded ``preds`` and
            ``target`` strings, and the ROUGE scores from ``calculate_rouge``.
        """
        # NOTE(review): with pad_to_length the pad id handed to the trimming helper
        # is -100 -- presumably so that nothing gets trimmed; confirm against
        # SummarizationDataset.trim_seq2seq_batch.
        pad_token_id = self.tokenizer.pad_token_id if not self.pad_to_length else -100
        source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
        t0 = time.time()
        generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=self.use_cache,)
        # Parenthesised so that the *elapsed* time is divided by the batch size;
        # without the parentheses only t0 was divided, yielding a meaningless value.
        gen_time = (time.time() - t0) / source_ids.shape[0]
        preds = self.ids_to_clean_text(generated_ids)
        target = self.ids_to_clean_text(y)
        loss_tensors = self._step(batch)
        base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        rouge: Dict = calculate_rouge(preds, target)
        summ_len = np.mean(lmap(len, generated_ids))
        base_metrics.update(gen_time=gen_time, summ_len=summ_len, preds=preds, target=target, **rouge)
        return base_metrics

    def test_step(self, batch, batch_idx):
        """Generate and score summaries for one test batch."""
        return self._generative_step(batch)

    def test_epoch_end(self, outputs):
        """Aggregate test outputs; same as validation but with the "test" prefix."""
        return self.validation_epoch_end(outputs, prefix="test")

    def get_dataset(self, type_path) -> SummarizationDataset:
        """Build the dataset for split ``type_path`` ("train", "val" or "test")."""
        n_obs = self.n_obs[type_path]
        max_target_length = self.target_lens[type_path]
        dataset = SummarizationDataset(
            self.tokenizer,
            type_path=type_path,
            n_obs=n_obs,
            max_target_length=max_target_length,
            pad_to_length=self.pad_to_length,
            **self.dataset_kwargs,
        )
        return dataset

    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        """Build the ``DataLoader`` for split ``type_path``.

        When ``--sortish_sampler`` is set, the training split uses the dataset's
        sortish sampler instead of shuffling (``DataLoader`` does not accept both
        a sampler and ``shuffle=True``).

        Raises:
            AssertionError: If the sortish sampler is requested with more than one GPU.
        """
        dataset = self.get_dataset(type_path)
        sampler = None
        if self.hparams.sortish_sampler and type_path == "train":
            assert self.hparams.gpus <= 1  # TODO: assert earlier
            sampler = dataset.make_sortish_sampler(batch_size)
            shuffle = False

        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=dataset.collate_fn,
            shuffle=shuffle,
            num_workers=self.num_workers,
            sampler=sampler,
        )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        """Return the training loader and, as a side effect, set ``self.lr_scheduler``.

        The total number of optimisation steps is derived from the dataset size,
        the effective batch size (per-device batch size times the number of GPUs,
        at least 1), the gradient accumulation steps and the number of epochs.
        """
        dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
        t_total = (
            (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
            // self.hparams.gradient_accumulation_steps
            * float(self.hparams.num_train_epochs)
        )
        # NOTE(review): self.opt is not assigned in this class -- presumably set by
        # BaseTransformer before this hook runs; confirm.
        scheduler = get_linear_schedule_with_warmup(
            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
        )
        self.lr_scheduler = scheduler
        return dataloader

    def val_dataloader(self) -> DataLoader:
        """Return the validation loader (evaluation batch size, no shuffling)."""
        return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)

    def test_dataloader(self) -> DataLoader:
        """Return the test loader (evaluation batch size, no shuffling)."""
        return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        """Register the base, generic and summarization-specific CLI options.

        Args:
            parser: ``argparse.ArgumentParser`` to extend in place.
            root_dir: Forwarded to ``BaseTransformer.add_model_specific_args`` and
                ``add_generic_args``.

        Returns:
            The same ``parser`` with all options added.
        """
        BaseTransformer.add_model_specific_args(parser, root_dir)
        add_generic_args(parser, root_dir)
        parser.add_argument(
            "--max_source_length",
            default=1024,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--max_target_length",
            default=56,
            type=int,
            help="The maximum total target sequence length after tokenization for training. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--val_max_target_length",
            default=142,  # these defaults are optimized for CNNDM. For xsum, see README.md.
            type=int,
            help="The maximum total target sequence length after tokenization for validation. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--test_max_target_length",
            default=142,
            type=int,
            help="The maximum total target sequence length after tokenization for testing. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--data_dir",
            type=str,
            required=True,
            help="The input data dir. Should contain train.source, train.target, val.source, val.target, test.source, test.target",
        )
        parser.add_argument(
            "--attn_window",
            default=512,
            type=int,
            required=False,
            help="The attention window size (for longbart only).",
        )
        parser.add_argument(
            "--longbart_base_model",
            default='facebook/bart-large-cnn',
            type=str,
            required=False,
            help="The base model (for longbart only).",
        )
        parser.add_argument(
            "--model_variant",
            type=str,
            required=False,
            choices=['longbart', 'led', 'reformer_encoder_decoder', 'bart'],
            help="Which model variant to use.",
        )
        parser.add_argument(
            "--grad_checkpointing",
            action="store_true",
            default=False,
            help="Whether to use gradient checkpointing (for longbart only).",
        )
        parser.add_argument(
            "--led",
            action="store_true",
            default=False,
            help="Whether to use longformer encoder decoder (rather than longbart).",
        )
        parser.add_argument(
            "--reformerencoderdecoder_attn_type",
            type=str,
            default='mix',
            choices=['mix', 'lsh', 'local'],
            required=False,
            help="What type of reformer self attn to use. One of ['mix', 'lsh', 'local'].",
        )
        parser.add_argument(
            "--custom_tokenizer",
            action="store_true",
            default=False,
            help="Use the custom tokenizer?",
        )
        parser.add_argument(
            "--label_smoothing",
            type=float,
            default=None,
            help="What smoothing to use with label smoothing loss?",
        )
        parser.add_argument(
            "--memory_alloc",
            action="store_true",
            default=False,
            help="Test memory allocation?",
        )
        parser.add_argument(
            "--random_weights",
            action="store_true",
            default=False,
            help="Randomises the weights of the longbart encoder",
        )
        parser.add_argument("--freeze_encoder", action="store_true")
        parser.add_argument("--freeze_embeds", action="store_true")
        parser.add_argument("--sortish_sampler", action="store_true", default=False)
        parser.add_argument("--logger", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
        parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        return parser


def main(args, model=None) -> SummarizationModule:
    """Train and/or evaluate a summarization model according to ``args``.

    Args:
        args: Parsed command-line namespace (see
            ``SummarizationModule.add_model_specific_args``).
        model: Optional pre-built module; a ``SummarizationModule`` is created
            from ``args`` when omitted.

    Returns:
        The (possibly trained and tested) module.

    Raises:
        ValueError: If training is requested into an output directory that already
            holds more than three entries, or if ``args.logger`` is not recognised.
        FileNotFoundError: If a checkpoint has to be loaded for prediction but no
            ``*.ckpt`` file exists in the output directory.
    """
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if model is None:
        model = SummarizationModule(args)

    # Named ``pl_logger`` so that it does not shadow the module-level ``logger``.
    if (
        args.logger == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith(("/tmp", "/var"))
    ):
        pl_logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        pl_logger = WandbLogger(name=model.output_dir.name)
    elif args.logger == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        # TODO: separate LB for CNN, we should use Path(args.data_dir).name to determine the correct LB.
        pl_logger = WandbLogger(name=model.output_dir.name, project="hf_summarization")
    else:
        # Previously this fell through and failed later with an unbound local.
        raise ValueError("Unknown logger: {!r}".format(args.logger))

    if args.do_train:
        pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_rouge2_checkpoint_callback(args.output_dir),
        logger=pl_logger,
        # TODO: early stopping callback seems messed up
    )
    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
    # The last path in sorted order is the checkpoint that gets evaluated.
    checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
    if checkpoints and model.hparams.do_predict:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)

    # load model ckpt
    if hasattr(model.hparams, 'model_variant') and model.hparams.model_variant in ['longbart', 'led', 'reformer_encoder_decoder']:
        if not checkpoints:
            # Fail with a clear message instead of an IndexError on checkpoints[-1].
            raise FileNotFoundError(
                "No '*.ckpt' checkpoint found in output directory ({}).".format(args.output_dir)
            )
        ckpt = pl_load(checkpoints[-1], map_location=lambda storage, loc: storage)
        model.load_state_dict(ckpt['state_dict'])

    # Freeze weights
    model.eval()
    freeze_params(model.model.base_model.encoder)
    freeze_params(model.model.base_model.decoder)
    model.freeze_embeds()

    trainer.test(model)  # this breaks in DDP, known lightning issue. See evaluate_checkpoint to recover metrics.
    return model


def predict_args(args):
    """Rebuild the saved training arguments for a prediction-only run.

    Loads ``hparams.pkl`` from ``args.output_dir``, works out ``model_variant``
    from the stored options, and overrides the fields that must come from the
    current invocation (model path, data dir, output dir) or that are fixed for
    prediction (no training, eval batch size 1, default logger).

    Args:
        args: Namespace parsed from the current command line.

    Returns:
        An ``argparse.Namespace`` based on the stored hyper-parameters.
    """
    stored = pickle_load(os.path.join(args.output_dir, 'hparams.pkl'))
    orig_args = argparse.Namespace(**stored)

    # Work out which model variant the stored run used.
    named_variants = ['longbart', 'reformer_encoder_decoder']
    if getattr(orig_args, 'led', False):
        orig_args.model_variant = 'led'
    elif getattr(orig_args, 'longbart_base_model', None) in named_variants:
        orig_args.model_variant = orig_args.longbart_base_model
    elif orig_args.model_name_or_path in named_variants:
        orig_args.model_variant = orig_args.model_name_or_path
    elif not hasattr(orig_args, 'model_variant'):
        orig_args.model_variant = None

    # Values taken from the current invocation or fixed for prediction.
    overrides = {
        'model_name_or_path': args.model_name_or_path,
        'do_train': False,
        'eval_batch_size': 1,
        'logger': "default",
        'data_dir': args.data_dir,
        'output_dir': args.output_dir,
    }
    for field, value in overrides.items():
        setattr(orig_args, field, value)

    # Options missing from the stored hyper-parameters default to None.
    namespace_dict = vars(orig_args)
    for field in ('grad_checkpointing', 'custom_tokenizer', 'longbart_base_model'):
        namespace_dict.setdefault(field, None)

    return orig_args


if __name__ == "__main__":
    # Script entry point: parse the CLI, optionally restore the saved arguments
    # for a prediction run, build the requested model variant and hand over to main().
    parser = argparse.ArgumentParser()
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
    args = parser.parse_args()
    # A model path containing 'best_tfmr' triggers restoring the stored hparams.
    if 'best_tfmr' in args.model_name_or_path:
        args = predict_args(args)
    print('NAMESPACE:\n', args)
    # parents=True so that a nested output path does not fail on a missing parent.
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    if args.model_variant in ['longbart', 'led']:
        model, tokenizer = get_led(args)
    elif args.model_variant == 'reformer_encoder_decoder':
        model, tokenizer = get_red(args)
    else:
        # No variant given, or a variant without a dedicated loader here ('bart'):
        # use the default construction path in main(). Previously 'bart' crashed
        # with a NameError because neither branch above assigned model/tokenizer.
        # NOTE(review): get_bart is imported but unused -- if it follows the same
        # (args) -> (model, tokenizer) shape it may be the intended loader; confirm.
        model, tokenizer = None, None

    if model is None:
        main(args)
    else:
        model = SummarizationModule(args, model, tokenizer, pad_to_length=True)
        main(args, model=model)
