import time
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
import logging
from transformers import get_linear_schedule_with_warmup

from seq2seq_pl.src.modeling.lightning_base import BaseTransformer, add_generic_args

from seq2seq_pl.src.utils import (
    lmap,
    flatten_list,
    save_json,
    calculate_rouge,
    ROUGE_KEYS,
    calculate_bleu_score,
    _accuracy
)

from seq2seq_pl.src.dataset import AutoregLMDataset

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# NOTE: configures logging at INFO level as a side effect of importing this module.
logging.basicConfig(level=logging.INFO)


class NextTokenLMModule(BaseTransformer):
    mode = "next-token-prediction"
    loss_names = ["loss"]
    metric_names = ROUGE_KEYS + ["bleu"]

    def __init__(self, hparams, **kwargs):
        """Initialise the module, its metric bookkeeping and dataset settings.

        Args:
            hparams: Hyper-parameter namespace. This constructor reads
                ``num_workers`` from it directly and, through
                ``self.hparams``, ``data_dir``, ``max_source_length`` and
                ``max_target_length``.
            **kwargs: Passed through to the parent constructor.
        """
        super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)

        # Metrics are accumulated per split and written to this JSON file.
        self.metrics_save_path = Path(self.output_dir) / "metrics.json"
        self.step_count = 0
        self.metrics = defaultdict(list)

        # Keyword arguments shared by every dataset this module builds.
        self.dataset_kwargs: dict = {
            "data_dir": self.hparams.data_dir,
            "max_source_length": self.hparams.max_source_length,
            "max_target_length": self.hparams.max_target_length,
            "prefix": self.model.config.prefix or "",
        }

        self.num_workers = hparams.num_workers
        self.dataset_class = AutoregLMDataset

    def get_dataset(self, type_path, tgt_key, src_key) -> AutoregLMDataset:
        """Build the dataset for one data split.

        Args:
            type_path: Split identifier forwarded to the dataset class
                (this module uses ``"train"``, ``"val"`` and ``"test"``).
            tgt_key: Name of the target field, forwarded to the dataset.
            src_key: Name of the source field, forwarded to the dataset.

        Returns:
            An instance of ``self.dataset_class`` built with the module's
            tokenizer and the shared ``self.dataset_kwargs``.
        """
        return self.dataset_class(
            self.tokenizer,
            type_path=type_path,
            tgt_key=tgt_key,
            src_key=src_key,
            **self.dataset_kwargs,
        )

    def get_dataloader(
        self,
        type_path: str,
        batch_size: int,
        tgt_key: str,
        src_key: str,
        shuffle: bool = False,
    ) -> DataLoader:
        """Wrap the dataset of one split in a ``DataLoader``.

        The training split is collated with the dataset's ``collate_fn``;
        every other split uses ``collate_fn_eval``.

        Args:
            type_path: Split identifier; ``"train"`` selects the training collator.
            batch_size: Number of examples per batch.
            tgt_key: Name of the target field, forwarded to ``get_dataset``.
            src_key: Name of the source field, forwarded to ``get_dataset``.
            shuffle: Whether the loader should shuffle the data.

        Returns:
            The configured ``DataLoader``.
        """
        split_dataset = self.get_dataset(type_path, src_key=src_key, tgt_key=tgt_key)

        if type_path == "train":
            collate = split_dataset.collate_fn
        else:
            collate = split_dataset.collate_fn_eval

        return DataLoader(
            split_dataset,
            batch_size=batch_size,
            collate_fn=collate,
            shuffle=shuffle,
            num_workers=self.num_workers,
            sampler=None,
        )

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled training loader and set up the LR scheduler.

        As a side effect, a linear warmup/decay scheduler is created for
        ``self.opt`` and stored on ``self.lr_scheduler``. The number of warmup
        steps is ``int(warmup_prop * max_steps)``.

        Returns:
            The ``DataLoader`` over the ``"train"`` split.
        """
        dataloader = self.get_dataloader(
            "train",
            batch_size=self.hparams.train_batch_size,
            src_key=self.hparams.src_key,
            tgt_key=self.hparams.tgt_key,
            shuffle=True,
        )

        scheduler = get_linear_schedule_with_warmup(
            self.opt,
            num_warmup_steps=int(self.hparams.warmup_prop * self.hparams.max_steps),
            num_training_steps=self.hparams.max_steps,
        )
        # Warn only when no learning rate is positive. The previous check
        # (`> 0`) was inverted and emitted this message exactly when the
        # learning rates were non-zero.
        # NOTE(review): with a non-zero warmup the schedule presumably starts
        # at 0, so this may fire at the very first step -- confirm whether
        # that is acceptable.
        if max(scheduler.get_last_lr()) <= 0:
            warnings.warn("All learning rates are 0")
        self.lr_scheduler = scheduler
        return dataloader

    def val_dataloader(self) -> DataLoader:
        """Return the loader over the ``"val"`` split (default, unshuffled order)."""
        hp = self.hparams
        return self.get_dataloader(
            "val",
            batch_size=hp.eval_batch_size,
            src_key=hp.src_key,
            tgt_key=hp.tgt_key,
        )

    def test_dataloader(self) -> DataLoader:
        """Return the loader over the ``"test"`` split (default, unshuffled order)."""
        hp = self.hparams
        return self.get_dataloader(
            "test",
            batch_size=hp.eval_batch_size,
            src_key=hp.src_key,
            tgt_key=hp.tgt_key,
        )

    def training_step(self, batch, batch_idx: int) -> Dict:
        """Run one training step and log the loss as ``train_loss``.

        Args:
            batch: Batch dictionary accepted by ``_step``.
            batch_idx: Index of the batch (unused here).

        Returns:
            A dict with the training loss under ``"loss"`` and a mapping of
            loss names to loss tensors under ``"log"``.
        """
        loss_tensors = self._step(batch)
        train_loss = loss_tensors[0]
        named_losses = dict(zip(self.loss_names, loss_tensors))

        self.log(
            "train_loss",
            train_loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )
        return {"loss": train_loss, "log": named_losses}

    def _step(self, batch: dict) -> Tuple:
        """Run a forward pass on a batch and return its loss.

        Args:
            batch: Dict providing ``"input_ids"``, ``"attention_mask"`` and
                ``"labels"``.

        Returns:
            A one-element tuple holding the first item of the model output,
            which this module treats as the loss.
        """
        model_output = self(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
        )
        return (model_output[0],)

    def forward(self, input_ids, attention_mask, labels):
        """Call the wrapped model with the given inputs and return its output as is."""
        model_inputs = dict(
            input_ids=input_ids, labels=labels, attention_mask=attention_mask
        )
        return self.model(**model_inputs)

    def test_epoch_end(self, outputs):
        """Aggregate test-step outputs by reusing ``validation_epoch_end`` with the ``"test"`` prefix."""
        split_prefix = "test"
        return self.validation_epoch_end(outputs, prefix=split_prefix)

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch with the generative step (``batch_idx`` is unused)."""
        step_result = self._generative_step(batch)
        return step_result

    @torch.no_grad()
    def _generative_step(self, batch: dict) -> dict:
        """Generate text for a batch and score it against the targets.

        A separator token column (``tokenizer.io_sep_token_id``) is appended
        to the input ids before generation. Decoded predictions that still
        contain ``tokenizer.io_sep_token`` are cut down to the text after its
        last occurrence.

        Args:
            batch: Dict providing ``"input_ids"``, ``"source_text"`` and
                ``"target_text"``, plus whatever ``_step`` reads.

        Returns:
            A dict with one entry per loss name, the ROUGE and BLEU metrics,
            and the keys ``gen_time`` (seconds per example, covering
            generation and decoding), ``summ_len`` (mean length of the
            generated id rows), ``preds``, ``target`` and ``source``.
        """

        def calc_generative_metrics(preds, target) -> Dict:
            """Return the ROUGE metrics updated with the BLEU score."""
            scores = calculate_rouge(preds, target)
            scores.update(calculate_bleu_score(preds, target))
            return scores

        def _ids_to_clean_text(generated_ids):
            """Decode generated ids, skipping special tokens, and strip whitespace."""
            gen_text = self.tokenizer.batch_decode(
                generated_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
            return lmap(str.strip, gen_text)

        # hacks to do batch generation with gpt2
        # https://discuss.huggingface.co/t/batch-generation-with-gpt2/1517/2
        is_gpt2 = "gpt2" in self.hparams.model_name_or_path
        if is_gpt2:
            padding_side_bak = self.tokenizer.padding_side
            pad_token_bak = self.tokenizer.pad_token
            pad_token_id_bak = self.model.config.pad_token_id

        try:
            if is_gpt2:
                self.tokenizer.padding_side = "left"
                self.tokenizer.pad_token = self.tokenizer.eos_token
                self.model.config.pad_token_id = self.model.config.eos_token_id

            input_ids = batch["input_ids"]
            source_text, target_text = batch["source_text"], batch["target_text"]

            # Append one separator token to every row of the prompt.
            sep_token = torch.tensor(
                [[self.tokenizer.io_sep_token_id]] * len(input_ids)
            ).to(self.device)
            input_ids = torch.cat((input_ids, sep_token), dim=1)

            t0 = time.time()

            # NOTE(review): no attention_mask is passed to generate(); confirm
            # that padded prompts are handled as intended.
            generated_ids = self.model.generate(
                input_ids=input_ids,
                pad_token_id=self.model.config.eos_token_id,
                max_length=self.hparams.max_target_length
                + self.hparams.max_source_length
                + 20,
            )
            preds = _ids_to_clean_text(generated_ids)

            gen_time = (time.time() - t0) / input_ids.shape[0]
        finally:
            # Always undo the temporary overrides, even if generation fails,
            # so the tokenizer/model config are not left in the patched state.
            if is_gpt2:
                self.tokenizer.padding_side = padding_side_bak
                self.tokenizer.pad_token = pad_token_bak
                self.model.config.pad_token_id = pad_token_id_bak

        # Keep only the text after the last separator, when one is present.
        io_sep = self.tokenizer.io_sep_token
        preds = [
            pred.split(io_sep)[-1].strip() if io_sep in pred else pred
            for pred in preds
        ]

        loss_tensors = self._step(batch)
        base_metrics = dict(zip(self.loss_names, loss_tensors))
        rouge: Dict = calc_generative_metrics(preds, target_text)
        summ_len = np.mean(lmap(len, generated_ids))
        base_metrics.update(
            gen_time=gen_time,
            summ_len=summ_len,
            preds=preds,
            target=target_text,
            source=source_text,
            **rouge,
        )
        return base_metrics

    def validation_step(self, batch, batch_idx) -> Dict:
        """Evaluate one validation batch with the generative step (``batch_idx`` is unused)."""
        step_result = self._generative_step(batch)
        return step_result

    def validation_epoch_end(self, outputs, prefix="val") -> Dict:
        """Aggregate per-batch evaluation outputs, then log and persist them.

        Side effects: increments ``self.step_count``, appends the metrics via
        ``save_metrics``, logs every metric with ``self.log`` and writes the
        sources, targets and predictions to
        ``<output_dir>/<prefix>_generations/<global_step>.jsonl``.

        Args:
            outputs: List of dicts as returned by ``_generative_step``.
            prefix: Split name used to prefix metric names and the
                generations directory.

        Returns:
            A dict with all metrics under ``"log"`` and the mean loss under
            ``f"{prefix}_loss"``.
        """
        self.step_count += 1

        losses = {
            k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names
        }
        loss = losses["loss"]
        generative = {
            k: np.array([x[k] for x in outputs]).mean()
            for k in self.metric_names + ["gen_time", "summ_len"]
        }
        preds = flatten_list([x["preds"] for x in outputs])
        targets = flatten_list([x["target"] for x in outputs])
        sources = flatten_list([x["source"] for x in outputs])

        # Losses (as Python scalars) come first; generative metrics follow and
        # never override a loss that shares their name.
        scalars = {name: value.item() for name, value in losses.items()}
        for name, value in generative.items():
            scalars.setdefault(name, value)

        metrics = {f"{prefix}_{k}": v for k, v in scalars.items()}
        # NOTE: these two keys carry no split prefix, so "val" and "test"
        # results are stored under the same names.
        metrics["avg_rouge1"] = scalars["rouge1"]
        metrics["avg_rouge2"] = scalars["rouge2"]
        metrics["step_count"] = self.step_count
        metrics[f"{prefix}_accuracy"] = _accuracy(preds, targets)

        self.save_metrics(metrics, prefix)  # writes to self.metrics_save_path
        for metric_name, metric_val in metrics.items():
            self.log(
                metric_name, metric_val, on_epoch=True, prog_bar=True, sync_dist=True
            )

        od = Path(self.hparams.output_dir)
        generations_file = od / f"{prefix}_generations/{self.global_step:05d}.jsonl"
        # parents=True so a missing output directory does not abort the dump.
        generations_file.parent.mkdir(parents=True, exist_ok=True)
        # save predictions and targets for analysis
        pd.DataFrame({"source": sources, "target": targets, "pred": preds}).to_json(
            generations_file, orient="records", lines=True
        )

        return {
            "log": metrics,
            f"{prefix}_loss": loss.mean(),
        }

    def save_metrics(self, latest_metrics, type_path) -> None:
        """Append the latest metrics for a split and rewrite the metrics file.

        Args:
            latest_metrics: Metrics dict to record.
            type_path: Key (split name) under which the metrics are stored.
        """
        split_history = self.metrics[type_path]
        split_history.append(latest_metrics)
        save_json(self.metrics, self.metrics_save_path)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        """Register this module's command-line arguments on ``parser``.

        The base-transformer and generic arguments are added first, followed
        by the options specific to this module.

        Args:
            parser: Argument parser to extend in place.
            root_dir: Forwarded to the base and generic argument helpers.

        Returns:
            The same ``parser`` object.
        """
        BaseTransformer.add_model_specific_args(parser, root_dir)
        add_generic_args(parser, root_dir)

        add = parser.add_argument
        truncation_note = (
            "than this will be truncated, sequences shorter will be padded."
        )
        eval_length_help = (
            "The maximum total input sequence length after tokenization. Sequences longer "
            + truncation_note
        )
        n_examples_help = "# examples. -1 means use all."

        add("--add_lr_scheduler", action="store_true", default=False)

        # Sequence lengths and the jsonl keys to read from.
        add(
            "--max_source_length",
            default=48,
            type=int,
            help="The maximum total (training) input sequence question length after tokenization. Sequences longer "
            + truncation_note,
        )
        add("--src_key", default="question", help="source key in jsonl")
        add("--tgt_key", default="answer", help="target key in jsonl")
        add(
            "--max_target_length",
            default=24,
            type=int,
            help="The maximum total (training) input sequence length after tokenization. Sequences longer "
            + truncation_note,
        )
        for split in ("val", "test"):
            add(
                f"--{split}_max_target_length",
                default=24,
                type=int,
                help=eval_length_help,
            )

        add(
            "--data_dir",
            type=str,
            required=True,
            help="The input data dir. Should contain train.source, train.target, val.source, val.target, test.source, test.target",
        )
        add(
            "--freeze_encoder",
            action="store_true",
            help="freeze encoder layer in an encoder decoder architecture such as BART",
        )
        add("--freeze_embeds", action="store_true")
        add(
            "--logger_name",
            type=str,
            choices=["default", "wandb", "wandb_shared"],
            default="default",
        )

        # Per-split example caps.
        for split, default_count in (("train", -1), ("val", 500), ("test", -1)):
            add(
                f"--n_{split}",
                type=int,
                default=default_count,
                required=False,
                help=n_examples_help,
            )

        add("--val_metric", type=str, default="rouge2", required=False)
        add(
            "--special_tokens",
            help="Path to the file that stores the special tokens. One special token per line",
            type=str,
            default=None,
            required=False,
        )
        add("--save_top_k", type=int, help="Number of checkpoints to save", default=3)
        return parser
