#!/usr/bin/env python
#
# Mimic experiment implementation for Flow Matching with Hydra

import sys
import os
import time
import collections
from typing import Dict, Any, Tuple, Optional

import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader

# Add paths for imports
sys.path.append('../models')
sys.path.append('../external/robomimic')
sys.path.append('../external/diffusion_policy')

from unet import ConditionalUnet1D
from external.models.TransformerForDiffusion import TransformerForDiffusion
from diffusion_policy.dataset.robomimic_replay_lowdim_dataset import RobomimicReplayLowdimDataset
from diffusion_policy.model.common.normalizer import LinearNormalizer
from diffusion_policy.env.robomimic.robomimic_lowdim_wrapper import RobomimicLowdimWrapper

# Import robomimic utilities
import robomimic.utils.file_utils as FileUtils
import robomimic.utils.env_utils as EnvUtils

# Import parent class and Hydra config manager
sys.path.append('../utils')
from base_experiment import BaseExperiment
from hydra_config import HydraConfigManager

class MimicExperimentHydra(BaseExperiment):
    """Mimic manipulation experiment using Flow Matching with Hydra configuration.

    Provides the experiment-specific hooks used by ``BaseExperiment``:
    the Robomimic low-dim dataset/dataloader, the prediction network
    (UNet or Transformer), per-batch preprocessing, and the evaluation
    environment.
    """

    def __init__(self, config):
        """Wrap the raw Hydra config and initialise the base experiment.

        Args:
            config: Hydra config object. Kept as ``self.hydra_config`` and
                wrapped in ``HydraConfigManager`` as ``self.config``.
        """
        # Keep the raw Hydra config alongside the HydraConfigManager wrapper.
        self.hydra_config = config
        self.config = HydraConfigManager(config)

        # Initialise subclass state *before* the base constructor: if
        # BaseExperiment.__init__ invokes setup_dataset() (not visible from
        # this file), resetting these afterwards would discard the fitted
        # normalizer.
        self.normalizer = None
        self.rotation_transformer = None

        super().__init__(self.config)

    def setup_dataset(self) -> DataLoader:
        """Build the Robomimic low-dim dataset and return a shuffling DataLoader.

        Side effect: stores a ``LinearNormalizer`` holding a copy of the
        dataset normalizer's state on ``self.normalizer``. ``process_batch``
        needs it.

        Returns:
            DataLoader over ``RobomimicReplayLowdimDataset``.
        """
        print("Setting up Robomimic dataset...")

        kwargs = self.config.get_dataset_kwargs()

        # Same construction as in flow_mimic.py; abs_action=True selects the
        # absolute-action variant of the dataset.
        dataset = RobomimicReplayLowdimDataset(
            dataset_path=kwargs['hdf5_path'],
            horizon=kwargs['seq_length'],  # pred_horizon
            abs_action=True,
        )

        # Copy the dataset's normalizer statistics into a fresh LinearNormalizer.
        dataset_normalizer = dataset.get_normalizer()
        self.normalizer = LinearNormalizer()
        self.normalizer.load_state_dict(dataset_normalizer.state_dict())

        print(f"Dataset size: {len(dataset)}")

        num_workers = self.config.num_workers
        # torch's DataLoader raises ValueError for persistent_workers=True
        # with num_workers == 0, so only request it when workers exist.
        persistent_workers = bool(self.config.persistent_workers) and num_workers > 0

        dataloader = DataLoader(
            dataset,
            batch_size=self.config.batch_size,
            num_workers=num_workers,
            shuffle=True,
            pin_memory=self.config.pin_memory,
            persistent_workers=persistent_workers,
        )

        return dataloader

    def setup_model(self) -> nn.Module:
        """Build the prediction network for low-dim observations.

        No vision encoder is built. ``config.model_type == "transformer"``
        selects ``TransformerForDiffusion``; any other value selects
        ``ConditionalUnet1D``.

        Returns:
            ``nn.ModuleDict`` with a single ``'noise_pred_net'`` entry, moved
            to ``self.device``.
        """
        print(f"Setting up Mimic model (type: {self.config.model_type})...")

        # NOTE(review): vision_feature_dim is used as the conditioning size even
        # though no vision encoder exists; presumably it is configured to match
        # the flattened obs_cond from process_batch — confirm in the config.
        if self.config.model_type == "transformer":
            noise_pred_net = TransformerForDiffusion(
                input_dim=self.config.action_dim,
                output_dim=self.config.action_dim,
                horizon=self.config.pred_horizon,
                cond_dim=self.config.vision_feature_dim
            )
        else:  # unet
            noise_pred_net = ConditionalUnet1D(
                input_dim=self.config.action_dim,
                global_cond_dim=self.config.vision_feature_dim
            )

        # For low-dim observations only the prediction network is needed.
        nets = nn.ModuleDict({
            'noise_pred_net': noise_pred_net
        }).to(self.device)

        return nets

    def process_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Normalize a dataloader batch and split it into conditioning and target.

        Args:
            batch: dict containing at least ``'obs'`` and ``'action'`` tensors.

        Returns:
            ``(obs_cond, x_traj)`` on ``self.device``:

            - ``obs_cond``: the first ``obs_horizon`` normalized observation
              steps, flattened from dim 1 onward.
            - ``x_traj``: the normalized action trajectory.

        Raises:
            RuntimeError: if no normalizer exists because ``setup_dataset``
                has not been run.
        """
        if self.normalizer is None:
            raise RuntimeError(
                "Normalizer is not initialised; call setup_dataset() before process_batch()."
            )

        # Normalize the batch using the normalizer (as in flow_mimic.py).
        normalized_batch = self.normalizer.normalize(batch)

        # Assumes obs is (batch, time, feature_dim); keep the first obs_horizon steps.
        x_obs = normalized_batch['obs'][:, :self.config.obs_horizon].to(self.device)
        x_traj = normalized_batch['action'].to(self.device)

        # Flatten the time dimension into the feature dimension for conditioning:
        # (batch, obs_horizon, feature_dim) -> (batch, obs_horizon * feature_dim).
        obs_cond = x_obs.flatten(start_dim=1)

        return obs_cond, x_traj

    def setup_environment(self):
        """Create the Mimic evaluation environment (used for testing).

        Steps:

        1. Read env metadata from the dataset at ``self.config.dataset_path``.
        2. Create the env with robomimic's ``EnvUtils``.
        3. Wrap it in ``RobomimicLowdimWrapper``.

        Returns:
            The ``RobomimicLowdimWrapper`` around the created environment.

        Raises:
            Re-raises any exception from metadata loading, env creation, or
            wrapping, after printing a diagnostic hint.
        """
        print("Setting up Mimic environment...")

        try:
            # Get environment metadata from the dataset.
            env_meta = FileUtils.get_env_metadata_from_dataset(self.config.dataset_path)

            # Create the environment using robomimic utilities.
            env = EnvUtils.create_env_from_metadata(
                env_meta=env_meta,
                **self.config.get_env_kwargs()
            )

            # Wrap for low-dimensional observations.
            wrapper = RobomimicLowdimWrapper(
                env=env,
                **self.config.get_wrapper_kwargs()
            )
        except Exception as e:
            print(f"Error creating environment: {e}")
            print("Environment setup failed. Please check OpenGL/EGL dependencies.")
            # Bare raise re-raises the active exception with its original traceback.
            raise
        else:
            print(f"Environment created successfully: {env_meta['env_name']}")
            return wrapper
