from copy import deepcopy
import functools
import os
import socket
import datetime
import uuid
import json
import logging
from typing import Optional, List

import fire
import torch
import torch.nn as nn
import wandb
from algos import (
    OffPolicyMemory,
    SACMPCHydra,
    train_ravi_mpc_no_reg_minimal_best,
    find_worse_mesh,
)
from build import (
    build_envs,
    build_optimizer_mpc_no_reg,
    build_mpc_sac_hydra_agent,
    build_bounds,
    build_pessimist_agent_hydra_default_adaptative_threshold,
)
from project_agent import (
    MujocoSacMPCHydra,
    PessimistExpertAgentHydraDefaultAdaptativeThreshold,
)
from seed import seed_everything
from bounds import _get_center_bounds, BENCHMARK_MUJOCO_REFERENCE

# Script-wide logging: everything goes through the root logger at INFO level.
logger = logging.getLogger()  # no name -> the root logger
logging.basicConfig()  # add the default stderr handler if the root has none yet
logger.setLevel(logging.INFO)  # emit INFO and above


def run_ravi_experiment(
    nb_iteration: int,
    path_best_dr: str,
    seed: int = 1,
    experiment_name: str = "ExpertPessimist",
    device: str = "cuda:0",  # type: ignore
    num_envs: int = 1,
    env_type: str = "mujoco",
    env_name: str = "Halfcheetah",
    learning_rate_actor: float = 3e-4,
    learning_rate_critic: float = 1e-3,
    batch_size: int = 256,
    discount: float = 0.99,
    tau: float = 0.995,
    policy_frequency: int = 2,
    memory_size: int = 1_000_000,
    max_step: int = 3_000_000,
    train_step: int = 1,
    train_step_start: int = 10_000,
    log_loss_every_k_step: int = 10000,
    save_every_k_step: int = 100000,
    split: int = 10,
    worst_env_builder_name: str = "mujoco-vanilla",
    bounds_name: str = "mujoco-vanilla",
    init_bound_name: str = "mujoco-vanilla",
    project_name: str = "anonymous",
    threshold_list: Optional[List[float]] = None,
):
    """Run the iterative expert / pessimist worst-case training loop.

    The pessimist agent is loaded from ``path_best_dr``. A first grid search
    (``find_worse_mesh``) over the parameter bounds finds the environment
    parameters on which it performs worst. Then, for at most
    ``nb_iteration + 1`` iterations:

    1. a fresh SAC/MPC "hydra" expert is trained on the current worst
       parameters;
    2. copies of the expert's two Q-networks and actor are given to the
       pessimist agent as its robust nets;
    3. every threshold in ``threshold_list`` is tried, and the one whose
       worst-case performance is highest is kept, together with its worst
       parameters, for the next iteration.

    The loop stops early when the selected threshold is 0. Checkpoints, the
    collected search infos (``infos_all_train.json``) and a wandb artifact
    are produced from the folder ``result/tmp_<uuid>``.

    Args:
        nb_iteration (int): Number of extra iterations; the loop body runs at
            most ``nb_iteration + 1`` times.
        path_best_dr (str): Checkpoint loaded into the pessimist agent with
            ``agent_pessimist.load``.
        seed (int, optional): Random seed of the experiments. Defaults to 1.
        experiment_name (str, optional): Currently unused.
            Defaults to "ExpertPessimist".
        device (str, optional): Torch device string. Defaults to "cuda:0".
        num_envs (int, optional): Number of environments. Defaults to 1.
        env_type (str, optional): mujoco or brax; used to build the bounds
            and the training environments. Defaults to "mujoco".
        env_name (str, optional): Name of the environment.
            Defaults to "Halfcheetah".
        learning_rate_actor (float, optional): Learning rate of the actor.
            Defaults to 3e-4.
        learning_rate_critic (float, optional): Learning rate of the critics,
            also used as the SAC temperature learning rate. Defaults to 1e-3.
        batch_size (int, optional): Size of the batch sampled for the
            optimisation process. Defaults to 256.
        discount (float, optional): Gamma. Defaults to 0.99.
        tau (float, optional): Soft update value (Polyak). Defaults to 0.995.
        policy_frequency (int, optional): Optimize the policy network every
            `k` steps. Defaults to 2.
        memory_size (int, optional): Size of the replay buffer.
            Defaults to 1_000_000.
        max_step (int, optional): Number of training steps of each expert.
            Defaults to 3_000_000.
        train_step (int, optional): An optimisation step every `k` steps.
            Defaults to 1.
        train_step_start (int, optional): Number of steps before training.
            Defaults to 10_000.
        log_loss_every_k_step (int, optional): Log the loss every `k` steps.
            Defaults to 10000.
        save_every_k_step (int, optional): Save the model every `k` steps.
            Defaults to 100000.
        split (int, optional): Forwarded as ``split`` to ``find_worse_mesh``
            (presumably the number of grid points per parameter).
            Defaults to 10.
        worst_env_builder_name (str, optional): Environment builder used by
            the worst-case search: "mujoco-vanilla" or "mujoco_benchmark".
            Defaults to "mujoco-vanilla".
        bounds_name (str, optional): Bounds name given to ``build_bounds``;
            also selects the "reference" initial parameters.
            Defaults to "mujoco-vanilla".
        init_bound_name (str, optional): Initial environment parameters:
            "mujoco-vanilla" (mass and friction coefs of 1), "center" (center
            of the bounds) or "reference". Defaults to "mujoco-vanilla".
        project_name (str, optional): wandb project name.
            Defaults to "anonymous".
        threshold_list (list of float, optional): Candidate thresholds tried
            at each iteration; a single number is treated as a one-element
            list. Defaults to [0.0, 0.1, ..., 0.9].

    Raises:
        ValueError: If ``threshold_list`` is empty.
        KeyError: If ``worst_env_builder_name`` or ``init_bound_name`` is not
            one of the supported names.
    """
    # Validate the thresholds before any expensive work: an empty list would
    # otherwise leave no selected parameters and crash on the next iteration.
    if threshold_list is None:
        threshold_list = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    elif isinstance(threshold_list, (int, float)):
        # python-fire turns `--threshold_list=0.5` into a bare number.
        threshold_list = [threshold_list]
    else:
        threshold_list = list(threshold_list)
    if not threshold_list:
        raise ValueError("threshold_list must contain at least one threshold")

    video_folder = "result/"
    ravi_id = str(uuid.uuid4())
    run_path = os.path.join(video_folder, f"tmp_{ravi_id}")
    os.makedirs(run_path, exist_ok=True)

    sequence_seed_eval = [12345]
    # Builders for the environments used by the worst-case search; note they
    # use a fixed env_type, independent of the `env_type` argument.
    ENV_BUILDER_FN = {
        "mujoco-vanilla": functools.partial(
            build_envs,
            env_type="mujoco",
            device=device,  # type: ignore
            env_name=env_name,
        ),
        "mujoco_benchmark": functools.partial(
            build_envs,
            env_type="mujoco_benchmark",
            device=device,  # type: ignore
            env_name=env_name,
            seed=sequence_seed_eval[0],
        ),
    }
    BOUNDS = build_bounds(env_type=env_type, env_name=env_name, bound_name=bounds_name)

    # Only the requested initial parameters are built: evaluating every option
    # eagerly made e.g. "mujoco-vanilla" fail whenever
    # BENCHMARK_MUJOCO_REFERENCE had no entry for (env_name, bounds_name).
    init_bounds_factories = {
        "mujoco-vanilla": lambda: {"mass_coef": 1, "friction_coef": 1},
        "center": lambda: _get_center_bounds(BOUNDS),
        "reference": lambda: BENCHMARK_MUJOCO_REFERENCE[env_name][bounds_name],
    }

    worst_env_builder = ENV_BUILDER_FN[worst_env_builder_name]
    env_params_bounds = BOUNDS
    init_env_params = init_bounds_factories[init_bound_name]()
    seed_everything(seed=seed)
    run_name = (
        f"Default-{env_name}-{bounds_name}-{env_type}-{seed}-pessimist-hydra-best-dr"
    )

    device: torch.device = torch.device(device)  # type: ignore

    # Throw-away environment, only used to read the observation/action sizes.
    env_mock = build_envs(
        env_type=env_type,
        env_name=env_name,
        num_env=num_envs,
        device=device,
        seed=seed,
        **init_env_params,
    )
    observation_dim = env_mock.observation_space[0].shape[1]  # type: ignore
    action_dim = env_mock.action_space[0].shape[0]  # type: ignore
    del env_mock

    agent_pessimist: PessimistExpertAgentHydraDefaultAdaptativeThreshold = (
        build_pessimist_agent_hydra_default_adaptative_threshold(
            observation_dim=observation_dim,
            action_dim=action_dim,
            device=device,
        )
    )
    agent_pessimist.load(path_best_dr)
    list_infos = []

    # Initial worst-case search with the loaded agent, to seed the iterations.
    worst_parameters, infos = _find_worst_parameters(
        agent=agent_pessimist,
        env_builder=worst_env_builder,
        bound=env_params_bounds,
        split=split,
        sequence_seed=sequence_seed_eval,
    )
    logging.info(
        f"Worst parameters for domain randomization found is : {worst_parameters} \n"
    )
    list_infos.append(infos)
    best_worst_parameters = worst_parameters
    params = {
        "seed": seed,
        "num_envs": num_envs,
        "learning_rate_actor": learning_rate_actor,
        "learning_rate_critic": learning_rate_critic,
        "batch_size": batch_size,
        "discount": discount,
        "tau": tau,
        "policy_frequency": policy_frequency,
        "memory_size": memory_size,
        "max_step": max_step,
        "train_step": train_step,
        "train_step_start": train_step_start,
        "nb_iteration": nb_iteration,
        "machine_name": socket.gethostname(),
        "threshold_list": json.dumps(threshold_list),
    }
    wandb.init(
        project=project_name,
        config=params,
        name=f"{run_name}_ravi-{str(datetime.datetime.now())}",
        save_code=True,
        reinit=True,
    )
    for iteration in range(nb_iteration + 1):
        # Train a fresh expert on the current worst-case parameters.
        env = build_envs(
            env_type=env_type,
            env_name=env_name,
            num_env=num_envs,
            device=device,
            seed=sequence_seed_eval[0],
            **best_worst_parameters,
        )

        agent_star: MujocoSacMPCHydra = build_mpc_sac_hydra_agent(  # type: ignore
            observation_dim=observation_dim,
            action_dim=action_dim,
            device=device,
        )

        (
            optimizer_actor,
            optimizer_critic,
            optimizer_mpc,
            optimizer_critic_no_reg,
        ) = build_optimizer_mpc_no_reg(
            actor_parameters=agent_star.parameters_actor(),  # type: ignore
            critic_parameters=agent_star.parameters_critic(),  # type: ignore
            learning_rate_critic=learning_rate_critic,
            learning_rate_actor=learning_rate_actor,
        )

        # SAC heuristic: target entropy = -(product of the action shape).
        target_entropy = -torch.prod(
            torch.tensor(env.action_space[0].shape).to(device)  # type: ignore
        ).item()

        replay = OffPolicyMemory(memory_size=memory_size, device=device)
        algo_rl_star = SACMPCHydra(
            batch_size=batch_size,
            target_entropy=target_entropy,  # type: ignore
            discount=discount,
            tau=tau,
            reparametrized_sample=True,
            device=device,
            temperature_lr=learning_rate_critic,
            policy_frequency=policy_frequency,
        )
        # NOTE(review): `params` was already handed to wandb.init, so this
        # update likely does not reach the logged run config; consider
        # wandb.config.update(..., allow_val_change=True) if that is intended.
        params.update(best_worst_parameters)

        print(f"Train an expert on {best_worst_parameters} MDP")

        train_ravi_mpc_no_reg_minimal_best(  # type: ignore
            agent=agent_star,
            optimizer_actor=optimizer_actor,
            optimizer_critic=optimizer_critic,
            optimizer_mpc=optimizer_mpc,
            optimizer_critic_no_reg=optimizer_critic_no_reg,
            iteration_number=iteration,
            env=env,
            sac_algo=algo_rl_star,
            buffer=replay,
            max_step=max_step,
            train_step=train_step,
            train_step_start=train_step_start,
            artifact_folder=run_path,
            run_name=f"{run_name}-{iteration}",
            log_loss_every_k_step=log_loss_every_k_step,
            save_every_k_step=save_every_k_step,
            log_runner=wandb,
        )
        # Hand independent copies of the expert's nets to the pessimist agent.
        robust_q1: nn.Module = deepcopy(agent_star.q_value_network_1)
        robust_q2: nn.Module = deepcopy(agent_star.q_value_network_2)
        actor_robust: nn.Module = deepcopy(agent_star.actor)

        agent_pessimist.set_robust_nets(
            q_robust_1=robust_q1, q_robust_2=robust_q2, actor_robust=actor_robust
        )
        logging.info("start to find worst parameters \n")
        agent_pessimist.eval()

        # Init the adaptive threshold
        agent_pessimist.append_last_threshold(0)
        best_threshold, best_infos, best_worst_parameters = _select_best_threshold(
            agent_pessimist=agent_pessimist,
            threshold_list=threshold_list,
            env_builder=worst_env_builder,
            bound=env_params_bounds,
            split=split,
            sequence_seed=sequence_seed_eval,
        )

        logging.info(f"worst parameters found is : {best_worst_parameters} \n")
        logging.info(f"The best threshold is {best_threshold} \n")
        logging.info(f"{best_infos} \n")

        list_infos.append(best_infos)

        agent_pessimist.set_last_threshold(best_threshold)
        agent_pessimist.save(f"{run_path}/agent_pessimist_hydra_{iteration}.pt")

        if best_threshold == 0:
            logging.info(
                f"The threshold is 0, and the iteration is {iteration} we stop the training \n"
            )
            break

    agent_pessimist.save(f"{run_path}/agent_pessimist_hydra_best.pt")

    with open(f"{run_path}/infos_all_train.json", "w") as f:
        json.dump(list_infos, f)  # type: ignore
    artifact = wandb.Artifact(name="model_weights", type="model")
    artifact.add_dir(run_path)
    wandb.log_artifact(artifact_or_path=artifact)


