from pydoc import locate
import math

import torch
from torch import nn
from jaxtyping import Float
from einops import rearrange, repeat

from ..modular.layers import apply_wd


class AddTokensProj(nn.Module):
    """Project the main tokens and append a second set of tokens after them.

    Only ``x`` passes through the bias-free linear layer. ``x_extra`` is
    concatenated untouched, so its last dimension has to equal
    ``out_features`` for the concatenation to succeed.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        linear = nn.Linear(in_features, out_features, bias=False)
        # apply_wd comes from ..modular.layers; presumably it marks the layer
        # for weight decay -- confirm against that module.
        self.proj = apply_wd(linear)

    def forward(
        self,
        x: Float[torch.Tensor, "b ... d"],
        pos: Float[torch.Tensor, "b ... n"],
        x_extra: Float[torch.Tensor, "b ... d"],
        pos_extra: Float[torch.Tensor, "b ... n"],
        **kwargs,
    ):
        """Flatten both token sets to one sequence axis and concatenate them.

        Args:
            x: Main tokens; every axis between the first and the last is
                flattened into a single sequence axis before projection.
            pos: Positions belonging to ``x``, flattened the same way.
            x_extra: Extra tokens, flattened and appended after the projected
                ``x`` without being projected.
            pos_extra: Positions belonging to ``x_extra``.
            **kwargs: If a ``"check_dict"`` entry is present, its
                ``"x_extra"`` and ``"pos_extra"`` keys are set to ``True``.

        Returns:
            A ``(tokens, positions)`` tuple, each concatenated along the
            sequence axis (dim ``-2``) with the main entries first.
        """
        if "check_dict" in kwargs:
            # Record which optional inputs this module consumed.
            for consumed in ("x_extra", "pos_extra"):
                kwargs["check_dict"][consumed] = True

        projected = self.proj(rearrange(x, "b ... d -> b (...) d"))
        extra_tokens = rearrange(x_extra, "b ... d -> b (...) d")
        tokens = torch.cat([projected, extra_tokens], dim=-2)

        main_pos = rearrange(pos, "b ... n -> b (...) n")
        extra_pos = rearrange(pos_extra, "b ... n -> b (...) n")
        positions = torch.cat([main_pos, extra_pos], dim=-2)

        return tokens, positions


class Merge3DAddTokensProj(nn.Module):
    """Merge (t, h, w) patches into tokens, project them, and append extra tokens.

    Every patch of ``x`` is folded into the feature axis and sent through a
    bias-free linear layer, while the positions inside a patch are averaged.

    NOTE(review): the shape annotations on ``x`` and ``pos`` read
    ``"b c t h w"``, but the rearrange patterns treat the *last* axis as the
    feature/coordinate axis (channels-last) -- confirm the intended layout.
    """

    def __init__(self, in_features, out_features, patch_size=(1, 2, 2)):
        super().__init__()
        self.t, self.h, self.w = patch_size[0], patch_size[1], patch_size[2]
        # One merged token carries the features of t * h * w original cells.
        merged_features = in_features * self.t * self.h * self.w
        # apply_wd comes from ..modular.layers; presumably it marks the layer
        # for weight decay -- confirm against that module.
        self.proj = apply_wd(nn.Linear(merged_features, out_features, bias=False))

    def forward(
        self,
        x: Float[torch.Tensor, "b c t h w"],
        pos: Float[torch.Tensor, "b c t h w"],
        x_extra: Float[torch.Tensor, "b ... d"],
        pos_extra: Float[torch.Tensor, "b ... n"],
        **kwargs,
    ):
        """Patch-merge ``x``/``pos``, project, and concatenate with the extras.

        Args:
            x: Volume whose last four axes are read as ``(T, H, W, e)``; each
                of T, H, W is split by the corresponding patch size.
            pos: Positions with the same ``(T, H, W, e)`` trailing layout.
            x_extra: Extra tokens, appended after the projected patch tokens
                without being projected.
            pos_extra: Ignored -- it is replaced by the mean of the merged
                positions over the h and w axes.
            **kwargs: If a ``"check_dict"`` entry is present, its
                ``"x_extra"`` and ``"pos_extra"`` keys are set to ``True``.

        Returns:
            A ``(tokens, positions)`` tuple, each concatenated along the
            sequence axis (dim ``-2``) with the patch entries first.
        """
        if "check_dict" in kwargs:
            for consumed in ("x_extra", "pos_extra"):
                kwargs["check_dict"][consumed] = True

        patch = {"nt": self.t, "nh": self.h, "nw": self.w}

        # Fold every cell of a patch into the feature axis.
        x = rearrange(x, "... (t nt) (h nh) (w nw) e -> ... t h w (nt nh nw e)", **patch)

        # Gather the cells of a patch on their own axis, then average them so
        # each merged token gets a single position.
        cells_per_patch = rearrange(pos, "... (t nt) (h nh) (w nw) e -> ... t h w (nt nh nw) e", **patch)
        pos = cells_per_patch.mean(dim=-2)  # (B, T, H, W, C)

        # overwrite pos_extra with mean of pos (over the merged h and w axes)
        pos_extra: Float[torch.Tensor, "b t d"] = pos.mean(dim=(-3, -2))

        projected = self.proj(rearrange(x, "b ... d -> b (...) d"))
        extra_tokens = rearrange(x_extra, "b ... d -> b (...) d")
        tokens = torch.cat([projected, extra_tokens], dim=-2)

        main_pos = rearrange(pos, "b ... n -> b (...) n")
        extra_pos = rearrange(pos_extra, "b ... n -> b (...) n")
        positions = torch.cat([main_pos, extra_pos], dim=-2)

        return tokens, positions


class JustKeepExtraTokensProj(nn.Module):
    """Keep only the trailing tokens of a sequence and project them per frame.

    The last ``num_extra`` tokens of ``x`` are kept (presumably the extra
    tokens appended by the ``*AddTokensProj`` modules -- confirm with the
    caller), grouped ``num_motion_tokens_per_frame`` at a time by
    concatenating their features, and projected with one bias-free linear
    layer.
    """

    def __init__(self, in_features, out_features, num_motion_tokens_per_frame=1):
        super().__init__()
        self.num_motion_tokens_per_frame = num_motion_tokens_per_frame
        group_in = in_features * num_motion_tokens_per_frame
        group_out = out_features * num_motion_tokens_per_frame
        # apply_wd comes from ..modular.layers; presumably it marks the layer
        # for weight decay -- confirm against that module.
        self.proj = apply_wd(nn.Linear(group_in, group_out, bias=False))

    def forward(self, x, x_extra, **kwargs):
        """Slice off the trailing tokens, group them per frame, and project.

        Args:
            x: Token sequence; the sequence axis is dim ``-2``.
            x_extra: Only its shape is used -- the product of all axes except
                the first and the last gives the number of tokens to keep.
            **kwargs: May contain ``pos``; when it is not ``None`` it is
                sliced and grouped exactly like ``x`` (but not projected).

        Returns:
            The projected tokens, or a ``(tokens, pos)`` tuple when ``pos``
            was supplied.
        """
        per_frame = self.num_motion_tokens_per_frame
        grouping = "... (t n) d -> ... t (n d)"

        num_extra = math.prod(x_extra.shape[1:-1])  # b (t n) d -> t * n
        # NOTE(review): if num_extra is 0, the slice `-0:` selects the whole
        # sequence rather than nothing.
        x = rearrange(x[..., -num_extra:, :], grouping, n=per_frame)

        pos = kwargs.get("pos", None)
        if pos is None:
            return self.proj(x)

        pos = rearrange(pos[..., -num_extra:, :], grouping, n=per_frame)
        return self.proj(x), pos


class FrameCondProj(nn.Module):
    """Stack a conditioning stream with the main stream and tag each with a flag.

    ``x_cond`` and ``x`` are stacked along a new dim 1 (conditioning first).
    A constant channel is appended to the features -- 0 for the conditioning
    stream, 1 for the main stream -- and the result is projected back to
    ``out_features`` with a bias-free linear layer.

    Args:
        in_features: Forwarded to the optional pre-projection only.
        out_features: Feature size of the output; without a pre-projection
            the inputs must already have this feature size.
        pre_proj_cls: Optional dotted path resolved with ``pydoc.locate``;
            the class is built as ``cls(in_features, out_features, **params)``.
        pre_proj_params: Extra keyword arguments for the pre-projection.

    Raises:
        ImportError: If ``pre_proj_cls`` is given but cannot be resolved.
    """

    def __init__(self, in_features, out_features, pre_proj_cls=None, pre_proj_params=None):
        super().__init__()
        # +1 input feature for the stream-indicator channel added in forward().
        self.proj = nn.Linear(out_features + 1, out_features, bias=False)

        pre_proj = None
        if pre_proj_cls:
            factory = locate(pre_proj_cls)
            if factory is None:
                # locate() returns None for an unresolvable path; fail with a
                # readable message instead of "'NoneType' is not callable".
                raise ImportError(f"FrameCondProj: cannot locate pre_proj_cls {pre_proj_cls!r}")
            pre_proj = factory(in_features, out_features, **(pre_proj_params or {}))
        self.pre_proj = pre_proj

    def forward(
        self,
        x: Float[torch.Tensor, "b ... d"],  # the motion tokens [b (t n) d]
        pos: Float[torch.Tensor, "b ... n"],
        x_cond: Float[torch.Tensor, "b ... d"],  # the image latents [(b t) h w d]
        **kwargs,
    ):
        """Combine ``x`` and ``x_cond`` into one flagged, projected tensor.

        Args:
            x: Main stream tokens.
            pos: Positions for ``x``; the same positions are handed to the
                pre-projection of ``x_cond``.
            x_cond: Conditioning stream; after the optional pre-projection it
                must have the same shape as ``x`` so the two can be stacked.
            **kwargs: Forwarded to the pre-projection. If a ``"check_dict"``
                entry is present, its ``"x_cond"`` key is set to ``True``.

        Returns:
            A ``(tokens, positions)`` tuple; both carry a new dim 1 of size 2
            with the conditioning entry at index 0 and the main entry at
            index 1.
        """
        if "check_dict" in kwargs:
            kwargs["check_dict"]["x_cond"] = True

        if self.pre_proj is not None:
            # Both streams are pre-projected with the *incoming* positions.
            incoming_pos = pos
            x, pos = self.pre_proj(x, incoming_pos, **kwargs)
            x_cond, pos_extra = self.pre_proj(x_cond, incoming_pos, **kwargs)
        else:
            # Without a pre-projection the conditioning stream shares the
            # positions of x (previously pos_extra was left undefined here,
            # which raised UnboundLocalError).
            pos_extra = pos

        streams = torch.stack([x_cond, x], dim=1)
        # Indicator channel: 0 marks the conditioning stream, 1 the main one.
        flags = torch.stack(
            [
                torch.zeros_like(x_cond[..., :1]),
                torch.ones_like(x[..., :1]),
            ],
            dim=1,
        )
        tokens = self.proj(torch.cat([streams, flags], dim=-1))
        return tokens, torch.stack([pos_extra, pos], dim=1)


class RemoveFrameCondProj(nn.Module):
    """Keep only the last entry along dim 1 and optionally post-project it.

    ``FrameCondProj`` in this file stacks ``[x_cond, x]`` along dim 1, so
    index ``-1`` selects the main stream; presumably this module is its
    counterpart -- confirm with the caller.

    Args:
        in_features: Forwarded to the optional post-projection.
        out_features: Forwarded to the optional post-projection.
        post_proj_cls: Optional dotted path resolved with ``pydoc.locate``;
            the class is built as ``cls(in_features, out_features, **params)``.
        post_proj_params: Extra keyword arguments for the post-projection.

    Raises:
        ImportError: If ``post_proj_cls`` is given but cannot be resolved.
    """

    def __init__(
        self,
        in_features,
        out_features,
        post_proj_cls=None,
        post_proj_params=None,
    ):
        super().__init__()
        post_proj = None
        if post_proj_cls:
            factory = locate(post_proj_cls)
            if factory is None:
                # locate() returns None for an unresolvable path; fail with a
                # readable message instead of "'NoneType' is not callable".
                raise ImportError(f"RemoveFrameCondProj: cannot locate post_proj_cls {post_proj_cls!r}")
            post_proj = factory(in_features, out_features, **(post_proj_params or {}))
        self.post_proj = post_proj

    def forward(self, x, **kwargs):
        """Select index ``-1`` on dim 1 of ``x`` (and of ``pos`` if given).

        Args:
            x: Tensor with at least two dims; dim 1 is the stacked axis.
            **kwargs: May contain ``pos``; all kwargs are forwarded to the
                post-projection.

        Returns:
            The selected (and optionally post-projected) ``x``, or an
            ``(x, pos)`` tuple when ``pos`` was supplied and is not ``None``.
        """
        pos = kwargs.get("pos", None)
        x = x[:, -1]
        if pos is not None:
            pos = pos[:, -1]

        if self.post_proj is not None:
            if pos is None:
                x = self.post_proj(x, **kwargs)
            else:
                # NOTE(review): kwargs still holds the unsliced "pos" entry, so
                # a post_proj whose second positional parameter is named `pos`
                # would receive it twice -- confirm the intended contract.
                x = self.post_proj(x, pos, **kwargs)

        return x if pos is None else (x, pos)


class SingleToken2DProj(nn.Module):
    """Expand every token into an ``(h, w)`` grid of ``out_features`` vectors.

    A single bias-free linear layer produces ``out_features * h * w`` values
    per token, which are then unfolded into trailing ``h, w, c`` axes.
    Additional keyword arguments to the constructor are accepted and ignored.
    """

    def __init__(self, in_features, out_features, size=(16, 16), **kwargs):
        super().__init__()
        self.h, self.w = size[0], size[1]
        grid_features = out_features * self.h * self.w
        # apply_wd comes from ..modular.layers; presumably it marks the layer
        # for weight decay -- confirm against that module.
        self.proj = apply_wd(nn.Linear(in_features, grid_features, bias=False))

    def forward(self, x, pos, **kwargs):
        """Project ``x`` and unfold its last axis into ``(h, w, c)``.

        Args:
            x: Tokens whose last axis has ``in_features`` entries.
            pos: Returned unchanged.
            **kwargs: Accepted and ignored.

        Returns:
            A ``(grid, pos)`` tuple where ``grid`` has the leading axes of
            ``x`` followed by ``h, w, c``.
        """
        expanded = self.proj(x)
        grid = rearrange(expanded, "... (c h w) -> ... h w c", h=self.h, w=self.w)
        return grid, pos
