from typing import Sequence, Callable, Dict, List, Tuple, Optional
import functools
import itertools
import logging
import math
from statistics import mean
from .base_agent import IAgent
import numpy as np
import gym
from tqdm import tqdm
from cmaes import CMA


def run_steps(agent: IAgent, env: gym.Env, nb_step: int):
    """Roll out the agent's best action for a fixed number of steps (used to record videos).

    The environment is closed when the rollout ends, including when the agent
    or the environment raises part-way through, so its resources are always
    released.

    Note:
        The ``done`` flag returned by ``env.step`` is ignored and the env is
        never reset mid-rollout. Presumably ``env`` is a vectorised env that
        auto-resets finished episodes (``evaluate`` indexes ``reward[0]``).
        TODO confirm.

    Args:
        agent (IAgent): Trained agent
        env (gym.Env): environment
        nb_step (int): number of step to record
    """
    try:
        obs = env.reset()
        for _ in range(nb_step):
            action = agent.action(observation=obs).best_action()
            obs, _, _, _ = env.step(action)
    finally:
        env.close()


def _reset_agent(agent: IAgent):
    """Clear the agent's internal memory if it offers a ``reset_memory`` method.

    Agents without a callable ``reset_memory`` attribute are left untouched.

    Args:
        agent (IAgent): Agent whose memory should be cleared.
    """
    if not callable(getattr(agent, "reset_memory", None)):
        # Stateless agent: nothing to reset.
        return
    agent.reset_memory()  # type: ignore


def evaluate(
    agent: IAgent,
    env: gym.Env,
    nb_step: int,
    nb_trial: int,
    sequence_seed: Sequence[int],
) -> float:
    """Return the agent's mean cumulative reward over several seeded sessions.

    For every seed in ``sequence_seed`` the environment is seeded once, then
    ``nb_trial`` sessions are played. Each session:
    1. resets the agent memory and the env;
    2. runs ``nb_step`` greedy steps (``best_action``);
    3. sums ``reward[0]`` for every step.

    ``done`` is not checked, so a session always lasts ``nb_step`` steps.

    Args:
        agent (IAgent): Trained agent
        env (gym.Env): Env (its reward is indexed with ``[0]``; presumably a
            vectorised env - TODO confirm)
        nb_step (int): number of steps in one session
        nb_trial (int): number of sessions per seed
        sequence_seed (Sequence[int]): seeds applied to the env, one after the other

    Returns:
        float: mean of the per-session cumulative rewards

    Raises:
        statistics.StatisticsError: if no session was run (empty
            ``sequence_seed`` or ``nb_trial <= 0``).
    """

    def play_session(observation) -> float:
        # Accumulate the first reward component over a fixed-length rollout.
        total = 0
        for _step in range(nb_step):
            chosen = agent.action(observation=observation).best_action()
            observation, step_reward, _, _ = env.step(chosen)
            total += step_reward[0]
        return total

    session_rewards: List[float] = []
    for env_seed in sequence_seed:
        env.seed(seed=env_seed)
        for _trial in range(nb_trial):
            _reset_agent(agent=agent)
            session_rewards.append(play_session(env.reset()))

    return mean(session_rewards)


def find_worse(
    agent: IAgent,
    env_builder: Callable[..., gym.Env],
    bound: Dict[str, List[float]],
    step_grid: float,
    verbose: bool = False,
    **evaluate_kwargs,
) -> Tuple[Optional[Dict[str, float]], List[Tuple[Dict[str, float], float]]]:
    """Grid-search the environment parameters on which the agent performs worst.

    Each parameter of ``bound`` is sampled with
    ``np.arange(low, high, step_grid)``, so ``high`` is excluded. Every
    combination of the per-parameter grids is then evaluated.

    Note:
        With a non-integer ``step_grid``, floating-point rounding in
        ``np.arange`` can produce values like ``0.30000000000000004``. It can
        also occasionally produce a value very close to ``high``.

    Args:
        agent (IAgent): Trained agent.
        env_builder (Callable[..., gym.Env]): Builds an env from keyword
            arguments named after the keys of ``bound``.
        bound (Dict[str, List[float]]): ``{name: [low, high]}`` per parameter.
        step_grid (float): Grid spacing shared by all parameters.
        verbose (bool): Show a tqdm progress bar.
        **evaluate_kwargs: Forwarded to ``evaluate`` (``nb_step``,
            ``nb_trial``, ``sequence_seed``).

    Returns:
        Tuple: ``(worst_params, infos)`` where ``worst_params`` is the
        parameter dict with the lowest mean reward (``None`` if the grid is
        empty) and ``infos`` lists ``(params, performance)`` in evaluation
        order.
    """
    names = list(bound.keys())
    parameters_range = [
        np.arange(start=bound_value[0], stop=bound_value[1], step=step_grid).tolist()
        for bound_value in bound.values()
    ]
    worst_perf = math.inf
    worst_params: Optional[Dict[str, float]] = None
    infos: List[Tuple[Dict[str, float], float]] = []
    iterable_params = itertools.product(*parameters_range)
    if verbose:
        iterable_params = tqdm(iterable_params, ncols=80)
    for param_set in iterable_params:
        parameters = dict(zip(names, param_set))
        env = env_builder(**parameters)
        try:
            performance = evaluate(agent=agent, env=env, **evaluate_kwargs)
        finally:
            # Release the env even when the evaluation raises.
            env.close()
        if performance < worst_perf:
            worst_perf = performance
            worst_params = parameters
        infos.append((parameters, performance))
    return worst_params, infos


