import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .backbone import resnet18, TextEncoder, ImageEncoder, ImageEncoderaudio


class GOAL(torch.autograd.Function):
    """Fused linear head whose backward re-weights and de-conflicts per-modality gradients.

    Forward concatenates the modality feature tensors along dim 1 and returns
    ``classifier(concat)``. Backward returns gradients for the feature tensors
    only, computed in four steps:

    1. The ordinary gradient w.r.t. the concatenation is split back into one
       slice per modality using ``feature_dims``.
    2. Each modality is scored by running ``classifier`` on an input in which
       every other modality's slice is zero; its per-sample confidence is
       ``1 - H(softmax(logits)) / log(num_classes)``.
    3. A temperature softmax over the modality confidences gives per-sample
       weights; slice ``i`` is scaled by ``num_modalities * weight_i``.
    4. For each ordered pair ``(i, j)``, rows where slice ``i`` and the scaled
       slice ``j`` have a negative dot product get the component of slice ``i``
       along slice ``j`` removed. Dot products are taken between slices, so
       slices of different widths are not supported.

    ``classifier``, ``temperature`` and ``feature_dims`` receive no gradient.
    The classifier's parameters are not inputs of this Function, so nothing
    here produces gradients for them; ``GOALWrapper`` supplies that path.
    """

    @staticmethod
    def forward(ctx, classifier, temperature, feature_dims, *features):
        ctx.classifier = classifier
        ctx.temperature = temperature
        ctx.feature_dims = feature_dims
        # Tensors needed in backward go through save_for_backward rather than
        # plain ctx attributes, as the autograd.Function docs require (version
        # checks against in-place edits, no reference cycles / leaks).
        ctx.save_for_backward(*features)
        return classifier(torch.cat(features, dim=1))

    @staticmethod
    def backward(ctx, grad_output):
        classifier = ctx.classifier
        feature_dims = ctx.feature_dims
        features = ctx.saved_tensors
        num_modalities = len(features)

        # Step 1: plain gradient w.r.t. the concatenated features, recomputed on
        # a fresh graph because forward ran without autograd tracking.
        concatenated_features = torch.cat(features, dim=1).requires_grad_(True)
        with torch.enable_grad():
            logits = classifier(concatenated_features)
        (grad_concat,) = torch.autograd.grad(logits, concatenated_features, grad_outputs=grad_output)
        base_gradients = torch.split(grad_concat, feature_dims, dim=1)

        # Step 2: per-sample confidence of each modality on its own.
        confidences = []
        with torch.no_grad():
            start = 0
            for feature, width in zip(features, feature_dims):
                unimodal_input = torch.zeros_like(concatenated_features)
                unimodal_input[:, start:start + width] = feature
                start += width

                unimodal_logits = classifier(unimodal_input)
                probs = F.softmax(unimodal_logits, dim=1)
                log_probs = F.log_softmax(unimodal_logits, dim=1)
                entropy = -torch.sum(probs * log_probs, dim=1)
                # The epsilon keeps a single-output head (log(1) == 0) finite.
                max_entropy = math.log(unimodal_logits.shape[1]) + 1e-9
                confidences.append(1.0 - entropy / max_entropy)

        # Step 3: temperature softmax across modalities -> per-sample weights.
        weights = F.softmax(torch.stack(confidences, dim=1) / ctx.temperature, dim=1)
        leveraged_gradients = [
            num_modalities * weights[:, i:i + 1] * grad
            for i, grad in enumerate(base_gradients)
        ]

        # Step 4: pairwise removal of conflicting components. final_gradients[i]
        # is updated sequentially over j, while g_j always comes from the
        # un-projected (leveraged) gradients.
        final_gradients = list(leveraged_gradients)
        for i in range(num_modalities):
            for j in range(num_modalities):
                if i == j:
                    continue
                g_i = final_gradients[i]
                g_j = leveraged_gradients[j]
                dot_product = torch.sum(g_i * g_j, dim=1)
                g_j_norm_sq = torch.sum(g_j ** 2, dim=1) + 1e-8
                # torch.where avoids the host sync of `if mask.any()` and a
                # boolean-index scatter; non-conflicting rows get a zero scale.
                scale = torch.where(
                    dot_product < 0,
                    dot_product / g_j_norm_sq,
                    torch.zeros_like(dot_product),
                )
                final_gradients[i] = g_i - scale.unsqueeze(1) * g_j

        return (None, None, None, *final_gradients)


