import torch
import torch.nn as nn
from .decomp import decomp_block, Trend, Seasonality
import numpy as np


def soft_clamp5(x: torch.Tensor):
    """Smoothly squash ``x`` into the open interval (-5, 5) via ``5 * tanh(x / 5)``."""
    limit = 5.
    scaled = torch.tanh(x / limit)
    return scaled * limit

def soft_clamp3(x: torch.Tensor, bound: float = 1.0):
    """Smoothly squash ``x`` into the open interval ``(-bound, bound)``.

    Computes ``bound * tanh(x / bound)``, which is close to the identity for
    ``|x|`` much smaller than ``bound`` and saturates at ``+/-bound``.

    NOTE: despite its name, this function has always used a bound of 1, not 3.
    The default keeps that behaviour because every existing caller relies on
    it. Pass ``bound`` explicitly to use a different range.

    Args:
        x: input tensor (any shape); it is not modified.
        bound: half-width of the output range; must be non-zero.

    Returns:
        A new tensor with the same shape as ``x``.
    """
    return torch.tanh(x / bound) * bound

class ma_decomp(nn.Module):
    """Series decomposition into a moving-average part and its residual.

    ``forward(x)`` returns ``(x - moving_avg(x), moving_avg(x))``.

    NOTE(review): ``moving_avg`` is neither defined nor imported in this file
    as shown, so ``__init__`` raises NameError unless that name is supplied
    elsewhere. Confirm where it is meant to come from.
    """

    def __init__(self, kernel_size):
        super().__init__()
        # Sliding-window average with unit stride over ``kernel_size`` steps.
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        smoothed = self.moving_avg(x)
        return x - smoothed, smoothed

class decomp_one(nn.Module):
    """One trend/seasonality decomposition stage.

    The input goes through a stack of trend ``decomp_block``s. Each block sees
    the current residual, and its output is subtracted from that residual. The
    remaining residual passes through ReLU, then through a stack of
    seasonality blocks in the same way, and then through another ReLU.

    Args:
        input_size: size of the last dimension normalised by ``norm1``/``norm2``.
            It is also forwarded to every ``decomp_block``, ``Trend`` and
            ``Seasonality``.
        theta_size: unused, because the trend and seasonality blocks compute
            their own theta sizes below. Kept for signature compatibility.
        middle_size, layers, degree_of_polynomial, harmonics: forwarded to every
            ``decomp_block``.
        trend_layers: number of trend blocks (must be >= 1).
        season_layers: number of seasonality blocks (must be >= 1).
        norm: if True, ``forward`` applies ``LayerNorm`` to its output feature.

    Raises:
        ValueError: if ``trend_layers`` or ``season_layers`` is below 1.
            ``forward`` needs the last block of each stack to build the KL
            term, and would otherwise fail with an UnboundLocalError.
    """

    def __init__(self, input_size, theta_size, middle_size, layers,
                 degree_of_polynomial, harmonics, trend_layers, season_layers, norm=False):
        super(decomp_one, self).__init__()
        if trend_layers < 1 or season_layers < 1:
            raise ValueError(
                "decomp_one needs at least one trend and one season layer, got "
                f"trend_layers={trend_layers}, season_layers={season_layers}"
            )
        self.trend_block = nn.ModuleList([
            decomp_block(input_size=input_size, theta_size=2 * (degree_of_polynomial + 1),
                         middle_size=middle_size, layers=layers,
                         degree_of_polynomial=degree_of_polynomial, harmonics=harmonics,
                         function=Trend(harmonics, input_size))
            for _ in range(trend_layers)
        ])
        self.season_block = nn.ModuleList([
            decomp_block(input_size=input_size,
                         theta_size=4 * int(np.ceil(harmonics / 2 * input_size) - (harmonics - 1)),
                         middle_size=middle_size, layers=layers,
                         degree_of_polynomial=degree_of_polynomial, harmonics=harmonics,
                         function=Seasonality(harmonics, input_size))
            for _ in range(season_layers)
        ])
        # NOTE(review): this registers the *same* trend/season ModuleLists
        # several times and is never used in ``forward``. It is kept so that
        # parameter registration and state_dict keys stay unchanged.
        self.decomp_block = nn.ModuleList(
            [self.trend_block for _ in range(trend_layers)]
            + [self.season_block for _ in range(season_layers)]
        )
        # Not used in this class. The implicit ``dim`` is kept as originally written.
        self.softmax = nn.Softmax()
        self.norm = norm
        self.norm1 = nn.LayerNorm(input_size)  # applied to the KL term in forward
        self.norm2 = nn.LayerNorm(input_size)  # applied to the feature when norm=True

    def cal_kl(self, mu1, sigma1, mu2, sigma2):
        """Element-wise KL( N(mu1, exp(sigma1)) || N(mu2, exp(sigma2)) ).

        ``sigma1``/``sigma2`` are log-variances, the same convention used by
        ``cal_kl_p``. All four inputs are first squashed with ``soft_clamp3``.

        Uses the closed form
            0.5 * [ (mu1 - mu2)^2 / var2 + var1 / var2 - 1 + log var2 - log var1 ].
        The previous version used ``- 0.5 * sigma2`` in place of
        ``+ 0.5 * (sigma2 - sigma1)``, which could make the result negative.
        """
        mu1 = soft_clamp3(mu1)
        mu2 = soft_clamp3(mu2)
        sigma1 = soft_clamp3(sigma1)
        sigma2 = soft_clamp3(sigma2)

        diff = mu1 - mu2
        var_ratio = torch.exp(sigma1 - sigma2)
        kl = 0.5 * (diff * diff / torch.exp(sigma2) + var_ratio - 1.0 + sigma2 - sigma1)
        return kl

    def cal_kl_p(self, samples, mu, sigma):
        """Element-wise Gaussian log-density of ``samples`` under N(mu, exp(sigma)).

        Despite its name, this returns ``log p``, not a KL divergence. ``mu``
        and ``sigma`` (a log-variance) are squashed with ``soft_clamp3`` first.
        """
        mu = soft_clamp3(mu)
        sigma = soft_clamp3(sigma)
        normalized_samples = (samples - mu)
        log_p = - 0.5 * normalized_samples * normalized_samples/torch.exp(sigma) - 0.5 * np.log(2 * np.pi) - 0.5*sigma
        return log_p

    def forward(self, x):
        """Run the trend stack, then the season stack, on ``x``.

        Returns:
            A 4-tuple ``(feature, residual, kl_q, feature)``:
            - ``feature`` is the sum of every trend and season block's output,
              optionally layer-normalised. It is returned twice to match the
              tuple unpacked by callers.
            - ``residual`` is what remains after both stacks, clipped at 0 by ReLU.
            - ``kl_q`` is ``cal_kl`` between the last trend block's and the last
              season block's (mu, log-variance) outputs, layer-normalised by
              ``norm1``.
        """
        forecast_t = torch.zeros_like(x)
        forecast_s = torch.zeros_like(x)
        residual = x
        for layer in self.trend_block:
            fore_t, mu_trend, sigma_trend = layer(residual)
            forecast_t += fore_t
            residual = residual - fore_t

        residual = torch.relu(residual)
        for layer in self.season_block:
            fore_s, mu_season, sigma_season = layer(residual)
            forecast_s += fore_s
            residual = residual - fore_s
        residual = torch.relu(residual)

        # Use the accumulated contributions of *all* blocks. The previous code
        # summed only the last trend and last season output, which dropped every
        # earlier block's share of what was subtracted from the residual.
        feature = forecast_t + forecast_s
        if self.norm:
            feature = self.norm2(feature)

        kl_q = self.cal_kl(mu_trend, sigma_trend, mu_season, sigma_season)
        kl_q = self.norm1(kl_q)
        return feature, residual, kl_q, feature

