"""
DINOv2 视觉编码器
使用预训练的 DINOv2 作为冻结的 backbone，处理多视角图像输入
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple, Sequence
import torchvision.transforms as tvt

# Standard ImageNet per-channel statistics in (R, G, B) order; consumed by the
# Normalize step of the encoder's preprocessing pipeline.
IMAGENET_DEFAULT_MEAN: Tuple[float, float, float] = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD: Tuple[float, float, float] = (0.229, 0.224, 0.225)

class DINOv2MultiViewEncoder(nn.Module):
    """
    Multi-view visual encoder built on a frozen DINOv2 backbone.

    Features:
    - Multiple camera views per sample.
    - Framestack via channel stacking: each view carries ``3 * framestack``
      channels laid out frame-major (``[f0_R, f0_G, f0_B, f1_R, ...]``).
    - The DINOv2 backbone stays frozen (no gradients, always in eval mode).
    - A trainable MLP maps the concatenated per-frame view features to
      ``visual_feature_dim``.
    - Every frame is encoded independently and the per-frame embeddings are
      concatenated.

    Args:
        num_views: Number of camera views ``V``.
        visual_feature_dim: Per-frame feature size after the MLP; the final
            output size is ``visual_feature_dim * framestack``.
        mlp_hidden_dims: Hidden layer sizes of the MLP.
        model_type: DINOv2 torch.hub entry point, default ``'dinov2_vits14'``.
        use_cls_token: If True use the CLS token as the image feature,
            otherwise average-pool the patch tokens.
        dropout: Dropout rate inside the MLP (no dropout layers when 0).
        framestack: Number of stacked frames (1 means no stacking).
    """

    def __init__(
        self,
        num_views: int = 1,
        visual_feature_dim: int = 64,
        mlp_hidden_dims: Sequence[int] = (256, 64),
        model_type: str = 'dinov2_vits14',
        use_cls_token: bool = True,
        dropout: float = 0.0,
        framestack: int = 1
    ):
        super().__init__()

        self.num_views = num_views
        self.visual_feature_dim = visual_feature_dim
        self.use_cls_token = use_cls_token
        self.framestack = framestack

        # Load the pretrained DINOv2 model (needs torch.hub access or cache).
        print(f"Loading {model_type} from torch.hub...")
        self.dinov2_backbone = torch.hub.load('facebookresearch/dinov2', model_type)

        # Freeze the backbone; `train` below keeps it frozen afterwards too.
        for param in self.dinov2_backbone.parameters():
            param.requires_grad = False
        self.dinov2_backbone.eval()

        # Backbone feature size, for reference:
        # dinov2_vits14: 384, dinov2_vitb14: 768,
        # dinov2_vitl14: 1024, dinov2_vitg14: 1536
        self.dinov2_output_dim = self._get_dinov2_output_dim()

        # All views of one frame are concatenated before the MLP.
        concatenated_visual_dim = self.dinov2_output_dim * num_views

        # Preprocessing: uint8 -> float in [0, 1], resize to the 224x224
        # resolution used here for DINOv2, then ImageNet normalization.
        self.transform = tvt.Compose([
            tvt.ConvertImageDtype(torch.float32),
            tvt.Resize((224, 224)),
            tvt.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
        ])

        # MLP applied to the concatenated view features of a single frame.
        mlp_layers = []
        input_dim = concatenated_visual_dim
        for hidden_dim in mlp_hidden_dims:
            mlp_layers.append(nn.Linear(input_dim, hidden_dim))
            mlp_layers.append(nn.ReLU())
            if dropout > 0:
                mlp_layers.append(nn.Dropout(dropout))
            input_dim = hidden_dim
        # Final projection to the requested per-frame feature size.
        mlp_layers.append(nn.Linear(input_dim, visual_feature_dim))
        self.visual_mlp = nn.Sequential(*mlp_layers)

        # Final output size: one visual_feature_dim vector per stacked frame.
        self.output_dim = visual_feature_dim * framestack

        print("DINOv2 Encoder initialized:")
        print(f"  - Model type: {model_type}")
        print(f"  - Number of views: {num_views}")
        print(f"  - Framestack: {framestack}")
        print(f"  - DINOv2 output dim: {self.dinov2_output_dim}")
        print(f"  - Concatenated visual dim (per frame): {concatenated_visual_dim}")
        print(f"  - Visual feature dim (per frame after MLP): {visual_feature_dim}")
        print(f"  - Total output dim: {self.output_dim} (= {visual_feature_dim} * {framestack})")

    def _get_dinov2_output_dim(self) -> int:
        """
        Return the per-image feature size produced by the backbone.

        Uses the backbone's integer ``embed_dim`` attribute when present;
        otherwise probes it with a dummy forward pass on the device the
        backbone already lives on. The backbone is never moved here, so it
        stays on the same device as the MLP until the caller moves the whole
        module (e.g. with ``.to(device)``).
        """
        embed_dim = getattr(self.dinov2_backbone, "embed_dim", None)
        if isinstance(embed_dim, int):
            return embed_dim

        first_param = next(self.dinov2_backbone.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
        with torch.no_grad():
            dummy_input = torch.zeros(1, 3, 224, 224, device=device)
            # CLS and patch tokens share the same last dimension, so the
            # pooling mode does not change the result.
            return int(self.dinov2_backbone(dummy_input).shape[-1])

    @torch.no_grad()
    def _extract_dinov2_features(self, images: torch.Tensor) -> torch.Tensor:
        """
        Encode a batch of single images with the frozen backbone.

        Args:
            images: (N, 3, H, W) images. uint8 inputs are rescaled to [0, 1]
                by the dtype conversion; float inputs are only cast, so they
                are presumably already in [0, 1] (TODO confirm with callers).

        Returns:
            (N, dinov2_output_dim) features: the CLS token when
            ``use_cls_token`` is True, otherwise the mean of the patch tokens.
        """
        self.dinov2_backbone.eval()  # guard against external .train() calls
        images = self.transform(images)

        if self.use_cls_token or not hasattr(self.dinov2_backbone, "forward_features"):
            features = self.dinov2_backbone(images)
        else:
            # DINOv2's forward() returns only the CLS token, so patch tokens
            # must be fetched through forward_features().
            out = self.dinov2_backbone.forward_features(images)
            features = out["x_norm_patchtokens"] if isinstance(out, dict) else out

        # Token sequences (N, num_tokens, dim) are reduced to one vector.
        if features.dim() == 3:
            if self.use_cls_token:
                features = features[:, 0]  # CLS token
            else:
                features = features.mean(dim=1)  # global average pooling

        return features

    def forward(
        self,
        images: torch.Tensor,
    ) -> torch.Tensor:
        """
        Encode all views and stacked frames in one batched backbone pass.

        Args:
            images: (B, num_views, 3 * framestack, H, W) tensor whose channel
                axis is frame-major: channels ``3*f : 3*f + 3`` hold frame f.

        Returns:
            encoded_obs: (B, visual_feature_dim * framestack) features, frame
            embeddings concatenated in frame order.

        Raises:
            ValueError: on wrong rank, view count, or channel/framestack size.
        """
        if images.dim() != 5:
            raise ValueError(f"Expected 5D input (B, num_views, C*framestack, H, W), got shape {images.shape}")

        B, V, C_stacked, H, W = images.shape

        if V != self.num_views:
            raise ValueError(f"Expected {self.num_views} views, got {V}")

        # C_stacked = C * framestack with C = 3 (RGB)
        C = 3
        if C_stacked % C != 0:
            raise ValueError(f"Channel dimension {C_stacked} is not divisible by {C}. Expected C*framestack format.")

        actual_framestack = C_stacked // C
        if actual_framestack != self.framestack:
            raise ValueError(f"Expected framestack={self.framestack}, but got {actual_framestack} from input shape")

        F_ = self.framestack
        # reshape (not view) so non-contiguous inputs are accepted.
        images = images.reshape(B, V, F_, C, H, W)
        # (B, V, F, C, H, W) -> (B, F, V, C, H, W) -> (B*F*V, C, H, W)
        images_flat = images.permute(0, 2, 1, 3, 4, 5).reshape(B * F_ * V, C, H, W)

        # Frozen feature extraction (no_grad via the decorator).
        visual_features = self._extract_dinov2_features(images_flat)  # (B*F*V, D)

        # Concatenate views per frame: (B*F, V*D). reshape handles the case
        # where the features are a non-contiguous slice (e.g. features[:, 0]).
        visual_features = visual_features.reshape(B * F_, V * self.dinov2_output_dim)

        frame_embeddings = self.visual_mlp(visual_features)  # (B*F, visual_feature_dim)

        # Concatenate frames: (B, F * visual_feature_dim)
        return frame_embeddings.reshape(B, F_ * self.visual_feature_dim)

    def train(self, mode: bool = True):
        """Set train/eval mode on the trainable parts while keeping DINOv2 frozen."""
        super().train(mode)
        # The backbone always stays in eval mode with gradients disabled.
        self.dinov2_backbone.eval()
        for param in self.dinov2_backbone.parameters():
            param.requires_grad = False
        return self


class DINOv2Encoder(DINOv2MultiViewEncoder):
    """
    Convenience single-view DINOv2 encoder (``num_views=1``).

    Accepts either ``(B, 1, 3 * framestack, H, W)`` or the view-less
    ``(B, 3 * framestack, H, W)``; the latter gets the view axis inserted.
    All constructor arguments match ``DINOv2MultiViewEncoder`` except
    ``num_views``, which is fixed to 1.
    """
    def __init__(
        self,
        visual_feature_dim: int = 64,
        mlp_hidden_dims: Sequence[int] = (256, 64),
        model_type: str = 'dinov2_vits14',
        use_cls_token: bool = True,
        dropout: float = 0.0,
        framestack: int = 1
    ):
        super().__init__(
            num_views=1,
            visual_feature_dim=visual_feature_dim,
            mlp_hidden_dims=mlp_hidden_dims,
            model_type=model_type,
            use_cls_token=use_cls_token,
            dropout=dropout,
            framestack=framestack
        )

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Encode single-view images; 4-D input gets a view axis of size 1."""
        if images.dim() == 4:
            images = images.unsqueeze(1)  # (B, C*F, H, W) -> (B, 1, C*F, H, W)
        return super().forward(images)


