from copy import deepcopy
from typing import Optional, Tuple, Dict, Union, Any, Callable, List
import json
import gym  # type: ignore
import torch
from torch import optim
from tqdm import tqdm  # type: ignore
import wandb

# import wandb
from .base_agent import IAgent
from .sac import SAC, ImitationLearningPolicy, SACMPC, SACMPCHydra, SACMPCNoReg
from .replay import OffPolicyMemory, BatchTransition
from .evaluate import find_worse, evaluate


def train_sac(
    agent: IAgent,
    optimizer_actor: optim.Optimizer,
    optimizer_critic: optim.Optimizer,
    env: gym.Env,
    sac_algo: SAC,
    buffer: OffPolicyMemory,
    max_step: int,
    train_step: int,
    train_step_start: int,
    experiment_name: str,
    log_loss_every_k_step: int = 1000,
    artifact_folder: str = "result",
    run_name: str = "Default",
    params: Optional[Dict[str, Any]] = None,
):
    """Run the SAC interaction/update loop, then save the agent.

    For each of the ``max_step`` environment steps one transition is collected
    and appended to ``buffer``. Every ``train_step`` steps, once the warm-up of
    ``train_step_start`` steps is over, ``sac_algo.update`` is called. When the
    loop ends the agent is saved to ``{artifact_folder}/sac_{run_name}.pth``
    and the environment is closed.

    NOTE(review): the environment is never reset inside the loop; presumably
    ``env`` resets itself when an episode ends (e.g. a vectorized env) -
    confirm against the caller.

    Args:
        agent (IAgent): Policy queried for an action distribution and saved at
            the end.
        optimizer_actor (optim.Optimizer): Forwarded to ``sac_algo.update``.
        optimizer_critic (optim.Optimizer): Forwarded to ``sac_algo.update``.
        env (gym.Env): Environment to interact with; closed before returning.
        sac_algo (SAC): Algorithm object whose ``update`` performs the training.
        buffer (OffPolicyMemory): Replay memory receiving every transition.
        max_step (int): Total number of environment steps.
        train_step (int): Call ``sac_algo.update`` every ``train_step`` steps.
        train_step_start (int): Number of initial steps played with actions
            sampled from ``env.action_space`` and without any update.
        experiment_name (str): Not used by this function.
        log_loss_every_k_step (int): Not used by this function (logging is
            disabled here). Defaults to ``1000``.
        artifact_folder (str): Folder where the agent checkpoint is written.
        run_name (str): Name of the run, used in the checkpoint file name.
        params (Optional[Dict[str, Any]]): Not used by this function.
    """
    obs = env.reset()
    for step in tqdm(range(1, max_step + 1), ncols=80):
        # Acting and storing transitions never needs gradients.
        with torch.no_grad():
            if step > train_step_start:
                actions = agent.action(observation=obs).sample()
            else:
                # Warm-up phase: fill the buffer with random actions.
                actions = env.action_space.sample()
            new_obs, reward, done, _info = env.step(actions)
            buffer.append(
                BatchTransition(
                    state=obs,
                    action=actions,
                    reward=reward,  # type: ignore
                    next_state=new_obs,
                    done=done,  # type: ignore
                )
            )
            obs = new_obs
        if step % train_step == 0 and step > train_step_start:
            # The returned losses are not logged by this trainer.
            sac_algo.update(
                replay=buffer,
                agent=agent,
                optimizer_actor=optimizer_actor,
                optimizer_critic=optimizer_critic,
                scheduler=None,
            )
    agent.save(path=f"{artifact_folder}/sac_{run_name}.pth")
    env.close()


def train_sac_mpc(
    agent: IAgent,
    optimizer_actor: optim.Optimizer,
    optimizer_critic: optim.Optimizer,
    optimizer_mpc: optim.Optimizer,
    env: gym.Env,
    sac_algo: SACMPC,
    buffer: OffPolicyMemory,
    max_step: int,
    train_step: int,
    train_step_start: int,
    log_loss_every_k_step: int = 1000,
    artifact_folder: str = "result",
    run_name: str = "Default",
    experiment_name: Optional[str] = None,
    log_runner: Optional["wandb.sdk.wandb_run.Run"] = None,  # type: ignore
    params: Optional[Dict[str, Any]] = None,
):
    """Run the SAC-MPC interaction/update loop, then save the agent.

    For each of the ``max_step`` environment steps one transition is collected
    and appended to ``buffer``; when ``log_runner`` is given, the reduced
    environment info is logged at every step. Every ``train_step`` steps, once
    the warm-up of ``train_step_start`` steps is over, ``sac_algo.update`` is
    called. When the loop ends the agent is saved to
    ``{artifact_folder}/sacmpc_{run_name}.pth`` and the environment is closed.

    NOTE(review): the environment is never reset inside the loop; presumably
    ``env`` resets itself when an episode ends (e.g. a vectorized env) -
    confirm against the caller.

    Args:
        agent (IAgent): Policy queried for an action distribution and saved at
            the end.
        optimizer_actor (optim.Optimizer): Forwarded to ``sac_algo.update``.
        optimizer_critic (optim.Optimizer): Forwarded to ``sac_algo.update``.
        optimizer_mpc (optim.Optimizer): Forwarded to ``sac_algo.update``.
        env (gym.Env): Environment to interact with; closed before returning.
        sac_algo (SACMPC): Algorithm object whose ``update`` performs the
            training.
        buffer (OffPolicyMemory): Replay memory receiving every transition.
        max_step (int): Total number of environment steps.
        train_step (int): Call ``sac_algo.update`` every ``train_step`` steps.
        train_step_start (int): Number of initial steps played with actions
            sampled from ``env.action_space`` and without any update.
        log_loss_every_k_step (int): Not used by this function (losses are not
            logged here). Defaults to ``1000``.
        artifact_folder (str): Folder where the agent checkpoint is written.
        run_name (str): Name of the run, used in the checkpoint file name.
        experiment_name (Optional[str]): Not used by this function.
        log_runner: Optional wandb run; receives the reduced environment info
            at every step (possibly an empty dict).
        params (Optional[Dict[str, Any]]): Not used by this function.
    """
    obs = env.reset()
    for step in tqdm(range(1, max_step + 1)):
        # Acting and storing transitions never needs gradients.
        with torch.no_grad():
            if step > train_step_start:
                actions = agent.action(observation=obs).sample()
            else:
                # Warm-up phase: fill the buffer with random actions.
                actions = env.action_space.sample()
            new_obs, reward, done, info = env.step(actions)
            buffer.append(
                BatchTransition(
                    state=obs,
                    action=actions,
                    reward=reward,  # type: ignore
                    next_state=new_obs,  # type: ignore
                    done=done,  # type: ignore
                )
            )
            obs = new_obs
            if log_runner is not None:
                # Logged at every step, even when empty, so the run's implicit
                # step counter keeps following the environment steps.
                log_runner.log(data=_reduce_vec_info(infos=info))
        if step % train_step == 0 and step > train_step_start:
            # The returned losses are not logged by this trainer.
            sac_algo.update(
                replay=buffer,
                agent=agent,
                optimizer_actor=optimizer_actor,
                optimizer_critic=optimizer_critic,
                optimizer_mpc=optimizer_mpc,
                scheduler=None,
            )
    agent.save(path=f"{artifact_folder}/sacmpc_{run_name}.pth")
    env.close()


