"""
Load n BipedalWalker QD archives (run directories and algorithm names are listed in `main`).
Instantiate the env with different values of friction and mass.
Evaluate the policies on that env and select the best policy.
Log the performances and save them to a file called `adaptation_results.json`
Generate one plot per parameter (`adaptation_friction.png`, `adaptation_mass.png`) with
    the x axis being the different values of the friction/mass scale and the y axis being
    the performance of the best policy of the archive, with error bars indicating the
    standard error of its measurement (each policy is evaluated N_EVALS times), comparing
    the adaptation performance of policies found by different algorithms.

"""

from functools import partial
import json
import pickle
from typing import Callable, Dict, List, Optional
import hydra
import numpy as np
from omegaconf import DictConfig
import ray
import torch
import gymnasium as gym
from tqdm import tqdm
from src.algorithms.diayn import SACPolicy
from src.qd.wrappers.ablations import FrictionScalingWrapper, MassScalingWrapper
from src.utils import Trajectory, collect_trajectories
import matplotlib.pyplot as plt
import seaborn as sns

# Number of trajectories collected for each policy at each perturbation scale
# (passed as `n_trajectories` to `FlexiblePolicyEvaluator.evaluate_policy`).
N_EVALS = 32

#########################
##### Visualization #####
#########################


def plot_adaptation_results(results_file: str):
    """Draw one error-bar figure per perturbed parameter from a results JSON.

    The JSON is expected to map each algorithm name to a dict with the keys
    ``"friction"`` and ``"mass"``; each of those holds a list of
    ``[scale, returns]`` pairs, where ``returns`` is the list of episode
    returns measured at that scale.

    For every parameter a figure is written to ``adaptation_<param>.png`` in
    the current working directory. Each algorithm is one curve showing the mean
    return per scale, with the standard error of the mean as error bars.

    Args:
        results_file: Path to the JSON file to read.
    """
    with open(results_file, "r") as fh:
        results = json.load(fh)

    sns.set_theme(style="whitegrid", font_scale=1.2)
    palette = sns.color_palette("Set2", n_colors=len(results))

    for param in ("friction", "mass"):
        pretty_name = param.capitalize()
        plt.figure(figsize=(10, 6))

        # One colour per algorithm; the palette has exactly len(results) entries.
        for curve_color, (algo, data) in zip(palette, results.items()):
            scales, means, std_errs = [], [], []
            for scale, returns in data[param]:
                samples = np.array(returns)
                scales.append(scale)
                means.append(samples.mean())
                # Standard error of the mean (sample std with Bessel's correction).
                std_errs.append(samples.std(ddof=1) / np.sqrt(len(samples)))

            plt.errorbar(
                scales,
                means,
                yerr=std_errs,
                label=algo,
                color=curve_color,
                marker="o",
                capsize=4,
                linewidth=2,
            )

        plt.title(f"Best Policy Performance vs {pretty_name} Scale")
        plt.xlabel(f"{pretty_name} Scale")
        plt.ylabel("Return")
        plt.legend(title="Algorithm")
        plt.tight_layout()
        plt.savefig(f"adaptation_{param}.png", dpi=300)
        plt.close()


#################
##### Utils #####
#################


def agent_from_DictConfig(cfg: DictConfig):
    """Build an agent by handing ``cfg`` to ``hydra.utils.instantiate``.

    Args:
        cfg: Config describing the object to instantiate.

    Returns:
        Whatever object Hydra instantiates from ``cfg``.
    """
    agent = hydra.utils.instantiate(cfg)
    return agent


class SAC_Agent_Wrapper:
    """Adapter that drives a skill-conditioned policy via ``from_numpy``/``act``.

    The wrapped ``policy`` must provide ``sample_action(tensor)`` returning a
    3-tuple whose first element is the action tensor. The active skill is
    appended to every observation as a one-hot vector of length ``n_skills``.
    """

    def __init__(self, policy, n_skills):
        self.policy = policy
        self.n_skills = n_skills

    def from_numpy(self, skill_id):
        """Select the active skill and return ``self`` so calls can be chained.

        Args:
            skill_id: Index of the skill to activate (used to index the one-hot).
        """
        self.skill_id = skill_id
        one_hot = np.zeros(self.n_skills, dtype=np.float32)
        self.skill = one_hot
        one_hot[skill_id] = 1.0
        return self

    @torch.no_grad()
    def act(self, state: np.ndarray):
        """Sample an action for one observation or for a batch of observations.

        Args:
            state: A 1-D observation, or an array whose first axis is the batch.

        Returns:
            A NumPy array of actions: flattened for a 1-D ``state``, otherwise
            with the batch axis kept as returned by the policy.
        """
        is_single = state.ndim == 1
        if is_single:
            # Promote a lone observation to a batch of one.
            state = state[np.newaxis, :]

        n_rows = state.shape[0]
        # Same one-hot skill vector repeated for every row of the batch.
        skill_rows = np.repeat(self.skill[np.newaxis, :], n_rows, axis=0)
        policy_input = torch.FloatTensor(
            np.concatenate([state, skill_rows], axis=-1)
        )

        action, _, _ = self.policy.sample_action(policy_input)
        action_np = action.cpu().numpy()
        return action_np.flatten() if is_single else action_np


