import argparse
import json
import os
import pickle
from functools import partial
from pathlib import Path
from typing import Callable, Dict, List, Optional

import gymnasium as gym
import hydra
import numpy as np
import ray
import torch
import yaml
from omegaconf import DictConfig
from ribs.archives import GridArchive
from vendi_score import vendi

import os

import gc
from src.algorithms.diayn import SACPolicy
from src.embeddings.rff import RFF
from src.qd.wrappers import (
    AntBehavioralWrapper,
    BipedalBehavioralWrapper,
    HalfCheetahBehavioralWrapper,
    HopperBehavioralWrapper,
    SwimmerBehavioralWrapper,
    WalkerBehavioralWrapper,
    ANT_MEASURE_NAMES,
    SWIMMER_MEASURE_NAMES,
)

# from src.qd.wrappers.ablations import DistractorWrapper
from src.utils import Trajectory, collect_trajectories

# Ground-truth (hand-crafted) behavior measure names per environment ID.
# evaluate_population passes these to each evaluator as ``measure_names`` and
# builds one GridArchive dimension per name, each with range (0, 1) and 20 cells.
# NOTE(review): presumably every measure is a per-episode statistic in [0, 1]
# reported by the matching wrapper in ENV_WRAPPERS — confirm in src.qd.wrappers.
GT_MEASURES = {
    "BipedalWalker-v3": [
        "left_contact_freq",
        "right_contact_freq",
    ],
    "Ant-v5": ANT_MEASURE_NAMES,
    "HalfCheetah-v5": [
        "back_foot_freq",
        "front_foot_freq",
    ],
    "Hopper-v5": ["foot_contact_freq"],
    "Swimmer-v5": SWIMMER_MEASURE_NAMES,
    "Walker2d-v5": ["right_contact_freq", "left_contact_freq"],
}

# Behavioral wrapper class applied to every evaluation vector environment,
# keyed by the same environment IDs as GT_MEASURES.
ENV_WRAPPERS = {
    "BipedalWalker-v3": BipedalBehavioralWrapper,
    "Ant-v5": AntBehavioralWrapper,
    "HalfCheetah-v5": HalfCheetahBehavioralWrapper,
    "Hopper-v5": HopperBehavioralWrapper,
    "Swimmer-v5": SwimmerBehavioralWrapper,
    "Walker2d-v5": WalkerBehavioralWrapper,
}

##############################
##### Evaluation Utility #####
##############################


def agent_from_DictConfig(cfg: DictConfig):
    """Construct an agent object from its Hydra/OmegaConf config.

    The config is handed unchanged to ``hydra.utils.instantiate``, which builds
    the object it describes (normally via the config's ``_target_`` entry).
    """
    instantiate = hydra.utils.instantiate
    agent = instantiate(cfg)
    return agent


class SAC_Agent_Wrapper:
    """Expose a skill-conditioned SAC policy through the ``from_numpy``/``act`` API.

    The wrapped policy is fed the environment state concatenated with a one-hot
    vector of length ``n_skills`` that marks the active skill. ``from_numpy``
    selects that skill by index, and ``act`` samples actions for it.
    """

    def __init__(self, policy, n_skills):
        self.policy = policy
        self.n_skills = n_skills

    def from_numpy(self, skill_id):
        """Make ``skill_id`` the active skill; returns ``self`` for chaining."""
        self.skill_id, self.skill = skill_id, np.zeros(self.n_skills, dtype=np.float32)
        self.skill[self.skill_id] = 1.0
        return self

    @torch.no_grad()
    def act(self, state: np.ndarray):
        """Sample an action for ``state`` under the active skill.

        A 1-D ``state`` is treated as a single observation and yields a flat
        action array; otherwise each row is an observation and one action row
        is returned per observation.
        """
        single = state.ndim == 1
        observations = np.expand_dims(state, 0) if single else state
        rows = observations.shape[0]
        # One copy of the one-hot skill per observation row.
        skill_block = np.tile(self.skill, rows).reshape((rows, self.n_skills))

        policy_input = torch.FloatTensor(
            np.concatenate([observations, skill_block], axis=-1)
        )
        sampled, _, _ = self.policy.sample_action(policy_input)
        actions = sampled.cpu().numpy()
        return actions.flatten() if single else actions


