from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn, Tensor

from video_light.transformer import TransformerEncoderLayer, TransformerEncoder


def mask_logits(inputs, mask, mask_value=-1e30):
    """Offset masked-out logits by ``mask_value`` (typically before a softmax).

    ``mask`` is cast to float32. Where it is 1 the logit is left unchanged;
    where it is 0, ``mask_value`` is added.

    Args:
        inputs: Logit tensor.
        mask: Tensor (bool/int/float) broadcastable against ``inputs``.
        mask_value: Offset added at masked (0) positions.

    Returns:
        ``inputs + (1 - mask) * mask_value``.
    """
    keep = mask.to(torch.float32)
    penalty = (1.0 - keep) * mask_value
    return inputs + penalty


class GlobalFeatureExtractor(nn.Module):
    """Summarise a sequence into one vector using the GRU's final hidden state.

    Args:
        input_dim: Size of the last dimension of the input sequence.
        hidden_dim: GRU hidden size per direction.
        num_layers: Number of stacked GRU layers.
        bidirectional: If True, use a bidirectional GRU and concatenate the
            top layer's final forward and backward hidden states.
    """

    def __init__(self, input_dim, hidden_dim, num_layers=1, bidirectional=False):
        super().__init__()
        self.gru = nn.GRU(input_dim, hidden_dim, num_layers=num_layers,
                          bidirectional=bidirectional, batch_first=True)

    def forward(self, x):
        """Return the top GRU layer's final hidden state(s).

        Args:
            x: (batch_size, seq_length, input_dim), since the GRU is
               ``batch_first``; an unbatched (seq_length, input_dim) input
               is also accepted by ``nn.GRU``.

        Returns:
            (batch_size, hidden_dim), or (batch_size, 2 * hidden_dim) when
            bidirectional (without the batch axis for unbatched input).
            ``num_layers`` does not change the output size: only the top
            layer's state is returned.
        """
        _, hidden = self.gru(x)  # hidden: (num_layers * num_directions, [batch,] hidden_dim)

        if self.gru.bidirectional:
            # The last two entries are the top layer's forward and backward
            # states; concatenate on the feature axis (last dim) so batched
            # and unbatched inputs both work.
            return torch.cat((hidden[-2], hidden[-1]), dim=-1)
        # Top layer's final state.
        return hidden[-1]


class Conv1D(nn.Module):
    """``nn.Conv1d`` for channel-last sequences.

    Callers pass and receive tensors laid out as (batch_size, seq_len, dim).
    The module transposes to PyTorch's (batch, channels, length) layout for
    the convolution and transposes back afterwards.
    """

    def __init__(self, in_dim, out_dim, kernel_size=1, stride=1, padding=0, bias=True):
        super().__init__()
        self.conv1d = nn.Conv1d(
            in_channels=in_dim,
            out_channels=out_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, x):
        """Convolve along the sequence axis of a (batch_size, seq_len, in_dim) tensor.

        Returns a (batch_size, seq_len', out_dim) tensor, where seq_len'
        depends on kernel_size / stride / padding.
        """
        channels_first = x.transpose(1, 2)
        convolved = self.conv1d(channels_first)
        return convolved.transpose(1, 2)


class WeightedPool(nn.Module):
    """Pool a sequence into one vector with learned, optionally masked, weights.

    A learned projection ``weight`` of shape (dim, 1) scores every position.
    The scores are masked with ``mask_logits`` (when a mask is given) and
    softmax-normalised over the sequence axis. The output is the score-weighted
    sum of the input vectors.

    Args:
        dim: Feature size of the input sequence.
    """

    def __init__(self, dim):
        super().__init__()
        weight = torch.empty(dim, 1)
        nn.init.xavier_uniform_(weight)
        self.weight = nn.Parameter(weight, requires_grad=True)

    def forward(self, x, mask=None):
        """Attention-pool ``x`` over its sequence axis.

        Args:
            x: (batch_size, seq_length, dim).
            mask: Optional (batch_size, seq_length). Positions with 1/True are
                kept and positions with 0/False are suppressed. ``None`` means
                every position is used.

        Returns:
            (batch_size, dim) pooled features.
        """
        alpha = torch.tensordot(x, self.weight, dims=1)  # (batch_size, seq_length, 1)
        if mask is not None:
            alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
        alphas = torch.softmax(alpha, dim=1)  # normalise over the sequence axis
        pooled_x = torch.matmul(x.transpose(1, 2), alphas)  # (batch_size, dim, 1)
        return pooled_x.squeeze(2)


