import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import pmdarima as pm
import threading
import pandas as pd
from sklearn.ensemble import GradientBoostingRegressor
from statsmodels.tsa.holtwinters import ExponentialSmoothing
from statsmodels.tsa.api import VAR
# from fbprophet import Prophet
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C

class Naive_repeat(nn.Module):
    """Baseline forecaster that repeats the last observed time step."""

    def __init__(self, configs):
        super().__init__()
        # Number of future steps to emit per input window.
        self.test_pred_len = configs.test_pred_len

    def forward(self, x):
        """Tile the final time step of ``x`` along the time axis.

        Args:
            x: 3-D tensor, unpacked here as (B, L, D).

        Returns:
            Tensor of shape (B, test_pred_len, D) in which every step equals
            ``x[:, -1, :]``.
        """
        batch, _, channels = x.shape
        last_step = x[:, -1, :].reshape(batch, 1, channels)
        return last_step.repeat(1, self.test_pred_len, 1)

class Naive_thread(threading.Thread):
    """Thread that runs ``func(*args)`` and hands the outcome back to the caller.

    The return value of ``func`` is stored on the thread and retrieved with
    :meth:`return_result`. If ``func`` raises, the exception is stored as well
    and re-raised from :meth:`return_result`, so the caller sees the real
    failure instead of an ``AttributeError`` about a missing ``results``
    attribute.
    """

    def __init__(self, func, args=()):
        """
        Args:
            func: Callable executed in the thread.
            args: Positional arguments passed to ``func``.
        """
        super(Naive_thread, self).__init__()
        self.func = func
        self.args = args
        # Always defined, so reading it never depends on run() having succeeded.
        self.results = None
        self._exception = None

    def run(self):
        """Execute ``func(*args)``, recording either its result or its exception."""
        try:
            self.results = self.func(*self.args)
        except Exception as exc:
            # Defer the failure to return_result(), where the caller can handle it.
            self._exception = exc

    def return_result(self):
        """Wait for the thread to finish and return what ``func`` returned.

        Raises:
            Exception: Whatever ``func`` raised inside the thread.
        """
        self.join()
        if self._exception is not None:
            raise self._exception
        return self.results

def _arima(seq,test_pred_len,bt,i):
    """Fit an auto-selected ARIMA model to one series and forecast it.

    Args:
        seq: Univariate history handed to ``pm.auto_arima`` unchanged.
        test_pred_len: Number of future steps to predict.
        bt: Batch index, echoed back so the caller can place the forecast.
        i: Channel index, echoed back for the same reason.

    Returns:
        Tuple ``(forecasts, bt, i)``.
    """
    fitted = pm.auto_arima(seq)
    return fitted.predict(test_pred_len), bt, i

class Arima(nn.Module):
    """
    Per-channel auto-ARIMA baseline.

    Extremely slow, please sample < 0.1
    """

    def __init__(self, configs):
        super().__init__()
        # Forecast horizon, in time steps.
        self.test_pred_len = configs.test_pred_len

    def forward(self, x):
        """Forecast every channel of every window with its own ARIMA fit.

        One thread is started per (batch, channel) pair; results are collected
        once all threads have been launched.

        Args:
            x: Batch indexed as ``x[b][:, channel]``; ``x.shape[0]`` and
                ``x.shape[2]`` are read as batch size and channel count.

        Returns:
            NumPy array of shape [B, test_pred_len, D].
        """
        n_batch, n_channels = x.shape[0], x.shape[2]
        output = np.zeros((n_batch, self.test_pred_len, n_channels))

        workers = []
        for batch_idx, window in tqdm(enumerate(x)):
            for channel in range(window.shape[-1]):
                worker = Naive_thread(
                    func=_arima,
                    args=(window[:, channel], self.test_pred_len, batch_idx, channel),
                )
                workers.append(worker)
                worker.start()

        # Each worker reports which slot its forecast belongs in.
        for worker in tqdm(workers):
            forecast, batch_idx, channel = worker.return_result()
            output[batch_idx, :, channel] = forecast

        return output  # [B, L, D]

def _sarima(season,seq,test_pred_len,bt,i):
    """Fit an auto-selected seasonal ARIMA model to one series and forecast it.

    Args:
        season: Seasonal period, passed to ``pm.auto_arima`` as ``m``.
        seq: Univariate history handed to ``pm.auto_arima`` unchanged.
        test_pred_len: Number of future steps to predict.
        bt: Batch index, echoed back so the caller can place the forecast.
        i: Channel index, echoed back for the same reason.

    Returns:
        Tuple ``(forecasts, bt, i)``.
    """
    fitted = pm.auto_arima(seq, m=season, seasonal=True)
    return fitted.predict(test_pred_len), bt, i

