import os

from functools import partial
from datasets.stored_dataset_loader import load_from_stored_dataset
from datasets.taxi_suff import generate_taxi_suff_dataset
from datasets.minigrid_navix import minigrid_data_collection
from datasets.minigrid_navix import minigrid_doorkey_data_collection

from datasets.dataset_generators import (
    generate_multi_object_dataset,
    generate_multi_object_dataset_with_selection,
    generate_taxi_dataset,
    generate_l_shaped_dataset
)
from datasets.gymnax_taxi import generate_taxi_gymnax_dataset

import jax
import jax.numpy as jnp
import numpy as np
from tqdm import tqdm

def get_generator_fn(config):
    """Return the episode generator selected by ``config.env``.

    For most environments this is a ``functools.partial`` over one of the
    imported dataset generators. Its environment-specific keyword arguments
    are read from ``config.datasets[config.env]`` (or ``config.thickness``
    for ``grid_2d``). For ``navix`` and ``navix_doorkey``, the value returned
    by ``minigrid_data_collection`` / ``minigrid_doorkey_data_collection`` is
    returned directly.

    In this module the result is called as ``fn(horizon, rng)`` and its
    output is unpacked as ``(dataset, env_config)``.

    Args:
        config: Configuration object; ``config.env`` selects the environment.

    Returns:
        A callable that generates episodes for the selected environment.

    Raises:
        ValueError: If ``config.env`` is not a recognised environment name.
    """
    env = config.env

    if env == 'grid_2d':
        return partial(generate_l_shaped_dataset, thickness=config.thickness)

    if env == 'taxi':
        ds_cfg = config.datasets[env]
        return partial(
            generate_taxi_dataset,
            grid_size=ds_cfg.grid_size,
            n_passengers=ds_cfg.n_passengers,
            img_size=ds_cfg.img_size,
        )

    if env == 'taxi_suff':
        ds_cfg = config.datasets[env]
        return partial(
            generate_taxi_suff_dataset,
            grid_size=ds_cfg.grid_size,
            n_passengers=ds_cfg.n_passengers,
            img_size=ds_cfg.img_size,
        )

    if env == 'taxi_gymnax':
        ds_cfg = config.datasets[env]
        return partial(
            generate_taxi_gymnax_dataset,
            # The gymnax generator takes the grid size as ``size``.
            size=ds_cfg.grid_size,
            n_passengers=ds_cfg.n_passengers,
            img_size=ds_cfg.img_size,
            allow_dropoff_anywhere=ds_cfg.get('allow_dropoff_anywhere', True),
            exploring_starts=ds_cfg.get('exploring_starts', True),
        )

    if env == 'navix':
        return minigrid_data_collection(config)

    if env == 'navix_doorkey':
        # Behaviour-policy mixture for DoorKey data collection. How these
        # probabilities are combined is defined inside
        # minigrid_doorkey_data_collection (not visible here); the
        # top-level probabilities are not asserted to sum to 1.
        behavior_config = {
            'random_prob': 0.3,       # 30% purely random
            'smart_prob': 0.4,        # 40% smart policy
            'key_drop_prob': 0.05,    # 5% chance to drop the key
            'goal_directed_probs': {
                'key_only': 0.25,     # 25%: get key, then random
                'door_only': 0.25,    # 25%: open door, then random
                'random_switch_prob': 0.05,  # 5% per-step chance to switch to random
            },
            'action_probs': {
                'pickup': 0.4,        # boosted probability of the pickup action
                'open': 0.4,          # boosted probability of the open action
            },
            'explore_prob': 0.15,
        }
        return minigrid_doorkey_data_collection(config, behavior_config)

    if env == 'multi_object':
        ds_cfg = config.datasets[env]
        return partial(
            generate_multi_object_dataset,
            n_objects=ds_cfg.n_objects,
            img_size=ds_cfg.img_size,
        )

    if env == 'multi_object_selection':
        ds_cfg = config.datasets[env]
        return partial(
            generate_multi_object_dataset_with_selection,
            n_objects=ds_cfg.n_objects,
            img_size=ds_cfg.img_size,
        )

    raise ValueError(f"Unknown environment type: {config.env}")
    
