import asgl
import time
import argparse
import warnings
import numpy as np
import pandas as pd
import scipy.linalg as linalg
import matplotlib.pyplot as plt
from sklearn.linear_model import Lasso
from sklearn.metrics import mean_absolute_error, mean_squared_error
from statsmodels.tsa.seasonal import STL
warnings.filterwarnings('ignore')

def quantile_regression_error(true, pred):
    """Return the share of predictions that lie at or above the ground truth.

    This is the empirical coverage of ``pred`` used as an upper bound of
    ``true``: the count of ``pred >= true`` divided by ``pred.size``.
    """
    covered = pred >= true
    return np.sum(covered) / pred.size

def positive_mean_absolute_error(true, pred):
    """Mean over-prediction, computed only where ``pred >= true``.

    Samples where the prediction falls below the truth are ignored, so every
    term is non-negative. When no sample qualifies, ``np.mean`` of an empty
    selection yields NaN.
    """
    over = pred >= true
    excess = pred[over] - true[over]
    return np.mean(excess)

def positive_mean_squared_error(true, pred):
    """Mean squared over-prediction, computed only where ``pred >= true``.

    Samples where the prediction falls below the truth are ignored. When no
    sample qualifies, ``np.mean`` of an empty selection yields NaN.
    """
    over = pred >= true
    excess = pred[over] - true[over]
    return np.mean(excess ** 2)

def FFT(t, x, max_k=100, norm=False):
    """Return the periods of the ``max_k`` strongest spectral components of ``x``.

    Parameters
    ----------
    t : array_like
        Sample times. They are assumed to be uniformly spaced; the spacing is
        inferred as ``(max(t) - min(t)) / (len(t) - 1)``.
    x : array_like
        Signal sampled at ``t``.
    max_k : int
        Number of dominant components to return.
    norm : bool
        If False (default), ``x`` is z-scored before the transform. Pass True
        when ``x`` is already normalized.

    Returns
    -------
    numpy.ndarray
        Periods, in the units of ``t``, ordered by decreasing spectral
        magnitude. The zero-frequency (DC) bin maps to an infinite period.
        For an even number of samples the Nyquist bin has period ``2 * spacing``.
    """
    x = np.asarray(x)
    t = np.asarray(t)
    if not norm:
        x = (x - np.mean(x)) / np.std(x)
    n = len(t)
    # n uniformly spaced samples span n - 1 intervals (dividing by n skews every
    # period by a factor (n - 1) / n, which accumulates when extrapolating).
    spacing = (np.max(t) - np.min(t)) / (n - 1) if n > 1 else 1.0
    magnitude = np.abs(np.fft.rfft(x))
    # rfftfreq matches the rfft bin layout; fftfreq would give the Nyquist bin
    # a negative frequency.
    freq = np.fft.rfftfreq(len(x), spacing)
    strongest = np.argsort(magnitude)[::-1][:max_k]
    with np.errstate(divide='ignore'):
        return 1 / freq[strongest]


def basis(t, T_max_k, max_k=100):
    r"""Build a Fourier design matrix from a list of periods.

    For each ``i`` in ``range(max_k)`` two rows are emitted, the sine first:

    .. math ::
        [\sin(2\pi t / T_0), \cos(2\pi t / T_0), \sin(2\pi t / T_1), \cos(2\pi t / T_1), \dots]

    Parameters
    ----------
    t : numpy.ndarray
        Time points at which the basis is evaluated.
    T_max_k : sequence of float
        Periods; at least ``max_k`` entries are indexed.
    max_k : int
        Number of periods to use.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(2 * max_k, len(t))`` for 1-D ``t``.
    """
    rows = []
    for i in range(max_k):
        # The sine and cosine share one phase, so compute it once per period.
        phase = 2 * np.pi * t / T_max_k[i]
        rows.append(np.sin(phase))
        rows.append(np.cos(phase))
    return np.array(rows)


