# CViT (Continuous Vision Transformer) - PyTorch Version (1:1 with Flax)
from typing import Any, Dict, Tuple, cast, Literal

import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat
from functorch import vmap
from torch import Tensor


def resize_abs_pos_embed(
    pos_embed: torch.Tensor,
    new_size,
    old_size,
    interpolation: str = "bicubic",
    antialias: bool = True,
) -> torch.Tensor:
    """
    Resample a channels-last absolute position embedding grid.

    Args:
        pos_embed: Embedding laid out as [1, H, W, C].
        new_size: Target (H', W').
        old_size: Current (H, W); when equal to ``new_size`` the input is
            returned unchanged (same object).
        interpolation: Mode forwarded to ``F.interpolate``.
        antialias: Forwarded to ``F.interpolate``.

    Returns:
        Tensor laid out as [1, H', W', C].
    """
    if new_size == old_size:
        return pos_embed

    # F.interpolate wants channels in dim 1: move C from the last axis to dim 1.
    channels_first = torch.movedim(pos_embed, -1, 1)  # [1, C, H, W]
    resized = F.interpolate(
        channels_first, size=new_size, mode=interpolation, antialias=antialias
    )
    return torch.movedim(resized, 1, -1)  # [1, H', W', C]


def choose_kernel_size_random(kernel_scales_seq, probabilities=None, seed=None):
    """Sample one entry of ``kernel_scales_seq``.

    Args:
        kernel_scales_seq: Non-empty sequence of candidate kernel sizes.
        probabilities: Optional per-entry sampling weights, one per element of
            ``kernel_scales_seq``; uniform when ``None``. They are passed to
            ``torch.multinomial``, so they only need to be non-negative.
        seed: When given, a private generator seeded with it is used, so the
            draw is reproducible and does not disturb the global torch RNG.
            When ``None``, the global torch RNG is used (controlled by
            ``torch.manual_seed``).

    Returns:
        The selected element of ``kernel_scales_seq``.

    Raises:
        ValueError: if ``kernel_scales_seq`` is empty or ``probabilities`` has
            the wrong length.
    """
    n = len(kernel_scales_seq)
    if n == 0:
        raise ValueError("kernel_scales_seq must not be empty")

    if probabilities is None:
        weights = torch.full((n,), 1.0 / n, dtype=torch.float32)  # uniform
    else:
        weights = torch.as_tensor(probabilities, dtype=torch.float32)
        if weights.numel() != n:
            raise ValueError(
                f"probabilities has {weights.numel()} entries, expected {n}"
            )

    # A fresh, unseeded torch.Generator always starts from the same default
    # seed, which would make every "random" draw identical. Only build a
    # private generator when an explicit seed is requested.
    generator = None
    if seed is not None:
        generator = torch.Generator()
        generator.manual_seed(seed)

    index = torch.multinomial(weights, 1, replacement=True, generator=generator).item()
    return kernel_scales_seq[index]

def choose_kernel_size_alternating(kernel_scales_seq, probabilities=None, seed=None):
    """Deterministically pick a kernel size by cycling through the sequence.

    Args:
        kernel_scales_seq: Non-empty sequence of candidate kernel sizes.
        probabilities: Unused; accepted so the signature matches
            ``choose_kernel_size_random``.
        seed: Cycle position. ``None`` (the default) or a non-positive value
            selects the first entry; a positive value selects
            ``kernel_scales_seq[seed % len(kernel_scales_seq)]``.

    Returns:
        The selected element of ``kernel_scales_seq``.
    """
    # ``seed`` defaults to None, which cannot be compared with ``>``; treat it
    # like seed == 0 instead of raising TypeError.
    if seed is None or seed <= 0:
        return kernel_scales_seq[0]
    return kernel_scales_seq[seed % len(kernel_scales_seq)]

