from functools import partial
import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax
import chex
import orbax.checkpoint as ocp

from omegaconf import OmegaConf as oc
from envs.envbuilder import EnvironmentBuilder
from datasets import TransitionData
from policies.navix_smart import doorkey_policy, PolicyState, get_dfs_path, DFSState, MINIGRID_ACTION_MAP, COMPLETE_ACTION_MAP, get_grid
from policies.dfs import DFSState

# Register an OmegaConf `${eval:...}` resolver exactly once (re-registering an
# existing resolver name would fail when this module is re-imported/reloaded).
# NOTE(security): this resolver hands the interpolated string to Python's
# builtin `eval`, so any config file loaded through OmegaConf can execute
# arbitrary code. Only load trusted configuration files.
if not oc.has_resolver('eval'):
    oc.register_new_resolver('eval', eval)

@chex.dataclass(frozen=True)
class EnvConfig:
    """Minimal immutable environment description: an observation and an action count.

    NOTE(review): in the visible part of this file it is only referenced from
    commented-out code; confirm whether other modules still use it.
    """
    # Observation (or observation template) for the environment; exact
    # shape/dtype is not established here — TODO confirm.
    obs: chex.Array
    # Number of discrete actions available to the agent.
    n_actions: int

@chex.dataclass(frozen=True)
class MinigridTransitionData(TransitionData):
    """TransitionData extended with a per-step graph array.

    The data generators in this file fill ``graph`` from the environment's
    ``info['local_graph']`` on every step.
    """
    # Per-step graph returned by the environment step as info['local_graph'];
    # its structure/shape is defined by the env wrapper — TODO confirm.
    graph: jax.Array

# Action name -> action index for the full (complete) action set.
# The position of each name in this tuple is its action index.
FULL_ACTION_MAP = {
    action_name: action_index
    for action_index, action_name in enumerate((
        'noop',
        'rotate_cw',
        'rotate_ccw',
        'forward',
        'right',
        'backward',
        'left',
        'pickup',
        'open',
        'done',
    ))
}

# Probability assigned to the `pickup` / `open` action by the heuristic
# policy in `minigrid_data_generator` when the player is next to the
# key / door; the remaining mass is spread over the other actions.
PICKUP_PROB = OPEN_PROB = 0.8

