import torch
import torch.nn as nn
import torch.nn.functional as F

import random
from typing import Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import fairseq
from fairseq.fairseq import checkpoint_utils
import transformers
from srm_filters import *
import argparse



############################
## FOR fine-tuned SSL MODEL
############################
class SSLModel(nn.Module):
    """Wrap a fairseq wav2vec 2.0 (XLSR) checkpoint as a fine-tunable feature extractor.

    Attributes:
        model: First model of the ensemble loaded from the checkpoint.
        device: Device handle given by the caller. It is only stored; the
            wrapped model is moved lazily inside :meth:`extract_feat`.
        out_dim: Hard-coded to 1024 (presumably the embedding width of the
            XLSR-300M checkpoint -- confirm if another checkpoint is used).
    """

    # Default location of the pre-trained XLSR checkpoint; pass ``cp_path`` to override.
    DEFAULT_CP_PATH = '/mnt/storage/datasets/ido_audio_df/models/ssl/trained_models/wav2vec/xlsr2_300m.pt'

    def __init__(self, device, noise=False, cp_path=None):
        """Load the checkpoint and mark all of its parameters as trainable.

        Args:
            device: Stored on the instance as ``self.device``.
            noise: Currently unused; kept so existing callers keep working.
            cp_path: Optional path of the fairseq checkpoint. Falls back to
                ``DEFAULT_CP_PATH`` when ``None``.
        """
        super().__init__()

        if cp_path is None:
            cp_path = self.DEFAULT_CP_PATH
        model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
        self.model = model[0]
        self.device = device
        self.out_dim = 1024
        # ``module.requires_grad = True`` would only create a plain attribute;
        # ``requires_grad_`` is what actually flags every parameter.
        self.model.requires_grad_(True)

    def extract_feat(self, input_data):
        """Run the wrapped model and return its ``'x'`` feature output.

        Args:
            input_data: Tensor of rank 2 or rank 3. For rank 3 only index 0
                of the second axis is used (assumed to be a channel axis,
                i.e. ``(batch, channel, length)`` -- confirm with callers).

        Returns:
            The ``'x'`` entry of the wrapped model's output; per the original
            comment this is expected to be ``(batch, length, dim)``.
        """
        # Lazily align the wrapped model with the input's device and dtype.
        reference = next(self.model.parameters())
        if reference.device != input_data.device or reference.dtype != input_data.dtype:
            self.model.to(input_data.device, dtype=input_data.dtype)
            # Follow the wrapper's own mode instead of forcing train mode, so a
            # prior ``.eval()`` is not silently undone during inference.
            self.model.train(self.training)

        # The model expects input of shape (batch, length).
        if input_data.ndim == 3:
            input_tmp = input_data[:, 0, :]
        else:
            input_tmp = input_data

        return self.model(input_tmp, mask=False, features_only=True)['x']



class SimpleClassifier(nn.Module):
    """Two-branch classifier built on a pair of SSL feature extractors.

    One ``SSLModel`` receives the raw input ("content" branch); a second one
    receives the output of ``srm_module`` ("noise" branch). Both feature
    sequences are mean-pooled over axis 1, concatenated and fed to a small
    two-layer feed-forward head.
    """

    def __init__(self, embed_dim=1024, hidden_dim=256, num_classes=2, device='cuda:0'):
        """Build the SRM filter, both SSL branches and the classification head.

        Args:
            embed_dim: Feature width of each SSL branch; the head consumes
                ``2 * embed_dim`` inputs.
            hidden_dim: Width of the hidden layer of the head.
            num_classes: Number of output logits.
            device: Device handle stored on the instance and passed to both
                ``SSLModel`` instances.
        """
        super().__init__()
        self.device = device
        self.srm_module = ConstrainedConv1dWithResidual(1, 30, 5)

        # Content network (wav2vec 2.0).
        self.ssl_model = SSLModel(self.device)

        # Noise network (wav2vec 2.0).
        self.noise_ssl_model = SSLModel(self.device)
        # ``module.requires_grad = True`` would only create a plain attribute;
        # ``requires_grad_`` is what actually flags every parameter as trainable.
        self.noise_ssl_model.requires_grad_(True)

        # A very small feed-forward network.
        self.fc1 = nn.Linear(2 * embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Classify a batch of inputs and expose both branch feature maps.

        Args:
            x: Input batch. Assumed to be ``(batch, 1, samples)`` or
                ``(batch, samples)`` -- confirm against ``srm_module`` and the
                data loader.

        Returns:
            A tuple ``(logits, content_feats, noise_feats)``:
            ``logits`` is the output of the head with ``num_classes`` entries
            per sample (raw scores, no softmax/sigmoid applied);
            ``content_feats`` and ``noise_feats`` are clones of the un-pooled
            features of the content and noise branches.
        """
        # Keep an untouched copy for the content branch before SRM filtering.
        x_copy = x.clone()

        x = self.srm_module(x)

        # Fine-tuned wav2vec feature extraction for both branches.
        x_content_ssl_feat = self.ssl_model.extract_feat(x_copy.squeeze(1))
        x_noise_ssl_feat = self.noise_ssl_model.extract_feat(x.squeeze(1))

        x_low_copy = x_content_ssl_feat.clone()
        x_high_copy = x_noise_ssl_feat.clone()

        # 1) Aggregate across axis 1 (time) by mean pooling.
        z_low_avg = torch.mean(x_content_ssl_feat, dim=1)
        z_high_avg = torch.mean(x_noise_ssl_feat, dim=1)

        # 2) Concatenate the pooled features of both branches.
        x = torch.cat([z_low_avg, z_high_avg], dim=-1)

        # 3) Feed-forward head; the result is raw logits, suitable for
        #    nn.CrossEntropyLoss when num_classes=2.
        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        return x, x_low_copy, x_high_copy