def create_sac_agent(state_dim, action_dim, hidden_dims, n_skills, state_dict):
    """Rebuild a trained skill-conditioned SAC policy and wrap it for evaluation.

    The policy network's input is the state concatenated with a one-hot skill
    vector, so its input size is ``state_dim + n_skills``. The saved weights in
    ``state_dict`` are loaded before the policy is wrapped in SAC_Agent_Wrapper.
    """
    input_dim = state_dim + n_skills
    network = SACPolicy(input_dim, action_dim, hidden_dims)
    network.load_state_dict(state_dict)
    wrapped = SAC_Agent_Wrapper(network, n_skills)
    return wrapped


@ray.remote
class FlexiblePolicyEvaluator:
    """Ray actor that rolls out policies in its own asynchronous vector env.

    The vector environment is built once in ``__init__`` and reused by every
    ``evaluate_policy`` call made on this actor.
    """

    def __init__(
        self,
        env_id: str,
        env_kwargs: Dict,
        num_envs: int,
        agent_creator: Callable,
        wrappers: Optional[List[gym.Wrapper]] = None,
        measure_names: Optional[List[str]] = None,
    ):
        """Initialize policy evaluator.

        Args:
            env_id: Gymnasium environment ID
            env_kwargs: Extra keyword arguments forwarded to ``gym.make_vec``;
                ``None`` is treated as "no extra arguments"
            num_envs: Number of parallel environments to run
            agent_creator: Zero-argument callable returning an agent object that
                provides ``from_numpy(params)`` (see ``evaluate_policy``)
            wrappers: None or list of wrappers passed to the vector env constructor
            measure_names: None or list of handcrafted measures that will be
                extracted; measures are only collected when ``wrappers`` is
                also non-empty
        """
        # One intra-op torch thread per actor; presumably to avoid oversubscribing
        # the CPUs that each actor reserves.
        torch.set_num_threads(1)
        self.agent_creator = agent_creator
        # Copy so that appending ablation wrappers never mutates the caller's list,
        # and always honour the wrappers that were passed in.
        self.wrappers = list(wrappers) if wrappers else []
        # Same gating as before: measures are requested only when both wrappers
        # and measure names were provided.
        self.include_measures = bool(self.wrappers) and bool(measure_names)
        self.measure_names = list(measure_names) if self.include_measures else None

        # Special token indicating ablation with distractors
        # NOTE: the distractor ablation is currently disabled (see commented code),
        # so num_distractors stays 0.
        self.num_distractors = 0
        # if "@" in env_id:
        #     env_id, num_distractors = env_id.split("@")
        #     self.num_distractors = int(num_distractors)
        #     self.wrappers.append(
        #         partial(DistractorWrapper, num_extra_dims=self.num_distractors)
        #     )

        self.envs = gym.make_vec(
            env_id,
            num_envs=num_envs,
            vectorization_mode=gym.VectorizeMode.ASYNC,
            wrappers=self.wrappers,
            # YAML configs may yield None for an empty env_kwargs entry.
            **(env_kwargs or {}),
        )

    def evaluate_policy(
        self,
        policy_params: np.ndarray,
        n_trajectories: int,
    ) -> List[Trajectory]:
        """Roll out a single policy and return its trajectories.

        Args:
            policy_params: Parameters handed to the agent's ``from_numpy`` (a
                flat parameter vector, or a skill index for skill-based agents)
            n_trajectories: Number of trajectories to collect

        Returns:
            The list of trajectories from ``collect_trajectories``, with
            handcrafted measures included when this evaluator was configured
            with them.
        """
        policy = self.agent_creator().from_numpy(policy_params)
        if self.include_measures:
            trajs = collect_trajectories(
                self.envs,
                policy,
                n_trajectories,
                self.measure_names,
            )
        else:
            trajs = collect_trajectories(self.envs, policy, n_trajectories)

        if self.num_distractors > 0:
            # Drop the trailing distractor dimensions from recorded states.
            for traj in trajs:
                traj.states = traj.states[:, : -self.num_distractors]
        return trajs

    def close(self):
        """Close the vector environment and release its worker processes."""
        self.envs.close()


#############################
##### Loading From File #####
#############################


