import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter
from torch.nn import LayerNorm
from models.conv import Conv2d, FullyConvolution
from models.i3d_model2 import I3D
from models.rnn import DynamicGRU
from models.info_nce import InfoNCE

class StyleRandomization(nn.Module):
    """Randomly mix per-channel feature statistics across the batch (training only).

    In training mode every sample is normalised with its own per-channel
    mean/variance (taken over the flattened last two dims) and then
    re-scaled with statistics interpolated towards those of another randomly
    chosen sample of the batch.  The interpolation weight is drawn uniformly
    from ``[0, 1/K)`` per sample.  In eval mode the input is returned as-is.
    """

    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps  # added to the variance before taking the square root

    def forward(self, x, K=1):
        """Apply statistic mixing.

        Args:
            x: 4-D tensor ``[N, C, H, W]``; statistics are computed over ``H * W``.
            K: divisor of the mixing weight (larger ``K`` -> weaker mixing).

        Returns:
            Tensor with the same shape as ``x`` (``x`` itself in eval mode).
        """
        N, C, H, W = x.size()

        if self.training:
            # reshape (not view) so non-contiguous inputs (e.g. transposed maps) are accepted.
            x = x.reshape(N, C, -1)
            mean = x.mean(-1, keepdim=True)
            var = x.var(-1, keepdim=True)

            x = (x - mean) / (var + self.eps).sqrt()

            idx_swap = torch.randperm(N)
            # Draw on CPU (same RNG stream as before) and move to x's device;
            # `.cuda()` would target the current CUDA device, which may not be x's.
            alpha = (torch.rand(N, 1, 1) / K).to(x.device)
            mean = (1 - alpha) * mean + alpha * mean[idx_swap]
            var = (1 - alpha) * var + alpha * var[idx_swap]

            x = x * (var + self.eps).sqrt() + mean
            x = x.reshape(N, C, H, W)

        return x



def get_video_spatial_feature(featmap_H, featmap_W):
    """Build an 8-channel spatial encoding for a ``featmap_H x featmap_W`` grid.

    For the cell in row ``h`` / column ``w`` the channels are
    ``[xmin, ymin, xmax, ymax, xctr, yctr, 1/featmap_W, 1/featmap_H]``, where
    the box edges of the cell are mapped to ``[-1, 1]``.

    Returns:
        float32 tensor of shape ``[1, 8, featmap_H, featmap_W]``.
    """
    # Work in float64 with the same operation order as the scalar formula, then
    # cast once, so the result matches a per-cell Python-float computation exactly.
    cols = torch.arange(featmap_W, dtype=torch.float64)
    rows = torch.arange(featmap_H, dtype=torch.float64)
    xmin = cols / featmap_W * 2 - 1
    xmax = (cols + 1) / featmap_W * 2 - 1
    xctr = (xmin + xmax) / 2
    ymin = rows / featmap_H * 2 - 1
    ymax = (rows + 1) / featmap_H * 2 - 1
    yctr = (ymin + ymax) / 2

    def along_w(v):  # [W] -> [H, W], constant down each column
        return v.unsqueeze(0).expand(featmap_H, featmap_W)

    def along_h(v):  # [H] -> [H, W], constant along each row
        return v.unsqueeze(1).expand(featmap_H, featmap_W)

    ones = torch.ones(featmap_H, featmap_W, dtype=torch.float64)
    channels = [
        along_w(xmin), along_h(ymin), along_w(xmax), along_h(ymax),
        along_w(xctr), along_h(yctr),
        ones / featmap_W, ones / featmap_H,
    ]
    return torch.stack(channels).unsqueeze(0).float()


