import logging
import os
import pickle
import time
from pathlib import Path
from typing import Callable, Dict, List, Tuple, Union

import hydra
import matplotlib.pyplot as plt
import numpy as np
import ray
import torch
from omegaconf import DictConfig, OmegaConf
from ribs.archives import GridArchive
from ribs.emitters import EvolutionStrategyEmitter
from ribs.schedulers import Scheduler
from ribs.visualize import grid_archive_heatmap

from src.logging_utils import MetricLogger
from src.utils import PolicyEvaluator, umap_archive_plot


class AURORA:
    """Implementation of AURORA algorithm from Cully (2019).

    AURORA alternates between:
    1. QD optimization using auto-encoded behavioral descriptors (measures)
    2. Updating the auto-encoder using data from the archive, then re-encoding
       every elite and rebuilding the archive with freshly computed ranges

    Supported QD back-ends (``qd.method``): ``"cma_me"``, ``"cma_mae"`` and
    ``"sep_cma_mae"``. The MAE variants search over ``self.train_archive``
    (which has an archive learning rate) while ``self.archive`` holds the
    result elites.
    """

    # QD methods that search over a separate thresholded training archive.
    _MAE_METHODS = ("cma_mae", "sep_cma_mae")

    def __init__(
        self,
        env_cfg: DictConfig,
        qd: DictConfig,
        agent: DictConfig,
        encoder: DictConfig,
        encoder_training: DictConfig,
        use_lstm: bool,
        measures_dim: int,
        total_iterations: int,
        update_interval: Union[int, List[int]],
        measure_margin: float,
        traj_sampling_freq: float,
        n_evals: int,
        exp_dir: Union[str, os.PathLike],
        metric_logger: MetricLogger,
        info_logger: logging.Logger,
        device: str,
        seed: int,
        **kwargs,
    ) -> None:
        """Initialize AURORA.

        Args:
            env_cfg: Environment configuration (reads ``env_id``, optional
                ``env_kwargs``, ``state_dim``, ``min_score``, ``max_score``)
            qd: QD optimizer configuration
            agent: Agent configuration
            encoder: Auto-encoder configuration
            encoder_training: Parameters related to AE's update
            use_lstm: If true, autoencoder will be an LSTM reconstructing whole
                trajectories; otherwise it encodes final states
            measures_dim: Dimensionality of latent space (measures)
            total_iterations: Total number of QD iterations
            update_interval: Iterations at which the encoder is updated. An int
                ``k`` means every ``k`` iterations (``k, 2k, ...`` below
                ``total_iterations``); a list gives the explicit schedule.
            measure_margin: Fraction of each latent dimension's observed span
                added on both sides of the archive ranges (room to explore)
            traj_sampling_freq: Slice step used to sub-sample trajectory states
                (every k-th state) for the LSTM encoder
            n_evals: Number of evaluations (episodes) per solution
            exp_dir: Directory for experiment results
            metric_logger: Logger for metrics
            info_logger: Logger for info messages
            device: Device to run the encoder on
            seed: Random seed
            **kwargs: Ignored; allows passing extra config keys
        """
        self.env_cfg = env_cfg
        self.device = device
        self.seed = seed
        self.measure_margin = measure_margin
        self.traj_sampling_freq = traj_sampling_freq
        self.use_lstm = use_lstm
        # Seeded RNG so the emitter restarts chosen in reset_qd are reproducible.
        self._rng = np.random.default_rng(seed)

        # Initialize components (setup_qd relies on env_cfg / seed / use_lstm above)
        self.encoder = hydra.utils.instantiate(encoder, device=device)
        self.setup_qd(qd, agent, measures_dim, n_evals)

        # Store configurations
        self.agent_cfg = agent
        self.qd_cfg = qd
        self.encoder_cfg = encoder
        self.encoder_training = encoder_training

        # Training parameters
        self.total_iterations = total_iterations
        self.update_interval = self._resolve_update_schedule(
            update_interval, total_iterations
        )
        self.latent_dim = measures_dim
        self.n_evals = n_evals

        # Logging
        self.exp_dir = exp_dir
        self.logger = metric_logger
        self.info_logger = info_logger

        # Setup parallel evaluation
        if not ray.is_initialized():
            ray.init()
        cpus_per_worker = 2
        total_cpus = ray.cluster_resources().get("CPU", cpus_per_worker)
        num_workers = int(max(1, total_cpus // cpus_per_worker))
        info_logger.info(
            f"Using {num_workers} workers each with {cpus_per_worker} CPUs"
        )

        env_kwargs = (
            OmegaConf.to_container(self.env_cfg.env_kwargs, resolve=True)
            if "env_kwargs" in self.env_cfg
            else {}
        )
        self.evaluators = [
            PolicyEvaluator.remote(
                env_id=self.env_cfg.env_id,
                env_kwargs=env_kwargs,
                num_envs=n_evals,
                agent_cfg=self.agent_cfg,
            )
            for _ in range(num_workers)
        ]

    @staticmethod
    def _resolve_update_schedule(
        update_interval: Union[int, List[int]], total_iterations: int
    ) -> List[int]:
        """Expand ``update_interval`` into the explicit list of update iterations.

        Args:
            update_interval: An int period, or a list (plain or OmegaConf
                ListConfig) of iteration numbers
            total_iterations: Upper bound (exclusive) for an int period

        Returns:
            List of iteration numbers at which the encoder is updated
        """
        # Hydra hands scalars over as plain Python values, which
        # OmegaConf.to_container rejects; only convert actual OmegaConf nodes.
        if OmegaConf.is_config(update_interval):
            update_interval = OmegaConf.to_container(update_interval, resolve=True)
        if isinstance(update_interval, (list, tuple)):
            return list(update_interval)
        return list(range(update_interval, total_iterations, update_interval))

    def _extra_fields(self, n_evals: int) -> Dict[str, tuple]:
        """Archive extra-field spec for the per-elite data the encoder consumes.

        Args:
            n_evals: Number of evaluations stored per solution

        Returns:
            pyribs ``extra_fields`` mapping (name -> (shape, dtype))
        """
        if self.use_lstm:
            # n_evals trajectories, each stored as an array of states
            return {"trajs": ((n_evals,), object)}
        # the last states of n_evals trajectories
        return {"states": ((n_evals, self.env_cfg.state_dim), np.float32)}

    def _new_archive(
        self,
        qd: DictConfig,
        solution_dim: int,
        dims,
        ranges: List[Tuple[float, float]],
        extra_fields: Dict[str, tuple],
        train: bool = False,
    ) -> GridArchive:
        """Build a GridArchive with the shared settings.

        Args:
            qd: QD configuration
            solution_dim: Number of policy parameters
            dims: Number of cells along each measure dimension
            ranges: (min, max) bounds for each measure dimension
            extra_fields: pyribs extra-field spec
            train: If true, add the CMA-MAE ``threshold_min`` / ``learning_rate``

        Returns:
            The new (empty) archive
        """
        mae_kwargs = (
            {"threshold_min": qd.qd_score_offset, "learning_rate": qd.archive_lr}
            if train
            else {}
        )
        return GridArchive(
            solution_dim=solution_dim,
            dims=dims,
            ranges=ranges,
            qd_score_offset=qd.qd_score_offset,
            extra_fields=extra_fields,
            seed=self.seed,
            **mae_kwargs,
        )

    def _setup_emitters(
        self, qd: DictConfig, x0_for: Callable[[int], np.ndarray]
    ) -> None:
        """Create ``self.emitters`` and ``self.scheduler`` for ``qd.method``.

        Must be called after ``self.archive`` (and, for the MAE methods,
        ``self.train_archive``) has been built.

        Args:
            qd: QD configuration
            x0_for: Maps an emitter index to that emitter's initial solution
        """
        use_mae = qd.method in self._MAE_METHODS
        if use_mae:
            search_archive = self.train_archive
            es_kwargs = {"ranker": "imp", "selection_rule": "mu"}
            if qd.method == "sep_cma_mae":
                es_kwargs["es"] = "sep_cma_es"
        else:
            search_archive = self.archive
            es_kwargs = {"ranker": "2imp"}

        self.emitters = [
            EvolutionStrategyEmitter(
                search_archive,
                x0=x0_for(i),
                # Optionally spread exploration scales: sigma0 * 2**i per emitter
                sigma0=(qd.sigma0 * (2**i)) if qd.different_sigmas else qd.sigma0,
                restart_rule=qd.restart_rule,
                batch_size=qd.batch_size,
                seed=self.seed + i,
                **es_kwargs,
            )
            for i in range(qd.emitters)
        ]

        if use_mae:
            self.scheduler = Scheduler(
                self.train_archive, self.emitters, result_archive=self.archive
            )
        else:
            self.scheduler = Scheduler(self.archive, self.emitters)

    def setup_qd(
        self, qd: DictConfig, agent: DictConfig, latent_dim: int, n_evals: int
    ) -> None:
        """Initialize QD components including archive and emitters.

        Args:
            qd: QD configuration
            agent: Agent configuration
            latent_dim: Dimensionality of latent space (measures)
            n_evals: Number of states/trajectories to store per solution

        Raises:
            ValueError: If ``qd.method`` is not a supported QD method
        """
        if qd.method not in ("cma_me", *self._MAE_METHODS):
            raise ValueError(f"'{qd.method}' is not a valid QD method for AURORA")
        solution_dim = hydra.utils.instantiate(agent).num_params()
        dims = tuple(qd.archive_dim for _ in range(latent_dim))

        # Placeholder ranges - replaced in reset_qd once encodings exist
        ranges = [(-1.0, 1.0) for _ in range(latent_dim)]
        extra_fields = self._extra_fields(n_evals)

        self.archive = self._new_archive(qd, solution_dim, dims, ranges, extra_fields)
        if qd.method in self._MAE_METHODS:
            self.train_archive = self._new_archive(
                qd, solution_dim, dims, ranges, extra_fields, train=True
            )

        # Each emitter starts from the parameters of a freshly instantiated agent
        self._setup_emitters(qd, lambda i: hydra.utils.instantiate(agent).to_numpy())

    def compute_archive_ranges(
        self, encodings: torch.Tensor, min_width: float = 1e-6
    ) -> List[Tuple[float, float]]:
        """Compute archive ranges based on encoded states.

        Each dimension spans the observed [min, max], widened on both sides by
        ``measure_margin * (max - min)``. If the resulting range is narrower
        than ``min_width`` (e.g. all encodings identical), it is widened
        symmetrically to ``min_width`` so the archive never gets a zero-width
        dimension.

        Args:
            encodings: Encoded states/trajectories, shape (n, latent_dim)
            min_width: Minimum allowed width of each range

        Returns:
            List of (min, max) tuples for each latent dimension
        """
        ranges = []
        for dim in range(self.latent_dim):
            dim_min = float(encodings[:, dim].min())
            dim_max = float(encodings[:, dim].max())
            margin = self.measure_margin * (dim_max - dim_min)
            lower, upper = dim_min - margin, dim_max + margin
            if upper - lower < min_width:
                center = 0.5 * (lower + upper)
                lower, upper = center - 0.5 * min_width, center + 0.5 * min_width
            ranges.append((lower, upper))

        return ranges

    def _encode_elite_data(
        self, elite_data: np.ndarray, batch_size: int = 64
    ) -> torch.Tensor:
        """Re-encode stored elite data and average over each elite's evaluations.

        Args:
            elite_data: Stored extra-field data, one row per elite (``trajs``
                object array of shape (n_elites, n_evals), or ``states`` of
                shape (n_elites, n_evals, state_dim))
            batch_size: Number of elites encoded per forward pass (limits memory)

        Returns:
            Mean encoding per elite, shape (n_elites, latent_dim)
        """
        chunks = []
        with torch.no_grad():
            for start in range(0, len(elite_data), batch_size):
                batch = elite_data[start : start + batch_size]
                if self.use_lstm:
                    inputs = [
                        torch.from_numpy(x).to(self.device) for x in batch.flatten()
                    ]
                else:
                    # Same (N, state_dim) layout as in step_qd / update_encoder
                    inputs = torch.from_numpy(
                        batch.reshape(-1, self.env_cfg.state_dim)
                    ).to(self.device)
                chunks.append(self.encoder.encode(inputs))
        encodings = torch.cat(chunks, dim=0).reshape(-1, self.n_evals, self.latent_dim)
        return encodings.mean(dim=1)

    def reset_qd(self) -> None:
        """Reset archive and emitters with updated ranges.

        Every elite's stored data is re-encoded with the current encoder, and
        the mean over its ``n_evals`` encodings becomes its new measures.
        Ranges are recomputed from those measures; elites are re-inserted into
        a fresh result archive (elites now sharing a cell keep only the best);
        the MAE training archive, if any, is recreated empty. Emitter 0
        restarts from the best elite, the others from random elites in the top
        20% by objective. Does nothing (besides a warning) if the archive is
        empty.
        """
        if len(self.archive) == 0:
            self.info_logger.warning("Archive is empty; skipping QD reset.")
            return

        field = "trajs" if self.use_lstm else "states"
        # A single data() call keeps solutions, objectives and stored data aligned
        elites = self.archive.data()
        mean_encodings = self._encode_elite_data(elites[field])
        ranges = self.compute_archive_ranges(mean_encodings)

        extra_fields = self._extra_fields(self.n_evals)
        solution_dim = self.archive.solution_dim
        dims = self.archive.dims
        if self.qd_cfg.method in self._MAE_METHODS:
            # NOTE(review): the training archive restarts empty (thresholds back
            # at threshold_min); elites are only transferred to the result archive.
            self.train_archive = self._new_archive(
                self.qd_cfg, solution_dim, dims, ranges, extra_fields, train=True
            )

        new_archive = self._new_archive(
            self.qd_cfg, solution_dim, dims, ranges, extra_fields
        )
        new_archive.add(
            elites["solution"],
            elites["objective"],
            mean_encodings.cpu().numpy(),
            **{field: elites[field]},
        )
        self.archive = new_archive

        # Start emitters from the best solution + random solutions in the top 20%
        objectives = self.archive.data("objective")
        solutions = self.archive.data("solution")
        top_indices = np.flatnonzero(objectives >= np.percentile(objectives, 80))
        best_solution = self.archive.best_elite["solution"].reshape(-1)

        def x0_for(i: int) -> np.ndarray:
            if i == 0:
                return best_solution
            return solutions[self._rng.choice(top_indices)].reshape(-1)

        self._setup_emitters(self.qd_cfg, x0_for)

    def step_qd(self) -> Dict[str, float]:
        """Execute one ask / evaluate / tell step of QD optimization.

        Solutions are dispatched round-robin to the Ray evaluators, each run
        for ``n_evals`` episodes. A solution's objective is its mean episode
        return; its measures are the mean encoding of its per-episode data
        (final states, or sub-sampled trajectories when ``use_lstm``).

        Returns:
            Dictionary of timing information
        """
        timings = {}

        ask_time = time.perf_counter()
        solution_batch = self.scheduler.ask()
        timings["time/ask"] = time.perf_counter() - ask_time

        eval_time = time.perf_counter()
        futures = [
            self.evaluators[i % len(self.evaluators)].evaluate_policy.remote(
                solution, self.n_evals, "cpu"
            )
            for i, solution in enumerate(solution_batch)
        ]
        # Blocks until every evaluation has finished
        trajectory_batches = ray.get(futures)
        timings["time/eval"] = time.perf_counter() - eval_time

        process_time = time.perf_counter()
        objective_batch, measure_batch, extra_data_batch = [], [], []

        for trajs in trajectory_batches:
            if self.use_lstm:
                # Downsampled trajectories: every traj_sampling_freq-th state
                traj_data = [
                    torch.from_numpy(t.states[:: self.traj_sampling_freq]).to(
                        self.device
                    )
                    for t in trajs
                ]
            else:
                # Final state of each trajectory, stacked to (n_evals, state_dim)
                traj_data = torch.from_numpy(
                    np.stack([t.states[-1] for t in trajs])
                ).to(self.device)

            with torch.no_grad():
                encodings = self.encoder.encode(traj_data)

            # Mean reward over trajectories is the objective; mean encoding the BD
            objective_batch.append(np.mean([t.rewards.sum() for t in trajs]))
            measure_batch.append(encodings.mean(dim=0).cpu().numpy())

            if self.use_lstm:
                # Fill element-wise so numpy keeps a 1-D object array even when
                # all trajectories happen to have the same length.
                element = np.empty(len(traj_data), dtype=object)
                for j, states in enumerate(traj_data):
                    element[j] = states.cpu().numpy()
                extra_data_batch.append(element)
            else:
                extra_data_batch.append(traj_data.cpu().numpy())

        field = "trajs" if self.use_lstm else "states"
        extra_info = {field: np.stack(extra_data_batch)}

        timings["time/process"] = time.perf_counter() - process_time

        self.scheduler.tell(
            np.array(objective_batch),
            np.stack(measure_batch),
            **extra_info,
        )

        return timings

    def update_encoder(self) -> Dict[str, float]:
        """Update encoder using states from archive.

        Returns:
            Dictionary of training metrics (from ``encoder.update``) plus
            ``time/train_time``
        """
        train_time = time.perf_counter()

        if self.use_lstm:
            # One tensor per stored trajectory (variable lengths)
            data = [
                torch.from_numpy(np.array(t)).to(self.device)
                for t in self.archive.data("trajs").flatten()
            ]
        else:
            # All stored states reshaped to (n_solutions * n_evals, state_dim)
            data = torch.from_numpy(
                self.archive.data("states").reshape(-1, self.env_cfg.state_dim)
            ).to(self.device)

        metrics = self.encoder.update(
            data,
            batch_size=self.encoder_training.batch_size,
            epochs=self.encoder_training.epochs,
            learning_rate=self.encoder_training.learning_rate,
            validation_split=self.encoder_training.validation_split,
            patience=self.encoder_training.patience,
        )

        metrics["time/train_time"] = time.perf_counter() - train_time
        return metrics

    def train(self) -> None:
        """Execute main training loop.

        Runs ``total_iterations`` QD steps, logging archive statistics each
        iteration, saving an archive image every 10 iterations, and updating
        the encoder + resetting the QD search at the iterations listed in
        ``self.update_interval``.
        """
        total_evals = 0

        for itr in range(1, self.total_iterations + 1):
            timings = self.step_qd()
            total_evals += sum(e.batch_size for e in self.emitters) * self.n_evals

            self.logger.update(
                {
                    "archive_size": len(self.archive),
                    "coverage": self.archive.stats.coverage,
                    "mean_obj": self.archive.stats.obj_mean,
                    "max_obj": self.archive.stats.obj_max,
                    "qd_score": self.archive.stats.qd_score,
                    "total_evals": total_evals,
                }
                | timings
            )

            if itr % 10 == 0:
                self.save_archive_img(itr)

            # Update encoder and reset QD search
            if itr in self.update_interval:
                if len(self.archive) == 0:
                    self.info_logger.warning(
                        f"Archive empty at iteration {itr}; skipping encoder update."
                    )
                else:
                    self.logger.update(self.update_encoder())
                    self.reset_qd()

            self.logger.log(itr)

    def save_archive_img(self, itr: int) -> None:
        """Save a plot of the result archive to ``<exp_dir>/archive_images/<itr>.png``.

        A 2-D latent space is drawn with ``grid_archive_heatmap``; other
        dimensionalities with ``umap_archive_plot``.

        Args:
            itr: Iteration number used in the title and filename
        """
        out_dir = Path(self.exp_dir) / "archive_images"
        out_dir.mkdir(parents=True, exist_ok=True)

        plt.figure(figsize=(8, 6))
        try:
            if self.latent_dim == 2:
                grid_archive_heatmap(
                    self.archive,
                    vmin=self.env_cfg.min_score,
                    vmax=self.env_cfg.max_score,
                )
            else:
                umap_archive_plot(
                    self.archive,
                    vmin=self.env_cfg.min_score,
                    vmax=self.env_cfg.max_score,
                )

            plt.title(f"Result Archive at t={itr}")
            plt.xlabel("m1")
            plt.ylabel("m2")
            plt.savefig(out_dir / f"{itr}.png", dpi=300, bbox_inches="tight")
        finally:
            # Always close the figure so a plotting error does not leak memory
            plt.close()

    def save(self, name: str) -> None:
        """Save algorithm state to ``<exp_dir>/checkpoints/<name>.pkl``.

        Pickles the result archive, the encoder and the agent config. Only
        load such checkpoints from trusted sources (pickle can execute code).

        Args:
            name: Checkpoint file stem
        """
        path = Path(self.exp_dir) / "checkpoints" / f"{name}.pkl"
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as f:
            pickle.dump(
                {
                    "archive": self.archive,
                    "encoder": self.encoder,
                    "agent_cfg": self.agent_cfg,
                },
                f,
            )
