import torch
import torch.nn as nn
import numpy as np
import math 

def soft_clamp1(x: torch.Tensor, bound: float = 1.0) -> torch.Tensor:
    """Softly clamp ``x`` into ``[-bound, bound]`` with a differentiable tanh.

    Computes ``bound * tanh(x / bound)``. With the default ``bound=1.0`` this
    is exactly ``tanh(x)``, i.e. a soft clamp into [-1, 1]. The range is
    [-1, 1], not [-5, 5] as an earlier comment claimed.

    Args:
        x: input tensor; it is not modified.
        bound: positive half-width of the output range.

    Returns:
        A new tensor with the same shape as ``x``.
    """
    return torch.tanh(x / bound) * bound


def sample_normal_jit(mu, sigma):
    """Draw a reparameterised sample ``z = mu + sigma * eps``, with ``eps ~ N(0, 1)``.

    Args:
        mu: mean tensor. The noise is drawn with mu's shape, dtype and device.
        sigma: scale multiplied onto the noise; expected to match ``mu``'s shape.

    Returns:
        ``(z, eps)``: the sample and the standard-normal noise used to build
        it, as two distinct tensors.
    """
    eps = torch.randn_like(mu)
    # Out-of-place on purpose. The previous in-place
    # ``eps.mul_(sigma).add_(mu)`` overwrote ``eps``, so ``z`` and ``eps``
    # were the same tensor and the raw noise was lost.
    z = eps * sigma + mu
    return z, eps

class decomp_block(nn.Module):
    """Map the input to basis coefficients, expand them with ``function``, then
    sample a Gaussian with the reparameterisation trick.

    Args:
        input_size: input width of ``mapping``. Also the in/out width of the
            ``mu``/``sigma`` heads, so ``function``'s output must have this
            size in its last dimension.
        theta_size: number of coefficients produced by ``mapping``.
        middle_size: hidden width of ``mapping``.
        layers: number of hidden layers in ``mapping``.
        degree_of_polynomial: unused in this class; kept for interface compatibility.
        harmonics: unused in this class; kept for interface compatibility.
        function: callable (registered as a submodule if it is an
            ``nn.Module``) that turns the coefficients into a signal.
            Presumably a ``Trend`` or ``Seasonality`` basis; confirm at call sites.
    """
    def __init__(self, input_size, theta_size, middle_size, layers, degree_of_polynomial, harmonics, function):
        super(decomp_block, self).__init__()
        self.mapping_block = mapping(input_size, theta_size, middle_size, layers)
        self.function = function
        self.mu = nn.Linear(input_size, input_size)
        self.sigma = nn.Linear(input_size, input_size)

    def forward(self, x):
        """Return ``(sample, mu, sigma)``.

        ``mu`` and ``sigma`` pass through ``soft_clamp1``. ``sigma`` is used as
        a log-variance: the sample's standard deviation is
        ``exp(0.5 * sigma)``.
        """
        theta = self.mapping_block(x)
        res = self.function(theta)
        mu = soft_clamp1(self.mu(res))
        sigma = soft_clamp1(self.sigma(res))  # log-variance, despite the name
        # Draw the noise directly on mu's device and dtype. The previous code
        # built a float32 CPU tensor and copied it over on every call.
        # mu.shape == res.shape here, since self.mu is square.
        eps = torch.randn_like(mu)
        sample = mu + torch.exp(0.5 * sigma) * eps
        return sample, mu, sigma

class mapping(nn.Module):
    """ReLU multilayer perceptron followed by a linear projection to ``theta_size`` outputs.

    Args:
        input_size: width of the incoming features.
        theta_size: width of the final (un-activated) projection.
        middle_size: width of every hidden layer.
        layers: number of hidden Linear+ReLU layers. At least one is always built.
    """
    def __init__(self, input_size, theta_size, middle_size, layers):
        super().__init__()
        # The first hidden layer reads input_size features; the rest read middle_size.
        in_widths = [input_size] + [middle_size] * max(layers - 1, 0)
        hidden = [nn.Linear(in_features=width, out_features=middle_size) for width in in_widths]
        self.layers = nn.ModuleList(hidden)
        self.basis_parameters = nn.Linear(in_features=middle_size, out_features=theta_size)

    def forward(self, x):
        h = x
        for linear in self.layers:
            h = linear(h).relu()
        return self.basis_parameters(h)


