import os

import numpy as np
from scipy.fftpack import dct, idct
from sklearn.linear_model import Lasso

# Seed NumPy's global RNG for reproducibility. Nothing visible in this script
# draws random numbers directly; presumably the data loaders do (TODO confirm).
np.random.seed(42)


def compute_mse(predicted, true):
    """Return the mean squared error between ``predicted`` and ``true``.

    The difference is taken element-wise with NumPy broadcasting, so the two
    inputs should have matching (or broadcast-compatible) shapes.
    """
    residual = predicted - true
    return np.mean(np.square(residual))

def compute_mape(predicted, true):
    """Return the Mean Absolute Percentage Error (MAPE), in percent.

    Entries where ``true == 0`` are excluded: their percentage error is
    undefined, and keeping them would turn the whole mean into ``inf``/``nan``
    (e.g. for data with exact zeros). Returns ``nan`` if no entry of ``true``
    is non-zero (including empty input).

    ``predicted`` and ``true`` are broadcast against each other element-wise,
    exactly as in a plain ``predicted - true``.
    """
    predicted, true = np.broadcast_arrays(
        np.asarray(predicted, dtype=float), np.asarray(true, dtype=float)
    )
    nonzero = true != 0
    if not np.any(nonzero):
        return float("nan")
    rel_err = np.abs((predicted[nonzero] - true[nonzero]) / true[nonzero])
    return np.mean(rel_err) * 100


def create_measurement_matrix(t_l_norm, t_ori_norm):
    """Build a 0/1 sampling matrix that maps the fine time grid onto the coarse one.

    The result has shape ``(len(t_l_norm), len(t_ori_norm))``. Row ``i`` holds a
    single 1 in the column of the ``t_ori_norm`` entry nearest to ``t_l_norm[i]``
    (ties resolve to the lowest index, as with ``np.argmin``); all other
    entries are 0.
    """
    coarse = np.ravel(t_l_norm)
    fine = np.ravel(t_ori_norm)
    # Pairwise |coarse_i - fine_j|; argmin along the fine axis gives each row's nearest column.
    nearest = np.argmin(np.abs(fine[np.newaxis, :] - coarse[:, np.newaxis]), axis=1)
    sampling = np.zeros((len(t_l_norm), len(t_ori_norm)))
    sampling[np.arange(len(t_l_norm)), nearest] = 1
    return sampling


# Generate high and low-resolution data
from MFN_NODE.Interpolation.get_date_send import *
# Dataset selection: uncomment exactly one loader/name pair below. Every loader
# is unpacked into (t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval).
# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_200bus_9_10()  # 200-bus data
# data_name = "200bus"

# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_PV_online()  # PV data
# data_name = "PV"

# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_load_oncor()  # load data
# data_name = "load"

# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_ari_quality()  # air-quality data
# data_name = "air"

t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_spiral()  # spiral data
data_name = "spiral"

model_name = "CS"

# Number of channels (columns) in the low-resolution observations.
input_size = y_low.shape[1]
print(input_size)

# Fraction of the high-resolution grid that is actually observed, in percent.
sampling_rate_pct = len(t_low) / len(t_high) * 100
print(f"Rate: {sampling_rate_pct} %")



# --- Compressed-sensing reconstruction -----------------------------------------
# Model: the full-resolution signal is x = Psi @ s + b, where Psi is the
# orthonormal DCT synthesis basis (Psi @ s == idct(s, norm='ortho')), s is a
# sparse coefficient vector and b a per-channel constant. Only A @ x is observed
# (A selects the t_high samples nearest to t_low), so s and b are estimated with
# L1-regularised least squares (Lasso).
A = create_measurement_matrix(t_low, t_high)

# Column j of Psi is idct(e_j), so the model basis matches the idct() used for
# reconstruction below. Note: dct(np.eye(N), norm='ortho').T would instead be
# the *forward* DCT matrix, which is inconsistent with that reconstruction.
Psi = idct(np.eye(len(t_high)), axis=0, norm='ortho')

lasso = Lasso(alpha=1e-7, max_iter=10000)  # alpha needs tuning per dataset
lasso.fit(A @ Psi, y_low)

# coef_ is (n_channels, N) for several channels but (N,) for a single one; make
# it 2-D so the reconstruction below is always (N, n_channels).
x_sparse = np.atleast_2d(lasso.coef_)

# Synthesise the signal and add back the fitted intercept. Lasso centres the data
# by default (fit_intercept=True), so omitting intercept_ leaves a constant offset.
cs_res = idct(x_sparse, norm='ortho').T + lasso.intercept_
print(cs_res.shape)


### CS results: ------------------------------------------------------------------------
# Print both shapes; they should match for the element-wise metrics below.
for arr in (cs_res, y_low_high_res):
    print(arr.shape)
cs_mse, cs_mape = (
    metric(cs_res, y_low_high_res) for metric in (compute_mse, compute_mape)
)
print(f'CS: Low-Res MSE: {cs_mse:.4f}, Low-Res MAPE: {cs_mape:.4f}%')


# Optional raw dumps of the first channel (uncomment for debugging / external plotting):
# print("cs_res", cs_res[:, 0].tolist())
# print("truth", y_low_high_res[:, 0].tolist())
# print("t_high", t_high.tolist())
# print("t_low", t_low.tolist())
# print("truth_low", y_low[:, 0].tolist())


# Figures are saved under results/; create it so savefig does not fail on a
# fresh checkout.
os.makedirs('results', exist_ok=True)

for i in range(input_size):
    # One figure per channel: ground truth, observed samples, and reconstruction.
    plt.figure(figsize=(10, 5))
    plt.plot(t_high, y_low_high_res[:, i], label=f'High-res signal {i} (true)', linestyle='dashed')
    plt.plot(t_low, y_low[:, i], 'o', label=f'Low-res signal {i} (observed)', color='blue', alpha=1)
    plt.plot(t_high, cs_res[:, i], label=f'Interpolated signal {i}')

    plt.title(f'{model_name} Interpolation', fontsize=14)
    plt.xlabel('Time', fontsize=14)
    plt.ylabel('Amplitude', fontsize=14)
    plt.legend(fontsize=12)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    # With several channels, suffix the channel index so each figure gets its own
    # file instead of every iteration overwriting the same PDF. A single channel
    # keeps the original filename.
    channel_suffix = f'_{i}' if input_size > 1 else ''
    plt.savefig(f'results/{model_name}_{data_name}_interpolation{channel_suffix}.pdf', format='pdf', dpi=300)


if data_name == "spiral":
    # 2-D view: first channel of y_high on the x-axis against truth / reconstruction.
    plt.figure(figsize=(6, 5))
    plt.plot(y_high[:, 0], y_low_high_res[:, 0], label=f'True {data_name}', linestyle='dashed')
    plt.plot(y_high[:, 0], cs_res[:, 0], label=f'Interpolated {data_name}', linestyle='dashed')

    plt.title(f'{model_name} Interpolation 2d', fontsize=14)
    plt.xlabel('X', fontsize=14)
    plt.ylabel('Y', fontsize=14)
    plt.legend(fontsize=12)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.savefig(f'results/{model_name}_{data_name}_interpolation_2d.pdf', format='pdf', dpi=300)

plt.show()

