import copy
import torch.nn as nn
import torch
import torchvision

import random


import open_clip


from imagebind.models import imagebind_model
from imagebind.models.helpers import SelectElement

class AudioModel(nn.Module):
    """Audio classifier built from the ImageBind audio branch.

    Holds deep copies of ImageBind's ``'audio'`` modality preprocessor and
    trunk, followed by a ``LayerNorm`` and a bias-free linear classifier.

    Args:
        args: Configuration object. ``args.num_classes`` sets the classifier
            width; an optional truthy ``args.FreezeEnc`` makes the encoder run
            without gradient tracking.
        audio_embed_dim: Feature width fed to the LayerNorm and classifier.
    """

    def __init__(self, args, audio_embed_dim=768) -> None:
        super().__init__()
        self.args = args
        # The full pretrained model is loaded only to copy out its audio
        # branch; the remainder is released at the end of the constructor.
        imagebind = imagebind_model.imagebind_huge(pretrained=True)
        self.preprocessor = copy.deepcopy(imagebind.modality_preprocessors['audio'])
        self.backbone = copy.deepcopy(imagebind.modality_trunks['audio'])
        self.layernorm = nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6)
        self.clf = nn.Linear(audio_embed_dim, args.num_classes, bias=False)
        del imagebind

    def forward(self, inputs):
        """Return class logits for ``inputs`` (see ``forward_feature``)."""
        return self.clf(self.forward_feature(inputs))

    def _encode(self, inputs, flat_batch):
        """Merge the two leading dims and run preprocessor + trunk."""
        flat_inputs = inputs.reshape(flat_batch, *inputs.shape[2:])
        trunk_kwargs = self.preprocessor(audio=flat_inputs)["trunk"]
        return self.backbone(**trunk_kwargs)

    def forward_feature(self, inputs):
        """Encode audio clips into one pooled feature vector per sample.

        Args:
            inputs: Tensor with at least 5 dims whose first two are
                (batch, segments); the remaining dims are passed through to
                the ImageBind audio preprocessor unchanged.

        Returns:
            Tensor of shape (batch, feature_dim): the mean over segments of
            the layer-normalised feature taken at index 0 along dim 1 of the
            trunk output (presumably the class token - confirm against the
            ImageBind trunk).
        """
        assert inputs.ndim >= 5
        batch, segments = inputs.shape[:2]
        flat_batch = batch * segments

        if getattr(self.args, 'FreezeEnc', False):
            # Frozen encoder: no autograd graph is built through it.
            with torch.no_grad():
                tokens = self._encode(inputs, flat_batch)
        else:
            tokens = self._encode(inputs, flat_batch)

        tokens = self.layernorm(tokens)
        per_segment = tokens[:, 0, ...].reshape(batch, segments, -1)
        return per_segment.mean(dim=1)

class VideoModel(nn.Module):
    """Frame-based video classifier on top of an open_clip visual tower.

    Args:
        args: Configuration object providing ``clip_model`` and
            ``clip_pretraining_data`` (forwarded to
            ``open_clip.create_model_and_transforms``) and ``num_classes``.
            An optional truthy ``args.FreezeEnc`` makes the visual tower run
            without gradient tracking.
    """

    def __init__(self, args) -> None:
        super().__init__()
        self.args = args
        # Feature width is chosen from the model name: 768 for 'ViT-L-14',
        # 512 for anything else.
        if args.clip_model == 'ViT-L-14':
            embedding_dim = 768
        else:
            embedding_dim = 512

        # Only the visual tower is kept; the rest of the CLIP model and the
        # returned transforms are discarded.
        full_clip, _, _ = open_clip.create_model_and_transforms(
            args.clip_model, pretrained=args.clip_pretraining_data
        )
        self.clip_visual_model = copy.deepcopy(full_clip.visual)
        del full_clip

        self.layernorm = nn.LayerNorm(normalized_shape=embedding_dim, eps=1e-6)
        self.clf = nn.Linear(embedding_dim, args.num_classes, bias=False)

    def forward(self, inputs):
        """Return class logits of shape (batch, num_classes)."""
        return self.clf(self.forward_feature(inputs))

    def _encode(self, inputs, flat_batch):
        """Merge the two leading dims and run the CLIP visual tower."""
        frames = inputs.reshape(flat_batch, *inputs.shape[2:])
        return self.clip_visual_model(frames)

    def forward_feature(self, inputs):
        """Encode frames and average their features per sample.

        Args:
            inputs: Tensor with at least 5 dims whose first two are
                (batch, frames), e.g. (B, 3, 3, 224, 224) as noted by the
                original author - the trailing dims are whatever the visual
                tower expects.

        Returns:
            Tensor of shape (batch, embedding_dim): layer-normalised frame
            features averaged over the frame axis.
        """
        assert inputs.ndim >= 5
        batch, frames_per_clip = inputs.shape[:2]
        flat_batch = batch * frames_per_clip

        if getattr(self.args, 'FreezeEnc', False):
            # Frozen encoder: no autograd graph is built through it.
            with torch.no_grad():
                frame_feats = self._encode(inputs, flat_batch)
        else:
            frame_feats = self._encode(inputs, flat_batch)

        frame_feats = self.layernorm(frame_feats)
        return frame_feats.reshape(batch, frames_per_clip, -1).mean(dim=1)