class GOALWrapper(nn.Module):
    """Module front-end that applies GOAL around an ``nn.Linear`` fusion head.

    Args:
        classifier: the fusion head; must be ``nn.Linear`` with
            ``in_features == sum(feature_dims)``.
        feature_dims: width of each modality tensor, in the order the tensors
            are passed to ``forward``.
        temperature: softmax temperature used by GOAL's confidence weighting.

    Raises:
        TypeError: if ``classifier`` is not an ``nn.Linear``.
        ValueError: if ``sum(feature_dims)`` differs from ``classifier.in_features``.
    """

    def __init__(self, classifier, feature_dims, temperature=1.0):
        super().__init__()
        if not isinstance(classifier, nn.Linear):
            raise TypeError("The wrapped classifier must be an nn.Linear layer.")
        if sum(feature_dims) != classifier.in_features:
            raise ValueError(
                f"feature_dims sum to {sum(feature_dims)}, but the classifier expects "
                f"{classifier.in_features} input features."
            )
        self.classifier = classifier
        self.feature_dims = feature_dims
        self.temperature = temperature

    def forward(self, *features):
        """Return fused logits for ``features`` (one tensor per modality).

        Feature gradients come from GOAL's re-weighted/projected backward. The
        head's weight and bias receive their ordinary linear-layer gradients.

        Raises:
            ValueError: if the number of tensors differs from ``len(feature_dims)``.
        """
        if len(features) != len(self.feature_dims):
            raise ValueError(f"Expected {len(self.feature_dims)} feature tensors, but got {len(features)}.")

        logits = GOAL.apply(self.classifier, self.temperature, self.feature_dims, *features)

        if torch.is_grad_enabled() and any(p.requires_grad for p in self.classifier.parameters()):
            # GOAL's backward yields no gradient for the head's parameters, so on
            # its own the fusion head would never be trained. Add a term whose
            # value is zero for finite logits (y - y) but which carries the usual
            # d(logits)/d(weight, bias) path. Its input is detached, so feature
            # gradients still come only from GOAL.
            param_logits = self.classifier(torch.cat(features, dim=1).detach())
            logits = logits + (param_logits - param_logits.detach())
        return logits


class ConcatFusion(nn.Module):
    """Fusion head mapping already-concatenated features to logits with one linear layer.

    Args:
        input_dim: width of the concatenated feature vector (default 1024 + 512).
        output_dim: number of output logits (default 100).
    """

    def __init__(self, input_dim=1024+512, output_dim=100):
        super().__init__()
        self.fc_out = nn.Linear(input_dim, output_dim)

    def forward(self, out):
        """Apply the linear layer to ``out`` and return the logits."""
        return self.fc_out(out)



class AVClassifier(nn.Module):
    """Audio-visual classifier with ResNet-18 backbones and concatenation fusion.

    ``forward(audio, visual)`` returns ``(fused_logits, audio_logits, visual_logits)``.
    The pooled audio and visual features are concatenated for ``head`` and are
    also scored individually by ``head_audio`` / ``head_video``.
    """

    def __init__(self, args):
        super(AVClassifier, self).__init__()

        classes_per_dataset = {
            'VGGSound': 309,
            'KineticSound': 31,
            'CREMAD': 6,
            'AVE': 28,
        }
        n_classes = classes_per_dataset.get(args.dataset)
        if n_classes is None:
            raise NotImplementedError('Incorrect dataset name {}'.format(args.dataset))

        self.dataset = args.dataset

        # Keep this construction order: it fixes the RNG draws used for
        # initialisation and the order of parameters().
        self.audio_net = resnet18(modality='audio')
        self.visual_net = resnet18(modality='visual')

        self.head = nn.Linear(1024, n_classes)
        self.head_audio = nn.Linear(512, n_classes)
        self.head_video = nn.Linear(512, n_classes)

    def forward(self, audio, visual):
        audio_map = self.audio_net(audio)
        frame_maps = self.visual_net(visual)

        # The visual output is regrouped as (batch, -1, C, H, W) using the audio
        # batch size, i.e. its first dim is assumed to be batch * frames. Channels
        # are then moved ahead of frames so one 3-D pool averages frames and space.
        n_batch = audio_map.size(0)
        _, n_chan, height, width = frame_maps.size()
        clip_maps = frame_maps.view(n_batch, -1, n_chan, height, width).permute(0, 2, 1, 3, 4)

        audio_feat = torch.flatten(F.adaptive_avg_pool2d(audio_map, 1), 1)
        visual_feat = torch.flatten(F.adaptive_avg_pool3d(clip_maps, 1), 1)

        fused_logits = self.head(torch.cat((audio_feat, visual_feat), 1))
        return fused_logits, self.head_audio(audio_feat), self.head_video(visual_feat)


