import argparse
from functools import partial
import os
from pathlib import Path
import pickle
from typing import Callable, Dict, List, Optional
import hydra
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from omegaconf import DictConfig
import ray
from scipy.stats import friedmanchisquare, wilcoxon
import torch
import gymnasium as gym

from src.qd.wrappers import (
    AntBehavioralWrapper,
    BipedalBehavioralWrapper,
    HalfCheetahBehavioralWrapper,
    HopperBehavioralWrapper,
    SwimmerBehavioralWrapper,
    WalkerBehavioralWrapper,
    ANT_MEASURE_NAMES,
    SWIMMER_MEASURE_NAMES,
)
from src.utils import Trajectory, collect_trajectories

# Names of the hand-crafted measures for each supported environment, keyed by
# Gymnasium environment ID. get_measures() passes the entry for the selected
# environment as `measure_names` to the evaluator actors.
GT_MEASURES = {
    "BipedalWalker-v3": [
        "left_contact_freq",
        "right_contact_freq",
    ],
    "Ant-v5": ANT_MEASURE_NAMES,
    "HalfCheetah-v5": [
        "back_foot_freq",
        "front_foot_freq",
    ],
    "Hopper-v5": ["foot_contact_freq"],
    "Swimmer-v5": SWIMMER_MEASURE_NAMES,
    "Walker2d-v5": ["right_contact_freq", "left_contact_freq"],
}

# Behavioral wrapper class for each supported environment, keyed by Gymnasium
# environment ID (same keys as GT_MEASURES). get_measures() passes the matching
# class as the single entry of the `wrappers` list given to the evaluator actors.
ENV_WRAPPERS = {
    "BipedalWalker-v3": BipedalBehavioralWrapper,
    "Ant-v5": AntBehavioralWrapper,
    "HalfCheetah-v5": HalfCheetahBehavioralWrapper,
    "Hopper-v5": HopperBehavioralWrapper,
    "Swimmer-v5": SwimmerBehavioralWrapper,
    "Walker2d-v5": WalkerBehavioralWrapper,
}

##############################
##### Evaluation Utility #####
##############################


def agent_from_DictConfig(cfg: DictConfig):
    """Build an agent from its config.

    Module-level (rather than a lambda) so that it can be wrapped with
    ``functools.partial`` and handed to the evaluator actors as a zero-argument
    agent factory.

    Args:
        cfg: Config describing the agent; forwarded unchanged to
            ``hydra.utils.instantiate``.

    Returns:
        Whatever ``hydra.utils.instantiate`` builds from ``cfg``.
    """
    agent = hydra.utils.instantiate(cfg)
    return agent


@ray.remote
class FlexiblePolicyEvaluator:
    """Ray actor that rolls out policies in its own async vector environment.

    The vector env is created once per actor; the agent is rebuilt from a numpy
    parameter array on every call to :meth:`evaluate_policy`.
    """

    def __init__(
        self,
        env_id: str,
        env_kwargs: Dict,
        num_envs: int,
        agent_creator: Callable,
        wrappers: Optional[List[gym.Wrapper]] = None,
        measure_names: Optional[List[str]] = None,
    ):
        """Initialize policy evaluator.

        Args:
            env_id: Gymnasium environment ID
            env_kwargs: Extra keyword arguments forwarded to ``gym.make_vec``
            num_envs: Number of parallel environments to run
            agent_creator: Function that creates an instance of the agent (nn.Module)
            wrappers: None or list of wrappers passed to the vector env constructor
            measure_names: None or list of handcrafted measures that will be extracted

        Note:
            Measures are only extracted when BOTH ``wrappers`` and
            ``measure_names`` are non-empty. If either is missing, the wrappers
            are dropped as well and plain trajectories are collected.
        """
        # Keep each actor single-threaded in torch; parallelism comes from
        # running several actors / sub-environments.
        torch.set_num_threads(1)
        self.agent_creator = agent_creator

        measures_enabled = bool(wrappers and measure_names)
        self.include_measures = measures_enabled
        self.measure_names = measure_names if measures_enabled else None
        self.wrappers = wrappers if measures_enabled else []

        self.envs = gym.make_vec(
            env_id,
            num_envs=num_envs,
            vectorization_mode=gym.VectorizeMode.ASYNC,
            wrappers=self.wrappers,
            **env_kwargs,
        )

    def evaluate_policy(
        self,
        policy_params: np.ndarray,
        n_trajectories: int,
    ) -> List[Trajectory]:
        """Roll out one policy and return its trajectories.

        Args:
            policy_params: Parameters loaded into a fresh agent via ``from_numpy``
            n_trajectories: Number of trajectories requested from
                ``collect_trajectories``

        Returns:
            The list returned by ``collect_trajectories``; measure names are
            passed along only when measure extraction is enabled.
        """
        policy = self.agent_creator().from_numpy(policy_params)
        # The measure names go in as an optional trailing positional argument.
        extra_args = (self.measure_names,) if self.include_measures else ()
        return collect_trajectories(self.envs, policy, n_trajectories, *extra_args)


