import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange, einsum
import math


class ETLayer(nn.Module):
    """Attention layer over the trailing time axis that updates coordinates and features.

    For each entry of the leading ``BN`` axis, every time step ``t`` attends over all
    time steps ``s`` of the same entry. Keys and values are shifted by a learned
    projection of a sinusoidal encoding of the relative offset ``t - s``. The features
    ``h`` receive a residual attention update. The coordinates ``x`` are updated with
    attention-weighted differences ``x_t - x_s``, each scaled by a scalar that
    ``x_mlp`` derives from the value of step ``s``.

    :param node_dim: feature width of ``h``; must equal ``hidden_dim``.
    :param hidden_dim: width of queries/keys and of the projected relative-time embedding.
    :param act_fn: activation module handed to every internal ``MLP``.
    :param time_emb_dim: width of the sinusoidal relative-time encoding before projection.
    """

    def __init__(self, node_dim, hidden_dim, act_fn, time_emb_dim):
        super().__init__()

        # Values (node_dim wide) are summed with the relative-time embedding
        # (hidden_dim wide) in ``forward`` and then added residually to ``h``, so
        # the two widths must agree.
        assert node_dim == hidden_dim

        self.q_mlp = MLP(
            in_dim=node_dim,
            hidden_dim=hidden_dim,
            out_dim=hidden_dim,
            n_layer=2,
            act_fn=act_fn,
        )
        self.k_mlp = MLP(
            in_dim=node_dim,
            hidden_dim=hidden_dim,
            out_dim=hidden_dim,
            n_layer=2,
            act_fn=act_fn,
        )
        self.v_mlp = MLP(
            in_dim=node_dim,
            hidden_dim=hidden_dim,
            out_dim=node_dim,
            n_layer=2,
            act_fn=act_fn,
        )
        # Maps each value vector to one scalar weight for the coordinate update.
        self.x_mlp = MLP(
            in_dim=node_dim, hidden_dim=hidden_dim, out_dim=1, n_layer=2, act_fn=act_fn
        )

        self.time_emb = nn.Linear(time_emb_dim, hidden_dim)
        self.time_emb_dim = time_emb_dim

    def merge_time_dim(self, x):
        """Fold the trailing time axis into the batch axis: ``[M, D, T] -> [M*T, D]``."""
        return rearrange(x, "m d t -> (m t) d")

    def separate_time_dim(self, x, t):
        """Inverse of ``merge_time_dim``: ``[M*T, D] -> [M, D, T]`` with ``T == t``."""
        return rearrange(x, "(m t) d -> m d t", t=t)

    def get_timestep_embedding(self, timesteps, embedding_dim, max_positions=10000):
        """Sinusoidal encoding of a 1-D tensor of (possibly negative) timesteps.

        :param timesteps: 1-D tensor of length ``N``.
        :param embedding_dim: output width; odd widths are zero-padded by one column.
        :param max_positions: base of the geometric frequency progression.
        :return: float32 tensor of shape ``[N, embedding_dim]``, laid out as
            ``[sin(...), cos(...)]`` (plus the pad column when the width is odd).
        """
        half_dim = embedding_dim // 2
        # magic number 10000 is from transformers
        # Clamp the divisor: with half_dim == 1 (embedding_dim 2 or 3) the original
        # ``half_dim - 1`` was zero and raised ZeroDivisionError. Results for
        # half_dim >= 2 are unchanged.
        emb = math.log(max_positions) / max(half_dim - 1, 1)
        emb = torch.exp(
            torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb
        )
        emb = timesteps.float()[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
        if embedding_dim % 2 == 1:  # zero pad
            emb = F.pad(emb, (0, 1), mode="constant")
        return emb

    def forward(self, x, h):
        """
        :param x: coordinates, shape [BN, 3, T]
        :param h: features, shape [BN, Hh, T]
        :return: tuple ``(x, h)`` of updated tensors with the same shapes as the inputs
        """

        T = x.size(-1)

        # Relative offsets t - s for every (query step t, key step s) pair.
        time_index = torch.arange(T).to(x)  # [T]
        rel_time_index = time_index.unsqueeze(-1) - time_index.unsqueeze(-2)  # [T, S]
        rel_time_emb = self.get_timestep_embedding(
            rel_time_index.view(-1), embedding_dim=self.time_emb_dim
        )  # [T*S, Ht]
        rel_time_emb = self.time_emb(rel_time_emb).view(T, T, -1)  # [T, S, H]
        # Every batch entry shares this embedding. Keep a singleton leading axis and
        # let broadcasting expand it, rather than materialising BN copies.
        rel_time_emb = rel_time_emb.permute(2, 0, 1).unsqueeze(0)  # [1, H, T, S]

        h_flat = self.merge_time_dim(h)  # [BN*T, Hh]
        q = self.q_mlp(h_flat)  # [BN*T, H]
        k = self.k_mlp(h_flat)  # [BN*T, H]
        v = self.v_mlp(h_flat)  # [BN*T, Hh]
        v_x = self.x_mlp(v)  # [BN*T, 1]

        v_x_s = self.separate_time_dim(v_x, t=T)  # [BN, 1, S]
        qt = self.separate_time_dim(q, t=T).transpose(-1, -2)  # [BN, H, T] -> [BN, T, H]
        # [BN, H, 1, S] + [1, H, T, S] broadcasts to [BN, H, T, S].
        k_ts = self.separate_time_dim(k, t=T).unsqueeze(-2) + rel_time_emb  # [BN, H, T, S]
        v_ts = self.separate_time_dim(v, t=T).unsqueeze(-2) + rel_time_emb  # [BN, Hh, T, S]

        # Unscaled dot-product attention of each query step t over all key steps s.
        alpha_ts = F.softmax(einsum(qt, k_ts, "n t h, n h t s-> n t s"), dim=-1)  # [BN, T, S]
        h = h + einsum(alpha_ts, v_ts, "n t s, n h t s-> n h t")  # [BN, Hh, T]

        # Coordinate update from pairwise differences x_t - x_s.
        x_ts = x.unsqueeze(-1) - x.unsqueeze(-2)  # [BN, 3, T, S]
        alpha_x_ts = alpha_ts.unsqueeze(1) * x_ts  # [BN, 3, T, S]
        x = x + (alpha_x_ts * v_x_s.unsqueeze(-2)).sum(dim=-1)  # [BN, 3, T]
        return x, h


class MLP(nn.Module):
    """Plain feed-forward network: ``n_layer`` Linear layers separated by ``act_fn``.

    Resulting ``self.actions`` layout::

        Linear(in_dim, hidden_dim), act_fn,
        [Linear(hidden_dim, hidden_dim), act_fn] * (n_layer - 2),
        Linear(hidden_dim, out_dim),
        act_fn                      # only when last_act is true

    The same ``act_fn`` module object fills every activation slot. A parametric
    activation would therefore share its parameters across all of them.

    :param in_dim: input feature width.
    :param hidden_dim: width of every hidden layer.
    :param out_dim: output feature width.
    :param n_layer: number of Linear layers; must be at least 2.
    :param act_fn: activation module inserted between layers.
    :param last_act: also apply ``act_fn`` after the final Linear layer.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, n_layer, act_fn, last_act=False):
        super().__init__()
        assert n_layer >= 2

        # Consecutive width pairs give exactly n_layer Linear layers, created in order.
        widths = [in_dim] + [hidden_dim] * (n_layer - 1) + [out_dim]
        layers = []
        for position, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            if position > 0:
                layers.append(act_fn)
            layers.append(nn.Linear(fan_in, fan_out))
        if last_act:
            layers.append(act_fn)

        self.actions = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack to ``x`` (features on the last axis)."""
        return self.actions(x)