class SArima(nn.Module):
    """
    Per-channel seasonal auto-ARIMA baseline.

    Extremely extremely slow, please sample < 0.01
    """

    def __init__(self, configs):
        super().__init__()
        self.test_pred_len = configs.test_pred_len
        self.seq_len = configs.seq_len

        # Seasonal period is picked from the dataset path: 12 for 'Ettm',
        # 1 for 'ILI', 24 otherwise.
        # NOTE(review): these substring checks are case-sensitive, so a path
        # such as 'ETTm1.csv' falls through to 24 -- confirm that is intended.
        data_path = configs.data_path
        if 'Ettm' in data_path:
            season = 12
        elif 'ILI' in data_path:
            season = 1
        else:
            season = 24

        # A period that is not shorter than the input window is replaced by 1.
        self.season = 1 if season >= self.seq_len else season

    def forward(self, x):
        """Forecast every channel of every window with its own SARIMA fit.

        One thread is started per (batch, channel) pair; results are collected
        once all threads have been launched.

        Args:
            x: Batch indexed as ``x[b][:, channel]``; ``x.shape[0]`` and
                ``x.shape[2]`` are read as batch size and channel count.

        Returns:
            NumPy array of shape [B, test_pred_len, D].
        """
        n_batch, n_channels = x.shape[0], x.shape[2]
        output = np.zeros((n_batch, self.test_pred_len, n_channels))

        workers = []
        for batch_idx, window in tqdm(enumerate(x)):
            for channel in range(window.shape[-1]):
                worker = Naive_thread(
                    func=_sarima,
                    args=(self.season, window[:, channel], self.test_pred_len, batch_idx, channel),
                )
                workers.append(worker)
                worker.start()

        # Each worker reports which slot its forecast belongs in.
        for worker in tqdm(workers):
            forecast, batch_idx, channel = worker.return_result()
            output[batch_idx, :, channel] = forecast
        return output  # [B, L, D]

def _gbrt(seq,seq_len,test_pred_len,bt,i):
    """Fit a gradient-boosted regressor on the time index and forecast ahead.

    The model regresses the series values on their positions ``0..seq_len-1``
    and is then evaluated at positions ``seq_len..seq_len+test_pred_len-1``.

    Args:
        seq: Univariate history with ``seq_len`` values; a torch tensor (any
            device, with or without grad) or anything ``np.asarray`` accepts.
        seq_len: Length of the history.
        test_pred_len: Number of future steps to predict.
        bt: Batch index, echoed back so the caller can place the forecast.
        i: Channel index, echoed back for the same reason.

    Returns:
        Tuple ``(forecasts, bt, i)``.
    """
    # Detach and move to host memory first, matching the other helpers in this
    # file; np.asarray alone fails for tensors on an accelerator or with grad.
    if torch.is_tensor(seq):
        seq = seq.detach().cpu().numpy()
    # scikit-learn regressors expect a 1-D target; an (n, 1) column is flattened
    # with a DataConversionWarning on every single fit.
    target = np.asarray(seq).reshape(-1)

    history_steps = np.arange(seq_len).reshape(-1, 1)
    future_steps = np.arange(seq_len, seq_len + test_pred_len).reshape(-1, 1)

    model = GradientBoostingRegressor()
    model.fit(history_steps, target)
    forecasts = model.predict(future_steps)
    return forecasts,bt,i

class GBRT(nn.Module):
    """Per-channel gradient-boosted regression-tree baseline."""

    def __init__(self, configs):
        super().__init__()
        # History length and forecast horizon, in time steps.
        self.seq_len = configs.seq_len
        self.test_pred_len = configs.test_pred_len

    def forward(self, x):
        """Forecast every channel of every window with its own GBRT fit.

        One thread is started per (batch, channel) pair; results are collected
        once all threads have been launched.

        Args:
            x: Batch indexed as ``x[b][:, channel]``; ``x.shape[0]`` and
                ``x.shape[2]`` are read as batch size and channel count.

        Returns:
            NumPy array of shape [B, test_pred_len, D].
        """
        n_batch, n_channels = x.shape[0], x.shape[2]
        output = np.zeros((n_batch, self.test_pred_len, n_channels))

        workers = []
        for batch_idx, window in tqdm(enumerate(x)):
            for channel in range(window.shape[-1]):
                worker = Naive_thread(
                    func=_gbrt,
                    args=(window[:, channel], self.seq_len, self.test_pred_len, batch_idx, channel),
                )
                workers.append(worker)
                worker.start()

        # Each worker reports which slot its forecast belongs in.
        for worker in tqdm(workers):
            forecast, batch_idx, channel = worker.return_result()
            output[batch_idx, :, channel] = forecast
        return output  # [B, L, D]
    
    
