# Deep Gaussian Processes
import gpytorch
import tqdm
import numpy as np
import torch
from torch.nn import Linear
from gpytorch.means import ConstantMean, LinearMean, ZeroMean
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.variational import VariationalStrategy, CholeskyVariationalDistribution,\
    NaturalVariationalDistribution, TrilNaturalVariationalDistribution, MeanFieldVariationalDistribution
from models.base_deep_gp import BaseDeepGP


class ToyDeepGPHiddenLayer(gpytorch.models.deep_gps.DeepGPLayer):
    """A single variational GP layer for use inside a deep GP.

    The layer maps ``input_dims`` inputs to ``output_dims`` batched GP outputs,
    or to a single non-batched GP output when ``output_dims`` is ``None``.
    It uses a constant/linear/zero mean and an RBF kernel with one lengthscale
    per input dimension (``ard_num_dims=input_dims``), optionally wrapped in a
    ``ScaleKernel``.
    """

    def __init__(self, input_dims, output_dims, num_inducing=32,
                 mean_type='constant',
                 scale_kernel=False,
                 ngd=False,
                 inducing_points=None):
        """
        :param input_dims: how many inputs this hidden layer will expect
        :param output_dims: how many hidden GPs to create outputs for; ``None``
            builds a single GP with an empty batch shape (e.g. a final layer)
        :param num_inducing: number of randomly initialised inducing points;
            overridden by ``inducing_points.shape[0]`` when points are given
        :param mean_type: one of ``'constant'``, ``'linear'`` or ``'zero'``
        :param scale_kernel: wrap the RBF kernel in a ``ScaleKernel``
        :param ngd: use a ``NaturalVariationalDistribution`` instead of a
            ``CholeskyVariationalDistribution``
        :param inducing_points: optional fixed inducing locations of shape
            ``(M,)``, ``(M, 1)`` or ``(M, input_dims)``. Single-column points are
            tiled across all ``input_dims`` columns. When given, inducing
            locations are not learned.
        :raises NotImplementedError: if ``mean_type`` is not recognised
        :raises ValueError: if ``inducing_points`` has an incompatible shape
        """
        # Fail fast, before any module or variational state is built.
        if mean_type not in ('constant', 'linear', 'zero'):
            raise NotImplementedError(f"Unsupported mean_type: {mean_type!r}")

        learn_inducing_locations = inducing_points is None
        if inducing_points is None:
            if output_dims is None:
                inducing_points = torch.randn(num_inducing, input_dims)
            else:
                inducing_points = torch.randn(output_dims, num_inducing, input_dims)
        else:
            inducing_points = self._tile_inducing_points(inducing_points, input_dims)
            num_inducing = inducing_points.shape[0]
            if output_dims is not None:
                # One copy of the (M, input_dims) locations per output GP.
                inducing_points = inducing_points.unsqueeze(0).repeat(output_dims, 1, 1)

        batch_shape = torch.Size([]) if output_dims is None else torch.Size([output_dims])

        if ngd:
            variational_distribution = NaturalVariationalDistribution(
                num_inducing_points=num_inducing,
                batch_shape=batch_shape
            )
        else:
            variational_distribution = CholeskyVariationalDistribution(
                num_inducing_points=num_inducing,
                batch_shape=batch_shape
            )

        variational_strategy = VariationalStrategy(
            self,
            inducing_points,
            variational_distribution,
            learn_inducing_locations=learn_inducing_locations
        )

        super().__init__(variational_strategy, input_dims, output_dims)

        if mean_type == 'constant':
            self.mean_module = ConstantMean(batch_shape=batch_shape)
        elif mean_type == 'linear':
            # NOTE(review): LinearMean is built without batch_shape, so its weights
            # appear to be shared across output_dims - confirm this is intended.
            self.mean_module = LinearMean(input_dims)
        else:  # 'zero' (validated at the top of __init__)
            self.mean_module = ZeroMean(batch_shape=batch_shape)

        if scale_kernel:
            self.covar_module = ScaleKernel(
                RBFKernel(batch_shape=batch_shape, ard_num_dims=input_dims),
                batch_shape=batch_shape, ard_num_dims=input_dims
            )
        else:
            self.covar_module = RBFKernel(batch_shape=batch_shape, ard_num_dims=input_dims)

    @staticmethod
    def _tile_inducing_points(inducing_points, input_dims):
        """Return a fresh ``(M, input_dims)`` copy of user-supplied inducing points.

        ``(M,)`` and ``(M, 1)`` tensors are tiled across ``input_dims`` columns.
        ``(M, input_dims)`` tensors are copied as-is.

        :raises ValueError: for any other shape
        """
        if inducing_points.dim() == 1:
            inducing_points = inducing_points.unsqueeze(-1)
        if inducing_points.dim() != 2:
            raise ValueError(
                "inducing_points must be 1-D or 2-D, got shape "
                f"{tuple(inducing_points.shape)}"
            )
        width = inducing_points.shape[-1]
        if width == 1:
            # Same locations replicated in every input dimension (repeat copies).
            return inducing_points.repeat(1, input_dims)
        if width == input_dims:
            # Copy so the registered inducing points never alias the caller's tensor.
            return inducing_points.clone()
        raise ValueError(
            f"inducing_points has {width} columns; expected 1 or input_dims={input_dims}"
        )

    def forward(self, x):
        """Return the GP prior ``MultivariateNormal`` of this layer evaluated at ``x``."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
    
#    def __call__(self, x, *other_inputs, **kwargs):
#        """
#        Overriding __call__ isn't strictly necessary, but it lets us add concatenation based skip connections
#        easily. For example, hidden_layer2(hidden_layer1_outputs, inputs) will pass the concatenation of the first
#        hidden layer's outputs and the input data to hidden_layer2.
#        """
#        if len(other_inputs):
#            if isinstance(x, gpytorch.distributions.MultitaskMultivariateNormal):
#                x = x.rsample()
#
#            processed_inputs = [
#                inp.unsqueeze(0).expand(self.num_samples, *inp.shape)
#                for inp in other_inputs
#            ]
#
#            x = torch.cat([x] + processed_inputs, dim=-1)
#
#        return super().__call__(x, are_samples=bool(len(other_inputs)))    

    
class DeepGP(BaseDeepGP):
    """Deep GP: one (optionally two) hidden GP layer(s) followed by a single-output GP layer.

    All layers use a zero mean. The optional second hidden layer uses a scaled
    RBF kernel; the other layers use a plain RBF kernel.
    """

    def __init__(self, train_x_shape, num_output_dims, likelihood, num_inducing=32, ngd=False,
                 inducing_points=None, use_second_hidden_layer=False):
        """
        :param train_x_shape: shape of the training inputs. The last entry is the
            input dimensionality; a 1-D shape ``(N,)`` means scalar inputs.
        :param num_output_dims: number of GP outputs of each hidden layer
        :param likelihood: likelihood module stored on the model
        :param num_inducing: inducing points per layer (when randomly initialised)
        :param ngd: use natural-parameter variational distributions in every layer
        :param inducing_points: optional fixed inducing locations shared by all
            layers (see ``ToyDeepGPHiddenLayer``)
        :param use_second_hidden_layer: insert a second hidden layer (scaled RBF
            kernel) between the first hidden layer and the final layer
        """
        # forward() unsqueezes 1-D inputs to (N, 1), so a 1-D shape means one input dim.
        input_dims = train_x_shape[-1] if len(train_x_shape) > 1 else 1

        hidden_layer = ToyDeepGPHiddenLayer(
            input_dims=input_dims,
            output_dims=num_output_dims,
            mean_type='zero',
            scale_kernel=False,
            num_inducing=num_inducing,
            inducing_points=inducing_points,
            ngd=ngd,
        )

        # Only build the optional layer when it is actually used by forward().
        hidden_layer2 = None
        if use_second_hidden_layer:
            hidden_layer2 = ToyDeepGPHiddenLayer(
                input_dims=hidden_layer.output_dims,
                output_dims=num_output_dims,
                mean_type='zero',
                scale_kernel=True,
                num_inducing=num_inducing,
                inducing_points=inducing_points,
                ngd=ngd,
            )

        # output_dims=None -> a single, non-batched GP output.
        last_layer = ToyDeepGPHiddenLayer(
            input_dims=hidden_layer.output_dims,
            output_dims=None,
            mean_type='zero',
            num_inducing=num_inducing,
            inducing_points=inducing_points,
            ngd=ngd,
        )

        super().__init__()

        self.hidden_layer = hidden_layer
        if hidden_layer2 is not None:
            self.hidden_layer2 = hidden_layer2
        self.last_layer = last_layer
        self.likelihood = likelihood
        # NOTE(review): presumably the number of samples BaseDeepGP draws through the layers - confirm.
        self.n_samples = 10
        self.train_x_shape = train_x_shape

    def forward(self, inputs, return_hidden_layers=False):
        """Propagate ``inputs`` through the hidden layer(s) and the final layer.

        :param inputs: input tensor; 1-D inputs are treated as ``(N, 1)``
        :param return_hidden_layers: also return the last hidden layer's output
        :returns: the final layer output, or ``(output, hidden_output)``
        """
        # Make sure the first dimension is the batch size
        inputs = inputs.unsqueeze(-1) if len(inputs.shape) == 1 else inputs

        hidden_rep1 = self.hidden_layer(inputs)
        # getattr default keeps models built without the optional layer working.
        hidden_layer2 = getattr(self, 'hidden_layer2', None)
        if hidden_layer2 is not None:
            hidden_rep1 = hidden_layer2(hidden_rep1)
        output = self.last_layer(hidden_rep1)

        if return_hidden_layers:
            return output, hidden_rep1

        return output