def minigrid_data_generator(env_config, horizon, rng):
    """Roll out a heuristic random policy in a minigrid environment.

    Actions are sampled uniformly, except when the player is within Manhattan
    distance 1 of the key (checked first) or the door. In that case ``pickup``
    (resp. ``open``) gets probability ``PICKUP_PROB`` (resp. ``OPEN_PROB``),
    and the remaining mass is spread evenly over the other actions.

    Args:
        env_config: Built environment config exposing ``env_reset``,
            ``env_step``, ``base_env.get_state`` and ``n_actions``.
        horizon: Number of environment steps to roll out.
        rng: JAX PRNG key.

    Returns:
        ``(dataset, env_config)``. ``dataset`` is a ``MinigridTransitionData``
        stacked over time. Its ``obs`` and ``state`` have one extra trailing
        entry: the final observation and state after the last step.
    """
    print('Generating minigrid with random policy...')
    n_actions = env_config.n_actions

    # The action distributions are constant for the whole rollout, so they are
    # built once here rather than inside every traced policy call.
    # NOTE: the biased indices come from FULL_ACTION_MAP. If n_actions is too
    # small to contain them, JAX skips the out-of-bounds `.at[...].set` update
    # and the corresponding policy degenerates to uniform.
    pickup_policy = (jnp.ones((n_actions,)) * (1 - PICKUP_PROB) / (n_actions - 1)).at[
        FULL_ACTION_MAP['pickup']].set(PICKUP_PROB)
    open_policy = (jnp.ones((n_actions,)) * (1 - OPEN_PROB) / (n_actions - 1)).at[
        FULL_ACTION_MAP['open']].set(OPEN_PROB)
    uniform_policy = jnp.ones((n_actions,)) / n_actions

    def policy_fn(rng, state):
        """Sample one action, biased towards pickup/open next to key/door."""
        navix_state = state.env_state.timestep.state
        player = navix_state.entities['player'][0]
        key = navix_state.entities.get('key', [None])[0]
        door = navix_state.entities.get('door', [None])[0]

        # Entity presence is static (known at trace time), so a Python branch is fine.
        if key is None or door is None:
            probs = uniform_policy
        else:
            is_key_adjacent = jnp.abs(player.position - key.position).sum() <= 1
            is_door_adjacent = jnp.abs(player.position - door.position).sum() <= 1
            # Key adjacency takes precedence over door adjacency.
            probs = jnp.where(
                is_key_adjacent,
                pickup_policy,
                jnp.where(is_door_adjacent, open_policy, uniform_policy),
            )
        return jax.random.categorical(rng, jnp.log(probs))

    def rollout_env(n_samples, rng):
        """Scan ``n_samples`` environment steps and append the final obs/state."""
        # Keep the 3-way split so the reset/step keys (and the generated data)
        # stay the same; the third key was reserved for observation noise.
        rng_reset, rng_step, _ = jax.random.split(rng, 3)
        init_obs, init_env_state = env_config.env_reset(rng_reset)

        def _rollout_step(carry, step_rng):
            last_obs, env_state, last_done = carry
            rng_action, rng_env_step, _ = jax.random.split(step_rng, 3)
            action = policy_fn(rng_action, env_state)
            obs, env_state, reward, done, info = env_config.env_step(rng_env_step, env_state, action)
            done = done.astype(jnp.float32)
            transition = MinigridTransitionData(
                obs=last_obs,
                action=action,
                reward=reward,
                done=done,
                # The step after a `done` starts a new episode.
                is_first=last_done,
                state=info['state'],
                graph=info['local_graph'],
            )
            return (obs, env_state, done), transition

        # `done` is carried as float32. Start from an explicit float32 zero
        # instead of relying on weak-type promotion of a Python `0`.
        init_carry = (init_obs, init_env_state, jnp.zeros((), dtype=jnp.float32))
        (final_obs, final_env_state, _), data = jax.lax.scan(
            _rollout_step,
            init_carry,
            jax.random.split(rng_step, n_samples),
        )

        final_state = env_config.base_env.get_state(final_env_state.env_state.timestep.state)
        return data.replace(
            obs=jnp.concatenate([data.obs, final_obs[None]], axis=0),
            state=jnp.concatenate([data.state, final_state[None]], axis=0),
        )

    dataset = rollout_env(horizon, rng)
    return dataset, env_config

