import torch
import torch.nn as nn
from transformers import AutoModel, AutoVideoProcessor
import math
import re
from einops import repeat

# Helper function from LDM codebase
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Build sinusoidal embeddings for a batch of timesteps (adapted from the LDM codebase).

    :param timesteps: 1-D tensor with one (possibly fractional) timestep per batch element.
    :param dim: width of each output embedding row.
    :param max_period: controls the lowest frequency used by the sinusoids.
    :param repeat_only: when True, skip the sinusoids and just broadcast each raw
                        timestep value across ``dim`` columns.
    :return: tensor of shape [N, dim].
    """
    if repeat_only:
        return repeat(timesteps, 'b -> b d', d=dim)

    n_freqs = dim // 2
    # Geometric frequency ladder from 1 down towards 1/max_period (kept in the original op order).
    log_scaled = -math.log(max_period) * torch.arange(start=0, end=n_freqs, dtype=torch.float32) / n_freqs
    frequencies = torch.exp(log_scaled).to(device=timesteps.device)
    phases = timesteps[:, None].float() * frequencies[None]

    pieces = [torch.cos(phases), torch.sin(phases)]
    if dim % 2:
        # An odd width gets one zero column so the result is exactly `dim` wide.
        pieces.append(torch.zeros_like(phases[:, :1]))
    return torch.cat(pieces, dim=-1)


class VJEPA2_UNet_Wrapper(nn.Module):
    """
    UNet stand-in that uses a frozen VJEPA2 transformer as its backbone.

    It is designed to be compatible with the LatentVisualDiffusionv84 interface: it receives
    noisy latents, timesteps and conditions, and returns the predicted noise (``model_output``)
    together with a list of intermediate features (``feat_list``) that imitates the UNet's
    feature outputs.

    Input:
        - x (torch.Tensor): Noisy latent representation, [B, C_in, T, H, W]
        - timesteps (torch.Tensor): Timesteps, [B]
        - context (dict): Condition dictionary. Only ``context['c_crossattn'][0]`` is read; it is
          projected from ``context_dim`` to the VJEPA2 hidden size and prepended as extra tokens.

    Output:
        - tuple[torch.Tensor, list[torch.Tensor]]:
            - model_output (torch.Tensor): Predicted noise with ``out_channels`` channels and the
              same T/H/W extent as ``x`` (i.e. ``x.shape`` when out_channels == in_channels).
            - feat_list (list[torch.Tensor]): One tensor per entry of ``v84_feat_dims``, each shaped
              [B, T_p, C_feat, H_p, W_p] on the patch grid.

    NOTE(review): the backbone attribute paths used here (``embeddings.patch_embeddings.projection``,
    ``embeddings.position_embeddings``, ``encoder.layer``) and the ``inputs_embeds`` forward keyword
    must exist on the loaded checkpoint's model class — confirm against the installed transformers.
    """
    def __init__(
        self,
        # --- UNet compatibility params ---
        in_channels: int,
        model_channels: int,
        out_channels: int,
        context_dim: int,
        # --- VJEPA2 specific params ---
        vjepa_model_size: str = 'giant',
        vjepa_checkpoint_path: str = None,
        # --- Compatibility params for v84 ---
        v84_feat_dims: list = None
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.context_dim = context_dim
        self.model_channels = model_channels

        # 1. Load the VJEPA2 backbone. An explicit checkpoint path takes precedence over the hub repo,
        #    so report the path that is actually loaded.
        hf_repo = self._get_hf_repo(vjepa_model_size)
        model_path = vjepa_checkpoint_path if vjepa_checkpoint_path else hf_repo
        print(f"Loading VJEPA2 model from: {model_path}")
        try:
            self.vjepa = AutoModel.from_pretrained(model_path, output_hidden_states=True)
        except Exception as e:
            raise IOError(f"Failed to load VJEPA2 model: {e}") from e

        self.vjepa_hidden_dim = self.vjepa.config.hidden_size

        # Freeze the backbone: it is used purely as a feature extractor.
        for param in self.vjepa.parameters():
            param.requires_grad = False

        # 2. Timestep embedding: sinusoidal [B, model_channels] -> MLP -> [B, vjepa_hidden_dim].
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            nn.Linear(model_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, self.vjepa_hidden_dim),
        )

        # 3. Input projection. The native patch embedding expects RGB input, so a fresh Conv3d with
        #    the same kernel/stride maps `in_channels` latent channels to the hidden size.
        vjepa_patch_embed = self.vjepa.embeddings.patch_embeddings.projection
        self.input_proj = nn.Conv3d(
            in_channels=in_channels,
            out_channels=self.vjepa_hidden_dim,
            kernel_size=vjepa_patch_embed.kernel_size,
            stride=vjepa_patch_embed.stride
        )
        self.pos_embed = self.vjepa.embeddings.position_embeddings

        # 4. Conditional context projection.
        self.context_proj = nn.Linear(context_dim, self.vjepa_hidden_dim)

        # 5. Output and feature-list decoder. Copy the list so later caller mutations don't leak in.
        self.v84_feat_dims = list(v84_feat_dims or [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280, 1280, 1280, 1280, 1280, 1280, 1280, 1280, 640, 640, 640, 320, 320, 320])

        # Pick len(v84_feat_dims) backbone layers, evenly spaced from the first to the last layer.
        num_vjepa_layers = len(self.vjepa.encoder.layer)
        self.vjepa_layer_indices = torch.linspace(0, num_vjepa_layers - 1, len(self.v84_feat_dims)).long().tolist()

        # One linear projector per requested feature width.
        self.feat_list_projectors = nn.ModuleList([
            nn.Linear(self.vjepa_hidden_dim, c_out) for c_out in self.v84_feat_dims
        ])

        # Transposed conv mirrors the patch embedding to return from the patch grid to latent resolution.
        self.out_proj = nn.ConvTranspose3d(
            self.vjepa_hidden_dim,
            out_channels,
            kernel_size=vjepa_patch_embed.kernel_size,
            stride=vjepa_patch_embed.stride
        )

    def _get_hf_repo(self, model_size: str) -> str:
        """Build the HuggingFace repository name for a model size ('large', 'huge' or 'giant').

        Raises:
            ValueError: if ``model_size`` is not one of the supported sizes.
        """
        size_map = {'large': 'vitl', 'huge': 'vith', 'giant': 'vitg'}
        if model_size not in size_map:
            raise ValueError(f"Unsupported VJEPA2 model size: {model_size}.")
        # Assumes the 256-resolution, 64-frames-per-clip pretrained checkpoints.
        return f"facebook/vjepa2-{size_map[model_size]}-fpc64-256"

    def _match_pos_embed(self, num_patches):
        """Return positional embeddings covering exactly ``num_patches`` tokens.

        ``self.pos_embed`` is indexed as [1, N_pos, D] (assumed tensor/parameter — TODO confirm for
        the loaded checkpoint). Longer sequences tile the table (plus a partial copy for the
        remainder); shorter ones truncate it.
        """
        pos_len = self.pos_embed.shape[1]
        if num_patches <= pos_len:
            return self.pos_embed[:, :num_patches]
        n_full, remainder = divmod(num_patches, pos_len)
        return torch.cat([self.pos_embed] * n_full + [self.pos_embed[:, :remainder]], dim=1)

    @staticmethod
    def _tokens_to_grid(tokens, T_p, H_p, W_p):
        """Fold tokens [B, N, C] (N == T_p*H_p*W_p, flattened in T, H, W order) into [B, C, T_p, H_p, W_p]."""
        batch, _, channels = tokens.shape
        return tokens.transpose(1, 2).reshape(batch, channels, T_p, H_p, W_p)

    def forward(self, x, timesteps, context, **kwargs):
        """Predict noise for latent ``x`` and return ``(model_output, feat_list)``.

        Args:
            x: noisy latents, [B, C_in, T, H, W].
            timesteps: [B] timesteps.
            context: dict; ``context['c_crossattn'][0]`` must be [B, N_ctx, context_dim].
            **kwargs: accepted for interface compatibility and ignored.
        """
        # 1. Patchify: [B, C_in, T, H, W] -> [B, D, T_p, H_p, W_p] -> tokens [B, T_p*H_p*W_p, D].
        x_embed = self.input_proj(x)
        _, _, T_p, H_p, W_p = x_embed.shape
        x_embed = x_embed.flatten(2).transpose(1, 2)

        # 2. Positional embeddings, tiled/truncated to the actual token count.
        x_embed = x_embed + self._match_pos_embed(x_embed.shape[1])

        # 3. Timestep embedding broadcast over all tokens: [B, D] -> [B, 1, D].
        t_embed = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        x_embed = x_embed + t_embed.unsqueeze(1)

        # 4. Condition tokens: [B, N_ctx, context_dim] -> [B, N_ctx, D].
        c_embed = self.context_proj(context['c_crossattn'][0])
        n_ctx = c_embed.shape[1]

        # 5. Run [condition tokens | content tokens] through the backbone.
        full_seq = torch.cat([c_embed, x_embed], dim=1)
        vjepa_outputs = self.vjepa(inputs_embeds=full_seq, output_hidden_states=True)
        # hidden_states[0] is the embedding output; hidden_states[i + 1] is the output of layer i.
        # Drop the condition tokens so only content tokens remain.
        content_hidden_states = [h[:, n_ctx:, :] for h in vjepa_outputs.hidden_states]

        # 6. feat_list: project selected layers and lay them out as [B, T_p, C_feat, H_p, W_p].
        feat_list = []
        for layer_idx, proj_layer in zip(self.vjepa_layer_indices, self.feat_list_projectors):
            projected = proj_layer(content_hidden_states[layer_idx + 1])  # [B, N_content, C_feat]
            grid = self._tokens_to_grid(projected, T_p, H_p, W_p)  # [B, C_feat, T_p, H_p, W_p]
            feat_list.append(grid.permute(0, 2, 1, 3, 4))  # v84 expects T before C

        # 7. model_output: last layer back to the patch grid, then upsample to latent resolution.
        last_grid = self._tokens_to_grid(content_hidden_states[-1], T_p, H_p, W_p)
        model_output = self.out_proj(last_grid)

        # Only the T/H/W extent must match x; the channel count is `out_channels` by construction,
        # so comparing full shapes would trigger a pointless resize whenever out_channels != in_channels.
        if model_output.shape[2:] != x.shape[2:]:
            model_output = nn.functional.interpolate(
                model_output, size=x.shape[2:], mode='trilinear', align_corners=False
            )

        return model_output, feat_list


class VJEPA2_Direct_Wrapper(nn.Module):
    """
    VJEPA2 feature extractor that consumes raw videos directly.

    Bypasses the VAE and diffusion process. Raw frames go through the checkpoint's own
    AutoVideoProcessor and into a frozen VJEPA2 model. The hidden states of
    ``num_output_features`` evenly spaced transformer layers are returned (e.g. for a physics
    loss). All returned features share the VJEPA2 hidden width, matching the uniform-dimension
    requirement of LatentVisualDiffusionv84_VJEPA.

    Input:
        - x (torch.Tensor): Raw video frames, [B, C, T, H, W] (C=3)
        - timesteps / conditioning (``t``, ``timesteps``, ``cond``, ``context``) are accepted for
          interface compatibility but are not used.

    Output:
        - tuple[torch.Tensor, list[torch.Tensor]]:
            - model_output (torch.Tensor): hidden states of the last VJEPA2 layer, [B, N_patches, D].
            - feat_list (list[torch.Tensor]): one tensor per selected layer, each [B, T_p, D, H_p, W_p].
    """
    def __init__(
        self,
        context_dim: int,
        vjepa_model_size: str = 'giant',
        vjepa_checkpoint_path: str = None,
        num_output_features: int = 16,
        temporal_length: int = 16,
        image_cross_attention: bool = True,
        **kwargs, # Accept and ignore unused arguments from the config
    ):
        super().__init__()
        self.context_dim = context_dim
        self.num_output_features = num_output_features
        self.temporal_length = temporal_length
        self.image_cross_attention = image_cross_attention

        # 1. Load the backbone and its matching preprocessor. An explicit checkpoint path takes
        #    precedence over the hub repo, so report the path that is actually loaded.
        hf_repo = self._get_hf_repo(vjepa_model_size)
        model_path = vjepa_checkpoint_path if vjepa_checkpoint_path else hf_repo
        print(f"Loading VJEPA2 model for direct video processing from: {model_path}")
        try:
            self.vjepa = AutoModel.from_pretrained(model_path, output_hidden_states=True)
            self.processor = AutoVideoProcessor.from_pretrained(model_path)
        except Exception as e:
            raise IOError(f"Failed to load VJEPA2 model or processor: {e}") from e

        self.vjepa_hidden_dim = self.vjepa.config.hidden_size

        # Freeze all backbone parameters.
        for param in self.vjepa.parameters():
            param.requires_grad = False

        # 2. Context projection. Kept for config/checkpoint compatibility; forward() does not use it.
        self.context_proj = nn.Linear(context_dim, self.vjepa_hidden_dim)

        # 3. Choose `num_output_features` layers, evenly spaced from the first to the last layer.
        #    Their hidden states are returned as-is (no per-feature projection).
        num_vjepa_layers = len(self.vjepa.encoder.layer)
        self.vjepa_layer_indices = torch.linspace(
            0, num_vjepa_layers - 1, self.num_output_features
        ).long().tolist()

        print(f"VJEPA2 Wrapper will output {self.num_output_features} features.")
        print(f"Selected VJEPA layer indices: {self.vjepa_layer_indices}")

        # No projectors: all features keep the uniform VJEPA2 hidden width.
        self.feat_list_projectors = None

    def _get_hf_repo(self, model_size: str) -> str:
        """Build the HuggingFace repository name for a model size ('large', 'huge' or 'giant').

        Raises:
            ValueError: if ``model_size`` is not one of the supported sizes.
        """
        size_map = {'large': 'vitl', 'huge': 'vith', 'giant': 'vitg'}
        if model_size not in size_map:
            raise ValueError(f"Unsupported VJEPA2 model size: {model_size}.")
        return f"facebook/vjepa2-{size_map[model_size]}-fpc64-256"

    @staticmethod
    def _infer_patch_grid(num_patches, height, width, patch_size):
        """Return (T_p, H_p, W_p) for ``num_patches`` tokens of a ``height`` x ``width`` video.

        The spatial grid comes from the patch size; the temporal patch count is whatever is left.

        Raises:
            ValueError: if the token count is not a whole multiple of the spatial patch count.
        """
        H_p = height // patch_size
        W_p = width // patch_size
        spatial = H_p * W_p
        if spatial == 0 or num_patches % spatial != 0:
            raise ValueError(f"Total patches {num_patches} not divisible by spatial patches {H_p}x{W_p}")
        return num_patches // spatial, H_p, W_p

    @staticmethod
    def _tokens_to_grid(tokens, T_p, H_p, W_p):
        """Fold tokens [B, N, C] (N == T_p*H_p*W_p, in T, H, W order) into [B, C, T_p, H_p, W_p]."""
        batch, _, channels = tokens.shape
        return tokens.transpose(1, 2).reshape(batch, channels, T_p, H_p, W_p)

    def forward(self, x, t=None, timesteps=None, cond=None, context=None, **kwargs):
        """Extract VJEPA2 features from raw video ``x`` [B, C, T, H, W].

        Timestep and conditioning arguments are accepted under several names for interface
        compatibility and are ignored.

        Returns:
            (last-layer hidden states [B, N_patches, D], list of [B, T_p, D, H_p, W_p] features).
        """
        # The processor is fed one [T, C, H, W] tensor per video.
        videos = list(x.permute(0, 2, 1, 3, 4))

        # 1. Preprocess with the checkpoint's own video processor.
        # NOTE(review): confirm the processor accepts tensors on x.device and which pixel range it
        # expects; if it converts to numpy/PIL internally, gradients will not flow back to x.
        inputs = self.processor(videos, return_tensors="pt").to(x.device)

        # 2. Run the frozen backbone. hidden_states[0] is the patch embedding;
        #    hidden_states[1:] are the per-layer transformer outputs.
        vjepa_outputs = self.vjepa(**inputs, output_hidden_states=True)
        layer_outputs = vjepa_outputs.hidden_states[1:]
        # Hoisted out of the loop: this is also defined when no layers are selected.
        last_hidden = layer_outputs[-1]  # [B, N_patches, D]

        # 3. Recover the patch grid from the processed resolution and the actual token count.
        _, _, _, proc_h, proc_w = inputs['pixel_values_videos'].shape
        T_p, H_p, W_p = self._infer_patch_grid(
            layer_outputs[0].shape[1], proc_h, proc_w, self.vjepa.config.patch_size
        )

        # 4. Reshape each selected layer to [B, T_p, D, H_p, W_p].
        feat_list = [
            self._tokens_to_grid(layer_outputs[layer_idx], T_p, H_p, W_p).permute(0, 2, 1, 3, 4)
            for layer_idx in self.vjepa_layer_indices
        ]

        return last_hidden, feat_list