import os
import random
# from utils import *
from get_date_sendv2 import *
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, mean_absolute_percentage_error
from sklearn.preprocessing import MinMaxScaler
from scipy.interpolate import interp1d
import torchdiffeq


def set_seed(seed):
    """Seed every RNG this script relies on and force deterministic cuDNN.

    Covers PyTorch (CPU, the current CUDA device and all CUDA devices),
    NumPy's global generator and Python's ``random`` module.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # relevant when several GPUs are used
        np.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)
    # No cuDNN autotuning, deterministic kernels only: same seed -> same run.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


seed = 42  # any fixed integer works; change it to get a different reproducible run
set_seed(seed)



model_name = "NCDE_RNN"
# Parameters
input_size = 2  # Number of input features
hidden_size = 60  # Number of features in the hidden state
output_size = 2  # Number of output features
num_epochs = 300
learning_rate = 0.001
SEQ_LENGTH = 10  # Number of time steps to use for each input sequence
# True: resample the low-resolution series onto t_high with cubic splines and use
# y_high as the target series; False: work directly on the low-resolution series.
after_interpolation = True

### DATA #################################################################################
# Alternative datasets: uncomment one group and comment out the spiral block below.
# dataname = "200bus"
# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_200bus_9_10()
# factor = 0.66
# input_size, output_size = 1, 1


# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_PV_online()
# dataname = "PV"
# factor = 0.8


# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_load_oncor()
# dataname = "load"
# factor = 0.6
# input_size, output_size = 1, 1

# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_ari_quality()
# dataname = "air"
# factor = 0.66
# input_size, output_size = 1, 1


t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_spiral()
dataname = "spiral"
factor = 0.66  # fraction of the sequences used for training
input_size, output_size = 1, 1


print(t_high.shape, y_high.shape, t_low.shape, y_low.shape, y_low_high_res.shape, t_low_index.shape, Interval)
print(f"Rate {len(t_low) / len(t_high)* 100}   %")
# After interpolation ################################################################################
if after_interpolation:
    # Upsample each low-resolution feature onto the t_high grid with its own cubic
    # spline; points outside the range of t_low are extrapolated.
    data_dim = y_low.shape[1]
    cubic_spline_Y = np.zeros(shape=(len(t_high), data_dim))
    for d in range(data_dim):
        cubic_spline = interp1d(t_low, y_low[:, d], kind='cubic', fill_value='extrapolate')
        cubic_spline_Y[:, d] = cubic_spline(t_high)

    t, value, value_high_real = t_high, cubic_spline_Y, y_high
    dataname += "_after_interpolation_"
else:
    # There is no series aligned with t_low other than y_low itself, so it is also
    # the target source (the evaluation code compares against y_low[SEQ_LENGTH:]
    # in this mode). Leaving this as None crashed on `.shape` / torch.tensor below.
    # NOTE(review): assumes create_sequences() draws targets from its third
    # argument, as the call with value_high_real below suggests — confirm.
    t, value, value_high_real = t_low, y_low, y_low
print(t_low.shape, y_low.shape, t_high.shape, value_high_real.shape)
###############################################################################

# Convert data to PyTorch tensors
value_tensor = torch.tensor(value, dtype=torch.float32)
value_high_real = torch.tensor(value_high_real, dtype=torch.float32)
t_tensor = torch.tensor(t, dtype=torch.float32)

sequences, targets = create_sequences(value_tensor, SEQ_LENGTH, value_high_real)
sequences_t, _ = create_sequences(value_tensor, SEQ_LENGTH, t_tensor)  # NOTE(review): currently unused


train_size = int(len(sequences) * factor)
test_size = len(sequences) - train_size
print(train_size)

# Chronological split; each split is fed to the model as a single batch
# ([batch_size, sequence_length, num_features]).
train_sequences = sequences[:train_size]
train_targets = targets[:train_size]

test_sequences = sequences[train_size:]
test_targets = targets[train_size:]
###############################################################################################
# Right-hand side of the hypernetwork's matrix ODE, dY/dt = S(t) Y, where S(t) is a learned combination of skew-symmetric matrices generated from t
class ODEFunc(nn.Module):
    """Right-hand side of the matrix ODE dY/dt = S(t) Y.

    S(t) = sum_i a_i * (A_i(t) - A_i(t)^T) mixes ``num_matrices`` skew-symmetric
    matrices. Each A_i(t) is produced from the time t by its own small MLP
    (Linear(1, latent_dim) -> tanh -> Linear(latent_dim, hidden_size**2)).
    The ODE state Y is carried flattened, with hidden_size * hidden_size entries.
    """

    def __init__(self, latent_dim, hidden_size, num_matrices):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_matrices = num_matrices

        # Generators t -> A_i(t). The attribute names are the state_dict keys that
        # saved checkpoints use, and the creation order (all first layers, then all
        # second layers, then the mixing weights) fixes the seeded initialisation.
        self.fc1_list = nn.ModuleList(nn.Linear(1, latent_dim) for _ in range(num_matrices))
        self.fc2_list = nn.ModuleList(
            nn.Linear(latent_dim, hidden_size * hidden_size) for _ in range(num_matrices)
        )

        # Mixing weights a_i of the linear combination.
        self.coefficients = nn.Parameter(torch.randn(num_matrices))

    def forward(self, t, y):
        """Return S(t) @ Y, flattened to shape [-1, hidden_size * hidden_size]."""
        n = self.hidden_size
        t_column = t.view(-1, 1)

        S_total = 0
        for idx, (first_layer, second_layer) in enumerate(zip(self.fc1_list, self.fc2_list)):
            raw = second_layer(torch.tanh(first_layer(t_column))).view(-1, n, n)
            # raw - raw^T is skew-symmetric by construction.
            S_total = S_total + self.coefficients[idx] * (raw - raw.transpose(1, 2))

        Y = y.view(-1, n, n)
        return torch.matmul(S_total, Y).view(-1, n * n)

    def check_skew_symmetric(self, S):
        """Return ||S + S^T|| over a batch S of shape [B, n, n]; 0 means skew-symmetric."""
        return torch.norm(S + S.transpose(1, 2)).item()


# Define the HyperNetwork that outputs time-dependent weight matrices W_hh
class HyperNetwork(nn.Module):
    """Produce one [hidden_size, hidden_size] weight matrix per requested time point.

    The matrices are the solution Y(t) of dY/dt = S(t) Y (see ``ODEFunc``),
    integrated with ``torchdiffeq.odeint`` from Y = I at the first time point
    passed in. S(t) is skew-symmetric, so the exact solution stays orthogonal;
    the numerical solution does so only approximately.
    """

    def __init__(self, latent_dim, hidden_size, num_matrices):
        super(HyperNetwork, self).__init__()
        self.latent_dim = latent_dim
        self.hidden_size = hidden_size
        self.ode_func = ODEFunc(latent_dim, hidden_size, num_matrices)

    def forward(self, t):
        """Return Y(t) of shape [num_time_points, hidden_size, hidden_size].

        Args:
            t: time points, flattened to 1-D (e.g. shape [N] or [N, 1]). odeint
                starts the integration (Y = I) at t[0].
        """
        # Flatten to 1-D. Unlike squeeze(), reshape(-1) keeps a single time point
        # as shape [1] instead of collapsing it to a 0-d scalar (on which the
        # size(0) below fails).
        times = t.reshape(-1)
        n = self.hidden_size

        # Identity initial state (orthogonal), flattened to [1, n * n].
        y0 = torch.eye(n, device=times.device).reshape(1, n * n)

        # ode_out: [num_time_points, 1, n * n]
        ode_out = torchdiffeq.odeint(self.ode_func, y0, times)
        return ode_out.reshape(times.shape[0], n, n)


# Define the RNN with dynamic W_hh from Hypernetwork
class DynamicRNN(nn.Module):
    """Elman-style RNN whose recurrent weight is supplied per sample by a hypernetwork.

    ``hyper_net(t)`` must return one [hidden_size, hidden_size] matrix per batch
    sample. That matrix is the sample's recurrent weight at every step of its
    sequence. The input-to-hidden weight and both biases come from an
    ``nn.RNNCell``. The cell's own ``weight_hh`` is never used in ``forward``, so
    it receives no gradient. A linear layer maps the final hidden state to the
    output.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers, hyper_net):
        super(DynamicRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers  # stored only; forward always runs a single layer
        self.rnn_cell = nn.RNNCell(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)
        self.hyper_net = hyper_net

    def check_orthogonality(self, W):
        """Return ||W^T W - I|| for a batch W of shape [B, n, n]; 0 means orthogonal."""
        W_T_W = torch.bmm(W.transpose(1, 2), W)  # Compute W^T * W
        identity = torch.eye(W.size(1), W.size(2)).to(W.device)  # Identity matrix for comparison
        difference = torch.norm(W_T_W - identity)  # Norm of the difference between W^T * W and the identity matrix
        return difference.item()

    def forward(self, x, t):
        """Run the recurrence and return the prediction from the last hidden state.

        Args:
            x: input sequences, [batch, seq_length, input_size].
            t: per-sample time stamps forwarded to ``hyper_net``.

        Returns:
            Tensor of shape [batch, output_size].

        Raises:
            ValueError: if ``hyper_net(t)`` does not return exactly one
                [hidden_size, hidden_size] matrix per batch sample.
        """
        batch_size, seq_length, _ = x.size()
        W_hh_seq = self.hyper_net(t)

        # bmm does not broadcast, so a length mismatch between t and the batch
        # would otherwise surface as an opaque bmm error (after the ODE solve).
        expected_shape = (batch_size, self.hidden_size, self.hidden_size)
        if tuple(W_hh_seq.shape) != expected_shape:
            raise ValueError(
                f"hyper_net(t) returned shape {tuple(W_hh_seq.shape)}, expected {expected_shape}; "
                "t must provide exactly one time stamp per sequence in the batch"
            )

        h_t = torch.zeros(batch_size, self.hidden_size, device=x.device)
        w_ih_t = self.rnn_cell.weight_ih.t()  # loop-invariant
        for i in range(seq_length):
            # [batch, 1, hidden] @ [batch, hidden, hidden] -> per-sample h_t @ W_hh[b]
            recurrent = torch.bmm(h_t.unsqueeze(1), W_hh_seq).squeeze(1)
            h_t = torch.tanh(
                x[:, i] @ w_ih_t + recurrent + self.rnn_cell.bias_ih + self.rnn_cell.bias_hh)

        return self.fc(h_t)


#############################################################################################
# Model: the hypernetwork is built first and then handed to the RNN. This order
# decides which values the seeded RNG gives to each layer's initialiser.
hyper_net = HyperNetwork(latent_dim=20, hidden_size=hidden_size, num_matrices=1)
model = DynamicRNN(
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=output_size,
    num_layers=1,
    hyper_net=hyper_net,
)

# Training objective and optimiser.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Checkpointing state: file for the best weights, and the lowest loss seen so far.
best_model_path = f'best_model_{model_name}.pth'
best_loss = float('inf')


# test_sequences_t = torch.tensor([i for i in range(len())], dtype=torch.float32)

def assign_time_indices(total_samples, days=8):
    """Return per-sample time stamps that ramp from 0 to 1 and restart every "day".

    The first ``days * (total_samples // days)`` samples form ``days`` identical
    ramps produced by ``np.linspace(0, 1, total_samples // days)``. The
    ``total_samples % days`` leftover samples are appended as a partial ramp. It
    starts at 0 with the same spacing, or is flat at 0 when a day holds fewer
    than two samples.

    Args:
        total_samples: number of time stamps to produce (non-negative).
        days: number of repeating periods; must be positive.

    Returns:
        float32 tensor of shape [total_samples].

    Raises:
        ValueError: if ``days`` is not positive.
    """
    if days <= 0:
        raise ValueError(f"days must be positive, got {days}")
    samples_per_day = total_samples // days
    t_high_normalized = np.tile(np.linspace(0, 1, samples_per_day), days)

    # Leftover samples that do not fill a whole day. Computed by subtraction so
    # total_samples < days (samples_per_day == 0) no longer divides by zero.
    remainder = total_samples - samples_per_day * days
    if remainder > 0:
        # Reuse the within-day spacing; with fewer than two samples per day there
        # is no spacing to copy.
        interval = t_high_normalized[1] - t_high_normalized[0] if samples_per_day > 1 else 0.0
        extra_time_indices = interval * np.arange(0, remainder)
        t_high_normalized = np.concatenate((t_high_normalized, extra_time_indices))
    return torch.tensor(t_high_normalized, dtype=torch.float32)


# One normalized time stamp in [0, 1] per input sequence. DynamicRNN consumes one
# stamp per sample, so the grid is sized from the sequences actually built and
# split at the same index (train_size) as the sequences. Sizing it from t_high
# only matched in the interpolated mode.
total_samples = len(sequences)
t_normalized = torch.linspace(0, 1, total_samples, dtype=torch.float32)

t_train_normalized = t_normalized[:train_size]
t_test_normalized = t_normalized[train_size:]


# Training the model: full batch, one gradient step per epoch.
for epoch in range(num_epochs):
    model.train()
    outputs = model(train_sequences, t_train_normalized)
    loss = criterion(outputs, train_targets)
    loss_value = loss.item()

    if (epoch + 1) % 50 == 0 or epoch == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss_value:.6f}')

    # Checkpoint BEFORE the optimizer step: loss_value was measured with the
    # current weights, so those are the weights that earned it. Saving after
    # step() stored the next, unevaluated weights under this loss.
    if loss_value < best_loss:
        best_loss = loss_value
        torch.save(model.state_dict(), best_model_path)  # Save the model's state dict
        print(f"New best model saved with loss: {best_loss:.6f}")

    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Function to make predictions
