import math
from dataclasses import dataclass

from pyparsing import Optional
import torch
from einops import rearrange
from torch import Tensor, einsum, nn


def gelu(x):
    """Exact (erf-based) GELU activation, i.e. ``x * Phi(x)``.

    ``Phi`` is the standard normal CDF, evaluated here as
    ``0.5 * (1 + erf(x / sqrt(2)))``.

    For reference, OpenAI GPT uses a tanh approximation that gives slightly
    different results:
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    half_x = x * 0.5
    erf_term = torch.erf(x / math.sqrt(2.0))
    return half_x * (1.0 + erf_term)


def modulate(x, shift, scale):
    """Apply the affine modulation ``x * (1 + scale) + shift``.

    Both ``shift`` and ``scale`` get a singleton axis inserted at position 1
    before they are broadcast against ``x``.
    """
    gain = 1 + torch.unsqueeze(scale, 1)
    offset = torch.unsqueeze(shift, 1)
    return x * gain + offset


class EmbedND(nn.Module):
    """Build rotary position embeddings for N-dimensional position ids.

    One ``rope`` table is computed per position axis (the last dimension of
    ``ids``) and the tables are concatenated along the frequency axis.
    """

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        # axes_dim[i] is the rotary dimension used for position axis i.
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        """Return the concatenated rotary tables with a singleton axis inserted at dim 1."""
        per_axis = []
        for axis in range(ids.shape[-1]):
            per_axis.append(rope(ids[..., axis], self.axes_dim[axis], self.theta))

        # dim=-3 is the frequency axis of each (..., n, d/2, 2, 2) table.
        table = torch.cat(per_axis, dim=-3)
        return table.unsqueeze(1)


def attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    pe: Tensor | None = None,
    mode: str = "scaled_dot_product",
) -> Tensor:
    """Multi-head attention over ``(B, H, L, D)`` inputs, returned as ``(B, L, H*D)``.

    :param q: query tensor laid out as ``B H L D``.
    :param k: key tensor laid out as ``B H L D``.
    :param v: value tensor laid out as ``B H L D``.
    :param pe: optional rotary table; when given it is applied to ``q`` and ``k``.
    :param mode: ``"scaled_dot_product"`` selects PyTorch's fused SDPA; any
        other value falls through to ``attention_linear``.
    :return: attention output with the head axis merged into the feature axis.
    """
    if pe is not None:
        q, k = apply_rope(q, k, pe)

    if mode == "scaled_dot_product":
        heads_out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    else:
        # Every non-default mode string is routed to linear attention.
        heads_out = attention_linear(q, k, v)

    return rearrange(heads_out, "B H L D -> B L (H D)")


def attention_linear(
    q: Tensor,
    k: Tensor,
    v: Tensor,
) -> Tensor:
    """Linear attention: softmax-normalise ``q`` and ``k`` separately, then contract.

    ``q`` is softmaxed over the feature axis and scaled by ``D**-0.5``; ``k``
    is softmaxed over the sequence axis. The key/value summary
    (``d x e`` per head) is built first, so the cost is linear in sequence length.

    Reference: https://github.com/lucidrains/linear-attention-transformer
    """
    head_dim = q.size(-1)
    q_weights = q.softmax(dim=-1) * head_dim**-0.5
    k_weights = k.softmax(dim=-2)

    kv_summary = einsum("bhnd,bhne->bhde", k_weights, v)
    return einsum("bhnd,bhde->bhne", q_weights, kv_summary)


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    """Build 2x2 rotary rotation matrices for every (position, frequency) pair.

    :param pos: position indices with any number of leading dimensions;
        integer or floating-point values are accepted.
    :param dim: rotary dimension for this axis; must be even. ``dim // 2``
        frequencies are produced.
    :param theta: base of the geometric frequency progression.
    :return: float32 tensor of shape ``(*pos.shape, dim // 2, 2, 2)`` whose
        trailing 2x2 blocks are ``[[cos, -sin], [sin, cos]]``.
    """
    assert dim % 2 == 0, f"rope dim must be even, got {dim}"
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)

    # Promote positions to float64 explicitly so integer ids (e.g. from
    # torch.arange) do not depend on einsum's handling of mixed dtypes.
    angles = torch.einsum("...n,d->...nd", pos.to(omega.dtype), omega)
    cos, sin = torch.cos(angles), torch.sin(angles)
    out = torch.stack([cos, -sin, sin, cos], dim=-1)

    # Split the trailing 4 entries into a 2x2 matrix; works for any leading rank.
    out = out.reshape(*out.shape[:-1], 2, 2)
    return out.float()


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    """Rotate queries and keys with the rotary table ``freqs_cis``.

    The last axis of each input is viewed as consecutive pairs; every pair is
    multiplied by the matching 2x2 matrix from ``freqs_cis``. The arithmetic
    runs in float32 and each result is cast back to its input's dtype.
    """

    def _rotate(x: Tensor) -> Tensor:
        pairs = x.float().reshape(*x.shape[:-1], -1, 1, 2)
        # Matrix-vector product written out column by column.
        rotated = freqs_cis[..., 0] * pairs[..., 0] + freqs_cis[..., 1] * pairs[..., 1]
        return rotated.reshape(*x.shape).type_as(x)

    return _rotate(xq), _rotate(xk)


def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """Create sinusoidal timestep embeddings.

    :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: multiplier applied to ``t`` before the sinusoids are evaluated.
    :return: an (N, D) Tensor of positional embeddings; cosines first, then
        sines, then a single zero column when ``dim`` is odd.
    """
    t = time_factor * t
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(t.device)

    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Allocate the pad column explicitly: slicing ``embedding[:, :1]`` is
        # empty when dim == 1 (half == 0) and would yield an (N, 0) result.
        pad = embedding.new_zeros((embedding.shape[0], 1))
        embedding = torch.cat([embedding, pad], dim=-1)
    if torch.is_floating_point(t):
        # Match the dtype/device of the (scaled) input timesteps.
        embedding = embedding.to(t)
    return embedding


class MLPEmbedder(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) mapping ``in_dim`` to ``hidden_dim``."""

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        """Project ``x``, apply SiLU, then project again."""
        hidden = self.in_layer(x)
        activated = self.silu(hidden)
        return self.out_layer(activated)


class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last axis with a learned scale."""

    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        """Normalise ``x`` in float32, cast back to the input dtype, then apply ``scale``."""
        in_dtype = x.dtype
        x32 = x.float()
        mean_sq = x32.pow(2).mean(dim=-1, keepdim=True)
        # 1e-6 keeps the reciprocal square root finite for all-zero rows.
        normed = x32 * torch.rsqrt(mean_sq + 1e-6)
        return normed.to(dtype=in_dtype) * self.scale


class QKNorm(torch.nn.Module):
    """Independent RMS normalisation of queries and keys."""

    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        """Normalise ``q`` and ``k``; ``v`` is only used as the dtype/device target."""
        q_normed = self.query_norm(q).to(v)
        k_normed = self.key_norm(k).to(v)
        return q_normed, k_normed


class SelfAttention(nn.Module):
    """Multi-head self-attention with QK normalisation and optional rotary embedding."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        attention_mode: str = "scaled_dot_product",
    ):
        super().__init__()
        self.num_heads = num_heads
        self.attention_mode = attention_mode
        per_head_dim = dim // num_heads

        # Single fused projection producing queries, keys and values.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(per_head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor, pe: Tensor | None = None) -> Tensor:
        """Attend over ``x`` (laid out as ``B L C``) and project the result."""
        fused = self.qkv(x)
        q, k, v = rearrange(fused, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        attended = attention(q, k, v, pe=pe, mode=self.attention_mode)
        return self.proj(attended)


@dataclass
class ModulationOut:
    """One (shift, scale, gate) triple produced by a modulation layer."""

    shift: Tensor  # additive term handed to ``modulate``
    scale: Tensor  # multiplicative term handed to ``modulate`` (applied as ``1 + scale``)
    gate: Tensor  # multiplies a block's output before it is added to the residual


class Modulation(nn.Module):
    """Project a conditioning vector into one or two (shift, scale, gate) triples."""

    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        # Three chunks per triple; two triples when ``double`` is set.
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        """Return ``(first, second)``; ``second`` is ``None`` unless ``double`` was set."""
        projected = self.lin(nn.functional.silu(vec))
        # Insert a broadcast axis at dim 1, then split the feature axis.
        chunks = projected[:, None, :].chunk(self.multiplier, dim=-1)

        first = ModulationOut(*chunks[:3])
        second = ModulationOut(*chunks[3:]) if self.is_double else None
        return first, second


class ModulationTriple(nn.Module):
    """Project a conditioning vector into three (shift, scale, gate) triples."""

    def __init__(self, dim: int):
        super().__init__()
        # Three triples of three chunks each.
        self.multiplier = 9
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut, ModulationOut]:
        """Return the three modulation triples in projection order."""
        projected = self.lin(nn.functional.silu(vec))
        chunks = projected[:, None, :].chunk(self.multiplier, dim=-1)

        first = ModulationOut(*chunks[:3])
        second = ModulationOut(*chunks[3:6])
        third = ModulationOut(*chunks[6:])
        return first, second, third


