import math
from typing import Optional, Union
from einops import rearrange
# from torchvision import transforms
from pathlib import Path
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

# ---- ImageNet normalization statistics, defined once at module level ----
# Shaped (1, 3, 1, 1) so they broadcast over [B, 3, H, W] batches.
# (DINOv2StreamingEncoder registers its own device-resident copies of the same values.)
IMAGENET_DEFAULT_MEAN = torch.tensor((0.485, 0.456, 0.406)).reshape(1, 3, 1, 1)
IMAGENET_DEFAULT_STD = torch.tensor((0.229, 0.224, 0.225)).reshape(1, 3, 1, 1)


def _patch_aligned_resize(
    images: torch.Tensor,
    patch_size: int,
    keep_aspect: bool = True,
    *,
    mean: torch.Tensor,
    std: torch.Tensor,
    assume_normalized: bool = True,
) -> torch.Tensor:
    """
    Resize images so both spatial sides are multiples of `patch_size`, then normalize.

    Inputs:
      images: [B, 3, H, W]. Float values already in [0, 1] when
              `assume_normalized` is True; 0..255 values otherwise. uint8
              input is always treated as 0..255 and rescaled by 1/255.
      patch_size: spatial multiple the output must satisfy.
      keep_aspect: True -> scale uniformly until both sides cover the
              patch-aligned size, then center-crop the excess;
              False -> stretch straight to the patch-aligned size.
      mean, std: per-channel constants broadcastable to [1, 3, 1, 1]; they are
              moved to the images' device/dtype before use.
      assume_normalized: whether non-uint8 input is already in [0, 1].
    Returns:
      out: [B, 3, H', W'] float32 with H' = ceil(H / patch_size) * patch_size and
           W' = ceil(W / patch_size) * patch_size, equal to (resized - mean) / std.
    Raises:
      AssertionError: input is not [B, 3, H, W], or values fall outside [0, 1]
           after the optional 1/255 rescale.
    Notes:
      - GPU-friendly ops only (F.interpolate + manual normalization).
      - Normalization runs once in FP32 for stability.
      - The caller's tensor is never modified in place.
    """
    assert images.ndim == 4 and images.shape[1] == 3, "Expected [B,3,H,W]"

    # `.float()` returns the *same* tensor when the input is already float32,
    # so rescale out of place to avoid clobbering the caller's frame.
    x = images.float()
    if images.dtype == torch.uint8 or not assume_normalized:
        x = x / 255.0
    assert x.min() >= 0 and x.max() <= 1, "Images must be in [0,1]"

    H, W = x.shape[-2:]
    new_h = ((H + patch_size - 1) // patch_size) * patch_size
    new_w = ((W + patch_size - 1) // patch_size) * patch_size

    if keep_aspect:
        # Uniform scale large enough that both sides reach the aligned size;
        # ceil keeps tgt >= new on each axis, so the crop below is valid.
        scale = max(new_w / W, new_h / H)
        tgt_h = int(math.ceil(H * scale))
        tgt_w = int(math.ceil(W * scale))
    else:
        tgt_h, tgt_w = new_h, new_w

    # Resize in FP32 (more stable); channels_last helps some kernels.
    x = x.contiguous(memory_format=torch.channels_last)
    x = F.interpolate(x, size=(tgt_h, tgt_w), mode="bicubic", antialias=True)

    if keep_aspect:
        # Center-crop the overshoot down to the patch-aligned window.
        sy = (tgt_h - new_h) // 2
        sx = (tgt_w - new_w) // 2
        x = x[:, :, sy:sy + new_h, sx:sx + new_w]

    # The constants may live on another device/dtype than the images
    # (e.g. CUDA buffers vs. CPU frames); align them before normalizing.
    mean = mean.to(device=x.device, dtype=x.dtype)
    std = std.to(device=x.device, dtype=x.dtype)
    return (x - mean) / std

def load_pca_components(pca_path: Union[str, Path]):
    """
    Read a PCA basis that was saved with `torch.save`.

    The file must hold a mapping with two entries:
      'mean'       -> per-feature mean, shape (C,)
      'components' -> projection rows, shape (k, C)
    Each entry is converted with `torch.as_tensor` to float32; the file is
    loaded with map_location='cpu', so the results start out on the CPU.

    Returns:
      (mean, components): float32 tensors of shape (C,) and (k, C).
    """
    # NOTE: torch.load may unpickle arbitrary objects; only load trusted files.
    checkpoint = torch.load(str(pca_path), map_location='cpu')
    mean, components = (
        torch.as_tensor(checkpoint[key], dtype=torch.float32)
        for key in ("mean", "components")
    )
    return mean, components

class DINOv2StreamingEncoder(nn.Module):
    """
    DINOv2 patch-feature extractor with optional grid downsampling and PCA projection.

    Per frame batch: patch-align and ImageNet-normalize the images, run the
    DINOv2 backbone to get a grid of patch tokens, optionally downsample that
    grid bilinearly, then project each feature vector onto a PCA basis.
    `get_grid_pca` also returns the difference to the previous projected grid.

    Args:
      pca_path: .pt file with 'mean' (C,) and 'components' (k, C) entries.
      model_name: entry point in the facebookresearch/dinov2 torch.hub repo.
      device: device for the backbone, the PCA tensors and the normalization buffers.
      dtype: dtype the backbone weights are cast to; also the autocast dtype
             and the output dtype of `apply_pca`.
      resize_grid_factor: spatial downsampling factor for the token grid
             (<= 1 disables downsampling).
      maintain_aspect: keep the aspect ratio (upscale + center-crop) when
             patch-aligning input frames.

    Side effects:
      Enables TF32 CUDA matmuls process-wide (torch.backends.cuda.matmul.allow_tf32).
    """

    def __init__(
        self,
        pca_path: Union[str, Path],
        model_name: str = "dinov2_vits14_reg",
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        resize_grid_factor: int = 2,
        maintain_aspect: bool = True,
    ):
        super().__init__()
        self.model = torch.hub.load("facebookresearch/dinov2", model_name).to(device).eval()
        self.model = self.model.to(dtype=dtype)
        self.patch_size = self.model.patch_size
        self.device = device
        self.autocast_dtype = dtype
        self.keep_aspect = maintain_aspect
        self.resize_grid_factor = int(resize_grid_factor)

        # PCA basis. These are plain attributes (not buffers), so Module.to()
        # does not move them; apply_pca aligns them with its input on use.
        pca_mean, pca_components = load_pca_components(pca_path)
        self.pca_W = pca_components.to(device=device)   # (k, C)
        self.pca_mu = pca_mean.to(device=device)        # (C,)

        # ImageNet statistics on the target device; non-persistent -> not in state_dict.
        self.register_buffer(
            "img_mean", torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32, device=device).view(1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "img_std", torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32, device=device).view(1, 3, 1, 1),
            persistent=False,
        )

        torch.backends.cuda.matmul.allow_tf32 = True

    @torch.inference_mode()
    def encode_grid(self, frame_bchw: torch.Tensor) -> torch.Tensor:
        """
        Run the backbone on a frame batch and return its patch-token grid.

        Args:
          frame_bchw: [B,3,H,W]; uint8 in [0,255] or float in [0,1]. May live on
                      any device; it is moved to `self.device` before preprocessing.
        Returns:
          grid: [B, h, w, C_dino] with h = H'/patch_size, w = W'/patch_size, where
                H', W' is the patch-aligned size. The dtype is whatever the backbone
                emits under autocast (the original docstring assumed the autocast
                dtype, e.g. BF16; confirm for your backbone).
        """
        # uint8 frames are 0..255 and need rescaling; other dtypes are assumed in [0, 1].
        assume_norm = frame_bchw.dtype != torch.uint8

        # Move to the target device first (keeping the compact input dtype for the
        # transfer) so preprocessing runs where the normalization buffers live.
        frame = frame_bchw.to(device=self.device, non_blocking=True)
        x = _patch_aligned_resize(
            frame, self.patch_size, keep_aspect=self.keep_aspect,
            mean=self.img_mean, std=self.img_std,
            assume_normalized=assume_norm,
        )
        x = x.to(dtype=torch.float32)

        # Autocast on the device type actually in use ("cuda", "cpu", ...), not always CUDA.
        device_type = torch.device(self.device).type
        with torch.autocast(device_type=device_type, dtype=self.autocast_dtype):
            tokens = self.model.get_intermediate_layers(x)[0]   # [B, (h*w), C_dino]
        h_patches = x.shape[2] // self.patch_size
        return rearrange(tokens, 'b (h w) d -> b h w d', h=h_patches).contiguous()

    @torch.inference_mode()
    def downsample_grid(self, grid: torch.Tensor) -> torch.Tensor:
        """
        Spatially shrink the grid by `resize_grid_factor` using bilinear interpolation.

        The target size is round(h / factor) x round(w / factor), using Python's
        round (ties go to the even integer), clamped to at least 1 so that small
        grids never collapse to zero size. A factor <= 1 returns `grid` unchanged.

        Input/Output:
          grid: [B, h, w, C] -> [B, h', w', C]; dtype is preserved.
        """
        factor = self.resize_grid_factor
        if factor <= 1:
            return grid

        _, h, w, _ = grid.shape
        new_h = max(1, int(round(h / factor)))
        new_w = max(1, int(round(w / factor)))

        # F.interpolate expects channels-first: NHWC -> NCHW -> resize -> NHWC.
        nchw = grid.permute(0, 3, 1, 2).contiguous()
        nchw = F.interpolate(nchw, size=(new_h, new_w), mode="bilinear", align_corners=False)
        return nchw.permute(0, 2, 3, 1).contiguous()

    @torch.inference_mode()
    def apply_pca(self, grid: torch.Tensor) -> torch.Tensor:
        """
        Project every grid feature onto the PCA basis, (x - mu) @ W.T, in FP32.

        Args:
          grid: [B, h, w, C_dino]; any float dtype and any memory layout.
        Returns:
          reduced: [B, h, w, C_pca] on grid's device, cast to `self.autocast_dtype`.
        """
        B, h, w, C = grid.shape
        # reshape (not view) so non-contiguous grids are accepted.
        flat = grid.to(dtype=torch.float32).reshape(B * h * w, C)
        # Align the basis with the input; a no-op when device/dtype already match.
        mu = self.pca_mu.to(device=flat.device, dtype=torch.float32).reshape(1, -1)
        basis = self.pca_W.to(device=flat.device, dtype=torch.float32)
        reduced = (flat - mu) @ basis.t()   # [B*h*w, C_pca]
        # Explicit C_pca (not -1) so an empty grid still reshapes cleanly.
        return reduced.reshape(B, h, w, basis.shape[0]).to(self.autocast_dtype)

    @torch.inference_mode()
    def get_grid_pca(
        self,
        frame_bchw: torch.Tensor,
        prev_grid_pca: Optional[torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Encode a frame batch into a PCA-reduced grid plus its temporal delta.

        Args:
          frame_bchw: [B,3,H,W] frame batch (see `encode_grid` for accepted dtypes).
          prev_grid_pca: `grid_pca` from the previous call, or None for the first frame.
        Returns:
          grid_pca: [B, h', w', C_pca]
          delta:    [B, h', w', C_pca] = grid_pca - prev_grid_pca (zeros when prev is None)
        """
        grid = self.encode_grid(frame_bchw)         # [B, h, w, C_dino]
        grid = self.downsample_grid(grid)           # [B, h', w', C_dino]
        grid_pca = self.apply_pca(grid)             # [B, h', w', C_pca]
        if prev_grid_pca is None:
            delta = torch.zeros_like(grid_pca)
        else:
            delta = grid_pca - prev_grid_pca
        return grid_pca, delta