import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter
from torchvision.models import resnet50

from models.conv import Conv2d, FullyConvolution
from models.i3d_model import InceptionModule, MaxPool3dSamePadding, Unit3D
from models.rnn import DynamicGRU


# class QueryGuidedRegionAttention2(nn.Module):
#     def __init__(self, config, src_dim):
#         super().__init__()
#         self.config = config
#         self.fc_i = nn.Linear(src_dim, config['hidden_size'])
#         self.fc_q = nn.Linear(config['hidden_size'], config['hidden_size'])
#         self.fc_k = nn.Linear(config['hidden_size'], config['hidden_size'])
#         self.fc_v = nn.Linear(config['hidden_size'], config['hidden_size'])
#
#         self.attn = VisionGuidedAttention(src_dim, config['hidden_size'])
#         self.fc_g = nn.Linear(config['hidden_size'], src_dim)
#         self.fc_o = nn.Linear(config['hidden_size'], src_dim)
#
#     def forward(self, clip_h, clip_mask, segment, query, query_len, query_mask):
#         res = clip_h.transpose(-2, -3).transpose(-1, -2)
#
#         clip_h_ = clip_h.transpose(-2, -3).transpose(-1, -2)
#         clip_h_ = self.fc_i(clip_h_)
#         clip_h_ = clip_h_.transpose(-1, -2).transpose(-2, -3)
#
#         bsz, dim, h, w = clip_h_.size()
#         segment_ = segment.reshape(bsz, h * w, 1).expand(bsz, h * w, dim)
#         clip_h_ = clip_h_.transpose(-2, -3).transpose(-1, -2).reshape(bsz, h * w, dim)
#         # segment_: [bsz, h * w, dim], clip_h_: [bsz, h * w, dim]
#         cluster_emb = scatter(src=clip_h_, index=segment_, dim=1, reduce='max')
#         cluster_len = torch.max(segment.reshape(bsz, -1), dim=-1)[0] + 1
#         cluster_mask = generate_mask(cluster_emb, cluster_len)
#         # cluster_mask[:, 0] = 0
#
#         q, k = self.fc_q(cluster_emb), self.fc_k(cluster_emb)
#         v = self.fc_v(cluster_emb)
#         s = torch.matmul(q, k.transpose(-1, -2))
#         s = s.masked_fill(cluster_mask.unsqueeze(1) == 0, float('-inf'))
#         s = F.softmax(s, dim=-1)
#         x = torch.matmul(s, v)
#         segment_ = segment.reshape(bsz, h * w, 1).expand(bsz, h * w, dim)
#         x = x.gather(dim=1, index=segment_)
#         x = x.reshape(bsz, h, w, -1)
#         x = self.fc_o(x)
#
#         x = res + x
#
#         query_ = self.attn(x, clip_mask, query, query_mask)
#         x = torch.sigmoid(self.fc_g(query_)) * x
#         x = x.transpose(-1, -2).transpose(-2, -3)
#         x = F.dropout(x, 0.1, self.training)
#
#         return x

# class QueryGuidedRegionAttention2D(nn.Module):
#     def __init__(self, config, src_dim):
#         super().__init__()
#         self.config = config
#         hidden_size = config['hidden_size']
#         self.fc_i = nn.Linear(src_dim, hidden_size)
#         self.fc_q = nn.Linear(hidden_size, hidden_size)
#         self.fc_k = nn.Linear(hidden_size, hidden_size)
#         self.fc_v = nn.Linear(hidden_size, hidden_size)
#
#         self.attn = VisionGuidedAttention(src_dim, hidden_size)
#         self.fc_g = nn.Linear(hidden_size, hidden_size)
#         self.fc_o = nn.Conv2d(hidden_size * 3, src_dim, kernel_size=1, stride=1, padding=0)
#
#     def forward(self, clip_h, clip_mask, segment, query, query_len, query_mask):
#         bsz, dim, h, w = clip_h.size()
#
#         x = clip_h.transpose(-2, -3).transpose(-1, -2)
#         query_h = self.attn(x, clip_mask, query, query_mask)
#         x = self.fc_i(x) * torch.tanh(self.fc_g(query_h))
#
#         res = x
#         segment_ = segment.reshape(bsz, h * w, 1).expand(bsz, h * w, res.size(-1))
#         x = x.reshape(bsz, h * w, -1)
#         cluster_emb = scatter(src=x, index=segment_, dim=1, reduce='max')
#         cluster_len = torch.max(segment.reshape(bsz, -1), dim=-1)[0] + 1
#         cluster_mask = generate_mask(cluster_emb, cluster_len)
#         q, k = self.fc_q(cluster_emb), self.fc_k(cluster_emb)
#         v = self.fc_v(cluster_emb)
#         s = torch.matmul(q, k.transpose(-1, -2))
#         s = s.masked_fill(cluster_mask.unsqueeze(1) == 0, float('-inf'))
#         s = F.softmax(s, dim=-1)
#         x = torch.matmul(s, v)
#         segment_ = segment.reshape(bsz, h * w, 1).expand(bsz, h * w, res.size(-1))
#         x = x.gather(dim=1, index=segment_)
#         x = x.reshape(bsz, h, w, -1)
#
#         # print(res.size(), x.size(), query_h.size())
#         x = torch.cat([res, x, query_h], dim=-1).transpose(-1, -2).transpose(-2, -3)
#         x = self.fc_o(x)
#         x = F.dropout(x, p=0.1, training=self.training)
#         return x