def create_sac_agent(state_dim, action_dim, hidden_dims, n_skills, state_dict):
    """Rebuild a skill-conditioned SAC policy and wrap it for evaluation.

    Args:
        state_dim: Size of the raw observation vector.
        action_dim: Size of the action vector.
        hidden_dims: Hidden layer sizes forwarded to ``SACPolicy``.
        n_skills: Number of skills; the policy input is widened by this amount
            to make room for the one-hot skill vector.
        state_dict: Weights loaded into the policy via ``load_state_dict``.

    Returns:
        A ``SAC_Agent_Wrapper`` around the loaded policy.
    """
    input_dim = state_dim + n_skills
    actor = SACPolicy(input_dim, action_dim, hidden_dims)
    actor.load_state_dict(state_dict)
    wrapped = SAC_Agent_Wrapper(actor, n_skills)
    return wrapped


@ray.remote
class FlexiblePolicyEvaluator:
    """Ray actor that owns a vector env and rolls out policies in it."""

    def __init__(
        self,
        env_id: str,
        env_kwargs: Dict,
        num_envs: int,
        agent_creator: Callable,
        wrappers: Optional[List[gym.Wrapper]] = None,
    ):
        """Create the vector environment used for every later evaluation.

        Args:
            env_id: Gymnasium environment ID.
            env_kwargs: Extra keyword arguments forwarded to ``gym.make_vec``.
            num_envs: Number of sub-environments in the vector env.
            agent_creator: Zero-argument callable returning a fresh agent that
                exposes ``from_numpy``.
            wrappers: ``None`` or a list of wrappers handed to ``gym.make_vec``;
                ``None`` (or any empty value) becomes an empty list.
        """
        # NOTE(review): presumably limits torch to one intra-op thread per
        # actor to avoid CPU oversubscription across workers — confirm.
        torch.set_num_threads(1)

        self.agent_creator = agent_creator
        self.wrappers = wrappers if wrappers else []

        vec_env = gym.make_vec(
            env_id,
            num_envs=num_envs,
            vectorization_mode=gym.VectorizeMode.ASYNC,
            wrappers=self.wrappers,
            **env_kwargs,
        )
        self.envs = vec_env

    def evaluate_policy(
        self,
        policy_params: np.ndarray,
        n_trajectories: int,
    ) -> List[Trajectory]:
        """Roll out one policy in this actor's vector env.

        Args:
            policy_params: Value passed to the new agent's ``from_numpy``.
            n_trajectories: Number of trajectories requested from
                ``collect_trajectories``.

        Returns:
            The result of ``collect_trajectories`` for the configured agent.
        """
        agent = self.agent_creator()
        configured = agent.from_numpy(policy_params)
        return collect_trajectories(self.envs, configured, n_trajectories)


def _load_solutions_and_agent_creator(run_dir: str, algo_name: str):
    """Load one run's checkpoint and return ``(solutions, agent_creator)``.

    For ``"smerl"`` the solutions are the skill indices ``0..n_skills-1`` and
    the creator rebuilds the SAC policy; for every other algorithm the
    solutions come from the pickled archive and the creator instantiates the
    stored agent config.
    """
    if algo_name == "smerl":
        # Policies are only ever fed CPU tensors and are shipped to Ray
        # workers, so load the weights on CPU even if they were saved on GPU.
        ckpt = torch.load(
            f"1_outputs/{run_dir}/checkpoints/final.pt", map_location="cpu"
        )
        agent_creator = partial(
            create_sac_agent,
            state_dim=ckpt["state_dim"],
            action_dim=ckpt["action_dim"],
            hidden_dims=ckpt["hidden_dims"],
            n_skills=ckpt["n_skills"],
            state_dict=ckpt["policy"],
        )
        return np.arange(ckpt["n_skills"]), agent_creator

    # SECURITY NOTE: pickle.load executes arbitrary code embedded in the file;
    # only load checkpoints produced by trusted runs.
    with open(f"1_outputs/{run_dir}/checkpoints/final.pkl", "rb") as f:
        ckpt = pickle.load(f)

    solutions: np.ndarray = ckpt["archive"].data("solution")  # N x sol_dim
    agent_cfg: DictConfig = ckpt["agent_cfg"]
    return solutions, partial(agent_from_DictConfig, cfg=agent_cfg)


