import os
from collections import defaultdict
from pathlib import Path

import gymnasium
import numpy as np
import torch
from joblib import Parallel, delayed
from matplotlib import pyplot as plt
from matplotlib.colors import Normalize
from mpl_toolkits.mplot3d import Axes3D  # Required for 3D plotting # noqa
from scipy.interpolate import griddata
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv
from torch import Tensor, nn
from tqdm import tqdm
import seaborn as sns
from compression_autoencoder.policies.policy import Policy
from compression_autoencoder.utils.misc import get_seed_sequence

# Set a professional and clean plot style as the module-wide default.
# NOTE: multi_reward_visualize() additionally applies a seaborn theme and its
# own rcParams overrides when it is called.
plt.style.use("seaborn-v0_8-whitegrid")


def evaluate_policies_batch(
    vec_env: VecEnv,
    sample_policy: Policy,
    policies_weights: Tensor,
    n_eps_per_policy: int,
    device: torch.device,
) -> tuple[np.ndarray, np.ndarray]:  # (mean_reward, total_steps)
    """Roll out one policy per sub-environment and average their episode returns.

    All sub-environments are stepped together using a single batched
    ``sample_policy.predict`` call. For every sub-environment the summed reward
    of its first ``n_eps_per_policy`` finished episodes is recorded; episodes
    that finish after that quota is reached are ignored while slower
    sub-environments catch up.

    The loop never calls ``reset`` after the initial one, so it relies on the
    vectorized environment restarting finished sub-environments by itself.

    Args:
        vec_env: Vectorized environment to roll out in.
        sample_policy: Policy object whose ``predict(obs, weights,
            deterministic=True)`` returns the batched actions as a tensor.
        policies_weights: Batched policy weights; the first dimension must equal
            ``vec_env.num_envs``. Presumably row ``i`` drives sub-environment
            ``i`` -- confirm against ``Policy.predict``.
        n_eps_per_policy: Number of episodes recorded per sub-environment.
        device: Device on which observations and weights are placed.

    Returns:
        A ``(mean_reward, total_steps)`` tuple of arrays with one entry per
        sub-environment: the mean return over the recorded episodes and the
        number of environment steps spent in those recorded episodes.

    Raises:
        AssertionError: If the number of weight rows differs from
            ``vec_env.num_envs``.
    """
    assert policies_weights.shape[0] == vec_env.num_envs, (
        "Number of policies must match the number of environments."
    )

    # Observations are moved to `device` before every predict call, so the
    # weights must live there too (mirrors the other chunk evaluators).
    policies_weights = policies_weights.to(device)

    n_envs = vec_env.num_envs
    policy_reward_episodes = np.empty((n_envs, n_eps_per_policy), dtype=float)
    obs = vec_env.reset()
    eps_counts = np.zeros(n_envs, dtype=int)
    current_episode_rewards = np.zeros(n_envs)
    num_completed_steps = np.zeros(n_envs)
    num_total_completed_steps = np.zeros(n_envs)

    while np.any(eps_counts < n_eps_per_policy):
        with torch.no_grad():
            actions = (
                sample_policy.predict(
                    torch.from_numpy(obs).float().to(device),
                    policies_weights,
                    deterministic=True,
                )
                .cpu()
                .numpy()
            )

        obs, rew, dones, _ = vec_env.step(actions)
        current_episode_rewards += rew
        num_completed_steps += 1

        # Only sub-environments that just finished an episode AND still have
        # quota left get a new entry; this keeps the column index in bounds.
        finished_indices = np.logical_and(dones, eps_counts < n_eps_per_policy)
        policy_reward_episodes[finished_indices, eps_counts[finished_indices]] = (
            current_episode_rewards[finished_indices]
        )
        current_episode_rewards[finished_indices] = 0.0
        eps_counts[finished_indices] += 1
        num_total_completed_steps[finished_indices] += num_completed_steps[
            finished_indices
        ]
        num_completed_steps[finished_indices] = 0

    mean_reward = np.mean(policy_reward_episodes, axis=1)
    return mean_reward, num_total_completed_steps


