import math
import torch
import os

class Periodic_kernel(torch.nn.Module):
    # Rational Quadratic Kernel
    def __init__(self, exp_restrict = 'exp', length_scale=1., scale=1., period=1.) -> None:
        super().__init__()
        self.exp_restrict = exp_restrict

        if exp_restrict is True:
            self.length_scale = torch.nn.Parameter(torch.log(torch.tensor(length_scale)))
            self.scale = torch.nn.Parameter(torch.log(torch.tensor(scale)))
        else:
            self.length_scale = torch.nn.Parameter(torch.tensor(length_scale))
            self.scale = torch.nn.Parameter(torch.tensor(scale))

        self.period = torch.nn.Parameter(torch.tensor(period))
        # TODO should we add noise here?

    def forward(self, X, X2):
        if self.exp_restrict is True:
            length_scale = torch.exp(self.length_scale).view(1, -1)
            scale = torch.exp(self.scale).view(1, -1)
        else:
            length_scale = self.length_scale.view(1, -1)
            scale = self.scale.view(1, -1)

        # optimize for multi dim.
        if len(X.shape)>2 :
            assert len(X2.shape)>2, "X and X2 should be same dim"
            X = X.reshape(X.size(0), -1)
            X2 = X2.reshape(X2.size(0), -1)

        X = X / length_scale.expand(X.size(0), length_scale.size(1))
        X2 = X2 / length_scale.expand(X2.size(0), length_scale.size(1))
        K = -2 * torch.sin(math.pi*torch.abs(X.expand(X.size(0), X2.size(0)) -X2.expand(X.size(0), X2.size(0)))/self.period)
        K = scale**2 * torch.exp(K)

        return K

    # def get_param(self):
    #     return self.length_scale, self.log_scale

    def get_param(self, optimize_param_list):
        optimize_param_list.append(self.length_scale)
        optimize_param_list.append(self.scale)
        optimize_param_list.append(self.period)
        return optimize_param_list

    def set_param(self, param_list):
        with torch.no_grad():
            self.length_scale.copy_(param_list[0])
            self.scale.copy_(param_list[1])
            self.period.copy_(param_list[2])

    def get_params_need_check(self):
        _temp_list = []
        self.get_param(_temp_list)
        return _temp_list