def find_worse_mesh(
    agent: IAgent,
    env_builder: Callable[..., gym.Env],
    bound: Dict[str, List[float]],
    split: int,
    verbose: bool = False,
    **evaluate_kwargs,
) -> Tuple[Optional[Dict[str, float]], List[Tuple[Dict[str, float], float]]]:
    """Search a regular mesh of environment parameters for the agent's worst case.

    Each ``[low, high)`` interval of ``bound`` is divided into exactly
    ``split`` evenly spaced points (``high`` excluded), rounded to 3 decimals.
    Every combination of these points is then evaluated.

    Args:
        agent (IAgent): Trained agent.
        env_builder (Callable[..., gym.Env]): Builds an env from keyword
            arguments named after the keys of ``bound``.
        bound (Dict[str, List[float]]): ``{name: [low, high]}`` per parameter.
        split (int): Number of points per parameter (>= 1).
        verbose (bool): Show a tqdm progress bar.
        **evaluate_kwargs: Forwarded to ``evaluate`` (``nb_step``,
            ``nb_trial``, ``sequence_seed``).

    Returns:
        Tuple: ``(worst_params, infos)`` where ``worst_params`` is the
        parameter dict with the lowest mean reward (``None`` if no point was
        evaluated) and ``infos`` lists ``(params, performance)`` sorted by
        ascending performance.

    Raises:
        ValueError: If ``split`` is smaller than 1.
    """
    if split < 1:
        raise ValueError(f"split must be >= 1, got {split}")
    names = list(bound.keys())
    # linspace(endpoint=False) always yields exactly `split` points. arange with
    # the float step (high - low) / split can overshoot and include `high`,
    # e.g. bound [1, 1.3] with split=3.
    parameters_range = [
        np.round(
            np.linspace(bound_value[0], bound_value[1], num=split, endpoint=False),
            3,
        ).tolist()
        for bound_value in bound.values()
    ]
    worst_perf = math.inf
    worst_params: Optional[Dict[str, float]] = None
    infos: List[Tuple[Dict[str, float], float]] = []
    iterable_params = itertools.product(*parameters_range)
    if verbose:
        iterable_params = tqdm(iterable_params, ncols=80)
    for param_set in iterable_params:
        parameters = dict(zip(names, param_set))
        env = env_builder(**parameters)
        try:
            performance = evaluate(agent=agent, env=env, **evaluate_kwargs)
        finally:
            # Release the env even when the evaluation raises.
            env.close()
        if performance < worst_perf:
            worst_perf = performance
            worst_params = parameters
        infos.append((parameters, performance))
    # sort by performance (stable: ties keep evaluation order)
    infos.sort(key=lambda item: item[1])
    return worst_params, infos


def objective_cma_es(
    *params,
    worst_env_builder,
    params_names,
    agent,
    sequence_seed,
    nb_step,
    nb_trial,
    **kparams,
):
    """CMA-ES objective: the agent's mean reward on an env built from ``params``.

    ``params`` are matched positionally with ``params_names``. The result is
    passed as keyword arguments to ``worst_env_builder``. The resulting env is
    scored with ``evaluate`` and closed afterwards, even when ``evaluate``
    raises.

    Args:
        *params: Parameter values, in the same order as ``params_names``.
        worst_env_builder: Callable building an env from keyword parameters.
        params_names: Iterable of parameter names.
        agent: Trained agent.
        sequence_seed: Seeds forwarded to ``evaluate``.
        nb_step: Steps per session, forwarded to ``evaluate``.
        nb_trial: Sessions per seed, forwarded to ``evaluate``.
        **kparams: Accepted for interface compatibility; unused.

    Returns:
        float: Mean session reward returned by ``evaluate``.
    """
    params_env = dict(zip(params_names, params))
    env_set = worst_env_builder(**params_env)
    try:
        return evaluate(
            agent=agent,
            env=env_set,
            nb_step=nb_step,
            nb_trial=nb_trial,
            sequence_seed=sequence_seed,
        )
    finally:
        # One env is built per objective call; close it so repeated CMA-ES
        # evaluations do not accumulate open environments.
        env_set.close()