def load_or_create_embedding_map(env_id, state_dim, action_dim, allow_create=False):
    """Load the fixed RFF embedding map used to compare populations of ``env_id``.

    All populations of an environment must be embedded with the same map for
    their Vendi scores to be comparable, so by default a missing map is an
    error instead of being silently regenerated. Any ``@<suffix>`` on
    ``env_id`` is stripped before files are looked up or written.

    Args:
        env_id: Environment ID, optionally with an ``@<suffix>``.
        state_dim: State dimension; only used when a new map is created.
        action_dim: Action dimension; only used when a new map is created.
        allow_create: If True and no saved map exists, build a new 400-dim RFF
            map, set its normalizer from
            ``./evaluation_data/<env>_mean_std.pkl``, and save it to
            ``./evaluation_data/embedding_map/<env>.pkl`` for future runs.

    Returns:
        The embedding map object (loaded from disk or newly created).

    Raises:
        ValueError: If no saved map exists and ``allow_create`` is False.
    """
    actual_env_id = env_id.split("@")[0]
    embedding_dir = Path("./evaluation_data/embedding_map")
    map_path = embedding_dir / f"{actual_env_id}.pkl"

    if map_path.is_file():
        # NOTE: pickle can execute arbitrary code on load; only load maps you produced.
        with open(map_path, "rb") as f:
            embedding_map = pickle.load(f)
        print("Loaded RFF embedding map from file")
        return embedding_map

    if not allow_create:
        # Fail before doing any work or touching the filesystem.
        raise ValueError(
            f"Could not find embedding map for {actual_env_id} at {map_path}"
        )

    embedding_map = RFF(
        dim=400,
        state_dim=state_dim,
        action_dim=action_dim,
        kernel_width=None,
        normalize=True,
        gamma=0.999,
    )
    embedding_map.training = False
    with open(f"./evaluation_data/{actual_env_id}_mean_std.pkl", "rb") as f:
        stats = pickle.load(f)
    embedding_map.normalizer.mean = torch.tensor(stats["mean"], dtype=torch.float32)
    embedding_map.normalizer.std = torch.tensor(stats["std"], dtype=torch.float32)

    os.makedirs(embedding_dir, exist_ok=True)
    # Save under the stripped id so the lookup above finds it on the next run.
    with open(map_path, "wb") as f:
        pickle.dump(embedding_map, f)
    print("Initialized new embedding map")
    return embedding_map


def load_median_gamma(env_id):
    """Return the stored median RFF kernel gamma for ``env_id``.

    The value is read from ``evaluation_data/gammas.json`` at
    ``[<env>]["rff"]["median"]``, after stripping any ``@<suffix>`` from
    ``env_id``.
    """
    base_env_id, *_ = env_id.split("@")
    gammas_path = Path("evaluation_data") / "gammas.json"
    with gammas_path.open() as handle:
        table = json.load(handle)
    return table[base_env_id]["rff"]["median"]


def load_population(algo_name, path):
    """Load the final population saved by a training run.

    Args:
        algo_name: One of "auto_qd", "regular_qd", "aurora", "lstm_aurora"
            (archive-based runs) or "diayn", "smerl" (skill-based runs).
        path: The run's log directory (str or os.PathLike); checkpoints are
            read from ``<path>/checkpoints``.

    Returns:
        ``(solutions, agent_creator)``. For archive-based runs, ``solutions`` are
        the archive's solution vectors (N x sol_dim) and ``agent_creator``
        instantiates the saved agent config. For skill-based runs,
        ``solutions`` are the skill indices ``0..n_skills-1`` and
        ``agent_creator`` rebuilds the trained SAC policy.

    Raises:
        ValueError: If ``algo_name`` is not recognized.
    """
    if algo_name in ["auto_qd", "regular_qd", "aurora", "lstm_aurora"]:
        ckpt_path = Path(path) / "checkpoints" / "final.pkl"
        # NOTE: pickle can execute arbitrary code on load; only load trusted runs.
        with open(ckpt_path, "rb") as f:
            ckpt = pickle.load(f)
        archive = ckpt["archive"]
        solutions: np.ndarray = archive.data("solution")  # N x sol_dim
        agent_cfg: DictConfig = ckpt["agent_cfg"]
        agent_creator = partial(agent_from_DictConfig, cfg=agent_cfg)
        return solutions, agent_creator

    if algo_name in ["diayn", "smerl"]:
        ckpt_path = Path(path) / "checkpoints" / "final.pt"
        # Map tensors to CPU: rollouts happen on CPU Ray workers, and a checkpoint
        # saved from a GPU run would otherwise fail to load on CPU-only hosts.
        ckpt = torch.load(ckpt_path, map_location="cpu")
        agent_creator = partial(
            create_sac_agent,
            state_dim=ckpt["state_dim"],
            action_dim=ckpt["action_dim"],
            hidden_dims=ckpt["hidden_dims"],
            n_skills=ckpt["n_skills"],
            state_dict=ckpt["policy"],
        )
        solutions = np.arange(ckpt["n_skills"])
        return solutions, agent_creator

    raise ValueError(f"algo_name {algo_name} not recognized")


