from copy import deepcopy
import functools
import os
import socket
import datetime
import uuid
import json
import logging

import fire
import torch
import torch.nn as nn
import wandb

from algos import (
    OffPolicyMemory,
    SACMPCHydra,
    train_ravi_mpc_no_reg_minimal,
    find_worse_cma_es,
)
from build import (
    build_envs,
    build_optimizer_mpc_no_reg,
    build_mpc_sac_hydra_agent,
    build_bounds,
    build_pessimist_agent_hydra_default_adaptative_threshold,
)
from project_agent import (
    MujocoSacMPCHydra,
    PessimistExpertAgentHydraDefaultAdaptativeThreshold,
)
from seed import seed_everything
from bounds import _get_center_bounds, BENCHMARK_MUJOCO_REFERENCE

# Module-wide logging setup, executed once at import time.
# `logger` is the root logger (getLogger() is called without a name), so the
# INFO level set here also governs the module-level `logging.info(...)` calls.
logger = logging.getLogger()
logging.basicConfig()
# Set the level explicitly: basicConfig() leaves an already-configured root
# logger untouched, so it cannot be relied on to apply the level.
logger.setLevel(logging.INFO)


def run_ravi_experiment(
    nb_iteration: int,
    seed: int = 1,
    experiment_name: str = "IWOCS",
    device: str = "cuda:0",  # type: ignore
    num_envs: int = 1,
    env_type: str = "mujoco",
    env_name: str = "Halfcheetah",
    learning_rate_actor: float = 3e-4,
    learning_rate_critic: float = 1e-3,
    batch_size: int = 256,
    discount: float = 0.99,
    tau: float = 0.995,
    policy_frequency: int = 2,
    memory_size: int = 1_000_000,
    max_step: int = 3_000_000,
    max_step_dr: int = 3_000_000,
    train_step: int = 1,
    first_threshold_pc: float = 1.5,
    train_step_start: int = 10_000,
    generation: int = 20,
    population_size: int = 10,
    mean_value_es: float = 0.5,
    sigma_value_es: float = 0.3,
    log_loss_every_k_step: int = 10000,
    save_every_k_step: int = 100000,
    worst_env_builder_name: str = "mujoco-vanilla",
    bounds_name: str = "mujoco-vanilla",
    init_bound_name: str = "mujoco-vanilla",
    project_name: str = "IWOCS",
):
    """Run the iterative expert-training / worst-case-search experiment.

    For each iteration in ``range(nb_iteration + 1)`` the function:

    1. builds a training environment: a domain-randomization environment over
       the parameter bounds on iteration 0, otherwise an environment using the
       worst-case parameters selected by the previous iteration;
    2. builds a fresh agent with its optimizers and trains it with
       ``train_ravi_mpc_no_reg_minimal``;
    3. copies the trained Q-networks and actor into the pessimist agent via
       ``set_robust_nets``;
    4. runs ``find_worse_cma_es`` once per candidate threshold and keeps the
       threshold whose reported worst-case score (``infos[0][1]``) is highest;
    5. saves the pessimist agent inside the run folder.

    After the last iteration the per-iteration search results are written to
    ``infos_all_train.json`` and the run folder is logged as a wandb artifact.

    Args:
        nb_iteration (int): Number of refinement iterations. The loop runs
            ``nb_iteration + 1`` times because iteration 0 trains the
            domain-randomization expert.
        seed (int, optional): Seed passed to ``seed_everything``, the
            environment builders and the CMA-ES search. Defaults to 1.
        experiment_name (str, optional): Not used by the function body; kept
            for interface compatibility. Defaults to "IWOCS".
        device (str, optional): Device string, converted to ``torch.device``.
            Defaults to "cuda:0".
        num_envs (int, optional): Number of environments passed to
            ``build_envs``. Defaults to 1.
        env_type (str, optional): Environment type passed to ``build_envs``
            and ``build_bounds``. Defaults to "mujoco".
        env_name (str, optional): Name of the specific environment.
            Defaults to "Halfcheetah".
        learning_rate_actor (float, optional): Learning rate of the actor
            optimizer. Defaults to 3e-4.
        learning_rate_critic (float, optional): Learning rate of the critic
            optimizer; also passed as ``temperature_lr``. Defaults to 1e-3.
        batch_size (int, optional): Batch size passed to the algorithm.
            Defaults to 256.
        discount (float, optional): Discount rate (gamma). Defaults to 0.99.
        tau (float, optional): ``tau`` passed to ``SACMPCHydra`` (documented
            by the original author as the soft-update factor).
            Defaults to 0.995.
        policy_frequency (int, optional): ``policy_frequency`` passed to
            ``SACMPCHydra``. Defaults to 2.
        memory_size (int, optional): Size of the replay buffer.
            Defaults to 1_000_000.
        max_step (int, optional): Training step budget for iterations >= 1.
            Defaults to 3_000_000.
        max_step_dr (int, optional): Training step budget for iteration 0
            (domain randomization). Defaults to 3_000_000.
        train_step (int, optional): ``train_step`` passed to the training
            routine. Defaults to 1.
        first_threshold_pc (float, optional): The single threshold evaluated
            on iteration 0. Defaults to 1.5.
        train_step_start (int, optional): ``train_step_start`` passed to the
            training routine. Defaults to 10_000.
        generation (int, optional): Number of generations for the CMA-ES
            search. Defaults to 20.
        population_size (int, optional): Population size for the CMA-ES
            search. Defaults to 10.
        mean_value_es (float, optional): Mean value for the CMA-ES search.
            Defaults to 0.5.
        sigma_value_es (float, optional): Sigma value for the CMA-ES search.
            Defaults to 0.3.
        log_loss_every_k_step (int, optional): Loss logging period passed to
            the training routine. Defaults to 10000.
        save_every_k_step (int, optional): Checkpoint period passed to the
            training routine. Defaults to 100000.
        worst_env_builder_name (str, optional): Which environment builder the
            worst-case search uses: "mujoco-vanilla" or "mujoco_benchmark".
            Defaults to "mujoco-vanilla".
        bounds_name (str, optional): Name of the parameter bounds.
            Defaults to "mujoco-vanilla".
        init_bound_name (str, optional): Which initial environment parameters
            to use: "mujoco-vanilla", "center" or "reference".
            Defaults to "mujoco-vanilla".
        project_name (str, optional): wandb project name.
            Defaults to "IWOCS".

    Returns:
        None. Results are written to the run folder and logged to wandb.

    Raises:
        KeyError: If ``worst_env_builder_name`` or ``init_bound_name`` is not
            one of the supported names.
    """  # noqa: E501

    video_folder = "result/"
    ravi_id = str(uuid.uuid4())
    run_path = f"{video_folder}/tmp_{ravi_id}"
    # exist_ok avoids the race of a separate exists() check before creation.
    os.makedirs(run_path, exist_ok=True)

    # Builders handed to the worst-case search. They capture `device` as the
    # original string, i.e. before the torch.device conversion below.
    env_builder_fns = {
        "mujoco-vanilla": functools.partial(
            build_envs,
            env_type="mujoco",
            device=device,  # type: ignore
            env_name=env_name,
        ),
        "mujoco_benchmark": functools.partial(
            build_envs,
            env_type="mujoco_benchmark",
            device=device,  # type: ignore
            env_name=env_name,
        ),
    }
    bounds = build_bounds(env_type=env_type, env_name=env_name, bound_name=bounds_name)

    # Resolved lazily so that only the requested entry is evaluated: the
    # "reference" lookup must not fail a run that asked for another entry.
    init_bounds_factories = {
        "mujoco-vanilla": lambda: {"mass_coef": 1, "friction_coef": 1},
        "center": lambda: _get_center_bounds(bounds),
        "reference": lambda: BENCHMARK_MUJOCO_REFERENCE[env_name][bounds_name],
    }

    worst_env_builder = env_builder_fns[worst_env_builder_name]
    env_params_bounds = bounds
    init_env_params = init_bounds_factories[init_bound_name]()
    # A single fixed seed is used for the evaluation environments.
    sequence_seed_eval = [12345]
    seed_everything(seed=seed)
    run_name = f"Default-{env_name}-{bounds_name}-{env_type}-{seed}-pessimist-hydra"

    device: torch.device = torch.device(device)  # type: ignore
    worst_parameters = init_env_params
    best_worst_parameters = worst_parameters

    # Throwaway environment, built only to read the observation/action sizes.
    env_mock = build_envs(
        env_type=env_type,
        env_name=env_name,
        num_env=num_envs,
        device=device,
        seed=seed,
        **worst_parameters,
    )
    observation_dim = env_mock.observation_space[0].shape[1]  # type: ignore
    action_dim = env_mock.action_space[0].shape[0]  # type: ignore
    del env_mock

    agent_pessimist: PessimistExpertAgentHydraDefaultAdaptativeThreshold = (
        build_pessimist_agent_hydra_default_adaptative_threshold(
            observation_dim=observation_dim,
            action_dim=action_dim,
            device=device,
        )
    )

    params = {
        "seed": seed,
        "num_envs": num_envs,
        "learning_rate_actor": learning_rate_actor,
        "learning_rate_critic": learning_rate_critic,
        "batch_size": batch_size,
        "discount": discount,
        "tau": tau,
        "policy_frequency": policy_frequency,
        "memory_size": memory_size,
        "max_step": max_step,
        "train_step": train_step,
        "train_step_start": train_step_start,
        "nb_iteration": nb_iteration,
        "machine_name": socket.gethostname(),
        "mean_value_es": mean_value_es,
        "sigma_value_es": sigma_value_es,
    }
    wandb.init(
        project=project_name,
        config=params,
        name=f"{run_name}_ravi-{datetime.datetime.now()}",
        save_code=True,
        reinit=True,
    )
    list_infos = []
    for iteration in range(nb_iteration + 1):
        is_first_iteration = iteration == 0

        # Build exactly one training environment per iteration. Iteration 0
        # trains on domain randomization over the bounds; later iterations
        # train on the worst-case parameters chosen by the previous one.
        if is_first_iteration:
            env = build_envs(
                env_type="domain_randomization_mujoco_benchmark",
                env_name=env_name,
                num_env=num_envs,
                device=device,
                seed=seed,
                bound=bounds,
            )
        else:
            env = build_envs(
                env_type=env_type,
                env_name=env_name,
                num_env=num_envs,
                device=device,
                seed=seed,
                **best_worst_parameters,
            )

        agent_star: MujocoSacMPCHydra = build_mpc_sac_hydra_agent(  # type: ignore
            observation_dim=observation_dim,
            action_dim=action_dim,
            device=device,
        )

        (
            optimizer_actor,
            optimizer_critic,
            optimizer_mpc,
            optimizer_critic_no_reg,
        ) = build_optimizer_mpc_no_reg(
            actor_parameters=agent_star.parameters_actor(),  # type: ignore
            critic_parameters=agent_star.parameters_critic(),  # type: ignore
            learning_rate_critic=learning_rate_critic,
            learning_rate_actor=learning_rate_actor,
        )

        # Negative product of the action-space shape entries.
        target_entropy = -torch.prod(
            torch.tensor(env.action_space[0].shape).to(device)  # type: ignore
        ).item()

        replay = OffPolicyMemory(memory_size=memory_size, device=device)
        algo_rl_star = SACMPCHydra(
            batch_size=batch_size,
            target_entropy=target_entropy,  # type: ignore
            discount=discount,
            tau=tau,
            reparametrized_sample=True,
            device=device,
            temperature_lr=learning_rate_critic,
            policy_frequency=policy_frequency,
        )
        # NOTE(review): `params` was already handed to wandb.init above;
        # whether this later update reaches the run config is not visible
        # here -- confirm.
        params.update(best_worst_parameters)

        print(f"Train an expert on {best_worst_parameters} MDP")

        expert_max_step = max_step
        if is_first_iteration:
            print("Train the default expert on domain randomization")
            expert_max_step = max_step_dr

        train_ravi_mpc_no_reg_minimal(  # type: ignore
            agent=agent_star,
            optimizer_actor=optimizer_actor,
            optimizer_critic=optimizer_critic,
            optimizer_mpc=optimizer_mpc,
            optimizer_critic_no_reg=optimizer_critic_no_reg,
            iteration_number=iteration,
            env=env,
            sac_algo=algo_rl_star,
            buffer=replay,
            max_step=expert_max_step,
            train_step=train_step,
            train_step_start=train_step_start,
            artifact_folder=run_path,
            run_name=f"{run_name}-{iteration}",
            log_loss_every_k_step=log_loss_every_k_step,
            save_every_k_step=save_every_k_step,
            log_runner=wandb,
        )
        # Hand the pessimist agent independent copies of the trained networks,
        # since `agent_star` is rebuilt from scratch on the next iteration.
        robust_q1: nn.Module = deepcopy(agent_star.q_value_network_1)
        robust_q2: nn.Module = deepcopy(agent_star.q_value_network_2)
        actor_robust: nn.Module = deepcopy(agent_star.actor)

        agent_pessimist.set_robust_nets(
            q_robust_1=robust_q1, q_robust_2=robust_q2, actor_robust=actor_robust
        )
        logging.info("start to find worst parameters \n")
        agent_pessimist.eval()

        # Init the adaptive threshold
        agent_pessimist.append_last_threshold(0)
        best_treshold = None
        best_infos = None
        best_worst_parameters = None
        best_worst_performance = -float("inf")

        # Iteration 0 evaluates only `first_threshold_pc`. The candidate is
        # chosen BEFORE it is applied to the agent, so the threshold used by
        # the search is the one that gets logged, recorded and saved.
        if is_first_iteration:
            candidate_thresholds = [first_threshold_pc]
        else:
            candidate_thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

        for threshold_pc in candidate_thresholds:
            # the best threshold is the one that gives the best worst performance
            agent_pessimist.set_last_threshold(threshold_pc)

            logging.info("start to find worst parameters \n")
            worst_parameters, infos = find_worse_cma_es(
                agent=agent_pessimist,
                worst_env_builder=worst_env_builder,
                bound=env_params_bounds,
                generation=generation,
                population_size=population_size,
                mean_value=mean_value_es,
                sigma_value=sigma_value_es,
                sequence_seed=sequence_seed_eval,  # seed for eval envs
                seed_cma=seed,
                nb_step=1000,
                nb_trial=1,
            )

            # example = [[{'worldfriction': 4.0, 'torsomass': 7.0, 'backthighmass': 0.1676624785390559}, 245.83758857162775], [{'worldfriction': 4.0, 'torsomass': 7.0, 'backthighmass': 0.3753823338048763}, 419.4773163709352]]  # noqa: E501
            worst_case = infos[0][1]
            if worst_case > best_worst_performance:
                best_worst_performance = worst_case
                best_treshold = threshold_pc
                best_infos = infos
                best_worst_parameters = worst_parameters
            logging.info(
                f"Threshold {threshold_pc} gives {worst_case} worst performance \n"
            )

        if is_first_iteration:
            logging.info(
                "First iteration, we stop here because we set the threshold for the first expert \n"  # noqa: E501
            )

        if best_worst_parameters is None:
            # No score compared greater than -inf (e.g. every score was NaN).
            # Keep the last evaluated candidate so the next iteration can
            # still build its environment instead of failing on `**None`.
            logging.warning(
                "No threshold produced a comparable worst-case score; "
                "falling back to the last evaluated candidate \n"
            )
            best_treshold = threshold_pc
            best_infos = infos
            best_worst_parameters = worst_parameters

        logging.info(f"worst parameters found is : {best_worst_parameters} \n")
        logging.info(f"The best threshold is {best_treshold} \n")
        logging.info(f"{best_infos} \n")

        list_infos.append(best_infos)

        agent_pessimist.set_last_threshold(best_treshold)
        agent_pessimist.save(f"{run_path}/agent_pessimist_hydra_{iteration}.pt")

    with open(f"{run_path}/infos_all_train.json", "w", encoding="utf-8") as f:
        json.dump(list_infos, f)  # type: ignore
    artifact = wandb.Artifact(name="model_weights", type="model")
    artifact.add_dir(run_path)
    wandb.log_artifact(artifact_or_path=artifact)
    # NOTE(review): the run folder is kept on disk and the training
    # environments are never explicitly closed -- confirm this is intended.


if __name__ == "__main__":
    # Script entry point: run the experiment function through python-fire,
    # which builds the command-line interface from the function passed in.
    fire.Fire(component=run_ravi_experiment)
