from typing import List, Dict

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from timm.layers import trunc_normal_

from model.backbone import resnet
from model.torchvggish.vggish import VGGish
from model.module.query_generator import PromptGenerator
from model.module.transformer import ELFDecoder
from model.module.repvgg import RepVGGBlock

# Downsampling factor of each backbone feature level relative to the input
# image: level 0 -> 4, level 1 -> 8, level 2 -> 16, level 3 -> 32.
spatial_reduction = {level: 2 ** (level + 2) for level in range(4)}


class Interpolate(nn.Module):
    """Module wrapper around ``F.interpolate`` with fixed resize settings.

    Lets a resize step be placed inside containers such as ``nn.Sequential``.

    Args:
        scale_factor: Multiplier for the spatial size, forwarded to
            ``F.interpolate``.
        mode: Interpolation algorithm name, forwarded to ``F.interpolate``.
        align_corners: Forwarded to ``F.interpolate`` for the interpolating
            modes. For other modes (e.g. ``"nearest"``) the default ``False``
            is dropped, because ``F.interpolate`` rejects any explicit
            ``align_corners`` value there.
    """

    # Modes for which F.interpolate accepts an explicit ``align_corners``.
    _ALIGN_CORNERS_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(self, scale_factor, mode, align_corners=False):
        super().__init__()

        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x: Tensor):
        """Return ``x`` resized by ``scale_factor`` using ``mode``."""
        align_corners = self.align_corners
        # A falsy value on a non-interpolating mode is just the unset default:
        # pass None so modes like "nearest"/"area" work. A truthy value is
        # still forwarded so F.interpolate reports the genuine misuse.
        if self.mode not in self._ALIGN_CORNERS_MODES and not align_corners:
            align_corners = None
        return F.interpolate(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=align_corners,
        )


