import math
import threading
from typing import Literal
from dataclasses import dataclass
from functools import reduce
from pydoc import locate
from contextlib import contextmanager

import torch.nn as nn
import torch.nn.functional as F
import torch
from einops import rearrange


# ==================================================================================================


# Per-thread switch read by `get_checkpointing`. Attributes assigned to a
# `threading.local` exist only in the thread that assigned them, so the default
# set below is visible to the importing thread only; every reader has to supply
# its own fallback for other threads.
state = threading.local()
state.checkpointing = False


@contextmanager
def checkpointing(enable=True):
    """Temporarily set the current thread's checkpointing flag.

    Args:
        enable: Value the flag holds while the ``with`` body runs.

    The previous value is restored on exit, including when the body raises.
    """
    # Read with a fallback: in a thread that never assigned the attribute, a
    # plain `state.checkpointing` raises AttributeError. Capturing the old value
    # before entering `try` also guarantees that `finally` always has it.
    old_checkpointing = getattr(state, "checkpointing", False)
    state.checkpointing = enable
    try:
        yield
    finally:
        state.checkpointing = old_checkpointing


def get_checkpointing():
    """Return the current thread's checkpointing flag (``False`` if never set)."""
    return getattr(state, "checkpointing", False)


def checkpoint(function, *args, **kwargs):
    """Call ``function``, via activation checkpointing when it is enabled.

    When ``get_checkpointing()`` is false, ``function(*args, **kwargs)`` is
    called directly. Otherwise the call is routed through
    ``torch.utils.checkpoint.checkpoint`` with ``use_reentrant`` defaulting to
    ``True`` unless the caller supplied it in ``kwargs``.
    """
    if not get_checkpointing():
        return function(*args, **kwargs)
    kwargs.setdefault("use_reentrant", True)
    return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)


def zero_init(layer):
    """Zero ``layer.weight`` (and ``layer.bias`` when it is not None) in place.

    Returns:
        The same ``layer`` object, so the call can be nested in an expression.
    """
    nn.init.zeros_(layer.weight)
    bias = layer.bias
    if bias is not None:
        nn.init.zeros_(bias)
    return layer


def tag_param(param, tag):
    """Add ``tag`` to the ``_tags`` set stored on ``param`` and return ``param``.

    The ``_tags`` attribute is created on first use.
    """
    if hasattr(param, "_tags"):
        param._tags.add(tag)
    else:
        param._tags = {tag}
    return param


def tag_module(module, tag):
    """Apply ``tag_param(..., tag)`` to every parameter yielded by ``module.parameters()``.

    Returns:
        The same ``module`` object.
    """
    for p in module.parameters():
        tag_param(p, tag)
    return module


def apply_wd(module):
    """Tag every parameter whose name ends in ``"weight"`` with ``"wd"``.

    Returns:
        The same ``module`` object.
    """
    weights = (p for name, p in module.named_parameters() if name.endswith("weight"))
    for p in weights:
        tag_param(p, "wd")
    return module


def rms_norm(x, scale, eps):
    """Scale ``x`` by ``scale / sqrt(mean(x**2) + eps)`` over the last dimension.

    The statistics are computed in a dtype at least as wide as float32; the
    resulting factor is cast back to ``x.dtype`` before the final multiply, so
    the output keeps the dtype of ``x``.
    """
    compute_dtype = torch.promote_types(torch.promote_types(x.dtype, scale.dtype), torch.float32)
    squared = torch.pow(x.to(compute_dtype), 2)
    mean_sq = squared.mean(dim=-1, keepdim=True)
    gain = scale.to(compute_dtype) * (mean_sq + eps).rsqrt()
    return x * gain.to(x.dtype)


def scale_for_cosine_sim(q, k, scale, eps):
    """Rescale ``q`` and ``k`` so their dot product is a scaled cosine similarity.

    Each tensor is multiplied by ``sqrt(scale) / sqrt(sum(t**2) + eps)`` taken
    over its last dimension. The factors are computed in a dtype at least as
    wide as float32 and cast back to each input's dtype before multiplying.

    Returns:
        The pair ``(q_scaled, k_scaled)``.
    """
    compute_dtype = q.dtype
    for other in (k.dtype, scale.dtype, torch.float32):
        compute_dtype = torch.promote_types(compute_dtype, other)

    # sqrt(scale) goes on each side so the product q.k carries `scale` once.
    root_scale = scale.to(compute_dtype).sqrt()

    def _rescale(t):
        sum_sq = torch.pow(t.to(compute_dtype), 2).sum(dim=-1, keepdim=True)
        factor = root_scale * (sum_sq + eps).rsqrt()
        return t * factor.to(t.dtype)

    return _rescale(q), _rescale(k)


