from collections import Counter
import json, re, string
from typing import Callable, Optional
from functools import partial
from pathlib import Path
import pandas as pd
import torch
from torch import Tensor
import torch.nn.functional as F
from torchmetrics.functional import (
    accuracy,
    f1_score,
    matthews_corrcoef,
    pearson_corrcoef,
    spearman_corrcoef,
)

from _abstract_task.training import TrainingModule
from _modeling.configuration_tsp import TSPConfig
from _modeling.modeling_tsp import (
    TSPModelForSequenceClassification,
    TSPModelForSpanComparison,
    TSPModelForReCoRD,
    TSPModelForTextPair,
)


class GLUEModel(TrainingModule):
    """Base training module for GLUE / SuperGLUE classification tasks.

    Subclasses configure a task through the class attributes below and override
    ``forward``, ``_calculate_metrics``, ``_step_outputs_to_epoch_result`` or
    ``test_epoch_end`` when the task needs special handling.

    References:
        microsoft/DeBERTa (https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/apps/models/record_qa.py)
        nyu-mll/jiant (https://github.com/nyu-mll/jiant/blob/de5437ae710c738a0481b13dc9d266dd558c43a4/jiant/proj/main/modeling/taskmodels.py)
    """

    MODEL_CLASS = TSPModelForSequenceClassification
    TASK: str  # task name, forwarded to load_pretrained_model in from_pretrained
    LABEL_NAMES: list[str]  # label names; len() is the classifier's number of classes
    METRICS: dict[str, Callable]  # metric name -> fn(preds, labels)
    TEST_OUTPUT_FILE_NAME: str  # e.g. xxx.tsv (GLUE) or xxx.jsonl (SuperGLUE)
    TEST_MAP_LABEL_TO_STRING: bool  # whether to map int predictions to LABEL_NAMES strings

    def __init__(self, config, model=None):
        """
        Args:
            config: run configuration. When ``model`` is None it is converted with
                ``TSPConfig.from_config`` to build the model.
            model: an already-constructed model. If None, one is created with
                ``MODEL_CLASS.from_pretrained`` and ``len(LABEL_NAMES)`` classes.
        """
        super().__init__()
        # The model is a module, not a hyperparameter; keep it out of hparams.
        self.save_hyperparameters(ignore=["model"])
        self.config = config

        # Model
        self.model = model
        if model is None:
            self.model = self.MODEL_CLASS.from_pretrained(
                TSPConfig.from_config(config), num_classes=len(self.LABEL_NAMES),
            )

    @classmethod
    def from_pretrained(cls, config, except_classifier=False):
        """Build the task module around a model from ``load_pretrained_model``.

        ``except_classifier`` is forwarded unchanged to ``load_pretrained_model``;
        presumably it skips loading the classification head (TODO confirm).
        """
        return cls(
            config=config,
            model=cls.load_pretrained_model(
                config,
                model_cls=cls.MODEL_CLASS,
                task=cls.TASK,
                num_classes=len(cls.LABEL_NAMES),
                except_classifier=except_classifier,
            ),
        )

    def forward(self, batch):
        """Return classification logits of shape (B, C) for a tokenized batch."""
        return self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["non_pad"].long(),  # non-pad mask cast to integers
            token_type_ids=batch["segment_ids"],
        )  # (B,C)

    def training_step(self, batch, batch_idx):
        """Cross-entropy between logits (B, C) and integer labels (B,)."""
        logits = self(batch)  # (B,C)
        loss = F.cross_entropy(logits, batch["label"])  # preds (N,C) target (N,)
        return loss

    def _get_prediction(self, logits):  # (B, NUM_CLASSES)
        """Pick the highest-scoring class per row."""
        return logits.argmax(dim=-1)  # <int>(B)

    def validation_step(self, batch, batch_idx, dataloader_idx=None):
        """Return ``(idx, prediction, label)`` so the epoch-end hook can aggregate."""
        logits = self(batch)  # (B,C)
        preds = self._get_prediction(logits)
        return batch["idx"], preds, batch["label"]

    def test_step(self, batch, batch_idx, dataloader_idx=None):
        """Same outputs as ``validation_step``; the test hook writes them to disk."""
        return self.validation_step(batch, batch_idx, dataloader_idx)

    def _calculate_metrics(
        self, idxs, preds, labels,  # <int>(N) or (N,...)  # <int>(N)  # <int>(N)
    ) -> dict[str, float]:
        """Apply every function in ``METRICS`` to ``(preds, labels)``.

        ``idxs`` is unused here but kept so subclasses can group by example.
        """
        return {name: metric_fn(preds, labels) for name, metric_fn in self.METRICS.items()}

    # Concatenate step outputs to get epoch-wise result.
    # Tasks that split one GLUE example into several data instances override this
    # method to recover the original examples.
    def _step_outputs_to_epoch_result(
        self, step_outputs: list[tuple[Tensor, Tensor, Optional[Tensor]]]
    ) -> tuple[Tensor, Tensor, Optional[Tensor]]:
        """Concatenate per-step ``(idxs, preds, labels)``; labels may be None."""
        idxs, preds, labels = tuple(zip(*step_outputs))
        idxs = torch.cat(idxs)  # <int>(N)
        preds = torch.cat(preds)  # <int>(N)
        # Labels are None when validation_step returns None for them.
        labels = torch.cat(labels) if labels[0] is not None else None  # <int>(N)
        return idxs, preds, labels

    def validation_epoch_end(self, step_outputs: list[tuple]):
        """Compute epoch-level metrics and log each one to the progress bar."""
        idxs, preds, labels = self._step_outputs_to_epoch_result(
            step_outputs
        )  # <int>(N)
        metric_dict = self._calculate_metrics(idxs, preds, labels)
        for name, value in metric_dict.items():
            self.log(name=name, value=value, prog_bar=True)

    @property
    def test_output_dir(self):
        """``<checkpoint dir>/test_outputs``, created on first access."""
        saving_checkpoint_dir = Path(self.trainer.checkpoint_callback.dirpath)
        test_output_dir = saving_checkpoint_dir / "test_outputs"
        test_output_dir.mkdir(exist_ok=True, parents=True)
        return test_output_dir

    @staticmethod
    def _write_jsonl(path: Path, idxs: list, preds: list) -> None:
        """Write one ``{"idx": ..., "label": ...}`` JSON object per line to ``path``."""
        with path.open(mode="w") as f:
            for idx, pred in zip(idxs, preds):
                f.write(json.dumps({"idx": idx, "label": pred}) + "\n")

    def test_epoch_end(self, step_outputs: list[tuple]):
        """Write test predictions in the benchmark submission format.

        ``*.tsv`` files get GLUE's ``index``/``prediction`` columns, and ``*.jsonl``
        files get SuperGLUE's ``{"idx", "label"}`` lines. RTE is written in both
        formats because it is part of both benchmarks.

        Raises:
            ValueError: if ``TEST_OUTPUT_FILE_NAME`` has neither a ``tsv`` nor a
                ``jsonl`` extension.
        """
        idxs, preds, _ = self._step_outputs_to_epoch_result(step_outputs)
        if not isinstance(idxs, list):
            idxs = idxs.tolist()
        if not isinstance(preds, list):
            preds = preds.tolist()

        # Cast to string label if needed
        if self.TEST_MAP_LABEL_TO_STRING:
            preds = [self.LABEL_NAMES[p] for p in preds]

        # rpartition keeps working if the identifier itself contains a dot.
        identifier, _, extension = self.TEST_OUTPUT_FILE_NAME.rpartition(".")
        test_file_path = self.test_output_dir / self.TEST_OUTPUT_FILE_NAME
        if extension == "tsv":  # GLUE
            test_df = pd.DataFrame({"index": idxs, "prediction": preds})
            # GLUE expects exactly the columns "index" and "prediction"; do not
            # also emit the DataFrame's own row index as an unnamed first column.
            test_df.to_csv(test_file_path, sep="\t", index=False)
        elif extension == "jsonl":  # SuperGLUE
            self._write_jsonl(test_file_path, idxs, preds)
        else:
            raise ValueError(
                f"Unsupported test output extension in {self.TEST_OUTPUT_FILE_NAME!r}; "
                "expected .tsv or .jsonl"
            )
        if identifier == "RTE":  # RTE is also in SuperGLUE
            self._write_jsonl(test_file_path.with_suffix(".jsonl"), idxs, preds)