class decoder(nn.Module):
    """Stack of ``decomp_one`` stages applied one after another.

    Each stage receives the residual left by the previous stage. The stages'
    outputs are summed into the forecast, the mean of each stage's KL term is
    summed, and the stages' features are stacked along a new leading dimension.

    NOTE(review): ``decomp_blocks`` repeats the *same* ``decomp_one`` instance
    ``decomp_layers`` times, so every stage shares one set of parameters. This
    is kept as-is to preserve behaviour and checkpoint keys. Confirm that the
    sharing is intended.
    """

    def __init__(self, input_size, theta_size, middle_size, layers,
                 degree_of_polynomial, harmonics, trend_layers, season_layers, decomp_layers, norm):
        super(decoder, self).__init__()
        self.decomp_block = decomp_one(input_size, theta_size, middle_size, layers,
                                       degree_of_polynomial, harmonics, trend_layers, season_layers, norm)
        self.decomp_blocks = nn.ModuleList([self.decomp_block for _ in range(decomp_layers)])

    def forward(self, x):
        """Apply every stage in order, feeding each one the previous residual.

        Returns:
            forecast: sum of each stage's first output.
            x: the residual returned by the last stage.
            kl_q_sum: sum over stages of ``kl_q.mean()``.
            latent_variables: the stages' features stacked on dim 0, with shape
                ``(len(self.decomp_blocks), *feature.shape)``.

        Raises:
            ValueError: if there are no stages. Previously this case failed
                with an UnboundLocalError.
        """
        forecast = None
        kl_q_sum = None
        features = []
        for decomp_layer in self.decomp_blocks:
            fore, x, kl_q, feature = decomp_layer(x)
            features.append(feature)
            forecast = fore if forecast is None else forecast + fore
            kl_mean = kl_q.mean()
            kl_q_sum = kl_mean if kl_q_sum is None else kl_q_sum + kl_mean
        if not features:
            raise ValueError("decoder has no decomposition layers (decomp_layers must be >= 1)")
        # A single stack instead of repeated cat inside the loop (which was
        # quadratic in the number of stages); the result is identical.
        latent_variables = torch.stack(features, dim=0)
        return forecast, x, kl_q_sum, latent_variables
