from copy import deepcopy
import functools
import os
import socket
import datetime
import uuid
import json
import logging
from typing import List, Optional

import fire
import torch
import torch.nn as nn
import wandb

from algos import (
    OffPolicyMemory,
    SACMPCHydra,
    train_ravi_mpc_no_reg_minimal,
    train_ravi_mpc_no_reg_minimal_best,
    find_worse_cma_es,
)
from build import (
    build_envs,
    build_optimizer_mpc_no_reg,
    build_mpc_sac_hydra_agent,
    build_bounds,
    build_pessimist_agent_hydra_default_adaptative_threshold,
)
from project_agent import (
    MujocoSacMPCHydra,
    PessimistExpertAgentHydraDefaultAdaptativeThreshold,
)
from seed import seed_everything
from bounds import _get_center_bounds, BENCHMARK_MUJOCO_REFERENCE

# Attach a default stderr handler (if none is configured yet) and let
# INFO-level records from the root logger through; the progress messages in
# this script are emitted with ``logging.info`` on that root logger.
logging.basicConfig()
logger = logging.root
logger.setLevel("INFO")


def _build_expert_env(
    iteration: int,
    *,
    env_type: str,
    env_name: str,
    num_envs: int,
    device,
    seed: int,
    bounds,
    env_params: dict,
):
    """Build the environment the expert of ``iteration`` is trained on.

    Iteration 0 uses the domain-randomized benchmark env over ``bounds``; every
    later iteration uses the fixed MDP described by ``env_params``.
    """
    if iteration == 0:
        return build_envs(
            env_type="domain_randomization_mujoco_benchmark",
            env_name=env_name,
            num_env=num_envs,
            device=device,
            seed=seed,
            bound=bounds,
        )
    return build_envs(
        env_type=env_type,
        env_name=env_name,
        num_env=num_envs,
        device=device,
        seed=seed,
        **env_params,
    )


def _search_best_threshold(agent_pessimist, thresholds: List[float], search_worst):
    """Run the worst-case search once per threshold and keep the best one.

    The best threshold is the one whose worst-case performance is highest.

    Args:
        agent_pessimist: Pessimist agent whose last threshold is set before
            each search.
        thresholds: Candidate thresholds, tried in order.
        search_worst: Zero-argument callable returning
            ``(worst_parameters, infos)``, as ``find_worse_cma_es`` does.

    Returns:
        ``(best_threshold, best_infos, best_worst_parameters)``; all ``None``
        when ``thresholds`` is empty.
    """
    best_threshold = None
    best_infos = None
    best_worst_parameters = None
    best_worst_performance = -float("inf")

    for threshold_pc in thresholds:
        # The threshold must be installed BEFORE the search so that the
        # evaluated agent actually uses the threshold being scored.
        agent_pessimist.set_last_threshold(threshold_pc)

        logging.info("start to find worst parameters \n")
        worst_parameters, infos = search_worst()

        # example = [[{'worldfriction': 4.0, 'torsomass': 7.0, 'backthighmass': 0.1676624785390559}, 245.83758857162775], [{'worldfriction': 4.0, 'torsomass': 7.0, 'backthighmass': 0.3753823338048763}, 419.4773163709352]]  # noqa: E501
        worst_case = infos[0][1]
        # ``best_infos is None`` guarantees the first candidate is recorded
        # even if its score does not compare greater than -inf (e.g. NaN), so
        # the caller never receives None parameters for a non-empty list.
        if best_infos is None or worst_case > best_worst_performance:
            best_worst_performance = worst_case
            best_threshold = threshold_pc
            best_infos = infos
            best_worst_parameters = worst_parameters
        logging.info(
            f"Thesorshold {threshold_pc} gives {worst_case} worst performance \n"
        )

    return best_threshold, best_infos, best_worst_parameters


