import jax
import jax.numpy as jnp
import numpy as np
from utils.old_datasets import StoredDataset
from datasets import TransitionData
import chex

def load_from_stored_dataset(dataset_path: str, split: str = 'train', batch_size: int = 1000) -> tuple:
    """
    Load a StoredDataset split and convert it to TransitionData format.

    Every stored sample is unpacked as
    ``(obs, action, next_obs, reward, state, next_state)`` and turned into a
    two-step sequence: along axis 1, index 0 holds the current step and
    index 1 holds the next step.

    Args:
        dataset_path (str): Path to the dataset directory. The split name is
            appended to it, i.e. ``{dataset_path}/{split}`` is what gets loaded.
        split (str): Dataset split to load (e.g. 'train', 'val', 'test').
        batch_size (int): Number of samples sliced from the dataset per read.
            Must be positive.

    Returns:
        tuple: ``(transition_data, data_cfg, env_config)`` where
            - ``transition_data`` is a TransitionData whose ``obs``, ``reward``,
              ``done``, ``is_first`` and ``state`` have a length-2 axis 1, and
              whose ``action`` has a length-1 axis 1;
            - ``data_cfg`` is ``dataset.description['data_cfg']``;
            - ``env_config`` is an ``EnvConfig`` dataclass built from
              ``data_cfg['n_actions']``, ``data_cfg['discrete']`` and the first
              loaded observation.

    Raises:
        ValueError: If ``batch_size`` is not positive or the split contains
            no samples.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Load the dataset and the configuration stored alongside it
    dataset = StoredDataset.load(f'{dataset_path}/{split}')
    data_cfg = dataset.description['data_cfg']

    n_samples = len(dataset)
    if n_samples == 0:
        # Fail early with a clear message instead of an obscure indexing
        # error further down when the first sample is accessed.
        raise ValueError(f"Dataset split '{split}' at '{dataset_path}' contains no samples")

    # Per-field accumulators, filled sample by sample
    obs_list = []
    next_obs_list = []
    action_list = []
    reward_list = []
    state_list = []
    next_state_list = []

    # Read the dataset slice by slice rather than one index at a time
    for start_idx in range(0, n_samples, batch_size):
        end_idx = min(start_idx + batch_size, n_samples)
        for obs, action, next_obs, reward, state, next_state in dataset[start_idx:end_idx]:
            obs_list.append(obs)
            next_obs_list.append(next_obs)
            action_list.append(action)
            reward_list.append(reward)
            state_list.append(state)
            next_state_list.append(next_state)

    # Pair each current/next entry along a new axis 1
    obs_array = jnp.stack([jnp.array(obs_list), jnp.array(next_obs_list)], axis=1)
    state_array = jnp.stack([jnp.array(state_list), jnp.array(next_state_list)], axis=1)
    action_array = jnp.array(action_list)[:, None]

    # The stored reward goes to step 1; step 0 is padded with zeros
    rewards = jnp.array(reward_list)
    reward_array = jnp.stack([jnp.zeros_like(rewards), rewards], axis=1)

    # No done flags are read from the samples, so all of them are set to False
    done_array = jnp.zeros_like(reward_array, dtype=bool)

    # NOTE(review): `.at[0]` selects the whole first sample, so BOTH time steps
    # of sample 0 are flagged as first and every other entry stays False.
    # Confirm this is the layout downstream consumers expect (as opposed to
    # flagging only [0, 0], or step 0 of every sample).
    is_first_array = jnp.zeros_like(done_array, dtype=bool).at[0].set(True)

    if len(data_cfg['obs_dim']) == 3:
        obs_array = jnp.transpose(obs_array, (0, 1, 3, 4, 2))  # (B, T, C, H, W) -> (B, T, H, W, C)

    transition_data = TransitionData(
        obs=obs_array,
        action=action_array,
        reward=reward_array,
        done=done_array,
        is_first=is_first_array,
        state=state_array
    )

    # Defined inside the function, so every call creates a distinct EnvConfig class.
    @chex.dataclass
    class EnvConfig:
        n_actions : int
        discrete: bool
        obs : chex.Array

    # The returned environment config is rebuilt from data_cfg plus one example
    # observation (first sample, first time step, after any transpose above).
    env_config = EnvConfig(
        n_actions=data_cfg['n_actions'],
        discrete=data_cfg['discrete'],
        obs=obs_array[0, 0]
    )
    return transition_data, data_cfg, env_config

# Example usage:
if __name__ == "__main__":
    # Example of how to use the function and access the data

    # 1. Load the dataset.
    # Pass the dataset root only: load_from_stored_dataset appends "/{split}"
    # itself, so a path that already ends in the split name would resolve to
    # ".../train/train".
    dataset_path = "experiments/datasets/taxi_pixel_variable_goal"  # Replace with actual path
    transition_data, data_cfg, env_config = load_from_stored_dataset(
        dataset_path=dataset_path,
        split='train',
        batch_size=5000
    )

    # 2. Print some information about the dataset
    print("Dataset loaded successfully!")
    print(f"Number of transitions: {len(transition_data.obs)}")
    print(f"Observation shape: {transition_data.obs.shape}")
    print(f"Action shape: {transition_data.action.shape}")
    print(f"State dimension: {transition_data.state.shape[-1]}")

    # 3. Print configuration information
    print("\nData Configuration:")
    for key, value in data_cfg.items():
        print(f"  {key}: {value}")

    print("\nEnvironment Configuration:")
    if hasattr(env_config, 'n_actions'):
        print(f"  Number of actions: {env_config.n_actions}")

    # 4. Access a few samples (at most three)
    print("\nFirst few observations:")
    for i in range(min(3, len(transition_data.obs))):
        print(f"Observation {i}:")
        print(transition_data.obs[i])
        print(f"Action {i}: {transition_data.action[i]}")
        print(f"Reward {i}: {transition_data.reward[i]}")
        print(f"State {i}: {transition_data.state[i]}")
        print("-" * 30)