class MultitaskDeepGP(BaseDeepGP):
    """Two-layer deep GP whose final layer has one GP output per task.

    Both layers use a constant mean and an RBF kernel. The final layer is
    batched over ``num_tasks`` outputs.
    """

    def __init__(self, train_x_shape, num_output_dims, likelihood, num_tasks, num_inducing=32, ngd=False):
        """
        :param train_x_shape: shape of the training inputs. The last entry is the
            input dimensionality; a 1-D shape ``(N,)`` means scalar inputs.
        :param num_output_dims: number of GP outputs of the hidden layer
        :param likelihood: likelihood module stored on the model (presumably a
            multitask likelihood over ``num_tasks`` outputs - confirm with caller)
        :param num_tasks: number of GP outputs of the final layer
        :param num_inducing: randomly initialised inducing points per layer
        :param ngd: use natural-parameter variational distributions in both layers
        """
        # forward() unsqueezes 1-D inputs to (N, 1), so a 1-D shape means one input dim.
        input_dims = train_x_shape[-1] if len(train_x_shape) > 1 else 1

        hidden_layer = ToyDeepGPHiddenLayer(
            input_dims=input_dims,
            output_dims=num_output_dims,  # hidden dimension
            mean_type='constant',
            num_inducing=num_inducing,
            ngd=ngd,
        )
        # ngd must be forwarded here too, so that every layer's variational
        # parameters use the same parameterisation.
        last_layer = ToyDeepGPHiddenLayer(
            input_dims=hidden_layer.output_dims,
            output_dims=num_tasks,
            mean_type='constant',
            num_inducing=num_inducing,
            ngd=ngd,
        )

        super().__init__()

        self.hidden_layer = hidden_layer
        self.last_layer = last_layer
        self.likelihood = likelihood
        # NOTE(review): presumably the number of samples BaseDeepGP draws through the layers - confirm.
        self.n_samples = 10
        self.train_x_shape = train_x_shape

    def forward(self, inputs, return_hidden_layers=False):
        """Propagate ``inputs`` through the hidden layer and the per-task final layer.

        :param inputs: input tensor; 1-D inputs are treated as ``(N, 1)``
        :param return_hidden_layers: also return the hidden layer's output
        :returns: the final layer output, or ``(output, hidden_output)``
        """
        # Make sure the first dimension is the batch size
        inputs = inputs.unsqueeze(-1) if len(inputs.shape) == 1 else inputs

        hidden_rep1 = self.hidden_layer(inputs)
        output = self.last_layer(hidden_rep1)

        if return_hidden_layers:
            return output, hidden_rep1

        return output
