import os
import random

# from utils import *
# NOTE: this wildcard import deliberately stays ahead of the explicit imports
# below, so those explicit imports win over any same-named symbols it exports.
from get_date_sendv2 import *

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, mean_absolute_percentage_error
from sklearn.preprocessing import MinMaxScaler
from scipy.interpolate import interp1d

def set_seed(seed):
    """Seed every random number generator used by this script.

    The same value is handed to Python's ``random`` module, to NumPy and to
    PyTorch (CPU generator, current CUDA device and all CUDA devices). cuDNN
    is additionally switched to deterministic mode with benchmarking off.

    Args:
        seed: Integer seed shared by all generators.
    """
    # Pure-Python and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: CPU, current GPU, then every GPU (multi-GPU setups).
    for seed_fn in (torch.manual_seed,
                    torch.cuda.manual_seed,
                    torch.cuda.manual_seed_all):
        seed_fn(seed)

    # Prefer repeatable cuDNN kernels over auto-tuned ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# Fix all RNGs before anything random happens; the value itself is arbitrary.
seed = 42
set_seed(seed)


# --- Experiment configuration ------------------------------------------------
model_name = "MFN"  # tag embedded in checkpoint and figure file names

# Model dimensions. input_size / output_size are reassigned by the dataset
# selection further down in this script.
input_size = 2    # number of input features
output_size = 2   # number of output features
hidden_size = 60  # hidden width handed to the model

# Optimisation settings.
num_epochs = 300
learning_rate = 0.001

# Data windowing / preprocessing.
SEQ_LENGTH = 10  # number of time steps in each input sequence
after_interpolation = True  # True: train on the series interpolated onto the high-rate grid

### DATA #################################################################################
# dataname = "200bus"
# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_200bus_9_10()
# factor = 0.66
# input_size, output_size = 1, 1


# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_PV_online()
# dataname = "PV"
# factor = 0.8

# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_load_oncor()
# dataname = "load"
# factor = 0.6
# input_size, output_size = 1, 1

# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_ari_quality()
# dataname = "air"
# factor = 0.66
# input_size, output_size = 1, 1

#
# Active dataset: spiral. The loader is provided by the wildcard import at the
# top of the file (presumably get_date_sendv2 -- confirm).
(t_high, y_high,
 t_low, y_low,
 y_low_high_res, t_low_index,
 Interval) = get_data_spiral()
dataname = "spiral"
factor = 0.66  # fraction of the windows used for training
input_size = output_size = 1  # overrides the defaults set in the parameter section



# Report array shapes and the low-rate / high-rate sample ratio.
print(
    t_high.shape,
    y_high.shape,
    t_low.shape,
    y_low.shape,
    y_low_high_res.shape,
    t_low_index.shape,
    Interval,
)
print(f"Rate {len(t_low) / len(t_high)* 100}   %")


# After interpolation ################################################################################
# Select the series the model is trained on:
#   * after_interpolation=True  -> every column of y_low is cubic-interpolated
#     onto the t_high grid (extrapolating outside the range of t_low), and
#     y_high is kept as the high-rate reference.
#   * after_interpolation=False -> the raw low-rate samples are used and there
#     is no high-rate reference (value_high_real is None).
if after_interpolation:
    data_dim = y_low.shape[1]
    cubic_spline_Y = np.zeros(shape=(len(t_high), data_dim))
    for d in range(data_dim):
        cubic_spline = interp1d(t_low, y_low[:, d], kind='cubic', fill_value='extrapolate')
        cubic_spline_Y[:, d] = cubic_spline(t_high)

    t, value, value_high_real = t_high, cubic_spline_Y, y_high
    dataname += "_after_interpolation_"
else:
    t, value, value_high_real = t_low, y_low, None

# value_high_real is None when interpolation is disabled, so its shape must
# not be dereferenced unconditionally (that raised AttributeError before).
value_high_real_shape = None if value_high_real is None else value_high_real.shape
print(t_low.shape, y_low.shape, t_high.shape, value_high_real_shape)
###############################################################################

