"""Base runner for off-policy algorithms."""
import os
import time
import torch
import numpy as np
import setproctitle
from mas3ac.common.valuenorm import ValueNorm
from torch.distributions import Categorical
from mas3ac.utils.trans_tools import _t2n
from mas3ac.utils.envs_tools import (
    make_eval_env,
    make_train_env,
    make_render_env,
    set_seed,
    get_num_agents,
)
from mas3ac.utils.models_tools import init_device
from mas3ac.utils.configs_tools import init_dir, save_config, get_task_name
from mas3ac.algorithms.actors import ALGO_REGISTRY
from mas3ac.algorithms.critics import CRITIC_REGISTRY
from mas3ac.common.buffers.off_policy_buffer_ep import OffPolicyBufferEP
from mas3ac.common.buffers.off_policy_buffer_fp import OffPolicyBufferFP
import wandb
import socket


def calculate_decayed_lr(epoch, total_num_epochs, initial_lr):
    """Return a linearly decayed learning rate.

    The result equals ``initial_lr`` when ``epoch == 1`` and falls by
    ``initial_lr / total_num_epochs`` for every further epoch, reaching zero
    at ``epoch == total_num_epochs + 1``.

    Args:
        epoch: (int) current epoch.
        total_num_epochs: (int) total number of epochs.
        initial_lr: (float) learning rate at the first epoch.
    Returns:
        (float) the decayed learning rate.
    """
    elapsed_fraction = (epoch - 1) / float(total_num_epochs)
    return initial_lr - (initial_lr * elapsed_fraction)

