import math

import torch
from torch import Tensor, nn


class MLPTimeConcat(nn.Module):
    """MLP over a flattened input with the time value appended as an extra feature.

    The network is four ``nn.Linear`` layers separated by GELU activations. The
    first layer consumes ``in_dim + 1`` features: the flattened ``x_t`` plus
    one time column.

    Args:
        in_dim: Number of features of ``x_t`` once all non-batch dimensions are
            flattened.
        out_dim: Width of the last linear layer. ``forward`` reshapes the
            network output to ``[B, D, D]`` (``D`` is the last dimension of
            ``x_t``), so this has to equal ``D * D``.
        hidden_dim: Width of the hidden layers.
    """

    def __init__(self, in_dim, out_dim, hidden_dim):
        super().__init__()
        # Note from the original author: use ReLU or GELU for the activations;
        # ELU and SiLU were observed not to work (reason not established).
        self.net = nn.Sequential(
            nn.Linear(in_dim + 1, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, out_dim),
        )

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every linear layer of ``self.net`` in place."""
        for layer in self.net:
            if isinstance(layer, nn.Linear):
                # Use larger initialization
                nn.init.xavier_uniform_(layer.weight, gain=2.0)
                if layer.bias is not None:
                    # Positive bias to encourage non-zero outputs
                    nn.init.uniform_(layer.bias, 0, 0.1)

    def forward(self, x_t: Tensor, t: Tensor) -> Tensor:
        """Run the MLP on ``x_t`` concatenated with the time value.

        Args:
            x_t: Input of shape ``[B, N, D]``; flattened to ``[B, N * D]``.
            t: Time value as a 0-dim scalar, ``[1]``, ``[B]``, ``[1, 1]`` or
                ``[B, 1]``. Size-1 inputs are broadcast over the batch.

        Returns:
            Tensor of shape ``[B, D, D]``.

        Raises:
            RuntimeError: If ``t`` cannot be broadcast to the batch size or the
                network output cannot be viewed as ``[B, D, D]``.
        """
        B = x_t.shape[0]
        orig_dim = x_t.shape[-1]

        x_t = x_t.flatten(1)
        if t.dim() < 2:
            # Scalar or 1-D time becomes a single column.
            t = t.reshape(-1, 1)
        if t.dim() == 2:
            # Broadcast a shared time value ([1, 1]) across the batch; this is
            # a no-op when t already has B rows.
            t = t.expand(B, -1)

        xin = torch.cat((x_t, t), dim=1)

        out = self.net(xin)

        out = out.view(B, orig_dim, orig_dim)

        return out


class ComplexMLPTimeConcat(MLPTimeConcat):
    """Variant of ``MLPTimeConcat`` that accepts a complex-valued ``x_t``.

    The complex input is split into real and imaginary parts with
    ``torch.view_as_real`` before flattening, so the inherited network must be
    built with ``in_dim`` equal to ``2 * N * D`` for an input of shape
    ``[B, N, D]``. The output is a real tensor.
    """

    def forward(self, x_t: Tensor, t: Tensor) -> Tensor:
        """Run the MLP on the real view of ``x_t`` concatenated with time.

        Args:
            x_t: Complex tensor of shape ``[B, N, D]``.
            t: Time value as a 0-dim scalar, ``[1]``, ``[B]``, ``[1, 1]`` or
                ``[B, 1]``. Size-1 inputs are broadcast over the batch.

        Returns:
            Real tensor of shape ``[B, out_dim // D, D]``.

        Raises:
            RuntimeError: If ``x_t`` is not complex, ``t`` cannot be broadcast
                to the batch size, or the network output width is not a
                multiple of ``D``.
        """
        B = x_t.shape[0]
        orig_dim = x_t.shape[-1]

        # view_as_real appends a trailing (real, imag) axis of size 2:
        # [B, N, D] -> [B, N, D, 2] -> [B, N * D * 2]
        x_t = torch.view_as_real(x_t).flatten(1)
        if t.dim() < 2:
            # Scalar or 1-D time becomes a single column.
            t = t.reshape(-1, 1)
        if t.dim() == 2:
            # Broadcast a shared time value ([1, 1]) across the batch; this is
            # a no-op when t already has B rows.
            t = t.expand(B, -1)

        xin = torch.cat((x_t, t), dim=1)

        out = self.net(xin)

        out = out.view(B, -1, orig_dim)

        return out


class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D tensor of positions / time values.

    Uses ``dim // 2`` frequencies spaced geometrically from ``1`` down to
    ``1 / theta`` and returns the sines followed by the cosines.

    Args:
        dim: Requested embedding width. The actual output width is
            ``2 * (dim // 2)``, i.e. one less than ``dim`` when ``dim`` is odd.
        theta: Ratio between the largest and smallest frequency.
    """

    def __init__(self, dim, theta=10000):
        super().__init__()
        self.dim = dim
        self.theta = theta

    def forward(self, x):
        """Embed ``x``.

        Args:
            x: Tensor of shape ``[B]``.

        Returns:
            Tensor of shape ``[B, 2 * (dim // 2)]`` laid out as
            ``[sin(x * f_0), ..., sin(x * f_k), cos(x * f_0), ..., cos(x * f_k)]``.
        """
        device = x.device
        half_dim = self.dim // 2
        # With a single frequency (dim of 2 or 3) ``half_dim - 1`` is zero;
        # clamp the divisor so that case yields the frequency 1 instead of
        # raising ZeroDivisionError. Larger dims are unaffected.
        step = math.log(self.theta) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -step)
        angles = x[:, None] * freqs[None, :]

        return torch.cat((angles.sin(), angles.cos()), dim=-1)