# === Positional Embeddings ===
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Sin-cos embedding for an arbitrary set of (flattened) positions.

    Args:
        embed_dim: Output width; must be even. The first half holds sines and
            the second half cosines.
        pos: Tensor of positions; it is flattened to 1D.

    Returns:
        Tensor of shape (pos.numel(), embed_dim).
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    # Frequencies 1 / 10000^(k / half_dim) for k = 0 .. half_dim - 1.
    exponent = torch.arange(half_dim, dtype=torch.float32) / (embed_dim / 2.0)
    inv_freq = 1.0 / (10000**exponent)

    phase = torch.einsum("m,d->md", pos.reshape(-1), inv_freq)
    return torch.cat([phase.sin(), phase.cos()], dim=1)

def get_1d_sincos_pos_embed(embed_dim, length):
    """Sin-cos embedding for integer positions 0 .. length - 1.

    Returns:
        Tensor of shape (1, length, embed_dim).
    """
    positions = torch.arange(length, dtype=torch.float32)
    table = get_1d_sincos_pos_embed_from_grid(embed_dim, positions)
    return table[None]  # add the leading batch axis

def get_2d_sincos_pos_embed(embed_dim, grid_size):
    """Build a fixed 2D sin-cos positional embedding table.

    Args:
        embed_dim: Total embedding width. It is split in half between the two
            grid axes, and each half must itself be even (the 1D helper asserts
            evenness), so ``embed_dim`` must be divisible by 4.
        grid_size: ``(H, W)``, the number of positions along each axis.

    Returns:
        Tensor of shape ``(1, H * W, embed_dim)``.

    NOTE(review): the meshgrid below is built from ``(grid_w, grid_h)`` with
    ``indexing="ij"``, so both grids have shape ``(W, H)``. The flattened rows
    are therefore ordered w-major: row ``k`` corresponds to ``w = k // H`` and
    ``h = k % H``. ``PatchEmbed`` flattens patches h-major (``h_p * w_p``).
    Confirm this ordering is the intended match with the Flax reference.
    """
    def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
        """Concatenate per-axis 1D embeddings: grid[0] -> first half, grid[1] -> second half."""
        assert embed_dim % 2 == 0
        emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
        emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
        return torch.cat([emb_h, emb_w], dim=1)  # (H*W, D)

    grid_h = torch.arange(grid_size[0], dtype=torch.float32)
    grid_w = torch.arange(grid_size[1], dtype=torch.float32)
    # Flax: W then H. Both outputs are (W, H); the first varies along W, the second along H.
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")  # Flax: W then H
    # Row 0 holds h coordinates and row 1 holds w coordinates, flattened w-major.
    grid = torch.stack([grid_h.reshape(-1), grid_w.reshape(-1)], dim=0)  # (2, H*W)

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    return pos_embed.unsqueeze(0)  # (1, H*W, D)

class PatchEmbed(nn.Module):
    """Split a channels-last video into non-overlapping 3D patches and embed them.

    A ``nn.Conv3d`` whose kernel equals its stride (``patch_size``) projects
    each patch to ``emb_dim`` features. LayerNorm is applied afterwards when
    ``use_norm`` is set; otherwise an identity is used.

    Input:  (B, T, H, W, C) with C == ``in_chans``.
    Output: (B, T', H' * W', emb_dim), where T', H', W' are the per-axis patch counts.
    """

    def __init__(self, patch_size=(1, 16, 16), emb_dim=768, use_norm=False, layer_norm_eps=1e-5, in_chans=4):
        super().__init__()
        self.patch_size = patch_size
        self.emb_dim = emb_dim
        self.use_norm = use_norm

        self.proj = nn.Conv3d(
            in_channels=in_chans,
            out_channels=emb_dim,
            kernel_size=patch_size,
            stride=patch_size,  # stride == kernel -> non-overlapping patches
        )
        self.norm = nn.LayerNorm(emb_dim, eps=layer_norm_eps) if use_norm else nn.Identity()

    def forward(self, x, **kwargs):
        # Unpacking also rejects anything that is not a 5D (B, T, H, W, C) tensor.
        _, _, _, _, _ = x.shape
        tokens = self.proj(torch.movedim(x, -1, 1))  # (B, D, T', H', W')
        tokens = torch.movedim(tokens, 1, -1)  # (B, T', H', W', D)
        tokens = tokens.flatten(2, 3)  # (B, T', H'*W', D)
        return self.norm(tokens)
    