def scale_for_cosine_sim_qkv(qkv, scale, eps):
    """Apply ``scale_for_cosine_sim`` to the q and k slices of a packed tensor.

    ``qkv`` is split into three tensors along dimension 2; the first two are
    rescaled (with ``scale`` given an extra axis at position 1), the third is
    left untouched, and all three are stacked back along dimension 2.
    """
    q, k, v = torch.unbind(qkv, dim=2)
    q, k = scale_for_cosine_sim(q, k, scale[:, None], eps)
    return torch.stack([q, k, v], dim=2)


def scale_for_cosine_sim_kv(q, kv, scale, eps):
    """Apply ``scale_for_cosine_sim`` to ``q`` and the k slice of a packed ``kv``.

    ``kv`` is split into two tensors along dimension 2; ``q`` and the first
    slice are rescaled (with ``scale`` given an extra axis at position 1), and
    the second slice is stacked back unchanged along dimension 2.

    Returns:
        The pair ``(q_scaled, kv_scaled)``.
    """
    k, v = torch.unbind(kv, dim=2)
    q, k = scale_for_cosine_sim(q, k, scale[:, None], eps)
    repacked = torch.stack([k, v], dim=2)
    return q, repacked


def linear_swiglu(x, weight, bias=None):
    """Linear projection followed by a SiLU-gated product of its two halves.

    ``x`` is multiplied by the transpose of ``weight`` (plus ``bias`` when
    given); the result is split in two along the last dimension and the first
    half is multiplied by ``silu`` of the second half.
    """
    projected = torch.matmul(x, weight.mT)
    if bias is not None:
        projected = projected + bias
    value, gate = torch.chunk(projected, 2, dim=-1)
    return value * F.silu(gate)


class LinearSwiGLU(nn.Linear):
    """``nn.Linear`` whose output is passed through ``linear_swiglu``.

    The stored weight (and bias) has ``2 * out_features`` output rows, while
    the ``out_features`` attribute is reset to the requested value.
    """

    def __init__(self, in_features, out_features, bias=True):
        # Allocate a projection twice as wide as requested; forward splits it in two.
        super().__init__(in_features, 2 * out_features, bias=bias)
        # Overwrite the doubled width recorded by nn.Linear with the requested one.
        self.out_features = out_features

    def forward(self, x):
        return linear_swiglu(x, weight=self.weight, bias=self.bias)


# ==================================================================================================


# Note this uses SwiGLU instead of GeGLU
class FeedForwardBlock(nn.Module):
    """Residual feed-forward block: norm -> SwiGLU up-projection -> dropout -> down-projection.

    Args:
        d_model: Width of the block's input and output.
        d_ff: Width after the SwiGLU up-projection.
        d_cond_norm: When not None, the norm is an ``AdaRMSNorm`` conditioned on
            a vector of this width; otherwise a plain ``RMSNorm`` is used.
        dropout: Probability passed to ``nn.Dropout``.

    The down-projection is zero-initialised, so a freshly built block returns
    its input unchanged.
    """

    def __init__(self, d_model, d_ff, d_cond_norm=None, dropout=0.0):
        super().__init__()
        self.norm = RMSNorm(d_model) if d_cond_norm is None else AdaRMSNorm(d_model, d_cond_norm)
        self.up_proj = apply_wd(LinearSwiGLU(d_model, d_ff, bias=False))
        self.dropout = nn.Dropout(dropout)
        down = nn.Linear(d_ff, d_model, bias=False)
        self.down_proj = apply_wd(zero_init(down))

    # @torch.compile(fullgraph=True)
    def forward(self, x, check_dict, cond_norm=None, **kwargs):
        """Run the block; extra keyword arguments are accepted and ignored.

        When ``cond_norm`` is given it is passed to the norm and
        ``check_dict["cond_norm"]`` is set to True.
        """
        if cond_norm is None:
            hidden = self.norm(x)
        else:
            hidden = self.norm(x, cond_norm)
            check_dict["cond_norm"] = True
        hidden = self.up_proj(hidden)
        hidden = self.dropout(hidden)
        hidden = self.down_proj(hidden)
        return hidden + x


