"""
Result collection and analysis for continuous action environments.

This module collects and analyzes results from agents trained with the EvA-RL framework
on continuous action environments. It provides functionality for loading experiment
results, parsing experiment names, computing statistics, and generating plots.
"""

import argparse
import os
import pickle
from collections import defaultdict, namedtuple
from functools import partial
from typing import Any, Dict, List, NamedTuple, Optional, Tuple

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy.stats import sem
from tqdm import tqdm

# Ask JAX/XLA to allocate GPU memory on demand instead of preallocating a large
# share of the device at start-up. NOTE(review): `jax` is imported above; this
# presumably still takes effect because JAX reads the variable when its backend
# is first initialised -- confirm if preallocation is still observed.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Matplotlib styling: serif Palatino text without LaTeX rendering, and one
# shared 18 pt size for every font-size setting listed below.
plt.rcParams.update(
    {
        "text.usetex": False,
        "font.family": "serif",
        "font.serif": ["Palatino"],
        **dict.fromkeys(
            (
                "axes.labelsize",
                "font.size",
                "legend.fontsize",
                "xtick.labelsize",
                "ytick.labelsize",
            ),
            18,
        ),
    }
)
# Seaborn background style applied to all subsequent plots.
sns.set_style("darkgrid")

# --- Module constants -------------------------------------------------------

# Estimator MAEs at or above this value are treated as outliers and discarded
# when group statistics are aggregated.
THRESH = 1e3

# Per-environment reference "lowest" return; the comparison plot normalizes a
# return R as (R - lowest) / (baseline - lowest), so 0 maps to this value.
ENV_LOWEST_RETURNS = dict(halfcheetah=-560, reacher=-15.8, ant=-900)

# Environments considered by the analysis, in plotting order.
SUPPORTED_ENVIRONMENTS = "reacher ant halfcheetah".split()

# Estimator names; presumably the non-"mc" keys of the eval-value dicts
# (lower-cased). NOTE(review): not referenced in this section -- confirm usage.
ESTIMATORS = "fqe pdis dr predictor".split()


class Transition(NamedTuple):
    """Transition tuple for storing environment transitions.

    Not referenced by the functions in this section of the module.
    NOTE(review): it may be kept so that pickled result files containing
    ``Transition`` objects can still be unpickled -- confirm before removing.
    Field order is significant (NamedTuple positional layout).
    """
    done: jnp.ndarray  # presumably per-step episode-termination flags
    action: jnp.ndarray  # action(s) taken
    value: jnp.ndarray  # presumably the critic's value estimate
    reward: jnp.ndarray  # reward(s) received
    log_prob: jnp.ndarray  # presumably log-probability of `action` under the policy
    obs: jnp.ndarray  # observation(s)
    info: Any  # environment info payload; structure not visible here


class ExperimentConfig:
    """Configuration class for experiment parameters.

    A plain attribute container: every constructor argument is stored unchanged
    as an attribute of the same name.

    Args:
        save_dir: Directory that holds the ``*_eval_values.pkl`` result files.
        experiment_name: Name of a single experiment.
        num_eval_samples: Number of evaluation samples (presumably evaluation
            start states -- confirm against the evaluation script).
        num_trajs_per_state: Number of trajectories sampled per state during
            Monte-Carlo evaluation.
        experiment_name_file: File containing the experiment names to evaluate.
        pred_coef: Predictability coefficient.
        pred_lr: Learning rate for the predictability transformer.
        use_pretrained_transformer: Integer flag (0/1) for using a pretrained
            transformer.
    """

    def __init__(self, save_dir: str = "./complete_continuous_longrun/",
                 experiment_name: str = "",
                 num_eval_samples: int = 250,
                 num_trajs_per_state: int = 10,
                 experiment_name_file: str = "./complete_continuous_long_run_scripts/experiment_names.txt",
                 pred_coef: float = 0.0,
                 pred_lr: float = 0.0,
                 use_pretrained_transformer: int = 0):
        self.save_dir = save_dir
        self.experiment_name = experiment_name
        self.num_eval_samples = num_eval_samples
        self.num_trajs_per_state = num_trajs_per_state
        self.experiment_name_file = experiment_name_file
        self.pred_coef = pred_coef
        self.pred_lr = pred_lr
        self.use_pretrained_transformer = use_pretrained_transformer

    def __repr__(self) -> str:
        """Show every stored setting, e.g. ``ExperimentConfig(save_dir='...', ...)``."""
        settings = ", ".join(f"{name}={value!r}" for name, value in vars(self).items())
        return f"{type(self).__name__}({settings})"