# Build float32 tensors for the series values and for the time axis.
value_tensor = torch.tensor(value, dtype=torch.float32)
# Time stamps become a column that is tiled once per input feature:
# (length, input_size).
t_tensor = torch.tensor(t, dtype=torch.float32)[:, None].repeat(1, input_size)

# Window the data. The time tensor is passed as the third argument so that it
# (rather than value_high_real) serves as the model input; create_sequences
# comes from the wildcard import -- its exact output layout is assumed to be
# (num_windows, SEQ_LENGTH, features) for sequences, TODO confirm.
sequences, targets = create_sequences(value_tensor, SEQ_LENGTH, t_tensor)

# Chronological split: the first `factor` share of the windows is used for
# training (e.g. 0.6 for load, 0.8 for PV), the remainder for testing.
train_size = int(len(sequences) * factor)
test_size = len(sequences) - train_size

# The whole training set is fed to the model as a single batch.
train_sequences, test_sequences = sequences[:train_size], sequences[train_size:]
train_targets, test_targets = targets[:train_size], targets[train_size:]

###############################################################################################


class FilterLayer(nn.Module):
    """Sinusoidal filter pair driven by a learnable linear projection.

    The input is multiplied by a learnable matrix of shape
    ``(in_channels, intermediate_channels)``; the sine and the cosine of that
    product are returned as two separate tensors.

    Args:
        intermediate_channels: Size of the last dimension of both outputs.
        in_channels: Size of the last dimension of the input.
    """

    def __init__(self, intermediate_channels, in_channels):
        super().__init__()
        self.intermediate_channels = intermediate_channels

        # Learnable projection shared by the sine and cosine filters,
        # initialised from a standard normal distribution.
        initial_weights = torch.randn(in_channels, intermediate_channels)
        self.weights = nn.Parameter(initial_weights)

    def forward(self, t):
        """Return ``(sin(t @ W), cos(t @ W))`` for the learnable matrix ``W``.

        Args:
            t: Tensor whose last dimension equals ``in_channels``.

        Returns:
            Tuple ``(sine_filter, cosine_filter)``; both have the shape of
            ``t`` with the last dimension replaced by
            ``intermediate_channels``.
        """
        projected = t @ self.weights
        return torch.sin(projected), torch.cos(projected)


class MFN(nn.Module):
    """Multiplicative filter network with a per-sequence aggregation head.

    A linear embedding of the input is repeatedly modulated by sine/cosine
    filters computed from the *original* input and passed through a linear
    layer. A final linear layer maps every time step to ``out_channels``, and
    an aggregation layer collapses the flattened
    ``seq_length * out_channels`` features to ``out_channels`` per sample.

    Args:
        in_channels: Size of the last dimension of the input.
        out_channels: Number of values produced per sample.
        num_layers: Number of (filter, linear) stages.
        intermediate_channels: Hidden width of the network.
        seq_length: Number of time steps in each input sequence; it fixes the
            input width of the aggregation layer. Defaults to the
            module-level ``SEQ_LENGTH`` so existing four-argument calls keep
            their behaviour.
    """

    def __init__(self, in_channels, out_channels, num_layers, intermediate_channels,
                 seq_length=None):
        super(MFN, self).__init__()

        # Resolve the window length lazily so the class no longer *requires*
        # the global; the global is only consulted when nothing is passed.
        if seq_length is None:
            seq_length = SEQ_LENGTH
        self.seq_length = seq_length

        # NOTE: sub-modules are created in this exact order on purpose -- the
        # order determines how the seeded RNG is consumed during
        # initialisation, and therefore the initial weights.
        self.initial_layer = nn.Linear(in_channels, intermediate_channels)

        self.filter_layers = nn.ModuleList()
        self.linear_layers = nn.ModuleList()

        for _ in range(num_layers):
            self.filter_layers.append(FilterLayer(intermediate_channels, in_channels))
            self.linear_layers.append(nn.Linear(intermediate_channels, intermediate_channels))

        self.final_layer = nn.Linear(intermediate_channels, out_channels)
        self.aggregate_layer = nn.Linear(seq_length * out_channels, out_channels)

    def forward(self, t):
        """Map a batch of input sequences to one output vector per sample.

        Args:
            t: Tensor whose last dimension is ``in_channels`` and whose
                remaining non-batch dimensions flatten to ``seq_length``
                steps, i.e. ``(batch, seq_length, in_channels)`` as used in
                this script.

        Returns:
            Tensor of shape ``(batch, out_channels)``.
        """
        x = self.initial_layer(t)

        for filter_layer, linear_layer in zip(self.filter_layers, self.linear_layers):
            sine_filter, cosine_filter = filter_layer(t)
            # Modulate the features by both filters and SUM the two products
            # (nothing is stacked/concatenated, so the width is unchanged).
            x = x * sine_filter + x * cosine_filter
            x = linear_layer(x)  # purely linear mixing, no extra non-linearity

        x = self.final_layer(x)
        x = x.view(x.size(0), -1)  # [batch_size, sequence_length * out_channels]

        # Collapse the per-step outputs into a single prediction per sample.
        x = self.aggregate_layer(x)  # [batch_size, out_channels]
        return x

