from dataclasses import dataclass, field
from typing import Tuple
import numpy as np
import torch 
import tree
from huggingface_hub import snapshot_download
from huggingface_hub.errors import HFValidationError, RepositoryNotFoundError
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.feature_extraction_utils import BatchFeature
from vggt.models.vggt import VGGT
from ultralytics import YOLO
from vggt.utils.pose_enc import pose_encoding_to_extri_intri
from PIL import Image
import torchvision.transforms.functional as TF
from collections import defaultdict
import hdbscan
from tqdm import tqdm 
from torch import nn
from torch.nn.utils.rnn import pad_sequence
import math
import torchvision.ops as ops
import time
from .action_head.flow_matching_action_head import (
    FlowmatchingActionHead,
    FlowmatchingActionHeadConfig,
)
from .backbone import EagleBackbone
# Keys looked up in the BatchFeature objects exchanged between the backbone,
# the action head, and the validation helpers of SOMA.
BACKBONE_FEATURE_KEY = "backbone_features"
ACTION_KEY = "action_pred"
LOSS_KEY = "loss"

# Common header prepended to every validation error message.
ERROR_MSG = "Error: unexpected input/output"

# Size that SOMA.validate_inputs requires on axis 3 of the "video" input.
N_COLOR_CHANNELS = 3
@dataclass
class GR00T_N1_5_Config(PretrainedConfig):
    """Configuration object for the ``gr00t_n1_5`` model type.

    The constructor forwards every keyword argument to
    ``PretrainedConfig.__init__`` and afterwards also sets each of them as an
    attribute on the instance, so keys such as ``backbone_cfg`` or
    ``action_head_cfg`` end up on the config even though they are declared
    with ``init=False`` below.
    """

    model_type = "gr00t_n1_5"
    backbone_cfg: dict = field(init=False, metadata={"help": "Backbone configuration."})
    action_head_cfg: dict = field(init=False, metadata={"help": "Action head configuration."})
    action_horizon: int = field(init=False, metadata={"help": "Action horizon."})
    action_dim: int = field(init=False, metadata={"help": "Action dimension."})
    compute_dtype: str = field(default="float32", metadata={"help": "Compute dtype."})

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Mirror every kwarg onto the instance (in the order given), including
        # keys that PretrainedConfig.__init__ did not store itself.
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])
class SOMA(PreTrainedModel):
    """GR00T N1.5 policy (Eagle backbone + flow-matching action head) extended
    with an object-level spatial memory.

    Expected contracts:
      * the backbone output is a ``BatchFeature`` with key
        ``'backbone_features'`` of shape (batch_size, n, hidden_size), where n
        is variable and can be e.g. time, 1 or user specified;
      * the action head output is a ``BatchFeature`` with key
        ``'action_pred'`` of shape (batch_size, time, action_dim) at inference
        time (``'loss'`` is accepted instead during training);
      * both may carry any number of additional user-specified keys.

    The spatial memory (``inputs['overview']`` / ``inputs['current_obs']``)
    is encoded by ``process_spatial_memory_v3`` and passed, together with the
    backbone features, through ``boosting_transformer``. Its output replaces
    ``'backbone_features'`` before the action head runs.
    """

    supports_gradient_checkpointing = True
    config_class = GR00T_N1_5_Config

    def __init__(
        self,
        config: GR00T_N1_5_Config,
        local_model_path: str,
    ):
        assert isinstance(config.backbone_cfg, dict)
        assert isinstance(config.action_head_cfg, dict)
        super().__init__(config)
        self.local_model_path = local_model_path
        self.backbone = EagleBackbone(**config.backbone_cfg)
        action_head_cfg = FlowmatchingActionHeadConfig(**config.action_head_cfg)
        self.action_head = FlowmatchingActionHead(action_head_cfg)
        self.action_horizon = config.action_horizon
        self.action_dim = config.action_dim
        self.compute_dtype = config.compute_dtype
        # Placeholders kept for attribute compatibility; not used in this class.
        self.n_layer = None
        self.perception_head = None
        self.geometry_head = None
        self.det_head = None

        # Width of every per-label memory vector.
        self.memory_mlp_hidden_dim = 384
        hidden = self.memory_mlp_hidden_dim
        # All MLPs below share one layout (see _make_memory_mlp). Layer order
        # is identical to the original hand-written nn.Sequential definitions,
        # so state_dict keys such as "bbox_embedding_mlp.1.weight" are unchanged.
        # 8 box corners x 3 coordinates, flattened -> memory vector.
        self.bbox_embedding_mlp = self._make_memory_mlp([8 * 3, hidden * 2, hidden * 2, hidden])
        self.spatial_pos_mlp = self._make_memory_mlp([hidden, hidden * 2, hidden * 2, hidden])
        self.memory_mlp = self._make_memory_mlp([hidden, hidden * 2, hidden])
        # Input is [memory, new, memory - new] concatenated.
        self.sim_mlp = self._make_memory_mlp([hidden * 3, hidden * 2, hidden])
        # Input is [memory, new] concatenated.
        self.fusion_mlp = self._make_memory_mlp([hidden * 2, hidden, hidden])
        from .action_head.cross_attention_dit import SelfAttentionTransformerMaskQuery
        attn_cfg = {'attention_head_dim': 64,
                    'dropout': 0.2,
                    'final_dropout': True,
                    'num_attention_heads': 32,
                    'num_layers': 3,
                    'positional_embeddings': None}
        self.boosting_transformer = SelfAttentionTransformerMaskQuery(**attn_cfg)
        # Projects memory vectors to the width the action head expects.
        self.memory_alignment_mlp = self._make_memory_mlp(
            [hidden, hidden * 2, action_head_cfg.backbone_embedding_dim]
        )

    @staticmethod
    def _make_memory_mlp(dims):
        """Build LayerNorm(dims[0]) -> (Linear -> GELU) per consecutive dim pair -> LayerNorm(dims[-1])."""
        layers = [nn.LayerNorm(dims[0])]
        for in_dim, out_dim in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(in_dim, out_dim))
            layers.append(nn.GELU())
        layers.append(nn.LayerNorm(dims[-1]))
        return torch.nn.Sequential(*layers)

    def init_weights(self):
        """Re-initialize the spatial-memory modules added on top of GR00T.

        Only the memory MLPs and ``boosting_transformer`` are touched; the
        backbone and the action head keep their current weights.
        """
        for module in (
            self.bbox_embedding_mlp,
            self.spatial_pos_mlp,
            self.memory_mlp,
            self.sim_mlp,
            self.fusion_mlp,
            self.memory_alignment_mlp,
            self.boosting_transformer,
        ):
            module.apply(self._init_layer_weight)

    def _init_layer_weight(self, m):
        """Initialize one submodule; meant to be used with ``nn.Module.apply``.

        Linear: mirrors ``nn.Linear.reset_parameters`` (Kaiming-uniform weight
        with a=sqrt(5), bias ~ U(-1/sqrt(fan_in), 1/sqrt(fan_in))).
        Embedding: N(0, 0.02). LayerNorm: weight=1, bias=0.
        Any other module type is left untouched.
        """
        if isinstance(m, nn.Linear):
            torch.nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
            if m.bias is not None:
                fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(m.weight)
                # Guard against zero-width layers, as nn.Linear itself does.
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                torch.nn.init.uniform_(m.bias, -bound, bound)
        elif isinstance(m, torch.nn.Embedding):
            torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
        elif isinstance(m, torch.nn.LayerNorm):
            # LayerNorm(elementwise_affine=False) / bias=False stores None here.
            if m.weight is not None:
                torch.nn.init.ones_(m.weight)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)

    def embed_bbox_3d(self, bbox_3d_corners):
        """Embed 3D bounding boxes given by their 8 corners.

        Args:
            bbox_3d_corners: torch.Tensor or np.ndarray of shape (N, 8, 3).
                NumPy input is converted to float32 on ``self.device``;
                tensor input is used as-is.
        Returns:
            torch.Tensor of shape (N, memory_mlp_hidden_dim).
        Raises:
            TypeError: if the input is neither a tensor nor an ndarray.
        """
        if isinstance(bbox_3d_corners, np.ndarray):
            bbox_3d_corners = torch.from_numpy(bbox_3d_corners).float().to(self.device)
        elif not isinstance(bbox_3d_corners, torch.Tensor):
            raise TypeError("bbox_3d_corners must be torch.Tensor or np.ndarray")
        N = bbox_3d_corners.shape[0]
        x = bbox_3d_corners.reshape(N, 8 * 3)
        return self.bbox_embedding_mlp(x)

    def _to_memory_tensor(self, memory):
        """Return ``memory`` as a tensor; NumPy arrays become float32 on ``self.device``."""
        if isinstance(memory, np.ndarray):
            return torch.from_numpy(memory).float().to(self.device)
        if not isinstance(memory, torch.Tensor):
            raise TypeError("overview_memory must be torch.Tensor or np.ndarray")
        return memory

    def _encode_observation(self, obs):
        """Encode one observation entry into per-label memory vectors.

        Args:
            obs: dict with keys 'overview_memory' (one row per label; row width
                must match memory_mlp's input, memory_mlp_hidden_dim),
                'overview_labels', 'bbox_3d_labels' (one label per box) and
                'bbox_3d_corners' (one (8, 3) corner set per box).
        Returns:
            (memory, labels): ``memory_mlp(overview_memory)`` plus a spatial
            position term per label, and ``obs['overview_labels']`` unchanged.
        """
        memory = self._to_memory_tensor(obs['overview_memory'])
        labels = obs['overview_labels']
        bbox_labels = np.array(obs['bbox_3d_labels'])
        bbox_embeddings = self.embed_bbox_3d(np.array(obs['bbox_3d_corners']))
        grouped = []
        for label in labels:
            mask = bbox_labels == label
            # np.any also copes with numpy returning a plain bool when the
            # element-wise comparison is not possible.
            if np.any(mask):
                # Average the embeddings of all boxes that carry this label.
                grouped.append(bbox_embeddings[mask].mean(dim=0))
            else:
                # No box for this label: random placeholder (non-deterministic).
                grouped.append(torch.empty_like(memory[0]).normal_())
        spatial_pos = self.spatial_pos_mlp(torch.stack(grouped, dim=0))
        return self.memory_mlp(memory) + spatial_pos, labels

    def _fuse_memories(self, overview_memory, overview_labels, current_memory, current_labels):
        """Merge the overview and current per-label memories of one sample.

        A label present in both is blended as ``alpha * new + (1 - alpha) * mem``
        with ``alpha = sigmoid(fusion_mlp([mem, new])) * sigmoid(sim_mlp([mem, new, mem - new]))``.
        Rows come first for every current label (in order), then for overview
        labels absent from the current observation.

        Returns:
            (stacked memory tensor, list of labels in the same row order)
        """
        overview_label_list = list(overview_labels)
        fused, fused_labels = [], []
        for j, label in enumerate(current_labels):
            new_vec = current_memory[j]
            if label in overview_labels:
                mem_vec = overview_memory[overview_label_list.index(label)]
                sim = torch.sigmoid(self.sim_mlp(torch.cat([mem_vec, new_vec, mem_vec - new_vec], dim=-1)))
                fusion_gate = torch.sigmoid(self.fusion_mlp(torch.cat([mem_vec, new_vec], dim=-1)))
                alpha = fusion_gate * sim
                fused.append(alpha * new_vec + (1 - alpha) * mem_vec)
            else:
                fused.append(new_vec)
            fused_labels.append(label)
        for j, label in enumerate(overview_labels):
            if label not in current_labels:
                fused.append(overview_memory[j])
                fused_labels.append(label)
        return torch.stack(fused), fused_labels

    def process_spatial_memory_v3(self, input):
        """Build the fused per-label spatial memory for every sample in the batch.

        Args:
            input: mapping whose 'overview' entry is a sequence of observation
                dicts (see ``_encode_observation``); 'current_obs' is indexed
                at the same positions.
        Returns:
            (final_obs_memory, final_obs_labels): one entry per sample each, a
            (num_labels_i, memory_mlp_hidden_dim) tensor and the list of
            labels describing its rows, in the same order.
        """
        final_obs_memory = []
        final_obs_labels = []
        for i, overview_obs in enumerate(input['overview']):
            overview_memory, overview_labels = self._encode_observation(overview_obs)
            current_memory, current_labels = self._encode_observation(input['current_obs'][i])
            per_final_obs_memory, per_final_obs_label = self._fuse_memories(
                overview_memory, overview_labels, current_memory, current_labels
            )
            final_obs_memory.append(per_final_obs_memory)
            # Previously never appended, so the labels were always returned empty.
            final_obs_labels.append(per_final_obs_label)
        return final_obs_memory, final_obs_labels

    def filtering_keys(self, input):
        """Return ``input`` without the spatial-memory keys 'overview' and 'current_obs'.

        These keys are consumed by ``process_spatial_memory_v3`` and are
        stripped so they are not passed on to ``prepare_input``. A shallow copy
        is filtered, so the caller's mapping is left untouched.
        """
        import copy

        filtered = copy.copy(input)
        for key in ('overview', 'current_obs'):
            if key in filtered:
                del filtered[key]
        return filtered

    def validate_inputs(self, inputs):
        """Check the optional 'action' and 'video' entries of ``inputs``.

        'action' must be a torch.Tensor of shape (B, action_horizon, action_dim).
        'video' must be a uint8 np.ndarray with 6 dims and size
        N_COLOR_CHANNELS on axis 3.

        Raises:
            ValueError: listing every problem that was found.
        """
        detected_error = False
        error_msg = ERROR_MSG
        if "action" in inputs:
            action = inputs["action"]
            # getattr: a non-tensor action must yield a ValueError, not an AttributeError.
            action_shape = getattr(action, "shape", None)
            if not isinstance(action, torch.Tensor):
                error_msg += f"\n{type(action)=}"
                detected_error = True
            shape_ok = (
                action_shape is not None
                and len(action_shape) == 3
                and action_shape[1] == self.action_horizon
                and action_shape[2] == self.action_dim
            )
            if not shape_ok:
                error_msg += f"\naction.shape={action_shape!r}"
                detected_error = True
        if "video" in inputs:
            video = inputs["video"]
            video_dtype = getattr(video, "dtype", None)
            video_shape = getattr(video, "shape", None)
            if not isinstance(video, np.ndarray):
                error_msg += f"\n{type(video)=}"
                detected_error = True
            if not (video_dtype is not None and video_dtype == np.uint8):
                error_msg += f"\nvideo.dtype={video_dtype!r}"
                detected_error = True
            shape_ok = (
                video_shape is not None
                and len(video_shape) == 6
                and video_shape[3] == N_COLOR_CHANNELS
            )
            if not shape_ok:
                error_msg += f"\nvideo.shape={video_shape!r}"
                detected_error = True
        if detected_error:
            raise ValueError(error_msg)

    def validate_data(self, action_head_outputs, backbone_outputs, is_training):
        """Check that backbone and action-head outputs follow the expected contract.

        Raises:
            ValueError: if ``backbone_outputs`` is not a BatchFeature holding
                BACKBONE_FEATURE_KEY, or if ``action_head_outputs`` is not a
                BatchFeature holding either LOSS_KEY (training only) or an
                ACTION_KEY tensor of shape (B, action_horizon, action_dim, ...).
        """
        fail_backbone = (
            not isinstance(backbone_outputs, BatchFeature)
            or BACKBONE_FEATURE_KEY not in backbone_outputs
        )
        if fail_backbone:
            error_msg = ERROR_MSG
            error_msg += f"\n{isinstance(backbone_outputs, BatchFeature)=}"
            error_msg += f"\n{BACKBONE_FEATURE_KEY in backbone_outputs=}"
            # Only report the shape when the key exists; indexing a missing key
            # used to raise KeyError and hide this ValueError.
            if BACKBONE_FEATURE_KEY in backbone_outputs:
                feature_shape = getattr(backbone_outputs[BACKBONE_FEATURE_KEY], "shape", None)
                error_msg += f"\nbackbone_outputs[BACKBONE_FEATURE_KEY].shape={feature_shape!r}"
            raise ValueError(error_msg)

        action_pred = action_head_outputs[ACTION_KEY] if ACTION_KEY in action_head_outputs else None
        action_shape = getattr(action_pred, "shape", None)
        action_ok = (
            action_shape is not None
            and len(action_shape) >= 3
            and action_shape[1] == self.action_horizon
            and action_shape[2] == self.action_dim
        )
        fail_action_head = (not isinstance(action_head_outputs, BatchFeature)) or not (
            (LOSS_KEY in action_head_outputs and is_training) or action_ok
        )
        if fail_action_head:
            error_msg = ERROR_MSG
            error_msg += f"\n{isinstance(action_head_outputs, BatchFeature)=}"
            error_msg += f"\n{LOSS_KEY in action_head_outputs=}"
            error_msg += f"\naction_head_outputs[ACTION_KEY].shape={action_shape!r}"
            error_msg += f"\n{self.action_horizon=}"
            error_msg += f"\n{self.action_dim=}"
            raise ValueError(error_msg)

    def get_opengl_conversion_matrix(self) -> np.ndarray:
        """
        Constructs and returns the OpenGL conversion matrix.
        Returns:
            numpy.ndarray: A 4x4 identity matrix with the y and z axes negated.
        """
        matrix = np.identity(4)
        matrix[1, 1] = -1
        matrix[2, 2] = -1
        return matrix

    def boost_backbone_outputs_spatial(self, backbone_outputs, memories):
        """Inject per-sample spatial memories into the backbone features.

        Args:
            backbone_outputs: BatchFeature whose 'backbone_features' entry is
                presumably (B, T, D) -- TODO confirm against EagleBackbone.
            memories: list with one tensor per sample, each
                (L_i, memory_mlp_hidden_dim); the L_i may differ.
        Returns:
            ``backbone_outputs`` with 'backbone_features' replaced (in place)
            by the output of ``boosting_transformer``.
        """
        backbone_features = backbone_outputs['backbone_features']  # (B, T, D)
        # Pad the variable-length memories to (B, max_L, hidden) and build a
        # boolean mask that is True on real (non-padded) rows.
        total_memory = pad_sequence(memories, batch_first=True)
        lengths = torch.tensor([m.size(0) for m in memories], device=total_memory.device)
        max_len = total_memory.size(1)
        memory_mask = torch.arange(max_len, device=total_memory.device)[None, :] < lengths[:, None]
        # Project memories to the backbone embedding width.
        total_memory_expand = self.memory_alignment_mlp(total_memory)
        boosted_backbone_features = self.boosting_transformer(
            backbone_features,
            attention_mask=None,
            encoder_hidden_states=total_memory_expand,
            encoder_attention_mask=memory_mask)
        backbone_outputs['backbone_features'] = boosted_backbone_features
        return backbone_outputs

    def forward(
        self,
        inputs: dict,
    ) -> BatchFeature:
        """Training forward pass.

        Requires ``inputs['overview']`` and ``inputs['current_obs']``. They are
        encoded into spatial memories, stripped from (a copy of) the inputs,
        and used to boost the backbone features before the action head runs.
        """
        final_obs_memory, _final_obs_labels = self.process_spatial_memory_v3(inputs)
        inputs = self.filtering_keys(inputs)
        backbone_inputs, action_inputs = self.prepare_input(inputs)
        backbone_outputs = self.backbone(backbone_inputs)
        backbone_outputs = self.boost_backbone_outputs_spatial(backbone_outputs, final_obs_memory)
        action_head_outputs = self.action_head(backbone_outputs, action_inputs)
        self.validate_data(action_head_outputs, backbone_outputs, is_training=True)
        return action_head_outputs

    def get_action(
        self,
        inputs: dict,
    ) -> BatchFeature:
        """Inference pass returning the action head's ``get_action`` output.

        When both 'overview' and 'current_obs' are present, the same spatial
        memory boosting as in ``forward`` is applied, so inference sees the
        features the action head was trained on. Without those keys the
        backbone features are used unmodified.
        """
        memories = None
        if "overview" in inputs and "current_obs" in inputs:
            memories, _memory_labels = self.process_spatial_memory_v3(inputs)
            inputs = self.filtering_keys(inputs)
        backbone_inputs, action_inputs = self.prepare_input(inputs)
        backbone_outputs = self.backbone(backbone_inputs)
        if memories is not None:
            backbone_outputs = self.boost_backbone_outputs_spatial(backbone_outputs, memories)
        action_head_outputs = self.action_head.get_action(backbone_outputs, action_inputs)
        self.validate_data(action_head_outputs, backbone_outputs, is_training=False)
        return action_head_outputs

    def prepare_input(self, inputs) -> Tuple[BatchFeature, BatchFeature]:
        """Validate ``inputs`` and build the backbone and action-head inputs.

        Every leaf (via ``tree.map_structure``) is converted from NumPy to a
        tensor if needed and moved to ``self.device``; floating-point leaves
        are also cast to the action head's dtype.
        """
        self.validate_inputs(inputs)
        backbone_inputs = self.backbone.prepare_input(inputs)
        action_inputs = self.action_head.prepare_input(inputs)

        def to_device_with_maybe_dtype(x):
            if isinstance(x, np.ndarray):
                x = torch.from_numpy(x)
            if torch.is_floating_point(x):
                return x.to(self.device, dtype=self.action_head.dtype)
            else:
                return x.to(self.device)

        backbone_inputs = tree.map_structure(to_device_with_maybe_dtype, backbone_inputs)
        action_inputs = tree.map_structure(to_device_with_maybe_dtype, action_inputs)
        return backbone_inputs, action_inputs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """Load a SOMA model from the Hugging Face Hub or from a local path.

        Keyword arguments consumed here (all others are forwarded to
        ``PreTrainedModel.from_pretrained``):
            tune_visual (default True), tune_llm (default False):
                forwarded to ``backbone.set_trainable_parameters``.
            tune_projector (default True), tune_diffusion_model (default True):
                forwarded to ``action_head.set_trainable_parameters``.
            reinit_memory_modules (default True): call ``init_weights()`` after
                loading. This re-initializes the spatial-memory MLPs and
                ``boosting_transformer``, overwriting any weights for them that
                were loaded from the checkpoint. Pass False when resuming from
                a checkpoint that already contains trained memory modules.

        If ``snapshot_download`` rejects the name or cannot find the repo,
        ``pretrained_model_name_or_path`` is used as a local path.
        """
        tune_visual = kwargs.pop("tune_visual", True)
        tune_llm = kwargs.pop("tune_llm", False)
        tune_projector = kwargs.pop("tune_projector", True)
        tune_diffusion_model = kwargs.pop("tune_diffusion_model", True)
        reinit_memory_modules = kwargs.pop("reinit_memory_modules", True)
        print(f"Loading pretrained dual brain from {pretrained_model_name_or_path}")
        print(f"Tune backbone vision tower: {tune_visual}")
        print(f"Tune backbone LLM: {tune_llm}")
        print(f"Tune action head projector: {tune_projector}")
        print(f"Tune action head DiT: {tune_diffusion_model}")
        try:
            local_model_path = snapshot_download(pretrained_model_name_or_path, repo_type="model")
        except (HFValidationError, RepositoryNotFoundError):
            print(
                f"Model not found or avail in the huggingface hub. Loading from local path: {pretrained_model_name_or_path}"
            )
            local_model_path = pretrained_model_name_or_path
        pretrained_model = super().from_pretrained(
            local_model_path, local_model_path=local_model_path, **kwargs
        )
        pretrained_model.backbone.set_trainable_parameters(
            tune_visual=tune_visual, tune_llm=tune_llm
        )
        pretrained_model.action_head.set_trainable_parameters(
            tune_projector=tune_projector, tune_diffusion_model=tune_diffusion_model
        )
        if reinit_memory_modules:
            pretrained_model.init_weights()
        return pretrained_model
# Import-time side effect: register the config and model classes with the
# transformers Auto* factories. The string passed to AutoConfig.register must
# equal GR00T_N1_5_Config.model_type ("gr00t_n1_5").
AutoConfig.register("gr00t_n1_5", GR00T_N1_5_Config)
AutoModel.register(GR00T_N1_5_Config, SOMA)
