import gymnasium as gym
import numpy as np
from numpy.typing import NDArray
from typing import Any
from abc import ABC, abstractmethod
from umfavi.utils.tabular import q_opt
from umfavi.utils.math import softmax
import stable_baselines3 as sb3
from umfavi.envs.env_types import TabularEnv
from umfavi.utils.torch_utils import get_model_device, to_numpy, to_torch
from stable_baselines3.common.utils import obs_as_tensor
from pathlib import Path
import torch
import torch.nn as nn
import matplotlib.pyplot as plt


def create_policy(policy_path: str, beta: float, env: gym.Env, gamma: float = 0.99):
    """
    Create an expert policy for ``env``.

    For tabular environments the expert is built from Q-values computed on
    the environment itself, so ``policy_path`` is not read. For every other
    environment a stable-baselines3 model is loaded from ``policy_path`` and
    the algorithm is inferred from a ``/ppo/`` or ``/dqn/`` directory
    component in the (case-insensitive) path.

    Args:
        policy_path: Path to the saved policy model. ``~`` is expanded.
        beta: Rationality parameter for Q-value based policies (not used by
            PPO policies).
        env: The environment to create the policy for.
        gamma: Discount factor (used for tabular environments).

    Returns:
        A policy object that can be used for trajectory generation.

    Raises:
        ValueError: If the environment is not tabular and the path contains
            neither a ``/ppo/`` nor a ``/dqn/`` directory component.
    """
    unwrapped_env = env.unwrapped
    if isinstance(unwrapped_env, TabularEnv):
        q_model = TabularQValueModel(unwrapped_env, gamma=gamma)
        return QValueExpert(q_model, beta=beta)

    # Expand ~ to home directory if present
    expanded_path = Path(policy_path).expanduser()
    load_path = str(expanded_path)

    # Detect the algorithm from the directory names "/ppo/" or "/dqn/".
    # as_posix() normalises separators so that the check also works for
    # Windows-style paths, where str(Path) would contain backslashes.
    path_key = expanded_path.as_posix().lower()
    if "/ppo/" in path_key:
        print("Detected PPO policy from path. PPO does not support rationality parameter, using default value of 1.0")
        return PPOExpert(sb3.PPO.load(load_path))
    if "/dqn/" in path_key:
        print("Detected DQN policy from path")
        dqn_model = DQNQValueModel(sb3.DQN.load(load_path))
        return QValueExpert(dqn_model, beta=beta)

    raise ValueError(
        f"Could not determine policy type from path: {load_path!r}. "
        "Expected a '/ppo/' or '/dqn/' directory component."
    )


class QValueModel(ABC):
    """
    Interface for models that score actions with Q-values.

    A Q-value Q(s,a) is the expected return of a state-action pair.
    Concrete subclasses implement :meth:`q_values`.
    """

    @abstractmethod
    def q_values(self, observation) -> NDArray:
        """
        Return the Q-values of all actions for ``observation``.

        Args:
            observation: Environment observation

        Returns:
            Array holding one Q-value per action
        """


class TabularQValueModel(QValueModel):
    """
    Q-value model backed by a precomputed Q-table.

    The table is built once at construction time by passing the
    environment's transition matrix and reward matrix to ``q_opt``.
    """

    def __init__(self, env: TabularEnv, gamma: float = 0.99):
        """
        Build the Q-table for ``env``.

        Args:
            env: Tabular environment providing ``get_reward_matrix()`` and
                ``get_transition_matrix()``
            gamma: Discount factor forwarded to ``q_opt``
        """
        reward_matrix = env.get_reward_matrix()
        transition_matrix = env.get_transition_matrix()
        self.Q_optimal = q_opt(transition_matrix, reward_matrix, gamma)

    def q_values(self, observation) -> NDArray:
        """Look up the Q-table entry for ``observation`` (used as the state index)."""
        q_table = self.Q_optimal
        return q_table[observation]


