import math
from collections import namedtuple
from typing import Optional, Set, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import VideoMAEForVideoClassification

try:
    from models.adapter import (
        video_mae_adapter_layer_conv,
        video_mae_adapter_layer_conv_multi,
    )
    from models import MODELS
    from models.optimal_transport import OptimalTransport, OptimalTransportReg
except ImportError:
    from .adapter import (
        video_mae_adapter_layer_conv,
        video_mae_adapter_layer_conv_multi,
    )
    from . import MODELS
    from .optimal_transport import OptimalTransport, OptimalTransportReg

# Module-level alias of math.pi (full double precision), kept for code that imports `pi`.
pi = math.pi


class FeedForward(nn.Module):
    """Two-layer MLP applied over the last dimension.

    Layer order: Linear(d_model -> d_ff) -> Dropout -> GELU ->
    Linear(d_ff -> d_model) -> Dropout.

    Args:
        d_model: input and output feature size.
        d_ff: hidden feature size.
        dropout: probability used by both Dropout layers.
    """

    def __init__(self, d_model, d_ff, dropout=0.):
        super().__init__()
        # Kept as one nn.Sequential so state_dict keys stay feed_forward.0 / feed_forward.3.
        layers = [
            nn.Linear(d_model, d_ff),
            nn.Dropout(dropout),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        ]
        self.feed_forward = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x``; the output has the same shape as the input."""
        out = self.feed_forward(x)
        return out


class AddNorm(nn.Module):
    """Normalise, apply a residual feed-forward update, then normalise again.

    forward(x) = norm2(FeedForward(dropout(norm1(x))) + norm1(x)), where the
    FeedForward hidden size equals ``d_model``.

    Args:
        d_model: size of the last input dimension.
        dropout: probability used by ``self.dropout`` and inside ``FeedForward``.
    """

    def __init__(self, d_model, dropout=0.2):
        super().__init__()
        # Submodules are registered in the original order so the state_dict
        # layout and parameters() order are unchanged.
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.feed_forward = FeedForward(d_model, d_model, dropout)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """Return the normalised residual update of ``x`` (same shape as ``x``)."""
        normed = self.norm1(x)
        update = self.feed_forward(self.dropout(normed))
        return self.norm2(update + normed)



def cheb_coeffs(
    K: int,
    mode: str = "T_fixed",
    device=None,
    dtype=torch.float32,
    alpha: Optional[torch.Tensor] = None,
    dolph_R: float = 60.0
) -> torch.Tensor:
    """Return ``K`` Chebyshev-style mixing coefficients as a 1-D tensor.

    Modes:
        "T_fixed":     cos((2k-1) * pi / (2K)) for k = 1..K.
        "U":           sin((2k-1) * pi / (2K)) / sin(pi / (2K)), rescaled to the
                       L2 norm of the "T_fixed" vector.
        "T_learnable": cos((2k-1) * softplus(alpha) * pi / (2K)), clamped to
                       [-1, 1]. Gradients flow back into ``alpha``.
        "dolph":       T_{K-1}(beta * cos(pi * n / K)) / T_{K-1}(beta) for
                       n = 1..K, with beta = cosh(acosh(10**(dolph_R/20)) / max(K-1, 1)),
                       rescaled to the L2 norm of the "T_fixed" vector.

    No ``torch.no_grad`` is applied. In "T_learnable" mode, ``alpha`` is
    typically an ``nn.Parameter`` and must receive gradients. The other modes
    depend on no trainable inputs, so they build no autograd graph anyway.

    Args:
        K: number of coefficients; ``K <= 0`` returns an empty tensor.
        mode: one of the modes listed above.
        device: device of the result (CPU when None).
        dtype: floating dtype of the result.
        alpha: scalar tensor used by "T_learnable" (defaults to 1.0).
        dolph_R: dB value used by "dolph", converted as ``10 ** (dolph_R / 20)``.

    Returns:
        Tensor of shape ``(K,)``.

    Raises:
        ValueError: if ``mode`` is unknown (checked only when ``K > 0``).
    """
    device = device or torch.device("cpu")
    if K <= 0:
        return torch.empty(0, device=device, dtype=dtype)

    k = torch.arange(1, K + 1, device=device, dtype=dtype)
    theta_fixed = (2 * k - 1) * (math.pi / (2.0 * K))

    if mode == "T_fixed":
        C = torch.cos(theta_fixed)

    elif mode == "U":
        theta0 = math.pi / (2.0 * K)
        denom = torch.sin(torch.tensor(theta0, device=device, dtype=dtype))
        C = torch.sin(theta_fixed) / (denom + 1e-12)
        C = C / (C.norm() + 1e-12) * (torch.cos(theta_fixed).norm() + 1e-12)

    elif mode == "T_learnable":
        if alpha is None:
            alpha = torch.tensor(1.0, device=device, dtype=dtype)
        # softplus keeps the angular scale positive for any real alpha.
        theta_base = (math.pi / (2.0 * K)) * F.softplus(alpha)
        theta = (2 * k - 1) * theta_base
        C = torch.clamp(torch.cos(theta), -1.0, 1.0)

    elif mode == "dolph":
        Rlin = 10.0 ** (dolph_R / 20.0)
        beta = torch.cosh(
            torch.acosh(torch.tensor(Rlin, device=device, dtype=dtype)) / max(K - 1, 1)
        )
        n = torch.arange(1, K + 1, device=device, dtype=dtype)
        x = beta * torch.cos(math.pi * n / (K + 1e-12))

        def cheb_T_m(xx, m):
            """Evaluate the Chebyshev polynomial T_m at ``xx`` by the three-term recurrence."""
            if m == 0:
                return torch.ones_like(xx)
            if m == 1:
                return xx
            Tm_2, Tm_1 = torch.ones_like(xx), xx
            for _ in range(2, m + 1):
                Tm = 2 * xx * Tm_1 - Tm_2
                Tm_2, Tm_1 = Tm_1, Tm
            return Tm_1

        num = cheb_T_m(x, K - 1)
        # `beta` is already a tensor on the right device/dtype. Re-wrapping it with
        # torch.tensor() would emit a copy-construct UserWarning.
        den = cheb_T_m(beta, K - 1)
        C = num / (den + 1e-12)
        C = C / (C.norm() + 1e-12) * (torch.cos(theta_fixed).norm() + 1e-12)

    else:
        raise ValueError(f"Unknown chebyshev mode: {mode}")

    return C


class FrequencyDomainProcessing(nn.Module):
    """Scale the temporal power spectrum of a sequence by a Chebyshev-mixed filter bank.

    ``forward(x)`` does the following:
    1. Applies an orthonormal real FFT along dim 1.
    2. Computes the power ``(re**2 + im**2) / T``.
    3. Multiplies every frequency bin by a complex per-channel filter. The filter
       is the sum of the ``num_filter`` complex filters in ``filter_bank``,
       weighted by ``cheb_coeffs(num_filter, cheb_mode, ...)``.

    The result is complex.

    Args:
        d_model: size of the last input dimension (channels).
        num_filter: number of filters in the bank (also ``K`` for ``cheb_coeffs``).
        dropout: dropout probability forwarded to ``AddNorm``.
        kernel_size: stored as ``self.kernel_size``; not used by this class.
        cheb_mode: coefficient mode passed to ``cheb_coeffs``.
        alpha_init: initial value of the learnable ``alpha`` parameter. It is only
            registered when ``cheb_mode == "T_learnable"``; otherwise ``alpha`` is None.
        dolph_R: value passed to ``cheb_coeffs`` as ``dolph_R``.
        **kwargs: stored as ``self.extra_kwargs``; not used by this class.
    """

    def __init__(
        self,
        d_model,
        num_filter=2,
        dropout=0.15,
        kernel_size=None,
        cheb_mode: str = "T_fixed",
        alpha_init: float = 1.0,
        dolph_R: float = 60.0,
        **kwargs
    ):
        super().__init__()
        self.d_model = d_model
        self.num_filter = num_filter
        self.cheb_mode = cheb_mode
        self.dolph_R = dolph_R

        # Creation order is kept so RNG-dependent initialisation and the
        # state_dict layout are unchanged.
        # NOTE(review): `weight` and `add_norm` are registered (and saved in the
        # state_dict) but are not used by forward().
        self.weight = nn.Parameter(torch.randn(d_model, 2))
        self.filter_bank = nn.Parameter(torch.randn(num_filter, d_model, 2))
        self.add_norm = AddNorm(d_model, dropout)

        learnable_alpha = (
            nn.Parameter(torch.tensor(alpha_init, dtype=torch.float32))
            if cheb_mode == "T_learnable"
            else None
        )
        self.register_parameter("alpha", learnable_alpha)

        self.kernel_size = kernel_size
        self.extra_kwargs = kwargs

    @staticmethod
    def _to_complex(w2):
        """Reinterpret a real tensor whose last dimension is 2 as a complex tensor."""
        return torch.view_as_complex(w2)

    def unimodal_spectrum_compression(self, x):
        """Return the filtered power spectrum of ``x`` along dim 1.

        ``x`` must be 3-D (the shape is unpacked into three names). The output
        has ``x.shape[1] // 2 + 1`` frequency bins along dim 1 and is complex.
        """
        _, seq_len, _ = x.shape
        spectrum = torch.fft.rfft(x, dim=1, norm='ortho')
        power = (spectrum.real ** 2 + spectrum.imag ** 2) / (seq_len + 0.0)

        coeffs = cheb_coeffs(
            self.num_filter,
            mode=self.cheb_mode,
            device=x.device,
            dtype=self.filter_bank.dtype,
            alpha=self.alpha,
            dolph_R=self.dolph_R,
        )

        # Weighted sum over the filter axis -> one complex filter per channel.
        bank = self._to_complex(self.filter_bank)
        mixed = (bank * coeffs[:, None]).sum(dim=0)

        # Broadcast the per-channel filter over batch and frequency bins.
        return power * mixed[None, None, :]

    def forward(self, x):
        """Alias for ``unimodal_spectrum_compression``."""
        return self.unimodal_spectrum_compression(x)


class WrappedEncoderLayer(nn.Module):
    """Wrap a layer so it can be chained in ``nn.Sequential``.

    If the wrapped layer returns a tuple, only its first element is passed on;
    any other return value is passed through unchanged.
    """

    def __init__(self, layer):
        super().__init__()
        self.layer = layer

    def forward(self, hidden_states):
        """Run the wrapped layer and unwrap a tuple result to its first item."""
        result = self.layer(hidden_states)
        if isinstance(result, tuple):
            return result[0]
        return result


class VideoMAEEncoderAdapter(nn.Module):
    """Rebuild an encoder as an ``nn.Sequential`` of (optionally adapter-wrapped) layers.

    With ``adapter=True``, the first ``num_encoder`` entries of ``encoder.layer``
    are split into ``num_adapter`` consecutive groups of
    ``num_encoder // num_adapter`` layers each:

    * ``num_adapter == num_encoder``: each layer is wrapped individually by
      ``video_mae_adapter_layer_conv``.
    * otherwise: each group (a slice of ``encoder.layer``) is wrapped by
      ``video_mae_adapter_layer_conv_multi``.

    With ``adapter=False``, every layer's parameters get ``requires_grad = True``
    and the layer is wrapped by ``WrappedEncoderLayer``.

    Args:
        num_encoder: number of layers taken from ``encoder.layer``.
        num_adapter: number of adapter groups; must equal ``num_encoder`` or
            evenly divide it (and be positive).
        hidden_dim: forwarded to the adapter factories as ``hidden_dim``.
        encoder: module exposing an indexable/sliceable ``layer`` container; also
            registered as ``self.encoder``.
        adapter: whether to insert adapters.
        adapter_type: only ``"efficient_conv"`` is supported.
        adapter_hidden_dim: forwarded to the adapter factories as ``adapter_hidden_dim``.

    Raises:
        RuntimeError: if ``num_adapter`` neither equals nor evenly divides ``num_encoder``.
        NotImplementedError: if ``adapter_type`` is not ``"efficient_conv"``.
    """

    def __init__(
            self,
            num_encoder,
            num_adapter,
            hidden_dim,
            encoder,
            adapter,
            adapter_type,
            adapter_hidden_dim,
    ) -> None:
        super(VideoMAEEncoderAdapter, self).__init__()
        self.encoder = encoder
        layer_list = []
        if adapter:
            one_per_layer = num_adapter == num_encoder
            if one_per_layer:
                block = 1
            elif num_adapter > 0 and num_encoder % num_adapter == 0:
                block = num_encoder // num_adapter
            else:
                # Also covers num_adapter <= 0, which previously raised ZeroDivisionError
                # or silently produced no layers.
                raise RuntimeError(
                    f"Number of encoders ({num_encoder}) must be divisible by the "
                    f"number of adapters ({num_adapter})."
                )

            if adapter_type == "mlp":
                raise NotImplementedError(
                    "MLP adapter is not implemented in VideoMAE, please use efficient_conv instead."
                )
            if adapter_type != "efficient_conv":
                raise NotImplementedError(
                    f"Adapter type {adapter_type!r} has not been implemented."
                )

            # `start` already advances in steps of `block`, so it is the slice start
            # directly. The old code multiplied it by `block` again, which produced
            # wrong or empty slices for every group after the first.
            for start in range(0, num_encoder, block):
                if one_per_layer:
                    layer_list.append(
                        video_mae_adapter_layer_conv(
                            transformer_encoder_layer=encoder.layer[start],
                            hidden_dim=hidden_dim,
                            adapter_hidden_dim=adapter_hidden_dim,
                        )
                    )
                else:
                    layer_list.append(
                        video_mae_adapter_layer_conv_multi(
                            transformer_encoder_layers=encoder.layer[start: start + block],
                            hidden_dim=hidden_dim,
                            adapter_hidden_dim=adapter_hidden_dim,
                        )
                    )
        else:
            # No adapters: the original layers are fully trainable.
            for i in range(num_encoder):
                for p in encoder.layer[i].parameters():
                    p.requires_grad = True
                layer_list.append(WrappedEncoderLayer(encoder.layer[i]))
        self.layer = nn.Sequential(*layer_list)

    def forward(self, hidden_states):
        """Run ``hidden_states`` through the layer sequence in order."""
        return self.layer(hidden_states)


@MODELS.register
class VideoMAE(nn.Module):
    """Pretrained VideoMAE backbone followed by a frequency-domain head and a linear classifier.

    ``extract_feature`` runs the following steps:
    1. VideoMAE embeddings.
    2. ``VideoMAEEncoderAdapter``.
    3. Linear map to ``feature_dim``.
    4. ``FrequencyDomainProcessing`` along the token axis.
    5. Inverse real FFT back to the token count.
    6. Mean over tokens, giving a ``(batch, feature_dim)`` tensor.

    ``forward`` classifies that vector and computes a cross-entropy loss.

    All parameters of the pretrained ``VideoMAEForVideoClassification`` are frozen.
    ``VideoMAEEncoderAdapter`` unfreezes the encoder layers when ``adapter=False``.

    Args:
        feature_dim: size of the mapped token features and of the pooled
            feature fed to the classifier.
        num_classes: number of output classes.
        num_encoders: number of encoder layers used.
        num_frames: stored as ``self.num_frames``; not used by this class.
        hidden_dim: must equal the pretrained config's ``hidden_size``.
        pretrained_model_visual: name or path passed to ``from_pretrained``.
        adapter, num_adapter, adapter_type, adapter_hidden_dim: forwarded to
            ``VideoMAEEncoderAdapter``.
        num_filter, dropout: forwarded to ``FrequencyDomainProcessing``.
        **kwargs: accepted and ignored.
    """

    def __init__(
            self,
            feature_dim: int = 512,
            num_classes: int = 2,
            num_encoders: int = 12,
            num_frames: int = 16,
            hidden_dim: int = 768,
            pretrained_model_visual: Optional[str] = None,
            adapter: bool = True,
            num_adapter: int = 12,
            adapter_type: str = "efficient_conv",
            adapter_hidden_dim: int = 32,
            num_filter: int = 2,
            dropout: float = 0.15,
            **kwargs,
    ) -> None:
        super(VideoMAE, self).__init__()

        self.num_encoders = num_encoders
        self.adapter = adapter
        self.adapter_type = adapter_type
        self.hidden_dim = hidden_dim
        self.num_frames = num_frames
        self.num_classes = num_classes
        self.num_adapter = num_adapter
        self.feature_dim = feature_dim
        # `reduction` replaces the deprecated `reduce` flag. "mean" matches what
        # `reduce="mean"` effectively selected.
        self.loss_cls = nn.CrossEntropyLoss(reduction="mean")

        self.videoMAEModel = VideoMAEForVideoClassification.from_pretrained(
            pretrained_model_visual
        )
        assert self.hidden_dim == self.videoMAEModel.videomae.config.hidden_size, (
            f"hidden_dim={self.hidden_dim} does not match the pretrained hidden_size="
            f"{self.videoMAEModel.videomae.config.hidden_size}"
        )
        for p in self.videoMAEModel.parameters():
            p.requires_grad = False

        self.embedding = self.videoMAEModel.videomae.embeddings

        self.encoder = VideoMAEEncoderAdapter(
            self.num_encoders,
            self.num_adapter,
            self.hidden_dim,
            self.videoMAEModel.videomae.encoder,
            adapter,
            adapter_type,
            adapter_hidden_dim,
        )
        self.mapping = nn.Sequential(nn.Linear(self.hidden_dim, self.feature_dim))

        # The classifier consumes the pooled `feature_dim` vector that
        # extract_feature returns (not `hidden_dim`). It emits raw logits because
        # CrossEntropyLoss applies log-softmax itself. A Softmax layer here would
        # normalise twice.
        self.classifier = nn.Sequential(
            nn.Linear(self.feature_dim, self.num_classes),
        )
        self.result = namedtuple("Result", ["output", "loss"])

        self.freq_processor = FrequencyDomainProcessing(
            feature_dim, num_filter, dropout,
            cheb_mode="T_learnable",
            alpha_init=1.0,
            dolph_R=60.0
        )

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    @property
    def type(self):
        """Dtype of the first registered parameter.

        NOTE(review): this property shadows ``nn.Module.type(dst_type)``, so
        ``model.type(torch.float16)`` will not cast the module.
        """
        return next(self.parameters()).dtype

    def forward(self, data_collection):
        """Classify a batch and compute the loss.

        Assumes ``data_collection`` exposes ``.visual`` (model input) and ``.label``
        (class indices accepted by CrossEntropyLoss) — TODO confirm against the loader.

        Returns:
            ``Result(output, loss)``. ``output`` holds the softmax class
            probabilities of shape ``(batch, num_classes)``. ``loss`` is the mean
            cross-entropy computed from the logits.
        """
        label = data_collection.label.to(self.device)
        features = self.extract_feature(data_collection)
        logits = self.classifier(features)
        loss = self.loss_cls(logits, label)
        output = F.softmax(logits, dim=1)
        return self.result(output=output, loss=loss)

    def extract_feature(self, data_collection):
        """Return the pooled frequency-processed feature, shape ``(batch, feature_dim)``."""
        pixel_values = data_collection.visual.to(self.device)
        tokens = self.embedding(pixel_values, None)
        tokens = self.encoder(hidden_states=tokens)

        seq_feature = self.mapping(tokens)

        spectrum = self.freq_processor(seq_feature)
        # irfft must be told the original token count. The rfft output only has
        # seq_len // 2 + 1 bins, so passing spectrum.size(1) would truncate the
        # reconstruction.
        restored = torch.fft.irfft(spectrum, n=seq_feature.size(1), dim=1, norm='ortho')
        return restored.mean(dim=1)

