import torch.nn as nn
from torch.nn.functional import softplus
import exp_utils as PQ


class FLAGS(PQ.BaseFLAGS):
    """Weights for the two terms summed by the Lyapunov candidate."""

    # Scale on the softplus-shaped learned term.
    ell_coef = 1.0
    # Scale on the barrier term.
    # NOTE(review): this default is an int, unlike ell_coef; if BaseFLAGS
    # infers command-line types from defaults, fractional values may be
    # rejected — confirm against exp_utils.
    barrier_coef = 1


class Lyapunov(nn.Module):
    """Lyapunov candidate: a learned network term plus a barrier term.

    For a batch of states ``s`` the value computed by :meth:`forward` is::

        ell_coef * softplus(net(s) - net(s0)) + barrier_coef * barrier_fn(s)

    Because ``softplus`` is positive, the learned term is positive for every
    state. The reference state ``s0`` only sets the offset that the network
    output is measured against.

    Args:
        net: module applied to the states and to the reference state ``s0``.
            Its output must broadcast against ``barrier_fn(states)``
            (TODO confirm the expected output shape).
        barrier_fn: callable applied to ``states``. Its result is scaled by
            ``FLAGS.barrier_coef`` and added to the learned term.
        s0: single (unbatched) reference state. A leading axis of size 1 is
            added before it is passed through ``net``.
    """
    # Exposed on the class so subclasses/callers can swap in other flags;
    # forward() reads coefficients through self.FLAGS to honor that.
    FLAGS = FLAGS

    def __init__(self, net, barrier_fn, s0):
        super().__init__()
        self.net = net
        self.barrier_fn = barrier_fn
        # NOTE(review): s0 is a plain attribute, not a registered buffer, so
        # .to(device) / state_dict() do not track it — confirm intended.
        self.s0 = s0
        self.ell = softplus

    def forward(self, states):
        """Evaluate the Lyapunov candidate on a batch of states.

        Args:
            states: batch of states accepted by both ``net`` and ``barrier_fn``.

        Returns:
            ``ell_coef * softplus(net(states) - net(s0)) + barrier_coef * barrier_fn(states)``.
        """
        flags = self.FLAGS
        # net(s0) is recomputed on every call because net may change between
        # calls (e.g. while it is being trained). s0[None] adds a batch axis.
        shifted = self.net(states) - self.net(self.s0[None])
        return self.ell(shifted) * flags.ell_coef \
               + self.barrier_fn(states) * flags.barrier_coef
