from copy import deepcopy
from typing import Optional, Tuple, Dict, Union, Any, Callable, List
import json
import gym  # type: ignore
import torch
from torch import optim
from tqdm import tqdm  # type: ignore

# import wandb
from .base_agent import IAgent
from .sac import SAC, SACMPC, SACMPCHydra
from .replay import OffPolicyMemory, BatchTransition


def train_sac(
    agent: IAgent,
    optimizer_actor: optim.Optimizer,
    optimizer_critic: optim.Optimizer,
    env: gym.Env,
    sac_algo: SAC,
    buffer: OffPolicyMemory,
    max_step: int,
    train_step: int,
    train_step_start: int,
    experiment_name: str,
    log_loss_every_k_step: int = 1000,
    artifact_folder: str = "result",
    run_name: str = "Default",
    params: Optional[Dict[str, Any]] = None,
):
    """Run an off-policy SAC training loop on ``env``.

    At each of the ``max_step`` steps an action is chosen, the resulting
    transition is appended to ``buffer``, and, once past the warm-up phase,
    ``sac_algo.update`` is called every ``train_step`` steps. Actions are
    sampled uniformly from ``env.action_space`` for the first
    ``train_step_start`` steps, then drawn with ``.sample()`` from the object
    returned by ``agent.action``. The environment is closed on exit, including
    when training raises.

    Args:
        agent (IAgent): Policy used to act and passed to ``sac_algo.update``.
        optimizer_actor (optim.Optimizer): Optimizer for the actor update.
        optimizer_critic (optim.Optimizer): Optimizer for the critic update.
        env (gym.Env): Environment to interact with.
        sac_algo (SAC): Algorithm object performing the gradient updates.
        buffer (OffPolicyMemory): Replay buffer receiving every transition.
        max_step (int): Number of environment steps.
        train_step (int): Run an update every ``train_step`` steps.
        train_step_start (int): Number of warm-up steps (random actions,
            no updates).
        experiment_name (str): Currently unused; kept for interface
            compatibility.
        log_loss_every_k_step (int): Currently unused. Defaults to ``1000``.
        artifact_folder (str): Currently unused. Defaults to ``"result"``.
        run_name (str): Currently unused. Defaults to ``"Default"``.
        params (Optional[Dict[str, Any]]): Currently unused.
    """
    try:
        obs = env.reset()
        for step in tqdm(range(1, max_step + 1), ncols=80):
            warmed_up = step > train_step_start
            with torch.no_grad():
                if warmed_up:
                    actions = agent.action(observation=obs).sample()
                else:
                    actions = env.action_space.sample()
                new_obs, reward, done, _ = env.step(actions)
                buffer.append(
                    BatchTransition(
                        state=obs,
                        action=actions,
                        reward=reward,  # type: ignore
                        next_state=new_obs,
                        done=done,  # type: ignore
                    )
                )
                # NOTE(review): obs is never reset on `done`; this assumes the
                # env auto-resets (e.g. a vectorized env) — confirm.
                obs = new_obs
            if step % train_step == 0 and warmed_up:
                # The returned loss dict is not logged by this trainer.
                sac_algo.update(
                    replay=buffer,
                    agent=agent,
                    optimizer_actor=optimizer_actor,
                    optimizer_critic=optimizer_critic,
                    scheduler=None,
                )
    finally:
        env.close()


