#!/usr/bin/env python
#
# Kitchen experiment implementation for Flow Matching with Hydra

import os
import torch
import torch.nn as nn
import numpy as np
from typing import Dict, Any, Tuple
from torch.utils.data import DataLoader
import sys
# Add path for imports
sys.path.append('../models')
sys.path.append('../kitchen')

from unet import ConditionalUnet1D
from external.models.TransformerForDiffusion import TransformerForDiffusion
import kitchen_lowdim_dataset
from diffusion_policy.env.kitchen.v0 import KitchenAllV0

# Import parent class
sys.path.append('../utils')
from base_experiment import BaseExperiment
from hydra_config import HydraConfigManager

class KitchenExperimentHydra(BaseExperiment):
    """Kitchen manipulation experiment using Flow Matching with Hydra configuration.

    Wraps the raw Hydra config in a ``HydraConfigManager`` and implements the
    dataset, model, batch-processing and environment hooks for
    ``BaseExperiment``. Kitchen uses low-dimensional observations, so no
    vision encoder is built.
    """

    def __init__(self, config):
        """Store the Hydra config and initialise the base experiment.

        Args:
            config: Raw Hydra config object. It is kept on
                ``self.hydra_config``; a ``HydraConfigManager`` wrapping it is
                stored on ``self.config`` and passed to ``BaseExperiment``.
        """
        self.hydra_config = config
        # Convert the Hydra config into the attribute-style interface that the
        # methods below read from (e.g. ``self.config.batch_size``).
        self.config = HydraConfigManager(config)

        super().__init__(self.config)

    def setup_dataset(self) -> DataLoader:
        """Build the shuffled training DataLoader over the Kitchen low-dim dataset.

        Returns:
            A ``DataLoader`` over ``KitchenLowdimDataset`` constructed from
            ``self.config.get_dataset_kwargs()``, using the batch size, worker
            count, pin-memory and persistent-worker settings from the config.
        """
        print("Setting up Kitchen dataset...")

        dataset = kitchen_lowdim_dataset.KitchenLowdimDataset(
            **self.config.get_dataset_kwargs()
        )

        print(f"Dataset size: {len(dataset)}")

        num_workers = self.config.num_workers
        # torch.utils.data.DataLoader raises ValueError when
        # persistent_workers=True is combined with num_workers=0 (loading in
        # the main process), so only request persistent workers when worker
        # processes will actually exist.
        persistent_workers = bool(self.config.persistent_workers) and num_workers > 0

        dataloader = DataLoader(
            dataset,
            batch_size=self.config.batch_size,
            num_workers=num_workers,
            shuffle=True,
            pin_memory=self.config.pin_memory,
            persistent_workers=persistent_workers,
        )

        return dataloader

    def setup_model(self) -> nn.Module:
        """Build the prediction network for low-dim Kitchen observations.

        ``self.config.model_type == "transformer"`` selects
        ``TransformerForDiffusion``. Any other value falls back to
        ``ConditionalUnet1D``.

        Returns:
            An ``nn.ModuleDict`` with a single ``'noise_pred_net'`` entry,
            moved to ``self.device``. ``self.device`` is not set in this
            class; it is assumed to be provided by ``BaseExperiment``.
        """
        print(f"Setting up Kitchen model (type: {self.config.model_type})...")

        # NOTE(review): process_batch conditions on the flattened observation
        # window (obs_horizon * per-step obs size). This relies on
        # config.vision_feature_dim being set to that size even though no
        # vision encoder is used — confirm in the Hydra config.
        if self.config.model_type == "transformer":
            noise_pred_net = TransformerForDiffusion(
                input_dim=self.config.action_dim,
                output_dim=self.config.action_dim,
                horizon=self.config.pred_horizon,
                cond_dim=self.config.vision_feature_dim
            )
        else:  # unet (default for any other model_type)
            noise_pred_net = ConditionalUnet1D(
                input_dim=self.config.action_dim,
                global_cond_dim=self.config.vision_feature_dim
            )

        # Low-dim observations need only the prediction network (no encoder).
        nets = nn.ModuleDict({
            'noise_pred_net': noise_pred_net
        }).to(self.device)

        return nets

    def process_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split a Kitchen batch into (conditioning, target trajectory) tensors.

        Args:
            batch: Mapping with ``'obs'`` and ``'action'`` tensors. ``'obs'``
                is truncated along dim 1 to its first
                ``self.config.obs_horizon`` steps.

        Returns:
            ``(obs_cond, x_traj)`` on ``self.device``. ``obs_cond`` is the
            observation window flattened from dim 1 onward. ``x_traj`` is
            ``batch['action']`` unchanged apart from the device move.
        """
        x_obs = batch['obs'][:, :self.config.obs_horizon].to(self.device)
        x_traj = batch['action'].to(self.device)

        # Low-dim observations need no encoder: flattening the time window
        # yields one global-conditioning vector per sample.
        # x_obs shape: [batch_size, obs_horizon, feature_dim]
        obs_cond = x_obs.flatten(start_dim=1)

        return obs_cond, x_traj

    def setup_environment(self):
        """Create the Kitchen evaluation environment (used for testing).

        Returns:
            A ``KitchenAllV0`` instance built from
            ``self.config.get_env_kwargs()``.
        """
        print("Setting up Kitchen environment...")

        env = KitchenAllV0(**self.config.get_env_kwargs())
        return env