def evaluate_single_policy(
    vec_env: VecEnv,
    sample_policy: Policy,
    single_policy_weights: Tensor,
    n_eval_episodes: int,
    n_envs: int,
    device: torch.device,
) -> tuple[float, float]:  # (mean_reward, total_steps)
    """Evaluate one policy across all sub-environments of ``vec_env``.

    Every sub-environment is driven with the same ``single_policy_weights``.
    Stepping continues until at least ``n_eval_episodes`` episodes have
    finished in total; because several sub-environments can finish on the same
    step, slightly more episodes than requested may be included in the average.

    The loop never calls ``reset`` after the initial one, so it relies on the
    vectorized environment restarting finished sub-environments by itself.

    Args:
        vec_env: Vectorized environment to roll out in.
        sample_policy: Policy object whose ``predict(obs, weights,
            deterministic=True)`` returns the batched actions as a tensor.
        single_policy_weights: Weights of the policy under evaluation.
        n_eval_episodes: Minimum number of finished episodes to average over.
        n_envs: Length of the per-environment bookkeeping buffers; expected to
            match ``vec_env.num_envs`` since they are combined element-wise with
            the reward and done vectors returned by ``vec_env.step``.
        device: Device on which observations and weights are placed.

    Returns:
        A ``(mean_reward, total_steps)`` tuple: the mean summed reward per
        finished episode and the total number of steps of those episodes.
    """
    # Observations are moved to `device` before every predict call, so the
    # weights must live there too (mirrors the other chunk evaluators).
    single_policy_weights = single_policy_weights.to(device)

    obs = vec_env.reset()
    total_rewards = 0.0
    current_episode_rewards = np.zeros(n_envs)
    num_total_completed_steps = 0
    num_total_completed_episodes = 0
    num_completed_steps = np.zeros(n_envs, dtype=int)

    while num_total_completed_episodes < n_eval_episodes:
        with torch.no_grad():
            actions = (
                sample_policy.predict(
                    torch.from_numpy(obs).float().to(device),
                    single_policy_weights,
                    deterministic=True,
                )
                .cpu()
                .numpy()
            )

        obs, rew, dones, _ = vec_env.step(actions)
        current_episode_rewards += rew
        num_completed_steps += 1

        # Fold every episode that ended on this step into the running totals
        # and clear its per-environment accumulators for the next episode.
        finished_indices = np.where(dones)[0]
        num_total_completed_episodes += finished_indices.shape[0]
        total_rewards += np.sum(current_episode_rewards[finished_indices])
        current_episode_rewards[finished_indices] = 0.0
        num_total_completed_steps += np.sum(num_completed_steps[finished_indices])
        num_completed_steps[finished_indices] = 0

    mean_reward = total_rewards / num_total_completed_episodes
    return float(mean_reward), float(num_total_completed_steps)


def _collect_trajectories_chunk(
    vec_env: VecEnv,
    sample_policy: Policy,
    weights_chunk: Tensor,
    n_steps_to_collect: int,
    device: torch.device,
) -> np.ndarray:
    """Record the observations visited by a chunk of policies.

    The environments are stepped ``n_steps_to_collect`` times with deterministic
    actions from a batched ``sample_policy.predict`` call. The observation seen
    *before* each step is stored; done flags are ignored, so a trajectory simply
    continues with whatever observation the vectorized environment returns next.

    Returns:
        A float32 array of shape ``(n_policies, n_steps_to_collect, obs_dim)``
        where ``obs_dim`` is ``vec_env.observation_space.shape[0]``.

    Raises:
        AssertionError: If the chunk size differs from ``vec_env.num_envs``.
    """
    chunk_size = weights_chunk.shape[0]
    assert chunk_size == vec_env.num_envs, (
        "The number of policies in the chunk must match the number of environments."
    )

    weights_on_device = weights_chunk.to(device)

    obs_dim = vec_env.observation_space.shape[0]
    trajectories = np.empty(
        (chunk_size, n_steps_to_collect, obs_dim), dtype=np.float32
    )

    current_obs = vec_env.reset()
    for t in range(n_steps_to_collect):
        # Store the observation the policy is about to act on.
        trajectories[:, t, :] = current_obs

        with torch.no_grad():
            policy_input = torch.from_numpy(current_obs).float().to(device)
            predicted = sample_policy.predict(
                policy_input, weights_on_device, deterministic=True
            )
            actions = predicted.cpu().numpy()

        current_obs, _, _, _ = vec_env.step(actions)

    return trajectories