def train_sac_mpc(
    agent: IAgent,
    optimizer_actor: optim.Optimizer,
    optimizer_critic: optim.Optimizer,
    optimizer_mpc: optim.Optimizer,
    env: gym.Env,
    sac_algo: SACMPC,
    buffer: OffPolicyMemory,
    max_step: int,
    train_step: int,
    train_step_start: int,
    experiment_name: str,
    log_loss_every_k_step: int = 1000,
    artifact_folder: str = "result",
    run_name: str = "Default",
    params: Optional[Dict[str, Any]] = None,
):
    """Run an off-policy SAC-MPC training loop on ``env`` and save the agent.

    At each of the ``max_step`` steps an action is chosen, the resulting
    transition is appended to ``buffer``, and, once past the warm-up phase,
    ``sac_algo.update`` (actor, critic and MPC optimizers) is called every
    ``train_step`` steps. Actions are sampled uniformly from
    ``env.action_space`` for the first ``train_step_start`` steps, then drawn
    with ``.sample()`` from the object returned by ``agent.action``. After
    the loop the agent is saved to
    ``{artifact_folder}/sacmpc_{run_name}.pth``. The environment is closed
    on exit, including when training or saving raises.

    Args:
        agent (IAgent): Policy used to act, updated and saved.
        optimizer_actor (optim.Optimizer): Optimizer for the actor update.
        optimizer_critic (optim.Optimizer): Optimizer for the critic update.
        optimizer_mpc (optim.Optimizer): Optimizer for the MPC update.
        env (gym.Env): Environment to interact with.
        sac_algo (SACMPC): Algorithm object performing the gradient updates.
        buffer (OffPolicyMemory): Replay buffer receiving every transition.
        max_step (int): Number of environment steps.
        train_step (int): Run an update every ``train_step`` steps.
        train_step_start (int): Number of warm-up steps (random actions,
            no updates).
        experiment_name (str): Currently unused; kept for interface
            compatibility.
        log_loss_every_k_step (int): Currently unused. Defaults to ``1000``.
        artifact_folder (str): Folder the final checkpoint is written to.
            Defaults to ``"result"``.
        run_name (str): Used in the checkpoint file name.
            Defaults to ``"Default"``.
        params (Optional[Dict[str, Any]]): Currently unused.
    """
    try:
        obs = env.reset()
        for step in tqdm(range(1, max_step + 1)):
            warmed_up = step > train_step_start
            with torch.no_grad():
                if warmed_up:
                    actions = agent.action(observation=obs).sample()
                else:
                    actions = env.action_space.sample()
                new_obs, reward, done, _ = env.step(actions)
                buffer.append(
                    BatchTransition(
                        state=obs,
                        action=actions,
                        reward=reward,  # type: ignore
                        next_state=new_obs,  # type: ignore
                        done=done,  # type: ignore
                    )
                )
                # NOTE(review): obs is never reset on `done`; this assumes the
                # env auto-resets (e.g. a vectorized env) — confirm.
                obs = new_obs
            if step % train_step == 0 and warmed_up:
                # The returned loss dict is not logged by this trainer.
                sac_algo.update(
                    replay=buffer,
                    agent=agent,
                    optimizer_actor=optimizer_actor,
                    optimizer_critic=optimizer_critic,
                    optimizer_mpc=optimizer_mpc,
                    scheduler=None,
                )
        agent.save(path=f"{artifact_folder}/sacmpc_{run_name}.pth")
    finally:
        env.close()


def _reduce_vec_info(
    infos: Union[Dict[str, float], Tuple[Dict[str, float]]]
) -> Dict[str, float]:
    if isinstance(infos, dict):
        if "episode" in infos:
            infos["episode_length"] = infos["episode"]["l"]
            infos["episode_reward"] = infos["episode"]["r"]
            infos["episode_time"] = infos["episode"]["t"]
            del infos["episode"]
        return infos

    flatten_info = {
        "episode_length": 0,
        "episode_reward": 0,
        "episode_time": 0,
    }

    nb_episode_end = 0
    for info in infos:
        if "episode" in info:
            flatten_info["episode_length"] += info["episode"]["l"] / len(infos)  # type: ignore
            flatten_info["episode_reward"] += info["episode"]["r"] / len(infos)  # type: ignore
            flatten_info["episode_time"] += info["episode"]["t"] / len(infos)  # type: ignore
            nb_episode_end += 1
    return {} if nb_episode_end == 0 else flatten_info


