import torch
import torch.nn as nn
import torch.nn.functional as F
import math


def get_sinusoid(N, thetas, ns_s, s_s):
    """Apply cos / sin to ``N * thetas`` for the oscillating head groups.

    Args:
        N: tensor whose leading axis indexes heads laid out as
            ``[ns_s cos | ns_s sin | s_s cos | s_s sin]``. Any trailing shape
            is accepted (e.g. ``H x length`` or ``H x len1 x len2``).
        thetas: 1-D tensor holding one angle per row of ``N``.
        ns_s: number of non-seasonal cos/sin pairs.
        s_s: number of seasonal cos/sin pairs.

    Returns:
        ``(ns_cos, ns_sin, s_cos, s_sin)``: consecutive row slices of
        ``N * thetas`` passed through ``cos`` or ``sin`` as named.

    Raises:
        ValueError: if ``N`` is 0-dimensional (there is no head axis).
    """
    if N.dim() == 0:
        raise ValueError("N must have at least one dimension (the head axis)")
    # Shape thetas as (H, 1, ..., 1) so each head's angle broadcasts over
    # every trailing axis of N, whatever its rank.
    inner = N * thetas.reshape(-1, *([1] * (N.dim() - 1)))

    ns_cos = inner[:ns_s].cos()
    ns_sin = inner[ns_s:2 * ns_s].sin()
    s_cos = inner[2 * ns_s:2 * ns_s + s_s].cos()
    s_sin = inner[2 * ns_s + s_s:].sin()

    return ns_cos, ns_sin, s_cos, s_sin
    
def sinusoid_func(N, ns_thetas, s_thetas, ns_s, s_s, ns_r, s_r, device):
    """Build the per-head sinusoid multiplier for the positional kernel.

    Rows (heads) of ``N`` are laid out as
    ``[ns_r real | ns_s cos | ns_s sin | s_r real | s_s cos | s_s sin]``.
    Real heads get a constant 1; cos/sin heads get cos/sin(theta * N) with
    one theta shared by each cos/sin pair.

    Args:
        N: tensor of shape ``(ns_r + 2*ns_s + s_r + 2*s_s, *rest)``.
        ns_thetas: angles for the non-seasonal pairs (length ``ns_s``).
        s_thetas: angles for the seasonal pairs (length ``s_s``).
        ns_s, s_s: number of non-seasonal / seasonal cos-sin pairs.
        ns_r, s_r: number of non-seasonal / seasonal real heads.
        device: device on which the constant-1 rows are created.

    Returns:
        Tensor with the same row layout and trailing shape ``rest`` as ``N``.
    """
    trailing = tuple(N.shape[1:])
    ns_start = ns_r                      # first non-seasonal cos row
    s_start = ns_r + 2 * ns_s + s_r      # first seasonal cos row

    # Only the oscillating rows need cos/sin; gather them and their angles.
    N_inp = torch.cat((N[ns_start:ns_start + 2 * ns_s], N[s_start:]), dim=0)
    thetas = torch.cat((ns_thetas.repeat(2), s_thetas.repeat(2)), dim=0)
    ns_cos, ns_sin, s_cos, s_sin = get_sinusoid(N_inp, thetas, ns_s, s_s)

    # Re-interleave with constant-1 rows for the real heads; the trailing
    # shape is taken from N so any rank works (no unbound result).
    return torch.cat((torch.ones((ns_r, *trailing), device=device), ns_cos, ns_sin,
                      torch.ones((s_r, *trailing), device=device), s_cos, s_sin), dim=0)

def toeplitz2D(x1, x2, device):
    """Build one Toeplitz matrix per leading row of ``x1`` / ``x2``.

    Entry ``(h, i, j)`` equals ``x2[h, j - i]`` when ``j >= i`` and
    ``x1[h, i - j]`` otherwise, so ``x2`` supplies the first row and
    ``x1`` the first column. The diagonal is taken from ``x2[:, 0]``;
    ``x1[:, 0]`` is ignored.

    Args:
        x1: tensor of shape ``(H, L1)``; first-column values.
        x2: tensor of shape ``(H, L2)``; first-row values.
        device: device on which the index grid is created.

    Returns:
        Tensor of shape ``(H, L1, L2)``.
    """
    n_rows, n_cols = x1.size(-1), x2.size(-1)
    # Negative offsets (below the diagonal) wrap around to the end of `vals`,
    # where the reversed tail of x1 is stored: vals[-k] == x1[:, k].
    vals = torch.cat((x2, x1[:, 1:].flip(-1)), dim=-1)
    rows = torch.arange(n_rows, device=device)
    cols = torch.arange(n_cols, device=device)
    # (L1, L2) grid of j - i, built by broadcasting. This avoids nonzero(),
    # which has a data-dependent output shape and forces a device sync.
    offsets = cols.unsqueeze(0) - rows.unsqueeze(1)
    return vals[:, offsets]