class AVClassifierGOAL(nn.Module):
    """Audio-visual classifier whose fused logits come from a GOALWrapper head.

    Same backbones and heads as AVClassifier. ``self.classifier`` wraps the
    same ``self.head`` Linear and receives (audio features, visual features) in
    that order. ``forward(audio, visual)`` returns
    ``(fused_logits, audio_logits, visual_logits)``.
    """

    def __init__(self, args):
        super(AVClassifierGOAL, self).__init__()

        classes_per_dataset = {
            'VGGSound': 309,
            'KineticSound': 31,
            'CREMAD': 6,
            'AVE': 28,
        }
        n_classes = classes_per_dataset.get(args.dataset)
        if n_classes is None:
            raise NotImplementedError('Incorrect dataset name {}'.format(args.dataset))

        self.dataset = args.dataset

        # Keep this construction order: it fixes the RNG draws used for
        # initialisation and the order of parameters().
        self.audio_net = resnet18(modality='audio')
        self.visual_net = resnet18(modality='visual')

        self.head = nn.Linear(1024, n_classes)
        self.head_audio = nn.Linear(512, n_classes)
        self.head_video = nn.Linear(512, n_classes)
        # Shares parameters with self.head (same Linear instance).
        self.classifier = GOALWrapper(self.head, feature_dims=[512, 512], temperature=1.0)

    def forward(self, audio, visual):
        audio_map = self.audio_net(audio)
        frame_maps = self.visual_net(visual)

        # Regroup the visual output as (batch, -1, C, H, W) using the audio batch
        # size, then pool over frames and space in one 3-D pool.
        n_batch = audio_map.size(0)
        _, n_chan, height, width = frame_maps.size()
        clip_maps = frame_maps.view(n_batch, -1, n_chan, height, width).permute(0, 2, 1, 3, 4)

        audio_feat = torch.flatten(F.adaptive_avg_pool2d(audio_map, 1), 1)
        visual_feat = torch.flatten(F.adaptive_avg_pool3d(clip_maps, 1), 1)

        fused_logits = self.classifier(audio_feat, visual_feat)
        return fused_logits, self.head_audio(audio_feat), self.head_video(visual_feat)

class AVClassifieraudio(nn.Module):
    """Audio-only baseline built with the same modules as AVClassifier.

    Both ResNet-18 backbones and all three heads are constructed, but only the
    audio path is evaluated. ``forward(audio, visual)`` returns
    ``([], audio_logits, [])``.
    """

    def __init__(self, args):
        super(AVClassifieraudio, self).__init__()

        classes_per_dataset = {
            'VGGSound': 309,
            'KineticSound': 31,
            'CREMAD': 6,
            'AVE': 28,
        }
        n_classes = classes_per_dataset.get(args.dataset)
        if n_classes is None:
            raise NotImplementedError('Incorrect dataset name {}'.format(args.dataset))

        self.dataset = args.dataset

        self.audio_net = resnet18(modality='audio')
        self.visual_net = resnet18(modality='visual')

        self.head = nn.Linear(1024, n_classes)
        self.head_audio = nn.Linear(512, n_classes)
        self.head_video = nn.Linear(512, n_classes)

    def forward(self, audio, visual):
        """Return ``([], audio_logits, [])``.

        ``visual`` is accepted for call-site compatibility with the audio-visual
        models and is ignored.
        """
        # The visual backbone is deliberately not run: its output never reached
        # the returned values, so evaluating it only cost a full forward pass.
        a = self.audio_net(audio)
        a = F.adaptive_avg_pool2d(a, 1)
        a = torch.flatten(a, 1)

        out_audio = self.head_audio(a)
        return [], out_audio, []