def train_ravi(
    agent: IAgent,
    optimizer_actor: optim.Optimizer,
    optimizer_critic: optim.Optimizer,
    env: gym.Env,
    sac_algo: SAC,
    buffer: OffPolicyMemory,
    max_step: int,
    train_step: int,
    train_step_start: int,
    experiment_name: str,
    worst_env_builder: Callable[..., gym.Env],
    plotter_heatmap: Callable[[List[Tuple[Dict[str, float], float]]], str],
    env_params_bounds: Dict[str, List[int]],
    log_loss_every_k_step: int = 1000,
    artifact_folder: str = "result",
    run_name: str = "Default",
    params: Optional[Dict[str, Any]] = None,
    **params_worst_env,
):
    """Train with SAC, then search for the worst environment parameters.

    The training phase collects one transition per environment step into
    ``buffer`` and calls ``sac_algo.update`` every ``train_step`` steps once
    the warm-up of ``train_step_start`` steps is over. Afterwards the agent is
    saved, ``find_worse`` is run, its two results are dumped as JSON files in
    ``artifact_folder`` and a heatmap is plotted on a best-effort basis.

    NOTE(review): the environment is never reset inside the loop; presumably
    ``env`` resets itself when an episode ends (e.g. a vectorized env) -
    confirm against the caller.

    Args:
        agent (IAgent): Policy queried for an action distribution, saved after
            training and passed to ``find_worse``.
        optimizer_actor (optim.Optimizer): Forwarded to ``sac_algo.update``.
        optimizer_critic (optim.Optimizer): Forwarded to ``sac_algo.update``.
        env (gym.Env): Environment to interact with; closed before returning.
        sac_algo (SAC): Algorithm object whose ``update`` performs the training.
        buffer (OffPolicyMemory): Replay memory receiving every transition.
        max_step (int): Total number of environment steps.
        train_step (int): Call ``sac_algo.update`` every ``train_step`` steps.
        train_step_start (int): Number of initial steps played with actions
            sampled from ``env.action_space`` and without any update.
        experiment_name (str): Not used by this function.
        worst_env_builder (Callable[..., gym.Env]): Passed to ``find_worse`` as
            ``env_builder``.
        plotter_heatmap (Callable): Called with ``history_configuration`` and
            the output path of the heatmap image.
        env_params_bounds (Dict[str, List[int]]): Passed to ``find_worse`` as
            ``bound``.
        log_loss_every_k_step (int): Not used by this function (logging is
            disabled here). Defaults to ``1000``.
        artifact_folder (str): Folder where checkpoint, JSON files and heatmap
            are written.
        run_name (str): Name of the run, used in every output file name.
        params (Optional[Dict[str, Any]]): Not used by this function.
        **params_worst_env: Extra keyword arguments forwarded to ``find_worse``.

    Returns:
        The ``(worst_parameters, history_configuration)`` pair returned by
        ``find_worse``.
    """
    obs = env.reset()
    for step in tqdm(range(1, max_step + 1)):
        # Acting and storing transitions never needs gradients.
        with torch.no_grad():
            if step > train_step_start:
                actions = agent.action(observation=obs).sample()
            else:
                # Warm-up phase: fill the buffer with random actions.
                actions = env.action_space.sample()
            new_obs, reward, done, _info = env.step(actions)
            buffer.append(
                BatchTransition(
                    state=obs,
                    action=actions,
                    reward=reward,  # type: ignore
                    next_state=new_obs,
                    done=done,  # type: ignore
                )
            )
            obs = new_obs
        if step % train_step == 0 and step > train_step_start:
            # The returned losses are not logged by this trainer.
            sac_algo.update(
                replay=buffer,
                agent=agent,
                optimizer_actor=optimizer_actor,
                optimizer_critic=optimizer_critic,
                scheduler=None,
            )
    agent.save(path=f"{artifact_folder}/sac_{run_name}.pth")

    print("COMPUTE WORST ENV PARAMS")
    worst_parameters, history_configuration = find_worse(
        agent=agent,
        env_builder=worst_env_builder,
        bound=env_params_bounds,
        **params_worst_env,
    )
    with open(f"{artifact_folder}/worst_parameters_{run_name}.json", "w") as wp:
        json.dump(worst_parameters, wp)
    with open(f"{artifact_folder}/history_configuration_{run_name}.json", "w") as hc:
        json.dump(history_configuration, hc)
    print(f"WORST PARAMS IS: {worst_parameters}")
    try:
        plotter_heatmap(
            history_configuration, f"{artifact_folder}/heatmap_{run_name}.png"
        )
    except ValueError as err:
        # Plotting stays best effort, but the failure is reported instead of
        # being silently dropped. TODO: fix the underlying plotting issue.
        print(f"HEATMAP SKIPPED: {err}")

    env.close()
    return worst_parameters, history_configuration