def predict(model, sequences, t):
    """Return ``model(sequences, t)`` computed in eval mode with autograd disabled.

    Note: this leaves ``model`` in eval mode.
    """
    model.eval()
    with torch.no_grad():
        return model(sequences, t)


################################
# Reload the checkpoint with the lowest training loss into a freshly built model.
hyper_net = HyperNetwork(latent_dim=20, hidden_size=hidden_size, num_matrices=1)
model = DynamicRNN(input_size=input_size, hidden_size=hidden_size, output_size=output_size, num_layers=1,
                   hyper_net=hyper_net)
model.load_state_dict(torch.load(best_model_path))

# Predict all sequences in one pass over the full normalized timeline, then
# split. The hypernetwork integrates its weight ODE from identity at the FIRST
# time point it receives. Evaluating the test split on its own therefore
# restarted the weights at identity at the first test time, instead of
# continuing the trajectory used in training (identity at t = 0).
predictions = predict(model, sequences, t_normalized)
train_predictions = predictions[:train_size]
test_predictions = predictions[train_size:]

# Convert tensors to NumPy arrays
predictions_np = predictions.numpy()
actual_np = torch.cat((train_targets, test_targets), dim=0).numpy()

# No scaler is fitted in this script, so values are already in original units.
predictions_actual = predictions_np
actual_actual = actual_np


