import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from models.conv import FullyConvolution, DeConvolution, Conv2d
from models.i3d_model import InceptionModule
from models.rnn import DynamicGRU


class VisualSelfAttention(nn.Module):
    """Single-head dot-product self-attention over the spatial positions of a feature map.

    Queries, keys and values are produced by three independent 1x1 convolutions
    mapping ``config['clip_dim']`` channels to ``config['hidden_size']`` channels.
    """

    def __init__(self, config):
        super().__init__()
        in_dim, hidden = config['clip_dim'], config['hidden_size']
        # Creation order (q, k, v) is kept so parameter initialisation is unchanged.
        self.fc_q = nn.Conv2d(in_dim, hidden, kernel_size=1, stride=1, padding=0)
        self.fc_k = nn.Conv2d(in_dim, hidden, kernel_size=1, stride=1, padding=0)
        self.fc_v = nn.Conv2d(in_dim, hidden, kernel_size=1, stride=1, padding=0)

    def forward(self, x, mask):
        """Attend every spatial position to all *valid* spatial positions.

        Args:
            x: feature map of shape (bsz, clip_dim, h, w).
            mask: tensor reshapeable to (bsz, h * w); entries equal to 0 mark
                positions that must not be attended to.

        Returns:
            Tensor of shape (bsz, hidden_size, h, w).
        """
        bsz, _, h, w = x.size()

        def project(conv):
            # (bsz, hid, h, w) -> (bsz, h * w, hid)
            return conv(x).permute(0, 2, 3, 1).reshape(bsz, h * w, -1)

        q, k, v = project(self.fc_q), project(self.fc_k), project(self.fc_v)

        # score[b, i, j] is the affinity of query position i to key position j.
        # The mask is broadcast over the query axis so that masked *keys* get
        # zero weight; masking the query axis instead would leave invalid keys
        # visible to every valid position.
        key_mask = mask.reshape(bsz, 1, h * w)
        score = torch.matmul(q, k.transpose(-1, -2))
        score = F.softmax(score.masked_fill(key_mask == 0, float('-1e30')), dim=-1)

        res = torch.matmul(score, v)
        # (bsz, h * w, hid) -> (bsz, hid, h, w)
        return res.reshape(bsz, h, w, -1).permute(0, 3, 1, 2)