# Interpolation modes accepted by ``torch.nn.functional.interpolate``; used to
# type the ``interpolation`` argument of ``PatchEmbedFlexi``.
InterpolationType = Literal[
    "nearest",
    "linear",
    "bilinear",
    "bicubic",
    "trilinear",
    "area",
    "nearest-exact",
]

class PatchEmbedFlexi(nn.Module):
    """FlexiViT-style patch embedding whose patch size can change per call.

    A single ``nn.Conv3d`` is trained with ``base_kernel`` as both kernel size
    and stride. With ``flexi=True``, ``forward`` accepts another patch size via
    the ``random_kernel`` keyword. The conv weights are resampled to that size
    by multiplying with the pseudo-inverse of the kernel-resize operator (see
    ``_calculate_pinv``). The convolution is then applied with
    ``stride == kernel``, so patches never overlap.

    Input:  (B, T, H, W, C) with C == ``in_chans``.
    Output: (B, T', H' * W', emb_dim).

    NOTE(review): only ``spatial_dims == 2`` is consistent with the
    ``nn.Conv3d`` projection. The 1D/3D branches pick a ``conv_func`` or base
    kernel that does not match it, so confirm before using them. ``groups``
    and ``base_kernel_size1d`` are accepted but unused.
    """

    def __init__(
        self,
        base_kernel_size1d: Tuple[Tuple[int, ...], ...] = ((1, 16),),
        base_kernel_size2d: Tuple[Tuple[int, ...], ...] = ((1, 16, 16),),
        base_kernel_size3d: Tuple[Tuple[int, ...], ...] = ((1, 16, 16, 16),),
        in_chans: int = 4,
        emb_dim: int = 768,
        spatial_dims: int = 2,
        bias: bool = True,
        antialias: bool = False,
        interpolation: "InterpolationType" = "bicubic",
        groups: int = 12,
        use_norm=False,
        layer_norm_eps=1e-5,
        flexi=False,
    ) -> None:
        super().__init__()

        # Only the first entry of the selected per-dimensionality list is used.
        self.base_kernel_size = (
            base_kernel_size2d[0] if spatial_dims == 2 else base_kernel_size3d[0]
        )
        # Normalised to a tuple so comparisons with requested kernels (tuples) work.
        self.base_kernel = tuple(self.base_kernel_size)
        self.stride = self.base_kernel

        self.in_chans = in_chans
        self.spatial_dims = spatial_dims
        self.antialias = antialias
        self.interpolation = interpolation
        self.norm_layer = nn.GroupNorm  # kept for attribute compatibility; unused
        self.flexi = flexi

        if spatial_dims == 1:
            self.conv_func = F.conv1d
            self.interpolation = "nearest"
        elif spatial_dims == 2:
            # "2D" patches are (T, H, W) kernels with T == 1, hence conv3d.
            self.conv_func = F.conv3d
        elif spatial_dims == 3:
            self.conv_func = F.conv3d
            self.interpolation = "trilinear"

        # Single convolutional layer holding the base-size weights.
        self.proj = nn.Conv3d(
            in_chans,
            emb_dim,
            kernel_size=self.base_kernel,
            stride=self.base_kernel,
            bias=bias,
        )

        if use_norm:
            self.norm = nn.LayerNorm(emb_dim, eps=layer_norm_eps)
        else:
            self.norm = nn.Identity()

        self.kernel_scales = ((1, 4, 4), (1, 8, 8), (1, 16, 16))
        # Pre-calculate pinvs for the expected patch sizes.
        self.pinvs = self._cache_pinvs()

    def _cache_pinvs(self) -> dict:
        """Pre-calculate the pinv matrix for every entry of ``self.kernel_scales``."""
        return {
            ps: self._calculate_pinv(self.base_kernel, ps) for ps in self.kernel_scales
        }

    def _resize(self, x: Tensor, shape: Tuple[int, ...]) -> Tensor:
        """Resize a kernel of shape (1, H, W) to (1, H', W') with 2D interpolation.

        Only the last two entries of ``shape`` are used as the target size.
        """
        assert x.ndim == 3 and x.shape[0] == 1, f"Expected input shape (1, H, W), got {x.shape}"
        x_resized = F.interpolate(
            x[None],  # (1, 1, H, W): batch and channel axes for F.interpolate
            size=tuple(shape[-2:]),  # (H', W')
            mode=self.interpolation,
            antialias=self.antialias,
        )
        return x_resized[0]  # (1, H', W')

    def _calculate_pinv(
        self, old_shape: Tuple[int, ...], new_shape: Tuple[int, ...]
    ) -> Tensor:
        """Pseudo-inverse of the linear map that resizes ``old_shape`` kernels.

        Row ``i`` of the resize matrix is the resized ``i``-th one-hot basis
        kernel, so the matrix has shape (old_numel, new_numel). The returned
        pinv has shape (new_numel, old_numel).
        """
        old_shape = tuple(old_shape)
        rows = []
        for i in range(int(np.prod(old_shape))):
            basis_vec = torch.zeros(old_shape)
            basis_vec[np.unravel_index(i, old_shape)] = 1.0
            rows.append(self._resize(basis_vec, new_shape).reshape(-1))
        resize_matrix = torch.stack(rows)
        return torch.linalg.pinv(resize_matrix)

    def resize_patch_embed(self, patch_embed: Tensor, new_patch_size: Tuple[int, ...]) -> Tensor:
        """Resample conv weights (out, in, *base_kernel) to (out, in, *new_patch_size).

        The pinv for a new size is computed on first use and cached in ``self.pinvs``.
        """
        new_patch_size = tuple(new_patch_size)  # lists are unhashable as cache keys
        if self.base_kernel == new_patch_size:
            return patch_embed

        pinv = self.pinvs.get(new_patch_size)
        if pinv is None:
            pinv = self._calculate_pinv(self.base_kernel, new_patch_size)
            self.pinvs[new_patch_size] = pinv
        # The cache is built on CPU in float32; match the weights' device and dtype.
        pinv = pinv.to(device=patch_embed.device, dtype=patch_embed.dtype)

        # Apply ``pinv @ kernel.reshape(-1)`` to every (out, in) kernel at once.
        out_ch, in_ch = patch_embed.shape[:2]
        flat = patch_embed.reshape(out_ch, in_ch, -1)  # (out, in, old_numel)
        resampled = flat @ pinv.T  # (out, in, new_numel)
        return resampled.reshape(out_ch, in_ch, *new_patch_size)

    def forward(
        self, x: Tensor, bcs=None, metadata=None, **kwargs
    ) -> Tensor:
        """Embed ``x`` (B, T, H, W, C) with the patch size ``kwargs['random_kernel']``.

        When ``random_kernel`` is absent, the base kernel is used.

        Raises:
            ValueError: if a non-base kernel is requested while ``flexi`` is False.
        """
        new_kernel = tuple(kwargs.get("random_kernel", self.base_kernel))
        assert len(new_kernel) == len(self.base_kernel), "Patch size dimensionality mismatch"

        if self.flexi:
            weight = self.resize_patch_embed(self.proj.weight, new_kernel)
        elif new_kernel == self.base_kernel:
            weight = self.proj.weight
        else:
            raise ValueError(
                f"random_kernel {new_kernel} differs from base kernel "
                f"{self.base_kernel}; resizing requires flexi=True"
            )

        # x: (B, T, H, W, C)
        b, t, h, w, c = x.shape
        x = x.permute(0, 4, 1, 2, 3)  # -> (B, C, T, H, W)
        # Pass the projection bias explicitly; calling the functional conv
        # with only the weight silently drops it.
        x = self.conv_func(x, weight, bias=self.proj.bias, stride=new_kernel)
        _, d, t_p, h_p, w_p = x.shape
        x = x.permute(0, 2, 3, 4, 1).reshape(b, t_p, h_p * w_p, d)  # (B, T', S, D)
        return self.norm(x)

