#!/usr/bin/env python
"""Q-learning with options used only for exploration.

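This script evaluates Q-learning where options are used ONLY for exploration
(not for credit assignment). On an epsilon-greedy exploration step, the agent
picks a random startable option or a random primitive action with equal
probability (falling back to a primitive action when no option can start).
Options have no Q-values of their own and are never chosen during
exploitation; the primitive transitions they execute still update the
primitive-action Q-table.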
"""
from __future__ import annotations
import argparse, random
from pathlib import Path
from typing import Dict, List
import glob, os
import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt

# ----------------- default hyper-parameters -----------------
# Defaults for the CLI flags parsed in main(); override them on the command line.
EPISODES = 2000  # legacy episode budget; training length is governed by MAX_STEPS (--max_steps)
STEPS_PER_EP = 200  # default per-episode step cap (--max_len)
MAX_STEPS = 100000  # Total training steps (same for all methods)
EVAL_STEP_INTERVAL = 1000  # Evaluate every N steps
EVAL_EVERY = 10  # Deprecated: kept for backward compatibility
EVAL_EPISODES = 10  # evaluation episodes averaged per evaluation point (cf. --eval_trials)
EPSILON = 0.1  # probability of an exploratory step in epsilon-greedy selection
ALPHA = 0.1  # Q-learning step size
GAMMA = 0.99  # discount factor used in the TD target

# ----------------- utility functions ------------------------
def option_can_start(q_row: np.ndarray) -> bool:
    """
    Decide whether an option may be initiated in the current state.

    Every state is in the initiation set (Gridworld-style), but the option
    only counts as startable when the largest value in its local Q-row is
    strictly positive.
    """
    best_value = q_row.max()
    return best_value > 0


def option_terminated(q_row: np.ndarray, L: int = 15) -> bool:
    """
    Decide whether a running option stops in the current state.

    The option ends immediately when its local Q-max is non-positive;
    otherwise it ends stochastically with probability 1/L per step. The
    random draw is only consumed in the second case.
    """
    if q_row.max() <= 0:
        return True
    return random.random() < 1.0 / L


def random_argmax(qrow: np.ndarray) -> int:
    """Return the index of a maximal entry of *qrow*, breaking ties uniformly at random."""
    top = qrow.max()
    candidates = np.flatnonzero(qrow == top)
    return int(random.choice(candidates))


# ---------------- Q-learning agent with option exploration ------------------
class QLearningAgentWithOptionExploration:
    """
    Tabular Q-learning agent in which options are used ONLY for exploration.

    The Q-table ``self.Q`` has shape ``(S, A)`` and covers primitive actions
    only; options have no Q-values of their own. On an epsilon-greedy
    exploration step the agent may launch a startable option, and every
    primitive transition executed by that option still updates ``self.Q``.

    Args:
        env: Environment with discrete spaces (``observation_space.n`` and
            ``action_space.n`` are read).
        Qopt: Option Q-tables indexed as ``Qopt[o, s]`` (a row per state whose
            argmax is used as the option's primitive action), or ``None`` for a
            primitive-only agent. Presumably shaped ``(K, S, A)`` — confirm
            against the saved option files.
        epsilon: Probability of taking an exploratory step.
        alpha: Q-learning step size.
        gamma: Discount factor.
    """
    def __init__(
        self,
        env: gym.Env,
        Qopt: np.ndarray | None,
        epsilon: float,
        alpha: float,
        gamma: float,
    ):
        self.env = env
        self.S = env.observation_space.n
        self.Ap = env.action_space.n  # Number of primitive actions

        self.Qopt = Qopt
        self.K = 0 if Qopt is None else Qopt.shape[0]

        # Q-function ONLY for primitive actions (not options)
        self.Q = np.zeros((self.S, self.Ap), np.float32)
        self.eps, self.alpha, self.gamma = epsilon, alpha, gamma

        # Greedy policy of each option, with ties broken randomly once here
        # (so it is fixed for the agent's lifetime and reproducible per seed).
        if self.K:
            self.opt_pi = np.zeros((self.K, self.S), dtype=int)
            for o in range(self.K):
                for s in range(self.S):
                    row = Qopt[o, s]
                    best = np.flatnonzero(row == row.max())
                    self.opt_pi[o, s] = int(random.choice(best))
        else:
            self.opt_pi = None

    # ------ ε-greedy: options only used in exploration -------------
    def _select_action(self, s: int, *, greedy: bool = False) -> tuple[int, bool]:
        """
        Select a primitive action or (during exploration only) an option.

        Returns:
            (action, is_option):
            - If primitive: (action_idx, False)
            - If option: (option_id, True)
        """
        if (not greedy) and random.random() < self.eps:
            if self.K > 0:
                startable_opts = [
                    oid for oid in range(self.K)
                    if option_can_start(self.Qopt[oid, s])
                ]
                # Equally likely to choose an option or a primitive action
                # (only when at least one option can start here).
                if startable_opts and random.random() < 0.5:
                    oid = int(random.choice(startable_opts))
                    return (oid, True)

            a = int(random.randrange(self.Ap))
            return (a, False)

        # Exploitation: only primitive actions are ever chosen greedily.
        a = random_argmax(self.Q[s])
        return (a, False)

    # ------ one-step Q-learning update ---------------------------
    def _q_update(self, s: int, a: int, r: float, sn: int, terminated: bool) -> None:
        """
        Apply the one-step Q-learning update for transition (s, a, r, sn).

        Only true termination zeroes the bootstrap term. A truncation (time
        limit) does not end the underlying MDP, so the value of ``sn`` is still
        bootstrapped; treating truncation as terminal biases the Q-values.
        """
        boot = 0.0 if terminated else self.Q[sn].max()
        td = r + self.gamma * boot - self.Q[s, a]
        self.Q[s, a] += self.alpha * td

    def _step(self, s: int, a: int, train: bool) -> tuple[int, float, bool, bool]:
        """Execute primitive action *a* from *s*; update Q when training.

        Returns (next_state, reward, terminated, truncated).
        """
        sn, r, term, trunc, _ = self.env.step(a)
        sn = int(sn)
        if train:
            self._q_update(s, a, r, sn, term)
        return sn, r, term, trunc

    # ------ run one episode ---------------------------------
    def run_episode(self, max_len: int, *, train: bool) -> tuple[float, bool, int]:
        """
        Run one episode of at most *max_len* primitive steps.

        With ``train=True`` actions are epsilon-greedy (options possible) and
        every primitive transition updates ``self.Q``; with ``train=False``
        actions are greedy over primitive Q-values and nothing is updated.

        Returns:
            (total_reward, success, num_steps):
            - total_reward: cumulative (undiscounted) reward
            - success: True if any step yielded reward 20 (a successful
              drop-off in Taxi-v3)
            - num_steps: number of primitive steps taken in this episode
        """
        s, _ = self.env.reset()
        s = int(s)
        total_r, steps = 0.0, 0
        success = False

        while steps < max_len:
            action_or_opt, is_option = self._select_action(s, greedy=not train)
            term, trunc = False, False

            if not is_option:
                # ---- single primitive action ----
                s, r, term, trunc = self._step(s, action_or_opt, train)
                total_r += r
                if r == 20:  # Taxi-v3 success reward
                    success = True
                steps += 1
            else:
                # ---- option rollout: follow the option's greedy policy ----
                # The option may terminate before taking any step (1/L chance);
                # control then returns to action selection in the same state.
                oid = action_or_opt
                while steps < max_len and not option_terminated(self.Qopt[oid, s]):
                    ain = int(self.opt_pi[oid, s])
                    s, r, term, trunc = self._step(s, ain, train)
                    total_r += r
                    if r == 20:  # Taxi-v3 success reward
                        success = True
                    steps += 1
                    if term or trunc:
                        break

            if term or trunc:
                break
        return total_r, success, steps


# ---------------- option file loading -----------------------
def load_option_groups(env_id: str, out_dir: Path, outer: int):
    """
    Load saved option Q-tables from *out_dir*, grouped by option type.

    Files matching ``{env_id}_*_*Opt_*.npy`` are classified by a
    case-insensitive substring of their file name: "random", "eigen" or "vps"
    (checked in that order; files matching none are ignored). Files are
    visited in sorted name order so group ``g`` maps to the same file on every
    run and machine — ``Path.glob`` gives no ordering guarantee. Only the
    first *outer* files of each kind are loaded.

    Returns:
        dict mapping "random"/"eigen"/"vps" to lists of exactly *outer* arrays.

    Raises:
        RuntimeError: if any kind has fewer than *outer* matching files.
    """
    paths = {"random": [], "eigen": [], "vps": []}
    for f in sorted(out_dir.glob(f"{env_id}_*_*Opt_*.npy")):
        name = f.name.lower()
        kind = next((k for k in paths if k in name), None)
        if kind is not None:
            paths[kind].append(f)

    groups = {}
    for k, lst in paths.items():
        if len(lst) < outer:
            raise RuntimeError(f"[{env_id}] {k} groups={len(lst)} < outer={outer}")
        groups[k] = [np.load(p) for p in lst[:outer]]
    return groups


# ---------------- run experiment ----------------------------
def _evaluate(agent, max_len: int, trials: int) -> float:
    """Mean undiscounted return of *trials* greedy (non-training) episodes."""
    return sum(
        agent.run_episode(max_len, train=False)[0] for _ in range(trials)
    ) / trials


def _train_and_evaluate(
    ag,
    meth: str,
    max_steps: int,
    max_len: int,
    eval_trials: int,
    eval_step_interval: int,
    print_freq: int,
):
    """Train *ag* for at least *max_steps* steps; return (eval returns, eval steps)."""
    success_count = 0
    total_steps = 0
    last_eval_step = 0
    last_print_step = 0
    run_rewards: list = []
    run_steps: list = []
    label = "Primitive" if meth == "primitive" else f"{meth} option"

    # Whole episodes are run, so training may overshoot max_steps slightly.
    while total_steps < max_steps:
        _, success, ep_steps = ag.run_episode(max_len, train=True)
        if success:
            success_count += 1
        total_steps += ep_steps

        # Progress report roughly every print_freq full-length episodes.
        if total_steps - last_print_step >= print_freq * max_len:
            print(f"{label}, step {total_steps}, {success_count} success")
            last_print_step = total_steps

        # Evaluate once at least eval_step_interval steps have elapsed.
        if total_steps - last_eval_step >= eval_step_interval:
            run_rewards.append(_evaluate(ag, max_len, eval_trials))
            run_steps.append(total_steps)
            last_eval_step = total_steps

    # Final evaluation at the end of training (if not already recorded).
    if not run_steps or run_steps[-1] < max_steps:
        run_rewards.append(_evaluate(ag, max_len, eval_trials))
        run_steps.append(total_steps)
    return run_rewards, run_steps


def _stack_runs(curves_list, step_counts_list, methods):
    """Trim every run to the shortest run's length and stack into 2-D arrays."""
    lengths = [
        min(len(run_data) for run_data in curves_list[m])
        for m in methods
        if curves_list[m]
    ]
    if not lengths:
        # No data collected, return empty arrays
        return {m: np.zeros((0, 0), np.float32) for m in methods}, \
               {m: np.zeros((0, 0), np.int64) for m in methods}

    n = min(lengths)
    curves = {}
    step_counts = {}
    for m in methods:
        rewards_rows = [r[:n] for r in curves_list[m]]
        steps_rows = [s[:n] for s in step_counts_list[m]]
        curves[m] = np.asarray(rewards_rows, np.float32).reshape(len(rewards_rows), n)
        step_counts[m] = np.asarray(steps_rows, np.int64).reshape(len(steps_rows), n)
    return curves, step_counts


def run_experiment(
    env_id: str,
    groups,
    outer,
    inner,
    max_steps: int,
    max_len,
    eps,
    alpha,
    gamma,
    eval_trials: int,
    eval_step_interval: int,
    print_freq: int = 100,
):
    """
    Run the comparison; every method trains for the same number of steps.

    For each option group ``g`` (< outer) and seed index ``inn`` (< inner), the
    four methods ("primitive", "random", "eigen", "vps") are trained from
    identical initial conditions: the ``random``/``numpy`` global RNGs are
    reseeded before each method and the environment's own RNG is seeded via
    ``env.reset(seed=...)`` (the global seeds do not reach it).

    Args:
        groups: mapping "random"/"eigen"/"vps" -> list of option Q-tables.
        max_steps: Total training steps (same for all methods)
        max_len: Per-episode step cap; must be positive.
        eval_step_interval: Evaluate every N steps
        eval_trials: Greedy episodes averaged per evaluation point.
        print_freq: Print progress every ``print_freq * max_len`` steps.

    Returns:
        (curves, step_counts): dicts mapping method -> array of shape
        (runs, eval_points), trimmed to the shortest run.

    Raises:
        ValueError: if ``max_len`` is not positive (training would never
            advance and would loop forever).
    """
    if max_len <= 0:
        raise ValueError(f"max_len must be positive, got {max_len}")

    methods = ["primitive", "random", "eigen", "vps"]
    curves_list = {m: [] for m in methods}
    step_counts_list = {m: [] for m in methods}

    for g in range(outer):
        method_options = [
            ("primitive", None),
            ("random", groups["random"][g]),
            ("eigen", groups["eigen"][g]),
            ("vps", groups["vps"][g]),
        ]

        for inn in range(inner):
            seed = g * 10000 + inn

            for meth, Qopt in method_options:
                # Common random numbers: each method starts from the same state.
                random.seed(seed)
                np.random.seed(seed)
                env = gym.make(env_id)
                try:
                    env.reset(seed=seed)  # seeds the environment's own RNG
                    ag = QLearningAgentWithOptionExploration(
                        env, Qopt, eps, alpha, gamma
                    )
                    run_rewards, run_steps = _train_and_evaluate(
                        ag, meth, max_steps, max_len,
                        eval_trials, eval_step_interval, print_freq,
                    )
                finally:
                    env.close()

                curves_list[meth].append(run_rewards)
                step_counts_list[meth].append(run_steps)

            print(f"[{env_id}] group={g} seed={inn} finished")

    return _stack_runs(curves_list, step_counts_list, methods)


# ---------------- plotting helpers --------------------------
def _moving_average(arr: np.ndarray, k: int) -> np.ndarray:
    """Simple 1-D moving average; k=1 returns the input unchanged."""
    if k <= 1:
        return arr
    kernel = np.ones(k) / k
    return np.convolve(arr, kernel, mode="valid")  # length T-k+1


def plot_curves(
    curves: Dict[str, np.ndarray],
    step_counts: Dict[str, np.ndarray],
    env_id: str,
    win: int = 1,
    save_path: Path | None = None,
):
    """Plot mean return (± one std across runs) against training steps.

    Args:
        curves: method name -> returns array of shape (runs, timepoints)
        step_counts: method name -> step array of shape (runs, timepoints)
        env_id: environment identifier used in the title
        win: moving-average window applied to mean and std (1 = none)
        save_path: if given, the figure is also written to this path

    The figure is always shown with ``plt.show()``.
    """
    # method -> (line colour, legend label)
    style = {
        "primitive": ("tab:blue", "Q-Learning (Primitive Only)"),
        "random": ("tab:purple", "Q-Learning + Random Option Exploration"),
        "eigen": ("tab:orange", "Q-Learning + Eigen Option Exploration"),
        "vps": ("tab:green", "Q-Learning + VPS Option Exploration"),
    }

    plt.figure(figsize=(8, 4.5))

    for method in curves:
        returns = curves[method]
        mu = returns.mean(0)
        sigma = returns.std(0)
        xs = step_counts[method].mean(0).astype(float)

        if win > 1:
            mu = _moving_average(mu, win)
            sigma = _moving_average(sigma, win)
            # Smoothing drops the first win-1 points; drop matching x-values.
            xs = xs[win - 1:]

        colour, label = style[method]
        plt.plot(xs, mu, color=colour, label=label)
        plt.fill_between(xs, mu - sigma, mu + sigma, color=colour, alpha=0.25)

    if win > 1:
        title = f"{env_id} (Options for Exploration Only, window={win})"
    else:
        title = f"{env_id} (Options for Exploration Only)"

    plt.xlabel("Time Steps", fontsize=16)
    plt.ylabel("Return", fontsize=16)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.title(title, fontsize=16)
    plt.legend(fontsize=14)
    plt.grid(alpha=0.3)
    plt.tight_layout()

    if save_path is not None:
        plt.savefig(save_path)
        print(f"[✓] plot saved → {save_path.resolve()}")

    plt.show()


# ----------------------------- CLI --------------------------
def main():
    """
    Parse CLI arguments, run the experiment and plot the learning curves.

    Option files are read from ``<script dir>/<--out_dir>/<--env>/`` (created
    if missing) and the plot is saved into the same directory.
    """
    pa = argparse.ArgumentParser(
        description="Q-learning with options used only for exploration"
    )
    pa.add_argument("--env", default="Taxi-v3")
    pa.add_argument("--out_dir", default="option_results")
    pa.add_argument("--outer", type=int, default=5)
    pa.add_argument("--inner", type=int, default=5)
    pa.add_argument("--episodes", type=int, default=None,
                    help="Number of episodes (deprecated, use --max_steps instead)")
    pa.add_argument("--max_steps", type=int, default=MAX_STEPS,
                    help="Total training steps (same for all methods)")
    pa.add_argument("--max_len", type=int, default=STEPS_PER_EP)
    pa.add_argument("--epsilon", type=float, default=EPSILON)
    pa.add_argument("--alpha", type=float, default=ALPHA)
    pa.add_argument("--gamma", type=float, default=GAMMA)
    pa.add_argument("--plot_path", default="qlearning_option_exploration_curve.png")
    pa.add_argument(
        "--smooth",
        type=int,
        default=1,  # 1 = no smoothing
        help="moving-average window length for the reward curve",
    )
    pa.add_argument(
        "--eval_trials",
        type=int,
        default=EVAL_EPISODES,
        help="number of evaluation episodes to average",
    )
    pa.add_argument(
        "--eval_step_interval",
        type=int,
        default=EVAL_STEP_INTERVAL,
        help="evaluate every N steps",
    )
    pa.add_argument(
        "--print_freq",
        type=int,
        default=100,
        help="print the training success count roughly every "
             "print_freq * max_len environment steps",
    )

    args = pa.parse_args()

    # --episodes is accepted for backward compatibility only; say so instead
    # of silently ignoring it.
    if args.episodes is not None:
        print("[!] --episodes is deprecated and ignored; "
              "training length is set by --max_steps")

    # Resolve option directory relative to this script:
    # `<script dir>/<out_dir>/<Env>/...`.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    save_dir = Path(script_dir) / args.out_dir / args.env
    save_dir.mkdir(parents=True, exist_ok=True)
    plot_path = save_dir / args.plot_path

    groups = load_option_groups(args.env, save_dir, args.outer)
    curves, step_counts = run_experiment(
        args.env,
        groups,
        args.outer,
        args.inner,
        args.max_steps,
        args.max_len,
        args.epsilon,
        args.alpha,
        args.gamma,
        args.eval_trials,
        args.eval_step_interval,
        args.print_freq,
    )
    plot_curves(curves, step_counts, env_id=args.env, win=args.smooth, save_path=plot_path)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()