def mixed_doorkey_data_generator(env_config, horizon, rng, action_set="full", config=None):
    """Generate doorkey data from a mix of random and goal-directed behaviours.

    The rollout is split into segments, and each segment runs one behaviour mode:

    * 0 - random: uniform random actions
    * 1 - smart: actions from ``doorkey_policy``
    * 2 - key only: ``doorkey_policy`` until the player holds something
      (``pocket != -1``), uniform random afterwards
    * 3 - door only: ``doorkey_policy`` until the door is open, uniform
      random afterwards
    * 4 - exploration: uniform random actions

    A new mode and a new segment length (uniform in [10, 20]) are sampled
    when the current segment ends. A resample also happens on any step with
    probability ``config['goal_directed_probs']['random_switch_prob']``.

    Args:
        env_config: Built environment config exposing ``env_reset``,
            ``env_step``, ``base_env.get_state`` and ``n_actions``.
        horizon: Number of environment steps to generate.
        rng: JAX PRNG key.
        action_set: ``"full"`` sets ``action_set_id=0`` on the smart
            policy state; any other value sets ``action_set_id=1``.
        config: Optional behaviour configuration (see defaults below):
            - random_prob, smart_prob, explore_prob: weights of modes 0, 1, 4
            - goal_directed_probs: weights 'key_only' / 'door_only' of
              modes 2 / 3, plus the per-step 'random_switch_prob'
            - key_drop_prob: present in the defaults but not read here
            - action_probs: only read by ``biased_random_policy_fn``, which no
              behaviour mode currently uses
            Mode weights are normalised to sum to one.

    Returns:
        ``((dataset, env_config), env_states)``. ``dataset`` is a
        ``MinigridTransitionData`` whose ``obs``/``state`` have
        ``horizon + 1`` entries. ``env_states`` stacks the initial environment
        state followed by the state after each step.
    """
    # Import from the same `policies` package as the module-level imports.
    # The former `src.policies` path fails when only `src/` is on sys.path,
    # and otherwise loads a second, separate copy of these modules.
    from policies.navix_smart import doorkey_policy, PolicyState
    from policies.dfs import DFSState

    # Set default configuration if not provided
    if config is None:
        config = {
            'random_prob': 0.1,       # weight of mode 0 (purely random)
            'smart_prob': 0.4,        # weight of mode 1 (smart policy)
            'key_drop_prob': 0.05,    # NOTE(review): not read anywhere in this function
            'goal_directed_probs': {
                'key_only': 0.25,     # weight of mode 2 (get key, then random)
                'door_only': 0.25,    # weight of mode 3 (open door, then random)
                'random_switch_prob': 0.05  # per-step chance to resample the mode
            },
            'action_probs': {
                'pickup': 0.4,        # only used by biased_random_policy_fn
                'open': 0.4           # only used by biased_random_policy_fn
            },
            'explore_prob': 0.1     # weight of mode 4 (exploration)
        }

    def random_policy_fn(key, state):
        """Sample a uniformly random action (``state`` is ignored)."""
        probs = jnp.ones((env_config.n_actions,))
        action = jax.random.categorical(key, jnp.log(probs/probs.sum()))
        return action, None

    def biased_random_policy_fn(key, state, env_state, policy_state):
        """Random policy with bias toward pickup/open next to the key/door.

        NOTE(review): not used by any behaviour mode at the moment. It was
        previously evaluated every step with its result discarded.
        """
        # Default uniform probabilities
        probs = jnp.ones((env_config.n_actions,))

        # Get action set ID from policy state
        action_set_id = policy_state.action_set_id if hasattr(policy_state, 'action_set_id') else 0

        # Bias toward pickup when near key
        if 'pickup' in config['action_probs']:
            player = env_state.env_state.timestep.state.entities['player'][0]
            key_entity = env_state.env_state.timestep.state.entities.get('key', [None])[0]

            if key_entity is not None:
                # Check if player is adjacent to key
                player_pos = player.position
                key_pos = key_entity.position

                # If manhattan distance is 1, increase pickup probability
                manhattan_dist = jnp.abs(player_pos[0] - key_pos[0]) + jnp.abs(player_pos[1] - key_pos[1])
                is_adjacent = manhattan_dist <= 1

                # Get the correct pickup action index based on action set
                pickup_idx = jax.lax.cond(
                    action_set_id == 0,
                    lambda _: 7,  # PICKUP in COMPLETE action set
                    lambda _: MINIGRID_ACTION_MAP[7],  # PICKUP in MINIGRID action set
                    operand=None
                )

                # Modify pickup action probability
                pickup_prob = config['action_probs']['pickup']
                probs = jax.lax.cond(
                    is_adjacent,
                    lambda p: p.at[pickup_idx].set(pickup_prob * 10),  # Boost probability
                    lambda p: p,
                    probs
                )

        # Bias toward open when near door
        if 'open' in config['action_probs']:
            player = env_state.env_state.timestep.state.entities['player'][0]
            door_entity = env_state.env_state.timestep.state.entities.get('door', [None])[0]

            if door_entity is not None:
                # Check if player is adjacent to door
                player_pos = player.position
                door_pos = door_entity.position

                # If manhattan distance is 1, increase open probability
                manhattan_dist = jnp.abs(player_pos[0] - door_pos[0]) + jnp.abs(player_pos[1] - door_pos[1])
                is_adjacent = manhattan_dist <= 1

                # Get the correct open action index based on action set
                open_idx = jax.lax.cond(
                    action_set_id == 0,
                    lambda _: 8,  # OPEN in COMPLETE action set
                    lambda _: MINIGRID_ACTION_MAP[8],  # TOGGLE (open) in MINIGRID action set
                    operand=None
                )

                # Modify open action probability
                open_prob = config['action_probs']['open']
                probs = jax.lax.cond(
                    is_adjacent,
                    lambda p: p.at[open_idx].set(open_prob * 10),  # Boost probability
                    lambda p: p,
                    probs
                )

        # Normalize probabilities
        probs = probs / jnp.sum(probs)

        # Sample action
        action = jax.random.categorical(key, jnp.log(probs))
        return action, None

    def rollout_env(n_samples, rng):
        """Execute a rollout that mixes random and smart behaviour modes."""
        # The 4th key was reserved for observation noise and is unused. The
        # 4-way split is kept so the reset/step keys are unchanged.
        rng_reset, rng_step, rng_policy, _ = jax.random.split(rng, 4)
        init_obs, env_state = env_config.env_reset(rng_reset)
        # Separate keys for the initial mode and the initial segment length
        # (drawing both from one key correlates the two samples).
        rng_mode, rng_length = jax.random.split(rng_policy)

        # Initialize policy state for smart policy with dynamic grid size
        max_depth = 100  # reasonable max depth for pathfinding
        grid_shape = get_grid(env_state.env_state.timestep.state).shape

        # Empty DFS state: stack_index == -1 means the stack is empty.
        dummy_path = DFSState(
            stack=jnp.zeros((max_depth, 2), dtype=jnp.int32),
            stack_index=-1,
            visited=jnp.zeros(grid_shape, dtype=jnp.bool_),
            path=jnp.zeros((max_depth, 2), dtype=jnp.int32),
            path_length=0,
            iterations=0
        )

        policy_state = PolicyState(path=dummy_path, current_idx=0, action_set_id=0 if action_set == "full" else 1)

        # Behaviour mode weights, in mode order:
        # 0: random, 1: smart, 2: key only, 3: door only, 4: exploration
        behavior_probs = jnp.array([
            config['random_prob'],
            config['smart_prob'],
            config['goal_directed_probs']['key_only'],
            config['goal_directed_probs']['door_only'],
            config.get('explore_prob', 0.0)
        ])
        behavior_probs = behavior_probs / jnp.sum(behavior_probs)  # Normalize
        behavior_mode = jax.random.categorical(rng_mode, jnp.log(behavior_probs))

        # Steps spent in the current segment
        episode_counter = 0
        # Segment length before switching behaviour (10-20 steps inclusive)
        episode_length = jax.random.randint(rng_length, (), 10, 21)

        def _rollout_step(state, rng):
            """One step of the rollout with behaviour-switching logic."""
            last_obs, env_state, last_done, policy_state, behavior_mode, episode_counter, episode_length = state

            # Split random key for different operations (4th key unused)
            rng_action, rng_step, rng_policy, _, rng_switch = jax.random.split(rng, 5)
            # Independent keys for resampling the segment length and the mode
            rng_length, rng_mode = jax.random.split(rng_policy)

            # Switch when the segment is over, or at random
            switch_behavior = episode_counter >= episode_length
            random_switch = jax.random.uniform(rng_switch) < config['goal_directed_probs']['random_switch_prob']
            switch_behavior = jnp.logical_or(switch_behavior, random_switch)

            # Reset counter and maybe change behavior mode
            episode_counter = jax.lax.cond(
                switch_behavior,
                lambda _: 0,
                lambda _: episode_counter + 1,
                operand=None
            )

            episode_length = jax.lax.cond(
                switch_behavior,
                lambda _: jax.random.randint(rng_length, (), 10, 21),
                lambda _: episode_length,
                operand=None
            )

            behavior_mode = jax.lax.cond(
                switch_behavior,
                lambda _: jax.random.categorical(rng_mode, jnp.log(behavior_probs)),
                lambda _: behavior_mode,
                operand=None
            )

            # Select action based on the (possibly updated) behaviour mode.
            # Called before env_state is rebound by the env step below, so the
            # closure sees the pre-step state.
            def get_action():
                random_action, _ = random_policy_fn(rng_action, last_obs)
                smart_action, smart_policy_state = doorkey_policy(
                    rng_action,
                    env_state.env_state.timestep.state,
                    policy_state
                )

                def handle_random(_):
                    return random_action, policy_state

                def handle_smart(_):
                    return smart_action, smart_policy_state

                def handle_key_only(_):
                    player = env_state.env_state.timestep.state.entities['player'][0]
                    has_key = player.pocket != -1  # Check if player has key
                    return jax.lax.cond(
                        has_key,
                        lambda _: (random_action, policy_state),  # Use random after getting key
                        lambda _: (smart_action, smart_policy_state),  # Use smart to get key
                        operand=None
                    )

                def handle_door_only(_):
                    door = env_state.env_state.timestep.state.entities['door'][0]
                    door_open = door.open  # Check if door is open
                    return jax.lax.cond(
                        door_open,
                        lambda _: (random_action, policy_state),  # Use random after opening door
                        lambda _: (smart_action, smart_policy_state),  # Use smart to open door
                        operand=None
                    )

                # Exploration behavior: pure random exploration
                def handle_explore(_):
                    return random_action, policy_state

                return jax.lax.switch(
                    behavior_mode,
                    [handle_random, handle_smart, handle_key_only, handle_door_only, handle_explore],
                    None
                )

            action, new_policy_state = get_action()

            # Step the environment
            obs, env_state, reward, done, info = env_config.env_step(rng_step, env_state, action)

            # Keep `done` float32 so the carry type matches the initial carry
            done_float = done.astype(jnp.float32)

            transition = MinigridTransitionData(
                obs=last_obs,
                action=action,
                reward=reward,
                done=done_float,
                is_first=last_done,
                state=info['state'],  # factored state from the env step info
                graph=info['local_graph']
            )

            next_state = (obs, env_state, done_float, new_policy_state, behavior_mode, episode_counter, episode_length)
            return next_state, (transition, env_state)

        # Initial carry; `done` starts as float32 zero
        init_state = (
            init_obs,
            env_state,
            jnp.array(False, dtype=jnp.float32),
            policy_state,
            behavior_mode,
            episode_counter,
            episode_length
        )

        last_state, (data, env_states) = jax.lax.scan(
            _rollout_step,
            init_state,
            jax.random.split(rng_step, n_samples)
        )

        # Append the final observation and factored state
        last_obs = last_state[0]
        final_env_state = last_state[1]
        final_state_info = env_config.base_env.get_state(final_env_state.env_state.timestep.state)

        data = data.replace(
            obs=jnp.concatenate([data.obs, last_obs[None]], axis=0),
            state=jnp.concatenate([data.state, final_state_info[None]], axis=0)
        )

        # Prepend the initial env state so env_states has n_samples + 1 entries
        all_env_states = jax.tree.map(
            lambda x1, x0: jnp.concatenate([x0[None], x1], axis=0),
            env_states,
            env_state
        )
        return data, all_env_states

    dataset, env_states = rollout_env(horizon, rng)
    return (dataset, env_config), env_states

