import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
import torch
import os
import random

def set_seed(seed):
    """Seed every random generator this script touches and pin cuDNN behaviour.

    Seeds Python's ``random`` module, NumPy's global generator and torch
    (CPU, the current CUDA device, and all CUDA devices), then turns off the
    cuDNN autotuner and requests deterministic cuDNN algorithms.

    Args:
        seed: integer seed passed unchanged to every generator.
    """
    # Pure-Python and NumPy global generators.
    random.seed(seed)
    np.random.seed(seed)

    # torch generators: CPU, current CUDA device, then every CUDA device.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.

    # Disable benchmark-mode kernel selection and ask for deterministic kernels
    # so repeated runs on the same hardware give the same numbers.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed = 42  # You can set any value you want here
set_seed(seed)



# def calculate_mape(y_true, y_pred, axis=None):
#     """
#     Calculate Mean Absolute Percentage Error (MAPE) for multidimensional arrays.
#     Parameters:
#     - y_true: np.ndarray
#         Array representing the true values.
#     - y_pred: np.ndarray
#         Array representing the predicted values.
#     - axis: int or tuple of ints, optional
#         Axis or axes along which the MAPE is computed. Default is None.
#     Returns:
#     - np.ndarray
#         Mean Absolute Percentage Error (MAPE) along the specified axis or axes.
#     """
#     # Avoid division by zero
#     mask = y_true != 0
#     y_true_masked = np.ma.array(y_true, mask=~mask)
#     y_pred_masked = np.ma.array(y_pred, mask=~mask)
#     mape = np.mean(np.abs((y_true_masked - y_pred_masked) / y_true_masked)) * 100
#     if axis is not None:
#         mape = np.mean(mape, axis=axis)
#     return mape
#
# def calculate_mse(y_true, y_pred):
#     y_true = np.array(y_true)
#     y_pred = np.array(y_pred)
#     return np.mean(np.square(y_true - y_pred))



def compute_mse(predicted, true):
    """Return the mean squared error between ``predicted`` and ``true``.

    The arrays are subtracted with normal NumPy broadcasting, and the squared
    residuals are averaged over every element.
    """
    residual = predicted - true
    squared = np.square(residual)
    return np.mean(squared)

# Compute Mean Absolute Percentage Error (MAPE)
def compute_mape(predicted, true):
    """Return the mean absolute percentage error, in percent.

    MAPE = mean(|predicted - true| / |true|) * 100, averaged over all
    elements after broadcasting ``predicted`` against ``true``.

    Elements where ``true == 0`` are excluded from the average. The
    percentage error is undefined there, and including them would turn the
    whole result into ``inf``/``nan``.

    Args:
        predicted: array-like of predictions.
        true: array-like of reference values, broadcast-compatible with
            ``predicted``.

    Returns:
        The MAPE as a NumPy float, or ``nan`` if every reference value is 0.
    """
    predicted, true = np.broadcast_arrays(
        np.asarray(predicted, dtype=float), np.asarray(true, dtype=float)
    )
    nonzero = true != 0
    if not np.any(nonzero):
        # Nothing to average: every percentage error would be a division by zero.
        return float("nan")
    return np.mean(np.abs((predicted[nonzero] - true[nonzero]) / true[nonzero])) * 100



def normalize_2d(matrix):
    """Min-max scale ``matrix`` along axis 0 into the range [0, 1].

    Min and max are taken with ``axis=0``, so for a 2-D input each column is
    scaled independently.

    Args:
        matrix: array-like numeric data.

    Returns:
        np.ndarray of the same shape. Columns whose max equals their min
        (constant columns) map to 0 instead of NaN.
    """
    matrix = np.asarray(matrix)
    col_min = np.min(matrix, axis=0)
    col_range = np.max(matrix, axis=0) - col_min
    # A constant column has zero range. Dividing by 1 there gives 0 instead of 0/0 = NaN.
    safe_range = np.where(col_range == 0, 1, col_range)
    return (matrix - col_min) / safe_range



#### this is 200 bus data
from MFN_NODE.Interpolation.get_date_send import *
# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_200bus_9_10()
# data_name = "200bus"

# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_PV_online() # this is PV data
# data_name = "PV"

# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_load_oncor() # this is load data
# data_name = "load"
#
# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_ari_quality() # this is air-quality data
# data_name = "air"


t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_spiral() # this is spiral data
data_name = "spiral"

model_name = "LinearSpline"

# Every plt.savefig() below writes into results/. Create the directory up
# front so a fresh checkout does not fail with FileNotFoundError at the end.
os.makedirs('results', exist_ok=True)