class RTEModel(GLUEModel):
    """Recognizing Textual Entailment (binary entailment, scored by accuracy).

    Test predictions are written as label strings to ``RTE.tsv``.
    """

    TASK = "rte"
    LABEL_NAMES = ["entailment", "not_entailment"]
    TEST_MAP_LABEL_TO_STRING = True
    TEST_OUTPUT_FILE_NAME = "RTE.tsv"
    METRICS = dict(accuracy=accuracy)


class CBModel(GLUEModel):
    """CommitmentBank: 3-way entailment, scored by accuracy and macro-averaged F1."""

    TASK = "cb"
    LABEL_NAMES = ["entailment", "contradiction", "neutral"]
    METRICS = dict(
        accuracy=accuracy,
        # One F1 per class, then an unweighted mean over the 3 classes.
        f1=partial(f1_score, average="macro", num_classes=len(LABEL_NAMES)),
    )
    TEST_MAP_LABEL_TO_STRING = True
    TEST_OUTPUT_FILE_NAME = "CB.jsonl"


class COPAModel(GLUEModel):
    """Choice of Plausible Alternatives.

    The model emits one logit per (premise, choice) row. ``forward`` regroups the
    logits so each example gets a 2-way score vector. The inherited cross-entropy
    then picks between choice1 (0) and choice2 (1).
    """

    TASK = "copa"
    LABEL_NAMES = ["this_choice_is_answer"]  # single output unit per row
    METRICS = dict(accuracy=accuracy)
    TEST_MAP_LABEL_TO_STRING = False  # predictions stay as 0/1 choice indices
    TEST_OUTPUT_FILE_NAME = "COPA.jsonl"

    def forward(self, batch):
        """Return (B, 2) logits: column 0 for choice1, column 1 for choice2."""
        # Shape (2B, 1): rows [0, B) score choice1 and rows [B, 2B) score choice2.
        row_scores = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["non_pad"].long(),
            token_type_ids=batch["segment_ids"],
        )
        per_choice = row_scores.view(2, -1)  # (2, B): row k holds choice k+1
        return per_choice.t()  # (B, 2)


