import torch
import torch.nn as nn

import exp_utils as PQ


class SGradOptimizer(nn.Module):
    """Search a batch of candidate states for one maximising ``obj_eval``'s ``'hard_obj'``.

    The candidates are stored in a single trainable parameter ``self.s`` with
    one row per candidate, and are updated with Adam (``lr=1e-3``) on the
    negated mean of ``'hard_obj'``.

    Args:
        dim_state: Size of the last dimension of the candidate tensor.
        obj_eval: Callable taking the candidate tensor and returning a mapping.
            ``step`` reads the key ``'hard_obj'``; ``evaluate`` additionally
            reads ``'constraint'`` and ``'obj'``.
        normalizer: Object exposing ``mean`` and ``std``; used by ``reinit``
            to scale and shift freshly sampled candidates.
        n_samples: Number of candidate rows optimised in parallel.
    """

    def __init__(self, dim_state, obj_eval, normalizer, n_samples=10000):
        super().__init__()
        self.obj_eval = obj_eval
        self.s = nn.Parameter(torch.randn(n_samples, dim_state), requires_grad=True)
        # The optimizer holds a reference to the Parameter object itself, so
        # in-place updates to ``self.s`` (see ``reinit``) remain visible to it.
        self.opt = torch.optim.Adam([self.s], lr=1e-3)
        self.normalizer = normalizer

    def step(self):
        """Run one Adam update that increases the mean of ``'hard_obj'``.

        Returns:
            The scalar loss ``(-hard_obj).mean()`` evaluated *before* the
            parameter update.
        """
        result = self.obj_eval(self.s)
        loss = (-result['hard_obj']).mean()

        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        return loss

    @torch.no_grad()
    def reinit(self, s=None):
        """Overwrite the candidates in place.

        Args:
            s: Values to copy into the candidates.  When ``None``, the
                candidates are resampled as
                ``randn * normalizer.std + normalizer.mean``.

        Note:
            Adam's internal state is left untouched by this method.
        """
        if s is None:
            s = torch.randn_like(self.s) * self.normalizer.std + self.normalizer.mean
        # copy_ (rather than set_) keeps the parameter's storage, shape, dtype
        # and device fixed regardless of what the source tensor looks like.
        self.s.copy_(s)

    def evaluate(self, *, step):
        """Log and return the best ``'hard_obj'`` among the current candidates.

        Args:
            step: Accepted for interface compatibility; not used here.

        Returns:
            ``{'optimal': <max of 'hard_obj' as a Python float>}``.
        """
        result = self.obj_eval(self.s)
        hard_obj = result['hard_obj']
        constraint = result['constraint']
        obj = result['obj']

        idx = hard_obj.argmax()
        best = hard_obj.max().item()
        # "inside" counts the candidates whose 'constraint' value is <= 0.
        PQ.log.info(f"[S grad opt] hardD = {best:.6f}, L = {constraint[idx].item():.6f}, U = {obj[idx].item():.6f}, "
                    f"inside = {(constraint <= 0).sum().item()}")

        return {
            'optimal': best,
        }
