import tempfile
from pathlib import Path
from typing import Any, Callable
import stable_baselines3 as sb3
from stable_baselines3.common.callbacks import EvalCallback, CallbackList
from stable_baselines3.common.monitor import Monitor
import torch.nn as nn # for type hints, ignore if unused
import yaml

from umfavi.true_reward_callback import TrueRewardCallback
from umfavi.utils.policies import PPOExpert


def linear_schedule(initial_value: float) -> Callable[[float], float]:
    """
    Build a schedule that decays linearly from ``initial_value`` down to zero.

    Args:
        initial_value: Value produced when progress_remaining is 1.0 (start of training).

    Returns:
        A callable mapping ``progress_remaining`` (1.0 at the start, 0.0 at the end)
        to ``progress_remaining * initial_value``.
    """
    start = initial_value

    def _decayed(progress_remaining: float) -> float:
        # Scale the starting value by the fraction of training still left.
        return progress_remaining * start

    return _decayed


def parse_schedule_string(value: Any) -> Any:
    """
    Parse schedule strings like 'lin_0.001' into callable schedules.

    Supports:
        - 'lin_<value>': linear schedule from <value> to 0 (built by linear_schedule).

    Any other input -- non-strings, or strings without a recognised prefix --
    is returned unchanged.

    Raises:
        ValueError: if the string starts with 'lin_' but the remainder is not a
            valid float. The underlying parsing error is chained as ``__cause__``.
    """
    if not isinstance(value, str) or not value.startswith("lin_"):
        # Not a schedule string, return as-is
        return value

    raw_initial = value[len("lin_"):]
    try:
        initial_value = float(raw_initial)
    except ValueError as err:
        # Chain the original error so the offending token stays visible in tracebacks.
        raise ValueError(f"Invalid linear schedule string: {value}") from err
    return linear_schedule(initial_value)


# Hyperparameters whose YAML value may be a schedule string. Only the 'lin_<value>'
# prefix is currently understood by parse_schedule_string; any other value is left
# unchanged.
SCHEDULE_PARAMS = {"learning_rate", "clip_range", "clip_range_vf"}


def get_hyperparams(algo: str, env_name: str) -> dict[str, Any]:
    """
    Load hyperparameters for a given algorithm and environment from YAML files.

    The file read is ``hyperparams/<algo lowercased>.yml`` located three directory
    levels above this module. Its top level is expected to map environment names
    to dicts of hyperparameters.

    Post-processing applied to the returned dict:
        - a string ``policy_kwargs`` is evaluated into a Python object;
        - entries named in SCHEDULE_PARAMS are passed through parse_schedule_string
          (e.g. 'lin_0.001' becomes a callable schedule).

    Args:
        algo: Algorithm name (e.g., 'ppo', 'dqn', 'sac')
        env_name: Environment name (e.g., 'LunarLander-v3', 'CartPole-v1')

    Returns:
        A new dictionary of hyperparameters (a shallow copy of the YAML entry) for
        the SB3 model constructor. Callers may pop keys from it freely.

    Raises:
        FileNotFoundError: if the YAML file does not exist.
        KeyError: if ``env_name`` is not a top-level key of the file (including
            when the file is empty).
        ValueError: propagated from parse_schedule_string for malformed 'lin_' strings.
    """
    # Find the hyperparams directory relative to this file
    hyperparams_dir = Path(__file__).parent.parent.parent / "hyperparams"
    hyperparams_file = hyperparams_dir / f"{algo.lower()}.yml"

    if not hyperparams_file.exists():
        raise FileNotFoundError(f"Hyperparameters file not found: {hyperparams_file}")

    with open(hyperparams_file, "r", encoding="utf-8") as f:
        # safe_load returns None for an empty document; treat that as "no environments"
        # so the lookup below raises the documented KeyError instead of a TypeError.
        all_hyperparams = yaml.safe_load(f) or {}

    if env_name not in all_hyperparams:
        raise KeyError(f"Environment '{env_name}' not found in {hyperparams_file}")

    # An environment key with no body loads as None; treat it as "no overrides".
    hyperparams = dict(all_hyperparams[env_name] or {})

    # Evaluate policy_kwargs string if present
    policy_kwargs = hyperparams.get("policy_kwargs")
    if isinstance(policy_kwargs, str):
        # SECURITY NOTE: eval() executes arbitrary code from the YAML file. This is only
        # acceptable while these files are trusted, repo-local config. Names used in the
        # expression (e.g. `nn`) resolve from this module's globals.
        hyperparams["policy_kwargs"] = eval(policy_kwargs)

    # Parse schedule strings (e.g., 'lin_0.001') into callable schedules
    for param in SCHEDULE_PARAMS:
        if param in hyperparams:
            hyperparams[param] = parse_schedule_string(hyperparams[param])

    return hyperparams


