from dataclasses import dataclass
from typing import Dict, List, Optional

import gymnasium as gym
import hydra
import matplotlib.pyplot as plt
import numpy as np
import ray
import torch
import umap
from omegaconf import DictConfig


@dataclass
class Trajectory:
    """A single completed episode.

    ``states``, ``actions`` and ``rewards`` hold one entry per recorded step,
    stacked along the first axis. ``measures`` is an optional per-episode
    vector and stays ``None`` when measures are not collected.
    """

    # Observations seen at each recorded step.
    states: np.ndarray
    # Actions taken at each recorded step.
    actions: np.ndarray
    # Rewards received at each recorded step.
    rewards: np.ndarray
    # Per-episode measure values, or None when no measures were requested.
    measures: Optional[np.ndarray] = None


@ray.remote  # (num_gpus=0.02)
class PolicyEvaluator:
    """Ray actor that rolls out policies built from flat parameter vectors.

    Each actor owns one asynchronous vectorized environment, which is reused
    across calls to :meth:`evaluate_policy`.
    """

    def __init__(
        self,
        env_id: str,
        env_kwargs: Dict,
        num_envs: int,
        agent_cfg: DictConfig,
        wrappers: Optional[List[gym.Wrapper]] = None,
        measure_names: Optional[List[str]] = None,
    ):
        """Initialize policy evaluator.

        Args:
            env_id: Gymnasium environment ID
            env_kwargs: Extra keyword arguments forwarded to ``gym.make_vec``
                (``None`` is treated as no extra arguments)
            num_envs: Number of parallel environments to run
            agent_cfg: Agent configuration used to instantiate agents
            wrappers: None or list of wrappers passed to the vector env constructor.
                They are always applied, whether or not measures are requested.
            measure_names: None or list of handcrafted measures that will be extracted.
                Measures are only extracted when ``wrappers`` is also non-empty.
        """
        # NOTE(review): presumably limits torch to one intra-op thread so many
        # actors on one machine do not oversubscribe the CPU — confirm.
        torch.set_num_threads(1)
        self.agent_cfg = agent_cfg

        # Always apply the given wrappers: dropping them when no measures are
        # requested would also silently drop observation/reward wrappers.
        self.wrappers = list(wrappers) if wrappers else []
        self.include_measures = bool(wrappers and measure_names)
        # None whenever measures are disabled, so it can be forwarded as-is.
        self.measure_names = measure_names if self.include_measures else None

        self.envs = gym.make_vec(
            env_id,
            num_envs=num_envs,
            vectorization_mode=gym.VectorizeMode.ASYNC,
            wrappers=self.wrappers,
            **(env_kwargs or {}),
        )

    def evaluate_policy(
        self,
        policy_params: np.ndarray,
        n_trajectories: int,
        device: str = "cpu",
    ) -> List[Trajectory]:
        """Evaluate a single policy.

        Args:
            policy_params: Flat parameter vector loaded into a freshly
                instantiated agent via its ``from_numpy`` method.
            n_trajectories: Number of complete episodes to collect.
            device: Device the policy is moved to before the rollout.

        Returns:
            The collected trajectories, with measures attached when enabled.
        """
        policy = (
            hydra.utils.instantiate(self.agent_cfg).from_numpy(policy_params).to(device)
        )
        return collect_trajectories(
            self.envs,
            policy,
            n_trajectories,
            self.measure_names,
        )

    def close(self) -> None:
        """Close the underlying vectorized environment."""
        self.envs.close()