class RNN_SPosEmbed_Func(torch.autograd.Function):
    """Autograd Function building a length x length positional bias per head.

    Heads are stacked as
    ``[ns_r real | ns_s cos | ns_s sin | s_r real | s_s cos | s_s sin]`` with
    ``ns_r = ns_H - 2*ns_s`` and ``s_r = s_H - 2*s_s``. For head h and
    positions (i, j) the forward value is

        base_h ** N_hij * sinusoid_hij,    N_hij = |i - j| / period_h

    where base_h is a lambda (real heads) or a gamma (shared by each cos/sin
    pair), and sinusoid_hij is 1 for real heads and cos/sin(theta * N_hij)
    for paired heads. An entry is zeroed when |i - j| is not a multiple of
    period_h, when N_hij exceeds ``trunc``, or when it lies on the diagonal.
    With ``mask_flag`` only the strictly lower triangle (j < i) is kept.

    ``backward`` returns hand-derived gradients for the six parameter tensors
    and ``None`` for every other forward input.
    """
    @staticmethod
    def forward(ctx, ns_lamvdas, ns_gammas, ns_thetas, s_lamvdas, s_gammas, s_thetas,
                length, ns_H, ns_s, s_H, s_s, period,
                mask_flag=True, trunc=200, device=torch.device('cuda:0')):
        # NOTE(review): `period` is viewed as one value per head and must
        # broadcast against N (ns_H + s_H rows). The module wrappers pass a
        # tensor of exactly that length.

        # prepare power: 1 x L
        power = torch.arange(length, device=device).unsqueeze(0)
        # 1 x L x L, entry (i, j) = |i - j|
        power = toeplitz2D(power, power, device=device)
        # construct N: one copy of the distance matrix per head
        N = power.repeat(ns_H + s_H, 1, 1)

        # period: keep only distances that are whole multiples of the head's
        # period, re-expressed in units of that period (N becomes float here)
        period = period.view(-1, 1, 1).repeat(1, length, length)
        period_mask = (N.remainder(period) != 0)
        N = N.div(period).masked_fill(period_mask, 0)
        
        # trunc_mask: True where N > trunc (those entries are dropped)
        trunc_mask = (N != N.clamp(max=trunc))
        N = N.masked_fill(trunc_mask, 0)    
        
        # get sinusoid: 1 for real heads, cos/sin(theta * N) for paired heads
        sinusoid = sinusoid_func(N, ns_thetas, s_thetas, ns_s, s_s, ns_H-2*ns_s, s_H-2*s_s, device)

        # one base per head (each gamma is repeated for its cos and sin rows)
        base = torch.cat((ns_lamvdas, ns_gammas.repeat(2), s_lamvdas, s_gammas.repeat(2))).view(-1, 1, 1).repeat(1, length, length)
        const = base.pow(N)

        # deal with truncation and seasonal period: masked N were set to 0,
        # so base**0 == 1 there; force those entries to 0 instead
        const = const.masked_fill(trunc_mask, 0.0)
        const = const.masked_fill(period_mask, 0.0)
        
        # elementwise product, with a leading singleton batch axis
        P = torch.einsum('hij, hij -> hij', const, sinusoid).unsqueeze(0)

        # the diagonal (distance 0) is always dropped; mask_flag also drops
        # the upper triangle (j > i)
        if mask_flag:
            P = P.tril(-1)
        else:
            P = P.tril(-1) + P.triu(1)

        ctx.save_for_backward(N, base, sinusoid, P.clone().detach())
        ctx.mask_flag = mask_flag
        ctx.device = device
        ctx.ns_H = ns_H
        ctx.ns_s = ns_s
        ctx.s_H = s_H
        ctx.s_s = s_s
      
        # 1 x H x L x L (batch axis is a singleton)
        return P
    
    @staticmethod
    def backward(ctx, grad_output):
        N, base, sinusoid, P = ctx.saved_tensors

        # number of real (non-oscillating) heads in each group
        ns_r = ctx.ns_H - 2*ctx.ns_s
        s_r = ctx.s_H - 2*ctx.s_s

        # d/d(base) [base**N * sinusoid] = N * base**(N-1) * sinusoid.
        # N-1 is clamped at 0 so pow never sees a negative exponent where
        # N == 0 (for a zero base, 0**-1 = inf and inf * 0 = NaN); the N
        # factor zeroes those entries anyway.
        mN = N.add(-1).clamp(min=0)
        const = base.pow(mN)
        # H x LQ x LK
        grad_base = torch.einsum('hij, hij -> hij', const*N, sinusoid).unsqueeze(0)

        # apply the same triangular mask as forward
        if ctx.mask_flag:
            grad_base = grad_base.tril(-1)
        else:
            grad_base = grad_base.tril(-1) + grad_base.triu(1) 

        # grad_thetas:
        #   d/dtheta [b**N cos(theta N)] = -N * (b**N sin(theta N))
        #   d/dtheta [b**N sin(theta N)] =  N * (b**N cos(theta N))
        # P already holds the masked b**N cos / b**N sin rows, so swap each
        # cos block with its sin block (negating what lands in the cos slot)
        # and multiply by N. Real-head rows keep P; they are not read below.
        newP = P.clone().detach()
        newP[:, ns_r:ns_r+ctx.ns_s,:,:] = P[:, ns_r+ctx.ns_s:ctx.ns_H,:,:]*(-1)
        newP[:, ns_r+ctx.ns_s:ctx.ns_H,:,:] = P[:, ns_r:ns_r+ctx.ns_s,:,:]
        newP[:, ctx.ns_H+s_r:ctx.ns_H+s_r+ctx.s_s,:,:] = P[:, ctx.ns_H+s_r+ctx.s_s:ctx.ns_H+ctx.s_H, :, :]*(-1)
        newP[:, ctx.ns_H+s_r+ctx.s_s:ctx.ns_H+ctx.s_H, :, :] = P[:, ctx.ns_H+s_r:ctx.ns_H+s_r+ctx.s_s,:,:]
        grad_thetas = torch.einsum('bhls, bhls -> bhls', N.unsqueeze(0), newP)

        # chain rule with the incoming gradient, reduced to one value per head
        grad_base = (grad_base * grad_output).sum((0,2,3)) 
        grad_ns_lamvdas = grad_base[:ns_r,]
        # each gamma feeds both its cos row and its sin row: sum the two
        grad_ns_gammas = grad_base[ns_r:ns_r+ctx.ns_s] + grad_base[ns_r+ctx.ns_s:ctx.ns_H]
        grad_s_lamvdas = grad_base[ctx.ns_H: ctx.ns_H+s_r]
        grad_s_gammas = grad_base[ctx.ns_H+ctx.s_H-2*ctx.s_s:ctx.ns_H+ctx.s_H-ctx.s_s] + grad_base[ctx.ns_H+ctx.s_H-ctx.s_s:ctx.ns_H+ctx.s_H]

        # likewise each theta is shared by its cos and sin rows
        grad_thetas = (grad_thetas * grad_output).sum((0,2,3))
        grad_ns_thetas = grad_thetas[ns_r:ns_r+ctx.ns_s] + grad_thetas[ns_r+ctx.ns_s:ctx.ns_H]
        grad_s_thetas = grad_thetas[ctx.ns_H+ctx.s_H-2*ctx.s_s:ctx.ns_H+ctx.s_H-ctx.s_s] + grad_thetas[ctx.ns_H+ctx.s_H-ctx.s_s:ctx.ns_H+ctx.s_H]

        # None for the nine remaining inputs (length .. device, incl. period)
        return grad_ns_lamvdas, grad_ns_gammas, grad_ns_thetas, grad_s_lamvdas, grad_s_gammas, grad_s_thetas, None, None, None, None, None, None, None, None, None