# Network with 8 filter stages, trained full-batch with MSE + Adam.
model = MFN(input_size, output_size, 8, hidden_size)

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Lowest training loss seen so far, and where the matching checkpoint goes.
best_loss = float('inf')
best_model_path = 'best_model_{}.pth'.format(model_name)

# Training the model: every epoch is one optimisation step over the whole
# training set (single batch).
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()  # clear gradients left over from the previous step

    # Forward pass and loss on the full training set.
    outputs = model(train_sequences)
    loss = criterion(outputs, train_targets)

    # Backward pass and parameter update.
    loss.backward()
    optimizer.step()

    # Progress report on the first epoch and on every 50th one.
    is_first_epoch = epoch == 0
    if is_first_epoch or (epoch + 1) % 50 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.6f}')

    # Keep a checkpoint whenever the training loss improves. Note that the
    # state dict is written after optimizer.step(), i.e. the stored weights
    # are one update past the ones that produced `best_loss`.
    epoch_loss = loss.item()
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(model.state_dict(), best_model_path)  # Save the model's state dict
        print(f"New best model saved with loss: {best_loss:.6f}")

# Function to make predictions
def predict(model, sequences):
    """Run ``model`` on ``sequences`` in evaluation mode without gradients.

    The model is switched to eval mode (and left in that mode) before the
    forward pass.

    Args:
        model: Module to evaluate; called as ``model(sequences)``.
        sequences: Input passed to the model unchanged.

    Returns:
        The model output, computed under ``torch.no_grad()``.
    """
    model.eval()
    with torch.no_grad():
        return model(sequences)


################################
# Rebuild the network with the training-time architecture and restore the
# checkpoint written during training.
model = MFN(input_size, output_size, 8, hidden_size)
best_state = torch.load('best_model_{}.pth'.format(model_name))
model.load_state_dict(best_state)

# Predict over every window in one pass (training windows first, then test
# windows), and convert the result to a NumPy array.
predictions_actual = predict(model, sequences).numpy() # train + test together

# Targets in the same train-then-test order; both names refer to one array.
actual_actual = actual_np = torch.cat((train_targets, test_targets), dim=0).numpy()

# Inverse transform to get actual values
# predictions_actual = scaler.inverse_transform(predictions_np)
# actual_actual = scaler.inverse_transform(actual_np)


# Adjust the time axis to match the reduced size (excluding first SEQ_LENGTH steps)
t_adjusted = t[SEQ_LENGTH:]

# Time stamp of the last training target; marks the train/test boundary.
split_time = t[train_size + SEQ_LENGTH - 1]

# savefig() does not create missing directories; make sure the output folder
# exists so a finished training run is not lost to a FileNotFoundError.
os.makedirs('./results', exist_ok=True)

