import logging
import os
import pickle
import time
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union

import hydra
import matplotlib.pyplot as plt
import numpy as np
import ray
import torch
from omegaconf import DictConfig, OmegaConf
from ribs.archives import CVTArchive, GridArchive
from ribs.emitters import EvolutionStrategyEmitter, GaussianEmitter, IsoLineEmitter
from ribs.schedulers import Scheduler
from ribs.visualize import cvt_archive_heatmap, grid_archive_heatmap

from src.logging_utils import MetricLogger
from src.measure_maps.cwpca import CalibratedWeightedPCAMap
from src.measure_maps.greedy import GreedyAnchorMap, SubspaceAnchorMap
from src.measure_maps.pca import PCAMap
from src.measure_maps.wpca import WeightedPCAMap
from src.qd.pga_emitter import PGAEmitter
from src.qd.td3 import TD3
from src.utils import PolicyEvaluator, umap_archive_plot


class AutoQD:
    """Quality-Diversity search whose measures are learned from trajectory embeddings.

    Candidate policies (flat parameter vectors) are evaluated on Ray workers; their
    trajectories are embedded by ``embedding_map`` and projected to low-dimensional
    measures by ``measure_map``. At every iteration listed in ``update_interval`` the
    measure map is refined on the best archive solutions, and the archives, emitters
    and scheduler are rebuilt around the recomputed measures.
    """

    # Map types whose measure ranges are symmetric around 0 (instead of [0, 1] + margin).
    _CENTERED_MAP_TYPES = ("pca", "wpca", "cwpca")
    # QD algorithms understood by setup_qd and reset_qd.
    _QD_METHODS = ("cma_me", "sep_cma_mae", "cma_mae", "map_elites", "pga_me")

    def __init__(
        self,
        env_cfg: DictConfig,
        qd: DictConfig,
        agent: DictConfig,
        embedding: DictConfig,
        measures_dim: int,
        measure_margin: float,
        total_iterations: int,
        update_interval: Union[int, List[int]],
        refine_frac: float,
        map_type: str,
        n_evals: int,
        exp_dir: str,
        metric_logger: MetricLogger,
        info_logger: logging.Logger,
        device,
        seed,
        **kwargs,
    ) -> None:
        """Initialize the AutoQD algorithm.

        Args:
            env_cfg: Configuration for environment (``env_id``, optional
                ``env_kwargs``, and ``min_score``/``max_score`` used for plotting)
            qd: Configuration for QD optimization
            agent: Agent configuration
            embedding: Embedding map configuration
            measures_dim: Dimensionality of measure space
            measure_margin: The margin added to the measures to allow extrapolation
            total_iterations: Number of QD iterations in total.
            update_interval: Either a period (refine the measure map every
                ``update_interval`` iterations) or an explicit list of iterations at
                which to refine it. Plain ints/lists and OmegaConf lists are accepted.
            refine_frac: Fraction of top solutions that contribute to measure map's refinement
            map_type: Type of measure mapping ("greedy", "subspace", "pca", "wpca" or "cwpca")
            n_evals: Number of evaluations per solution
            exp_dir: Directory for experiment results (str or path-like)
            metric_logger: MetricLogger instance
            info_logger: Regular Logger instance
            device: cuda or cpu
            seed: experiment seed
            **kwargs: Optional extras. ``cpus_per_worker`` (default 2) sets how many
                CPUs each Ray evaluator reserves; any other key is ignored.
        """
        self.env_cfg = env_cfg
        self.agent_cfg = agent
        self.qd_cfg = qd
        self.embedding_cfg = embedding
        self.device = device
        self.seed = seed
        self.map_type = map_type
        self.measure_margin = measure_margin
        self.logger = metric_logger
        self.info_logger = info_logger
        self.setup_qd(qd, agent, measures_dim, embedding.dim)
        # mapping trajectories to embedding space
        self.embedding_map = hydra.utils.instantiate(embedding, device=device)
        # mapping embeddings to (low-dim) measures
        self.measure_map = self._create_measure_map(
            map_type, measures_dim, embedding.dim
        )

        self.total_iterations = total_iterations
        self.update_interval = self._parse_update_interval(
            update_interval, total_iterations
        )
        assert 0 < refine_frac <= 1
        self.refine_frac = refine_frac
        self.measure_dim = measures_dim
        self.embedding_dim = embedding.dim
        self.n_evals = n_evals
        self.exp_dir = exp_dir

        # Setup ray multiprocessing
        if not ray.is_initialized():
            ray.init()
        cpus_per_worker = kwargs.get("cpus_per_worker", 2)
        total_cpus = ray.cluster_resources().get("CPU", 0)
        num_workers = int(max(1, total_cpus // cpus_per_worker))
        info_logger.info(
            f"Using {num_workers} workers each with {cpus_per_worker} CPUs"
        )

        env_kwargs = (
            OmegaConf.to_container(self.env_cfg.env_kwargs, resolve=True)
            if "env_kwargs" in self.env_cfg
            else {}
        )
        self.evaluators = [
            PolicyEvaluator.options(num_cpus=cpus_per_worker).remote(
                env_id=self.env_cfg.env_id,
                env_kwargs=env_kwargs,
                num_envs=n_evals,
                agent_cfg=self.agent_cfg,
            )
            for _ in range(num_workers)
        ]

    @staticmethod
    def _parse_update_interval(
        update_interval: Union[int, List[int]], total_iterations: int
    ) -> List[int]:
        """Return the list of iterations at which the measure map is refined.

        A list (plain or OmegaConf) is used as-is; an int ``k`` expands to
        ``k, 2k, ...`` strictly below ``total_iterations``.
        """
        # OmegaConf.to_container only accepts OmegaConf objects; plain ints and
        # lists (as allowed by the type hint) must bypass it.
        if OmegaConf.is_config(update_interval):
            update_interval = OmegaConf.to_container(update_interval, resolve=True)
        if isinstance(update_interval, (list, tuple)):
            return list(update_interval)
        return list(range(update_interval, total_iterations, update_interval))

    def _create_measure_map(
        self, map_type: str, measures_dim: int, embedding_dim: int
    ) -> Any:
        """Create the appropriate measure mapping based on specified type.

        Args:
            map_type: Type of measure mapping
            measures_dim: Dimensionality of measure space
            embedding_dim: Dimensionality of embedding space

        Returns:
            Initialized measure mapping object

        Raises:
            ValueError: If map_type is not recognized
        """
        map_classes = {
            "greedy": GreedyAnchorMap,
            "subspace": SubspaceAnchorMap,
            "pca": PCAMap,
            "wpca": WeightedPCAMap,
            "cwpca": CalibratedWeightedPCAMap,
        }
        if map_type not in map_classes:
            raise ValueError(f"Undefined map_type: {map_type}")
        return map_classes[map_type](measures_dim, embedding_dim, device=self.device)

    def _measure_ranges(self, measures_dim: int, initial: bool) -> List[tuple]:
        """Per-dimension (low, high) measure bounds for the archives.

        Args:
            measures_dim: Dimensionality of measure space
            initial: True for the archives built before any refinement. Centred maps
                then use a 10x narrower range because their measures start out
                clustered near 0.
        """
        if self.map_type in self._CENTERED_MAP_TYPES:
            bound = self.measure_margin / 10 if initial else self.measure_margin
            return [(-bound, bound) for _ in range(measures_dim)]
        return [
            (-self.measure_margin, 1.0 + self.measure_margin)
            for _ in range(measures_dim)
        ]

    def _build_archive(
        self,
        qd: DictConfig,
        solution_dim: int,
        measures_dim: int,
        embedding_dim: int,
        ranges: List[tuple],
        *,
        learning_rate: Optional[float] = None,
        custom_centroids: Optional[np.ndarray] = None,
    ):
        """Construct a Grid or CVT archive according to ``qd.archive``.

        Passing ``learning_rate`` builds a CMA-MAE training archive
        (``threshold_min=qd.qd_score_offset``). ``custom_centroids`` is only used
        for CVT archives.

        Raises:
            ValueError: If ``qd.archive`` is neither "grid" nor "cvt".
        """
        common = {
            "solution_dim": solution_dim,
            "ranges": ranges,
            "qd_score_offset": qd.qd_score_offset,
            "extra_fields": {"embedding": ((embedding_dim,), np.float32)},
            "seed": self.seed,
        }
        if learning_rate is not None:
            common["threshold_min"] = qd.qd_score_offset
            common["learning_rate"] = learning_rate

        if qd.archive == "grid":
            dims = tuple(qd.archive_dim for _ in range(measures_dim))
            return GridArchive(dims=dims, **common)
        if qd.archive == "cvt":
            if custom_centroids is not None:
                common["custom_centroids"] = custom_centroids
            return CVTArchive(cells=qd.archive_dim, **common)
        raise ValueError(f"Archive type {qd.archive} not recognized.")

    def _rescaled_centroids(self, ranges: List[tuple]) -> Optional[np.ndarray]:
        """CVT centroids of ``self.archive`` mapped onto ``ranges`` (None for grids).

        Centroids are reused across resets so the cell layout stays the same. When
        the measure ranges change (centred maps widen from +/- margin/10 to
        +/- margin after the first refinement), the centroids are mapped affinely
        from the old bounds to the new ones. Otherwise the cells would stay packed
        in the old, narrower box.
        """
        if self.qd_cfg.archive != "cvt":
            return None
        centroids = np.asarray(self.archive.centroids)
        old_lo = np.asarray(self.archive.lower_bounds, dtype=np.float64)
        old_hi = np.asarray(self.archive.upper_bounds, dtype=np.float64)
        new_lo = np.array([lo for lo, _ in ranges], dtype=np.float64)
        new_hi = np.array([hi for _, hi in ranges], dtype=np.float64)
        if np.allclose(old_lo, new_lo) and np.allclose(old_hi, new_hi):
            return centroids
        return new_lo + (centroids - old_lo) / (old_hi - old_lo) * (new_hi - new_lo)

    def _build_es_emitters(
        self,
        qd: DictConfig,
        archive,
        x0_for: Callable[[int], np.ndarray],
    ) -> list:
        """Create ``qd.emitters`` CMA-ES based emitters for cma_me / (sep_)cma_mae.

        Args:
            qd: QD configuration
            archive: Archive the emitters insert into (training archive for CMA-MAE)
            x0_for: Returns the starting point of emitter ``i``. It is called lazily,
                once per emitter and in order.
        """
        if qd.method == "cma_me":
            method_kwargs = {"ranker": "2imp"}
        else:
            method_kwargs = {"ranker": "imp", "selection_rule": "mu"}
            if qd.method == "sep_cma_mae":
                method_kwargs["es"] = "sep_cma_es"
        return [
            EvolutionStrategyEmitter(
                archive,
                x0=x0_for(i),
                sigma0=(qd.sigma0 * (2**i)) if qd.different_sigmas else qd.sigma0,
                restart_rule=qd.restart_rule,
                batch_size=qd.batch_size,
                seed=self.seed + i,
                **method_kwargs,
            )
            for i in range(qd.emitters)
        ]

    def _build_pga_emitters(
        self, qd: DictConfig, td3, make_x0: Callable[[], np.ndarray]
    ) -> list:
        """Create the PGA-ME emitter pair (policy-gradient + iso-line) on self.archive.

        ``make_x0`` is called once for each emitter, right before it is constructed.
        """
        return [
            PGAEmitter(
                self.archive,
                td3=td3,
                x0=make_x0(),
                sigma0=qd.pga_sigma0,
                batch_size=qd.pga_batch_size,
                init_iters=qd.init_iters,
                seed=self.seed + 1,
            ),
            IsoLineEmitter(
                self.archive,
                iso_sigma=qd.iso_sigma,
                line_sigma=qd.line_sigma,
                x0=make_x0(),
                batch_size=qd.isoline_batch_size,
                seed=self.seed,
            ),
        ]

    def _build_gaussian_emitter(self, qd: DictConfig, x0: np.ndarray):
        """Create the MAP-Elites Gaussian emitter bound to self.archive."""
        return GaussianEmitter(
            self.archive,
            x0=x0,
            sigma=qd.sigma,
            batch_size=qd.batch_size,
            seed=self.seed,
        )

    def _elite_x0_sampler(self) -> Callable[[int], np.ndarray]:
        """Starting points for restarted ES emitters.

        Emitter 0 starts from the best elite. The others start from an elite drawn
        uniformly (with the global NumPy RNG) from the top 20% by objective.
        """
        objectives = self.archive.data("objective")
        solutions = self.archive.data("solution")
        threshold = np.percentile(objectives, 80)
        top_indices = np.where(objectives >= threshold)[0]

        def x0_for(i: int) -> np.ndarray:
            if i == 0:
                return self.archive.best_elite["solution"].reshape(-1)
            return solutions[np.random.choice(top_indices)].reshape(-1)

        return x0_for

    def setup_qd(
        self, qd: DictConfig, agent: DictConfig, measures_dim: int, embedding_dim: int
    ) -> None:
        """Initialize QD components including archive and emitters.

        Args:
            qd: QD configuration
            agent: Agent configuration
            measures_dim: Dimensionality of measure space
            embedding_dim: Dimensionality of embedding space

        Raises:
            AssertionError: If QD method is not supported
            ValueError: If the archive type is not recognized
        """
        assert qd.method in self._QD_METHODS, f"'{qd.method}' is not a valid QD method"

        solution_dim = hydra.utils.instantiate(agent).num_params()
        if qd.init_from_pretrained:
            init_solution = self.load_pretrained_policy_from_file()
        else:
            init_solution = hydra.utils.instantiate(agent).to_numpy()
        # Kept so reset_qd can rebuild the MAP-Elites emitter around a new archive.
        self._init_solution = init_solution

        self.info_logger.info(f"Number of agent parameters: {len(init_solution)}")
        ranges = self._measure_ranges(measures_dim, initial=True)
        self.archive = self._build_archive(
            qd, solution_dim, measures_dim, embedding_dim, ranges
        )

        def random_x0() -> np.ndarray:
            return hydra.utils.instantiate(agent).to_numpy()

        def es_x0_for(i: int) -> np.ndarray:
            # First emitter starts from the (possibly pretrained) init solution.
            return init_solution if i == 0 else random_x0()

        # Create emitters, results_archive (if needed) and schedulers for each algo
        if qd.method in ("cma_mae", "sep_cma_mae"):
            self.train_archive = self._build_archive(
                qd,
                solution_dim,
                measures_dim,
                embedding_dim,
                ranges,
                learning_rate=qd.archive_lr,
                custom_centroids=(
                    self.archive.centroids if qd.archive == "cvt" else None
                ),
            )
            self.emitters = self._build_es_emitters(qd, self.train_archive, es_x0_for)
            self.scheduler = Scheduler(
                self.train_archive, self.emitters, result_archive=self.archive
            )
        elif qd.method == "cma_me":
            self.emitters = self._build_es_emitters(qd, self.archive, es_x0_for)
            self.scheduler = Scheduler(self.archive, self.emitters)
        elif qd.method == "pga_me":
            td3 = TD3(qd.td3, agent, seed=self.seed + 2323)
            self.emitters = self._build_pga_emitters(qd, td3, random_x0)
            self.scheduler = Scheduler(self.archive, self.emitters)
        elif qd.method == "map_elites":
            self.emitters = [self._build_gaussian_emitter(qd, init_solution)]
            self.scheduler = Scheduler(self.archive, self.emitters)

    def _evaluate_solutions(self, solutions) -> list:
        """Roll out each solution on the Ray evaluators (assigned round-robin).

        Returns one trajectory batch per solution, in input order.
        """
        futures = [
            self.evaluators[i % len(self.evaluators)].evaluate_policy.remote(
                solution, self.n_evals, self.device
            )
            for i, solution in enumerate(solutions)
        ]
        return ray.get(futures)

    def step_qd(self) -> Dict[str, float]:
        """Execute one ask / evaluate / tell step of QD optimization.

        Returns:
            Wall-clock durations (seconds) of the step's phases, keyed by
            ``time/ask``, ``time/eval``, ``time/traj_process`` and, for PGA-ME,
            ``time/td3_training``.
        """
        timings = {}

        start = time.perf_counter()
        solution_batch = self.scheduler.ask()
        timings["time/ask"] = time.perf_counter() - start

        start = time.perf_counter()
        trajectory_batches = self._evaluate_solutions(solution_batch)
        timings["time/eval"] = time.perf_counter() - start

        start = time.perf_counter()
        use_pga = self.qd_cfg.method == "pga_me"
        objective_batch, measure_batch, embedding_batch = [], [], []
        for trajs in trajectory_batches:
            emb = self.embedding_map.embed_trajectories(trajs)
            meas = self.measure_map(emb)
            # Objective = mean undiscounted return over the solution's rollouts.
            objective_batch.append(np.mean([t.rewards.sum() for t in trajs]))
            measure_batch.append(meas.cpu().numpy())
            embedding_batch.append(emb.cpu().numpy())
            if use_pga:
                # Collected experience will be added to td3's buffer to be used for training
                self.emitters[0].add_experience(trajs)

        objective_batch = np.array(objective_batch)
        measure_batch = np.stack(measure_batch)
        embedding_batch = np.stack(embedding_batch)
        timings["time/traj_process"] = time.perf_counter() - start

        self.scheduler.tell(objective_batch, measure_batch, embedding=embedding_batch)
        if use_pga:
            # Train critics and actor of td3
            start = time.perf_counter()
            self.emitters[0].train_td3()
            timings["time/td3_training"] = time.perf_counter() - start
        return timings

    def reset_qd(self, archive_data: Dict[str, np.ndarray]) -> None:
        """Rebuild archives, emitters and scheduler after the measure map changed.

        Measures of the kept solutions are recomputed from their embeddings with the
        current ``measure_map`` and re-inserted into a fresh result archive. For
        CMA-MAE variants the training archive restarts empty. Every method gets a new
        scheduler bound to the new archive(s), so no component keeps pointing at the
        discarded archive.

        Args:
            archive_data: Dict containing `solution`, `objective`, and `embedding` keys
                corresponding to solutions from previous iterations that we want to keep
                in the result archive.
        """
        qd = self.qd_cfg
        assert qd.method in self._QD_METHODS, f"QD method {qd.method} is not supported"

        ranges = self._measure_ranges(self.measure_dim, initial=False)
        # CVT cells must follow the (possibly widened) ranges; None for grid archives.
        centroids = self._rescaled_centroids(ranges)

        # Reset train archive
        if qd.method in ("cma_mae", "sep_cma_mae"):
            self.train_archive = self._build_archive(
                qd,
                self.train_archive.solution_dim,
                self.measure_dim,
                self.embedding_dim,
                ranges,
                learning_rate=qd.archive_lr,
                custom_centroids=centroids,
            )

        # Re-insert solutions into new [result] archive with newly computed measures
        new_archive = self._build_archive(
            qd,
            self.archive.solution_dim,
            self.measure_dim,
            self.embedding_dim,
            ranges,
            custom_centroids=centroids,
        )
        for sol, obj, emb in zip(
            archive_data["solution"],
            archive_data["objective"],
            archive_data["embedding"],
        ):
            new_meas = (
                self.measure_map(torch.from_numpy(emb).to(self.device)).cpu().numpy()
            )
            new_archive.add_single(sol, obj, new_meas, embedding=emb)
        self.archive = new_archive

        # Reset emitters and scheduler
        if qd.method in ("cma_mae", "sep_cma_mae"):
            self.emitters = self._build_es_emitters(
                qd, self.train_archive, self._elite_x0_sampler()
            )
            self.scheduler = Scheduler(
                self.train_archive, self.emitters, result_archive=self.archive
            )
        elif qd.method == "cma_me":
            self.emitters = self._build_es_emitters(
                qd, self.archive, self._elite_x0_sampler()
            )
            self.scheduler = Scheduler(self.archive, self.emitters)
        elif qd.method == "pga_me":
            # No need to reset TD3 critics and actor
            td3 = self.emitters[0].td3
            self.emitters = self._build_pga_emitters(
                qd, td3, lambda: self.archive.best_elite["solution"].reshape(-1)
            )
            self.scheduler = Scheduler(self.archive, self.emitters)
        elif qd.method == "map_elites":
            # The Gaussian emitter samples parents from the archive it was built
            # with, so it must be rebuilt around the new archive.
            self.emitters = [self._build_gaussian_emitter(qd, self._init_solution)]
            self.scheduler = Scheduler(self.archive, self.emitters)

    def reevaluate_archive(self) -> Dict[str, np.ndarray]:
        """Re-evaluate the solutions in self.archive; return new embeddings and objectives.

        This is needed when, for instance, the embedding map has changed as a result
        of updating the normalizer.

        Returns:
            Dict with `solution`, `embedding` and `objective` arrays (one row per elite).
        """
        solution_batch = [elite["solution"] for elite in self.archive]
        trajectory_batches = self._evaluate_solutions(solution_batch)

        objective_batch, embedding_batch = [], []
        for trajs in trajectory_batches:
            emb = self.embedding_map.embed_trajectories(trajs)
            objective_batch.append(np.mean([t.rewards.sum() for t in trajs]))
            embedding_batch.append(emb.cpu().numpy())

        return {
            "solution": np.stack(solution_batch),
            "embedding": np.stack(embedding_batch),
            "objective": np.array(objective_batch),
        }

    def refine_measures(self, archive_data: Dict[str, np.ndarray]) -> Dict[str, float]:
        """Refine measure mapping using collected embeddings.

        - Only considers the `self.refine_frac` best solutions (all of them if that
          would leave no more than `self.measure_dim` solutions).

        - Explicitly takes the data as an argument (instead of using self.archive)
          because with normalization, the embeddings will change, and we would need to
          recompute them before refining the measures.

        Args:
            archive_data: Dict containing `objective` and `embedding` keys corresponding
                to solutions from the archive that we want to use for refinement.

        Returns:
            Metrics for logging, as returned by the measure map's ``adapt``.
        """
        all_embeddings = torch.from_numpy(archive_data["embedding"])
        all_objs = torch.from_numpy(archive_data["objective"])
        num_tops = int(len(all_objs) * self.refine_frac)
        # If too few solutions, consider them all
        if num_tops <= self.measure_dim:
            num_tops = len(all_objs)
        _, indices = all_objs.topk(num_tops)
        top_embeddings = all_embeddings[indices]
        # Weighted PCA variants also use the objectives as weights.
        if isinstance(self.measure_map, (WeightedPCAMap, CalibratedWeightedPCAMap)):
            return self.measure_map.adapt(top_embeddings, all_objs[indices])
        return self.measure_map.adapt(top_embeddings)

    def train(self) -> None:
        """Execute the main training loop; Ray is shut down when it ends or fails."""
        itr = 0
        total_evals = 0  # Total number of trajectories throughout training
        try:
            while itr < self.total_iterations:
                timings = self.step_qd()
                itr += 1
                total_evals += (
                    sum([e.batch_size for e in self.emitters]) * self.n_evals
                )
                self.logger.update(
                    {
                        "archive_size": len(self.archive),
                        "coverage": self.archive.stats.coverage,
                        "mean_obj": self.archive.stats.obj_mean,
                        "max_obj": self.archive.stats.obj_max,
                        "qd_score": self.archive.stats.qd_score,
                        "total_evals": total_evals,
                    }
                    | timings
                )
                if itr % 10 == 9:
                    plot_time = time.perf_counter()
                    self.save_archive_img(itr)
                    plot_time = time.perf_counter() - plot_time
                else:
                    plot_time = 0
                if itr in self.update_interval:
                    if self.embedding_map.normalize:
                        # Update state normalizer
                        self.embedding_map.normalizer.apply_update()
                        # Re-evaluate solutions in the archive
                        # because updating the normalizer changes the embeddings
                        archive_data = self.reevaluate_archive()
                    else:
                        all_data = self.archive.data()
                        archive_data = {
                            "solution": all_data["solution"],
                            "embedding": all_data["embedding"],
                            "objective": all_data["objective"],
                        }
                    refine_time = time.perf_counter()
                    metric_info = self.refine_measures(archive_data)
                    refine_time = time.perf_counter() - refine_time
                    self.logger.update(
                        metric_info
                        | {"time/refine": refine_time, "time/plot": plot_time}
                    )
                    self.reset_qd(archive_data)

                self.logger.log(itr)
        finally:
            ray.shutdown()

    def save_archive_img(self, itr: int) -> None:
        """Save a plot of the result archive to ``exp_dir/archive_images/{itr}.png``.

        2-D measure spaces get a Grid or CVT heatmap (matching the archive type);
        higher dimensions use a UMAP projection.
        """
        fig = plt.figure(figsize=(8, 6))
        try:
            color_range = {
                "vmin": self.env_cfg.min_score,
                "vmax": self.env_cfg.max_score,
            }
            if self.measure_dim != 2:
                umap_archive_plot(self.archive, **color_range)
            elif isinstance(self.archive, CVTArchive):
                # grid_archive_heatmap requires a GridArchive.
                cvt_archive_heatmap(self.archive, **color_range)
            else:
                grid_archive_heatmap(self.archive, **color_range)

            plt.title(f"Result Archive at t={itr}")
            plt.xlabel("m1")
            plt.ylabel("m2")

            path = Path(self.exp_dir) / "archive_images"
            path.mkdir(parents=True, exist_ok=True)
            plt.savefig(path / f"{itr}.png", dpi=300, bbox_inches="tight")
        finally:
            # Close the figure even if plotting fails, to free memory.
            plt.close(fig)

    def save(self, name) -> None:
        """Pickle archive, embedding/measure maps and agent config to
        ``exp_dir/checkpoints/{name}.pkl`` (creating the directory if needed)."""
        path = Path(self.exp_dir) / f"checkpoints/{name}.pkl"
        path.parent.mkdir(parents=True, exist_ok=True)
        try:
            with open(path, "wb") as f:
                pickle.dump(
                    {
                        "archive": self.archive,
                        "embedding_map": self.embedding_map.to("cpu"),
                        "measure_map": self.measure_map,
                        "agent_cfg": self.agent_cfg,
                    },
                    f,
                )
        finally:
            # `.to("cpu")` may move the live embedding map in place (nn.Module
            # semantics); move it back so training can continue on self.device.
            self.embedding_map.to(self.device)

    def _map_sac_actor_to_agent(
        self, actor_state_dict: Dict[str, torch.Tensor], agent
    ) -> Optional[Dict[str, torch.Tensor]]:
        """Rename SAC actor parameters to the agent's ``layers.{i}`` naming.

        ``fc1..fcK`` map to ``layers.0..layers.{K-1}`` and ``fc_mean`` maps to the
        last agent layer; ``fc_logstd`` is dropped. This assumes every entry of
        ``agent.layers`` owns a weight and a bias.

        Returns:
            The translated state dict, or None (after logging the reason) when the
            layer count, a parameter name, or any tensor shape does not match.
        """
        weight_suffix = ".weight"
        hidden_ids = sorted(
            int(key[2 : -len(weight_suffix)])
            for key in actor_state_dict
            if key.startswith("fc")
            and key.endswith(weight_suffix)
            and key[2 : -len(weight_suffix)].isdigit()
        )
        num_agent_layers = len(agent.layers)
        if "fc_mean.weight" not in actor_state_dict:
            self.info_logger.warning("Pretrained actor has no 'fc_mean' output layer")
            return None
        if hidden_ids != list(range(1, num_agent_layers)):
            self.info_logger.warning(
                f"Layer counts don't match: Actor: {len(hidden_ids)} hidden layers "
                f"+ fc_mean, Agent: {num_agent_layers} layers"
            )
            return None

        layer_map = {f"fc{i}": f"layers.{i - 1}" for i in hidden_ids}
        # Use only the mean head of the stochastic policy as the output layer.
        layer_map["fc_mean"] = f"layers.{num_agent_layers - 1}"

        agent_state = agent.state_dict()
        new_state_dict = {}
        for src, dst in layer_map.items():
            for param in ("weight", "bias"):
                src_key, dst_key = f"{src}.{param}", f"{dst}.{param}"
                if src_key not in actor_state_dict or dst_key not in agent_state:
                    self.info_logger.warning(
                        f"Missing parameter: Actor '{src_key}' -> Agent '{dst_key}'"
                    )
                    return None
                src_shape = tuple(actor_state_dict[src_key].shape)
                dst_shape = tuple(agent_state[dst_key].shape)
                if src_shape != dst_shape:
                    self.info_logger.warning(
                        f"Shape mismatch for {dst_key}: Actor ({src_key}): "
                        f"{src_shape}, Agent: {dst_shape}"
                    )
                    return None
                new_state_dict[dst_key] = actor_state_dict[src_key]
        return new_state_dict

    def load_pretrained_policy_from_file(self) -> np.ndarray:
        """Load ``./pretrained_policies/{env_id}.pt`` and return it as agent parameters.

        The file must hold a SAC Actor state dict (see baseline_sac.py) such that if
        we throw away its logstd output layer and use the mean output as the final
        layer, we get the exact same architecture as the Agent in the config. Layer
        count and every tensor shape are checked before anything is loaded, so an
        incompatible file never leaves the agent partially overwritten.

        Loading is best-effort: on any error or incompatibility, a warning is logged
        and the randomly initialized agent is used instead.

        Returns:
            Flat parameter vector of the (pretrained or random) agent.
        """
        path = f"./pretrained_policies/{self.env_cfg.env_id}.pt"
        # Create agent from config
        agent = hydra.utils.instantiate(self.agent_cfg)

        try:
            # map_location lets checkpoints saved on GPU load on CPU-only machines.
            actor_state_dict = torch.load(path, map_location="cpu")
            new_state_dict = self._map_sac_actor_to_agent(actor_state_dict, agent)
            if new_state_dict is None:
                self.info_logger.warning(
                    "Architectures are not compatible. Using randomly initialized agent."
                )
            else:
                agent.load_state_dict(new_state_dict, strict=False)
                self.info_logger.info(
                    f"Successfully loaded pretrained policy from {path}"
                )
        except Exception as e:
            # Deliberate best-effort fallback (missing file, unexpected format, ...).
            self.info_logger.warning(
                f"Error loading pretrained policy: {e}. Using randomly initialized agent."
            )

        return agent.to_numpy()
