import torch
from gcip.utils.io import dict_to_cn

import torch.nn as nn
from gcip.utils.kernels.kernel_base import KernelBase


class ExponentialKernel(KernelBase):
    """Exponential decay kernel ``k(t) = exp(-w * t)`` with an optional ``bias`` term.

    Parameters are stored as 1-element ``nn.Parameter`` tensors whose
    ``requires_grad`` flag is read from ``self.requires_grad`` (presumably set
    by ``KernelBase.__init__`` from the ``requires_grad`` argument -- confirm).

    Attributes:
        w: decay rate, 1-element ``nn.Parameter``; must be > 0.
        bias_value: plain-Python copy of the bias. ``integrate`` and
            ``integrate_delta_t`` only apply ``self.bias`` when this is > 0.
        bias: 1-element ``nn.Parameter`` holding the bias. ``__init__`` only
            creates it when ``bias > 0``; ``reset_parameters`` and
            ``set_params`` always (re)create it.
    """

    def __init__(self, w, requires_grad, bias=0.0):
        """Create the kernel.

        Args:
            w: initial decay rate; must be strictly positive.
            requires_grad: forwarded to ``KernelBase``; used as the
                ``requires_grad`` flag of the parameters created here.
            bias: initial bias; must be non-negative. The ``bias`` parameter
                is only created when ``bias > 0``.

        Raises:
            AssertionError: if ``w <= 0`` or ``bias < 0`` (checked with
                ``assert``, so skipped under ``python -O``).
        """
        super().__init__(name='exp', requires_grad=requires_grad)
        assert w > 0.0
        assert bias >= 0.0
        self.w = nn.Parameter(torch.tensor([w]), requires_grad=self.requires_grad)
        self.bias_value = bias
        if bias > 0.0:
            # Only create a bias parameter when a bias is actually requested.
            self.bias = nn.Parameter(torch.tensor([bias]), requires_grad=self.requires_grad)

    def reset_parameters(self):
        """Re-initialise ``w`` and ``bias`` randomly.

        ``w`` is drawn from ``U[0.1, 1.1)`` -- the +0.1 keeps it away from 0,
        which ``integrate``/``integrate_delta_t`` divide by. ``bias`` is drawn
        from ``U[0, 1)``. ``bias_value`` is left unchanged, so the new bias is
        only used by the integration methods if the bias was already enabled.
        """
        self.w = nn.Parameter(torch.rand(1) + 0.1, requires_grad=self.requires_grad)
        self.bias = nn.Parameter(torch.rand(1), requires_grad=self.requires_grad)

    def set_params(self, w, bias=0.0):
        """Overwrite ``w`` and ``bias`` with the given values.

        Args:
            w: new decay rate; must be strictly positive.
            bias: new bias; must be non-negative. A positive value also
                enables the bias term in ``integrate``/``integrate_delta_t``.

        Raises:
            AssertionError: if ``w <= 0`` or ``bias < 0`` (same checks as
                ``__init__``).
        """
        assert w > 0.0
        assert bias >= 0.0
        self.w = nn.Parameter(torch.tensor([w]), requires_grad=self.requires_grad)
        self.bias = nn.Parameter(torch.tensor([bias]), requires_grad=self.requires_grad)
        if bias > 0.0:
            # The integration methods gate on bias_value; without this, a
            # positive bias set on a kernel built with bias=0 would be ignored.
            # A zero bias leaves the flag as-is, so a kernel that already uses
            # a bias keeps self.bias in the computation graph.
            self.bias_value = bias

    @classmethod
    def params(cls, kernel):
        """Return the kernel's parameters as ``{'w': kernel.w}``.

        Args:
            kernel: a kernel object, or a dict that is first converted with
                ``dict_to_cn`` (presumably into an attribute-access namespace
                -- confirm).
        """
        if isinstance(kernel, dict):
            kernel = dict_to_cn(kernel)
        return {
            'w': kernel.w,
        }

    def forward(self, t):
        """Evaluate ``exp(-w * t)`` at the non-negative entries of ``t``.

        Negative entries are dropped by boolean-mask indexing, so the result
        is 1-D and can be shorter than ``t``. ``bias`` is not applied here.
        """
        output = torch.exp(-self.w * t[t >= 0.0])
        return output

    def integrate(self, t_init, t_end, t_i):
        """Integrate the kernel triggered at ``t_i`` over ``[t_init, t_end]``.

        Returns ``(1/w) * (exp(-w*(a - t_i)) - exp(-w*(b - t_i)))`` where
        ``b = max(t_end, t_i)`` and ``a = min(max(t_init', t_i), b)``, with
        ``t_init' = t_init + bias`` when the bias is enabled and ``t_init``
        otherwise. Without a bias this is the integral of ``exp(-w*(t - t_i))``
        over the part of the window after ``t_i``.

        Args:
            t_init: window start (presumably a scalar); must be < ``t_end``.
            t_end: window end.
            t_i: tensor of event times; the result broadcasts against it.

        Returns:
            A non-negative tensor broadcast against ``t_i``.
        """
        assert t_end > t_init

        if self.bias_value > 0:
            # NOTE(review): this shifts the window start by the bias, whereas
            # integrate_delta_t shifts the time elapsed since the event; for
            # t_init == t_i the two results differ by a factor exp(-w * bias).
            # Confirm which semantics is intended.
            t_init = t_init + self.bias
        t_init = t_init * torch.ones_like(t_i)
        t_end = t_end * torch.ones_like(t_i)

        # Only the part of the window after the event contributes.
        t_init = torch.where(t_init < t_i, t_i, t_init)
        t_end = torch.where(t_end < t_i, t_i, t_end)
        # With a bias, t_init + bias can exceed t_end: the window is then empty
        # and the integral must be 0, not negative.
        t_init = torch.where(t_init > t_end, t_end, t_init)

        upper = torch.exp(-self.w * (t_init - t_i))
        bottom = torch.exp(-self.w * (t_end - t_i))

        return 1 / self.w * (upper - bottom)

    def integrate_delta_t(self, delta_t):
        """Return ``(1/w) * (1 - exp(-w * d))`` for elapsed time ``delta_t``.

        ``d = max(delta_t - bias, 0)`` when the bias is enabled
        (``bias_value > 0``), otherwise ``d = delta_t``. Without a bias this is
        the integral of ``exp(-w * s)`` over ``[0, delta_t]``.
        """
        if self.bias_value > 0:
            delta_t = torch.clamp(delta_t - self.bias, min=0.0)
        integral = 1 / self.w * (1.0 - torch.exp(-self.w * delta_t))

        return integral