import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50

from models.conv import DeConvolution, FullyConvolution
from models.i3d_model import InceptionModule, MaxPool3dSamePadding, Unit3D
from models.rnn import DynamicGRU


class ga2018(nn.Module):
    """Query-conditioned grounding network built from I3D, a GRU and ResNet-50.

    ``forward`` encodes a video clip with the I3D trunk (up to ``Mixed_4f``),
    pools the GRU-encoded query into one vector, and produces sigmoid
    alignment maps stored under the keys ``16``, ``64`` and ``256``.  In
    training mode it additionally encodes ``target_frame`` with ResNet-50
    (through ``layer2``) and derives a foreground score map plus a
    contrastive score between labelled regions of that frame.

    The ``config`` mapping must provide ``'hidden_size'``, ``'clip_dim'`` and
    ``'query_dim'``.
    """

    # Default I3D checkpoint read by ``_load_pretrained_weights``.  Override it on a
    # subclass (or on the class before construction) to load weights from elsewhere.
    i3d_weights_path = '/home/user/pytorch-i3d/models/rgb_charades.pt'

    def __init__(self, config):
        """Build all sub-modules and load the pretrained I3D / ResNet weights.

        Args:
            config: mapping read with ``config['hidden_size']``,
                ``config['clip_dim']`` and ``config['query_dim']``.
        """
        super().__init__()
        self.config = config
        self._build_clip_encoder()
        self._build_query_encoder()

        # Query projections, one per decoding stage.  Their output widths have to match
        # the channel count of the visual feature they gate in ``forward``.
        # NOTE(review): 832 is the Mixed_4f channel count (256 + 320 + 128 + 128), so
        # config['clip_dim'] is presumably 832 as well - confirm.
        self.fc_q1 = nn.Linear(config['hidden_size'], 832)
        self.fc_q2 = nn.Linear(config['hidden_size'], 256)
        self.fc_q3 = nn.Linear(config['hidden_size'], 128)

        self.d1 = DeConvolution(config['clip_dim'], 256)
        self.d2 = DeConvolution(256, 128)

        self.fcn1 = FullyConvolution(config['clip_dim'], config['hidden_size'], 1)
        self.fcn2 = FullyConvolution(256, config['hidden_size'], 1)
        self.fcn3 = FullyConvolution(128, config['hidden_size'], 1)

        # Decoder applied to the ResNet ``layer2`` feature of the target frame.
        self.seq = nn.Sequential(
            nn.PixelShuffle(2),
            FullyConvolution(128, 128, 128),
            nn.UpsamplingBilinear2d(scale_factor=2),
            FullyConvolution(128, 128, 128),
            nn.UpsamplingBilinear2d(scale_factor=2),
        )

        self.pix_emb = FullyConvolution(128, 128, 128)
        self.pix_score = FullyConvolution(128, 128, 1)
        # Fixed (non-trainable) scale applied to the cosine similarities before softmax.
        self.lambda_ = nn.Parameter(torch.ones(1, 1).float() * 10, requires_grad=False)

        # Loaded before ``self.resnet`` exists, so only the modules above are candidates
        # for the checkpoint's keys.
        self._load_pretrained_weights()

        self.resnet = resnet50(pretrained=True)
        # NOTE(review): freezing the parameters does not put the ResNet in eval mode;
        # any BatchNorm running statistics would presumably still update under
        # ``.train()`` - confirm whether that is intended.
        for p in self.resnet.parameters():
            p.requires_grad = False

    def _load_pretrained_weights(self, path=None):
        """Load a checkpoint into this module, ignoring missing/unexpected keys.

        Args:
            path: checkpoint file; defaults to ``self.i3d_weights_path``.
        """
        if path is None:
            path = self.i3d_weights_path
        # map_location='cpu' keeps loading independent of the device the checkpoint was
        # saved from; load_state_dict copies the values into the existing parameters.
        state_dict = torch.load(path, map_location='cpu')
        self.load_state_dict(state_dict, strict=False)

    def forward(self, clip, target_frame, query, query_len,
                coarse_gt_mask=None, fine_gt_mask=None, anch_mask=None, mask=None, segment=None,
                **kwargs):
        """Compute alignment maps and, in training mode, the contrastive outputs.

        Args:
            clip: video clip fed to the I3D trunk.
            target_frame: image fed to ResNet-50; only used in training mode.
            query: query features passed to the GRU together with ``query_len``.
            query_len: per-sample query lengths; only the first ``query_len[i]``
                GRU outputs of sample ``i`` are averaged.
            coarse_gt_mask, fine_gt_mask: passed through unchanged to the output.
            anch_mask: indexed as ``anch_mask[256]`` in training mode (region ids,
                see ``_contrastive_score``).
            mask: passed through to the output; indexed as ``mask[256]`` in
                training mode.
            segment, **kwargs: accepted but unused.

        Returns:
            dict with ``'ali_score_map'`` (dict keyed by 16 / 64 / 256),
            ``'fine_gt_mask'``, ``'coarse_gt_mask'``, ``'mask'`` and, in training
            mode only, ``'fg_score_map'`` and ``'contrast_score'``.
        """
        # Feature shapes noted by the original author for the I3D trunk:
        # input torch.Size([bsz, 3, 8, 256, 256])
        # Conv3d_1a_7x7 torch.Size([bsz, 64, 4, 128, 128])
        # Conv3d_2c_3x3 torch.Size([bsz, 192, 4, 64, 64])
        # Mixed_3c torch.Size([bsz, 480, 4, 32, 32])
        # Mixed_4f torch.Size([bsz, 832, 2, 16, 16])

        # Mean-pool the GRU outputs over the valid (unpadded) steps of each query.
        query_h = self.gru(query, query_len)
        query_h = torch.stack(
            [query_h[i, :l].mean(dim=0) for i, l in enumerate(query_len)], dim=0)

        clip_h = clip
        for end_point in self.end_points:
            clip_h = self._modules[end_point](clip_h)  # use _modules to work with dataparallel
            if end_point == 'Mixed_4f':
                break

        ali_score_map = {}
        q1, q2, q3 = self.fc_q1(query_h), self.fc_q2(query_h), self.fc_q3(query_h)
        # Collapse dim 2 of the I3D feature (temporal axis per the shape notes above).
        x = clip_h.mean(dim=2)

        # Gate the visual feature channel-wise with the query, then score each location.
        ali_score_map[16] = torch.sigmoid(self.fcn1(x * q1.unsqueeze(-1).unsqueeze(-1)))
        x = self.d1(x)
        ali_score_map[64] = torch.sigmoid(self.fcn2(x * q2.unsqueeze(-1).unsqueeze(-1)))
        x = self.d2(x)
        ali_score_map[256] = torch.sigmoid(self.fcn3(x * q3.unsqueeze(-1).unsqueeze(-1)))

        final_dict = {
            'ali_score_map': ali_score_map,
            'fine_gt_mask': fine_gt_mask,
            'coarse_gt_mask': coarse_gt_mask,
            'mask': mask
        }

        if self.training:
            frame_h = self._forward_impl(target_frame)
            frame_h = self.seq(frame_h)

            emb = self.pix_emb(frame_h)
            fg_score_map = self.pix_score(frame_h)

            contrast_score, _, _ = self._contrastive_score(emb, mask[256], anch_mask[256],
                                                           fg_score_map)
            final_dict.update({
                # Zero the foreground probability wherever mask[256] is 0.
                'fg_score_map': torch.sigmoid(fg_score_map).masked_fill(mask[256].unsqueeze(1) == 0, 0),
                'contrast_score': contrast_score,
            })

        return final_dict

    def _forward_impl(self, x):
        """Run the ResNet-50 stem, ``layer1`` and ``layer2``; return the ``layer2`` output."""
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)
        x = self.resnet.layer1(x)
        return self.resnet.layer2(x)

    def _build_clip_encoder(self):
        """Create the I3D end points, register them, and freeze all but ``Mixed_4f``.

        ``self.end_points`` keeps the insertion-ordered name -> module mapping that
        ``forward`` iterates; each module is also registered under the same name so
        that its parameters belong to this module.
        """
        in_channels = 3
        name = 'inception_i3d'

        self.end_points = {}
        end_point = 'Conv3d_1a_7x7'
        self.end_points[end_point] = Unit3D(in_channels=in_channels, output_channels=64, kernel_shape=[7, 7, 7],
                                            stride=(2, 2, 2), padding=(3, 3, 3), name=name + end_point)

        end_point = 'MaxPool3d_2a_3x3'
        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
                                                          padding=0)

        end_point = 'Conv3d_2b_1x1'
        self.end_points[end_point] = Unit3D(in_channels=64, output_channels=64, kernel_shape=[1, 1, 1], padding=0,
                                            name=name + end_point)

        end_point = 'Conv3d_2c_3x3'
        self.end_points[end_point] = Unit3D(in_channels=64, output_channels=192, kernel_shape=[3, 3, 3], padding=1,
                                            name=name + end_point)

        end_point = 'MaxPool3d_3a_3x3'
        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
                                                          padding=0)

        end_point = 'Mixed_3b'
        self.end_points[end_point] = InceptionModule(192, [64, 96, 128, 16, 32, 32], name + end_point)

        end_point = 'Mixed_3c'
        self.end_points[end_point] = InceptionModule(256, [128, 128, 192, 32, 96, 64], name + end_point)

        end_point = 'MaxPool3d_4a_3x3'
        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[3, 3, 3], stride=(2, 2, 2),
                                                          padding=0)

        end_point = 'Mixed_4b'
        self.end_points[end_point] = InceptionModule(128 + 192 + 96 + 64, [192, 96, 208, 16, 48, 64], name + end_point)

        end_point = 'Mixed_4c'
        self.end_points[end_point] = InceptionModule(192 + 208 + 48 + 64, [160, 112, 224, 24, 64, 64], name + end_point)

        end_point = 'Mixed_4d'
        self.end_points[end_point] = InceptionModule(160 + 224 + 64 + 64, [128, 128, 256, 24, 64, 64], name + end_point)

        end_point = 'Mixed_4e'
        self.end_points[end_point] = InceptionModule(128 + 256 + 64 + 64, [112, 144, 288, 32, 64, 64], name + end_point)

        end_point = 'Mixed_4f'
        self.end_points[end_point] = InceptionModule(112 + 288 + 64 + 64, [256, 160, 320, 32, 128, 128],
                                                     name + end_point)

        end_point = 'MaxPool3d_5a_2x2'
        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[2, 2, 2], stride=(2, 2, 2),
                                                          padding=0)

        end_point = 'Mixed_5b'
        self.end_points[end_point] = InceptionModule(256 + 320 + 128 + 128, [256, 160, 320, 32, 128, 128],
                                                     name + end_point)

        end_point = 'Mixed_5c'
        self.end_points[end_point] = InceptionModule(256 + 320 + 128 + 128, [384, 192, 384, 48, 128, 128],
                                                     name + end_point)

        for key, module in self.end_points.items():
            self.add_module(key, module)

        # Only Mixed_4f stays trainable; every other end point is frozen.
        for key, module in self.end_points.items():
            if 'Mixed_4f' not in key:
                for p in module.parameters():
                    p.requires_grad = False

    def _build_query_encoder(self):
        """Create the bidirectional GRU; each direction gets ``hidden_size // 2`` units."""
        self.gru = DynamicGRU(self.config['query_dim'], self.config['hidden_size'] >> 1,
                              num_layers=1, bidirectional=True, batch_first=True)

    def _contrastive_score(self, clip_h, mask, anch_mask, fg_score_map):
        """Score how well the anchor region matches the positive vs. negative regions.

        ``anch_mask`` labels every spatial position with a region id: ``1`` is the
        anchor, ``2`` the positive, ``3``-``6`` the (optional) negatives.  For each
        region an embedding is pooled from ``clip_h`` with softmax weights taken from
        ``fg_score_map`` restricted to that region.

        Args:
            clip_h: embedding map of size ``[bsz, dim, h, w]``.
            mask: only used for the dtype/device of the positive's "exists" column.
            anch_mask: region ids; a channel axis is inserted at dim 1 before it is
                compared against ``fg_score_map``.
            fg_score_map: unnormalised scores, reshaped to ``[bsz, h * w]``.

        Returns:
            Tuple ``(score, diversity_loss, same_loss)``: the softmax probability of
            the positive among the regions that exist (one value per sample), the
            batch-mean entropy of the anchor weights, and the entropy of the
            negative weights averaged over the samples containing each negative
            region and then over the four negative ids.
        """
        bsz, _, h, w = clip_h.size()
        anch_mask = anch_mask.unsqueeze(1)
        negative_ids = (3, 4, 5, 6)

        def region_weights(region_id):
            # Softmax over all h * w positions with everything outside the region
            # pushed to -1e30; an absent region therefore yields uniform weights.
            logits = fg_score_map.masked_fill(anch_mask != region_id, float('-1e30'))
            return F.softmax(logits.reshape(bsz, h * w), dim=-1)

        def pool(weights):
            # Weighted sum of the embeddings over both spatial axes -> [bsz, dim].
            return (clip_h * weights.reshape(bsz, 1, h, w)).sum(dim=-1).sum(dim=-1)

        def entropy(weights):
            # Per-sample entropy over the flattened spatial positions -> [bsz].
            return -(weights * torch.log(weights + 1e-10)).sum(dim=-1)

        anchor_weights = region_weights(1)
        diversity_loss = entropy(anchor_weights).mean(dim=0)
        anchor_emb = F.normalize(pool(anchor_weights), dim=-1)

        # The positive is always treated as present.
        exist_mask = [torch.ones(bsz, 1).type_as(mask)]
        contrast_emb = [pool(region_weights(2))]

        same_loss = 0.0
        for neg_idx in negative_ids:
            is_exist = ((anch_mask == neg_idx).sum(dim=-1).sum(dim=-1) > 0).long()  # [bsz, 1]
            weights = region_weights(neg_idx)

            # Masked mean of the per-sample entropy.  Both factors are [bsz] so the
            # product stays per-sample instead of broadcasting to [bsz, bsz, ...].
            present = is_exist.squeeze(1).float()
            same_loss = same_loss + (entropy(weights) * present).sum() / (present.sum() + 1e-10)

            contrast_emb.append(pool(weights))
            exist_mask.append(is_exist)
        same_loss = same_loss / len(negative_ids)
        exist_mask = torch.cat(exist_mask, dim=1)

        contrast_emb = F.normalize(torch.stack(contrast_emb, dim=1), dim=-1)

        # Scaled cosine similarity between the anchor and every candidate.  ``lambda_``
        # is a registered parameter, so it already lives on this module's device.
        score = torch.matmul(anchor_emb.unsqueeze(1), contrast_emb.transpose(-1, -2)).squeeze(1) * self.lambda_
        score = F.softmax(score.masked_fill(exist_mask == 0, float('-1e30')), dim=-1)
        return score[:, 0], diversity_loss, same_loss