def generate_dataset(config, rng_data, batch_size=1024, n_samples=None):
    """Generate or load a dataset based on the configuration.

    If ``config.stored_dataset_path`` is set and exists on disk, the 'train'
    split is loaded from it. Otherwise, episodes are generated in batches of
    ``batch_size`` with the generator returned by ``get_generator_fn(config)``.
    Each batch is copied to host (NumPy) memory, and all batches are then
    concatenated along the first axis.

    Args:
        config: Configuration object containing dataset parameters.
        rng_data: JAX random key for data generation. Not used when a stored
            dataset is loaded.
        batch_size: Number of episodes generated per batch (default: 1024).
        n_samples: Total number of samples to generate. Defaults to
            ``config.data_collection.n_samples``. The number of episodes is
            ``n_samples // horizon``. Ignored when a stored dataset is loaded.

    Returns:
        tuple: ``(dataset, env_config, horizon)``, containing the dataset,
        the environment configuration and the horizon length. For a stored
        dataset, the horizon is ``dataset.obs.shape[1] - 1``.

    Raises:
        ValueError: If ``horizon`` or ``batch_size`` is not positive, or if
            ``n_samples`` is too small to yield at least one full episode.
    """
    stored_path = config.get('stored_dataset_path', None)
    if stored_path is not None and os.path.exists(stored_path):
        # NOTE(review): the loader's batch size is fixed at 5000 and is
        # independent of ``batch_size`` -- confirm this is intended.
        dataset, _data_cfg, env_config = load_from_stored_dataset(
            dataset_path=stored_path,
            split='train',
            batch_size=5000,
        )
        horizon = dataset.obs.shape[1] - 1
        return dataset, env_config, horizon

    horizon = config.horizon
    if horizon <= 0:
        raise ValueError(f"horizon must be positive, got {horizon}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    if n_samples is None:
        n_samples = config.data_collection.n_samples
    n_eps = n_samples // horizon
    if n_eps <= 0:
        # Without this check, no batches are produced and the final
        # concatenation fails with an unhelpful TypeError.
        raise ValueError(
            f"n_samples={n_samples} is smaller than horizon={horizon}; "
            "no complete episode can be generated"
        )

    # Generate in batches to bound device memory; ceiling division.
    n_batches = (n_eps + batch_size - 1) // batch_size

    generator_fn = get_generator_fn(config)
    # Un-jitted probe call; only env_config is kept, its dataset is discarded.
    _, env_config = generator_fn(horizon, rng_data)

    print(f'Generating {n_batches} batches of {batch_size} episodes of {config.env}')
    # horizon is a static (hashable Python) argument to jit; only the
    # dataset half of the generator's output is returned.
    generator_fn_jitted = jax.jit(lambda *args: generator_fn(*args)[0], static_argnums=(0,))
    batched_generator = jax.vmap(generator_fn_jitted, in_axes=(None, 0))

    batches = []
    for i in tqdm(range(n_batches)):
        start_idx = i * batch_size
        # The last batch may hold fewer episodes than batch_size.
        batch_size_i = min(batch_size, n_eps - start_idx)

        rng_data, rng_data_i = jax.random.split(rng_data)
        batch_dataset_gpu = batched_generator(
            horizon,
            jax.random.split(rng_data_i, batch_size_i),
        )

        # Copy the batch into host (NumPy) memory and drop the device arrays.
        batches.append(jax.tree.map(lambda x: np.array(x), batch_dataset_gpu))
        del batch_dataset_gpu

    # Concatenate all batches along the episode axis.
    dataset = jax.tree.map(lambda *xs: np.concatenate(xs), *batches)
    return dataset, env_config, horizon