def minigrid_doorkey_data_collection(config, behavior_config):
    """Build a doorkey data generator driven by the mixed behaviour policy.

    The environment is built once, with a key derived from
    ``config.data_collection.seed``. The returned callable runs
    ``mixed_doorkey_data_generator`` on that environment.

    Args:
        config: Configuration with ``envs``, ``env`` (the key into ``envs``)
            and ``data_collection.seed``. The selected env config must define
            ``action_space``, which is forwarded as ``action_set``.
        behavior_config: Behaviour configuration forwarded as ``config`` to
            ``mixed_doorkey_data_generator`` (``None`` selects its defaults).

    Returns:
        A callable ``generate(horizon, rng)`` that returns
        ``(dataset, env_config)``. The per-step environment states produced
        by the generator are discarded.
    """
    raw_env_config = config.envs[config.env]
    action_set = raw_env_config.action_space

    rng = jax.random.key(config.data_collection.seed)
    _, rng_env = jax.random.split(rng)
    env_config = EnvironmentBuilder.build(raw_env_config, rng=rng_env)

    def generate(horizon, rng):
        (dataset, built_env_config), _env_states = mixed_doorkey_data_generator(
            env_config,
            horizon,
            rng,
            config=behavior_config,
            action_set=action_set,
        )
        return dataset, built_env_config

    return generate

