# Public names exported by a star-import of this module.
__all__=["ApproxContainer","DSACTHcrl"]
import math
import time
from copy import deepcopy
from typing import Any, Optional, Tuple

import torch
import torch.nn as nn
from torch.distributions import Normal
from torch.optim import Adam

from gops.algorithm.base import AlgorithmBase, ApprBase
from gops.create_pkg.create_apprfunc import create_apprfunc
from gops.utils.tensorboard_setup import tb_tags
from gops.utils.gops_typing import DataDict
from gops.utils.common_utils import get_apprfunc_dict
from gops.utils.lagrangian import HCRL


class ApproxContainer(ApprBase):
    """Bundle of the approximate functions used by the algorithm in this file.

    Holds two reward critics (``q1``, ``q2``), one cost critic (``qc``) and one
    policy, each paired with a deep-copied target network whose parameters do
    not require gradients, plus the learnable log entropy coefficient
    ``log_alpha`` and one Adam optimizer per trainable component.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Reward critics: two separately created networks sharing one argument dict.
        value_kwargs = get_apprfunc_dict("value", **kwargs)
        self.q1: nn.Module = create_apprfunc(**value_kwargs)
        self.q2: nn.Module = create_apprfunc(**value_kwargs)
        self.q1_target = deepcopy(self.q1)
        self.q2_target = deepcopy(self.q2)

        # Cost critic.
        cost_value_kwargs = get_apprfunc_dict("cost_value", **kwargs)
        self.qc: nn.Module = create_apprfunc(**cost_value_kwargs)
        self.qc_target = deepcopy(self.qc)

        # Policy.
        actor_kwargs = get_apprfunc_dict("policy", **kwargs)
        self.policy: nn.Module = create_apprfunc(**actor_kwargs)
        self.policy_target = deepcopy(self.policy)

        # Target networks are excluded from back-propagation.
        frozen_nets = (
            self.policy_target,
            self.q1_target,
            self.q2_target,
            self.qc_target,
        )
        for frozen_net in frozen_nets:
            for weight in frozen_net.parameters():
                weight.requires_grad = False

        # Entropy coefficient, stored in log space.
        self.log_alpha = nn.Parameter(torch.tensor(1, dtype=torch.float32))

        # One optimizer per trainable component; the three critics share a rate.
        critic_lr = kwargs["value_learning_rate"]
        self.q1_optimizer = Adam(self.q1.parameters(), lr=critic_lr)
        self.q2_optimizer = Adam(self.q2.parameters(), lr=critic_lr)
        self.qc_optimizer = Adam(self.qc.parameters(), lr=critic_lr)
        self.policy_optimizer = Adam(
            self.policy.parameters(), lr=kwargs["policy_learning_rate"]
        )
        self.alpha_optimizer = Adam([self.log_alpha], lr=kwargs["alpha_learning_rate"])

    def create_action_distributions(self, logits):
        """Return the action distribution the policy builds from ``logits``."""
        return self.policy.get_act_dist(logits)


class DSACTHcrl(AlgorithmBase):
    """DSACTHcrl algorithm with three refinements, higher performance and more stable.

    :param int index: algorithm index, forwarded to ``AlgorithmBase``.
    :param float gamma: discount factor.
    :param float tau: param for soft update of target network.
    :param float alpha: initial temperature.
    :param bool auto_alpha: whether to adjust temperature automatically.
    :param Optional[float] target_entropy: target entropy for automatic
        temperature adjustment; ``-action_dim`` is used when it is ``None``.
    :param bool bound: whether to bound the q value.
    :param float cost_limit: cost limit for constraints.
    :param float lagrangian_multiplier_init: initial value of the Lagrangian
        multiplier.
    :param float lagrangian_multiplier_lr: learning rate for Lagrangian
        multiplier updates.
    :param Optional[float] lagrangian_upper_bound: upper bound for the
        Lagrangian multiplier.
    :param float eta: learning rate for parameter updates.
    :param float beta: learning rate for Lagrangian multiplier updates.
    :param float c_init: initial value for the scaling parameter c.
    :param float gamma0: decay rate for scaling parameter c.
    :param float alpha_lr: learning rate for alpha optimization.
    :param int alpha_iters: number of iterations for alpha optimization.

    The following values are read from ``**kwargs`` rather than from named
    parameters: ``delay_update`` (required; delay update steps for actor),
    ``action_dim`` (required when ``target_entropy`` is ``None``) and
    ``tau_b`` (optional; defaults to ``tau``).
    """

    def __init__(
        self,
        index: int = 0,
        gamma: float = 0.99,
        tau: float = 0.005,
        alpha: float = math.e,
        auto_alpha: bool = True,
        target_entropy: Optional[float] = None,
        bound: bool = True,
        cost_limit: float = 25.0,
        lagrangian_multiplier_init: float = 0.001,
        lagrangian_multiplier_lr: float = 0.035,
        lagrangian_upper_bound: Optional[float] = None,
        eta: float = 0.01,
        beta: float = 0.01,
        c_init: float = 0.1,
        gamma0: float = 0.9,
        alpha_lr: float = 0.05,
        alpha_iters: int = 50,
        **kwargs: Any,
    ):
        """Create the networks, store hyper-parameters and build the HCRL helper.

        ``kwargs`` is forwarded to the base class and to ``ApproxContainer``;
        it must contain ``delay_update`` and, when ``target_entropy`` is
        ``None``, ``action_dim``. ``tau_b`` is optional and falls back to
        ``tau``.
        """
        super().__init__(index, **kwargs)

        # Networks; the stored log-temperature is overwritten with log(alpha).
        self.networks = ApproxContainer(**kwargs)
        self.networks.log_alpha.data.fill_(math.log(alpha))

        # Plain hyper-parameters.
        self.gamma = gamma
        self.tau = tau
        self.auto_alpha = auto_alpha
        self.target_entropy = (
            -kwargs["action_dim"] if target_entropy is None else target_entropy
        )
        self.bound = bound
        self.delay_update = kwargs["delay_update"]

        # Running averages of the critics' std outputs; filled on first loss call.
        self.mean_std1 = None
        self.mean_std2 = None
        self.tau_b = kwargs.get("tau_b", self.tau)

        # Lagrangian / constraint-handling helper.
        hcrl_settings = dict(
            cost_limit=cost_limit,
            lagrangian_multiplier_init=lagrangian_multiplier_init,
            lagrangian_multiplier_lr=lagrangian_multiplier_lr,
            lagrangian_upper_bound=lagrangian_upper_bound,
            eta=eta,
            beta=beta,
            c_init=c_init,
            gamma0=gamma0,
            alpha_lr=alpha_lr,
            alpha_iters=alpha_iters,
        )
        self.hcrl = HCRL(**hcrl_settings)


    @property
    def adjustable_parameters(self):
        return (
            "gamma",
            "tau",
            "auto_alpha",
            "alpha",
            "delay_update",
        )

    def local_update(self, data: DataDict, iteration: int) -> dict:
        tb_info = self._compute_gradient(data, iteration)
        self._update(iteration)
        return tb_info

    def get_remote_update_info(
        self, data: DataDict, iteration: int
    ) -> Tuple[dict, dict]:
        tb_info = self._compute_gradient(data, iteration)

        update_info = {
            "q1_grad": [p._grad for p in self.networks.q1.parameters()],
            "q2_grad": [p._grad for p in self.networks.q2.parameters()],
            "qc_grad": [p.grad for p in self.networks.qc.parameters()],
            "policy_grad": [p._grad for p in self.networks.policy.parameters()],
            "lagrangian_multiplier_grad": self.lagrange._lagrangian_multiplier.grad,
            "iteration": iteration,
        }
        if self.auto_alpha:
            update_info.update({"log_alpha_grad":self.networks.log_alpha.grad})

        return tb_info, update_info

    def remote_update(self, update_info: dict):
        iteration = update_info["iteration"]
        q1_grad = update_info["q1_grad"]
        q2_grad = update_info["q2_grad"]
        qc_grad = update_info["qc_grad"]
        policy_grad = update_info["policy_grad"]

        for p, grad in zip(self.networks.q1.parameters(), q1_grad):
            p._grad = grad
        for p, grad in zip(self.networks.q2.parameters(), q2_grad):
            p._grad = grad
        for p, grad in zip(self.networks.qc.parameters(), qc_grad):
            p._grad = grad
        for p, grad in zip(self.networks.policy.parameters(), policy_grad):
            p._grad = grad
        if self.auto_alpha:
            self.networks.log_alpha._grad = update_info["log_alpha_grad"]
        self.lagrange._lagrangian_multiplier.grad = update_info["lagrangian_multiplier_grad"]

        self._update(iteration)

    def _get_alpha(self, requires_grad: bool = False):
        alpha = self.networks.log_alpha.exp()
        if requires_grad:
            return alpha
        else:
            return alpha.item()

    def _compute_gradient(self, data: DataDict, iteration: int):
        """Run every backward pass for one batch and return logging values.

        Gradients are left on the parameters; no optimizer is stepped here.
        ``data`` is extended in place with ``new_act`` and ``new_log_prob``
        sampled from the current policy. ``iteration`` is not used in this
        method.

        :param data: batch dict; this method reads ``obs`` and ``cost`` and the
            loss helpers read further keys.
        :param iteration: current iteration.
        :return: dict of scalars for logging.
        """
        start_time = time.time()

        obs = data["obs"]
        logits = self.networks.policy(obs)
        logits_mean, logits_std = torch.chunk(logits, chunks=2, dim=-1)
        policy_mean = torch.tanh(logits_mean).mean().item()
        policy_std = logits_std.mean().item()

        act_dist = self.networks.create_action_distributions(logits)
        new_act, new_log_prob = act_dist.rsample()
        data.update({"new_act": new_act, "new_log_prob": new_log_prob})

        # Reward critics.
        self.networks.q1_optimizer.zero_grad()
        self.networks.q2_optimizer.zero_grad()
        loss_q, q1, q2, std1, std2, min_std1, min_std2 = self._compute_loss_q(data)
        loss_q.backward()

        # Cost critic.
        self.networks.qc_optimizer.zero_grad()
        loss_qc, qc = self._compute_loss_qc(data)
        loss_qc.backward()

        # Freeze all critics while the policy gradient is computed so that the
        # backward passes through them leave their gradients untouched. The
        # flags are restored in ``finally`` so an exception in the policy step
        # cannot leave the critics permanently frozen.
        critics = (self.networks.q1, self.networks.q2, self.networks.qc)
        for net in critics:
            for p in net.parameters():
                p.requires_grad = False
        try:
            entropy = self.update_policy_grad(data)
        finally:
            for net in critics:
                for p in net.parameters():
                    p.requires_grad = True

        # Lagrangian multiplier, driven by the batch-mean cost.
        c = data["cost"].mean().item()
        self.hcrl.lambda_optimizer.zero_grad()
        lambda_loss = self.hcrl.compute_lambda_loss(c)
        lambda_loss.backward()

        if self.auto_alpha:
            self.networks.alpha_optimizer.zero_grad()
            loss_alpha = self._compute_loss_alpha(data)
            loss_alpha.backward()

        tb_info = {
            "DSACTHcrl/critic_avg_q1-RL iter": q1.item(),
            "DSACTHcrl/critic_avg_q2-RL iter": q2.item(),
            "DSACTHcrl/critic_avg_std1-RL iter": std1.item(),
            "DSACTHcrl/critic_avg_std2-RL iter": std2.item(),
            "DSACTHcrl/critic_avg_qc-RL iter": qc.item(),
            "DSACTHcrl/critic_avg_min_std1-RL iter": min_std1.item(),
            "DSACTHcrl/critic_avg_min_std2-RL iter": min_std2.item(),
            tb_tags["loss_actor"]: -q1.item(),
            tb_tags["loss_critic"]: loss_q.item(),
            "DSACTHcrl/policy_mean-RL iter": policy_mean,
            "DSACTHcrl/policy_std-RL iter": policy_std,
            "DSACTHcrl/entropy-RL iter": entropy.item(),
            "DSACTHcrl/alpha-RL iter": self._get_alpha(),
            "DSACTHcrl/mean_std1": self.mean_std1,
            "DSACTHcrl/mean_std2": self.mean_std2,
            "DSACTHcrl/lagrangian_multiplier": self.hcrl.lagrangian_multiplier,
            "DSACTHcrl/ep_cost_mean": c,
            tb_tags["alg_time"]: (time.time() - start_time) * 1000,
        }

        return tb_info

    def _q_evaluate(self, obs, act, qnet):
        StochaQ = qnet(obs, act)
        mean, std = StochaQ[..., 0], StochaQ[..., -1]
        normal = Normal(torch.zeros_like(mean), torch.ones_like(std))
        z = normal.sample()
        z = torch.clamp(z, -3, 3)
        q_value = mean + torch.mul(z, std)
        return mean, std, q_value
    
    def _qc_evaluate(self, obs, act, qnet):
        StochaQ = qnet(obs, act)
        mean, std = StochaQ[..., 0], StochaQ[..., -1]
        return mean

    def _compute_loss_q(self, data: DataDict):
        """Compute the summed loss of the two reward critics.

        Reads ``obs``, ``act``, ``rew``, ``obs2`` and ``done`` from ``data``
        and updates the running averages ``self.mean_std1`` / ``self.mean_std2``
        as a side effect.

        :return: ``(loss, mean q1, mean q2, mean std1, mean std2, min std1,
            min std2)``; every element except the loss is detached.
        """
        obs, act, rew, obs2, done = (
            data["obs"],
            data["act"],
            data["rew"],
            data["obs2"],
            data["done"],
        )
        # Next action and its log-probability come from the target policy.
        logits_2 = self.networks.policy_target(obs2)
        act2_dist = self.networks.create_action_distributions(logits_2)
        act2, log_prob_act2 = act2_dist.rsample()

        q1, q1_std, _ = self._q_evaluate(obs, act, self.networks.q1)
        q2, q2_std, _ = self._q_evaluate(obs, act, self.networks.q2)
        # Running average of each critic's mean std: initialised on the first
        # call, then blended with weight tau_b.
        if self.mean_std1 is None:
            self.mean_std1 = torch.mean(q1_std.detach())
        else:
            self.mean_std1 = (1 - self.tau_b) * self.mean_std1 + self.tau_b * torch.mean(q1_std.detach())

        if self.mean_std2 is None:
            self.mean_std2 = torch.mean(q2_std.detach())
        else:
            self.mean_std2 = (1 - self.tau_b) * self.mean_std2 + self.tau_b * torch.mean(q2_std.detach())

        q1_next, _, q1_next_sample = self._q_evaluate(
            obs2, act2, self.networks.q1_target
        )

        q2_next, _, q2_next_sample = self._q_evaluate(
            obs2, act2, self.networks.q2_target
        )
        # Element-wise minimum of the two target means; the sample is taken
        # from whichever target critic has the smaller mean.
        q_next = torch.min(q1_next, q2_next)
        q_next_sample = torch.where(q1_next < q2_next, q1_next_sample, q2_next_sample)

        # Both critics share the same next-state quantities; only the current
        # estimate and the running std differ.
        target_q1, target_q1_bound = self._compute_target_q(
            rew,
            done,
            q1.detach(),
            self.mean_std1.detach(),
            q_next.detach(),
            q_next_sample.detach(),
            log_prob_act2.detach(),
        )

        target_q2, target_q2_bound = self._compute_target_q(
            rew,
            done,
            q2.detach(),
            self.mean_std2.detach(),
            q_next.detach(),
            q_next_sample.detach(),
            log_prob_act2.detach(),
        )

        # Non-negative, detached copies of the std used only as coefficients.
        q1_std_detach = torch.clamp(q1_std, min=0.).detach()
        q2_std_detach = torch.clamp(q2_std, min=0.).detach()
        # Constant added to the scale factor and to every denominator below.
        bias = 0.1

        # Each term is a detached coefficient times a network output (q or
        # q_std), so gradients flow only through those two outputs. The whole
        # batch mean is scaled by (running mean std)^2 + bias.
        q1_loss = (torch.pow(self.mean_std1, 2) + bias) * torch.mean(
            -(target_q1 - q1).detach() / ( torch.pow(q1_std_detach, 2)+ bias)*q1
            -((torch.pow(q1.detach() - target_q1_bound, 2)- q1_std_detach.pow(2) )/ (torch.pow(q1_std_detach, 3) +bias)
            )*q1_std
        )

        q2_loss = (torch.pow(self.mean_std2, 2) + bias)*torch.mean(
            -(target_q2 - q2).detach() / ( torch.pow(q2_std_detach, 2)+ bias)*q2
            -((torch.pow(q2.detach() - target_q2_bound, 2)- q2_std_detach.pow(2) )/ (torch.pow(q2_std_detach, 3) +bias)
            )*q2_std
        )

        return q1_loss +q2_loss, q1.detach().mean(), q2.detach().mean(), q1_std.detach().mean(), q2_std.detach().mean(), q1_std.min().detach(), q2_std.min().detach()

    def _compute_loss_qc(self, data: DataDict):
        obs, act, cost, obs2, done = (
            data["obs"],
            data["act"],
            data["cost"],
            data["obs2"],
            data["done"],
        )

        logits_2 = self.networks.policy_target(obs2)
        act2_dist = self.networks.create_action_distributions(logits_2)
        act2 = act2_dist.rsample()[0]
 
        qc = self._qc_evaluate(obs, act, self.networks.qc)
        qc_next = self._qc_evaluate(obs2, act2, self.networks.qc_target)

        target_qc = self._compute_target_qc(
            cost,
            done,
            qc_next.detach()
        )

        qc_loss = torch.mean(
            torch.pow(qc - target_qc, 2)
        )
        return qc_loss, qc.detach().mean()


    def _compute_target_q(self, r, done, q,q_std, q_next, q_next_sample, log_prob_a_next):
        target_q = r + (1 - done) * self.gamma * (
            q_next - self._get_alpha() * log_prob_a_next
        )
        target_q_sample = r + (1 - done) * self.gamma * (
            q_next_sample - self._get_alpha() * log_prob_a_next
        )
        td_bound = 3 * q_std
        difference = torch.clamp(target_q_sample - q, -td_bound, td_bound)
        target_q_bound = q + difference
        return target_q.detach(), target_q_bound.detach()
    
    def _compute_target_qc(self, c, done, qc_next):
        target_qc = c + (1 - done) * self.gamma * qc_next
        return target_qc.detach()

    def update_policy_grad(self, data):
        """Compute the combined policy gradient and write it onto the policy.

        Three separate backward passes produce flat gradient vectors for the
        performance objective (``-min(q1, q2)``), the safety objective (cost
        critic value) and the entropy term (``alpha * log_prob``). The first
        two are merged by ``self.hcrl.update_parameters``, the entropy vector
        is added, and the total is applied to the policy parameters with
        ``apply_grad_vector_to_params``.

        :param data: batch dict; reads ``obs`` and ``new_act``. ``new_act``
            must still be attached to the policy's graph.
        :return: detached entropy estimate ``-mean(log_prob)`` from the third
            pass.
        """
        from gops.utils.lagrangian import apply_grad_vector_to_params

        obs, new_act = data["obs"], data["new_act"]
        policy = self.networks.policy
        optimizer = self.networks.policy_optimizer

        def flat_policy_grad():
            # parameters_to_vector concatenates into a new tensor, so the
            # result survives the following zero_grad()/backward() calls
            # without an extra copy.
            return torch.nn.utils.parameters_to_vector(
                [p.grad for p in policy.parameters()]
            )

        def resample():
            # Fresh forward pass: the previous backward freed the old graph.
            fresh_logits = policy(obs)
            return self.networks.create_action_distributions(fresh_logits).rsample()

        # 1) Performance gradient, using the action already sampled upstream.
        optimizer.zero_grad()
        q1, _, _ = self._q_evaluate(obs, new_act, self.networks.q1)
        q2, _, _ = self._q_evaluate(obs, new_act, self.networks.q2)
        (-torch.min(q1, q2)).mean().backward()
        policy_performance_grad = flat_policy_grad()

        # 2) Safety gradient through the cost critic (only its mean is used).
        optimizer.zero_grad()
        new_act, _ = resample()
        qc, _, _ = self._q_evaluate(obs, new_act, self.networks.qc)
        qc.mean().backward()
        policy_safety_grad = flat_policy_grad()

        # 3) Entropy gradient.
        optimizer.zero_grad()
        _, new_log_prob = resample()
        entropy = -new_log_prob.detach().mean()
        (self._get_alpha() * new_log_prob).mean().backward()
        policy_entropy_grad = flat_policy_grad()

        combine_grad = self.hcrl.update_parameters(
            policy_performance_grad, policy_safety_grad
        )
        total_policy_grad = combine_grad + policy_entropy_grad
        apply_grad_vector_to_params(policy.parameters(), total_policy_grad)

        return entropy

    def _compute_loss_alpha(self, data: DataDict):
        new_log_prob = data["new_log_prob"]
        loss_alpha = (
            -self.networks.log_alpha
            * (new_log_prob.detach() + self.target_entropy).mean()
        )
        return loss_alpha

    def _update(self, iteration: int):
        self.networks.q1_optimizer.step()
        self.networks.q2_optimizer.step()
        self.networks.qc_optimizer.step()
        self.hcrl.lambda_optimizer.step()
        self.hcrl._lagrangian_multiplier.data.clamp_(
            0.0, self.hcrl.lagrangian_upper_bound
        )

        if iteration % self.delay_update == 0:
            self.networks.policy_optimizer.step()

            if self.auto_alpha:
                self.networks.alpha_optimizer.step()

            with torch.no_grad():
                polyak = 1 - self.tau
                for p, p_targ in zip(
                    self.networks.q1.parameters(), self.networks.q1_target.parameters()
                ):
                    p_targ.data.mul_(polyak)
                    p_targ.data.add_((1 - polyak) * p.data)
                for p, p_targ in zip(
                    self.networks.q2.parameters(), self.networks.q2_target.parameters()
                ):
                    p_targ.data.mul_(polyak)
                    p_targ.data.add_((1 - polyak) * p.data)
                for p, p_targ in zip(
                    self.networks.qc.parameters(), self.networks.qc_target.parameters()
                ):
                    p_targ.data.mul_(polyak)
                    p_targ.data.add_((1 - polyak) * p.data)

                for p, p_targ in zip(
                    self.networks.policy.parameters(),
                    self.networks.policy_target.parameters(),
                ):
                    p_targ.data.mul_(polyak)
                    p_targ.data.add_((1 - polyak) * p.data)
