from typing import Dict, Optional, Union
import numpy as np 
import random
from collections import OrderedDict

import torch
import torchvision
from torchvision import transforms
from accelerate import Accelerator
from diffusers.utils.torch_utils import is_compiled_module
from safetensors import safe_open
from einops import rearrange


def unwrap_model(accelerator: Accelerator, model):
    """
    Return the underlying model with Accelerator and compile wrappers removed.

    Args:
        accelerator: Accelerator whose ``unwrap_model`` is used to strip its wrapping.
        model: Possibly wrapped (and possibly compiled) model.

    Returns:
        The object returned by ``accelerator.unwrap_model``; if that object is a
        compiled module, its ``_orig_mod`` attribute is returned instead.
    """
    unwrapped = accelerator.unwrap_model(model)
    # A compiled module keeps the original module under `_orig_mod`.
    if is_compiled_module(unwrapped):
        return unwrapped._orig_mod
    return unwrapped


def align_device_and_dtype(
    x: Union[torch.Tensor, Dict[str, torch.Tensor]],
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """
    Cast a tensor, or every value of a (possibly nested) dict, to a device and/or dtype.

    Args:
        x: Tensor or dictionary to convert. Any other kind of object is returned untouched.
        device: Destination device. ``None`` keeps the current device.
        dtype: Destination dtype. ``None`` keeps the current dtype.

    Returns:
        The converted tensor, or a new dictionary with converted values. When both
        ``device`` and ``dtype`` are ``None`` the input object itself is returned.
    """
    if isinstance(x, torch.Tensor):
        # Device first, then dtype; a `None` target is skipped.
        for target in (device, dtype):
            if target is not None:
                x = x.to(target)
        return x

    nothing_requested = device is None and dtype is None
    if isinstance(x, dict) and not nothing_requested:
        # Recurse so that nested dicts and tensor values are handled the same way.
        converted = {}
        for key, value in x.items():
            converted[key] = align_device_and_dtype(value, device, dtype)
        return converted

    return x


def seed_everything(seed: int = 0):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA) with ``seed``."""
    # Host-side generators.
    np.random.seed(seed)
    random.seed(seed)
    # PyTorch: default generator, current CUDA device, then every CUDA device.
    torch_seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in torch_seeders:
        apply_seed(seed)


def load_safetensor(path: str, device_map: Union[str, torch.device] = 'cpu'):
    """
    Load a Safetensors file into a dictionary of PyTorch tensors.

    The file is always read on the CPU and each tensor is then moved to
    ``device_map``. Reading on a hard-coded GPU index would fail on hosts without
    CUDA and would stage every tensor on GPU 0 even when the caller asked for CPU.

    Args:
        path: Path to the Safetensors file.
        device_map: Target device to move tensors to (e.g., 'cpu', 'cuda:0').

    Returns:
        Dictionary where keys are tensor names and values are the corresponding
        PyTorch tensors on ``device_map``.
    """
    tensors = {}
    # Open on CPU (works everywhere), then place each tensor on the requested device.
    with safe_open(path, framework="pt", device="cpu") as f:
        for k in f.keys():
            tensors[k] = f.get_tensor(k).to(device_map)
    return tensors


def load_state_dict_skip_mismatch(model, state_dict: OrderedDict):
    """
    Load a state dictionary into a model, skipping entries that cannot be loaded.

    An entry is skipped when its key is not present in ``model.state_dict()`` or
    when its shape differs from the model's tensor of the same name. The remaining
    entries are loaded with ``strict=False``.

    Args:
        model: PyTorch model to load the state dictionary into.
        state_dict: OrderedDict containing pre-trained weights.

    Returns:
        Tuple ``(mismatched_keys, missing_keys, unexpected_keys)``:
            - mismatched_keys: Keys from ``state_dict`` that were skipped (absent
              from the model or shape mismatch), in iteration order.
            - missing_keys / unexpected_keys: As reported by
              ``model.load_state_dict`` for the filtered state dictionary.
    """
    model_state_dict = model.state_dict()
    mismatched_keys = []

    # Step 1: keep only entries whose key exists in the model with an identical shape
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        in_model = k in model_state_dict
        if in_model and v.shape == model_state_dict[k].shape:
            new_state_dict[k] = v
            continue

        mismatched_keys.append(k)
        if in_model:
            print(f"Skipping {k} due to shape mismatch.")
            model_shape = model_state_dict[k].shape
        else:
            # Not a shape problem: the model simply has no parameter/buffer with this name.
            print(f"Skipping {k}: key not found in model.")
            model_shape = 'N/A'
        print(f"    Loaded shape: {v.shape}, Model shape: {model_shape}")

    # Step 2: load the filtered state dictionary into the model
    missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)

    return mismatched_keys, missing_keys, unexpected_keys


def load_state_dict(
    model, 
    path: str, 
    strict: bool = False, 
    device_map: Union[str, torch.device] = 'cpu', 
    ignore_keys: Optional[list] = None
):
    """
    Read weights from disk and load them into ``model``, skipping incompatible entries.

    Args:
        model: PyTorch model that receives the weights.
        path: Weight file. Paths ending in 'safetensors' are read with
            ``load_safetensor``; anything else goes through ``torch.load``.
        strict: Accepted for signature compatibility; not used by this function.
        device_map: Device the tensors are loaded onto.
        ignore_keys: Keys dropped from the loaded state dict before loading.
            Keys that are not present are ignored.

    Returns:
        The same ``model`` object, after loading.
    """
    # NOTE(review): torch.load unpickles the file; only use it on trusted checkpoints.
    reader = load_safetensor if path.endswith('safetensors') else None
    if reader is not None:
        state_dict = reader(path, device_map)
    else:
        state_dict = torch.load(path, map_location=device_map)

    # Drop the entries the caller asked to ignore (absent keys are fine).
    for key in (ignore_keys if ignore_keys is not None else ()):
        state_dict.pop(key, None)

    # Load what fits and report what did not.
    mismatched_keys, missing_keys, unexpected_keys = load_state_dict_skip_mismatch(model, state_dict)
    print(f'Loaded state dict from {path}.\nMissing keys: {missing_keys}, Unexpected keys: {unexpected_keys}, Mismatched keys: {mismatched_keys}')

    return model


def apply_color_jitter_to_video(
    tensor: torch.Tensor, 
    jitter: Optional[transforms.ColorJitter] = None, 
    same_jitter_within_view: bool = False, 
    n_view: int = 3
):
    """
    Colour-jitter a video batch of shape (B, C, T, H, W) whose values lie in [-1, 1].

    The jitter transform is called once per "clip":
        - ``same_jitter_within_view=False``: one call per batch entry, on its T frames.
        - ``same_jitter_within_view=True``: one call per group of ``n_view``
          consecutive batch entries, on all ``n_view * T`` frames of that group.
    NOTE(review): torchvision's ColorJitter presumably samples its parameters once
    per call, so all frames passed in a single call share the same jitter - confirm.

    Args:
        tensor: Input video, shape (B, C, T, H, W) with C == 3, value range [-1, 1].
        jitter: Transform to apply. Defaults to
            ``ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1)``.
        same_jitter_within_view: Selects the grouping described above.
        n_view: Group size along the batch axis (only used when
            ``same_jitter_within_view`` is True; B must then be divisible by it).

    Returns:
        Jittered video with the same shape, mapped back to [-1, 1].
    """
    B, C, T, H, W = tensor.shape
    assert C == 3, "ColorJitter only applies to 3-channel RGB images"

    if jitter is None:
        jitter = transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1)

    # Pick the layout in which the leading axis enumerates the clips to jitter.
    if same_jitter_within_view:
        to_clips = '(b v) c t h w -> b (v t) c h w'
        from_clips = 'b (v t) c h w -> (b v) c t h w'
        axis_sizes = dict(v=n_view, t=T)
        n_clips = B // n_view
    else:
        to_clips = 'b c t h w -> b t c h w'
        from_clips = 'b t c h w -> b c t h w'
        axis_sizes = {}
        n_clips = B

    # ColorJitter works on [0, 1] images, so shift the range before and after.
    frames = (tensor + 1.0) / 2.0
    frames = rearrange(frames, to_clips, **axis_sizes)
    for clip_idx in range(n_clips):
        frames[clip_idx] = jitter(frames[clip_idx])
    frames = rearrange(frames, from_clips, **axis_sizes)

    return frames * 2.0 - 1.0


def _get_latents_encode(vae, frames, device, dtype, generator, enc_only, sampling):
    """Encode ``frames`` with ``vae`` and return (normalized unless ``enc_only``) latents."""
    enc_out = vae.encode(frames)
    if enc_only:
        # Raw encoder output: no sampling and no normalization.
        return enc_out.to(device=device, dtype=dtype)

    latent_dist = enc_out.latent_dist
    if sampling:
        latents = latent_dist.sample(generator=generator)
    else:
        latents = latent_dist.mode()
    # `None` for device/dtype leaves the corresponding property untouched.
    latents = latents.to(device=device, dtype=dtype)
    # Normalize latents using VAE's preconfigured mean/std
    return _normalize_latents(latents, vae.config.latents_mean, vae.config.latents_std)


def get_latents(
    vae,
    mem: torch.Tensor,
    video: torch.Tensor,
    patch_size: int = 1,
    patch_size_t: int = 1,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    generator: Optional[torch.Generator] = None,
    enc_only: bool = False,
    return_unpack: bool = False,
    sampling: bool = True,
    sep_mem: bool = True,
):
    """
    Encode memory and video tensors into latents using a VAE, then pack latents into token format.

    The video is encoded before the memory, so a shared ``generator`` is consumed
    in that order.

    Args:
        vae: Variational Autoencoder (VAE) model for encoding. Must provide ``encode``
            and ``config.latents_mean`` / ``config.latents_std``.
        mem: Memory tensor, shape (B*V, C, M, H, W), value range [-1, 1] (V = number of views, M = memory frames).
        video: Video tensor, shape (B*V, C, F, H, W), value range [-1, 1] (F = video frames).
        patch_size: Spatial patch size for latent packing (height/width dimension).
        patch_size_t: Temporal patch size for latent packing (frame dimension).
        device: Target device for latents. If None, latents stay on the device the
            VAE produced them on.
        dtype: Target dtype for latents. If None, the dtype produced by the VAE is kept.
        generator: PyTorch Generator for deterministic sampling from VAE latent distribution.
        enc_only: Whether to use the raw ``vae.encode`` output (no sampling, no
            normalization) instead of latent samples/mode.
        return_unpack: Whether to return both packed and unpacked latents.
        sampling: Whether to sample from the VAE latent distribution (True) or use the mode (False).
        sep_mem: Whether to treat each memory frame as an individual single-frame
            sample during encoding.

    Returns:
        4-tuple ``(mem_packed, video_packed, mem_unpacked, video_unpacked)``:
            - mem_packed / video_packed: Outputs of ``_pack_latents`` for the memory
              and video latents.
            - mem_unpacked: Memory latents before packing. With ``sep_mem=True`` the
              per-frame latents are regrouped to (B*V, C, M*F_latent, H_latent, W_latent).
            - video_unpacked: Video latents before packing.
        The last two entries are ``None`` when ``return_unpack`` is False.
    """
    mem_size = mem.shape[2]  # Number of memory frames (M)

    # Reshape memory tensor if treating each frame as an individual sample
    if sep_mem:
        # (B*V, C, M, H, W) → (B*V*M, C, 1, H, W) (add single-frame temporal dimension)
        mem = rearrange(mem, 'b c m h w -> (b m) c h w').unsqueeze(2)

    # Encode and pack the video first, then the memory (keeps generator order stable)
    video_latents = _get_latents_encode(vae, video, device, dtype, generator, enc_only, sampling)
    video_latents_pack = _pack_latents(video_latents, patch_size, patch_size_t)

    mem_latents = _get_latents_encode(vae, mem, device, dtype, generator, enc_only, sampling)
    mem_latents_pack = _pack_latents(mem_latents, patch_size, patch_size_t)

    if not return_unpack:
        return mem_latents_pack, video_latents_pack, None, None

    if sep_mem:
        # Fold the per-frame samples back into one temporal axis per original sample
        mem_latents = rearrange(mem_latents, '(b m) c f h w -> b c (m f) h w', m=mem_size)
    return mem_latents_pack, video_latents_pack, mem_latents, video_latents


def _normalize_latents(
    latents: torch.Tensor, 
    latents_mean: Union[torch.Tensor, float], 
    latents_std: Union[torch.Tensor, float], 
    scaling_factor: float = 1.0,
    reverse: bool = False,
) -> torch.Tensor:
    """
    Normalize or denormalize VAE latents across the channel dimension.
    
    Args:
        latents: Latent tensor, shape (B, C, F, H, W).
        latents_mean: Mean value(s) for normalization (per channel or scalar).
        latents_std: Standard deviation value(s) for normalization (per channel or scalar).
        scaling_factor: Additional scaling factor applied during normalization.
        reverse: Whether to perform denormalization (reverse of normalization).
    
    Returns:
        Normalized or denormalized latent tensor.
    """
    # Convert mean/std to tensors if provided as scalars
    if not isinstance(latents_mean, torch.Tensor):
        latents_mean = torch.tensor(latents_mean, device=latents.device, dtype=latents.dtype)
    if not isinstance(latents_std, torch.Tensor):
        latents_std = torch.tensor(latents_std, device=latents.device, dtype=latents.dtype)
    
    # Reshape mean/std to broadcast across (B, C, F, H, W)
    latents_mean = latents_mean.view(1, -1, 1, 1, 1)
    latents_std = latents_std.view(1, -1, 1, 1, 1)
    
    if not reverse:
        # Normalization: (latents - mean) * scaling_factor / std
        latents = (latents - latents_mean) * scaling_factor / latents_std
    else:
        # Denormalization: latents * std / scaling_factor + mean
        latents = latents * latents_std / scaling_factor + latents_mean
    return latents


def unpack_latents(
        latents: torch.Tensor, 
        num_frames: int, 
        height: int, 
        width: int, 
        patch_size: int = 1, 
        patch_size_t: int = 1
    ) -> torch.Tensor:
    """
    Unpack packed latent tokens back into a video tensor. Inverse operation of `_pack_latents`.
    
    Args:
        latents: Packed latent tokens, shape (B, S, D) (S = sequence length, D = feature dimension).
        num_frames: Number of frames in the original video (F).
        height: Spatial height of the original latent map (H_latent).
        width: Spatial width of the original latent map (W_latent).
        patch_size: Spatial patch size used during packing (height/width dimension).
        patch_size_t: Temporal patch size used during packing (frame dimension).
    
    Returns:
        Unpacked video latent tensor, shape (B, C, F, H_latent, W_latent).
    """
    batch_size = latents.size(0)
    # Reshape packed tokens back into patch dimensions
    latents = latents.reshape(
        batch_size, 
        num_frames, 
        height, 
        width, 
        -1,  # Number of channels (C)
        patch_size_t,  # Temporal patch size
        patch_size,    # Spatial patch size (height)
        patch_size     #
