import logging
import os
import pickle
import time
from pathlib import Path
from typing import List

import hydra
import matplotlib.pyplot as plt
import numpy as np
import ray
from omegaconf import DictConfig, OmegaConf
from ribs.archives import CVTArchive, GridArchive
from ribs.emitters import EvolutionStrategyEmitter, GaussianEmitter, IsoLineEmitter
from ribs.schedulers import Scheduler
from ribs.visualize import grid_archive_heatmap

from src.logging_utils import MetricLogger
from src.qd.pga_emitter import PGAEmitter
from src.qd.td3 import TD3
from src.qd.wrappers import (
    AntBehavioralWrapper,
    BipedalBehavioralWrapper,
    HalfCheetahBehavioralWrapper,
    HopperBehavioralWrapper,
    SwimmerBehavioralWrapper,
    WalkerBehavioralWrapper,
    BIPEDAL_MEASURE_NAMES,
    ANT_MEASURE_NAMES,
    HALFCHEETAH_MEASURE_NAMES,
    HOPPER_MEASURE_NAMES,
    SWIMMER_MEASURE_NAMES,
    WALKER_MEASURE_NAMES,
)
from src.utils import PolicyEvaluator, umap_archive_plot


# Full list of measure names for each supported environment, keyed by env id.
# RegularQD falls back to the entry for `env_cfg.env_id` when it is constructed
# without any measure names.
ALL_MEASURES = {
    "BipedalWalker-v3": BIPEDAL_MEASURE_NAMES,
    "Ant-v5": ANT_MEASURE_NAMES,
    "HalfCheetah-v5": HALFCHEETAH_MEASURE_NAMES,
    "Hopper-v5": HOPPER_MEASURE_NAMES,
    "Swimmer-v5": SWIMMER_MEASURE_NAMES,
    "Walker2d-v5": WALKER_MEASURE_NAMES,
}