class MainModel(nn.Module):
    """Video/query grounding network producing alignment and foreground score maps.

    The clip is encoded by a stack of Inception modules, the query by a
    bidirectional GRU. Two heads are computed from the clip features:

    * a cross-modal head (``vg_attn`` -> ``cg`` -> ``fconv1`` -> ``dconv1`` ->
      ``ali_predictor``) yielding ``ali_score_map``;
    * a visual-only head (``dconv2`` -> ``conv2`` -> ``fg_predictor``) yielding
      ``fg_score_map`` and the embedding used by the contrastive losses.

    ``config`` must provide ``'clip_dim'``, ``'hidden_size'`` and ``'query_dim'``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self._build_clip_encoder()
        self._build_query_encoder()

        clip_dim = config['clip_dim']
        hidden_size = config['hidden_size']

        # contrastive & tightness: fg and bg
        # The "+ 2" terms account for the two coordinate channels concatenated in forward().
        self.dconv2 = DeConvolution(clip_dim + 2, hidden_size)
        self.conv2 = Conv2d(hidden_size + 2, hidden_size, kernel_size=1, stride=1, padding=0)
        self.fg_predictor = Conv2d(hidden_size, 1, kernel_size=1, stride=1, padding=0)
        # Fixed (non-trainable) scale of 10 applied to the contrastive cosine similarities.
        self.lambda_ = nn.Parameter(torch.ones(1, 1).float() * 10, requires_grad=False)

        # cross-modal learning
        self.vg_attn = VisionGuidedAttention(clip_dim, hidden_size)
        self.cg = CrossGate(clip_dim, hidden_size)
        self.fconv1 = FullyConvolution(clip_dim + hidden_size + 2, hidden_size, hidden_size)
        self.dconv1 = DeConvolution(hidden_size, hidden_size)
        self.ali_predictor = FullyConvolution(hidden_size + hidden_size + 2, hidden_size, 1)
        self.up = nn.UpsamplingBilinear2d(scale_factor=4)

    def forward(self, clip, query, query_len, coarse_gt_mask, fine_gt_mask, anch_mask, mask, **kwargs):
        """Compute the score maps and, in training mode, the contrastive terms.

        Args:
            clip: 5-D tensor unpacked as (bsz, t, h, w, dim).
            query: query token features, passed to the GRU with ``query_len``.
            query_len: per-sample query lengths.
            coarse_gt_mask: returned untouched in the result dict.
            fine_gt_mask: returned untouched in the result dict.
            anch_mask: indexable by ``128``; region labels consumed by
                ``_contrastive_score`` (only read in training mode).
            mask: indexable by ``32`` and ``128``; zeros mark invalid positions.
                NOTE(review): presumably keyed by spatial resolution — confirm
                against the data pipeline.
            **kwargs: ignored.

        Returns:
            dict with ``ali_score_map``, ``clip_emb``, ``fg_score_map``,
            ``mix_score_map``, ``fine_gt_mask``, ``coarse_gt_mask`` and ``mask``;
            in training mode also ``contrast_score``, ``diversity_loss`` and
            ``same_loss``.
        """
        query_mask = generate_mask(query, query_len)
        query_h = self.gru(query, query_len)

        # (bsz, h, w, 2) -> (bsz, 2, h, w)
        coord_emb = generate_coordinate_emb(clip).permute(0, 3, 1, 2)
        # (bsz, t, h, w, dim) -> (bsz, dim, t, h, w)
        clip_h = clip.permute(0, 4, 1, 2, 3)
        for layer in self.i3d:
            clip_h = layer(clip_h)
        clip_h = clip_h.mean(dim=2)  # average over the temporal axis
        # NOTE: self.vs_attn is built but deliberately not applied here.

        input_clip_h = clip_h

        # cross-modal learning (attention and gating operate channels-last)
        clip_h = clip_h.permute(0, 2, 3, 1)
        query_h_ = self.vg_attn(clip_h, mask[32], query_h, query_mask)
        clip_h, query_h_ = self.cg(clip_h, query_h_)
        clip_h = clip_h.permute(0, 3, 1, 2)
        query_h_ = query_h_.permute(0, 3, 1, 2)

        clip_h = self.fconv1(torch.cat([clip_h, query_h_, coord_emb], dim=1), mask[32])  # [bsz, hid, h, w]
        clip_h = self.dconv1(clip_h, mask[128])
        x = torch.cat([clip_h, self.up(query_h_), self.up(coord_emb)], 1)
        ali_score_map = self.ali_predictor(x, mask[128])

        # visual-only foreground head
        clip_h = F.relu(self.dconv2(torch.cat([input_clip_h, coord_emb], dim=1), mask[128]))
        clip_h = self.conv2(torch.cat([clip_h, self.up(coord_emb)], dim=1), mask[128])
        fg_score_map = self.fg_predictor(clip_h, mask[128])

        invalid = mask[128].unsqueeze(1) == 0
        res = {
            'ali_score_map': torch.sigmoid(ali_score_map).masked_fill(invalid, 0),
            'clip_emb': F.normalize(clip_h, dim=1),
            'fg_score_map': torch.sigmoid(fg_score_map).masked_fill(invalid, 0),
            # Currently computed exactly like 'ali_score_map'.
            'mix_score_map': torch.sigmoid(ali_score_map).masked_fill(invalid, 0),
            'fine_gt_mask': fine_gt_mask,
            'coarse_gt_mask': coarse_gt_mask,
            'mask': mask
        }

        if self.training:
            contrast_score, diversity_loss, same_loss = self._contrastive_score(
                clip_h, mask[128], anch_mask[128], fg_score_map)
            res.update({
                'contrast_score': contrast_score,
                'diversity_loss': diversity_loss,
                'same_loss': same_loss,
            })

        return res

    def _contrastive_score(self, clip_h, mask, anch_mask, fg_score_map):
        """Contrast an anchor region embedding against positive/negative regions.

        Region labels in ``anch_mask``: 1 = anchor, 2 = positive, 3..6 = negatives.
        Each region embedding is the ``clip_h`` average weighted by a softmax of
        ``fg_score_map`` restricted to that region.

        Args:
            clip_h: (bsz, dim, h, w) embedding map.
            mask: only its dtype/device are used (for the existence flags).
            anch_mask: (bsz, h, w) integer region labels.
            fg_score_map: foreground logits with bsz * h * w elements
                (reshaped to (bsz, h * w)).

        Returns:
            Tuple ``(contrast_score, diversity_loss, same_loss)`` where
            ``contrast_score`` (bsz,) is the softmax probability of the positive
            among the regions that exist for each sample, ``diversity_loss`` is
            the batch-mean entropy of the anchor attention, and ``same_loss`` is
            the entropy of each negative attention averaged over the samples in
            which that negative exists, then averaged over the negative ids.
        """
        bsz, dim, h, w = clip_h.size()
        anch_mask = anch_mask.unsqueeze(1)  # (bsz, 1, h, w)

        def region_attention(region_id):
            """Return the region indicator and the softmax of fg logits inside it."""
            region = (anch_mask == region_id).long()
            logits = fg_score_map.masked_fill(region == 0, float('-1e30'))
            return region, F.softmax(logits.reshape(bsz, h * w), dim=-1)

        def entropy(attn):
            """Per-sample entropy of a (bsz, h * w) distribution -> (bsz,)."""
            return -(attn * torch.log(attn + 1e-10)).sum(dim=-1)

        def pool(attn):
            """Attention-weighted sum of clip_h over space -> (bsz, dim)."""
            return (clip_h * attn.reshape(bsz, 1, h, w)).sum(dim=-1).sum(dim=-1)

        _, anchor_attn = region_attention(1)
        diversity_loss = entropy(anchor_attn).mean(dim=0)
        anchor_emb = F.normalize(pool(anchor_attn), dim=-1)

        _, pos_attn = region_attention(2)
        # The positive is always treated as existing.
        exist_mask = [torch.ones(bsz, 1).type_as(mask)]
        contrast_emb = [pool(pos_attn)]

        same_loss = 0.0
        neg_ids = [3, 4, 5, 6]
        for neg_idx in neg_ids:
            region, neg_attn = region_attention(neg_idx)
            is_exist = ((region == 1).sum(dim=-1).sum(dim=-1) > 0).long()  # (bsz, 1)
            # Weight each sample's entropy by its own existence flag. Both
            # operands are (bsz,) so no cross-sample broadcasting can occur.
            exists = is_exist.float().squeeze(1)
            same_loss = same_loss + (entropy(neg_attn) * exists).sum() / (exists.sum() + 1e-10)

            contrast_emb.append(pool(neg_attn))
            exist_mask.append(is_exist)
        same_loss = same_loss / len(neg_ids)
        exist_mask = torch.cat(exist_mask, dim=1)  # (bsz, 1 + len(neg_ids))

        contrast_emb = F.normalize(torch.stack(contrast_emb, dim=1), dim=-1)

        # lambda_ is a registered parameter and therefore already lives on the
        # module's device; no explicit device transfer is needed.
        score = torch.matmul(anchor_emb.unsqueeze(1), contrast_emb.transpose(-1, -2)).squeeze(1) * self.lambda_
        score = F.softmax(score.masked_fill(exist_mask == 0, float('-1e30')), dim=-1)
        return score[:, 0], diversity_loss, same_loss

    def _build_clip_encoder(self):
        """Create the Inception stack (Mixed_4b..Mixed_4f) and the self-attention module."""
        name = 'inception_i3d'
        # (end point, input channels, branch output channels)
        specs = [
            ('Mixed_4b', 128 + 192 + 96 + 64, [192, 96, 208, 16, 48, 64]),
            ('Mixed_4c', 192 + 208 + 48 + 64, [160, 112, 224, 24, 64, 64]),
            ('Mixed_4d', 160 + 224 + 64 + 64, [128, 128, 256, 24, 64, 64]),
            ('Mixed_4e', 128 + 256 + 64 + 64, [112, 144, 288, 32, 64, 64]),
            ('Mixed_4f', 112 + 288 + 64 + 64, [256, 160, 320, 32, 128, 128]),
        ]
        self.i3d = nn.ModuleList(
            [InceptionModule(in_channels, out_channels, name + end_point)
             for end_point, in_channels, out_channels in specs]
        )

        self.vs_attn = VisualSelfAttention(self.config)

    def _build_query_encoder(self):
        """Create the 2-layer bidirectional GRU; each direction gets hidden_size // 2 units."""
        self.gru = DynamicGRU(self.config['query_dim'], self.config['hidden_size'] >> 1,
                              num_layers=2, bidirectional=True, batch_first=True)