class DQNQValueModel(QValueModel):
    """
    Q-value model using Deep Q-Networks.

    Wraps a trained DQN model from stable-baselines3 and evaluates its
    ``q_net`` to obtain Q-values.
    """

    def __init__(self, dqn_model: sb3.DQN):
        """
        Initialize DQN Q-value model.

        Args:
            dqn_model: Trained stable-baselines3 DQN model
        """
        self.model = dqn_model
        self.action_dim = dqn_model.action_space.n
        self.gamma = float(dqn_model.gamma)

    def q_values(self, obs: Any) -> NDArray:
        """
        Get Q-values for an observation.

        Args:
            obs: A single observation (1-D) or a batch of observations.

        Returns:
            The output of ``q_net`` as a NumPy array. For a 1-D observation
            the batch dimension added internally is removed again.
        """
        obs_t = obs_as_tensor(obs, self.model.device)
        # Add batch dimension if missing (q_net expects batch dimension)
        single_obs = obs_t.ndim == 1
        if single_obs:
            obs_t = obs_t.unsqueeze(0)
        # Pure inference: no autograd graph is needed, and the result can be
        # converted to NumPy without having to be detached first.
        with torch.no_grad():
            q_values = self.model.q_net(obs_t)
        if single_obs:
            q_values = q_values.squeeze(0)
        return to_numpy(q_values)


class NeuralQValueModel(QValueModel):
    """Adapter exposing a trained nn.Module through the QValueModel interface.

    With actions_discrete=True the network is called as q_net(state) and is
    expected to output Q-values for all actions, shape (batch, num_actions).

    With actions_discrete=False the network is called as q_net(state, action)
    and is expected to output a single Q-value per row, shape (batch, 1).
    """

    def __init__(self, q_net: nn.Module, actions_discrete: bool = True):
        """Store the network, the device it lives on and the action-space mode."""
        self.q_net = q_net
        self.device = get_model_device(q_net)
        self.actions_discrete = actions_discrete

    def q_values(self, obs: Any, action: Any = None) -> NDArray:
        """Evaluate the network on an observation (and, if continuous, an action).

        Args:
            obs: Environment observation; a 1-D input is treated as a single
                observation and gets a batch dimension added internally
            action: Action (required for continuous actions, ignored for discrete)

        Returns:
            The network output as a NumPy array. For a single (1-D)
            observation only the internally added leading batch dimension is
            removed; all other dimensions of the network output are kept.

        Raises:
            ValueError: If the model is continuous and no action is given.
        """
        with torch.no_grad():
            state = to_torch(obs, self.device)
            unbatched = state.ndim == 1
            if unbatched:
                state = state.unsqueeze(0)

            if not self.actions_discrete:
                if action is None:
                    raise ValueError("action must be provided for continuous action spaces")
                act = to_torch(action, self.device)
                if unbatched:
                    # Keep the action aligned with the batched observation
                    act = act.unsqueeze(0)
                output = self.q_net(state, act)
            else:
                output = self.q_net(state)

            result = output.squeeze(0) if unbatched else output
            return to_numpy(result)


class Expert(ABC):
    """
    Common interface for expert policies.

    An expert holds an underlying model and a rationality parameter and
    exposes action selection (:meth:`predict`) and action probabilities
    (:meth:`predict_proba`), both implemented by subclasses.
    """

    def __init__(self, model: Any, beta: float = 1.0):
        """
        Store the wrapped model and the rationality parameter.

        Args:
            model: Model the expert uses for action selection
            beta: Rationality parameter (β)
        """
        self.model, self.beta = model, beta

    @abstractmethod
    def predict(self, obs: Any, deterministic: bool = False):
        """Select an action for ``obs``; must be implemented by subclasses."""

    @abstractmethod
    def predict_proba(self, obs: Any) -> NDArray:
        """Return action probabilities for ``obs``; must be implemented by subclasses."""


