import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import h5py
import numpy as np
import torch
from torch.utils.data import Dataset

from src.dataset.normalizer import LinearNormalizer
from src.common.control import ControlMode
import src.common.geometry as C

def load_h5_data(data):
    """Recursively read an h5py group into nested dicts of in-memory values.

    Args:
        data: An ``h5py.File``/``h5py.Group`` (anything exposing ``items()``
            whose values are ``h5py.Dataset`` objects or nested groups).

    Returns:
        A dict mirroring the group hierarchy in which every dataset has been
        read completely into memory.
    """
    out = {}
    for key, item in data.items():
        if isinstance(item, h5py.Dataset):
            # ``item[()]`` reads the whole dataset and, unlike ``item[:]``,
            # also works for scalar (0-d) datasets.
            out[key] = item[()]
        else:
            out[key] = load_h5_data(item)
    return out


class ManiSkillStateDataset(Dataset):
    """Windowed ManiSkill state-observation demonstrations for sequence policies.

    Every selected episode from the given ``.h5`` trajectory files is
    concatenated into flat observation/action buffers, normalized with a
    ``LinearNormalizer`` and served as fixed-length windows that are
    edge-padded at episode boundaries.

    Each item is a dict with ``obs``, ``robot_state`` and ``parts_poses``
    (the first ``obs_horizon`` frames of the window), ``action``
    (``pred_horizon`` frames) and ``success`` (always ``IntTensor([1])``).
    """

    # Control mode recorded for files without usable ``.json`` metadata.
    _FALLBACK_CONTROL_MODE = "pd_ee_delta_pos"
    # TODO HARDCODED - the first 8 observation dims are treated as robot state
    # (Panda robot has 8-DoF in ManiSkill); the remaining dims as parts poses.
    _ROBOT_STATE_DIM = 8

    def __init__(
        self,
        dataset_paths: Union[List[Path], Path],
        pred_horizon: int,
        obs_horizon: int,
        action_horizon: int,
        data_subset: Optional[int] = None,
        predict_past_actions: bool = False,
        control_mode: ControlMode = ControlMode.delta,
        pad_after: bool = True,
        normalizer: Optional[LinearNormalizer] = None,
        verify_control_mode: bool = True,
    ):
        """Load, normalize and index the demonstrations.

        Args:
            dataset_paths: One path or a list of paths to ``.h5`` trajectory
                files. A sibling file with a ``.json`` suffix is read for
                episode metadata when available.
            pred_horizon: Number of action frames per sample.
            obs_horizon: Number of observation frames per sample.
            action_horizon: Only used to size the padding after each episode
                (``action_horizon - 1`` frames when ``pad_after`` is True).
            data_subset: If given, load at most this many episodes in total
                across all files.
            predict_past_actions: If True, the action slice starts at the
                beginning of the window; otherwise it starts at the last
                observed frame.
            control_mode: Requested control mode; stored, not used for loading.
            pad_after: Whether to pad the end of each episode.
            normalizer: Pre-fitted normalizer to reuse. If None, a new
                ``LinearNormalizer`` is fitted on this dataset.
            verify_control_mode: Stored on the instance.

        Raises:
            ValueError: If a trajectory has no action data, its action and
                observation frame counts differ, or no episodes were loaded.
        """
        super().__init__()
        self.pred_horizon = pred_horizon
        self.action_horizon = action_horizon
        self.obs_horizon = obs_horizon
        self.predict_past_actions = predict_past_actions
        self.control_mode = control_mode
        self.dataset_control_mode = None
        # NOTE(review): stored for callers, but this class never compares
        # ``control_mode`` against ``dataset_control_mode``.
        self.verify_control_mode = verify_control_mode

        if isinstance(dataset_paths, (str, Path)):
            dataset_paths = [dataset_paths]

        # Per-episode arrays, later flattened into contiguous buffers.
        self.observations = []
        self.actions = []
        self.episode_ends = []
        self.metadata = self._load_trajectories(dataset_paths, data_subset)

        episode_count = len(self.episode_ends)
        frame_count = self.episode_ends[-1]

        self.observations = np.vstack(self.observations)
        self.actions = np.vstack(self.actions)
        self.episode_ends = np.array(self.episode_ends)
        self.obs_dim = self.observations.shape[1]
        self.action_dim = self.actions.shape[1]

        self.robot_state_dim = min(self._ROBOT_STATE_DIM, self.obs_dim)
        self.parts_poses_dim = self.obs_dim - self.robot_state_dim

        observations_tensor = torch.from_numpy(self.observations)
        self.train_data = {
            "obs": observations_tensor.clone(),
            "action": torch.from_numpy(self.actions),
            # Column slices of the observations; a zero-width split yields an
            # (N, 0) tensor.
            "robot_state": observations_tensor[:, : self.robot_state_dim],
            "parts_poses": observations_tensor[:, self.robot_state_dim :],
        }

        # Reuse a pre-fitted normalizer if given, otherwise fit on this data.
        self.normalizer = LinearNormalizer() if normalizer is None else normalizer
        if normalizer is None:
            self.normalizer.fit(self.train_data)

        # Forward-normalize every key the normalizer has statistics for; a
        # supplied normalizer may know keys this dataset does not have.
        for key in self.normalizer.keys():
            if key in self.train_data:
                self.train_data[key] = self.normalizer(
                    self.train_data[key], key, forward=True
                )

        # The window must hold obs_horizon observations plus pred_horizon
        # actions that (unless predicting past actions) start at the last
        # observed frame.
        self.sequence_length = (
            pred_horizon if predict_past_actions else obs_horizon + pred_horizon - 1
        )
        self.indices = self._create_sample_indices(
            pad_before=obs_horizon - 1,
            pad_after=action_horizon - 1 if pad_after else 0,
        )
        self.n_samples = len(self.indices)

        # Slice of each window that is returned as "action".
        self.first_action_idx = 0 if predict_past_actions else self.obs_horizon - 1
        self.final_action_idx = self.first_action_idx + self.pred_horizon

        print("ManiSkill State Dataset Summary:")
        print(f"  - Total episodes loaded: {episode_count}")
        print(f"  - Total frames loaded: {frame_count}")
        print(f"  - Observation dimension: {self.obs_dim} (robot_state: {self.robot_state_dim}, parts_poses: {self.parts_poses_dim})")
        print(f"  - Action dimension: {self.action_dim}")
        print(f"  - Samples after processing: {self.n_samples}")
        if data_subset is not None:
            print(f"  - Loaded {episode_count}/{data_subset} requested episodes ({(episode_count/data_subset)*100:.1f}%)")
        print(f"  - Control mode: {self.dataset_control_mode}")
        print("=" * 50)

    def _load_trajectories(
        self, dataset_paths, data_subset: Optional[int]
    ) -> Dict[str, Dict[str, int]]:
        """Append the selected episodes of every file to the instance buffers.

        Fills ``self.observations``/``self.actions`` with per-episode arrays,
        appends cumulative frame counts to ``self.episode_ends`` and updates
        ``self.dataset_control_mode``. At most ``data_subset`` episodes are
        read in total when it is given.

        Returns:
            Per-file usage: ``{str(path): {"n_episodes_used", "n_frames_used"}}``.

        Raises:
            ValueError: On missing actions, an obs/action length mismatch, or
                when no episode was loaded at all.
        """
        metadata = {}
        episode_count = 0
        frame_count = 0

        for path in dataset_paths:
            if data_subset is not None and episode_count >= data_subset:
                break  # Quota already met; don't open the remaining files.

            print(f"Loading data from {path}")
            path_episodes = 0
            path_frames = 0
            # ``with`` guarantees the file is closed even if an episode fails.
            with h5py.File(path, "r") as data:
                episodes = self._read_episode_list(path, data)

                if data_subset is not None:
                    remaining = data_subset - episode_count
                    if len(episodes) > remaining:
                        original_length = len(episodes)
                        episodes = episodes[:remaining]
                        print(f"Using data_subset={data_subset}, selecting {len(episodes)} out of {original_length} episodes from {path}")
                    else:
                        print(f"Using data_subset={data_subset}, currently loaded {episode_count} episodes, adding all {len(episodes)} from {path}")

                for episode in episodes:
                    episode_id = episode["episode_id"]
                    trajectory_data = load_h5_data(data[f"traj_{episode_id}"])

                    # Drop the last observation (presumably the terminal one,
                    # which has no matching action).
                    obs = trajectory_data["obs"][:-1]
                    actions = self._extract_actions(trajectory_data, episode_id)
                    if len(actions) != len(obs):
                        # Mismatched lengths would silently misalign the flat
                        # obs/action buffers for every later episode.
                        raise ValueError(
                            f"Trajectory {episode_id} in {path} has {len(obs)} "
                            f"observation frames but {len(actions)} actions"
                        )

                    self.observations.append(obs)
                    self.actions.append(actions)
                    frame_count += len(obs)
                    self.episode_ends.append(frame_count)
                    episode_count += 1
                    path_episodes += 1
                    path_frames += len(obs)

            metadata[str(path)] = {
                "n_episodes_used": path_episodes,
                "n_frames_used": path_frames,
            }

        if not self.episode_ends:
            raise ValueError(
                f"No episodes were loaded from {[str(p) for p in dataset_paths]} "
                f"(data_subset={data_subset})"
            )
        return metadata

    def _read_episode_list(self, path, data) -> List[Dict[str, Any]]:
        """Return the episode descriptors for ``path`` and record its control mode.

        Prefers the sibling ``.json`` metadata. If that file is missing,
        unreadable or lacks an ``episodes`` entry, the ``traj_<id>`` groups of
        the open h5 file ``data`` are enumerated instead and
        ``_FALLBACK_CONTROL_MODE`` is recorded.
        """
        # ``with_suffix`` only swaps the file extension; ``str.replace`` would
        # also rewrite ".h5" occurring inside directory names.
        json_path = Path(path).with_suffix(".json")
        try:
            with open(json_path, "r") as f:
                json_data = json.load(f)
            episodes = json_data["episodes"]
        except (OSError, ValueError, KeyError, TypeError):
            self.dataset_control_mode = self._FALLBACK_CONTROL_MODE
            return self._episodes_from_h5_keys(data)

        # Prefer the environment-level control mode, then the first episode's.
        env_info = json_data.get("env_info")
        env_kwargs = env_info.get("env_kwargs") if isinstance(env_info, dict) else None
        if isinstance(env_kwargs, dict) and "control_mode" in env_kwargs:
            self.dataset_control_mode = env_kwargs["control_mode"]
            print(f"Found control mode in dataset: {self.dataset_control_mode}")
        elif episodes and isinstance(episodes[0], dict) and "control_mode" in episodes[0]:
            self.dataset_control_mode = episodes[0]["control_mode"]
            print(f"Found control mode in episode: {self.dataset_control_mode}")
        return episodes

    @staticmethod
    def _episodes_from_h5_keys(data) -> List[Dict[str, Any]]:
        """Build ``{"episode_id": id}`` descriptors from numeric ``traj_<id>`` keys."""
        prefix = "traj_"
        ids = sorted(
            int(key[len(prefix):])
            for key in data.keys()
            if key.startswith(prefix) and key[len(prefix):].isdigit()
        )
        return [{"episode_id": episode_id} for episode_id in ids]

    @staticmethod
    def _extract_actions(trajectory_data: Dict[str, Any], episode_id):
        """Return a trajectory's actions from ``actions``, ``action`` or any ``action*`` key."""
        for key in ("actions", "action"):
            if key in trajectory_data:
                return trajectory_data[key]
        action_keys = [k for k in trajectory_data.keys() if k.startswith("action")]
        if action_keys:
            return trajectory_data[action_keys[0]]
        raise ValueError(f"No action data found in trajectory {episode_id}")

    def _create_sample_indices(self, pad_before=0, pad_after=0):
        """Enumerate every window of ``self.sequence_length`` frames per episode.

        Windows may start up to ``pad_before`` frames before an episode and
        end up to ``pad_after`` frames after it. Each row is
        ``[buffer_start, buffer_end, sample_start, sample_end, episode_idx]``:
        ``buffer_*`` index the flat data buffers, and ``sample_*`` give the
        slice of the output window those frames fill; the rest is padding.
        """
        indices = []
        for i in range(len(self.episode_ends)):
            start_idx = self.episode_ends[i - 1] if i > 0 else 0
            end_idx = self.episode_ends[i]
            episode_length = end_idx - start_idx

            min_start = -pad_before
            max_start = episode_length - self.sequence_length + pad_after

            for idx in range(min_start, max_start + 1):
                # Clip the window to the episode, remembering how much was cut.
                buffer_start_idx = max(idx, 0) + start_idx
                buffer_end_idx = min(idx + self.sequence_length, episode_length) + start_idx
                start_offset = buffer_start_idx - (idx + start_idx)
                end_offset = (idx + self.sequence_length + start_idx) - buffer_end_idx
                sample_start_idx = start_offset
                sample_end_idx = self.sequence_length - end_offset
                indices.append(
                    [buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx, i]
                )

        return np.array(indices)

    def set_normalizer(self, normalizer: LinearNormalizer):
        """Copy ``normalizer``'s state into this dataset's normalizer.

        ``train_data`` was already normalized in ``__init__`` and is not
        re-normalized here.
        """
        self.normalizer.load_state_dict(normalizer.state_dict())

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        """Return the padded, trimmed sample window at position ``idx``."""
        (
            buffer_start_idx,
            buffer_end_idx,
            sample_start_idx,
            sample_end_idx,
            demo_idx,
        ) = self.indices[idx]

        nsample = self._sample_sequence(
            sequence_length=self.sequence_length,
            buffer_start_idx=buffer_start_idx,
            buffer_end_idx=buffer_end_idx,
            sample_start_idx=sample_start_idx,
            sample_end_idx=sample_end_idx,
        )

        # Keep only the prediction horizon of actions ...
        nsample["action"] = nsample["action"][
            self.first_action_idx : self.final_action_idx, :
        ]

        # ... and only the observation horizon of observations.
        nsample["obs"] = nsample["obs"][: self.obs_horizon, :]
        nsample["robot_state"] = nsample["robot_state"][: self.obs_horizon, :]
        nsample["parts_poses"] = nsample["parts_poses"][: self.obs_horizon, :]
        nsample["success"] = torch.IntTensor([1])

        return nsample

    def _sample_sequence(
        self,
        sequence_length: int,
        buffer_start_idx: int,
        buffer_end_idx: int,
        sample_start_idx: int,
        sample_end_idx: int,
    ) -> Dict[str, torch.Tensor]:
        """Slice one window out of every ``train_data`` tensor.

        When the window was clipped at an episode boundary, the missing
        leading/trailing frames are filled by repeating the first/last
        available frame.
        """
        result = {}
        for key, input_arr in self.train_data.items():
            sample = input_arr[buffer_start_idx:buffer_end_idx]
            data = sample

            if (sample_start_idx > 0) or (sample_end_idx < sequence_length):
                data = torch.zeros(
                    size=(sequence_length,) + input_arr.shape[1:],
                    dtype=input_arr.dtype,
                )
                if sample_start_idx > 0:
                    data[:sample_start_idx] = sample[0]
                if sample_end_idx < sequence_length:
                    data[sample_end_idx:] = sample[-1]
                data[sample_start_idx:sample_end_idx] = sample

            result[key] = data

        return result

    def train(self):
        """No-op; kept for a model-like train/eval interface."""
        pass

    def eval(self):
        """No-op; kept for a model-like train/eval interface."""
        pass