class RNN_SMixPosEmbed_Func(torch.autograd.Function):
    """Autograd Function building a query x key positional bias per head.

    Query positions are ``seq_len - label_len .. seq_len + pred_len - 1`` and
    key positions are ``0 .. seq_len - 1``. The kernel is the same as in
    ``RNN_SPosEmbed_Func``: ``base_h ** N * sinusoid`` with
    ``N = |q - k| / period_h``. An entry is zeroed when the distance is not a
    multiple of period_h or N exceeds ``trunc``. The head layout is
    ``[ns_r real | ns_s cos | ns_s sin | s_r real | s_s cos | s_s sin]``.

    With ``mask_flag`` only keys strictly before the query are kept
    (k < q); otherwise only the k == q entry is dropped.
    ``backward`` returns hand-derived gradients for the six parameter tensors
    and ``None`` for every other forward input.
    """
    @staticmethod
    def forward(ctx, ns_lamvdas, ns_gammas, ns_thetas, s_lamvdas, s_gammas, s_thetas,
                seq_len, label_len, pred_len, ns_H, ns_s, s_H, s_s, period,
                mask_flag=True, trunc=200, device=torch.device('cuda:0')):

        # prepare power: absolute query/key distances
        # query positions
        N1 = torch.arange(seq_len - label_len, seq_len + pred_len, device=device)
        # key positions
        N2 = torch.arange(seq_len, device=device)
        # power.shape: 1 x (label_len + pred_len) x seq_len
        power = (N1[:, None] + (-1)*N2[None, :]).abs().unsqueeze(0)
        # N.shape = H x (label_len + pred_len) x seq_len
        N = power.repeat(ns_H + s_H, 1, 1)

        # period: keep only distances that are whole multiples of the head's
        # period, re-expressed in units of that period (N becomes float here).
        # NOTE(review): `period` must broadcast against N's H rows.
        period = period.view(-1, 1, 1).repeat(1, N1.size(0), N2.size(0))
        period_mask = (N.remainder(period) != 0)
        N = N.div(period).masked_fill(period_mask, 0)

        # trunc mask: True where N > trunc (those entries are dropped)
        trunc_mask = (N != N.clamp(max=trunc))
        N = N.masked_fill(trunc_mask, 0)
        
        # sinusoid: 1 for real heads, cos/sin(theta * N) for paired heads
        sinusoid = sinusoid_func(N, ns_thetas, s_thetas, ns_s, s_s, ns_H-2*ns_s, s_H-2*s_s, device)

        # one base per head (each gamma is repeated for its cos and sin rows)
        base = torch.cat((ns_lamvdas, ns_gammas.repeat(2), s_lamvdas, s_gammas.repeat(2))).view(-1, 1, 1).repeat(1, N1.size(0), N2.size(0))
        const = base.pow(N)
        # deal with truncation and seasonal period: masked N were set to 0,
        # so base**0 == 1 there; force those entries to 0 instead
        const = const.masked_fill(trunc_mask, 0.0)
        const = const.masked_fill(period_mask, 0.0)
        
        # elementwise product, with a leading singleton batch axis
        P = torch.einsum('hij, hij -> hij', const, sinusoid).unsqueeze(0)

        # Row r is query q = seq_len - label_len + r, and column c is key k = c.
        # tril(diag) keeps c - r <= diag, i.e. k <= q - 1.
        # triu(diag + 2) keeps k >= q + 1. Both masks drop k == q.
        diag = seq_len-label_len-1
        if mask_flag:
            P = P.tril(diag)
        else:
            P = P.tril(diag) + P.triu(diag+2)    

        # save for backward
        # shape: H x LQ x LK
        ctx.save_for_backward(N, base, sinusoid, P.clone().detach())
        ctx.mask_flag = mask_flag
        ctx.diag = diag
        ctx.device = device
        ctx.ns_H = ns_H
        ctx.ns_s = ns_s
        ctx.s_H = s_H
        ctx.s_s = s_s
        
        # 1 x H x LQ X LK (batch axis is a singleton)
        return P
    
    @staticmethod
    def backward(ctx, grad_output):
        N, base, sinusoid, P = ctx.saved_tensors

        # number of real (non-oscillating) heads in each group
        ns_r = ctx.ns_H - 2*ctx.ns_s
        s_r = ctx.s_H - 2*ctx.s_s

        # d/d(base) [base**N * sinusoid] = N * base**(N-1) * sinusoid.
        # N-1 is clamped at 0 so pow never sees a negative exponent where
        # N == 0 (for a zero base, 0**-1 = inf and inf * 0 = NaN); the N
        # factor zeroes those entries anyway.
        mN = N.add(-1).clamp(min=0)
        const = base.pow(mN)
        # H x LQ x LK
        grad_base = torch.einsum('hij, hij -> hij', const*N, sinusoid).unsqueeze(0)

        # apply the same triangular mask as forward
        if ctx.mask_flag:
            grad_base = grad_base.tril(ctx.diag)
        else:
            grad_base = grad_base.tril(ctx.diag) + grad_base.triu(ctx.diag+2) 

        # grad_thetas:
        #   d/dtheta [b**N cos(theta N)] = -N * (b**N sin(theta N))
        #   d/dtheta [b**N sin(theta N)] =  N * (b**N cos(theta N))
        # P already holds the masked b**N cos / b**N sin rows, so swap each
        # cos block with its sin block (negating what lands in the cos slot)
        # and multiply by N. Real-head rows keep P; they are not read below.
        newP = P.clone().detach()
        newP[:, ns_r:ns_r+ctx.ns_s,:,:] = P[:, ns_r+ctx.ns_s:ctx.ns_H,:,:]*(-1)
        newP[:, ns_r+ctx.ns_s:ctx.ns_H,:,:] = P[:, ns_r:ns_r+ctx.ns_s,:,:]
        newP[:, ctx.ns_H+s_r:ctx.ns_H+s_r+ctx.s_s,:,:] = P[:, ctx.ns_H+s_r+ctx.s_s:ctx.ns_H+ctx.s_H, :, :]*(-1)
        newP[:, ctx.ns_H+s_r+ctx.s_s:ctx.ns_H+ctx.s_H, :, :] = P[:, ctx.ns_H+s_r:ctx.ns_H+s_r+ctx.s_s,:,:]
        grad_thetas = torch.einsum('bhls, bhls -> bhls', N.unsqueeze(0), newP)

        # chain rule with the incoming gradient, reduced to one value per head
        grad_base = (grad_base * grad_output).sum((0,2,3)) 
        grad_ns_lamvdas = grad_base[:ns_r,]
        # each gamma feeds both its cos row and its sin row: sum the two
        grad_ns_gammas = grad_base[ns_r:ns_r+ctx.ns_s] + grad_base[ns_r+ctx.ns_s:ctx.ns_H]
        grad_s_lamvdas = grad_base[ctx.ns_H: ctx.ns_H+s_r]
        grad_s_gammas = grad_base[ctx.ns_H+ctx.s_H-2*ctx.s_s:ctx.ns_H+ctx.s_H-ctx.s_s] + grad_base[ctx.ns_H+ctx.s_H-ctx.s_s:ctx.ns_H+ctx.s_H]

        # likewise each theta is shared by its cos and sin rows
        grad_thetas = (grad_thetas * grad_output).sum((0,2,3))
        grad_ns_thetas = grad_thetas[ns_r:ns_r+ctx.ns_s] + grad_thetas[ns_r+ctx.ns_s:ctx.ns_H]
        grad_s_thetas = grad_thetas[ctx.ns_H+ctx.s_H-2*ctx.s_s:ctx.ns_H+ctx.s_H-ctx.s_s] + grad_thetas[ctx.ns_H+ctx.s_H-ctx.s_s:ctx.ns_H+ctx.s_H]


        # None for the eleven remaining inputs (seq_len .. device, incl. period)
        return grad_ns_lamvdas, grad_ns_gammas, grad_ns_thetas, grad_s_lamvdas, grad_s_gammas, grad_s_thetas, None, None, None, None, None, None, None, None, None, None, None