###################################
##### Main Evaluation Methods #####
###################################


def chunked_rbf_kernel_gpu(X: np.ndarray, gamma: float, chunk_size=1024):
    """Compute the dense RBF kernel ``K[a, b] = exp(-gamma * ||X[a] - X[b]||^2)``.

    Squared distances use the expansion ``|x-y|^2 = |x|^2 + |y|^2 - 2<x, y>``
    and are evaluated tile by tile (``chunk_size`` rows x ``chunk_size``
    columns). The work runs on CUDA when available and on CPU otherwise. The
    result is returned as a symmetric float32 NumPy array of shape (n, n).
    """
    num_points = X.shape[0]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    points = torch.as_tensor(X, device=device, dtype=torch.float32)
    kernel = torch.empty((num_points, num_points), device=device, dtype=torch.float32)
    sq_norms = torch.sum(points**2, dim=1)

    spans = [
        (lo, min(lo + chunk_size, num_points))
        for lo in range(0, num_points, chunk_size)
    ]
    for row_lo, row_hi in spans:
        rows = points[row_lo:row_hi]
        for col_lo, col_hi in spans:
            cols = points[col_lo:col_hi]
            gram = torch.mm(rows, cols.t())
            row_norms = sq_norms[row_lo:row_hi].unsqueeze(1)
            col_norms = sq_norms[col_lo:col_hi].unsqueeze(0)
            # Floating-point cancellation can make the expansion slightly
            # negative; clamp to keep distances valid.
            dist2 = torch.clamp(row_norms + col_norms - 2 * gram, min=0.0)
            kernel[row_lo:row_hi, col_lo:col_hi] = torch.exp(-gamma * dist2)

    # Average with the transpose so the matrix is exactly symmetric.
    kernel = 0.5 * (kernel + kernel.T)
    return kernel.cpu().numpy()