def _collect_trajectories_worker(
    weights_batch: Tensor,
    env_id: str,
    wrapper_class: gymnasium.Wrapper,
    seed: int,
    device: torch.device,
    layer_shapes: list[tuple[int, int, bool]],
    activation_fn: str,
    n_steps_to_collect: int,
    last_activation_fn: str,
    input_scaler: nn.Module | None = None,
) -> np.ndarray:
    """
    Standalone joblib job: build a vectorized environment with one
    sub-environment per policy in ``weights_batch``, build a ``Policy``,
    collect the trajectories for that batch, and release the environment
    (and the CUDA cache when running on a GPU) whether or not collection
    succeeded.
    """
    n_policies = weights_batch.shape[0]

    env = None
    try:
        env = make_vec_env(
            env_id=env_id,
            n_envs=n_policies,
            vec_env_cls=DummyVecEnv,
            wrapper_class=wrapper_class,  # type: ignore
            seed=seed,
        )
        policy = Policy(
            layer_shapes=layer_shapes,
            activation_func=activation_fn,
            last_activation_func=last_activation_fn,
            input_scaler=input_scaler,
            device=device,
        )
        # Returning from inside `try` still runs the cleanup below first.
        return _collect_trajectories_chunk(
            vec_env=env,
            sample_policy=policy,
            weights_chunk=weights_batch,
            n_steps_to_collect=n_steps_to_collect,
            device=device,
        )
    finally:
        # The environment may not exist if its construction failed.
        if env is not None:
            env.close()
        if device.type == "cuda":
            torch.cuda.empty_cache()


