import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter

from models.conv import FullyConvolution
from models.i3d_model import InceptionModule
from models.rnn import DynamicGRU


class QueryGuidedRegionAttention(nn.Module):
    """Region-level self-attention whose projections are gated by the query.

    Positions of the feature map are max-pooled into one embedding per segment
    id, every region attends over the query words (``TanhAttention``), the
    attended query gates the region's q/k/v projections, regions attend to
    each other, and the mixed region features are copied back to every
    position and added to the input as a residual.
    """

    def __init__(self, config):
        """Create the projections; all layers are ``hidden_size`` wide."""
        super().__init__()
        self.config = config
        hidden = config['hidden_size']

        # q/k/v projections for region-to-region attention.
        self.fc_q = nn.Linear(hidden, hidden)
        self.fc_k = nn.Linear(hidden, hidden)
        self.fc_v = nn.Linear(hidden, hidden)

        # Region-to-word attention plus one query gate per projection.
        self.attn = TanhAttention(hidden)
        self.fc_g1 = nn.Linear(hidden, hidden)
        self.fc_g2 = nn.Linear(hidden, hidden)
        self.fc_g3 = nn.Linear(hidden, hidden)

    def forward(self, clip_h, clip_mask, segment, query, query_len, query_mask):
        """Apply query-gated region attention.

        Args:
            clip_h: feature map of size ``[bsz, dim, h, w]``.
            clip_mask: position mask; it is unsqueezed at dim 1 and multiplied
                with the attended features (presumably ``[bsz, h, w]``).
            segment: integer segment id per position, reshaped here to
                ``[bsz, h * w]``.
            query: word features handed to ``TanhAttention`` as memory.
            query_len: not used by this module.
            query_mask: word mask handed to ``TanhAttention``.

        Returns:
            ``clip_h`` plus the masked region-attention output.
        """
        bsz, dim, h, w = clip_h.size()
        n_pos = h * w

        # The same per-position index drives both the pooling and the
        # broadcast back to positions.
        seg_index = segment.reshape(bsz, n_pos, 1).expand(bsz, n_pos, dim)
        pixels = clip_h.permute(0, 2, 3, 1).reshape(bsz, n_pos, dim)

        # Max-pool all positions that share a segment id into one embedding.
        region_emb = scatter(src=pixels, index=seg_index, dim=1, reduce='max')
        region_count = segment.reshape(bsz, -1).max(dim=-1)[0] + 1
        region_mask = generate_mask(region_emb, region_count)
        # Segment id 0 is never used as an attention key.
        region_mask[:, 0] = 0

        region_query, _ = self.attn(region_emb, query, query_mask)

        q = self.fc_q(region_emb) * torch.tanh(self.fc_g1(region_query))
        k = self.fc_k(region_emb) * torch.tanh(self.fc_g2(region_query))
        v = self.fc_v(region_emb) * torch.tanh(self.fc_g3(region_query))

        weights = torch.matmul(q, k.transpose(-1, -2))
        weights = weights.masked_fill(region_mask.unsqueeze(1) == 0, float('-inf'))
        weights = F.softmax(weights, dim=-1)
        mixed = torch.matmul(weights, v)

        # Copy each region's feature back to the positions that belong to it.
        per_pos = mixed.gather(dim=1, index=seg_index)
        per_pos = per_pos.reshape(bsz, h, w, -1).permute(0, 3, 1, 2)

        return clip_h + per_pos * clip_mask.unsqueeze(1).float()