class QueryGuidedRegionAttention2D(nn.Module):
    """Query-gated self-attention among the regions of a segmented 2D feature map.

    ``forward`` proceeds in six steps:
      1. every pixel attends over the query tokens (``self.attn``);
      2. pixel features are gated by ``tanh(fc_g(query context))``;
      3. gated features are max-pooled per region id taken from ``segment``;
      4. the pooled regions attend to each other (``fc_q`` / ``fc_k`` / ``fc_v``);
      5. each pixel receives its region's attended vector back (gather);
      6. gated features, query context and region vectors are concatenated,
         projected back to ``src_dim1`` channels by a 1x1 conv, and passed
         through dropout (p=0.1).

    ``src_dim1`` is the feature-map channel count and ``src_dim2`` the query
    feature size. ``self.attn`` applies ``Linear(hidden_size, hidden_size)`` to
    the query, so in practice ``src_dim2`` has to equal ``hidden_size``.
    """

    def __init__(self, src_dim1, src_dim2, hidden_size):
        super().__init__()
        # Registration order and attribute names are part of the checkpoint format.
        self.attn = VisionGuidedAttention(src_dim1, hidden_size)

        self.fc_q = nn.Linear(src_dim1, hidden_size)
        self.fc_k = nn.Linear(src_dim1, hidden_size)
        self.fc_v = nn.Linear(src_dim1, hidden_size)

        self.fc_g = nn.Linear(src_dim2, src_dim1)
        self.fc_o = nn.Conv2d(hidden_size + src_dim2 + src_dim1, src_dim1, kernel_size=1, stride=1, padding=0)

    def forward(self, clip_h, clip_mask, segment, query, query_len, query_mask):
        """Return a (bsz, src_dim1, h, w) query-aware refinement of ``clip_h``.

        Args:
            clip_h: (bsz, src_dim1, h, w) feature map.
            clip_mask: forwarded unchanged to ``self.attn``.
            segment: integer region ids, reshaped to (bsz, h * w). The region
                count per sample is taken as ``max id + 1``, so ids are assumed
                to run from 0 upward — TODO confirm against the data pipeline.
            query: (bsz, L, src_dim2) query token features.
            query_len: not used here; kept for a uniform call signature.
            query_mask: (bsz, L); tokens equal to 0 are excluded by ``self.attn``.
        """
        bsz, _, h, w = clip_h.size()
        num_pixels = h * w

        # Channels-last view (bsz, h, w, C) so Linear layers act on channels.
        pixels = clip_h.permute(0, 2, 3, 1)
        query_ctx = self.attn(pixels, clip_mask, query, query_mask)
        gated = pixels * torch.tanh(self.fc_g(query_ctx))

        # Pool pixel features into one embedding per region (element-wise max).
        flat = gated.reshape(bsz, num_pixels, -1)
        seg_index = segment.reshape(bsz, num_pixels, 1)
        region_emb = scatter(src=flat, index=seg_index.expand(bsz, num_pixels, flat.size(-1)), dim=1, reduce='max')
        region_len = segment.reshape(bsz, -1).max(dim=-1)[0] + 1
        region_mask = generate_mask(region_emb, region_len)

        # Region-to-region attention; padded region slots are masked out as keys.
        logits = self.fc_q(region_emb) @ self.fc_k(region_emb).transpose(-1, -2)
        logits = logits.masked_fill(region_mask.unsqueeze(1) == 0, float('-inf'))
        attended = torch.softmax(logits, dim=-1) @ self.fc_v(region_emb)

        # Broadcast each region's attended vector back to its pixels.
        back_index = seg_index.expand(bsz, num_pixels, attended.size(-1))
        per_pixel = attended.gather(dim=1, index=back_index).reshape(bsz, h, w, -1)

        fused = torch.cat([gated, query_ctx, per_pixel], dim=-1).permute(0, 3, 1, 2)
        return F.dropout(self.fc_o(fused), p=0.1, training=self.training)