def train_ravi_mpc(
    agent: IAgent,
    optimizer_actor: optim.Optimizer,
    optimizer_critic: optim.Optimizer,
    optimizer_mpc: optim.Optimizer,
    env: gym.Env,
    sac_algo: SACMPC,
    buffer: OffPolicyMemory,
    max_step: int,
    train_step: int,
    train_step_start: int,
    experiment_name: str,
    log_runner: "wandb.sdk.wandb_run.Run",  # type: ignore
    worst_env_builder: Callable[..., gym.Env],
    plotter_heatmap: Callable[[List[Tuple[Dict[str, float], float]]], str],
    env_params_bounds: Dict[str, List[int]],
    iteration_number: int,
    log_loss_every_k_step: int = 1000,
    artifact_folder: str = "result",
    run_name: str = "Default",
    params: Optional[Dict[str, Any]] = None,
    **params_worst_env,
):
    """Train with SAC-MPC, log to wandb, then search for the worst env params.

    The training phase collects one transition per environment step into
    ``buffer`` and calls ``sac_algo.update`` every ``train_step`` steps once
    the warm-up of ``train_step_start`` steps is over. Every logged key is
    suffixed with ``iteration_number`` and plotted against the custom step
    metric ``ravi_step_{iteration_number}``. Afterwards the agent is saved,
    ``find_worse`` is run, its two results are dumped as JSON files in
    ``artifact_folder`` and a heatmap is plotted on a best-effort basis.

    NOTE(review): the environment is never reset inside the loop; presumably
    ``env`` resets itself when an episode ends (e.g. a vectorized env) -
    confirm against the caller.

    Args:
        agent (IAgent): Policy queried for an action distribution, saved after
            training and passed to ``find_worse``.
        optimizer_actor (optim.Optimizer): Forwarded to ``sac_algo.update``.
        optimizer_critic (optim.Optimizer): Forwarded to ``sac_algo.update``.
        optimizer_mpc (optim.Optimizer): Forwarded to ``sac_algo.update``.
        env (gym.Env): Environment to interact with; closed before returning.
        sac_algo (SACMPC): Algorithm object whose ``update`` performs the
            training and returns a dict of losses.
        buffer (OffPolicyMemory): Replay memory receiving every transition.
        max_step (int): Total number of environment steps.
        train_step (int): Call ``sac_algo.update`` every ``train_step`` steps.
        train_step_start (int): Number of initial steps played with actions
            sampled from ``env.action_space`` and without any update.
        experiment_name (str): Not used by this function.
        log_runner: wandb run used to define metrics and log data.
        worst_env_builder (Callable[..., gym.Env]): Passed to ``find_worse`` as
            ``env_builder``.
        plotter_heatmap (Callable): Called with ``history_configuration`` and
            the output path of the heatmap image.
        env_params_bounds (Dict[str, List[int]]): Passed to ``find_worse`` as
            ``bound``.
        iteration_number (int): Suffix of every metric and output file name.
        log_loss_every_k_step (int): Losses are logged on steps that are a
            multiple of this value (after the warm-up). Defaults to ``1000``.
        artifact_folder (str): Folder where checkpoint, JSON files and heatmap
            are written.
        run_name (str): Name of the run, used in every output file name.
        params (Optional[Dict[str, Any]]): Not used by this function.
        **params_worst_env: Extra keyword arguments forwarded to ``find_worse``.

    Returns:
        The ``(worst_parameters, history_configuration)`` pair returned by
        ``find_worse``.
    """
    obs = env.reset()
    custom_step = f"ravi_step_{iteration_number}"
    log_runner.define_metric(custom_step)
    # The step metric is defined just above; seeding the set keeps it from
    # being declared a second time with itself as x-axis.
    key_already_seen = {custom_step}

    def _define_new_metrics(metric_names):
        # Bind every not-yet-declared metric to the custom step axis.
        for name in metric_names:
            if name not in key_already_seen:
                log_runner.define_metric(name, step_metric=custom_step)
                key_already_seen.add(name)

    for step in tqdm(range(1, max_step + 1)):
        # Acting and storing transitions never needs gradients.
        with torch.no_grad():
            if step > train_step_start:
                actions = agent.action(observation=obs).sample()
            else:
                # Warm-up phase: fill the buffer with random actions.
                actions = env.action_space.sample()
            new_obs, reward, done, info = env.step(actions)
            buffer.append(
                BatchTransition(
                    state=obs,
                    action=actions,
                    reward=reward,  # type: ignore
                    next_state=new_obs,
                    done=done,  # type: ignore
                )
            )
            obs = new_obs
            info_log_env = {
                f"{k}_{iteration_number}": v
                for k, v in _reduce_vec_info(infos=info).items()
            }
            info_log_env[custom_step] = step
            _define_new_metrics(info_log_env)

        log_runner.log(data=info_log_env)
        if step % train_step == 0 and step > train_step_start:
            dict_loss = sac_algo.update(
                replay=buffer,
                agent=agent,
                optimizer_actor=optimizer_actor,
                optimizer_critic=optimizer_critic,
                optimizer_mpc=optimizer_mpc,
                scheduler=None,
            )
            dict_loss = {f"{k}_{iteration_number}": v for k, v in dict_loss.items()}
            dict_loss[custom_step] = step
            # Register the loss keys themselves (not the env-info keys) so the
            # losses are plotted against the custom step.
            _define_new_metrics(dict_loss)
        else:
            dict_loss = {}
        if step % log_loss_every_k_step == 0 and step > train_step_start:
            log_runner.log(data=dict_loss)
    agent.save(path=f"{artifact_folder}/sac_{run_name}_{iteration_number}.pth")

    print("COMPUTE WORST ENV PARAMS")
    worst_parameters, history_configuration = find_worse(
        agent=agent,
        env_builder=worst_env_builder,
        bound=env_params_bounds,
        verbose=True,
        **params_worst_env,
    )
    with open(
        f"{artifact_folder}/worst_parameters_{run_name}_{iteration_number}.json", "w"
    ) as wp:
        json.dump(worst_parameters, wp)
    with open(
        f"{artifact_folder}/history_configuration_{run_name}_{iteration_number}.json",
        "w",
    ) as hc:
        json.dump(history_configuration, hc)
    print(f"WORST PARAMS IS: {worst_parameters}")
    try:
        plotter_heatmap(
            history_configuration,
            f"{artifact_folder}/heatmap_{run_name}_{iteration_number}.png",
        )
    except ValueError as err:
        # Plotting stays best effort, but the failure is reported instead of
        # being silently dropped. TODO: fix the underlying plotting issue.
        print(f"HEATMAP SKIPPED: {err}")

    env.close()
    return worst_parameters, history_configuration