# Each prediction needs a full input window, so predictions start SEQ_LENGTH steps in.
t_adjusted = t[SEQ_LENGTH:]


# Visualization of the entire dataset (training and test) with t_adjusted as x-axis
plt.figure(figsize=(12, 6))
for i in range(output_size):
    plt.subplot(output_size, 1, i + 1)
    plt.plot(t_high, y_low_high_res[:, i], label="True high")
    plt.scatter(t_low, y_low[:, i], label='Actual', marker='o', c="darkblue")
    plt.scatter(t_adjusted, predictions_actual[:, i], label='Predicted', marker='.', c="orange", s=3)
    # Dashed line at the time of the last training prediction.
    plt.axvline(x=t[train_size + SEQ_LENGTH - 1], color='red', linestyle='--', label='Train/Test Split')
    plt.title(f'Feature {i + 1}: {model_name}')
    plt.xlabel('Time (t)')
    plt.ylabel('Value')
    plt.legend()
    plt.grid(True)

plt.tight_layout()
os.makedirs('./results', exist_ok=True)  # savefig does not create missing directories
plt.savefig('./results/{}_{}_forecasting.pdf'.format(dataname, model_name), format='pdf', dpi=300)

# Calculate MSE and MAPE for training and test sets separately.
# NOTE: MAPE divides by |true value|, so it is unstable for targets near zero.
if after_interpolation:
    actual_actual_low = y_low_high_res[t_low_index]
    actual_high_low = y_high[t_low_index]
    actual_actual = y_low_high_res[SEQ_LENGTH:]
    actual_high = y_high[SEQ_LENGTH:]
