import time
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
import logging
from transformers import T5ForConditionalGeneration
from transformers.models.bart.modeling_bart import shift_tokens_right
from modeling.lightning_base import BaseTransformer, add_generic_args
from transformers import get_linear_schedule_with_warmup

from seq2seq_pl.src.utils import (
    lmap,
    flatten_list,
    save_json,
    calculate_rouge,
    ROUGE_KEYS,
    calculate_bleu_score,
    freeze_params,
    assert_all_frozen
)

from seq2seq_pl.src.dataset import Seq2SeqDataset

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): this configures logging at INFO level as an import-time side
# effect of importing this module — confirm that is intended for library use.
logging.basicConfig(level=logging.INFO)



class Seq2SeqLMModule(BaseTransformer):
    """Sequence-to-sequence fine-tuning module built on ``BaseTransformer``.

    Training minimises a token-level cross-entropy loss. Validation and test
    steps additionally generate text and score it with ROUGE and BLEU.
    """

    # Mode string handed to ``BaseTransformer.__init__``.
    mode = "seq2seq"
    # Names paired positionally with the tensors returned by ``_step``.
    loss_names = ["loss"]
    # Metric keys averaged over batches at the end of an evaluation epoch.
    metric_names = ROUGE_KEYS + ["bleu"]

    def __init__(self, hparams, **kwargs):
        """Initialise the module on top of ``BaseTransformer``.

        Args:
            hparams: Hyper-parameter namespace. The constructor reads
                ``data_dir``, ``max_source_length``, ``max_target_length``,
                ``val_max_target_length``, ``test_max_target_length``,
                ``n_train``, ``n_val``, ``n_test``, ``freeze_embeds``,
                ``freeze_encoder`` and, when present, ``num_workers``.
            **kwargs: Forwarded unchanged to ``BaseTransformer.__init__``.

        Raises:
            AssertionError: If the training target length is larger than the
                validation or the test target length.
        """
        super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)

        # Metric bookkeeping; ``metrics`` is dumped to ``metrics.json``.
        self.metrics_save_path = Path(self.output_dir) / "metrics.json"
        self.step_count = 0
        self.metrics = defaultdict(list)
        self.vocab_size = self.config.vocab_size

        # Keyword arguments shared by the dataset of every split.
        self.dataset_kwargs: dict = {
            "data_dir": self.hparams.data_dir,
            "max_source_length": self.hparams.max_source_length,
            "max_target_length": self.hparams.max_target_length,
            "prefix": self.model.config.prefix or "",
        }

        # Negative counts are stored as None (the CLI help documents -1 as
        # "use all").
        requested_counts = (
            ("train", self.hparams.n_train),
            ("val", self.hparams.n_val),
            ("test", self.hparams.n_test),
        )
        self.n_obs = {
            split: (count if count >= 0 else None) for split, count in requested_counts
        }

        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test": self.hparams.test_max_target_length,
        }
        # The evaluation splits must allow targets at least as long as training.
        for eval_split in ("val", "test"):
            assert (
                self.target_lens["train"] <= self.target_lens[eval_split]
            ), f"target_lens: {self.target_lens}"

        if self.hparams.freeze_embeds:
            self.freeze_embeds()
        if self.hparams.freeze_encoder:
            freeze_params(self.model.get_encoder())
            assert_all_frozen(self.model.get_encoder())

        # Fall back to two loader workers when hparams does not define any.
        self.num_workers = getattr(hparams, "num_workers", 2)

        self.decoder_start_token_id = None
        self.dataset_class = Seq2SeqDataset

    def freeze_embeds(self):
        """Freeze the embedding parameters of the wrapped model.

        Models exposing ``model.model`` (the BART layout) get the shared,
        positional and token embeddings frozen. Any ``AttributeError`` on that
        path falls back to the flat layout (T5), where only the shared and
        token embeddings are frozen.
        """
        try:
            inner = self.model.model
            freeze_params(inner.shared)
            stacks = (inner.encoder, inner.decoder)
            for stack in stacks:
                freeze_params(stack.embed_positions)
                freeze_params(stack.embed_tokens)
        except AttributeError:
            freeze_params(self.model.shared)
            stacks = (self.model.encoder, self.model.decoder)
            for stack in stacks:
                freeze_params(stack.embed_tokens)

    def training_step(self, batch, batch_idx: int) -> Dict:
        """Compute the training loss for one batch and log it as ``train_loss``.

        Args:
            batch: Batch dictionary consumed by ``_step``.
            batch_idx: Index of the batch (unused).

        Returns:
            A dict with the loss under ``"loss"`` and every named loss under
            ``"log"``.
        """
        step_losses = self._step(batch)
        named_losses = dict(zip(self.loss_names, step_losses))
        main_loss = step_losses[0]
        self.log(
            "train_loss",
            main_loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )
        return {"loss": main_loss, "log": named_losses}

    def _step(self, batch: dict) -> Tuple:
        """Run a teacher-forced forward pass and return the cross-entropy loss.

        Args:
            batch: Dict with ``"input_ids"``, ``"attention_mask"`` and
                ``"decoder_input_ids"``; the latter is used as the labels.

        Returns:
            A one-element tuple holding the loss tensor.

        Raises:
            AssertionError: If the last logits dimension differs from
                ``self.vocab_size``.
        """
        pad_id = self.tokenizer.pad_token_id
        encoder_ids = batch["input_ids"]
        encoder_mask = batch["attention_mask"]
        labels = batch["decoder_input_ids"]

        # The decoder is fed the labels shifted one position to the right;
        # T5 models bring their own shifting helper.
        if isinstance(self.model, T5ForConditionalGeneration):
            shifted_labels = self.model._shift_right(labels)
        else:
            shifted_labels = shift_tokens_right(labels, pad_id)

        model_out = self(
            encoder_ids,
            attention_mask=encoder_mask,
            decoder_input_ids=shifted_labels,
            use_cache=False,
        )
        logits = model_out[0]
        # Padding positions in the labels do not contribute to the loss.
        criterion = torch.nn.CrossEntropyLoss(ignore_index=pad_id)

        n_vocab = logits.shape[-1]
        assert n_vocab == self.vocab_size, f"{logits.shape} != {self.vocab_size}"

        flat_logits = logits.view(-1, n_vocab)
        flat_labels = labels.view(-1)
        return (criterion(flat_logits, flat_labels),)

    def forward(self, input_ids, **kwargs):
        """Call the wrapped model with ``input_ids`` and any extra keyword arguments."""
        model_outputs = self.model(input_ids, **kwargs)
        return model_outputs

    def validation_step(self, batch, batch_idx) -> Dict:
        """Evaluate one validation batch via ``_generative_step``; ``batch_idx`` is unused."""
        step_result = self._generative_step(batch)
        return step_result

    def _generative_step(self, batch: dict) -> dict:
        """Generate predictions for a batch and score them against the targets.

        Generation uses beam search with four beams and is capped at
        ``hparams.max_target_length`` tokens.

        Args:
            batch: Batch dictionary as produced by the dataset's collate function.

        Returns:
            A dict with the named losses, ``gen_time`` (elapsed generation
            time divided by the first dimension of the source ids),
            ``summ_len`` (mean length of the generated sequences), the decoded
            ``preds``, ``target`` and ``source`` texts, and the ROUGE/BLEU
            scores.
        """
        pad_id = self.tokenizer.pad_token_id
        src_ids, src_mask, tgt_ids = Seq2SeqDataset.trim_seq2seq_batch(batch, pad_id)

        started_at = time.time()
        generated = self.model.generate(
            input_ids=src_ids,
            attention_mask=src_mask,
            use_cache=True,
            num_beams=4,
            max_length=self.hparams.max_target_length,
            decoder_start_token_id=self.decoder_start_token_id,
        )
        time_per_example = (time.time() - started_at) / src_ids.shape[0]

        source_texts = self.ids_to_clean_text(src_ids)
        pred_texts = self.ids_to_clean_text(generated)
        target_texts = self.ids_to_clean_text(tgt_ids)

        # The loss is computed on the batch as received, not on the trimmed tensors.
        step_losses = self._step(batch)
        result = dict(zip(self.loss_names, step_losses))

        scores: Dict = self.calc_generative_metrics(pred_texts, target_texts)
        mean_generated_len = np.mean(lmap(len, generated))
        result.update(
            gen_time=time_per_example,
            summ_len=mean_generated_len,
            preds=pred_texts,
            target=target_texts,
            source=source_texts,
            **scores,
        )
        return result

    def ids_to_clean_text(self, generated_ids: List[int]):
        """Decode a batch of token ids into whitespace-stripped strings.

        Special tokens are skipped and tokenization spaces are cleaned up
        while decoding.
        """
        decode_options = {
            "skip_special_tokens": True,
            "clean_up_tokenization_spaces": True,
        }
        decoded = self.tokenizer.batch_decode(generated_ids, **decode_options)
        return lmap(str.strip, decoded)

    def calc_generative_metrics(self, preds, target) -> Dict:
        """Return the ROUGE scores of ``preds`` vs ``target``, extended with the BLEU score."""
        scores = calculate_rouge(preds, target)
        bleu_scores = calculate_bleu_score(preds, target)
        scores.update(bleu_scores)
        return scores

    def validation_epoch_end(self, outputs, prefix="val") -> Dict:
        """Aggregate per-batch evaluation outputs, then log and persist them.

        Averages the losses and metrics over ``outputs``, appends them to the
        metrics file, logs every value and writes the sources, targets and
        predictions to
        ``<output_dir>/<prefix>_generations/<global_step>.jsonl``.

        Args:
            outputs: List of dicts returned by ``_generative_step``.
            prefix: Split name used to prefix metric keys and the output folder.

        Returns:
            A dict with all metrics under ``"log"``, the mean loss tensor
            under ``"<prefix>_loss"`` and the monitored metric tensor under
            ``"<prefix>_<val_metric>"``.
        """
        self.step_count += 1

        mean_losses = {
            name: torch.stack([out[name] for out in outputs]).mean()
            for name in self.loss_names
        }
        loss = mean_losses["loss"]

        averaged_keys = self.metric_names + ["gen_time", "summ_len"]
        averaged = {
            name: np.array([out[name] for out in outputs]).mean() for name in averaged_keys
        }
        monitored = self.hparams.val_metric
        rouge_tensor: torch.FloatTensor = torch.tensor(averaged[monitored]).type_as(loss)

        # Plain-number view of everything: the losses come first, followed by
        # the averaged metrics; a loss keeps its slot on a name clash.
        scalars = {name: value.item() for name, value in mean_losses.items()}
        for name, value in averaged.items():
            scalars.setdefault(name, value)

        metrics = {f"{prefix}_{name}": value for name, value in scalars.items()}
        metrics["avg_rouge1"] = scalars["rouge1"]
        metrics["avg_rouge2"] = scalars["rouge2"]
        metrics["step_count"] = self.step_count

        self.save_metrics(metrics, prefix)  # writes to self.metrics_save_path
        preds = flatten_list([out["preds"] for out in outputs])
        targets = flatten_list([out["target"] for out in outputs])
        sources = flatten_list([out["source"] for out in outputs])
        for key, value in metrics.items():
            self.log(key, value, on_epoch=True, prog_bar=True, sync_dist=True)

        # Keep the generations next to their references for later analysis.
        output_root = Path(self.hparams.output_dir)
        generations_file = output_root / f"{prefix}_generations/{self.global_step:05d}.jsonl"
        generations_file.parent.mkdir(exist_ok=True)
        generations = pd.DataFrame({"source": sources, "target": targets, "pred": preds})
        generations.to_json(generations_file, orient="records", lines=True)

        monitored_key = f"{prefix}_{monitored}"
        result = {
            "log": metrics,
            f"{prefix}_loss": loss,
            monitored_key: rouge_tensor,
        }
        self.metrics[monitored_key] = rouge_tensor.item()
        return result

    def save_metrics(self, latest_metrics, type_path) -> None:
        """Append ``latest_metrics`` to the ``type_path`` history and rewrite the metrics file."""
        history = self.metrics[type_path]
        history.append(latest_metrics)
        save_json(self.metrics, self.metrics_save_path)

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch via ``_generative_step``; ``batch_idx`` is unused."""
        step_result = self._generative_step(batch)
        return step_result

    def test_epoch_end(self, outputs):
        """Aggregate test outputs by reusing ``validation_epoch_end`` with the ``"test"`` prefix."""
        aggregated = self.validation_epoch_end(outputs, prefix="test")
        return aggregated

    def get_dataset(self, type_path, tgt_key, src_key) -> Seq2SeqDataset:
        """Build the dataset for one split.

        Args:
            type_path: Split name; must be a key of ``self.n_obs`` and
                ``self.target_lens`` (``"train"``, ``"val"`` or ``"test"``).
            tgt_key: Name of the target field, forwarded to the dataset class.
            src_key: Name of the source field, forwarded to the dataset class.

        Returns:
            An instance of ``self.dataset_class`` for the requested split.

        Raises:
            KeyError: If ``type_path`` is not a known split.
        """
        # The per-split target length used to be looked up and then discarded,
        # so every split silently used the training length stored in
        # ``dataset_kwargs``. Override the shared value with the split's own
        # limit so --val_max_target_length / --test_max_target_length apply.
        split_kwargs = {
            **self.dataset_kwargs,
            "max_target_length": self.target_lens[type_path],
        }
        return self.dataset_class(
            self.tokenizer,
            type_path=type_path,
            n_obs=self.n_obs[type_path],
            tgt_key=tgt_key,
            src_key=src_key,
            **split_kwargs,
        )

    def get_dataloader(self, type_path: str, batch_size: int, tgt_key: str, src_key: str, shuffle: bool = False) -> DataLoader:
        """Create the ``DataLoader`` for one split.

        Args:
            type_path: Split name passed on to ``get_dataset``.
            batch_size: Number of examples per batch.
            tgt_key: Name of the target field, passed on to ``get_dataset``.
            src_key: Name of the source field, passed on to ``get_dataset``.
            shuffle: Whether the loader shuffles the data.

        Returns:
            A loader that collates with the dataset's own ``collate_fn``.
        """
        split_dataset = self.get_dataset(type_path, src_key=src_key, tgt_key=tgt_key)
        # No custom sampler is used here, so hparams.sortish_sampler has no
        # effect in this method.
        loader_options = {
            "batch_size": batch_size,
            "collate_fn": split_dataset.collate_fn,
            "shuffle": shuffle,
            "num_workers": self.num_workers,
            "sampler": None,
        }
        return DataLoader(split_dataset, **loader_options)

    def train_dataloader(self) -> DataLoader:
        """Build the shuffled training loader and the linear warmup LR scheduler.

        The total number of optimizer steps is derived from the dataset size,
        the batch size, the number of GPUs, the gradient accumulation factor
        and the number of epochs. As a side effect the scheduler is stored in
        ``self.lr_scheduler``.

        Returns:
            The training ``DataLoader``.
        """
        dataloader = self.get_dataloader(
            "train",
            batch_size=self.hparams.train_batch_size,
            src_key=self.hparams.src_key,
            tgt_key=self.hparams.tgt_key,
            shuffle=True,
        )
        examples_per_step = self.hparams.train_batch_size * max(1, self.hparams.gpus)
        steps_per_epoch = (
            len(dataloader.dataset) // examples_per_step
        ) // self.hparams.accumulate_grad_batches
        t_total = steps_per_epoch * float(self.hparams.max_epochs)
        num_warmup_steps = int(self.hparams.warmup_prop * t_total)
        scheduler = get_linear_schedule_with_warmup(
            self.opt,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=t_total,
        )
        # The original check was inverted: it warned when the learning rate
        # was positive and stayed silent when it was zero. With warmup the
        # schedule legitimately starts at 0, so only a zero rate without
        # warmup (e.g. zero total steps or a zero base rate) is reported.
        if num_warmup_steps == 0 and max(scheduler.get_last_lr()) <= 0:
            warnings.warn("All learning rates are 0")
        self.lr_scheduler = scheduler
        return dataloader

    def val_dataloader(self) -> DataLoader:
        """Return the validation loader, batched with ``hparams.eval_batch_size``."""
        hp = self.hparams
        return self.get_dataloader(
            "val",
            batch_size=hp.eval_batch_size,
            src_key=hp.src_key,
            tgt_key=hp.tgt_key,
        )

    def test_dataloader(self) -> DataLoader:
        """Return the test loader, batched with ``hparams.eval_batch_size``."""
        hp = self.hparams
        return self.get_dataloader(
            "test",
            batch_size=hp.eval_batch_size,
            src_key=hp.src_key,
            tgt_key=hp.tgt_key,
        )

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        """Register the command-line arguments of this module on ``parser``.

        The arguments of ``BaseTransformer`` and the generic arguments are
        registered first, followed by the seq2seq-specific ones.

        Args:
            parser: Argument parser to extend in place.
            root_dir: Passed through to the base and generic registrars.

        Returns:
            The same ``parser`` object.
        """
        BaseTransformer.add_model_specific_args(parser, root_dir)
        add_generic_args(parser, root_dir)
        parser.add_argument("--add_lr_scheduler", action="store_true")

        # Sequence lengths and the jsonl field names.
        parser.add_argument(
            "--max_source_length",
            type=int,
            required=True,
            help="The maximum total (training) input sequence question length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument("--src_key", default="question", help="source key in jsonl")
        parser.add_argument("--tgt_key", default="answer", help="target key in jsonl")
        parser.add_argument(
            "--max_target_length",
            type=int,
            required=True,
            help="The maximum total (training) input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        eval_length_help = (
            "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        )
        for eval_length_flag in ("--val_max_target_length", "--test_max_target_length"):
            parser.add_argument(eval_length_flag, type=int, default=24, help=eval_length_help)

        # Data location, freezing and logging options.
        parser.add_argument(
            "--data_dir",
            type=str,
            required=True,
            help="The input data dir. Should contain train.source, train.target, val.source, val.target, test.source, test.target",
        )
        parser.add_argument(
            "--freeze_encoder",
            action="store_true",
            help="freeze encoder layer in an encoder decoder architecture such as BART",
        )
        parser.add_argument("--freeze_embeds", action="store_true")
        parser.add_argument("--sortish_sampler", action="store_true")
        parser.add_argument(
            "--logger_name",
            type=str,
            default="default",
            choices=["default", "wandb", "wandb_shared"],
        )

        # Per-split example limits.
        split_limits = (("--n_train", -1), ("--n_val", 500), ("--n_test", -1))
        for limit_flag, limit_default in split_limits:
            parser.add_argument(
                limit_flag, type=int, default=limit_default, help="# examples. -1 means use all."
            )

        parser.add_argument("--val_metric", type=str, default="rouge2")
        parser.add_argument(
            "--special_tokens",
            type=str,
            default=None,
            help="Path to the file that stores the special tokens. One special token per line",
        )
        parser.add_argument("--save_top_k", type=int, default=3, help="Number of checkpoints to save")
        return parser