class QueryGuidedRegionAttention2D_I3D(nn.Module):
    """Per-frame query gating of a (bsz, dim, t, h, w) feature volume.

    Every frame attends over the same query tokens (repeated once per frame),
    and its features are scaled by ``tanh(fc_g(query context))``. No region
    attention is applied; ``segment`` and ``query_len`` are accepted only for a
    call signature shared with ``QueryGuidedRegionAttention2D``.
    """

    def __init__(self, src_dim1, src_dim2, hidden_size):
        super().__init__()
        # NOTE(review): fc1 is never used in forward(); it is kept so existing
        # checkpoints keep their parameter names.
        self.fc1 = nn.Linear(src_dim1, hidden_size)
        self.attn = VisionGuidedAttention(src_dim1, hidden_size)

        self.fc_g = nn.Linear(src_dim2, src_dim1)
        self.hidden_size = hidden_size

    def forward(self, clip_h, clip_mask, segment, query, query_len, query_mask):
        """Return a tensor shaped like ``clip_h`` (bsz, dim, t, h, w), gated per frame."""
        bsz, dim, t, h, w = clip_h.size()

        # Fold time into the batch: (bsz * t, h, w, dim), channels last.
        frames = clip_h.permute(0, 2, 3, 4, 1).reshape(bsz * t, h, w, dim)
        # Sample i's query is repeated t times consecutively, matching the fold above.
        frame_query = query.repeat_interleave(t, dim=0)
        frame_query_mask = query_mask.repeat_interleave(t, dim=0)

        query_ctx = self.attn(frames, clip_mask, frame_query, frame_query_mask)
        gated = frames * torch.tanh(self.fc_g(query_ctx))
        # Unfold time and restore the (bsz, dim, t, h, w) layout.
        return gated.reshape(bsz, t, h, w, dim).permute(0, 4, 1, 2, 3)


# class QueryGuidedRegionAttention3D(nn.Module):
#     def __init__(self, src_dim1, src_dim2, hidden_size):
#         super().__init__()
#         self.fc1 = nn.Linear(src_dim1, hidden_size)
#         self.fc2 = nn.Linear(src_dim2, hidden_size)
#
#         self.fc_q = nn.Linear(hidden_size, hidden_size)
#         self.fc_k = nn.Linear(hidden_size, hidden_size)
#         self.fc_v = nn.Linear(hidden_size, hidden_size)
#
#         self.fc_g = nn.Linear(hidden_size, hidden_size)
#         self.fc_o = nn.Conv2d(hidden_size, src_dim1, kernel_size=1, stride=1, padding=0)
#
#     def forward(self, clip_h, clip_mask, segment, query, query_len, query_mask):
#         bsz, dim, t, h, w = clip_h.size()
#
#         clip_h_ = clip_h.transpose(-3, -4).transpose(-2, -3).transpose(-1, -2)
#         # query_ = []
#         # for i, l in enumerate(query_len):
#         #     query_.append(query[i, :l].mean(dim=0))
#         # query_ = torch.stack(query_, dim=0).unsqueeze(1)
#         clip_h_, query_ = self.fc1(clip_h_), self.fc2(query).unsqueeze(1).unsqueeze(1).unsqueeze(1)
#         clip_h_ = clip_h_ * torch.tanh(self.fc_g(query_))
#
#         x = clip_h_
#         segment = segment.unsqueeze(1).expand(bsz, t, h, w).unsqueeze(-1)
#         segment_ = segment.expand(bsz, t, h, w, clip_h_.size(-1))
#         x = x.reshape(bsz * t, h * w, -1)
#         cluster_emb = scatter(src=x, index=segment_.reshape(bsz * t, h * w, -1), dim=1, reduce='mean')
#         cluster_len = torch.max(segment.reshape(bsz * t, -1), dim=-1)[0] + 1
#         # print(cluster_len)
#         # exit(0)
#         cluster_mask = generate_mask(cluster_emb, cluster_len)
#         q, k = self.fc_q(cluster_emb), self.fc_k(cluster_emb)
#         v = self.fc_v(cluster_emb)
#         s = torch.matmul(q, k.transpose(-1, -2))
#         s = s.masked_fill(cluster_mask.unsqueeze(1) == 0, float('-inf'))
#         s = F.softmax(s, dim=-1)
#         x = torch.matmul(s, v)
#         segment_ = segment.reshape(bsz * t, h * w, 1).expand(bsz * t, h * w, clip_h_.size(-1))
#         x = x.gather(dim=1, index=segment_)
#         x = x.reshape(bsz, t, h, w, -1)
#
#         x = (clip_h_ + x) / math.sqrt(2)
#         x = x.reshape(bsz * t, h, w, -1).transpose(-1, -2).transpose(-2, -3)
#         x = self.fc_o(x)
#         x = F.dropout(x, p=0.1, training=self.training)
#         return x.reshape(bsz, t, dim, h, w).transpose(-3, -4)