@torch.no_grad()
def get_measures(
    env_id,
    env_kwargs,
    agent_creator,
    all_solutions,
    auto_qd_measure_fn,
    aurora_measure_fn,
    lstm_aurora_measure_fn,
    n_evals,
):
    """Roll out every solution and compute all measure variants on the rollouts.

    Rollouts are produced by a pool of ``FlexiblePolicyEvaluator`` actors (Ray is
    initialised here if needed); solutions are assigned to actors round-robin.

    Args:
        env_id: Gymnasium environment ID; also selects the wrapper and the
            ground-truth measure names from ENV_WRAPPERS / GT_MEASURES.
        env_kwargs: Keyword arguments forwarded to each evaluator.
        agent_creator: Zero-argument agent factory forwarded to each evaluator.
        all_solutions: Iterable of policy parameter arrays, one per policy.
        auto_qd_measure_fn: Callable mapping a list of trajectories to AutoQD measures.
        aurora_measure_fn: Callable mapping a list of trajectories to Aurora measures.
        lstm_aurora_measure_fn: Callable mapping a list of trajectories to
            LSTM-Aurora measures.
        n_evals: Number of trajectories collected per policy.

    Returns:
        Dict with keys "GT", "AutoQD", "Aurora", "LSTM-Aurora" (in that order),
        each a numpy array stacked over policies. Presumably each has shape
        (n_policies, n_evals, measure_dim), as compute_policy_variances expects
        -- TODO confirm against the measure functions.
    """
    if not ray.is_initialized():
        ray.init()

    cpus_per_worker = 2
    total_cpus = ray.cluster_resources()["CPU"]
    num_workers = int(max(1, total_cpus // cpus_per_worker))
    print(f"Using {num_workers} workers each with {cpus_per_worker} CPUs")

    evaluators = []
    for _ in range(num_workers):
        actor = FlexiblePolicyEvaluator.options(num_cpus=cpus_per_worker).remote(
            env_id=env_id,
            env_kwargs=env_kwargs,
            num_envs=4,
            agent_creator=agent_creator,
            wrappers=[ENV_WRAPPERS[env_id]],
            measure_names=GT_MEASURES[env_id],
        )
        evaluators.append(actor)

    # Round-robin dispatch; ray.get returns results in submission order, so
    # batch k below belongs to solution k.
    futures = []
    for idx, individual in enumerate(all_solutions):
        worker = evaluators[idx % num_workers]
        futures.append(worker.evaluate_policy.remote(individual, n_evals))
    trajectory_batches = ray.get(futures)

    learned_measure_fns = {
        "AutoQD": auto_qd_measure_fn,
        "Aurora": aurora_measure_fn,
        "LSTM-Aurora": lstm_aurora_measure_fn,
    }
    collected = {"GT": [], "AutoQD": [], "Aurora": [], "LSTM-Aurora": []}
    for trajs in trajectory_batches:
        # (n_evals, meas_dim)
        collected["GT"].append(np.array([t.measures for t in trajs]))
        for label, measure_fn in learned_measure_fns.items():
            collected[label].append(measure_fn(trajs))

    return {label: np.array(batches) for label, batches in collected.items()}


##############################
##### Stability Analysis #####
##############################


def compute_policy_variances(measures):
    """
    Per-policy spread of each measure function across repeated rollouts.

    For every policy the population variance (ddof=0) is taken over the rollout
    axis for each measure dimension, and the per-dimension variances are summed
    (i.e. the trace of the covariance matrix).

    measures: dict mapping algo name to np.ndarray of shape (n_policies, k_rollouts, m_dims)
    returns: dict mapping algo name to 1D array of shape (n_policies,)
    """

    def _trace_of_covariance(samples):
        per_dim_variance = np.var(samples, axis=1, ddof=0)  # (n_policies, m_dims)
        return per_dim_variance.sum(axis=1)  # (n_policies,)

    return {algo: _trace_of_covariance(data) for algo, data in measures.items()}


def compute_summary_stats(variances):
    """
    Summarise each method's per-policy variances.

    The standard error uses the sample standard deviation (ddof=1) divided by
    the square root of the number of policies.

    variances: dict mapping method to 1D array of variances
    returns: dict mapping method to (mean, sem)
    """
    return {
        method: (vals.mean(), vals.std(ddof=1) / np.sqrt(len(vals)))
        for method, vals in variances.items()
    }


def plot_variance_distributions(variances, output_path, env_name):
    """
    Save a box + strip plot of per-policy variances with a broken y-axis.

    Two stacked panels share the x-axis: the upper one is zoomed on the largest
    values, the lower one on the low/mid range, so both stay readable.

    variances: dict mapping method to 1D array of variances
    output_path: where the figure is written (passed to plt.savefig)
    env_name: environment name shown in the figure title
    """
    sns.set_theme(style="whitegrid", palette="pastel", font_scale=1.2)

    # Flatten to long form: one (label, value) pair per policy.
    data = [value for values in variances.values() for value in values]
    labels = [
        method for method, values in variances.items() for _ in range(len(values))
    ]

    fig, (upper, lower) = plt.subplots(
        2, 1, sharex=True, figsize=(10, 6), gridspec_kw={"height_ratios": [1, 3]}
    )
    panels = (upper, lower)

    # Same data on both panels; only the y-limits differ.
    for panel in panels:
        sns.boxplot(x=labels, y=data, ax=panel, width=0.6, linewidth=1.3)

    # Individual points on top of the boxes.
    for panel in panels:
        sns.stripplot(
            x=labels, y=data, color="black", size=3, jitter=True, alpha=0.5, ax=panel
        )

    # y-limits (tune these depending on the data range)
    max_variance = max(data)
    upper.set_ylim(max_variance * 0.6, max_variance * 1.05)  # high values
    lower.set_ylim(0, np.percentile(data, 75) * 1.5)  # low/mid values

    # Remove the facing spines so the two panels read as one broken axis.
    upper.spines["bottom"].set_visible(False)
    lower.spines["top"].set_visible(False)
    upper.tick_params(labeltop=False)  # don't double-label
    lower.xaxis.tick_bottom()

    # Diagonal break marks: on the bottom edge of the upper panel and on the
    # top edge of the lower panel, at the left and right corners.
    d = 0.015
    corner_x_spans = ((-d, +d), (1 - d, 1 + d))
    edge_y_spans = ((upper, (-d, +d)), (lower, (1 - d, 1 + d)))
    for panel, y_span in edge_y_spans:
        for x_span in corner_x_spans:
            panel.plot(
                x_span, y_span, transform=panel.transAxes, color="k", clip_on=False
            )

    fig.suptitle(
        f"Embedding Variance Across Methods ({env_name})", fontsize=16, weight="bold"
    )
    lower.set_ylabel("Embedding Variance (Trace)")
    plt.tight_layout()
    plt.subplots_adjust(hspace=0.05)
    plt.savefig(output_path, dpi=300)
    plt.close()


def run_statistical_tests(variances):
    """
    Compare methods with an omnibus Friedman test and pairwise Wilcoxon tests.

    Every array in ``variances`` is treated as paired across methods (entry k
    of each array belongs to the same policy). Pairs are visited in dict order,
    each unordered pair once, and keyed as "<first> vs <second>".

    variances: dict mapping method to 1D array of variances
    returns: dict with 'friedman': (stat, p), 'pairwise': {comp: (stat, p)}
    """
    methods = list(variances)

    stat_f, p_f = friedmanchisquare(*(variances[name] for name in methods))

    pairwise = {}
    for position, first in enumerate(methods):
        # Only pair with methods that come later, so each pair appears once.
        for second in methods[position + 1 :]:
            stat_w, p_w = wilcoxon(variances[first], variances[second])
            pairwise[f"{first} vs {second}"] = (stat_w, p_w)

    return {"friedman": (stat_f, p_f), "pairwise": pairwise}


def save_results(summary, tests, path):
    """
    Write the summary table and the test results to a plain-text report.

    The file at ``path`` is overwritten. Layout: a tab-separated table of
    per-method mean variance and SEM, then the Friedman result, then one line
    per pairwise Wilcoxon comparison.

    summary: dict mapping method to (mean, sem)
    tests: dict with 'friedman': (stat, p) and 'pairwise': {comp: (stat, p)}
    path: destination file path
    """
    with open(path, "w") as report:
        report.write("Method\tMeanVariance\tSEM\n")
        report.writelines(
            f"{method}\t{mean:.6f}\t{sem:.6f}\n"
            for method, (mean, sem) in summary.items()
        )

        stat_f, p_f = tests["friedman"]
        report.writelines(
            (
                "\nFriedman test:\n",
                f"Statistic={stat_f:.6f}, p-value={p_f:.6e}\n\n",
                "Pairwise Wilcoxon signed-rank tests:\n",
            )
        )
        report.writelines(
            f"{comp}: Statistic={stat_w:.6f}, p-value={p_w:.6e}\n"
            for comp, (stat_w, p_w) in tests["pairwise"].items()
        )


def _auto_qd_raw_measures(embedding_map, measure_map, trajs):
    """Embed each trajectory, map the embeddings to measures, return as numpy."""
    # NOTE: Takes in a list of trajectories, returns a numpy array of shape n_trajs, measure_dim
    # NOTE: Only cwpca works with batches
    embeddings = embedding_map.embed_individual_trajectories(trajs)
    return measure_map(embeddings).cpu().numpy()


def _aurora_raw_measures(encoder, trajs):
    """Encode the final state of each trajectory and return the result as numpy."""
    final_states = np.stack([t.states[-1] for t in trajs])
    return encoder.encode(torch.from_numpy(final_states).float()).cpu().numpy()


def _lstm_aurora_raw_measures(encoder, trajs):
    """Encode every 10th state of each trajectory and return the result as numpy."""
    sequences = [torch.from_numpy(t.states[::10]).float() for t in trajs]
    return encoder.encode(sequences).cpu().numpy()


def _normalized_measure_fn(raw_measure_fn, lower_bounds, interval_size):
    """Return a function computing ``(raw_measure_fn(trajs) - lower_bounds) / interval_size``.

    The bounds are bound here, at creation time, so the returned function keeps
    using the archive bounds of the checkpoint it was built from.
    """

    def measure_fn(trajs):
        return (raw_measure_fn(trajs) - lower_bounds) / interval_size

    return measure_fn


# Run with python -m src.evaluation_suite.measure_stability bipedal_walker
def main():
    """Command-line entry point: evaluate measure stability for one environment.

    For each algorithm ("auto_qd", "regular_qd", "aurora", "lstm_aurora") the
    final checkpoint of the first matching run under ``./1_outputs`` is loaded,
    up to 100 solutions are sampled from its archive, and (where the algorithm
    has a learned measure) a normalized measure function is built from the
    checkpoint. All sampled solutions are then rolled out 32 times each, the
    per-policy variance of every measure is computed, and a plot plus a text
    report are written to ``measure_stability_results/``.

    Raises:
        FileNotFoundError: if no run directory exists for one of the algorithms.
    """
    env_ids = {
        "ant": "Ant-v5",
        "bipedal_walker": "BipedalWalker-v3",
        "halfcheetah": "HalfCheetah-v5",
        "hopper": "Hopper-v5",
        "swimmer": "Swimmer-v5",
        "walker2d": "Walker2d-v5",
    }
    parser = argparse.ArgumentParser(
        description="Analyze stability of policy embedding methods."
    )
    # `choices` turns an unknown environment into a usage error instead of a
    # bare KeyError traceback.
    parser.add_argument("env_name", choices=sorted(env_ids))
    args = parser.parse_args()
    env_id = env_ids[args.env_name]
    env_kwargs = (
        {"include_cfrc_ext_in_observation": False} if env_id == "Ant-v5" else {}
    )

    outputs_path = Path("./1_outputs")
    all_solutions = []
    measure_fns = {}
    agent_creator = None
    for algo_name in ["auto_qd", "regular_qd", "aurora", "lstm_aurora"]:
        pattern = f"{algo_name}_{args.env_name}_*"
        # If several runs match, the first one yielded by glob is used.
        logdir = next(outputs_path.glob(pattern), None)
        if logdir is None:
            raise FileNotFoundError(
                f"No log directory matching {pattern!r} under {outputs_path}"
            )
        with open(logdir / "checkpoints" / "final.pkl", "rb") as f:
            # NOTE: unpickling can execute arbitrary code; only load checkpoints
            # produced by trusted runs.
            ckpt = pickle.load(f)

        # Load solutions
        archive = ckpt["archive"]
        solutions: np.ndarray = archive.data("solution")  # N x sol_dim
        n_keep = min(100, solutions.shape[0])
        solutions = solutions[
            np.random.choice(solutions.shape[0], size=n_keep, replace=False)
        ]
        all_solutions.append(solutions)

        # NOTE(review): the agent config of the LAST checkpoint loaded is used to
        # rebuild every policy below; this assumes all runs share the same agent
        # architecture -- confirm.
        agent_cfg: DictConfig = ckpt["agent_cfg"]
        agent_creator = partial(agent_from_DictConfig, cfg=agent_cfg)

        # Load the measure map. Everything taken from `ckpt` is bound eagerly
        # (partial / factory arguments). Lambdas closing over loop-local names
        # would all see the values from the last iteration once they are called,
        # i.e. the wrong encoder and the wrong archive bounds.
        if algo_name == "auto_qd":
            raw_measure_fn = partial(
                _auto_qd_raw_measures, ckpt["embedding_map"], ckpt["measure_map"]
            )
        elif algo_name == "aurora":
            raw_measure_fn = partial(_aurora_raw_measures, ckpt["encoder"])
        elif algo_name == "lstm_aurora":
            raw_measure_fn = partial(_lstm_aurora_raw_measures, ckpt["encoder"])
        else:
            # regular_qd only contributes solutions; it has no learned measure.
            continue
        measure_fns[algo_name] = _normalized_measure_fn(
            raw_measure_fn, archive.lower_bounds, archive.interval_size
        )

    all_solutions = np.concatenate(all_solutions)

    n_evals = 32  # rollouts per policy
    measures = get_measures(
        env_id,
        env_kwargs,
        agent_creator,
        all_solutions,
        measure_fns["auto_qd"],
        measure_fns["aurora"],
        measure_fns["lstm_aurora"],
        n_evals,
    )

    variances = compute_policy_variances(measures)
    summary = compute_summary_stats(variances)
    os.makedirs("measure_stability_results/", exist_ok=True)
    plot_variance_distributions(
        variances,
        f"measure_stability_results/{args.env_name}_distribution.png",
        args.env_name,
    )
    tests = run_statistical_tests(variances)
    save_results(summary, tests, f"measure_stability_results/{args.env_name}.txt")


if __name__ == "__main__":
    main()
