import cv2
import copy
import logging
import numpy as np
import torch
from torch import nn
import torch.utils.model_zoo as model_zoo
from torch.autograd import Variable
import torch.nn.functional as F
from itertools import chain

from inclearn.lib import factory
from inclearn.lib.network.mlp import Fusion, MLP, MaxOut_MLP

logger = logging.getLogger(__name__)


class AudioVisualNet(nn.Module):
    """Audio-guided attention network over per-frame visual features.

    The audio feature vector is used to compute channel-wise attention
    weights over the visual features, first across the spatial axis of each
    frame and then across frames. The attention-pooled visual feature and the
    audio feature are each passed through a linear projection + ReLU, summed,
    and fed to a linear classifier.

    Args:
        dataset_name: Accepted for interface compatibility; it is not used by
            this class.
        n_classes: Number of output units of the classifier.
        device: Device the module is moved to at construction time
            (``None`` leaves the parameters where they were created).
        feature_dim: Size of the last axis of both the audio and the visual
            features. Defaults to 768.
        n_frames: Number of frames the visual input is split into in
            ``forward``. Defaults to 8.
    """

    def __init__(
        self,
        dataset_name,
        n_classes,
        device=None,
        feature_dim=768,
        n_frames=8,
    ):
        super(AudioVisualNet, self).__init__()
        self.device = device
        self.n_frames = n_frames

        # Projections applied to the features that are fused and classified.
        self.audio_proj = nn.Linear(feature_dim, feature_dim)
        self.visual_proj = nn.Linear(feature_dim, feature_dim)
        # Projections used only to compute the attention scores.
        self.attn_audio_proj = nn.Linear(feature_dim, feature_dim)
        self.attn_visual_proj = nn.Linear(feature_dim, feature_dim)

        self.classifier = nn.Linear(feature_dim, n_classes)

        # A single move covers every registered sub-module.
        self.to(self.device)

    @property
    def features_dim(self):
        """Dimensionality of the fused feature fed to the classifier."""
        return self.classifier.in_features

    def on_task_end(self):
        """Hook called at the end of a task; nothing to do for this model."""
        pass

    def on_epoch_end(self):
        """Hook called at the end of an epoch; nothing to do for this model."""
        pass

    def forward(self, inputs):
        """Run the network on a batch.

        Args:
            inputs: Mapping with the keys ``'audio'`` (tensor whose last axis
                has ``features_dim`` entries, one vector per sample) and
                ``'visual'`` (tensor that can be reshaped to
                ``(batch, n_frames, -1, features_dim)``).

        Returns:
            dict with the keys ``'attention_spatial'``,
            ``'attention_temporal'``, ``'features_audio'``,
            ``'features_visual'``, ``'features_fused'`` and ``'logits'``.
        """
        outputs = {}

        audio = inputs['audio']
        visual = inputs['visual']
        # reshape (rather than view) so non-contiguous inputs are accepted too.
        visual = visual.reshape(
            visual.shape[0], self.n_frames, -1, self.features_dim
        )

        spatial_attn_score, temporal_attn_score = self._audio_visual_attention(
            audio, visual
        )
        outputs['attention_spatial'] = spatial_attn_score
        outputs['attention_temporal'] = temporal_attn_score

        # Attention-weighted sum over the spatial axis, then over the frames.
        visual_pooled_feature = torch.sum(spatial_attn_score * visual, dim=2)
        visual_pooled_feature = torch.sum(
            temporal_attn_score * visual_pooled_feature, dim=1
        )

        outputs['features_audio'] = F.relu(self.audio_proj(audio))
        outputs['features_visual'] = F.relu(
            self.visual_proj(visual_pooled_feature)
        )
        outputs['features_fused'] = (
            outputs['features_visual'] + outputs['features_audio']
        )

        outputs['logits'] = self.classifier(outputs['features_fused'])

        return outputs

    def _audio_visual_attention(self, audio_features, visual_features):
        """Compute channel-wise spatial and temporal attention weights.

        Args:
            audio_features: Tensor of shape ``(batch, dim)``.
            visual_features: Tensor of shape ``(batch, frames, spatial, dim)``.

        Returns:
            Tuple ``(spatial, temporal)`` where ``spatial`` has shape
            ``(batch, frames, spatial, dim)`` and is soft-maxed over the
            spatial axis, and ``temporal`` has shape ``(batch, frames, dim)``
            and is soft-maxed over the frame axis.
        """
        proj_audio_features = torch.tanh(self.attn_audio_proj(audio_features))
        proj_visual_features = torch.tanh(
            self.attn_visual_proj(visual_features)
        )

        # Element-wise product of every visual location with the audio vector;
        # no axis is contracted, so the scores stay per channel.
        spatial_score = torch.einsum(
            "ijkd,id->ijkd", proj_visual_features, proj_audio_features
        )
        spatial_attn_score = F.softmax(spatial_score, dim=2)
        # (batch, frames, dim): projected visual features pooled per frame.
        spatial_attned_proj_visual_features = torch.sum(
            spatial_attn_score * proj_visual_features, dim=2
        )

        temporal_score = torch.einsum(
            "ijd,id->ijd",
            spatial_attned_proj_visual_features,
            proj_audio_features,
        )
        temporal_attn_score = F.softmax(temporal_score, dim=1)

        return spatial_attn_score, temporal_attn_score

    def freeze(self, trainable=False, model="all"):
        """Set ``requires_grad`` and train/eval mode of selected sub-modules.

        Args:
            trainable: Value assigned to ``requires_grad`` of every selected
                parameter. Selected modules are put in train mode when it is
                true and in eval mode otherwise.
            model: A name or an iterable of names among ``"all"``,
                ``"proj_audio"``, ``"proj_visual"``, ``"att_audio"``,
                ``"att_visual"`` and ``"classifier"``.

        Returns:
            ``self``.

        Raises:
            AssertionError: If a name is not one of the names listed above.
                All names are validated before any module is modified.
        """
        if isinstance(model, str):
            model = [model]

        available = {
            "all": self,
            "proj_audio": self.audio_proj,
            "proj_visual": self.visual_proj,
            "att_audio": self.attn_audio_proj,
            "att_visual": self.attn_visual_proj,
            "classifier": self.classifier,
        }

        modules = {}
        for m in model:
            if m not in available:
                # Raised explicitly: an `assert` statement would be stripped
                # under `python -O` and unknown names silently ignored.
                raise AssertionError(m)
            modules[m] = available[m]

        for module in modules.values():
            for param in module.parameters():
                param.requires_grad = trainable
            # train(False) is what eval() does.
            module.train(bool(trainable))

        return self

    def get_group_parameters(self):
        """Return the parameters of each sub-module, keyed by group name.

        The values are the iterators returned by ``Module.parameters()`` and
        can therefore be consumed only once.
        """
        groups = {
            "classifier": self.classifier.parameters(),
            "proj_audio": self.audio_proj.parameters(),
            "proj_visual": self.visual_proj.parameters(),
            "att_audio": self.attn_audio_proj.parameters(),
            "att_visual": self.attn_visual_proj.parameters(),
        }

        return groups

    def copy(self):
        """Return a deep copy of this network."""
        return copy.deepcopy(self)

    def save(self, path):
        """Write this network's ``state_dict`` to ``path``."""
        torch.save(self.state_dict(), path)

    def load(self, path):
        """Load a ``state_dict`` written by ``save`` from ``path``.

        The checkpoint is deserialised onto the CPU so that it can be read on
        a machine that lacks the device it was saved from;
        ``load_state_dict`` then copies the values into the existing
        parameters, which stay on their current device.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        save_states = torch.load(path, map_location="cpu")
        self.load_state_dict(save_states)