def collect_trajectories(
    env: gym.vector.VectorEnv,
    policy: torch.nn.Module,
    n_trajectories: int,
    measure_names: Optional[List[str]] = None,
) -> List[Trajectory]:
    """
    Collects n trajectories using the given policy in a vectorized environment.

    The environment is reset once, then all sub-environments are stepped in
    lockstep until ``n_trajectories`` episodes have finished. The loop assumes
    next-step autoreset: the ``step`` call right after a sub-environment reports
    ``terminated`` or ``truncated`` only resets it, so that transition is not
    recorded.

    Args:
        env: A vectorized gymnasium environment
        policy: Policy object with an ``act(obs)`` method that maps the batched
            observations to a batch of actions (called under ``torch.no_grad``)
        n_trajectories: Number of trajectories to collect
        measure_names: Name of measures that should be included in the trajectory. The
            env should return these measures as part of the final info dict, i.e.
            ``info[name][i]`` on the step where episode ``i`` ends.

    Returns:
        List of at most ``n_trajectories`` Trajectory objects containing states,
        actions, rewards and (when ``measure_names`` is non-empty) measures.
    """
    include_measures = bool(measure_names)
    num_envs = env.num_envs

    # Per-sub-env buffers for the episode currently in progress.
    active = [{"states": [], "actions": [], "rewards": []} for _ in range(num_envs)]
    completed: List[Trajectory] = []
    # autoreset[i] is True when sub-env i's next step is the reset-only step.
    autoreset = np.zeros(num_envs, dtype=bool)

    obs, _ = env.reset()

    while len(completed) < n_trajectories:
        with torch.no_grad():
            actions = policy.act(obs)

        next_obs, rewards, terminations, truncations, infos = env.step(actions)

        # Record the transition for every sub-env that did not just reset.
        for i in range(num_envs):
            if autoreset[i]:
                continue
            buf = active[i]
            buf["states"].append(obs[i].astype(np.float32))
            buf["actions"].append(actions[i].astype(np.float32))
            buf["rewards"].append(rewards[i])

        # Close out every episode that ended on this step.
        autoreset = np.logical_or(terminations, truncations)
        for i in np.flatnonzero(autoreset):
            buf = active[i]
            measures = (
                np.array([float(infos[m][i]) for m in measure_names])
                if include_measures
                else None
            )
            completed.append(
                Trajectory(
                    states=np.array(buf["states"]),
                    actions=np.array(buf["actions"]),
                    rewards=np.array(buf["rewards"]),
                    measures=measures,
                )
            )
            active[i] = {"states": [], "actions": [], "rewards": []}

            if len(completed) >= n_trajectories:
                break

        obs = next_obs

    return completed[:n_trajectories]


def umap_archive_plot(
    archive,
    ax=None,
    cmap="magma",
    point_size=3,
    vmax=None,
    vmin=None,
    top_k=3000,
):
    """Plots a UMAP projection of the ``top_k`` highest-objective points
    in a high-dimensional GridArchive.

    This function reduces the high-dimensional measure space to 2D using UMAP
    and visualizes the points with a scatter plot, where color represents
    the objective (fitness) value.

    Args:
        archive (GridArchive): A GridArchive with high-dimensional measures.
        ax (matplotlib.axes.Axes, optional): Axes on which to plot.
            If None, the current axis is used.
        cmap (str): Colormap for the scatter plot.
        point_size (int): Size of the scatter plot points.
        vmax (float, optional): Maximum value for the color scale.
        vmin (float, optional): Minimum value for the color scale.
        top_k (int): Number of highest-objective points to embed and plot.

    Returns:
        The axes that were drawn on, or None when the archive holds fewer than
        5 points (nothing is plotted in that case).

    Raises:
        ValueError: If ``top_k`` is not positive.
    """
    # A non-positive top_k would make the slice below select every point.
    if top_k <= 0:
        raise ValueError(f"top_k must be positive, got {top_k}")

    # Retrieve archive data
    measures = archive.data("measures")
    objectives = archive.data("objective")

    if len(objectives) < 5:
        # UMAP will not work with too few datapoints
        return None

    # argsort is ascending, so the last top_k indices hold the highest objectives.
    top_indices = np.argsort(objectives)[-top_k:]
    measures = measures[top_indices]
    objectives = objectives[top_indices]

    # Reduce to 2D using UMAP
    reducer = umap.UMAP(n_components=2, low_memory=False)
    embedded = reducer.fit_transform(measures)

    # Plot the results
    ax = plt.gca() if ax is None else ax
    scatter = ax.scatter(
        embedded[:, 0],
        embedded[:, 1],
        c=objectives,
        cmap=cmap,
        s=point_size,
        vmin=vmin,
        vmax=vmax,
    )
    # Attach the colorbar to the figure that owns `ax`, not pyplot's current figure.
    ax.figure.colorbar(scatter, ax=ax, label="Objective Value")

    return ax


