import scipy.optimize as opt
import scipy.special
import numpy as np
import copy
import torch.nn.functional as F
import torch.utils
import safety_gymnasium as safety_gym
from safety_gymnasium.wrappers import SafeNormalizeObservation, SafeNormalizeReward
import torch
from torch import nn
import cv2
from tqdm import trange
from wrappers import (
    SafeClipAction,
    ObservationStack,
    SafeTanHAction,
    SafeRecordEpisodeStatistics,
)
from torch.utils.tensorboard import SummaryWriter
import argparse
import time
from utils import retrace
from off_policy_variational import off_policy_surrogate


def difference_of_probs(logp1, logp2):
    """Return ``exp(logp1) - exp(logp2)`` using a shared max shift.

    Both inputs are shifted by their joint maximum ``c`` before being
    exponentiated, and the difference is rescaled by ``exp(c)`` afterwards.

    Args:
        logp1: Log-probabilities; a scalar or an array-like of any shape.
        logp2: Log-probabilities broadcastable against ``logp1``.

    Returns:
        The element-wise difference of the two probabilities as a NumPy
        value with the broadcast shape of the inputs.
    """
    logp1 = np.asarray(logp1)
    logp2 = np.asarray(logp2)
    # np.max reduces over every axis, so scalars and N-d arrays are handled;
    # the builtin max() only works on flat, one-dimensional sequences.
    c = max(np.max(logp1), np.max(logp2))
    return (np.exp(logp1 - c) - np.exp(logp2 - c)) * np.exp(c)


# Function to update the target network using Polyak averaging
def polyak_average(model, target_model, beta=0.995):
    """Update ``target_model`` in place by Polyak (exponential) averaging.

    Every target parameter becomes
    ``beta * target_param + (1 - beta) * model_param``. Parameters are paired
    positionally through ``parameters()``; module buffers are not touched.

    Args:
        model (nn.Module): The current model (left unchanged).
        target_model (nn.Module): The target model to be updated.
        beta (float): The averaging factor (typically close to 1.0).

    Returns:
        None. ``target_model`` is modified in place.
    """
    # The update is pure bookkeeping and must not be recorded by autograd.
    with torch.no_grad():
        for target_param, model_param in zip(
            target_model.parameters(), model.parameters()
        ):
            target_param.copy_(beta * target_param + (1 - beta) * model_param)


