import json
import os
import pickle
from functools import partial
from pathlib import Path
from typing import Callable, Dict, List, Optional

import gymnasium as gym
import hydra
import numpy as np
import ray
import torch
import yaml
from omegaconf import DictConfig
from src.algorithms.diayn import SACPolicy
from src.embeddings.rff import RFF

from src.utils import Trajectory, collect_trajectories


def agent_from_DictConfig(cfg: DictConfig):
    """Build an agent from its Hydra configuration.

    Args:
        cfg: Config node handed unchanged to ``hydra.utils.instantiate``.

    Returns:
        The object produced by ``hydra.utils.instantiate(cfg)``.
    """
    agent = hydra.utils.instantiate(cfg)
    return agent


class SAC_Agent_Wrapper:
    """Adapter that queries a skill-conditioned policy with one fixed skill.

    ``from_numpy`` selects the active skill; ``act`` appends that skill's
    one-hot vector to every state before calling ``policy.sample_action``.
    """

    def __init__(self, policy, n_skills):
        """Store the wrapped policy and the number of available skills."""
        self.n_skills = n_skills
        self.policy = policy

    def from_numpy(self, skill_id):
        """Activate ``skill_id`` and return ``self`` so calls can be chained.

        Args:
            skill_id: Index into the one-hot skill vector of length
                ``n_skills``.
        """
        self.skill_id = skill_id
        self.skill = np.zeros(self.n_skills, dtype=np.float32)
        self.skill[skill_id] = 1.0
        return self

    @torch.no_grad()
    def act(self, state: np.ndarray):
        """Sample actions for ``state`` under the currently active skill.

        Args:
            state: Either a single 1-D state or a 2-D batch of states.

        Returns:
            A NumPy array of actions: flattened for a single 1-D state,
            otherwise exactly as returned by the policy for the batch.
        """
        is_single = state.ndim == 1
        if is_single:
            # Promote a lone state to a batch of one row.
            state = state[np.newaxis, ...]

        n_rows = state.shape[0]
        # One copy of the one-hot skill per state row.
        skill_rows = np.tile(self.skill, (n_rows, 1))
        policy_input = torch.FloatTensor(
            np.concatenate([state, skill_rows], axis=-1)
        )

        action, _, _ = self.policy.sample_action(policy_input)
        actions = action.cpu().numpy()
        return actions.flatten() if is_single else actions


def create_sac_agent(state_dim, action_dim, hidden_dims, n_skills, state_dict):
    """Rebuild a skill-conditioned SAC policy and wrap it for evaluation.

    Args:
        state_dim: Size of the raw state; the policy input is
            ``state_dim + n_skills`` wide.
        action_dim: Passed to ``SACPolicy`` as its second argument.
        hidden_dims: Passed to ``SACPolicy`` as its third argument.
        n_skills: Number of skills (length of the one-hot skill vector).
        state_dict: Parameters loaded into the policy via ``load_state_dict``.

    Returns:
        A ``SAC_Agent_Wrapper`` around the restored policy.
    """
    policy_input_dim = state_dim + n_skills
    restored_policy = SACPolicy(policy_input_dim, action_dim, hidden_dims)
    restored_policy.load_state_dict(state_dict)
    wrapped_agent = SAC_Agent_Wrapper(restored_policy, n_skills)
    return wrapped_agent