class MlpBlock(nn.Module):
    """Two-layer feed-forward block: Linear(out_dim -> hidden_dim), GELU, Linear(hidden_dim -> out_dim).

    Input and output feature widths are both ``out_dim``. The parameter naming
    follows the Flax original.
    """

    def __init__(self, hidden_dim: int, out_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(out_dim, hidden_dim)  # expand
        self.fc2 = nn.Linear(hidden_dim, out_dim)  # project back

    def forward(self, x):
        hidden = F.gelu(self.fc1(x))
        return self.fc2(hidden)

class Mlp(nn.Module):
    """Residual MLP head: ``num_layers`` post-norm residual layers, then a projection.

    Each layer computes ``x = LayerNorm(x + GELU(Linear(x)))`` at width
    ``hidden_dim``. ``final`` then maps ``hidden_dim`` to ``out_dim``.
    """

    def __init__(self, num_layers, hidden_dim, out_dim, eps=1e-5):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.eps = eps
        self.num_layers = num_layers

        def _dense_gelu():
            return nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.GELU())

        self.layers = nn.ModuleList(_dense_gelu() for _ in range(num_layers))
        self.norms = nn.ModuleList(nn.LayerNorm(hidden_dim, eps) for _ in range(num_layers))
        self.final = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        out = x
        for dense, layer_norm in zip(self.layers, self.norms):
            out = layer_norm(out + dense(out))  # residual, then post-norm (Flax order)
        return self.final(out)