class VisionGuidedAttention(nn.Module):
    """Attend from every spatial clip position to the query tokens.

    Clip features are projected by ``fc1`` and query features by ``fc2``; their
    scaled dot product yields attention weights over the query tokens, which are
    used to average the *unprojected* query features for each clip position.
    """

    def __init__(self, src_dim, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.fc1 = nn.Linear(src_dim, self.hidden_size)
        self.fc2 = nn.Linear(self.hidden_size, self.hidden_size)

    def forward(self, clip, clip_mask, query, query_mask=None):
        """Return a query summary for each clip position.

        Args:
            clip: (bsz, h, w, src_dim) channels-last clip features.
            clip_mask: ``None`` or a tensor reshapeable to (bsz, h * w); outputs
                at positions where it is 0 are zeroed.
            query: (bsz, n_tokens, hidden_size) query features.
            query_mask: ``None`` or (bsz, n_tokens); tokens where it is 0 receive
                no attention.

        Returns:
            Tensor of shape (bsz, h, w, hidden_size).
        """
        bsz, h, w, _ = clip.size()
        clip_proj = self.fc1(clip).reshape(bsz, h * w, -1)
        query_proj = self.fc2(query)
        score = torch.matmul(clip_proj, query_proj.transpose(-1, -2)) / math.sqrt(self.hidden_size)
        if query_mask is not None:
            # Use the most negative finite value instead of -inf: a row whose
            # tokens are all masked then softmaxes to a uniform distribution
            # rather than NaN, while partially masked rows still give masked
            # tokens exactly zero weight.
            fill_value = torch.finfo(score.dtype).min
            score = score.masked_fill(query_mask.unsqueeze(1) == 0, fill_value)
        score = F.softmax(score, -1)
        query_ = torch.matmul(score, query)
        if clip_mask is not None:
            invalid = clip_mask.reshape(bsz, -1).unsqueeze(-1) == 0
            query_ = query_.masked_fill(invalid, 0)
        return query_.reshape(bsz, h, w, -1)


class CrossGate(nn.Module):
    """Mutual gating of two feature streams.

    Each stream is multiplied element-wise by a sigmoid gate computed from the
    *other* stream through a linear layer.
    """

    def __init__(self, h1, h2):
        super().__init__()
        # g1 maps stream 2 (h2 features) to a gate for stream 1 (h1 features), and vice versa.
        self.g1 = nn.Linear(h2, h1)
        self.g2 = nn.Linear(h1, h2)

    def forward(self, x1, x2):
        """Return ``(gated_x1, gated_x2)``; both gates are computed from the ungated inputs."""
        gate_for_x1 = torch.sigmoid(self.g1(x2))
        gate_for_x2 = torch.sigmoid(self.g2(x1))
        gated_x1 = x1 * gate_for_x1
        gated_x2 = x2 * gate_for_x2
        return gated_x1, gated_x2


def generate_mask(x, x_len):
    """Build a 0/1 padding mask from sequence lengths.

    Args:
        x: tensor whose second dimension is the (padded) sequence length.
        x_len: per-sample lengths (tensor or sequence of ints).

    Returns:
        Long tensor of shape (len(x_len), x.size(1)) on ``x``'s device, with
        ``mask[i, j] == 1`` iff ``j < x_len[i]``.
    """
    # Built directly on x's device so the function works on CPU as well as on
    # GPU, and without a per-sample Python loop.
    lengths = torch.as_tensor(x_len, device=x.device).reshape(-1, 1)
    positions = torch.arange(x.size(1), device=x.device).unsqueeze(0)
    return (positions < lengths).long()


def generate_coordinate_emb(clip):
    """Build a unit-length (row, column) coordinate embedding for every pixel.

    Row and column indices are mapped linearly to [-1, 1] and each (row, col)
    pair is L2-normalised along the last axis; a pair that is exactly (0, 0)
    stays zero.

    NOTE: the normalisation axis is passed by keyword. Passing ``-1``
    positionally to ``F.normalize`` sets the norm order ``p``, not ``dim``,
    which normalises with a p=-1 norm along dim 1. Weights trained against
    that earlier behaviour will see different coordinate features.

    Args:
        clip: 5-D tensor unpacked as (bsz, t, h, w, dim); only its sizes, dtype
            and device are used.

    Returns:
        Tensor of shape (bsz, h, w, 2) with the same dtype/device as ``clip``.
    """
    bsz, _, h, w, _ = clip.size()
    rows = 2 * (torch.linspace(0, h, h).type_as(clip) / h) - 1
    cols = 2 * (torch.linspace(0, w, w).type_as(clip) / w) - 1
    rows = rows.unsqueeze(-1).expand(h, w)
    cols = cols.unsqueeze(0).expand(h, w)
    co = torch.stack([rows, cols], -1)
    co = F.normalize(co, dim=-1)
    return co.unsqueeze(0).expand(bsz, -1, -1, -1)


if __name__ == '__main__':
    # Smoke test: build the A2D training loader and the model, run a single
    # batch through the model on the GPU, then terminate the process.
    from torch.utils.data import DataLoader
    from fairseq.utils import move_to_cuda
    from datasets.a2d import A2D

    dataset_config = {
        "videoset_path": "/home/user/data/A2D/Release/videoset.csv",
        "annotation_path": "/home/user/data/A2D/Release/Annotations",
        "vocab_path": "/home/user/code/mm-2020/data/glove_a2d.bin",
        "sample_path": "/home/user/data/A2D/a2d_annotation2.txt",
        "max_num_words": 20,
    }
    dataset = A2D(dataset_config)
    loader = DataLoader(
        dataset.train_set,
        batch_size=4,
        shuffle=True,
        num_workers=1,
        pin_memory=True,
        collate_fn=dataset.collate_fn,
    )

    model_config = {
        "hidden_size": 256,
        "clip_dim": 832,
        "query_dim": 300,
    }
    model = MainModel(model_config).cuda()

    # Only the first batch is processed; exit(0) ends the run inside the loop.
    for batch in loader:
        net_input = move_to_cuda(batch['net_input'])
        output = model(**net_input)
        exit(0)