def lasso_identification(x, y, train_rate=0.7, def_alpha=1e-2, it=10000, norm=False):
    """Fit a Lasso on the leading ``train_rate`` share of samples and predict all of them.

    Parameters
    ----------
    x : numpy.ndarray
        Regressors laid out feature-major, shape ``(n_features, n_samples)``
        (samples are columns, as the ``x[:, :train_len]`` split implies).
    y : array_like
        Target, length ``n_samples``.
    train_rate : float
        Leading fraction of samples used for fitting.
    def_alpha : float
        Lasso ``alpha`` (L1 penalty strength).
    it : int
        Lasso ``max_iter``.
    norm : bool
        If False (default), ``y`` is z-scored with training-span statistics
        before fitting, and predictions are mapped back to the original scale.
        If True, ``y`` is used as given and predictions stay on that scale.

    Returns
    -------
    numpy.ndarray
        Predictions for every sample (training span and beyond), length ``n_samples``.
    """
    y = np.asarray(y, dtype=float)
    train_len = int(train_rate * len(y))
    if norm:
        # y is declared already normalized: fit and predict on its own scale,
        # so no de-normalization may be applied afterwards either.
        y_mean, y_std = 0.0, 1.0
    else:
        y_mean = np.mean(y[:train_len])
        y_std = np.std(y[:train_len])
        if y_std == 0:
            # Constant training target: avoid 0/0, keep only the offset.
            y_std = 1.0
        y = (y - y_mean) / y_std
    samples = np.transpose(x)
    lasso = Lasso(alpha=def_alpha, max_iter=it)
    lasso.fit(samples[:train_len], y[:train_len])
    y_pre = lasso.predict(samples)
    return y_pre * y_std + y_mean


def knowledge_fusion(x, x_pre_SI, predict_len, alpha_ft, lambda_ft, *, beta=0.0, gamma=0.5):
    """Blend an observed window ``x`` with a model reconstruction ``x_pre_SI``.

    The output is split into three consecutive segments:

    * head ``[0, int(beta * n))``: fixed blend
      ``alpha_ft * x + (1 - alpha_ft) * x_pre_SI``;
    * middle, up to ``n - predict_len``: the weight on ``x`` ramps from
      ``alpha_ft`` to ``lambda_ft``. The ramp is linear in ``w ** (1 / gamma)``,
      so it is a square-root ramp for the default ``gamma = 0.5``;
    * tail (last ``predict_len`` samples): copied from ``x_pre_SI``.

    Parameters
    ----------
    x, x_pre_SI : array_like
        1-D sequences of equal length ``n``.
    predict_len : int
        Tail length taken verbatim from ``x_pre_SI``; must be in ``[0, n]``.
    alpha_ft, lambda_ft : float
        Start and end weight on ``x`` over the middle segment.
    beta : float, keyword-only
        Fraction of ``n`` used as the fixed-blend head (default 0: no head).
    gamma : float, keyword-only
        Shape exponent of the weight ramp (default 0.5).

    Returns
    -------
    numpy.ndarray
        Fused 1-D array of length ``n`` (floating dtype even for integer ``x``).

    Raises
    ------
    ValueError
        If ``predict_len`` is outside ``[0, n]``.
    """
    x = np.asarray(x)
    x_pre_SI = np.asarray(x_pre_SI)
    n = len(x)
    if not 0 <= predict_len <= n:
        raise ValueError(f"predict_len must be in [0, {n}], got {predict_len}")
    # Explicit end index: a `-predict_len` slice misbehaves for predict_len == 0.
    end = n - predict_len
    sep_beta = min(int(beta * n), end)
    # Floating dtype so the blend is not truncated for integer inputs.
    x_new = np.zeros(n, dtype=np.result_type(x, 1.0))
    x_new[:sep_beta] = alpha_ft * x[:sep_beta] + (1 - alpha_ft) * x_pre_SI[:sep_beta]
    weights = np.linspace(alpha_ft ** (1 / gamma), lambda_ft ** (1 / gamma), end - sep_beta) ** gamma
    x_new[sep_beta:end] = weights * x[sep_beta:end] + (1 - weights) * x_pre_SI[sep_beta:end]
    x_new[end:] = x_pre_SI[end:]
    return x_new