class SelfAttnBlock(nn.Module):
    """Pre-norm transformer block on (B, N, D) tokens.

    Computes ``x + MHSA(LN(x))``, then adds ``MLP(LN(.))`` of the result as a
    second residual. The MLP hidden width is ``emb_dim * mlp_ratio``.
    """

    def __init__(self, emb_dim, num_heads, mlp_ratio, layer_norm_eps=1e-5):
        super().__init__()
        self.emb_dim, self.num_heads, self.mlp_ratio = emb_dim, num_heads, mlp_ratio

        self.norm1 = nn.LayerNorm(emb_dim, eps=layer_norm_eps)
        self.attn = nn.MultiheadAttention(emb_dim, num_heads, batch_first=True)

        self.norm2 = nn.LayerNorm(emb_dim, eps=layer_norm_eps)
        self.mlp = MlpBlock(emb_dim * mlp_ratio, emb_dim)

    def forward(self, x):
        # x: (B, N, D)
        normed = self.norm1(x)
        x = x + self.attn(normed, normed, normed)[0]
        return x + self.mlp(self.norm2(x))

class CrossAttnBlock(nn.Module):
    """Pre-norm cross-attention block: queries attend to a separately normed context.

    Accepts either (B, N, D) inputs or 4D (B, S, T, D) inputs. In the 4D case
    the first two axes are folded into the attention batch, so each of the B*S
    sequences attends independently. The output has the same shape as ``q_input``.
    """

    def __init__(self, emb_dim, num_heads, mlp_ratio, layer_norm_eps=1e-5):
        super().__init__()
        self.norm_q = nn.LayerNorm(emb_dim, eps=layer_norm_eps)
        self.norm_kv = nn.LayerNorm(emb_dim, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(emb_dim, eps=layer_norm_eps)
        self.attn = nn.MultiheadAttention(emb_dim, num_heads, batch_first=True)
        self.mlp = MlpBlock(emb_dim * mlp_ratio, emb_dim)

    def forward(self, q_input, kv_input):
        queries = self.norm_q(q_input)
        context = self.norm_kv(kv_input)
        skip = q_input
        folded_shape = None

        if queries.dim() == 4:
            # Fold (B, S) into one batch axis for nn.MultiheadAttention.
            folded_shape = queries.shape
            queries = queries.flatten(0, 1)
            context = context.flatten(0, 1)
            skip = q_input.flatten(0, 1)

        attended, _ = self.attn(queries, context, context)
        hidden = attended + skip
        if folded_shape is not None:
            hidden = hidden.reshape(folded_shape)

        # Feed-forward residual.
        return hidden + self.mlp(self.norm2(hidden))


class TimeAggregation(nn.Module):
    """Collapse the time axis with learned latent queries via cross-attention.

    Input (B, T, S, D) is turned into (B, L, S, D), where L == ``num_latents``.
    For every (batch, spatial) position, the L latents attend over the T time
    tokens through ``depth`` ``CrossAttnBlock`` layers.
    """

    def __init__(self, emb_dim, depth, num_latents=1, num_heads=8, mlp_ratio=1, eps=1e-5):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, emb_dim))
        self.blocks = nn.ModuleList(
            [CrossAttnBlock(emb_dim, num_heads, mlp_ratio, eps) for _ in range(depth)]
        )

    def forward(self, x):  # (B, T, S, D)
        batch, _, spatial, _ = x.shape
        # (L, D) -> (B, S, L, D): one copy of the latents per spatial position.
        queries = self.latents.expand(batch, spatial, -1, -1).to(x.device)
        context = x.transpose(1, 2)  # (B, S, T, D)
        for blk in self.blocks:
            queries = blk(queries, context)
        return queries.transpose(1, 2)  # (B, L, S, D)