class ParallelMLPAttentionV2(nn.Module):
    """Attention and MLP branches computed in parallel from one fused projection.

    ``linear1`` produces q/k/v and the MLP hidden activations in a single
    matmul; ``linear2`` maps the concatenated attention output and
    GELU-activated MLP branch back to ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
        attention_mode: str = "scaled_dot_product",
    ):
        super().__init__()
        per_head_dim = hidden_size // num_heads

        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        # Stored for reference; not read by ``forward`` in this class.
        self.scale = qk_scale if qk_scale else per_head_dim**-0.5
        self.attention_mode = attention_mode

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # Fused input projection: 3 * hidden_size for qkv + the MLP hidden width.
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # Fused output projection over [attention output, MLP activations].
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(per_head_dim)

        self.hidden_size = hidden_size
        # Registered but not applied in ``forward``; callers normalise beforehand.
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

    def forward(self, x: Tensor, pe: Tensor | None = None) -> Tensor:
        """Run both branches on ``x`` (laid out as ``B L C``) and fuse them."""
        fused = self.linear1(x)
        qkv, mlp = fused.split([3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        attn = attention(q, k, v, pe=pe, mode=self.attention_mode)

        merged = torch.cat((attn, gelu(mlp)), 2)
        return self.linear2(merged)


class LatentSIV3Layer(nn.Module):
    """Factorized transformer layer: spatial attention followed by temporal attention.

    Each sub-block is wrapped in a modulated pre-norm and a gated residual,
    with the shift/scale/gate values derived from the conditioning vector.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: int,
        attention_mode: str = "scaled_dot_product",
    ):
        super().__init__()
        # double=True yields two triples: one per sub-block.
        self.modulation = Modulation(hidden_size, double=True)
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.spatial_block = ParallelMLPAttentionV2(
            hidden_size=hidden_size,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            attention_mode=attention_mode,
        )

        self.temporal_block = ParallelMLPAttentionV2(
            hidden_size=hidden_size,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            attention_mode=attention_mode,
        )

    def forward(
        self, x: Tensor, y: Tensor, pe_spatial: Tensor, pe_temporal: Tensor
    ) -> Tensor:
        """Apply the spatial then the temporal sub-block.

        :param x: activations laid out as ``B T L D``.
        :param y: conditioning vector fed to the modulation projection.
        :param pe_spatial: rotary table passed to the spatial attention.
        :param pe_temporal: rotary table passed to the temporal attention.
        :return: tensor with the same ``B T L D`` layout as ``x``.
        """
        _, T, L, _ = x.size()

        mod1, mod2 = self.modulation(y)

        # Spatial pass: fold time into the batch so attention runs over L.
        residual = x
        x = modulate(self.pre_norm(x), mod1.shift, mod1.scale)
        x = rearrange(x, "B T L D -> (B T) L D", L=L)
        x = self.spatial_block(x=x, pe=pe_spatial)
        x = rearrange(x, "(B T) L D -> B T L D", T=T)
        x = residual + mod1.gate.unsqueeze(1) * x

        # Temporal pass: fold space into the batch so attention runs over T.
        residual = x
        x = modulate(self.pre_norm(x), mod2.shift, mod2.scale)
        x = rearrange(x, "B T L D -> (B L) T D", L=L)
        x = self.temporal_block(x=x, pe=pe_temporal)
        x = rearrange(x, "(B L) T D -> B T L D", L=L)
        x = residual + mod2.gate.unsqueeze(1) * x

        return x