class OffPolicyBaseRunner:
    """Base runner for off-policy algorithms.

    Builds the environments, per-agent actors, critics, barrier and Lyapunov
    networks, the replay buffer and (optionally) value normalizers, then drives
    the warmup / collect / train / log loop. Subclasses implement ``train``.
    """

    def __init__(self, args, algo_args, env_args):
        """Initialize the OffPolicyBaseRunner class.
        Args:
            args: command-line arguments parsed by argparse. Three keys: algo, env, exp_name.
            algo_args: arguments related to algo, loaded from config file and updated with unparsed command-line arguments.
            env_args: arguments related to env, loaded from config file and updated with unparsed command-line arguments.
        """
        self.args = args
        self.algo_args = algo_args
        self.env_args = env_args

        # Constraint metrics are logged to Weights & Biases for MAS3AC only;
        # every other algorithm gets ``None`` and skips wandb logging.
        if args["algo"] == "mas3ac":
            project_name = (
                "MAS3AC" + "_" + env_args["scenario"] + "_" + env_args["agent_conf"]
            )
            self.constrained_writer = wandb.init(
                project=project_name,  # Name of the group where data are saved in wandb
                config={**algo_args, **env_args},
                notes=socket.gethostname(),
                dir="wandb_logs",
                group="MASAFEMUJOCO",
            )
        else:
            self.constrained_writer = None

        self.policy_freq = self.algo_args["algo"].get("policy_freq", 1)
        self.state_type = env_args.get("state_type", "EP")
        self.share_param = algo_args["algo"]["share_param"]
        self.fixed_order = algo_args["algo"]["fixed_order"]

        set_seed(algo_args["seed"])
        self.device = init_device(algo_args["device"])

        # Read the seed only after set_seed(), which receives the seed config.
        seed = algo_args["seed"]["seed"]
        use_render = self.algo_args["render"]["use_render"]

        self.task_name = get_task_name(args["env"], env_args)
        if not use_render:
            self.run_dir, self.log_dir, self.save_dir, self.writter = init_dir(
                args["env"],
                env_args,
                args["algo"],
                args["exp_name"],
                seed,
                logger_path=algo_args["logger"]["log_dir"],
            )
            save_config(args, algo_args, env_args, self.run_dir)
            # Kept open for the runner's lifetime; closed in close().
            self.log_file = open(
                os.path.join(self.run_dir, "progress.txt"), "w", encoding="utf-8"
            )
        setproctitle.setproctitle(
            str(args["algo"]) + "-" + str(args["env"]) + "-" + str(args["exp_name"])
        )

        # env
        if use_render:  # make envs for rendering
            (
                self.envs,
                self.manual_render,
                self.manual_expand_dims,
                self.manual_delay,
                self.env_num,
            ) = make_render_env(args["env"], seed, env_args)
        else:
            self.envs = make_train_env(
                args["env"],
                seed,
                algo_args["train"]["n_rollout_threads"],
                env_args,
            )
            self.eval_envs = (
                make_eval_env(
                    args["env"],
                    seed,
                    algo_args["eval"]["n_eval_rollout_threads"],
                    env_args,
                )
                if algo_args["eval"]["use_eval"]
                else None
            )
        self.num_agents = get_num_agents(args["env"], env_args, self.envs)
        # Per-thread, per-agent "died at the previous step" flags; see insert().
        self.agent_deaths = np.zeros(
            (self.algo_args["train"]["n_rollout_threads"], self.num_agents, 1)
        )

        self.action_spaces = self.envs.action_space
        for agent_id in range(self.num_agents):
            self.action_spaces[agent_id].seed(seed + agent_id + 1)

        # Either a single value or one entry per agent (see _per_agent).
        self.policy_init_lr = self.algo_args["model"]["lr"]
        self.lambda_init_lr = self.algo_args["model"]["lagrangian_lr"]

        if self.share_param:
            agent = ALGO_REGISTRY[args["algo"]](
                {**algo_args["model"], **algo_args["algo"]},
                self.envs.observation_space[0],
                self.envs.action_space[0],
                device=self.device,
            )
            self.actor = [agent]
            for agent_id in range(1, self.num_agents):
                assert (
                    self.envs.observation_space[agent_id]
                    == self.envs.observation_space[0]
                ), "Agents have heterogeneous observation spaces, parameter sharing is not valid."
                assert (
                    self.envs.action_space[agent_id] == self.envs.action_space[0]
                ), "Agents have heterogeneous action spaces, parameter sharing is not valid."
                # Every entry aliases the same shared actor object.
                self.actor.append(self.actor[0])
        else:
            self.actor = []
            for agent_id in range(self.num_agents):
                agent = ALGO_REGISTRY[args["algo"]](
                    {**algo_args["model"], **algo_args["algo"]},
                    self.envs.observation_space[agent_id],
                    self.envs.action_space[agent_id],
                    lr=self._per_agent(self.policy_init_lr, agent_id),
                    device=self.device,
                )
                self.actor.append(agent)

        # Per-agent Lagrange multipliers for the barrier and Lyapunov terms and
        # their learning rates. They are created in both parameter-sharing
        # modes so the training code can rely on them existing.
        self.barrier_lambda_values = []
        self.Lyapunov_lambda_values = []
        self.barrier_lambda_lr = []
        self.Lyapunov_lambda_lr = []
        for agent_id in range(self.num_agents):
            initial_lambda = self.algo_args["model"]["lagrangian_value"]
            initial_lambda_lr = self._per_agent(self.lambda_init_lr, agent_id)
            self.barrier_lambda_values.append(initial_lambda)
            self.Lyapunov_lambda_values.append(initial_lambda)
            self.barrier_lambda_lr.append(initial_lambda_lr)
            self.Lyapunov_lambda_lr.append(initial_lambda_lr)

        if not use_render:
            # Build all critics, then all barriers, then all Lyapunov networks.
            # This keeps weight initialisation identical under a fixed seed.
            self.critic = [self._make_critic() for _ in range(self.num_agents)]
            self.barrier = [self._make_critic() for _ in range(self.num_agents)]
            self.Lyapunov = [self._make_critic() for _ in range(self.num_agents)]

            buffer_classes = {"EP": OffPolicyBufferEP, "FP": OffPolicyBufferFP}
            if self.state_type not in buffer_classes:
                raise NotImplementedError(
                    f"Unsupported state_type: {self.state_type}"
                )
            self.buffer = buffer_classes[self.state_type](
                {**algo_args["train"], **algo_args["model"], **algo_args["algo"]},
                self.envs.share_observation_space[0],
                self.num_agents,
                self.envs.observation_space,
                self.envs.action_space,
            )

        # One (optional) normalizer per agent for value, barrier and Lyapunov
        # targets; entries are None when value normalisation is disabled.
        use_valuenorm = self.algo_args["train"].get("use_valuenorm", False)
        self.value_normalizer = []
        self.barrier_normalizer = []
        self.Lyapunov_normalizer = []
        for _ in range(self.num_agents):
            if use_valuenorm:
                self.value_normalizer.append(ValueNorm(1, device=self.device))
                self.barrier_normalizer.append(ValueNorm(1, device=self.device))
                self.Lyapunov_normalizer.append(ValueNorm(1, device=self.device))
            else:
                self.value_normalizer.append(None)
                self.barrier_normalizer.append(None)
                self.Lyapunov_normalizer.append(None)

        if self.algo_args["train"]["model_dir"] is not None:
            self.restore()

        self.total_it = 0  # total iteration

        if self.algo_args["algo"].get("auto_alpha", False):
            self.target_entropy = [
                self._target_entropy(self.envs.action_space[agent_id])
                for agent_id in range(self.num_agents)
            ]
            self.log_alpha = []
            self.alpha_optimizer = []
            self.alpha = []
            for _ in range(self.num_agents):
                log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
                self.log_alpha.append(log_alpha)
                self.alpha_optimizer.append(
                    torch.optim.Adam([log_alpha], lr=self.algo_args["algo"]["alpha_lr"])
                )
                self.alpha.append(torch.exp(log_alpha.detach()))
        elif "alpha" in self.algo_args["algo"]:
            self.alpha = [self.algo_args["algo"]["alpha"]] * self.num_agents

    @staticmethod
    def _per_agent(value, agent_id):
        """Return ``value[agent_id]`` for a per-agent list/tuple, else ``value`` itself."""
        if isinstance(value, (list, tuple)):
            return value[agent_id]
        return value

    @staticmethod
    def _target_entropy(action_space):
        """Return the automatic-entropy-tuning target for one agent's action space.

        Box spaces use ``-prod(shape)``, since differential entropy can be
        negative. Other spaces use 98% of the maximum entropy ``log(n)``.
        ``n`` is ``action_space.n`` when present (Discrete); otherwise it falls
        back to ``prod(shape)``.
        """
        if action_space.__class__.__name__ == "Box":
            return -np.prod(action_space.shape)
        n_actions = getattr(action_space, "n", None)
        if n_actions is None:
            n_actions = np.prod(action_space.shape)
        return -0.98 * np.log(1.0 / n_actions)

    def _make_critic(self):
        """Build one critic-registry network (used for value, barrier and Lyapunov heads)."""
        return CRITIC_REGISTRY[self.args["algo"]](
            {
                **self.algo_args["train"],
                **self.algo_args["model"],
                **self.algo_args["algo"],
            },
            self.envs.share_observation_space[0],
            self.envs.action_space,
            self.num_agents,
            self.state_type,
            device=self.device,
        )

    def run(self):
        """Run the training (or rendering) pipeline."""
        if self.algo_args["render"]["use_render"]:
            self.render()
            return
        n_threads = self.algo_args["train"]["n_rollout_threads"]
        # Running per-thread episode returns, accumulated in insert().
        self.train_episode_rewards = np.zeros(n_threads)
        self.train_episode_costs = np.zeros(n_threads)
        self.train_episode_stability_costs = np.zeros(n_threads)
        self.done_episodes_rewards = []
        self.done_episode_costs = []
        self.done_episode_stability_costs = []
        # warmup
        print("start warmup")
        obs, share_obs, available_actions = self.warmup()
        print("finish warmup, start training")
        # train and eval
        steps = self.algo_args["train"]["num_env_steps"] // n_threads
        update_num = int(  # update number per train
            self.algo_args["train"]["update_per_train"]
            * self.algo_args["train"]["train_interval"]
        )
        for step in range(1, steps + 1):
            actions = self.get_actions(
                obs, available_actions=available_actions, add_random=True
            )
            obs, share_obs, available_actions = self._step_and_insert(
                obs, share_obs, available_actions, actions
            )
            if step % self.algo_args["train"]["train_interval"] == 0:
                if self.algo_args["train"]["use_linear_lr_decay"]:
                    self._decay_learning_rates(step, steps)
                for _ in range(update_num):
                    self.train()
            if step % self.algo_args["train"]["eval_interval"] == 0:
                self._log_progress(step)
                self.save()

    def _decay_learning_rates(self, step, steps):
        """Linearly decay actor, critic, barrier, Lyapunov and multiplier learning rates."""
        if self.share_param:
            # All entries of self.actor alias one object; decay it only once.
            self.actor[0].lr_decay(step, steps)
        else:
            for agent_id in range(self.num_agents):
                self.actor[agent_id].lr_decay(step, steps)
        for agent_id in range(self.num_agents):
            self.critic[agent_id].lr_decay(step, steps)
            self.barrier[agent_id].lr_decay(step, steps)
            self.Lyapunov[agent_id].lr_decay(step, steps)
            initial_lambda_lr = self._per_agent(self.lambda_init_lr, agent_id)
            self.barrier_lambda_lr[agent_id] = calculate_decayed_lr(
                step, steps, initial_lambda_lr
            )
            self.Lyapunov_lambda_lr[agent_id] = calculate_decayed_lr(
                step, steps, initial_lambda_lr
            )

    def _log_progress(self, step):
        """Print and log episode statistics gathered since the last call, then clear them."""
        cur_step = (
            self.algo_args["train"]["warmup_steps"]
            + step * self.algo_args["train"]["n_rollout_threads"]
        )
        print(
            f"Env {self.args['env']} Task {self.task_name} Algo {self.args['algo']} Exp {self.args['exp_name']} Step {cur_step} / {self.algo_args['train']['num_env_steps']}, average step reward in buffer: {self.buffer.get_mean_rewards()}.\n"
        )
        if not self.done_episodes_rewards:
            return
        aver_episode_rewards = np.mean(self.done_episodes_rewards)
        aver_episode_costs = np.mean(self.done_episode_costs)
        aver_episode_stability_costs = np.mean(self.done_episode_stability_costs)

        # NOTE(review): costs are reported with a flipped sign; presumably the
        # env emits them as negative values - confirm.
        print(
            "Some episodes done, average episode reward is {}, aver_episode_costs is {},  aver_episode_stability_costs is {} \n".format(
                aver_episode_rewards, -aver_episode_costs, aver_episode_stability_costs
            )
        )
        if self.constrained_writer is not None:
            self.constrained_writer.log({
                'train_episode_rewards': aver_episode_rewards,
                'train_episode_costs': -aver_episode_costs,
                'train_episode_stability_costs': aver_episode_stability_costs,
                'total_num_steps': cur_step,
            })
        self.log_file.write(
            ",".join(map(str, [cur_step, aver_episode_rewards])) + "\n"
        )
        self.log_file.flush()
        self.done_episodes_rewards = []
        self.done_episode_costs = []
        self.done_episode_stability_costs = []

    def warmup(self):
        """Warmup the replay buffer with random actions"""
        warmup_steps = (
            self.algo_args["train"]["warmup_steps"]
            // self.algo_args["train"]["n_rollout_threads"]
        )
        obs, share_obs, available_actions = self.envs.reset()
        for _ in range(warmup_steps):
            actions = self.sample_actions(available_actions)
            obs, share_obs, available_actions = self._step_and_insert(
                obs, share_obs, available_actions, actions
            )
        return obs, share_obs, available_actions

    def _step_and_insert(self, obs, share_obs, available_actions, actions):
        """Step the training envs with ``actions`` and store the transition.

        Args:
            obs, share_obs, available_actions: observations/masks the actions were chosen on.
            actions: (np.ndarray) actions, shape (n_threads, n_agents, dim).
        Returns:
            (new_obs, new_share_obs, new_available_actions) as returned by the env.
        """
        (
            new_obs,
            new_share_obs,
            rewards,
            costs,
            stability_costs,
            dones,
            infos,
            new_available_actions,
        ) = self.envs.step(actions)
        # insert() may overwrite rows of next_obs / next_share_obs with the
        # pre-reset observation, so pass copies and keep the env's arrays for
        # the next policy step.
        next_obs = new_obs.copy()
        next_share_obs = new_share_obs.copy()
        next_available_actions = new_available_actions.copy()
        has_action_masks = len(np.array(available_actions).shape) == 3
        data = (
            share_obs,
            obs.transpose(1, 0, 2),
            actions.transpose(1, 0, 2),
            available_actions.transpose(1, 0, 2) if has_action_masks else None,
            rewards,
            costs,
            stability_costs,
            dones,
            infos,
            next_share_obs,
            next_obs,
            next_available_actions.transpose(1, 0, 2) if has_action_masks else None,
        )
        self.insert(data)
        return new_obs, new_share_obs, new_available_actions

    @staticmethod
    def _is_bad_transition(info):
        """Return True when an env info dict flags the step as a truncation ("bad_transition")."""
        return bool(info.get("bad_transition", False))

    def insert(self, data):
        """Accumulate episode statistics and push one transition into the buffer.

        Args:
            data: tuple (share_obs, obs, actions, available_actions, rewards,
                costs, stability_costs, dones, infos, next_share_obs, next_obs,
                next_available_actions), as built by _step_and_insert.
        """
        (
            share_obs,
            obs,
            actions,
            available_actions,
            rewards,
            costs,
            stability_costs,
            dones,
            infos,
            next_share_obs,
            next_obs,
            next_available_actions,
        ) = data
        n_threads = self.algo_args["train"]["n_rollout_threads"]

        dones_env = np.all(dones, axis=1)
        self.train_episode_rewards += np.mean(rewards, axis=1).flatten()
        self.train_episode_costs += np.mean(costs, axis=1).flatten()
        self.train_episode_stability_costs += np.mean(
            stability_costs, axis=1
        ).flatten()

        # An agent that was already dead before this step produces an invalid
        # transition.
        valid_transitions = 1 - self.agent_deaths
        # Take an owned copy: rows are reset in place below, and
        # np.expand_dims would otherwise return a view into the caller's dones.
        self.agent_deaths = np.array(np.expand_dims(dones, axis=-1))

        # terms marks true terminations: done, and not a time-limit truncation.
        if self.state_type == "EP":
            terms = np.full((n_threads, 1), False)
            for i in range(n_threads):
                if dones_env[i] and not self._is_bad_transition(infos[i][0]):
                    terms[i][0] = True
        elif self.state_type == "FP":
            terms = np.full((n_threads, self.num_agents, 1), False)
            for i in range(n_threads):
                for agent_id in range(self.num_agents):
                    if dones[i][agent_id] and not self._is_bad_transition(
                        infos[i][agent_id]
                    ):
                        terms[i][agent_id][0] = True
        else:
            raise NotImplementedError(f"Unsupported state_type: {self.state_type}")

        for i in range(n_threads):
            if not dones_env[i]:
                continue
            self.done_episodes_rewards.append(self.train_episode_rewards[i])
            self.train_episode_rewards[i] = 0
            self.done_episode_costs.append(self.train_episode_costs[i])
            self.train_episode_costs[i] = 0
            self.done_episode_stability_costs.append(
                self.train_episode_stability_costs[i]
            )
            self.train_episode_stability_costs[i] = 0
            # Only this thread's env restarted, so only its agents are alive again.
            self.agent_deaths[i] = 0
            # Store the pre-reset observation instead of the reset one.
            if "original_obs" in infos[i][0]:
                next_obs[i] = infos[i][0]["original_obs"].copy()
            if "original_state" in infos[i][0]:
                next_share_obs[i] = infos[i][0]["original_state"].copy()

        if self.state_type == "EP":
            # NOTE(review): the EP tuple carries no costs / stability_costs,
            # unlike FP; confirm OffPolicyBufferEP does not need them.
            data = (
                share_obs[:, 0],
                obs,
                actions,
                available_actions,
                rewards[:, 0],
                np.expand_dims(dones_env, axis=-1),
                valid_transitions.transpose(1, 0, 2),
                terms,  # (n_threads, 1)
                next_share_obs[:, 0],
                next_obs.transpose(1, 0, 2),
                next_available_actions,
            )
        else:
            data = (
                share_obs,
                obs,
                actions,
                available_actions,
                rewards,
                costs,
                stability_costs,
                np.expand_dims(dones, axis=-1),
                valid_transitions.transpose(1, 0, 2),
                terms,
                next_share_obs,
                next_obs.transpose(1, 0, 2),
                next_available_actions,
            )

        self.buffer.insert(data)

    def sample_actions(self, available_actions=None):
        """Sample random actions for warmup.
        Args:
            available_actions: (np.ndarray) denotes which actions are available to agent (if None, all actions available),
                                 shape is (n_threads, n_agents, action_number) or (n_threads, ) of None
        Returns:
            actions: (np.ndarray) sampled actions, shape is (n_threads, n_agents, dim)
        """
        n_threads = self.algo_args["train"]["n_rollout_threads"]
        actions = []
        for agent_id in range(self.num_agents):
            agent_actions = []
            for thread in range(n_threads):
                if available_actions is None or available_actions[thread] is None:
                    agent_actions.append(self.action_spaces[agent_id].sample())
                else:
                    # Sample uniformly over the actions enabled by the mask.
                    mask = torch.tensor(available_actions[thread, agent_id, :])
                    agent_actions.append(Categorical(mask).sample().item())
            actions.append(agent_actions)
        # The action type is taken from the last agent's space, as before.
        last_space = self.action_spaces[self.num_agents - 1]
        if last_space.__class__.__name__ == "Discrete":
            # Discrete samples are scalars; add a trailing action dimension.
            return np.expand_dims(np.array(actions).transpose(1, 0), axis=-1)

        return np.array(actions).transpose(1, 0, 2)

    @torch.no_grad()
    def get_actions(self, obs, available_actions=None, add_random=True):
        """Get actions for rollout.
        Args:
            obs: (np.ndarray) input observation, shape is (n_threads, n_agents, dim)
            available_actions: (np.ndarray) denotes which actions are available to agent (if None, all actions available),
                                 shape is (n_threads, n_agents, action_number) or (n_threads, ) of None
            add_random: (bool) whether to add randomness
        Returns:
            actions: (np.ndarray) agent actions, shape is (n_threads, n_agents, dim)
        """
        has_action_masks = len(np.array(available_actions).shape) == 3
        actions = []
        for agent_id in range(self.num_agents):
            agent_obs = obs[:, agent_id]
            if self.args["algo"] != "mas3ac":
                agent_actions = self.actor[agent_id].get_actions(agent_obs, add_random)
            elif has_action_masks:
                agent_actions = self.actor[agent_id].get_actions(
                    agent_obs, available_actions[:, agent_id], add_random
                )
            else:
                agent_actions = self.actor[agent_id].get_actions(
                    agent_obs, stochastic=add_random
                )
            actions.append(_t2n(agent_actions))
        return np.array(actions).transpose(1, 0, 2)

    def train(self):
        """Train the model"""
        raise NotImplementedError

    @staticmethod
    def _unpack_render_step(step_result):
        """Return (obs, rewards, dones, available_actions) from an ``env.step`` result.

        Accepts the 8-value layout used by this runner's training envs
        (obs, share_obs, rewards, costs, stability_costs, dones, infos,
        available_actions) and the 6-value layout without the two cost entries.
        """
        if len(step_result) == 8:
            obs, _, rewards, _, _, dones, _, available_actions = step_result
        elif len(step_result) == 6:
            obs, _, rewards, dones, _, available_actions = step_result
        else:
            raise ValueError(
                f"Unexpected number of values returned by env.step: {len(step_result)}"
            )
        return obs, rewards, dones, available_actions

    @torch.no_grad()
    def render(self):
        """Render the model"""
        print("start rendering")
        if self.manual_expand_dims:
            # Single, non-vectorized env: add a leading thread dimension.
            for _ in range(self.algo_args["render"]["render_episodes"]):
                eval_obs, _, eval_available_actions = self.envs.reset()
                eval_obs = np.expand_dims(np.array(eval_obs), axis=0)
                eval_available_actions = np.array([eval_available_actions])
                rewards = 0
                while True:
                    eval_actions = self.get_actions(
                        eval_obs,
                        available_actions=eval_available_actions,
                        add_random=False,
                    )
                    (
                        eval_obs,
                        eval_rewards,
                        eval_dones,
                        eval_available_actions,
                    ) = self._unpack_render_step(self.envs.step(eval_actions[0]))
                    rewards += eval_rewards[0][0]
                    eval_obs = np.expand_dims(np.array(eval_obs), axis=0)
                    eval_available_actions = np.array([eval_available_actions])
                    if self.manual_render:
                        self.envs.render()
                    if self.manual_delay:
                        time.sleep(0.1)
                    if eval_dones[0]:
                        print(f"total reward of this episode: {rewards}")
                        break
        else:
            for _ in range(self.algo_args["render"]["render_episodes"]):
                eval_obs, _, eval_available_actions = self.envs.reset()
                rewards = 0
                while True:
                    eval_actions = self.get_actions(
                        eval_obs,
                        available_actions=eval_available_actions,
                        add_random=False,
                    )
                    (
                        eval_obs,
                        eval_rewards,
                        eval_dones,
                        eval_available_actions,
                    ) = self._unpack_render_step(self.envs.step(eval_actions))
                    rewards += eval_rewards[0][0][0]
                    if self.manual_render:
                        self.envs.render()
                    if self.manual_delay:
                        time.sleep(0.1)
                    if eval_dones[0][0]:
                        print(f"total reward of this episode: {rewards}")
                        break
        if "smac" in self.args["env"]:  # replay for smac, no rendering
            if "v2" in self.args["env"]:
                self.envs.env.save_replay()
            else:
                self.envs.save_replay()

    def restore(self):
        """Restore actors and, when training, per-agent critics and value normalizers.

        Mirrors save(): critics are restored per agent and value normalizers
        are read from ``value_normalizer{agent_id}.pt``.
        """
        model_dir = self.algo_args["train"]["model_dir"]
        for agent_id in range(self.num_agents):
            self.actor[agent_id].restore(model_dir, agent_id)
        if self.algo_args["render"]["use_render"]:
            return
        for agent_id in range(self.num_agents):
            # NOTE(review): assumes critic.restore(dir, agent_id) mirrors
            # critic.save(dir, agent_id) used in save() - confirm.
            self.critic[agent_id].restore(model_dir, agent_id)
            if self.value_normalizer[agent_id] is not None:
                state_dict = torch.load(
                    os.path.join(str(model_dir), "value_normalizer" + str(agent_id) + ".pt"),
                    map_location=self.device,
                )
                self.value_normalizer[agent_id].load_state_dict(state_dict)

    def save(self):
        """Save actors, critics and (if enabled) value normalizers to ``self.save_dir``."""
        for agent_id in range(self.num_agents):
            self.actor[agent_id].save(self.save_dir, agent_id)
            self.critic[agent_id].save(self.save_dir, agent_id)
        if self.value_normalizer[0] is not None:
            for agent_id in range(self.num_agents):
                torch.save(
                    self.value_normalizer[agent_id].state_dict(),
                    str(self.save_dir) + "/value_normalizer" + str(agent_id) + ".pt",
                )

    def close(self):
        """Close environments, writer, log file and the wandb run (if any)."""
        self.envs.close()
        if not self.algo_args["render"]["use_render"]:
            if self.algo_args["eval"]["use_eval"] and self.eval_envs is not self.envs:
                self.eval_envs.close()
            self.writter.export_scalars_to_json(str(self.log_dir + "/summary.json"))
            self.writter.close()
            self.log_file.close()
        if self.constrained_writer is not None:
            self.constrained_writer.finish()