class FeatureRefinement(nn.Module):
    """Fuse video clip features with query word features.

    For every clip, four d_model-sized features are concatenated along the
    feature axis and mixed back to d_model by a kernel-size-1 ``Conv1D``,
    followed by ReLU and dropout:

    1. the clip feature itself (after dropout);
    2. the cosine similarity between the clip and a pooled sentence feature,
       scaled by the learned vector ``sim_w``;
    3. the pooled sentence feature, repeated for every clip;
    4. a word/clip correlation term built from ``cor_q_w`` and ``cor_v_w``.

    Args:
        d_model: Feature size shared by the video and query inputs.
        dropout: Dropout probability applied to each branch and to the output.
    """

    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        sim_w = torch.empty(d_model, 1)
        nn.init.xavier_uniform_(sim_w)
        self.sim_w = nn.Parameter(sim_w, requires_grad=True)

        cor_v_w = torch.empty(1, d_model)
        nn.init.xavier_uniform_(cor_v_w)
        self.cor_v_w = nn.Parameter(cor_v_w, requires_grad=True)

        cor_q_w = torch.empty(1, 1)
        nn.init.xavier_uniform_(cor_q_w)
        self.cor_q_w = nn.Parameter(cor_q_w, requires_grad=True)

        self.word_to_sentence_pool = WeightedPool(dim=d_model)
        # Mixes the four concatenated d_model-sized branches back to d_model.
        self.mixer = Conv1D(in_dim=4 * d_model, out_dim=d_model, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, video_features, query_features,
                video_mask: Optional[Tensor] = None, query_mask: Optional[Tensor] = None):
        """Refine clip features with query information.

        Args:
            video_features: (batch_size, num_clips, d_model).
            query_features: (batch_size, num_words, d_model).
            video_mask: Optional (batch_size, num_clips). Clips with 0 get a
                large negative offset (log(1e-45)) on their similarity score.
                ``None`` means all clips are valid.
            query_mask: Optional (batch_size, num_words) used when pooling
                words into a sentence feature. ``None`` means all words are
                valid.

        Returns:
            (batch_size, num_clips, d_model) refined clip features.
        """
        bs, vl, dim = video_features.shape
        _, ql, _ = query_features.shape

        if query_mask is None:
            # All-ones mask == no masking in the sentence pooling.
            query_mask = query_features.new_ones(bs, ql)

        # Word/clip dot products, (bs, num_words, num_clips). A batched matmul
        # gives the same values as broadcasting and summing over d_model,
        # without materialising a (bs, num_words, num_clips, d_model) tensor.
        correlation = torch.matmul(query_features, video_features.transpose(1, 2))
        correlation_scores = torch.softmax(correlation, dim=1)  # normalised over words

        # (num_words, d_model); every row equals cor_q_w * cor_v_w.
        cor_w = (self.cor_q_w * self.cor_v_w).expand(ql, dim)
        # NOTE(review): correlation_scores sums to 1 over the word axis and all
        # rows of cor_w are identical, so this product equals cor_q_w * cor_v_w
        # for every clip regardless of the inputs. Kept unchanged; confirm
        # whether a softmax over dim=2 or a per-word weight was intended.
        corr_matrix = self.dropout(torch.matmul(correlation_scores.transpose(1, 2), cor_w))  # (bs, num_clips, d_model)

        # word-level -> sentence-level, (bs, 1, d_model)
        sentence_feature = self.word_to_sentence_pool(query_features, query_mask).unsqueeze(1)
        sim = F.cosine_similarity(video_features, sentence_feature, dim=-1)  # (bs, num_clips)
        if video_mask is not None:
            # log(1 + 1e-45) == 0 for valid clips; padded clips get log(1e-45).
            # NOTE(review): for a float16 mask 1e-45 underflows to 0 and the
            # offset becomes -inf; confirm the mask dtype used by callers.
            sim = sim + (video_mask + 1e-45).log()

        # Scale each clip's similarity by sim_w: (bs, num_clips, 1) * (d_model,).
        sim_features = self.dropout(sim.unsqueeze(-1) * self.sim_w.squeeze(-1))

        pooled_query = self.dropout(sentence_feature.repeat(1, vl, 1))  # (bs, num_clips, d_model)

        features = torch.cat([self.dropout(video_features), sim_features, pooled_query, corr_matrix], dim=2)
        out_features = self.mixer(features)
        return self.dropout(F.relu(out_features))

    def _reset_parameters(self):
        """Re-initialise every parameter with more than one dimension (Xavier uniform).

        This covers submodule weights such as the mixer's convolution and the
        pooling weight. 1-D parameters (e.g. biases) are left untouched. Not
        called from ``__init__``.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