def train_ravi_mpc_no_reg(
    agent: IAgent,
    optimizer_actor: optim.Optimizer,
    optimizer_critic: optim.Optimizer,
    optimizer_mpc: optim.Optimizer,
    optimizer_critic_no_reg: optim.Optimizer,
    env: gym.Env,
    sac_algo: SACMPCHydra,
    buffer: OffPolicyMemory,
    max_step: int,
    train_step: int,
    train_step_start: int,
    experiment_name: str,
    log_runner: "wandb.sdk.wandb_run.Run",  # type: ignore
    worst_env_builder: Callable[..., gym.Env],
    plotter_heatmap: Callable[[List[Tuple[Dict[str, float], float]]], str],
    env_params_bounds: Dict[str, List[int]],
    iteration_number: int,
    log_loss_every_k_step: int = 1000,
    artifact_folder: str = "result",
    run_name: str = "Default",
    params: Optional[Dict[str, Any]] = None,
    **params_worst_env,
):
    """Train with the hydra SAC-MPC variant, log to wandb, then find worst env.

    The training phase collects one transition per environment step into
    ``buffer`` and calls ``sac_algo.update`` (which additionally receives
    ``optimizer_critic_no_reg``) every ``train_step`` steps once the warm-up
    of ``train_step_start`` steps is over. Every logged key is suffixed with
    ``iteration_number`` and plotted against the custom step metric
    ``ravi_step_{iteration_number}``. Afterwards the agent is saved,
    ``find_worse`` is run, its two results are dumped as JSON files in
    ``artifact_folder`` and a heatmap is plotted on a best-effort basis.

    NOTE(review): the environment is never reset inside the loop; presumably
    ``env`` resets itself when an episode ends (e.g. a vectorized env) -
    confirm against the caller.

    Args:
        agent (IAgent): Policy queried for an action distribution, saved after
            training and passed to ``find_worse``.
        optimizer_actor (optim.Optimizer): Forwarded to ``sac_algo.update``.
        optimizer_critic (optim.Optimizer): Forwarded to ``sac_algo.update``.
        optimizer_mpc (optim.Optimizer): Forwarded to ``sac_algo.update``.
        optimizer_critic_no_reg (optim.Optimizer): Forwarded to
            ``sac_algo.update``.
        env (gym.Env): Environment to interact with; closed before returning.
        sac_algo (SACMPCHydra): Algorithm object whose ``update`` performs the
            training and returns a dict of losses.
        buffer (OffPolicyMemory): Replay memory receiving every transition.
        max_step (int): Total number of environment steps.
        train_step (int): Call ``sac_algo.update`` every ``train_step`` steps.
        train_step_start (int): Number of initial steps played with actions
            sampled from ``env.action_space`` and without any update.
        experiment_name (str): Not used by this function.
        log_runner: wandb run used to define metrics and log data.
        worst_env_builder (Callable[..., gym.Env]): Passed to ``find_worse`` as
            ``env_builder``.
        plotter_heatmap (Callable): Called with ``history_configuration`` and
            the output path of the heatmap image.
        env_params_bounds (Dict[str, List[int]]): Passed to ``find_worse`` as
            ``bound``.
        iteration_number (int): Suffix of every metric and output file name.
        log_loss_every_k_step (int): Losses are logged on steps that are a
            multiple of this value (after the warm-up). Defaults to ``1000``.
        artifact_folder (str): Folder where checkpoint, JSON files and heatmap
            are written.
        run_name (str): Name of the run, used in every output file name.
        params (Optional[Dict[str, Any]]): Not used by this function.
        **params_worst_env: Extra keyword arguments forwarded to ``find_worse``.

    Returns:
        The ``(worst_parameters, history_configuration)`` pair returned by
        ``find_worse``.
    """
    obs = env.reset()
    custom_step = f"ravi_step_{iteration_number}"
    log_runner.define_metric(custom_step)
    # The step metric is defined just above; seeding the set keeps it from
    # being declared a second time with itself as x-axis.
    key_already_seen = {custom_step}

    def _define_new_metrics(metric_names):
        # Bind every not-yet-declared metric to the custom step axis.
        for name in metric_names:
            if name not in key_already_seen:
                log_runner.define_metric(name, step_metric=custom_step)
                key_already_seen.add(name)

    for step in tqdm(range(1, max_step + 1)):
        # Acting and storing transitions never needs gradients.
        with torch.no_grad():
            if step > train_step_start:
                actions = agent.action(observation=obs).sample()
            else:
                # Warm-up phase: fill the buffer with random actions.
                actions = env.action_space.sample()
            new_obs, reward, done, info = env.step(actions)
            buffer.append(
                BatchTransition(
                    state=obs,
                    action=actions,
                    reward=reward,  # type: ignore
                    next_state=new_obs,
                    done=done,  # type: ignore
                )
            )
            obs = new_obs
            info_log_env = {
                f"{k}_{iteration_number}": v
                for k, v in _reduce_vec_info(infos=info).items()
            }
            info_log_env[custom_step] = step
            _define_new_metrics(info_log_env)

        log_runner.log(data=info_log_env)
        if step % train_step == 0 and step > train_step_start:
            dict_loss = sac_algo.update(
                replay=buffer,
                agent=agent,
                optimizer_actor=optimizer_actor,
                optimizer_critic=optimizer_critic,
                optimizer_mpc=optimizer_mpc,
                optimizer_critic_no_reg=optimizer_critic_no_reg,
                scheduler=None,
            )
            dict_loss = {f"{k}_{iteration_number}": v for k, v in dict_loss.items()}
            dict_loss[custom_step] = step
            # Register the loss keys themselves (not the env-info keys) so the
            # losses are plotted against the custom step.
            _define_new_metrics(dict_loss)
        else:
            dict_loss = {}
        if step % log_loss_every_k_step == 0 and step > train_step_start:
            log_runner.log(data=dict_loss)
    agent.save(path=f"{artifact_folder}/sac_{run_name}_{iteration_number}.pth")

    print("COMPUTE WORST ENV PARAMS")
    worst_parameters, history_configuration = find_worse(
        agent=agent,
        env_builder=worst_env_builder,
        bound=env_params_bounds,
        verbose=True,
        **params_worst_env,
    )
    with open(
        f"{artifact_folder}/worst_parameters_{run_name}_{iteration_number}.json", "w"
    ) as wp:
        json.dump(worst_parameters, wp)
    with open(
        f"{artifact_folder}/history_configuration_{run_name}_{iteration_number}.json",
        "w",
    ) as hc:
        json.dump(history_configuration, hc)
    print(f"WORST PARAMS IS: {worst_parameters}")
    try:
        plotter_heatmap(
            history_configuration,
            f"{artifact_folder}/heatmap_{run_name}_{iteration_number}.png",
        )
    except ValueError as err:
        # Plotting stays best effort, but the failure is reported instead of
        # being silently dropped. TODO: fix the underlying plotting issue.
        print(f"HEATMAP SKIPPED: {err}")

    env.close()
    return worst_parameters, history_configuration