class AVClassifiervideo(nn.Module):
    """Visual-only baseline built with the same modules as AVClassifier.

    ``forward(audio, visual)`` returns ``([], [], visual_logits)``.
    """

    def __init__(self, args):
        super(AVClassifiervideo, self).__init__()

        classes_per_dataset = {
            'VGGSound': 309,
            'KineticSound': 31,
            'CREMAD': 6,
            'AVE': 28,
        }
        n_classes = classes_per_dataset.get(args.dataset)
        if n_classes is None:
            raise NotImplementedError('Incorrect dataset name {}'.format(args.dataset))

        self.dataset = args.dataset

        # Keep this construction order: it fixes the RNG draws used for
        # initialisation and the order of parameters().
        self.audio_net = resnet18(modality='audio')
        self.visual_net = resnet18(modality='visual')

        self.head = nn.Linear(1024, n_classes)
        self.head_audio = nn.Linear(512, n_classes)
        self.head_video = nn.Linear(512, n_classes)

    def forward(self, audio, visual):
        # NOTE(review): the audio backbone runs only so that its output's batch
        # size can regroup the visual maps; its features are otherwise unused.
        n_batch = self.audio_net(audio).size(0)
        frame_maps = self.visual_net(visual)

        _, n_chan, height, width = frame_maps.size()
        clip_maps = frame_maps.view(n_batch, -1, n_chan, height, width).permute(0, 2, 1, 3, 4)
        visual_feat = torch.flatten(F.adaptive_avg_pool3d(clip_maps, 1), 1)

        return [], [], self.head_video(visual_feat)


class AVClassifierGOALcmumosi(nn.Module):
    """Audio-visual-text classifier whose fused logits come from a GOALWrapper head.

    ``text`` is indexed as ``(input_ids, attention_mask)`` for the text encoder.
    ``forward(audio, visual, text)`` returns
    ``(fused_logits, audio_logits, visual_logits, text_logits)``.
    """

    def __init__(self, args):
        super(AVClassifierGOALcmumosi, self).__init__()

        classes_per_dataset = {
            'VGGSound': 309,
            'KineticSound': 31,
            'CREMAD': 6,
            'AVE': 28,
            'MOSI': 3,
            'MELD': 3,
            'IEMOCAP': 4,
        }
        n_classes = classes_per_dataset.get(args.dataset)
        if n_classes is None:
            raise NotImplementedError('Incorrect dataset name {}'.format(args.dataset))

        # Keep this construction order: it fixes the RNG draws used for
        # initialisation and the order of parameters().
        self.audio_net = resnet18(modality='audio')
        self.visual_net = resnet18(modality='visual')
        self.text_net = TextEncoder()

        self.head = nn.Linear(1536, n_classes)
        self.head_audio = nn.Linear(512, n_classes)
        self.head_video = nn.Linear(512, n_classes)
        self.head_text = nn.Linear(512, n_classes)
        # Shares parameters with self.head (same Linear instance).
        self.classifier = GOALWrapper(self.head, feature_dims=[512, 512, 512], temperature=1.0)

    def forward(self, audio, visual, text):
        audio_map = self.audio_net(audio)
        frame_maps = self.visual_net(visual)
        input_ids, attention_mask = text[0], text[1]
        text_feat = self.text_net(input_ids, attention_mask)

        # Regroup the visual output as (batch, -1, C, H, W) using the audio batch
        # size, then pool over frames and space in one 3-D pool.
        n_batch = audio_map.size(0)
        _, n_chan, height, width = frame_maps.size()
        clip_maps = frame_maps.view(n_batch, -1, n_chan, height, width).permute(0, 2, 1, 3, 4)

        audio_feat = torch.flatten(F.adaptive_avg_pool2d(audio_map, 1), 1)
        visual_feat = torch.flatten(F.adaptive_avg_pool3d(clip_maps, 1), 1)

        fused_logits = self.classifier(audio_feat, visual_feat, text_feat)
        return (
            fused_logits,
            self.head_audio(audio_feat),
            self.head_video(visual_feat),
            self.head_text(text_feat),
        )

