from fairseq_manual.data_utils import compute_mask_indices
from hydra.utils import instantiate
import torch
import torch.distributed as dist
from torch.optim import AdamW
from pytorch_lightning import LightningModule
import random

from schedulers.warmup_cosine import WarmupCosineScheduler

from espnet.asr.asr_utils import add_results_to_json, torch_load
from espnet.nets.batch_beam_search import BatchBeamSearch
from espnet.nets.pytorch_backend.lm.transformer import TransformerLM
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
from espnet.nets.scorers.length_bonus import LengthBonus
from metrics import WER
from schedulers.warmup_cosine import WarmupCosineMomentumScheduler, ConstantThenLinearScheduler
from utils.utils import ids_to_str, set_requires_grad, UNIGRAM1000_LIST, get_param_groups


class SSLLearner(LightningModule):
    """Lightning module that validates and tests the video branch of a model.

    The wrapped model is instantiated from ``cfg.model.obj`` (optionally
    through ``torch.compile``) and may be initialised from a pretrained
    checkpoint. Validation logs the CTC loss and/or the value returned by the
    backbone's labelled forward pass; testing decodes every utterance with
    beam search and accumulates the result in a ``WER`` metric.
    """

    # Key prefix of the encoder weights inside a pretrained checkpoint. The
    # trailing dot is part of the prefix so that sibling modules whose names
    # merely start with "encoder" are not picked up and mis-sliced.
    _ENCODER_CKPT_PREFIX = "model._orig_mod.model.backbone.encoder."

    def __init__(self, cfg):
        """Build the model, optionally load pretrained weights, set up decoding.

        Args:
            cfg: Hydra/OmegaConf-style configuration. This constructor reads
                ``compile_model``, ``model.obj``, ``model.pretrained_model_path``,
                ``model.transfer_only_encoder`` and, through
                :meth:`get_beam_search`, the ``decode`` section.
        """
        super().__init__()
        self.save_hyperparameters(cfg)
        self.cfg = cfg

        model = instantiate(cfg.model.obj, cfg)
        self.model = torch.compile(model) if cfg.compile_model else model

        if cfg.model.pretrained_model_path:
            print("Load pretrained model weights")
            # The map_location callable returns the storage unchanged, i.e. the
            # checkpoint tensors are not moved to another device while loading.
            ckpt = torch.load(cfg.model.pretrained_model_path, map_location=lambda storage, loc: storage)

            if cfg.model.transfer_only_encoder:
                prefix = self._ENCODER_CKPT_PREFIX
                # Keep only encoder weights, strip the prefix so the keys match
                # the encoder's own state dict, and drop the "after_norm" entries.
                encoder_state = {
                    key[len(prefix):]: value
                    for key, value in ckpt.items()
                    if key.startswith(prefix) and not key[len(prefix):].startswith("after_norm")
                }
                # The same weights initialise both the online and target encoders.
                self.model.model.backbone.encoder.load_state_dict(encoder_state, strict=False)
                self.model.model.target_backbone.encoder.load_state_dict(encoder_state, strict=False)
            else:
                self.model.load_state_dict(ckpt, strict=False)

        # Gradient watching needs the trainer's logger, which does not exist
        # yet during construction; it is registered in on_fit_start instead.
        self._gradients_watched = False

        self.ignore_id = -1
        self.beam_search_video = self.get_beam_search(self.model.model.backbone)
        self.wer_video = WER()

    def on_fit_start(self):
        """Register gradient logging once a trainer (and its logger) is attached.

        Does nothing unless ``cfg.debug.log_gradients`` is truthy and a logger
        is configured. The call is made at most once per module instance.
        """
        if not self.cfg.debug.log_gradients or self._gradients_watched:
            return
        if self.logger is None:
            return
        # NOTE(review): ``experiment.watch`` looks like the Weights & Biases
        # API - confirm that the configured logger provides it.
        self.logger.experiment.watch(self.model, log="gradients")
        self._gradients_watched = True

    def get_beam_search(self, model):
        """Create the batch beam-search decoder for ``model``.

        Also stores the token list on ``self.token_list`` for later
        conversion of ids to text.

        Args:
            model: Object exposing ``scorers()``; here the model backbone.

        Returns:
            A ``BatchBeamSearch`` configured from ``cfg.decode``.
        """
        token_list = UNIGRAM1000_LIST

        odim = len(token_list)
        self.token_list = token_list

        scorers = model.scorers()

        if self.cfg.decode.lm_weight and self.cfg.model.pretrained_lm_path:
            lm = TransformerLM(odim, self.cfg.model.language_model)
            set_requires_grad(lm, False)
            print("Load pretrained language model weights")
            torch_load(self.cfg.model.pretrained_lm_path, lm)
        else:
            # No language model requested or no weights available.
            lm = None

        scorers["lm"] = lm
        scorers["length_bonus"] = LengthBonus(odim)

        ctc_weight = self.cfg.decode.ctc_weight
        weights = dict(
            decoder=1.0 - ctc_weight,
            ctc=ctc_weight,
            lm=self.cfg.decode.lm_weight,
            length_bonus=self.cfg.decode.penalty,
        )
        beam_search = BatchBeamSearch(
            beam_size=self.cfg.decode.beam_size,
            vocab_size=odim,
            weights=weights,
            scorers=scorers,
            # The last vocabulary entry is used as both <sos> and <eos>.
            sos=odim - 1,
            eos=odim - 1,
            token_list=token_list,
            # With a CTC-only setup there is no decoder score to pre-prune on.
            pre_beam_score_key=None if ctc_weight == 1.0 else "decoder",
        )

        return beam_search

    def shared_val_test_step(self, data):
        """Encode a batch of videos and log validation metrics.

        Args:
            data: Batch mapping with the keys ``"video"``, ``"label"`` and
                ``"video_lengths"``.
        """
        video, label = data["video"], data["label"]
        padding_mask_v = make_non_pad_mask(data["video_lengths"]).to(video.device).unsqueeze(-2)

        features_v = self.model.model.get_encoded_features_video(video.squeeze(1), padding_mask_v)

        if self.cfg.model.ctc_rel_weight > 0.0:
            loss_ctc_v = self.model.model.backbone.ctc_v(
                features_v, torch.tensor(data["video_lengths"], device=features_v.device), label.squeeze(1)
            )

            self.log("loss_ctc_v_val", loss_ctc_v, batch_size=len(label), sync_dist=True)

        if self.cfg.model.ctc_rel_weight < 1.0:
            # NOTE(review): ``[-3:]`` yields a sequence of up to three trailing
            # return values that is logged under a single metric name; confirm
            # against forward_labelled_video whether a single element (e.g.
            # ``[-1]``) was intended.
            acc_video = self.model.model.backbone.forward_labelled_video(features_v, padding_mask_v, label)[-3:]

            self.log("acc_video_val", acc_video, batch_size=len(label), sync_dist=True)

    def validation_step(self, data, batch_idx):
        """Run the shared validation logic on one batch."""
        self.shared_val_test_step(data)

    def calculate_wer(self, video, padding_mask, labels):
        """Decode each sample with beam search and update the WER metric.

        Args:
            video: Batch of videos, iterated sample by sample.
            padding_mask: Per-sample non-padding masks, aligned with ``video``.
            labels: Target token ids; dimension 1 is squeezed away and entries
                equal to ``self.ignore_id`` are dropped.
        """
        labels = labels.squeeze(1)
        for vid, label, mask in zip(video, labels, padding_mask):
            # Samples are encoded and decoded one at a time (batch size 1).
            feat_v = self.model.model.get_encoded_features_video(
                vid.unsqueeze(0), mask.unsqueeze(0).unsqueeze(-2)
            )

            nbest_hyps_v = self.beam_search_video(
                x=feat_v.squeeze(0),
                modality="v",
                maxlenratio=self.cfg.decode.maxlenratio,
                minlenratio=self.cfg.decode.minlenratio,
            )

            # Keep only the first hypothesis (or none if decoding returned nothing).
            nbest_hyps_v = [h.asdict() for h in nbest_hyps_v[:1]]

            transcription_v = add_results_to_json(nbest_hyps_v, self.token_list)
            transcription_v = transcription_v.replace("<eos>", "")

            label = label[label != self.ignore_id]
            groundtruth = ids_to_str(label, self.token_list)

            # Turn the subword boundary marker into spaces for both strings.
            groundtruth = groundtruth.replace("▁", " ").strip()
            transcription_v = transcription_v.replace("▁", " ").strip()

            self.wer_video.update(transcription_v, groundtruth)

    def test_step(self, data, batch_idx):
        """Decode one test batch and print the metric value after it.

        Args:
            data: Batch mapping with the keys ``"video"``, ``"label"`` and
                ``"video_lengths"``.
            batch_idx: Index of the batch (unused).
        """
        lengths = torch.tensor(data["video_lengths"], device=data["video"].device)
        padding_mask = make_non_pad_mask(lengths).to(lengths.device)
        self.calculate_wer(
            data["video"].squeeze(1),
            padding_mask,
            data["label"],
        )

        # Progress output: the metric value after this batch's updates.
        print(self.wer_video.compute())

    def on_test_epoch_end(self):
        """Print and log the final WER, then reset the metric."""
        wer_video = self.wer_video.compute()
        print(wer_video)
        self.log("wer_video", wer_video)
        self.wer_video.reset()