# ===================================================================================================
# need to adjust the adarms norm to work in the 1d case
# the standard rms norm is fine because no conditioning is needed
class AdaRMSNorm(nn.Module):
    """RMS norm whose scale is ``1 + linear(cond)``.

    The bias-free linear layer is zero-initialised, so the scale is exactly 1
    at construction. Its parameters are tagged ``"mapping"``, and its weight is
    also tagged ``"wd"``.
    """

    def __init__(self, features, cond_features, eps=1e-6):
        super().__init__()
        self.eps = eps
        projection = nn.Linear(cond_features, features, bias=False)
        self.linear = apply_wd(zero_init(projection))
        tag_module(self.linear, "mapping")

    def extra_repr(self):
        return f"eps={self.eps},"

    def forward(self, x, cond):
        # removed one additional expansion here
        # A single axis is inserted at position 1 of the projected conditioning
        # so it broadcasts against x — assumes x has one axis between batch and
        # features; TODO confirm against callers.
        gain = self.linear(cond)[:, None, :] + 1
        return rms_norm(x, gain, self.eps)


class RMSNorm(nn.Module):
    """RMS norm with a learnable scale of the given ``shape``, initialised to ones."""

    def __init__(self, shape, eps=1e-6):
        super().__init__()
        self.eps = eps
        initial_scale = torch.ones(shape)
        self.scale = nn.Parameter(initial_scale)

    def extra_repr(self):
        scale_shape = tuple(self.scale.shape)
        return f"shape={scale_shape}, eps={self.eps}"

    def forward(self, x):
        return rms_norm(x, scale=self.scale, eps=self.eps)


class LayerNorm(nn.Module):
    """Layer norm computed in at least float32 and cast back to the input dtype.

    Args:
        shape: Normalised shape; an int is wrapped into a 1-tuple.
        eps: Value passed as ``eps`` to ``F.layer_norm``.
        elementwise_affine: When true, learnable ``weight`` (ones) and ``bias``
            (zeros) parameters of ``shape`` are created and applied.
    """

    def __init__(self, shape, eps=1e-5, elementwise_affine=True):
        super().__init__()
        self.eps = eps
        self.shape = shape if not isinstance(shape, int) else (shape,)
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.ones(shape))
            self.bias = nn.Parameter(torch.zeros(shape))

    def forward(self, x):
        if not self.elementwise_affine:
            compute_dtype = torch.promote_types(x.dtype, torch.float32)
            normed = F.layer_norm(x.to(compute_dtype), self.shape, None, None, self.eps)
            return normed.to(x.dtype)

        # Widen to the common dtype of input and parameters, never below float32.
        compute_dtype = x.dtype
        for other in (self.weight.dtype, self.bias.dtype, torch.float32):
            compute_dtype = torch.promote_types(compute_dtype, other)
        normed = F.layer_norm(
            x.to(compute_dtype),
            self.shape,
            self.weight.to(compute_dtype),
            self.bias.to(compute_dtype),
            self.eps,
        )
        return normed.to(x.dtype)


# ===================================================================================================


def Patch2D(
    type: Literal["split", "split_last", "merge"],
    in_features,
    out_features,
    patch_size=(2, 2),
):
    """Build the 2D patching module selected by ``type``.

    Args:
        type: ``"split"`` -> ``TokenSplit2D``, ``"split_last"`` ->
            ``TokenSplitLast2D``, ``"merge"`` -> ``TokenMerge2D``.
        in_features: Forwarded to the selected module's constructor.
        out_features: Forwarded to the selected module's constructor.
        patch_size: Forwarded to the selected module's constructor.

    Raises:
        ValueError: If ``type`` is not one of the three accepted strings.
    """
    if type == "split":
        return TokenSplit2D(in_features, out_features, patch_size)
    elif type == "split_last":
        return TokenSplitLast2D(in_features, out_features, patch_size)
    elif type == "merge":
        return TokenMerge2D(in_features, out_features, patch_size)
    else:
        raise ValueError(f"Unknown type: {type}")


