import numpy as np
import torch
import torch.nn as nn
import os
import gpytorch


from gpytorch.models import ExactGP
from gpytorch.likelihoods import DirichletClassificationLikelihood
from gpytorch.means import ConstantMean
from gpytorch.kernels import ScaleKernel, RBFKernel


def sample_measurement_set(self, X, num_data, num_samples=200):
    """Pick a random subset of rows of ``X`` without replacement.

    Args:
        self: Unused; kept so existing callers that pass an owner object
            as the first argument keep working.
        X: Tensor indexed as ``X[idx, :]``, so it must have at least two
            dimensions. Rows are drawn from the first ``num_data`` rows.
        num_data: Number of candidate rows; converted with ``int``.
        num_samples: Maximum number of rows to return. Defaults to 200,
            the size that used to be hard-coded here.

    Returns:
        Tensor holding ``min(num_samples, num_data)`` distinct rows of ``X``
        in random order.

    Raises:
        ValueError: If ``num_samples`` is negative.
    """
    num_data = int(num_data)
    num_samples = int(num_samples)
    if num_samples < 0:
        raise ValueError("num_samples must be non-negative")

    # A random permutation truncated to the requested size yields distinct
    # row indices; when fewer rows exist than requested, all are returned.
    perm = torch.randperm(num_data)
    idx = perm[:min(num_samples, num_data)]

    return X[idx, :]

class ExactGPModel(gpytorch.models.ExactGP):
    """Exact GP with a linear mean and a scaled RBF covariance."""

    def __init__(self, train_x, train_y, likelihood, input_dim):
        """Set up the mean and covariance modules.

        Args:
            train_x: Training inputs forwarded to ``ExactGP``.
            train_y: Training targets forwarded to ``ExactGP``.
            likelihood: Likelihood forwarded to ``ExactGP``.
            input_dim: Input size handed to ``LinearMean``.
        """
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.LinearMean(input_dim)
        # Scaled RBF kernel. Matern(nu=0.5), RBF * Periodic and a plain
        # linear kernel appeared here as commented-out alternatives.
        rbf = gpytorch.kernels.RBFKernel()
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf)

    def forward(self, x):
        """Return the multivariate normal given by the mean and kernel at ``x``."""
        mean = self.mean_module(x)
        covar = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean, covar)

def prior_sample_functions(X, prior, num_sample):
    """Draw per-point Gaussian samples from the marginals of ``prior`` at ``X``.

    Only the diagonal of the covariance is used, so the values at different
    points are sampled independently of each other.

    Args:
        X: Inputs passed straight to ``prior``.
        prior: Callable returning an object with ``mean`` and
            ``covariance_matrix`` attributes. The indexing below assumes a
            1-D mean of length N and a 2-D (N, N) covariance -- TODO confirm
            against callers that use batched priors.
        num_sample: Number of samples to draw per point.

    Returns:
        Tensor of shape ``(num_sample, N, 1)`` holding
        ``mean + std * eps`` with ``eps`` drawn from a standard normal.
    """
    prior_marginal = prior(X)

    # Tile the mean to (num_sample, N, 1).
    mean_prior = prior_marginal.mean[:, None].repeat(num_sample, 1, 1)

    # Marginal variances sit on the covariance diagonal. Clamp at zero so
    # tiny negative values from round-off cannot turn into NaN under sqrt.
    variance = torch.diag(prior_marginal.covariance_matrix)
    std_prior = torch.clamp(variance, min=0).sqrt()
    std_prior = std_prior[:, None].repeat(num_sample, 1, 1)

    # A Gaussian sample scales standard-normal noise by the standard
    # deviation; the previous code scaled uniform noise by the variance.
    gp_sample = mean_prior + std_prior * torch.randn_like(mean_prior)

    return gp_sample

def prior_sample_functions2(X, prior, num_sample):
    """Return placeholder samples built from random statistics.

    ``X`` and ``prior`` are accepted but never read: the per-point offset
    and scale are themselves drawn from a standard normal over a fixed
    40 points, then combined with uniform noise.

    Args:
        X: Unused.
        prior: Unused.
        num_sample: Number of samples along the leading dimension.

    Returns:
        Tensor of shape ``(num_sample, 40, 1)``.
    """
    # Random stand-in for the prior mean, tiled to (num_sample, 40, 1).
    offsets = torch.randn(40).unsqueeze(-1).repeat(num_sample, 1, 1)

    # Random stand-in for the prior covariance diagonal, tiled the same way.
    scales = torch.randn(40).unsqueeze(-1).repeat(num_sample, 1, 1)

    noise = torch.rand_like(offsets)
    return offsets + scales * noise
    
class DirichletGPModel(ExactGP):
    """Exact GP whose mean and kernel are batched over ``num_classes``.

    NOTE(review): presumably paired with ``DirichletClassificationLikelihood``
    (imported at the top of the file) -- confirm against callers.
    """

    def __init__(self, train_x, train_y, likelihood, num_classes):
        """Create a constant mean and scaled RBF kernel per class.

        Args:
            train_x: Training inputs forwarded to ``ExactGP``.
            train_y: Training targets forwarded to ``ExactGP``.
            likelihood: Likelihood forwarded to ``ExactGP``.
            num_classes: Size of the batch dimension shared by the mean
                and kernel modules.
        """
        super().__init__(train_x, train_y, likelihood)
        # One batch entry per class, shared by every module below.
        class_batch = torch.Size((num_classes,))
        self.mean_module = ConstantMean(batch_shape=class_batch)
        rbf = RBFKernel(batch_shape=class_batch)
        self.covar_module = ScaleKernel(rbf, batch_shape=class_batch)

    def forward(self, x):
        """Return the multivariate normal given by the mean and kernel at ``x``."""
        mean = self.mean_module(x)
        covar = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean, covar)
        


