from pathlib import Path
from transformers import Wav2Vec2PreTrainedModel,Wav2Vec2Config
from transformers.modeling_outputs import  ModelOutput
from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2BaseModelOutput,Wav2Vec2ForPreTrainingOutput, _compute_mask_indices, Wav2Vec2FeatureEncoder,Wav2Vec2FeatureProjection
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2EncoderStableLayerNorm,Wav2Vec2Encoder,Wav2Vec2Adapter,Wav2Vec2GumbelVectorQuantizer
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
import pandas as pd
from typing import Union, Optional, List, Dict

from dataclasses import dataclass
from typing import Optional, Tuple, Union

import soundfile as sf

from collections import defaultdict
#from ANN.models.wav2vec2.utils_data  import get_dataloader,_downsample_output
from datasets import IterableDataset
import random
from transformers import Wav2Vec2PreTrainedModel, Wav2Vec2Config, Wav2Vec2Model
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2BaseModelOutput
import torch
import torch.nn as nn

@dataclass
class customWav2Vec2_quantizeModelOutput(ModelOutput):
    """Outputs of ``customWav2Vec2ForQuantize.forward``.

    Every field defaults to ``None``; ``forward`` only fills ``activity`` when
    called with ``return_hidden_activity=True`` and never fills
    ``quantizedTransformerOutput``.
    """

    # Quantized codevectors after the ``project_q`` linear projection.
    quantizedactivity: Optional[torch.FloatTensor] = None
    # Quantized codevectors before ``project_q`` (returned as well because
    # ``project_q`` was observed not to be full rank).
    quantizedactivity_preproj : Optional[torch.FloatTensor] = None
    # Codevector selection from the quantizer: per-group integer indices in
    # eval mode, the hard Gumbel-softmax one-hot matrix in training mode.
    codevectorIdx: Optional[torch.Tensor] = None
    # Features that were fed to the quantizer (only set on request).
    activity : Optional[torch.FloatTensor] = None
    # Not populated anywhere in this file.
    quantizedTransformerOutput : Optional[torch.FloatTensor] = None