class QValueExpert(Expert):
    """
    Expert policy for Q-value models.

    Actions are drawn from a softmax over ``beta * Q(s, ·)``. With
    ``beta == inf`` (or ``deterministic=True``) only actions with maximal
    Q-value are chosen, with ties broken uniformly at random.
    """

    def __init__(self, model: QValueModel, beta: float = 1.0):
        """
        Initialize Q-value expert policy.

        Args:
            model: QValueModel used for Q-value computation
            beta: Rationality parameter (β) for softmax policy
        """
        super().__init__(model, beta)

    def predict(self, observation: Any, deterministic: bool = False):
        """
        Select an action for a single observation or one action per batch row.

        Args:
            observation: Observation(s) forwarded to the Q-value model.
            deterministic: If True, pick among the actions with maximal
                Q-value instead of sampling from the softmax policy. Ties are
                still broken at random.

        Returns:
            A single action index when the Q-values are 1-D, otherwise an
            array with one action index per row.
        """
        if deterministic or self.beta == float('inf'):
            q_values = self.model.q_values(observation)
            # When multiple actions have equal max Q-value, sample uniformly among them
            is_max = q_values == np.max(q_values, axis=-1, keepdims=True)
            if q_values.ndim == 1:
                return np.random.choice(np.flatnonzero(is_max))
            return np.array([np.random.choice(np.flatnonzero(row)) for row in is_max])

        probs = self.predict_proba(observation)
        if probs.ndim == 1:
            # Single observation case
            return np.random.choice(len(probs), p=probs)

        # Batch case: inverse-CDF sampling, one uniform draw per observation.
        cdf = np.cumsum(probs, axis=-1)
        # Normalise so the last entry is exactly 1.0; otherwise rounding in
        # the cumulative sum could leave every entry below the draw.
        cdf = cdf / cdf[..., -1:]
        rand_vals = np.random.random(probs.shape[0])
        # Strict comparison: draws lie in [0, 1), so the first entry strictly
        # above the draw exists and never belongs to a zero-probability action.
        return np.argmax(cdf > rand_vals[:, np.newaxis], axis=-1)

    def predict_proba(self, observation: Any) -> NDArray:
        """
        Return the action distribution of the policy.

        Args:
            observation: Observation(s) forwarded to the Q-value model.

        Returns:
            Softmax of ``beta * Q`` over the last axis. For ``beta == inf``
            the distribution is uniform over the actions with maximal
            Q-value, matching the tie-breaking used by :meth:`predict`.
        """
        q_values = self.model.q_values(observation)
        if self.beta == float('inf'):
            # inf * q would yield inf/nan inside the softmax; use the limit
            # distribution directly instead.
            is_max = q_values == np.max(q_values, axis=-1, keepdims=True)
            return is_max / is_max.sum(axis=-1, keepdims=True)
        return softmax(self.beta * q_values, dims=-1)