class AVESFormer(nn.Module):
    """Audio-conditioned dense prediction network.

    The model combines a ResNet image backbone, a VGGish audio encoder, a
    ``PromptGenerator`` that turns the audio embedding into decoder queries,
    an ``ELFDecoder`` over the selected backbone levels, and a RepVGG-based
    head that emits a ``num_classes``-channel prediction map.
    """

    def __init__(self,
                 img_size: int,
                 backbone: str,
                 pretrained,
                 in_channels: List[int],
                 vggish: Dict,
                 audio_dim: int,
                 embed_dim: int,
                 num_classes: int,
                 query_generator: Dict,
                 decoder: Dict,
                 valid_indices=(1, 2, 3),
                 ) -> None:
        """Build all sub-modules.

        Args:
            img_size: Side length of the (square) input image; used to
                precompute the per-level feature map sizes as
                ``img_size // spatial_reduction[level]``.
            backbone: ``'resnet50'`` or ``'resnet18'``.
            pretrained: Forwarded to the backbone constructor.
            in_channels: Channel count of each backbone feature level.
                ``in_channels[0]`` feeds the input projection; the entries at
                ``valid_indices`` are passed to the decoder.
            vggish: Keyword arguments for ``VGGish``.
            audio_dim: Input width of the audio projection layer.
            embed_dim: Shared embedding width. Used with ``GroupNorm(32, ...)``
                so it must be divisible by 32.
            num_classes: Number of output channels of the prediction head.
            query_generator: Keyword arguments for ``PromptGenerator``.
            decoder: Keyword arguments for ``ELFDecoder``.
            valid_indices: Backbone levels (keys of ``spatial_reduction``)
                handed to the decoder.

        Raises:
            NotImplementedError: If ``backbone`` is not a supported name.
        """
        super().__init__()
        self.img_size = img_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        # Copy so instances never share (or mutate) the caller's/default object.
        self.valid_indices = list(valid_indices)
        self.num_feats = len(self.valid_indices)
        self.num_classes = num_classes

        if backbone == 'resnet50':
            self.backbone = resnet.resnet50(pretrained=pretrained)
        elif backbone == 'resnet18':
            self.backbone = resnet.resnet18(pretrained=pretrained)
        else:
            raise NotImplementedError(f'Backbone {backbone} is not supported')

        self.vggish = VGGish(**vggish)
        self.audio_proj = nn.Linear(audio_dim, embed_dim)

        self.query_generator = PromptGenerator(**query_generator)
        self.query_generator.apply(AVESFormer.init_weights)

        # Projects the level-0 backbone feature to the shared embedding width.
        self.in_proj = nn.Sequential(
            nn.Conv2d(in_channels[0], embed_dim, kernel_size=1),
            nn.GroupNorm(32, embed_dim)
        )

        # Static (H, W) of every decoder level, derived from the input size.
        level_sizes = [
            (self.img_size // spatial_reduction[idx],) * 2
            for idx in self.valid_indices
        ]
        # All-False masks -> every position counts as valid (ratio 1.0).
        masks = [torch.zeros((1, *size), dtype=torch.bool) for size in level_sizes]
        valid_ratio = torch.stack([AVESFormer.get_valid_ratio(m) for m in masks], 1)
        valid_ratios = nn.Parameter(valid_ratio, requires_grad=False)

        spatial_shapes = torch.as_tensor(level_sizes, dtype=torch.long)
        # Offset of each level inside the flattened multi-level sequence.
        level_start_index = torch.cat(
            (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
        )

        spatial_shapes = nn.Parameter(spatial_shapes, requires_grad=False)
        level_start_index = nn.Parameter(level_start_index, requires_grad=False)

        self.decoder = ELFDecoder(**decoder,
                                  in_channels=[in_channels[i] for i in self.valid_indices],
                                  spatial_shapes=spatial_shapes,
                                  valid_ratios=valid_ratios,
                                  level_start_index=level_start_index)

        self.out_conv = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim,
                      kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(32, embed_dim),
            nn.ReLU(True)
        )

        # Prediction head: refine, upsample x4, refine, project to classes.
        self.fc = nn.Sequential(
            RepVGGBlock(embed_dim, 128, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=4, mode="bilinear"),
            RepVGGBlock(128, 32, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(32, num_classes, kernel_size=1,
                      stride=1, padding=0, bias=False)
        )
        self.fc.apply(AVESFormer.init_weights)

    @staticmethod
    def init_weights(module):
        """Initialise ``nn.Linear`` / ``nn.Conv2d`` weights; ignore other modules.

        Linear weights get a truncated normal (std 0.02), conv weights get
        Xavier-uniform (gain 1); existing biases are zeroed.
        """
        if isinstance(module, nn.Linear):
            trunc_normal_(module.weight, std=.02)
        elif isinstance(module, nn.Conv2d):
            nn.init.xavier_uniform_(module.weight, gain=1)
        else:
            return
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    def fuse(self):
        """Switch sub-modules to their deploy form and enable in-place ops.

        Calls ``switch_to_deploy()`` on every sub-module that defines it and
        sets ``inplace = True`` on every sub-module that has that attribute,
        printing one line per change.
        """
        for name, m in self.named_modules():
            if hasattr(m, 'switch_to_deploy'):
                m.switch_to_deploy()
                print(f"Fusing {name}")
            if hasattr(m, 'inplace'):
                print(f"Change {name} to inplace")
                m.inplace = True

    def mul_temporal_mask(self, feats, vid_temporal_mask_flag=None):
        """Multiply features by the temporal mask flag.

        Args:
            feats: A tensor, or a list/tuple of tensors.
            vid_temporal_mask_flag: Tensor broadcastable against ``feats``,
                or ``None`` to disable masking.

        Returns:
            ``feats`` itself when the flag is ``None``; otherwise the masked
            tensor, or a new list of masked tensors for sequence input.

        Raises:
            TypeError: If masking is requested for an unsupported ``feats``
                type.
        """
        if vid_temporal_mask_flag is None:
            return feats
        if isinstance(feats, torch.Tensor):
            return feats * vid_temporal_mask_flag
        if isinstance(feats, (list, tuple)):
            return [x * vid_temporal_mask_flag for x in feats]
        raise TypeError(
            "feats must be a Tensor or a list/tuple of Tensors, "
            f"got {type(feats).__name__}"
        )

    @staticmethod
    def get_valid_ratio(mask):
        """Return the fraction of un-masked width and height per sample.

        Args:
            mask: Boolean tensor of shape ``(N, H, W)``; ``False`` entries are
                counted as valid.

        Returns:
            Float tensor of shape ``(N, 2)`` ordered ``(width_ratio,
            height_ratio)``, measured along the first row / first column.
        """
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        return torch.stack([valid_ratio_w, valid_ratio_h], -1)

    def reform_output_sequences(self, memory, spatial_shapes, level_start_index, dim=1):
        """Split a flattened multi-level sequence back into per-level chunks.

        Args:
            memory: Tensor whose ``dim`` axis concatenates all levels.
            spatial_shapes: Unused; kept for interface compatibility.
            level_start_index: Start offset of each level along ``dim``; only
                the first ``self.num_feats`` entries are read.
            dim: Axis to split along.

        Returns:
            Tuple of ``self.num_feats`` tensors, one per level.
        """
        # Plain ints: torch.split documents its section sizes as list[int].
        bounds = [int(level_start_index[i]) for i in range(self.num_feats)]
        bounds.append(memory.shape[dim])  # the last level runs to the end
        sections = [end - start for start, end in zip(bounds, bounds[1:])]
        return torch.split(memory, sections, dim=dim)

    def forward(self, img: Tensor, audio: Tensor, vid_temporal_mask_flag=None):
        """Run the model on an image batch and its matching audio.

        Args:
            img: Image batch passed to the backbone.
            audio: Audio input passed to VGGish.
            vid_temporal_mask_flag: Optional per-sample flag; it is reshaped
                to ``(-1, 1, 1, 1)`` and multiplied into the backbone
                features, the prediction and the auxiliary mask feature.

        Returns:
            ``pred`` in eval mode; ``(pred, aux)`` in training mode, where
            ``aux`` holds ``"inter_out"`` (empty list) and ``"mask_feature"``.
        """
        if vid_temporal_mask_flag is not None:
            vid_temporal_mask_flag = vid_temporal_mask_flag.view(-1, 1, 1, 1)

        audio_feat = self.vggish(audio)  # [B*T,128]
        audio_feat = self.audio_proj(audio_feat.unsqueeze(1))

        # Shallow copy: element 0 is replaced below, and the backbone's own
        # container must not be mutated (it may also be an immutable tuple).
        img_feat = list(self.mul_temporal_mask(self.backbone(img), vid_temporal_mask_flag))
        img_feat[0] = self.in_proj(img_feat[0])

        # prepare queries
        bs = audio_feat.shape[0]
        query = self.query_generator(audio_feat)  # [B,num_query, embed_dim]

        # NOTE(review): levels are taken as img_feat[1:], which matches the
        # default valid_indices only — confirm before using other indices.
        outputs, memory = self.decoder(img_feat[1:], query, query_pos=None)

        level_shapes = self.decoder.spatial_shapes
        level_seqs = self.reform_output_sequences(
            outputs, level_shapes, self.decoder.level_start_index, 1)
        # (bs, HW, C) per level -> (bs, C, H, W)
        decoder_feats = [
            seq.transpose(1, 2).view(bs, -1, level_shapes[i][0], level_shapes[i][1])
            for i, seq in enumerate(level_seqs)
        ]

        # Fuse the last decoder memory into the projected level-0 feature.
        inter_feature = F.interpolate(memory[-1], size=img_feat[0].shape[2:], mode='bilinear')
        inter_feature = self.out_conv(img_feat[0] + inter_feature)
        mask_feature = F.interpolate(decoder_feats[0], size=inter_feature.shape[2:], mode='bilinear')
        all_feats = inter_feature + inter_feature * mask_feature

        pred = self.mul_temporal_mask(self.fc(all_feats), vid_temporal_mask_flag)
        if not self.training:
            return pred

        aux = {
            "inter_out": [],
            "mask_feature": self.mul_temporal_mask(mask_feature, vid_temporal_mask_flag),
        }
        return pred, aux