def CustomProj(
    type: Literal["split", "split_last", "merge"],
    in_features,
    out_features,
    in_cls,
    out_cls,
    in_params=None,
    out_params=None,
    **kwargs,
):
    """Instantiate a projection class looked up by dotted path.

    Args:
        type: ``"split"`` or ``"split_last"`` selects ``out_cls``; ``"merge"``
            selects ``in_cls``.
        in_features: First positional argument for the selected class.
        out_features: Second positional argument for the selected class.
        in_cls: Dotted path resolved with ``pydoc.locate`` for ``"merge"``.
        out_cls: Dotted path resolved with ``pydoc.locate`` for the split types.
        in_params: Extra keyword arguments for ``in_cls``; None means none.
        out_params: Extra keyword arguments for ``out_cls``; None means none.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``type`` is not one of the three accepted strings.
        TypeError: If the path cannot be resolved (``locate`` returns None,
            which is then called).
    """
    # None instead of `{}` as the default avoids a dict shared across calls.
    if type in ("split", "split_last"):
        return locate(out_cls)(in_features, out_features, **(out_params or {}))
    elif type == "merge":
        return locate(in_cls)(in_features, out_features, **(in_params or {}))
    else:
        raise ValueError(f"Unknown type: {type}")


class CondTokenMerge2D(nn.Module):
    """Merge ``patch_size`` blocks of tokens, concatenate conditioning tokens, and project.

    The bias-free projection takes ``in_features * h * w + cond_features``
    inputs and produces ``out_features`` outputs.
    """

    def __init__(self, in_features, out_features, cond_features, patch_size=(2, 2)):
        super().__init__()
        self.h = patch_size[0]
        self.w = patch_size[1]
        merged_features = in_features * self.h * self.w + cond_features
        self.proj = apply_wd(nn.Linear(merged_features, out_features, bias=False))

    def forward(self, x, pos, cond_tokens, check_dict, **kwargs):
        """Return ``(projected_tokens, merged_pos)``; extra keyword arguments are ignored.

        ``check_dict["cond_tokens"]`` is set to True. ``merged_pos`` is the mean
        of ``pos`` over each merged patch.
        """
        check_dict["cond_tokens"] = True
        patch = {"nh": self.h, "nw": self.w}
        # Fold every (nh, nw) patch into the feature axis, then append the conditioning.
        tokens = rearrange(x, "... (h nh) (w nw) e -> ... h w (nh nw e)", **patch)
        tokens = torch.cat([tokens, cond_tokens], dim=-1)

        grouped_pos = rearrange(pos, "... (h nh) (w nw) e -> ... h w (nh nw) e", **patch)
        return self.proj(tokens), grouped_pos.mean(dim=-2)


class SimpleProj(nn.Module):
    """Bias-free linear projection whose weight is tagged ``"wd"``."""

    def __init__(self, in_features, out_features):
        super().__init__()
        linear = nn.Linear(in_features, out_features, bias=False)
        self.proj = apply_wd(linear)

    def forward(self, x, **kwargs):
        """Project ``x``; extra keyword arguments are accepted and ignored."""
        projected = self.proj(x)
        return projected


class TokenMerge2D(nn.Module):
    """Merge ``patch_size`` blocks of tokens into one token and project.

    The bias-free projection maps ``in_features * h * w`` inputs to
    ``out_features`` outputs.
    """

    def __init__(self, in_features, out_features, patch_size=(2, 2)):
        super().__init__()
        self.h = patch_size[0]
        self.w = patch_size[1]
        merged_features = in_features * self.h * self.w
        self.proj = apply_wd(nn.Linear(merged_features, out_features, bias=False))

    def forward(self, x, pos, **kwargs):
        """Return ``(projected_tokens, merged_pos)``; extra keyword arguments are ignored.

        ``merged_pos`` is the mean of ``pos`` over each merged patch.
        """
        patch = {"nh": self.h, "nw": self.w}
        tokens = rearrange(x, "... (h nh) (w nw) e -> ... h w (nh nw e)", **patch)
        grouped_pos = rearrange(pos, "... (h nh) (w nw) e -> ... h w (nh nw) e", **patch)
        return self.proj(tokens), grouped_pos.mean(dim=-2)


class TokenSplit2D(nn.Module):
    """Project each token and split it into a ``patch_size`` block of tokens.

    When a ``skip`` tensor is supplied, the result is linearly interpolated
    with it using the learnable factor ``fac`` (initialised to 0.5).
    """

    def __init__(self, in_features, out_features, patch_size=(2, 2)):
        super().__init__()
        self.h = patch_size[0]
        self.w = patch_size[1]
        split_features = out_features * self.h * self.w
        self.proj = apply_wd(nn.Linear(in_features, split_features, bias=False))
        self.fac = nn.Parameter(0.5 * torch.ones(1))

    def forward(self, x, skip=None, **kwargs):
        """Return the split tokens, blended with ``skip`` when it is given.

        Extra keyword arguments are accepted and ignored.
        """
        out = rearrange(
            self.proj(x),
            "... h w (nh nw e) -> ... (h nh) (w nw) e",
            nh=self.h,
            nw=self.w,
        )
        if skip is not None:
            # fac == 0 yields `skip`, fac == 1 yields the freshly split tokens.
            out = torch.lerp(skip, out, self.fac.to(out.dtype))
        return out


