import torch
import torch.nn as nn
import torchaudio

from collections import namedtuple

try:
    from . import MODELS
    from .adapter import w2v2_adapter_conv
except ImportError:
    from models import MODELS
    from models.adapter import w2v2_adapter_conv

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Set, Tuple, Union

from transformers import VideoMAEForVideoClassification
from collections import namedtuple

try:
    from models.adapter import (
        video_mae_adapter_layer_conv,
        video_mae_adapter_layer_conv_multi,
    )
    from models import MODELS
    from models.optimal_transport import OptimalTransport, OptimalTransportReg
except ImportError:
    from .adapter import (
        video_mae_adapter_layer_conv,
        video_mae_adapter_layer_conv_multi,
    )
    from . import MODELS
    from .optimal_transport import OptimalTransport, OptimalTransportReg

# Module-level approximation of π, truncated to 10 decimal places.
pi = 3.1415926535



class FeedForward(nn.Module):
    """Two-layer feed-forward block mapping ``d_model -> d_ff -> d_model``.

    The stages are applied in this order: Linear, Dropout, GELU, Linear,
    Dropout. They are stored in a single ``nn.Sequential`` under the
    attribute ``feed_forward``.

    Args:
        d_model: Size of the last dimension of the input and output.
        d_ff: Size of the hidden projection.
        dropout: Drop probability used by both dropout stages.
    """

    def __init__(self, d_model, d_ff, dropout=0.):
        super().__init__()
        # Create the projections first (expansion, then contraction) and
        # assemble the pipeline afterwards.
        expand = nn.Linear(d_model, d_ff)
        contract = nn.Linear(d_ff, d_model)
        stages = [
            expand,
            nn.Dropout(dropout),
            nn.GELU(),
            contract,
            nn.Dropout(dropout),
        ]
        self.feed_forward = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the sequential pipeline and return the result."""
        return self.feed_forward(x)


class AddNorm(nn.Module):
    """Normalise, apply a residual feed-forward branch, then normalise again.

    The computation is ``norm2(feed_forward(dropout(norm1(x))) + norm1(x))``:
    the residual is taken from the *normalised* input, not from the raw one.
    The inner ``FeedForward`` uses ``d_model`` for its hidden width as well.

    Args:
        d_model: Size of the last dimension of the input.
        dropout: Drop probability for the pre-branch dropout and for the
            dropout stages inside the feed-forward block.
    """

    def __init__(self, d_model, dropout=0.15):
        super().__init__()

        self.norm1 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.feed_forward = FeedForward(d_model, d_model, dropout)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """Return the doubly-normalised residual output for ``x``."""
        normed = self.norm1(x)
        branch = self.feed_forward(self.dropout(normed))
        return self.norm2(branch + normed)


def cheb_coeffs(
    K: int,
    mode: str = "T_fixed",
    device=None,
    dtype=torch.float32,
    alpha: Optional[torch.Tensor] = None,
    dolph_R: float = 60.0
):
    """Build a length-``K`` vector of Chebyshev-style mixing coefficients.

    The function is intentionally *not* wrapped in ``torch.no_grad()``: in
    ``"T_learnable"`` mode the result must stay attached to ``alpha`` so the
    parameter can receive gradients. The other modes only use constants, so
    no autograd graph is recorded for them anyway.

    Modes (with ``theta_k = (2k - 1) * pi / (2K)`` for ``k = 1..K``):

    * ``"T_fixed"``: ``cos(theta_k)`` (the roots of the degree-``K``
      Chebyshev polynomial of the first kind).
    * ``"U"``: ``sin(theta_k) / sin(pi / (2K))``, rescaled so its L2 norm
      matches the ``"T_fixed"`` vector.
    * ``"T_learnable"``: ``cos((2k - 1) * pi / (2K) * softplus(alpha))``.
    * ``"dolph"``: ``T_{K-1}(beta * cos(pi * n / K)) / T_{K-1}(beta)`` for
      ``n = 1..K`` with
      ``beta = cosh(acosh(10 ** (dolph_R / 20)) / max(K - 1, 1))``,
      rescaled to the L2 norm of the ``"T_fixed"`` vector.

    Args:
        K: Number of coefficients. ``K <= 0`` yields an empty tensor.
        mode: One of the mode names listed above.
        device: Target device; defaults to CPU when falsy.
        dtype: Floating dtype of the result.
        alpha: Scalar tensor used by ``"T_learnable"``; defaults to 1.0.
        dolph_R: Ratio in dB used by ``"dolph"``.

    Returns:
        A 1-D tensor with ``K`` elements on ``device`` with ``dtype``.

    Raises:
        ValueError: If ``mode`` is unknown (only checked when ``K > 0``).
    """
    # Function-scope import: use full-precision pi without touching the
    # file's import block.
    import math

    device = device or torch.device("cpu")
    if K <= 0:
        return torch.empty(0, device=device, dtype=dtype)

    k = torch.arange(1, K + 1, device=device, dtype=dtype)
    theta_fixed = (2 * k - 1) * (math.pi / (2.0 * K))

    if mode == "T_fixed":
        C = torch.cos(theta_fixed)

    elif mode == "U":
        theta0 = math.pi / (2.0 * K)
        denom = torch.sin(torch.tensor(theta0, device=device, dtype=dtype))
        C = torch.sin(theta_fixed) / (denom + 1e-12)

        # Match the energy of the "T_fixed" coefficients.
        C = C / (C.norm() + 1e-12) * (torch.cos(theta_fixed).norm() + 1e-12)

    elif mode == "T_learnable":
        if alpha is None:
            alpha = torch.tensor(1.0, device=device, dtype=dtype)
        # Align with the requested device/dtype; ``.to`` keeps the autograd
        # link back to the original parameter.
        alpha = alpha.to(device=device, dtype=dtype)
        theta_base = (math.pi / (2.0 * K)) * F.softplus(alpha)
        theta = (2 * k - 1) * theta_base
        C = torch.cos(theta)
        # Defensive only: cos() is already within [-1, 1].
        C = torch.clamp(C, -1.0, 1.0)

    elif mode == "dolph":
        Rlin = 10.0 ** (dolph_R / 20.0)
        beta = torch.cosh(
            torch.acosh(torch.tensor(Rlin, device=device, dtype=dtype)) / max(K - 1, 1)
        )
        x = beta * torch.cos(math.pi * k / K)

        def cheb_T_m(xx, m):
            """Evaluate T_m(xx) with the three-term recurrence."""
            if m == 0:
                return torch.ones_like(xx)
            if m == 1:
                return xx
            Tm_2, Tm_1 = torch.ones_like(xx), xx
            for _ in range(2, m + 1):
                Tm = 2 * xx * Tm_1 - Tm_2
                Tm_2, Tm_1 = Tm_1, Tm
            return Tm_1

        num = cheb_T_m(x, K - 1)
        # ``beta`` is already a tensor with the right device/dtype; wrapping
        # it in torch.tensor() again would only copy it and emit a warning.
        den = cheb_T_m(beta, K - 1)
        C = num / (den + 1e-12)
        C = C / (C.norm() + 1e-12) * (torch.cos(theta_fixed).norm() + 1e-12)

    else:
        raise ValueError(f"Unknown chebyshev mode: {mode}")

    return C


class FrequencyDomainProcessing(nn.Module):
    """Weight the power spectrum of a sequence with a learned complex filter.

    The input is transformed with a real FFT along dimension 1, turned into a
    power spectrum, and multiplied by one complex weight per feature. That
    weight is the sum of the ``num_filter`` rows of ``filter_bank``, each
    scaled by the matching coefficient returned by ``cheb_coeffs``.

    Args:
        d_model: Size of the last input dimension.
        num_filter: Number of rows in the filter bank (and of coefficients).
        dropout: Drop probability passed to the ``AddNorm`` sub-module.
        kernel_size: Stored on the instance; not used by this class.
        cheb_mode: Mode string forwarded to ``cheb_coeffs``.
        alpha_init: Initial value of ``alpha`` when ``cheb_mode`` is
            ``"T_learnable"``.
        dolph_R: Value forwarded to ``cheb_coeffs`` as ``dolph_R``.
        **kwargs: Stored in ``extra_kwargs``; not used by this class.
    """

    def __init__(
        self,
        d_model,
        num_filter=2,
        dropout=0.15,
        kernel_size=None,
        cheb_mode: str = "T_fixed",
        alpha_init: float = 1.0,
        dolph_R: float = 60.0,
        **kwargs
    ):
        super().__init__()
        self.d_model = d_model
        self.num_filter = num_filter
        self.cheb_mode = cheb_mode
        self.dolph_R = dolph_R

        # ``weight`` and ``add_norm`` are registered but not referenced by
        # the methods of this class.
        self.weight = nn.Parameter(torch.randn(d_model, 2))
        # Last dimension of size 2 holds (real, imag) for view_as_complex.
        self.filter_bank = nn.Parameter(torch.randn(num_filter, d_model, 2))
        self.add_norm = AddNorm(d_model, dropout)

        # ``alpha`` exists as a parameter only in the learnable mode; in all
        # other modes the name is registered with value None.
        learnable_alpha = None
        if cheb_mode == "T_learnable":
            learnable_alpha = nn.Parameter(
                torch.tensor(alpha_init, dtype=torch.float32)
            )
        self.register_parameter("alpha", learnable_alpha)

        self.kernel_size = kernel_size
        self.extra_kwargs = kwargs

    @staticmethod
    def _to_complex(w2):
        """View a real tensor whose last dimension is 2 as a complex tensor."""
        return torch.view_as_complex(w2)

    def unimodal_spectrum_compression(self, x):
        """Return the filter-weighted power spectrum of ``x``.

        ``x`` must be 3-D; the FFT runs over dimension 1, so the result has
        ``x.shape[1] // 2 + 1`` entries along that dimension. The result is
        complex because the per-feature weight is complex.
        """
        _, seq_len, _ = x.shape  # unpacking enforces a 3-D input

        spectrum = torch.fft.rfft(x, dim=1, norm='ortho')
        energy = spectrum.real ** 2 + spectrum.imag ** 2
        power = energy / (seq_len + 0.0)

        coeffs = cheb_coeffs(
            self.num_filter,
            mode=self.cheb_mode,
            device=x.device,
            dtype=self.filter_bank.dtype,
            alpha=self.alpha,
            dolph_R=self.dolph_R,
        )

        # (num_filter, d_model) complex bank, mixed down to one weight per
        # feature and broadcast over the two leading dimensions of ``power``.
        bank = self._to_complex(self.filter_bank)
        mixed = (bank * coeffs.unsqueeze(1)).sum(dim=0)
        return power * mixed[None, None]

    def forward(self, x):
        """Alias for :meth:`unimodal_spectrum_compression`."""
        return self.unimodal_spectrum_compression(x)




@MODELS.register
class W2V2_Model(nn.Module):
    """Audio classifier built on torchaudio's pretrained WAV2VEC2_BASE.

    Pipeline: frozen conv feature extractor -> projection stack -> the first
    ``num_encoders`` transformer layers (adapter-wrapped, or unfrozen when
    ``adapter`` is False) -> linear ``mapping`` to ``feature_dim`` ->
    ``FrequencyDomainProcessing`` -> pooling -> linear classifier.

    Args:
        feature_dim: Output width of ``mapping`` and input width of the
            classifier.
        num_classes: Number of output classes.
        num_encoders: How many pretrained transformer layers to use.
        adapter: Wrap each layer with an adapter (True) or fine-tune the
            layer itself (False).
        adapter_type: ``"efficient_conv"`` is supported; ``"mlp"`` raises
            ``NotImplementedError``; anything else raises ``ValueError``.
        hidden_dim: Input width of ``mapping``; also passed to the adapter.
        adapter_hidden_dim: Passed to the adapter as ``adapter_hidden_dim``.
        pretrained_model_audio: Passed to torchaudio as ``model_dir`` in
            ``dl_kwargs``.
        num_filter: Passed to ``FrequencyDomainProcessing``.
        dropout: Passed to ``FrequencyDomainProcessing``.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        feature_dim: int = 512,
        num_classes: int = 2,
        num_encoders: int = 4,
        adapter: bool = True,
        adapter_type: str = "efficient_conv",
        hidden_dim: int = 768,
        adapter_hidden_dim: int = 32,
        pretrained_model_audio: Optional[str] = None,
        num_filter: int = 2,
        dropout: float = 0.15,
        **kwargs,
    ):
        super(W2V2_Model, self).__init__()

        self.num_encoders = num_encoders
        self.adapter = adapter
        self.adapter_type = adapter_type
        self.hidden_dim = hidden_dim
        self.feature_dim = feature_dim
        self.num_classes = num_classes

        self.freq_processor = FrequencyDomainProcessing(
            feature_dim, num_filter, dropout,
            cheb_mode="T_learnable",
            alpha_init=1.0,
            dolph_R=60.0
        )
        model = torchaudio.pipelines.WAV2VEC2_BASE.get_model(
            dl_kwargs=dict(model_dir=pretrained_model_audio)
        )
        # Freeze the whole pretrained backbone; selected layers are unfrozen
        # again below when no adapter is used.
        for p in model.parameters():
            p.requires_grad = False

        self.FEATURE_EXTRACTOR = model.feature_extractor

        # NOTE(review): torchaudio's own Transformer appears to add the
        # pos_conv_embed output to its input (residual) before dropout;
        # chaining the modules in a Sequential does not. Confirm this is
        # intended before relying on the pretrained weights.
        self.FEATURE_PROJECTOR = nn.Sequential(
            model.encoder.feature_projection,
            model.encoder.transformer.pos_conv_embed,
            model.encoder.transformer.layer_norm,
            model.encoder.transformer.dropout,
        )

        encoder_layers = model.encoder.transformer.layers
        self.TRANSFORMER = nn.Sequential(
            *[
                self._build_encoder_layer(encoder_layers[i], adapter_hidden_dim)
                for i in range(self.num_encoders)
            ]
        )

        # Index 0 yields raw logits, index 1 turns them into probabilities
        # over the class (last) dimension.
        self.classifier = nn.Sequential(
            nn.Linear(self.feature_dim, self.num_classes),
            nn.Softmax(dim=-1),
        )

        self.mapping = nn.Sequential(nn.Linear(self.hidden_dim, self.feature_dim))
        self.result = namedtuple("Result", ["output", "loss"])
        # ``reduction`` is the supported keyword; ``reduce`` is a deprecated
        # boolean flag.
        self.loss_fn = nn.CrossEntropyLoss(reduction="mean")

    def _build_encoder_layer(self, encoder_layer, adapter_hidden_dim):
        """Return the module to use for one pretrained transformer layer."""
        if not self.adapter:
            # Full fine-tuning: undo the global freeze for this layer.
            for p in encoder_layer.parameters():
                p.requires_grad = True
            return encoder_layer

        if self.adapter_type == "mlp":
            raise NotImplementedError(
                "MLP adapter is not implemented in W2V2_Model, please use efficient_conv instead."
            )
        if self.adapter_type == "efficient_conv":
            return w2v2_adapter_conv(
                transformer_encoder=encoder_layer,
                hidden_dim=self.hidden_dim,
                adapter_hidden_dim=adapter_hidden_dim,
            )
        # Previously an unknown type silently produced a transformer stack
        # with missing layers.
        raise ValueError(f"Unknown adapter_type: {self.adapter_type!r}")

    @property
    def device(self):
        """Device of the first parameter of the module."""
        return next(self.parameters()).device

    @property
    def type(self):
        """Dtype of the first parameter of the module.

        NOTE: this property shadows ``nn.Module.type(dst_type)``; the name is
        kept for backward compatibility with existing callers.
        """
        return next(self.parameters()).dtype

    @staticmethod
    def _pool_tokens(tokens):
        """Reduce ``extract_feature`` output to a real 2-D tensor.

        Complex input is taken back to the time domain with an inverse real
        FFT along dimension 1, and 3-D input is averaged over dimension 1.

        NOTE(review): this reinstates the irfft + temporal mean that were
        commented out in the original code. Without some reduction the real
        ``nn.Linear`` classifier cannot consume the complex 3-D spectrum.
        Confirm this is the intended pooling.
        """
        if torch.is_complex(tokens):
            n_time = max(2 * (tokens.size(1) - 1), 1)
            tokens = torch.fft.irfft(tokens, n=n_time, dim=1, norm="ortho")
        if tokens.dim() == 3:
            tokens = tokens.mean(dim=1)
        return tokens

    def forward(self, data_collection, train_stage=False):
        """Classify a batch and compute the cross-entropy loss.

        Args:
            data_collection: Object exposing ``audio`` and ``label`` tensors.
            train_stage: Accepted for interface compatibility; unused.

        Returns:
            ``Result(output, loss)`` where ``output`` holds class
            probabilities (softmax applied once) and ``loss`` is computed on
            the raw logits, as ``CrossEntropyLoss`` applies log-softmax
            internally.
        """
        label = data_collection.label.to(self.device)
        pooled = self._pool_tokens(self.extract_feature(data_collection))

        raw_logits = self.classifier[0](pooled)
        loss = self.loss_fn(raw_logits, label)
        probabilities = self.classifier[1](raw_logits)

        return self.result(output=probabilities, loss=loss)

    def extract_feature(self, data_collection):
        """Return the frequency-domain features for ``data_collection.audio``.

        The audio has dimension 1 squeezed out and is moved to the model's
        device, then passed through the feature extractor, the projection
        stack, the transformer layers, ``mapping`` and finally
        ``freq_processor``. The output of ``freq_processor`` is returned
        unchanged.
        """
        features, _ = self.FEATURE_EXTRACTOR(
            data_collection.audio.squeeze(dim=1).to(self.device), None
        )

        x = self.FEATURE_PROJECTOR(features)
        for layer in self.TRANSFORMER:
            x = layer(x)
            # Some layers return a tuple; the hidden states come first.
            if isinstance(x, tuple):
                x = x[0]

        seq_feature = self.mapping(x)
        return self.freq_processor(seq_feature)