def qr(x, y, train_rate=0.7, tau=0.99, Lambda=None, norm=False):
    """Fit a quantile regression on the leading samples and predict all of them.

    Parameters
    ----------
    x : numpy.ndarray
        Regressors laid out sample-major, shape ``(n_samples, n_features)``
        (samples are rows, as the ``x[:train_len]`` split implies).
    y : array_like
        Target, length ``n_samples``.
    train_rate : float
        Leading fraction of samples used for fitting.
    tau : float
        Quantile level passed to ``asgl.ASGL``.
    Lambda : float or None
        Passed to ``asgl.ASGL`` as ``lambda1``.
    norm : bool
        If False (default), ``y`` is z-scored with training-span statistics
        before fitting, and predictions are mapped back. If True, ``y`` is used
        as given and predictions stay on that scale.

    Returns
    -------
    numpy.ndarray
        Predictions for every sample.
    """
    y = np.asarray(y, dtype=float)
    train_len = int(train_rate * len(y))
    if norm:
        # y is declared already normalized: no de-normalization afterwards.
        y_mean, y_std = 0.0, 1.0
    else:
        y_mean = np.mean(y[:train_len])
        y_std = np.std(y[:train_len])
        if y_std == 0:
            # Constant training target: avoid 0/0, keep only the offset.
            y_std = 1.0
        y = (y - y_mean) / y_std
    model = asgl.ASGL(model="qr", penalization=None, tau=tau, lambda1=Lambda)
    model.fit(x[:train_len], y[:train_len])
    # NOTE(review): predict() is indexed with [0]; presumably it returns one
    # prediction vector per penalty value -- confirm against the asgl version.
    y_pre = model.predict(x)[0]
    return y_pre * y_std + y_mean

def stl(x_new, train_rate_new, b, plot=False, period=12):
    """Decompose ``x_new`` with STL and re-predict its trend and seasonal parts.

    The trend and seasonal components are each fit with
    ``lasso_identification`` on the regressors ``b`` (feature-major,
    ``(n_features, len(x_new))``) over the leading ``train_rate_new`` share.
    The STL residual is added back unchanged.

    Parameters
    ----------
    x_new : array_like
        Series to decompose.
    train_rate_new : float
        Leading fraction of samples used to fit each Lasso.
    b : numpy.ndarray
        Regressors for the Lasso fits.
    plot : bool
        If True, show the original series and its three STL components.
    period : int, keyword
        Seasonal period passed to ``STL`` (default 12).

    Returns
    -------
    pandas.Series
        ``predicted trend + predicted seasonal + STL residual``.
    """
    decomposition = STL(pd.Series(x_new), period=period).fit()
    trend = decomposition.trend
    seasonal = decomposition.seasonal
    residual = decomposition.resid
    trend_pred = lasso_identification(
        b, trend.values, train_rate=train_rate_new, def_alpha=2e-5, it=1000)
    seasonal_pred = lasso_identification(
        b, seasonal.values, train_rate=train_rate_new, def_alpha=2e-5, it=1000)
    x_pre_SI_new = trend_pred + seasonal_pred + residual
    if plot:
        panels = [(x_new, 'Original'), (trend, 'Trend'), (seasonal, 'Seasonal'), (residual, 'Residual')]
        plt.figure(figsize=(12, 8))
        for row, (series, label) in enumerate(panels, start=1):
            plt.subplot(4, 1, row)
            plt.plot(series, label=label)
            plt.legend()
        plt.show()
    return x_pre_SI_new