class AVClassifiercmumosi(nn.Module):
    """Audio-visual-text classifier with concatenation fusion.

    ``text`` is indexed as ``(input_ids, attention_mask)`` for the text encoder.
    ``forward(audio, visual, text)`` returns
    ``(fused_logits, audio_logits, visual_logits, text_logits)``.
    """

    def __init__(self, args):
        super(AVClassifiercmumosi, self).__init__()

        classes_per_dataset = {
            'VGGSound': 309,
            'KineticSound': 31,
            'CREMAD': 6,
            'AVE': 28,
            'MOSI': 3,
            'MELD': 3,
            'IEMOCAP': 4,
        }
        n_classes = classes_per_dataset.get(args.dataset)
        if n_classes is None:
            raise NotImplementedError('Incorrect dataset name {}'.format(args.dataset))

        # Keep this construction order: it fixes the RNG draws used for
        # initialisation and the order of parameters().
        self.audio_net = resnet18(modality='audio')
        self.visual_net = resnet18(modality='visual')
        self.text_net = TextEncoder()

        self.head = nn.Linear(1536, n_classes)
        self.head_audio = nn.Linear(512, n_classes)
        self.head_video = nn.Linear(512, n_classes)
        self.head_text = nn.Linear(512, n_classes)

    def forward(self, audio, visual, text):
        audio_map = self.audio_net(audio)
        frame_maps = self.visual_net(visual)
        input_ids, attention_mask = text[0], text[1]
        text_feat = self.text_net(input_ids, attention_mask)

        # Regroup the visual output as (batch, -1, C, H, W) using the audio batch
        # size, then pool over frames and space in one 3-D pool.
        n_batch = audio_map.size(0)
        _, n_chan, height, width = frame_maps.size()
        clip_maps = frame_maps.view(n_batch, -1, n_chan, height, width).permute(0, 2, 1, 3, 4)

        audio_feat = torch.flatten(F.adaptive_avg_pool2d(audio_map, 1), 1)
        visual_feat = torch.flatten(F.adaptive_avg_pool3d(clip_maps, 1), 1)

        fused_logits = self.head(torch.cat((audio_feat, visual_feat, text_feat), 1))
        return (
            fused_logits,
            self.head_audio(audio_feat),
            self.head_video(visual_feat),
            self.head_text(text_feat),
        )


class AVClassifier2modal(nn.Module):
    """Visual-text classifier with concatenation fusion.

    ``text`` is indexed as ``(input_ids, attention_mask)`` for the text encoder.
    ``forward(visual, text)`` returns ``(fused_logits, visual_logits, text_logits)``.
    """

    def __init__(self, args):
        super(AVClassifier2modal, self).__init__()

        classes_per_dataset = {
            'VGGSound': 309,
            'KineticSound': 31,
            'CREMAD': 6,
            'AVE': 28,
            'MOSI': 3,
            'MELD': 3,
            'FOOD': 101,
            'Hateful': 2,
        }
        n_classes = classes_per_dataset.get(args.dataset)
        if n_classes is None:
            raise NotImplementedError('Incorrect dataset name {}'.format(args.dataset))

        # Keep this construction order: it fixes the RNG draws used for
        # initialisation and the order of parameters().
        self.visual_net = resnet18(modality='visual')
        self.text_net = TextEncoder()

        self.head = nn.Linear(1024, n_classes)
        self.head_video = nn.Linear(512, n_classes)
        self.head_text = nn.Linear(512, n_classes)

    def forward(self, visual, text):
        frame_maps = self.visual_net(visual)
        input_ids, attention_mask = text[0], text[1]
        text_feat = self.text_net(input_ids, attention_mask)

        # NOTE(review): the batch size is read from the visual output itself, so
        # the view below always yields one frame per row. If the backbone returns
        # batch * frames rows with frames > 1, the concat with text_feat would
        # fail. Confirm single-frame visual inputs.
        n_batch, n_chan, height, width = frame_maps.size()
        clip_maps = frame_maps.view(n_batch, -1, n_chan, height, width).permute(0, 2, 1, 3, 4)
        visual_feat = torch.flatten(F.adaptive_avg_pool3d(clip_maps, 1), 1)

        fused_logits = self.head(torch.cat((visual_feat, text_feat), 1))
        return fused_logits, self.head_video(visual_feat), self.head_text(text_feat)