class QueryGuidedRegionAttention2(nn.Module):
    """Region self-attention over a feature map that is first gated by the query.

    Every position attends over the query words (``VisionGuidedAttention``);
    the attended query gates the position's features. The gated positions are
    max-pooled per segment id, regions attend to each other, and the result is
    copied back to positions, masked, passed through dropout and ReLU, and
    added to the input as a residual.
    """

    def __init__(self, config):
        """Create the projections; all layers are ``hidden_size`` wide."""
        super().__init__()
        self.config = config
        hidden = config['hidden_size']

        self.fc_q = nn.Linear(hidden, hidden)
        self.fc_k = nn.Linear(hidden, hidden)
        self.fc_v = nn.Linear(hidden, hidden)

        self.attn = VisionGuidedAttention(hidden, hidden)
        self.fc_g = nn.Linear(hidden, hidden)

    def forward(self, clip_h, clip_mask, segment, query, query_len, query_mask):
        """Apply query-gated region attention.

        Args:
            clip_h: channels-first feature map (4-D; moved to channels-last
                internally).
            clip_mask: position mask, forwarded to ``VisionGuidedAttention``
                and multiplied (after ``unsqueeze(1)``) with the output.
            segment: integer segment id per position, reshaped here to
                ``[bsz, h * w]``.
            query: word features attended to by every position.
            query_len: not used by this module.
            query_mask: word mask forwarded to ``VisionGuidedAttention``.

        Returns:
            ``clip_h`` plus the non-negative region-attention update.
        """
        # Channels-last view so the linear gate acts on the feature axis.
        pixels = clip_h.permute(0, 2, 3, 1)
        word_ctx = self.attn(pixels, clip_mask, query, query_mask)
        gated = torch.tanh(self.fc_g(word_ctx)) * pixels

        bsz, h, w, dim = gated.size()
        n_pos = h * w
        seg_index = segment.reshape(bsz, n_pos, 1).expand(bsz, n_pos, dim)
        flat = gated.reshape(bsz, n_pos, dim)

        # Max-pool all positions that share a segment id into one embedding.
        region_emb = scatter(src=flat, index=seg_index, dim=1, reduce='max')
        region_count = segment.reshape(bsz, -1).max(dim=-1)[0] + 1
        region_mask = generate_mask(region_emb, region_count)
        # Segment id 0 is never used as an attention key.
        region_mask[:, 0] = 0

        q = self.fc_q(region_emb)
        k = self.fc_k(region_emb)
        v = self.fc_v(region_emb)

        weights = torch.matmul(q, k.transpose(-1, -2))
        weights = weights.masked_fill(region_mask.unsqueeze(1) == 0, float('-inf'))
        weights = F.softmax(weights, dim=-1)
        mixed = torch.matmul(weights, v)

        # Copy each region's feature back to the positions that belong to it.
        update = mixed.gather(dim=1, index=seg_index)
        update = update.reshape(bsz, h, w, -1).permute(0, 3, 1, 2)
        update = update * clip_mask.unsqueeze(1).float()
        update = F.dropout(update, p=0.2, training=self.training)
        update = F.relu(update)

        return clip_h + update