@ray.remote
class FlexiblePolicyEvaluator:
    """Ray actor that rolls out policies in its own vectorized environment."""

    def __init__(
        self,
        env_id: str,
        env_kwargs: Dict,
        num_envs: int,
        agent_creator: Callable,
        wrappers: Optional[List[gym.Wrapper]] = None,
        measure_names: Optional[List[str]] = None,
    ):
        """Create the vector environment used for every evaluation.

        Args:
            env_id: Gymnasium environment ID.
            env_kwargs: Extra keyword arguments forwarded to ``gym.make_vec``.
            num_envs: Number of parallel environments to run.
            agent_creator: Zero-argument callable returning a fresh agent
                that exposes ``from_numpy``.
            wrappers: None or list of wrappers passed to the vector env
                constructor.
            measure_names: None or list of handcrafted measures that will be
                extracted.
        """
        # Restrict torch to a single thread inside this actor process.
        torch.set_num_threads(1)
        self.agent_creator = agent_creator

        # Measures are extracted only when BOTH wrappers and measure names are
        # supplied; in every other case the wrappers are not used at all.
        use_measures = bool(wrappers and measure_names)
        self.include_measures = use_measures
        self.measure_names = measure_names if use_measures else None
        self.wrappers = wrappers if use_measures else []

        self.envs = gym.make_vec(
            env_id,
            num_envs=num_envs,
            vectorization_mode=gym.VectorizeMode.ASYNC,
            wrappers=self.wrappers,
            **env_kwargs,
        )

    def evaluate_policy(
        self,
        policy_params: np.ndarray,
        n_trajectories: int,
    ) -> List[Trajectory]:
        """Roll out one policy and return its trajectories.

        Args:
            policy_params: Passed to ``from_numpy`` of a freshly created agent.
            n_trajectories: Number of trajectories requested from
                ``collect_trajectories``.

        Returns:
            The result of ``collect_trajectories`` for this policy.
        """
        policy = self.agent_creator().from_numpy(policy_params)
        # The measure names are appended as a fourth positional argument only
        # when measure extraction is enabled.
        measure_args = (self.measure_names,) if self.include_measures else ()
        return collect_trajectories(
            self.envs, policy, n_trajectories, *measure_args
        )


def load_or_create_embedding_map(env_id, state_dim, action_dim):
    """Return the RFF embedding map for ``env_id``, creating it if needed.

    A pickled map at ``./evaluation_data/embedding_map/<env_id>.pkl`` is
    reused when present. Otherwise a new ``RFF`` map is built, its normalizer
    statistics are read from ``./evaluation_data/<env_id>_mean_std.pkl``, and
    the result is pickled to the cache path for later runs.

    Args:
        env_id: Environment ID used to name the cache and statistics files.
        state_dim: Forwarded to ``RFF`` when a new map is built.
        action_dim: Forwarded to ``RFF`` when a new map is built.

    Returns:
        The loaded or newly created embedding map.
    """
    cache_dir = Path("./evaluation_data/embedding_map")
    cache_file = cache_dir / f"{env_id}.pkl"

    if cache_file.is_file():
        with open(cache_file, "rb") as cached:
            embedding_map = pickle.load(cached)
        print("Loaded RFF embedding map from file")
        return embedding_map

    os.makedirs(cache_dir, exist_ok=True)
    embedding_map = RFF(
        dim=400,
        state_dim=state_dim,
        action_dim=action_dim,
        kernel_width=None,
        normalize=True,
        gamma=0.999,
    )
    embedding_map.training = False

    # Install the stored normalization statistics on the new map.
    with open(f"./evaluation_data/{env_id}_mean_std.pkl", "rb") as stats_file:
        stats = pickle.load(stats_file)
        normalizer = embedding_map.normalizer
        normalizer.mean = torch.tensor(stats["mean"], dtype=torch.float32)
        normalizer.std = torch.tensor(stats["std"], dtype=torch.float32)

    with open(cache_file, "wb") as cached:
        pickle.dump(embedding_map, cached)
    print("Initialized new embedding map")
    return embedding_map


def load_population(algo_name, path):
    """Load the final population of a run together with an agent factory.

    Args:
        algo_name: Algorithm identifier. ``"auto_qd"``, ``"regular_qd"``,
            ``"aurora"`` and ``"lstm_aurora"`` read
            ``<path>/checkpoints/final.pkl``; ``"diayn"`` and ``"smerl"``
            read ``<path>/checkpoints/final.pt``.
        path: Run directory, as a string or ``pathlib.Path``.

    Returns:
        A ``(solutions, agent_creator)`` tuple. For the archive-based
        algorithms ``solutions`` is a random subset (at most 1000 rows, drawn
        without replacement) of the archive's ``"solution"`` data. For the
        skill-based algorithms it is ``np.arange(n_skills)``.
        ``agent_creator`` is a zero-argument callable building a fresh agent.

    Raises:
        ValueError: If ``algo_name`` is not one of the names listed above.
    """
    archive_algos = ("auto_qd", "regular_qd", "aurora", "lstm_aurora")
    skill_algos = ("diayn", "smerl")
    if algo_name not in archive_algos and algo_name not in skill_algos:
        raise ValueError(f"algo_name {algo_name} not recognized")

    # Path() accepts both str and Path, unlike string concatenation.
    checkpoint_dir = Path(path) / "checkpoints"

    if algo_name in archive_algos:
        # NOTE: pickle executes arbitrary code on load; only point this at
        # checkpoints produced by trusted runs.
        with open(checkpoint_dir / "final.pkl", "rb") as f:
            ckpt = pickle.load(f)
        archive = ckpt["archive"]
        solutions: np.ndarray = archive.data("solution")  # N x sol_dim
        n_total = solutions.shape[0]
        # Subsample so that large archives do not dominate evaluation time.
        keep = np.random.choice(n_total, size=min(1000, n_total), replace=False)
        solutions = solutions[keep]
        agent_creator = partial(agent_from_DictConfig, cfg=ckpt["agent_cfg"])
        return solutions, agent_creator

    # Map tensors to CPU: the wrapped policy is queried with CPU tensors, and
    # this lets GPU-trained checkpoints load on hosts without CUDA.
    ckpt = torch.load(checkpoint_dir / "final.pt", map_location="cpu")
    agent_creator = partial(
        create_sac_agent,
        state_dim=ckpt["state_dim"],
        action_dim=ckpt["action_dim"],
        hidden_dims=ckpt["hidden_dims"],
        n_skills=ckpt["n_skills"],
        state_dict=ckpt["policy"],
    )
    solutions = np.arange(ckpt["n_skills"])
    return solutions, agent_creator