def calc_g(Q, Y):
    """Least-squares coefficient of ``Y`` on each column of ``Q`` taken alone.

    Entry ``i`` is ``<Y, q_i> / <q_i, q_i>``. The result is a column vector
    of shape ``(Q.shape[1], 1)``.
    """
    ratios = [np.dot(Y.T, col) / np.dot(col.T, col) for col in Q.T]
    return np.array(ratios, dtype=float).reshape(-1, 1)

def calc_err(g, Q, Y):
    """Error-reduction ratio of each column of ``Q``.

    Entry ``i`` is ``g[i]**2 * <q_i, q_i> / <Y, Y>``: the share of ``Y``'s
    energy explained by ``q_i`` with coefficient ``g[i]``. The result has the
    same shape and dtype as ``g``.
    """
    total_energy = np.dot(Y.T, Y)
    ratios = np.full_like(g, np.nan)
    for idx, col in enumerate(Q.T):
        ratios[idx] = g[idx] ** 2 * np.dot(col.T, col) / total_energy
    return ratios
def orthogonalize(P):
    """Return the orthonormal ``Q`` factor of the reduced QR decomposition of ``P``.

    If ``P`` has full column rank, the columns of ``Q`` span its column space.
    """
    return np.linalg.qr(P, mode='reduced')[0]
def remove_B_columns_from_A(A, B):
    """Subtract from every column of ``A`` its projection onto each column of ``B``.

    Projection coefficients use the *original* column of ``A`` (classical
    Gram-Schmidt). The result is therefore the component orthogonal to
    ``span(B)`` only when ``B``'s columns are mutually orthogonal. ``A`` itself
    is not modified; a new array is returned.
    """
    residual = A.copy()
    for j, a_col in enumerate(A.T):
        for b_col in B.T:
            coeff = np.dot(b_col.T, a_col) / np.dot(b_col.T, b_col)
            residual[:, j] -= coeff * b_col
    return residual
def solve_triangular(A, B):
    """Solve ``A @ x = B`` for upper-triangular ``A`` with an implied unit diagonal.

    Only the strict upper triangle of ``A`` is read; its diagonal is treated
    as all ones (``unit_diagonal=True``).
    """
    return linalg.solve_triangular(A, B, lower=False, unit_diagonal=True)