def evaluate_population(
    population: List,
    agent_creator: Callable,
    env_id: str,
    env_kwargs: Dict,
    n_evals: int,
    min_score: float,  # For computing QD score
    state_dim: int,
    action_dim: int,
):
    """
    Given a population, we do two types of analysis to determine the quality and diversity
    of them collectively.

    1. Ground Truth Measures: Create an archive using known ground truth measures of diversity
        and put the policies in that archive. By reporting the QD score, coverage, mean
        and max objectives we understand how good the archive is based on human notions
        of diversity.

    2. Occupancy Measure Based Analysis: We use embeddings similar to those employed by
        AutoQD to obtain an embedding for each policy. The same embedding map is used to
        compare all populations (and it is different from the one used by AutoQD to train).
        We use the Vendi score to evaluate each population based on the diversity of
        their embeddings.

    Args:
        population: Sequence of per-policy parameters, each handed to the agent's
            ``from_numpy``.
        agent_creator: Zero-argument callable that builds an agent.
        env_id: Environment ID; must be a key of GT_MEASURES and ENV_WRAPPERS.
        env_kwargs: Extra keyword arguments for the evaluation environments.
        n_evals: Number of trajectories collected per policy.
        min_score: QD-score offset for the ground-truth archive.
        state_dim: State dimension, used if an embedding map must be created.
        action_dim: Action dimension, used if an embedding map must be created.

    Returns:
        Dict of scalar metrics (population size, objective statistics,
        ground-truth archive statistics and the median-gamma Vendi score).

    Raises:
        ValueError: If ``population`` is empty.
    """
    if len(population) == 0:
        raise ValueError("Cannot evaluate an empty population")
    print(f"Evaluating population of size: {len(population)}")

    measure_names = GT_MEASURES[env_id]
    n_measures = len(measure_names)

    # Load the fixed evaluation artifacts before any rollouts, so a missing file
    # fails immediately instead of after the expensive evaluation.
    embedding_map = load_or_create_embedding_map(env_id, state_dim, action_dim)
    median_gamma = load_median_gamma(env_id)

    if not ray.is_initialized():
        ray.init()
    cpus_per_worker = 2
    num_workers = int(max(1, (ray.cluster_resources()["CPU"]) // cpus_per_worker))
    print(f"Using {num_workers} workers each with {cpus_per_worker} CPUs")

    evaluators = [
        FlexiblePolicyEvaluator.options(num_cpus=cpus_per_worker).remote(
            env_id=env_id,
            env_kwargs=env_kwargs,
            num_envs=4,
            agent_creator=agent_creator,
            wrappers=[ENV_WRAPPERS[env_id]],
            measure_names=measure_names,
        )
        for _ in range(num_workers)
    ]

    objectives, measures, embeddings = [], [], []
    # Dispatch up to 10 policies per worker per round, round-robin over workers.
    batch_size = 10 * num_workers
    for start in range(0, len(population), batch_size):
        batch = population[start : start + batch_size]
        futures = [
            evaluators[j % num_workers].evaluate_policy.remote(individual, n_evals)
            for j, individual in enumerate(batch)
        ]
        batch_trajectories = ray.get(futures)

        for trajs in batch_trajectories:
            objectives.append(np.mean([t.rewards.sum() for t in trajs]))
            measures.append(np.array([t.measures for t in trajs]).mean(0))
            embeddings.append(embedding_map.embed_trajectories(trajs))

        # call garbage collector to reclaim memory
        gc.collect()

    # Drop the actor handles: Ray terminates an actor once no handle refers to it,
    # which frees the workers' CPUs and environments before the kernel phase.
    del evaluators

    objectives = np.array(objectives)
    measures = np.array(measures)
    embeddings = np.array(embeddings)
    max_objective = objectives.max()

    # 1. Archive over ground-truth measures: 20 cells per measure on [0, 1].
    archive = GridArchive(
        solution_dim=1,
        dims=[20] * n_measures,
        ranges=((0, 1),) * n_measures,
        qd_score_offset=min_score,
    )
    # Only objectives and measures matter here, so dummy 1-D solutions suffice.
    archive.add(np.zeros((len(objectives), 1)), objectives, measures)
    gt_size = len(archive)
    gt_coverage = archive.stats.coverage
    gt_qd_score = archive.stats.qd_score
    gt_mean_objective = archive.stats.obj_mean

    # 2. Vendi score of the RFF embeddings under the stored median gamma.
    sym_matrix = chunked_rbf_kernel_gpu(embeddings, gamma=median_gamma)
    median_vendi_score = float(vendi.score_K(sym_matrix))  # Vendi score
    del sym_matrix

    metrics = {
        "population_size": len(population),
        "mean_objective": float(objectives.mean()),
        "min_objective": float(objectives.min()),
        "max_objective": float(max_objective),
        "gt_coverage": float(gt_coverage),  # Diversity
        "gt_size": int(gt_size),  # Diversity
        "gt_qd_score": float(gt_qd_score),  # Quality and Diversity
        "gt_mean_objective": float(gt_mean_objective),  # Quality
        "gt_max_objective": float(max_objective),  # Quality
        "median_vendi_score": median_vendi_score,
    }
    return metrics


def main():
    """Command-line entry point: evaluate the final population of a training run.

    Reads ``<path>/config.yaml`` for the algorithm and environment settings,
    loads the population with ``load_population``, evaluates it with
    ``evaluate_population``, and writes the metrics to
    ``<path>/evaluation_results.json``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("path", help="Path to the log directory of the run to evaluate")
    parser.add_argument(
        "--n-evals",
        type=int,
        default=32,
        help="Number of evaluation trajectories per policy (default: 32)",
    )
    args = parser.parse_args()

    logdir = args.path

    with open(os.path.join(logdir, "config.yaml"), "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    algo_name = config["algorithm"]["algo_name"]
    env_cfg = config["env"]
    env_id = env_cfg["env_id"]
    # An empty ``env_kwargs:`` entry in YAML loads as None; normalize it so it
    # can be **-unpacked downstream.
    env_kwargs = env_cfg.get("env_kwargs") or {}
    state_dim, action_dim = env_cfg["state_dim"], env_cfg["action_dim"]
    min_score = env_cfg["min_score"]
    population, agent_creator = load_population(algo_name, logdir)

    metrics = evaluate_population(
        population,
        agent_creator,
        env_id,
        env_kwargs,
        args.n_evals,
        min_score,
        state_dim,
        action_dim,
    )

    results_path = os.path.join(logdir, "evaluation_results.json")
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=4)


if __name__ == "__main__":
    main()