def minigrid_data_collection(config):
    """Build a minigrid data generator selected by the configuration.

    The environment is built once, with a key derived from
    ``config.data_collection.seed``. If
    ``config.data_collection.use_mixed_policy`` is set and truthy, the mixed
    doorkey behaviour generator is used; otherwise the heuristic random
    generator is used.

    Args:
        config: Configuration with ``envs``, ``env`` (the key into ``envs``),
            ``data_collection.seed`` and optionally
            ``data_collection.use_mixed_policy`` /
            ``data_collection.behavior_config``.

    Returns:
        A callable ``generate(horizon, rng)`` that returns
        ``(dataset, env_config)`` in both modes.
    """
    raw_env_config = config.envs[config.env]
    rng = jax.random.key(config.data_collection.seed)
    _, rng_env = jax.random.split(rng)
    env_config = EnvironmentBuilder.build(raw_env_config, rng=rng_env)

    use_mixed = hasattr(config.data_collection, 'use_mixed_policy') and config.data_collection.use_mixed_policy
    if not use_mixed:
        return partial(minigrid_data_generator, env_config)

    behavior_config = config.data_collection.get('behavior_config', None)
    # Forward the env's action space, as minigrid_doorkey_data_collection does,
    # so the smart policy emits actions from the matching action set.
    action_set = raw_env_config.get('action_space', 'full')

    def generate(horizon, rng, **overrides):
        # Defaults mirror the previous partial(); callers may still override them.
        kwargs = {'action_set': action_set, 'config': behavior_config}
        kwargs.update(overrides)
        # The mixed generator returns ((dataset, env_config), env_states); drop
        # env_states so both modes share the (dataset, env_config) contract.
        (dataset, built_env_config), _env_states = mixed_doorkey_data_generator(
            env_config, horizon, rng, **kwargs
        )
        return dataset, built_env_config

    return generate

