import torch
import torch.nn as nn
import exp_utils as PQ


class SSampleOptimizer(nn.Module):
    """Track a hard objective over a fixed-size batch of random samples.

    Holds ``n_samples`` standard-normal points of dimension ``dim_s``. On each
    :meth:`evaluate` call it redraws them, evaluates the objective on the whole
    batch, and logs the maximum value.

    Args:
        dim_s: Dimensionality of each sample.
        obj_eval: Callable applied to the sample tensor; its result must support
            ``result['hard_obj']``. :meth:`evaluate` additionally calls
            ``obj_eval.safe_invariant.L`` on the samples.
        n_samples: Number of samples held and redrawn per evaluation
            (default ``100_000``, the previously hard-coded value).
    """

    def __init__(self, dim_s, obj_eval, n_samples=100_000):
        super().__init__()
        self.obj_eval = obj_eval
        # A non-trainable Parameter is registered on the module, so it follows
        # .to()/.cuda()/.double() and is included in state_dict().
        self.s = nn.Parameter(torch.randn(n_samples, dim_s), requires_grad=False)
        self.dim_s = dim_s
        self.n_samples = n_samples

    def hardD(self, s):
        """Return the ``'hard_obj'`` entry of ``obj_eval(s)``."""
        result = self.obj_eval(s)
        return result['hard_obj']

    @torch.no_grad()
    def evaluate(self, *, step):
        """Redraw the samples and log objective statistics for ``step``.

        Writes the batch maximum of :meth:`hardD` via ``PQ.writer.add_scalar``
        under ``'L/sample_s/hardD'``. Also emits a debug log line that reports
        that maximum and the number of samples with
        ``obj_eval.safe_invariant.L(s) <= 1``.
        """
        # Refill in place instead of set_(randn(...).to(device)). This keeps the
        # parameter's device *and* dtype (e.g. after module.double()) and avoids
        # a CPU allocation plus host-to-device copy on every call.
        s = self.s
        s.normal_()

        hardD_sample_s = self.hardD(s).max().item()
        # NOTE(review): presumably {L <= 1} is the invariant region; confirm threshold.
        inside = (self.obj_eval.safe_invariant.L(s) <= 1).sum().item()
        PQ.writer.add_scalar('L/sample_s/hardD', hardD_sample_s, global_step=step)
        PQ.log.debug(f"[S sampler]: hardD = {hardD_sample_s:.6f}, inside = {inside}")