class BestSupervisedModel(nn.Module):
    """Query-conditioned alignment-map predictor on top of I3D ``Mixed_4*`` blocks.

    The clip is encoded by five Inception modules initialised from an I3D
    checkpoint, averaged over time, fused with coordinate embeddings and
    refined by two query-guided region-attention stages (before and after a
    4x bilinear upsampling). A final convolutional head produces one score
    per position.

    Optional config key:
        ``i3d_weights_path``: location of the I3D RGB checkpoint; defaults to
        the path the original code hard-coded.
    """

    # Default checkpoint location, kept for backward compatibility.
    _DEFAULT_I3D_WEIGHTS = '/home/user/pytorch-i3d/models/rgb_imagenet.pt'
    # ``_contrastive_score`` scales similarities by ``self.lambda_``. That
    # parameter is not created in ``__init__`` (its initialiser, a constant
    # 10, was disabled), so this value is used whenever the attribute is
    # missing. It is a plain class constant and does not enter the state dict.
    _DEFAULT_CONTRASTIVE_SCALE = 10.0
    # Region ids in ``anch_mask`` that are treated as negatives.
    _NEGATIVE_REGION_IDS = (3, 4, 5, 6)

    def __init__(self, config):
        """Build all sub-modules and load the pretrained I3D weights.

        Args:
            config: mapping with ``hidden_size``, ``clip_dim`` and
                ``query_dim`` (and optionally ``i3d_weights_path``).
        """
        super().__init__()
        self.config = config
        self._build_clip_encoder()
        self._build_query_encoder()

        self.ra1 = QueryGuidedRegionAttention2(config)
        self.ra2 = QueryGuidedRegionAttention2(config)

        hidden = config['hidden_size']
        # cross-modal learning; the "+ 2" accounts for the coordinate channels.
        self.fconv1 = FullyConvolution(config['clip_dim'] + 2, hidden, hidden)
        self.fconv2 = FullyConvolution(hidden + 2, hidden, hidden)
        self.ali_predictor = FullyConvolution(hidden, hidden, 1)
        self.up = nn.UpsamplingBilinear2d(scale_factor=4)

        # Not used by ``forward``; kept so existing checkpoints still load.
        self.ali_predictor2 = FullyConvolution(hidden, hidden, 1)

        self._load_pretrained_weights()

    def _load_pretrained_weights(self):
        """Load the ``Mixed_4b`` .. ``Mixed_4f`` weights of an I3D checkpoint.

        Checkpoint keys ``Mixed_4<letter><rest>`` are renamed to
        ``<index><rest>`` so they match the entries of ``self.i3d``.
        """
        path = self.config.get('i3d_weights_path', self._DEFAULT_I3D_WEIGHTS)
        # map_location keeps loading independent of the device the checkpoint
        # was saved from; load_state_dict copies into the existing parameters.
        state_dict = torch.load(path, map_location='cpu')

        prefix = 'Mixed_4'
        block_index = {'b': '0', 'c': '1', 'd': '2', 'e': '3', 'f': '4'}
        new_state_dict = {}
        for key, value in state_dict.items():
            # The renaming indexes from the start of the key, so only keys
            # that actually begin with the prefix are eligible.
            if key.startswith(prefix):
                letter = key[len(prefix)]
                new_key = block_index[letter] + key[len(prefix) + 1:]
                new_state_dict[new_key] = value

        print('load i3d model:', self.i3d.load_state_dict(new_state_dict))

    def forward(self, clip, query, query_len, coarse_gt_mask, fine_gt_mask, anch_mask, mask, segment, **kwargs):
        """Predict the alignment score map for a batch.

        Args:
            clip: 5-D clip features ``[bsz, t, h, w, dim]``.
            query: padded word features.
            query_len: number of valid words per query.
            coarse_gt_mask: returned unchanged in the result.
            fine_gt_mask: returned unchanged in the result.
            anch_mask: not used by ``forward``.
            mask: container indexed with ``32`` and ``128`` holding the
                position masks of the two stages (presumably keyed by
                spatial resolution -- TODO confirm).
            segment: container indexed with ``32`` and ``128`` holding the
                segment-id maps of the two stages.

        Returns:
            dict with ``ali_score_map`` (sigmoid scores, zero where
            ``mask[128]`` is zero), ``fine_gt_mask``, ``coarse_gt_mask`` and
            ``mask``.
        """
        query_mask = generate_mask(query, query_len)
        query_h = self.gru(query, query_len)

        # [bsz, h, w, 2] -> [bsz, 2, h, w]
        coord_emb = generate_coordinate_emb(clip).transpose(-1, -2).transpose(-3, -2)
        # [bsz, t, h, w, dim] -> [bsz, dim, t, h, w]
        clip_h = clip.transpose(-1, -2).transpose(-2, -3).transpose(-3, -4)
        for block in self.i3d:
            clip_h = block(clip_h)
        # Average over the temporal axis.
        clip_h = clip_h.mean(dim=2).float()

        clip_h = self.fconv1(torch.cat([clip_h, coord_emb], dim=1), mask[32])
        clip_h = self.ra1(clip_h, mask[32],
                          segment[32], query_h, query_len, query_mask)

        clip_h, coord_emb = self.up(clip_h), self.up(coord_emb)
        clip_h = self.fconv2(torch.cat([clip_h, coord_emb], dim=1), mask[128])
        clip_h = self.ra2(clip_h, mask[128],
                          segment[128], query_h, query_len, query_mask)
        ali_score_map = self.ali_predictor(clip_h, mask[128])

        res = {
            'ali_score_map': torch.sigmoid(ali_score_map).masked_fill(mask[128].unsqueeze(1) == 0, 0),
            'fine_gt_mask': fine_gt_mask,
            'coarse_gt_mask': coarse_gt_mask,
            'mask': mask
        }

        return res

    def _contrastive_score(self, clip_h, mask, anch_mask, fg_score_map):
        """Contrast the anchor region against one positive and four negatives.

        Each region's embedding is the ``clip_h`` average weighted by a
        softmax of ``fg_score_map`` restricted to that region. Region id 1 is
        the anchor, 2 the positive, and 3-6 the negatives.

        Args:
            clip_h: features ``[bsz, dim, h, w]``.
            mask: not used by the computation (kept for interface
                compatibility).
            anch_mask: integer region ids ``[bsz, h, w]``.
            fg_score_map: scores holding ``bsz * h * w`` values (it is
                reshaped to ``[bsz, h * w]``).

        Returns:
            Tuple ``(score, diversity_loss, same_loss)``: the softmax
            probability of the positive among the existing candidates
            (``[bsz]``), the mean entropy of the anchor attention, and the
            entropy of the negative attentions averaged over the samples in
            which the negative region exists and over the four negative ids.
        """
        bsz, dim, h, w = clip_h.size()
        anch_mask = anch_mask.unsqueeze(1)

        def region_attention(region_id):
            # Softmax over positions, restricted to cells with this id.
            inside = anch_mask == region_id
            logits = fg_score_map.masked_fill(~inside, -1e30).reshape(bsz, h * w)
            return inside, F.softmax(logits, dim=-1)

        def entropy(prob):
            # Per-sample entropy of a [bsz, h * w] distribution.
            return -(prob * torch.log(prob + 1e-10)).sum(dim=-1)

        def pooled(prob):
            # Attention-weighted spatial average -> [bsz, dim].
            return (clip_h * prob.reshape(bsz, 1, h, w)).sum(dim=-1).sum(dim=-1)

        _, anchor_attn = region_attention(1)
        diversity_loss = entropy(anchor_attn).mean(dim=0)
        anchor_emb = F.normalize(pooled(anchor_attn), dim=-1)

        _, pos_attn = region_attention(2)
        contrast_emb = [pooled(pos_attn)]
        # The positive is always treated as present.
        exist_mask = [torch.ones(bsz, 1, dtype=torch.long, device=anch_mask.device)]

        same_loss = 0.0
        for neg_idx in self._NEGATIVE_REGION_IDS:
            inside, neg_attn = region_attention(neg_idx)
            is_exist = (inside.sum(dim=-1).sum(dim=-1) > 0).long()  # [bsz, 1]
            # Entropy is reduced to one value per sample *before* masking so
            # that samples lacking this region really contribute nothing;
            # multiplying a [bsz, 1, h] entropy by the [bsz, 1] flag would
            # broadcast across the batch instead of masking it.
            per_sample = entropy(neg_attn) * is_exist.squeeze(1).float()
            same_loss = same_loss + per_sample.sum() / (is_exist.sum().float() + 1e-10)

            contrast_emb.append(pooled(neg_attn))
            exist_mask.append(is_exist)
        same_loss = same_loss / len(self._NEGATIVE_REGION_IDS)
        exist_mask = torch.cat(exist_mask, dim=1)

        contrast_emb = F.normalize(torch.stack(contrast_emb, dim=1), dim=-1)

        # Use ``lambda_`` when something defines it; otherwise fall back to
        # the class default. Follow the input's device rather than forcing
        # CUDA.
        scale = getattr(self, 'lambda_', None)
        if scale is None:
            scale = self._DEFAULT_CONTRASTIVE_SCALE
        else:
            scale = scale.to(anchor_emb.device)

        score = torch.matmul(anchor_emb.unsqueeze(1), contrast_emb.transpose(-1, -2)).squeeze(1) * scale
        score = F.softmax(score.masked_fill(exist_mask == 0, -1e30), dim=-1)
        return score[:, 0], diversity_loss, same_loss

    def _build_clip_encoder(self):
        """Create the five I3D ``Mixed_4*`` Inception modules as ``self.i3d``."""
        name = 'inception_i3d'
        # (end point, input channels, output channel spec)
        specs = [
            ('Mixed_4b', 128 + 192 + 96 + 64, [192, 96, 208, 16, 48, 64]),
            ('Mixed_4c', 192 + 208 + 48 + 64, [160, 112, 224, 24, 64, 64]),
            ('Mixed_4d', 160 + 224 + 64 + 64, [128, 128, 256, 24, 64, 64]),
            ('Mixed_4e', 128 + 256 + 64 + 64, [112, 144, 288, 32, 64, 64]),
            ('Mixed_4f', 112 + 288 + 64 + 64, [256, 160, 320, 32, 128, 128]),
        ]
        self.i3d = nn.ModuleList([
            InceptionModule(in_channels, out_channels, name + end_point)
            for end_point, in_channels, out_channels in specs
        ])

    def _build_query_encoder(self):
        """Create the 2-layer bidirectional GRU used to encode the query."""
        # Half the hidden size per direction so the concatenated output is
        # ``hidden_size`` wide.
        self.gru = DynamicGRU(self.config['query_dim'], self.config['hidden_size'] >> 1,
                              num_layers=2, bidirectional=True, batch_first=True)