def train_ravi_mpc_no_reg_minimal(
    agent: IAgent,
    optimizer_actor: optim.Optimizer,
    optimizer_critic: optim.Optimizer,
    optimizer_mpc: optim.Optimizer,
    optimizer_critic_no_reg: optim.Optimizer,
    env: gym.Env,
    sac_algo: SACMPCHydra,
    buffer: OffPolicyMemory,
    max_step: int,
    train_step: int,
    train_step_start: int,
    log_runner: "wandb.sdk.wandb_run.Run",  # type: ignore
    iteration_number: int,
    log_loss_every_k_step: int = 1000,
    save_every_k_step: int = 300_000,
    artifact_folder: str = "result",
    run_name: str = "Default",
):
    """Train a SAC-MPC "hydra" agent, log metrics to ``log_runner`` and checkpoint.

    At each of the ``max_step`` steps an action is chosen, the resulting
    transition is appended to ``buffer``, and the flattened env info is
    logged. Actions are sampled uniformly from ``env.action_space`` for the
    first ``train_step_start`` steps, then drawn with ``.sample()`` from the
    object returned by ``agent.action``. Once past warm-up, ``sac_algo.update``
    runs every ``train_step`` steps. Its losses are logged on steps that are
    also multiples of ``log_loss_every_k_step``.

    Every metric name is suffixed with ``_{iteration_number}`` and bound
    (via ``define_metric``) to the custom step axis
    ``ravi_step_{iteration_number}``. The environment is closed on exit,
    including when training raises.

    Args:
        agent (IAgent): Policy used to act, updated and saved.
        optimizer_actor (optim.Optimizer): Optimizer for the actor update.
        optimizer_critic (optim.Optimizer): Optimizer for the critic update.
        optimizer_mpc (optim.Optimizer): Optimizer for the MPC update.
        optimizer_critic_no_reg (optim.Optimizer): Optimizer for the
            unregularized critic update.
        env (gym.Env): Environment to interact with.
        sac_algo (SACMPCHydra): Algorithm object performing the updates.
        buffer (OffPolicyMemory): Replay buffer receiving every transition.
        max_step (int): Number of environment steps.
        train_step (int): Run an update every ``train_step`` steps.
        train_step_start (int): Number of warm-up steps (random actions,
            no updates, no checkpoints).
        log_runner (wandb Run): Run object receiving ``define_metric`` and
            ``log`` calls.
        iteration_number (int): Suffix for metric names and checkpoint files.
        log_loss_every_k_step (int): Log losses on update steps that are
            multiples of this value. Defaults to ``1000``.
        save_every_k_step (int): Save a checkpoint every k steps after
            warm-up. Defaults to ``300_000``.
        artifact_folder (str): Folder checkpoints are written to.
            Defaults to ``"result"``.
        run_name (str): Used in checkpoint file names.
            Defaults to ``"Default"``.
    """
    custom_step = f"ravi_step_{iteration_number}"
    key_already_seen = set()
    log_runner.define_metric(custom_step)

    def _register_metrics(keys) -> None:
        # Bind each metric to the custom step axis the first time it is seen.
        for key in keys:
            if key not in key_already_seen:
                log_runner.define_metric(key, step_metric=custom_step)
                key_already_seen.add(key)

    try:
        obs = env.reset()
        for step in tqdm(range(1, max_step + 1)):
            warmed_up = step > train_step_start
            with torch.no_grad():
                if warmed_up:
                    actions = agent.action(observation=obs).sample()
                else:
                    actions = env.action_space.sample()
                new_obs, reward, done, info = env.step(actions)
                buffer.append(
                    BatchTransition(
                        state=obs,
                        action=actions,
                        reward=reward,  # type: ignore
                        next_state=new_obs,
                        done=done,  # type: ignore
                    )
                )
                # NOTE(review): obs is never reset on `done`; this assumes the
                # env auto-resets (e.g. a vectorized env) — confirm.
                obs = new_obs
                info_log_env = {
                    f"{k}_{iteration_number}": v
                    for k, v in _reduce_vec_info(infos=info).items()
                }
                info_log_env[custom_step] = step
                _register_metrics(info_log_env)

            # Logged on every environment step, even when it holds only the
            # step counter.
            log_runner.log(data=info_log_env)

            dict_loss: Dict[str, Any] = {}
            if warmed_up and step % train_step == 0:
                raw_loss = sac_algo.update(
                    replay=buffer,
                    agent=agent,
                    optimizer_actor=optimizer_actor,
                    optimizer_critic=optimizer_critic,
                    optimizer_mpc=optimizer_mpc,
                    optimizer_critic_no_reg=optimizer_critic_no_reg,
                    scheduler=None,
                )
                dict_loss = {
                    f"{k}_{iteration_number}": v for k, v in raw_loss.items()
                }
                dict_loss[custom_step] = step
                # Register the *loss* keys, so they are plotted against the
                # custom step axis too.
                _register_metrics(dict_loss)

            # Only log when an update actually produced losses.
            if dict_loss and step % log_loss_every_k_step == 0:
                log_runner.log(data=dict_loss)

            if warmed_up and step % save_every_k_step == 0:
                agent.save(
                    path=f"{artifact_folder}/sac_hydra_{run_name}_{iteration_number}_step_{step}.pth"
                )

        # After a full loop the last step equals max_step. Using max_step also
        # avoids an unbound `step` when max_step < 1.
        agent.save(
            path=f"{artifact_folder}/sac_hydra_{run_name}_{iteration_number}_final_step_{max_step}.pth"
        )
    finally:
        env.close()