def collect_trajectories(
    env_id: str,
    wrapper_class: gymnasium.Wrapper,
    seed: int,
    layer_shapes: list[tuple[int, int, bool]],
    activation_fn: str,
    last_activation_fn: str,
    policies_weights: Tensor,
    eval_batch_size: int,
    n_steps_to_collect: int = 3000,
    input_scaler: nn.Module | None = None,
    n_jobs: int = -1,
) -> np.ndarray:
    """
    Collect observation trajectories for many policies in parallel using joblib.

    ``policies_weights`` is split into batches of at most ``eval_batch_size``
    policies; each batch is handled by its own worker with its own seed
    derived from ``seed``. The per-batch arrays are concatenated along the
    first (policy) axis, preserving the order of ``policies_weights``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    weight_batches = torch.split(policies_weights, eval_batch_size)
    seeds = get_seed_sequence(np.random.default_rng(seed), len(weight_batches))

    # Arguments that are identical for every worker.
    shared_kwargs = dict(
        env_id=env_id,
        wrapper_class=wrapper_class,
        device=device,
        layer_shapes=layer_shapes,
        activation_fn=activation_fn,
        n_steps_to_collect=n_steps_to_collect,
        last_activation_fn=last_activation_fn,
        input_scaler=input_scaler,
    )

    progress = tqdm(weight_batches, desc="Dispatching collection jobs")
    per_batch_states = Parallel(n_jobs=n_jobs)(
        delayed(_collect_trajectories_worker)(
            weights_batch=weights, seed=seeds[idx], **shared_kwargs
        )
        for idx, weights in enumerate(progress)
    )

    return np.concatenate(per_batch_states, axis=0)  # type: ignore


def _multi_evaluate_chunk(
    vec_env: VecEnv,
    sample_policy: Policy,
    weights_chunk: Tensor,
    stats_to_collect: list[str],
    n_eval_episodes: int,
    device: torch.device,
) -> dict[str, np.ndarray]:
    """Evaluate a chunk of policies and average per-episode statistics.

    All sub-environments are stepped together with deterministic actions from a
    batched ``sample_policy.predict`` call. Whenever a sub-environment reports
    ``done``, each requested statistic present in that step's ``info`` dict is
    recorded for the corresponding policy.

    Exactly the first ``n_eval_episodes`` episodes of every policy are
    recorded. Sub-environments that reach the quota early keep being stepped
    while the others catch up, but their extra episodes are discarded so that
    policies with short episodes do not contribute more samples than others.

    Args:
        vec_env: Vectorized environment with one sub-environment per policy.
        sample_policy: Policy object providing the batched ``predict``.
        weights_chunk: Batched policy weights; first dimension must equal
            ``vec_env.num_envs``.
        stats_to_collect: Keys looked up in the ``info`` dict of finished
            episodes.
        n_eval_episodes: Number of episodes recorded per policy.
        device: Device on which observations and weights are placed.

    Returns:
        A dict mapping each statistic name to an array with one mean value per
        policy. A policy for which the statistic never appeared in ``info``
        gets ``nan``.

    Raises:
        AssertionError: If the chunk size differs from ``vec_env.num_envs``.
    """
    n_policies_in_chunk = weights_chunk.shape[0]
    assert n_policies_in_chunk == vec_env.num_envs, (
        "The number of policies in the chunk must match the number of environments."
    )

    weights_chunk = weights_chunk.to(device)

    # Use a dictionary of lists to dynamically store episode results
    episode_stats: dict[str, list[list[float]]] = {
        stat: [[] for _ in range(n_policies_in_chunk)] for stat in stats_to_collect
    }

    episode_counts = np.zeros(n_policies_in_chunk, dtype=int)
    obs = vec_env.reset()

    while np.any(episode_counts < n_eval_episodes):
        with torch.no_grad():
            obs_tensor = torch.from_numpy(obs).float().to(device)
            actions = (
                sample_policy.predict(obs_tensor, weights_chunk, deterministic=True)
                .cpu()
                .numpy()
            )

        obs, _, dones, infos = vec_env.step(actions)

        # Ignore episodes from policies that already filled their quota;
        # otherwise fast-finishing policies would be averaged over more
        # episodes than requested.
        recordable = np.logical_and(dones, episode_counts < n_eval_episodes)
        for env_idx in np.flatnonzero(recordable):
            episode_counts[env_idx] += 1
            info = infos[env_idx]
            for stat in stats_to_collect:
                if stat in info:
                    episode_stats[stat][env_idx].append(info[stat])

    final_stats = {}
    for stat, all_policy_episodes in episode_stats.items():
        # Explicit nan for "never observed" avoids numpy's empty-mean warning.
        final_stats[stat] = np.array(
            [
                np.mean(episodes) if episodes else np.nan
                for episodes in all_policy_episodes
            ]
        )

    return final_stats


def _evaluate_worker(
    weights_batch: Tensor,
    stats_to_collect: list[str],
    n_eval_episodes: int,
    env_id: str,
    wrapper_class: gymnasium.Wrapper,
    seed: int,
    device: torch.device,
    layer_shapes: list[tuple[int, int, bool]],
    activation_fn: str,
    last_activation_fn: str,
    input_scaler: nn.Module | None = None,
) -> dict[str, np.ndarray]:
    """
    Standalone joblib job: build a vectorized environment with one
    sub-environment per policy in ``weights_batch``, build a ``Policy``,
    evaluate that batch, and release the environment (and the CUDA cache
    when running on a GPU) whether or not evaluation succeeded.
    """
    n_policies = weights_batch.shape[0]

    env = None
    try:
        env = make_vec_env(
            env_id=env_id,
            n_envs=n_policies,
            vec_env_cls=DummyVecEnv,
            wrapper_class=wrapper_class,  # type: ignore
            seed=seed,
        )
        policy = Policy(
            layer_shapes=layer_shapes,
            activation_func=activation_fn,
            last_activation_func=last_activation_fn,
            input_scaler=input_scaler,
            device=device,
        )
        # Returning from inside `try` still runs the cleanup below first.
        return _multi_evaluate_chunk(
            vec_env=env,
            sample_policy=policy,
            weights_chunk=weights_batch,
            stats_to_collect=stats_to_collect,
            n_eval_episodes=n_eval_episodes,
            device=device,
        )
    finally:
        # The environment may not exist if its construction failed.
        if env is not None:
            env.close()
        if device.type == "cuda":
            torch.cuda.empty_cache()


def multi_reward_evaluate(
    env_id: str,
    wrapper_class: gymnasium.Wrapper,
    seed: int,
    layer_shapes: list[tuple[int, int, bool]],
    activation_fn: str,
    last_activation_fn: str,
    policies_weights: Tensor,
    stats_to_collect: list[str],
    eval_batch_size: int,
    n_eval_episodes: int = 10,
    input_scaler: nn.Module | None = None,
    n_jobs: int = -1,
) -> dict[str, list[float]]:
    """
    Evaluate many policies in parallel with joblib and gather their statistics.

    ``policies_weights`` is split into batches of at most ``eval_batch_size``
    policies; each batch is evaluated by its own worker with its own seed
    derived from ``seed``. For every statistic, the per-policy values of all
    batches are joined into one flat list in the order of ``policies_weights``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    weight_batches = torch.split(policies_weights, eval_batch_size)
    seeds = get_seed_sequence(np.random.default_rng(seed), len(weight_batches))

    # Arguments that are identical for every worker.
    shared_kwargs = dict(
        stats_to_collect=stats_to_collect,
        n_eval_episodes=n_eval_episodes,
        env_id=env_id,
        wrapper_class=wrapper_class,
        device=device,
        layer_shapes=layer_shapes,
        activation_fn=activation_fn,
        last_activation_fn=last_activation_fn,
        input_scaler=input_scaler,
    )

    progress = tqdm(weight_batches, desc="Dispatching evaluation jobs")
    per_batch_stats = Parallel(n_jobs=n_jobs)(
        delayed(_evaluate_worker)(
            weights_batch=weights, seed=seeds[idx], **shared_kwargs
        )
        for idx, weights in enumerate(progress)
    )

    # Flatten the per-batch arrays into one list per statistic.
    combined = defaultdict(list)
    for batch_stats in per_batch_stats:
        for name, per_policy_values in batch_stats.items():
            combined[name].extend(per_policy_values.tolist())

    return dict(combined)