class VisionGuidedAttention(nn.Module):
    """Scaled dot-product attention from every spatial position to the query words."""

    def __init__(self, src_dim, hidden_size):
        """Create the position (``fc1``) and word (``fc2``) projections.

        Args:
            src_dim: feature size of the visual input.
            hidden_size: projection size; also the feature size ``fc2``
                expects for the query.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.fc1 = nn.Linear(src_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)

    def forward(self, clip, clip_mask, query, query_mask=None):
        """Return, for every position, an attention-weighted sum of ``query``.

        Args:
            clip: channels-last visual features ``[bsz, h, w, src_dim]``.
            clip_mask: optional position mask with ``bsz * h * w`` entries;
                outputs at positions where it is 0 are set to 0.
            query: word features ``[bsz, n_words, hidden_size]``.
            query_mask: optional ``[bsz, n_words]`` mask; words where it is 0
                receive no attention.

        Returns:
            Tensor ``[bsz, h, w, query_dim]``. The values are the raw
            ``query`` features, not their ``fc2`` projection.
        """
        bsz, h, w, _ = clip.size()

        pos_proj = self.fc1(clip).reshape(bsz, h * w, -1)
        word_proj = self.fc2(query)

        logits = torch.matmul(pos_proj, word_proj.transpose(-1, -2))
        logits = logits / math.sqrt(self.hidden_size)
        if query_mask is not None:
            padded_words = query_mask.unsqueeze(1) == 0
            logits = logits.masked_fill(padded_words, float('-inf'))
        weights = F.softmax(logits, -1)

        attended = torch.matmul(weights, query)
        if clip_mask is not None:
            invalid_pos = clip_mask.reshape(bsz, -1).unsqueeze(-1) == 0
            attended = attended.masked_fill(invalid_pos, 0)
        return attended.reshape(bsz, h, w, -1)


class TanhAttention(nn.Module):
    """Additive (tanh) attention of ``x`` over ``memory``."""

    def __init__(self, d_model):
        """Create the two input projections and the scalar scoring layer."""
        super().__init__()
        self.ws1 = nn.Linear(d_model, d_model, bias=True)
        self.ws2 = nn.Linear(d_model, d_model, bias=False)
        self.wst = nn.Linear(d_model, 1, bias=False)

    def reset_parameters(self):
        """Re-initialise all three linear layers."""
        for layer in (self.ws1, self.ws2, self.wst):
            layer.reset_parameters()

    def forward(self, x, memory, memory_mask=None, fast_weights=None, **kwargs):
        """Attend from each element of ``x`` to the elements of ``memory``.

        Args:
            x: queries ``[nb, len1, d]``.
            memory: keys/values ``[nb, len2, d]``.
            memory_mask: optional ``[nb, len2]`` mask; entries equal to 0 are
                excluded from the softmax.
            fast_weights: optional mapping with ``ws1.weight``, ``ws1.bias``,
                ``ws2.weight`` and ``wst.weight``; when given, these tensors
                are used instead of the module's own parameters.
            **kwargs: ignored.

        Returns:
            Tuple ``(context, weights)`` with ``context`` ``[nb, len1, d]``
            and ``weights`` ``[nb, len1, len2]``.
        """
        if fast_weights is not None:
            proj_x = F.linear(x, fast_weights['ws1.weight'], fast_weights['ws1.bias'])
            proj_mem = F.linear(memory, fast_weights['ws2.weight'])
        else:
            proj_x = self.ws1(x)
            proj_mem = self.ws2(memory)

        # Broadcast-add every (query, memory) pair: [nb, len1, len2, d].
        pairwise = torch.tanh(proj_x.unsqueeze(2) + proj_mem.unsqueeze(1))

        if fast_weights is not None:
            logits = F.linear(pairwise, fast_weights['wst.weight'])
        else:
            logits = self.wst(pairwise)
        logits = logits.squeeze(-1)  # [nb, len1, len2]

        if memory_mask is not None:
            logits = logits.masked_fill(memory_mask.unsqueeze(1) == 0, float('-inf'))
        weights = F.softmax(logits, -1)
        return torch.matmul(weights, memory), weights


class CrossGate(nn.Module):
    """Gate each of two inputs with a sigmoid computed from the other one."""

    def __init__(self, h1, h2):
        """Create the gates.

        Args:
            h1: feature size of the first input.
            h2: feature size of the second input.
        """
        super().__init__()
        self.g1 = nn.Linear(h2, h1)  # gate for x1, computed from x2
        self.g2 = nn.Linear(h1, h2)  # gate for x2, computed from x1

    def forward(self, x1, x2):
        """Return ``(x1 * sigmoid(g1(x2)), x2 * sigmoid(g2(x1)))``."""
        gate_for_x1 = torch.sigmoid(self.g1(x2))
        gate_for_x2 = torch.sigmoid(self.g2(x1))
        return x1 * gate_for_x1, x2 * gate_for_x2


def generate_mask(x, x_len):
    """Build a 0/1 padding mask from sequence lengths.

    Args:
        x: tensor whose second dimension (``x.size(1)``) is the padded
            length; only its size and device are used.
        x_len: lengths, one per row (1-D tensor or sequence of ints);
            expected to be non-negative.

    Returns:
        ``LongTensor`` of shape ``[len(x_len), x.size(1)]`` on ``x``'s device,
        with ``mask[i, j] == 1`` iff ``j < x_len[i]``.
    """
    # Build the mask on the input's device instead of forcing CUDA, so the
    # function also works on CPU and on non-default GPUs.
    lengths = torch.as_tensor(x_len, device=x.device).reshape(-1, 1)
    positions = torch.arange(x.size(1), device=x.device).unsqueeze(0)
    return (positions < lengths).long()


def generate_coordinate_emb(clip):
    """Build a two-channel coordinate grid matching ``clip``'s spatial size.

    Args:
        clip: 5-D tensor ``[bsz, t, h, w, dim]``; only its size, dtype and
            device are used.

    Returns:
        Tensor ``[bsz, h, w, 2]`` (an expanded view shared across the batch).
        Channel 0 is derived from the row coordinate and channel 1 from the
        column coordinate, both initially spanning ``[-1, 1]``.
    """
    bsz, _, h, w, _ = clip.size()

    def axis(n):
        # n evenly spaced values mapped from [0, n] to [-1, 1].
        return 2 * (torch.linspace(0, n, n).type_as(clip) / n) - 1

    row_coord = axis(h).unsqueeze(-1).expand(h, w)
    col_coord = axis(w).unsqueeze(0).expand(h, w)
    grid = torch.stack([row_coord, col_coord], -1)

    # NOTE(review): the original call was ``F.normalize(co, -1)``. The second
    # positional argument of F.normalize is ``p``, so this is a p=-1 norm
    # taken over dim 1 (the width axis), not an L2 norm over the last axis.
    # It looks unintended, but is kept as-is because trained weights may
    # depend on the resulting values -- confirm before changing.
    grid = F.normalize(grid, p=-1, dim=1)

    return grid.unsqueeze(0).expand(bsz, -1, -1, -1)


if __name__ == '__main__':
    # Smoke test: build the A2D training loader, run a single batch through
    # the model on the GPU, then stop.
    from torch.utils.data import DataLoader
    from fairseq.utils import move_to_cuda
    from datasets.a2d import A2D

    data_args = {
        "videoset_path": "/home/user/data/A2D/Release/videoset.csv",
        "annotation_path": "/home/user/data/A2D/Release/Annotations",
        "vocab_path": "/home/user/code/mm-2020/data/glove_a2d.bin",
        "sample_path": "/home/user/data/A2D/a2d_annotation2.txt",
        "max_num_words": 20,
    }
    dataset = A2D(data_args)
    loader = DataLoader(dataset.train_set, batch_size=4, shuffle=True, num_workers=1,
                        pin_memory=True, collate_fn=dataset.collate_fn)
    model_args = {
        "hidden_size": 256,
        "clip_dim": 832,
        "query_dim": 300,
    }
    # The model class defined in this file is BestSupervisedModel; the
    # previously referenced `MainModel` is neither defined nor imported here.
    model = BestSupervisedModel(model_args).cuda()
    for batch in loader:
        net_input = move_to_cuda(batch['net_input'])
        output = model(**net_input)
        # One forward pass is enough for the smoke test.
        break
