import copy
import torch.nn as nn
import torch
import torchvision

import random

from utils import change_multihead

import open_clip

from peft import LoraModel, LoraConfig

from imagebind.models import imagebind_model
from imagebind.models.helpers import SelectElement

class AudioModel(nn.Module):
    """Audio classifier built on the ImageBind audio preprocessor and trunk.

    Inputs carry two leading axes ``(B, S)``: ``B`` samples, each made of
    ``S`` clips. Every clip is encoded on its own, index 0 along dimension 1
    of the trunk output is kept (presumably the class token -- confirm against
    ImageBind), and the ``S`` clip embeddings are averaged into one vector per
    sample. ``clf`` is a bias-free linear head over that vector.
    """

    def __init__(self, args, audio_embed_dim=768) -> None:
        super().__init__()
        self.args = args
        # Only the audio branch of ImageBind is kept; the deep copies let the
        # full multi-modal model be released right after construction.
        pretrained = imagebind_model.imagebind_huge(pretrained=True)
        self.preprocessor = copy.deepcopy(pretrained.modality_preprocessors['audio'])
        self.backbone = copy.deepcopy(pretrained.modality_trunks['audio'])
        self.layernorm = nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6)
        self.clf = nn.Linear(audio_embed_dim, args.num_classes, bias=False)
        del pretrained

    def forward(self, inputs):
        """Return class logits computed from the pooled clip features."""
        return self.clf(self.forward_feature(inputs))

    def _encode(self, inputs):
        """Merge the ``(B, S)`` axes and run the preprocessor and the trunk."""
        num_samples, num_clips = inputs.shape[:2]
        flat = inputs.reshape(num_samples * num_clips, *inputs.shape[2:])
        trunk_kwargs = self.preprocessor(audio=flat)["trunk"]
        return self.backbone(**trunk_kwargs)

    def forward_feature(self, inputs):
        """Return one feature vector per sample, averaged over its clips.

        When ``args.FreezeEnc`` exists and is truthy the preprocessor and the
        trunk run under ``torch.no_grad()``; the layer norm that follows is
        always applied with the ambient gradient mode.
        """
        assert inputs.ndim >= 5
        B, S = inputs.shape[:2]
        if getattr(self.args, 'FreezeEnc', False):
            with torch.no_grad():
                tokens = self._encode(inputs)
        else:
            tokens = self._encode(inputs)
        # Normalise, keep index 0 on dimension 1, then average the S clips.
        clip_embed = self.layernorm(tokens)[:, 0, ...]
        return clip_embed.reshape(B, S, -1).mean(dim=1)

class VideoModel(nn.Module):
    """Frame-based video classifier built on an open_clip visual tower.

    Inputs carry two leading axes ``(B, S)``: ``B`` samples with ``S`` frames
    each (the original comments use ``B, 3, 3, 224, 224`` as the example).
    Frames are encoded independently, layer-normalised and averaged into one
    vector per sample, which ``clf`` (bias-free linear) maps to class logits.
    """

    def __init__(self, args) -> None:
        super().__init__()
        self.args = args
        # 'ViT-L-14' is treated as 768-dimensional, every other model as 512.
        embedding_dim = {'ViT-L-14': 768}.get(args.clip_model, 512)
        clip_model, _, _ = open_clip.create_model_and_transforms(
            args.clip_model, pretrained=args.clip_pretraining_data
        )
        # Keep only the visual tower and release the rest of CLIP.
        self.clip_visual_model = copy.deepcopy(clip_model.visual)
        del clip_model

        self.layernorm = nn.LayerNorm(normalized_shape=embedding_dim, eps=1e-6)
        self.clf = nn.Linear(embedding_dim, args.num_classes, bias=False)

    def forward(self, inputs):
        """Return class logits of shape ``(B, num_classes)``."""
        return self.clf(self.forward_feature(inputs))

    def _encode(self, inputs):
        """Merge the ``(B, S)`` axes and embed every frame with the CLIP tower."""
        num_samples, num_frames = inputs.shape[:2]
        frames = inputs.reshape(num_samples * num_frames, *inputs.shape[2:])
        return self.clip_visual_model(frames)

    def forward_feature(self, inputs):
        """Return one feature vector per sample, averaged over its frames.

        When ``args.FreezeEnc`` exists and is truthy the CLIP tower runs under
        ``torch.no_grad()``; the layer norm afterwards is applied with the
        ambient gradient mode.
        """
        assert inputs.ndim >= 5
        B, S = inputs.shape[:2]
        if getattr(self.args, 'FreezeEnc', False):
            with torch.no_grad():
                frame_embed = self._encode(inputs)
        else:
            frame_embed = self._encode(inputs)
        frame_embed = self.layernorm(frame_embed)         # (B*S, D)
        return frame_embed.reshape(B, S, -1).mean(dim=1)  # (B, D)