class MultiTaskMMmodel(nn.Module):
    """Audio-visual model returning unimodal and fused class logits.

    Two extra linear heads (``a_clf``, ``v_clf``) produce the per-modality
    predictions, while the fused prediction is the sum of the logits from the
    classifiers owned by ``audio_model`` and ``video_model``. Gradients from
    all three outputs flow back into the shared features.

    Args:
        args: Configuration object; must provide ``modalDrop``,
            ``clip_model`` and ``num_classes`` plus whatever ``AudioModel``
            and ``VideoModel`` require.
    """

    def __init__(self, args):
        super().__init__()
        # Stored but not read anywhere in this class.
        self.modalDrop = args.modalDrop
        self.audio_model = AudioModel(args)
        self.video_model = VideoModel(args)

        audio_width = 768
        video_width = 768 if args.clip_model == 'ViT-L-14' else 512

        # Auxiliary per-modality heads (these keep the default bias term).
        self.a_clf = nn.Linear(audio_width, args.num_classes)
        self.v_clf = nn.Linear(video_width, args.num_classes)

    def forward(self, audio, video):
        """Return ``(audio_predict, video_predict, av_predict)`` logits.

        Args:
            audio: Input accepted by ``AudioModel.forward_feature``.
            video: Input accepted by ``VideoModel.forward_feature``.
        """
        audio_features = self.audio_model.forward_feature(audio)
        video_features = self.video_model.forward_feature(video)

        # Unimodal predictions from the auxiliary heads.
        audio_predict = self.a_clf(audio_features)
        video_predict = self.v_clf(video_features)

        # Fused prediction: sum of the encoders' own classifier logits.
        audio_logits = self.audio_model.clf(audio_features)
        video_logits = self.video_model.clf(video_features)
        av_predict = audio_logits + video_logits

        return audio_predict, video_predict, av_predict


class MMmodel(nn.Module):
    """Audio-visual model with detached auxiliary unimodal heads.

    The fused prediction is the sum of the logits from the classifiers owned
    by ``audio_model`` and ``video_model``. The auxiliary heads ``a_clf`` and
    ``v_clf`` are fed *detached* features, so their losses update only the
    heads themselves and never the encoders.

    Args:
        args: Configuration object; must provide ``modalDrop``,
            ``clip_model`` and ``num_classes`` plus whatever ``AudioModel``
            and ``VideoModel`` require.
    """

    def __init__(self, args):
        super(MMmodel, self).__init__()
        # Probability used by ``modality_drop`` for each modality.
        self.modalDrop = args.modalDrop
        self.audio_model = AudioModel(args)
        self.video_model = VideoModel(args)
        audio_embed_dim = 768
        video_embedding_dim = 768 if args.clip_model == 'ViT-L-14' else 512

        # Auxiliary per-modality heads (these keep the default bias term).
        self.a_clf = nn.Linear(audio_embed_dim, args.num_classes)
        self.v_clf = nn.Linear(video_embedding_dim, args.num_classes)

    def forward(self, audio, video):
        """Return ``(audio_predict, video_predict, av_predict)`` logits.

        Args:
            audio: Input accepted by ``AudioModel.forward_feature``.
            video: Input accepted by ``VideoModel.forward_feature``.

        Returns:
            Tuple of three logit tensors. The first two come from the
            auxiliary heads on detached features; the third is the sum of the
            encoders' own classifier outputs and carries gradients back into
            both encoders.
        """
        audio_features = self.audio_model.forward_feature(audio)
        video_features = self.video_model.forward_feature(video)

        # detach(): the auxiliary heads must not back-propagate into the
        # encoders.
        audio_predict = self.a_clf(audio_features.detach())
        video_predict = self.v_clf(video_features.detach())

        av_predict = self.audio_model.clf(audio_features) + self.video_model.clf(video_features)

        return audio_predict, video_predict, av_predict

    def modality_drop(self, audio_feature, video_feature):
        """Randomly replace at most one modality's features with zeros.

        A single uniform draw ``p`` in [0, 1) decides the outcome: the audio
        features are zeroed when ``p < modalDrop``, otherwise the video
        features are zeroed when ``p < 2 * modalDrop``; otherwise both are
        returned untouched. Not called from ``forward`` in this class.

        Args:
            audio_feature: Audio feature tensor.
            video_feature: Video feature tensor.

        Returns:
            ``(audio_feature, video_feature)`` where the dropped one (if any)
            is a fresh all-zero tensor with the same shape, dtype and device
            as the original and no gradient history.
        """
        new_audio_feature = audio_feature
        new_video_feature = video_feature
        p = random.random()
        # zeros_like already matches the source tensor's device and dtype and
        # yields a tensor that does not require grad, so no explicit .cuda()
        # (which breaks CPU runs and can move data to the wrong GPU) or
        # .detach() is needed.
        if p < self.modalDrop:
            new_audio_feature = torch.zeros_like(audio_feature)
        elif p < self.modalDrop * 2:
            new_video_feature = torch.zeros_like(video_feature)
        return new_audio_feature, new_video_feature





if __name__ == "__main__":
    # Script entry point: only parses the command-line options below; no
    # model is built or run here. Earlier manual experiments in this block
    # used random tensors shaped (1, 3, 1, 128, 204) for audio and
    # (1, 3, 3, 224, 224) for video - treat those shapes as hints, not as a
    # verified contract.
    import argparse

    parser = argparse.ArgumentParser(description='Training Uni-Video Model')

    # (flag, keyword arguments) pairs, registered in this order.
    cli_options = (
        ('--num_classes', dict(type=int, default=28)),
        ('--clip_model', dict(
            type=str,
            default='ViT-B-16',
            choices=['ViT-L-14', 'ViT-B-16'],
            help='using which size of model',
        )),
        ('--clip_pretraining_data', dict(
            type=str,
            default='laion2b_s34b_b88k',
            choices=['datacomp_xl_s13b_b90k', 'laion2b_s34b_b88k'],
            help='using which dataset for pre-training',
        )),
    )
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()