def get_embeddings(
    population: List,
    agent_creator: Callable,
    states: np.ndarray,
    env_id: str,
    env_kwargs: Dict,
    n_evals: int,
    state_dim: int,
    action_dim: int,
):
    """Compute trajectory (RFF) and action embeddings for a population.

    Each individual is rolled out by a pool of ``FlexiblePolicyEvaluator``
    Ray actors and its trajectories are embedded with the environment's
    embedding map. Separately, each individual's actions on the fixed
    ``states`` are flattened into an action embedding.

    Args:
        population: Individuals; each is passed to ``from_numpy`` of an agent.
        agent_creator: Zero-argument callable returning a fresh agent.
        states: States on which every individual is queried via ``act``.
        env_id: Gymnasium environment ID used for the rollouts.
        env_kwargs: Keyword arguments forwarded to the evaluators.
        n_evals: Number of trajectories collected per individual.
        state_dim: Forwarded to ``load_or_create_embedding_map``.
        action_dim: Forwarded to ``load_or_create_embedding_map``.

    Returns:
        ``(rff_embeddings, act_embeddings)``: a list with one trajectory
        embedding per individual, and a NumPy array stacking one flattened
        action vector per individual.
    """
    print(f"Evaluating population of size: {len(population)}")

    if not ray.is_initialized():
        ray.init()
    cpus_per_worker = 2
    cluster_cpus = ray.cluster_resources().get("CPU", 0)
    # Never start more actors than there are individuals to evaluate: the
    # surplus actors would each spawn environments and then sit idle.
    num_workers = int(
        max(1, min(cluster_cpus // cpus_per_worker, len(population)))
    )
    print(f"Using {num_workers} workers each with {cpus_per_worker} CPUs")

    evaluators = [
        FlexiblePolicyEvaluator.options(num_cpus=cpus_per_worker).remote(
            env_id=env_id,
            env_kwargs=env_kwargs,
            num_envs=4,
            agent_creator=agent_creator,
        )
        for _ in range(num_workers)
    ]

    embedding_map = load_or_create_embedding_map(env_id, state_dim, action_dim)

    # Distribute individuals over the evaluators round-robin.
    futures = [
        evaluators[j % num_workers].evaluate_policy.remote(individual, n_evals)
        for j, individual in enumerate(population)
    ]
    batch_trajectories = ray.get(futures)
    rff_embeddings = [
        embedding_map.embed_trajectories(trajs) for trajs in batch_trajectories
    ]

    # action-embeddings
    agent = agent_creator()
    act_embeddings = []
    for individual in population:
        # Use the object returned by from_numpy, exactly as the evaluators do,
        # so agents whose from_numpy returns a new instance are handled too.
        policy = agent.from_numpy(individual)
        act_embeddings.append(policy.act(states).flatten())
    act_embeddings = np.array(act_embeddings)
    return rff_embeddings, act_embeddings


def _pairwise_sq_dists(embs):
    """Return squared Euclidean distances between all pairs of rows of ``embs``.

    Distances are produced one row at a time, so the largest temporary has
    the shape of ``embs`` instead of the full ``(N, N, ...)`` difference
    tensor. ``embs`` must have at least one row and at least two axes.
    """
    n_rows = embs.shape[0]
    first_row = np.sum((embs[0] - embs) ** 2, axis=-1)
    sq_dists = np.empty((n_rows,) + first_row.shape, dtype=first_row.dtype)
    sq_dists[0] = first_row
    for i in range(1, n_rows):
        sq_dists[i] = np.sum((embs[i] - embs) ** 2, axis=-1)
    return sq_dists


def _gamma_or_none(scale):
    """Return ``log(2) / scale`` as a float, or ``None`` if it is undefined."""
    scale = float(scale)
    if not np.isfinite(scale) or scale <= 0.0:
        return None
    gamma = float(np.log(2) / scale)
    return gamma if np.isfinite(gamma) else None


def _kernel_gammas(embs):
    """Compute the median- and mean-based gammas for a list of embeddings."""
    embs = np.array(embs)
    # Fewer than two embedding vectors give no meaningful pairwise distance.
    if embs.ndim < 2 or embs.shape[0] < 2:
        return {"median": None, "mean": None}
    # The statistics are taken over the full matrix, including its zero
    # diagonal and both (i, j) and (j, i) entries.
    sq_dists = _pairwise_sq_dists(embs)
    return {
        "median": _gamma_or_none(np.median(sq_dists)),
        "mean": _gamma_or_none(np.mean(sq_dists)),
    }


def main():
    """Compute kernel gammas per environment from all runs in ``./outputs``.

    Every run directory whose name starts with a letter and whose first
    ``_``-separated token is not ``"dvd"`` is evaluated: its population is
    loaded and embedded, and the embeddings are pooled per environment. For
    each environment and embedding type, ``gamma = log(2) / d`` is computed
    with ``d`` the median and the mean pairwise squared distance. Results are
    written to ``evaluation_data/gammas.json`` and pretty-printed. A gamma
    that cannot be computed (fewer than two embeddings, or a zero or
    non-finite distance statistic) is stored as ``null`` so the file stays
    valid JSON.
    """
    from pprint import pprint

    embeddings = {
        "BipedalWalker-v3": {"rff": [], "act": []},
        "Ant-v5": {"rff": [], "act": []},
        "HalfCheetah-v5": {"rff": [], "act": []},
        "Hopper-v5": {"rff": [], "act": []},
        "Swimmer-v5": {"rff": [], "act": []},
        "Walker2d-v5": {"rff": [], "act": []},
    }
    outputs_path = Path("./outputs")
    for logdir_path in outputs_path.iterdir():
        if not logdir_path.is_dir():
            continue
        name = logdir_path.name
        if not name[0].isalpha() or name.split("_")[0] == "dvd":
            continue

        print(f"Evaluating {name}")
        with open(f"outputs/{name}/config.yaml", "r") as file:
            config = yaml.safe_load(file)

        algo_name = config["algorithm"]["algo_name"]
        env_id = config["env"]["env_id"]
        env_kwargs = config["env"].get("env_kwargs", {})
        state_dim = config["env"]["state_dim"]
        action_dim = config["env"]["action_dim"]

        # NOTE: pickle executes arbitrary code on load; the evaluation data
        # must come from a trusted source.
        with open(f"./evaluation_data/{env_id}.pkl", "rb") as f:
            states = pickle.load(f)  # np array of shape (n, state_dim)
        population, agent_creator = load_population(algo_name, f"outputs/{name}")

        rff_embs, act_embs = get_embeddings(
            population,
            agent_creator,
            states,
            env_id,
            env_kwargs,
            32,  # n_evals
            state_dim,
            action_dim,
        )
        # Environments missing from the table above get their own entry
        # instead of aborting the whole run with a KeyError.
        env_embeddings = embeddings.setdefault(env_id, {"rff": [], "act": []})
        env_embeddings["rff"].extend(rff_embs)
        env_embeddings["act"].extend(act_embs)

    gammas = {
        env_id: {
            embedding_type: _kernel_gammas(embs)
            for embedding_type, embs in all_embs.items()
        }
        for env_id, all_embs in embeddings.items()
    }

    Path("evaluation_data").mkdir(parents=True, exist_ok=True)
    with open("evaluation_data/gammas.json", "w") as f:
        json.dump(gammas, f, indent=4)

    pprint(gammas)


# Run the evaluation only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
