from fairseq_manual.data_utils import compute_mask_indices
from hydra.utils import instantiate
import torch
import torch.distributed as dist
from torch.optim import AdamW
from pytorch_lightning import LightningModule
import random

from schedulers.warmup_cosine import WarmupCosineScheduler

from espnet.asr.asr_utils import add_results_to_json, torch_load
from espnet.nets.batch_beam_search import BatchBeamSearch
from espnet.nets.pytorch_backend.lm.transformer import TransformerLM
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
from espnet.nets.scorers.length_bonus import LengthBonus
from metrics import WER
from schedulers.warmup_cosine import WarmupCosineMomentumScheduler, ConstantThenLinearScheduler
from utils.utils import ids_to_str, set_requires_grad, UNIGRAM1000_LIST, get_param_groups


class SSLLearner(LightningModule):
    """Lightning module that evaluates an audio model's WER with beam search.

    The model described by ``cfg.model.obj`` is instantiated (optionally wrapped
    in ``torch.compile``), pretrained weights are optionally loaded, and an
    ESPnet ``BatchBeamSearch`` decoder is built over
    ``self.model.model.backbone``. ``test_step`` decodes every utterance and
    accumulates word error rate in ``self.wer_audio``. Automatic optimization is
    disabled (``automatic_optimization = False``).
    """

    # Raw audio samples per encoder frame; used to turn sample lengths into
    # frame lengths in ``test_step``. Presumably 16 kHz audio at 25 frames/s
    # (16000 / 25 = 640) -- TODO confirm against the model front-end.
    AUDIO_SAMPLES_PER_FRAME = 640

    # Key prefix of the online encoder inside the pretrained checkpoint. The
    # ``_orig_mod`` segment is what ``torch.compile`` inserts, so the checkpoint
    # was presumably saved from a compiled learner -- verify if loading fails.
    _ENCODER_CKPT_PREFIX = "model._orig_mod.model.backbone.encoder."

    def __init__(self, cfg):
        """Build the model, optionally load pretrained weights and set up decoding.

        Args:
            cfg: Config object (passed to Hydra ``instantiate``). Reads
                ``compile_model``, ``model.*``, ``debug.log_gradients`` and
                ``decode.*``.
        """
        super().__init__()
        self.save_hyperparameters(cfg)
        self.cfg = cfg

        model = instantiate(cfg.model.obj, cfg)
        self.model = torch.compile(model) if cfg.compile_model else model

        if cfg.model.pretrained_model_path:
            self._load_pretrained_weights(cfg.model.pretrained_model_path)

        # Gradient watching needs ``self.logger``, which only exists once a
        # Trainer is attached; it is therefore started in ``setup``, not here.
        self._gradients_watched = False

        self.ignore_id = -1
        self.beam_search_audio = self.get_beam_search(self.model.model.backbone)
        self.wer_audio = WER()

        self.automatic_optimization = False

    def _load_pretrained_weights(self, path):
        """Load checkpoint weights from ``path`` into the model (non-strict)."""
        print("Load pretrained model weights")
        ckpt = torch.load(path, map_location=lambda storage, loc: storage)

        if self.cfg.model.transfer_only_encoder:
            prefix = self._ENCODER_CKPT_PREFIX
            # Keep only encoder weights (prefix stripped), dropping the final
            # ``after_norm`` layer.
            encoder_state = {
                k[len(prefix):]: v
                for k, v in ckpt.items()
                if k.startswith(prefix) and not k[len(prefix):].startswith("after_norm")
            }
            # The same encoder weights initialise both the online and the
            # target backbone.
            self.model.model.backbone.encoder.load_state_dict(encoder_state, strict=False)
            self.model.model.target_backbone.encoder.load_state_dict(encoder_state, strict=False)
        else:
            # NOTE(review): strict=False silently ignores mismatched keys; if the
            # checkpoint key prefix differs from this model's, nothing is loaded.
            self.model.load_state_dict(ckpt, strict=False)

    def setup(self, stage):
        """Lightning hook: start gradient logging once a logger is attached.

        Args:
            stage: Lightning stage name (``"fit"``, ``"test"``, ...).
        """
        super().setup(stage)
        if self.cfg.debug.log_gradients and not self._gradients_watched:
            self.logger.experiment.watch(self.model, log="gradients")
            self._gradients_watched = True

    def get_beam_search(self, model):
        """Build a ``BatchBeamSearch`` decoder over ``model``'s scorers.

        Args:
            model: Backbone exposing ``scorers()`` (a dict of ESPnet scorers).

        Returns:
            A ``BatchBeamSearch`` combining the model's decoder/CTC scorers, an
            optional Transformer LM and a length bonus, weighted by
            ``cfg.decode``. Also stores the token list on ``self.token_list``.
        """
        token_list = UNIGRAM1000_LIST
        self.token_list = token_list
        odim = len(token_list)

        scorers = model.scorers()

        # The LM is only loaded when both a weight and a checkpoint path are
        # configured; otherwise its slot is None.
        if self.cfg.decode.lm_weight and self.cfg.model.pretrained_lm_path:
            lm = TransformerLM(odim, self.cfg.model.language_model)
            set_requires_grad(lm, False)
            print("Load pretrained language model weights")
            torch_load(self.cfg.model.pretrained_lm_path, lm)
        else:
            lm = None

        scorers["lm"] = lm
        scorers["length_bonus"] = LengthBonus(odim)

        ctc_weight = self.cfg.decode.ctc_weight
        weights = dict(
            decoder=1.0 - ctc_weight,
            ctc=ctc_weight,
            lm=self.cfg.decode.lm_weight,
            length_bonus=self.cfg.decode.penalty,
        )
        # The last vocabulary index serves as both <sos> and <eos>.
        return BatchBeamSearch(
            beam_size=self.cfg.decode.beam_size,
            vocab_size=odim,
            weights=weights,
            scorers=scorers,
            sos=odim - 1,
            eos=odim - 1,
            token_list=token_list,
            # Pure-CTC decoding has no decoder scores to pre-prune on.
            pre_beam_score_key=None if ctc_weight == 1.0 else "decoder",
        )

    def calculate_wer(self, audio, padding_mask, labels):
        """Decode each utterance with beam search and accumulate WER.

        Args:
            audio: Batched audio input, iterated over dim 0 (one utterance each).
            padding_mask: Per-utterance mask of valid frames, iterated over dim 0.
            labels: Per-utterance ground-truth transcriptions (presumably
                strings, since they are compared to the decoded text).
        """
        for aud, groundtruth, mask in zip(audio, labels, padding_mask):
            feat_a = self.model.model.get_encoded_features_audio(
                aud.unsqueeze(0), mask.unsqueeze(0).unsqueeze(-2)
            )

            nbest_hyps_a = self.beam_search_audio(
                x=feat_a.squeeze(0),
                modality="a",
                maxlenratio=self.cfg.decode.maxlenratio,
                minlenratio=self.cfg.decode.minlenratio,
            )

            # Only the single best hypothesis is scored.
            best_hyps = [h.asdict() for h in nbest_hyps_a[:1]]

            transcription_a = add_results_to_json(best_hyps, self.token_list)
            # Strip the end token and turn "▁" (presumably the SentencePiece
            # word-boundary marker) back into spaces.
            transcription_a = transcription_a.replace("<eos>", "")
            transcription_a = transcription_a.replace("▁", " ").strip()

            print(transcription_a)
            print(groundtruth)

            self.wer_audio.update(transcription_a, groundtruth)

    def test_step(self, data, batch_idx):
        """Decode one test batch and update the running WER.

        Args:
            data: Batch dict with ``"audio"``, ``"audio_lengths"`` (in samples)
                and ``"label"`` entries.
            batch_idx: Index of the batch (unused).
        """
        audio = data["audio"]
        # ``as_tensor`` accepts both lists and tensors without a copy warning.
        sample_lengths = torch.as_tensor(data["audio_lengths"], device=audio.device)
        lengths = sample_lengths // self.AUDIO_SAMPLES_PER_FRAME
        padding_mask = make_non_pad_mask(lengths).to(lengths.device)
        # NOTE(review): assumes audio is (batch, channels, time) and the encoder
        # expects (batch, time, channels) -- confirm against the data pipeline.
        self.calculate_wer(
            audio.transpose(1, 2),
            padding_mask,
            data["label"],
        )

        print(self.wer_audio.compute())

    def on_test_epoch_end(self):
        """Log the epoch's accumulated WER as ``wer_audio`` and reset the metric."""
        wer_audio = self.wer_audio.compute()
        print(wer_audio)
        self.log("wer_audio", wer_audio)
        self.wer_audio.reset()
