
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

def set_seed(seed):
    """Seed every random number generator used by this script.

    The same value is handed to PyTorch (CPU, current CUDA device and all
    CUDA devices), NumPy and Python's ``random`` module. cuDNN benchmarking
    is then switched off and its deterministic flag switched on.

    Args:
        seed: Value passed unchanged to each seeding function.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers the multi-GPU case
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


# Any integer works here; it is fixed so that repeated runs match.
seed = 42
set_seed(seed)


def compute_mse(predicted, true):
    """Return the mean squared error between ``predicted`` and ``true``.

    The element-wise differences are squared and averaged over every
    element with ``np.mean``.
    """
    residual = predicted - true
    squared_error = residual ** 2
    return np.mean(squared_error)

# Compute Mean Absolute Percentage Error (MAPE)
def compute_mape(predicted, true):
    """Return the mean absolute percentage error of ``predicted``, in percent.

    Each error is divided by the matching entry of ``true`` before the
    absolute value is averaged. The division is not guarded, so zeros in
    ``true`` are the caller's responsibility.
    """
    relative_error = (predicted - true) / true
    mean_abs_ratio = np.mean(np.abs(relative_error))
    return mean_abs_ratio * 100


class FilterLayer(nn.Module):
    """Learnable sine/cosine filter pair driven by the input ``t``.

    Holds one parameter, ``weights``, of shape ``(1, intermediate_channels)``
    drawn from a standard normal distribution. ``forward`` multiplies it with
    ``t`` (relying on broadcasting) and returns the sine and cosine of that
    product.
    """

    def __init__(self, intermediate_channels):
        super().__init__()
        initial_weights = torch.randn(1, intermediate_channels)
        self.weights = nn.Parameter(initial_weights)

    def forward(self, t):
        """Return ``(sin(weights * t), cos(weights * t))``."""
        w = self.weights
        # The product is formed separately for each output, as before, so the
        # autograd graph keeps the same structure.
        return torch.sin(w * t), torch.cos(w * t)

class MFN(nn.Module):
    """Stack of sinusoid-modulated linear layers with two output heads.

    The input ``t`` is lifted to ``intermediate_channels`` features by
    ``initial_layer``. Each of the ``num_layers`` stages then multiplies the
    features by the sum of a ``FilterLayer``'s sine and cosine outputs and
    applies a linear layer. One of two final linear heads ('high' or 'low')
    projects the result to the requested output width.

    Args:
        in_channels: Width of the input ``t``.
        intermediate_channels: Width of the hidden features.
        num_layers: Number of filter + linear stages.
        out_channels_high: Output width of the 'high' head.
        out_channels_low: Output width of the 'low' head.
    """

    def __init__(self, in_channels, intermediate_channels, num_layers, out_channels_high, out_channels_low):
        super().__init__()
        self.num_layers = num_layers
        self.intermediate_channels = intermediate_channels
        # Sub-modules are created in a fixed order so that a given RNG seed
        # always yields the same initial parameters.
        self.initial_layer = nn.Linear(in_channels, intermediate_channels)
        self.filter_layers = nn.ModuleList(
            [FilterLayer(intermediate_channels) for _ in range(num_layers)])
        self.linear_layers = nn.ModuleList(
            [nn.Linear(intermediate_channels, intermediate_channels) for _ in range(num_layers)])
        self.final_layer_high = nn.Linear(intermediate_channels, out_channels_high)
        self.final_layer_low = nn.Linear(intermediate_channels, out_channels_low)

    def forward(self, t, output_type):
        """Evaluate the network at ``t`` using the selected output head.

        Args:
            t: Input tensor; it is fed to ``initial_layer`` and to every
                filter layer.
            output_type: ``'high'`` or ``'low'``, choosing which final linear
                head is applied.

        Returns:
            The output of the selected head.

        Raises:
            ValueError: If ``output_type`` is neither ``'high'`` nor ``'low'``.
        """
        # Resolve the head first so a typo fails fast instead of silently
        # returning the un-projected hidden features.
        if output_type == 'high':
            head = self.final_layer_high
        elif output_type == 'low':
            head = self.final_layer_low
        else:
            raise ValueError(
                f"output_type must be 'high' or 'low', got {output_type!r}")

        x = self.initial_layer(t)  # Shape: [time_steps, intermediate_channels]
        for filter_layer, linear_layer in zip(self.filter_layers, self.linear_layers):
            sine_filter, cosine_filter = filter_layer(t)
            # Modulate the features by both filters, then mix them linearly.
            x = x * sine_filter + x * cosine_filter
            x = linear_layer(x)

        return head(x)



from MFN_NODE.Interpolation.get_date_send import *

# --- Dataset selection -------------------------------------------------------
# Keep exactly one loader / data_name pair active; the others stay commented.

# 200-bus data
# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_200bus_9_10()
# data_name = "200bus"

# PV data
# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_PV_online()
# data_name = "PV"

# Load data (active)
(t_high, y_high, t_low, y_low,
 y_low_high_res, t_low_index, Interval) = get_load_oncor()
data_name = "load"

# Air-quality data
# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_ari_quality()
# data_name = "air"

# Spiral data
# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_spiral()
# data_name = "spiral"


model_name = "MFN"
# Input and output widths are both the column count of the low-res targets.
input_size = output_size = y_low.shape[1]
print(input_size)
# Share of low-resolution time stamps relative to high-resolution ones.
print(f"Rate: {len(t_low)/ len(t_high) * 100} %")


# Time stamps become column vectors (one row per sample, one feature);
# the targets keep the layout returned by the loader.
t_high = torch.tensor(t_high, dtype=torch.float32).view(-1, 1)
y_high = torch.tensor(y_high, dtype=torch.float32)
t_low = torch.tensor(t_low, dtype=torch.float32).view(-1, 1)
y_low = torch.tensor(y_low, dtype=torch.float32)

# Single source of truth for Adam's step size.
# NOTE(review): a value of 0.005 used to be declared here but never reached
# the optimizer, which hard-coded 0.001; 0.001 keeps the training behaviour
# that was actually run. Raise it here if 0.005 was the intent.
learning_rate = 0.001
num_epochs = 1000
penalty_weight = 0  # not referenced in the visible part of this script

in_channels = 1  # Single time index as input
out_channels_high = y_high.shape[1]  # Multi-dimensional output for high resolution
out_channels_low = y_low.shape[1]  # Multi-dimensional output for low resolution
num_layers = 12
intermediate_channels = 40
latent_dim = 10  # not referenced in the visible part of this script
CDE_hidden_size = 20  # not referenced in the visible part of this script

model = MFN(in_channels, intermediate_channels, num_layers, out_channels_high, out_channels_low)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


# When True, the high-resolution data also contributes to the training loss
# and the 'high' head is used for the final prediction.
if_use_HR = False

# Full-batch training: every epoch is one optimizer step over all samples.
model.train()
for epoch in range(num_epochs):
    output_low = model(t_low, 'low')
    loss_low = criterion(output_low, y_low)

    if if_use_HR:
        output_high = model(t_high, 'high')
        loss_high = criterion(output_high, y_high)
    else:
        loss_high = 0
    # TODO: decide on the training signal: low only, high only, or low + high.
    loss = loss_high + loss_low

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

# Predict on the high-resolution time grid with the head that was trained.
# No gradients are needed here, so autograd is switched off for the call.
model.eval()
with torch.no_grad():
    MFN_predicted = model(t_high, 'high' if if_use_HR else 'low').cpu().numpy()


### MFN results: --------------------------------------------------------------------
# Score the prediction on the high-resolution grid against `y_low_high_res`
# as returned by the data loader.
MFN_mse, MFN_mape = (
    compute_mse(MFN_predicted, y_low_high_res),
    compute_mape(MFN_predicted, y_low_high_res),
)
print(f'MFN: Low-Res MSE: {MFN_mse:.4f}, Low-Res MAPE: {MFN_mape:.4f}%')

# Dump the first output column as a plain Python list.
print("MFN_res", MFN_predicted[:, 0].tolist())

# plt.figure(figsize=(10, 5))
#
# plt.plot(t_high, y_low_high_res[:, 0], label='High-res signal 1 (true)',  linestyle='solid', color='blue')
# plt.plot(t_high, y_low_high_res[:, 1], label='High-res signal 2 (true)', linestyle='solid', color='orange')
# plt.plot(t_low, y_low[:, 0], 'o', label='Low-res signal 1 (observed)', color='blue', alpha=1, )
# plt.plot(t_low, y_low[:, 1], 'o', label='Low-res signal 2 (observed)', color='orange', alpha=1, )
# plt.plot(t_high, MFN_predicted[:, 0], label='Interpolated signal 1', linestyle='dashed', color='blue')
# plt.plot(t_high, MFN_predicted[:, 1], label='Interpolated signal 2', linestyle='dashed', color='orange')
#
#
# plt.title("MFN (if use HR? {}) interpolation".format(if_use_HR), fontsize=14)
# plt.xlabel('Time', fontsize=14)
# plt.ylabel('Amplitude', fontsize=14)
# plt.legend(fontsize=12)
# plt.xticks(fontsize=14)
# plt.yticks(fontsize=14)
# plt.savefig('results/MFN (if use HR? {}) interpolation.pdf'.format(if_use_HR), format='pdf', dpi=300)
#
# plt.show()

# NOTE(review): `plt` is never imported explicitly in this file; presumably it
# is re-exported by the star import from get_date_send. Confirm.

# savefig does not create missing directories, so make sure the target exists
# before any figure is written.
os.makedirs('results', exist_ok=True)

plt.figure(figsize=(10, 5))
for i in range(input_size):
    # One reference curve, the observed samples and the prediction per signal.
    plt.plot(t_high, y_low_high_res[:, i], label=f'High-res signal {i} (true)', linestyle='dashed')
    plt.plot(t_low, y_low[:, i], 'o', label=f'Low-res signal {i} (observed)', color='blue', alpha=1)
    plt.plot(t_high, MFN_predicted[:, i], label=f'Interpolated signal {i}')

# Decorate and save once, after every signal has been drawn. Doing this inside
# the loop only rewrote the same PDF on each iteration.
plt.title(f'{model_name} Interpolation', fontsize=14)
plt.xlabel('Time', fontsize=14)
plt.ylabel('Amplitude', fontsize=14)
plt.legend(fontsize=12)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.savefig(f'results/{model_name}_{data_name}_interpolation.pdf', format='pdf', dpi=300)


if data_name == "spiral":
    # Extra figure for the spiral dataset.
    # NOTE(review): both curves use `y_high` as the x-argument; verify this is
    # the intended pairing for a 2-D trajectory plot.
    plt.figure(figsize=(6, 5))
    plt.plot(y_high, y_low_high_res, label=f'True {data_name}', linestyle='dashed')
    plt.plot(y_high, MFN_predicted, label=f'Interpolated {data_name}', linestyle='dashed')

    plt.title(f'{model_name} Interpolation 2d', fontsize=14)
    plt.xlabel('X', fontsize=14)
    plt.ylabel('Y', fontsize=14)
    plt.legend(fontsize=12)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.savefig(f'results/{model_name}_{data_name}_interpolation_2d.pdf', format='pdf', dpi=300)

plt.show()