from .base import Model
from .base import ModelFactory
from .np_priors import GPPrior

import torch

class GPFactory(ModelFactory):
    """Factory whose model class is :class:`GPModel`.

    All construction logic is inherited from ``ModelFactory``; this subclass
    only supplies ``GPModel`` as the model class argument.
    """

    def __init__(self, config, benchmark, simulation_budget):
        # GPModel is resolved at call time, so defining it later in the
        # module is fine.
        model_cls = GPModel
        super(GPFactory, self).__init__(
            config, benchmark, simulation_budget, model_cls
        )


class GPModel(Model):
    """Training-free model whose log-density is estimated from a GP prior.

    ``log_prob`` draws ``config["nb_estimators"]`` function samples from a
    :class:`GPPrior` and averages them. Nothing is learned, so the training,
    persistence and device/mode hooks below are no-ops.
    """

    # Lower bound applied before ``log`` when the GP prior is not in log
    # space, so zero/negative averaged values do not produce -inf/NaN.
    # Subclasses may override it.
    LOG_FLOOR = 1e-10

    def __init__(self, benchmark, model_path, config, normalization_constants):
        """Build the GP prior from the benchmark and config.

        Args:
            benchmark: Object providing ``get_prior()`` and ``get_device()``;
                also passed to ``GPPrior``.
            model_path: Accepted for interface compatibility; unused here.
            config: Mapping that must contain ``"nb_estimators"``; also passed
                unchanged to ``GPPrior``.
            normalization_constants: Forwarded to ``Model.__init__``.
        """
        super().__init__(normalization_constants)

        self.benchmark = benchmark
        self.prior = benchmark.get_prior()
        self.config = config
        self.device = benchmark.get_device()
        # Number of GP function samples averaged in each log_prob call.
        self.nb_estimators = config["nb_estimators"]
        self.gp_prior = GPPrior(benchmark, config, self.device)

    def log_prob(self, theta, x):
        """Return the GP-based log-density estimate for ``(theta, x)``.

        Draws ``self.nb_estimators`` function samples, squeezes dim 2 and
        averages over dim 1. If the prior is not in log space, the average is
        floored at ``LOG_FLOOR`` and then log-transformed.
        """
        samples = self.gp_prior.sample_functions(theta, x, self.nb_estimators)
        # NOTE(review): assumes dim 2 is a singleton output axis and dim 1
        # indexes the function samples; confirm against GPPrior.
        output = samples.squeeze(dim=2).mean(dim=1)

        if not self.gp_prior.log_space:
            # clamp_min keeps dtype, shape and device, and avoids building
            # a constant tensor (plus a host->device copy) on every call.
            output = output.clamp_min(self.LOG_FLOOR).log()

        return output

    def train_models(self):
        """No-op: the GP model has nothing to train."""

    @classmethod
    def is_trained(cls, model_path):
        """Always ``True``; ``model_path`` is ignored because nothing is trained."""
        return True

    def sampling_enabled(self):
        """Report that sampling is disabled for this model."""
        return False

    def __call__(self, theta, x):
        """Alias for :meth:`log_prob`."""
        return self.log_prob(theta, x)

    def get_loss_fct(self):
        """No loss function exists for this model; returns ``None``."""
        return None

    def save(self):
        """No-op: there is no state to persist."""

    def load(self):
        """No-op: there is no state to restore."""

    def train(self, mode=True):
        """No-op mode switch; returns ``self`` so calls can be chained.

        ``mode`` is accepted (and ignored) so ``train(False)``-style calls
        work as with ``torch.nn.Module``.
        """
        return self

    def eval(self):
        """No-op mode switch; returns ``self`` so calls can be chained."""
        return self

    def to(self, device):
        """No-op device move; returns ``self`` so ``m = m.to(dev)`` keeps ``m``.

        NOTE(review): ``self.gp_prior`` stays on the device chosen at
        construction; ``device`` is ignored.
        """
        return self