def _save_and_close(fig, output_file: str, description: str) -> None:
    """Save *fig* to *output_file* at 300 dpi, close it, and print the path."""
    fig.savefig(output_file, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f" -> Saved {description} to {output_file}")


def _plot_range_text(
    metric: str, metric_values: np.ndarray, save_path: str | Path
) -> None:
    """Save a text-only figure showing the min/max of one metric (latent dim > 3)."""
    min_val, max_val = np.min(metric_values), np.max(metric_values)
    text_str = f"{metric.replace('_', ' ').title()} Range:\nMin: {min_val:.3f}\nMax: {max_val:.3f}"

    # A fresh figure per metric: reusing one would stack the text boxes of
    # earlier metrics underneath the current one.
    fig, ax = plt.subplots()
    ax.text(
        0.5, 0.5, text_str,
        fontsize=18, ha="center", va="center", transform=ax.transAxes,
        bbox=dict(boxstyle="round,pad=0.5", facecolor="lightyellow", edgecolor="gray"),
    )
    ax.axis("off")
    output_file = os.path.join(save_path, f"{metric}_range_text.png")
    _save_and_close(fig, output_file, "range text plot")


def _plot_1d_scatter(
    codes: np.ndarray,
    metric: str,
    metric_values: np.ndarray,
    norm: "Normalize",
    cmap: str,
    save_path: str | Path,
    show_colorbar: bool,
) -> None:
    """Save a scatter of 1D codes laid out on a horizontal line, colored by metric."""
    fig, ax = plt.subplots(figsize=(8, 8))
    scatter = ax.scatter(
        codes[:, 0], np.zeros_like(codes[:, 0]),
        c=metric_values, cmap=cmap, norm=norm, s=100,
    )

    if show_colorbar:
        fig.colorbar(scatter, ax=ax)

    ax.axis("off")
    fig.tight_layout()
    output_file = os.path.join(save_path, f"{metric}_1D_scatter.pdf")
    _save_and_close(fig, output_file, "1D plot")


