# Public API of this module: only the RevIN-wrapped DLinear model is exported.
__all__ = ['DLinearR']


import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ..models.DLinear import DLinear
from ..models.RevIN import RevIN

class DLinearR(nn.Module):
    """DLinear model optionally wrapped in a RevIN normalize/denormalize pair.

    Config attributes read here: ``revin`` (enable the wrapper), ``affine``
    (passed to RevIN) and ``enc_in`` (channel count given to RevIN). The whole
    ``configs`` object is also forwarded to ``DLinear``.
    """

    def __init__(self, configs):
        super().__init__()
        self.revin = configs.revin
        self.affine = configs.affine
        self.c_in = configs.enc_in
        # Built and registered before the RevIN layer. This keeps state_dict
        # ordering and parameter-initialisation order stable.
        self.model = DLinear(configs)
        if self.revin:
            # Only created when RevIN is enabled; forward() never touches it
            # otherwise.
            self.revin_layer = RevIN(self.c_in, affine=self.affine)

    def _revin_pass(self, z, mode):
        """Apply ``self.revin_layer`` in *mode* with the last two axes swapped.

        The tensor is permuted with ``(0, 2, 1)``, which requires it to be
        3-D. It is fed to the RevIN layer and then permuted back to the
        incoming layout.
        """
        swapped = self.revin_layer(z.permute(0, 2, 1), mode)
        return swapped.permute(0, 2, 1)

    def forward(self, z):
        """Run the wrapped DLinear model, optionally inside RevIN norm/denorm.

        Args:
            z: input tensor. The original author annotated it as
                [Batch, Channel, Input length]. When RevIN is enabled it must
                be 3-D, because of ``permute(0, 2, 1)``.

        Returns:
            The output of ``self.model``. When RevIN is enabled, that output is
            passed through the RevIN layer in ``'denorm'`` mode, with the last
            two axes swapped around the call. The output shape is determined
            by ``DLinear``, which is not visible here.
        """
        # NOTE(review): the input layout DLinear expects is not visible in this
        # file; confirm it matches the [Batch, Channel, Length] layout assumed
        # by the permutes around RevIN.
        if not self.revin:
            return self.model(z)
        normalized = self._revin_pass(z, 'norm')
        prediction = self.model(normalized)
        return self._revin_pass(prediction, 'denorm')