class AVClassifier2modalgoal(nn.Module):
    """Visual-text classifier whose fused logits come from a GOALWrapper head.

    ``text`` is indexed as ``(input_ids, attention_mask)`` for the text encoder.
    ``forward(visual, text)`` returns ``(fused_logits, visual_logits, text_logits)``.
    """

    def __init__(self, args):
        super(AVClassifier2modalgoal, self).__init__()

        classes_per_dataset = {
            'VGGSound': 309,
            'KineticSound': 31,
            'CREMAD': 6,
            'AVE': 28,
            'MOSI': 3,
            'MELD': 3,
            'FOOD': 101,
            'Hateful': 2,
        }
        n_classes = classes_per_dataset.get(args.dataset)
        if n_classes is None:
            raise NotImplementedError('Incorrect dataset name {}'.format(args.dataset))

        # Keep this construction order: it fixes the RNG draws used for
        # initialisation and the order of parameters().
        self.visual_net = resnet18(modality='visual')
        self.text_net = TextEncoder()

        self.head = nn.Linear(1024, n_classes)
        self.head_video = nn.Linear(512, n_classes)
        self.head_text = nn.Linear(512, n_classes)

        # Shares parameters with self.head (same Linear instance).
        self.classifier = GOALWrapper(self.head, feature_dims=[512, 512], temperature=1.0)

    def forward(self, visual, text):
        frame_maps = self.visual_net(visual)
        input_ids, attention_mask = text[0], text[1]
        text_feat = self.text_net(input_ids, attention_mask)

        # NOTE(review): the batch size is read from the visual output itself, so
        # the view below always yields one frame per row. Confirm single-frame
        # visual inputs.
        n_batch, n_chan, height, width = frame_maps.size()
        clip_maps = frame_maps.view(n_batch, -1, n_chan, height, width).permute(0, 2, 1, 3, 4)
        visual_feat = torch.flatten(F.adaptive_avg_pool3d(clip_maps, 1), 1)

        fused_logits = self.classifier(visual_feat, text_feat)
        return fused_logits, self.head_video(visual_feat), self.head_text(text_feat)


class vitvit(nn.Module):
    """Audio-visual classifier built on ImageEncoder backbones with concatenation fusion.

    ``forward(audio, visual)`` returns ``(fused_logits, audio_logits, visual_logits)``.
    """

    def __init__(self, args):
        super(vitvit, self).__init__()

        classes_per_dataset = {
            'VGGSound': 309,
            'KineticSound': 31,
            'CREMAD': 6,
            'AVE': 28,
        }
        n_classes = classes_per_dataset.get(args.dataset)
        if n_classes is None:
            raise NotImplementedError('Incorrect dataset name {}'.format(args.dataset))

        self.dataset = args.dataset

        # Keep this construction order: it fixes the RNG draws used for
        # initialisation and the order of parameters().
        self.audio_net = ImageEncoderaudio()
        self.visual_net = ImageEncoder()

        self.head = nn.Linear(1024, n_classes)
        self.head_audio = nn.Linear(512, n_classes)
        self.head_video = nn.Linear(512, n_classes)

    def forward(self, audio, visual):
        # Audio is tiled 3x along dim 1 before encoding (presumably a 1-channel
        # spectrogram -> 3 channels; TODO confirm). Visual is averaged over dim 1
        # before encoding.
        audio_feat = self.audio_net(audio.repeat(1, 3, 1, 1))
        visual_feat = self.visual_net(visual.mean(dim=1))

        fused_logits = self.head(torch.cat((audio_feat, visual_feat), 1))
        return fused_logits, self.head_audio(audio_feat), self.head_video(visual_feat)

