# -*- coding: utf-8 -*-

from .modules.seanet import SEANetEncoder, SEANetDecoder
from .quantization  import ResidualVectorQuantizer
from einops import rearrange
import torch
import numpy as np
import torch.nn as nn

class KeylessAttention(nn.Module):
    r"""
    Keyless (query-free) attention over a feature sequence.

    A 1x1 ``Conv1d`` produces one score per time step, a softmax over the
    time axis turns the scores into weights, and the input sequence is
    rescaled element-wise by those weights.
    """

    def __init__(self, feature_embed_size):
        r"""
            Args:
                feature_embed_size: number of feature channels scored by the
                    1x1 convolution (the last dimension of ``x`` in ``forward``).
        """
        super().__init__()

        self.feature_embed_size = feature_embed_size
        # Collapses feature_embed_size channels to a single score per step.
        self.attention_module = nn.Conv1d(feature_embed_size, 1, 1)
        # Normalizes the scores along the time axis (last dim of the scores).
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        r"""
            Args:
                x: feature sequence, expected shape
                    (batch, time, feature_embed_size).
            Returns:
                output: ``x`` multiplied element-wise by the attention weights.
                weights: attention weights broadcast to the shape of ``x``;
                    they sum to 1 along the time axis.
        """
        # (batch, time, feat) -> (batch, feat, time) -> conv -> (batch, 1, time) -> (batch, time)
        scores = self.attention_module(x.transpose(1, 2)).squeeze(1)
        attn = self.softmax(scores).unsqueeze(-1).expand_as(x)
        return x * attn, attn