def train_imitation(
    agent: IAgent,
    optimizer_actor: optim.Optimizer,
    optimizer_critic: Optional[optim.Optimizer],
    env: gym.Env,
    imitation_algo: ImitationLearningPolicy,
    buffer: OffPolicyMemory,
    max_step_imitation: int,
    experiment_name: str,
    iteration_number: int,
    worst_env_builder: Callable[..., gym.Env],
    plotter_heatmap: Callable[[List[Tuple[Dict[str, float], float]]], str],
    env_params_bounds: Dict[str, List[int]],
    log_runner: "wandb.sdk.wandb_run.Run",  # type: ignore
    log_loss_every_k_step: int = 1000,
    artifact_folder: str = "result",
    run_name: str = "Imitation_Default",
    params: Optional[Dict[str, Any]] = None,
    **params_worst_env,
):
    """Run imitation updates from a buffer, then search for the worst env.

    ``imitation_algo.update`` is called ``max_step_imitation`` times on
    ``buffer``; no environment interaction happens during these updates.
    Every ``log_loss_every_k_step`` steps the agent is evaluated on a deep
    copy of ``env`` and the losses plus the mean reward are logged, suffixed
    with ``iteration_number``, against a custom step metric. Afterwards the
    agent is saved, ``find_worse`` is run, its two results are dumped as JSON
    files in ``artifact_folder`` and a heatmap is plotted on a best-effort
    basis.

    Args:
        agent (IAgent): Policy being trained, evaluated, saved and passed to
            ``find_worse``.
        optimizer_actor (optim.Optimizer): Forwarded to
            ``imitation_algo.update``.
        optimizer_critic (Optional[optim.Optimizer]): Forwarded to
            ``imitation_algo.update``.
        env (gym.Env): Environment deep-copied for each evaluation; closed
            before returning.
        imitation_algo (ImitationLearningPolicy): Algorithm object whose
            ``update`` performs the training and returns a dict of losses.
        buffer (OffPolicyMemory): Replay memory passed to the updates.
        max_step_imitation (int): Number of update steps.
        experiment_name (str): Not used by this function.
        iteration_number (int): Suffix of every metric and output file name.
        worst_env_builder (Callable[..., gym.Env]): Passed to ``find_worse`` as
            ``env_builder``.
        plotter_heatmap (Callable): Called with ``history_configuration`` and
            the output path of the heatmap image.
        env_params_bounds (Dict[str, List[int]]): Passed to ``find_worse`` as
            ``bound``.
        log_runner: wandb run used to define metrics and log data.
        log_loss_every_k_step (int): Evaluate and log every this many update
            steps. Defaults to ``1000``.
        artifact_folder (str): Folder where checkpoint, JSON files and heatmap
            are written.
        run_name (str): Name of the run, used in every output file name.
        params (Optional[Dict[str, Any]]): Not used by this function.
        **params_worst_env: Extra keyword arguments forwarded to ``find_worse``.

    Returns:
        The ``(worst_parameters, history_configuration)`` pair returned by
        ``find_worse``.
    """
    # NOTE: the metric name keeps its historical spelling ("imitiation") so
    # that existing dashboards keep matching.
    custom_step = f"imitiation_step_{iteration_number}"
    # Define the step metric that is actually logged below.
    log_runner.define_metric(custom_step)
    # Seeding the set keeps the step metric from being declared a second time
    # with itself as x-axis.
    key_already_seen = {custom_step}
    for step in tqdm(range(1, max_step_imitation + 1)):
        dict_loss = imitation_algo.update(
            replay=buffer,
            agent=agent,
            optimizer_critic=optimizer_critic,
            optimizer_actor=optimizer_actor,
            scheduler=None,
        )
        if step % log_loss_every_k_step != 0:
            continue
        mean_reward = evaluate(
            agent=agent,
            env=deepcopy(env),
            nb_step=1000,
            nb_trial=1,
            sequence_seed=[42],
        )
        logged = {f"{k}_{iteration_number}": v for k, v in dict_loss.items()}
        logged[f"mean_reward_imitation_{iteration_number}"] = mean_reward
        logged[custom_step] = step
        for name in logged:
            if name not in key_already_seen:
                log_runner.define_metric(name, step_metric=custom_step)
                key_already_seen.add(name)
        log_runner.log(data=logged)
    agent.save(
        path=f"{artifact_folder}/sac_imitation_pc_{run_name}_{iteration_number}.pth"
    )

    print("COMPUTE WORST ENV PARAMS")
    worst_parameters, history_configuration = find_worse(
        agent=agent,
        env_builder=worst_env_builder,
        bound=env_params_bounds,
        **params_worst_env,
    )
    with open(
        f"{artifact_folder}/worst_parameters_imitation_{run_name}_{iteration_number}.json",
        "w",
    ) as wp:
        json.dump(worst_parameters, wp)
    with open(
        f"{artifact_folder}/history_configuration_imitation_{run_name}_{iteration_number}.json",
        "w",
    ) as hc:
        json.dump(history_configuration, hc)
    print(f"WORST PARAMS IS: {worst_parameters}")
    try:
        plotter_heatmap(
            history_configuration,
            f"{artifact_folder}/heatmap_imitation_{run_name}_{iteration_number}.png",
        )
    except ValueError as err:
        # Plotting stays best effort, but the failure is reported instead of
        # being silently dropped. TODO: fix the underlying plotting issue.
        print(f"HEATMAP SKIPPED: {err}")

    env.close()
    return worst_parameters, history_configuration


