from typing import Union
from pathlib import Path
import warnings
import torch
import torch.nn.functional as F
from torchmetrics import Accuracy

from _abstract_task.training import TrainingModule
from _modeling.modeling_tsp import TSPConfig, TSPModelForPreTraining
from Pretraining.convert_hf_checkpoint import load_converted_hf_state_dict


class Pretraining(TrainingModule):
    """Training module that sums the enabled pre-training objectives.

    Depending on ``config`` the total loss is made of:

    * the masked language modeling (MLM) loss (always),
    * the replaced token detection (RTD) loss when ``config.use_electra``,
    * the text structure prediction (TSP) loss when ``config.tsp_loss_weight``
      is truthy,
    * an inter-segment classification loss when ``config.inter_segment_task``
      is truthy.

    Args:
        config: Experiment configuration object; only the attributes read in
            this class are required (see the individual methods).
        model: Optional already-built model. When ``None`` a fresh
            ``TSPModelForPreTraining`` is created from ``config``.
    """

    # Weight applied to the RTD loss before it is added to the total loss.
    RTD_LOSS_WEIGHT = 50

    def __init__(self, config, model=None):
        super().__init__()
        # Must be called from ``__init__`` so the constructor arguments can be
        # captured; the model itself is excluded from the hyperparameters.
        self.save_hyperparameters(ignore=["model"])
        self.config = config
        self.tsp_classes, num_tsp_classes = self._tsp_label_scheme(config)

        if model is None:
            self.model = TSPModelForPreTraining(
                TSPConfig.from_config(config), num_classes=num_tsp_classes,
            )
        else:
            self.model = model

        if self.config.tsp_loss_weight:
            self.tsp_accuracy = Accuracy(num_classes=num_tsp_classes)

        if self.config.inter_segment_task:
            # "sso" is a 3-way classification, every other task is binary.
            num_classes = 3 if self.config.inter_segment_task == "sso" else 2
            self.inter_segment_head = torch.nn.Linear(config.hidden_size, num_classes)

    @staticmethod
    def _tsp_label_scheme(config):
        """Return ``(relation name -> class id, number of classes)`` for TSP.

        The full scheme distinguishes six relations (three hierarchy levels
        times two directions). Each ablation flag merges some of them:

        * ``tsp_without_hierarchy``: only the direction is kept (2 classes).
        * ``tsp_without_order``: only the hierarchy level is kept (3 classes).
        * ``tsp_without_paragraph``: "same paragraph" is merged into
          "same document" (4 classes).

        The flags are checked in this order and the first truthy one wins.
        """
        if config.tsp_without_hierarchy:
            return {
                "same_doc_and_reverse": 0,
                "same_para_and_reverse": 0,
                "neighbor_and_reverse": 0,
                "neighbor_and_forward": 1,
                "same_para_and_forward": 1,
                "same_doc_and_forward": 1,
            }, 2
        if config.tsp_without_order:
            # Forward and reverse variants of a relation must share one class.
            # (Previously the forward ids were 0, 1, 2, which paired
            # "neighbor forward" with "same doc reverse" and vice versa.)
            return {
                "same_doc_and_reverse": 0,
                "same_para_and_reverse": 1,
                "neighbor_and_reverse": 2,
                "neighbor_and_forward": 2,
                "same_para_and_forward": 1,
                "same_doc_and_forward": 0,
            }, 3
        if config.tsp_without_paragraph:
            return {
                "same_doc_and_reverse": 0,
                "same_para_and_reverse": 0,
                "neighbor_and_reverse": 1,
                "neighbor_and_forward": 2,
                "same_para_and_forward": 3,
                "same_doc_and_forward": 3,
            }, 4
        return {
            "same_doc_and_reverse": 0,
            "same_para_and_reverse": 1,
            "neighbor_and_reverse": 2,
            "neighbor_and_forward": 3,
            "same_para_and_forward": 4,
            "same_doc_and_forward": 5,
        }, 6

    @classmethod
    def from_pretrained(cls, config):
        """Build the module around a model loaded from ``config.load_ckpt_path``.

        A path that does not end with ``.ckpt`` is passed through
        ``load_converted_hf_state_dict`` first; otherwise no explicit state
        dict is handed to ``load_pretrained_model``.
        """
        # ``str()`` lets ``load_ckpt_path`` be a ``pathlib.Path`` as well.
        need_converting = not str(config.load_ckpt_path).endswith(".ckpt")
        state_dict = (
            load_converted_hf_state_dict(config.load_ckpt_path)
            if need_converting
            else None
        )
        # NOTE(review): the TSP head is always loaded with 6 classes here, even
        # when an ablation flag makes ``__init__`` use fewer -- confirm that
        # this is intended for checkpoints trained with an ablated scheme.
        model = cls.load_pretrained_model(
            config=config,
            model_cls=TSPModelForPreTraining,
            state_dict=state_dict,
            num_classes=6,
        )
        return cls(config, model=model)

    def training_step(self, batch, batch_idx):
        """Run every enabled objective on ``batch`` and return the summed loss.

        ``batch`` is updated in place by the individual phases (e.g. with
        ``hidden_states``), so the order of the phases below matters.
        """
        # Starting from the int 0 makes the first ``+=`` create a new tensor
        # instead of modifying ``mlm_loss`` in place.
        loss = 0

        batch, mlm_loss = self.mlm_phase(batch)
        loss += mlm_loss

        if self.config.use_electra:
            batch = self.sampling_phase(batch)
            if "permutation" in batch:
                batch = self.apply_sentence_shuffling(batch)
            batch, rtd_loss = self.rtd_phase(batch)
            loss += rtd_loss * self.RTD_LOSS_WEIGHT

        if self.config.tsp_loss_weight:
            tsp_loss = self.text_structure_prediction(batch)
            loss += tsp_loss * self.config.tsp_loss_weight

        if self.config.inter_segment_task:
            seg_loss = self.inter_segment_task(
                batch=batch,
                hidden_states=batch["hidden_states"],
                label=batch["inter_segment_labels"],
            )
            loss += seg_loss

        self.log("loss", loss, sync_dist=True)
        if self.global_step == 10:
            # One-off print of the loss value at step 10.
            print(loss.item())
        if not loss.isfinite().item():
            # raise KeyboardInterrupt()
            # ``isfinite`` is False for NaN as well as for +/-inf.
            warnings.warn("Infinite loss value is detected, it is set to zero.")
            loss.zero_()

        return loss

    # The final clean version of training is also written under the model
    # def training_step(self, batch, batch_idx):
    #     loss = self.model.forward(
    #         masked_ids=batch["corrupted_ids"],
    #         segment_ids=batch["segment_ids"],
    #         mlm_labels=batch["mlm_labels"],
    #         sentence_marks=batch["sentence_marks"],
    #         paragraph_ids=batch["paragraph_ids"],
    #         permutation=batch["permutation"],
    #     )
    #     if self.global_step == 10:
    #         print(loss.item())
    #     return loss

    def mlm_phase(self, batch):
        """Run the model's MLM phase, store its outputs in ``batch`` and log the loss.

        Reads ``corrupted_ids``, ``segment_ids`` and ``mlm_labels``; writes
        ``mlm_logits``, ``mlm_original_ids``, ``hidden_states`` and
        ``mlm_selected``.

        Returns:
            The updated ``batch`` and the MLM loss.
        """
        (
            mlm_loss,
            batch["mlm_logits"],
            batch["mlm_original_ids"],
            batch["hidden_states"],
            batch["mlm_selected"],
        ) = self.model.mlm_phase(
            masked_ids=batch["corrupted_ids"],
            segment_ids=batch["segment_ids"],
            mlm_labels=batch["mlm_labels"],
        )
        self.log("mlm_loss", mlm_loss, sync_dist=True)
        return batch, mlm_loss

    def sampling_phase(self, batch: dict):
        """Overwrite ``corrupted_ids`` and set ``rtd_labels`` from the model's sampling phase.

        Requires the outputs that ``mlm_phase`` stored in ``batch``.
        """
        batch["corrupted_ids"], batch["rtd_labels"] = self.model.sampling_phase(
            mlm_selected=batch["mlm_selected"],
            masked_ids=batch["corrupted_ids"],
            mlm_logits=batch["mlm_logits"],
            mlm_original_ids=batch["mlm_original_ids"],
        )
        return batch

    def apply_sentence_shuffling(self, batch):
        """Reorder tokens along dim 1 by ``batch["permutation"]`` and rebuild segment ids.

        ``corrupted_ids`` and ``rtd_labels`` are gathered with the same
        permutation so they stay aligned.
        """
        perm = batch["permutation"]
        # Keep every index inside the valid range of dim 1 before gathering.
        perm = perm.clip(min=0, max=batch["corrupted_ids"].shape[1] - 1)
        batch["corrupted_ids"] = batch["corrupted_ids"].gather(1, perm)
        batch["rtd_labels"] = batch["rtd_labels"].gather(1, perm)
        is_sep = batch["corrupted_ids"] == self.model.config.sep_token_id
        # Number of SEP tokens strictly before each position; positions up to
        # and including the first SEP get segment 0, all later ones get 1.
        batch["segment_ids"] = (is_sep.cumsum(dim=1) - is_sep.long()).bool().long()
        return batch

    def rtd_phase(self, batch):
        """Run the model's RTD phase, refresh ``hidden_states`` and log the loss.

        Returns:
            The updated ``batch`` and the RTD loss.
        """
        rtd_loss, rtd_logits, batch["hidden_states"] = self.model.rtd_phase(
            replaced_ids=batch["corrupted_ids"],
            segment_ids=batch["segment_ids"],
            rtd_labels=batch["rtd_labels"],
        )
        self.log("rtd_loss", rtd_loss, sync_dist=True)
        return batch, rtd_loss

    def text_structure_prediction(self, batch):
        """Build pairwise sentence-relation labels and return the TSP loss.

        For every pair of sentence slots ``(i, j)`` of a sample the label
        encodes the direction (``j > i`` is "forward", ``j < i`` is "reverse")
        and the closest relation level (adjacent slots, same paragraph id, or
        neither), mapped to class ids via ``self.tsp_classes``. Pairs with
        ``i == j`` and pairs involving a padded slot are labeled ``-1``.
        """
        B, S, device = *batch["paragraph_ids"].shape, batch["paragraph_ids"].device
        para_ids = batch["paragraph_ids"]  # <int>(B, S)
        padding = para_ids == -1  # <bool>(B, S)

        # Get relations of sentence pairs at different level of hierarchy
        positions = torch.arange(S, device=device).view(1, S)
        # diff[0, i, j] == j - i; broadcast over the batch where needed.
        diff = positions.view(-1, 1, S) - positions.view(-1, S, 1)  # <int>(1,S,S)
        dist = diff.abs()  # <int>(1,S,S)
        same_para = para_ids.view(B, 1, S) == para_ids.view(B, S, 1)  # <bool>(B,S,S)
        neighbor = dist == 1  # <bool>(1,S,S)
        reverse, forward = (diff < 0).expand(B, S, S), (diff > 0).expand(B, S, S)

        # Labeling: written from the broadest relation to the most specific
        # one, so a more specific relation overrides a broader one.
        labels = torch.full_like(same_para, -1, dtype=torch.long)
        labels[forward] = self.tsp_classes["same_doc_and_forward"]
        labels[reverse] = self.tsp_classes["same_doc_and_reverse"]
        labels[same_para & forward] = self.tsp_classes["same_para_and_forward"]
        labels[same_para & reverse] = self.tsp_classes["same_para_and_reverse"]
        labels[neighbor & forward] = self.tsp_classes["neighbor_and_forward"]
        labels[neighbor & reverse] = self.tsp_classes["neighbor_and_reverse"]

        # Ignoring: any pair that involves a padded slot.
        # NOTE(review): -1 is presumably treated as "ignore" by the model's
        # loss -- confirm in ``_text_structure_prediction``.
        ignoring = padding.view(B, 1, S) | padding.view(B, S, 1)
        labels.masked_fill_(ignoring, -1)

        tsp_loss, tsp_logits, tsp_labels = self.model._text_structure_prediction(
            sentence_embeddings=self.model.get_sentence_embeddings(
                hidden_states=batch["hidden_states"],
                sentence_marks=batch["sentence_marks"],
                max_num_sentences=batch["paragraph_ids"].shape[1],
            ),
            tsp_labels=labels,
        )

        self.log("tsp_loss", tsp_loss, sync_dist=True)
        # if self.config.logger:
        #     self.log(f"tsp_accuracy", self.tsp_accuracy(tsp_logits, tsp_labels))
        return tsp_loss

    def inter_segment_task(self, batch, hidden_states, label):
        """Classify each sequence from its first position and return the cross-entropy loss.

        Args:
            batch: Unused; kept for interface compatibility.
            hidden_states: Tensor indexed as ``[:, 0, :]`` to get one vector
                per sequence.
            label: Target class indices passed to ``F.cross_entropy``.
        """
        seq_encs = hidden_states[:, 0, :]  # (B,D)
        # (B,C) with C == 3 for the "sso" task and C == 2 otherwise.
        inter_seg_logits = self.inter_segment_head(seq_encs)
        inter_seg_loss = F.cross_entropy(inter_seg_logits, label)
        self.log("inter_segment_loss", inter_seg_loss, sync_dist=True)
        return inter_seg_loss


