import torch
import torch.nn as nn
from torch.nn.functional import softplus
import exp_utils as PQ


class FLAGS(PQ.BaseFLAGS):
    # Coefficients for ``Barrier.forward``; subclasses ``exp_utils.BaseFLAGS``
    # (behavior of that base class is not visible here).
    # Multiplier applied to the softplus (learned, ``net``-based) term.
    ell_coef = 1.
    # Multiplier applied to the ``env_barrier_fn`` term.
    barrier_coef = 1


class Barrier(nn.Module):
    """Barrier function combining a learned network term with an environment barrier.

    ``forward(states)`` returns::

        softplus(net(states) - net(s0[None])) * FLAGS.ell_coef
            + env_barrier_fn(states) * FLAGS.barrier_coef - 1

    Because the network output is shifted by its value at the reference state
    ``s0``, the learned term equals ``softplus(0) * ell_coef = ln(2) * ell_coef``
    at ``s0``. Coefficients are read from the class attribute ``FLAGS`` so that a
    subclass may override them.
    """

    FLAGS = FLAGS

    def __init__(self, net, env_barrier_fn, s0):
        """
        Args:
            net: callable (typically an ``nn.Module``) applied to ``states`` and to
                ``s0[None]``; its two outputs are subtracted, so they must broadcast.
            env_barrier_fn: callable applied to ``states``; its output is added to
                the learned term, so it must broadcast with ``net(states)``.
            s0: reference state. ``s0[None]`` adds a leading axis before calling
                ``net`` -- presumably ``s0`` is a single unbatched state; TODO confirm.
        """
        super().__init__()
        self.net = net
        self.env_barrier_fn = env_barrier_fn
        # NOTE(review): if ``s0`` is a plain tensor it is stored as an ordinary
        # attribute (not a registered buffer), so ``.to(device)`` and
        # ``state_dict`` do not include it.
        self.s0 = s0
        self.ell = softplus

    def forward(self, states: torch.Tensor) -> torch.Tensor:
        """Evaluate the barrier on ``states`` (see class docstring for the formula)."""
        # Use the class attribute rather than the module-global so subclass
        # overrides of ``FLAGS`` take effect.
        flags = self.FLAGS
        learned = self.ell(self.net(states) - self.net(self.s0[None])) * flags.ell_coef
        env = self.env_barrier_fn(states) * flags.barrier_coef
        return learned + env - 1


class HandCraftFLAGS(PQ.BaseFLAGS):
    # Half-widths of the symmetric intervals ``[-max, max]`` that
    # ``HandCraftBarrierSwing.forward`` passes to ``interval_barrier`` for the
    # first four state components (units not stated here).
    max_pos = 0.5  # bound for states[..., 0]
    max_ang = 0.7  # bound for states[..., 1]
    max_vel = 2  # bound for states[..., 2]
    max_ang_vel = 2  # bound for states[..., 3]


class HandCraftBarrierSwing(nn.Module):
    """Hand-designed barrier over the first four state components.

    Components ``states[..., 0]`` .. ``states[..., 3]`` (position, angle,
    velocity, angular velocity) are each passed to ``interval_barrier`` with the
    symmetric interval ``[-limit, limit]`` taken from ``FLAGS``. The result is
    the elementwise maximum of those four values, minus 1.
    """

    FLAGS = HandCraftFLAGS

    def forward(self, states: torch.Tensor) -> torch.Tensor:
        # Imported inside forward -- presumably to avoid an import cycle with
        # ``safe.envs``; TODO confirm.
        from safe.envs.safe_env_spec import interval_barrier

        cfg = self.FLAGS
        # Half-widths ordered to match state indices 0..3.
        limits = (cfg.max_pos, cfg.max_ang, cfg.max_vel, cfg.max_ang_vel)
        per_dim = [
            interval_barrier(states[..., idx], -limit, limit)
            for idx, limit in enumerate(limits)
        ]
        return torch.stack(per_dim, dim=0).max(dim=0).values - 1