def train_dqn(
    wrapped_env,
    reference_env_name,
    eval_freq: int = 10000,
    n_eval_episodes: int = 5,
    eval_env=None,
):
    """
    Train a DQN model on the wrapped environment using the appropriate hyperparams for the reference environment.
    Returns the best model based on evaluation performance.

    Hyperparameters come from ``get_hyperparams("dqn", reference_env_name)``. The keys
    "normalize", "env_wrapper" and "frame_stack" are not DQN constructor arguments and
    are dropped, so whatever normalization/wrapping they describe is NOT applied here.
    The policy is taken from the YAML entry's "policy" key.

    Args:
        wrapped_env: The environment to train on (with learned reward).
        reference_env_name: Name of the original environment (for loading hyperparams).
        eval_freq: How often to evaluate the model (in timesteps).
        n_eval_episodes: Number of episodes to run for each evaluation.
        eval_env: Optional separate environment instance used by EvalCallback. If None,
            ``wrapped_env`` itself is wrapped in a Monitor and used for evaluation. In
            that case evaluation episodes reset and step the same env instance the
            training loop is using. Pass a distinct instance to avoid that.

    Returns:
        The model reloaded from EvalCallback's best checkpoint (``best_model.zip``) if one
        was written, otherwise the final trained model.

    Raises:
        KeyError: if the hyperparams entry has no "n_timesteps", or from get_hyperparams.
        FileNotFoundError: from get_hyperparams if the YAML file is missing.
    """
    hyperparams = get_hyperparams("dqn", reference_env_name)
    # YAML budgets such as `!!float 5e4` load as floats; use an integer step count.
    n_timesteps = int(hyperparams.pop("n_timesteps"))
    # Env-construction keys, not DQN constructor args; passing them would raise TypeError.
    for key in ("normalize", "env_wrapper", "frame_stack"):
        hyperparams.pop(key, None)

    with tempfile.TemporaryDirectory() as tmpdir:
        true_reward_cb = TrueRewardCallback(window_size=100)
        # Create a monitored evaluation environment (SB3's evaluation helpers warn when
        # the eval env is not Monitor-wrapped).
        monitored_eval_env = Monitor(eval_env if eval_env is not None else wrapped_env)
        eval_cb = EvalCallback(
            monitored_eval_env,
            best_model_save_path=tmpdir,
            eval_freq=eval_freq,
            n_eval_episodes=n_eval_episodes,
            deterministic=True,
            verbose=0,
        )
        callback = CallbackList([true_reward_cb, eval_cb])

        dqn_model = sb3.DQN(env=wrapped_env, **hyperparams, verbose=1)
        dqn_model.learn(total_timesteps=n_timesteps, callback=callback, progress_bar=True)

        # Load while still inside the `with`: tmpdir (and the checkpoint) is deleted on exit.
        best_model_path = Path(tmpdir) / "best_model.zip"
        if best_model_path.exists():
            return sb3.DQN.load(best_model_path, env=wrapped_env)

        # Fallback to final model if no best model was saved
        return dqn_model