class BestSupervisedModelI3D(nn.Module):
    """I3D video encoder and GRU query encoder fused by query-guided region attention.

    The I3D backbone is tapped at five end points (spatial sizes 128/64/32/16/8
    for a 256x256 input, see ``forward``). At each tap a
    ``QueryGuidedRegionAttention2D`` refines the time-averaged features. A
    coarse-to-fine decoder then upsamples from 8x8 and emits sigmoid alignment
    score maps at 16, 32, 64, 128 and 256.

    ``config`` is indexed with ``'hidden_size'`` and ``'query_dim'``. The
    optional key ``'i3d_pretrained_path'`` overrides the default I3D checkpoint
    location.
    """

    # Tapped I3D end point -> spatial resolution of its output. The resolution
    # selects the matching RegionAttention_*_i3d module and mask/segment entries.
    _TAPPED_END_POINTS = {
        'Conv3d_1a_7x7': 128,
        'Conv3d_2c_3x3': 64,
        'Mixed_3c': 32,
        'Mixed_4f': 16,
        'Mixed_5c': 8,
    }
    # Fallback I3D checkpoint used when config has no 'i3d_pretrained_path'.
    _DEFAULT_I3D_PATH = '/home/user/pytorch-i3d/models/rgb_charades.pt'

    def __init__(self, config):
        super().__init__()
        self.config = config
        self._build_clip_encoder()
        self._build_query_encoder()

        # Score heads keyed by the resolution they predict at:
        # (output resolution, input channels, second FullyConvolution argument).
        self.ali_predictor = {}
        for (a, b, c) in [(256, 64, 64), (128, 64, 64), (64, 192, 128), (32, 480, 256), (16, 832, 512)]:
            self.ali_predictor['pred_{}'.format(a)] = FullyConvolution(b, c, 1)

        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

        # Modules are collected in plain dicts and registered with add_module
        # below, so their state_dict keys stay top-level
        # (e.g. 'RegionAttention_8.fc_q.weight').
        self.fconv = {}
        self.ra = {}
        # Spatial resolution -> channel count of the I3D feature tapped at that
        # resolution. Iteration order (coarse to fine) drives the decoder.
        self.resolution = {
            8: 1024,
            16: 832,
            32: 480,
            64: 192,
            128: 64
        }
        for res in self.resolution.keys():
            self.ra['RegionAttention_{}'.format(res)] = \
                QueryGuidedRegionAttention2D(self.resolution[res], config['hidden_size'], config['hidden_size'])
            self.ra['RegionAttention_{}_i3d'.format(res)] = \
                QueryGuidedRegionAttention2D(self.resolution[res], config['hidden_size'], config['hidden_size'])
            if res != 128:
                # Channel adapter from this resolution to the next finer one.
                self.fconv['Conv_{}'.format(res)] = \
                    Conv2d(self.resolution[res], self.resolution[res * 2], kernel_size=1, stride=1, padding=0)
        for k in self.ra.keys():
            self.add_module(k, self.ra[k])
        for k in self.fconv.keys():
            self.add_module(k, self.fconv[k])
        for k in self.ali_predictor.keys():
            self.add_module(k, self.ali_predictor[k])
        self._load_pretrained_weights()

        # NOTE(review): ``resnet`` is not used by forward(). It is created after
        # the I3D weights are loaded but still contributes parameters to
        # state_dict, so it is kept for checkpoint compatibility.
        self.resnet = resnet50(pretrained=True)

    def _load_pretrained_weights(self):
        """Load pretrained I3D weights non-strictly and print the load report.

        The checkpoint path is ``config['i3d_pretrained_path']`` when present
        (e.g. an ``rgb_imagenet.pt`` checkpoint); otherwise it is
        ``_DEFAULT_I3D_PATH``. Tensors are mapped to CPU first and
        ``load_state_dict`` then copies them onto the parameters. Loading
        therefore does not depend on the GPU the checkpoint was saved from.
        """
        path = self._DEFAULT_I3D_PATH
        if hasattr(self.config, 'get'):
            path = self.config.get('i3d_pretrained_path', path)
        state_dict = torch.load(path, map_location='cpu')
        print('load i3d model:', self.load_state_dict(state_dict, strict=False))

    def forward(self, clip, query, query_len, coarse_gt_mask, fine_gt_mask, anch_mask, mask, segment, **kwargs):
        """Predict multi-resolution alignment score maps for a clip/query pair.

        Shape notes for a (bsz, 3, 8, 256, 256) input:
            Conv3d_1a_7x7 -> (bsz, 64, 4, 128, 128)
            Conv3d_2c_3x3 -> (bsz, 192, 4, 64, 64)
            Mixed_3c      -> (bsz, 480, 4, 32, 32)
            Mixed_4f      -> (bsz, 832, 2, 16, 16)

        ``mask`` and ``segment`` are indexed by spatial resolution
        (``mask[res]``, ``segment[res]``). ``mask`` must also contain the
        resolutions of every score map, up to 256. ``anch_mask`` and
        ``kwargs`` are not used.

        Returns:
            dict with ``ali_score_map`` (resolution -> sigmoid score map, zeroed
            where ``mask`` is 0), ``fix_score_map`` (maps accumulated
            coarse-to-fine up to the finest resolution), and the passed-through
            ``fine_gt_mask``, ``coarse_gt_mask`` and ``mask``.
        """
        query_mask = generate_mask(query, query_len)
        query_h = self.gru(query, query_len)

        # Encoder: run I3D. At each tapped end point, add query-guided region
        # attention computed on the time-averaged features (broadcast over time),
        # and keep the middle frame as that resolution's feature map.
        multi_res = {}
        clip_h = clip
        for end_point in self.end_points:
            clip_h = self._modules[end_point](clip_h)  # use _modules to work with dataparallel
            res = self._TAPPED_END_POINTS.get(end_point)
            if res is None:
                continue
            x = self._modules['RegionAttention_{}_i3d'.format(res)](clip_h.mean(dim=2),
                                                                    mask[res], segment[res],
                                                                    query_h, query_len, query_mask)
            clip_h = clip_h + x.unsqueeze(2)
            multi_res[res] = clip_h[:, :, clip_h.size(2) // 2]

        # Decoder: coarse to fine. Add the skip feature, refine with region
        # attention, adapt channels, upsample x2, then score at the new resolution.
        ali_score_map = {}
        x = 0.0
        for res in self.resolution:
            x = x + multi_res[res]
            x = x + self._modules['RegionAttention_{}'.format(res)](x, mask[res], segment[res],
                                                                    query_h, query_len, query_mask)
            conv_name = 'Conv_{}'.format(res)
            if conv_name in self.fconv:
                x = self._modules[conv_name](x, mask[res])
            x = self.up(x)
            out_res = res * 2
            pred_name = 'pred_{}'.format(out_res)
            if pred_name in self.ali_predictor:
                logits = self._modules[pred_name](x, mask[out_res])
                ali_score_map[out_res] = torch.sigmoid(logits).masked_fill(mask[out_res].unsqueeze(1) == 0, 0)

        # Accumulate score maps from the coarsest to the finest resolution.
        resolutions = sorted(ali_score_map)
        fix_score_map = ali_score_map[resolutions[0]]
        for res in resolutions[1:]:
            fix_score_map = self.up(fix_score_map) + ali_score_map[res]
            fix_score_map = fix_score_map.masked_fill(mask[res].unsqueeze(1) == 0, 0)
        final_dict = {
            'ali_score_map': ali_score_map,
            'fix_score_map': fix_score_map,
            'fine_gt_mask': fine_gt_mask,
            'coarse_gt_mask': coarse_gt_mask,
            'mask': mask
        }

        return final_dict

    def _contrastive_score(self, clip_h, mask, anch_mask, fg_score_map):
        """Contrast an anchor region against one positive and up to four negatives.

        Regions are selected by integer labels in ``anch_mask``: 1 = anchor,
        2 = positive, 3-6 = negatives. Each region embedding is the sum of
        ``clip_h`` weighted by a softmax of ``fg_score_map`` restricted to that
        region's pixels.

        Args:
            clip_h: (bsz, dim, h, w) feature map.
            mask: tensor whose dtype/device the existence mask adopts.
            anch_mask: (bsz, h, w) region labels (assumed; it is unsqueezed to
                broadcast against ``fg_score_map``).
            fg_score_map: presumably (bsz, 1, h, w) scores.

        Returns:
            A tuple of three values:
            - the softmax probability of the positive among the existing contrast
              candidates, shape (bsz,);
            - the anchor-weight entropy averaged over the batch;
            - the per-negative entropy, averaged over samples where that negative
              exists, then over the 4 negative labels.

        NOTE(review): reads ``self.lambda_`` (a tensor temperature), which this
        class never defines; forward() does not call this method.
        """
        bsz, dim, h, w = clip_h.size()
        anch_mask = anch_mask.unsqueeze(1)

        mask1 = (anch_mask == 1).long()
        score = fg_score_map.masked_fill(mask1 == 0, float('-1e30')).reshape(bsz, h * w)
        score = F.softmax(score, dim=-1)
        diversity_loss = -(score * torch.log(score + 1e-10)).sum(dim=-1).mean(dim=0)
        score = score.reshape(bsz, 1, h, w)
        anchor_emb = (clip_h * score).sum(dim=-1).sum(dim=-1)
        anchor_emb = F.normalize(anchor_emb, dim=-1)

        mask1 = (anch_mask == 2).long()
        score = fg_score_map.masked_fill(mask1 == 0, float('-1e30')).reshape(bsz, h * w)
        score = F.softmax(score, dim=-1).reshape(bsz, 1, h, w)
        pos_emb = (clip_h * score).sum(dim=-1).sum(dim=-1)

        exist_mask = [torch.ones(bsz, 1).type_as(mask)]
        contrast_emb = [pos_emb]

        same_loss = 0.0

        for neg_idx in [3, 4, 5, 6]:
            mask1 = (anch_mask == neg_idx).long()
            is_exist = ((mask1 == 1).sum(dim=-1).sum(dim=-1) > 0).long()  # (bsz, 1)
            score = fg_score_map.masked_fill(mask1 == 0, float('-1e30')).reshape(bsz, h * w)
            score = F.softmax(score, dim=-1)
            # Entropy over all h * w pixels of each sample, averaged over the
            # samples where this negative exists. Flattening first keeps both
            # factors at shape (bsz,); summing only the last spatial axis would
            # broadcast (bsz, 1, h) against (bsz, 1) into (bsz, bsz, h) and mix
            # samples.
            entropy = -(score * torch.log(score + 1e-10)).sum(dim=-1)
            present = is_exist.reshape(bsz).float()
            same_loss += (entropy * present).sum() / (present.sum() + 1e-10)

            score = score.reshape(bsz, 1, h, w)
            neg_emb = (clip_h * score).sum(dim=-1).sum(dim=-1)
            contrast_emb.append(neg_emb)
            # Match dtypes so torch.cat below does not depend on type promotion.
            exist_mask.append(is_exist.type_as(exist_mask[0]))
        same_loss /= 4
        exist_mask = torch.cat(exist_mask, dim=1)

        contrast_emb = F.normalize(torch.stack(contrast_emb, dim=1), dim=-1)

        # Temperature-scaled cosine similarities. lambda_ is moved to the scores'
        # device so the method also works off-GPU.
        score = torch.matmul(anchor_emb.unsqueeze(1), contrast_emb.transpose(-1, -2)).squeeze(1)
        score = score * self.lambda_.to(score.device)
        score = F.softmax(score.masked_fill(exist_mask == 0, float('-1e30')), dim=-1)
        return score[:, 0], diversity_loss, same_loss

    def _build_clip_encoder(self):
        """Build the InceptionI3D backbone layers and register them as submodules.

        ``self.end_points`` maps end-point name -> module in execution order.
        The names match the pretrained I3D checkpoint keys loaded in
        ``_load_pretrained_weights``.
        """
        in_channels = 3
        name = 'inception_i3d'

        self.end_points = {}
        end_point = 'Conv3d_1a_7x7'
        self.end_points[end_point] = Unit3D(in_channels=in_channels, output_channels=64, kernel_shape=[7, 7, 7],
                                            stride=(2, 2, 2), padding=(3, 3, 3), name=name + end_point)

        end_point = 'MaxPool3d_2a_3x3'
        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
                                                          padding=0)

        end_point = 'Conv3d_2b_1x1'
        self.end_points[end_point] = Unit3D(in_channels=64, output_channels=64, kernel_shape=[1, 1, 1], padding=0,
                                            name=name + end_point)

        end_point = 'Conv3d_2c_3x3'
        self.end_points[end_point] = Unit3D(in_channels=64, output_channels=192, kernel_shape=[3, 3, 3], padding=1,
                                            name=name + end_point)

        end_point = 'MaxPool3d_3a_3x3'
        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
                                                          padding=0)

        end_point = 'Mixed_3b'
        self.end_points[end_point] = InceptionModule(192, [64, 96, 128, 16, 32, 32], name + end_point)

        end_point = 'Mixed_3c'
        self.end_points[end_point] = InceptionModule(256, [128, 128, 192, 32, 96, 64], name + end_point)

        end_point = 'MaxPool3d_4a_3x3'
        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[3, 3, 3], stride=(2, 2, 2),
                                                          padding=0)

        end_point = 'Mixed_4b'
        self.end_points[end_point] = InceptionModule(128 + 192 + 96 + 64, [192, 96, 208, 16, 48, 64], name + end_point)

        end_point = 'Mixed_4c'
        self.end_points[end_point] = InceptionModule(192 + 208 + 48 + 64, [160, 112, 224, 24, 64, 64], name + end_point)

        end_point = 'Mixed_4d'
        self.end_points[end_point] = InceptionModule(160 + 224 + 64 + 64, [128, 128, 256, 24, 64, 64], name + end_point)

        end_point = 'Mixed_4e'
        self.end_points[end_point] = InceptionModule(128 + 256 + 64 + 64, [112, 144, 288, 32, 64, 64], name + end_point)

        end_point = 'Mixed_4f'
        self.end_points[end_point] = InceptionModule(112 + 288 + 64 + 64, [256, 160, 320, 32, 128, 128],
                                                     name + end_point)

        end_point = 'MaxPool3d_5a_2x2'
        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[2, 2, 2], stride=(2, 2, 2),
                                                          padding=0)

        end_point = 'Mixed_5b'
        self.end_points[end_point] = InceptionModule(256 + 320 + 128 + 128, [256, 160, 320, 32, 128, 128],
                                                     name + end_point)

        end_point = 'Mixed_5c'
        self.end_points[end_point] = InceptionModule(256 + 320 + 128 + 128, [384, 192, 384, 48, 128, 128],
                                                     name + end_point)
        for k in self.end_points.keys():
            self.add_module(k, self.end_points[k])

    def _build_query_encoder(self):
        """Create the single-layer bidirectional GRU query encoder ``self.gru``.

        Each direction gets ``hidden_size >> 1`` units. Presumably DynamicGRU
        concatenates both directions, so outputs are ``hidden_size`` wide for an
        even hidden_size — TODO confirm in models.rnn.
        """
        self.gru = DynamicGRU(self.config['query_dim'], self.config['hidden_size'] >> 1,
                              num_layers=1, bidirectional=True, batch_first=True)


