"""
Result collection and analysis for discrete action environments.

This module collects and analyzes results from agents trained with the EvA-RL framework
on discrete action environments. It provides functionality for loading experiment
results, parsing experiment names, computing statistics, and generating plots.
"""

import os
import argparse
import pickle
from typing import Any, Dict, List, Tuple, NamedTuple
from collections import defaultdict, namedtuple
from functools import partial
from itertools import cycle

import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import sem
from tqdm import tqdm

# Set environment variables
# Ask XLA not to preallocate accelerator memory up front.
# NOTE(review): this assignment runs after `import jax` above; presumably it
# still takes effect because the backend is initialised lazily — TODO confirm.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Configure matplotlib
# Global figure style shared by every plot in this module: no LaTeX rendering,
# serif (Palatino) text and a uniform 18pt size for labels, ticks and legends.
plt.rcParams.update({
    "text.usetex": False,
    "font.family": "serif",
    "font.serif": ["Palatino"],
    "axes.labelsize": 18,
    "font.size": 18,
    "legend.fontsize": 18,
    "xtick.labelsize": 18,
    "ytick.labelsize": 18,
})
sns.set_style("darkgrid")

# Constants
# Environments for which the driver-test-size ablation plots are produced.
SUPPORTED_ENVIRONMENTS = ["Freeway-MinAtar", "Asterix-MinAtar", "SpaceInvaders-MinAtar"]
# All estimator keys used by the plots. "predictor" must stay last: plotting
# code slices ESTIMATORS[:-1] to exclude it.
ESTIMATORS = ["fqe", "tis", "pdis", "dr", "predictor"]
# The estimators drawn as the "RL + OPE" baseline bars.
OPE_ESTIMATORS = ["fqe", "tis", "pdis", "dr"]


class Transition(NamedTuple):
    """Transition tuple for storing environment transitions.

    Not referenced by any function visible in this file.
    NOTE(review): presumably kept so that objects produced by the training
    code can be handled here — confirm before removing.
    """
    # Field meanings are inferred from their names only — TODO confirm
    # against the code that creates these tuples.
    done: jnp.ndarray       # presumably the episode-termination flag
    action: jnp.ndarray     # presumably the action taken
    value: jnp.ndarray      # presumably the critic's value estimate
    reward: jnp.ndarray     # presumably the reward received
    log_prob: jnp.ndarray   # presumably the log-probability of `action`
    obs: jnp.ndarray        # presumably the observation
    info: Any               # free-form extra data; type not constrained


class ExperimentConfig:
    """Plain container holding the settings of one analysis run.

    Every constructor argument is stored, unchanged, on an attribute with the
    same name; no validation or conversion is performed.
    """

    def __init__(self, save_dir: str = "./complete_discrete_long_run/",
                 experiment_name: str = "",
                 num_eval_samples: int = 250,
                 num_trajs_per_state: int = 10,
                 experiment_name_file: str = "./complete_discrete_long_run_scripts/experiment_names.txt",
                 pred_coef: float = 0.0,
                 pred_lr: float = 0.0,
                 use_pretrained_transformer: int = 0):
        # Where the results live and which run is addressed.
        self.save_dir, self.experiment_name = save_dir, experiment_name
        # Evaluation sampling settings.
        self.num_eval_samples, self.num_trajs_per_state = (
            num_eval_samples, num_trajs_per_state
        )
        # File listing the experiment names.
        self.experiment_name_file = experiment_name_file
        # Predictability-related settings.
        self.pred_coef, self.pred_lr = pred_coef, pred_lr
        self.use_pretrained_transformer = use_pretrained_transformer