def _plot_2d_heatmap_and_surface(
    codes: np.ndarray,
    metric: str,
    metric_values: np.ndarray,
    norm: "Normalize",
    cmap: str,
    save_path: str | Path,
    points_per_axis: int | None,
    show_colorbar: bool,
) -> None:
    """Save a pcolormesh heatmap and an interpolated 3D surface for 2D codes.

    Prints an error and returns without plotting when ``points_per_axis`` is
    missing or the data cannot be reshaped to a square grid of that size.
    """
    if points_per_axis is None:
        print("Error: `points_per_axis` is required for 2D plotting.")
        return

    grid_shape = (points_per_axis, points_per_axis)
    try:
        X = codes[:, 0].reshape(grid_shape)
        Y = codes[:, 1].reshape(grid_shape)
        Z = metric_values.reshape(grid_shape)
    except ValueError:
        print(f"Error: Could not reshape data to {grid_shape}. Check `points_per_axis`.")
        return

    # Plot 1: True Heatmap with pcolormesh
    fig1, ax1 = plt.subplots(figsize=(8, 8))
    mesh = ax1.pcolormesh(X, Y, Z, cmap=cmap, norm=norm, shading="gouraud")

    if show_colorbar:
        fig1.colorbar(mesh, ax=ax1)

    ax1.set_aspect("equal")
    ax1.axis("off")
    fig1.tight_layout()
    output_file1 = os.path.join(save_path, f"{metric}_2D_heatmap.pdf")
    _save_and_close(fig1, output_file1, "2D heatmap plot")

    # Plot 2: 3D Interpolated Surface Plot
    fig2 = plt.figure()  # Let rcParams handle figsize
    ax2 = fig2.add_subplot(projection="3d")

    # Resample the scattered codes onto a regular 100x100 grid for a smooth surface.
    grid_x, grid_y = np.mgrid[X.min():X.max():100j, Y.min():Y.max():100j]
    grid_z = griddata(codes, metric_values, (grid_x, grid_y), method="cubic")
    face_colors = plt.get_cmap(cmap)(norm(grid_z))

    ax2.plot_surface(
        grid_x, grid_y, grid_z,
        facecolors=face_colors, rstride=1, cstride=1, antialiased=True, shade=False,
    )

    if show_colorbar:
        # plot_surface with explicit facecolors carries no color mapping of its
        # own, so the colorbar is built from a standalone mappable.
        mappable = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        fig2.colorbar(mappable, ax=ax2, shrink=0.6, aspect=20, pad=0.1)

    ax2.set_zlabel("Performance", labelpad=10)
    ax2.view_init(elev=20, azim=135)

    output_file2 = os.path.join(save_path, f"{metric}_3D_surface.pdf")
    _save_and_close(fig2, output_file2, "3D surface plot")