def convert_zero_checkpoint_to_lightning_checkpoint(
    checkpoint_dir: Union[Path, str], tag: str = None
):
    """Convert a DeepSpeed ZeRO checkpoint directory into a single ``.ckpt`` file.

    The output file is ``checkpoint_dir`` with its suffix replaced by
    ``.ckpt``. After the fp32 conversion, the file is reloaded on CPU, the
    entries for shared (tied) parameters are re-created in its
    ``state_dict`` and the file is saved again at the same path.

    Args:
        checkpoint_dir: Directory holding the ZeRO checkpoint.
        tag: Forwarded unchanged to the Lightning/DeepSpeed converter.
    """
    import torch
    from pytorch_lightning.utilities.deepspeed import (
        convert_zero_checkpoint_to_fp32_state_dict,
    )

    source_dir = Path(checkpoint_dir)
    target_file = source_dir.with_suffix(".ckpt")

    # Get lightning checkpoint
    convert_zero_checkpoint_to_fp32_state_dict(
        checkpoint_dir=source_dir, output_file=target_file, tag=tag
    )
    checkpoint = torch.load(target_file, map_location="cpu")
    weights = checkpoint["state_dict"]

    # Fix state_dict and resave
    ## The converter provided by deepspeed (and used by lightning) does not
    ## handle parameter sharing well: parameters that refer to other
    ## parameters are missing from its output. Re-establishing the references
    ## is enough to fix that.
    ## See also https://github.com/microsoft/DeepSpeed/issues/1896#issuecomment-1124463675
    # fmt: off
    generator_marker = "model.generator_backbone.layers.0.transition_block.norm.weight"
    # (missing key, key of the parameter it shares storage with)
    generator_ties = (
        ("model.generator_backbone.embeddings.word_embeddings.weight", "model.backbone.embeddings.word_embeddings.weight"),
        ("model.generator_backbone.embeddings.position_embeddings.weight", "model.backbone.embeddings.position_embeddings.weight"),
        ("model.generator_backbone.embeddings.token_type_embeddings.weight", "model.backbone.embeddings.token_type_embeddings.weight"),
        ("model.generator_backbone.embeddings.norm.weight", "model.backbone.embeddings.norm.weight"),
        ("model.generator_backbone.embeddings.norm.bias", "model.backbone.embeddings.norm.bias"),
    )
    if generator_marker in weights:
        # Using ELECTRA and thus has generator-discriminator embeddings sharing
        for tied_key, source_key in generator_ties:
            weights[tied_key] = weights[source_key]
    # The MLM head's predictor weight is always re-pointed at the word embeddings.
    weights["model.mlm_head.predictor.weight"] = weights["model.backbone.embeddings.word_embeddings.weight"]
    # fmt: on
    torch.save(checkpoint, target_file)