# Unlike TokenSplit2D, this has no skip connection but applies a norm before projecting
class TokenSplitLast2D(nn.Module):
    """RMS-normalise, project, and split each token into a ``patch_size`` block.

    Args:
        in_features: Width of the incoming tokens.
        out_features: Width of each produced token.
        patch_size: ``(h, w)`` number of tokens produced per input token.
        zero_init: When true, the projection weight is zero-initialised.
    """

    def __init__(self, in_features, out_features, patch_size=(2, 2), zero_init=True):
        super().__init__()
        self.h = patch_size[0]
        self.w = patch_size[1]
        self.norm = RMSNorm(in_features)
        split_features = out_features * self.h * self.w
        self.proj = apply_wd(nn.Linear(in_features, split_features, bias=False))
        # The `zero_init` flag shadows the module-level helper of the same
        # name here, so the weight is zeroed through nn.init directly.
        if zero_init:
            nn.init.zeros_(self.proj.weight)

    def forward(self, x, **kwargs):
        """Return the split tokens; extra keyword arguments are accepted and ignored."""
        projected = self.proj(self.norm(x))
        return rearrange(
            projected,
            "... h w (nh nw e) -> ... (h nh) (w nw) e",
            nh=self.h,
            nw=self.w,
        )


# ===================================================================================================


def Patch3D(
    type: Literal["split", "split_last", "merge"],
    in_features,
    out_features,
    patch_size=(1, 2, 2),
):
    """Build the 3D patching module selected by ``type``.

    Args:
        type: ``"split"`` -> ``TokenSplit3D``, ``"split_last"`` ->
            ``TokenSplitLast3D``, ``"merge"`` -> ``TokenMerge3D``.
        in_features: Forwarded to the selected module's constructor.
        out_features: Forwarded to the selected module's constructor.
        patch_size: Forwarded to the selected module's constructor.

    Raises:
        ValueError: If ``type`` is not one of the three accepted strings.
    """
    if type == "split":
        return TokenSplit3D(in_features, out_features, patch_size)
    elif type == "split_last":
        return TokenSplitLast3D(in_features, out_features, patch_size)
    elif type == "merge":
        return TokenMerge3D(in_features, out_features, patch_size)
    else:
        raise ValueError(f"Unknown type: {type}")


class TokenMerge3D(nn.Module):
    """Merge ``patch_size`` (t, h, w) blocks of tokens into one token and project.

    The bias-free projection maps ``in_features * t * h * w`` inputs to
    ``out_features`` outputs.
    """

    def __init__(self, in_features, out_features, patch_size=(1, 2, 2)):
        super().__init__()
        self.t = patch_size[0]
        self.h = patch_size[1]
        self.w = patch_size[2]
        merged_features = in_features * self.t * self.h * self.w
        self.proj = apply_wd(nn.Linear(merged_features, out_features, bias=False))

    def forward(self, x, pos, **kwargs):
        """Return ``(projected_tokens, merged_pos)``; extra keyword arguments are ignored.

        ``merged_pos`` is the mean of ``pos`` over each merged patch.
        """
        patch = {"nt": self.t, "nh": self.h, "nw": self.w}
        tokens = rearrange(x, "... (t nt) (h nh) (w nw) e -> ... t h w (nt nh nw e)", **patch)
        grouped_pos = rearrange(pos, "... (t nt) (h nh) (w nw) e -> ... t h w (nt nh nw) e", **patch)

        return self.proj(tokens), grouped_pos.mean(dim=-2)