class PosEmb(nn.Module):
    """Append an 8-channel spatial encoding to a square feature map.

    Encodings from ``get_video_spatial_feature`` are precomputed for side
    lengths 10, 20, 40, 80 and 160 and kept as frozen parameters
    ``pos10`` ... ``pos160`` (attribute names are part of the state_dict).
    """

    def __init__(self):
        super().__init__()
        self.pos10 = nn.Parameter(get_video_spatial_feature(10, 10), requires_grad=False)
        self.pos20 = nn.Parameter(get_video_spatial_feature(20, 20), requires_grad=False)
        self.pos40 = nn.Parameter(get_video_spatial_feature(40, 40), requires_grad=False)
        self.pos80 = nn.Parameter(get_video_spatial_feature(80, 80), requires_grad=False)
        self.pos160 = nn.Parameter(get_video_spatial_feature(160, 160), requires_grad=False)

    def forward(self, x):
        """Concatenate the positional encoding to ``x`` along dim 1.

        Args:
            x: ``[bsz, dim, h, w]`` with ``h == w`` in {10, 20, 40, 80, 160}.

        Returns:
            ``[bsz, dim + 8, h, w]`` tensor.

        Raises:
            ValueError: if no encoding was precomputed for ``h x w``.
        """
        bsz, dim, h, w = x.size()
        table = {10: self.pos10, 20: self.pos20, 40: self.pos40, 80: self.pos80, 160: self.pos160}
        if h != w or h not in table:
            raise ValueError(
                'PosEmb: no positional encoding for a {}x{} feature map; '
                'supported square sizes are {}'.format(h, w, sorted(table)))
        # .to(x.device) works for CPU and any CUDA device (`.cuda(x.device)` fails on CPU);
        # move the small [1, 8, h, w] tensor first, then expand it to the batch.
        pos_emb = table[h].to(x.device).expand(bsz, 8, h, w)
        return torch.cat([x, pos_emb], 1)


