from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from . import MODELS
    from .video_mae import VideoMAE
    from .audio_model import W2V2_Model
except ImportError:
    from models import MODELS
    from models.video_mae import VideoMAE
    from models.audio_model import W2V2_Model
try:
    from . import MODELS
    from .video_mae import VideoMAE
    from .audio_model import W2V2_Model
    from .optimal_transport import Sinkhorn_low_level, Sinkhorn_high_level, OT
    from .map_model import Map
except ImportError:
    from models import MODELS
    from models.video_mae import VideoMAE
    from models.audio_model import W2V2_Model
    from models.optimal_transport import Sinkhorn_low_level, Sinkhorn_high_level, OT
    from models.map_model import Map




# Module-level approximation of pi (truncated to 10 decimal places, so it is
# slightly less precise than math.pi). Kept verbatim in case other modules
# import it.
pi: float = 3.1415926535

class FeedForward(nn.Module):
    """Two-layer MLP applied to the last dimension of its input.

    Layout of ``self.feed_forward`` (the indices determine the state_dict
    keys ``feed_forward.0.*`` and ``feed_forward.3.*``):
    0 Linear(d_model, d_ff) -> 1 Dropout -> 2 GELU -> 3 Linear(d_ff, d_model)
    -> 4 Dropout.

    Args:
        d_model: size of the input and output last dimension.
        d_ff: hidden width of the intermediate layer.
        dropout: probability used by both Dropout layers.
    """

    def __init__(self, d_model, d_ff, dropout=0.):
        super().__init__()
        # List elements are evaluated left to right, so the two Linear layers
        # draw their random initialisation in the same order as before.
        layers = [
            nn.Linear(d_model, d_ff),
            nn.Dropout(dropout),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        ]
        self.feed_forward = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the MLP; the last dimension maps d_model -> d_model."""
        return self.feed_forward(x)


class AddNorm(nn.Module):
    """LayerNorm, then a residual FeedForward branch, then LayerNorm.

    ``forward`` computes ``norm2(feed_forward(dropout(h)) + h)`` with
    ``h = norm1(x)``: the residual is taken from the ``norm1`` output before
    dropout is applied.

    Args:
        d_model: size of the last dimension of the input.
        dropout: probability for the standalone Dropout and inside FeedForward.
    """

    def __init__(self, d_model, dropout=0.15):
        super().__init__()
        # Registration order is kept as-is: it fixes parameters() ordering and
        # the RNG draws made while FeedForward initialises its Linear layers.
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.feed_forward = FeedForward(d_model, d_model, dropout)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """Apply norm -> residual feed-forward -> norm; shape is preserved."""
        residual = self.norm1(x)
        branch = self.feed_forward(self.dropout(residual))
        return self.norm2(branch + residual)



class OTBasedCrossModalFusion(nn.Module):
    """Fuse ``source_feat`` with ``target_feat`` through an optimal-transport plan.

    Both inputs are linearly projected, the ``OT`` solver is run on the
    projections with a uniform marginal over source positions, and the
    resulting plan aggregates the *unprojected* target features onto the
    source positions. The source features are then added back as a residual.

    Args:
        feature_dim: size of the last dimension of both inputs.
        eps, max_iter, reduction: forwarded unchanged to ``OT``.
    """

    def __init__(self, feature_dim, eps=0.01, max_iter=100, reduction="none"):
        super().__init__()
        self.ot = OT(eps=eps, max_iter=max_iter, reduction=reduction)

        self.proj_source = nn.Linear(feature_dim, feature_dim)
        self.proj_target = nn.Linear(feature_dim, feature_dim)

    def forward(self, source_feat, target_feat):
        """Return ``plan @ target_feat + source_feat``.

        NOTE(review): inputs are assumed to be (batch, seq_len, feature_dim),
        TODO confirm. ``torch.bmm`` below requires the plan returned by ``OT`` to
        be (batch, n_source, n_target).
        """
        # Complex inputs are reduced to their real part (imaginary part dropped).
        if source_feat.is_complex():
            source_feat = source_feat.real
        if target_feat.is_complex():
            target_feat = target_feat.real

        source_proj = self.proj_source(source_feat)
        target_proj = self.proj_target(target_feat)

        batch_size, num_source = source_proj.shape[0], source_proj.shape[1]
        # Uniform marginal over source positions. It is allocated directly on the
        # projection's device instead of being built on CPU and copied on every
        # call. The values and the default floating dtype are unchanged.
        mu = torch.ones(batch_size, num_source, device=source_proj.device) / num_source

        # OT's result is unpacked into three values; only the second one (the
        # transport plan) is used. It is named ``transport_plan`` so it no longer
        # shadows the module-level ``pi`` constant.
        _cost, transport_plan, _cost_matrix = self.ot(source_proj, target_proj, mu)
        fused_feat = torch.bmm(transport_plan, target_feat)
        return fused_feat + source_feat



@MODELS.register
class FusionModel(nn.Module):
    """Audio-visual classifier built on a ``VideoMAE`` and a ``W2V2_Model`` backbone.

    ``extract_feature`` proceeds in these steps:
    1. Obtain features from both backbones.
    2. Fuse each modality with the other through ``OTBasedCrossModalFusion``.
    3. Apply an inverse real FFT along dim 1.
    4. Average over dim 1.
    5. Mix the two pooled vectors using two learned square matrices, weighted
       by ``js_div`` of the pooled vectors.

    ``forward`` classifies that feature and computes the cross-entropy loss.

    Args:
        feature_dim: width of the fused feature; also passed to both backbones.
        num_classes: number of output classes.
        num_encoders, hidden_dim, adapter, adapter_type, adapter_hidden_dim,
        **kwargs: forwarded to both backbones.
        num_frames, num_adapter, pretrained_model_visual: forwarded to
            ``VideoMAE`` only.
        pretrained_model_audio: forwarded to ``W2V2_Model`` only.
        num_filter: accepted but not used in this class as shown.
        dropout: dropout probability for ``AddNorm``.
        act_layer: stored as ``self.act_layer``; not used in this class as shown.
    """

    def __init__(
        self,
        feature_dim: int = 512,
        num_classes: int = 2,
        num_encoders: int = 12,
        num_frames: int = 16,
        hidden_dim: int = 768,
        pretrained_model_visual: str = None,
        pretrained_model_audio: str = None,
        adapter: bool = True,
        num_adapter: int = 12,
        adapter_type: str = "efficient_conv",
        adapter_hidden_dim: int = 32,
        num_filter: int = 2,
        dropout: float = 0.15,
        act_layer=torch.tanh,
        **kwargs
    ) -> None:
        super(FusionModel, self).__init__()
        # Submodules and parameters are created in the original order, so
        # random-init draws and parameters() ordering (optimizer state) are
        # unchanged.
        self.visual_model = VideoMAE(
            num_encoders=num_encoders,
            num_frames=num_frames,
            hidden_dim=hidden_dim,
            feature_dim=feature_dim,
            num_classes=num_classes,
            pretrained_model_visual=pretrained_model_visual,
            adapter=adapter,
            num_adapter=num_adapter,
            adapter_type=adapter_type,
            adapter_hidden_dim=adapter_hidden_dim,
            **kwargs
        )
        self.audio_model = W2V2_Model(
            num_encoders=num_encoders,
            adapter=adapter,
            adapter_type=adapter_type,
            hidden_dim=hidden_dim,
            adapter_hidden_dim=adapter_hidden_dim,
            feature_dim=feature_dim,
            num_classes=num_classes,
            pretrained_model_audio=pretrained_model_audio,
            **kwargs
        )

        # (feature_dim x feature_dim) mixing matrices used in extract_feature.
        # NOTE(review): standard-normal init gives matmul outputs whose variance
        # grows with feature_dim; consider a scaled init if training is unstable.
        self.audio_weight = nn.Parameter(torch.randn(feature_dim, feature_dim, dtype=torch.float32))
        self.image_weight = nn.Parameter(torch.randn(feature_dim, feature_dim, dtype=torch.float32))
        self.act_layer = act_layer
        # NOTE(review): add_norm is not referenced by forward/extract_feature.
        self.add_norm = AddNorm(feature_dim, dropout)
        # Linear -> Softmax. forward() feeds the Linear output (logits) to the
        # loss and returns the Softmax output as probabilities. The explicit
        # dim=-1 replaces the deprecated implicit dim, which resolves to dim=1
        # for the 2-D pooled features produced by extract_feature.
        self.classifier = nn.Sequential(
            nn.Linear(feature_dim, num_classes),
            nn.Softmax(dim=-1),
        )
        self.result = namedtuple("Result", ["output", "loss"])
        # `reduction` replaces the deprecated `reduce` argument. The old
        # reduce="mean" was interpreted as reduce=True, i.e. reduction="mean".
        self.loss_fn = nn.CrossEntropyLoss(reduction="mean")

        self.visual_to_audio_ot = OTBasedCrossModalFusion(feature_dim)
        self.audio_to_visual_ot = OTBasedCrossModalFusion(feature_dim)

    @staticmethod
    def js_div(p, q):
        """Return a symmetric KL-based divergence between softmax(p) and softmax(q).

        With ``M = (p + q) / 2`` taken on the raw (pre-softmax) inputs, the
        result is ``0.5 * KL(softmax(p) || softmax(M)) + 0.5 * KL(softmax(q) ||
        softmax(M))``. Each KL is averaged over the first dimension
        (``reduction='batchmean'``), so the result is a scalar.

        NOTE(review): this is not the textbook Jensen-Shannon divergence, whose
        mixture is ``(softmax(p) + softmax(q)) / 2``. The value is therefore not
        bounded by ln 2 and can exceed 1, which makes ``1 - js`` in
        extract_feature negative. It is left unchanged to preserve the behaviour
        of trained models.
        """
        M = (p + q) / 2
        kl1 = F.kl_div(F.log_softmax(M, dim=-1), F.softmax(p, dim=-1), reduction='batchmean')
        kl2 = F.kl_div(F.log_softmax(M, dim=-1), F.softmax(q, dim=-1), reduction='batchmean')
        gamma = 0.5 * kl1 + 0.5 * kl2
        return gamma

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    @property
    def type(self):
        """Dtype of the first registered parameter.

        NOTE(review): this property shadows ``nn.Module.type(dst_type)``.
        ``model.type(torch.half)`` fails here because this returns a dtype,
        which is not callable.
        """
        return next(self.parameters()).dtype

    def forward(self, data_collection):
        """Classify ``data_collection`` and compute the loss.

        Returns:
            ``Result(output, loss)``. ``output`` is the softmax class
            probabilities. ``loss`` is the mean cross-entropy against
            ``data_collection.label``.
        """
        label = data_collection.label.to(self.device)
        feature = self.extract_feature(data_collection)
        # CrossEntropyLoss applies log_softmax internally, so it must receive
        # raw logits. Passing it the Softmax output (as before) applied softmax
        # twice. The returned `output` is still the probabilities.
        logits = self.classifier[0](feature)
        output = self.classifier[1](logits)
        loss = self.loss_fn(logits, label)
        return self.result(output=output, loss=loss)

    def extract_feature(self, data_collection):
        """Return the fused audio-visual feature for ``data_collection``.

        NOTE(review): backbone outputs are assumed to be
        (batch, seq_len, feature_dim), TODO confirm. Dim 1 is treated as the
        sequence axis for the FFT and pooling below.
        """
        visual_feature = self.visual_model.extract_feature(data_collection)
        audio_feature = self.audio_model.extract_feature(data_collection)

        # Cross-modal OT fusion in both directions.
        visual_selected = self.visual_to_audio_ot(visual_feature, audio_feature)
        audio_selected = self.audio_to_visual_ot(audio_feature, visual_feature)
        # Inverse real FFT along dim 1, keeping the size of dim 1 unchanged.
        visual = torch.fft.irfft(visual_selected, n=visual_selected.size(1), dim=1, norm='ortho')
        audio = torch.fft.irfft(audio_selected, n=audio_selected.size(1), dim=1, norm='ortho')

        # Average over dim 1.
        visual = visual.mean(dim=1)
        audio = audio.mean(dim=1)

        # Scalar weight shared by the whole batch (js_div uses 'batchmean').
        js = self.js_div(visual, audio)

        fusion = torch.matmul(audio, self.audio_weight) + torch.matmul(visual, self.image_weight)
        feature = (1 - js) * fusion + js * visual + js * audio
        return feature