class vitvitgoal(nn.Module):
    """ImageEncoder-based audio-visual classifier whose fused logits come from GOALWrapper.

    ``forward(audio, visual)`` returns ``(fused_logits, audio_logits, visual_logits)``.
    """

    def __init__(self, args):
        super(vitvitgoal, self).__init__()

        classes_per_dataset = {
            'VGGSound': 309,
            'KineticSound': 31,
            'CREMAD': 6,
            'AVE': 28,
        }
        n_classes = classes_per_dataset.get(args.dataset)
        if n_classes is None:
            raise NotImplementedError('Incorrect dataset name {}'.format(args.dataset))

        self.dataset = args.dataset

        # Keep this construction order: it fixes the RNG draws used for
        # initialisation and the order of parameters().
        self.audio_net = ImageEncoderaudio()
        self.visual_net = ImageEncoder()

        self.head = nn.Linear(1024, n_classes)
        self.head_audio = nn.Linear(512, n_classes)
        self.head_video = nn.Linear(512, n_classes)
        # Shares parameters with self.head (same Linear instance).
        self.classifier = GOALWrapper(self.head, feature_dims=[512, 512], temperature=1.0)

    def forward(self, audio, visual):
        # Audio is tiled 3x along dim 1 before encoding (presumably a 1-channel
        # spectrogram -> 3 channels; TODO confirm). Visual is averaged over dim 1
        # before encoding.
        audio_feat = self.audio_net(audio.repeat(1, 3, 1, 1))
        visual_feat = self.visual_net(visual.mean(dim=1))

        fused_logits = self.classifier(audio_feat, visual_feat)
        return fused_logits, self.head_audio(audio_feat), self.head_video(visual_feat)

class AVClassifierregress(nn.Module):
    """Visual-text regressor with concatenation fusion and single-output heads.

    ``args`` is accepted for a uniform constructor signature but is not read.
    ``forward(visual, text)`` returns ``(fused, visual_only, text_only)`` with the
    trailing size-1 output dimension squeezed away.
    """

    def __init__(self, args):
        super(AVClassifierregress, self).__init__()

        # Keep this construction order: it fixes the RNG draws used for
        # initialisation and the order of parameters().
        self.visual_net = resnet18(modality='visual')
        self.text_net = TextEncoder()

        self.head = nn.Linear(1024, 1)
        self.head_video = nn.Linear(512, 1)
        self.head_text = nn.Linear(512, 1)

    def forward(self, visual, text):
        frame_maps = self.visual_net(visual)
        input_ids, attention_mask = text[0], text[1]
        text_feat = self.text_net(input_ids, attention_mask)

        # Batch size is read from the visual output, so the view yields one
        # frame per row.
        n_batch, n_chan, height, width = frame_maps.size()
        clip_maps = frame_maps.view(n_batch, -1, n_chan, height, width).permute(0, 2, 1, 3, 4)
        visual_feat = torch.flatten(F.adaptive_avg_pool3d(clip_maps, 1), 1)

        fused = self.head(torch.cat((visual_feat, text_feat), 1))
        visual_only = self.head_video(visual_feat)
        text_only = self.head_text(text_feat)
        return fused.squeeze(-1), visual_only.squeeze(-1), text_only.squeeze(-1)


class AVClassifierregressgoal(nn.Module):
    """Visual-text regressor whose fused output comes from a GOALWrapper head.

    ``args`` is accepted for a uniform constructor signature but is not read.
    With a single-output head, GOAL's per-modality softmax is over one class.
    ``forward(visual, text)`` returns ``(fused, visual_only, text_only)`` with the
    trailing size-1 output dimension squeezed away.
    """

    def __init__(self, args):
        super(AVClassifierregressgoal, self).__init__()

        # Keep this construction order: it fixes the RNG draws used for
        # initialisation and the order of parameters().
        self.visual_net = resnet18(modality='visual')
        self.text_net = TextEncoder()

        self.head = nn.Linear(1024, 1)
        self.head_video = nn.Linear(512, 1)
        self.head_text = nn.Linear(512, 1)
        # Shares parameters with self.head (same Linear instance).
        self.classifier = GOALWrapper(self.head, feature_dims=[512, 512], temperature=1.0)

    def forward(self, visual, text):
        frame_maps = self.visual_net(visual)
        input_ids, attention_mask = text[0], text[1]
        text_feat = self.text_net(input_ids, attention_mask)

        # Batch size is read from the visual output, so the view yields one
        # frame per row.
        n_batch, n_chan, height, width = frame_maps.size()
        clip_maps = frame_maps.view(n_batch, -1, n_chan, height, width).permute(0, 2, 1, 3, 4)
        visual_feat = torch.flatten(F.adaptive_avg_pool3d(clip_maps, 1), 1)

        fused = self.classifier(visual_feat, text_feat)
        visual_only = self.head_video(visual_feat)
        text_only = self.head_text(text_feat)
        return fused.squeeze(-1), visual_only.squeeze(-1), text_only.squeeze(-1)