class MultiRCModel(GLUEModel):  # Multi-Sentence Reading Comprehension
    """MultiRC: binary True/False classification of each answer option.

    Rows of ``batch["idx"]`` are ``(passage_idx, question_idx, answer_idx)``
    triples, as unpacked in ``test_epoch_end``.
    """

    TASK = "multirc"
    LABEL_NAMES = ["False", "True"]
    METRICS = {"em": accuracy, "f1a": partial(f1_score, multiclass=False)}
    TEST_OUTPUT_FILE_NAME = "MultiRC.jsonl"
    TEST_MAP_LABEL_TO_STRING = False

    def _calculate_metrics(
        self, idxs, preds, labels,  # <int>(N,3)  # <int>(N)  # <int>(N)
    ) -> dict[str, float]:
        """Compute ``f1a`` over all answer options and ``em`` over all questions.

        A question counts as an exact match only if every one of its answer
        options is predicted correctly.
        """
        # f1a: binary F1 over every answer option in the split. Use the function
        # declared in METRICS so the class config and the computation agree.
        f1 = self.METRICS["f1a"](preds, labels)

        # EM: group by (passage, question) so the result does not depend on
        # question indices being unique across passages.
        df = pd.DataFrame(
            {
                "passage_idx": idxs[:, 0].tolist(),
                "question_idx": idxs[:, 1].tolist(),
                "correct": (preds == labels).tolist(),
            }
        )
        em = float(
            df.groupby(["passage_idx", "question_idx"])["correct"].all().mean()
        )

        return {"em": em, "f1a": f1}

    def test_epoch_end(self, outputs):
        """Write predictions in the nested SuperGLUE MultiRC submission format.

        Each line is one passage:
        ``{"idx": p, "passage": {"questions": [{"idx": q, "answers": [{"idx": a, "label": 0/1}]}]}}``.
        """
        idxs, preds, _ = tuple(zip(*outputs))
        idxs = torch.cat(idxs).tolist()  # [[passage, question, answer], ...]
        preds = torch.cat(preds).tolist()  # <int>(N)

        # Nested grouping: passage_idx -> question_idx -> answer_idx -> pred.
        # dicts keep insertion order, so output follows prediction order.
        result_dict: dict[int, dict[int, dict[int, int]]] = {}
        for (passage_idx, question_idx, answer_idx), pred in zip(idxs, preds):
            questions = result_dict.setdefault(passage_idx, {})
            questions.setdefault(question_idx, {})[answer_idx] = pred

        # Create and Save testing results
        test_file_path = self.test_output_dir / self.TEST_OUTPUT_FILE_NAME
        with test_file_path.open(mode="w") as f:
            for passage_idx, questions in result_dict.items():
                passage_dict = {
                    "idx": passage_idx,
                    "passage": {
                        "questions": [
                            {
                                "idx": question_idx,
                                "answers": [
                                    {"idx": answer_idx, "label": pred}
                                    for answer_idx, pred in answers.items()
                                ],
                            }
                            for question_idx, answers in questions.items()
                        ]
                    },
                }
                f.write(json.dumps(passage_dict) + "\n")


