import exp_utils as PQ
import rl_utils
import torch
import torch.nn as nn


class FLAGS(PQ.BaseFLAGS):
    """Hyper-parameters for this module's action optimisation (read by GreedyBackPolicy)."""

    # Number of candidate actions optimised in parallel on every policy call.
    batch_size = 256
    # Number of optimiser steps taken on the candidates per policy call.
    n_opt_iters = 1_000


class GreedyBackPolicy(rl_utils.DetNetPolicy):
    """Deterministic policy that chooses actions by gradient descent through ``model``.

    On every call, ``FLAGS.batch_size`` candidate actions are re-initialised
    uniformly in ``[-1, 1]`` and optimised with Adam for ``FLAGS.n_opt_iters``
    steps, so that ``model(state, action)`` comes as close as possible to the
    fixed target ``s0``. Closeness is measured as squared error averaged over
    the last dimension. The candidate with the lowest error *after* optimisation
    is returned.

    Args:
        model: callable invoked as ``model(states, actions)``; presumably a
            learned dynamics model returning predicted states — TODO confirm.
        dim_action: size of each action vector.
        s0: target the model's predictions are pulled towards; must broadcast
            against the model's output.
    """

    def __init__(self, model, dim_action, s0):
        super().__init__()
        # One row per candidate action; all rows are optimised in parallel.
        self.actions = nn.Parameter(
            torch.zeros(FLAGS.batch_size, dim_action), requires_grad=True
        )
        self.optim = torch.optim.Adam([self.actions])
        self.model = model
        self.s0 = s0

    def _candidate_losses(self, states):
        """Return the per-candidate squared error between ``model`` output and ``s0``."""
        predictions = self.model(states, self.actions)
        return (predictions - self.s0).pow(2).mean(dim=-1)

    @torch.enable_grad()
    def forward(self, state):
        """Optimise the candidate actions for ``state`` and return the best one.

        Args:
            state: a single state, either 1-D or with a leading dimension of 1.

        Returns:
            A detached copy of the best candidate action, shape ``(dim_action,)``.
        """
        assert len(state) == 1 or state.ndim == 1
        actions = self.actions
        # The batched state is the same on every iteration, so build it once.
        states = state.repeat(FLAGS.batch_size, 1).detach()

        nn.init.uniform_(actions, -1, 1)
        # Fresh candidates define a fresh optimisation problem: drop the Adam
        # moment estimates and step count left over from the previous call.
        self.optim.state.clear()
        for _ in range(FLAGS.n_opt_iters):
            loss = self._candidate_losses(states)
            # Differentiate w.r.t. the actions only, so no gradients are
            # accumulated into the model's own parameters as a side effect.
            (grad,) = torch.autograd.grad(loss.mean(), actions, allow_unused=True)
            actions.grad = grad
            self.optim.step()

        # Score the candidates *after* the last update. The loop's last loss
        # predates the final step, so it does not describe the values being
        # returned. It also does not exist when n_opt_iters == 0.
        with torch.no_grad():
            final_loss = self._candidate_losses(states)
            best = actions[final_loss.argmin()]
        # Copy so the result does not alias the parameter, which the next
        # call re-initialises in place.
        return best.detach().clone()