class Encoder(nn.Module):
    """Patchify, add space/time position embeddings, aggregate time, then self-attend.

    Input ``x`` is (B, T, H, W, C). After patch embedding, the tokens are
    (B, T', S, D). ``TimeAggregation`` with one latent reduces time, giving
    (B, 1, S, D). The tokens are flattened to (B, S, D) and passed through
    ``depth`` ``SelfAttnBlock`` layers.

    Args:
        patch_size: (pt, ph, pw) used for the fixed position-embedding grid and
            for ``PatchEmbed`` when ``flexi`` is False.
        input_time, input_hw: Expected T and (H, W) of the input; they size the
            fixed embeddings.
        flexi: Use ``PatchEmbedFlexi`` and resize the spatial embedding to the
            per-call patch size.
    """

    def __init__(self, patch_size, emb_dim, depth, num_heads, mlp_ratio, eps=1e-5,
                 input_time=6, input_hw=(128, 384), in_chans=4, flexi=False):
        super().__init__()
        if flexi:
            # NOTE(review): the flexi embedder is built with its own default base
            # kernel rather than ``patch_size``; the spatial embedding is resized
            # to the requested kernel in ``forward``.
            self.patch_embed_flexi = PatchEmbedFlexi(emb_dim=emb_dim, flexi=flexi, in_chans=in_chans)
        else:
            self.patch_embed = PatchEmbed(patch_size, emb_dim, in_chans=in_chans)
        self.time_agg = TimeAggregation(emb_dim, depth=2, num_latents=1,
                                        num_heads=num_heads, mlp_ratio=mlp_ratio, eps=eps)
        self.norm = nn.LayerNorm(emb_dim, eps)
        self.grid_size = input_hw
        self.flexi = flexi
        self.blocks = nn.ModuleList([
            SelfAttnBlock(emb_dim, num_heads, mlp_ratio, eps) for _ in range(depth)
        ])
        self.emb_dim = emb_dim
        self.patch_size = patch_size

        # Fixed sin-cos embeddings for the base patch grid, registered once and reused.
        pt, ph, pw = patch_size
        h_p, w_p = input_hw[0] // ph, input_hw[1] // pw
        t_p = input_time // pt

        t_emb = get_1d_sincos_pos_embed(emb_dim, t_p)         # (1, T', D)
        s_emb = get_2d_sincos_pos_embed(emb_dim, (h_p, w_p))  # (1, S, D)
        self.old_s_emb_size = (h_p, w_p)

        self.register_buffer("t_emb", t_emb, persistent=True)
        self.register_buffer("s_emb", s_emb, persistent=True)

    @property
    def s_emb_grid(self) -> Tensor:
        """Spatial embedding viewed as a (1, H', W', D) grid.

        It is derived from the ``s_emb`` buffer, so it follows the module's
        device and any loaded state.
        """
        h_p, w_p = self.old_s_emb_size
        return rearrange(self.s_emb, '1 (H W) C -> 1 H W C', H=h_p, W=w_p)

    def forward(self, x, **kwargs):
        """Encode ``x`` (B, T, H, W, C) into tokens of shape (B, S, D).

        In flexi mode, ``kwargs['random_kernel']`` selects the patch size. When
        it is absent, the flexi embedder's base kernel is used.
        """
        if self.flexi:
            # Without a requested kernel, fall back to the base one instead of
            # reaching for ``self.patch_embed``, which does not exist in flexi mode.
            new_patch_size = tuple(kwargs.get("random_kernel", self.patch_embed_flexi.base_kernel))
            kwargs["random_kernel"] = new_patch_size
            x = self.patch_embed_flexi(x, **kwargs)
            h_p = self.grid_size[0] // new_patch_size[1]
            w_p = self.grid_size[1] // new_patch_size[2]
            pos_embed = resize_abs_pos_embed(
                self.s_emb_grid,
                new_size=(h_p, w_p),
                old_size=self.old_s_emb_size,
                interpolation="bicubic",
                antialias=True,
            )
            pos_embed = rearrange(pos_embed, '1 H W C -> 1 (H W) C')
        else:
            x = self.patch_embed(x)  # (B, T', S, D)
            pos_embed = self.s_emb

        # Both embeddings are buffers, so they already live on the module's device.
        x = x + self.t_emb[:, :, None, :] + pos_embed[:, None, :, :]  # (B, T', S, D)

        x = self.time_agg(x)
        x = self.norm(x)
        x = rearrange(x, 'b t s d -> b (t s) d')
        for block in self.blocks:
            x = block(x)
        return x

