from collections.abc import Callable

import numpy as np
import torch
from gymnasium import Env
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv
from tqdm import tqdm

from compression_autoencoder.pgpe.pgpe import PGPE
from compression_autoencoder.policies.policy import Policy
from compression_autoencoder.utils.evaluation import (
    evaluate_policies_batch,
    evaluate_single_policy,
)
from compression_autoencoder.utils.history import History


class PGPELearner:
    """Run a ``PGPE`` ask/tell loop on an environment and log center evaluations.

    Each generation asks the optimiser for a population of parameter vectors,
    converts them with ``parameter_to_weights``, scores them with
    ``evaluate_policies_batch`` and feeds the rewards back via ``tell``. The
    optimiser's ``center`` is evaluated periodically and recorded in a
    ``History`` under the key ``"val_reward"``.
    """

    def __init__(
        self,
        popsize: int,
        num_generations: int,
        num_params: int,
        env: Env | VecEnv,
        sample_policy: Policy,
        device: torch.device,
        verbose: bool = True,
        parameter_to_weights: Callable[[np.ndarray], torch.Tensor] | None = None,
        exact_step_count: bool = False,
        center_init_dist: str = "normal",  # zeros, normal, policy
        update_type: str = "reinforce",
        pgpe_kwargs: dict | None = None,
        seed: int = 0,
    ) -> None:
        """Build the optimiser and the evaluation environment.

        Args:
            popsize: Population size handed to ``PGPE``; also used for the
                approximate step accounting in ``learn``.
            num_generations: Number of ask/tell iterations ``learn`` performs;
                also passed to ``PGPE`` as ``max_generations``.
            num_params: Length of a parameter vector (``solution_length``).
            env: A single ``Env`` (wrapped in a one-element ``DummyVecEnv``)
                or an existing ``VecEnv`` (used as-is).
            sample_policy: Policy object passed to the evaluation helpers;
                its ``extract_weights()`` seeds the center when
                ``center_init_dist == "policy"``.
            device: Device forwarded to the evaluation helpers and used by
                the default ``parameter_to_weights``.
            verbose: Stored on the instance. NOTE: ``learn`` takes its own
                ``verbose`` argument and does not read this attribute.
            parameter_to_weights: Maps the arrays produced by ``PGPE`` to the
                tensors given to the evaluation helpers. Defaults to a
                float32 ``torch.from_numpy`` conversion moved to ``device``.
            exact_step_count: If true, ``learn`` advances its step counter by
                the step counts reported by ``evaluate_policies_batch``;
                otherwise by ``popsize * max_env_steps`` per generation.
            center_init_dist: ``"normal"`` draws the initial center from a
                standard normal, ``"policy"`` copies ``sample_policy``'s
                weights, and any other value yields an all-zero center.
            update_type: Forwarded to ``PGPE``.
            pgpe_kwargs: Extra keyword arguments for ``PGPE``. They are
                splatted next to the keywords set here, so repeating one of
                those (e.g. ``optimizer_config``) raises ``TypeError``.
            seed: Forwarded to ``PGPE`` and also used to derive the generator
                for the ``"normal"`` center initialisation, so that runs with
                the same seed start from the same center.
        """
        self.popsize = popsize
        self.num_generations = num_generations
        self.num_params = num_params
        self.verbose = verbose
        self.sample_policy = sample_policy
        self.exact_step_count = exact_step_count
        self.device = device

        # The evaluation helpers are always handed a VecEnv.
        if isinstance(env, Env):
            self.vec_env: VecEnv = DummyVecEnv([lambda: env])
        else:
            self.vec_env = env

        if pgpe_kwargs is None:
            pgpe_kwargs = {}

        if center_init_dist == "normal":
            # Derive the generator from ``seed`` for reproducibility. A spawned
            # child stream is used rather than ``seed`` itself so the draw cannot
            # coincide with PGPE's own samples should PGPE seed a NumPy
            # generator with the same value (assumption; PGPE is not visible here).
            center_stream = np.random.SeedSequence(seed).spawn(1)[0]
            rng = np.random.default_rng(center_stream)
            center_init = rng.standard_normal(num_params, dtype=np.float32)
        elif center_init_dist == "policy":
            center_init = self.sample_policy.extract_weights().detach().cpu().numpy()
        else:
            center_init = np.zeros(num_params, dtype=np.float32)

        self.pgpe = PGPE(
            solution_length=self.num_params,
            popsize=self.popsize,
            center_init=center_init,
            update_type=update_type,
            seed=seed,
            max_generations=self.num_generations,
            optimizer_config={"beta1": 0.25},
            **pgpe_kwargs,
        )

        if parameter_to_weights is None:
            self.parameter_to_weights = (
                lambda x: torch.from_numpy(x).float().to(self.device)
            )
        else:
            self.parameter_to_weights = parameter_to_weights

    def _evaluate_center(
        self,
        history: History,
        total_steps: int,
        n_eval_episodes: int,
        verbose: bool,
    ) -> None:
        """Evaluate the current PGPE center and record it as ``"val_reward"``."""
        center_weights = self.parameter_to_weights(self.pgpe.center.copy())
        reward, _ = evaluate_single_policy(
            self.vec_env,
            self.sample_policy,
            center_weights,
            n_eval_episodes=n_eval_episodes,
            n_envs=self.vec_env.num_envs,
            device=self.device,
        )
        history.append("val_reward", reward, total_steps)
        if verbose:
            print(f"Eval Reward: {reward:.2f}")

    def learn(
        self,
        eval_freq_steps: int = 100_000,
        max_env_steps: int = 10_000,
        n_eps_per_policy: int = 1,
        n_eval_episodes: int = 10,
        verbose: bool = False,
    ) -> History:
        """Run ``num_generations`` ask/tell iterations and return the eval history.

        The center is evaluated once before training, every time the running
        step counter reaches the next evaluation threshold, and always after
        the final generation.

        Args:
            eval_freq_steps: Spacing, in counted environment steps, between
                evaluation thresholds.
            max_env_steps: Per-policy step count assumed when
                ``exact_step_count`` is false; the counter then grows by
                ``popsize * max_env_steps`` each generation. It is not passed
                to the rollout helpers by this method.
            n_eps_per_policy: Forwarded to ``evaluate_policies_batch``.
            n_eval_episodes: Episodes per center evaluation. The final
                evaluation uses ``max(n_eval_episodes, 100)``.
            verbose: Show a tqdm progress bar and print each evaluation
                reward. Independent of the constructor's ``verbose``.

        Returns:
            The ``History`` holding one ``"val_reward"`` entry per evaluation,
            each appended together with the step counter at that time.
        """
        history = History()

        generations: range | tqdm = range(self.num_generations)
        if verbose:
            generations = tqdm(generations, desc="PGPE Learning")

        total_steps_counter = 0
        next_eval_step = eval_freq_steps

        # Baseline evaluation of the initial center at step 0.
        self._evaluate_center(history, total_steps_counter, n_eval_episodes, verbose)

        last_generation = self.num_generations - 1
        for generation in generations:
            parameters = self.pgpe.ask()

            weights = self.parameter_to_weights(parameters)
            rewards, steps = evaluate_policies_batch(
                vec_env=self.vec_env,
                sample_policy=self.sample_policy,
                policies_weights=weights,
                n_eps_per_policy=n_eps_per_policy,
                device=self.device,
            )
            self.pgpe.tell(rewards)

            if self.exact_step_count:
                # This is the correct way of counting env interactions, but across
                # different runs it produces different sampling frequencies
                total_steps_counter += np.sum(steps)
            else:
                # This overcounts the number of steps, but is consistent across runs
                total_steps_counter += self.popsize * max_env_steps

            is_last = generation == last_generation
            if total_steps_counter >= next_eval_step or is_last:
                # The final evaluation uses at least 100 episodes.
                eval_eps = max(n_eval_episodes, 100) if is_last else n_eval_episodes
                self._evaluate_center(history, total_steps_counter, eval_eps, verbose)

                next_eval_step += eval_freq_steps
                if eval_freq_steps > 0 and next_eval_step <= total_steps_counter:
                    # One generation crossed several thresholds: skip past all
                    # of them so the stale threshold does not trigger an
                    # evaluation on each of the following generations.
                    next_eval_step = (
                        total_steps_counter // eval_freq_steps + 1
                    ) * eval_freq_steps
        return history