class LatentSIV3TemporalLayer(nn.Module):
    """Temporal-only layer: modulated pre-norm, attention over T, gated residual."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: int,
        attention_mode: str = "scaled_dot_product",
    ):
        super().__init__()
        # Kept as double=True so the parameter shapes stay unchanged; only the
        # second (shift, scale, gate) triple is consumed in ``forward``.
        self.modulation = Modulation(hidden_size, double=True)
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.temporal_block = ParallelMLPAttentionV2(
            hidden_size=hidden_size,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            attention_mode=attention_mode,
        )

    def forward(
        self, x: Tensor, y: Tensor, pe_spatial: Tensor, pe_temporal: Tensor
    ) -> Tensor:
        """Apply the temporal sub-block.

        :param x: activations laid out as ``B T L D``.
        :param y: conditioning vector fed to the modulation projection.
        :param pe_spatial: unused here; accepted so all layer variants share one call signature.
        :param pe_temporal: rotary table passed to the temporal attention.
        :return: tensor with the same ``B T L D`` layout as ``x``.
        """
        _, _, L, _ = x.size()

        _, mod2 = self.modulation(y)

        # Fold space into the batch so attention runs over T.
        residual = x
        x = modulate(self.pre_norm(x), mod2.shift, mod2.scale)
        x = rearrange(x, "B T L D -> (B L) T D", L=L)
        x = self.temporal_block(x=x, pe=pe_temporal)
        x = rearrange(x, "(B L) T D -> B T L D", L=L)
        x = residual + mod2.gate.unsqueeze(1) * x

        return x


class LatentSIV3SpatialLayer(nn.Module):
    """Spatial-only layer: modulated pre-norm, attention over L, gated residual."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: int,
        attention_mode: str = "scaled_dot_product",
    ):
        super().__init__()
        # Kept as double=True so the parameter shapes stay unchanged; only the
        # first (shift, scale, gate) triple is consumed in ``forward``.
        self.modulation = Modulation(hidden_size, double=True)
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.spatial_block = ParallelMLPAttentionV2(
            hidden_size=hidden_size,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            attention_mode=attention_mode,
        )

    def forward(
        self, x: Tensor, y: Tensor, pe_spatial: Tensor, pe_temporal: Tensor
    ) -> Tensor:
        """Apply the spatial sub-block.

        :param x: activations laid out as ``B T L D``.
        :param y: conditioning vector fed to the modulation projection.
        :param pe_spatial: rotary table passed to the spatial attention.
        :param pe_temporal: unused here; accepted so all layer variants share one call signature.
        :return: tensor with the same ``B T L D`` layout as ``x``.
        """
        _, T, L, _ = x.size()

        mod1, _ = self.modulation(y)

        # Fold time into the batch so attention runs over L.
        residual = x
        x = modulate(self.pre_norm(x), mod1.shift, mod1.scale)
        x = rearrange(x, "B T L D -> (B T) L D", L=L)
        x = self.spatial_block(x=x, pe=pe_spatial)
        x = rearrange(x, "(B T) L D -> B T L D", T=T)
        x = residual + mod1.gate.unsqueeze(1) * x

        return x