def get_coords_from_stride(image_hw, patch_hw, stride_hw, device):
    """Normalised [0, 1] centre grid for a strided sliding window.

    Args:
        image_hw: (H, W) of the image.
        patch_hw: (kernel_h, kernel_w) window size.
        stride_hw: (stride_h, stride_w).
        device: Device for the returned tensor.

    Returns:
        ``(coords, h_out, w_out)``: ``coords`` has shape (h_out * w_out, 2) in
        row-major (h, w) order, and ``h_out``/``w_out`` are the valid window counts.
    """
    (H, W), (kh, kw), (sh, sw) = image_hw, patch_hw, stride_hw

    h_out = 1 + (H - kh) // sh
    w_out = 1 + (W - kw) // sw

    rows, cols = torch.meshgrid(
        torch.linspace(0, 1, h_out, device=device),
        torch.linspace(0, 1, w_out, device=device),
        indexing='ij',
    )
    coords = torch.stack((rows.reshape(-1), cols.reshape(-1)), dim=1)
    return coords, h_out, w_out

class CViT(nn.Module):
    """Continuous Vision Transformer.

    Encodes a (B, T, H, W, C) input with ``Encoder`` and decodes it at query
    coordinates. With ``embedding_type == "grid"``, each query coordinate takes
    a Gaussian-weighted mix of learned latents on a ``grid_size`` lattice over
    [0, 1]^2. The mix is projected, cross-attends to the encoder tokens, and is
    mapped to ``out_dim`` channels by an ``Mlp`` head.

    Args (selected):
        eps: Inverse bandwidth of the interpolation weights exp(-eps * d^2).
            This is not a LayerNorm epsilon; that is ``layer_norm_eps``.
        out_dim: Output channels; also passed as the encoder's input channel count.
        flexi: Choose the encoder patch size per call from ``kernel_scales_seq``.

    NOTE(review): ``latents`` have width ``latent_dim`` but ``coords_proj``
    expects ``dec_emb_dim``, so the two must match. The encoder is built with
    its default ``input_hw``; confirm that matches the data.
    """

    def __init__(self, patch_size=(1, 16, 16), grid_size=(256, 256), latent_dim=256,
                 emb_dim=256, depth=3, num_heads=8, dec_emb_dim=256, dec_num_heads=8,
                 dec_depth=1, num_mlp_layers=1, mlp_ratio=1, out_dim=11, eps=1e5,
                 embedding_type='grid', layer_norm_eps=1e-5, flexi=False,
                 input_time=6):
        super().__init__()
        self.encoder = Encoder(patch_size, emb_dim, depth, num_heads, mlp_ratio, in_chans=out_dim, eps=layer_norm_eps, flexi=flexi, input_time=input_time)
        self.embedding_type = embedding_type
        self.dec_emb_dim = dec_emb_dim
        self.eps = eps
        self.patch_size = patch_size
        self.grid_size = grid_size
        self.flexi = flexi
        self.kernel_scales_seq = (4, 8, 16)

        if embedding_type == "grid":
            x = torch.linspace(0, 1, grid_size[0])
            y = torch.linspace(0, 1, grid_size[1])
            xx, yy = torch.meshgrid(x, y, indexing="ij")
            self.register_buffer("grid", torch.stack([xx.flatten(), yy.flatten()], dim=1))
            self.latents = nn.Parameter(torch.randn(grid_size[0] * grid_size[1], latent_dim))

        self.coords_proj = nn.Linear(dec_emb_dim, dec_emb_dim)
        self.cross_attn = nn.ModuleList([
            CrossAttnBlock(dec_emb_dim, dec_num_heads, mlp_ratio, layer_norm_eps)
            for _ in range(dec_depth)
        ])
        self.norm = nn.LayerNorm(dec_emb_dim, eps=layer_norm_eps)
        self.mlp = Mlp(num_layers=num_mlp_layers,
                       hidden_dim=dec_emb_dim,
                       out_dim=out_dim,
                       eps=layer_norm_eps)
        self.enc_to_dec = nn.Linear(emb_dim, dec_emb_dim)

    def forward(self, x, coords=None, train=False, seed=None):
        """Predict ``out_dim`` channels at the query coordinates.

        Args:
            x: Input of shape (B, T, H, W, C).
            coords: (N, 2) query coordinates. Defaults to the H x W lattice on
                [0, 1]^2; user-supplied values are presumably in the same range.
            train: In flexi mode, sample the patch size randomly (True) or
                cyclically by ``seed`` (False). When False, the output is
                reshaped to (B, H, W, out_dim), which requires N == H * W.
            seed: Forwarded to the patch-size chooser.

        Returns:
            (B, N, out_dim) when ``train`` is True, otherwise (B, H, W, out_dim).
        """
        b, t, h, w, c = x.shape
        if coords is None:
            x_lin = torch.linspace(0, 1, h, device=x.device)
            y_lin = torch.linspace(0, 1, w, device=x.device)
            xx, yy = torch.meshgrid(x_lin, y_lin, indexing='ij')
            coords = torch.stack([xx, yy], dim=-1).reshape(-1, 2)  # (H*W, 2)

        if self.flexi:
            choose = choose_kernel_size_random if train else choose_kernel_size_alternating
            kernel = choose(self.kernel_scales_seq, seed=seed)
            x = self.encoder(x, random_kernel=(1, kernel, kernel))
        else:
            x = self.encoder(x)                # -> (B, N, D)
        x = F.layer_norm(x, (x.shape[-1],))  # Matches Flax LN after encoder
        x = self.enc_to_dec(x)             # -> (B, N, D)

        if self.embedding_type == "grid":
            d2 = ((coords[:, None, :] - self.grid[None, :, :]) ** 2).sum(dim=-1)  # (P, G)
            # softmax(-eps * d2) equals exp(-eps * d2) / sum(...), but it subtracts
            # the row max first. With eps ~ 1e5, a query far from every grid point
            # would otherwise underflow to 0 / 0 = NaN.
            w_ = torch.softmax(-self.eps * d2, dim=1)
            coords = torch.einsum("ic,pi->pc", self.latents, w_)
            coords = self.coords_proj(coords)
            coords = F.layer_norm(coords, (self.dec_emb_dim,))
            coords = repeat(coords, 'n d -> b n d', b=b)

        for block in self.cross_attn:
            coords = block(coords, x)

        x = self.norm(coords)
        x = self.mlp(x)  # Final prediction (B, N, out_dim)
        if not train:
            x = rearrange(x, "b (h w) c -> b h w c", h=h, w=w)
        return x