class VisionGuidedAttention(nn.Module):
    """Scaled dot-product attention from each pixel of a feature map over query tokens.

    Pixels are projected by ``fc1`` (src_dim -> hidden_size) and query tokens by
    ``fc2`` (hidden_size -> hidden_size). The resulting weights then average the
    *raw* query vectors, so the output keeps the query's feature size.
    """

    def __init__(self, src_dim, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.fc1 = nn.Linear(src_dim, self.hidden_size)
        self.fc2 = nn.Linear(self.hidden_size, self.hidden_size)

    def forward(self, clip, clip_mask, query, query_mask=None):
        """Return the per-pixel query context, shape (bsz, h, w, query_dim).

        Args:
            clip: (bsz, h, w, src_dim) channels-last feature map.
            clip_mask: accepted for interface compatibility; not used.
            query: (bsz, L, hidden_size) query token features.
            query_mask: optional (bsz, L). Tokens where it equals 0 receive zero
                weight. A row with every token masked yields NaN, because the
                softmax is taken over all -inf.
        """
        bsz, h, w, _ = clip.size()
        temperature = math.sqrt(self.hidden_size)

        pixel_proj = self.fc1(clip).reshape(bsz, h * w, -1)   # (bsz, h*w, hidden)
        token_proj = self.fc2(query)                           # (bsz, L, hidden)
        logits = pixel_proj @ token_proj.transpose(-1, -2) / temperature

        if query_mask is not None:
            logits = logits.masked_fill(query_mask.unsqueeze(1) == 0, float('-inf'))
        weights = logits.softmax(dim=-1)
        return (weights @ query).reshape(bsz, h, w, -1)


class TanhAttention(nn.Module):
    """Additive (tanh) attention of a sequence ``x`` over a ``memory`` sequence.

    Energies are ``wst(tanh(ws1(x_i) + ws2(m_j)))``, softmax-normalised over
    the memory positions.
    """

    def __init__(self, d_model):
        super().__init__()
        self.ws1 = nn.Linear(d_model, d_model, bias=True)
        self.ws2 = nn.Linear(d_model, d_model, bias=False)
        self.wst = nn.Linear(d_model, 1, bias=False)

    def reset_parameters(self):
        """Re-initialise all three projections."""
        for layer in (self.ws1, self.ws2, self.wst):
            layer.reset_parameters()

    def _linear(self, name, inp, fast_weights):
        """Apply sub-layer ``name`` using its own parameters or those in ``fast_weights``."""
        layer = getattr(self, name)
        if fast_weights is None:
            return layer(inp)
        # Only layers that own a bias read one from fast_weights (ws1).
        bias = fast_weights[name + '.bias'] if layer.bias is not None else None
        return F.linear(inp, fast_weights[name + '.weight'], bias)

    def forward(self, x, memory, memory_mask=None, fast_weights=None, **kwargs):
        """Attend from ``x`` over ``memory``.

        Args:
            x: (nb, len1, d) queries.
            memory: (nb, len2, d) keys/values.
            memory_mask: optional (nb, len2); positions equal to 0 get zero weight.
            fast_weights: optional dict with keys 'ws1.weight', 'ws1.bias',
                'ws2.weight' and 'wst.weight'. When given, these tensors replace
                the module's own parameters.
            **kwargs: ignored.

        Returns:
            (context of shape (nb, len1, d), attention weights of shape (nb, len1, len2)).
        """
        x_proj = self._linear('ws1', x, fast_weights)              # [nb, len1, d]
        m_proj = self._linear('ws2', memory, fast_weights)         # [nb, len2, d]
        hidden = torch.tanh(x_proj.unsqueeze(2) + m_proj.unsqueeze(1))  # [nb, len1, len2, d]
        energies = self._linear('wst', hidden, fast_weights).squeeze(-1)  # [nb, len1, len2]
        if memory_mask is not None:
            energies = energies.masked_fill(memory_mask.unsqueeze(1) == 0, float('-inf'))
        weights = F.softmax(energies, -1)
        return torch.matmul(weights, memory), weights


class CrossGate(nn.Module):
    """Mutually gate two feature streams: each is scaled by a sigmoid of the other.

    ``h1`` and ``h2`` are the last-dimension sizes of ``x1`` and ``x2``.
    """

    def __init__(self, h1, h2):
        super().__init__()
        self.g1 = nn.Linear(h2, h1)
        self.g2 = nn.Linear(h1, h2)

    def forward(self, x1, x2):
        """Return ``(x1 * sigmoid(g1(x2)), x2 * sigmoid(g2(x1)))``."""
        gate_on_x1 = torch.sigmoid(self.g1(x2))
        gate_on_x2 = torch.sigmoid(self.g2(x1))
        return x1 * gate_on_x1, x2 * gate_on_x2


def generate_mask(x, x_len):
    """Build a padding mask for a batch of variable-length sequences.

    Args:
        x: tensor whose dimension 1 is the padded sequence length. Only its size
            and device are used.
        x_len: per-sample valid lengths, as a 1-D tensor or a list of ints.

    Returns:
        A long tensor of shape (len(x_len), x.size(1)) on ``x.device``, with
        ``mask[i, j] == 1`` iff ``j < x_len[i]`` and 0 otherwise.
    """
    positions = torch.arange(x.size(1), device=x.device).unsqueeze(0)
    lengths = torch.as_tensor(x_len, device=x.device).reshape(-1, 1)
    # Vectorised comparison on x's device (no hard-coded .cuda(), no Python loop).
    return (positions < lengths).long()


def generate_coordinate_emb(clip, h, w):
    """Build a (bsz, h, w, 2) grid of unit-length (row, column) coordinate directions.

    Row and column coordinates are first spread evenly over [-1, 1] with
    ``torch.linspace``. Each (row, col) pair is then L2-normalised along the last
    dimension. The result takes ``clip``'s dtype/device and batch size
    ``clip.size(0)``; the batch dimension is an ``expand`` view, not a copy.
    """
    bsz = clip.size(0)
    rows = (2 * (torch.linspace(0, h, h).type_as(clip) / h) - 1).unsqueeze(-1).expand(h, w)
    cols = (2 * (torch.linspace(0, w, w).type_as(clip) / w) - 1).unsqueeze(0).expand(h, w)
    co = torch.stack([rows, cols], -1)
    # Pass p and dim by keyword: F.normalize's second positional argument is p,
    # so a bare -1 there would request a p=-1 "norm" over dim 1.
    co = F.normalize(co, p=2, dim=-1)
    return co.unsqueeze(0).expand(bsz, -1, -1, -1)


if __name__ == '__main__':
    # Smoke test: build the A2D training loader and the model, then run one forward pass.
    from torch.utils.data import DataLoader
    from fairseq.utils import move_to_cuda
    from datasets.a2d import A2D

    data_args = {
        "videoset_path": "/home/user/data/A2D/Release/videoset.csv",
        "annotation_path": "/home/user/data/A2D/Release/Annotations",
        "vocab_path": "/home/user/code/mm-2020/data/glove_a2d.bin",
        "sample_path": "/home/user/data/A2D/a2d_annotation2.txt",
        "max_num_words": 20,
    }
    dataset = A2D(data_args)
    loader = DataLoader(dataset.train_set, batch_size=4, shuffle=True, num_workers=1,
                        pin_memory=True, collate_fn=dataset.collate_fn)
    model_config = {
        "hidden_size": 256,
        "clip_dim": 832,
        "query_dim": 300,
    }
    # Instantiate the model class defined in this module.
    model = BestSupervisedModelI3D(model_config).cuda()
    for batch in loader:
        net_input = move_to_cuda(batch['net_input'])
        output = model(**net_input)
        break  # a single batch is enough for the smoke test