def _reduce_vec_info(
    infos: Union[Dict[str, float], Tuple[Dict[str, float]]]
) -> Dict[str, float]:
    if isinstance(infos, dict):
        if "episode" in infos:
            infos["episode_length"] = infos["episode"]["l"]
            infos["episode_reward"] = infos["episode"]["r"]
            infos["episode_time"] = infos["episode"]["t"]
            del infos["episode"]
        return infos

    flatten_info = {
        "episode_length": 0,
        "episode_reward": 0,
        "episode_time": 0,
    }

    nb_episode_end = 0
    for info in infos:
        if "episode" in info:
            flatten_info["episode_length"] += info["episode"]["l"] / len(infos)  # type: ignore
            flatten_info["episode_reward"] += info["episode"]["r"] / len(infos)  # type: ignore
            flatten_info["episode_time"] += info["episode"]["t"] / len(infos)  # type: ignore
            nb_episode_end += 1
    return {} if nb_episode_end == 0 else flatten_info


def train_ravi_mpc_minimal(
    agent: IAgent,
    optimizer_actor: optim.Optimizer,
    optimizer_critic: optim.Optimizer,
    optimizer_mpc: optim.Optimizer,
    env: gym.Env,
    sac_algo: SACMPC,
    buffer: OffPolicyMemory,
    max_step: int,
    train_step: int,
    train_step_start: int,
    log_runner: "wandb.sdk.wandb_run.Run",  # type: ignore
    iteration_number: int,
    log_loss_every_k_step: int = 1000,
    save_every_k_step: int = 100_000,
    params: Optional[Dict[str, Any]] = None,
    artifact_folder: str = "result",
    run_name: str = "Default",
):
    """Train with SAC-MPC, log to wandb and checkpoint periodically.

    The loop collects one transition per environment step into ``buffer`` and
    calls ``sac_algo.update`` every ``train_step`` steps once the warm-up of
    ``train_step_start`` steps is over. Every logged key is suffixed with
    ``iteration_number`` and plotted against the custom step metric
    ``ravi_step_{iteration_number}``. The agent is saved every
    ``save_every_k_step`` steps (after the warm-up) and once more at the end.

    NOTE(review): the environment is never reset inside the loop; presumably
    ``env`` resets itself when an episode ends (e.g. a vectorized env) -
    confirm against the caller.

    Args:
        agent (IAgent): Policy queried for an action distribution and saved.
        optimizer_actor (optim.Optimizer): Forwarded to ``sac_algo.update``.
        optimizer_critic (optim.Optimizer): Forwarded to ``sac_algo.update``.
        optimizer_mpc (optim.Optimizer): Forwarded to ``sac_algo.update``.
        env (gym.Env): Environment to interact with; closed before returning.
        sac_algo (SACMPC): Algorithm object whose ``update`` performs the
            training and returns a dict of losses.
        buffer (OffPolicyMemory): Replay memory receiving every transition.
        max_step (int): Total number of environment steps.
        train_step (int): Call ``sac_algo.update`` every ``train_step`` steps.
        train_step_start (int): Number of initial steps played with actions
            sampled from ``env.action_space`` and without any update.
        log_runner: wandb run used to define metrics and log data.
        iteration_number (int): Suffix of every metric and checkpoint name.
        log_loss_every_k_step (int): Losses are logged on steps that are a
            multiple of this value (after the warm-up). Defaults to ``1000``.
        save_every_k_step (int): Checkpoint period in steps. Defaults to
            ``100_000``.
        params (Optional[Dict[str, Any]]): Not used by this function.
        artifact_folder (str): Folder where checkpoints are written.
        run_name (str): Name of the run, used in the checkpoint file names.
    """
    obs = env.reset()
    custom_step = f"ravi_step_{iteration_number}"
    log_runner.define_metric(custom_step)
    # The step metric is defined just above; seeding the set keeps it from
    # being declared a second time with itself as x-axis.
    key_already_seen = {custom_step}

    def _define_new_metrics(metric_names):
        # Bind every not-yet-declared metric to the custom step axis.
        for name in metric_names:
            if name not in key_already_seen:
                log_runner.define_metric(name, step_metric=custom_step)
                key_already_seen.add(name)

    for step in tqdm(range(1, max_step + 1)):
        # Acting and storing transitions never needs gradients.
        with torch.no_grad():
            if step > train_step_start:
                actions = agent.action(observation=obs).sample()
            else:
                # Warm-up phase: fill the buffer with random actions.
                actions = env.action_space.sample()
            new_obs, reward, done, info = env.step(actions)
            buffer.append(
                BatchTransition(
                    state=obs,
                    action=actions,
                    reward=reward,  # type: ignore
                    next_state=new_obs,
                    done=done,  # type: ignore
                )
            )
            obs = new_obs
            info_log_env = {
                f"{k}_{iteration_number}": v
                for k, v in _reduce_vec_info(infos=info).items()
            }
            info_log_env[custom_step] = step
            _define_new_metrics(info_log_env)

        log_runner.log(data=info_log_env)
        if step % train_step == 0 and step > train_step_start:
            dict_loss = sac_algo.update(
                replay=buffer,
                agent=agent,
                optimizer_actor=optimizer_actor,
                optimizer_critic=optimizer_critic,
                optimizer_mpc=optimizer_mpc,
                scheduler=None,
            )
            dict_loss = {f"{k}_{iteration_number}": v for k, v in dict_loss.items()}
            dict_loss[custom_step] = step
            # Register the loss keys themselves (not the env-info keys) so the
            # losses are plotted against the custom step.
            _define_new_metrics(dict_loss)
        else:
            dict_loss = {}
        if step % log_loss_every_k_step == 0 and step > train_step_start:
            log_runner.log(data=dict_loss)

        if step % save_every_k_step == 0 and step > train_step_start:
            agent.save(
                path=f"{artifact_folder}/sac_{run_name}_{iteration_number}_step_{step}.pth"
            )
    # Equals the last loop value of ``step``; also defined when the loop body
    # never ran (max_step < 1), where reading ``step`` would raise NameError.
    final_step = max(max_step, 0)
    agent.save(
        path=f"{artifact_folder}/sac_{run_name}_{iteration_number}_final_step_{final_step}.pth"
    )

    env.close()


