#!/usr/bin/env python
#
# PushT experiment implementation for Flow Matching with Hydra support

import sys
import os
import torch
import torch.nn as nn
import numpy as np
import collections
from typing import Dict, Any, Tuple
from torch.utils.data import DataLoader

# Add path for imports
sys.path.append('../models')

from unet import ConditionalUnet1D
from TransformerForDiffusion import TransformerForDiffusion
from resnet import get_resnet, replace_bn_with_gn
import pusht

# Import parent class and Hydra config manager
sys.path.append('../utils')
from base_experiment import BaseExperiment
from hydra_config import HydraConfigManager

class PushTExperimentHydra(BaseExperiment):
    """PushT manipulation experiment using Flow Matching with Hydra configuration.

    Provides the dataset, model, batch-processing and environment hooks
    (``setup_dataset``, ``setup_model``, ``process_batch`` and
    ``setup_environment``) on top of ``BaseExperiment``.

    Attributes:
        hydra_config: The raw configuration object handed to the constructor.
        config: ``HydraConfigManager`` wrapper around ``hydra_config``; every
            setting used by this class is read from it.
        stats: Dataset statistics copied from ``dataset.stats`` by
            ``setup_dataset``; ``None`` until that method has run.
    """

    def __init__(self, config):
        """Wrap the Hydra config and initialise the base experiment.

        Args:
            config: Hydra configuration object. It is stored unchanged as
                ``hydra_config`` and wrapped in a ``HydraConfigManager`` that
                is passed on to ``BaseExperiment.__init__``.
        """
        # Convert Hydra config to compatible format
        self.hydra_config = config
        self.config = HydraConfigManager(config)

        # Initialise before handing control to the base class.
        # BaseExperiment.__init__ is not visible from this file; if it calls
        # setup_dataset(), resetting the attribute afterwards would discard
        # the statistics that call stored.
        self.stats = None

        super().__init__(self.config)

    def setup_dataset(self) -> DataLoader:
        """Build the PushT image dataset and a shuffled loader over it.

        Side effect: stores ``dataset.stats`` on ``self.stats``.

        Returns:
            A ``DataLoader`` over ``pusht.PushTImageDataset`` whose batch size,
            worker count and memory pinning come from ``self.config``.
        """
        print("Setting up PushT dataset...")

        dataset = pusht.PushTImageDataset(**self.config.get_dataset_kwargs())

        # Save data statistics for normalization
        self.stats = dataset.stats

        print(f"Dataset size: {len(dataset)}")

        num_workers = self.config.num_workers
        # DataLoader raises ValueError when persistent_workers=True is combined
        # with num_workers=0 (e.g. a single-process debugging run), so only
        # request persistence when worker processes actually exist.
        persistent_workers = bool(self.config.persistent_workers) and num_workers > 0

        return DataLoader(
            dataset,
            batch_size=self.config.batch_size,
            num_workers=num_workers,
            shuffle=True,
            pin_memory=self.config.pin_memory,
            persistent_workers=persistent_workers,
        )

    def setup_model(self) -> nn.Module:
        """Build the vision encoder and the noise prediction network.

        ``config.model_type == "transformer"`` selects
        ``TransformerForDiffusion``; any other value falls back to
        ``ConditionalUnet1D``.

        Returns:
            An ``nn.ModuleDict`` with keys ``'vision_encoder'`` and
            ``'noise_pred_net'``, already moved to ``self.device``.
        """
        print(f"Setting up PushT model (type: {self.config.model_type})...")

        # Setup vision encoder
        vision_encoder = get_resnet(self.config.vision_encoder_type)
        vision_encoder = replace_bn_with_gn(vision_encoder)

        # Setup noise prediction network
        # NOTE(review): process_batch concatenates agent_pos onto the image
        # features (and, for the U-Net, flattens across the observation
        # horizon), so config.vision_feature_dim must already account for
        # that -- confirm against the config.
        if self.config.model_type == "transformer":
            noise_pred_net = TransformerForDiffusion(
                input_dim=self.config.action_dim,
                output_dim=self.config.action_dim,
                horizon=self.config.pred_horizon,
                cond_dim=self.config.vision_feature_dim,
            )
        else:  # unet
            noise_pred_net = ConditionalUnet1D(
                input_dim=self.config.action_dim,
                global_cond_dim=self.config.vision_feature_dim,
            )

        # Combine into ModuleDict
        nets = nn.ModuleDict({
            'vision_encoder': vision_encoder,
            'noise_pred_net': noise_pred_net,
        }).to(self.device)

        return nets

    def process_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Turn a raw PushT batch into (conditioning, action trajectory).

        Args:
            batch: Mapping with ``'image'``, ``'agent_pos'`` and ``'action'``
                entries. The first two are truncated to the first
                ``config.obs_horizon`` steps along dimension 1.

        Returns:
            ``(obs_cond, x_traj)``. ``obs_cond`` is the per-step encoder output
            concatenated with the agent position; it keeps its step dimension
            for the transformer and is flattened from dimension 1 onwards
            otherwise. ``x_traj`` is ``batch['action']`` on ``self.device``.
        """
        # Extract observations and actions
        obs_horizon = self.config.obs_horizon
        x_img = batch['image'][:, :obs_horizon].to(self.device)
        x_pos = batch['agent_pos'][:, :obs_horizon].to(self.device)
        x_traj = batch['action'].to(self.device)

        # Process images through vision encoder: merge the two leading
        # dimensions for the encoder call, then restore them afterwards.
        image_features = self.model['vision_encoder'](x_img.flatten(end_dim=1))
        image_features = image_features.reshape(*x_img.shape[:2], -1)

        # Combine image features and position
        obs_features = torch.cat([image_features, x_pos], dim=-1)

        if self.config.model_type == "transformer":
            # For transformer, keep sequence dimension
            obs_cond = obs_features
        else:
            # For UNet, flatten
            obs_cond = obs_features.flatten(start_dim=1)

        return obs_cond, x_traj

    def setup_environment(self):
        """Setup PushT evaluation environment.

        This part is for testing.

        Returns:
            A freshly constructed ``pusht.PushTImageEnv``.
        """
        print("Setting up PushT environment...")

        env = pusht.PushTImageEnv()
        return env