class OGMGEMMmodel(nn.Module):
    """Late-fusion audio/video model: the fused logits are the sum of both heads."""

    def __init__(self, args):
        super().__init__()
        self.audio_model = AudioModel(args)
        self.video_model = VideoModel(args)

    def forward(self, audio, video):
        """Return ``(audio_logits, video_logits, audio_logits + video_logits)``.

        Each branch's own ``clf`` head is applied to that branch's pooled
        features; no extra fusion layer is involved.
        """
        audio_features = self.audio_model.forward_feature(audio)
        video_features = self.video_model.forward_feature(video)

        audio_predict = self.audio_model.clf(audio_features)
        video_predict = self.video_model.clf(video_features)

        return audio_predict, video_predict, audio_predict + video_predict

class MultiTaskMMmodel(nn.Module):
    """Audio/video model with extra per-modality heads next to the fused output.

    ``a_clf`` and ``v_clf`` (linear, with bias) give the uni-modal logits,
    while the fused logits are the sum of the encoders' own ``clf`` heads.
    All three outputs are computed from the same, non-detached features.
    """

    def __init__(self, args):
        super().__init__()
        self.modalDrop = args.modalDrop
        self.audio_model = AudioModel(args)
        self.video_model = VideoModel(args)

        audio_dim = 768
        # 'ViT-L-14' is treated as 768-dimensional, every other model as 512.
        video_dim = {'ViT-L-14': 768}.get(args.clip_model, 512)
        self.a_clf = nn.Linear(audio_dim, args.num_classes)
        self.v_clf = nn.Linear(video_dim, args.num_classes)

    def forward(self, audio, video):
        """Return ``(audio_logits, video_logits, fused_logits)``."""
        audio_features = self.audio_model.forward_feature(audio)
        video_features = self.video_model.forward_feature(video)

        # Uni-modal predictions come from the dedicated heads ...
        uni_modal = (self.a_clf(audio_features), self.v_clf(video_features))
        # ... the fused prediction sums the encoders' built-in heads.
        fused = self.audio_model.clf(audio_features) + self.video_model.clf(video_features)

        return (*uni_modal, fused)


class MMmodel(nn.Module):
    """Audio/video model with detached uni-modal probes and a fused output.

    ``a_clf`` and ``v_clf`` are applied to *detached* features, so losses on
    the uni-modal logits update only those two heads and never reach the
    encoders. The fused logits are the sum of the encoders' own ``clf`` heads
    and do carry gradients into the encoders.
    """

    def __init__(self, args):
        super().__init__()
        self.modalDrop = args.modalDrop
        self.audio_model = AudioModel(args)
        self.video_model = VideoModel(args)
        audio_embed_dim = 768
        video_embedding_dim = 768 if args.clip_model == 'ViT-L-14' else 512

        self.a_clf = nn.Linear(audio_embed_dim, args.num_classes)
        self.v_clf = nn.Linear(video_embedding_dim, args.num_classes)

    def forward(self, audio, video):
        """Return ``(audio_logits, video_logits, fused_logits)``."""
        audio_features = self.audio_model.forward_feature(audio)
        video_features = self.video_model.forward_feature(video)

        # Probe heads see detached features: their gradients stop here.
        audio_predict = self.a_clf(audio_features.detach())
        video_predict = self.v_clf(video_features.detach())

        av_predict = self.audio_model.clf(audio_features) + self.video_model.clf(video_features)

        return audio_predict, video_predict, av_predict

    def modality_drop(self, audio_feature, video_feature):
        """Randomly replace at most one modality's features with zeros.

        With probability ``modalDrop`` the audio features are zeroed;
        otherwise, with a further probability ``modalDrop``, the video
        features are zeroed; otherwise both are returned untouched.

        The zero tensors are created with ``torch.zeros_like`` and therefore
        share the device and dtype of the features they replace. (An earlier
        version forced them onto CUDA, which failed on CPU-only hosts and
        mixed devices for CPU inputs.)

        Returns:
            ``(audio_feature, video_feature)`` with at most one of the two
            replaced by zeros of the same shape.
        """
        p = random.random()
        if p < self.modalDrop:
            # zeros_like yields a fresh tensor that carries no gradient history.
            return torch.zeros_like(audio_feature), video_feature
        if p < self.modalDrop * 2:
            return audio_feature, torch.zeros_like(video_feature)
        return audio_feature, video_feature