def _plot_3d_scatter(
    codes: np.ndarray,
    metric: str,
    metric_values: np.ndarray,
    norm: "Normalize",
    cmap: str,
    save_path: str | Path,
    show_colorbar: bool,
) -> None:
    """Save a 3D scatter of codes where color and marker size both encode the metric."""
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(projection="3d")

    # Marker area grows linearly with the normalized metric value.
    sizes = 5 + 100 * norm(metric_values)
    scatter = ax.scatter(
        codes[:, 0], codes[:, 1], codes[:, 2],
        c=metric_values, cmap=cmap, norm=norm, s=sizes, alpha=0.6, depthshade=True,
    )

    if show_colorbar:
        fig.colorbar(scatter, ax=ax, shrink=0.6, aspect=20, pad=0.1)

    ax.view_init(elev=20, azim=45)
    ax.axis("off")

    fig.tight_layout()
    output_file = os.path.join(save_path, f"{metric}_3D_scatter.pdf")
    _save_and_close(fig, output_file, "3D scatter plot")


def multi_reward_visualize(
    codes: np.ndarray,
    stats: list[str],
    codes_stats: dict[str, list[float]],
    save_path: str | Path = ".",
    points_per_axis: int | None = None,
    show_colorbar: bool = True,
) -> None:
    """
    Generates and saves visualizations of policy performance in the latent space,
    aligned with a professional, publication-ready style.

    One set of plots is produced per metric, depending on the latent dimension:
    - 1D: A single 2D scatter plot showing a line of dots colored by performance.
    - 2D: Two plots per metric:
        1. A 2D heatmap using pcolormesh for a smooth visualization.
        2. A 3D interpolated surface plot.
    - 3D: A single 3D scatter plot where both color and size of points
          represent performance, enhancing visual clarity.
    - >3D: No spatial plot is possible; a text-only figure with the metric's
          minimum and maximum is saved instead.

    Calling this function changes global plotting state (seaborn theme and
    matplotlib rcParams) and creates ``save_path`` if it does not exist.

    Args:
        codes: A numpy array of latent codes, shape (N, D) where D is the dimension.
        stats: A list of metric names (strings) to plot from codes_stats.
        codes_stats: A dictionary mapping metric names to their corresponding values.
        save_path: The directory where plot images will be saved.
        points_per_axis: The number of points along each axis for grid-based data.
            Required for 2D latent spaces, where codes and metric values are
            reshaped to a (points_per_axis, points_per_axis) grid; a metric is
            skipped with an error message if this is missing or does not fit.
        show_colorbar: A flag to control the visibility of the color bar on plots.
    """
    # --- Set consistent, publication-quality plotting style ---
    sns.set_theme(style="whitegrid", palette="colorblind")
    plt.rcParams.update({
        "font.family": "serif", "font.serif": ["Times New Roman", "DejaVu Serif"],
        "font.size": 18,
        "axes.labelsize": 20,      # Reduced for better fit
        "xtick.labelsize": 18,     # Reduced for better fit
        "ytick.labelsize": 18,     # Reduced for better fit
        "legend.fontsize": 14,
        "figure.figsize": (8, 8),
    })

    os.makedirs(save_path, exist_ok=True)
    latent_dim = codes.shape[1]
    cmap = "viridis"

    for metric in stats:
        metric_values = np.array(codes_stats[metric])

        # --- Handle high-dimensional case (not plottable) ---
        if latent_dim > 3:
            _plot_range_text(metric, metric_values, save_path)
            continue

        # Shared color scale so every plot of this metric is comparable.
        norm = Normalize(vmin=metric_values.min(), vmax=metric_values.max())

        print(f"Plotting Metric: '{metric}' for {latent_dim}D latent space...")

        if latent_dim == 1:
            _plot_1d_scatter(
                codes, metric, metric_values, norm, cmap, save_path, show_colorbar
            )
        elif latent_dim == 2:
            _plot_2d_heatmap_and_surface(
                codes, metric, metric_values, norm, cmap, save_path,
                points_per_axis, show_colorbar,
            )
        elif latent_dim == 3:
            _plot_3d_scatter(
                codes, metric, metric_values, norm, cmap, save_path, show_colorbar
            )