# Visualization of the entire dataset (training and test) with t_adjusted as x-axis
plt.figure(figsize=(12, 5))
# One subplot per output feature: high-rate reference curve, the low-rate
# samples, the model predictions and the train/test boundary.
for i in range(output_size):
    plt.subplot(output_size, 1, i + 1)
    plt.plot(t_high, y_low_high_res[:, i], label="True high")
    plt.scatter(t_low, y_low[:, i], label='Actual', marker='o', c="darkblue")
    plt.scatter(t_adjusted, predictions_actual[:, i], label='Predicted', marker='.',  c="orange", s=3)
    plt.axvline(x=split_time, color='red', linestyle='--', label='Train/Test Split')  # Mark the split point
    plt.title(f'Feature {i + 1}: {model_name}')
    plt.xlabel('Time (t)')
    plt.ylabel('Value')
    plt.legend()
    plt.grid(True)
plt.tight_layout()
plt.savefig('./results/{}_{}_forecasting.pdf'.format(dataname, model_name), format='pdf', dpi=300)



# Calculate MSE and MAPE for training and test sets separately

# Pick the reference series the predictions are scored against. The first
# SEQ_LENGTH samples are dropped because no prediction exists for them.
if after_interpolation:
    actual_actual_low = y_low_high_res[t_low_index]
    actual_high_low = y_high[t_low_index]
    actual_actual = y_low_high_res[SEQ_LENGTH:]
    actual_high = y_high[SEQ_LENGTH:]
else:
    # No high-rate reference exists in this mode. Define the names anyway so
    # the code below does not hit a NameError (previously `actual_high` was
    # left undefined on this path).
    actual_actual_low = None
    actual_high_low = None
    actual_actual = y_low[SEQ_LENGTH:]
    actual_high = None

# Split predictions and references at the same index as the training data.
train_predictions_actual = predictions_actual[:train_size]
test_predictions_actual = predictions_actual[train_size:]

train_actual_actual = actual_actual[:train_size]
test_actual_actual = actual_actual[train_size:]

if actual_high is not None:
    train_actual_high = actual_high[:train_size]
    test_actual_high = actual_high[train_size:]
else:
    train_actual_high = None
    test_actual_high = None

train_mse = mean_squared_error(train_actual_actual, train_predictions_actual)
train_mape = mean_absolute_percentage_error(train_actual_actual, train_predictions_actual) * 100

test_mse = mean_squared_error(test_actual_actual, test_predictions_actual)
test_mape = mean_absolute_percentage_error(test_actual_actual, test_predictions_actual) * 100

print(f'{model_name} Training MSE: {train_mse:.4f}')
print(f'{model_name} Training MAPE: {train_mape:.2f}%')

print(f'{model_name} Test MSE: {test_mse:.6f}')
print(f'{model_name} Test MAPE: {test_mape:.2f}%')




# Extra 2-D figure, produced only for the interpolated spiral dataset: the
# high-rate series goes on the x-axis, the reference/predicted series on y.
if dataname == "spiral" + "_after_interpolation_":
    # Not used by the plot below; kept so the names remain defined.
    actual_low = y_low[:]
    actual_low_2 = y_high[t_low_index]

    plt.figure(figsize=(6, 5))

    # Reference trajectory (train and test parts drawn in the same colour).
    plt.plot(train_actual_high, train_actual_actual, label="True", c="black")
    plt.plot(test_actual_high, test_actual_actual, c="black")
    # Model predictions, coloured by split.
    plt.plot(train_actual_high, train_predictions_actual, label="Pred Train", c="blue")
    plt.plot(test_actual_high, test_predictions_actual, label="Pred Test", c="orange")

    # Reference values at the low-rate sample indices.
    plt.scatter(actual_high_low, actual_actual_low, label ="Low True")

    plt.xlabel('X')
    plt.ylabel('Y')

    # savefig() does not create missing directories; create the output folder
    # first so the figure is not lost to a FileNotFoundError.
    os.makedirs('./results', exist_ok=True)
    plt.savefig('./results/{}_{}_forecasting_2d.pdf'.format(dataname, model_name), format='pdf', dpi=300)

plt.show()