class Model(nn.Module):
    '''
    SEANet encoder/decoder codec with a residual vector quantizer (RVQ).

    Before quantization the encoder output can be fused with external
    representations from a speech model (e.g. HuBERT) and/or a language
    model (e.g. BERT), optionally refined by attention first.
    '''

    def __init__(self, config):
        '''
        Build all sub-modules from a configuration mapping.

        Parameters
        ----------
        config : dict
            Model config (e.g. the parsed JSON used by ``load_from_checkpoint``).
            Values are read with ``config.get``: 'n_filters', 'dimension',
            'semantic_dimension', 'strides', 'lstm_layers', 'bidirectional',
            'dilation_base', 'residual_kernel_size', 'n_residual_layers',
            'activation', 'sample_rate', 'n_q', 'attention_heads' and
            'codebook_size'.

        '''
        super().__init__()
        dimension = config.get('dimension')
        semantic_dimension = config.get('semantic_dimension')
        heads = config.get('attention_heads')

        self.encoder = SEANetEncoder(n_filters=config.get('n_filters'),
                                     dimension=dimension,
                                     ratios=config.get('strides'),
                                     lstm=config.get('lstm_layers'),
                                     bidirectional=config.get('bidirectional'),
                                     dilation_base=config.get('dilation_base'),
                                     residual_kernel_size=config.get('residual_kernel_size'),
                                     n_residual_layers=config.get('n_residual_layers'),
                                     activation=config.get('activation'))
        self.sample_rate = config.get('sample_rate')
        self.n_q = config.get('n_q')
        # Product of the encoder strides.
        self.downsample_rate = np.prod(config.get('strides'))
        # Maps RVQ features from `dimension` to `semantic_dimension`; applied
        # to the `feature` returned by `forward`.
        if dimension != semantic_dimension:
            self.transform = nn.Linear(dimension, semantic_dimension)
        else:
            self.transform = nn.Identity()
        # Projects external representations (semantic_dimension) to the
        # encoder's channel size so they can be summed/concatenated with it.
        self.feature_transform = nn.Linear(semantic_dimension, dimension)
        # Reduce concatenated [e, rep] / [e, rep, rep] back to `dimension`.
        self.one_fusion_transform = nn.Linear(dimension * 2, dimension)
        self.two_fusion_transform = nn.Linear(dimension * 3, dimension)
        # Multi-head attention used for 'self' and 'cross' stages;
        # *_b_* run before the feature projection, *_a_* after it.
        self.llm_s_b_attention = nn.MultiheadAttention(semantic_dimension, heads, batch_first=True)
        self.hubert_s_b_attention = nn.MultiheadAttention(semantic_dimension, heads, batch_first=True)
        self.llm_s_a_attention = nn.MultiheadAttention(dimension, heads, batch_first=True)
        self.hubert_s_a_attention = nn.MultiheadAttention(dimension, heads, batch_first=True)
        # Keyless attention used for 'keyless' stages (same before/after split).
        self.llm_k_b_attention = KeylessAttention(semantic_dimension)
        self.hubert_k_b_attention = KeylessAttention(semantic_dimension)
        self.llm_k_a_attention = KeylessAttention(dimension)
        self.hubert_k_a_attention = KeylessAttention(dimension)

        self.quantizer = ResidualVectorQuantizer(dimension=dimension, n_q=config.get('n_q'), bins=config.get('codebook_size'))
        self.decoder = SEANetDecoder(n_filters=config.get('n_filters'),
                                     dimension=dimension,
                                     ratios=config.get('strides'),
                                     lstm=config.get('lstm_layers'),
                                     bidirectional=False,
                                     dilation_base=config.get('dilation_base'),
                                     residual_kernel_size=config.get('residual_kernel_size'),
                                     n_residual_layers=config.get('n_residual_layers'),
                                     activation=config.get('activation'))

    @classmethod
    def load_from_checkpoint(cls,
                             config_path: str,
                             ckpt_path: str):
        '''
        Build a model from a JSON config file and load a state dict into it.

        Parameters
        ----------
        config_path : str
            Path of the JSON model configuration file.
        ckpt_path : str
            Path of the model checkpoint (a state dict saved with torch.save).

        Returns
        -------
        model : Model
            Model with the checkpoint weights loaded. It is left in the
            default (training) mode; call ``.eval()`` for inference.

        Notes
        -----
        ``strict=False`` is used, so missing or unexpected keys in the
        checkpoint are silently ignored.

        '''
        import json
        with open(config_path, encoding='utf-8') as f:
            cfg = json.load(f)
        model = cls(cfg)
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints (consider weights_only=True on torch versions that support it).
        params = torch.load(ckpt_path, map_location='cpu')
        model.load_state_dict(params, strict=False)
        return model

    @staticmethod
    def _attend(rep, other, attention_stage, keyless_module, mha_module):
        '''
        Apply the attention variant named in `attention_stage` to `rep`.

        'keyless', 'self' and 'cross' are checked as substrings in that
        order. For 'cross', `rep` is the query and `other` (the opposite
        modality) is the key/value. Returns `rep` unchanged when no variant
        matches, and None when `rep` is None.
        '''
        if rep is None:
            return None
        if "keyless" in attention_stage:
            return keyless_module(rep)[0]
        if "self" in attention_stage:
            return mha_module(rep, rep, rep)[0]
        if "cross" in attention_stage:
            if other is None:
                raise ValueError("'cross' attention requires both hubert_rep and llm_rep.")
            return mha_module(rep, other, other)[0]
        return rep

    @staticmethod
    def _drop_modality(rep, rate):
        '''
        Zero `rep` for a random subset of batch items.

        Each batch item is dropped independently with probability `rate`.
        Returns None when `rep` is None.
        '''
        if rep is None:
            return None
        keep = torch.rand(rep.size(0), 1, 1, device=rep.device) > rate
        return rep * keep.to(rep.dtype)

    def _fuse(self, e, hubert_rep, llm_rep, fusion_strategy):
        '''
        Combine the encoder output with the available external representations.

        All inputs are channel-first, (batch, dimension, frames). A
        representation that is None or entirely zero (e.g. dropped for the
        whole batch by modality dropout) is left out of the fusion.
        '''
        extras = [rep for rep in (hubert_rep, llm_rep)
                  if rep is not None and torch.count_nonzero(rep) > 0]
        if fusion_strategy == "sum":
            e_fused = e
            for rep in extras:
                e_fused = e_fused + rep
            return e_fused
        if fusion_strategy == "concat":
            fusion_list = [e, *extras]
            # (batch, k * dimension, frames) -> (batch, frames, k * dimension)
            e_fused = torch.cat(fusion_list, dim=1).permute(0, 2, 1)
            # Project back to `dimension`; with no extras, e passes through.
            if len(fusion_list) == 2:
                e_fused = self.one_fusion_transform(e_fused)
            elif len(fusion_list) == 3:
                e_fused = self.two_fusion_transform(e_fused)
            return e_fused.permute(0, 2, 1)
        raise ValueError(
            f"Unknown fusion_strategy {fusion_strategy!r}; expected 'sum' or 'concat'.")

    def forward(self,
                x: torch.Tensor,
                hubert_rep: torch.Tensor = None,
                llm_rep: torch.Tensor = None,
                n_q: int = None,
                layers: list = None,
                fusion_strategy: str = "sum",
                attention_stage: str = "",
                modality_dropout_rate: float = 0.1):
        '''
        Encode, fuse external representations, quantize and decode.

        Parameters
        ----------
        x : torch.Tensor
            Input wavs. Shape: (batch, channels, timesteps).
        hubert_rep : torch.Tensor, optional
            Hidden representation from a speech model (e.g., HuBERT). Expected
            shape (batch, frames, semantic_dimension), with the same number of
            frames as the encoder output. Default is None.
        llm_rep : torch.Tensor, optional
            Hidden representation from an LLM (e.g., BERT). Same expected
            shape as `hubert_rep`. Default is None.
        n_q : int, optional
            Number of quantizers in RVQ used to encode. Default (None or 0)
            uses ``self.n_q`` from the config.
        layers : list[int], optional
            RVQ layers whose quantized result should be returned by the
            quantizer. Default (None) is ``[0]``.
        fusion_strategy : str, optional
            'sum' adds the projected representations to the encoder output;
            'concat' concatenates them along channels and projects back.
            Default is "sum".
        attention_stage : str, optional
            Substring-matched. Contains 'before' and/or 'after' (relative to
            the feature projection) plus one of 'keyless', 'self' or 'cross'.
            'cross' requires both representations. Default "" (no attention).
        modality_dropout_rate : float, optional
            Probability (0.0 to 1.0) of zeroing each batch item's
            representation, per modality. Applied only in training mode.
            Default is 0.1.

        Returns
        -------
        o : torch.Tensor
            Output wavs. Shape: (batch, channels, timesteps).
        commit_loss : torch.Tensor
            Commitment loss from residual vector quantizers.
        feature : torch.Tensor
            First entry of the quantizer's per-layer outputs (the layer
            ``layers[0]``), mapped by ``self.transform``.
            Shape: (batch, frames, semantic_dimension).

        Raises
        ------
        ValueError
            If `fusion_strategy` is unknown, or if 'cross' attention is
            requested while one representation is missing.
        '''
        n_q = n_q if n_q else self.n_q
        layers = [0] if layers is None else layers

        # Encoder representations: (batch, dimension, frames).
        e = self.encoder(x)

        # Both modalities attend to the *pre-attention* features of the
        # other, so cross-attention does not depend on branch order.
        if "before" in attention_stage:
            hubert_in, llm_in = hubert_rep, llm_rep
            hubert_rep = self._attend(hubert_in, llm_in, attention_stage,
                                      self.hubert_k_b_attention, self.hubert_s_b_attention)
            llm_rep = self._attend(llm_in, hubert_in, attention_stage,
                                   self.llm_k_b_attention, self.llm_s_b_attention)

        # Linear projection to match e's channel dimension.
        hubert_rep = self.feature_transform(hubert_rep) if hubert_rep is not None else None
        llm_rep = self.feature_transform(llm_rep) if llm_rep is not None else None

        if "after" in attention_stage:
            hubert_in, llm_in = hubert_rep, llm_rep
            hubert_rep = self._attend(hubert_in, llm_in, attention_stage,
                                      self.hubert_k_a_attention, self.hubert_s_a_attention)
            llm_rep = self._attend(llm_in, hubert_in, attention_stage,
                                   self.llm_k_a_attention, self.llm_s_a_attention)

        # (batch, frames, dimension) -> (batch, dimension, frames), like e.
        hubert_rep = hubert_rep.permute(0, 2, 1) if hubert_rep is not None else None
        llm_rep = llm_rep.permute(0, 2, 1) if llm_rep is not None else None

        # Modality dropout is a training-time regularizer only.
        if self.training and modality_dropout_rate > 0.0:
            hubert_rep = self._drop_modality(hubert_rep, modality_dropout_rate)
            llm_rep = self._drop_modality(llm_rep, modality_dropout_rate)

        e_fused = self._fuse(e, hubert_rep, llm_rep, fusion_strategy)

        # Pass fused representation to quantizer
        quantized, codes, commit_loss, quantized_list = self.quantizer(e_fused, n_q=n_q, layers=layers)
        feature = rearrange(quantized_list[0], 'b d t -> b t d')
        feature = self.transform(feature)

        # Decode
        o = self.decoder(quantized)

        return o, commit_loss, feature

    def forward_feature(self,
                        x: torch.Tensor,
                        layers: list = None):
        '''
        Quantize the raw encoder output and return per-layer results.

        External representations are not used here: the encoder output
        goes straight to the quantizer without fusion.

        Parameters
        ----------
        x : torch.Tensor
            Input wavs. Shape should be (batch, channels, timesteps).
        layers : list[int], optional
            Layers of RVQ should return quantized result. The default
            (None or empty) is all ``self.n_q`` layers.

        Returns
        -------
        quantized_list : list[torch.Tensor]
            Quantized of required layers.

        '''
        e = self.encoder(x)
        layers = layers if layers else list(range(self.n_q))
        quantized, codes, commit_loss, quantized_list = self.quantizer(e, layers=layers)
        return quantized_list

    def encode(self,
               x: torch.Tensor,
               n_q: int = None,
               st: int = None):
        '''
        Encode wavs into RVQ code indices (without representation fusion).

        Parameters
        ----------
        x : torch.Tensor
            Input wavs. Shape: (batch, channels, timesteps).
        n_q : int, optional
            Number of quantizers in RVQ used to encode. The default
            (None or 0) is ``self.n_q``.
        st : int, optional
            Start quantizer index in RVQ. The default is 0.

        Returns
        -------
        codes : torch.Tensor
            Output indices for each quantizer. Shape: (n_q, batch, timesteps)

        '''
        e = self.encoder(x)
        st = 0 if st is None else st
        n_q = n_q if n_q else self.n_q
        codes = self.quantizer.encode(e, n_q=n_q, st=st)
        return codes

    def decode(self,
               codes: torch.Tensor,
               st: int = 0):
        '''
        Reconstruct wavs from RVQ code indices.

        Parameters
        ----------
        codes : torch.Tensor
            Indices for each quantizer. Shape: (n_q, batch, timesteps).
        st : int, optional
            Start quantizer index in RVQ. The default is 0.

        Returns
        -------
        o : torch.Tensor
            Reconstruct wavs from codes. Shape: (batch, channels, timesteps)

        '''
        quantized = self.quantizer.decode(codes, st=st)
        o = self.decoder(quantized)
        return o