class ContinuousQValueExpert(Expert):
    """
    Expert policy for continuous action spaces using Q-value models.

    Uses Cross-Entropy Method (CEM) to find actions that maximize Q-values.
    """

    def __init__(
        self,
        model: QValueModel,
        action_space: gym.spaces.Box,
        beta: float = 1.0,
        num_samples: int = 64,
        num_iterations: int = 3,
        elite_fraction: float = 0.25,
    ):
        """
        Initialize continuous Q-value expert policy.

        Args:
            model: QValueModel for Q-value computation; its ``q_values`` is
                called with ``(observations, actions)``, so it must support
                continuous actions
            action_space: Continuous action space (gym.spaces.Box)
            beta: Rationality parameter (higher = more deterministic)
            num_samples: Number of action samples per CEM iteration
            num_iterations: Number of CEM iterations
            elite_fraction: Fraction of top samples to use for distribution update
        """
        super().__init__(model, beta)
        self.action_space = action_space
        self.action_low = action_space.low
        self.action_high = action_space.high
        self.action_dim = action_space.shape[0]
        self.num_samples = num_samples
        self.num_iterations = num_iterations
        # Always keep at least one elite sample
        self.num_elites = max(1, int(num_samples * elite_fraction))

    def predict(self, observation: Any, deterministic: bool = False) -> NDArray:
        """
        Select action using CEM optimization over Q-values.

        Args:
            observation: A single 1-D observation or a batch of observations.
            deterministic: If True (or if ``beta`` is infinite), return the
                final CEM mean without added noise. The CEM search itself
                still draws random samples.

        Returns:
            One action of shape ``(action_dim,)`` for a 1-D observation,
            otherwise an array of shape ``(batch, action_dim)``.
        """
        obs = np.asarray(observation)
        single_obs = obs.ndim == 1

        if single_obs:
            obs = obs[np.newaxis, :]

        greedy = deterministic or self.beta == float('inf')

        # Initial CEM distribution: Gaussian centred in the action box with a
        # standard deviation of a quarter of its width. It is identical for
        # every observation, so it is computed once.
        init_mean = (self.action_low + self.action_high) / 2
        init_std = (self.action_high - self.action_low) / 4

        actions = np.zeros((obs.shape[0], self.action_dim))

        for b, obs_b in enumerate(obs):
            # The repeated observation does not change between CEM
            # iterations, so it is built once per observation.
            obs_batch = np.tile(obs_b, (self.num_samples, 1))
            mean, std = init_mean, init_std

            for _ in range(self.num_iterations):
                # Sample actions from current distribution
                samples = np.random.normal(mean, std, size=(self.num_samples, self.action_dim))
                samples = np.clip(samples, self.action_low, self.action_high)

                # Evaluate Q-values for all samples
                q_values = self.model.q_values(obs_batch, samples).flatten()

                # Select elite samples (highest Q-values)
                elite_indices = np.argsort(q_values)[-self.num_elites:]
                elite_samples = samples[elite_indices]

                # Update distribution; the epsilon keeps std strictly positive
                mean = np.mean(elite_samples, axis=0)
                std = np.std(elite_samples, axis=0) + 1e-6

            if greedy:
                actions[b] = mean
            else:
                # Add noise scaled by 1/beta for stochastic action selection
                noise_scale = 1.0 / (self.beta + 1e-6)
                noisy_action = mean + noise_scale * std * np.random.randn(self.action_dim)
                actions[b] = np.clip(noisy_action, self.action_low, self.action_high)

        if single_obs:
            return actions[0]
        return actions

    def predict_proba(self, observation: Any) -> NDArray:
        """Not applicable for continuous actions - raises NotImplementedError."""
        raise NotImplementedError(
            "predict_proba is not supported for continuous action spaces. "
            "Use predict() instead."
        )


class PPOExpert(Expert):
    """
    Expert policy for PPO models.

    Action selection is delegated to the wrapped model; ``beta`` is stored
    but not used by this class.
    """

    def __init__(self, model: sb3.PPO, beta: float = 1.0):
        """
        Initialize PPO expert policy.

        Args:
            model: Trained stable-baselines3 PPO model
            beta: Rationality parameter; kept for interface compatibility only
        """
        super().__init__(model, beta)

    def predict(self, observation: Any, deterministic: bool = False):
        """Return the action chosen by the wrapped model's ``predict``."""
        action, _ = self.model.predict(observation, deterministic=deterministic)
        return action

    def predict_proba(self, observation: Any) -> NDArray:
        """
        Return the action probabilities of the PPO policy.

        If the wrapped model provides its own ``predict_proba`` it is used.
        Otherwise the probabilities are read from the policy's action
        distribution, because stable-baselines3 ``PPO`` does not expose a
        ``predict_proba`` method.

        Args:
            observation: A single observation or a batch of observations.

        Returns:
            Action probabilities; the batch dimension is dropped when the
            policy reports the observation as non-vectorized.

        Raises:
            NotImplementedError: If the policy's action distribution does not
                expose per-action probabilities (e.g. continuous actions).
        """
        native = getattr(self.model, "predict_proba", None)
        if callable(native):
            return native(observation)

        policy = self.model.policy
        # obs_to_tensor moves the observation to the policy device, adds a
        # batch dimension if needed and reports whether one was already there.
        obs_t, vectorized = policy.obs_to_tensor(observation)
        with torch.no_grad():
            action_dist = policy.get_distribution(obs_t)

        probs = getattr(action_dist.distribution, "probs", None)
        if probs is None:
            raise NotImplementedError(
                "predict_proba is only supported for policies with a "
                "categorical action distribution."
            )
        probs = probs.detach().cpu().numpy()
        return probs if vectorized else probs[0]