class AVClassifiercmumosiregress(nn.Module):
    """Audio-visual-text regressor with concatenation fusion and single-output heads.

    ``args`` is accepted for a uniform constructor signature but is not read.
    ``forward(audio, visual, text)`` returns ``(fused, audio_only, visual_only,
    text_only)``.

    NOTE(review): unlike AVClassifierregress, these outputs keep the trailing
    size-1 dimension produced by the ``nn.Linear(..., 1)`` heads. Confirm that
    callers reshape them before comparing with 1-D targets.
    """

    def __init__(self, args):
        super(AVClassifiercmumosiregress, self).__init__()

        # Keep this construction order: it fixes the RNG draws used for
        # initialisation and the order of parameters().
        self.audio_net = resnet18(modality='audio')
        self.visual_net = resnet18(modality='visual')
        self.text_net = TextEncoder()

        self.head = nn.Linear(1536, 1)
        self.head_audio = nn.Linear(512, 1)
        self.head_video = nn.Linear(512, 1)
        self.head_text = nn.Linear(512, 1)

    def forward(self, audio, visual, text):
        audio_map = self.audio_net(audio)
        frame_maps = self.visual_net(visual)
        input_ids, attention_mask = text[0], text[1]
        text_feat = self.text_net(input_ids, attention_mask)

        # Regroup the visual output as (batch, -1, C, H, W) using the audio batch
        # size, then pool over frames and space in one 3-D pool.
        n_batch = audio_map.size(0)
        _, n_chan, height, width = frame_maps.size()
        clip_maps = frame_maps.view(n_batch, -1, n_chan, height, width).permute(0, 2, 1, 3, 4)

        audio_feat = torch.flatten(F.adaptive_avg_pool2d(audio_map, 1), 1)
        visual_feat = torch.flatten(F.adaptive_avg_pool3d(clip_maps, 1), 1)

        fused = self.head(torch.cat((audio_feat, visual_feat, text_feat), 1))
        return (
            fused,
            self.head_audio(audio_feat),
            self.head_video(visual_feat),
            self.head_text(text_feat),
        )

class AVClassifierGOAL3modalregress(nn.Module):
    """Audio-visual-text regressor whose fused output comes from a GOALWrapper head.

    ``args`` is accepted for a uniform constructor signature but is not read.
    ``forward(audio, visual, text)`` returns ``(fused, audio_only, visual_only,
    text_only)``.

    NOTE(review): unlike AVClassifierregressgoal, these outputs keep the trailing
    size-1 dimension produced by the ``nn.Linear(..., 1)`` heads. Confirm that
    callers reshape them before comparing with 1-D targets.
    """

    def __init__(self, args):
        super(AVClassifierGOAL3modalregress, self).__init__()

        # Keep this construction order: it fixes the RNG draws used for
        # initialisation and the order of parameters().
        self.audio_net = resnet18(modality='audio')
        self.visual_net = resnet18(modality='visual')
        self.text_net = TextEncoder()

        self.head = nn.Linear(1536, 1)
        self.head_audio = nn.Linear(512, 1)
        self.head_video = nn.Linear(512, 1)
        self.head_text = nn.Linear(512, 1)
        # Shares parameters with self.head (same Linear instance).
        self.classifier = GOALWrapper(self.head, feature_dims=[512, 512, 512], temperature=1.0)

    def forward(self, audio, visual, text):
        audio_map = self.audio_net(audio)
        frame_maps = self.visual_net(visual)
        input_ids, attention_mask = text[0], text[1]
        text_feat = self.text_net(input_ids, attention_mask)

        # Regroup the visual output as (batch, -1, C, H, W) using the audio batch
        # size, then pool over frames and space in one 3-D pool.
        n_batch = audio_map.size(0)
        _, n_chan, height, width = frame_maps.size()
        clip_maps = frame_maps.view(n_batch, -1, n_chan, height, width).permute(0, 2, 1, 3, 4)

        audio_feat = torch.flatten(F.adaptive_avg_pool2d(audio_map, 1), 1)
        visual_feat = torch.flatten(F.adaptive_avg_pool3d(clip_maps, 1), 1)

        fused = self.classifier(audio_feat, visual_feat, text_feat)
        return (
            fused,
            self.head_audio(audio_feat),
            self.head_video(visual_feat),
            self.head_text(text_feat),
        )




        
    

        
    