class VisionGuidedAttention(nn.Module):
    """Scaled dot-product attention from every spatial clip location to the query tokens.

    Each of the ``h * w`` clip locations attends over the query sequence and
    receives a weighted sum of the *raw* query vectors.
    """

    def __init__(self, src_dim, hidden_size):
        """
        Args:
            src_dim: size of the last axis of ``clip``.
            hidden_size: size of the last axis of ``query`` and of the
                projected space in which the similarities are computed.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.fc1 = nn.Linear(src_dim, self.hidden_size)
        self.fc2 = nn.Linear(self.hidden_size, self.hidden_size)

    def forward(self, clip, clip_mask, query, query_mask=None):
        """Attend from clip locations to query tokens.

        Args:
            clip: 4-D tensor unpacked as ``[bsz, h, w, src_dim]``.
            clip_mask: accepted for interface compatibility; not used.
            query: tensor whose last axis has ``hidden_size`` entries.
            query_mask: optional mask over query tokens; positions equal to 0
                receive zero attention weight.

        Returns:
            Tensor reshaped to ``[bsz, h, w, -1]`` holding the attended query
            vector for every clip location.
        """
        bsz, h, w, _ = clip.size()
        n_locations = h * w
        scale = math.sqrt(self.hidden_size)

        projected_clip = self.fc1(clip).reshape(bsz, n_locations, -1)
        projected_query = self.fc2(query)

        logits = torch.matmul(projected_clip, projected_query.transpose(-1, -2)) / scale
        if query_mask is not None:
            # Masked tokens get -inf so that softmax assigns them exactly zero weight.
            blocked = query_mask.unsqueeze(1) == 0
            logits = logits.masked_fill_(blocked, float('-inf'))

        weights = F.softmax(logits, -1)
        attended = torch.matmul(weights, query)
        return attended.reshape(bsz, h, w, -1)


class TanhAttention(nn.Module):
    """Additive (tanh) attention of ``x`` over ``memory``.

    The score of a pair is ``wst(tanh(ws1(x_i) + ws2(memory_j)))``.  The
    projection weights can optionally be supplied per call through
    ``fast_weights`` instead of using the module's own parameters.
    """

    def __init__(self, d_model):
        super().__init__()
        self.ws1 = nn.Linear(d_model, d_model, bias=True)
        self.ws2 = nn.Linear(d_model, d_model, bias=False)
        self.wst = nn.Linear(d_model, 1, bias=False)

    def reset_parameters(self):
        """Re-initialise the three linear layers."""
        for layer in (self.ws1, self.ws2, self.wst):
            layer.reset_parameters()

    def forward(self, x, memory, memory_mask=None, fast_weights=None, **kwargs):
        """Attend from ``x`` to ``memory``.

        Args:
            x: queries, ``[nb, len1, d]``.
            memory: keys/values, ``[nb, len2, d]``.
            memory_mask: optional ``[nb, len2]`` mask; positions equal to 0
                receive zero attention weight.
            fast_weights: optional mapping with the keys ``'ws1.weight'``,
                ``'ws1.bias'``, ``'ws2.weight'`` and ``'wst.weight'`` used in
                place of the module's parameters.
            **kwargs: accepted and ignored.

        Returns:
            Tuple ``(attended, weights)`` with ``attended`` of size
            ``[nb, len1, d]`` and ``weights`` of size ``[nb, len1, len2]``.
        """
        # Pick the three projections once; the rest of the computation is shared.
        if fast_weights is None:
            project_x, project_memory, to_score = self.ws1, self.ws2, self.wst
        else:
            project_x = lambda t: F.linear(t, fast_weights['ws1.weight'], fast_weights['ws1.bias'])
            project_memory = lambda t: F.linear(t, fast_weights['ws2.weight'])
            to_score = lambda t: F.linear(t, fast_weights['wst.weight'])

        x_proj = project_x(x)  # [nb, len1, d]
        memory_proj = project_memory(memory)  # [nb, len2, d]
        pairwise = x_proj.unsqueeze(2) + memory_proj.unsqueeze(1)  # [nb, len1, len2, d]
        logits = to_score(torch.tanh(pairwise)).squeeze(-1)  # [nb, len1, len2]

        if memory_mask is not None:
            blocked = memory_mask.unsqueeze(1) == 0  # [nb, 1, len2]
            logits = logits.masked_fill(blocked, float('-inf'))

        weights = F.softmax(logits, -1)
        return torch.matmul(weights, memory), weights  # [nb, len1, d]


class CrossGate(nn.Module):
    """Gate each of two feature streams with a sigmoid computed from the other."""

    def __init__(self, h1, h2):
        """
        Args:
            h1: feature size of the first stream.
            h2: feature size of the second stream.
        """
        super().__init__()
        self.g1 = nn.Linear(h2, h1)  # second stream -> gate for the first
        self.g2 = nn.Linear(h1, h2)  # first stream -> gate for the second

    def forward(self, x1, x2):
        """Return ``(x1 * sigmoid(g1(x2)), x2 * sigmoid(g2(x1)))``."""
        gate_for_x1 = torch.sigmoid(self.g1(x2))
        gate_for_x2 = torch.sigmoid(self.g2(x1))
        gated_x1 = x1 * gate_for_x1
        gated_x2 = x2 * gate_for_x2
        return gated_x1, gated_x2


def generate_mask(x, x_len):
    """Build a 0/1 padding mask from per-sample lengths.

    Args:
        x: tensor whose dim 1 gives the padded length; the mask is created on
            ``x``'s device.
        x_len: per-sample valid lengths (tensor or sequence of ints).  Lengths
            are expected to be non-negative; a length larger than the padded
            length yields an all-ones row.

    Returns:
        ``torch.long`` tensor of size ``[len(x_len), x.size(1)]`` where entry
        ``[i, j]`` is 1 iff ``j < x_len[i]``.
    """
    lengths = torch.as_tensor(x_len, device=x.device).reshape(-1, 1)
    positions = torch.arange(x.size(1), device=x.device).unsqueeze(0)
    # Broadcast comparison replaces the per-row Python loop.
    return (positions < lengths).long()


def generate_coordinate_emb(clip, h, w):
    """Return unit-length 2-D coordinate embeddings for an ``h`` x ``w`` grid.

    Row and column coordinates are spread linearly over ``[-1, 1]`` and every
    ``(row, col)`` pair is L2-normalised along the last axis.  A pair that is
    exactly ``(0, 0)`` stays zero after normalisation.

    Args:
        clip: reference tensor; supplies the batch size (dim 0) and the
            dtype/device of the result.
        h: number of grid rows.
        w: number of grid columns.

    Returns:
        Tensor of size ``[bsz, h, w, 2]``; the batch axis is an expanded view,
        not a copy.
    """
    bsz = clip.size(0)
    rows = 2 * (torch.linspace(0, h, h).type_as(clip) / h) - 1
    cols = 2 * (torch.linspace(0, w, w).type_as(clip) / w) - 1
    grid = torch.stack([rows.unsqueeze(-1).expand(h, w),
                        cols.unsqueeze(0).expand(h, w)], -1)
    # ``dim`` must be passed by keyword: the second positional parameter of
    # F.normalize is the norm order ``p``, not the axis.
    grid = F.normalize(grid, dim=-1)
    return grid.unsqueeze(0).expand(bsz, -1, -1, -1)