def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the evaluation script.

    Returns:
        An ``argparse.ArgumentParser`` exposing ``--save_dir``,
        ``--experiment_name``, ``--num_eval_samples``,
        ``--num_trajs_per_state``, ``--experiment_name_file``,
        ``--PREDICTABILITY_COEF``, ``--PRED_LR`` and
        ``--use_pretrained_transformer``, each with a default value.
    """
    # One (flag, add_argument keyword arguments) entry per option, in the
    # order they should appear in --help.
    option_table = (
        ("--save_dir", dict(
            type=str,
            default="./complete_discrete_long_run/",
            help="Directory to save results")),
        ("--experiment_name", dict(
            type=str,
            default="Pong-misc_64envs_100steps_10000000.0ts_2seed_10dtr_200gtr_3dt_Pong-misc_32bs_0.001lr_100ep_4h_8l_128hd")),
        ("--num_eval_samples", dict(type=int, default=250)),
        ("--num_trajs_per_state", dict(
            type=int,
            default=10,
            help="Number of trajectories to sample per state during MC evaluation")),
        ("--experiment_name_file", dict(
            type=str,
            default="./complete_discrete_long_run_scripts/experiment_names.txt",
            help="File containing the experiment names to evaluate")),
        ("--PREDICTABILITY_COEF", dict(
            type=float,
            default=0.0,
            help="Predictability coefficient")),
        ("--PRED_LR", dict(
            type=float,
            default=0.0,
            help="Learning rate for the predictability transformer")),
        ("--use_pretrained_transformer", dict(
            type=int,
            default=0,
            help="Use pretrained transformer")),
    )

    cli = argparse.ArgumentParser(description="Evaluation")
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli


def parse_experiment_name(experiment_name: str) -> Dict[str, Any]:
    """
    Parse experiment name to extract parameters.

    The name is split on "_" and each positional token is read as
    ``<value><unit>``; the unit suffix is removed before conversion.

    Example: 'Asterix-MinAtar_64envs_100steps_10000000.0ts_3seed_10dtr_200gtr_1dt_Asterix-MinAtar_256bs_0.001lr_100ep_4h_4l_16hd_0.05pc_1usepretrained_0.0001predlr'

    Args:
        experiment_name: Underscore-separated experiment identifier with at
            least 18 tokens in the order shown in the example. Extra trailing
            tokens are ignored.

    Returns:
        Dict with the keys ``env_name``, ``num_envs``, ``num_steps``,
        ``num_timesteps``, ``seed``, ``num_dtr``, ``num_gtr``, ``num_dt``,
        ``env_name_2``, ``batch_size``, ``learning_rate``, ``num_epochs``,
        ``num_heads``, ``num_layers``, ``hidden_dim``, ``pred_coef``,
        ``use_pretrained_transformer`` and ``pred_lr``.

    Raises:
        IndexError: If the name has fewer tokens than expected.
        ValueError: If a token's value part is not a valid number.
    """
    parts = experiment_name.split("_")

    def strip_unit(index: int, unit: str) -> str:
        # Slice by the unit's real length instead of a hand-counted number.
        # The previous hard-coded 6 for "steps" (5 characters) turned
        # "100steps" into "10".
        return parts[index][:-len(unit)]

    return {
        "env_name": parts[0],
        "num_envs": int(strip_unit(1, "envs")),
        "num_steps": int(strip_unit(2, "steps")),
        "num_timesteps": float(strip_unit(3, "ts")),
        "seed": int(strip_unit(4, "seed")),
        "num_dtr": int(strip_unit(5, "dtr")),
        "num_gtr": int(strip_unit(6, "gtr")),
        "num_dt": int(strip_unit(7, "dt")),
        "env_name_2": parts[8],
        "batch_size": int(strip_unit(9, "bs")),
        "learning_rate": float(strip_unit(10, "lr")),
        "num_epochs": int(strip_unit(11, "ep")),
        "num_heads": int(strip_unit(12, "h")),
        "num_layers": int(strip_unit(13, "l")),
        "hidden_dim": int(strip_unit(14, "hd")),
        "pred_coef": float(strip_unit(15, "pc")),
        # Single-digit flag: only the first character is read, as before.
        "use_pretrained_transformer": int(parts[16][0]),
        "pred_lr": float(strip_unit(17, "predlr")),
    }


def load_a2c_files(save_dir: str) -> List[str]:
    """Collect the names of all A2C checkpoint files below ``save_dir``.

    The directory tree is walked recursively and every file whose name ends
    in ``_a2c_evarl.pkl`` is recorded. Only the bare file name is kept, so
    the sub-directory a file was found in is not part of the result.

    Args:
        save_dir: Root directory to search.

    Returns:
        The matching file names in sorted order (empty if none are found).
    """
    suffix = "_a2c_evarl.pkl"
    matches = [
        filename
        for _root, _subdirs, filenames in os.walk(save_dir)
        for filename in filenames
        if filename.endswith(suffix)
    ]

    print(f"Found {len(matches)} A2C files")
    matches.sort()
    return matches


def load_experiment_results(save_dir: str, experiment_names: List[str]) -> Dict[str, Any]:
    """Load experiment results from pickle files.

    For every experiment name the file ``<save_dir>/<name>_eval_values.pkl``
    is unpickled. Loading is best-effort: a file that cannot be opened or
    unpickled is reported and skipped rather than aborting the whole run.

    Args:
        save_dir: Directory containing the ``*_eval_values.pkl`` files.
        experiment_names: Experiment identifiers (without file suffix).

    Returns:
        Mapping from experiment name to the unpickled object, containing
        only the experiments that loaded successfully.
    """
    eval_results = {}
    total_num = len(experiment_names)
    skipped = 0
    unsuccessful = 0

    for experiment_name in tqdm(experiment_names, desc="Loading experiments"):
        # Filter out previous runs with 1dt.
        # Compare whole "_"-separated tokens: a plain substring test would
        # also drop unrelated runs such as "..._11dt_..." or "..._1dtr_...".
        if "1dt" in experiment_name.split("_"):
            skipped += 1
            continue

        try:
            # NOTE(security): pickle.load executes arbitrary code embedded in
            # the file; only point save_dir at results you produced yourself.
            with open(f'{save_dir}/{experiment_name}_eval_values.pkl', "rb") as f:
                eval_values = pickle.load(f)
        except Exception as e:
            # Deliberately broad: unpickling can fail in many ways and one
            # bad file must not stop the remaining experiments from loading.
            print(f"Error loading {experiment_name} eval values: {e}")
            unsuccessful += 1
            continue
        eval_results[experiment_name] = eval_values

    print(f"Successful experiments: {len(eval_results)}/{total_num}")
    if skipped or unsuccessful:
        print(f"Skipped 1dt runs: {skipped}, failed to load: {unsuccessful}")
    return eval_results


def compute_group_mae(groups: Dict[str, List[str]], group_name: str, 
                     eval_results: Dict[str, Any]) -> Dict[str, Dict[str, float]]:
    """Summarise per-estimator Mean Absolute Error (MAE) over one group.

    For each experiment in ``groups[group_name]`` every entry of its result
    dict other than ``"mc"`` is compared against the ``"mc"`` entry, giving
    one MAE per estimator per experiment. Those MAEs are then aggregated
    across the experiments of the group.

    Args:
        groups: Mapping from group name to the experiment names it contains.
        group_name: Key of the group to summarise.
        eval_results: Mapping from experiment name to its result dict.

    Returns:
        ``{estimator_lowercase: {"agg": mean MAE, "std": standard error of
        the mean, "num_experiments": number of MAEs aggregated}}``.
    """
    maes_by_estimator = defaultdict(list)

    for member in groups[group_name]:
        result = eval_results[member]
        ground_truth = result["mc"]
        for name, estimate in result.items():
            if name != "mc":
                abs_error = np.abs(estimate - ground_truth)
                maes_by_estimator[name.lower()].append(float(np.mean(abs_error)))

    summary = {}
    for estimator, maes in maes_by_estimator.items():
        summary[estimator] = {
            "agg": np.mean(maes),
            "std": sem(maes),
            "num_experiments": len(maes),
        }
    return summary


def compute_group_return(groups: Dict[str, List[str]], group_name: str, 
                        eval_results: Dict[str, Any]) -> Dict[str, float]:
    """Summarise the ``"mc"`` values of one group of experiments.

    Each experiment contributes the mean of its ``"mc"`` entry; these
    per-experiment means are then aggregated across the group.

    Args:
        groups: Mapping from group name to the experiment names it contains.
        group_name: Key of the group to summarise.
        eval_results: Mapping from experiment name to its result dict.

    Returns:
        ``{"agg": mean of the per-experiment means, "std": their standard
        error of the mean, "num_experiments": size of the group}``.
    """
    members = groups[group_name]
    mean_returns = [np.mean(eval_results[name]["mc"]) for name in members]

    summary = {"agg": np.mean(mean_returns)}
    summary["std"] = sem(mean_returns)
    summary["num_experiments"] = len(members)
    return summary


def group_experiments_by_env_and_params(experiment_names: List[str]) -> Dict[str, List[str]]:
    """Bucket experiment names by environment and key hyper-parameters.

    Each name is parsed with ``parse_experiment_name`` and filed under the
    key ``"<env_name>_<num_dt>_<pred_coef>_<pred_lr>_<use_pretrained_transformer>"``.

    Args:
        experiment_names: Experiment identifiers to group.

    Returns:
        Mapping from group key to the experiment names sharing it, in the
        order they were encountered.
    """
    key_fields = ("env_name", "num_dt", "pred_coef", "pred_lr",
                  "use_pretrained_transformer")
    buckets: Dict[str, List[str]] = {}

    for name in experiment_names:
        parsed = parse_experiment_name(name)
        bucket_key = "_".join(str(parsed[field]) for field in key_fields)
        # setdefault creates the bucket the first time its key is seen.
        buckets.setdefault(bucket_key, []).append(name)

    return buckets


def create_pred_coef_groups(eval_results: Dict[str, Any], num_dt: int = 5) -> Tuple[Dict, Dict, Dict]:
    """Create groups for predictability coefficient analysis.

    Experiments whose ``num_dt`` equals ``num_dt`` and whose
    ``use_pretrained_transformer`` flag is 0 or 1 are grouped under the key
    ``"<env_name>_<pred_coef>_<pred_lr>_<use_pretrained_transformer>"``.
    MAE and return statistics are then computed once per group.

    Groups are inserted in a fixed order: pretrained flag 0 before 1 and,
    within each flag, runs with ``pred_lr == 0`` before runs with a non-zero
    ``pred_lr``. Plotting code that keeps "the first occurrence" of a
    setting depends on this order.

    Args:
        eval_results: Mapping from experiment name to its result dict.
        num_dt: Only experiments with this ``num_dt`` value are grouped.

    Returns:
        ``(groups, mae_stats, return_stats)``: three dicts keyed by group
        key, holding the member experiment names, the result of
        ``compute_group_mae`` and the result of ``compute_group_return``.
    """
    # Parse every name once instead of once per filter combination.
    parsed_names = {name: parse_experiment_name(name) for name in eval_results}

    pred_coef_groups = {}
    for use_pretrained_transformer in [0, 1]:
        for want_zero_lr in [True, False]:
            for experiment_name, parsed in parsed_names.items():
                if parsed["num_dt"] != num_dt:
                    continue
                if parsed["use_pretrained_transformer"] != use_pretrained_transformer:
                    continue
                # Only zero vs. non-zero matters here; the exact learning
                # rate is still kept in the group key below.
                if (parsed["pred_lr"] == 0) != want_zero_lr:
                    continue

                group_key = (f"{parsed['env_name']}_{parsed['pred_coef']}_"
                             f"{parsed['pred_lr']}_{parsed['use_pretrained_transformer']}")
                pred_coef_groups.setdefault(group_key, []).append(experiment_name)

    # Compute statistics once, after grouping is complete. Doing this inside
    # the loops above recomputed every group on each of the four passes
    # without changing the result.
    pred_coef_group_results = {}
    pred_coef_group_return = {}
    for pc_group in pred_coef_groups:
        pred_coef_group_results[pc_group] = compute_group_mae(
            pred_coef_groups, pc_group, eval_results)
        pred_coef_group_return[pc_group] = compute_group_return(
            pred_coef_groups, pc_group, eval_results)

    return pred_coef_groups, pred_coef_group_results, pred_coef_group_return


def split_key(k: str) -> Tuple[str, float, float, int]:
    """Split a group key ``"<env>_<beta>_<lr>_<pre>"`` into typed components.

    Args:
        k: Group key whose last three "_"-separated fields are the
            predictability coefficient, the predictor learning rate and the
            pretrained flag.

    Returns:
        ``(env, beta, lr, pre)`` with ``beta`` and ``lr`` as floats and
        ``pre`` as an int.

    Raises:
        ValueError: If the key has fewer than four fields or a numeric field
            cannot be converted.
    """
    # Split from the right so an environment name that itself contains "_"
    # stays intact; keys with exactly three underscores parse as before.
    env, beta, lr, pre = k.rsplit("_", 3)
    return env, float(beta), float(lr), int(pre)


def plot_mae_ope_vs_evarl(pred_coef_group_results: Dict[str, Any]):
    """Plot MAE comparison between OPE estimators and EvA-RL predictor.

    One bar chart is written per environment to
    ``./results/mae_ope_vs_evarl_<env>.pdf``. The first cluster shows the OPE
    estimators' MAE for the standard-RL baseline (β = 0, pred_lr = 0,
    pretrained = 1); each following bar shows the predictor's MAE for one
    EvA-RL setting (β > 0, pred_lr = 1e-3, pretrained = 0). Environments
    lacking either the baseline or any EvA-RL group are skipped.

    Args:
        pred_coef_group_results: Mapping from group key (see ``split_key``)
            to per-estimator statistics with ``"agg"`` and ``"std"`` entries.

    Raises:
        KeyError: If a selected group lacks one of the required estimators.
    """
    # Constants that define the two settings to compare
    OPE_FILTER = dict(pred_coef=0, pred_lr=0, pre=1)   # RL + OPE
    EVA_FILTER = dict(pred_lr=1e-3, pre=0)            # EvA-RL

    palette_ope = plt.cm.Pastel1.colors[:4]           # 4 distinct colours
    color_pred = "tab:orange"                         # predictor bars

    # Collect data
    baseline = {}                       # env -> {est: (mean, std)}
    eva = defaultdict(list)            # env -> list of (β, mean, std)

    for key, stat in pred_coef_group_results.items():
        env, beta, lr, pre = split_key(key)

        # RL + OPE baseline (β = 0)
        if (beta == OPE_FILTER["pred_coef"] and
            lr == OPE_FILTER["pred_lr"] and
            pre == OPE_FILTER["pre"]):
            baseline[env] = {e: (stat[e]["agg"], stat[e]["std"])
                             for e in OPE_ESTIMATORS}

        # EvA-RL predictor (pre=0, lr=1e-3, β>0)
        if (pre == EVA_FILTER["pre"] and
            lr == EVA_FILTER["pred_lr"] and
            beta > 0):
            eva[env].append((beta,
                             stat["predictor"]["agg"],
                             stat["predictor"]["std"]))

    # Sort β lists
    for env in eva:
        eva[env].sort(key=lambda tup: tup[0])

    # savefig does not create missing directories.
    os.makedirs("./results", exist_ok=True)

    # Plot for every environment that has both datasets
    for env in sorted(eva):
        if env not in baseline:
            continue                              # need baseline to compare

        eva_betas = [b for (b, _, _) in eva[env]]
        x = np.arange(len(eva_betas) + 1)       # slot 0 is the baseline cluster
        width = 0.15                            # width of individual bars

        fig, ax = plt.subplots(figsize=(6, 5))

        # β = 0 cluster: 4 OPE estimator bars
        for j, est in enumerate(OPE_ESTIMATORS):
            mu, sd = baseline[env][est]
            offset = (-1.5 + j) * width          # centres the 4 bars
            ax.bar(x[0] + offset, mu, width,
                   yerr=sd, capsize=3,
                   color=palette_ope[j], edgecolor="black",
                   label=est.upper())

        # β > 0 clusters: Predictor bar
        for idx, (beta, mu, sd) in enumerate(eva[env], start=1):
            ax.bar(x[idx], mu, width*2, yerr=sd, capsize=3,
                   color=color_pred, edgecolor="black",
                   label="Predictor" if idx == 1 else None)

        # Cosmetics
        # Build the tick labels from the β values actually plotted; a fixed
        # three-label list mislabels or fails for any other set of β values.
        tick_labels = ["Standard RL"] + [f"EvA-RL\n$\\beta={b:g}$" for b in eva_betas]
        ax.set_xticks(x)
        ax.set_xticklabels(tick_labels, rotation=0)
        ax.set_ylabel("MAE")
        ax.grid(axis="y", linestyle="--", alpha=0.4)
        ax.legend(frameon=False, ncol=1, loc="upper right")
        fig.tight_layout()
        
        out = f"./results/mae_ope_vs_evarl_{env}.pdf"
        fig.savefig(out)
        plt.close(fig)
        print("saved", out)


def plot_evarl_vs_rl_return(pred_coef_group_return: Dict[str, Any]):
    """Plot EvA-RL vs Standard RL return comparison.

    For each environment the EvA-RL returns (β > 0, pred_lr = 1e-3,
    pretrained = 0) are divided by the standard-RL return (β = 0,
    pred_lr = 0, pretrained = 1) and drawn as bars against a horizontal line
    at 1. Figures are written to ``./results/evarl_vs_rl_return_<env>.pdf``.
    Environments without a baseline, or with a baseline of exactly 0, are
    skipped because the ratio would be undefined.

    Args:
        pred_coef_group_return: Mapping from group key (see ``split_key``)
            to return statistics with ``"agg"`` and ``"std"`` entries.
    """
    # Settings that define each regime
    BASE_FILTER = dict(beta=0.0, lr=0.0, pre=1)     # standard RL
    EVA_FILTER = dict(lr=1e-3, pre=0)              # EvA-RL + predictor

    # Split every key once; both passes below reuse the result.
    entries = [(split_key(key), stat)
               for key, stat in pred_coef_group_return.items()]

    # First pass: baseline return per environment.
    baseline_ret = {}
    for (env, beta, lr, pre), stat in entries:
        if (beta == BASE_FILTER["beta"] and lr == BASE_FILTER["lr"]
                and pre == BASE_FILTER["pre"]):
            baseline_ret[env] = float(stat["agg"])

    # Second pass: normalised EvA-RL returns. This needs every baseline to be
    # known already, hence the separate loop.
    bars = defaultdict(dict)                        # env -> β -> (mean, err)
    for (env, beta, lr, pre), stat in entries:
        if env not in baseline_ret or baseline_ret[env] == 0:
            continue
        if not (beta > 0 and lr == EVA_FILTER["lr"] and pre == EVA_FILTER["pre"]):
            continue
        if beta in bars[env]:                       # keep first run per β
            continue
        base = baseline_ret[env]
        bars[env][beta] = (float(stat["agg"]) / base,
                           float(stat["std"]) / abs(base))

    # savefig does not create missing directories.
    os.makedirs("./results", exist_ok=True)

    # Plotting
    for env, beta_dict in bars.items():
        betas = sorted(beta_dict)
        if not betas:
            continue

        means = [beta_dict[b][0] for b in betas]
        errs = [beta_dict[b][1] for b in betas]
        x = np.arange(len(betas))
        w = 0.6

        fig, ax = plt.subplots(figsize=(5, 3.2))
        # Horizontal baseline line at y = 1
        ax.axhline(1.0, color="red", linewidth=1.2, label="Standard RL")

        # EvA-RL bars
        ax.bar(x, means, width=w, yerr=errs, capsize=3,
               color="tab:blue", edgecolor="black", label="EvA-RL")

        ax.set_xticks(x)
        ax.set_xticklabels([f"β={b:g}" for b in betas], rotation=25)
        ax.set_ylabel("Norm. Return")
        ax.legend(loc="upper center", bbox_to_anchor=(.5, 1.3),
                  frameon=False, ncol=2, fontsize=12)
        ax.grid(axis="y", linestyle="--", alpha=0.4)

        fig.tight_layout()
        out = f"./results/evarl_vs_rl_return_{env}.pdf"
        fig.savefig(out)
        plt.close(fig)
        print("saved", out)


def plot_mae_by_beta(pred_coef_groups: Dict[str, List[str]], 
                    pred_coef_group_results: Dict[str, Any],
                    pred_coef_group_return: Dict[str, Any],
                    plot_pred_lr: float = 0.001,
                    plot_pretrain: int = 0):
    """Plot MAE bar plot for each environment showing effect of beta.

    Only groups whose predictor learning rate equals ``plot_pred_lr`` and
    whose pretrained flag equals ``plot_pretrain`` are used. For each
    environment with at least one such group, a grouped bar chart (one
    cluster per OPE estimator, one bar per β) is written to
    ``./results/mae_<env>_lr<plot_pred_lr>_pre<plot_pretrain>.pdf``.
    If no group matches, a notice is printed and nothing is plotted.

    Args:
        pred_coef_groups: Mapping from group key (see ``split_key``) to the
            member experiment names; only the keys are used.
        pred_coef_group_results: Mapping from group key to per-estimator
            statistics with ``"agg"`` and ``"std"`` entries.
        pred_coef_group_return: Not used by this plot; kept so existing
            callers keep working.
        plot_pred_lr: Predictor learning rate to select.
        plot_pretrain: Pretrained-transformer flag to select.
    """
    ope_names = ESTIMATORS[:-1]                      # exclude predictor
    palette = plt.cm.Set2.colors                     # ≈ 8 qualitative colours

    # env → lists aligned by index (β, per-estimator MAE, per-estimator error)
    env2_beta, env2_mae, env2_std = (
        defaultdict(list), defaultdict(list), defaultdict(list)
    )

    for key in pred_coef_groups:
        env, beta, lr, pre = split_key(key)
        if lr != plot_pred_lr or pre != plot_pretrain:
            continue

        mae_per_est = pred_coef_group_results[key]      # dict per estimator
        env2_beta[env].append(beta)
        env2_mae[env].append([mae_per_est[e]["agg"] for e in ope_names])
        env2_std[env].append([mae_per_est[e]["std"] for e in ope_names])

    # Without this guard the bar-width computation below would divide by
    # zero (or call max() on an empty sequence) when the filter selects
    # nothing.
    if not env2_beta:
        print(f"No groups with pred_lr={plot_pred_lr} and pretrained={plot_pretrain}; "
              "nothing to plot")
        return

    # Same bar width in every figure: sized for the env with the most β values.
    bar_w = 0.8 / max(len(b) for b in env2_beta.values())    # width per β

    # savefig does not create missing directories.
    os.makedirs("./results", exist_ok=True)

    # MAE bar-plot for each environment that has matching data
    for env in sorted(env2_beta):
        # Sort β inside the env
        order = np.argsort(env2_beta[env])
        betas = np.array(env2_beta[env])[order]
        mae = np.array(env2_mae[env])[order]         # shape [n_beta, n_est]
        mae_std = np.array(env2_std[env])[order]

        x_est = np.arange(len(ope_names))            # one cluster per estimator
        fig, ax = plt.subplots(figsize=(6, 4))

        for j, (beta, col) in enumerate(zip(betas, cycle(palette))):
            offs = (j - (len(betas)-1)/2) * bar_w    # centre the cluster on its tick
            ax.bar(x_est + offs, mae[j], bar_w,
                   yerr=mae_std[j], capsize=3,
                   color=col, edgecolor="black",
                   label=f"β={beta:.2g}")

        ax.set_xticks(x_est)
        ax.set_xticklabels(ope_names)
        ax.set_ylabel("MAE")
        ax.legend(bbox_to_anchor=(.75, 1), loc='upper left', borderaxespad=0.)

        fig.tight_layout()
        fig.savefig(f"./results/mae_{env}_lr{plot_pred_lr}_pre{plot_pretrain}.pdf")
        plt.close(fig)                               # release the figure
        print("saved at ", f"./results/mae_{env}_lr{plot_pred_lr}_pre{plot_pretrain}.pdf")


def plot_frozen_vs_finetune(pred_coef_group_return: Dict[str, Any],
                           pred_coef_group_results: Dict[str, Any]):
    """Plot frozen vs finetuned transformer comparison.

    Only groups with the pretrained flag set to 1 are considered. Per
    environment, the mean return of all β = 0 groups forms the baseline;
    the return of every β > 0 group is divided by it and drawn as a
    "frozen" (pred_lr = 0) or "finetune" (pred_lr ≠ 0) bar. Figures are
    written to ``./results/ret_norm_frozen_vs_tuned_<env>.pdf``.
    Environments without a non-zero baseline are skipped.

    Args:
        pred_coef_group_return: Mapping from group key (see ``split_key``)
            to return statistics with ``"agg"`` and ``"std"`` entries.
        pred_coef_group_results: Not used by this plot; kept so existing
            callers keep working.
    """
    # Pass 1: build per-env baseline from all β = 0 runs
    baseline = defaultdict(list)                     # env → list of returns

    for k, stat in pred_coef_group_return.items():
        env, beta, lr, pre = split_key(k)
        if pre == 1 and beta == 0:                   # any pred_lr, but β = 0
            baseline[env].append(float(stat["agg"]))

    # Take the mean as the merged baseline
    base_mu = {env: np.mean(vals) for env, vals in baseline.items()
               if len(vals) > 0 and np.mean(vals) != 0}

    # Pass 2: collect & normalise
    data = defaultdict(lambda: defaultdict(dict))    # env → β → tag → (µ, σ)

    for k, rstat in pred_coef_group_return.items():
        env, beta, lr, pre = split_key(k)
        if pre != 1 or env not in base_mu:           # need valid baseline
            continue

        if beta == 0:                                # merged baseline
            data[env][0]["baseline"] = (1.0, 0.0)    # plotted once, value = 1
            continue

        tag = "frozen" if lr == 0 else "finetune"
        if tag not in data[env][beta]:               # keep first occurrence
            data[env][beta][tag] = (float(rstat["agg"]) / base_mu[env],
                                    float(rstat["std"]) / abs(base_mu[env]))

    # savefig does not create missing directories.
    os.makedirs("./results", exist_ok=True)

    # Plotting
    col = {"baseline": "lightgrey", "frozen": "tab:blue", "finetune": "tab:orange"}

    for env, beta_dict in data.items():
        betas = sorted(beta_dict)
        x = np.arange(len(betas))
        bar_w = 0.28
        shown = set()                                # tags already in the legend

        fig, ax = plt.subplots(figsize=(5, 4))

        for i, beta in enumerate(betas):
            entry = beta_dict[beta]

            # Baseline bar (β = 0)
            if beta == 0:
                mu, sd = entry["baseline"]
                ax.bar(x[i], mu, width=bar_w*2, yerr=sd, capsize=3,
                       color=col["baseline"], edgecolor="black",
                       label="baseline")
                continue

            # Frozen / finetune bars, side by side around the tick
            for j, tag in enumerate(("frozen", "finetune")):
                if tag not in entry:
                    continue
                mu, sd = entry[tag]
                offs = (-0.5 + j) * bar_w
                ax.bar(x[i] + offs, mu, width=bar_w, yerr=sd, capsize=3,
                       color=col[tag], edgecolor="black",
                       label=tag if tag not in shown else None)
                shown.add(tag)

        # Figure cosmetics
        ax.set_xticks(x)
        ax.set_xticklabels([f"β={b:g}" for b in betas], rotation=30)
        ax.set_ylabel("Norm. Return")
        ax.grid(axis="y", linestyle="--", alpha=0.4)
        ax.legend(loc="upper center", ncol=3, bbox_to_anchor=(0.5, 1.2),
                  frameon=False, fontsize=14)

        fig.tight_layout()
        out = f"./results/ret_norm_frozen_vs_tuned_{env}.pdf"
        fig.savefig(out)
        plt.close(fig)
        print("saved", out)


def plot_predictor_mae_frozen_vs_finetune(pred_coef_group_results: Dict[str, Any],
                                         plot_pred_lr: float = 1e-4):
    """Plot predictor MAE comparison between frozen and finetuned transformers.

    Only groups with the pretrained flag set to 1 are considered. A group
    with ``pred_lr == 0`` counts as "frozen", one with
    ``pred_lr == plot_pred_lr`` as "finetuned"; other learning rates are
    ignored. For each environment, the β values that have both variants are
    drawn as paired bars and written to
    ``./results/predmae_frozen_vs_finetune_<env>.pdf``.

    Args:
        pred_coef_group_results: Mapping from group key (see ``split_key``)
            to per-estimator statistics; the ``"predictor"`` entry is read.
        plot_pred_lr: Learning rate identifying the finetuned runs.
    """
    PRETRAIN = 1               # compare only when pretrained = 1

    # env → β → {frozen / fine : (µ , σ)}
    data = defaultdict(lambda: defaultdict(dict))

    for k, stat in pred_coef_group_results.items():
        env, beta, lr, pre = split_key(k)
        if pre != PRETRAIN:            # we plot only pretrained-on runs
            continue
        if lr == 0:                    # frozen
            tag = "frozen"
        elif lr == plot_pred_lr:       # chosen finetune lr
            tag = "fine"
        else:
            continue                   # ignore other lrs

        mu = stat["predictor"]["agg"]
        sd = stat["predictor"]["std"]
        data[env][beta][tag] = (mu, sd)

    # savefig does not create missing directories.
    os.makedirs("./results", exist_ok=True)

    # Plotting
    col = {"frozen": "tab:blue", "fine": "tab:orange"}
    width = 0.35
    # (tag, horizontal offset from the tick, legend label)
    series = (
        ("frozen", -width/2, "frozen (pred_lr=0)"),
        ("fine", width/2, f"finetune (pred_lr={plot_pred_lr:g})"),
    )

    for env, beta_dic in data.items():
        # Keep only β values that have BOTH bars
        betas = sorted(b for b, d in beta_dic.items() if {"frozen", "fine"} <= d.keys())
        if not betas:
            continue

        x = np.arange(len(betas))

        fig, ax = plt.subplots(figsize=(5, 3.5))
        # Draw both variants. Previously only the frozen bars were drawn, so
        # the figure showed no comparison at all.
        for tag, offset, label in series:
            vals = [beta_dic[b][tag][0] for b in betas]
            errs = [beta_dic[b][tag][1] for b in betas]
            ax.bar(x + offset, vals, width, yerr=errs, capsize=3,
                   color=col[tag], edgecolor="black", label=label)

        ax.set_xticks(x)
        ax.set_xticklabels([f"β={b:g}" for b in betas], rotation=30)
        ax.set_ylabel("Predictor MAE")
        ax.grid(axis="y", linestyle="--", alpha=0.4)
        ax.legend(frameon=False)

        fig.tight_layout()
        out = f"./results/predmae_frozen_vs_finetune_{env}.pdf"
        fig.savefig(out)
        plt.close(fig)
        print("saved", out)


def _plot_dt_series(points: List[Tuple[int, float, float, str]],
                    ylabel: str, out: str) -> None:
    """Draw one value-vs-num_dt curve with error bars and save it to ``out``.

    ``points`` holds ``(num_dt, value, error, label)`` tuples sorted by
    ``num_dt``.
    """
    plt.figure()
    xs = [x for x, _, _, _ in points]
    ys = [y for _, y, _, _ in points]
    for x, y, err, label in points:
        plt.scatter(x, y, label=label)
        plt.errorbar(x, y, yerr=err, fmt='o', color='orange')
    plt.plot(xs, ys, linestyle='-', color='orange')
    plt.xticks(rotation=45, ticks=xs)
    plt.xlabel("Number of Driver States")
    plt.ylabel(ylabel)
    plt.tight_layout()
    plt.savefig(out)
    print(f"Saved at {out}")
    plt.close()


def plot_driver_test_size_ablation(eval_results: Dict[str, Any],
                                   estimator: str = "predictor"):
    """Plot driver test size ablation study.

    Selects the experiments with pred_coef = 1e-2, a non-zero pred_lr and
    pretrained = 0, groups them by ``"<env_name>_<num_dt>"`` and, for every
    environment in ``SUPPORTED_ENVIRONMENTS`` that has such groups, writes
    two figures to ``./results``: the MAE of ``estimator`` versus ``num_dt``
    and the mean return versus ``num_dt``. Environments without matching
    experiments are skipped.

    Args:
        eval_results: Mapping from experiment name to its result dict.
        estimator: Lower-case estimator key whose MAE is plotted.
            ``compute_group_mae`` returns one statistics dict per estimator,
            so one has to be chosen. NOTE(review): "predictor" is assumed to
            be the intended curve — confirm.

    Raises:
        KeyError: If ``estimator`` is missing from a group's MAE statistics.
    """
    # Driver's test size effect
    pred_coef_ = [1e-2]
    pred_lr_ = 1e-4
    use_pretrained_transformer_ = 0

    dt_groups = {}
    for experiment_name in eval_results:
        parsed = parse_experiment_name(experiment_name)

        # Only zero vs. non-zero learning rate has to agree with pred_lr_.
        if (pred_lr_ == 0) != (parsed["pred_lr"] == 0):
            continue
        if (parsed["pred_coef"] not in pred_coef_
                or parsed["use_pretrained_transformer"] != use_pretrained_transformer_):
            continue

        group_key = f"{parsed['env_name']}_{parsed['num_dt']}"
        dt_groups.setdefault(group_key, []).append(experiment_name)

    # Compute statistics
    dt_group_mae = {g: compute_group_mae(dt_groups, g, eval_results) for g in dt_groups}
    dt_group_return = {g: compute_group_return(dt_groups, g, eval_results) for g in dt_groups}

    def num_dt_of(group_key: str) -> int:
        return int(group_key.split("_")[1])

    # savefig does not create missing directories.
    os.makedirs("./results", exist_ok=True)

    for env_name in SUPPORTED_ENVIRONMENTS:
        env_groups = sorted((g for g in dt_groups if g.startswith(env_name)),
                            key=num_dt_of)
        if not env_groups:
            continue                    # nothing to draw for this environment

        # MAE results. The statistics are nested per estimator, so index by
        # `estimator` first; reading "agg" directly raised KeyError.
        mae_points = [(num_dt_of(g),
                       dt_group_mae[g][estimator]["agg"],
                       dt_group_mae[g][estimator]["std"],
                       g)
                      for g in env_groups]
        _plot_dt_series(
            mae_points, "Mean Absolute Error",
            f"./results/dt_group_mae_eval_results_{env_name}_{use_pretrained_transformer_}pretrained_{pred_lr_}predlr.pdf")

        # Return results
        return_points = [(num_dt_of(g),
                          dt_group_return[g]["agg"],
                          dt_group_return[g]["std"],
                          g)
                         for g in env_groups]
        _plot_dt_series(
            return_points, "Mean Return",
            f"./results/dt_group_mc_eval_results_{env_name}_{use_pretrained_transformer_}_pretrained_{pred_lr_}predlr.pdf")


def main(argv=None):
    """Main function to run the result collection and analysis.

    Discovers the experiments under the save directory, loads their
    evaluation results, groups them and writes all figures.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, so options such
            as ``--save_dir`` take effect. Pass ``[]`` to force the built-in
            defaults, e.g. when calling this from a notebook.
    """
    # Parse arguments. Previously an empty list was always passed here,
    # which silently discarded every command-line option.
    parser = get_parser()
    args = parser.parse_args(argv)
    
    # Create configuration
    config = ExperimentConfig(
        save_dir=args.save_dir,
        experiment_name=args.experiment_name,
        num_eval_samples=args.num_eval_samples,
        num_trajs_per_state=args.num_trajs_per_state,
        experiment_name_file=args.experiment_name_file,
        pred_coef=args.PREDICTABILITY_COEF,
        pred_lr=args.PRED_LR,
        use_pretrained_transformer=args.use_pretrained_transformer
    )
    
    # Load A2C files; the experiment name is the file name without its suffix
    a2c_files = load_a2c_files(config.save_dir)
    experiment_names = [a2c_file[:-len("_a2c_evarl.pkl")] for a2c_file in a2c_files]
    print("Experiment names:", experiment_names)
    
    # Load experiment results
    eval_results = load_experiment_results(config.save_dir, experiment_names)
    
    # Create predictability coefficient groups
    pred_coef_groups, pred_coef_group_results, pred_coef_group_return = create_pred_coef_groups(eval_results)
    
    # Generate plots
    
    # Figure 2: Mean Absolute Error (MAE) of the predictor compared against the OPE estimators
    # Comparison of EvA-RL and RL + OPE
    plot_mae_ope_vs_evarl(pred_coef_group_results)
    
    # Returns of the agent trained using standard RL vs EvA-RL
    plot_evarl_vs_rl_return(pred_coef_group_return)
    
    # Figure 3: MAE bar plot for each environment showing effect of beta
    plot_mae_by_beta(pred_coef_groups, pred_coef_group_results, pred_coef_group_return)
    
    # Figure 4: EvA-RL with pretrained transformer - frozen vs finetuned comparison
    plot_frozen_vs_finetune(pred_coef_group_return, pred_coef_group_results)
    plot_predictor_mae_frozen_vs_finetune(pred_coef_group_results)
    
    # Appendix: Driver's test size ablation
    plot_driver_test_size_ablation(eval_results)


# Run the analysis only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()