class MLPSinusoidTimeEmbedding(nn.Module):
    """MLP whose input is the flattened ``x_t`` plus a learned time embedding.

    ``time_embed`` maps the time value through a sinusoidal embedding and a
    small MLP to a vector of width ``in_dim``; that vector is added to the
    flattened input before the main network ``net`` is applied.

    Args:
        in_dim: Number of features of ``x_t`` once all non-batch dimensions are
            flattened.
        out_dim: Width of the last linear layer. ``forward`` reshapes the
            network output to ``[B, D, D]`` (``D`` is the last dimension of
            ``x_t``), so this has to equal ``D * D``.
        hidden_dim: Width of the hidden layers and of the sinusoidal embedding.
    """

    def __init__(self, in_dim, out_dim, hidden_dim):
        super().__init__()
        # Note from the original author: use ReLU or GELU for the activations;
        # ELU and SiLU were observed not to work (reason not established).
        trunk = []
        width = in_dim
        for _ in range(3):
            trunk.append(nn.Linear(width, hidden_dim))
            trunk.append(nn.GELU())
            width = hidden_dim
        trunk.append(nn.Linear(hidden_dim, out_dim))
        self.net = nn.Sequential(*trunk)

        self.time_embed = nn.Sequential(
            SinusoidalPosEmb(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, in_dim),
        )
        # reset_parameters() is intentionally not called here; the layers keep
        # PyTorch's default initialisation unless the caller invokes it.

    def reset_parameters(self):
        """Re-initialise the linear layers of ``self.net`` in place.

        Every weight is drawn from N(0, 0.1^2). The final layer's bias is drawn
        uniformly from [-0.1, 0.1]; all other biases are zeroed. ``time_embed``
        is left untouched.
        """
        final_index = len(self.net) - 1
        for index, module in enumerate(self.net):
            if not isinstance(module, nn.Linear):
                continue
            nn.init.normal_(module.weight, mean=0.0, std=0.1)
            if index == final_index:
                # Non-zero bias on the output layer so the initial outputs are
                # not centred exactly on zero.
                nn.init.uniform_(module.bias, -0.1, 0.1)
            else:
                nn.init.zeros_(module.bias)

    def forward(self, x_t: Tensor, t: Tensor) -> Tensor:
        """Apply the network to ``x_t`` shifted by the time embedding.

        Args:
            x_t: Input of shape ``[B, N, D]``; flattened to ``[B, N * D]``.
            t: Time value as a 0-dim scalar, a 1-D tensor, or a 2-D tensor with
                a trailing singleton dimension.

        Returns:
            Tensor of shape ``[B, D, D]``.
        """
        batch = x_t.shape[0]
        side = x_t.shape[-1]

        rank = t.dim()
        if rank == 0:
            # Repeat the scalar time for every batch element.
            t = t.view(1).expand(batch)
        elif rank == 2:
            # Drop the trailing singleton so the embedding sees a 1-D tensor.
            t = t.squeeze(-1)

        shifted = x_t.flatten(1) + self.time_embed(t)
        return self.net(shifted).view(batch, side, side)