# Smoke test (downloads DINOv2 through torch.hub on first run).
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = 4

    print("Testing DINOv2MultiViewEncoder...")

    # Multi-view encoder: 3 views, no framestack.
    encoder = DINOv2MultiViewEncoder(
        num_views=3,
        visual_feature_dim=256,
        mlp_hidden_dims=[512, 256],
        model_type='dinov2_vits14'
    ).to(device)

    # Input layout: (B, num_views, 3 * framestack, H, W)
    images = torch.randn(batch_size, 3, 3, 224, 224, device=device)

    # forward() takes only the images tensor.
    output = encoder(images)
    print(f"Output shape: {output.shape}")
    print(f"Expected shape: ({batch_size}, {encoder.output_dim})")

    assert output.shape == (batch_size, encoder.output_dim)
    print("Test passed!")

    # Framestack: 2 views, 2 frames stacked on the channel axis.
    print("\nTesting DINOv2MultiViewEncoder with framestack=2...")
    stacked_encoder = DINOv2MultiViewEncoder(
        num_views=2,
        visual_feature_dim=32,
        mlp_hidden_dims=[128],
        model_type='dinov2_vits14',
        framestack=2
    ).to(device)
    stacked_images = torch.randn(batch_size, 2, 6, 224, 224, device=device)
    stacked_output = stacked_encoder(stacked_images)
    print(f"Framestack output shape: {stacked_output.shape}")
    assert stacked_output.shape == (batch_size, stacked_encoder.output_dim)
    print("Framestack test passed!")

    # Single-view encoder.
    print("\nTesting DINOv2Encoder (single view)...")
    single_encoder = DINOv2Encoder(
        visual_feature_dim=256,
        model_type='dinov2_vits14'
    ).to(device)

    # Keep the explicit view axis: (B, 1, 3, H, W)
    single_images = torch.randn(batch_size, 1, 3, 224, 224, device=device)
    single_output = single_encoder(single_images)
    print(f"Single view output shape: {single_output.shape}")
    print(f"Expected shape: ({batch_size}, {single_encoder.output_dim})")

    assert single_output.shape == (batch_size, single_encoder.output_dim)
    print("Single view test passed!")