class QueryGuidedRegionAttention2D(nn.Module):
    """Region self-attention, query-guided gating and video-query cross-attention on a 2-D map.

    Input/output feature maps are ``[bsz, src_dim1, h, w]``; internally the
    block works at ``hidden_size`` channels (``hidden_size + 8`` wherever the
    ``PosEmb`` encoding has been appended).
    """

    def __init__(self, src_dim1, src_dim2, hidden_size, region_self=True):
        super().__init__()
        # Project visual (src_dim1) and query (src_dim2) features to hidden_size.
        self.fc_input1 = nn.Linear(src_dim1, hidden_size)
        self.fc_input2 = nn.Linear(src_dim2, hidden_size)

        # Cross-attention from pixels to query tokens; +8 = PosEmb channels.
        self.attn = VisionGuidedAttention(hidden_size + 8, hidden_size, hidden_size)

        self.pos_emb = PosEmb()

        # Q/K/V projections for self-attention between region embeddings.
        self.fc_q = nn.Linear(hidden_size + 8, hidden_size)
        self.fc_k = nn.Linear(hidden_size + 8, hidden_size)
        self.fc_v = nn.Linear(hidden_size + 8, hidden_size)

        # Position-wise feed-forward: hidden -> 2*hidden -> hidden.
        self.fc1 = nn.Linear(hidden_size, hidden_size << 1)
        self.fc2 = nn.Linear(hidden_size << 1, hidden_size)

        self.hidden_size = hidden_size
        self.region_self = region_self  # stored; not read inside this class
        # 1x1 conv back to the input channel count.
        self.fc_o = nn.Conv2d(hidden_size, src_dim1, kernel_size=1, stride=1, padding=0)
        self.query_linear = nn.Linear(src_dim2, hidden_size)
        self.important_linear = nn.Linear(hidden_size, 1)
        self.video_linear = nn.Linear(hidden_size, hidden_size)
        # (sic) attribute name kept for checkpoint compatibility.
        self.attention_linar = nn.Linear(hidden_size, hidden_size)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(0.1)
        self.style_random_layer = StyleRandomization()

    def forward(self, clip_h, clip_mask, segment, query, query_len, query_mask, summary_query=None,
                IN=False, Aug=False, IN_query=False, epoch=0):
        """Run the three attention stages and project back to ``src_dim1`` channels.

        Args:
            clip_h: ``[bsz, src_dim1, h, w]`` visual features.
            clip_mask: unused.
            segment: integer region id per pixel, reshaped to ``[bsz, h * w]``;
                ids are used as scatter/gather indices (presumably ``0..K-1``).
            query: query token features with last dim ``src_dim2``
                (assumed ``[bsz, l, src_dim2]`` — TODO confirm).
            query_len: unused.
            query_mask: mask over query tokens; positions equal to 0 are ignored.
            summary_query: optional sentence vector with last dim ``src_dim2``;
                when given, enables the query-guided gating stage.
            IN: training only — apply ``StyleRandomization`` to the gating map.
            Aug: training only — background-swap augmentation of the gated features.
            IN_query: forwarded as ``IN`` to the cross-attention.
            epoch: with ``Aug`` in training, ``epoch >= 6`` returns the extra tensors.

        Returns:
            ``[bsz, src_dim1, h, w]``; or, when training with ``Aug`` and
            ``epoch >= 6``, a tuple ``(out, x_feature, x_aug, summary_query)`` where
            ``x_feature``/``x_aug`` are ``[bsz, hidden_size, h, w]``.
        """
        bsz, dim, h, w = clip_h.size()

        # [bsz, dim, h, w] -> [bsz, h, w, hidden]
        x = self.fc_input1(clip_h.transpose(-2, -3).transpose(-1, -2))
        query = self.fc_input2(query)

        # --- region self-attention ---
        res = x
        x = self.pos_emb(x.transpose(-1, -2).transpose(-2, -3)).transpose(-2, -3).transpose(-1, -2)
        segment_ = segment.reshape(bsz, h * w, 1).expand(bsz, h * w, x.size(-1))
        x = x.reshape(bsz, h * w, -1)
        # Max-pool the pixels of each region into one embedding per region id.
        cluster_emb = scatter(src=x, index=segment_, dim=1, reduce='max')
        cluster_len = torch.max(segment.reshape(bsz, -1), dim=-1)[0] + 1
        cluster_mask = generate_mask(cluster_emb, cluster_len)
        q, k = self.fc_q(cluster_emb), self.fc_k(cluster_emb)
        v = self.fc_v(cluster_emb)
        s = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.hidden_size)
        s = s.masked_fill(cluster_mask.unsqueeze(1) == 0, float('-inf'))
        s = F.softmax(s, dim=-1)
        x = torch.matmul(s, v)
        # Broadcast each region's attended embedding back to its pixels.
        segment_ = segment.reshape(bsz, h * w, 1).expand(bsz, h * w, x.size(-1))
        x = x.gather(dim=1, index=segment_)
        x = x.reshape(bsz, h, w, -1)
        x = res + x  # [bsz, h, w, hidden]

        # --- query-guided visual gating ---
        if summary_query is not None:
            res = x
            x = x.reshape(bsz, h * w, -1)
            summary_query = self.relu(self.query_linear(summary_query)).unsqueeze(-2)
            x_feature = self.relu(self.video_linear(x))
            query_x_raw = x_feature * summary_query

            query_att_map = self.attention_linar(query_x_raw).sigmoid()  # [bsz, h*w, hidden]
            if self.training:
                if IN:
                    query_att_map = query_att_map.reshape(bsz, h, w, -1)
                    query_att_map = self.style_random_layer(query_att_map, K=2)
                    query_att_map = query_att_map.reshape(bsz, h * w, -1)
                if Aug:
                    with torch.no_grad():
                        important_area = self.important_linear(query_att_map).sigmoid() > 0.05
                        unimportant_area = (~important_area).float()

                        # Gate features of a shuffled batch partner with this sample's query and
                        # paste their "background" (1 - gate) into this sample's unimportant area.
                        idx_swap = torch.randperm(bsz)
                        query_generlization_raw = x_feature[idx_swap] * summary_query
                        query_generlization_map = self.attention_linar(query_generlization_raw).sigmoid()
                        # `1 - map` equals `torch.ones(bsz, h*w, 1).cuda() - map` but stays on the
                        # input's device instead of the current CUDA device.
                        background_map = 1 - query_generlization_map
                        background_feat = x_feature[idx_swap] * background_map
                        # NOTE(review): x_aug is built under no_grad, so when it replaces x_feat
                        # below this branch passes no gradient to video_linear / query_linear /
                        # attention_linar — confirm this is intended.
                        x_aug = x_feature + background_feat * unimportant_area

            x_feat = x_feature * query_att_map
            if self.training and Aug:
                x_feat = x_aug
            x = x_feat.reshape(bsz, h, w, -1) + res

        # --- video-query cross-attention ---
        res = x
        input_x = self.pos_emb(x.transpose(-1, -2).transpose(-2, -3))
        x = self.attn(input_x.transpose(-2, -3).transpose(-1, -2),
                      None, query, query_mask, IN=IN_query)  # [bsz, h, w, hidden]
        x = F.dropout(x, p=0.1, training=self.training)
        x = res + x

        # --- feed-forward ---
        res = x
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.dropout(x, p=0.1, training=self.training)
        x = res + x

        out = self.fc_o(x.transpose(-1, -2).transpose(-2, -3))

        if self.training and Aug and epoch >= 6:
            x_feature = x_feature.reshape(bsz, h, w, -1).transpose(-1, -2).transpose(-2, -3)
            x_aug = x_aug.reshape(bsz, h, w, -1).transpose(-1, -2).transpose(-2, -3)
            return out, x_feature, x_aug, summary_query

        return out