def _evaluate_best_policy(
    solutions,
    agent_creator: Callable,
    wrapper: Callable,
    num_workers: int,
    cpus_per_worker: int,
) -> np.ndarray:
    """Evaluate every solution under ``wrapper`` and return the best one's returns.

    A fresh pool of ``num_workers`` evaluator actors is created for the call.
    Each solution is rolled out ``N_EVALS`` times; the per-episode returns of
    the solution with the highest mean return are returned.

    Raises:
        RuntimeError: If no solution produced a mean return greater than
            ``-inf`` (e.g. empty ``solutions`` or only NaN returns). The
            original code would silently reuse the previous scale's data here.
    """
    evaluators = [
        FlexiblePolicyEvaluator.options(num_cpus=cpus_per_worker).remote(
            env_id="BipedalWalker-v3",
            env_kwargs={},
            num_envs=4,
            agent_creator=agent_creator,
            wrappers=[wrapper],
        )
        for _ in range(num_workers)
    ]

    best_mean = -np.inf
    best_returns = None
    # Submit in chunks so only a bounded number of results is held at once.
    chunk_size = 10 * num_workers
    for start in range(0, len(solutions), chunk_size):
        chunk = solutions[start : start + chunk_size]
        futures = [
            evaluators[j % num_workers].evaluate_policy.remote(individual, N_EVALS)
            for j, individual in enumerate(chunk)
        ]
        for trajs in ray.get(futures):
            returns = np.array([t.rewards.sum() for t in trajs])
            mean_return = returns.mean()
            if mean_return > best_mean:
                best_mean = mean_return
                best_returns = returns.copy()

    if best_returns is None:
        raise RuntimeError(
            "No policy produced a valid mean return "
            f"({len(solutions)} solutions evaluated)"
        )
    return best_returns


def main():
    """Run the friction and mass robustness sweeps for every algorithm.

    For each hard-coded training run, every policy is evaluated at each
    friction scale and each mass scale, and the returns of the best policy per
    scale are recorded. Results are written to ``adaptation_results.json`` and
    plotted with ``plot_adaptation_results``. Ray is shut down on exit, even
    when the sweep fails.
    """
    # Setup ray multiprocessing
    if not ray.is_initialized():
        ray.init()

    try:
        cpus_per_worker = 2  # Adjust based on machine
        num_workers = int(
            max(1, 0.8 * ((ray.cluster_resources()["CPU"]) // cpus_per_worker))
        )
        print(f"Using {num_workers} workers each with {cpus_per_worker} CPUs")

        runs = [
            ("auto_qd_bipedal_walker_0411_1658", "auto_qd"),
            ("aurora_bipedal_walker_0411_1233", "aurora"),
            ("lstm_aurora_bipedal_walker_0412_2233", "lstm_aurora"),
            ("regular_qd_bipedal_walker_0413_0554", "regular_qd"),
            ("smerl_bipedal_walker_0415_1044", "smerl"),
        ]
        # (result key, wrapper class, wrapper keyword, scales to sweep)
        sweeps = [
            (
                "friction",
                FrictionScalingWrapper,
                "friction_scale",
                [0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0, 4.0, 6.0],
            ),
            ("mass", MassScalingWrapper, "mass_scale", np.linspace(0.2, 1.8, 17)),
        ]

        all_results = {}
        for run_dir, algo_name in runs:
            solutions, agent_creator = _load_solutions_and_agent_creator(
                run_dir, algo_name
            )
            all_results[algo_name] = {"friction": [], "mass": []}

            for param, wrapper_cls, scale_kwarg, scales in sweeps:
                print(f"Evaluating {algo_name} for {param} robustness")
                for scale in tqdm(scales):
                    best_returns = _evaluate_best_policy(
                        solutions,
                        agent_creator,
                        partial(wrapper_cls, **{scale_kwarg: scale}),
                        num_workers,
                        cpus_per_worker,
                    )
                    # Store plain Python values so the dict is JSON-serialisable.
                    all_results[algo_name][param].append(
                        (float(scale), best_returns.tolist())
                    )

        with open("adaptation_results.json", "w") as f:
            json.dump(all_results, f, indent=4)
        plot_adaptation_results("adaptation_results.json")
    finally:
        ray.shutdown()

# Run the full robustness sweep only when this file is executed as a script.
if __name__ == "__main__":
    main()
