"""Runner for off-policy MAS3AC algorithms."""
import torch
import numpy as np
import torch.nn.functional as F
from mas3ac.runners.off_policy_base_runner import OffPolicyBaseRunner
from wandb import agent
from mas3ac.utils.envs_tools import check


class OffPolicyMARunner(OffPolicyBaseRunner):
    """Runner for off-policy MAS3AC algorithms.

    One call to :meth:`train` samples a batch from the replay buffer and
    updates the agents sequentially. For every agent the critic, barrier and
    Lyapunov networks are trained. Every ``self.policy_freq`` iterations the
    agent's actor is also updated, using a soft actor-critic loss plus
    Lagrangian-weighted Lyapunov and barrier penalties. That actor step is
    followed by the temperature updates (when ``auto_alpha`` is enabled) and
    soft updates of the critic, barrier and Lyapunov networks.
    """

    # Decay rate of the Lyapunov decrease condition
    #   L(s', a') - L(s, a) + LYAPUNOV_DECAY * L(s, a) <= 0
    LYAPUNOV_DECAY = 0.01
    # Decay rate of the barrier condition
    #   B(s', a') - B(s, a) + BARRIER_DECAY * B(s, a) >= 0
    BARRIER_DECAY = 0.05
    # Weights of the quadratic terms added on top of the hinge penalties.
    LYAPUNOV_QUADRATIC_COEF = 0.01
    BARRIER_QUADRATIC_COEF = 0.001
    # Lagrange multipliers are clipped to this range after every ascent step.
    LAGRANGE_MIN = 0.01
    LAGRANGE_MAX = 10.0
    # Floor of the barrier/Lyapunov mean-loss ratio that rescales the
    # Lyapunov penalty (the ratio must not become too small).
    MIN_LOSS_RATIO = 0.05
    # Protects the loss ratio against division by zero.
    RATIO_EPS = 1e-6

    def train(self):
        """Train the model for one iteration on a freshly sampled batch.

        Agents are visited in id order when ``self.fixed_order`` is set, and
        in a new random permutation otherwise. The per-agent action lists are
        always indexed by agent id. Concatenating them therefore yields the
        joint action in agent-id order, whatever the visiting order.

        Raises:
            RuntimeError: If ``self.args["algo"]`` is not ``"mas3ac"``.
            ValueError: If policy active masks are enabled and
                ``self.state_type`` is neither ``"EP"`` nor ``"FP"``.
        """
        if self.args["algo"] != "mas3ac":
            raise RuntimeError(
                f"Unsupported algorithm: {self.args['algo']}. "
                "This training loop only supports 'mas3ac'."
            )

        self.total_it += 1
        if self.fixed_order:
            agent_order = list(range(self.num_agents))
        else:
            # Randomize the sequential update order every iteration.
            agent_order = [int(i) for i in np.random.permutation(self.num_agents)]

        (
            sp_share_obs,
            sp_obs,
            sp_actions,
            sp_available_actions,
            sp_reward,
            sp_cost,
            sp_stability_cost,
            sp_done,
            sp_valid_transition,
            sp_term,
            sp_next_share_obs,
            sp_next_obs,
            sp_next_available_actions,
            sp_gamma,
        ) = self.buffer.sample()
        valid_mask = torch.tensor(sp_valid_transition, device=self.device)

        # Actions/log-probs of every agent under the pre-update policies.
        # Next-step actions feed the value-network targets. Current actions
        # feed the critic term of the actor loss. Next-step actions are
        # sampled first, as sampling consumes the RNG stream.
        with torch.no_grad():
            next_actions, next_logp_actions = self._sample_all_actions(
                sp_next_obs, sp_next_available_actions
            )
            actions, logp_actions = self._sample_all_actions(
                sp_obs, sp_available_actions
            )

        update_policy = self.total_it % self.policy_freq == 0
        for agent_id in agent_order:
            # The value networks see the latest policies of agents already
            # updated earlier in this loop through ``next_actions``.
            self._train_value_networks(
                agent_id,
                sp_share_obs,
                sp_actions,
                sp_reward,
                sp_cost,
                sp_stability_cost,
                sp_done,
                sp_valid_transition,
                sp_term,
                sp_next_share_obs,
                next_actions,
                next_logp_actions,
                sp_gamma,
            )
            if not update_policy:
                continue

            self._update_actor(
                agent_id,
                sp_share_obs,
                sp_obs,
                sp_actions,
                sp_available_actions,
                sp_next_share_obs,
                sp_next_obs,
                sp_next_available_actions,
                actions,
                logp_actions,
                next_actions,
                next_logp_actions,
                valid_mask,
            )

            # Train the temperatures held by the critic and Lyapunov networks.
            if self.algo_args["algo"]["auto_alpha"]:
                self.critic[agent_id].update_alpha(
                    logp_actions[agent_id], self.target_entropy[agent_id]
                )
                self.Lyapunov[agent_id].update_alpha(
                    logp_actions[agent_id], self.target_entropy[agent_id]
                )

            # Soft update for critic, barrier and Lyapunov networks.
            self.critic[agent_id].soft_update()
            self.barrier[agent_id].soft_update()
            self.Lyapunov[agent_id].soft_update()

    def _agent_actions_with_logprobs(self, agent_id, obs, available_actions):
        """Return ``(action, logp_action)`` of ``agent_id`` from its actor.

        ``available_actions`` may be ``None`` (e.g. continuous action spaces).
        """
        return self.actor[agent_id].get_actions_with_logprobs(
            obs[agent_id],
            available_actions[agent_id] if available_actions is not None else None,
        )

    def _sample_all_actions(self, obs, available_actions):
        """Sample actions and log-probs for every agent, indexed by agent id."""
        actions = []
        logp_actions = []
        for agent_id in range(self.num_agents):
            action, logp_action = self._agent_actions_with_logprobs(
                agent_id, obs, available_actions
            )
            actions.append(action)
            logp_actions.append(logp_action)
        return actions, logp_actions

    def _train_value_networks(
        self,
        agent_id,
        share_obs,
        buffer_actions,
        reward,
        cost,
        stability_cost,
        done,
        valid_transition,
        term,
        next_share_obs,
        next_actions,
        next_logp_actions,
        gamma,
    ):
        """Train the critic, barrier and Lyapunov networks of ``agent_id``.

        The three networks share the same call signature. They differ only in
        the training signal (reward, cost, stability cost) and normalizer.
        """
        networks = (
            (self.critic[agent_id], reward, self.value_normalizer[agent_id]),
            (self.barrier[agent_id], cost, self.barrier_normalizer[agent_id]),
            (
                self.Lyapunov[agent_id],
                stability_cost,
                self.Lyapunov_normalizer[agent_id],
            ),
        )
        for network, signal, normalizer in networks:
            network.turn_on_grad()
            network.train(
                agent_id,
                share_obs,
                buffer_actions,
                signal,
                done,
                valid_transition,
                term,
                next_share_obs,
                next_actions,
                next_logp_actions,
                gamma,
                normalizer,
            )
            network.turn_off_grad()

    def _update_actor(
        self,
        agent_id,
        share_obs,
        obs,
        buffer_actions,
        available_actions,
        next_share_obs,
        next_obs,
        next_available_actions,
        actions,
        logp_actions,
        next_actions,
        next_logp_actions,
        valid_mask,
    ):
        """Take one actor step (and optional temperature step) for ``agent_id``.

        ``actions``, ``logp_actions``, ``next_actions`` and ``next_logp_actions``
        are per-agent lists. Their ``agent_id`` entries are overwritten in
        place, so agents updated later in this iteration see the new policy.
        """
        actor = self.actor[agent_id]
        actor.turn_on_grad()

        # Resample this agent's current and next-step actions with gradients:
        # the critic term uses the current ones, the constraint term the next.
        actions[agent_id], logp_actions[agent_id] = self._agent_actions_with_logprobs(
            agent_id, obs, available_actions
        )
        next_actions[agent_id], next_logp_actions[agent_id] = (
            self._agent_actions_with_logprobs(agent_id, next_obs, next_available_actions)
        )

        # Soft actor-critic term: maximize value - alpha * log pi.
        value_pred = self.critic[agent_id].get_values(
            share_obs[agent_id], torch.cat(actions, dim=-1)
        )
        sac_loss = self._reduce_actor_loss(
            -(value_pred - self.alpha[agent_id] * logp_actions[agent_id]),
            valid_mask,
            agent_id,
        )

        # Constraint term: Lagrangian-weighted Lyapunov and barrier penalties.
        # The multipliers are updated first and their new values are used.
        lyapunov_loss, barrier_loss = self._constraint_losses(
            agent_id, share_obs, buffer_actions, next_share_obs, next_actions
        )
        ratio = self._update_lagrange_multipliers(
            agent_id, lyapunov_loss, barrier_loss
        )
        lambda_lyapunov = float(self.Lyapunov_lambda_values[agent_id])
        lambda_barrier = float(self.barrier_lambda_values[agent_id])
        constraint_loss = self._reduce_actor_loss(
            lambda_lyapunov * lyapunov_loss * ratio + lambda_barrier * barrier_loss,
            valid_mask,
            agent_id,
        )

        actor.actor_optimizer.zero_grad()
        (sac_loss + constraint_loss).backward()
        actor.actor_optimizer.step()
        actor.turn_off_grad()

        if self.algo_args["algo"]["auto_alpha"]:
            self._train_alpha(agent_id, logp_actions[agent_id])

        # Refresh this agent's entries with the updated policy. Its gradients
        # are switched off again, so no graph is attached.
        next_actions[agent_id], next_logp_actions[agent_id] = (
            self._agent_actions_with_logprobs(agent_id, next_obs, next_available_actions)
        )
        actions[agent_id], _ = self._agent_actions_with_logprobs(
            agent_id, obs, available_actions
        )

    def _constraint_losses(
        self, agent_id, share_obs, buffer_actions, next_share_obs, next_actions
    ):
        """Per-sample Lyapunov and barrier penalties for ``agent_id``'s actor.

        Current-step values are computed from the replay-buffer actions and
        detached. Next-step values use ``next_actions``, so gradients flow back
        through whichever of its entries require them.

        Returns:
            Tuple ``(lyapunov_loss, barrier_loss)`` of unreduced penalties
            (presumably shaped ``(batch_size, 1)`` like the network outputs).
        """
        tpdv = self.critic[agent_id].tpdv
        state = share_obs[agent_id]
        next_state = next_share_obs[agent_id]
        # Buffer actions are stacked per agent along dim 0. Concatenating them
        # along the last axis gives the joint action (batch_size, wholedim).
        joint_action = torch.cat(list(check(buffer_actions).to(**tpdv)), dim=-1)
        next_joint_action = torch.cat(next_actions, dim=-1).to(**tpdv)

        barrier_net = self.barrier[agent_id]
        lyapunov_net = self.Lyapunov[agent_id]
        current_barrier = barrier_net.get_values(state, joint_action).detach()
        current_lyapunov = lyapunov_net.get_values(state, joint_action).detach()
        next_barrier = barrier_net.get_values(next_state, next_joint_action)
        next_lyapunov = lyapunov_net.get_values(next_state, next_joint_action)

        # Lyapunov decrease violation, passed through softplus (a smooth ReLU)
        # with a small quadratic term added for numerical stability.
        lyapunov_hinge = F.softplus(
            next_lyapunov - current_lyapunov + self.LYAPUNOV_DECAY * current_lyapunov
        )
        lyapunov_loss = lyapunov_hinge + self.LYAPUNOV_QUADRATIC_COEF * (
            0.5 * lyapunov_hinge ** 2
        )

        # Barrier condition violation with a ReLU hinge plus a quadratic term.
        barrier_hinge = F.relu(
            -(next_barrier - current_barrier) - self.BARRIER_DECAY * current_barrier
        )
        barrier_loss = barrier_hinge + self.BARRIER_QUADRATIC_COEF * (
            0.5 * barrier_hinge ** 2
        )
        return lyapunov_loss, barrier_loss

    def _update_lagrange_multipliers(self, agent_id, lyapunov_loss, barrier_loss):
        """Gradient-ascent step on ``agent_id``'s Lagrange multipliers.

        Each multiplier grows by its learning rate times the detached mean
        penalty and is clipped to ``[LAGRANGE_MIN, LAGRANGE_MAX]``.

        Returns:
            float: the barrier/Lyapunov mean-loss ratio, floored at
            ``MIN_LOSS_RATIO``. It balances the Lyapunov penalty against the
            barrier penalty.
        """
        lyapunov_mean = lyapunov_loss.mean().detach()
        barrier_mean = barrier_loss.mean().detach()

        self.barrier_lambda_values[agent_id] = torch.clamp(
            self.barrier_lambda_values[agent_id]
            + self.barrier_lambda_lr[agent_id] * barrier_mean,
            self.LAGRANGE_MIN,
            self.LAGRANGE_MAX,
        )
        self.Lyapunov_lambda_values[agent_id] = torch.clamp(
            self.Lyapunov_lambda_values[agent_id]
            + self.Lyapunov_lambda_lr[agent_id] * lyapunov_mean,
            self.LAGRANGE_MIN,
            self.LAGRANGE_MAX,
        )

        ratio = float(barrier_mean / (lyapunov_mean + self.RATIO_EPS))
        return max(ratio, self.MIN_LOSS_RATIO)

    def _reduce_actor_loss(self, per_sample_loss, valid_mask, agent_id):
        """Reduce a per-sample actor loss to a scalar.

        With ``use_policy_active_masks`` the loss is averaged over the valid
        transitions of ``agent_id`` (``valid_mask[agent_id]``). Otherwise it
        is a plain mean.

        Raises:
            ValueError: If masking is enabled and ``self.state_type`` is
                neither ``"EP"`` nor ``"FP"``.
        """
        if not self.algo_args["algo"]["use_policy_active_masks"]:
            return torch.mean(per_sample_loss)
        if self.state_type not in ("EP", "FP"):
            raise ValueError(
                f"Unsupported state_type: {self.state_type!r}. "
                "Expected 'EP' or 'FP'."
            )
        return self._masked_mean(per_sample_loss, valid_mask[agent_id])

    @staticmethod
    def _masked_mean(values, mask):
        """Mean of ``values`` over entries weighted by ``mask``.

        An all-zero mask yields 0 instead of the NaN of ``0 / 0``, which would
        otherwise propagate into the actor parameters.
        """
        total = mask.sum()
        denominator = torch.where(total > 0, total, torch.ones_like(total))
        return torch.sum(values * mask) / denominator

    def _train_alpha(self, agent_id, logp_action):
        """Take one gradient step on ``agent_id``'s entropy temperature."""
        log_prob = logp_action.detach() + self.target_entropy[agent_id]
        alpha_loss = -(self.log_alpha[agent_id] * log_prob).mean()
        optimizer = self.alpha_optimizer[agent_id]
        optimizer.zero_grad()
        alpha_loss.backward()
        optimizer.step()
        self.alpha[agent_id] = torch.exp(self.log_alpha[agent_id].detach())