class Trend(nn.Module):
    """
    Polynomial function to model trend.

    Holds a fixed, non-trainable basis ``backcast_time`` of shape
    ``(degree_of_polynomial + 1, backcast_size)``. Row ``i`` is
    ``(t / backcast_size) ** i`` for ``t = 0 .. backcast_size - 1``.
    """
    def __init__(self, degree_of_polynomial, backcast_size):
        super(Trend, self).__init__()
        self.polynomial_size = degree_of_polynomial + 1  # degree of polynomial with constant term
        # np.float was removed in NumPy 1.24. np.float64 is the type it
        # aliased, so the basis values are unchanged.
        normalized_time = np.arange(backcast_size, dtype=np.float64) / backcast_size
        powers = np.stack([np.power(normalized_time, i) for i in range(self.polynomial_size)])
        self.backcast_time = nn.Parameter(torch.tensor(powers, dtype=torch.float32), requires_grad=False)

        # NOTE(review): forward() does not use `theta` or `sigma`. They are
        # kept because removing them would change this module's parameters and
        # state_dict keys.
        self.theta = nn.Linear(backcast_size, backcast_size)
        self.sigma = nn.Linear(backcast_size, backcast_size)

    def forward(self, theta):
        """Project polynomial coefficients onto the trend basis.

        Args:
            theta: 3-D tensor ``(b, c, p)``. Only the last ``polynomial_size``
                entries along ``p`` are used.

        Returns:
            Tensor of shape ``(b, c, backcast_size)``.
        """
        coefficients = theta[:, :, -self.polynomial_size:]
        backcast = torch.einsum('bcp,pt->bct', coefficients, self.backcast_time)  # the extracted trend features
        return backcast


class Seasonality(nn.Module):
    """
    Harmonic (cosine/sine) basis used to model seasonality.

    The frequencies are 0 followed by
    ``arange(harmonics, harmonics / 2 * backcast_size) / harmonics``. They are
    kept in ``self.frequency`` as a plain float32 numpy array of shape
    ``(1, n_freq)``. This is not a registered buffer.

    The fixed, non-trainable templates ``backcast_cos_template`` and
    ``backcast_sin_template`` have shape ``(n_freq, backcast_size)``.
    """
    def __init__(self, harmonics, backcast_size):
        super().__init__()
        positive_freqs = np.arange(harmonics, harmonics / 2 * backcast_size, dtype=np.float32) / harmonics
        self.frequency = np.append(np.zeros(1, dtype=np.float32), positive_freqs)[None, :]

        # Column of normalised time steps. Multiplying by the (1, n_freq) row
        # of frequencies broadcasts to a (backcast_size, n_freq) phase grid.
        steps = np.arange(backcast_size, dtype=np.float32)[:, None] / backcast_size
        phase = -2 * np.pi * steps * self.frequency

        def _frozen(values):
            # Transpose to (n_freq, backcast_size) and wrap as a non-trainable float32 parameter.
            return nn.Parameter(torch.tensor(values.T, dtype=torch.float32), requires_grad=False)

        self.backcast_cos_template = _frozen(np.cos(phase))
        self.backcast_sin_template = _frozen(np.sin(phase))

    def forward(self, theta):
        """Combine harmonic coefficients into a seasonal signal.

        ``theta`` is 3-D ``(b, c, p)``. The first quarter of ``p`` weights the
        cosine template and the second quarter weights the sine template. The
        rest of ``p`` is ignored. Each quarter's width must equal the number of
        frequencies. Returns ``(b, c, backcast_size)``.
        """
        quarter = theta.shape[2] // 4
        cos_coeffs = theta[:, :, :quarter]
        sin_coeffs = theta[:, :, quarter:2 * quarter]
        seasonal_cos = torch.einsum('bcp,pt->bct', cos_coeffs, self.backcast_cos_template)
        seasonal_sin = torch.einsum('bcp,pt->bct', sin_coeffs, self.backcast_sin_template)
        return seasonal_sin + seasonal_cos