# Exponential Smoothing Model
def _ets(seq, test_pred_len, bt, i):
    """Fit an additive-trend, additive-seasonal exponential smoothing model.

    Args:
        seq: Univariate history as a torch tensor; it is detached and copied
            to a NumPy array on the CPU before fitting.
        test_pred_len: Number of future steps to predict.
        bt: Batch index, echoed back so the caller can place the forecast.
        i: Channel index, echoed back for the same reason.

    Returns:
        Tuple ``(forecasts, bt, i)``.
    """
    history = seq.detach().cpu().numpy()
    # The seasonal period is fixed at 12 steps.
    fitted = ExponentialSmoothing(
        history,
        trend='add',
        seasonal='add',
        seasonal_periods=12,
    ).fit()
    return fitted.forecast(test_pred_len), bt, i

class ETS(nn.Module):
    """Per-channel exponential smoothing baseline."""

    def __init__(self, configs):
        super().__init__()
        # Forecast horizon, in time steps.
        self.test_pred_len = configs.test_pred_len

    def forward(self, x):
        """Forecast every channel of every window with its own ETS fit.

        One thread is started per (batch, channel) pair; results are collected
        once all threads have been launched.

        Args:
            x: Batch indexed as ``x[b][:, channel]``; ``x.shape[0]`` and
                ``x.shape[2]`` are read as batch size and channel count.

        Returns:
            NumPy array of shape [B, test_pred_len, D].
        """
        n_batch, n_channels = x.shape[0], x.shape[2]
        output = np.zeros((n_batch, self.test_pred_len, n_channels))

        workers = []
        for batch_idx, window in tqdm(enumerate(x)):
            for channel in range(window.shape[-1]):
                worker = Naive_thread(
                    func=_ets,
                    args=(window[:, channel], self.test_pred_len, batch_idx, channel),
                )
                workers.append(worker)
                worker.start()

        # Each worker reports which slot its forecast belongs in.
        for worker in tqdm(workers):
            forecast, batch_idx, channel = worker.return_result()
            output[batch_idx, :, channel] = forecast
        return output  # [B, L, D]

# VAR Model
def _var(seqs, test_pred_len, bt, i):
    """Fit a vector autoregression on one multivariate window and forecast it.

    Args:
        seqs: 2-D torch tensor indexed as [time, channel]; it is detached and
            copied to a NumPy array on the CPU before fitting.
        test_pred_len: Number of future steps to predict.
        bt: Batch index, echoed back so the caller can place the forecast.
        i: Channel index, echoed back unchanged; not used by the fit.

    Returns:
        Tuple ``(forecast, bt, i)`` where ``forecast`` covers all channels.
    """
    # Convert to NumPy and check whether any channel never changes.
    history = seqs.detach().cpu().numpy()
    has_constant_column = any(
        np.all(history[:, col] == history[0, col])
        for col in range(history.shape[1])
    )

    # Pick the trend: none when a constant channel exists, the library default
    # otherwise. NOTE(review): presumably a constant channel clashes with the
    # default constant term -- confirm against statsmodels.
    fit_kwargs = {'trend': 'n'} if has_constant_column else {}
    fitted = VAR(history).fit(**fit_kwargs)

    # Seed the forecast with the last five observations.
    forecast = fitted.forecast(history[-5:], steps=test_pred_len)
    return forecast, bt, i


class VARModel(nn.Module):
    """Vector autoregression baseline fitted jointly over all channels."""

    def __init__(self, configs):
        super(VARModel, self).__init__()
        # History length and forecast horizon, in time steps.
        self.seq_len = configs.seq_len
        self.test_pred_len = configs.test_pred_len

    def forward(self, x):
        """Forecast each window with a single multivariate VAR fit.

        ``_var`` fits on the whole window and returns forecasts for every
        channel at once, so one thread per window is enough. Starting one
        thread per channel would refit the identical model for each channel
        and keep a single column of each result.

        Args:
            x: Batch indexed as ``x[b]`` -> [time, channel]; ``x.shape[0]`` and
                ``x.shape[2]`` are read as batch size and channel count.

        Returns:
            NumPy array of shape [B, test_pred_len, D].
        """
        result = np.zeros([x.shape[0], self.test_pred_len, x.shape[2]])
        threads = []
        for bt, seqs in tqdm(enumerate(x)):
            # The channel argument is only echoed back by _var; pass a placeholder.
            one_window = Naive_thread(func=_var, args=(seqs, self.test_pred_len, bt, None))
            threads.append(one_window)
            one_window.start()
        for every_thread in tqdm(threads):
            forecast, bt, _ = every_thread.return_result()
            # forecast is [test_pred_len, D]; copy every channel in one go.
            result[bt, :, :] = forecast
        return result # [B, L, D]