else:
    actual_actual = y_low[SEQ_LENGTH:]
    # y_high is aligned with t_high, not with the t_low grid used in this mode.
    actual_high = None

train_predictions_actual = predictions_actual[:train_size]
test_predictions_actual = predictions_actual[train_size:]

train_actual_actual = actual_actual[:train_size]
test_actual_actual = actual_actual[train_size:]

if actual_high is not None:
    train_actual_high = actual_high[:train_size]
    test_actual_high = actual_high[train_size:]
else:
    train_actual_high = test_actual_high = None

train_mse = mean_squared_error(train_actual_actual, train_predictions_actual)
train_mape = mean_absolute_percentage_error(train_actual_actual, train_predictions_actual) * 100

test_mse = mean_squared_error(test_actual_actual, test_predictions_actual)
test_mape = mean_absolute_percentage_error(test_actual_actual, test_predictions_actual) * 100

print(f'{model_name} Training MSE: {train_mse:.4f}')
print(f'{model_name} Training MAPE: {train_mape:.2f}%')

print(f'{model_name} Test MSE: {test_mse:.6f}')
print(f'{model_name} Test MAPE: {test_mape:.2f}%')


# Phase-plane plot for the interpolated spiral: y_high on the x-axis against the
# modelled series on the y-axis.
if dataname == "spiral" + "_after_interpolation_":
    plt.figure(figsize=(6, 5))

    plt.plot(train_actual_high, train_actual_actual, label="True", c="black")
    plt.plot(test_actual_high, test_actual_actual, c="black")
    plt.plot(train_actual_high, train_predictions_actual, label="Pred Train", c="blue")
    plt.plot(test_actual_high, test_predictions_actual, label="Pred Test", c="orange")

    plt.scatter(actual_high_low, actual_actual_low, label="Low True")

    plt.xlabel('X')
    plt.ylabel('Y')
    plt.savefig('./results/{}_{}_forecasting_2d.pdf'.format(dataname, model_name), format='pdf', dpi=300)

plt.show()