class LatentSIV3(nn.Module):
    """Stack of factorized spatial/temporal layers conditioned on a timestep.

    The input ``x`` and the conditioning ``x_cond`` are embedded and summed,
    passed through ``depth`` ``LatentSIV3Layer`` blocks, and projected back
    to ``in_dim`` through a modulated final norm.

    :param depth: number of ``LatentSIV3Layer`` blocks.
    :param in_dim: feature size of ``x`` / ``x_cond``; also the output size.
    :param hidden_size: model width; must be divisible by ``num_heads``.
    :param num_heads: attention heads per block.
    :param vec_in_dim: feature size of the optional ``y`` conditioning; when
        ``None`` no ``vec_in`` embedder is created.
    :param mlp_ratio: MLP hidden width as a multiple of ``hidden_size``.
    :param n_timesteps: stored on the instance; not read elsewhere in this class.
    :param theta: rotary embedding base.
    :param checkpointing: stored on the instance; not read elsewhere in this class.
    :param normalize: apply a layer norm to the summed input embeddings.
    :param attention_mode: forwarded to every block.
    :param share_weights: reuse one block instance for all ``depth`` positions.
    :param reset_parameters: run ``reset_parameters()`` at the end of construction.
    :raises ValueError: if ``hidden_size`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        depth: int,
        in_dim: int,
        hidden_size: int,
        num_heads: int,
        vec_in_dim=None,
        mlp_ratio: int = 2,
        n_timesteps: int = 10,
        theta: int = 10_000,
        checkpointing: bool = False,
        normalize: bool = False,
        attention_mode: str = "scaled_dot_product",
        share_weights: bool = False,
        reset_parameters: bool = True,
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = self.in_dim
        self.n_timesteps = n_timesteps
        self.checkpointing = checkpointing
        self.normalize = normalize
        self.attention_mode = attention_mode

        if hidden_size % num_heads != 0:
            raise ValueError(
                f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
            )
        pe_dim = hidden_size // num_heads

        self.x_in = nn.Linear(in_dim, hidden_size)
        self.cond_to_emb = nn.Linear(in_dim, hidden_size)
        # Not used by ``forward``; kept so the parameter set stays unchanged.
        self.mask_to_emb = nn.Embedding(2, hidden_size)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=hidden_size)
        if vec_in_dim is not None:
            self.vec_in = MLPEmbedder(in_dim=vec_in_dim, hidden_dim=hidden_size)
        self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=[pe_dim])
        # Parameter-free norm applied before the output projection (built once).
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        if share_weights:
            # The same module object is registered at every depth position.
            shared = LatentSIV3Layer(hidden_size, num_heads, mlp_ratio, attention_mode)
            layers = [shared] * depth
        else:
            layers = [
                LatentSIV3Layer(hidden_size, num_heads, mlp_ratio, attention_mode)
                for _ in range(depth)
            ]
        self.blocks = nn.ModuleList(layers)

        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )
        self.linear = nn.Linear(hidden_size, self.out_dim)

        if reset_parameters:
            self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise all weights.

        Linear layers get Xavier-uniform weights (gain ``1/sqrt(2)``) and zero
        biases; the timestep/vector embedders get ``N(0, 0.02)`` weights; each
        block's modulation projection and the output projection are zeroed.
        """

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight, gain=1.0 / math.sqrt(2))
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        nn.init.normal_(self.time_in.in_layer.weight, std=0.02)
        nn.init.normal_(self.time_in.out_layer.weight, std=0.02)

        if hasattr(self, "vec_in"):
            nn.init.normal_(self.vec_in.in_layer.weight, std=0.02)
            nn.init.normal_(self.vec_in.out_layer.weight, std=0.02)

        for block in self.blocks:
            for inner in (block.spatial_block, block.temporal_block):
                nn.init.xavier_uniform_(inner.linear1.weight, gain=1 / math.sqrt(2))
                nn.init.xavier_uniform_(inner.linear2.weight, gain=1 / math.sqrt(2))
                nn.init.constant_(inner.linear2.bias, 0.0)

            # Zero modulation => shift = scale = gate = 0 at initialisation.
            nn.init.constant_(block.modulation.lin.weight, 0.0)
            nn.init.constant_(block.modulation.lin.bias, 0.0)

        # std=0 draws exactly the mean (0): the output projection starts at zero.
        nn.init.normal_(self.linear.weight, std=0.00)
        nn.init.normal_(self.linear.bias, std=0.00)

    def temporal_rope_embedding(self, B: int, T: int, L: int, device: torch.device) -> Tensor:
        """Rotary table for positions ``0..T-1``, expanded to a batch of ``B * L``."""
        return self.pe_embedder(torch.arange(T, device=device)[None, :, None]).expand(
            B * L, -1, -1, -1, -1, -1
        )

    def spatial_rope_embedding(self, B: int, T: int, L: int, device: torch.device) -> Tensor:
        """Rotary table for positions ``0..L-1``, expanded to a batch of ``B * T``."""
        return self.pe_embedder(torch.arange(L, device=device)[None, :, None]).expand(
            B * T, -1, -1, -1, -1, -1
        )

    def forward(
        self, x: Tensor, t: Tensor, x_cond: Tensor, y: Tensor | None = None
    ) -> Tensor:
        """Run the model.

        :param x: input laid out as ``B T L in_dim``.
        :param t: 1-D tensor of timesteps, one per batch element.
        :param x_cond: conditioning input embedded by ``cond_to_emb`` and added to ``x``.
        :param y: optional extra conditioning vector; requires the model to
            have been built with ``vec_in_dim``.
        :return: tensor laid out as ``B T L in_dim``.
        """
        B, T, L, _ = x.size()
        x = self.x_in(x) + self.cond_to_emb(x_cond)
        if self.normalize:
            x = nn.functional.layer_norm(x, (x.size(-1),))

        vec = self.time_in(timestep_embedding(t, 256))
        if y is not None:
            vec = vec + self.vec_in(y)

        pe_spatial = self.spatial_rope_embedding(B, T, L, x.device)
        pe_temporal = self.temporal_rope_embedding(B, T, L, x.device)
        for block in self.blocks:
            x = block(x=x, y=vec, pe_spatial=pe_spatial, pe_temporal=pe_temporal)

        shift, scale = self.adaLN_modulation(vec)[:, None, :].chunk(2, dim=-1)
        x = modulate(self.pre_norm(x), shift, scale)
        x = self.linear(x)
        return x