def find_worse_cma_es(
    agent: IAgent,
    worst_env_builder,
    bound: Dict[str, List[float]],
    generation: int,
    population_size: Optional[int] = None,
    sequence_seed: Optional[Sequence[int]] = None,
    nb_step: int = 1000,
    nb_trial: int = 1,
    mean_value: float = 0.5,
    sigma_value: float = 0.3,
    seed_cma: int = 1,
    **kparams,
) -> Tuple[Dict[str, float], List[list]]:
    """Search the environment parameters that minimise the agent reward with CMA-ES.

    CMA-ES works in the normalised hypercube [0, 1]^N. Each sample is mapped
    linearly onto ``bound`` before being given to ``worst_env_builder``. The
    optimiser itself is always told the normalised sample it produced.

    Args:
        agent (IAgent): Trained agent to evaluate.
        worst_env_builder: Callable building an env from keyword parameters
            named after the keys of ``bound``.
        bound (Dict[str, List[float]]): ``{name: [low, high]}`` search bounds.
        generation (int): Maximum number of CMA-ES generations (>= 1).
        population_size (Optional[int]): CMA-ES population size; ``None`` lets
            the library choose.
        sequence_seed (Optional[Sequence[int]]): Env seeds used by
            ``evaluate``; defaults to five fixed seeds.
        nb_step (int): Steps per evaluation session.
        nb_trial (int): Sessions per seed.
        mean_value (float): Initial CMA-ES mean, in normalised space, for
            every dimension.
        sigma_value (float): Initial CMA-ES step size (normalised space).
        seed_cma (int): Seed of the CMA-ES sampler.
        **kparams: Accepted for interface compatibility; unused.

    Returns:
        Tuple: ``(worst_params, infos)``.
        ``worst_params`` maps each parameter name to the value with the lowest
        reward seen over all generations. ``infos`` lists
        ``[params_dict, reward]`` for the last evaluated generation, sorted by
        ascending reward.

    Raises:
        ValueError: If ``generation`` is smaller than 1.
    """
    if generation < 1:
        raise ValueError(f"generation must be >= 1, got {generation}")
    if sequence_seed is None:
        sequence_seed = [12345, 12346, 12347, 12348, 12349]
    params_names = list(bound.keys())
    bounds_cma = np.array([v for v in bound.values()])  # type: ignore
    lower = bounds_cma[:, 0]
    span = bounds_cma[:, 1] - bounds_cma[:, 0]

    # CMA ES works with normalized bounds [0, 1] x N parameters
    normalized_bounds = np.array([[0, 1] for _ in params_names])
    optimizer = CMA(
        mean=np.ones(len(params_names)) * mean_value,
        sigma=sigma_value,
        bounds=normalized_bounds,
        population_size=population_size,
        seed=seed_cma,
    )

    objective_params = functools.partial(
        objective_cma_es,
        worst_env_builder=worst_env_builder,
        agent=agent,
        sequence_seed=sequence_seed,
        nb_step=nb_step,
        nb_trial=nb_trial,
        params_names=params_names,
    )

    worst_reward = math.inf
    worst_x: Optional[np.ndarray] = None
    last_generation: List[Tuple[np.ndarray, float]] = []
    for gen_index in tqdm(range(generation)):
        # The optimiser must receive the *normalised* samples returned by ask();
        # telling it denormalised values would corrupt its mean/covariance update.
        normalized_solutions = []
        last_generation = []
        for _ in range(optimizer.population_size):
            x_normalized = optimizer.ask()
            x = x_normalized * span + lower
            value = objective_params(*x.tolist())
            normalized_solutions.append((x_normalized, value))
            last_generation.append((x, value))
            # Keep the worst point over all generations, not only the last one.
            if worst_x is None or value < worst_reward:
                worst_reward = value
                worst_x = x
        optimizer.tell(normalized_solutions)
        if optimizer.should_stop():
            logging.info("CMA ES Optimizer stop at generation %d", gen_index)
            break

    last_generation.sort(key=lambda solution: solution[1])
    infos: List[list] = [
        [dict(zip(params_names, psi.tolist())), reward]
        for psi, reward in last_generation
    ]
    worst_params = dict(zip(params_names, worst_x.tolist()))  # type: ignore

    return worst_params, infos
