#!/usr/bin/env python
"""
Hydra configuration manager for Flow Matching experiments.
This module provides compatibility between Hydra configs and existing experiment code.
"""

import logging
import os
from typing import Any, Dict, Optional

import torch
from omegaconf import DictConfig, OmegaConf


class HydraConfigManager:
    """Flatten a Hydra configuration into the flat dict used by existing experiment code.

    Processed values are available as ``manager['key']``, ``manager.key`` and
    ``manager.get_attr('key', default)``. Keys for config sections that are
    absent are filled in by :meth:`_set_defaults`.
    """

    def __init__(self, cfg: "DictConfig"):
        """Store *cfg* and eagerly build the flattened configuration.

        Args:
            cfg: Hydra configuration. It is only queried with ``hasattr`` and
                attribute access, so each optional section may be omitted.
        """
        self.cfg = cfg
        self._processed_config = self._process_config()

    def _process_config(self) -> Dict[str, Any]:
        """Flatten every recognised config section into a single dictionary.

        Each top-level section (``env``, ``execution``, ``training``, ...) is
        read only when present on ``self.cfg``. Remaining keys are filled by
        :meth:`_set_defaults`. Insertion order matters because
        :meth:`print_config` prints keys in that order.

        Returns:
            The flat configuration dictionary.
        """
        config_dict: Dict[str, Any] = {}
        cfg = self.cfg

        # Environment / dataset configuration
        if hasattr(cfg, 'env'):
            env = cfg.env
            # Log at DEBUG level instead of printing unconditionally to stdout.
            logger = logging.getLogger(__name__)
            logger.debug("Found env config: %s", env)
            logger.debug("env.dataset.path = %s", env.dataset.path)
            config_dict.update({
                'experiment_name': env.name,
                'dataset_path': env.dataset.path,
                'dataset_type': env.dataset.type,
                'obs_horizon': env.dataset.obs_horizon,
                'pred_horizon': env.dataset.pred_horizon,
                'action_horizon': env.dataset.action_horizon,
                'action_dim': env.dataset.action_dim,
                'vision_feature_dim': env.dataset.vision_feature_dim,
                'obs_keys': env.obs_keys,
                'max_steps': env.environment.max_steps,
                'success_threshold': env.environment.success_threshold,
                'render_mode': env.environment.render_mode,
                'render_hw': env.environment.render_hw,
            })

        # Execution configuration
        if hasattr(cfg, 'execution'):
            exec_cfg = cfg.execution
            config_dict.update({
                'mode': exec_cfg.mode,
                'type': exec_cfg.type,
            })

        # Training configuration
        if hasattr(cfg, 'training'):
            training = cfg.training
            config_dict.update({
                'epochs': training.epochs,
                'batch_size': training.batch_size,
                'learning_rate': training.learning_rate,
                'weight_decay': training.weight_decay,
                'ema_power': training.ema_power,
                # NOTE(review): this fallback (25) differs from the default in
                # _set_defaults (20), which applies only when the whole
                # ``training`` section is absent. Confirm which one is intended.
                'save_interval': getattr(training, 'save_interval', 25),
            })

        # DataLoader configuration
        if hasattr(cfg, 'dataloader'):
            dataloader = cfg.dataloader
            config_dict.update({
                'num_workers': dataloader.num_workers,
                'pin_memory': dataloader.pin_memory,
                'persistent_workers': dataloader.persistent_workers,
            })

        # Checkpoint configuration. The membership guard is defensive: no
        # section processed above currently sets 'checkpoint_dir'.
        if hasattr(cfg, 'checkpoint') and 'checkpoint_dir' not in config_dict:
            checkpoint = cfg.checkpoint
            config_dict.update({
                'checkpoint_dir': checkpoint.save_dir,
                'load_checkpoint': checkpoint.load_checkpoint,
            })

        # Validation configuration
        if hasattr(cfg, 'validation'):
            validation = cfg.validation
            config_dict.update({
                'validation_enabled': validation.enabled,
                'validation_interval': validation.interval,
                'val_split': validation.val_split,
            })

        # MLE configuration
        if hasattr(cfg, 'mle'):
            mle = cfg.mle
            config_dict.update({
                'mle_learning_rate': mle.learning_rate,
                'mle_solver_type': mle.solver.type,
                'mle_time_steps': mle.solver.time_steps,
            })

        # Testing configuration
        if hasattr(cfg, 'testing'):
            testing = cfg.testing
            config_dict.update({
                'test_start_seed': testing.start_seed,
                'test_episodes': testing.episodes,
                'test_runs_per_episode': testing.runs_per_episode,
            })

        # Results configuration
        if hasattr(cfg, 'results'):
            results = cfg.results
            config_dict.update({
                'results_save_dir': results.save_dir,
                'save_trajectories': results.save_trajectories,
                'save_videos': results.save_videos,
            })

        # Evaluation configuration
        if hasattr(cfg, 'evaluation'):
            evaluation = cfg.evaluation
            config_dict.update({
                'evaluation_metrics': evaluation.metrics,
                'evaluation_render': evaluation.render,
            })

        # Model configuration
        if hasattr(cfg, 'model'):
            model = cfg.model
            config_dict.update({
                'model_type': model.type,
                'model_architecture': model.architecture,
            })

            # UNet-specific parameters
            if hasattr(model, 'unet'):
                unet = model.unet
                config_dict.update({
                    'unet_hidden_size': unet.hidden_size,
                    'unet_num_blocks': unet.num_blocks,
                    'unet_num_layers_per_block': unet.num_layers_per_block,
                    'unet_use_attention': unet.use_attention,
                })

            # Transformer-specific parameters
            if hasattr(model, 'transformer'):
                transformer = model.transformer
                config_dict.update({
                    'transformer_hidden_dim': transformer.hidden_dim,
                    'transformer_num_layers': transformer.num_layers,
                    'transformer_num_heads': transformer.num_heads,
                    'transformer_dropout': transformer.dropout,
                })

            # Conditioning parameters
            if hasattr(model, 'conditioning'):
                conditioning = model.conditioning
                config_dict.update({
                    'use_vision_encoder': conditioning.use_vision_encoder,
                    'vision_encoder_type': conditioning.vision_encoder_type,
                    'replace_bn_with_gn': conditioning.replace_bn_with_gn,
                })

        # Global settings. The device falls back to CUDA when available.
        if hasattr(cfg, 'device'):
            config_dict['device'] = cfg.device
        else:
            config_dict['device'] = "cuda" if torch.cuda.is_available() else "cpu"

        if hasattr(cfg, 'seed'):
            config_dict['seed'] = cfg.seed

        if hasattr(cfg, 'ot_model'):
            config_dict['ot_model'] = cfg.ot_model

        if hasattr(cfg, 'logging'):
            # Named ``log_cfg`` so it does not shadow the ``logging`` module.
            log_cfg = cfg.logging
            config_dict.update({
                'log_level': log_cfg.level,
                'log_format': log_cfg.format,
            })

        # Fill in any keys not provided by the config.
        self._set_defaults(config_dict)

        return config_dict

    def _set_defaults(self, config_dict: Dict[str, Any]) -> None:
        """Fill missing keys of *config_dict* in place with default values.

        Also derives two values when they are unset (``None``):
        ``mle_learning_rate`` becomes ``learning_rate / 20``, and
        ``checkpoint_dir`` becomes ``./checkpoint/<experiment_name>/``.
        """
        # Built fresh on every call, so the list defaults are never shared
        # between instances.
        defaults = {
            'experiment_name': 'default_experiment',
            'dataset_path': '',
            'dataset_type': 'base',
            'obs_horizon': 1,
            'pred_horizon': 16,
            'action_horizon': 8,
            'action_dim': 2,
            'vision_feature_dim': 512,
            'obs_keys': ['obs'],
            'max_steps': 300,
            'success_threshold': 1.0,
            'render_mode': 'rgb_array',
            'render_hw': (512, 512),
            'mode': 'train',
            'type': 'flow_matching',
            'epochs': 1000,
            'batch_size': 64,
            'learning_rate': 1e-4,
            'weight_decay': 1e-6,
            'ema_power': 0.75,
            'ot_model': 'otcfm',  # Flow matching model type
            'sigma': 0.1,  # Flow matching noise scale parameter
            'lr_scheduler_type': 'cosine',
            'warmup_steps': 500,
            'num_workers': 4,
            'pin_memory': True,
            'persistent_workers': True,
            'checkpoint_dir': None,
            'save_interval': 20,
            'load_checkpoint': None,
            'validation_enabled': True,
            'validation_interval': 10,
            'val_split': 0.1,
            'mle_learning_rate': None,
            'mle_weight_decay': 1e-7,
            'mle_epochs': 100,
            'mle_solver_type': 'euler',
            'mle_time_steps': 16,
            'mle_discretization_points': 32,
            'mle_lr_scheduler_type': 'constant',
            'mle_warmup_steps': 0,
            'mle_ema_power': 0.999,
            'mle_checkpoint_prefix': 'mle',
            'test_start_seed': 1000,
            'test_episodes': 1,
            'test_runs_per_episode': 10,
            'results_save_dir': './results/',
            'save_trajectories': True,
            'save_videos': False,
            'evaluation_metrics': ['success_rate', 'reward', 'trajectory_length'],
            'evaluation_render': True,
            'model_type': 'unet',
            'model_architecture': 'ConditionalUnet1D',
            'unet_hidden_size': 256,
            'unet_num_blocks': 4,
            'unet_num_layers_per_block': 2,
            'unet_use_attention': True,
            'transformer_hidden_dim': 512,
            'transformer_num_layers': 6,
            'transformer_num_heads': 8,
            'transformer_dropout': 0.1,
            'use_vision_encoder': False,
            'vision_encoder_type': 'resnet18',
            'replace_bn_with_gn': True,
            'seed': 42,
            'log_level': 'INFO',
            'log_format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        }

        for key, default_value in defaults.items():
            config_dict.setdefault(key, default_value)

        # Derive the MLE learning rate from the main one when not provided.
        if config_dict['mle_learning_rate'] is None:
            config_dict['mle_learning_rate'] = config_dict['learning_rate'] / 20

        # Default the checkpoint directory to a per-experiment folder.
        if config_dict['checkpoint_dir'] is None:
            config_dict['checkpoint_dir'] = f"./checkpoint/{config_dict['experiment_name']}/"

    def get_config_dict(self) -> Dict[str, Any]:
        """Return a shallow copy of the processed configuration.

        Nested values (e.g. lists) are shared with the manager.
        """
        return self._processed_config.copy()

    def get_attr(self, key: str, default: Any = None) -> Any:
        """Return the value for *key*, or *default* if it is not configured."""
        return self._processed_config.get(key, default)

    def __getitem__(self, key: str) -> Any:
        """Dictionary-style access. Raises KeyError for unknown keys."""
        return self._processed_config[key]

    def __getattr__(self, key: str) -> Any:
        """Attribute-style access to processed configuration values.

        Python calls this only when normal attribute lookup fails. The backing
        dict is read from ``self.__dict__``, not via ``self._processed_config``,
        so an instance created without ``__init__`` raises AttributeError
        instead of recursing forever. ``copy``/``pickle`` create such
        instances and probe them with ``hasattr(obj, '__setstate__')``.
        """
        processed = self.__dict__.get('_processed_config')
        if processed is not None and key in processed:
            return processed[key]
        raise AttributeError(f"Configuration has no attribute '{key}'")

    def get_dataset_kwargs(self) -> Dict[str, Any]:
        """Return dataset constructor keyword arguments for the current experiment.

        The layout is chosen by ``experiment_name``: ``'pusht'``, ``'mimic'``,
        or a generic directory-based layout for anything else.
        """
        # Check if this is a PushT experiment
        if self._processed_config.get('experiment_name') == 'pusht':
            return {
                "dataset_path": self._processed_config['dataset_path'],
                "pred_horizon": self._processed_config['pred_horizon'],
                "obs_horizon": self._processed_config['obs_horizon'],
                "action_horizon": self._processed_config['action_horizon'],
            }
        elif self._processed_config.get('experiment_name') == 'mimic':
            # For Mimic/Robomimic experiments
            return {
                "hdf5_path": self._processed_config['dataset_path'],
                "obs_keys": self._processed_config['obs_keys'],
                "seq_length": self._processed_config['pred_horizon'],
                "frame_stack": self._processed_config['obs_horizon'],
                "pad_frame_stack": True,
                "pad_seq_length": True,
                "get_pad_mask": False,
                "goal_mode": None,
                "hdf5_cache_mode": "all",
                "hdf5_use_swmr": True,
                "hdf5_normalize_obs": True,
            }
        else:
            # Default for other experiments (like Kitchen)
            return {
                "dataset_dir": self._processed_config['dataset_path'],
                "horizon": self._processed_config['pred_horizon'],
                "seed": self._processed_config.get('seed', 42),
                "val_ratio": self._processed_config.get('val_split', 0.1),
            }

    def get_env_kwargs(self) -> Dict[str, Any]:
        """Return environment constructor keyword arguments (only non-empty for 'mimic')."""
        if self._processed_config.get('experiment_name') == 'mimic':
            return {
                "render": False,
                "render_offscreen": True,  # Enable offscreen rendering so images can be saved
                "use_image_obs": False,
            }
        else:
            return {}

    def get_wrapper_kwargs(self) -> Dict[str, Any]:
        """Return environment-wrapper keyword arguments (only non-empty for 'mimic').

        NOTE(review): ``_set_defaults`` always populates ``obs_keys`` and
        ``render_hw``, so the ``.get`` fallbacks for those two keys are
        effectively unreachable.
        """
        if self._processed_config.get('experiment_name') == 'mimic':
            return {
                "obs_keys": self._processed_config.get('obs_keys', ['object', 'robot0_eef_pos', 'robot0_eef_quat', 'robot0_gripper_qpos']),
                "render_hw": self._processed_config.get('render_hw', [500, 500]),
                "render_camera_name": self._processed_config.get('render_camera_name', 'frontview'),
            }
        else:
            return {}

    @staticmethod
    def _require(condition: bool, message: str) -> None:
        """Raise AssertionError(*message*) when *condition* is false.

        Unlike an ``assert`` statement, this still runs under ``python -O``.
        """
        if not condition:
            raise AssertionError(message)

    def validate(self) -> None:
        """Sanity-check the processed configuration.

        Checks run in order and stop at the first failure.

        Raises:
            AssertionError: if a parameter is out of range. This type is kept
                for compatibility with callers, but the checks no longer
                disappear under ``python -O``.
        """
        config = self._processed_config

        self._require(config['obs_horizon'] > 0, "obs_horizon must be positive")
        self._require(config['pred_horizon'] > 0, "pred_horizon must be positive")
        self._require(config['action_horizon'] > 0, "action_horizon must be positive")
        self._require(config['action_dim'] > 0, "action_dim must be positive")
        self._require(config['epochs'] > 0, "epochs must be positive")
        self._require(config['batch_size'] > 0, "batch_size must be positive")
        self._require(0 <= config['learning_rate'] <= 1, "learning_rate must be in [0, 1]")
        self._require(config['model_type'] in ["unet", "transformer"], "model_type must be 'unet' or 'transformer'")
        self._require(config['lr_scheduler_type'] in ["cosine", "linear", "constant"], "Invalid lr_scheduler_type")
        self._require(config['mle_solver_type'] in ["euler", "odeint", "odeint_adjoint"], "Invalid mle_solver_type")

    def print_config(self) -> None:
        """Print every processed key/value pair, in insertion order, to stdout."""
        print("=" * 50)
        print("Current Configuration:")
        print("=" * 50)
        for key, value in self._processed_config.items():
            print(f"{key}: {value}")
        print("=" * 50)