class ManiSkillImageDataset(Dataset):
    """Windowed ManiSkill demonstrations with state plus optional RGB/depth images.

    Every selected episode from the given ``.h5`` trajectory files is
    concatenated into flat buffers and served as fixed-length windows that are
    edge-padded at episode boundaries. Only ``robot_state`` and ``action`` are
    fitted by a freshly created normalizer.

    Each item contains ``obs``, ``robot_state``, ``parts_poses`` and any image
    keys (``color_image1``, ``depth_image1``), each trimmed to the first
    ``obs_horizon`` frames of the window, plus ``action`` (``pred_horizon``
    frames) and ``success`` (always ``IntTensor([1])``).
    """

    # TODO HARDCODED - the first 8 state dims are treated as robot state
    # (Panda robot has 8-DoF in ManiSkill); the remaining dims as parts poses.
    _ROBOT_STATE_DIM = 8

    def __init__(
        self,
        dataset_paths: Union[List[Path], Path],
        pred_horizon: int,
        obs_horizon: int,
        action_horizon: int,
        data_subset: Optional[int] = None,
        predict_past_actions: bool = False,
        control_mode: ControlMode = ControlMode.delta,
        pad_after: bool = True,
        normalizer: Optional[LinearNormalizer] = None,
    ):
        """Load, normalize and index the demonstrations.

        Args:
            dataset_paths: One path or a list of paths to ``.h5`` trajectory
                files. A sibling file with a ``.json`` suffix is read for
                episode metadata when available.
            pred_horizon: Number of action frames per sample.
            obs_horizon: Number of observation/image frames per sample.
            action_horizon: Only used to size the padding after each episode
                (``action_horizon - 1`` frames when ``pad_after`` is True).
            data_subset: If given, load at most this many episodes in total
                across all files.
            predict_past_actions: If True, the action slice starts at the
                beginning of the window; otherwise at the last observed frame.
            control_mode: Requested control mode; stored, not used for loading.
            pad_after: Whether to pad the end of each episode.
            normalizer: Pre-fitted normalizer to reuse. If None, a new
                ``LinearNormalizer`` is fitted on ``robot_state`` and ``action``.

        Raises:
            ValueError: If a trajectory has no action data, its per-frame
                arrays have inconsistent lengths, images exist for only some
                episodes, or no episodes were loaded.
        """
        super().__init__()
        self.pred_horizon = pred_horizon
        self.action_horizon = action_horizon
        self.obs_horizon = obs_horizon
        self.predict_past_actions = predict_past_actions
        self.control_mode = control_mode
        self.dataset_control_mode = None

        if isinstance(dataset_paths, (str, Path)):
            dataset_paths = [dataset_paths]

        # Per-episode arrays, later flattened into contiguous buffers.
        self.robot_states = []
        self.rgb_images = []
        self.depth_images = []
        self.actions = []
        self.episode_ends = []
        self.metadata = self._load_trajectories(dataset_paths, data_subset)

        episode_count = len(self.episode_ends)
        frame_count = self.episode_ends[-1]

        self.robot_states = np.vstack(self.robot_states)
        self.actions = np.vstack(self.actions)
        self.episode_ends = np.array(self.episode_ends)

        has_rgb = len(self.rgb_images) > 0
        has_depth = len(self.depth_images) > 0
        if has_rgb:
            self.rgb_images = np.vstack(self.rgb_images)
        if has_depth:
            self.depth_images = np.vstack(self.depth_images)

        self.action_dim = self.actions.shape[1]
        state_dim = self.robot_states.shape[1]
        self.robot_state_dim = min(self._ROBOT_STATE_DIM, state_dim)
        self.parts_poses_dim = state_dim - self.robot_state_dim

        robot_states_tensor = torch.from_numpy(self.robot_states)
        self.train_data = {
            "obs": robot_states_tensor.clone(),
            "action": torch.from_numpy(self.actions),
            # Column slices of the state; a zero-width split yields an
            # (N, 0) tensor.
            "robot_state": robot_states_tensor[:, : self.robot_state_dim],
            "parts_poses": robot_states_tensor[:, self.robot_state_dim :],
        }

        if has_rgb:
            # (N, H, W, C) -> (N, C, H, W); assumes images are stored
            # channels-last.
            self.train_data["color_image1"] = torch.from_numpy(
                self.rgb_images
            ).permute(0, 3, 1, 2)

        if has_depth:
            depth = self.depth_images
            if depth.ndim == 3:
                # Depth stored without a channel axis: append a singleton one.
                depth = depth.reshape(depth.shape + (1,))
            self.train_data["depth_image1"] = torch.from_numpy(depth).permute(0, 3, 1, 2)

        # Reuse a pre-fitted normalizer if given; otherwise fit only the
        # low-dimensional robot state and the actions (images are not fitted).
        self.normalizer = LinearNormalizer() if normalizer is None else normalizer
        if normalizer is None:
            self.normalizer.fit(
                {
                    "robot_state": self.train_data["robot_state"],
                    "action": self.train_data["action"],
                }
            )

        # Forward-normalize every key the normalizer knows that exists here.
        for key in self.normalizer.keys():
            if key in self.train_data:
                self.train_data[key] = self.normalizer(
                    self.train_data[key], key, forward=True
                )

        # The window must hold obs_horizon observations plus pred_horizon
        # actions that (unless predicting past actions) start at the last
        # observed frame.
        self.sequence_length = (
            pred_horizon if predict_past_actions else obs_horizon + pred_horizon - 1
        )
        self.indices = self._create_sample_indices(
            pad_before=obs_horizon - 1,
            pad_after=action_horizon - 1 if pad_after else 0,
        )
        self.n_samples = len(self.indices)

        # Slice of each window that is returned as "action".
        self.first_action_idx = 0 if predict_past_actions else self.obs_horizon - 1
        self.final_action_idx = self.first_action_idx + self.pred_horizon

        # Image keys exposed to the training pipeline.
        self.image_keys = []
        if has_rgb:
            self.image_keys.append("color_image1")
        if has_depth:
            self.image_keys.append("depth_image1")

        print("ManiSkill Image Dataset Summary:")
        print(f"  - Total episodes loaded: {episode_count}")
        print(f"  - Total frames loaded: {frame_count}")
        print(f"  - Observation dimension: robot_state: {self.robot_state_dim}, parts_poses: {self.parts_poses_dim}")
        print(f"  - Action dimension: {self.action_dim}")
        print(f"  - Image keys: {self.image_keys}")
        print(f"  - Samples after processing: {self.n_samples}")
        if data_subset is not None:
            print(f"  - Loaded {episode_count}/{data_subset} requested episodes ({(episode_count/data_subset)*100:.1f}%)")
        print(f"  - Control mode: {self.dataset_control_mode}")
        print("=" * 50)

    def _load_trajectories(
        self, dataset_paths, data_subset: Optional[int]
    ) -> Dict[str, Dict[str, int]]:
        """Append the selected episodes of every file to the instance buffers.

        Fills ``self.robot_states``, ``self.actions`` and, when present,
        ``self.rgb_images``/``self.depth_images`` with per-episode arrays, and
        appends cumulative frame counts to ``self.episode_ends``.

        Returns:
            Per-file usage: ``{str(path): {"n_episodes_used", "n_frames_used"}}``.

        Raises:
            ValueError: On missing actions, inconsistent per-frame lengths,
                images present in only some episodes, or no episodes at all.
        """
        metadata = {}
        episode_count = 0
        frame_count = 0

        for path in dataset_paths:
            if data_subset is not None and episode_count >= data_subset:
                break  # Quota already met; don't open the remaining files.

            print(f"Loading data from {path}")
            path_episodes = 0
            path_frames = 0
            # ``with`` guarantees the file is closed even if an episode fails.
            with h5py.File(path, "r") as data:
                episodes = self._read_episode_list(path, data)

                if data_subset is not None:
                    remaining = data_subset - episode_count
                    if len(episodes) > remaining:
                        original_length = len(episodes)
                        episodes = episodes[:remaining]
                        print(f"Using data_subset={data_subset}, selecting {len(episodes)} out of {original_length} episodes from {path}")
                    else:
                        print(f"Using data_subset={data_subset}, currently loaded {episode_count} episodes, adding all {len(episodes)} from {path}")

                for episode in episodes:
                    episode_id = episode["episode_id"]
                    trajectory_data = load_h5_data(data[f"traj_{episode_id}"])

                    # Drop the last frame of per-step observations (presumably
                    # the terminal one, which has no matching action).
                    obs = trajectory_data["obs"][:-1]
                    actions = self._extract_actions(trajectory_data, episode_id)
                    rgb = self._frames_without_last(trajectory_data, ("rgb", "rgbs"))
                    depth = self._frames_without_last(trajectory_data, ("depth", "depths"))

                    # Any length mismatch would silently misalign the flat
                    # buffers for every later episode.
                    for name, arr in (("actions", actions), ("rgb", rgb), ("depth", depth)):
                        if arr is not None and len(arr) != len(obs):
                            raise ValueError(
                                f"Trajectory {episode_id} in {path}: {name} has "
                                f"{len(arr)} frames but obs has {len(obs)}"
                            )

                    self.robot_states.append(obs)
                    self.actions.append(actions)
                    if rgb is not None:
                        self.rgb_images.append(rgb)
                    if depth is not None:
                        self.depth_images.append(depth)

                    frame_count += len(obs)
                    self.episode_ends.append(frame_count)
                    episode_count += 1
                    path_episodes += 1
                    path_frames += len(obs)

            metadata[str(path)] = {
                "n_episodes_used": path_episodes,
                "n_frames_used": path_frames,
            }

        if not self.episode_ends:
            raise ValueError(
                f"No episodes were loaded from {[str(p) for p in dataset_paths]} "
                f"(data_subset={data_subset})"
            )

        # Images must exist for all episodes or none, otherwise image frames
        # would not line up with state frames.
        for name, frames in (("RGB", self.rgb_images), ("Depth", self.depth_images)):
            if 0 < len(frames) < len(self.episode_ends):
                raise ValueError(
                    f"{name} images found in only {len(frames)} of "
                    f"{len(self.episode_ends)} episodes"
                )
        return metadata

    def _read_episode_list(self, path, data) -> List[Dict[str, Any]]:
        """Return the episode descriptors for ``path`` and record its control mode.

        Prefers the sibling ``.json`` metadata. If that file is missing,
        unreadable or lacks an ``episodes`` entry, the ``traj_<id>`` groups of
        the open h5 file ``data`` are enumerated instead.
        """
        # ``with_suffix`` only swaps the file extension; ``str.replace`` would
        # also rewrite ".h5" occurring inside directory names.
        json_path = Path(path).with_suffix(".json")
        try:
            with open(json_path, "r") as f:
                json_data = json.load(f)
            episodes = json_data["episodes"]
        except (OSError, ValueError, KeyError, TypeError):
            return self._episodes_from_h5_keys(data)

        # Prefer the environment-level control mode, then the first episode's.
        env_info = json_data.get("env_info")
        env_kwargs = env_info.get("env_kwargs") if isinstance(env_info, dict) else None
        if isinstance(env_kwargs, dict) and "control_mode" in env_kwargs:
            self.dataset_control_mode = env_kwargs["control_mode"]
        elif episodes and isinstance(episodes[0], dict) and "control_mode" in episodes[0]:
            self.dataset_control_mode = episodes[0]["control_mode"]
        return episodes

    @staticmethod
    def _episodes_from_h5_keys(data) -> List[Dict[str, Any]]:
        """Build ``{"episode_id": id}`` descriptors from numeric ``traj_<id>`` keys."""
        prefix = "traj_"
        ids = sorted(
            int(key[len(prefix):])
            for key in data.keys()
            if key.startswith(prefix) and key[len(prefix):].isdigit()
        )
        return [{"episode_id": episode_id} for episode_id in ids]

    @staticmethod
    def _extract_actions(trajectory_data: Dict[str, Any], episode_id):
        """Return a trajectory's actions from ``actions``, ``action`` or any ``action*`` key."""
        for key in ("actions", "action"):
            if key in trajectory_data:
                return trajectory_data[key]
        action_keys = [k for k in trajectory_data.keys() if k.startswith("action")]
        if action_keys:
            return trajectory_data[action_keys[0]]
        raise ValueError(f"No action data found in trajectory {episode_id}")

    @staticmethod
    def _frames_without_last(trajectory_data: Dict[str, Any], keys):
        """Return the first present entry of ``keys`` minus its last frame, else None."""
        for key in keys:
            if key in trajectory_data:
                return trajectory_data[key][:-1]
        return None

    def _create_sample_indices(self, pad_before=0, pad_after=0):
        """Enumerate every window of ``self.sequence_length`` frames per episode.

        Windows may start up to ``pad_before`` frames before an episode and
        end up to ``pad_after`` frames after it. Each row is
        ``[buffer_start, buffer_end, sample_start, sample_end, episode_idx]``:
        ``buffer_*`` index the flat data buffers, and ``sample_*`` give the
        slice of the output window those frames fill; the rest is padding.
        """
        indices = []
        for i in range(len(self.episode_ends)):
            start_idx = self.episode_ends[i - 1] if i > 0 else 0
            end_idx = self.episode_ends[i]
            episode_length = end_idx - start_idx

            min_start = -pad_before
            max_start = episode_length - self.sequence_length + pad_after

            for idx in range(min_start, max_start + 1):
                # Clip the window to the episode, remembering how much was cut.
                buffer_start_idx = max(idx, 0) + start_idx
                buffer_end_idx = min(idx + self.sequence_length, episode_length) + start_idx
                start_offset = buffer_start_idx - (idx + start_idx)
                end_offset = (idx + self.sequence_length + start_idx) - buffer_end_idx
                sample_start_idx = start_offset
                sample_end_idx = self.sequence_length - end_offset
                indices.append(
                    [buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx, i]
                )

        return np.array(indices)

    def set_normalizer(self, normalizer: LinearNormalizer):
        """Copy ``normalizer``'s state into this dataset's normalizer.

        ``train_data`` was already normalized in ``__init__`` and is not
        re-normalized here.
        """
        self.normalizer.load_state_dict(normalizer.state_dict())

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        """Return the padded, trimmed sample window at position ``idx``."""
        (
            buffer_start_idx,
            buffer_end_idx,
            sample_start_idx,
            sample_end_idx,
            demo_idx,
        ) = self.indices[idx]

        nsample = self._sample_sequence(
            sequence_length=self.sequence_length,
            buffer_start_idx=buffer_start_idx,
            buffer_end_idx=buffer_end_idx,
            sample_start_idx=sample_start_idx,
            sample_end_idx=sample_end_idx,
        )

        # Trim every observation-like entry (including "obs", as in the state
        # dataset) to the observation horizon.
        for key in ("obs", "robot_state", "parts_poses", *self.image_keys):
            if key in nsample:
                nsample[key] = nsample[key][: self.obs_horizon]

        nsample["action"] = nsample["action"][
            self.first_action_idx : self.final_action_idx, :
        ]
        nsample["success"] = torch.IntTensor([1])

        return nsample

    def _sample_sequence(
        self,
        sequence_length: int,
        buffer_start_idx: int,
        buffer_end_idx: int,
        sample_start_idx: int,
        sample_end_idx: int,
    ) -> Dict[str, torch.Tensor]:
        """Slice one window out of every ``train_data`` tensor.

        When the window was clipped at an episode boundary, the missing
        leading/trailing frames are filled by repeating the first/last
        available frame.
        """
        result = {}
        for key, input_arr in self.train_data.items():
            sample = input_arr[buffer_start_idx:buffer_end_idx]
            data = sample

            if (sample_start_idx > 0) or (sample_end_idx < sequence_length):
                data = torch.zeros(
                    size=(sequence_length,) + input_arr.shape[1:],
                    dtype=input_arr.dtype,
                )
                if sample_start_idx > 0:
                    data[:sample_start_idx] = sample[0]
                if sample_end_idx < sequence_length:
                    data[sample_end_idx:] = sample[-1]
                data[sample_start_idx:sample_end_idx] = sample

            result[key] = data

        return result

    def train(self):
        """No-op; kept for a model-like train/eval interface."""
        pass

    def eval(self):
        """No-op; kept for a model-like train/eval interface."""
        pass