class fit_surrogate:
    """Two-multiplier dual objective over a finite set of sampled actions.

    The instance is callable with ``(eta, lmbda)`` and exposes helpers that
    evaluate the re-weighted log-distribution ``q`` and a few diagnostics.
    All inputs are flattened to 1-D NumPy arrays on construction.

    Args:
        log_pi: Tensor of (possibly unnormalised) log-probabilities.
        qvals: Tensor of reward values for the same samples.
        cvals: Tensor of cost values for the same samples.
        affinity: Positive scalar; its log is the constraint baseline.
        eps: KL budget multiplying ``eta`` in the objective.
        budget: Cost budget; only cost above it is penalised in ``g``.
    """

    def __init__(self, log_pi, qvals, cvals, affinity, eps, budget):
        self.budget = budget
        self.eps = eps
        self.affinity = affinity
        self.cvals = cvals

        def _flatten(tensor):
            return tensor.detach().numpy().reshape(-1)

        self.log_pi = _flatten(log_pi)
        self.q_fn = _flatten(qvals)
        self.c_fn = _flatten(cvals)
        self.pi = np.exp(self.log_pi)
        # Log-likelihood of the constraint: log(affinity) minus excess cost.
        self.g = np.log(self.affinity) - np.maximum(self.c_fn - self.budget, 0)

    def _normalized_log_pi(self):
        """Return ``log_pi`` normalised so its exponentials sum to one."""
        return self.log_pi - scipy.special.logsumexp(self.log_pi)

    def __call__(self, l):
        """Evaluate the dual objective at ``l = (eta, lmbda)``."""
        eta, lmbda = l
        log_q = self.get_q(eta, lmbda)
        log_prior = self._normalized_log_pi()
        q_probs = np.exp(log_q)
        value_term = np.sum(q_probs * (self.q_fn - lmbda * self.g))
        constraint_term = lmbda * np.sum(np.exp(log_prior) * self.g)
        budget_term = eta * self.eps
        kl_term = -eta * np.sum(q_probs * (log_q - log_prior))
        return value_term + constraint_term + budget_term + kl_term

    def get_q(self, eta, lmbda, log_pi=None, q_fn=None, g=None):
        """Return the normalised log-weights ``log q`` for ``(eta, lmbda)``.

        ``log_pi``, ``q_fn`` and ``g`` default to the stored arrays.
        """
        log_pi = self.log_pi if log_pi is None else log_pi
        q_fn = self.q_fn if q_fn is None else q_fn
        g = self.g if g is None else g
        scores = log_pi + (q_fn - lmbda * g) / eta
        return scores - scipy.special.logsumexp(scores)

    def get_kl_div(
        self,
        eta,
        lmbda,
    ):
        """Return ``KL(q || pi)`` with ``pi`` normalised over the samples."""
        log_q = self.get_q(eta, lmbda)
        log_prior = self._normalized_log_pi()
        return np.sum(np.exp(log_q) * (log_q - log_prior))

    def get_constraint_diff(self, eta, lmbda):
        """Return ``E_pi[g] - E_q[g]`` and print the per-sample differences."""
        log_q = self.get_q(eta, lmbda)
        log_prior = self._normalized_log_pi()
        prior_term = np.exp(log_prior) * self.g
        posterior_term = np.exp(log_q) * self.g
        difference = np.sum(prior_term) - np.sum(posterior_term)
        print("yielddiff", prior_term - posterior_term)
        return difference

    def expected_returns(
        self,
        eta,
        lmbda,
    ):
        """Return the expectation of the reward values under ``q``."""
        log_q = self.get_q(eta, lmbda)
        return np.sum(self.q_fn * np.exp(log_q))


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise a layer in place and hand it back.

    The weight receives an orthogonal initialisation scaled by ``std`` and
    the bias is filled with ``bias_const``.

    Args:
        layer: Module exposing ``weight`` and ``bias`` tensors.
        std: Gain passed to the orthogonal initialiser.
        bias_const: Constant written into every bias entry.

    Returns:
        The same ``layer`` object, so the call can be nested in a model
        definition.
    """
    nn.init.orthogonal_(layer.weight, gain=std)
    nn.init.constant_(layer.bias, val=bias_const)
    return layer


def kl_split(d, dold):
    """Split the Gaussian KL between two policies into mean and covariance parts.

    With ``(mu, S)`` taken from ``d`` and ``(mu_old, S_old)`` from ``dold``:

    * mean part:       ``0.5 * (mu - mu_old)^T S_old^{-1} (mu - mu_old)``
    * covariance part: ``0.5 * (tr(S^{-1} S_old) - n + log|S| - log|S_old|)``

    Args:
        d: Distribution exposing ``mean`` and ``covariance_matrix``
            (e.g. ``torch.distributions.MultivariateNormal``) for the new policy.
        dold: Same kind of distribution for the old policy.

    Returns:
        Tuple ``(mu_kl, std_kl)`` of scalar tensors, each averaged over the
        batch.
    """
    mean_new = d.mean
    cov_new = d.covariance_matrix
    mean_old = dold.mean
    cov_old = dold.covariance_matrix

    diff_mu = mean_new - mean_old
    # Solve S_old x = diff instead of forming an explicit inverse.
    mu_kl = 0.5 * torch.sum(
        diff_mu * torch.linalg.lstsq(cov_old, diff_mu).solution, -1
    )

    # tr(S^{-1} S_old), taken over the last two dims for any batch shape.
    sigma_div = (
        torch.linalg.lstsq(cov_new, cov_old)
        .solution.diagonal(dim1=-2, dim2=-1)
        .sum(-1)
    )
    # Work with log|det| directly: the previous log(det / (det_old + 1e-9))
    # was dominated by the 1e-9 guard as soon as the covariances got small.
    log_det_ratio = (
        torch.linalg.slogdet(cov_new).logabsdet
        - torch.linalg.slogdet(cov_old).logabsdet
    )
    std_kl = 0.5 * (sigma_div - cov_new.shape[-1] + log_det_ratio)
    return mu_kl.mean(), std_kl.mean()

class tmpfun(nn.Module):
    """Gaussian policy with a full covariance parameterised by a Cholesky factor.

    A shared three-layer MLP trunk feeds two linear heads: one for the action
    mean (squashed into ``[action_low, action_high]``) and one for the
    ``actions * (actions + 1) / 2`` lower-triangular Cholesky entries.

    Args:
        size: Number of input (observation) features.
        actions: Dimensionality of the action vector.
    """

    def __init__(self, size, actions):
        super().__init__()
        self.size = size
        self.model = nn.Sequential(
            layer_init(nn.Linear(size, 256)),
            nn.LeakyReLU(),
            layer_init(nn.Linear(256, 256)),
            nn.LeakyReLU(),
            layer_init(nn.Linear(256, 256)),
            nn.LeakyReLU(),
        )
        self.mean = nn.Sequential(
            layer_init(nn.Linear(256, actions)),
        )
        # One output per lower-triangular entry; small gain keeps the initial
        # covariance outputs close to zero before the softplus.
        self.std = nn.Sequential(
            layer_init(nn.Linear(256, (actions * (actions + 1)) // 2), 0.1),
        )

        # Raw (pre-softplus) values of two dual multipliers, see dual_mult().
        self.dual_param = nn.Parameter(torch.tensor([0.0, 10.0]))
        self.action_low = -1
        self.action_high = 1
        self.temperature = 1
        self.actions = actions

    def dual_mult(self):
        """Return the two non-negative dual multipliers (softplus of the raw values)."""
        return F.softplus(self.dual_param)

    def forward(self, x, act):
        """Return ``(log_prob, entropy)`` of actions ``act`` under the policy at ``x``."""
        ds = self.get_distribution(x)
        return ds.log_prob(act[None]).squeeze(), ds.entropy()

    def get_distribution(self, x):
        """Build the action distribution for a batch of observations.

        Args:
            x: Tensor whose first dimension is the batch and whose last
                dimension has ``size`` features.

        Returns:
            ``torch.distributions.MultivariateNormal`` with one component per
            row of ``x``.
        """
        features = self.model(x)
        mu = self.mean(features) / self.temperature
        # Squash the clamped logits into (0, 1), then rescale to the action range.
        mu = torch.sigmoid(mu.clamp(-5, 5))
        mu = self.action_low + (self.action_high - self.action_low) * mu
        tril_entries = self.std(features)

        # Allocate on the same device/dtype as the network output so the
        # module also works after .to(device) or .double().
        tril_indices = torch.tril_indices(
            row=self.actions,
            col=self.actions,
            offset=0,
            device=tril_entries.device,
        )
        cholesky = tril_entries.new_zeros((len(x), self.actions, self.actions))
        cholesky[:, tril_indices[0], tril_indices[1]] = tril_entries
        diag = torch.arange(self.actions, device=tril_entries.device)
        # The diagonal of a Cholesky factor must be strictly positive.
        cholesky[:, diag, diag] = F.softplus(cholesky[:, diag, diag]) + 1e-4
        # NOTE(review): the factor is clamped with the *action* bounds, which
        # caps every entry at 1 in magnitude; confirm this coupling is intended.
        cholesky = cholesky.clamp(self.action_low, self.action_high)
        return torch.distributions.MultivariateNormal(mu, scale_tril=cholesky)

    def sample(self, x):
        """Draw one action per observation in ``x``."""
        return self.get_distribution(x).sample()

    def get_action(self, x):
        """Sample actions and return ``(action, log_prob, entropy)``."""
        ds = self.get_distribution(x)
        act = ds.sample()
        return act, ds.log_prob(act[None]).squeeze(), ds.entropy()


class qfun(nn.Module):
    """State-action critic: an MLP mapping ``(observation, action)`` to a scalar.

    Args:
        size: Number of observation features.
        n_acts: Number of action features.
        final_bias: Accepted for interface compatibility; not used.
    """

    def __init__(self, size, n_acts, final_bias=0):
        super().__init__()
        self.size = size
        hidden = 256
        layers = []
        in_features = size + n_acts
        # Three hidden Linear + LeakyReLU pairs followed by a scalar head.
        for _ in range(3):
            layers.append(nn.Linear(in_features, hidden))
            layers.append(nn.LeakyReLU())
            in_features = hidden
        layers.append(nn.Linear(hidden, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x, act):
        """Return the critic value for each ``(x, act)`` pair, shape ``(..., 1)``."""
        joint_input = torch.cat((x, act), dim=-1)
        return self.model(joint_input)


def mk_environment(envs, env_name, n_stack):
    """Create a synchronous vector environment with ``envs`` copies of ``env_name``.

    Every copy is created with ``render_mode="rgb_array"`` and wrapped with
    ``SafeRecordEpisodeStatistics`` and then ``SafeAutoResetWrapper``. Action
    clipping, tanh squashing, observation/reward normalisation and
    observation stacking are deliberately not applied.

    Args:
        envs: Number of environment copies.
        env_name: Environment id passed to ``safety_gym.make``.
        n_stack: Accepted for interface compatibility; currently unused
            because observation stacking is disabled.

    Returns:
        A ``SafetySyncVectorEnv`` over the wrapped copies.
    """

    def make_single_env():
        base_env = safety_gym.make(env_name, render_mode="rgb_array")
        with_stats = SafeRecordEpisodeStatistics(base_env)
        return safety_gym.wrappers.SafeAutoResetWrapper(with_stats)

    factories = [make_single_env] * envs
    return safety_gym.vector.SafetySyncVectorEnv(factories)


def eval_env(
    logger,
    policy,
    global_step,
    env_name,
    n_rounds=2,
    n_envs=50,
    n_steps=1000,
):
    """Roll out ``policy`` in fresh environments and log episode statistics.

    For each round a new vector environment is created, stepped ``n_steps``
    times with actions sampled from the policy, and closed again. Returns and
    costs of every finished episode are collected and their mean and maximum
    are written to ``logger`` under ``eval/...``.

    Args:
        logger: Object with ``add_scalar(tag, value, step)``
            (e.g. a ``SummaryWriter``).
        policy: Policy exposing ``get_action(obs) -> (action, log_prob, entropy)``.
        global_step: Step index attached to the logged scalars.
        env_name: Environment id forwarded to ``mk_environment``.
        n_rounds: Number of independent evaluation rounds.
        n_envs: Number of parallel environments per round.
        n_steps: Number of vector steps per round.

    Returns:
        None. Nothing is logged if no episode finished.
    """
    returns = []
    constraints = []
    for _ in range(n_rounds):
        env = mk_environment(n_envs, env_name, 1)
        try:
            next_obs, _ = env.reset()
            next_obs = torch.Tensor(next_obs)
            for _step in range(n_steps):
                # ALGO LOGIC: action logic
                with torch.no_grad():
                    action, _, _ = policy.get_action(next_obs)
                # TRY NOT TO MODIFY: execute the game and log data.
                next_obs, reward, cost, terminations, truncations, infos = env.step(
                    action.cpu().numpy()
                )
                next_obs = torch.Tensor(next_obs)
                if "final_info" in infos:
                    print("adding info elements")
                    for info in infos["final_info"]:
                        # Sub-environments that did not finish on this step
                        # carry no final info; skip them instead of crashing.
                        if info is None:
                            continue
                        episode = info["final_info"]["episode"]
                        returns.append(episode["r"])
                        constraints.append(episode["c"])
        finally:
            # Release the simulators even if a step raises.
            env.close()

    if not returns:
        # np.max on an empty list raises; there is nothing meaningful to log.
        print("eval_env: no episode finished, skipping evaluation logging")
        return
    logger.add_scalar("eval/episodic_return", np.mean(returns), global_step)
    logger.add_scalar("eval/episodic_cost", np.mean(constraints), global_step)
    logger.add_scalar("eval/episodic_cost_max", np.max(constraints), global_step)
    logger.add_scalar("eval/episodic_return_max", np.max(returns), global_step)


if __name__ == "__main__":
    # Training script. Each epoch:
    #   1. collects `timesteps` steps from `envs` parallel environments,
    #   2. fits the reward critic and the cost critic on retrace targets,
    #   3. solves a two-variable dual problem with SciPy to obtain weights over
    #      sampled actions,
    #   4. regresses the policy onto those weights under separate KL trust
    #      regions for the mean and the covariance.
    parser = argparse.ArgumentParser()
    parser.add_argument("--render", action="store_true")
    parser.add_argument("--timesteps", type=int, default=128)
    parser.add_argument("--envs", type=int, default=4)
    parser.add_argument("--env_name", type=str, default="SafetyCarGoal1-v0")
    parser.add_argument("--affinity", type=int, default=1)
    parser.add_argument("--eps", type=float, default=0.2)
    parser.add_argument("--batchsize", type=int, default=512)
    parser.add_argument("--training_iterations", type=int, default=32)
    parser.add_argument("--buffer_size", type=int, default=1)
    parser.add_argument("--cost_budget", type=int, default=0)
    parser.add_argument("--n_epoch_steps", type=int, default=100_000)
    parser.add_argument("--n_supersample", type=int, default=32)
    parser.add_argument("--n_training_steps", type=int, default=32)
    parser.add_argument("--kappa", type=float, default=1.0)
    parser.add_argument("--experiment_name", type=str, default="experiment")

    args = parser.parse_args()

    # Throw-away environment, used only to read the observation/action sizes.
    e = safety_gym.make(args.env_name, render_mode="rgb_array")
    size = e.observation_space.shape[0]
    actions_space = e.action_space.shape[0]
    e.close()

    render = args.render
    timesteps = args.timesteps
    envs = args.envs
    buffer_size = args.buffer_size
    n_supersample = args.n_supersample
    affinity = args.affinity
    eps = args.eps
    eps_mu = 0.1  # KL budget for the policy mean
    eps_std = 1e-4  # KL budget for the policy covariance
    decay = 0.99  # discount used for the reward critic targets
    n_training_steps = args.n_training_steps
    n_epoch_steps = args.n_epoch_steps
    batchsize = args.batchsize
    budget = args.cost_budget
    kappa_max = args.kappa
    kappa_min = kappa_max
    n_stack = 1
    device = "cpu"

    # Online networks and their slowly-moving (Polyak-averaged) copies.
    policy = tmpfun(size * n_stack, actions_space)
    policy_old = copy.deepcopy(policy)
    qfn = qfun(size * n_stack, actions_space)
    q_old = qfun(size * n_stack, actions_space)
    q_old.load_state_dict(qfn.state_dict())

    cfn = qfun(size * n_stack, actions_space)
    c_old = qfun(size * n_stack, actions_space)
    c_old.load_state_dict(cfn.state_dict())

    # The dual multipliers get their own, much larger learning rate.
    policy_opt = torch.optim.AdamW(
        [
            {"params": policy.model.parameters()},
            {"params": policy.std.parameters()},
            {"params": policy.mean.parameters()},
            {"params": policy.dual_param, "lr": 0.1},
        ],
        3e-4,
    )
    q_opt = torch.optim.AdamW(qfn.parameters(), 1e-3)
    c_opt = torch.optim.AdamW(cfn.parameters(), 1e-3)

    NAME = f"runs/{args.experiment_name}-{args.env_name}-{int(time.time())}"
    logger = SummaryWriter(
        log_dir=NAME,
    )
    logger.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s"
        % (
            "\n".join(
                [
                    f"|{key}|{value}|"
                    for key, value in {
                        "render": render,
                        "timesteps": timesteps,
                        "num_envs": envs,
                        "buffer size": buffer_size,
                        "action_space": actions_space,
                        "affinity": affinity,
                        "eps": eps,
                        "decay": decay,
                        "budget": budget,
                        "batchsize": batchsize,
                    }.items()
                ]
            )
        ),
    )

    # Ring buffer of the last `buffer_size` rollouts, indexed
    # (timestep, env, slot[, feature]).
    dataset_o = torch.zeros(timesteps, envs, buffer_size, size * n_stack)
    dataset_next_obs = torch.zeros(envs, buffer_size, size * n_stack)
    dataset_next_dones = torch.zeros(envs, buffer_size)
    dataset_a = torch.zeros(timesteps, envs, buffer_size, actions_space)
    dataset_r = torch.zeros(timesteps, envs, buffer_size)
    dataset_logp = torch.zeros(timesteps, envs, buffer_size)
    dataset_c = torch.zeros(timesteps, envs, buffer_size)
    dataset_rets = torch.zeros(timesteps, envs, buffer_size)
    dataset_constr = torch.zeros(timesteps, envs, buffer_size)
    dataset_done = torch.zeros(timesteps, envs, buffer_size)

    env = mk_environment(
        envs,
        args.env_name,
        n_stack,
    )
    next_obs, _ = env.reset()
    next_obs = torch.Tensor(next_obs).to(device)
    next_done = torch.zeros(envs).to(device)
    # Frame ring buffer for video export. Allocate it only when rendering and
    # directly as uint8: the float64 intermediate would need several GB.
    framelist = (
        np.zeros((1000, envs, 256, 256, 3), dtype=np.uint8) if render else None
    )
    frame_n = 0
    global_step = 0
    print("Starting optimization")
    for epoch in range(25_000):
        # kappa_min == kappa_max above, so this schedule is currently constant.
        kappa = min(epoch, 1000) / 1000 * (kappa_min - kappa_max) + kappa_max
        print("kappa current", kappa)
        if epoch % 50 == 5:
            eval_env(logger, policy, global_step, args.env_name)
            # saving all models:
            print("Saving models")
            torch.save(
                {
                    "policy": policy,
                    "cfn": cfn,
                    "qfn": qfn,
                },
                NAME + "/" + "model.pth",
            )

        # ---- Rollout collection -------------------------------------------
        rewards = torch.zeros(timesteps, envs)
        costs = torch.zeros(timesteps, envs)
        actions = torch.zeros(timesteps, envs, actions_space)
        obs = torch.zeros(timesteps, envs, size)
        returns = torch.zeros(timesteps, envs)
        logprobs = torch.zeros(timesteps, envs)
        constr = torch.zeros(timesteps, envs)
        dones = torch.zeros(timesteps, envs)

        finish_data = []

        for step in range(0, timesteps):
            frame_n += 1
            global_step += envs
            obs[step] = next_obs
            dones[step] = next_done

            # ALGO LOGIC: action logic
            with torch.no_grad():
                action, logprob, _ = policy.get_action(next_obs)
            actions[step] = action
            logprobs[step] = logprob

            # TRY NOT TO MODIFY: execute the game and log data.
            next_obs, reward, cost, terminations, truncations, infos = env.step(
                action.cpu().numpy()
            )
            next_done = np.logical_or(terminations, truncations)
            rewards[step] = torch.tensor(reward).to(device).view(-1)
            costs[step] = torch.tensor(cost)
            next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(
                next_done
            ).to(device)
            if render:
                framelist[frame_n % 1000] = env.call("render")
            if "final_info" in infos:
                for its, info in enumerate(infos["final_info"]):
                    # Sub-environments that did not finish on this step carry
                    # no final info; skip them instead of crashing.
                    if info is None:
                        continue
                    finish_data.append(info["final_info"]["episode"])
                    print(
                        f"global_step={global_step}, episodic_return={info['final_info']['episode']['r']}, episodic_costs={info['final_info']['episode']['c']}"
                    )
                    if render:
                        # OpenCV expects the frame size as (width, height) and
                        # frames in BGR channel order.
                        frame_h, frame_w = framelist[0][its].shape[:2]
                        w = cv2.VideoWriter(
                            f"outputvideo{its}.avi",
                            cv2.VideoWriter_fourcc(*"MJPG"),
                            30,
                            (frame_w, frame_h),
                        )
                        # Walk the ring buffer oldest-to-newest so the video
                        # plays in chronological order.
                        for offset in range(1, 1001):
                            frame = framelist[(frame_n + offset) % 1000, its]
                            w.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
                        w.release()
        if len(finish_data) > 0:
            finish_data = {k: [dic[k] for dic in finish_data] for k in finish_data[0]}
            logger.add_scalar(
                "charts/episodic_return", np.mean(finish_data["r"]), global_step
            )
            logger.add_scalar(
                "charts/episodic_cost", np.mean(finish_data["c"]), global_step
            )
            logger.add_scalar(
                "charts/episodic_length", np.mean(finish_data["l"]), global_step
            )

        # Diagnostic only: build the distribution once and without autograd.
        with torch.no_grad():
            last_dist = policy.get_distribution(obs[-1])
            print("action mean+std", last_dist.mean, last_dist.covariance_matrix)

        # ---- Store the rollout in the ring buffer -------------------------
        # next_obs / next_done are already float tensors at this point.
        slot = epoch % buffer_size
        dataset_next_obs[:, slot] = next_obs
        dataset_next_dones[:, slot] = next_done

        print(
            epoch,
            "rewards",
            rewards.sum(0).mean(),
        )
        print("costs", costs.sum(0))
        dataset_o[:, :, slot] = obs
        dataset_a[:, :, slot] = actions
        dataset_r[:, :, slot] = rewards
        dataset_c[:, :, slot] = costs
        dataset_rets[:, :, slot] = returns
        dataset_constr[:, :, slot] = constr
        dataset_logp[:, :, slot] = logprobs
        dataset_done[:, :, slot] = dones

        # Flat views over the filled part of the buffer (the slice is clipped
        # to buffer_size once the buffer has wrapped around).
        obs_flat = dataset_o[:, :, : epoch + 1].reshape(-1, size)
        acts_flat = dataset_a[:, :, : epoch + 1].reshape(-1, actions_space)
        logp_flat = dataset_logp[:, :, : epoch + 1].reshape(-1)
        next_obs_flat = dataset_next_obs[:, : epoch + 1].reshape(-1, size * n_stack)

        # Warm start for the dual optimiser; reset at the start of every epoch.
        point = np.array([1.0, 1.0])
        for it in trange(args.training_iterations):
            is_last_iteration = it == args.training_iterations - 1

            # ---- Critic targets (refreshed every 100 iterations) ----------
            if it % 100 == 0:
                with torch.no_grad():
                    q_ref = q_old(obs_flat, acts_flat).reshape(timesteps, -1)
                    c_ref = c_old(obs_flat, acts_flat).reshape(timesteps, -1)
                    policy_ref = policy(obs_flat, acts_flat)[0].reshape(timesteps, -1)
                    # Importance ratio current/behaviour policy, truncated at 1.
                    ratio = (
                        (policy_ref - logp_flat.reshape(timesteps, -1))
                        .exp()
                        .clamp(max=1)
                    )
                    # Bootstrap values at the observation following each rollout.
                    act_final, _, _ = policy_old.get_action(next_obs_flat)
                    q_bootstrap = q_old(next_obs_flat, act_final).reshape(-1)
                    c_bootstrap = c_old(next_obs_flat, act_final).reshape(-1)
                    nds = dataset_next_dones[:, : epoch + 1].reshape(-1)
                # `retrace` comes from utils; presumably it returns off-policy
                # corrected targets shaped like its first argument (TODO confirm).
                q_target = retrace(
                    q_ref,
                    dataset_r[:, :, : epoch + 1].reshape(timesteps, -1),
                    ratio,
                    dataset_done[:, :, : epoch + 1].reshape(timesteps, -1),
                    decay,
                    q_bootstrap,
                    nds,
                ).reshape(-1)
                # Costs are accumulated without discounting (decay of 1.0).
                c_target = retrace(
                    c_ref,
                    dataset_c[:, :, : epoch + 1].reshape(timesteps, -1),
                    ratio,
                    dataset_done[:, :, : epoch + 1].reshape(timesteps, -1),
                    1.0,
                    c_bootstrap,
                    nds,
                ).reshape(-1)
            if is_last_iteration:
                print(
                    "q_target",
                    q_target.mean(),
                    q_target.max(),
                    q_target.min(),
                    q_target.std(),
                )
                print(
                    "c_target",
                    c_target.mean(),
                    c_target.max(),
                    c_target.min(),
                    c_target.std(),
                )
                print("c_ref", c_ref.mean(), c_ref.max(), c_ref.min(), c_ref.std())
                print("q_ref", q_ref.mean(), q_ref.max(), q_ref.min(), q_ref.std())

            # ---- Critic regression on shuffled minibatches ----------------
            shuffle = torch.arange(len(obs_flat)).numpy()
            np.random.shuffle(shuffle)
            n_minibatches = min(
                n_epoch_steps, len(shuffle) // min(batchsize, len(obs_flat))
            )
            for s in range(n_minibatches):
                idxs = shuffle[s * batchsize : (s + 1) * batchsize]
                q = qfn(obs_flat[idxs], acts_flat[idxs]).squeeze()
                c = cfn(obs_flat[idxs], acts_flat[idxs]).squeeze()
                v_loss = torch.mean((q - q_target[idxs]) ** 2)
                c_loss = torch.mean((c - c_target[idxs]) ** 2)
                v_loss.backward()
                c_loss.backward()
                torch.nn.utils.clip_grad_norm_(qfn.parameters(), 10)
                torch.nn.utils.clip_grad_norm_(cfn.parameters(), 10)
                q_opt.step()
                c_opt.step()
                q_opt.zero_grad()
                c_opt.zero_grad()
                polyak_average(qfn, q_old)
                polyak_average(cfn, c_old)

            # ---- Non-parametric improvement step --------------------------
            # Training policy: only the first shuffled minibatch is used.
            s = 0
            idxs = shuffle[s * batchsize : (s + 1) * batchsize]
            with torch.no_grad():
                obs_supersample = torch.cat([obs_flat[idxs]] * n_supersample, 0)
                act, _, _ = policy_old.get_action(obs_supersample)
                # `off_policy_surrogate` is project code; it is used below as a
                # callable objective of two dual variables.
                k = off_policy_surrogate(
                    policy_old,
                    q_old,
                    c_old,
                    obs_flat[idxs],
                    affinity,
                    eps=eps,
                    budget=budget,
                    kappa=kappa,
                    act=act,
                )
            res = opt.minimize(
                k,
                point,
                bounds=[(1e-6, 1e7), (1e-6, 1e7)],
                method="SLSQP",
                options={
                    "eps": 1e-5,
                    "ftol": 1e-4,
                },
            )
            if np.any(np.isnan(res.x)):
                print("NAN in nonparametric estimation, trying with Powell")
                point = np.array([10.0, 10.0])
                res = opt.minimize(
                    k,
                    point,
                    bounds=[(1e-3, 1e6), (1e-3, 1e6)],
                    method="Powell",
                )
            # Warm-start the next iteration from the current solution.
            point = res.x

            # ---- Parametric policy fit ------------------------------------
            for _ in range(n_training_steps):
                logp, ent = policy(k.obs_supersample, k.act)
                prob = torch.exp(torch.tensor(k.get_q(*res.x))).squeeze()
                # Weighted maximum likelihood on the sampled actions.
                pg = -torch.mean(torch.sum(logp.reshape(k.N_samples, -1) * prob, 0))
                dold = policy_old.get_distribution(obs_flat[idxs])
                d = policy.get_distribution(obs_flat[idxs])

                kl_mean, kl_std = kl_split(d, dold)
                kl_stacked = (
                    torch.tensor([eps_mu, eps_std])
                    - torch.stack([kl_mean, kl_std], 0).squeeze()
                )
                # Lagrangian with stop-gradients: the first term only updates
                # the multipliers, the second only the policy parameters.
                trust_region = torch.sum(
                    policy.dual_mult().squeeze() * kl_stacked.detach()
                    - policy.dual_mult().squeeze().detach() * kl_stacked
                )
                loss = pg + trust_region
                loss.backward()
                torch.nn.utils.clip_grad_norm_(policy.parameters(), 0.1)
                policy_opt.step()
                policy_opt.zero_grad()
            polyak_average(policy, policy_old)

            logger.add_scalar("losses/loss", loss, global_step)
            logger.add_scalar("losses/v_loss", v_loss, global_step)
            logger.add_scalar("losses/c_loss", c_loss, global_step)
            logger.add_scalar("losses/pg", pg, global_step)
            logger.add_scalar(
                "losses/trust_region", trust_region.mean(), global_step
            )
            logger.add_scalar("losses/kl_div_err_mu", kl_mean.mean(), global_step)
            logger.add_scalar("losses/kl_div_err_std", kl_std.mean(), global_step)
            logger.add_scalar(
                "losses/dual_mult_mu", policy.dual_mult()[0], global_step
            )
            logger.add_scalar(
                "losses/dual_mult_std", policy.dual_mult()[1], global_step
            )
            logger.add_scalar(
                "losses/constraint_advantage", k.get_x(*res.x), global_step
            )
            if is_last_iteration:
                print(it, res)
                print("kl_div surrogate", k.get_kl_div(*res.x))
                print("constraint advantage", k.get_constraint_diff(*res.x))
                print("expected returns", k.expected_returns(*res.x))
                print("probabilities", np.exp(k.get_q(*res.x)))
                print(k.g.max(), k.g.mean(), k.g.min())
                print(
                    v_loss,
                    c_loss,
                    pg,
                    "kl_div",
                    torch.stack([kl_mean, kl_std], 0),
                    "dual multiplier",
                    policy.dual_mult(),
                    "cost",
                    c.mean(),
                    "±",
                    c.std(),
                    "value",
                    q.mean(),
                    "±",
                    q.std(),
                    "absolute cost error",
                    torch.mean((c - c_target[idxs]).abs()),
                    "cmax",
                    c.max(),
                )

    # Flush pending TensorBoard events and release the simulators.
    env.close()
    logger.close()