# # Prophet Model
# def _prophet(seq, test_pred_len, bt, i):
#     df = pd.DataFrame({'ds': pd.date_range(start="2021-01-01", periods=len(seq), freq='D'), 'y': seq})
#     model = Prophet(yearly_seasonality=True, weekly_seasonality=True)
#     model.fit(df)
#     future = model.make_future_dataframe(df, periods=test_pred_len)
#     forecast = model.predict(future)
#     return forecast['yhat'][-test_pred_len:].values, bt, i

# class ProphetModel(nn.Module):
#     def __init__(self, configs):
#         super(ProphetModel, self).__init__()
#         self.test_pred_len = configs.test_pred_len
    
#     def forward(self, x):
#         result = np.zeros([x.shape[0], self.test_pred_len, x.shape[2]])
#         threads = []
#         for bt, seqs in tqdm(enumerate(x)):
#             for i in range(seqs.shape[-1]):
#                 seq = seqs[:, i]
#                 one_seq = Naive_thread(func=_prophet, args=(seq, self.test_pred_len, bt, i))
#                 threads.append(one_seq)
#                 threads[-1].start()
#         for every_thread in tqdm(threads):
#             forcast, bt, i = every_thread.return_result()
#             result[bt, :, i] = forcast
#         return result # [B, L, D]

# Gaussian Process Model
def _gpr(seq, seq_len, test_pred_len, bt, i):
    """Fit a Gaussian process on the time index and forecast ahead.

    The model regresses the series values on their positions ``0..seq_len-1``
    with a constant-times-RBF kernel and is then evaluated at positions
    ``seq_len..seq_len+test_pred_len-1``.

    Args:
        seq: Univariate history with ``seq_len`` values; a torch tensor (it is
            detached and moved to the CPU) or anything ``np.asarray`` accepts.
        seq_len: Length of the history.
        test_pred_len: Number of future steps to predict.
        bt: Batch index, echoed back so the caller can place the forecast.
        i: Channel index, echoed back for the same reason.

    Returns:
        Tuple ``(forecast, bt, i)`` with ``forecast`` a 1-D array.
    """
    if torch.is_tensor(seq):
        seq = seq.detach().cpu().numpy()
    # Fit on a 1-D target so predict() returns a 1-D mean. With an (n, 1)
    # target some scikit-learn versions return an (n, 1) array, which cannot be
    # assigned into the caller's 1-D result slot.
    target = np.asarray(seq).reshape(-1)

    kernel = C(1.0, (1e-4, 1e1)) * RBF(1.0, (1e-4, 1e1))
    model = GaussianProcessRegressor(kernel=kernel)
    model.fit(np.arange(seq_len).reshape(-1, 1), target)
    forecast = model.predict(np.arange(seq_len, seq_len + test_pred_len).reshape(-1, 1))
    return forecast, bt, i

class GPRModel(nn.Module):
    """Per-channel Gaussian process regression baseline."""

    def __init__(self, configs):
        super().__init__()
        # History length and forecast horizon, in time steps.
        self.seq_len = configs.seq_len
        self.test_pred_len = configs.test_pred_len

    def forward(self, x):
        """Forecast every channel of every window with its own GPR fit.

        One thread is started per (batch, channel) pair; results are collected
        once all threads have been launched.

        Args:
            x: Batch indexed as ``x[b][:, channel]``; ``x.shape[0]`` and
                ``x.shape[2]`` are read as batch size and channel count.

        Returns:
            NumPy array of shape [B, test_pred_len, D].
        """
        n_batch, n_channels = x.shape[0], x.shape[2]
        output = np.zeros((n_batch, self.test_pred_len, n_channels))

        workers = []
        for batch_idx, window in tqdm(enumerate(x)):
            for channel in range(window.shape[-1]):
                worker = Naive_thread(
                    func=_gpr,
                    args=(window[:, channel], self.seq_len, self.test_pred_len, batch_idx, channel),
                )
                workers.append(worker)
                worker.start()

        # Each worker reports which slot its forecast belongs in.
        for worker in tqdm(workers):
            forecast, batch_idx, channel = worker.return_result()
            output[batch_idx, :, channel] = forecast
        return output  # [B, L, D]