def run_ravi_experiment(
    nb_iteration: int,
    seed: int = 1,
    experiment_name: str = "ExpertPessimist",
    device: str = "cuda:0",  # type: ignore
    num_envs: int = 1,
    env_type: str = "mujoco",
    env_name: str = "Halfcheetah",
    learning_rate_actor: float = 3e-4,
    learning_rate_critic: float = 1e-3,
    batch_size: int = 256,
    discount: float = 0.99,
    tau: float = 0.995,
    policy_frequency: int = 2,
    memory_size: int = 1_000_000,
    max_step: int = 3_000_000,
    max_step_dr: int = 3_000_000,
    train_step: int = 1,
    first_threshold_pc: float = 1.5,
    train_step_start: int = 10_000,
    generation: int = 20,
    population_size: int = 10,
    mean_value_es: float = 0.5,
    sigma_value_es: float = 0.3,
    log_loss_every_k_step: int = 10000,
    save_every_k_step: int = 100000,
    worst_env_builder_name: str = "mujoco-vanilla",
    bounds_name: str = "mujoco-vanilla",
    init_bound_name: str = "mujoco-vanilla",
    project_name: str = "anonymous",
    threshold_list: Optional[List[float]] = None,
):
    """Run the RAVI expert / pessimist training loop and log it to W&B.

    Iteration 0 trains a SAC-MPC hydra expert on domain-randomized
    environments over the parameter bounds. Each later iteration trains a new
    expert on the worst-case MDP parameters selected at the previous
    iteration. After every expert, its critics and actor are handed to the
    pessimist agent as robust nets, and a CMA-ES search looks for the
    worst-performing MDP under each candidate threshold; the threshold whose
    worst case is highest is kept. The loop stops early when an iteration
    after the first selects a threshold of 0.

    Pessimist checkpoints (one per iteration, plus
    ``agent_pessimist_hydra_best.pt`` saved once after the loop) and
    ``infos_all_train.json`` are written under ``result/tmp_<uuid>`` and
    uploaded as a ``model_weights`` W&B artifact.

    Args:
        nb_iteration (int): Number of pessimist iterations after the initial
            domain-randomization expert; the loop runs ``nb_iteration + 1``
            times.
        seed (int, optional): Seed for ``seed_everything``, the mock
            environment and the CMA-ES search. Defaults to 1.
        experiment_name (str, optional): Currently unused by this function.
            Defaults to "ExpertPessimist".
        device (str, optional): Torch device. Defaults to "cuda:0".
        num_envs (int, optional): Number of environments per build.
            Defaults to 1.
        env_type (str, optional): mujoco or brax; used for the mock env and
            for experts after iteration 0. Defaults to "mujoco".
        env_name (str, optional): Name of the environment.
            Defaults to "Halfcheetah".
        learning_rate_actor (float, optional): Learning rate of the actor.
            Defaults to 3e-4.
        learning_rate_critic (float, optional): Learning rate of the critics,
            also used as the SAC temperature learning rate. Defaults to 1e-3.
        batch_size (int, optional): Size of the sampled batch. Defaults to 256.
        discount (float, optional): Gamma. Defaults to 0.99.
        tau (float, optional): Soft update value (Polyak). Defaults to 0.995.
        policy_frequency (int, optional): Optimize the policy network every
            `k` steps. Defaults to 2.
        memory_size (int, optional): Replay buffer size. Defaults to 1_000_000.
        max_step (int, optional): Training steps of the experts after the
            first. Defaults to 3_000_000.
        max_step_dr (int, optional): Training steps of the first
            (domain-randomization) expert. Defaults to 3_000_000.
        train_step (int, optional): Every `k` steps an optimisation step.
            Defaults to 1.
        first_threshold_pc (float, optional): Threshold used by the pessimist
            at iteration 0 instead of ``threshold_list``. Defaults to 1.5.
        train_step_start (int, optional): Number of steps before training.
            Defaults to 10_000.
        generation (int, optional): Number of CMA-ES generations.
            Defaults to 20.
        population_size (int, optional): CMA-ES population size.
            Defaults to 10.
        mean_value_es (float, optional): CMA-ES mean value. Defaults to 0.5.
        sigma_value_es (float, optional): CMA-ES sigma value. Defaults to 0.3.
        log_loss_every_k_step (int, optional): Log losses every `k` steps.
            Defaults to 10000.
        save_every_k_step (int, optional): Save the model every `k` steps.
            Defaults to 100000.
        worst_env_builder_name (str, optional): Env builder used by the
            worst-case search: "mujoco-vanilla" or "mujoco_benchmark".
            Defaults to "mujoco-vanilla".
        bounds_name (str, optional): Bounds passed to ``build_bounds``; also
            indexes the reference parameters. Defaults to "mujoco-vanilla".
        init_bound_name (str, optional): Initial MDP parameters:
            "mujoco-vanilla" (mass and friction coefficients 1), "center"
            (center of the bounds) or "reference".
            Defaults to "mujoco-vanilla".
        project_name (str, optional): W&B project. Defaults to "anonymous".
        threshold_list (list of float, optional): Candidate thresholds tried
            at every iteration after the first.
            Defaults to 0.0, 0.1, ..., 0.9.

    Raises:
        ValueError: If ``nb_iteration > 0`` and ``threshold_list`` is empty.
        KeyError: If ``worst_env_builder_name`` or ``init_bound_name`` is not
            one of the supported names.
    """
    if threshold_list is None:
        threshold_list = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    # Fail before any expensive training: later iterations need at least one
    # candidate, otherwise no worst-case parameters can be selected.
    if nb_iteration > 0 and not threshold_list:
        raise ValueError(
            "threshold_list must contain at least one threshold when nb_iteration > 0"
        )

    video_folder = "result/"
    ravi_id = str(uuid.uuid4())
    run_path = os.path.join(video_folder, f"tmp_{ravi_id}")
    os.makedirs(run_path, exist_ok=True)

    sequence_seed_eval = [12345]
    # The partials capture the device string as given (before conversion).
    ENV_BUILDER_FN = {
        "mujoco-vanilla": functools.partial(
            build_envs,
            env_type="mujoco",
            device=device,  # type: ignore
            env_name=env_name,
        ),
        "mujoco_benchmark": functools.partial(
            build_envs,
            env_type="mujoco_benchmark",
            device=device,  # type: ignore
            env_name=env_name,
            seed=sequence_seed_eval[0],
        ),
    }
    BOUNDS = build_bounds(env_type=env_type, env_name=env_name, bound_name=bounds_name)

    # Resolve lazily: only the selected initial parameters are computed, so an
    # env/bounds pair without a reference entry does not raise KeyError unless
    # "reference" is actually requested.
    INIT_BOUNDS_FN = {
        "mujoco-vanilla": lambda: {"mass_coef": 1, "friction_coef": 1},
        "center": lambda: _get_center_bounds(BOUNDS),
        "reference": lambda: BENCHMARK_MUJOCO_REFERENCE[env_name][bounds_name],
    }

    worst_env_builder = ENV_BUILDER_FN[worst_env_builder_name]
    env_params_bounds = BOUNDS
    init_env_params = INIT_BOUNDS_FN[init_bound_name]()
    seed_everything(seed=seed)
    run_name = f"Default-{env_name}-{bounds_name}-{env_type}-{seed}-pessimist-hydra"

    device: torch.device = torch.device(device)  # type: ignore
    best_worst_parameters = init_env_params

    # Throw-away env used only to read the observation/action dimensions.
    env_mock = build_envs(
        env_type=env_type,
        env_name=env_name,
        num_env=num_envs,
        device=device,
        seed=seed,
        **best_worst_parameters,
    )
    observation_dim = env_mock.observation_space[0].shape[1]  # type: ignore
    action_dim = env_mock.action_space[0].shape[0]  # type: ignore
    del env_mock

    agent_pessimist: PessimistExpertAgentHydraDefaultAdaptativeThreshold = (
        build_pessimist_agent_hydra_default_adaptative_threshold(
            observation_dim=observation_dim,
            action_dim=action_dim,
            device=device,
        )
    )

    params = {
        "seed": seed,
        "num_envs": num_envs,
        "learning_rate_actor": learning_rate_actor,
        "learning_rate_critic": learning_rate_critic,
        "batch_size": batch_size,
        "discount": discount,
        "tau": tau,
        "policy_frequency": policy_frequency,
        "memory_size": memory_size,
        "max_step": max_step,
        "train_step": train_step,
        "train_step_start": train_step_start,
        "nb_iteration": nb_iteration,
        "machine_name": socket.gethostname(),
        "mean_value_es": mean_value_es,
        "sigma_value_es": sigma_value_es,
        "threshold_list": json.dumps(threshold_list),
    }
    wandb.init(
        project=project_name,
        config=params,
        name=f"{run_name}_ravi-{str(datetime.datetime.now())}",
        save_code=True,
        reinit=True,
    )

    # The CMA-ES configuration does not change between iterations or
    # thresholds; only the pessimist agent's internal state does.
    search_worst = functools.partial(
        find_worse_cma_es,
        agent=agent_pessimist,
        worst_env_builder=worst_env_builder,
        bound=env_params_bounds,
        generation=generation,
        population_size=population_size,
        mean_value=mean_value_es,
        sigma_value=sigma_value_es,
        sequence_seed=sequence_seed_eval,  # seed for eval envs
        seed_cma=seed,
        nb_step=1000,
        nb_trial=1,
    )

    list_infos = []
    for iteration in range(nb_iteration + 1):
        is_first_iteration = iteration == 0

        env = _build_expert_env(
            iteration,
            env_type=env_type,
            env_name=env_name,
            num_envs=num_envs,
            device=device,
            seed=sequence_seed_eval[0],
            bounds=BOUNDS,
            env_params=best_worst_parameters,
        )

        agent_star: MujocoSacMPCHydra = build_mpc_sac_hydra_agent(  # type: ignore
            observation_dim=observation_dim,
            action_dim=action_dim,
            device=device,
        )

        (
            optimizer_actor,
            optimizer_critic,
            optimizer_mpc,
            optimizer_critic_no_reg,
        ) = build_optimizer_mpc_no_reg(
            actor_parameters=agent_star.parameters_actor(),  # type: ignore
            critic_parameters=agent_star.parameters_critic(),  # type: ignore
            learning_rate_critic=learning_rate_critic,
            learning_rate_actor=learning_rate_actor,
        )

        # SAC target entropy: minus the number of action dimensions.
        target_entropy = -torch.prod(
            torch.tensor(env.action_space[0].shape).to(device)  # type: ignore
        ).item()

        replay = OffPolicyMemory(memory_size=memory_size, device=device)
        algo_rl_star = SACMPCHydra(
            batch_size=batch_size,
            target_entropy=target_entropy,  # type: ignore
            discount=discount,
            tau=tau,
            reparametrized_sample=True,
            device=device,
            temperature_lr=learning_rate_critic,
            policy_frequency=policy_frequency,
        )
        # NOTE(review): ``params`` was already passed to wandb.init above;
        # updating the dict here may not reach the logged W&B config — confirm.
        params.update(best_worst_parameters)

        print(f"Train an expert on {best_worst_parameters} MDP")
        if is_first_iteration:
            print("Train the default expert on domain randomization")
            # The default agent is trained on domain randomization, so the
            # "best" checkpoint would favour a simple MDP rather than good
            # overall performance; use the plain trainer instead.
            trainer_fn = train_ravi_mpc_no_reg_minimal
            expert_max_step = max_step_dr
        else:
            trainer_fn = train_ravi_mpc_no_reg_minimal_best  # type: ignore
            expert_max_step = max_step

        trainer_fn(  # type: ignore
            agent=agent_star,
            optimizer_actor=optimizer_actor,
            optimizer_critic=optimizer_critic,
            optimizer_mpc=optimizer_mpc,
            optimizer_critic_no_reg=optimizer_critic_no_reg,
            iteration_number=iteration,
            env=env,
            sac_algo=algo_rl_star,
            buffer=replay,
            max_step=expert_max_step,
            train_step=train_step,
            train_step_start=train_step_start,
            artifact_folder=run_path,
            run_name=f"{run_name}-{iteration}",
            log_loss_every_k_step=log_loss_every_k_step,
            save_every_k_step=save_every_k_step,
            log_runner=wandb,
        )

        # Hand independent copies of the trained expert to the pessimist.
        robust_q1: nn.Module = deepcopy(agent_star.q_value_network_1)
        robust_q2: nn.Module = deepcopy(agent_star.q_value_network_2)
        actor_robust: nn.Module = deepcopy(agent_star.actor)
        agent_pessimist.set_robust_nets(
            q_robust_1=robust_q1, q_robust_2=robust_q2, actor_robust=actor_robust
        )
        agent_pessimist.eval()

        # Init the adaptive threshold slot for this expert.
        agent_pessimist.append_last_threshold(0)

        # The first expert always uses ``first_threshold_pc``; later experts
        # try every candidate in ``threshold_list``.
        candidate_thresholds = (
            [first_threshold_pc] if is_first_iteration else threshold_list
        )
        best_threshold, best_infos, best_worst_parameters = _search_best_threshold(
            agent_pessimist, candidate_thresholds, search_worst
        )
        if is_first_iteration:
            logging.info(
                "First iteration, we stop here because we set the threshold for the first expert \n"  # noqa: E501
            )

        logging.info(f"worst parameters found is : {best_worst_parameters} \n")
        logging.info(f"The best threshold is {best_threshold} \n")
        logging.info(f"{best_infos} \n")

        list_infos.append(best_infos)

        agent_pessimist.set_last_threshold(best_threshold)
        agent_pessimist.save(f"{run_path}/agent_pessimist_hydra_{iteration}.pt")

        if not is_first_iteration and best_threshold == 0:
            logging.info(
                f"The threshold is 0, and the iteration is {iteration} we stop the training \n"
            )
            break

    agent_pessimist.save(f"{run_path}/agent_pessimist_hydra_best.pt")

    with open(f"{run_path}/infos_all_train.json", "w", encoding="utf-8") as f:
        json.dump(list_infos, f)  # type: ignore
    artifact = wandb.Artifact(name="model_weights", type="model")
    artifact.add_dir(run_path)
    wandb.log_artifact(artifact_or_path=artifact)
    # Close the run opened by wandb.init above so the artifact upload is
    # flushed before returning.
    wandb.finish()
    # env.close()


if __name__ == "__main__":
    # Use Python Fire to expose run_ravi_experiment as a command-line interface.
    fire.Fire(component=run_ravi_experiment)
