import logging
import pickle
from pathlib import Path

import gymnasium as gym
import hydra
import numpy as np
import torch
from gymnasium.wrappers import RecordVideo
from omegaconf import DictConfig


def load_checkpoint(checkpoint_path):
    """Read a pickled checkpoint from disk and return the unpickled object.

    Args:
        checkpoint_path: Location of the checkpoint file; anything accepted
            by the built-in ``open`` works.

    Returns:
        Whatever object was pickled into the file. The caller in this file
        indexes it with the ``"archive"`` and ``"agent_cfg"`` keys.
    """
    # SECURITY: unpickling executes arbitrary code embedded in the file, so
    # only point this at checkpoints that come from a trusted source.
    with open(checkpoint_path, mode="rb") as stream:
        return pickle.load(stream)


def select_diverse_policies(archive, num_policies):
    """
    Select a diverse set of high-quality policies from the archive.

    The best elite is always returned first. The remaining slots are filled
    by greedy farthest-point sampling in normalized measure space, restricted
    to the top 50% of archive entries by objective: each step picks the
    candidate whose distance to its nearest already-selected policy is
    largest.

    Args:
        archive: The QD archive. It must expose ``best_elite`` (a dict with
            ``"solution"``, ``"objective"`` and ``"measures"``) and a
            ``data()`` method returning a dict with the same keys holding
            per-elite sequences (presumably a pyribs archive -- confirm
            against the code that writes the checkpoint).
        num_policies: Number of policies to select (including the best elite)

    Returns:
        A list of elite dictionaries representing diverse, high-quality
        policies. It holds at most ``max(num_policies, 1)`` entries, and is
        empty when the archive has no elites and reports no best elite.
    """
    best_elite = archive.best_elite

    # Get all archive data
    archive_data = archive.data()
    measures = np.asarray(archive_data["measures"])
    objectives = np.asarray(archive_data["objective"])

    if len(measures) == 0:
        logging.warning("Archive is empty!")
        # Never hand a ``None`` elite to the caller: it would only crash
        # later when the elite's fields are read.
        return [] if best_elite is None else [best_elite]

    selected_policies = [best_elite]

    # Sort indices by objective values (assuming higher is better) and keep
    # the top 50% as candidates.
    sorted_indices = np.argsort(objectives)[::-1]
    num_top_policies = max(1, len(sorted_indices) // 2)
    top_indices = sorted_indices[:num_top_policies]
    top_objectives = objectives[top_indices]
    top_measures = measures[top_indices]

    # 2-D view used only for distance computations, so that scalar (1-D)
    # measures are handled the same way as vector measures.
    points = top_measures.reshape(num_top_policies, -1)
    best_point = np.asarray(best_elite["measures"]).reshape(-1)

    # Per-dimension range for normalizing distances. The best elite is
    # included so the scale stays meaningful even if it is not among the
    # candidates; when it is, this equals the range of the candidates alone.
    span = np.vstack([points, best_point])
    measure_range = np.maximum(
        span.max(axis=0) - span.min(axis=0), 1e-6
    )  # Avoid division by zero

    # Mark the candidate that *is* the best elite (if present) as already
    # taken so it is not selected a second time.
    taken = np.zeros(num_top_policies, dtype=bool)
    is_best = (top_objectives == best_elite["objective"]) & np.all(
        points == best_point, axis=1
    )
    if is_best.any():
        taken[int(np.argmax(is_best))] = True

    # Running distance from every candidate to its nearest selected policy.
    # Seeding from the best elite's own measures keeps the selection diverse
    # with respect to it even when it is not one of the candidates.
    min_dist = np.linalg.norm((points - best_point) / measure_range, axis=1)

    # Slots left after the best elite, capped by the remaining candidates.
    num_extra = min(num_policies - 1, int((~taken).sum()))
    for _ in range(num_extra):
        # Taken candidates get -inf so argmax only considers free ones;
        # argmax returns the first maximum, i.e. ties go to the candidate
        # with the higher objective.
        farthest_idx = int(np.argmax(np.where(taken, -np.inf, min_dist)))
        taken[farthest_idx] = True

        selected_policies.append(
            {
                "solution": archive_data["solution"][top_indices[farthest_idx]],
                "objective": top_objectives[farthest_idx],
                "measures": top_measures[farthest_idx],
            }
        )

        # Fold the newly selected policy into the nearest-selected distances.
        dist_to_new = np.linalg.norm(
            (points - points[farthest_idx]) / measure_range, axis=1
        )
        min_dist = np.minimum(min_dist, dist_to_new)

    return selected_policies


def record_policy_videos(env_id, policy, video_path, num_episodes=1):
    """Record videos of the policy's behavior.

    Args:
        env_id: Gymnasium environment id passed to ``gym.make``.
        policy: Object exposing ``act(obs)``; its return value is passed
            unchanged to ``env.step`` (assumes ``act`` returns something the
            environment accepts -- confirm for GPU-resident policies).
        video_path: Folder the videos are written to.
        num_episodes: Number of episodes to roll out; every one is recorded.
    """
    env = gym.make(env_id, render_mode="rgb_array")
    try:
        env = RecordVideo(
            env,
            video_folder=str(video_path),
            name_prefix="policy",
            # Record every episode rather than the wrapper's default schedule.
            episode_trigger=lambda _episode_idx: True,
        )

        for _ in range(num_episodes):
            obs, _ = env.reset()
            done = False

            while not done:
                # Inference only: no autograd graph is needed for rollouts.
                with torch.no_grad():
                    action = policy.act(obs)
                obs, _, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
    finally:
        # Always close so the recorder and the renderer are released even
        # when the rollout raises part-way.
        env.close()


@hydra.main(config_path="../conf", config_name="viz", version_base=None)
def main(cfg: DictConfig):
    """Record rollout videos for a handful of diverse elites from a checkpoint.

    Reads ``cfg.checkpoint_path``, ``cfg.env_id`` and
    ``cfg.episodes_per_policy``. Videos are written beneath a ``videos``
    directory next to the checkpoint file, one ``policy_<label>`` sub-folder
    per selected elite, where the label is ``best`` for the first elite and
    ``corner_<i>`` for the others.
    """
    # Pick the compute device and configure logging.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.basicConfig(level=logging.INFO)

    # Restore the checkpoint and choose which elites to visualize.
    ckpt_file = Path(cfg.checkpoint_path)
    ckpt = load_checkpoint(ckpt_file)
    elites = select_diverse_policies(ckpt["archive"], num_policies=5)

    # All recordings live next to the checkpoint.
    videos_root = ckpt_file.parent / "videos"
    videos_root.mkdir(parents=True, exist_ok=True)

    for rank, elite in enumerate(elites):
        if rank == 0:
            label = "best"
        else:
            label = f"corner_{rank}"

        objective = elite["objective"]
        elite_measures = elite["measures"]
        logging.info(
            f"Recording {label} policy "
            f"(Objective: {objective:.3f}, "
            f"Measures: {elite_measures})"
        )

        # Build a fresh agent from the stored config, then load the elite's
        # parameters into it and move it to the chosen device.
        agent = hydra.utils.instantiate(ckpt["agent_cfg"])
        policy = agent.from_numpy(elite["solution"]).to(device)

        out_dir = videos_root / f"policy_{label}"
        record_policy_videos(
            cfg.env_id,
            policy,
            out_dir,
            num_episodes=cfg.episodes_per_policy,
        )


# Run only when executed as a script; main() is wrapped by @hydra.main, so it
# is called without arguments here and receives `cfg` from the decorator.
if __name__ == "__main__":
    main()