class TransformerLayer(nn.Module):
    """Single-axis variant of the model: one spatial-only or temporal-only block.

    The input ``x`` and the conditioning ``x_cond`` are embedded and summed,
    passed through one ``LatentSIV3SpatialLayer`` or
    ``LatentSIV3TemporalLayer``, and projected back to ``in_dim`` through a
    modulated final norm.

    :param axis: ``"spatial"`` or ``"temporal"``; selects the block type.
    :param in_dim: feature size of ``x`` / ``x_cond``; also the output size.
    :param hidden_size: model width; must be divisible by ``num_heads``.
    :param num_heads: attention heads in the block.
    :param vec_in_dim: feature size of the optional ``y`` conditioning; when
        ``None`` no ``vec_in`` embedder is created.
    :param mlp_ratio: MLP hidden width as a multiple of ``hidden_size``.
    :param n_timesteps: stored on the instance; not read elsewhere in this class.
    :param theta: rotary embedding base.
    :param checkpointing: stored on the instance; not read elsewhere in this class.
    :param normalize: apply a layer norm to the summed input embeddings.
    :param attention_mode: forwarded to the block.
    :param share_weights: accepted for signature parity; not read in this class.
    :param reset_parameters: run ``reset_parameters()`` at the end of construction.
    :raises ValueError: if ``hidden_size`` is not divisible by ``num_heads``
        or ``axis`` is not one of the two supported values.
    """

    def __init__(
        self,
        axis: str,
        in_dim: int,
        hidden_size: int,
        num_heads: int,
        vec_in_dim=None,
        mlp_ratio: int = 2,
        n_timesteps: int = 10,
        theta: int = 10_000,
        checkpointing: bool = False,
        normalize: bool = False,
        attention_mode: str = "scaled_dot_product",
        share_weights: bool = False,
        reset_parameters: bool = True,
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = self.in_dim
        self.n_timesteps = n_timesteps
        self.checkpointing = checkpointing
        self.normalize = normalize
        self.attention_mode = attention_mode

        if hidden_size % num_heads != 0:
            raise ValueError(
                f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
            )
        pe_dim = hidden_size // num_heads

        self.x_in = nn.Linear(in_dim, hidden_size)
        self.cond_to_emb = nn.Linear(in_dim, hidden_size)
        # Not used by ``forward``; kept so the parameter set stays unchanged.
        self.mask_to_emb = nn.Embedding(1, hidden_size)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=hidden_size)
        if vec_in_dim is not None:
            self.vec_in = MLPEmbedder(in_dim=vec_in_dim, hidden_dim=hidden_size)
        self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=[pe_dim])
        # Parameter-free norm applied before the output projection (built once).
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.blocks = nn.ModuleList()
        if axis == "spatial":
            self.blocks.append(
                LatentSIV3SpatialLayer(hidden_size, num_heads, mlp_ratio, attention_mode)
            )
        elif axis == "temporal":
            self.blocks.append(
                LatentSIV3TemporalLayer(hidden_size, num_heads, mlp_ratio, attention_mode)
            )
        else:
            raise ValueError(f"Unknown axis type: {axis}")

        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )
        self.linear = nn.Linear(hidden_size, self.out_dim)

        if reset_parameters:
            self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise all weights.

        Linear layers get Xavier-uniform weights (gain ``1/sqrt(2)``) and zero
        biases; the timestep/vector embedders get ``N(0, 0.02)`` weights; the
        block's modulation projection and the output projection are zeroed.
        """

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight, gain=1.0 / math.sqrt(2))
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        nn.init.normal_(self.time_in.in_layer.weight, std=0.02)
        nn.init.normal_(self.time_in.out_layer.weight, std=0.02)

        if hasattr(self, "vec_in"):
            nn.init.normal_(self.vec_in.in_layer.weight, std=0.02)
            nn.init.normal_(self.vec_in.out_layer.weight, std=0.02)

        for block in self.blocks:
            # Pick whichever attention/MLP sub-module this block variant owns.
            if isinstance(block, LatentSIV3SpatialLayer):
                inner = block.spatial_block
            elif isinstance(block, LatentSIV3TemporalLayer):
                inner = block.temporal_block
            else:
                inner = None

            if inner is not None:
                nn.init.xavier_uniform_(inner.linear1.weight, gain=1 / math.sqrt(2))
                nn.init.xavier_uniform_(inner.linear2.weight, gain=1 / math.sqrt(2))
                nn.init.constant_(inner.linear2.bias, 0.0)

            # Zero modulation => shift = scale = gate = 0 at initialisation.
            nn.init.constant_(block.modulation.lin.weight, 0.0)
            nn.init.constant_(block.modulation.lin.bias, 0.0)

        # std=0 draws exactly the mean (0): the output projection starts at zero.
        nn.init.normal_(self.linear.weight, std=0.00)
        nn.init.normal_(self.linear.bias, std=0.00)

    def temporal_rope_embedding(self, B: int, T: int, L: int, device: torch.device) -> Tensor:
        """Rotary table for positions ``0..T-1``, expanded to a batch of ``B * L``."""
        return self.pe_embedder(torch.arange(T, device=device)[None, :, None]).expand(
            B * L, -1, -1, -1, -1, -1
        )

    def spatial_rope_embedding(self, B: int, T: int, L: int, device: torch.device) -> Tensor:
        """Rotary table for positions ``0..L-1``, expanded to a batch of ``B * T``."""
        return self.pe_embedder(torch.arange(L, device=device)[None, :, None]).expand(
            B * T, -1, -1, -1, -1, -1
        )

    def forward(
        self, x: Tensor, t: Tensor, x_cond: Tensor, y: Tensor | None = None
    ) -> Tensor:
        """Run the layer.

        :param x: input laid out as ``B T L in_dim``.
        :param t: 1-D tensor of timesteps, one per batch element.
        :param x_cond: conditioning input embedded by ``cond_to_emb`` and added to ``x``.
        :param y: optional extra conditioning vector; requires the module to
            have been built with ``vec_in_dim``.
        :return: tensor laid out as ``B T L in_dim``.
        """
        B, T, L, _ = x.size()
        x = self.x_in(x) + self.cond_to_emb(x_cond)
        if self.normalize:
            x = nn.functional.layer_norm(x, (x.size(-1),))

        vec = self.time_in(timestep_embedding(t, 256))
        if y is not None:
            vec = vec + self.vec_in(y)

        pe_spatial = self.spatial_rope_embedding(B, T, L, x.device)
        pe_temporal = self.temporal_rope_embedding(B, T, L, x.device)
        for block in self.blocks:
            x = block(x=x, y=vec, pe_spatial=pe_spatial, pe_temporal=pe_temporal)

        shift, scale = self.adaLN_modulation(vec)[:, None, :].chunk(2, dim=-1)
        x = modulate(self.pre_norm(x), shift, scale)
        x = self.linear(x)
        return x