class WiCModel(GLUEModel):
    """Words in Context: span-comparison classifier over ``batch["word_ranges"]``."""

    TASK = "wic"
    MODEL_CLASS = TSPModelForSpanComparison
    LABEL_NAMES = ["false", "true"]
    METRICS = dict(accuracy=accuracy)
    TEST_MAP_LABEL_TO_STRING = True
    TEST_OUTPUT_FILE_NAME = "WiC.jsonl"

    def forward(self, batch):
        """Return (B, C) logits, passing the word spans to the span-comparison model."""
        model_inputs = dict(
            input_ids=batch["input_ids"],
            attention_mask=batch["non_pad"].long(),
            token_type_ids=batch["segment_ids"],
            span_ranges=batch["word_ranges"],
        )
        return self.model(**model_inputs)  # (B, C)


class WSCModel(GLUEModel):
    """Winograd Schema Challenge: span-comparison classifier over ``batch["word_ranges"]``."""

    TASK = "wsc"
    MODEL_CLASS = TSPModelForSpanComparison
    LABEL_NAMES = ["False", "True"]
    METRICS = dict(accuracy=accuracy)
    TEST_MAP_LABEL_TO_STRING = True
    TEST_OUTPUT_FILE_NAME = "WSC.jsonl"

    def forward(self, batch):
        """Return (B, C) logits, passing the word spans to the span-comparison model."""
        input_ids = batch["input_ids"]
        attention_mask = batch["non_pad"].long()
        token_type_ids = batch["segment_ids"]
        span_ranges = batch["word_ranges"]
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            span_ranges=span_ranges,
        )  # (B, C)


class BoolQModel(GLUEModel):
    """BoolQ: yes/no question answering; predictions written as "false"/"true"."""

    TASK = "boolq"
    LABEL_NAMES = ["false", "true"]
    TEST_MAP_LABEL_TO_STRING = True
    TEST_OUTPUT_FILE_NAME = "BoolQ.jsonl"
    METRICS = dict(accuracy=accuracy)


class ReCoRDModel(GLUEModel):  # Reading Comprehension with Commonsense Reasoning
    """ReCoRD: score every candidate entity span and pick the best one.

    Padded entity slots are marked with ``entity_labels == -1``. Metrics are
    SQuAD-style exact match and token F1 against the gold answer strings.
    """

    TASK = "record"
    MODEL_CLASS = TSPModelForReCoRD
    LABEL_NAMES = ["this_entity_is_the_answer"]
    METRICS = {"accuracy": accuracy, "f1": f1_score}
    TEST_OUTPUT_FILE_NAME = "ReCoRD.jsonl"
    TEST_MAP_LABEL_TO_STRING = False

    def forward(self, batch):
        """Return (B, #entities) logits with padded entity slots set to -inf."""
        logits = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["non_pad"].long(),
            token_type_ids=batch["segment_ids"],
            masked_idx=batch["masked_idx"],  # <int>(B)
            entity_ranges=batch["entity_ranges"],  # <int>(B, #entities, 2)
        )
        logits = logits.squeeze(-1)  # (B, #entities)
        is_pad = batch["entity_labels"] == -1  # <bool>(B, #entities)
        # Out-of-place masking: avoids mutating the model output in place, which
        # can break autograd if an upstream op saved that tensor for backward.
        return logits.masked_fill(is_pad, -float("inf"))  # (B, #entities)

    def training_step(self, batch, batch_idx):
        """Binary cross-entropy over the non-padded candidate entities."""
        logits = self(batch)  # (B, #entities)
        # Derive the mask from the labels (the source of truth for padding)
        # rather than from isfinite(), which would also silently drop NaN/inf logits.
        non_pad = batch["entity_labels"] != -1
        loss = F.binary_cross_entropy_with_logits(
            logits[non_pad], batch["entity_labels"][non_pad].float()
        )
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=None):
        """Return ``(idx, predicted entity index, None)``; gold answers come from the raw dataset."""
        logits = self(batch)  # (B, #entities)
        preds = self._get_prediction(logits)
        return batch["idx"], preds, None

    def _step_outputs_to_epoch_result(
        self, step_outputs: list[tuple[Tensor, Tensor, Tensor]]
    ) -> tuple[list[int], list[str], list[list[str]]]:
        """Map predicted entity indices back to entity strings and collect gold answers.

        Assumes ``idx`` values index rows of ``datamodule.raw_datasets[split]``.
        """
        split = "test" if self.trainer.testing else "validation"
        dataset = self.trainer.datamodule.raw_datasets[split]
        idxs, preds, _ = super()._step_outputs_to_epoch_result(
            step_outputs
        )  # <int>(N), <int>(N)
        idxs = idxs.tolist()
        preds = preds.tolist()

        pred_answers: list[str] = []
        all_answers: list[list[str]] = []
        for idx, pred in zip(idxs, preds):
            example = dataset[idx]
            pred_answers.append(example["entity_spans"]["text"][pred])
            all_answers.append(example["answers"])

        return idxs, pred_answers, all_answers

    def _calculate_metrics(
        self, idxs: list[int], pred_answers: list[str], all_answers: list[list[str]]
    ) -> dict[str, float]:
        """Mean exact match ("accuracy") and mean best token-F1 ("f1") over examples."""
        ems, f1s = [], []
        for pred_answer, answers in zip(pred_answers, all_answers):
            pred_answer = normalize_text(pred_answer)
            answers = [normalize_text(a) for a in answers]
            ems.append(self._exact_match_over_answers(pred_answer, answers))
            f1s.append(self._f1_over_answers(pred_answer, answers))
        if not ems:  # avoid ZeroDivisionError on an empty evaluation set
            return {"accuracy": 0.0, "f1": 0.0}
        return {
            "accuracy": sum(ems) / len(ems),
            "f1": sum(f1s) / len(f1s),
        }

    def _exact_match_over_answers(self, pred_answer: str, answers: list[str]) -> float:
        """1.0 if the (normalized) prediction equals any gold answer, else 0.0."""
        return float(pred_answer in answers)

    def _f1_over_answers(self, pred_answer: str, answers: list[str]) -> float:
        """Token-level F1 of ``pred_answer`` against its best-matching gold answer.

        An answer with no token overlap scores 0 but does not stop the search.
        Returns 0.0 when ``answers`` is empty.
        """
        pred_counts = Counter(pred_answer.split())
        num_pred_tokens = sum(pred_counts.values())
        best_f1 = 0.0
        for answer in answers:
            gold_counts = Counter(answer.split())
            num_same = sum((pred_counts & gold_counts).values())
            if num_same == 0:
                continue  # F1 is 0 for this answer; other answers may still match
            precision = num_same / num_pred_tokens
            recall = num_same / sum(gold_counts.values())
            best_f1 = max(best_f1, (2 * precision * recall) / (precision + recall))
        return best_f1


