import torch
import numpy as np


from .datasets import pos_code
from .network import VanilaMLP


class Predictor:
    """Run a trained ``VanilaMLP`` checkpoint on an input and return a scalar.

    The checkpoint loaded from ``state_dict_name`` must contain:
      * ``'state_dict'`` -- weights for ``VanilaMLP``;
      * ``'mu'`` / ``'std'`` -- statistics used to normalize inputs in ``eval``.

    Note: ``torch.load`` unpickles the file, so only load trusted checkpoints.
    """

    def __init__(self, state_dict_name, device=None):
        """Load the checkpoint and build the network.

        Args:
            state_dict_name: path (or file-like object) accepted by ``torch.load``.
            device: torch device for the network. Defaults to ``'cuda:0'`` when
                CUDA is available, otherwise ``'cpu'``.
        """
        if device is None:
            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.device = device

        # Load onto the CPU so a GPU-saved checkpoint also opens on CPU-only
        # hosts, and so mu/std stay compatible with the numpy arithmetic in
        # ``eval``. load_state_dict copies the weights onto the net's device.
        self.save = torch.load(state_dict_name, map_location='cpu')
        self.net = VanilaMLP().to(self.device)
        self.net.load_state_dict(self.save['state_dict'])

        self.mu = self.save['mu']
        self.std = self.save['std']

    def eval(self, data, ran, pos=True, normalize=True):
        """Predict a single scalar for ``data``.

        Args:
            data: network input. NOTE(review): presumably an array whose first
                axis indexes sample positions (``len(data)`` sizes the position
                grid) -- confirm against callers.
            ran: ``(start, stop)`` pair; only read when ``pos`` is true.
            pos: if true, add ``pos_code`` of ``len(data)`` evenly spaced points
                in ``[ran[0], ran[1]]`` to ``data``.
            normalize: if true, standardize with the checkpoint's ``mu``/``std``.

        Returns:
            The network output as a Python number (via ``Tensor.item()``, so
            the output must contain exactly one element).
        """
        if pos:
            grid = np.linspace(ran[0], ran[1], len(data))
            data = data + pos_code(grid)
        if normalize:
            # The small epsilon guards against division by a zero std.
            data = (data - self.mu) / (self.std + 1e-5)

        self.net.eval()
        with torch.no_grad():
            inputs = torch.as_tensor(data, dtype=torch.float32, device=self.device)
            pred = self.net(inputs)
        return pred.item()

