import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
import torch
import os
import random

def set_seed(seed):
    """Seed every random number generator this script relies on.

    Applies *seed* to PyTorch (CPU, the current CUDA device and all CUDA
    devices), NumPy's global generator and the stdlib ``random`` module, then
    switches cuDNN to its deterministic, non-benchmarking mode.

    Args:
        seed: Seed value handed unchanged to each seeding function.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers the multi-GPU case
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
seed = 42  # arbitrary fixed value; change it to obtain a different (still reproducible) run
set_seed(seed)  # seed all RNGs once, as soon as the module is executed


def compute_mse(predicted, true):
    """Return the mean squared error between *predicted* and *true*.

    Both arguments may be NumPy arrays or any array-like (lists, tuples,
    scalars); they are converted with ``np.asarray`` so that plain Python
    sequences no longer fail on the subtraction.

    Args:
        predicted: Predicted values.
        true: Reference values, broadcast-compatible with *predicted*.

    Returns:
        The mean of the squared element-wise differences over all elements.
    """
    # NOTE: the inputs follow NumPy broadcasting rules, so pass arrays of the
    # same shape (e.g. flatten an (N, 1) reference before comparing it with an
    # (N,) prediction) to get an element-wise comparison.
    residual = np.asarray(predicted) - np.asarray(true)
    return np.mean(residual ** 2)

def compute_mape(predicted, true):
    """Return the Mean Absolute Percentage Error (MAPE), in percent.

    The percentage error is undefined wherever the reference value is exactly
    zero; those entries are left out of the average instead of turning the
    whole result into ``inf``/``nan``. Inputs without zero references give the
    same value as ``np.mean(np.abs((predicted - true) / true)) * 100``.

    Args:
        predicted: Predicted values (array or array-like).
        true: Reference values, broadcast-compatible with *predicted*.

    Returns:
        The mean of ``|predicted - true| / |true|`` over the entries with a
        non-zero reference, multiplied by 100; ``nan`` when no such entry
        exists (empty input or all references equal to zero).
    """
    # Broadcast first so the zero mask lines up with both operands.
    predicted, true = np.broadcast_arrays(np.asarray(predicted), np.asarray(true))

    nonzero = true != 0
    if not nonzero.all():
        predicted, true = predicted[nonzero], true[nonzero]

    if true.size == 0:
        return float("nan")
    return np.mean(np.abs((predicted - true) / true)) * 100



def normalize_2d(matrix):
    """Min-max scale *matrix* along axis 0.

    Each column is shifted by its minimum and divided by its range
    (max - min), so a column that varies ends up spanning [0, 1]. A column
    whose values are all equal has a zero range; it is mapped to zeros rather
    than to ``nan`` from a 0/0 division.

    Args:
        matrix: Array or array-like; the statistics are taken over axis 0.

    Returns:
        An array of the same shape as *matrix* holding the scaled values.
    """
    matrix = np.asarray(matrix)
    col_min = np.min(matrix, axis=0)
    col_range = np.max(matrix, axis=0) - col_min
    # Divide constant columns by 1 instead of 0; their numerator is already 0.
    safe_range = np.where(col_range == 0, 1, col_range)
    return (matrix - col_min) / safe_range



#### Dataset loaders -- the first (commented-out) option below is the 200-bus data
from MFN_NODE.Interpolation.get_date_send import *
# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_200bus_9_10()
# data_name = "200bus"

# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_PV_online() # this is PV data
# data_name = "PV"
# # #
# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_load_oncor() # this is load data
# data_name = "load"
#
# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_ari_quality() # this is air-quality data
# data_name = "air"


# Active dataset: the spiral data. The loader returns seven values that are
# unpacked into the module-level names used by the rest of the script.
t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_spiral() # this is spiral data
data_name = "spiral"  # appears in plot labels and in the saved figure file names


model_name = "CubicSpline"  # appears in plot titles and in the saved figure file names
input_size = y_low.shape[1]  # number of columns of y_low
print(input_size)
# Ratio of len(t_low) to len(t_high), expressed as a percentage.
print(f"Rate: {len(t_low)/ len(t_high) * 100} %")

# Cubic-spline baseline: fit one cubic spline per column of y_low on t_low,
# evaluate it on t_high, score the result against y_low_high_res (MSE and MAPE)
# and save a comparison figure. Only data_dim == 1 and data_dim == 2 are handled.
data_dim = input_size

# Every figure below is written under results/; create the directory first so
# plt.savefig does not raise FileNotFoundError when it does not exist yet.
os.makedirs('results', exist_ok=True)

if data_dim == 1:
    # fill_value='extrapolate' allows evaluation at t_high values that fall
    # outside the range covered by t_low.
    cubic_spline = interp1d(t_low, y_low[:, 0], kind='cubic', fill_value='extrapolate')
    cubic_spline_y2 = cubic_spline(t_high)

    print(cubic_spline_y2.shape)
    print(y_low_high_res.ravel().shape)
    ### cubic spline results: -----------------------------------------------------------------
    # The reference is flattened so both metric inputs are 1-D and are compared
    # element-wise (presumably y_low_high_res is (N, 1) here -- TODO confirm).
    cubic_spline_mape = compute_mape(cubic_spline_y2, y_low_high_res.ravel())
    cubic_spline_mse = compute_mse(cubic_spline_y2, y_low_high_res.ravel())

    print(f'Cubic Spline: Low-Res MSE: {cubic_spline_mse}, Low-Res MAPE: {cubic_spline_mape:.8f}%')

    plt.figure(figsize=(10, 5))
    plt.title("original data vs Cubic spline")
    plt.plot(t_high, cubic_spline_y2, label="Cubic spline of y2", c="green", )
    plt.plot(t_high, y_low_high_res, label="y2 in high representation", c="green", linestyle="--")
    plt.scatter(t_low, y_low, c="green", s=50)
    plt.legend()
    plt.savefig(f'results/{model_name}_{data_name}_interpolation.pdf', format='pdf', dpi=300)


elif data_dim == 2:
    # One interpolated column per input column, sampled at t_high.
    cubic_spline_Y = np.zeros(shape = (len(t_high), data_dim))

    for d in range(data_dim):
        cubic_spline = interp1d(t_low, y_low[:, d], kind='cubic', fill_value='extrapolate')
        # cubic_spline_y2 is read again further down in the script; after the
        # loop it holds the interpolation of the last column only.
        cubic_spline_y2 = cubic_spline(t_high)
        cubic_spline_Y[:, d] = cubic_spline_y2

    ### cubic spline results: -----------------------------------------------------------------
    print(cubic_spline_Y.shape)
    print(y_low_high_res.shape)
    cubic_spline_mse = compute_mse(cubic_spline_Y, y_low_high_res)
    cubic_spline_mape = compute_mape(cubic_spline_Y, y_low_high_res)
    print(f'Cubic Spline: Low-Res MSE: {cubic_spline_mse}, Low-Res MAPE: {cubic_spline_mape:.8f}%')

    print("cubic_spline_res", cubic_spline_Y[:, 0].tolist())

    plt.figure(figsize=(10, 5))
    # Reference curves (dashed), observed samples (markers) and the spline
    # (solid); each signal keeps one colour across the three.
    plt.plot(t_high, y_low_high_res[:, 0], label="High-res signal 1 (true)", c="blue", linestyle="--")
    plt.plot(t_high, y_low_high_res[:, 1], label="High-res signal 2 (true)", c="orange", linestyle="--")

    plt.plot(t_low, y_low[:, 0], 'o', label='Low-res signal 1 (observed)', color='blue', alpha=1, )
    plt.plot(t_low, y_low[:, 1], 'o', label='Low-res signal 2 (observed)', color='orange', alpha=1, )

    plt.plot(t_high, cubic_spline_Y[:, 0], label="Interpolated signal 1", c="blue", )
    plt.plot(t_high, cubic_spline_Y[:, 1], label="Interpolated signal 2", c="orange", )

    # The curves come from kind='cubic', so the title names a cubic spline
    # (it previously read "Spine Linear interpolation").
    plt.title("Cubic spline interpolation", fontsize=14)
    plt.xlabel('Time', fontsize=14)
    plt.ylabel('Amplitude', fontsize=14)
    plt.legend(fontsize=12)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    # NOTE(review): unlike the 1-D branch, this file name contains a space
    # before "interpolation"; it is kept unchanged so existing references to
    # the output file keep working.
    plt.savefig(f'results/{model_name}_{data_name} interpolation.pdf', format='pdf', dpi=300)



# Extra figure for the spiral dataset only: the reference values and the
# interpolated values are plotted against y_high rather than against t_high.
# NOTE(review): cubic_spline_y2 is assigned only inside the data_dim == 1 and
# data_dim == 2 branches above, so any other data_dim raises NameError here.
# In the data_dim == 2 case it holds the interpolation of the last column only,
# whereas y_low_high_res is plotted as a whole -- confirm this is intended.
if data_name == "spiral":
    plt.figure(figsize=(6, 5))
    plt.plot(y_high, y_low_high_res, label=f'True {data_name}', linestyle='dashed')
    plt.plot(y_high, cubic_spline_y2, label=f'Interpolated {data_name}', linestyle='dashed')

    plt.title(f'{model_name} Interpolation 2d', fontsize=14)
    plt.xlabel('X', fontsize=14)
    plt.ylabel('Y', fontsize=14)
    plt.legend(fontsize=12)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    # Saved under results/, like the other figures produced by this script.
    plt.savefig(f'results/{model_name}_{data_name}_interpolation_2d.pdf', format='pdf', dpi=300)

# Display every figure created above.
plt.show()