def train_ravi_mpc_no_reg_minimal(
    agent: IAgent,
    optimizer_actor: optim.Optimizer,
    optimizer_critic: optim.Optimizer,
    optimizer_mpc: optim.Optimizer,
    optimizer_critic_no_reg: optim.Optimizer,
    env: gym.Env,
    sac_algo: SACMPCHydra,
    buffer: OffPolicyMemory,
    max_step: int,
    train_step: int,
    train_step_start: int,
    log_runner: "wandb.sdk.wandb_run.Run",  # type: ignore
    iteration_number: int,
    log_loss_every_k_step: int = 1000,
    save_every_k_step: int = 300_000,
    artifact_folder: str = "result",
    run_name: str = "Default",
):
    """Train the hydra SAC-MPC variant, log to wandb, checkpoint periodically.

    The loop collects one transition per environment step into ``buffer`` and
    calls ``sac_algo.update`` (which additionally receives
    ``optimizer_critic_no_reg``) every ``train_step`` steps once the warm-up
    of ``train_step_start`` steps is over. Every logged key is suffixed with
    ``iteration_number`` and plotted against the custom step metric
    ``ravi_step_{iteration_number}``. The agent is saved every
    ``save_every_k_step`` steps (after the warm-up) and once more at the end.

    NOTE(review): the environment is never reset inside the loop; presumably
    ``env`` resets itself when an episode ends (e.g. a vectorized env) -
    confirm against the caller.

    Args:
        agent (IAgent): Policy queried for an action distribution and saved.
        optimizer_actor (optim.Optimizer): Forwarded to ``sac_algo.update``.
        optimizer_critic (optim.Optimizer): Forwarded to ``sac_algo.update``.
        optimizer_mpc (optim.Optimizer): Forwarded to ``sac_algo.update``.
        optimizer_critic_no_reg (optim.Optimizer): Forwarded to
            ``sac_algo.update``.
        env (gym.Env): Environment to interact with; closed before returning.
        sac_algo (SACMPCHydra): Algorithm object whose ``update`` performs the
            training and returns a dict of losses.
        buffer (OffPolicyMemory): Replay memory receiving every transition.
        max_step (int): Total number of environment steps.
        train_step (int): Call ``sac_algo.update`` every ``train_step`` steps.
        train_step_start (int): Number of initial steps played with actions
            sampled from ``env.action_space`` and without any update.
        log_runner: wandb run used to define metrics and log data.
        iteration_number (int): Suffix of every metric and checkpoint name.
        log_loss_every_k_step (int): Losses are logged on steps that are a
            multiple of this value (after the warm-up). Defaults to ``1000``.
        save_every_k_step (int): Checkpoint period in steps. Defaults to
            ``300_000``.
        artifact_folder (str): Folder where checkpoints are written.
        run_name (str): Name of the run, used in the checkpoint file names.
    """
    obs = env.reset()
    custom_step = f"ravi_step_{iteration_number}"
    log_runner.define_metric(custom_step)
    # The step metric is defined just above; seeding the set keeps it from
    # being declared a second time with itself as x-axis.
    key_already_seen = {custom_step}

    def _define_new_metrics(metric_names):
        # Bind every not-yet-declared metric to the custom step axis.
        for name in metric_names:
            if name not in key_already_seen:
                log_runner.define_metric(name, step_metric=custom_step)
                key_already_seen.add(name)

    for step in tqdm(range(1, max_step + 1)):
        # Acting and storing transitions never needs gradients.
        with torch.no_grad():
            if step > train_step_start:
                actions = agent.action(observation=obs).sample()
            else:
                # Warm-up phase: fill the buffer with random actions.
                actions = env.action_space.sample()
            new_obs, reward, done, info = env.step(actions)
            buffer.append(
                BatchTransition(
                    state=obs,
                    action=actions,
                    reward=reward,  # type: ignore
                    next_state=new_obs,
                    done=done,  # type: ignore
                )
            )
            obs = new_obs
            info_log_env = {
                f"{k}_{iteration_number}": v
                for k, v in _reduce_vec_info(infos=info).items()
            }
            info_log_env[custom_step] = step
            _define_new_metrics(info_log_env)

        log_runner.log(data=info_log_env)
        if step % train_step == 0 and step > train_step_start:
            dict_loss = sac_algo.update(
                replay=buffer,
                agent=agent,
                optimizer_actor=optimizer_actor,
                optimizer_critic=optimizer_critic,
                optimizer_mpc=optimizer_mpc,
                optimizer_critic_no_reg=optimizer_critic_no_reg,
                scheduler=None,
            )
            dict_loss = {f"{k}_{iteration_number}": v for k, v in dict_loss.items()}
            dict_loss[custom_step] = step
            # Register the loss keys themselves (not the env-info keys) so the
            # losses are plotted against the custom step.
            _define_new_metrics(dict_loss)
        else:
            dict_loss = {}
        if step % log_loss_every_k_step == 0 and step > train_step_start:
            log_runner.log(data=dict_loss)

        if step % save_every_k_step == 0 and step > train_step_start:
            agent.save(
                path=f"{artifact_folder}/sac_hydra_{run_name}_{iteration_number}_step_{step}.pth"
            )
    # Equals the last loop value of ``step``; also defined when the loop body
    # never ran (max_step < 1), where reading ``step`` would raise NameError.
    final_step = max(max_step, 0)
    agent.save(
        path=f"{artifact_folder}/sac_hydra_{run_name}_{iteration_number}_final_step_{final_step}.pth"
    )

    env.close()