def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the evaluation script.

    Options are declared in a single table and registered in order, so the
    ``--help`` listing follows the table. Note that the ``--experiment_name``
    default does not follow the 18-field naming scheme expected by
    ``parse_experiment_name``.
    """
    option_table = (
        ("--save_dir", dict(default="./complete_continuous_longrun/",
                            type=str, help="Directory to save results")),
        ("--experiment_name", dict(
            default="Pong-misc_64envs_100steps_10000000.0ts_2seed_10dtr_200gtr_3dt_Pong-misc_32bs_0.001lr_100ep_4h_8l_128hd",
            type=str)),
        ("--num_eval_samples", dict(default=250, type=int)),
        ("--num_trajs_per_state", dict(
            default=10, type=int,
            help="Number of trajectories to sample per state during MC evaluation")),
        ("--experiment_name_file", dict(
            type=str,
            default="./complete_continuous_long_run_scripts/experiment_names.txt",
            help="File containing the experiment names to evaluate")),
        ("--PREDICTABILITY_COEF", dict(default=0.0, type=float,
                                       help="Predictability coefficient")),
        ("--PRED_LR", dict(default=0.0, type=float,
                           help="Learning rate for the predictability transformer")),
        ("--use_pretrained_transformer", dict(default=0, type=int,
                                              help="Use pretrained transformer")),
    )

    cli = argparse.ArgumentParser(description="Evaluation")
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli


def _strip_suffix(token: str, suffix: str) -> str:
    """Return ``token`` without its trailing ``suffix``; raise ValueError if it is absent."""
    if len(token) <= len(suffix) or not token.endswith(suffix):
        raise ValueError(f"expected a value followed by {suffix!r}, got {token!r}")
    return token[: -len(suffix)]


def parse_experiment_name(experiment_name: str) -> Dict[str, Any]:
    """
    Parse experiment name to extract parameters.

    The name is split on ``"_"``; fields 0 and 8 are environment names, every
    other field is a number followed by a fixed unit suffix (``envs``,
    ``steps``, ``ts``, ...). Suffixes are checked and stripped by name rather
    than by a hand-counted slice length (the old ``[:-6]`` for ``"steps"``
    dropped a digit: ``"10steps"`` parsed as 1). Fields beyond the 18th are
    ignored.

    Example: 'reacher_2048envs_10steps_50000000.0ts_4seed_25dtr_1000gtr_3dt_reacher_256bs_0.001lr_100ep_4h_4l_16hd_0.0005pc_0usepretrained_0predlr'

    Raises:
        ValueError: if the name has fewer than 18 fields, a field lacks its
            expected suffix, or a value cannot be converted.
    """
    parts = experiment_name.split("_")
    if len(parts) < 18:
        raise ValueError(
            f"expected at least 18 '_'-separated fields in experiment name, "
            f"got {len(parts)}: {experiment_name!r}"
        )

    def value(index: int, suffix: str, cast):
        # Strip the unit suffix of field `index`, then convert the remainder.
        return cast(_strip_suffix(parts[index], suffix))

    return {
        "env_name": parts[0],
        "num_envs": value(1, "envs", int),
        "num_steps": value(2, "steps", int),
        "num_timesteps": value(3, "ts", float),
        "seed": value(4, "seed", int),
        "num_dtr": value(5, "dtr", int),
        "num_gtr": value(6, "gtr", int),
        "num_dt": value(7, "dt", int),
        "env_name_2": parts[8],
        "batch_size": value(9, "bs", int),
        "learning_rate": value(10, "lr", float),
        "num_epochs": value(11, "ep", int),
        "num_heads": value(12, "h", int),
        "num_layers": value(13, "l", int),
        "hidden_dim": value(14, "hd", int),
        "pred_coef": value(15, "pc", float),
        "use_pretrained_transformer": value(16, "usepretrained", int),
        "pred_lr": value(17, "predlr", float),
    }


def load_eval_files(save_dir: str) -> List[str]:
    """Return the sorted base names of all ``*predlr_eval_values.pkl`` files under ``save_dir``.

    The search recurses into subdirectories, but only bare file names are
    returned (directory parts are dropped). A missing ``save_dir`` yields an
    empty list. Prints the number of files found.
    """
    wanted_suffix = "predlr_eval_values.pkl"
    matches = sorted(
        file_name
        for _dirpath, _subdirs, file_names in os.walk(save_dir)
        for file_name in file_names
        if file_name.endswith(wanted_suffix)
    )
    print(f"Found {len(matches)} Eval Value files")
    return matches


def load_experiment_results(save_dir: str, experiment_names: List[str],
                            pred_coef_tokens: Tuple[str, ...] = ("0.0005pc", "0.0pc")) -> Dict[str, Any]:
    """Load ``{save_dir}/{name}_eval_values.pkl`` for the selected experiments.

    Only experiments whose name contains one of ``pred_coef_tokens`` as a whole
    ``"_"``-separated field are loaded. A plain substring test would also
    accept e.g. ``"10.0pc"`` for ``"0.0pc"``.

    Loading is best-effort: a file that cannot be opened or unpickled is
    reported and skipped rather than aborting the run.

    Args:
        save_dir: Directory containing the pickled eval-value files.
        experiment_names: Candidate experiment names (file names without the
            ``_eval_values.pkl`` suffix).
        pred_coef_tokens: Predictability-coefficient fields to keep.

    Returns:
        Mapping from experiment name to its unpickled eval values, in input order.
    """
    wanted = set(pred_coef_tokens)
    eval_results: Dict[str, Any] = {}
    attempted = 0

    for experiment_name in tqdm(experiment_names, desc="Loading experiments"):
        if wanted.isdisjoint(experiment_name.split("_")):
            continue
        attempted += 1

        path = os.path.join(save_dir, f"{experiment_name}_eval_values.pkl")
        try:
            with open(path, "rb") as f:
                # NOTE: pickle can execute arbitrary code; only load result
                # files produced by your own runs.
                eval_results[experiment_name] = pickle.load(f)
        except Exception as e:  # best-effort: one bad file must not stop the run
            print(f"Error loading {experiment_name} eval values: {e}")

    skipped = len(experiment_names) - attempted
    print(f"Successful experiments: {len(eval_results)}/{attempted} "
          f"({skipped} skipped by pred-coef filter)")
    return eval_results


def compute_group_mae(groups: Dict[str, List[str]], group_name: str,
                     eval_results: Dict[str, Any],
                     thresh: Optional[float] = None) -> Dict[str, Dict[str, float]]:
    """
    Compute Mean Absolute Error (MAE) for a group of experiments.

    For every experiment in ``groups[group_name]``, each non-``"mc"`` entry of
    its eval results is compared with the ``"mc"`` ground-truth vector. The
    per-experiment MAE is the mean absolute element-wise difference. MAEs that
    are not strictly below the threshold (including NaN) are discarded as
    outliers.

    Args:
        groups: Mapping from group name to a list of experiment names.
        group_name: Group to summarize.
        eval_results: Mapping from experiment name to a dict of estimator
            vectors plus the ``"mc"`` reference; array-likes (lists, NumPy or
            JAX arrays) are accepted.
        thresh: Outlier cut-off; ``None`` uses the module-level ``THRESH``.

    Returns:
        ``{estimator: {'agg', 'std', 'num_experiments'}}`` keyed by the
        lower-cased estimator name. ``agg`` is the mean MAE and ``std`` its
        standard error (``scipy.stats.sem``). Estimators with no surviving MAE
        are omitted.
    """
    limit = THRESH if thresh is None else thresh
    per_estimator: Dict[str, List[float]] = defaultdict(list)

    for exp in groups[group_name]:
        res = eval_results[exp]
        mc = np.asarray(res["mc"])  # ground-truth vector
        for key, val in res.items():
            if key == "mc":
                continue
            # asarray lets plain lists in the pickles be subtracted element-wise
            mae = float(np.mean(np.abs(np.asarray(val) - mc)))
            if mae < limit:  # also drops NaN MAEs
                per_estimator[key.lower()].append(mae)

    return {
        est: {
            "agg": np.mean(vals),
            "std": sem(vals),
            "num_experiments": len(vals),
        }
        for est, vals in per_estimator.items() if vals
    }


def compute_group_return(groups: Dict[str, List[str]], group_name: str,
                        eval_results: Dict[str, Any]) -> Dict[str, float]:
    """Summarize the Monte-Carlo return of every experiment in one group.

    Each experiment contributes the mean of its ``"mc"`` vector. The result
    holds the mean across experiments (``agg``), its standard error (``std``,
    via ``scipy.stats.sem``) and the group size (``num_experiments``).
    """
    members = groups[group_name]
    per_experiment_means = [np.mean(eval_results[name]["mc"]) for name in members]
    return dict(
        agg=np.mean(per_experiment_means),
        std=sem(per_experiment_means),
        num_experiments=len(members),
    )


def group_experiments_by_env_and_params(experiment_names: List[str]) -> Dict[str, List[str]]:
    """Bucket experiment names by environment and key hyper-parameters.

    The bucket key is
    ``"{env_name}_{num_dt}_{pred_coef}_{pred_lr}_{use_pretrained_transformer}"``,
    built from the fields returned by ``parse_experiment_name``. Names keep
    their input order inside each bucket.
    """
    key_fields = ("env_name", "num_dt", "pred_coef", "pred_lr", "use_pretrained_transformer")
    buckets: Dict[str, List[str]] = {}
    for name in experiment_names:
        fields = parse_experiment_name(name)
        bucket = "_".join(str(fields[field]) for field in key_fields)
        buckets.setdefault(bucket, []).append(name)
    return buckets


def create_pred_coef_groups(eval_results: Dict[str, Any], num_dt: int = 5) -> Tuple[Dict, Dict, Dict]:
    """Create groups for predictability coefficient analysis.

    Keeps experiments whose ``num_dt`` equals the argument and whose
    ``use_pretrained_transformer`` is 0 or 1. Experiments are grouped under
    ``"{env_name}_{pred_coef}_{pred_lr}_{use_pretrained_transformer}"``.

    Groups are inserted in this order:
    1. pretrained=0 before pretrained=1;
    2. within each, zero ``pred_lr`` before non-zero ``pred_lr`` (any
       non-zero value, not only 1e-4);
    3. then experiment order.

    Each group is summarized once, after grouping is complete. The previous
    version recomputed every group's statistics inside the ``pred_lr`` loop
    and parsed every name four times; results are unchanged.

    Returns:
        ``(groups, group_mae, group_return)`` where ``groups`` maps key ->
        experiment names, ``group_mae`` maps key -> ``compute_group_mae``
        output and ``group_return`` maps key -> ``compute_group_return``
        output (all with the same key order).
    """
    # Parse each experiment name once instead of once per filter combination.
    parsed = {name: parse_experiment_name(name) for name in eval_results}

    pred_coef_groups: Dict[str, List[str]] = {}
    for use_pretrained_transformer in (0, 1):
        for zero_pred_lr in (True, False):
            for experiment_name, info in parsed.items():
                if info["num_dt"] != num_dt:
                    continue
                if info["use_pretrained_transformer"] != use_pretrained_transformer:
                    continue
                if (info["pred_lr"] == 0) != zero_pred_lr:
                    continue
                group_key = (f"{info['env_name']}_{info['pred_coef']}_"
                             f"{info['pred_lr']}_{info['use_pretrained_transformer']}")
                pred_coef_groups.setdefault(group_key, []).append(experiment_name)

    pred_coef_group_results = {}
    pred_coef_group_return = {}
    for pc_group in pred_coef_groups:
        pred_coef_group_results[pc_group] = compute_group_mae(pred_coef_groups, pc_group, eval_results)
        pred_coef_group_return[pc_group] = compute_group_return(pred_coef_groups, pc_group, eval_results)

    return pred_coef_groups, pred_coef_group_results, pred_coef_group_return


def _collect_normalized_bars(pred_coef_group_return: Dict[str, Any]) -> Dict[str, Dict[float, Tuple[float, float]]]:
    """Normalize EvA-RL group returns against each environment's standard-RL baseline.

    Keys must look like ``"{env}_{beta}_{lr}_{pretrained}"``. With
    ``lowest = ENV_LOWEST_RETURNS[env]`` and ``span = baseline - lowest``:

    - the baseline group is the one with beta == 0, lr == 0, pretrained == 1;
    - EvA-RL groups have beta > 0, lr == 1e-3, pretrained == 0;
    - each EvA-RL group maps to ``((agg - lowest) / span, std / span)``, so
      1.0 equals the baseline.

    Environments with no baseline, no ``ENV_LOWEST_RETURNS`` entry, or a zero
    span are skipped. Only the first group seen per (env, beta) is kept.
    """
    base_filter = dict(beta=0.0, lr=0.0, pre=1)  # standard RL
    eva_filter = dict(lr=1e-3, pre=0)            # EvA-RL + predictor

    def unpack(key):
        env, beta, lr, pre = key.split("_", 3)
        return env, float(beta), float(lr), int(pre)

    baseline_ret = {}
    for key, stat in pred_coef_group_return.items():
        env, beta, lr, pre = unpack(key)
        if (beta == base_filter["beta"] and lr == base_filter["lr"]
                and pre == base_filter["pre"]):
            baseline_ret[env] = float(stat["agg"])
            print(f"env: {env}, return: {baseline_ret[env]}, std: {stat['std']}")

    bars = defaultdict(dict)
    for key, stat in pred_coef_group_return.items():
        env, beta, lr, pre = unpack(key)
        if env not in baseline_ret or env not in ENV_LOWEST_RETURNS:
            continue
        lowest = ENV_LOWEST_RETURNS[env]
        span = baseline_ret[env] - lowest
        if span == 0:  # guard the actual denominator, not the raw baseline
            continue

        if beta > 0 and lr == eva_filter["lr"] and pre == eva_filter["pre"]:
            print(f"EvARL env: {env}, return: {stat['agg']}, std: {stat['std']}")
            mu_norm = (float(stat["agg"]) - lowest) / span
            sigma_norm = float(stat["std"]) / span
            bars[env].setdefault(beta, (mu_norm, sigma_norm))  # keep first run per beta
    return dict(bars)


def plot_evarl_vs_rl_comparison(pred_coef_group_return: Dict[str, Any],
                               output_path: str = "./results/evarl_vs_rl_return_all_envs.pdf"):
    """Plot EvA-RL vs Standard RL comparison across environments.

    For each supported environment with data, draws one bar: the normalized
    EvA-RL return of the best-performing beta, with its normalized standard
    error as the error bar. A red line at y = 1 marks the standard-RL
    baseline.

    The parent directory of ``output_path`` is created if needed, and the
    figure is closed after saving. Prints progress and returns ``None``.
    """
    bars = _collect_normalized_bars(pred_coef_group_return)

    envs = [env for env in SUPPORTED_ENVIRONMENTS if env in bars]
    if not envs:
        print("No environments to plot")
        return

    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        # Horizontal baseline line at y = 1
        ax.axhline(1.0, color="red", linewidth=1.2, label="Standard RL")

        bar_width = 0.2
        x_positions = np.arange(len(envs))

        for i, env in enumerate(envs):
            beta_dict = bars[env]
            if not beta_dict:
                continue
            # max() returns the first maximum, so sorting makes ties resolve
            # to the smallest beta.
            best_beta = max(sorted(beta_dict), key=lambda b: beta_dict[b][0])
            mu, sigma = beta_dict[best_beta]
            print(f"env: {env}, best_beta: {best_beta}, μ: {mu}, σ: {sigma}")

            ax.bar(x_positions[i], mu, width=bar_width, yerr=sigma, capsize=3,
                   color="tab:blue", edgecolor="black", label="EvA-RL" if i == 0 else "")

        ax.set_xticks(x_positions)
        ax.set_xticklabels([f"{env}" for env in envs], rotation=0)
        ax.set_ylabel("Normalized Return")
        ax.legend(loc="upper center", bbox_to_anchor=(0.5, 1.15),
                  frameon=False, ncol=2, fontsize=12)
        ax.grid(axis="y", linestyle="--", alpha=0.4)

        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        plt.close(fig)  # release the figure even if saving fails
    print(f"Saved plot to {output_path}")


def main(argv: Optional[List[str]] = None):
    """Run the result collection and analysis pipeline.

    Steps:
    1. Parse the command line.
    2. Find and load the eval-value pickles under ``--save_dir``.
    3. Print a per-environment experiment count.
    4. Build the predictability-coefficient groups.
    5. Write the EvA-RL vs standard-RL comparison plot.

    Args:
        argv: Arguments to parse. ``None`` (default) reads ``sys.argv[1:]``;
            pass ``[]`` to force the parser defaults. The previous code always
            parsed ``[]``, which silently ignored every command-line flag.
    """
    parser = get_parser()
    args = parser.parse_args(argv)

    config = ExperimentConfig(
        save_dir=args.save_dir,
        experiment_name=args.experiment_name,
        num_eval_samples=args.num_eval_samples,
        num_trajs_per_state=args.num_trajs_per_state,
        experiment_name_file=args.experiment_name_file,
        pred_coef=args.PREDICTABILITY_COEF,
        pred_lr=args.PRED_LR,
        use_pretrained_transformer=args.use_pretrained_transformer
    )

    # Experiment names are the result file names minus their fixed suffix.
    eval_value_files = load_eval_files(config.save_dir)
    experiment_names = [eval_file[:-len("_eval_values.pkl")] for eval_file in eval_value_files]
    print("Experiment names:", experiment_names)

    eval_results = load_experiment_results(config.save_dir, experiment_names)
    experiment_names = list(eval_results.keys())

    # Count by the leading env field rather than by substring search anywhere in the name.
    env_wise_count = {env: 0 for env in SUPPORTED_ENVIRONMENTS}
    for experiment_name in experiment_names:
        env = experiment_name.split("_", 1)[0]
        if env in env_wise_count:
            env_wise_count[env] += 1
    print("Environment wise count:", env_wise_count)

    # Only the return summaries feed the plot.
    _, _, pred_coef_group_return = create_pred_coef_groups(eval_results)

    plot_evarl_vs_rl_comparison(pred_coef_group_return)


if __name__ == "__main__":
    # Run the analysis only when executed as a script, not on import.
    main()