input_size = y_low.shape[1]
print(input_size)
print(f"Rate: {len(t_low)/ len(t_high) * 100} %")

print(t_low.shape)
print(y_low.shape)


data_dim = input_size
if data_dim == 1:
    linear_spline = interp1d(t_low, y_low[:, 0], kind='linear', fill_value='extrapolate')
    linear_spline_y2 = linear_spline(t_high)
    # Column view of the interpolation, so later code can index channel 0
    # the same way for any data_dim.
    linear_spline_Y = linear_spline_y2[:, np.newaxis]

    ### linear spline MSE/MAPE results: -----------------------------------------------------------------

    print(linear_spline_y2.shape)
    print(y_low_high_res.ravel().shape)

    linear_spline_mape = compute_mape(linear_spline_y2, y_low_high_res.ravel())
    linear_spline_mse = compute_mse(linear_spline_y2, y_low_high_res.ravel())
    print(f'Linear Spline Low-Res MSE: {linear_spline_mse:.6f}, Low-Res MAPE: {linear_spline_mape:.8f}%')


    plt.figure(figsize=(10, 5))
    plt.title("original data vs linear spline")
    plt.plot(t_high, linear_spline_y2, label="linear spline of y2", c="green", )
    plt.plot(t_high, y_low_high_res, label="y2 in high representation", c="green", linestyle="--")
    plt.scatter(t_low, y_low, c="green", s=50)
    plt.legend()
    plt.savefig(f'results/{model_name}_{data_name} interpolation.pdf', format='pdf', dpi=300)
    plt.show()


else:
    # Multi-channel data: interpolate each channel independently onto t_high.
    # Works for any number of channels, not only 2.
    linear_spline_Y = np.zeros(shape=(len(t_high), data_dim))
    for d in range(data_dim):
        linear_spline = interp1d(t_low, y_low[:, d], kind='linear', fill_value='extrapolate')
        linear_spline_y2 = linear_spline(t_high)
        linear_spline_Y[:, d] = linear_spline_y2


    ### linear spline MSE/MAPE results: -----------------------------------------------------------------
    print(linear_spline_Y.shape)
    print(y_low_high_res.shape)
    linear_spline_mse = compute_mse(linear_spline_Y, y_low_high_res)
    linear_spline_mape = compute_mape(linear_spline_Y, y_low_high_res)
    print(f'Linear Spline Low-Res MSE: {linear_spline_mse:.6f}, Low-Res MAPE: {linear_spline_mape:.8f}%')
    print("linear_spline_res", linear_spline_Y[:, 0].tolist())

    # Channels 0 and 1 keep their blue/orange colours. Extra channels use
    # matplotlib's default colour cycle.
    signal_colors = ['blue', 'orange'] + [f"C{d}" for d in range(2, data_dim)]

    plt.figure(figsize=(10, 5))
    # Draw all truth curves, then all observations, then all interpolations,
    # so the legend groups entries by kind.
    for d in range(data_dim):
        plt.plot(t_high, y_low_high_res[:, d], label=f"High-res signal {d + 1} (true)", c=signal_colors[d], linestyle="--")
    for d in range(data_dim):
        plt.plot(t_low, y_low[:, d], 'o', label=f'Low-res signal {d + 1} (observed)', color=signal_colors[d], alpha=1)
    for d in range(data_dim):
        plt.plot(t_high, linear_spline_Y[:, d], label=f"Interpolated signal {d + 1}", c=signal_colors[d])

    plt.title("Spline Linear interpolation", fontsize=14)
    plt.xlabel('Time', fontsize=14)
    plt.ylabel('Amplitude', fontsize=14)
    plt.legend(fontsize=12)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.savefig(f'results/{model_name}_{data_name} interpolation.pdf', format='pdf', dpi=300)




if data_name == "spiral":
    plt.figure(figsize=(6, 5))
    # Plot channel 0 for both curves. Use linear_spline_Y[:, 0], not
    # linear_spline_y2: when data_dim > 1, linear_spline_y2 holds the LAST
    # channel left over from the per-channel loop.
    plt.plot(y_high[:, 0], y_low_high_res[:, 0], label=f'True {data_name}', linestyle='dashed')
    plt.plot(y_high[:, 0], linear_spline_Y[:, 0], label=f'Interpolated {data_name}', linestyle='dashed')

    plt.title(f'{model_name} Interpolation 2d', fontsize=14)
    plt.xlabel('X', fontsize=14)
    plt.ylabel('Y', fontsize=14)
    plt.legend(fontsize=12)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.savefig(f'results/{model_name}_{data_name}_interpolation_2d.pdf', format='pdf', dpi=300)

plt.show()