class UMEMMLoRAModel(nn.Module):
    """Audio/video late-fusion model whose encoders can be wrapped with LoRA.

    ``args.lora_modal`` selects the adapted branch: ``'mm'`` (both),
    ``'a'`` (audio only) or ``'v'`` (video only). The branch that is not
    adapted has ``requires_grad`` switched off on all its parameters and is
    evaluated in eval mode under ``torch.no_grad()``.
    """

    _LORA_MODALS = ('mm', 'a', 'v')

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.modalDrop = args.modalDrop
        self.audio_model = AudioModel(args)
        self.video_model = VideoModel(args)

    def _lora_modal(self):
        """Return ``args.lora_modal`` after checking that it is supported.

        Raises:
            ValueError: if the value is not one of ``'mm'``, ``'a'``, ``'v'``.
        """
        modal = self.args.lora_modal
        if modal not in self._LORA_MODALS:
            raise ValueError(
                f"unsupported lora_modal {modal!r}; expected one of {self._LORA_MODALS}"
            )
        return modal

    def lora_the_model(self):
        """Wrap the selected encoder(s) in a ``LoraModel`` and freeze the other.

        The mode is validated before anything is modified, so an invalid
        configuration raises ``ValueError`` and leaves the model untouched
        instead of silently terminating the process.

        Raises:
            ValueError: if ``args.lora_modal`` is not supported.
        """
        modal = self._lora_modal()

        # NOTE(review): change_multihead comes from utils; presumably it
        # rewrites the attention blocks so that a 'qkv' projection exists for
        # LoRA to target -- confirm against utils.change_multihead.
        change_multihead(self.audio_model.backbone, mode='A')
        change_multihead(self.video_model.clip_visual_model, mode='V')
        lora_config = LoraConfig(
            r=self.args.lora_r,
            lora_alpha=self.args.lora_alpha,
            target_modules=['qkv'],
        )

        if modal == 'mm':
            self.audio_model = LoraModel(self.audio_model, lora_config, "default")
            self.video_model = LoraModel(self.video_model, lora_config, "default")
        elif modal == 'a':
            self.audio_model = LoraModel(self.audio_model, lora_config, "default")
            for param in self.video_model.parameters():
                param.requires_grad = False
        else:  # modal == 'v'
            self.video_model = LoraModel(self.video_model, lora_config, "default")
            for param in self.audio_model.parameters():
                param.requires_grad = False
        print('lora done:', self.args.lora_modal)

    def forward(self, audio, video):
        """Return ``(audio_logits, video_logits, audio_logits + video_logits)``.

        In ``'a'`` / ``'v'`` mode the non-adapted branch is put into eval mode
        (this persists after the call) and run without gradient tracking.

        Raises:
            ValueError: if ``args.lora_modal`` is not supported.
        """
        modal = self._lora_modal()

        if modal == 'mm':
            audio_predict = self.audio_model(audio)
            video_predict = self.video_model(video)
        elif modal == 'a':
            self.video_model.eval()
            with torch.no_grad():
                video_predict = self.video_model(video)
            audio_predict = self.audio_model(audio)
        else:  # modal == 'v'
            self.audio_model.eval()
            with torch.no_grad():
                audio_predict = self.audio_model(audio)
            video_predict = self.video_model(video)

        av_predict = audio_predict + video_predict

        return audio_predict, video_predict, av_predict