class RNN_SEmbed(nn.Module):
    """Shared configuration for the RNN-style seasonal positional embeddings.

    Heads are split into ``ns_H`` non-seasonal heads followed by ``s_H``
    seasonal heads. Within each group the first ``H - 2*s`` heads are real
    (lambda) heads and the remaining ``2*s`` heads form cos / sin pairs.

    Args:
        ns_H: number of non-seasonal heads.
        ns_s: number of non-seasonal cos/sin pairs; ``ns_H - 2*ns_s`` real
            heads remain.
        s_H: number of seasonal heads.
        s_s: number of seasonal cos/sin pairs; ``s_H - 2*s_s`` real heads
            remain.
        period: ``s_H`` periods, one per seasonal head, ordered real heads
            first, then cos heads, then sin heads.
        mask_flag: stored for subclasses to forward to the kernel Function.
        trunc: stored for subclasses; distances larger than ``trunc`` periods
            are dropped by the kernel.
        device: device on which ``period`` is created.

    Raises:
        AssertionError: if the number of periods does not equal ``s_H``.
    """

    def __init__(self, ns_H, ns_s, s_H, s_s, period, mask_flag=True, trunc=200, device=torch.device('cuda:0')):
        super(RNN_SEmbed, self).__init__()

        # real (non-oscillating) heads per group; the rest are cos/sin pairs
        self.ns_r = ns_H - 2*ns_s
        self.s_r = s_H - 2*s_s
        self.ns_s = ns_s
        self.s_s = s_s

        self.ns_H = ns_H
        self.s_H = s_H
        self.H = ns_H + s_H

        self.mask_flag = mask_flag
        self.device = device

        # One period per head. Non-seasonal heads get period 1, followed by
        # the user-supplied seasonal periods as a 1-D tensor.
        seasonal_period = torch.as_tensor(period, device=self.device)
        self.period = torch.cat((torch.ones(self.ns_H, device=self.device), seasonal_period))
        assert self.period.size(0) == self.H, "period size is incompatible with seasonal heads"
        # truncation (in units of each head's period)
        self.trunc = trunc

        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