class StateNormalizer:
    """Normalize states with mean/std estimated from a bounded sample buffer.

    :meth:`update_stats` stores a random 10% of the given states in a
    fixed-size CPU buffer, overwriting random entries once it is full.
    :meth:`apply_update` recomputes mean/std from the buffer, and
    :meth:`normalize` applies them on ``device``.
    """

    def __init__(self, state_dim, device, buffer_size=10_000_000, eps=1e-8):
        """
        Initialize state normalizer with a buffer-based normalization strategy.

        Args:
            state_dim: Dimension of the state space
            device: torch device to store the normalization statistics on
            buffer_size: Maximum number of states to keep in the buffer
            eps: Small constant for numerical stability
        """
        self.state_dim = state_dim
        self.device = device
        self.eps = eps
        self.buffer_size = buffer_size

        # The sample buffer always lives on CPU (it is allocated eagerly and can be large).
        self.buffer = torch.zeros((buffer_size, state_dim))
        self.buffer_count = 0

        # Normalization statistics (identity transform until the first update).
        self.mean = torch.zeros(state_dim, device=device)
        self.std = torch.ones(state_dim, device=device)

    @torch.no_grad()
    def update_stats(self, states):
        """
        Add a random subset (10%) of the given states to the buffer.

        Inputs with fewer than 10 states contribute nothing.

        Args:
            states: Tensor of states from one episode [T, state_dim] or [B, T, state_dim]
        """
        states = states.reshape(-1, self.state_dim)
        num_samples = states.shape[0] // 10  # 10% of states
        sampled_states = states[torch.randperm(states.shape[0])[:num_samples]].cpu()

        num_new = sampled_states.shape[0]
        num_free = max(0, self.buffer_size - self.buffer_count)  # Free space in buffer
        num_fill = min(num_free, num_new)

        if num_fill > 0:
            # Fill up remaining space first
            start = self.buffer_count
            self.buffer[start : start + num_fill] = sampled_states[:num_fill]
            self.buffer_count += num_fill

        if num_new > num_free:
            # Buffer is full: overwrite random entries with the leftover samples.
            replace_indices = torch.randint(0, self.buffer_size, (num_new - num_free,))
            self.buffer[replace_indices] = sampled_states[num_free:]

    @torch.no_grad()
    def apply_update(self):
        """Recompute mean and standard deviation from the buffer and update normalization parameters.

        Does nothing until the buffer holds at least two states: the unbiased
        std of a single sample is NaN and would poison every normalized state.
        """
        if self.buffer_count < 2:
            return
        buffer_subset = self.buffer[: self.buffer_count]  # Only consider filled buffer entries
        # The buffer is on CPU; move the statistics to `device` so that
        # `normalize` does not mix devices.
        self.mean = buffer_subset.mean(dim=0).to(self.device)
        self.std = buffer_subset.std(dim=0, unbiased=True).to(self.device)

    @torch.no_grad()
    def normalize(self, states):
        """
        Normalize states using current statistics.

        Args:
            states: Tensor of states to normalize

        Returns:
            Normalized states on ``self.device``
        """
        return (states.to(self.device) - self.mean) / (self.std + self.eps)

    def get_stats(self):
        """Return current normalization statistics as numpy arrays plus the buffer fill count."""
        return {
            "mean": self.mean.cpu().numpy(),
            "std": self.std.cpu().numpy(),
            "buffer_size": self.buffer_count,
        }


# Manual smoke test for measure extraction. TODO: Remove
if __name__ == "__main__":
    from agents import MLPAgent
    from qd.wrappers import BipedalBehavioralWrapper

    envs = gym.make_vec(
        "BipedalWalker-v3",
        num_envs=5,
        vectorization_mode=gym.VectorizeMode.SYNC,
        wrappers=[BipedalBehavioralWrapper],
    )
    try:
        agent = MLPAgent(24, 4)
        trajs = collect_trajectories(
            envs,
            agent,
            50,
            measure_names=[
                "left_contact_freq",
                "right_contact_freq",
                "normalized_hull_angle",
                "normalized_height_var",
            ],
        )
    finally:
        # Release the sub-environments even if collection fails.
        envs.close()
    print("Done")