class ResNet18_Video(nn.Module):
    """Frame-based video classifier on a torchvision ResNet-18.

    The network's ``fc`` layer is replaced by an identity, so the backbone
    emits 512-dimensional features that are layer-normalised, averaged over
    the frames of a sample and classified by a bias-free linear head.
    """

    def __init__(self, args) -> None:
        super().__init__()
        resnet = torchvision.models.resnet18(pretrained=True)
        resnet.fc = nn.Identity()  # expose the pooled 512-d features
        self.backbone = resnet

        self.layernorm = nn.LayerNorm(normalized_shape=512, eps=1e-6)
        self.clf = nn.Linear(512, args.num_classes, bias=False)

    def forward_feature(self, inputs):
        """Return per-sample features averaged over the second axis.

        ``inputs`` has two leading axes ``(B, S)``; they are merged so every
        frame passes through the backbone independently.
        """
        assert inputs.ndim >= 5
        batch, frames = inputs.shape[:2]
        stacked = inputs.reshape(batch * frames, *inputs.shape[2:])
        per_frame = self.layernorm(self.backbone(stacked))     # (B*S, 512)
        return per_frame.reshape(batch, frames, -1).mean(dim=1)  # (B, 512)

    def forward(self, inputs):
        """Return class logits for ``inputs``."""
        return self.clf(self.forward_feature(inputs))


class ResNet18_Audio(nn.Module):
    """Audio classifier on a torchvision ResNet-18 with a 1-channel stem.

    ``conv1`` is replaced by a freshly initialised single-input-channel
    convolution (so the pretrained weights of that layer are not used) and
    ``fc`` by an identity, leaving 512-dimensional backbone features. These
    are layer-normalised, averaged over the clips of a sample and classified
    by a bias-free linear head.
    """

    def __init__(self, args) -> None:
        super().__init__()
        resnet = torchvision.models.resnet18(pretrained=True)
        # Accept one input channel instead of three.
        resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        resnet.fc = nn.Identity()
        self.backbone = resnet

        self.layernorm = nn.LayerNorm(normalized_shape=512, eps=1e-6)
        self.clf = nn.Linear(512, args.num_classes, bias=False)

    def forward_feature(self, inputs):
        """Return per-sample features averaged over the second axis.

        ``inputs`` has two leading axes ``(B, S)``; they are merged so every
        clip passes through the backbone independently.
        """
        assert inputs.ndim >= 5
        batch, clips = inputs.shape[:2]
        stacked = inputs.reshape(batch * clips, *inputs.shape[2:])
        per_clip = self.layernorm(self.backbone(stacked))      # (B*S, 512)
        return per_clip.reshape(batch, clips, -1).mean(dim=1)  # (B, 512)

    def forward(self, inputs):
        """Return class logits for ``inputs``."""
        return self.clf(self.forward_feature(inputs))

class ResNetMMmodel(nn.Module):
    """ResNet-18 audio/video model with detached uni-modal probe heads.

    ``a_clf`` / ``v_clf`` classify detached 512-d features, so their losses
    do not reach the encoders; the fused logits are the sum of the encoders'
    own ``clf`` heads.
    """

    def __init__(self, args) -> None:
        super().__init__()
        self.args = args
        self.audio_model = ResNet18_Audio(args)
        self.video_model = ResNet18_Video(args)
        self.a_clf = nn.Linear(512, args.num_classes)
        self.v_clf = nn.Linear(512, args.num_classes)

    def forward(self, audio, video):
        """Return ``(audio_logits, video_logits, fused_logits)``."""
        a_feat = self.audio_model.forward_feature(audio)
        v_feat = self.video_model.forward_feature(video)

        # Probe heads operate on detached copies of the features.
        probes = (self.a_clf(a_feat.detach()), self.v_clf(v_feat.detach()))

        fused = self.audio_model.clf(a_feat) + self.video_model.clf(v_feat)
        return (*probes, fused)