class RNN_SPosEmbed(RNN_SEmbed):
    """Positional bias over a single sequence.

    ``forward(length)`` returns a ``[1, H, length, length]`` tensor (the
    leading batch axis is a singleton and broadcasts over the batch).

    Learnable parameters are stored unconstrained (``*_etas``, ``*_nus``,
    ``*_thetas``). The decay rates fed to the kernel (``*_lamvdas``,
    ``*_gammas``) are derived from them on every access. Optimizer updates
    therefore reach the next forward pass, and each forward builds a fresh
    autograd graph. They were previously computed once in ``__init__``, which
    froze them and broke the second ``backward()``.
    """

    def __init__(self, *args, **kwargs):
        super(RNN_SPosEmbed, self).__init__(*args, **kwargs)

        # Registration order (seasonal first, then non-seasonal) is unchanged,
        # so parameters() / state_dict ordering matches existing checkpoints.

        # seasonal parameters
        self.s_etas = nn.Parameter(self._init_etas(self.s_r, self.device))
        self.s_nus = nn.Parameter(torch.linspace(1, 2, self.s_s, device=self.device))
        self.s_thetas = nn.Parameter(torch.ones(self.s_s, device=self.device) * (math.pi / 4))

        # non-seasonal parameters
        self.ns_etas = nn.Parameter(self._init_etas(self.ns_r, self.device))
        self.ns_nus = nn.Parameter(torch.linspace(1, 2, self.ns_s, device=self.device))
        self.ns_thetas = nn.Parameter(torch.ones(self.ns_s, device=self.device) * (math.pi / 4))

        self.fn = RNN_SPosEmbed_Func.apply

    @staticmethod
    def _init_etas(r, device):
        """Return initial pre-tanh values for ``r`` real heads.

        For ``r > 1``: ``linspace(-2, -1, ceil(r/2))`` followed by the
        negation of its first ``r // 2`` entries. Otherwise: ``r`` ones.
        """
        if r <= 1:
            return torch.ones(r, device=device)
        half = torch.linspace(-2, -1, (r + 1) // 2, device=device)
        return torch.cat((half, -half[:r // 2]))

    @property
    def ns_lamvdas(self):
        """Non-seasonal real-head decay rates, tanh(ns_etas), in (-1, 1)."""
        return self.tanh(self.ns_etas)

    @property
    def ns_gammas(self):
        """Non-seasonal cos/sin-pair decay rates, sigmoid(ns_nus), in (0, 1)."""
        return self.sigmoid(self.ns_nus)

    @property
    def s_lamvdas(self):
        """Seasonal real-head decay rates, tanh(s_etas), in (-1, 1)."""
        return self.tanh(self.s_etas)

    @property
    def s_gammas(self):
        """Seasonal cos/sin-pair decay rates, sigmoid(s_nus), in (0, 1)."""
        return self.sigmoid(self.s_nus)

    def forward(self, length):
        """Return the ``[1, H, length, length]`` positional bias."""
        return self.fn(self.ns_lamvdas, self.ns_gammas, self.ns_thetas, self.s_lamvdas, self.s_gammas, self.s_thetas,
                length, self.ns_H, self.ns_s, self.s_H, self.s_s, self.period,
                self.mask_flag, self.trunc, self.device)
    

class RNN_SMixPosEmbed(RNN_SEmbed):
    """Positional bias between decoder queries and encoder keys.

    ``forward(seq_len, label_len, pred_len)`` returns a tensor of shape
    ``[1, H, label_len + pred_len, seq_len]`` (leading singleton batch axis).

    Learnable parameters are stored unconstrained (``*_etas``, ``*_nus``,
    ``*_thetas``). The decay rates fed to the kernel (``*_lamvdas``,
    ``*_gammas``) are derived from them on every access. Optimizer updates
    therefore reach the next forward pass, and each forward builds a fresh
    autograd graph. They were previously computed once in ``__init__``, which
    froze them and broke the second ``backward()``.
    """

    def __init__(self, *args, **kwargs):
        super(RNN_SMixPosEmbed, self).__init__(*args, **kwargs)

        # Registration order (seasonal first, then non-seasonal) is unchanged,
        # so parameters() / state_dict ordering matches existing checkpoints.

        # seasonal parameters
        self.s_etas = nn.Parameter(self._init_etas(self.s_r, self.device))
        self.s_nus = nn.Parameter(torch.linspace(1, 2, self.s_s, device=self.device))
        self.s_thetas = nn.Parameter(torch.ones(self.s_s, device=self.device) * (math.pi / 4))

        # non-seasonal parameters
        self.ns_etas = nn.Parameter(self._init_etas(self.ns_r, self.device))
        self.ns_nus = nn.Parameter(torch.linspace(1, 2, self.ns_s, device=self.device))
        self.ns_thetas = nn.Parameter(torch.ones(self.ns_s, device=self.device) * (math.pi / 4))

        self.fn = RNN_SMixPosEmbed_Func.apply

    @staticmethod
    def _init_etas(r, device):
        """Return initial pre-tanh values for ``r`` real heads.

        For ``r > 1``: ``linspace(-2, -1, ceil(r/2))`` followed by the
        negation of its first ``r // 2`` entries. Otherwise: ``r`` ones.
        """
        if r <= 1:
            return torch.ones(r, device=device)
        half = torch.linspace(-2, -1, (r + 1) // 2, device=device)
        return torch.cat((half, -half[:r // 2]))

    @property
    def ns_lamvdas(self):
        """Non-seasonal real-head decay rates, tanh(ns_etas), in (-1, 1)."""
        return self.tanh(self.ns_etas)

    @property
    def ns_gammas(self):
        """Non-seasonal cos/sin-pair decay rates, sigmoid(ns_nus), in (0, 1)."""
        return self.sigmoid(self.ns_nus)

    @property
    def s_lamvdas(self):
        """Seasonal real-head decay rates, tanh(s_etas), in (-1, 1)."""
        return self.tanh(self.s_etas)

    @property
    def s_gammas(self):
        """Seasonal cos/sin-pair decay rates, sigmoid(s_nus), in (0, 1)."""
        return self.sigmoid(self.s_nus)

    def forward(self, seq_len, label_len, pred_len):
        """Return the ``[1, H, label_len + pred_len, seq_len]`` positional bias."""
        return self.fn(self.ns_lamvdas, self.ns_gammas, self.ns_thetas, self.s_lamvdas, self.s_gammas, self.s_thetas,
                seq_len, label_len, pred_len, self.ns_H, self.ns_s, self.s_H, self.s_s, self.period,
                self.mask_flag, self.trunc, self.device)