import warnings
warnings.filterwarnings("ignore")

import pickle
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.preprocessing import MinMaxScaler

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset

import sys 
sys.path.append("../")
import os
from utils.FeatureEngineering import FeatureEngineering

# Run on the GPU when one is available; fall back to CPU so the module can
# still be imported and used on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Fixed seed for reproducible weight initialisation and sampling across
# torch (CPU generator and every CUDA device) and NumPy.
SEED = 2023
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)

class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series.

    The series is edge-padded (first/last time step repeated) with a total of
    ``kernel_size - 1`` steps. With ``stride=1`` the output therefore has the
    same length as the input for both odd and even ``kernel_size``.

    Args:
        kernel_size: averaging window length, in time steps.
        stride: step of the pooling window.
    """
    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        """
        Args:
            x: rank-3 tensor laid out as [Batch, Time, Channel].

        Returns:
            Tensor laid out as [Batch, Time', Channel]; Time' == Time when
            stride == 1.
        """
        # Replicate the edge values rather than zero-padding, so the trend is
        # not pulled towards 0 at the boundaries. The total padding must be
        # kernel_size - 1 to preserve length. For an even window the odd extra
        # step goes at the end; for an odd window both sides get the same pad.
        pad_front = (self.kernel_size - 1) // 2
        pad_end = self.kernel_size - 1 - pad_front
        front = x[:, 0:1, :].repeat(1, pad_front, 1)
        end = x[:, -1:, :].repeat(1, pad_end, 1)
        x = torch.cat([front, x, end], dim=1)
        # AvgPool1d pools over the last dim, so move time there and back.
        x = self.avg(x.permute(0, 2, 1))
        return x.permute(0, 2, 1)


class series_decomp(nn.Module):
    """
    Split a series into a residual (seasonal) part and a moving-average trend.

    The trend comes from ``moving_avg(kernel_size, stride=1)``. ``forward(x)``
    returns the pair ``(x - trend, trend)``.
    """
    def __init__(self, kernel_size):
        super().__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        trend = self.moving_avg(x)
        return x - trend, trend

    
class DLinear(nn.Module):
    """
    Decomposition-Linear forecaster.

    ``series_decomp`` splits the input into a seasonal residual and a
    moving-average trend. Each component is mapped from ``seq_len`` to
    ``pred_len`` time steps by its own ``nn.Linear``, whose weights are shared
    across channels. The two projections are then summed.

    Args:
        in_features: number of input channels; only used to size ``self.out``.
        out_features: number of trailing channels kept in the output.
        args: object exposing ``seq_len`` and ``pred_len`` attributes.
        kernel_size: moving-average window used for the decomposition.
            Defaults to 25, the value previously hard-coded here. Odd sizes
            keep the trend aligned with the input; check ``moving_avg``'s
            padding before using an even size.
    """
    def __init__(self, in_features, out_features, args, kernel_size=25):
        super(DLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.seq_len = args.seq_len
        self.pred_len = args.pred_len

        # Decomposition block (attribute name kept as-is for compatibility).
        self.decompsition = series_decomp(kernel_size)
        # Per-channel ("individual") layers are not implemented. These flags
        # are not read anywhere in this class.
        self.individual = False
        self.channels = None

        # Temporal projections seq_len -> pred_len, applied along the last dim.
        # To visualise the learned weights, both can be initialised to a
        # uniform 1/seq_len average instead of the default init.
        self.Linear_Seasonal = nn.Linear(self.seq_len, self.pred_len)
        self.Linear_Trend = nn.Linear(self.seq_len, self.pred_len)
        # NOTE(review): not used in forward(); kept so the parameter set and
        # state_dict keys stay unchanged.
        self.out = nn.Linear(self.in_features, self.out_features)

    def forward(self, x):
        """
        Args:
            x: tensor laid out as [Batch, seq_len, Channel].

        Returns:
            Tensor laid out as [Batch, pred_len, out_features], holding the
            last ``out_features`` channels.
        """
        seasonal_init, trend_init = self.decompsition(x)
        # nn.Linear acts on the last dim, so move time there:
        # [Batch, Channel, seq_len].
        seasonal_init, trend_init = seasonal_init.permute(0, 2, 1), trend_init.permute(0, 2, 1)

        seasonal_output = self.Linear_Seasonal(seasonal_init)
        trend_output = self.Linear_Trend(trend_init)

        x = seasonal_output + trend_output  # [Batch, Channel, pred_len]
        # Dim 1 is channels at this point: keep only the trailing ones.
        x = x[:, -self.out_features:, :]
        return x.permute(0, 2, 1)  # [Batch, pred_len, out_features]
    
class NLinear(nn.Module):
    """
    Normalization-Linear forecaster.

    Each window is re-expressed relative to its last observed value. It is then
    projected from ``seq_len`` to ``pred_len`` time steps by a single
    ``nn.Linear`` shared across channels, and shifted back by that last value.

    Args:
        in_features: number of input channels; only used to size ``self.out``.
        out_features: number of trailing channels kept in the output.
        args: object exposing ``seq_len`` and ``pred_len`` attributes.
    """
    def __init__(self, in_features, out_features, args):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.seq_len, self.pred_len = args.seq_len, args.pred_len

        # Temporal projection seq_len -> pred_len, applied along the last dim.
        # To visualise the learned weights, it can be initialised to a uniform
        # 1/seq_len average instead of the default init.
        self.Linear = nn.Linear(self.seq_len, self.pred_len)

        # NOTE(review): not used in forward(); kept so the parameter set and
        # state_dict keys stay unchanged.
        self.out = nn.Linear(self.in_features, self.out_features)

    def forward(self, x):
        """
        Args:
            x: tensor laid out as [Batch, seq_len, Channel].

        Returns:
            Tensor laid out as [Batch, pred_len, out_features], holding the
            last ``out_features`` channels.
        """
        # Last observed step per channel, [Batch, 1, Channel]; detached so the
        # normalisation offset does not receive gradients of its own.
        last = x[:, -1:, :].detach()
        centred = (x - last).permute(0, 2, 1)          # [Batch, Channel, seq_len]
        projected = self.Linear(centred).permute(0, 2, 1)  # [Batch, pred_len, Channel]
        keep = slice(-self.out_features, None)
        # Undo the normalisation on the kept channels (broadcast over pred_len).
        return projected[:, :, keep] + last[:, :, keep]