class ResNetUMEMMLoRAModel(nn.Module):
    """ResNet-18 audio/video late-fusion model with optional LoRA adapters.

    ``args.lora_modal`` selects the adapted branch: ``'mm'`` (both),
    ``'a'`` (audio only) or ``'v'`` (video only). When ``args`` has no
    ``lora_modal`` attribute, ``forward`` runs both branches normally.
    """

    _LORA_MODALS = ('mm', 'a', 'v')

    def __init__(self, args) -> None:
        super().__init__()
        self.args = args
        self.audio_model = ResNet18_Audio(args)
        self.video_model = ResNet18_Video(args)

    def _lora_modal(self):
        """Return ``args.lora_modal`` after checking that it is supported.

        Raises:
            ValueError: if the value is not one of ``'mm'``, ``'a'``, ``'v'``.
        """
        modal = self.args.lora_modal
        if modal not in self._LORA_MODALS:
            raise ValueError(
                f"unsupported lora_modal {modal!r}; expected one of {self._LORA_MODALS}"
            )
        return modal

    def lora_the_model(self):
        """Wrap the selected branch(es) in a ``LoraModel`` and freeze the other.

        LoRA targets the modules named ``conv1`` and ``conv2``. The mode is
        validated first, so an invalid configuration raises ``ValueError``
        instead of silently terminating the process.

        Raises:
            ValueError: if ``args.lora_modal`` is not supported.
        """
        modal = self._lora_modal()
        lora_config = LoraConfig(
            r=self.args.lora_r,
            lora_alpha=self.args.lora_alpha,
            target_modules=['conv1', 'conv2'],
        )

        if modal == 'mm':
            self.audio_model = LoraModel(self.audio_model, lora_config, "default")
            self.video_model = LoraModel(self.video_model, lora_config, "default")
        elif modal == 'a':
            self.audio_model = LoraModel(self.audio_model, lora_config, "default")
            for param in self.video_model.parameters():
                param.requires_grad = False
        else:  # modal == 'v'
            self.video_model = LoraModel(self.video_model, lora_config, "default")
            for param in self.audio_model.parameters():
                param.requires_grad = False
        print('lora done:', self.args.lora_modal)

    def forward(self, audio, video):
        """Return ``(audio_logits, video_logits, audio_logits + video_logits)``.

        In ``'a'`` / ``'v'`` mode the non-adapted branch is put into eval mode
        (this persists after the call) and run without gradient tracking.

        Raises:
            ValueError: if ``args.lora_modal`` exists but is not supported.
        """
        # No LoRA configuration at all is treated like 'mm': both branches
        # run with the ambient gradient mode.
        modal = self._lora_modal() if hasattr(self.args, 'lora_modal') else 'mm'

        if modal == 'mm':
            audio_predict = self.audio_model(audio)
            video_predict = self.video_model(video)
        elif modal == 'a':
            self.video_model.eval()
            with torch.no_grad():
                video_predict = self.video_model(video)
            audio_predict = self.audio_model(audio)
        else:  # modal == 'v'
            self.audio_model.eval()
            with torch.no_grad():
                audio_predict = self.audio_model(audio)
            video_predict = self.video_model(video)

        mm_predict = audio_predict + video_predict
        return audio_predict, video_predict, mm_predict



if __name__ == "__main__":
    # Smoke check: build the late-fusion model from command-line options and
    # list every parameter with its size.
    # Example inputs used during development (from earlier notes):
    #   audio = torch.randn(1, 3, 1, 128, 204)
    #   video = torch.randn(1, 3, 3, 224, 224)
    import argparse

    cli = argparse.ArgumentParser(description='Training Uni-Video Model')
    cli.add_argument('--num_classes', type=int, default=28)
    cli.add_argument(
        '--clip_model',
        type=str,
        default='ViT-B-16',
        choices=['ViT-L-14', 'ViT-B-16'],
        help='using which size of model',
    )
    cli.add_argument(
        '--clip_pretraining_data',
        type=str,
        default='laion2b_s34b_b88k',
        choices=['datacomp_xl_s13b_b90k', 'laion2b_s34b_b88k'],
        help='using which dataset for pre-training',
    )
    args = cli.parse_args()

    mm_model = OGMGEMMmodel(args)
    for name, param in mm_model.named_parameters():
        print(name, param.size())