class TokenSplit3D(nn.Module):
    """Project each token and split it into a ``patch_size`` (t, h, w) block of tokens.

    When a ``skip`` tensor is supplied, the result is linearly interpolated
    with it using the learnable factor ``fac`` (initialised to 0.5).
    """

    def __init__(self, in_features, out_features, patch_size=(1, 2, 2)):
        super().__init__()
        self.t = patch_size[0]
        self.h = patch_size[1]
        self.w = patch_size[2]
        split_features = out_features * self.t * self.h * self.w
        self.proj = apply_wd(nn.Linear(in_features, split_features, bias=False))
        self.fac = nn.Parameter(0.5 * torch.ones(1))

    def forward(self, x, skip=None, **kwargs):
        """Return the split tokens, blended with ``skip`` when it is given.

        Extra keyword arguments are accepted and ignored.
        """
        patch = {"nt": self.t, "nh": self.h, "nw": self.w}
        out = rearrange(self.proj(x), "... t h w (nt nh nw e) -> ... (t nt) (h nh) (w nw) e", **patch)
        if skip is not None:
            # fac == 0 yields `skip`, fac == 1 yields the freshly split tokens.
            out = torch.lerp(skip, out, self.fac.to(out.dtype))
        return out


class TokenSplitLast3D(nn.Module):
    """RMS-normalise, project, and split each token into a ``patch_size`` (t, h, w) block."""

    def __init__(self, in_features, out_features, patch_size=(1, 2, 2)):
        super().__init__()
        self.t = patch_size[0]
        self.h = patch_size[1]
        self.w = patch_size[2]
        split_features = out_features * self.t * self.h * self.w
        # proj is registered before norm; the order is kept so parameter
        # ordering does not change.
        self.proj = apply_wd(nn.Linear(in_features, split_features, bias=False))
        self.norm = RMSNorm(in_features)

    def forward(self, x, **kwargs):
        """Return the split tokens; extra keyword arguments are accepted and ignored."""
        patch = {"nt": self.t, "nh": self.h, "nw": self.w}
        projected = self.proj(self.norm(x))
        return rearrange(projected, "... t h w (nt nh nw e) -> ... (t nt) (h nh) (w nw) e", **patch)


# ===================================================================================================


@dataclass
class MappingSpec:
    """Plain record of four mapping hyperparameters.

    NOTE(review): nothing in the visible part of this file reads these fields;
    the names suggest they configure ``MappingNetwork`` — confirm against the
    caller.
    """

    # presumably the number of blocks — TODO confirm
    depth: int
    # presumably the model width — TODO confirm
    width: int
    # presumably the feed-forward hidden width — TODO confirm
    d_ff: int
    # presumably a dropout probability — TODO confirm
    dropout: float


class MappingFeedForwardBlock(nn.Module):
    """Residual feed-forward block: RMS norm -> SwiGLU up-projection -> dropout -> down-projection.

    The down-projection is zero-initialised, so a freshly built block returns
    its input unchanged.
    """

    def __init__(self, d_model, d_ff, dropout=0.0):
        super().__init__()
        self.norm = RMSNorm(d_model)
        self.up_proj = apply_wd(LinearSwiGLU(d_model, d_ff, bias=False))
        self.dropout = nn.Dropout(dropout)
        down = nn.Linear(d_ff, d_model, bias=False)
        self.down_proj = apply_wd(zero_init(down))

    def forward(self, x):
        hidden = self.up_proj(self.norm(x))
        hidden = self.down_proj(self.dropout(hidden))
        return hidden + x


class MappingNetwork(nn.Module):
    """Stack of ``n_layers`` ``MappingFeedForwardBlock`` modules between two RMS norms."""

    def __init__(self, n_layers, d_model, d_ff, dropout=0.0):
        super().__init__()
        self.in_norm = RMSNorm(d_model)
        self.blocks = nn.ModuleList(
            MappingFeedForwardBlock(d_model, d_ff, dropout=dropout) for _ in range(n_layers)
        )
        self.out_norm = RMSNorm(d_model)

    def forward(self, x):
        hidden = self.in_norm(x)
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.out_norm(hidden)


# ===================================================================================================
class FourierFeatures(nn.Module):
    """Random Fourier feature embedding with a fixed (non-trainable) projection.

    A buffer ``weight`` of shape ``(out_features // 2, in_features)`` is drawn
    from a normal distribution scaled by ``std``. ``out_features`` must be even.
    """

    def __init__(self, in_features, out_features, std=1.0):
        super().__init__()
        assert out_features % 2 == 0
        frequencies = torch.randn([out_features // 2, in_features]) * std
        self.register_buffer("weight", frequencies)

    def forward(self, input):
        """Return ``cos`` and ``sin`` of ``2 * pi * input @ weight.T``, concatenated on the last axis."""
        scaled = (2 * math.pi) * input
        phase = scaled @ self.weight.T
        return torch.cat((torch.cos(phase), torch.sin(phase)), dim=-1)
