import torch
import torch.nn as nn
import torch.nn.functional as F
from .backbone import resnet18
from .fusion_modules import SumFusion, ConcatFusion, FiLM
from utils.gradcam import GradCAM
from utils.dynamic_drop import DynamicDrop



class AVClassifier(nn.Module):
    """Two-stream audio-visual classifier with optional CAM-guided feature drop.

    Each modality is encoded by its own ``resnet18`` backbone. The resulting
    feature maps are globally average-pooled and handed to a fusion module
    (``SumFusion``, ``ConcatFusion`` or ``FiLM``) selected by
    ``args.fusion_method``.

    Attributes read from ``args``:
        fusion_method: one of ``'sum'``, ``'concat'``, ``'film'``.
        dataset: one of ``'CREMAD'``, ``'AVE'``, ``'KineticSound'``.
        drop_mode: optional, defaults to ``'spatial'``; passed to ``DynamicDrop``.
        drop_init: optional, defaults to ``0.25``; passed to ``DynamicDrop``.
        mask_type: read only inside ``forward`` when ``apply_drop`` is true;
            ``'gradcam'`` enables masking, any other value means no masking.
    """

    # Number of target classes for every supported dataset name.
    _CLASSES_PER_DATASET = {'CREMAD': 6, 'AVE': 28, 'KineticSound': 31}

    def __init__(self, args):
        """Build backbones, fusion head, per-modality heads and drop helpers.

        Raises:
            NotImplementedError: if ``args.dataset`` or ``args.fusion_method``
                is not one of the supported values.
        """
        super().__init__()
        self.args = args

        fusion = args.fusion_method
        if args.dataset not in self._CLASSES_PER_DATASET:
            raise NotImplementedError('Incorrect dataset name {}'.format(args.dataset))
        n_classes = self._CLASSES_PER_DATASET[args.dataset]

        # NOTE: sub-modules are created in the same order as before so that
        # state_dict key order and seeded weight initialisation do not change.
        if fusion == 'sum':
            self.fusion_module = SumFusion(output_dim=n_classes)
        elif fusion == 'concat':
            self.fusion_module = ConcatFusion(output_dim=n_classes)
        elif fusion == 'film':
            self.fusion_module = FiLM(output_dim=n_classes, x_film=True)
        else:
            raise NotImplementedError('Incorrect fusion method: {}!'.format(fusion))

        self.audio_net = resnet18(modality='audio')
        self.visual_net = resnet18(modality='visual')

        # Per-modality linear heads. They are not used by forward() in this
        # class; presumably callers apply them to the returned embeddings
        # -- TODO confirm against the training script.
        self.audio_fc = nn.Linear(512, n_classes)
        self.visual_fc = nn.Linear(512, n_classes)

        self.drop_mode = getattr(args, 'drop_mode', 'spatial')
        self.drop_init = getattr(args, 'drop_init', 0.25)

        # CAM extractors are attached to each backbone's `out_conv` layer; the
        # drop helpers consume a feature map together with its CAM.
        self.cam_v = GradCAM(self.visual_net, self.visual_net.out_conv)
        self.cam_a = GradCAM(self.audio_net, self.audio_net.out_conv)
        self.drop_v = DynamicDrop(self.drop_mode, self.drop_init)
        self.drop_a = DynamicDrop(self.drop_mode, self.drop_init)

    def _pool_fuse(self, a_feat, v_feat):
        """Globally average-pool both feature maps and fuse the embeddings.

        Args:
            a_feat: audio feature map whose first dimension is the batch
                size ``B``; pooled over its last two dimensions.
            v_feat: 4-D visual feature map ``(B * T, C, h, w)``; the leading
                dimension is split into ``(B, T)`` before pooling (``T`` is
                inferred -- presumably the number of frames per sample).

        Returns:
            Whatever ``self.fusion_module(a_emb, v_emb)`` returns; ``forward``
            unpacks it as ``(a_emb, v_emb, logits)``.
        """
        B = a_feat.size(0)
        _, C, h, w = v_feat.size()
        # (B*T, C, h, w) -> (B, C, T, h, w) so that the 3-D pool averages over
        # T and both spatial axes at once.
        v_feat = v_feat.view(B, -1, C, h, w).permute(0, 2, 1, 3, 4)
        a_emb = F.adaptive_avg_pool2d(a_feat, 1).flatten(1)
        v_emb = F.adaptive_avg_pool3d(v_feat, 1).flatten(1)
        return self.fusion_module(a_emb, v_emb)

    def forward(self, audio, visual, apply_drop=False):
        """Encode both modalities and return the fused prediction.

        Args:
            audio: input forwarded to the audio backbone.
            visual: input forwarded to the visual backbone.
            apply_drop: when true and ``args.mask_type == 'gradcam'``, the
                feature maps are passed through the drop helpers (guided by a
                CAM of the top-scoring class) and fused a second time.

        Returns:
            The ``(a_emb, v_emb, logits)`` triple produced by the fusion
            module, computed from the masked features on the Grad-CAM path and
            from the unmasked features otherwise.
        """
        a_feat = self.audio_net(audio)
        v_feat = self.visual_net(visual)
        a_emb, v_emb, logits = self._pool_fuse(a_feat, v_feat)

        # Without masking, the second fusion pass would see exactly the same
        # feature maps, so the result computed above is reused instead of being
        # recomputed. This assumes the fusion module is deterministic for
        # identical inputs -- TODO confirm if a stochastic fusion is added.
        # `mask_type` is still only read when `apply_drop` is true.
        if not apply_drop or self.args.mask_type != 'gradcam':
            return a_emb, v_emb, logits

        # Score of the predicted (arg-max) class per sample drives the CAMs.
        idx = logits.argmax(dim=1)
        scores = logits.gather(1, idx.unsqueeze(1)).squeeze(1)
        cam_v = self.cam_v(scores)
        cam_a = self.cam_a(scores)
        v_mask = self.drop_v(v_feat, cam_v)
        a_mask = self.drop_a(a_feat, cam_a)

        return self._pool_fuse(a_mask, v_mask)
