import torch
import numpy as np
import torch.nn as nn


class PolicyPushSafeOptimizer(nn.Module):
    """Update ``policy`` with Adam so as to minimize the mean of ``U(s, policy(s))``.

    ``U`` is called as ``U(s, policy(s))`` and its result must support
    ``.mean()``; ``policy`` must expose ``.parameters()``. Because this class
    is an ``nn.Module``, ``U`` and ``policy`` are registered as submodules
    when they are themselves ``nn.Module`` instances.

    Args:
        U: Callable scoring a batch of inputs together with the policy's
            output for them; the batch mean of its result is the loss.
        policy: Module whose parameters are optimized.
        lr: Learning rate for the internal Adam optimizer. Defaults to
            ``3e-4``.
    """

    def __init__(self, U, policy: nn.Module, *, lr: float = 3e-4) -> None:
        super().__init__()
        self.U = U
        self.policy = policy
        # Only the policy's parameters are handed to the optimizer; nothing
        # belonging to U is ever stepped (or zeroed) by this class.
        self.opt = torch.optim.Adam(policy.parameters(), lr=lr)

    # enable_grad makes the update work even when the caller is inside a
    # torch.no_grad() context.
    @torch.enable_grad()
    def step(self, s: torch.Tensor) -> None:
        """Take one Adam step on the policy for the batch ``s``.

        Does nothing when ``s`` has length zero along its first dimension.

        Args:
            s: Batch of inputs fed to both ``policy`` and ``U``. It is
                detached first, so no gradient flows back into whatever
                produced it.
        """
        s = s.detach()
        # Skip empty batches: there is nothing to average over.
        if len(s) == 0:
            return
        loss = self.U(s, self.policy(s)).mean()
        self.opt.zero_grad()
        # NOTE(review): backward() also accumulates gradients into any other
        # leaf tensors in the graph that require grad (e.g. parameters of U,
        # if it has any); self.opt.zero_grad() does not clear those — confirm
        # that U's own training loop zeroes its gradients before use.
        loss.backward()
        self.opt.step()