def visualize_minigrid_state(state, save_path=None):
    """Render a minigrid state's ``grid`` as an image.

    Args:
        state: Object with a ``grid`` attribute that ``imshow`` can draw
            (presumably a 2-D array of cell codes — TODO confirm).
        save_path: If truthy, the figure is also written to this path.

    Returns:
        The matplotlib figure. It is left open, so the caller owns it.
    """
    fig, ax = plt.subplots(figsize=(4, 4))

    # Draw the grid with a categorical colormap.
    ax.imshow(state.grid, cmap=plt.cm.tab10)
    ax.axis('off')
    ax.set_title("Minigrid State")

    # Act on this figure explicitly rather than pyplot's global "current figure".
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path)
        print(f"Visualization saved to {save_path}")

    return fig

def visualize_minigrid_trajectory(trajectory, save_path=None):
    """Render a sequence of minigrid states side by side.

    Args:
        trajectory: Sized, iterable sequence of states, each with a ``grid``
            attribute that ``imshow`` can draw.
        save_path: If truthy, the figure is also written to this path.

    Returns:
        The matplotlib figure, with one 4x4-inch panel per state titled
        "Step 1", "Step 2", ... It is left open, so the caller owns it.

    Raises:
        ValueError: If ``trajectory`` is empty.
    """
    n_states = len(trajectory)
    if n_states == 0:
        # plt.subplots(1, 0) would fail with an obscure column-count error.
        raise ValueError("trajectory must contain at least one state")

    # squeeze=False always yields a 2-D array of axes, so the single-state
    # case needs no special handling.
    fig, axes = plt.subplots(1, n_states, figsize=(4 * n_states, 4), squeeze=False)

    for step, (ax, state) in enumerate(zip(axes[0], trajectory), start=1):
        ax.imshow(state.grid, cmap='tab10')
        ax.set_title(f"Step {step}")
        ax.axis('off')

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path)
        print(f"Visualization saved to {save_path}")

    return fig

