# -*- coding: utf-8 -*-
from .modules.seanet import SEANetEncoder, SEANetDecoder
from .quantization  import ResidualVectorQuantizer
import torch.nn as nn
from einops import rearrange
import torch
import numpy as np
import torch.nn.functional as F

class Model(nn.Module):
    '''
    Encoder / residual-vector-quantizer / decoder model with optional
    alignment or cross-attention against an external LLM representation.

    The model wires together a ``SEANetEncoder``, a ``ResidualVectorQuantizer``
    and a ``SEANetDecoder``, plus a projection (``self.transform``) applied to
    the quantized feature and a multi-head attention layer used by the
    ``'cross'`` mode of :meth:`forward`.
    '''

    def __init__(self, config):
        '''

        Parameters
        ----------
        config : dict
            Model config (e.g. the result of ``json.load``). Values are read
            with ``config.get``; the keys used are ``n_filters``,
            ``dimension``, ``strides``, ``lstm_layers``, ``bidirectional``,
            ``dilation_base``, ``residual_kernel_size``, ``n_residual_layers``,
            ``activation``, ``sample_rate``, ``n_q``, ``semantic_dimension``
            and ``codebook_size``.

        '''
        super().__init__()
        self.encoder = SEANetEncoder(n_filters=config.get('n_filters'),
                                     dimension=config.get('dimension'),
                                     ratios=config.get('strides'),
                                     lstm=config.get('lstm_layers'),
                                     bidirectional=config.get('bidirectional'),
                                     dilation_base=config.get('dilation_base'),
                                     residual_kernel_size=config.get('residual_kernel_size'),
                                     n_residual_layers=config.get('n_residual_layers'),
                                     activation=config.get('activation'))
        self.sample_rate = config.get('sample_rate')
        self.n_q = config.get('n_q')
        # Product of all strides, stored as the overall downsampling factor.
        self.downsample_rate = np.prod(config.get('strides'))
        # Project quantized features to the semantic dimension only when the
        # two dimensions differ; otherwise pass them through unchanged.
        if config.get('dimension') != config.get('semantic_dimension'):
            self.transform = nn.Linear(config.get('dimension'), config.get('semantic_dimension'))
        else:
            self.transform = nn.Identity()
        self.quantizer = ResidualVectorQuantizer(dimension=config.get('dimension'),
                                                 n_q=config.get('n_q'),
                                                 bins=config.get('codebook_size'))
        # The decoder mirrors the encoder settings, except that it is always
        # built with bidirectional=False regardless of config['bidirectional'].
        self.decoder = SEANetDecoder(n_filters=config.get('n_filters'),
                                     dimension=config.get('dimension'),
                                     ratios=config.get('strides'),
                                     lstm=config.get('lstm_layers'),
                                     bidirectional=False,
                                     dilation_base=config.get('dilation_base'),
                                     residual_kernel_size=config.get('residual_kernel_size'),
                                     n_residual_layers=config.get('n_residual_layers'),
                                     activation=config.get('activation'))
        # NOTE(review): embed_dim is hard-coded to 768 and kdim/vdim are left
        # at their defaults, so the 'cross' path in forward() needs both
        # llm_rep and the transformed feature to have last dimension 768.
        # Left unchanged because altering it would change checkpoint shapes.
        self.attention = nn.MultiheadAttention(embed_dim=768, num_heads=8, batch_first=True)

    @classmethod
    def load_from_checkpoint(cls,
                             config_path: str,
                             ckpt_path: str):
        '''
        Build a model from a JSON config file and load weights into it.

        Parameters
        ----------
        config_path : str
            Path of model configuration file (JSON).
        ckpt_path : str
            Path of model checkpoint (a state dict readable by ``torch.load``).

        Returns
        -------
        model : Model
            Model with the checkpoint parameters loaded (non-strictly).

        '''
        import json
        with open(config_path, encoding='utf-8') as f:
            cfg = json.load(f)
        model = cls(cfg)
        # SECURITY NOTE: torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        params = torch.load(ckpt_path, map_location='cpu')
        # strict=False: keys missing from / unexpected in the checkpoint are
        # tolerated instead of raising.
        model.load_state_dict(params, strict=False)
        return model

    def forward(self,
                x: torch.Tensor = None,
                n_q: int = None,
                layers: list = None,
                manipulate: str = "no",
                llm_rep: torch.Tensor = None):
        '''
        Perform the forward pass of the model.

        Parameters
        ----------
        x : torch.Tensor, optional
            Input waveform tensor. Shape: (batch_size, channels, timesteps).
        n_q : int, optional
            Number of quantizers in RVQ (Residual Vector Quantization) to encode the input.
            Default is `None`, which uses all quantization layers (``self.n_q``).
        layers : list[int], optional
            Specifies which RVQ layers should return quantized results.
            Default is `None`, which is treated as `[0]` (the first layer).
        manipulate : str, optional
            Mode selector, matched by substring: a value containing ``'align'``
            runs token alignment, a value containing ``'cross'`` runs
            cross-attention, anything else runs the standard path. The string
            is also forwarded to :meth:`token_alignment`.
        llm_rep : torch.Tensor, optional
            Precomputed LLM representation. Shape: (batch_size, timesteps, feature_dim).
            The 'align' and 'cross' modes are only taken when this is not None.

        Returns
        -------
        Depending on `manipulate`:
        - Standard encoder-decoder output:
            - o : torch.Tensor
                Decoder output. Shape: (batch_size, channels, timesteps).
            - commit_loss : torch.Tensor
                Commitment loss from residual vector quantizers.
            - feature : torch.Tensor
                Sum of the quantized outputs returned for `layers`, rearranged
                to (batch_size, timesteps, feature_dim) and passed through
                ``self.transform``.
        - Standard encoder-decoder output with aligned_llm_rep: aligned LLM representation based on token similarity.
        - Standard encoder-decoder output with conditioned_llm_rep: cross-attention LLM representation
          (llm_rep as query, feature as key and value).
        '''
        n_q = n_q if n_q else self.n_q
        # Build the default per call so no list object is shared across calls.
        layers = [0] if layers is None else layers

        e = self.encoder(x)
        quantized, codes, commit_loss, quantized_list = self.quantizer(e, n_q=n_q, layers=layers)
        # Sum the quantized outputs of every requested layer, then move the
        # channel axis last: (b, d, t) -> (b, t, d).
        feature = rearrange(torch.stack(quantized_list).sum(0), 'b d t -> b t d')
        feature = self.transform(feature)
        o = self.decoder(quantized)

        if 'align' in manipulate and llm_rep is not None:
            aligned_llm_rep = self.token_alignment(llm_rep, feature, manipulate)
            return o, commit_loss, feature, aligned_llm_rep

        elif 'cross' in manipulate and llm_rep is not None:
            # Cross-attention: the LLM representation queries the feature.
            Q, K, V = llm_rep, feature, feature
            conditioned_llm_rep, _ = self.attention(Q, K, V)
            return o, commit_loss, feature, conditioned_llm_rep

        return o, commit_loss, feature

    def token_alignment(self, llm_rep, feature, manipulate):
        '''
        Align LLM representations with feature representations.

        For every non-padded LLM token a window of feature positions is
        chosen, the cosine similarity between the token and each feature in
        the window is computed, and the token embedding is written to the
        best-matching position(s) of the output.

        Parameters
        ----------
        llm_rep : torch.Tensor
            LLM representation tensor. Shape: (batch_size, bert_tokens, feature_dim).
        feature : torch.Tensor
            Feature representation tensor. Shape: (batch_size, feature_tokens, feature_dim).
        manipulate : str
            String specifying alignment details:
            - ends with ``window<digits>``: use that fixed window size;
              otherwise the window is feature_tokens / bert_tokens (at least 1).
            - contains ``dynamic``: each window starts right after the last
              matched feature index instead of at ``bert_idx * window``.

        Returns
        -------
        aligned_llm_rep : torch.Tensor
            Tensor shaped like `feature`, zero everywhere except at matched
            feature positions, which hold the corresponding LLM embedding.
        '''
        # Count non-padded tokens: rows whose absolute sum is non-zero.
        num_bert_tokens = (llm_rep.abs().sum(dim=-1) > 0).sum(dim=-1)
        num_feature_tokens = (feature.abs().sum(dim=-1) > 0).sum(dim=-1)

        # Window: how many feature tokens correspond to each BERT token.
        window_suffix = manipulate.split('window')[-1]
        if 'window' in manipulate and window_suffix.isdigit():
            window = torch.full_like(num_feature_tokens, int(window_suffix), dtype=torch.float)
        else:
            # Windows are at least one token wide. A sample with zero BERT
            # tokens yields inf/nan here, but that value is never used
            # because such samples are skipped below.
            window = (num_feature_tokens.float() / num_bert_tokens.float()).clamp(min=1)

        aligned_llm_rep = torch.zeros_like(feature)
        dynamic = "dynamic" in manipulate

        # Read the per-sample scalars once instead of calling .item() inside
        # the inner loop (each call forces a host sync on accelerator tensors).
        bert_counts = num_bert_tokens.tolist()
        feature_counts = num_feature_tokens.tolist()
        window_values = window.tolist()

        for batch, (n_bert, n_feat, w) in enumerate(zip(bert_counts, feature_counts, window_values)):
            if n_bert == 0:
                # Nothing to align; also avoids int() on an inf/nan window.
                continue
            step = int(w)
            last_matched_index = 0  # Highest feature index matched so far
            for bert_idx in range(n_bert):  # Only non-padded BERT tokens
                if dynamic:
                    feature_start = last_matched_index + 1 if bert_idx > 0 else 0
                    feature_end = min(int(bert_idx * w) + step, n_feat)
                else:
                    feature_start = int(bert_idx * w)
                    feature_end = min(feature_start + step, n_feat)
                if feature_start >= feature_end:
                    # Fall back to the last `step` positions before feature_end.
                    feature_start = max(0, feature_end - step)
                if feature_start >= feature_end:
                    # Still empty (no feature tokens, or a zero-width window):
                    # there is nothing to compare against, so skip this token
                    # rather than calling max() on an empty tensor.
                    continue

                # Cosine similarity restricted to the current window.
                cosine_sim_sub = F.cosine_similarity(llm_rep[batch, bert_idx],
                                                     feature[batch, feature_start:feature_end],
                                                     dim=-1)
                peak = cosine_sim_sub.max()
                # Threshold: 100% of the max similarity.
                threshold = 1 * peak.item()
                # NOTE: with a positive max only the arg-max position(s) pass
                # `>=`; with a non-positive max, `<=` against the max selects
                # every position in the window.
                compare_op = torch.ge if peak > 0 else torch.le
                similar_feature_indices = torch.where(compare_op(cosine_sim_sub, threshold))[0] + feature_start
                last_matched_index = similar_feature_indices[-1].item()
                # Write the LLM embedding only at the matched positions.
                aligned_llm_rep[batch, similar_feature_indices] = llm_rep[batch, bert_idx]
        return aligned_llm_rep

    def forward_feature(self,
                        x: torch.Tensor,
                        layers: list = None):
        '''
        Encode the input and return the quantized outputs of selected layers.

        Parameters
        ----------
        x : torch.Tensor
            Input wavs. Shape should be (batch, channels, timesteps).
        layers : list[int], optional
            Layers of RVQ should return quantized result. The default
            (`None` or an empty list) is all layers, ``range(self.n_q)``.

        Returns
        -------
        quantized_list : list[torch.Tensor]
            Quantized of required layers.

        '''
        e = self.encoder(x)
        layers = layers if layers else list(range(self.n_q))
        quantized, codes, commit_loss, quantized_list = self.quantizer(e, layers=layers)
        return quantized_list

    def encode(self,
               x: torch.Tensor,
               n_q: int = None,
               st: int = None):
        '''
        Encode the input into RVQ code indices.

        Parameters
        ----------
        x : torch.Tensor
            Input wavs. Shape: (batch, channels, timesteps).
        n_q : int, optional
            Number of quantizers in RVQ used to encode. The default is all layers.
        st : int, optional
            Start quantizer index in RVQ. The default is 0.

        Returns
        -------
        codes : torch.Tensor
            Output indices for each quantizer. Shape: (n_q, batch, timesteps)

        '''
        e = self.encoder(x)
        if st is None:
            st = 0
        n_q = n_q if n_q else self.n_q
        codes = self.quantizer.encode(e, n_q=n_q, st=st)
        return codes

    def decode(self,
               codes: torch.Tensor,
               st: int = 0):
        '''
        Reconstruct the decoder output from RVQ code indices.

        Parameters
        ----------
        codes : torch.Tensor
            Indices for each quantizer. Shape: (n_q, batch, timesteps).
        st : int, optional
            Start quantizer index in RVQ. The default is 0.

        Returns
        -------
        o : torch.Tensor
            Reconstruct wavs from codes. Shape: (batch, channels, timesteps)

        '''
        quantized = self.quantizer.decode(codes, st=st)
        o = self.decoder(quantized)
        return o