def train_ravi_mpc_no_reg_minimal_best(
    agent: IAgent,
    optimizer_actor: optim.Optimizer,
    optimizer_critic: optim.Optimizer,
    optimizer_mpc: optim.Optimizer,
    optimizer_critic_no_reg: optim.Optimizer,
    env: gym.Env,
    sac_algo: SACMPCHydra,
    buffer: OffPolicyMemory,
    max_step: int,
    train_step: int,
    train_step_start: int,
    log_runner: "wandb.sdk.wandb_run.Run",  # type: ignore
    iteration_number: int,
    log_loss_every_k_step: int = 1000,
    save_every_k_step: int = 300_000,
    artifact_folder: str = "result",
    run_name: str = "Default",
) -> None:
    """Train the agent and restore its best-scoring checkpoint at the end.

    Each step one action is played in ``env``, the transition is appended to
    ``buffer`` and the reduced env info is logged through ``log_runner``.
    Whenever the reduced info holds an ``"episode_reward"`` strictly greater
    than the best seen so far, the agent is saved to the ``*_best.pth``
    checkpoint. After the loop a final checkpoint is written, the best
    checkpoint is loaded back into ``agent`` and ``env`` is closed.

    Args:
        agent (IAgent): Agent to train; provides ``action``, ``save``, ``load``.
        optimizer_actor (optim.Optimizer): Forwarded to ``sac_algo.update``.
        optimizer_critic (optim.Optimizer): Forwarded to ``sac_algo.update``.
        optimizer_mpc (optim.Optimizer): Forwarded to ``sac_algo.update``.
        optimizer_critic_no_reg (optim.Optimizer): Forwarded to
            ``sac_algo.update``.
        env (gym.Env): Environment to interact with; closed before returning.
        sac_algo (SACMPCHydra): Algorithm whose ``update`` returns a dict of
            values to log.
        buffer (OffPolicyMemory): Replay memory receiving one transition per
            step.
        max_step (int): Number of environment steps to run.
        train_step (int): Call ``sac_algo.update`` every ``train_step`` steps.
        train_step_start (int): Up to and including this step, actions come
            from ``env.action_space.sample()`` and no update, loss logging or
            periodic checkpoint happens.
        log_runner (wandb.sdk.wandb_run.Run): Run used for ``define_metric``
            and ``log``.
        iteration_number (int): Suffix appended to every metric name and
            embedded in checkpoint file names.
        log_loss_every_k_step (int): Log the loss dict every k steps.
            Defaults to `1000`.
        save_every_k_step (int): Save a periodic checkpoint every k steps.
            Defaults to `300_000`.
        artifact_folder (str): Folder in which checkpoints are written.
        run_name (str): Name of the run, embedded in checkpoint file names.
    """
    custom_step = f"ravi_step_{iteration_number}"
    checkpoint_prefix = f"{artifact_folder}/sac_hydra_{run_name}_{iteration_number}"
    best_checkpoint = f"{checkpoint_prefix}_best.pth"
    key_already_seen = set()

    def _define_new_metrics(metric_names):
        # Register each metric once, plotted against the custom step counter.
        # NOTE(review): `custom_step` itself goes through here on the first
        # step and is registered with itself as step_metric (as in the
        # original code) - confirm this is intended.
        for name in metric_names:
            if name not in key_already_seen:
                log_runner.define_metric(name, step_metric=custom_step)
                key_already_seen.add(name)

    obs = env.reset()
    log_runner.define_metric(custom_step)
    best_episode_reward = -float("inf")
    # Write an initial "best" checkpoint so the final `agent.load` always
    # finds a file, even if no episode reward is ever reported.
    agent.save(path=best_checkpoint)

    # Keeps `step` bound for the final checkpoint name when max_step < 1.
    step = 0
    for step in tqdm(range(1, max_step + 1)):
        with torch.no_grad():
            if step > train_step_start:
                mass = agent.action(observation=obs)
                actions = mass.sample()
            else:
                # Warm-up phase: sample actions from the env's action space.
                actions = env.action_space.sample()
            new_obs, reward, done, info = env.step(actions)
            buffer.append(
                BatchTransition(
                    state=obs,
                    action=actions,
                    reward=reward,  # type: ignore
                    next_state=new_obs,
                    done=done,  # type: ignore
                )
            )
            obs = new_obs

            info_log_env = _reduce_vec_info(infos=info)
            # Compare on the raw key, before the iteration suffix is added.
            if "episode_reward" in info_log_env:
                if info_log_env["episode_reward"] > best_episode_reward:
                    agent.save(path=best_checkpoint)
                    best_episode_reward = info_log_env["episode_reward"]
            info_log_env = {
                f"{k}_{iteration_number}": v for k, v in info_log_env.items()
            }
            info_log_env.update({custom_step: step})
            _define_new_metrics(info_log_env)

        log_runner.log(data=info_log_env)

        past_warmup = step > train_step_start
        if step % train_step == 0 and past_warmup:
            dict_loss = sac_algo.update(
                replay=buffer,
                agent=agent,
                optimizer_actor=optimizer_actor,
                optimizer_critic=optimizer_critic,
                optimizer_mpc=optimizer_mpc,
                optimizer_critic_no_reg=optimizer_critic_no_reg,
                scheduler=None,
            )
            dict_loss = {f"{k}_{iteration_number}": v for k, v in dict_loss.items()}
            dict_loss.update({custom_step: step})
            # Fix: register the *loss* keys. The original looped over
            # `info_log_env` again, so loss metrics were never tied to the
            # custom step metric.
            _define_new_metrics(dict_loss)
        else:
            dict_loss = {}

        if step % log_loss_every_k_step == 0 and past_warmup:
            log_runner.log(data=dict_loss)

        if step % save_every_k_step == 0 and past_warmup:
            agent.save(path=f"{checkpoint_prefix}_step_{step}.pth")

    agent.save(path=f"{checkpoint_prefix}_final_step_{step}.pth")
    # Leave the caller with the best-scoring weights rather than the last ones.
    agent.load(best_checkpoint)
    env.close()