class customWav2Vec2ForQuantize(Wav2Vec2PreTrainedModel):
    """Wav2vec2 CNN front-end followed by the Gumbel vector quantizer.

    Only the convolutional feature encoder and the feature projection
    (``customWav2Vec2Model_quantize``) are run; the transformer encoder is not
    used by ``forward``. The extracted features go to the quantizer without the
    pre-training feature dropout.

    This class is not meant to be trained. Training mode is nonetheless allowed
    so that gradients can be propagated through the network; in that mode the
    quantizer selects codevectors with a hard Gumbel-softmax (through which
    gradient can flow) and returns the one-hot selection matrix instead of
    integer indices. In eval mode it takes a non-differentiable argmax.
    """

    def __init__(self, config: Wav2Vec2Config):
        super().__init__(config)
        self.wav2vec2 = customWav2Vec2Model_quantize(config)

        # No feature dropout before the quantizer (the pre-training model would
        # apply ``nn.Dropout(config.feat_quantizer_dropout)`` here).
        self.quantizer = customWav2Vec2GumbelVectorQuantizer(config)

        # Initialize weights and apply final processing
        self.post_init()

        # Created after post_init so that project_hid & project_q are
        # initialized like normal linear layers.
        # NOTE(review): project_hid is not used by forward; presumably kept so
        # checkpoint keys match — confirm.
        self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
        self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        return_dict: Optional[bool] = None,
        return_hidden_activity : Optional[bool] = False,
    ) -> Union[Tuple, customWav2Vec2_quantizeModelOutput]:
        r"""Quantize the CNN features of a waveform batch.

        Args:
            input_values: waveform batch; the feature encoder indexes it as
                ``input_values[:, None]``, so presumably ``(batch, samples)``.
            return_dict: if falsy, return the 1-tuple ``(quantizedactivity,)``;
                ``None`` falls back to ``self.config.use_return_dict``.
            return_hidden_activity: if True (and returning a dict), also return
                the quantizer input as ``activity``.

        Returns:
            ``customWav2Vec2_quantizeModelOutput`` with ``quantizedactivity``
            (after ``project_q``), ``quantizedactivity_preproj`` (before it) and
            ``codevectorIdx``: ``(batch, time, num_groups)`` integer indices in
            eval mode, or the ``(batch*time, num_groups*num_vars)`` one-hot
            Gumbel selection in training mode.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 1. CNN features; the transformer encoder is not run.
        outputs = self.wav2vec2(input_values, return_dict=return_dict)
        features = outputs[0]

        # 2. Quantize all extracted features (the training dropout is skipped).
        quantized_features, codevector_idx = self.quantizer(features, mask_time_indices=None)

        quantized_features_out = self.project_q(quantized_features)
        # project_q was observed not to be full rank (252 vs 256), so the
        # pre-projection codevectors are returned too.

        if not return_dict:
            return (quantized_features_out,)
        return customWav2Vec2_quantizeModelOutput(
            quantizedactivity=quantized_features_out,
            quantizedactivity_preproj=quantized_features,
            codevectorIdx=codevector_idx,
            activity=features if return_hidden_activity else None,
        )

    def _read_quantize(self, dl):
        """Run the quantizer on every batch of ``dl`` and stack the results.

        Args:
            dl: iterable of batches, each a mapping holding an
                ``"input_values"`` tensor.

        Returns:
            dict with ``"quantized"`` (pre-projection codevectors) and
            ``"vectoridx"`` (codevector indices) as numpy arrays concatenated
            along the batch axis. All batches must share the same time length.

        Raises:
            RuntimeError: if the model is in training mode (the quantizer would
                return Gumbel one-hot matrices instead of indices).
            ValueError: if ``dl`` yields no batch.
        """
        if self.training:
            raise RuntimeError("_read_quantize requires eval mode; call model.eval() first")

        quantized_activ, vector_idxs = [], []
        with torch.no_grad():
            for batch in dl:
                x0 = batch["input_values"].to(self.device)
                encoder_outputs = self.forward(input_values=x0, return_dict=True)
                # outputs: [batch, time, units]
                quantized_activ.append(encoder_outputs.quantizedactivity_preproj.cpu().numpy())
                vector_idxs.append(encoder_outputs.codevectorIdx.cpu().numpy())

        if not quantized_activ:
            raise ValueError("_read_quantize received an empty dataloader")

        # produces (sounds, time, units)
        return {
            "quantized": np.concatenate(quantized_activ, axis=0),
            "vectoridx": np.concatenate(vector_idxs, axis=0),
        }

    def read_quantize(self, feature_extractor, sound_mat):
        """Build a dataloader for ``sound_mat`` and return ``_read_quantize`` on it.

        NOTE(review): ``get_dataloader`` is neither defined nor imported in this
        file (its import at the top of the file is commented out), so this
        raises NameError unless the name is provided elsewhere — confirm.
        """
        dl = get_dataloader(self, feature_extractor, sound_mat)
        return self._read_quantize(dl)


@dataclass
class customWav2Vec2_featureModelOutput(ModelOutput):
    """Output of ``customWav2Vec2Model_quantize.forward``."""

    # Second value returned by ``feature_projection``: the features before the
    # projection to hidden size (the projected + dropout branch is discarded).
    extractFeatures: Optional[torch.FloatTensor] = None
    
class customWav2Vec2Model_quantize(Wav2Vec2Model):
    """Wav2vec2 model whose forward stops after the feature projection.

    All ``Wav2Vec2Model`` sub-modules are (re)built here, presumably so that
    checkpoint keys still match; ``forward`` however only runs the custom
    convolutional feature encoder and the feature projection. ``encoder``,
    ``adapter`` and ``masked_spec_embed`` are never touched by ``forward``.
    """

    def __init__(self, config: Wav2Vec2Config):
        super().__init__(config)
        self.config = config
        # custom conv front-end (lets gradients reach the input waveform)
        self.feature_extractor = customWav2vec2FeatureEncoder(config)
        self.feature_projection = Wav2Vec2FeatureProjection(config)

        # the masking vector is only created when some masking probability is set
        needs_mask_embed = config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0
        if needs_mask_embed:
            self.masked_spec_embed = nn.Parameter(
                torch.FloatTensor(config.hidden_size).uniform_()
            )

        encoder_cls = (
            Wav2Vec2EncoderStableLayerNorm
            if config.do_stable_layer_norm
            else Wav2Vec2Encoder
        )
        self.encoder = encoder_cls(config)

        self.adapter = Wav2Vec2Adapter(config) if config.add_adapter else None

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        return_dict: Optional[bool] = None,
    ) -> customWav2Vec2_featureModelOutput:
        """Return the CNN features taken before the projection to hidden size.

        Args:
            input_values: waveform batch, handed unchanged to
                ``self.feature_extractor``.
            return_dict: accepted for API compatibility and ignored; a
                ``customWav2Vec2_featureModelOutput`` is always returned.
        """
        # swap axes 1 and 2 of the conv output (presumably channels <-> time)
        conv_out = self.feature_extractor(input_values).transpose(1, 2)

        # feature_projection yields (projected + dropout states, features);
        # only the second one is kept. Per the original note the projection
        # maps a 512-dim space to a 768-dim one.
        _projected, features = self.feature_projection(conv_out)

        return customWav2Vec2_featureModelOutput(extractFeatures=features)

class customWav2vec2FeatureEncoder(Wav2Vec2FeatureEncoder):
    """``Wav2Vec2FeatureEncoder`` that tolerates inputs already in an autograd graph.

    The parent class sets ``hidden_states.requires_grad = True`` whenever
    ``self._requires_grad and self.training``. That assignment raises when the
    input is a non-leaf tensor (e.g. when back-propagating to the input sound),
    so it is only done here when the input does not already require grad, and
    only when gradient checkpointing is active.
    """

    def forward(self, input_values):
        """Run the conv layers on ``input_values``.

        Args:
            input_values: waveform batch; ``input_values[:, None]`` inserts a
                channel axis, so presumably ``(batch, samples)``.

        Returns:
            The output of the last conv layer.
        """
        hidden_states = input_values[:, None]

        use_checkpointing = self._requires_grad and self.gradient_checkpointing and self.training

        # 19/04/23: the unconditional ``hidden_states.requires_grad = True`` was
        # removed because hidden_states may no longer be a leaf (backpropagation
        # across the model to the input). But the reentrant
        # ``torch.utils.checkpoint.checkpoint`` only yields parameter gradients if
        # one of its inputs requires grad, so set the flag when it is missing:
        # a tensor that does not require grad is always a leaf, so this is legal,
        # and an input that already requires grad is left untouched.
        # Note: after the switch to encodec we do not use this strategy anymore
        # to generate sounds...
        if use_checkpointing and not hidden_states.requires_grad:
            hidden_states.requires_grad = True

        for conv_layer in self.conv_layers:
            if use_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(conv_layer, hidden_states)
            else:
                hidden_states = conv_layer(hidden_states)
        return hidden_states


class customWav2Vec2GumbelVectorQuantizer(Wav2Vec2GumbelVectorQuantizer):
    """
    Vector quantization using gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH
    GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.

    Differences from the parent quantizer: the perplexity is not computed,
    ``mask_time_indices`` is accepted but unused, and the second return value
    depends on the mode (see ``forward``).
    """

    def forward(self, hidden_states, mask_time_indices=None):
        """Replace every frame by the sum of its selected per-group codevectors.

        Args:
            hidden_states: tensor unpacked as ``(batch, time, features)``.
            mask_time_indices: ignored (kept for signature compatibility).

        Returns:
            ``(codevectors, selection)`` with ``codevectors`` shaped
            ``(batch, time, -1)``. In training mode ``selection`` is the hard
            Gumbel-softmax one-hot matrix of shape
            ``(batch*time, num_groups*num_vars)``, through which gradients flow;
            in eval mode it holds the argmax index per group, shaped
            ``(batch, time, num_groups)``.
        """
        batch_size, sequence_length, _ = hidden_states.shape
        n_frames = batch_size * sequence_length

        # one row of logits per (frame, group)
        logits = self.weight_proj(hidden_states).view(n_frames * self.num_groups, -1)

        # custom 19/04: training mode re-allowed (used for metamer generation)
        if self.training:
            # differentiable hard sampling
            one_hot = nn.functional.gumbel_softmax(
                logits.float(), tau=self.temperature, hard=True
            ).type_as(logits)
            chosen_idx = None
        else:
            # non-differentiable argmax, expanded to a one-hot matrix
            chosen_idx = logits.argmax(dim=-1)
            one_hot = logits.new_zeros(*logits.shape).scatter_(-1, chosen_idx.view(-1, 1), 1.0)

        one_hot = one_hot.view(n_frames, -1)

        # weight each codevector by its one-hot entry, then sum inside each group
        per_group = (one_hot.unsqueeze(-1) * self.codevectors).view(
            n_frames, self.num_groups, self.num_vars, -1
        )
        codevectors = per_group.sum(-2).view(batch_size, sequence_length, -1)

        if chosen_idx is None:
            return codevectors, one_hot
        return codevectors, chosen_idx.view(batch_size, sequence_length, self.num_groups)





class ShuffledWav2Vec2Model(nn.Module):
    """Wav2vec2 whose CNN features are randomly permuted in time before the encoder.

    Wraps a plain ``Wav2Vec2Model`` (``self.model``): the convolutional
    features are extracted, their time axis is shuffled independently per batch
    item, and the result goes through the wrapped model's feature projection,
    transformer encoder and (when configured) adapter.
    """

    def __init__(self, config: "Wav2Vec2Config"):
        super().__init__()
        # Create the original model; forward reuses its sub-modules.
        self.model = Wav2Vec2Model(config)

    def shuffle_features(self, features, attention_mask: Optional[torch.Tensor] = None):
        """Fully shuffle the time axis of ``features`` independently per batch item.

        Args:
            features: tensor unpacked as ``(batch, time, channels)``.
            attention_mask: optional ``(batch, time)`` mask. When given, only the
                positions where it is truthy are permuted among themselves;
                the other positions keep their original values.

        Returns:
            A new tensor; ``features`` itself is not modified.
        """
        batch_size, seq_len, _ = features.shape
        shuffled_features = features.clone()

        for i in range(batch_size):
            if attention_mask is None:
                # drawn on the CPU generator, then moved to the data's device
                perm = torch.randperm(seq_len).to(features.device)
                shuffled_features[i] = features[i, perm]
            else:
                valid_positions = attention_mask[i].bool()
                valid_indices = torch.where(valid_positions)[0]
                perm = torch.randperm(valid_indices.numel()).to(valid_indices.device)
                shuffled_features[i, valid_positions] = features[i, valid_indices[perm]]

        return shuffled_features

    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run wav2vec2 on ``input_values`` with time-shuffled CNN features.

        Args:
            input_values: waveform batch passed to the wrapped feature extractor.
            attention_mask: optional input-level mask; it is converted to one
                entry per feature frame before shuffling and encoding.
            output_attentions / output_hidden_states: forwarded to the encoder.
            return_dict: accepted for API compatibility and ignored; a
                ``Wav2Vec2BaseModelOutput`` is always returned.
        """
        # Get features from the original model's feature extractor
        extract_features = self.model.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)

        if attention_mask is not None:
            # one mask entry per feature frame (length extract_features.shape[1])
            attention_mask = self.model._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask
            )

        extract_features = self.shuffle_features(extract_features, attention_mask)

        hidden_states, extract_features = self.model.feature_projection(extract_features)

        encoder_outputs = self.model.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        # Same post-encoder step as Wav2Vec2Model.forward: apply the adapter if any.
        hidden_states = encoder_outputs.last_hidden_state
        adapter = getattr(self.model, "adapter", None)
        if adapter is not None:
            hidden_states = adapter(hidden_states)

        return Wav2Vec2BaseModelOutput(
            last_hidden_state=hidden_states,
            extract_features=extract_features,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    @classmethod
    def from_pretrained(cls, pretrained_path, *args, **kwargs):
        """Build the wrapper around ``Wav2Vec2Model.from_pretrained(pretrained_path, ...)``.

        Extra positional/keyword arguments are forwarded to
        ``Wav2Vec2Model.from_pretrained`` (they used to be silently dropped).
        """
        pretrained = Wav2Vec2Model.from_pretrained(pretrained_path, *args, **kwargs)
        model = cls(pretrained.config)
        # Load the pretrained weights into the internal model
        model.model = pretrained
        return model

    def eval(self):
        """Switch to evaluation mode and return ``self``.

        ``nn.Module.eval`` already recurses into ``self.model``.
        """
        return super().eval()