import pyro
from models.base_pyro_gp import BaseFBGP
import torch

class FBGP(BaseFBGP):
    """GP regression model wrapping ``pyro.contrib.gp.models.GPRegression``.

    The Pyro GP is held in ``self.gpr``. ``noise_prior`` is attached to its
    observation noise as a ``pyro.nn.PyroSample``, so the noise is treated as
    a sampled latent variable. The original comments indicate the model is
    meant to be run under MCMC.
    """

    def __init__(self, args, train_x, train_y, kernel, length_prior, noise_prior):
        """Build the wrapped Pyro GP and attach the noise prior.

        Args:
            args: Configuration object. It is only stored as ``self.args`` here.
            train_x: Training inputs, passed to ``GPRegression`` and stored.
            train_y: Training targets, passed to ``GPRegression`` and stored.
            kernel: A ``pyro.contrib.gp`` kernel used by ``GPRegression``.
            length_prior: Distribution stored as ``self.length_prior``. It is
                NOT attached to any kernel lengthscale by this constructor.
            noise_prior: Distribution placed on ``self.gpr.noise`` via
                ``pyro.nn.PyroSample``.
        """
        super().__init__()

        self.args = args
        # Left as None here. NOTE(review): presumably filled in later
        # (e.g. after MCMC sampling) by other methods or BaseFBGP; confirm.
        self.batch_model = None
        self.mcmc_samples = None
        self.gpytorch_kernel = None
        self.gpytorch_likelihood = None

        # MCMC needs to pickle the model, but a class object cannot be pickled.
        # As a workaround, the GP is stored as an object inside this model
        # rather than being this class itself.
        self.gpr = pyro.contrib.gp.models.GPRegression(train_x, train_y, kernel)

        # Keep a direct reference to the training data on this wrapper.
        self.train_x = train_x
        self.train_y = train_y

        # Store the priors. Only the noise prior is attached below.
        # NOTE(review): lengthscale priors are not placed on the kernel here.
        # Presumably the caller configures them on `kernel` before passing it
        # in, or elsewhere uses self.length_prior; confirm.
        self.length_prior = length_prior
        self.noise_prior = noise_prior

        # Make the observation noise a sampled latent variable under `noise_prior`.
        self.gpr.noise = pyro.nn.PyroSample(noise_prior)