class FinalModel(nn.Module):
    """I3D backbone interleaved with query-guided region attention, plus a coarse-to-fine decoder.

    After each of the five I3D stages (spatial sizes ``self.ori_r``) a
    ``QueryGuidedRegionAttention2D`` block is applied to the time-averaged
    features and its output is added back.  The per-stage features are then
    fused coarse-to-fine (region attention -> 1x1 conv -> 2x bilinear
    upsampling) and a ``FullyConvolution`` head emits a sigmoid score map at
    every upsampled resolution.

    ``config`` keys read here: ``'hidden_size'``, ``'query_dim'``,
    ``'query_output_dim'``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self._build_clip_encoder()
        self._build_query_encoder()

        self.ali_predictor = {}

        # Spatial side length of the five backbone stages, coarsest first.
        self.ori_r = [10, 20, 40, 80, 160]

        # One score head per upsampled resolution a; b is the decoder's channel
        # count at that resolution, c is FullyConvolution's second argument.
        for (a, b, c) in [(self.ori_r[-1] * 2, 64, 64),
                          (self.ori_r[-2] * 2, 64, 64),
                          (self.ori_r[-3] * 2, 192, 128),
                          (self.ori_r[-4] * 2, 480, 256),
                          (self.ori_r[-5] * 2, 832, 512)]:
            self.ali_predictor['pred_{}'.format(a)] = FullyConvolution(b, c, 1)

        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

        self.fconv = {}

        self.ra = {}
        # Backbone channel count at each spatial size.
        self.resolution = {
            self.ori_r[0]: 1024,
            self.ori_r[1]: 832,
            self.ori_r[2]: 480,
            self.ori_r[3]: 192,
            self.ori_r[4]: 64
        }

        for res in self.resolution.keys():
            # Decoder-side region attention (called without a summary query).
            self.ra['RegionAttention_{}'.format(res)] = \
                QueryGuidedRegionAttention2D(self.resolution[res], self.config['query_output_dim'] << 1,
                                             config['hidden_size'])
            # Encoder-side region attention applied after the matching I3D stage.
            self.ra['RegionAttention_{}_i3d'.format(res)] = \
                QueryGuidedRegionAttention2D(self.resolution[res], self.config['query_output_dim'] << 1,
                                             config['hidden_size'], region_self=True)
            if res != self.ori_r[-1]:
                # Maps this level's channels to the next (2x larger) level's channels.
                self.fconv['Conv_{}'.format(res)] = \
                    Conv2d(self.resolution[res], self.resolution[res * 2], kernel_size=1, stride=1, padding=0)

        # The dicts above are plain dicts; register their entries as submodules
        # (same order as before, so state_dict / parameters() ordering is unchanged).
        for k, module in self.ra.items():
            self.add_module(k, module)
        for k, module in self.fconv.items():
            self.add_module(k, module)
        for k, module in self.ali_predictor.items():
            self.add_module(k, module)

        self.style_random_layer = StyleRandomization()
        # No `dropout` here: with num_layers=1 nn.LSTM never applies it (and warns).
        self.query_gru = nn.LSTM(input_size=768, hidden_size=1024,
                                 num_layers=1, batch_first=True, bidirectional=True)

    def _forward_impl(self, x):
        return x

    def load_pretrained_weights(self, path='/home/user/Archive/i3d_rgb.pth'):
        """Strictly load an I3D RGB checkpoint from ``path`` into ``self.i3d`` and print the result."""
        # map_location='cpu' avoids materialising the checkpoint on the device it was
        # saved from; load_state_dict then copies into the parameters' own device.
        state_dict = torch.load(path, map_location='cpu')
        print("-------------------------------------------")
        print(self.i3d.load_state_dict(state_dict, strict=True))

    def forward(self, clip, query, query_len,
                coarse_gt_mask=None, fine_gt_mask=None, anch_mask=None, mask=None, segment=None, epoch=None,
                **kwargs):
        """Predict per-resolution score maps for ``clip`` given ``query``.

        Args:
            clip: input video for the I3D backbone.
            query: query features, fed both to ``self.gru`` and to ``query_gru``
                (whose input size is 768).
            query_len: per-sample query lengths.
            coarse_gt_mask, fine_gt_mask: copied unchanged into the output dict.
            anch_mask: unused.
            mask: dict keyed by spatial size (``self.ori_r`` and their doubles).
            segment: dict keyed by ``self.ori_r`` with per-pixel region ids.
            epoch: current epoch; ``None`` is treated as 0.  Training with
                ``epoch >= 6`` enables the augmentation outputs; ``epoch >= 50``
                enables ``IN_query`` at the fourth stage.
            **kwargs: ignored.

        Returns:
            ``final_dict`` with ``'ali_score_map'`` (sigmoid maps keyed by
            20, 40, 80, 160, 320), ``'fine_gt_mask'``, ``'coarse_gt_mask'`` and
            ``'mask'``; when training with ``epoch >= 6`` the tuple
            ``(final_dict, query_NCE, x_aug_pool)``.
        """
        # `epoch` defaults to None, which used to crash on `epoch >= 6` / `epoch >= 50`.
        if epoch is None:
            epoch = 0
        use_aug = self.training and epoch >= 6

        query_mask = generate_mask(query, query_len)
        query_h = self.gru(query, query_len)
        bs = query_h.size(0)
        _, (summary_query, _) = self.query_gru(query)
        # h_n: [num_directions, bs, 1024] -> [bs, num_directions * 1024]
        summary_query = summary_query.transpose(0, 1).reshape(bs, -1)

        multi_res = {}

        def fuse(clip_h, res, **attn_kwargs):
            """Add the stage's region-attention output (on time-averaged features) back to clip_h."""
            x = clip_h.mean(dim=2)
            x = self._modules['RegionAttention_{}_i3d'.format(res)](
                x, mask[res], segment[res], query_h, query_len, query_mask, summary_query, **attn_kwargs)
            clip_h = clip_h + x.unsqueeze(2)
            multi_res[res] = clip_h.mean(dim=2)
            return clip_h

        # Stage 1 (ori_r[-1]): style randomisation + optional augmentation outputs.
        clip_h = self.i3d.conv3d_1a_7x7(clip)
        res = self.ori_r[-1]
        x = clip_h.mean(dim=2)  # average over dim 2 -> [batch, channel, H, W]
        x = self.style_random_layer(x, K=4)
        stage_attn = self._modules['RegionAttention_{}_i3d'.format(res)]
        if use_aug:
            x, _, x_aug, query_NCE = stage_attn(x, mask[res], segment[res], query_h, query_len,
                                                query_mask, summary_query, Aug=True, IN_query=False,
                                                epoch=epoch)
            x_aug_pool = F.adaptive_avg_pool2d(x_aug, (1, 1)).squeeze(-1).squeeze(-1)
            query_NCE = query_NCE.squeeze(1)
        else:
            x = stage_attn(x, mask[res], segment[res], query_h, query_len,
                           query_mask, summary_query, Aug=True)
        clip_h = clip_h + x.unsqueeze(2)
        multi_res[res] = clip_h.mean(dim=2)

        # Stage 2 (ori_r[-2]).
        clip_h = self.i3d.maxPool3d_2a_3x3(clip_h)
        clip_h = self.i3d.conv3d_2b_1x1(clip_h)
        clip_h = self.i3d.conv3d_2c_3x3(clip_h)
        clip_h = fuse(clip_h, self.ori_r[-2])

        # Stage 3 (ori_r[-3]).
        clip_h = self.i3d.maxPool3d_3a_3x3(clip_h)
        clip_h = self.i3d.mixed_3b(clip_h)
        clip_h = self.i3d.mixed_3c(clip_h)
        clip_h = fuse(clip_h, self.ori_r[-3])

        # Stage 4 (ori_r[-4]): query-side style randomisation from epoch 50 on.
        clip_h = self.i3d.maxPool3d_4a_3x3(clip_h)
        clip_h = self.i3d.mixed_4b(clip_h)
        clip_h = self.i3d.mixed_4c(clip_h)
        clip_h = self.i3d.mixed_4d(clip_h)
        clip_h = self.i3d.mixed_4e(clip_h)
        clip_h = self.i3d.mixed_4f(clip_h)
        clip_h = fuse(clip_h, self.ori_r[-4], IN_query=epoch >= 50)

        # Stage 5 (ori_r[-5], coarsest).
        clip_h = self.i3d.maxPool3d_5a_2x2(clip_h)
        clip_h = self.i3d.mixed_5b(clip_h)
        clip_h = self.i3d.mixed_5c(clip_h)
        clip_h = fuse(clip_h, self.ori_r[-5])

        # Coarse-to-fine decoder: iterate sizes in self.resolution order (10 -> 160).
        ali_score_map = {}
        x = 0.0
        for res in self.resolution:
            x = x + multi_res[res]
            x = x + self._modules['RegionAttention_{}'.format(res)](x, mask[res], segment[res],
                                                                    query_h, query_len, query_mask)
            if res != self.ori_r[-1]:
                x = self._modules['Conv_{}'.format(res)](x, mask[res])
            x = self.up(x)
            key = 'pred_{}'.format(res * 2)
            if key in self.ali_predictor:
                score = torch.sigmoid(self._modules[key](x, mask[res * 2]))
                if not self.training:
                    score = score.masked_fill(mask[res * 2].unsqueeze(1) == 0, 0)
                ali_score_map[res * 2] = score

        final_dict = {
            'ali_score_map': ali_score_map,
            'fine_gt_mask': fine_gt_mask,
            'coarse_gt_mask': coarse_gt_mask,
            'mask': mask
        }

        if use_aug:
            return final_dict, query_NCE, x_aug_pool

        return final_dict

    def _contrastive_score(self, clip_h, mask, anch_mask, fg_score_map):
        """Contrast an anchor-region embedding against positive/negative region embeddings.

        Regions are selected by the value of ``anch_mask`` (1 = anchor,
        2 = positive, 3..6 = negatives) and pooled with a masked softmax of
        ``fg_score_map``.

        Returns:
            ``(score of the positive, diversity_loss, same_loss)``.

        NOTE(review): reads ``self.lambda_``, which this class never sets; the
        method is not called from ``forward``.
        """
        bsz, dim, h, w = clip_h.size()
        anch_mask = anch_mask.unsqueeze(1)

        mask1 = (anch_mask == 1).long()
        score = fg_score_map.masked_fill(mask1 == 0, float('-1e30')).reshape(bsz, h * w)
        score = F.softmax(score, dim=-1)
        diversity_loss = -(score * torch.log(score + 1e-10)).sum(dim=-1).mean(dim=0)
        score = score.reshape(bsz, 1, h, w)
        anchor_emb = (clip_h * score).sum(dim=-1).sum(dim=-1)
        anchor_emb = F.normalize(anchor_emb, dim=-1)

        mask1 = (anch_mask == 2).long()
        score = fg_score_map.masked_fill(mask1 == 0, float('-1e30')).reshape(bsz, h * w)
        score = F.softmax(score, dim=-1).reshape(bsz, 1, h, w)
        pos_emb = (clip_h * score).sum(dim=-1).sum(dim=-1)

        exist_mask = [torch.ones(bsz, 1).type_as(mask)]
        contrast_emb = [pos_emb]

        same_loss = 0.0

        for neg_idx in [3, 4, 5, 6]:
            mask1 = (anch_mask == neg_idx).long()
            is_exist = ((mask1 == 1).sum(dim=-1).sum(dim=-1) > 0).long()
            score = fg_score_map.masked_fill(mask1 == 0, float('-1e30')).reshape(bsz, h * w)
            score = F.softmax(score, dim=-1).reshape(bsz, 1, h, w)
            same_loss += ((-(score * torch.log(score + 1e-10)).sum(dim=-1))
                          * is_exist.float()).sum() / (is_exist.sum().float() + 1e-10)

            neg_emb = (clip_h * score).sum(dim=-1).sum(dim=-1)
            contrast_emb.append(neg_emb)
            exist_mask.append(is_exist)
        same_loss /= 4
        exist_mask = torch.cat(exist_mask, dim=1)

        contrast_emb = F.normalize(torch.stack(contrast_emb, dim=1), dim=-1)

        score = torch.matmul(anchor_emb.unsqueeze(1), contrast_emb.transpose(-1, -2)).squeeze(1) * self.lambda_.cuda()
        score = F.softmax(score.masked_fill(exist_mask == 0, float('-1e30')), dim=-1)
        return score[:, 0], diversity_loss, same_loss

    def _build_clip_encoder(self):
        self.i3d = I3D(num_classes=400, modality='rgb')

    def _build_query_encoder(self):
        self.word2vec = nn.Embedding(1200, self.config['query_dim'], padding_idx=0)
        self.gru = DynamicGRU(self.config['query_dim'], self.config['query_output_dim'],
                              num_layers=1, bidirectional=True, batch_first=True)