def SOR(Y, P, pcols=None, term_thresh=1.0e-3, max_steps=1000):
    """Select regressors by forward orthogonal regression and fit them.

    At each step, every unselected column of ``P`` is orthogonalized against
    the already-selected columns. The candidate with the largest
    error-reduction ratio (``calc_err``) is then added. Selection stops when
    the unexplained share ``1 - sum(ERR)`` drops to ``term_thresh`` or below,
    when every column is selected, or after ``max_steps`` selections.

    Parameters
    ----------
    Y : numpy.ndarray
        Target with one entry per row of ``P``.
    P : numpy.ndarray
        Candidate regressors, shape ``(n_rows, n_cols)``.
    pcols : optional
        Unused; kept for interface compatibility.
    term_thresh : float
        Stop once the unexplained energy share is at or below this value.
    max_steps : int
        Maximum number of regressors to select.

    Returns
    -------
    full_coef : numpy.ndarray
        Shape ``(n_cols, 1)``; fitted coefficients, zero for unselected columns.
    full_err : numpy.ndarray
        Shape ``(n_cols, 1)``; error-reduction ratio of each selected column,
        zero elsewhere.
    pred_Y : numpy.ndarray
        Shape ``(n_rows, 1)``; fitted values ``P[:, selected] @ coef``.
    """
    n_rows, n_cols = P.shape
    step_limit = min(n_cols, max_steps)
    selected = np.zeros((n_cols,), dtype=bool)
    l_list, g_list, Q_list, err_list = [], [], [], []
    esr = 1.
    while esr > term_thresh and len(l_list) < step_limit:
        unselected_inds = np.flatnonzero(~selected)
        if l_list:
            # Keep only the part of each candidate orthogonal to the chosen columns.
            Q = orthogonalize(P[:, selected])
            candidates = remove_B_columns_from_A(P[:, unselected_inds], Q)
        else:
            candidates = P.copy()
        g = calc_g(candidates, Y)
        err = calc_err(g, candidates, Y)
        best = np.argmax(err)
        l = unselected_inds[best]
        l_list.append(l)
        g_list.append(g[best])
        Q_list.append(candidates[:, best])
        err_list.append(err[best])
        selected[l] = True
        esr = 1 - np.sum(np.array(err_list).flatten())

    if not l_list:
        # Nothing selected (e.g. term_thresh >= 1 or max_steps <= 0).
        return np.zeros((n_cols, 1)), np.zeros((n_cols, 1)), np.zeros((n_rows, 1))

    Q_f = np.column_stack(Q_list)
    g_f = np.array(g_list).reshape(-1, 1)
    err_f = np.array(err_list).reshape(-1, 1)
    n_identified = Q_f.shape[1]
    Pl = P[:, l_list]
    # Unit upper-triangular A with Pl = Q_f @ A (Q_f columns mutually orthogonal):
    # A[r, s] = <q_r, p_s> / <q_r, q_r>. The coefficients then solve A @ coef = g.
    A = np.eye(n_identified)
    for s in range(n_identified):
        for r in range(s):
            A[r, s] = np.dot(Q_f[:, r].T, Pl[:, s]) / np.dot(Q_f[:, r].T, Q_f[:, r])
    coef = solve_triangular(A, g_f)
    pred_Y = np.dot(Pl, coef)
    full_coef = np.zeros((n_cols, 1))
    full_err = np.zeros((n_cols, 1))
    full_coef[l_list] = coef
    full_err[l_list] = err_f
    return full_coef, full_err, pred_Y

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file')
    parser.add_argument('--features', type=str, default='M',
                        help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
    parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
    parser.add_argument('--train_len', type=int, default=96, help='input sequence length')
    parser.add_argument('--pred_len', type=int, default=96, help='prediction sequence length')
    parser.add_argument('--fourier_len', type=int, default=168, help='fourier sequence length')
    parser.add_argument('--tau', type=float, default=0.9, help='the quantile')
    parser.add_argument('--Lambda', type=float, default=0.1,
                        help='the value of the penalty parameter(s) that determine how much shrinkage is done. ')
    parser.add_argument('--max_K', type=int, default=1, help='')
    parser.add_argument('--need_step', type=int, default=1000, help='')
    parser.add_argument('--alpha_ft', type=float, default=0.8, help='')
    parser.add_argument('--lambda_ft', type=float, default=1, help='')
    parser.add_argument('--plot', action='store_true', default=False, help='')
    args = parser.parse_args()
    train_rate = 0.8

    if 'ETTh' in args.data_path:
        df = pd.read_csv('./dataset/' + args.data_path)[:20 * 30 * 24]
    elif 'ETTm' in args.data_path:
        df = pd.read_csv('./dataset/' + args.data_path)[:20 * 30 * 24 * 4]
    else:
        df = pd.read_csv('./dataset/' + args.data_path)
    df['date'] = pd.to_datetime(df['date'])
    df.set_index('date', inplace=True)
    n_t, n_features = df.shape
    is_weekend = df.index.weekday.isin([5, 6]).astype(float)

    t = np.arange(n_t)
    train_len = int(train_rate * n_t)
    # Derive the test length from train_len: int((1 - 0.8) * n_t) can drop a
    # sample because 1 - 0.8 == 0.19999999999999996 in floating point.
    predict_len = n_t - train_len
    test_start = n_t - predict_len

    # Normalize with training statistics only; constant columns keep scale 1.
    mean = df.values[:train_len].mean(axis=0)
    std = df.values[:train_len].std(axis=0)
    std[std == 0] = 1.0
    df = (df - mean) / std

    # Rolling-window settings (identical for every feature).
    data_len = args.pred_len + args.train_len
    train_rate_new = args.train_len / data_len
    iter_times = predict_len - data_len + 1
    max_k_new = (data_len - args.pred_len) // 2
    window_offsets = range(0, iter_times, args.need_step)
    n_windows = len(window_offsets)

    trues, preds = [], []
    for feature_idx in range(n_features):
        x = df.iloc[:, feature_idx].values

        # Stage 1: global Fourier + quantile-regression model.
        fourier_basis = basis(t, FFT(t[:args.fourier_len], x[:args.fourier_len], max_k=args.max_K),
                              max_k=args.max_K)
        # Fit the SOR selection on the training span only, so the test section
        # never informs the coefficients; then extrapolate over all of t.
        sor_coef = SOR(x[:train_len], fourier_basis[:, :train_len].T, pcols=None,
                       term_thresh=1.0e-1, max_steps=100)[0]
        basis_need = (fourier_basis.T @ sor_coef).T
        x_pre_SI = qr(np.vstack((basis_need, is_weekend)).T, x, train_rate=train_rate, tau=args.tau, Lambda=args.Lambda)
        if args.plot:
            plt.figure(figsize=(15, 6))
            plt.plot(t[test_start:], x[test_start:], label='Ground Truth')
            plt.plot(t[test_start:], x_pre_SI[test_start:], label='SI')
            plt.legend()
            plt.show()

        # Stage 2: rolling windows over the test section.
        x_array = np.zeros((n_windows, data_len))
        x_si_array = np.zeros_like(x_array)
        x_pre_SI_new_array = np.zeros_like(x_array)
        t_new_array = np.zeros_like(x_array)
        for row, offset in enumerate(window_offsets):
            lo = test_start + offset
            hi = lo + data_len
            x_array[row, :] = x[lo:hi]
            x_si_array[row, :] = x_pre_SI[lo:hi]

            # The last pred_len samples of x_new come from the stage-1 model only.
            x_new = knowledge_fusion(x_array[row, :], x_si_array[row, :], args.pred_len, alpha_ft=args.alpha_ft,
                                     lambda_ft=args.lambda_ft)
            t_new = t[lo:hi]
            basis_need_new = basis(t_new,
                                   FFT(t_new[:-args.pred_len], x_new[:-args.pred_len], max_k=max_k_new),
                                   max_k=max_k_new)
            basis_need_new = SOR(x_new, basis_need_new.T, pcols=None, term_thresh=1e-1, max_steps=100)[2].T
            # Alternative: replace x_new by its STL reconstruction before the QR fit:
            # x_new = stl(x_new, train_rate_new, b=np.vstack(
            #     (basis_need[:, lo:hi], basis_need_new))).values  # use STL
            x_pre_SI_new = qr(
                np.vstack((basis_need[:, lo:hi], basis_need_new)).T,
                x_new, train_rate=train_rate_new, tau=args.tau, Lambda=args.Lambda)

            x_pre_SI_new_array[row, :] = x_pre_SI_new
            t_new_array[row, :] = t_new
            trues.append(x[hi - args.pred_len:hi])
            preds.append(x_pre_SI_new[-args.pred_len:])

        if args.plot:
            place = 0
            for row in range(place, min(place + 2, n_windows)):
                plt.figure(figsize=(15, 6))
                plt.plot(t_new_array[row, :], x_si_array[row, :], label='SI_ori')
                plt.plot(t_new_array[row, :], x_array[row, :], label='Ground Truth')
                plt.plot(t_new_array[row, :], x_pre_SI_new_array[row, :], label='SI')
                plt.plot(t_new_array[row, -args.pred_len:], x_pre_SI_new_array[row, -args.pred_len:], label='SI_predict')
                plt.legend()
                plt.show()
    trues = np.array(trues)
    preds = np.array(preds)

    pmse = positive_mean_squared_error(trues, preds).round(3)
    pmae = positive_mean_absolute_error(trues, preds).round(3)
    qre = quantile_regression_error(trues, preds).round(3)
    print(pmse, pmae, qre)