class AXbModel(GLUEModel):
    """Broad-coverage diagnostics (AX-b), scored by Matthews correlation.

    Reuses the RTE task name so an RTE checkpoint can be evaluated on it.
    """

    TASK = "rte"
    LABEL_NAMES = ["entailment", "not_entailment"]
    METRICS = dict(matthewscorr=partial(matthews_corrcoef, num_classes=2))
    TEST_MAP_LABEL_TO_STRING = True
    TEST_OUTPUT_FILE_NAME = "AX-b.jsonl"


class AXgModel(GLUEModel):
    """Winogender schema diagnostics (AX-g), scored by accuracy.

    Reuses the RTE task name so an RTE checkpoint can be evaluated on it.
    """

    TASK = "rte"
    LABEL_NAMES = ["entailment", "not_entailment"]
    METRICS = dict(accuracy=accuracy)
    TEST_MAP_LABEL_TO_STRING = True
    TEST_OUTPUT_FILE_NAME = "AX-g.jsonl"


def normalize_text(text: str) -> str:
    """Normalize an answer string SQuAD-style for exact-match / F1 comparison.

    Steps, in order: lowercase, strip ASCII punctuation, replace the standalone
    articles "a"/"an"/"the" with spaces, then collapse runs of whitespace.
    """
    no_punct_table = str.maketrans("", "", string.punctuation)
    lowered = text.lower()
    without_punct = lowered.translate(no_punct_table)
    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punct)
    tokens = without_articles.split()
    return " ".join(tokens)


# Registry: task name -> training module class.
# NOTE(review): CoLAModel, SST2Model, MRPCModel, STSBModel, QQPModel, MNLIModel,
# QNLIModel, WNLIModel and AXModel are not defined in the visible source; they
# must be defined or imported before this statement runs, otherwise importing
# this module raises NameError.
MODEL_MODULES = dict(
    # GLUE
    cola=CoLAModel,
    sst2=SST2Model,
    mrpc=MRPCModel,
    stsb=STSBModel,
    qqp=QQPModel,
    mnli=MNLIModel,
    qnli=QNLIModel,
    rte=RTEModel,
    wnli=WNLIModel,
    ax=AXModel,
    # SuperGLUE
    cb=CBModel,
    copa=COPAModel,
    multirc=MultiRCModel,
    wic=WiCModel,
    wsc=WSCModel,
    boolq=BoolQModel,
    record=ReCoRDModel,
    axb=AXbModel,
    axg=AXgModel,
)