def _find_worst_parameters(agent, env_builder, bound, split, sequence_seed):
    """Run ``find_worse_mesh`` with the settings shared by every search here.

    Returns whatever ``find_worse_mesh`` returns: a
    ``(worst_parameters, infos)`` pair.
    """
    return find_worse_mesh(
        agent=agent,
        env_builder=env_builder,
        bound=bound,
        split=split,
        verbose=True,
        nb_step=1000,
        nb_trial=1,
        sequence_seed=sequence_seed,
    )


def _select_best_threshold(
    agent_pessimist, threshold_list, env_builder, bound, split, sequence_seed
):
    """Return the threshold whose worst-case performance is the highest.

    Each candidate is set as the pessimist agent's last threshold before a
    worst-case search is run. The agent is left with the last candidate
    set; the caller is responsible for setting the selected one.

    Args:
        agent_pessimist: Agent exposing ``set_last_threshold``.
        threshold_list: Non-empty sequence of candidate thresholds.
        env_builder, bound, split, sequence_seed: Forwarded to the search.

    Returns:
        tuple: ``(best_threshold, best_infos, best_worst_parameters)``.
    """
    best_threshold = None
    best_infos = None
    best_worst_parameters = None
    best_worst_performance = -float("inf")

    for index, threshold_pc in enumerate(threshold_list):
        agent_pessimist.set_last_threshold(threshold_pc)

        logging.info(
            f"start to find worst parameters for the threshold {threshold_pc} \n"
        )
        worst_parameters, infos = _find_worst_parameters(
            agent=agent_pessimist,
            env_builder=env_builder,
            bound=bound,
            split=split,
            sequence_seed=sequence_seed,
        )

        # `infos` holds [env_params, performance] pairs; entry 0 is taken as
        # the worst case (presumably sorted ascending by find_worse_mesh), e.g.
        # example = [[{'worldfriction': 4.0, 'torsomass': 7.0, 'backthighmass': 0.1676624785390559}, 245.83758857162775], [{'worldfriction': 4.0, 'torsomass': 7.0, 'backthighmass': 0.3753823338048763}, 419.4773163709352]]  # noqa: E501
        worst_case = infos[0][1]
        # Always accept the first candidate so a result exists even when no
        # performance compares greater than -inf (e.g. NaN).
        if index == 0 or worst_case > best_worst_performance:
            best_worst_performance = worst_case
            best_threshold = threshold_pc
            best_infos = infos
            best_worst_parameters = worst_parameters
        logging.info(
            f"Threshold {threshold_pc} gives {worst_case} worst performance \n"
        )

    return best_threshold, best_infos, best_worst_parameters


if __name__ == "__main__":
    # python-fire maps CLI flags onto run_ravi_experiment's parameters.
    fire.Fire(component=run_ravi_experiment)