def train_ppo(
    make_env_fn: Callable,
    seed: int,
    eval_freq: int = 10000,
    n_eval_episodes: int = 5,
    true_reward_threshold: float = None,
    verbose: int = 1,
    progress_bar: bool = True,
    n_envs: int = None,
    n_timesteps: int = None,
    reference_env_name: str = None,
):
    """
    Train a PPO model on environments built by ``make_env_fn``, using the hyperparams
    for ``reference_env_name`` when they can be loaded.
    Returns the best model based on true reward performance.

    The returned model is chosen in this order:
        1. ``best_true_reward_model.zip`` in the temporary save directory handed to
           TrueRewardCallback, if that file exists;
        2. ``best_model.zip`` written by EvalCallback (best mean episode reward on the
           eval env), if it exists;
        3. the final in-memory model.
    Models from 1 and 2 are reloaded with a fresh ``make_env_fn(seed)`` env attached.

    Args:
        make_env_fn: Factory called with a single integer seed that returns one
            (non-vectorized) environment.
        seed: Base random seed. Training env ``i`` is built with ``seed + i``; PPO, the
            eval env and reloaded-model envs use ``seed``.
        eval_freq: Approximate number of total environment timesteps (summed over all
            parallel envs) between evaluations. Converted to EvalCallback calls as
            ``eval_freq // n_envs``, but never fewer than 1000 calls.
        n_eval_episodes: Number of episodes to run for each evaluation.
        true_reward_threshold: If provided, stop training early if mean true reward goes below this threshold.
        verbose: Verbosity level for PPO model (0=none, 1=info). Default is 1.
        progress_bar: Whether to show progress bar during training. Default is True.
        n_envs: Number of parallel environments for training. If None, uses hyperparams or default (8).
        n_timesteps: Total timesteps for training. If None, uses hyperparams or default (1_000_000).
        reference_env_name: Environment name to load hyperparams for (e.g., 'LunarLander-v3').
            If None, or if loading fails with FileNotFoundError/KeyError, PPO's own
            defaults are used (with 8 envs and 1_000_000 timesteps unless overridden).

    Raises:
        ValueError: if the resolved number of environments is less than 1.
    """
    from stable_baselines3.common.vec_env import DummyVecEnv

    # Fallbacks used when no hyperparams are loaded.
    ppo_kwargs = {}
    hp_n_envs = 8
    hp_n_timesteps = 1_000_000

    if reference_env_name is not None:
        try:
            hyperparams = get_hyperparams("ppo", reference_env_name)
        except (FileNotFoundError, KeyError) as e:
            print(f"Warning: Could not load hyperparams for {reference_env_name}: {e}")
            print("Using default PPO hyperparameters")
        else:
            # Extract non-PPO params
            hp_n_envs = hyperparams.pop("n_envs", 8)
            hp_n_timesteps = hyperparams.pop("n_timesteps", 1_000_000)

            # Remove params that aren't PPO constructor args. "policy" is dropped because
            # MlpPolicy is passed explicitly below; "normalize", "env_wrapper" and
            # "frame_stack" describe env construction and are NOT applied here.
            for param in ("policy", "normalize", "env_wrapper", "frame_stack"):
                hyperparams.pop(param, None)

            ppo_kwargs = hyperparams
            print(f"Loaded PPO hyperparams for {reference_env_name}: {list(ppo_kwargs.keys())}")

    # Explicit arguments win over hyperparams/defaults. int() because YAML values such
    # as `!!float 1e6` load as floats.
    actual_n_envs = int(n_envs if n_envs is not None else hp_n_envs)
    actual_n_timesteps = int(n_timesteps if n_timesteps is not None else hp_n_timesteps)
    if actual_n_envs < 1:
        raise ValueError(f"n_envs must be >= 1, got {actual_n_envs}")

    # DummyVecEnv expects a list of callables (zero-arg functions) that return envs.
    # `i=i` binds the current loop value (avoids the late-binding closure pitfall).
    train_env = DummyVecEnv([lambda i=i: make_env_fn(seed + i) for i in range(actual_n_envs)])

    # Only start saving best model after 10% of training to avoid selecting undertrained models
    min_timesteps_before_save = int(0.1 * actual_n_timesteps)

    eval_env = None
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            # Track and save best model based on true reward
            true_reward_cb = TrueRewardCallback(
                window_size=100,
                true_reward_threshold=true_reward_threshold,
                best_model_save_path=tmpdir,
                min_timesteps_before_save=min_timesteps_before_save,
            )

            # Separate evaluation env so evaluation does not reset/step the training envs.
            eval_env = Monitor(make_env_fn(seed))

            # EvalCallback's eval_freq counts callback calls; each call is one step of the
            # vectorized env (= actual_n_envs timesteps), so convert from timesteps. The
            # floor of 1000 calls limits how often evaluation can run.
            eval_cb = EvalCallback(
                eval_env,
                best_model_save_path=tmpdir,
                eval_freq=max(eval_freq // actual_n_envs, 1000),
                n_eval_episodes=n_eval_episodes,
                deterministic=True,
                verbose=0,
            )
            callback = CallbackList([true_reward_cb, eval_cb])

            # Create PPO with hyperparams (or PPO's defaults when none were loaded)
            ppo_model = sb3.PPO(
                env=train_env,
                policy="MlpPolicy",
                verbose=verbose,
                seed=seed,
                **ppo_kwargs,
            )
            print(f"Training PPO for {actual_n_timesteps:,} timesteps with {actual_n_envs} parallel envs")
            ppo_model.learn(total_timesteps=actual_n_timesteps, callback=callback, progress_bar=progress_bar)

            # Checkpoints must be loaded before the `with` exits and deletes tmpdir.
            best_true_reward_path = Path(tmpdir) / "best_true_reward_model.zip"
            if best_true_reward_path.exists():
                print(f"Loading best model (true reward: {true_reward_cb.best_mean_true_reward:.2f})")
                return sb3.PPO.load(best_true_reward_path, env=make_env_fn(seed))

            # Fallback to best model based on learned reward if no true reward model was saved
            best_model_path = Path(tmpdir) / "best_model.zip"
            if best_model_path.exists():
                print("Warning: No true reward model saved, falling back to best learned reward model")
                return sb3.PPO.load(best_model_path, env=make_env_fn(seed))

            # Fallback to final model if no best model was saved
            print("Warning: No best model saved, returning final model")
            return ppo_model
    finally:
        # Release env resources even if training raises. As before, the training envs are
        # closed even though the final-model fallback still references them.
        train_env.close()
        if eval_env is not None:
            eval_env.close()