import torch
import exp_utils as PQ
import torch.nn as nn

INF = 1e10  # finite stand-in for infinity; used to initialise "empty" bounds


class StateBox:
    """Axis-aligned box tracking the element-wise extent of observed states.

    The box keeps running element-wise bounds ``_min`` / ``_max`` over every
    batch passed to :meth:`update`. From them it derives ``center`` and a
    half-width ``length``. ``length`` is inflated by ``expansion``, so samples
    drawn by :meth:`fill_` reach beyond the bounds seen so far.

    Args:
        shape: per-state shape of the bound tensors. It is passed to
            ``torch.full`` and is unpacked after the batch dimension in
            :meth:`find_box`.
        device: torch device the bound and sample tensors are allocated on.
        expansion: factor applied to the half-width of the tracked bounds.
    """

    def __init__(self, shape, device, expansion=1.5):
        self._max = torch.full(shape, -INF, device=device)
        self._min = torch.full(shape, +INF, device=device)
        # center/length stay None until the first update() or reset().
        self.center = None
        self.length = None
        self.expansion = expansion
        self.device = device
        self.shape = shape

    def _require_box(self):
        """Raise a descriptive error if center/length have not been computed yet."""
        if self.center is None or self.length is None:
            raise RuntimeError("StateBox is uninitialised; call reset() or update() first")

    @torch.no_grad()
    def find_box(self, crabs, n_samples=10_000, max_iters=1000, logging=True):
        """Grow the box until sampling it yields no new in-barrier states.

        Each round proceeds as follows:
        1. Fill ``n_samples`` states uniformly from the current expanded box.
        2. Keep those with ``crabs.barrier(s) < 0``.
        3. Widen the bounds if any kept state lies outside them.

        The search stops when nothing is kept, when no kept state extends the
        bounds, or after ``max_iters`` rounds.

        Args:
            crabs: object exposing ``barrier(s)``. Presumably it returns one
                value per sample along dim 0 — TODO confirm against callers.
            n_samples: number of samples drawn per round.
            max_iters: upper bound on the number of rounds.
            logging: forwarded to :meth:`update`.

        Raises:
            RuntimeError: if the box was never initialised via reset()/update().
        """
        s = torch.empty(n_samples, *self.shape, device=self.device)
        for _ in range(max_iters):
            self.fill_(s)
            inside = s[torch.where(crabs.barrier(s) < 0.0)[0]]
            if len(inside) == 0:
                break
            if not (torch.any(inside < self._min) or torch.any(inside > self._max)):
                break  # converged: no kept sample extends the current bounds
            self.update(inside, logging=logging)

    def update(self, data, logging=True):
        """Widen the bounds to cover ``data`` and recompute center/length.

        Args:
            data: batch of states. It is reduced over dim 0, and the result is
                combined element-wise with the current bounds.
            logging: if true, log the new bounds via ``PQ.log.info``.
        """
        self._max = self._max.maximum(data.max(dim=0).values)
        self._min = self._min.minimum(data.min(dim=0).values)
        self.center = (self._max + self._min) / 2
        self.length = (self._max - self._min) / 2 * self.expansion  # expand the box
        if logging:
            PQ.log.info(f"[StateBox] updated: max = {self._max.cpu()}, min = {self._min.cpu()}")

    @torch.no_grad()
    def reset(self, s0, margin=1e-3):
        """Reset the bounds to a small box around the states ``s0``.

        Args:
            s0: batch of states (reduced over dim 0 by :meth:`update`).
            margin: padding added on each side of ``s0``, so the box has a
                non-zero extent even for a single state.
        """
        nn.init.constant_(self._max, -INF)
        nn.init.constant_(self._min, +INF)
        self.update(s0 + margin, logging=False)
        self.update(s0 - margin, logging=False)

    @torch.no_grad()
    def fill_(self, s):
        """Overwrite ``s`` in place with uniform samples from the expanded box.

        Raises:
            RuntimeError: if the box was never initialised via reset()/update().
        """
        self._require_box()
        s.data.copy_((torch.rand_like(s) * 2 - 1) * self.length + self.center)

    def decode(self, s):
        """Map ``s`` from box-normalised coordinates to state coordinates.

        0 maps to ``center``, and +/-1 map to ``center +/- length``.

        Raises:
            RuntimeError: if the box was never initialised via reset()/update().
        """
        self._require_box()
        return s * self.length + self.center

