import torch
import torch.nn as nn
import torch.nn.functional as F
import time

class TemporalGlobalPoolingLayer(nn.Module):
    """Pool node features over time, then mix the pooled nodes with self-attention.

    Temporal step: for every valid node, one query (the node's last timestep, or a
    learned query when ``args.learn_query`` is set) attends over that node's T
    timesteps. The returned attention weights are scaled by ``relu(mha_factor)``,
    log-transformed, averaged over all valid nodes and divided by their sum, giving a
    single set of per-timestep weights shared by every node. Each node's features are
    then summed over time with those weights.

    Spatial step: the pooled node vectors attend to each other (padded nodes are
    excluded as keys). The result is blended with the last-timestep input using
    ``sigmoid(alpha)``.

    Args:
        args: namespace-like config. Attributes read: ``in_dim``, ``out_dim``, ``use_fc``,
            ``attn_mask_dropout``, ``mha_dropout``, ``add_zero_attn``, ``num_head``,
            ``alpha_type`` ('learnable' | 'fixed' | 'never_negative_gradient'),
            ``use_layer_norm``, ``skip_connection``, ``task``, and optionally
            ``learn_query`` and ``return_att_scores``. ``args`` is not modified.

    Raises:
        ValueError: if ``args.alpha_type`` is not one of the supported values.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.fc = nn.Linear(args.in_dim, args.out_dim) if args.use_fc else None
        # With ``attn_mask_dropout`` enabled, ``mha_dropout`` is used as the probability of
        # randomly hiding timesteps from the temporal query (see forward), and both attention
        # modules get no internal dropout (presumably because internal dropout would also
        # perturb the attention weights this layer consumes). A local value is used instead
        # of overwriting ``args.mha_dropout`` so the shared config is left intact for any
        # other layer constructed from it.
        if args.attn_mask_dropout:
            self.attn_mask_dropout = args.mha_dropout
            mha_dropout = 0.0
        else:
            self.attn_mask_dropout = 0.0
            mha_dropout = args.mha_dropout
        self.add_zero_attn = args.add_zero_attn
        self.temporal_attention = nn.MultiheadAttention(embed_dim=args.out_dim, num_heads=args.num_head,
                                                        dropout=mha_dropout, batch_first=True,
                                                        add_zero_attn=args.add_zero_attn)
        # Learnable scale applied (through ReLU) to the temporal attention weights before the log.
        self.mha_factor = torch.nn.Parameter(torch.Tensor([10]))
        self.spatial_attention = nn.MultiheadAttention(embed_dim=args.out_dim, num_heads=args.num_head,
                                                       dropout=mha_dropout, batch_first=True)
        self.eps = 1e-4
        self.alpha_type = args.alpha_type
        if self.alpha_type == 'learnable':
            self.alpha = nn.Parameter(torch.zeros(1))
        elif self.alpha_type == 'fixed':
            # Plain tensor (not a parameter/buffer); forward therefore uses sigmoid(1.0).
            self.alpha = torch.tensor(1.0, requires_grad=False)
        elif self.alpha_type == 'never_negative_gradient':
            self.alpha = nn.Parameter(torch.zeros(1))
            # Clamp every incoming gradient for alpha to be non-negative.
            self.alpha.register_hook(lambda grad: torch.clamp(grad, min=0))
        else:
            # Fail at construction instead of with an AttributeError in forward.
            raise ValueError(f"Unknown alpha_type: {self.alpha_type}")
        self.use_layer_norm = args.use_layer_norm
        self.skip_connection = args.skip_connection
        if self.use_layer_norm:
            self.layer_norm1 = nn.LayerNorm(args.out_dim)
            self.layer_norm2 = nn.LayerNorm(args.out_dim)
        self.T = None  # number of timesteps the current ``pos_emb`` buffer was built for
        self.learn_query = getattr(args, 'learn_query', False)
        if self.learn_query:
            self.query = nn.Parameter(torch.zeros(1, 1, args.out_dim))
            print("Learning query")

    def _generate_positional_embeddings(self, out_dim: int, T: int, device: torch.device) -> torch.Tensor:
        """Return the sinusoidal positional-embedding buffer ``pos_emb`` of shape (T, out_dim).

        The buffer is built on first use and rebuilt whenever the requested shape changes,
        so an input with a different number of timesteps never reuses a stale table.
        Even columns hold sines and odd columns cosines; for odd ``out_dim`` the last
        column is a sine.
        """
        existing = getattr(self, 'pos_emb', None)
        if self.T == T and existing is not None and existing.shape == (T, out_dim):
            return existing
        position = torch.arange(0, T, dtype=torch.float, device=device).unsqueeze(1)  # (T, 1)
        div_term = torch.exp(torch.arange(0, out_dim, 2, device=device).float() * (
                -torch.log(torch.tensor(10000.0, device=device)) / out_dim))  # (ceil(out_dim / 2),)
        pos_emb = torch.zeros(T, out_dim, device=device)
        pos_emb[:, 0::2] = torch.sin(position * div_term)
        # There are only floor(out_dim / 2) odd columns; trim div_term so odd out_dim works.
        pos_emb[:, 1::2] = torch.cos(position * div_term[: out_dim // 2])
        # Re-registering an existing buffer name replaces it.
        self.register_buffer('pos_emb', pos_emb)
        self.T = T
        return pos_emb

    def forward(self, inp: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Pool ``inp`` over time and nodes according to ``args.task``.

        Args:
            inp: node features indexed as (B, N_max, T, D_in); projected to ``out_dim``
                by ``fc`` when ``use_fc`` is set.
            mask: boolean (B, N_max) mask selecting the real (non-padded) nodes.

        Returns:
            For ``'graph_classification'``: a (B, D) tensor. It blends the masked mean of
            the attended node vectors with the masked mean of the last-timestep inputs.
            For ``'node_classification'``: a tuple ``(out, extra)``. ``out`` is a
            (B, N_max, D) tensor. ``extra`` is ``(attn_scores, x)`` when
            ``args.return_att_scores`` is truthy, and ``None`` otherwise.

        Raises:
            ValueError: if ``args.task`` is not one of the two supported tasks.
        """
        if self.fc is not None:
            inp = self.fc(inp)  # (B, N_max, T, D)
        B, N_max, T, D = inp.size()
        pos_emb = self._generate_positional_embeddings(self.args.out_dim, T, device=inp.device)
        x = inp + pos_emb[None, None, :, :]  # (B, N_max, T, D)
        y = x[mask]  # valid nodes only: (N_total, T, D)
        x = x.flatten(end_dim=1)  # (B*N_max, T, D)

        # Attention-mask dropout: randomly hide timesteps from the single (last-step) query.
        attn_mask = None
        if self.attn_mask_dropout > 0 and self.training:
            attn_mask = torch.rand(1, y.shape[-2], device=x.device) < self.attn_mask_dropout  # (1, T)
            if not self.add_zero_attn:
                # Hiding every timestep would leave the softmax nothing to attend to and yield
                # NaN weights; in that draw fall back to no masking. Only applied without
                # add_zero_attn, leaving that configuration's behavior unchanged.
                attn_mask = attn_mask & ~attn_mask.all(dim=-1, keepdim=True)

        y = self.layer_norm1(y) if self.use_layer_norm else y
        query = self.query.expand_as(y[:, -1:]) if self.learn_query else y[:, -1:]  # (N_total, 1, D)
        _, temporal_attn_weights = self.temporal_attention(query, y, y, attn_mask=attn_mask)  # (N_total, 1, T[+1])

        # One shared set of per-timestep scores: log-scaled weights averaged over valid nodes.
        C = torch.log(temporal_attn_weights * F.relu(self.mha_factor) + self.eps)
        C = C.mean(dim=0, keepdim=True)  # (1, 1, T[+1])
        attn_scores = C.clone()  # pre-normalization scores, optionally returned to the caller
        # Divide by the sum so the scores sum to ~1. NOTE(review): the log terms can be
        # negative, so this sum can be near zero or negative; eps does not guard that case.
        norm = C.sum(dim=-1, keepdim=True) + self.eps  # (1, 1, 1)
        C = C / norm
        if self.add_zero_attn:
            C = C[..., :-1]  # drop the appended zero-attention key (normalized before the drop)
        C = C.transpose(1, 2)  # (1, T, 1)
        temporal_attn_output = (x * C).sum(dim=1)  # (B*N_max, D)
        if self.skip_connection:
            temporal_attn_output = temporal_attn_output + x[:, -1]

        T_ = 1  # time has been pooled down to a single step
        temporal_attn_output = temporal_attn_output.reshape(B, N_max, T_, D)
        pooled = temporal_attn_output.sum(dim=-2)  # (B, N_max, D)

        z = self.layer_norm2(pooled) if self.use_layer_norm else pooled
        # Exclude padded nodes as attention keys so real nodes never attend to padding.
        # A batch element with no valid nodes is left unmasked, because masking every key
        # would make the softmax NaN.
        valid = mask.bool()
        key_padding_mask = ~valid & valid.any(dim=1, keepdim=True)  # (B, N_max), True = ignore
        spatial_attn_output, _ = self.spatial_attention(z, z, z, key_padding_mask=key_padding_mask,
                                                        need_weights=False)
        if self.skip_connection:
            spatial_attn_output = spatial_attn_output + z
        r = spatial_attn_output.reshape(B, N_max, T_, D)
        alpha = torch.sigmoid(self.alpha)
        last = inp[:, :, -1]  # last-timestep input features: (B, N_max, D)

        task = self.args.task
        if task == 'graph_classification':
            node_mask = mask.unsqueeze(-1)  # (B, N_max, 1)
            n_valid = mask.sum(dim=1, keepdim=True)  # (B, 1)
            result = (r * node_mask.unsqueeze(-1)).sum(dim=(1, 2)) / (T_ * n_valid)  # (B, D)
            last_mean = (last * node_mask).sum(dim=1) / n_valid  # (B, D)
            return alpha * result + (1 - alpha) * last_mean
        if task == 'node_classification':
            ret = alpha * r[:, :, 0] + (1 - alpha) * last
            if getattr(self.args, 'return_att_scores', False):
                return ret, (attn_scores, x)
            return ret, None
        raise ValueError(f"Unknown task: {self.args.task}")
