import copy
import math

from espnet.nets.pytorch_backend.ctc import CTC
from espnet.nets.pytorch_backend.e2e_asr_transformer import E2E
from espnet.nets.pytorch_backend.nets_utils import lengths_list_to_bool, pad_list, th_accuracy
from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos
from espnet.nets.pytorch_backend.transformer.encoder_layer import DropPath
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
from espnet.nets.pytorch_backend.transformer.layer_norm import InstanceNorm, LayerNorm
from hydra.utils import instantiate
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from .utils import EMA, set_requires_grad


class USRSingle(nn.Module):
    def __init__(self, cfg, backbone_args=None, pred_other=None):
        """Build the online backbone, its frozen target copy and the unlabelled heads.

        Args:
            cfg: Experiment configuration; ``cfg.model.target_dropout_off`` is
                read here and the whole object is kept on ``self.cfg``.
            backbone_args: Backbone configuration. ``ddim``, ``adim``,
                ``dropout_rate``, ``ctc_type`` and ``d_idim`` are read from it
                and it is forwarded unchanged to ``E2E``.
            pred_other: Optional config handed to hydra's ``instantiate``;
                when falsy no extra predictor is created.
        """
        super().__init__()
        self.cfg = cfg

        # Output dimension of every head (hard-coded) and the padding label id.
        self.odim = 1049
        self.ignore_id = -1

        args = backbone_args

        def build_unlabelled_ctc():
            # All three unlabelled CTC heads share one configuration:
            # un-reduced losses plus the target lengths.
            return CTC(
                self.odim,
                args.adim,
                args.dropout_rate,
                ctc_type=args.ctc_type,
                reduce="none",
                return_lengths=True,
            )

        # NOTE: the construction order below is deliberate; it fixes both the
        # sub-module registration order and the order of random initialisation.
        self.backbone = E2E(self.odim, args)
        if pred_other:
            self.predictor_other = instantiate(pred_other)
        else:
            self.predictor_other = None
        self.target_backbone = self.get_target_model(self.backbone)

        # Decoder-side projection heads used for unlabelled data, one per modality.
        self.out_layer_unlabelled_v = nn.Linear(args.ddim, self.odim)
        self.out_layer_unlabelled_a = nn.Linear(args.ddim, self.odim)
        self.out_layer_unlabelled_av = nn.Linear(args.ddim, self.odim)

        # Encoder-side CTC heads used for unlabelled data, one per modality.
        self.unlabelled_ctc_v = build_unlabelled_ctc()
        self.unlabelled_ctc_a = build_unlabelled_ctc()
        self.unlabelled_ctc_av = build_unlabelled_ctc()

        self.d_idim = args.d_idim
        self.ema = EMA()
        self.target_dropout_off = cfg.model.target_dropout_off

        self.layer_norm = LayerNorm(args.adim)
        self.instance_norm = InstanceNorm(args.adim)

        # SOS and EOS share the last output index.
        self.sos = self.odim - 1
        self.eos = self.odim - 1

    def update_moving_average(self, momentum):
        self.ema.update_moving_average(self.target_backbone, self.backbone, momentum)

    def get_target_model(self, model):
        """Return an independent deep copy of ``model`` for use as the target.

        The copy is passed through ``set_requires_grad(copy, False)`` before
        being returned (presumably freezing its parameters -- see ``.utils``).
        """
        frozen_copy = copy.deepcopy(model)
        set_requires_grad(frozen_copy, False)
        return frozen_copy

    def set_dropout_mode(self, model, train_mode):
        """Switch only the dropout-style sub-modules of ``model`` between modes.

        Every ``DropPath`` / ``nn.Dropout`` instance found in
        ``model.modules()`` is put in train mode when ``train_mode`` is truthy
        and in eval mode otherwise; all other modules are left untouched.
        """
        for module in model.modules():
            if not isinstance(module, (DropPath, nn.Dropout)):
                continue
            if train_mode:
                module.train()
            else:
                module.eval()

    @torch.no_grad()
    def get_target_features(self, x_v, x_a, padding_mask):
        if self.target_dropout_off:
            self.set_dropout_mode(self.target_backbone.encoder, train_mode=False) 

        e = self.target_backbone.encoder.forward_single(x_v, x_a, padding_mask, return_feats=True)[0]

        if self.target_dropout_off:
            self.set_dropout_mode(self.target_backbone.encoder, train_mode=True)

        return e

    @torch.no_grad()
    def get_encoder_targets(self, e):
        if self.target_dropout_off:
            self.set_dropout_mode(self.target_backbone.ctc_av, train_mode=False) 

        ctc_out = self.target_backbone.ctc_av.ctc_lo(e)

        if self.target_dropout_off:
            self.set_dropout_mode(self.target_backbone.ctc_av, train_mode=True)
        
        return ctc_out

    @torch.no_grad()
    def get_ctc_targets_with_probs(
        self,
        probs: torch.Tensor,
        padding_mask: torch.Tensor = None,
        blank_index: int = 0
    ):
        """
        Greedy CTC decoding for a batch of frame-level probabilities.

        For every sample the per-frame argmax path is collapsed (consecutive
        repeats merged, blanks removed). Each emitted token gets a confidence
        equal to ``exp(mean(log p))`` over the frames of its run. Masked-out
        frames are skipped entirely, so they neither contribute to a run nor
        break one. No EOS token is appended.

        Args:
            probs (torch.Tensor):
                Probabilities of shape (batch_size, time_steps, vocab_size);
                batch-first is required.
            padding_mask (torch.Tensor, optional):
                Mask with batch_size * time_steps elements, e.g. shape
                (batch_size, time_steps). Truthy => valid frame,
                falsy => padded frame (skipped). ``None`` keeps every frame.
            blank_index (int):
                Index in the vocab used for the CTC blank token.

        Returns:
            padded_tokens: (batch_size, max_decoded_len) long tensor, padded
                with -1.
            tokens_mask: result of ``lengths_list_to_bool`` on the per-sample
                token tensors (presumably a (batch_size, max_decoded_len)
                validity mask -- confirm against nets_utils).
            padded_probs: (batch_size, max_decoded_len) float tensor with the
                per-token confidences, padded with 0.0.
        """
        # 1) Log-probs; the epsilon guards against log(0).
        log_probs = torch.log(probs + 1e-8)

        # 2) Greedy path and the log-prob of the chosen token at every frame.
        predicted_tokens = log_probs.argmax(dim=-1)  # (batch_size, time_steps)
        batch_size, time_steps = predicted_tokens.shape
        predicted_logprobs = log_probs.gather(-1, predicted_tokens.unsqueeze(-1)).squeeze(-1)

        # 3) Move everything to Python lists once instead of issuing two
        #    `.item()` device synchronisations per frame.
        token_rows = predicted_tokens.tolist()
        logprob_rows = predicted_logprobs.tolist()
        if padding_mask is None:
            valid_rows = [None] * batch_size
        else:
            valid_rows = padding_mask.reshape(batch_size, time_steps).bool().tolist()

        decoded_tokens_per_sample = []
        decoded_probs_per_sample = []
        for token_ids, token_logprobs, valid in zip(token_rows, logprob_rows, valid_rows):
            tokens_list, probs_list = self._collapse_greedy_path(
                token_ids, token_logprobs, valid, blank_index
            )
            decoded_tokens_per_sample.append(
                torch.tensor(tokens_list, device=probs.device, dtype=torch.long)
            )
            decoded_probs_per_sample.append(
                torch.tensor(probs_list, device=probs.device, dtype=torch.float)
            )

        # 4) Pad all sequences to the same length.
        padded_tokens = pad_list(decoded_tokens_per_sample, pad_value=-1)
        tokens_mask = lengths_list_to_bool(decoded_tokens_per_sample)
        padded_probs = pad_list(decoded_probs_per_sample, pad_value=0.0)

        return padded_tokens, tokens_mask, padded_probs

    @staticmethod
    def _collapse_greedy_path(token_ids, token_logprobs, valid, blank_index):
        """Collapse one greedy CTC path into tokens and their mean-log-prob confidences.

        Args:
            token_ids: Per-frame argmax token ids (list of int).
            token_logprobs: Per-frame log-probability of the chosen token.
            valid: Per-frame validity flags, or ``None`` to keep every frame.
            blank_index: Id of the CTC blank token.

        Returns:
            ``(tokens, probs)``: the collapsed token ids and, for each one,
            ``exp`` of the average log-probability over the frames of its run.
        """
        tokens = []
        probs = []
        run_token = None  # token of the run being accumulated; never the blank
        run_sum = 0.0
        run_len = 0

        for t, token_id in enumerate(token_ids):
            if valid is not None and not valid[t]:
                continue

            if token_id != blank_index and token_id == run_token:
                # Same non-blank token repeated: extend the current run.
                run_sum += token_logprobs[t]
                run_len += 1
                continue

            # A blank or a different token closes the current run, if any.
            if run_token is not None:
                tokens.append(run_token)
                probs.append(math.exp(run_sum / run_len))

            if token_id == blank_index:
                run_token, run_sum, run_len = None, 0.0, 0
            else:
                run_token, run_sum, run_len = token_id, token_logprobs[t], 1

        # Flush a run that reaches the end of the sequence.
        if run_token is not None:
            tokens.append(run_token)
            probs.append(math.exp(run_sum / run_len))

        return tokens, probs

    def ctc_greedy_merge_collapse_batch(
        self,
        logits: torch.Tensor,
        blank: int = 0,
        pad_value: int = -1,
    ) -> torch.Tensor:
        """
        Performs greedy CTC decoding (argmax → merge repeats → remove blanks)
        on a batch of logits.

        Args:
            logits:   Tensor of shape (B, T, C), raw logits or log‐probs.
            blank:    int, index of the blank symbol.
            pad_value:int, value to pad shorter sequences to full batch max length.

        Returns:
            Tensor of shape (B, U_max) where U_max is the max collapsed length
            in the batch, padded with `pad_value`.
        """
        # 1) Greedy argmax over the class dimension → (B, T)
        preds = logits.argmax(dim=-1)

        collapsed_seqs = []
        for seq in preds:  # loop over batch
            # 2) merge repeated labels
            #    torch.unique_consecutive collapses runs of the same label
            merged = torch.unique_consecutive(seq)

            # 3) remove blank symbols
            if blank is not None:
                merged = merged[merged != blank]

            collapsed_seqs.append(merged)

        # 4) pad to a rectangular tensor
        #    pad_sequence will right‐pad shorter seqs with pad_value
        padded = pad_sequence(collapsed_seqs,
                            batch_first=True,
                            padding_value=pad_value)
        return padded  # shape (B, U_max)

    @torch.no_grad()
    def get_decoder_targets(self, x, padding_mask, max_sequence_lengths=None, modality="av"):
        """Greedily roll out the target decoder to produce decoder targets.

        Decoding starts from a single step whose last channel is set to the
        dtype's max and all other channels to the dtype's min (presumably a
        saturated SOS "logit" vector, since SOS is the last index). Each step's
        projected output is appended to the decoder input. A sample stops when
        its argmax equals ``self.eos`` or its ``max_sequence_lengths`` entry is
        reached; at most ``x.size(1) // 3`` steps are run.

        When ``self.target_dropout_off`` is set, the decoder's dropout layers
        are in eval mode during the roll-out and afterwards restored to the
        decoder's own mode (``decoder.training``) rather than unconditionally
        to train mode. The restore also happens if decoding raises.

        Args:
            x: Memory passed to ``decoder.forward_one_step``; indexed along
                dim 0 per sample, and ``x.size(1)`` bounds the step count.
            padding_mask: Memory mask passed to ``decoder.forward_one_step``.
            max_sequence_lengths: Optional per-sample cap on generated steps;
                defaults to ``x.size(1)`` for every sample.
            modality: ``"v"``, ``"a"`` or anything else (treated as ``"av"``);
                selects the decoder output layer.

        Returns:
            ``(pad_list(out, min_val), lengths_list_to_bool(out))`` where
            ``out[i]`` holds the generated steps of sample ``i`` without the
            start step, and ``min_val`` is the dtype minimum used as padding.
        """
        if max_sequence_lengths is None:
            max_sequence_lengths = torch.full((len(x),), x.size(1), device=x.device)

        decoder = self.target_backbone.decoder
        # set_dropout_mode leaves the decoder's own flag alone, so it records
        # the mode the dropout layers must be returned to afterwards.
        restore_train_mode = decoder.training
        if self.target_dropout_off:
            self.set_dropout_mode(decoder, train_mode=False)

        ys_in = torch.zeros((len(x), 1, self.d_idim), dtype=x.dtype, device=x.device)
        min_val = float(torch.finfo(ys_in.dtype).min)
        max_val = float(torch.finfo(ys_in.dtype).max)

        try:
            # The output projection depends only on the modality.
            if modality == "v":
                out_layer = decoder.out_layer_v
            elif modality == "a":
                out_layer = decoder.out_layer_a
            else:  # modality == "av"
                out_layer = decoder.out_layer_av

            ys_in.fill_(min_val)
            ys_in[:, :, -1] = max_val

            cache = None

            out = [None] * len(x)
            # Original batch position of every sample that is still decoding.
            idcs = torch.arange(len(x), device=x.device)

            for _ in range(x.size(1) // 3):
                ys_mask = torch.stack([subsequent_mask(ys_in.size(1), device=x.device)] * len(x))

                ys_out, cache = decoder.forward_one_step(
                    ys_in if self.cfg.model.soft_inputs else ys_in.argmax(-1),
                    ys_mask,
                    x,
                    padding_mask,
                    cache=cache,
                )

                ys_out = out_layer(ys_out)
                ys_in = torch.cat([ys_in, ys_out.unsqueeze(1)], dim=1)

                # A sample ends on EOS or when it hits its length cap.
                is_eos = (ys_out.argmax(dim=-1) == self.eos) | (max_sequence_lengths == ys_in.size(1) - 1)
                ended_idcs = torch.nonzero(is_eos, as_tuple=False).view(-1).to(x.device)
                remain_idcs = torch.nonzero(is_eos == 0, as_tuple=False).view(-1).to(x.device)
                for i in ended_idcs.tolist():
                    # Drop the start step before storing the finished sequence.
                    out[idcs[i].item()] = ys_in[i][1:]

                # Shrink every per-sample tensor to the still-active samples.
                idcs = idcs[remain_idcs]
                ys_in = ys_in[remain_idcs]
                x = x[remain_idcs]
                padding_mask = padding_mask[remain_idcs]
                max_sequence_lengths = max_sequence_lengths[remain_idcs]

                cache = [c[remain_idcs] for c in cache]

                if not len(idcs):
                    break

            # Samples that never finished keep whatever was generated so far.
            for i, idx in enumerate(idcs):
                out[idx.item()] = ys_in[i][1:]
        finally:
            if self.target_dropout_off:
                self.set_dropout_mode(decoder, train_mode=restore_train_mode)

        return pad_list(out, min_val), lengths_list_to_bool(out)
    
    def get_encoded_features(self, video, audio, padding_mask, return_feats=False):
        x_v, x_a, x_av, _, _ = self.backbone.encoder(video, audio, padding_mask, return_feats=return_feats)
        return x_v, x_a, x_av

    def get_encoded_features_video(self, video, padding_mask, return_feats=False):
        return self.backbone.encoder.forward_single(xs_v=video, masks=padding_mask, return_feats=return_feats)[0]

    def get_encoded_features_audio(self, audio, padding_mask, return_feats=False):
        return self.backbone.encoder.forward_single(xs_a=audio, masks=padding_mask, return_feats=return_feats)[0]

    def get_encoded_features_libri(self, audio, padding_mask, return_feats=False):
        return self.target_backbone.encoder.forward_single(xs_a=audio, masks=padding_mask, return_feats=return_feats)[0]
    
    # def get_encoder_losses(
    #         self, x_v, x_a, x_av, padding_mask, ctc_targets, is_labelled, mask_conf=None,
    #     ):
    #     if is_labelled:
    #         loss_ctc_v = self.backbone.ctc_v(x_v, padding_mask.sum(-1).squeeze(-1), ctc_targets)
    #         loss_ctc_a = self.backbone.ctc_a(x_a, padding_mask.sum(-1).squeeze(-1), ctc_targets)
    #         loss_ctc_av = self.backbone.ctc_av(x_av, padding_mask.sum(-1).squeeze(-1), ctc_targets)
    #     else:
    #         if ctc_targets is not None:
    #             pred_ctc_v = self.out_layer_unlabelled_ctc_v(x_v)
    #             loss_ctc_v = self.backbone.criterion_ctc(
    #                 pred_ctc_v, ctc_targets.argmax(-1), torch.logical_or(~mask_conf, ~padding_mask.squeeze(-2))
    #             )
    #             pred_ctc_a = self.out_layer_unlabelled_ctc_a(x_a)
    #             loss_ctc_a = self.backbone.criterion_ctc(
    #                 pred_ctc_a, ctc_targets.argmax(-1), torch.logical_or(~mask_conf, ~padding_mask.squeeze(-2))
    #             )
    #             pred_ctc_av = self.out_layer_unlabelled_ctc_av(x_av)
    #             loss_ctc_av = self.backbone.criterion_ctc(
    #                 pred_ctc_av, ctc_targets.argmax(-1), torch.logical_or(~mask_conf, ~padding_mask.squeeze(-2))
    #             )
    #         else:
    #             loss_ctc_v = loss_ctc_a = loss_ctc_av = None

    #     return loss_ctc_v, loss_ctc_a, loss_ctc_av

    def get_encoder_losses(
            self, x_v, x_a, x_av, padding_mask, ctc_targets, is_labelled, mask_conf=None, labels_aux=None, mask_conf_aux=None
        ):
        if is_labelled:
            loss_ctc_v = self.backbone.ctc_v(x_v, padding_mask.sum(-1), ctc_targets)
            loss_ctc_a = self.backbone.ctc_a(x_a, padding_mask.sum(-1), ctc_targets)
            loss_ctc_av = self.backbone.ctc_av(x_av, padding_mask.sum(-1), ctc_targets)
        else:
            loss_ctc_v, target_lengths = self.unlabelled_ctc_v(x_v, padding_mask.sum(-1), ctc_targets)
            if self.cfg.model.normalize_length:
                loss_ctc_v = loss_ctc_v / torch.clamp(target_lengths, min=1)
            loss_ctc_v = loss_ctc_v.masked_fill(~mask_conf, 0.0)
            loss_ctc_v = loss_ctc_v.sum()  # already divided by batch size in the espnet CTC implementation

            loss_ctc_a, target_lengths = self.unlabelled_ctc_a(x_a, padding_mask.sum(-1), ctc_targets)
            if self.cfg.model.normalize_length:
                loss_ctc_a = loss_ctc_a / torch.clamp(target_lengths, min=1)
            loss_ctc_a = loss_ctc_a.masked_fill(~mask_conf, 0.0)
            loss_ctc_a = loss_ctc_a.sum()

            loss_ctc_av, target_lengths = self.unlabelled_ctc_av(x_av, padding_mask.sum(-1), ctc_targets)
            if self.cfg.model.normalize_length:
                loss_ctc_av = loss_ctc_av / torch.clamp(target_lengths, min=1)
            loss_ctc_av = loss_ctc_av.masked_fill(~mask_conf, 0.0)
            loss_ctc_av = loss_ctc_av.sum()

            if labels_aux is not None:
                loss_ctc_aux_v, target_lengths = self.unlabelled_ctc_v(x_v, padding_mask.sum(-1), labels_aux)
                if self.cfg.model.normalize_length:
                    loss_ctc_aux_v = loss_ctc_aux_v / torch.clamp(target_lengths, min=1)
                loss_ctc_aux_v = loss_ctc_aux_v.masked_fill(~mask_conf_aux, 0.0)
                loss_ctc_aux_v = loss_ctc_aux_v.sum()

                loss_ctc_aux_a, target_lengths = self.unlabelled_ctc_a(x_a, padding_mask.sum(-1), labels_aux)
                if self.cfg.model.normalize_length:
                    loss_ctc_aux_a = loss_ctc_aux_a / torch.clamp(target_lengths, min=1)
                loss_ctc_aux_a = loss_ctc_aux_a.masked_fill(~mask_conf_aux, 0.0)
                loss_ctc_aux_a = loss_ctc_aux_a.sum()

                loss_ctc_aux_av, target_lengths = self.unlabelled_ctc_av(x_av, padding_mask.sum(-1), labels_aux)
                if self.cfg.model.normalize_length:
                    loss_ctc_aux_av = loss_ctc_aux_av / torch.clamp(target_lengths, min=1)
                loss_ctc_aux_av = loss_ctc_aux_av.masked_fill(~mask_conf_aux, 0.0)
                loss_ctc_aux_av = loss_ctc_aux_av.sum()

                loss_ctc_v = (1 - self.cfg.model.ctc_aux_weight)*loss_ctc_v + self.cfg.model.ctc_aux_weight*loss_ctc_aux_v
                loss_ctc_a = (1 - self.cfg.model.ctc_aux_weight)*loss_ctc_a + self.cfg.model.ctc_aux_weight*loss_ctc_aux_a
                loss_ctc_av = (1 - self.cfg.model.ctc_aux_weight)*loss_ctc_av + self.cfg.model.ctc_aux_weight*loss_ctc_aux_av

                # loss_ctc_v = loss_ctc_aux_v
                # loss_ctc_a = loss_ctc_aux_a
                # loss_ctc_av = loss_ctc_aux_av

        return loss_ctc_v, loss_ctc_a, loss_ctc_av
    
    # def get_decoder_losses(
    #         self, 
    #         x_v,
    #         x_a, 
    #         x_av,
    #         padding_mask, 
    #         labels, 
    #         is_labelled, 
    #         mask_conf=None,
    #         mask_targets=None,
    #     ):
    #     if is_labelled:
    #         loss_att_v, loss_att_a, loss_att_av, acc_v, acc_a, acc_av = self.backbone.forward_labelled(
    #             x_v, x_a, x_av, padding_mask, labels
    #         )
    #     else:
    #         e_v, e_a, e_av = self.backbone.forward_unlabelled(x_v, x_a, x_av, padding_mask, labels)
            
    #         pred_v = self.out_layer_unlabelled_v(e_v)
    #         loss_att_v = self.backbone.criterion_u(pred_v, labels.argmax(-1), torch.logical_or(~mask_conf, ~mask_targets))
    #         acc_v = th_accuracy(
    #                 pred_v.view(-1, self.odim), labels.argmax(-1), ignore_label=self.ignore_id
    #             )

    #         pred_a = self.out_layer_unlabelled_a(e_a)
    #         loss_att_a = self.backbone.criterion_u(pred_a, labels.argmax(-1), torch.logical_or(~mask_conf, ~mask_targets))
    #         acc_a = th_accuracy(
    #                 pred_a.view(-1, self.odim), labels.argmax(-1), ignore_label=self.ignore_id
    #             )

    #         pred_av = self.out_layer_unlabelled_av(e_av)
    #         loss_att_av = self.backbone.criterion_u(pred_av, labels.argmax(-1), torch.logical_or(~mask_conf, ~mask_targets))
    #         acc_av = th_accuracy(
    #                 pred_av.view(-1, self.odim), labels.argmax(-1), ignore_label=self.ignore_id
    #             )

    #     return loss_att_v, loss_att_a, loss_att_av, acc_v, acc_a, acc_av


    def get_decoder_losses(
            self, 
            x_v,
            x_a, 
            x_av,
            padding_mask, 
            labels, 
            is_labelled, 
            labels_in=None,
            mask_conf=None,
            labels_aux=None,
            mask_conf_aux=None,
        ):
        if is_labelled:
            loss_att_v, loss_att_a, loss_att_av, acc_v, acc_a, acc_av = self.backbone.forward_labelled(
                x_v, x_a, x_av, padding_mask, labels
            )
        else:
            e_v, e_a, e_av = self.backbone.forward_unlabelled(x_v, x_a, x_av, padding_mask, labels_in)
            
            pred_v = self.out_layer_unlabelled_v(e_v)
            loss_att_v = self.backbone.criterion_u(pred_v, labels, ~mask_conf)
            acc_v = th_accuracy(
                    pred_v.view(-1, self.odim), labels, ignore_label=self.ignore_id
                )

            pred_a = self.out_layer_unlabelled_a(e_a)
            loss_att_a = self.backbone.criterion_u(pred_a, labels, ~mask_conf)
            acc_a = th_accuracy(
                    pred_a.view(-1, self.odim), labels, ignore_label=self.ignore_id
                )

            pred_av = self.out_layer_unlabelled_av(e_av)
            loss_att_av = self.backbone.criterion_u(pred_av, labels, ~mask_conf)
            acc_av = th_accuracy(
                    pred_av.view(-1, self.odim), labels, ignore_label=self.ignore_id
                )

            if labels_aux is not None:
                labels_aux = add_sos_eos(labels_aux, self.sos, self.eos, self.ignore_id)[1]

                loss_att_aux_v = self.backbone.criterion_u(pred_v, labels_aux, ~mask_conf_aux)
                loss_att_aux_a = self.backbone.criterion_u(pred_a, labels_aux, ~mask_conf_aux)
                loss_att_aux_av = self.backbone.criterion_u(pred_av, labels_aux, ~mask_conf_aux)

                loss_att_v = (1 - self.cfg.model.dec_aux_weight)*loss_att_v + self.cfg.model.dec_aux_weight*loss_att_aux_v
                loss_att_a = (1 - self.cfg.model.dec_aux_weight)*loss_att_a + self.cfg.model.dec_aux_weight*loss_att_aux_a
                loss_att_av = (1 - self.cfg.model.dec_aux_weight)*loss_att_av + self.cfg.model.dec_aux_weight*loss_att_aux_av

                # loss_att_v = loss_att_aux_v
                # loss_att_a = loss_att_aux_a
                # loss_att_av = loss_att_aux_av

        return loss_att_v, loss_att_a, loss_att_av, acc_v, acc_a, acc_av


class USR(nn.Module):
    def __init__(self, cfg=None):
        """Wrap a single ``USRSingle`` model built from ``cfg``.

        ``cfg.model.predictor_2a`` is passed on as the predictor config only
        when ``cfg.model.v2a_weight`` is truthy. Although ``cfg`` defaults to
        ``None``, ``cfg.model`` is read unconditionally, so a real config is
        required.
        """
        super().__init__()
        if cfg.model.v2a_weight:
            predictor_cfg = cfg.model.predictor_2a
        else:
            predictor_cfg = None
        self.model = USRSingle(cfg, cfg.model.backbone, predictor_cfg)
        self.cfg = cfg

    def update_moving_average(self, momentum):
        self.model.update_moving_average(momentum)