class VisionGuidedAttention(nn.Module):
    """Single-head cross-attention from every spatial position to the query tokens.

    Each pixel of ``clip`` attends over the ``l`` query tokens (softmax over
    the token axis); the attended query values are projected to
    ``src_dim - 8`` output channels.
    """

    def __init__(self, src_dim, src_dim2, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.fc1 = nn.Linear(src_dim, self.hidden_size)   # pixels -> attention space
        self.fc2 = nn.Linear(src_dim2, self.hidden_size)  # query tokens -> keys
        self.fc3 = nn.Linear(src_dim2, self.hidden_size)  # query tokens -> values
        # Output width is src_dim - 8 (the caller in this file passes hidden_size + 8).
        self.fco = nn.Linear(self.hidden_size, src_dim - 8)
        self.style_random_layer = StyleRandomization()

    def forward(self, clip, clip_mask, query, query_mask=None, IN=False):
        """Attend from ``clip`` ([bsz, h, w, src_dim]) to ``query`` ([bsz, l, src_dim2]).

        Args:
            clip_mask: unused.
            query_mask: optional [bsz, l]; tokens equal to 0 get -inf logits.
            IN: in training mode, apply ``StyleRandomization`` to the logit map.

        Returns:
            [bsz, h, w, src_dim - 8] tensor.
        """
        batch, height, width, _ = clip.size()
        n_tokens = query.size(1)
        heads = 1
        pixels = height * width

        vis = self.fc1(clip)
        keys = self.fc2(query)
        values = self.fc3(query)
        vis = vis.reshape(batch, pixels, heads, -1)
        keys = keys.reshape(batch, n_tokens, heads, -1)
        values = values.reshape(batch, n_tokens, heads, -1)

        # logits: [batch, pixels, n_tokens, heads]
        scale = math.sqrt(self.hidden_size // heads)
        logits = torch.einsum('bihd,bjhd->bijh', vis, keys) / scale

        if self.training and IN == True:
            # Treat tokens as channels of an [h, w] map for statistic mixing.
            grid = logits.squeeze(-1).reshape(batch, height, width, n_tokens).permute(0, 3, 1, 2)
            grid = self.style_random_layer(grid, K=2)
            logits = grid.permute(0, 2, 3, 1).reshape(batch, pixels, n_tokens).unsqueeze(-1)

        if query_mask is not None:
            blocked = query_mask.unsqueeze(1).unsqueeze(-1) == 0  # [batch, 1, n_tokens, 1]
            logits = logits.masked_fill(blocked, float('-inf'))
        weights = F.softmax(logits, -2)  # normalise over query tokens

        attended = torch.einsum('bijh,bjhd->bihd', weights, values)
        return self.fco(attended.reshape(batch, height, width, -1))


class TanhAttention(nn.Module):
    """Additive (tanh) attention of ``x`` over ``memory``.

    logits[i, j] = wst(tanh(ws1(x_i) + ws2(memory_j))); a softmax over j
    gives the weights used to average ``memory``.
    """

    def __init__(self, d_model):
        super().__init__()
        self.ws1 = nn.Linear(d_model, d_model, bias=True)
        self.ws2 = nn.Linear(d_model, d_model, bias=False)
        self.wst = nn.Linear(d_model, 1, bias=False)

    def reset_parameters(self):
        """Re-initialise the three projections."""
        for layer in (self.ws1, self.ws2, self.wst):
            layer.reset_parameters()

    def forward(self, x, memory, memory_mask=None, fast_weights=None, **kwargs):
        """Attend from ``x`` ([nb, len1, d]) over ``memory`` ([nb, len2, d]).

        Args:
            memory_mask: optional [nb, len2]; positions equal to 0 are excluded.
            fast_weights: optional mapping with keys ``'ws1.weight'``, ``'ws1.bias'``,
                ``'ws2.weight'``, ``'wst.weight'`` used instead of the module's parameters.
            **kwargs: ignored.

        Returns:
            ``(attended [nb, len1, d], weights [nb, len1, len2])``.
        """
        if fast_weights is None:
            proj_x = self.ws1(x)
            proj_mem = self.ws2(memory)
            to_logit = self.wst
        else:
            proj_x = F.linear(x, fast_weights['ws1.weight'], fast_weights['ws1.bias'])
            proj_mem = F.linear(memory, fast_weights['ws2.weight'])
            wst_weight = fast_weights['wst.weight']

            def to_logit(t):
                return F.linear(t, wst_weight)

        # [nb, len1, 1, d] + [nb, 1, len2, d] -> [nb, len1, len2, d]
        hidden = torch.tanh(proj_x.unsqueeze(2) + proj_mem.unsqueeze(1))
        logits = to_logit(hidden).squeeze(-1)  # [nb, len1, len2]
        if memory_mask is not None:
            logits = logits.masked_fill(memory_mask.unsqueeze(1) == 0, float('-inf'))
        weights = F.softmax(logits, -1)
        return torch.matmul(weights, memory), weights


class CrossGate(nn.Module):
    """Mutually gate two feature streams: each is scaled by a sigmoid of a projection of the other."""

    def __init__(self, h1, h2):
        super().__init__()
        self.g1 = nn.Linear(h2, h1)  # x2 -> gate applied to x1
        self.g2 = nn.Linear(h1, h2)  # x1 -> gate applied to x2

    def forward(self, x1, x2):
        """Return ``(x1 * sigmoid(g1(x2)), x2 * sigmoid(g2(x1)))``."""
        gate_for_x1 = torch.sigmoid(self.g1(x2))
        gate_for_x2 = torch.sigmoid(self.g2(x1))
        return x1 * gate_for_x1, x2 * gate_for_x2


def generate_mask(x, x_len):
    """Build a length mask of shape ``[len(x_len), x.size(1)]``.

    Row ``i`` is 1 at positions ``< x_len[i]`` and 0 elsewhere; lengths larger
    than ``x.size(1)`` give an all-ones row.

    Args:
        x: tensor whose dim 1 is the padded length; the mask is created on ``x.device``.
        x_len: per-sample lengths (tensor on any device, or a sequence of ints).

    Returns:
        ``torch.long`` tensor on ``x.device``.
    """
    # Follow x's device instead of a hard-coded `.cuda()` (which broke on CPU and
    # picked the current CUDA device rather than x's).
    lengths = torch.as_tensor(x_len, device=x.device)
    positions = torch.arange(x.size(1), device=x.device)
    return (positions.unsqueeze(0) < lengths.reshape(-1, 1)).long()



def generate_coordinate_emb(clip, h, w):
    """Return an L2-normalised 2-D coordinate embedding for an ``h x w`` grid.

    Row coordinates ``x`` and column coordinates ``y`` are evenly spaced over
    ``[-1, 1]`` (endpoints included when the size is > 1), and each ``(x, y)``
    pair is scaled to unit length.

    Args:
        clip: tensor providing the batch size (dim 0) and the dtype/device.
        h, w: grid size.

    Returns:
        ``[bsz, h, w, 2]`` expanded view (do not write into it).
    """
    bsz = clip.size(0)
    x = (2 * (torch.linspace(0, h, h).type_as(clip) / h) - 1)
    x = x.unsqueeze(-1).expand(h, w)
    y = (2 * (torch.linspace(0, w, w).type_as(clip) / w) - 1)
    y = y.unsqueeze(0).expand(h, w)
    co = torch.stack([x, y], -1)
    # F.normalize's second positional parameter is `p`, not `dim`: the former
    # `F.normalize(co, -1)` applied a p=-1 "norm" over dim 1 (the w axis).
    co = F.normalize(co, dim=-1)
    co = co.unsqueeze(0).expand(bsz, -1, -1, -1)
    return co


if __name__ == '__main__':
    # Manual smoke test: push one A2D training batch through the model.
    from torch.utils.data import DataLoader
    from fairseq.utils import move_to_cuda
    from datasets.a2d import A2D

    args = {
        "videoset_path": "/home1/user/data/A2D/Release/videoset.csv",
        "annotation_path": "/home1/user/data/A2D/Release/Annotations",
        "vocab_path": "/home1/user/code/mm-2020/data/glove_a2d.bin",
        "sample_path": "/home1/user/data/A2D/a2d_annotation2.txt",
        "max_num_words": 20,
    }
    dataset = A2D(args)
    loader = DataLoader(dataset.train_set, batch_size=4, shuffle=True, num_workers=1,
                        pin_memory=True, collate_fn=dataset.collate_fn)
    # FinalModel reads 'hidden_size', 'query_dim' and 'query_output_dim'.
    # The same raw query tensor feeds self.gru (query_dim inputs) and
    # query_gru (768 inputs), so query_dim must be 768. query_gru's 2 x 1024
    # summary feeds query_linear, whose input size is query_output_dim << 1,
    # hence query_output_dim = 1024.
    # NOTE(review): confirm the dataset yields 768-d query features.
    args = {
        "hidden_size": 256,
        "clip_dim": 832,
        "query_dim": 768,
        "query_output_dim": 1024,
    }
    model = FinalModel(args).cuda()
    for batch in loader:
        net_input = move_to_cuda(batch['net_input'])
        # FinalModel.forward compares `epoch` with ints; default it if the batch has none.
        output = model(**{'epoch': 0, **net_input})
        break