class RegularQD:
    def __init__(
        self,
        env_cfg: DictConfig,
        qd: DictConfig,
        agent: DictConfig,
        measure_names: List[str],
        total_iterations: int,
        n_evals: int,
        exp_dir: str,
        metric_logger: MetricLogger,
        info_logger: logging.Logger,
        device,
        seed,
        **kwargs,
    ) -> None:
        """Initialize the QD algorithm.

        Builds the archive, emitters and scheduler via ``setup_qd``, then
        starts Ray (if it is not running yet) and creates one remote
        ``PolicyEvaluator`` per worker.

        Args:
            env_cfg: Configuration for environment
            qd: Configuration for QD optimization
            agent: Agent configuration
            measure_names: List of names of measures that will be used. Must be a subset of the measures provided by the Wrapper.
                When empty, every measure listed for ``env_cfg.env_id`` in ``ALL_MEASURES`` is used.
            total_iterations: Number of QD iterations in total.
            n_evals: Number of evaluations per solution
            exp_dir: Directory for experiment results (a string or any path-like
                object; stored as a ``pathlib.Path``)
            metric_logger: MetricLogger instance
            info_logger: Regular Logger instance
            device: cuda or cpu
            seed: experiment seed
            **kwargs: Accepted for call-site convenience and ignored.
        """
        self.env_cfg = env_cfg
        self.device = device
        self.seed = seed
        self.measure_names = (
            measure_names if measure_names else ALL_MEASURES[env_cfg.env_id]
        )
        self.measure_dim = len(self.measure_names)
        # setup_qd reads self.seed and self.measure_dim, so it has to run
        # after the assignments above.
        self.setup_qd(qd, agent)

        self.agent_cfg = agent
        self.qd_cfg = qd
        self.total_iterations = total_iterations
        self.n_evals = n_evals
        # Normalise to Path so that a plain string and a path-like object
        # behave the same wherever output paths are derived from exp_dir.
        self.exp_dir = Path(exp_dir)
        self.logger = metric_logger
        self.info_logger = info_logger

        # Setup ray multiprocessing
        if not ray.is_initialized():
            ray.init()
        cpus_per_worker = 2  # Adjust based on machine
        num_workers = int(max(1, (ray.cluster_resources()["CPU"]) // cpus_per_worker))
        info_logger.info(
            f"Using {num_workers} workers each with {cpus_per_worker} CPUs"
        )

        wrapper_cls = self.get_wrapper(self.env_cfg.env_id)
        # The environment kwargs are the same for every worker, so convert the
        # config node to a plain container once instead of once per worker.
        env_kwargs = (
            OmegaConf.to_container(self.env_cfg.env_kwargs, resolve=True)
            if "env_kwargs" in self.env_cfg
            else {}
        )
        self.evaluators = [
            PolicyEvaluator.options(num_cpus=cpus_per_worker).remote(
                env_id=self.env_cfg.env_id,
                env_kwargs=env_kwargs,
                num_envs=self.n_evals,
                agent_cfg=self.agent_cfg,
                wrappers=[wrapper_cls],
                measure_names=self.measure_names,
            )
            for _ in range(num_workers)
        ]

    def setup_qd(self, qd: DictConfig, agent: DictConfig) -> None:
        """Initialize QD components including archive and emitters.

        Args:
            qd: QD configuration
            agent: Agent configuration

        Raises:
            AssertionError: If QD method is not supported
        """
        assert qd.method in [
            "cma_me",
            "sep_cma_mae",
            "cma_mae",
            "map_elites",
            "pga_me",
        ], f"'{qd.method}' is not a valid QD method"

        solution_dim = hydra.utils.instantiate(agent).num_params()
        dims = tuple((qd.archive_dim for _ in range(self.measure_dim)))
        # Handcrafted measures are expected to be normalized
        ranges = list([(0.0, 1.0) for _ in range(self.measure_dim)])

        if qd.archive == "grid":
            self.archive = GridArchive(
                solution_dim=solution_dim,
                dims=dims,
                ranges=ranges,
                qd_score_offset=qd.qd_score_offset,
                seed=self.seed,
            )
        elif qd.archive == "cvt":
            self.archive = CVTArchive(
                solution_dim=solution_dim,
                cells=qd.archive_dim,
                ranges=ranges,
                qd_score_offset=qd.qd_score_offset,
                seed=self.seed,
            )
        else:
            raise ValueError(f"Archive type {qd.archive} not recognized.")

        # Create emitters, results_archive (if needed) and schedulers for each algo
        if qd.method in ["cma_mae", "sep_cma_mae"]:
            if qd.archive == "grid":
                self.train_archive = GridArchive(
                    solution_dim=solution_dim,
                    dims=dims,
                    ranges=ranges,
                    qd_score_offset=qd.qd_score_offset,
                    threshold_min=qd.qd_score_offset,
                    learning_rate=qd.archive_lr,
                    seed=self.seed,
                )
            elif qd.archive == "cvt":
                self.train_archive = CVTArchive(
                    solution_dim=solution_dim,
                    cells=qd.archive_dim,
                    ranges=ranges,
                    qd_score_offset=qd.qd_score_offset,
                    threshold_min=qd.qd_score_offset,
                    learning_rate=qd.archive_lr,
                    custom_centroids=self.archive.centroids,
                    seed=self.seed,
                )
            if qd.method == "cma_mae":
                self.emitters = [
                    EvolutionStrategyEmitter(
                        self.train_archive,
                        x0=hydra.utils.instantiate(agent).to_numpy(),
                        sigma0=qd.sigma0,
                        ranker="imp",
                        selection_rule="mu",
                        restart_rule=qd.restart_rule,
                        batch_size=qd.batch_size,
                        seed=self.seed + i,
                    )
                    for i in range(qd.emitters)
                ]
            elif qd.method == "sep_cma_mae":
                self.emitters = [
                    EvolutionStrategyEmitter(
                        self.train_archive,
                        x0=hydra.utils.instantiate(agent).to_numpy(),
                        sigma0=qd.sigma0,
                        ranker="imp",
                        es="sep_cma_es",
                        selection_rule="mu",
                        restart_rule=qd.restart_rule,
                        batch_size=qd.batch_size,
                        seed=self.seed + i,
                    )
                    for i in range(qd.emitters)
                ]
            self.scheduler = Scheduler(
                self.train_archive, self.emitters, result_archive=self.archive
            )

        elif qd.method == "cma_me":
            self.emitters = [
                EvolutionStrategyEmitter(
                    self.archive,
                    x0=hydra.utils.instantiate(agent).to_numpy(),
                    sigma0=qd.sigma0,
                    ranker="2imp",
                    restart_rule=qd.restart_rule,
                    batch_size=qd.batch_size,
                    seed=self.seed + i,
                )
                for i in range(qd.emitters)
            ]
            self.scheduler = Scheduler(self.archive, self.emitters)

        elif qd.method == "pga_me":
            td3 = TD3(qd.td3, agent, seed=self.seed + 2323)
            self.emitters = [
                PGAEmitter(
                    self.archive,
                    td3=td3,
                    x0=hydra.utils.instantiate(agent).to_numpy(),
                    sigma0=qd.pga_sigma0,
                    batch_size=qd.pga_batch_size,
                    init_iters=qd.init_iters,
                    seed=self.seed + 1,
                ),
                IsoLineEmitter(
                    self.archive,
                    iso_sigma=qd.iso_sigma,
                    line_sigma=qd.line_sigma,
                    x0=hydra.utils.instantiate(agent).to_numpy(),
                    batch_size=qd.isoline_batch_size,
                    seed=self.seed,
                ),
            ]
            self.scheduler = Scheduler(self.archive, self.emitters)

        elif qd.method == "map_elites":
            self.emitters = [
                GaussianEmitter(
                    self.archive,
                    x0=hydra.utils.instantiate(agent).to_numpy(),
                    sigma=qd.sigma,
                    batch_size=qd.batch_size,
                    seed=self.seed,
                )
            ]
            self.scheduler = Scheduler(self.archive, self.emitters)

    def step_qd(self) -> dict[str, float]:
        """Execute one step of QD optimization.

        Asks the scheduler for a batch of solutions, evaluates them on the
        remote evaluators, reduces each solution's trajectories to one
        objective and one measure vector, and tells the scheduler the results.
        For ``pga_me`` the trajectories are also handed to the first emitter
        and its TD3 training is run afterwards.

        Returns:
            Wall-clock durations in seconds keyed by ``"time/ask"``,
            ``"time/eval"``, ``"time/traj_process"`` and, for ``pga_me`` only,
            ``"time/td3_training"``.
        """
        timings = {}

        started = time.perf_counter()
        solution_batch = self.scheduler.ask()
        timings["time/ask"] = time.perf_counter() - started

        started = time.perf_counter()
        n_workers = len(self.evaluators)
        # Spread the solutions over the evaluators round-robin.
        futures = [
            self.evaluators[i % n_workers].evaluate_policy.remote(
                solution, self.n_evals, self.device
            )
            for i, solution in enumerate(solution_batch)
        ]
        # ray.get blocks until every evaluation has finished and returns the
        # results in the order the futures were submitted, so entry i belongs
        # to solution i.
        trajectory_batches = ray.get(futures)
        timings["time/eval"] = time.perf_counter() - started

        uses_pga = self.qd_cfg.method == "pga_me"

        started = time.perf_counter()
        objective_batch, measure_batch = [], []
        for trajs in trajectory_batches:
            # Average the measures and the summed rewards over the
            # trajectories collected for this solution.
            measure_batch.append(np.array([t.measures for t in trajs]).mean(0))
            objective_batch.append(np.mean([t.rewards.sum() for t in trajs]))

            if uses_pga:
                # Collected experience will be added to td3's buffer to be used for training
                self.emitters[0].add_experience(trajs)

        objective_batch = np.array(objective_batch)
        measure_batch = np.stack(measure_batch)
        timings["time/traj_process"] = time.perf_counter() - started

        self.scheduler.tell(objective_batch, measure_batch)
        if uses_pga:
            # Train critics and actor of td3
            started = time.perf_counter()
            self.emitters[0].train_td3()
            timings["time/td3_training"] = time.perf_counter() - started
        return timings

    def train(self) -> None:
        """Execute the main training loop."""
        itr = 0
        total_evals = 0  # Total number of trajectories throughout training
        while itr < self.total_iterations:
            timings = self.step_qd()
            itr += 1
            total_evals += sum([e.batch_size for e in self.emitters]) * self.n_evals
            self.logger.update(
                {
                    "archive_size": len(self.archive),
                    "coverage": self.archive.stats.coverage,
                    "mean_obj": self.archive.stats.obj_mean,
                    "max_obj": self.archive.stats.obj_max,
                    "qd_score": self.archive.stats.qd_score,
                    "total_evals": total_evals,
                }
                | timings
            )
            if itr % 10 == 9:
                self.save_archive_img(itr)
            self.logger.log(itr)

    def save_archive_img(self, itr: int) -> None:
        """Plot the result archive and write it to ``<exp_dir>/archive_images/<itr>.png``.

        With exactly two measures the archive is drawn as a heatmap (the CVT
        variant for a ``CVTArchive``, the grid variant otherwise); for any
        other number of measures ``umap_archive_plot`` is used. The colour
        scale is bounded by ``env_cfg.min_score`` and ``env_cfg.max_score``.

        Args:
            itr: Iteration number, used in the plot title and the file name.
        """
        plt.figure(figsize=(8, 6))
        try:
            colour_range = {
                "vmin": self.env_cfg.min_score,
                "vmax": self.env_cfg.max_score,
            }
            if self.measure_dim != 2:
                umap_archive_plot(self.archive, **colour_range)
            elif isinstance(self.archive, CVTArchive):
                # grid_archive_heatmap only handles GridArchive; a CVT archive
                # needs its own heatmap routine. Imported here because the
                # module-level import only brings in the grid variant.
                from ribs.visualize import cvt_archive_heatmap

                cvt_archive_heatmap(self.archive, **colour_range)
            else:
                grid_archive_heatmap(self.archive, **colour_range)

            plt.title(f"Result Archive at t={itr}")
            plt.xlabel("m1")
            plt.ylabel("m2")

            # os.path.join accepts both str and path-like exp_dir values.
            out_dir = os.path.join(self.exp_dir, "archive_images")
            os.makedirs(out_dir, exist_ok=True)
            plt.savefig(
                os.path.join(out_dir, f"{itr}.png"), dpi=300, bbox_inches="tight"
            )
        finally:
            # Close the figure even if plotting or saving failed, to free memory.
            plt.close()

    def save(self, name):
        path = self.exp_dir / f"checkpoints/{name}.pkl"
        with open(path, "wb") as f:
            pickle.dump(
                {
                    "archive": self.archive,
                    "agent_cfg": self.agent_cfg,
                },
                f,
            )

    def get_wrapper(self, env_id):
        env_id = env_id.lower()
        if "ant" in env_id:
            return AntBehavioralWrapper
        elif "bipedal" in env_id:
            return BipedalBehavioralWrapper
        elif "cheetah" in env_id:
            return HalfCheetahBehavioralWrapper
        elif "hopper" in env_id:
            return HopperBehavioralWrapper
        elif "swimmer" in env_id:
            return SwimmerBehavioralWrapper
        elif "walker2d" in env_id:
            return WalkerBehavioralWrapper
        else:
            raise ValueError(f"Behavioral wrapper for {env_id} is not implemented")
