import os
import random

# NOTE: the star import intentionally stays ahead of the explicit imports
# below, so those names keep precedence over anything it re-exports.
from get_date_sendv2 import *

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, mean_absolute_percentage_error
from sklearn.preprocessing import MinMaxScaler
import controldiffeq  # Import controldiffeq for Neural CDE
from scipy.interpolate import interp1d

def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy, and PyTorch (CPU plus the
    current and all CUDA devices), then switches cuDNN to deterministic
    mode with benchmarking turned off.

    Args:
        seed: Integer seed shared by all generators.
    """
    # Python and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators; the CUDA variants cover single- and multi-GPU runs.
    for seeder in (torch.manual_seed,
                   torch.cuda.manual_seed,
                   torch.cuda.manual_seed_all):
        seeder(seed)

    # Deterministic cuDNN kernels, no benchmarking.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Seed all random number generators before anything random runs.
seed = 42  # You can set any value you want here
set_seed(seed)

# Tag used in the checkpoint file name, plot titles, output file names and printed metrics.
model_name = "NCDE"
# Parameters
input_size = 2  # Number of input features (reassigned in the DATA section below)
hidden_size = 60  # Number of features in the hidden state
output_size = 2  # Number of output features (reassigned in the DATA section below)
num_epochs = 300  # Number of passes over the training loader
learning_rate = 0.001  # Step size handed to the Adam optimizer
SEQ_LENGTH = 10  # Number of time steps to use for each input sequence
# True: train on the low-rate samples cubic-interpolated onto t_high;
# False: train on the raw low-rate samples.
after_interpolation = True

### DATA #################################################################################
# dataname = "200bus"
# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_200bus_9_10()
# factor = 0.66
# input_size, output_size = 1, 1

# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_PV_online()
# dataname = "PV"
# factor = 0.8
# input_size, output_size = 1, 1

# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_load_oncor()
# dataname = "load"
# input_size, output_size = 1, 1
# factor = 0.6

# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_ari_quality()
# dataname = "air"
# factor = 0.66
# input_size, output_size = 1, 1


# Active dataset: spiral (the alternatives are the commented-out groups above).
(t_high, y_high, t_low, y_low,
 y_low_high_res, t_low_index, Interval) = get_data_spiral()
dataname = "spiral"
factor = 0.66  # share of the windows used for training
input_size = 1
output_size = 1


# Sanity output: array shapes first, then the low-/high-rate sample-count
# ratio expressed in percent.
print(t_high.shape, y_high.shape, t_low.shape, y_low.shape, y_low_high_res.shape, t_low_index.shape, Interval)
sampling_rate_percent = len(t_low) / len(t_high) * 100
print(f"Rate {sampling_rate_percent}   %")

# After interpolation ################################################################################
# Pick the series the model is trained on:
#   * after_interpolation is True  -> the low-rate samples, cubic-interpolated
#     onto the high-rate time grid, with y_high kept as the target source;
#   * after_interpolation is False -> the raw low-rate samples, with no
#     separate target source (value_high_real stays None).
if after_interpolation:
    data_dim = y_low.shape[1]
    cubic_spline_Y = np.zeros(shape=(len(t_high), data_dim))
    for d in range(data_dim):
        # One cubic interpolant per feature column; values outside the range
        # of t_low are extrapolated.
        cubic_spline = interp1d(t_low, y_low[:, d], kind='cubic', fill_value='extrapolate')
        cubic_spline_y2 = cubic_spline(t_high)
        cubic_spline_Y[:, d] = cubic_spline_y2

    t, value, value_high_real = t_high, cubic_spline_Y, y_high
    dataname += "_after_interpolation_"
else:
    t, value, value_high_real = t_low, y_low, None

# value_high_real is None in the non-interpolated branch and has no .shape,
# so print None in that case instead of raising AttributeError.
print(t_low.shape, y_low.shape, t_high.shape,
      None if value_high_real is None else value_high_real.shape)
###############################################################################

# Convert data to PyTorch tensors
value_tensor = torch.tensor(value, dtype=torch.float32)
if value_high_real is not None:
    # torch.tensor(None) raises, so only convert when a target series exists.
    value_high_real = torch.tensor(value_high_real, dtype=torch.float32)

# NOTE(review): in the non-interpolated branch None is forwarded as the third
# argument -- confirm that create_sequences accepts it.
sequences, targets = create_sequences(value_tensor, SEQ_LENGTH, value_high_real)

# Ordered split: the first `factor` share of the windows is used for
# training, everything after it is held out for testing.
train_size = int(len(sequences) * factor)  # 0.6 for load, 0.8 for PV
test_size = len(sequences) - train_size

# Presumably [num_windows, sequence_length, num_features] -- confirm against
# create_sequences.
train_sequences, test_sequences = sequences[:train_size], sequences[train_size:]
train_targets, test_targets = targets[:train_size], targets[train_size:]

###############################################################################################
# Neural Controlled Differential Equation (NCDE) Model using controldiffeq

class CDEFunc(nn.Module):
    """Vector field network passed as ``func`` to ``controldiffeq.cdeint``.

    A two-layer MLP (ReLU in between, tanh on the output) that maps a hidden
    state with ``hidden_channels`` features to a
    ``hidden_channels x input_channels`` matrix.

    Args:
        input_channels: Number of channels of the control path.
        hidden_channels: Size of the hidden state.
    """

    def __init__(self, input_channels, hidden_channels):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels

        self.linear1 = nn.Linear(hidden_channels, 128)
        self.linear2 = nn.Linear(128, hidden_channels * input_channels)

    def forward(self, z):
        """Evaluate the vector field at hidden state ``z``.

        Args:
            z: Tensor of shape ``(..., hidden_channels)``; any number of
                leading (batch) dimensions is accepted.

        Returns:
            Tensor of shape ``(..., hidden_channels, input_channels)`` whose
            entries lie in ``[-1, 1]`` because of the final tanh.
        """
        z = self.linear1(z)
        z = z.relu()
        z = self.linear2(z)
        # Easy-to-forget gotcha: Best results tend to be obtained by adding a final tanh nonlinearity.
        z = z.tanh()
        # Split the flat last axis into a matrix while keeping every leading
        # dimension, instead of assuming exactly one batch dimension.
        return z.view(*z.shape[:-1], self.hidden_channels, self.input_channels)

class NeuralCDE(nn.Module):
    """Neural controlled differential equation model built on controldiffeq.

    A natural cubic spline is built from the supplied coefficients, the
    hidden state is initialised from the spline's value at the first time
    point, the CDE is integrated from the first to the last time point, and
    a linear readout maps the final hidden state to the prediction.

    Args:
        input_channels: Number of channels of the control path.
        hidden_channels: Size of the hidden state.
        output_channels: Size of each prediction.
        atol: Absolute tolerance forwarded to ``controldiffeq.cdeint``.
        rtol: Relative tolerance forwarded to ``controldiffeq.cdeint``.
    """

    def __init__(self, input_channels, hidden_channels, output_channels,
                 atol=1e-2, rtol=1e-2):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.atol = atol
        self.rtol = rtol

        self.func = CDEFunc(input_channels, hidden_channels)
        # Initial hidden state is a function of the first observation
        self.initial = nn.Linear(input_channels, hidden_channels)
        self.readout = nn.Linear(hidden_channels, output_channels)

    def forward(self, times, coeffs):
        """Predict from a batch of spline-encoded windows.

        Args:
            times: 1-D tensor of observation times shared by the batch.
            coeffs: Tuple of spline coefficient tensors, as produced by
                ``controldiffeq.natural_cubic_spline_coeffs``.

        Returns:
            The readout applied to the hidden state at ``times[-1]``.
        """
        spline = controldiffeq.NaturalCubicSpline(times, coeffs)
        z0 = self.initial(spline.evaluate(times[0]))
        # Solve the CDE; only the first and last time points are requested.
        z_T = controldiffeq.cdeint(dX_dt=spline.derivative,
                                   z0=z0,
                                   func=self.func,
                                   t=times[[0, -1]],
                                   atol=self.atol,
                                   rtol=self.rtol)
        # Get the final hidden state
        z_T = z_T[-1]
        # Apply the readout layer
        return self.readout(z_T)



# Time grid shared by every window: 0, 1, ..., SEQ_LENGTH - 1
t_window = torch.linspace(0., SEQ_LENGTH - 1, steps=SEQ_LENGTH)

batch_size = 32

# Training split: spline coefficients -> dataset -> loader (original order kept, no shuffling)
train_coeffs = controldiffeq.natural_cubic_spline_coeffs(t_window, train_sequences)
train_dataset = torch.utils.data.TensorDataset(*train_coeffs, train_targets)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=False)

# Test split: same pipeline
test_coeffs = controldiffeq.natural_cubic_spline_coeffs(t_window, test_sequences)
test_dataset = torch.utils.data.TensorDataset(*test_coeffs, test_targets)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False)

# Instantiate the NeuralCDE model
# The third argument is the readout width and must be output_size (not
# input_size) so predictions match the target width when the two differ.
model = NeuralCDE(input_size, hidden_size, output_size)

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)



best_loss = float('inf')  # Set initial best loss to infinity
best_model_path = 'best_model_{}.pth'.format(model_name)  # Path to save the best model

# Fail early with a clear message instead of a ZeroDivisionError when the
# average loss is computed.
num_train_batches = len(train_dataloader)
if num_train_batches == 0:
    raise ValueError("train_dataloader is empty; nothing to train on")

# Training the NeuralCDE model
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    for *batch_coeffs, batch_targets in train_dataloader:
        # Every tensor before the last one is a spline coefficient tensor.
        optimizer.zero_grad()
        outputs = model(t_window, tuple(batch_coeffs))
        loss = criterion(outputs, batch_targets)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    avg_loss = epoch_loss / num_train_batches
    if (epoch + 1) % 50 == 0 or epoch == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.6f}')

    # Checkpoint on the epoch-average loss: the loss of the last mini-batch
    # alone is noisy and depends on which samples happen to end the epoch.
    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save(model.state_dict(), best_model_path)  # Save the model's state dict
        print(f"New best model saved with loss: {best_loss:.6f}")

# NOTE(review): the checkpoint at best_model_path is never loaded back, so the
# evaluation further down runs with the final-epoch weights -- confirm that
# this is intended.


# Function to make predictions
def predict(model, times, dataloader):
    """Run ``model`` over every batch of ``dataloader`` without gradients.

    The model is put into eval mode. Each batch is expected to hold the
    spline coefficient tensors followed by the target tensor as its last
    element.

    Args:
        model: Callable module invoked as ``model(times, coeffs_tuple)``.
        times: Time grid forwarded unchanged to the model.
        dataloader: Iterable of batches ``(*coeffs, targets)``.

    Returns:
        ``(predictions, actuals)``: the per-batch model outputs and targets,
        each concatenated along dimension 0 in loader order.
    """
    model.eval()
    collected_outputs, collected_targets = [], []
    with torch.no_grad():
        for *coeff_parts, target_batch in dataloader:
            collected_outputs.append(model(times, tuple(coeff_parts)))
            collected_targets.append(target_batch)
    return torch.cat(collected_outputs, dim=0), torch.cat(collected_targets, dim=0)

# Run the model over the training and the test loader.
train_predictions, train_actuals = predict(model, t_window, train_dataloader)
test_predictions, test_actuals = predict(model, t_window, test_dataloader)

# Stitch both splits back together (train first, then test) for plotting,
# and move the results to NumPy.
predictions = torch.cat([train_predictions, test_predictions], dim=0)
actuals = torch.cat([train_actuals, test_actuals], dim=0)
predictions_np, actual_np = predictions.numpy(), actuals.numpy()

# The scaler.inverse_transform step is disabled in this script, so the arrays
# are used as-is.
predictions_actual, actual_actual = predictions_np, actual_np

# Adjust the time axis to match the reduced size (excluding first SEQ_LENGTH steps)
t_adjusted = t[SEQ_LENGTH:len(predictions_actual)+SEQ_LENGTH]

# Time stamp used for the vertical train/test split marker (same for every subplot).
split_time = t[train_size + SEQ_LENGTH - 1]

# Visualization of the entire dataset (training and test) with t_adjusted as x-axis
plt.figure(figsize=(12, 6))
# One subplot per output feature
for i in range(output_size):
    plt.subplot(output_size, 1, i + 1)
    plt.plot(t_high, y_low_high_res[:, i], label="True high")
    plt.scatter(t_low, y_low[:, i], label='Actual', marker='o', c="darkblue")
    plt.scatter(t_adjusted, predictions_actual[:, i], label='Predicted', marker='.',  c="orange", s=3)
    plt.axvline(x=split_time, color='red', linestyle='--', label='Train/Test Split')  # Mark the split point
    plt.title(f'Feature {i + 1}: {model_name}')
    plt.xlabel('Time (t)')
    plt.ylabel('Value')
    plt.legend()
    plt.grid(True)

plt.tight_layout()
# savefig does not create missing directories; make sure ./results exists so
# the run does not fail here after training has already finished.
os.makedirs('./results', exist_ok=True)
plt.savefig('./results/{}_{}_forecasting.pdf'.format(dataname, model_name), format='pdf', dpi=300)


# Calculate MSE and MAPE for training and test sets separately

if after_interpolation:
    # Reference values at the low-rate sample positions (used by the 2-D plot).
    actual_actual_low = y_low_high_res[t_low_index]
    actual_high_low = y_high[t_low_index]
    # Reference series with the first SEQ_LENGTH steps dropped.
    actual_actual = y_low_high_res[SEQ_LENGTH:]
    actual_high = y_high[SEQ_LENGTH:]
else:
    actual_actual = y_low[SEQ_LENGTH:]
    # No high-rate reference in this mode; define the name so the slicing
    # below does not raise NameError.
    actual_high = None

train_predictions_actual = predictions_actual[:train_size]
test_predictions_actual = predictions_actual[train_size:]

train_actual_actual = actual_actual[:train_size]
test_actual_actual = actual_actual[train_size:]

if actual_high is not None:
    train_actual_high = actual_high[:train_size]
    test_actual_high = actual_high[train_size:]
else:
    train_actual_high = test_actual_high = None

train_mse = mean_squared_error(train_actual_actual, train_predictions_actual)
train_mape = mean_absolute_percentage_error(train_actual_actual, train_predictions_actual) * 100

test_mse = mean_squared_error(test_actual_actual, test_predictions_actual)
test_mape = mean_absolute_percentage_error(test_actual_actual, test_predictions_actual) * 100

print(f'{model_name} Training MSE: {train_mse:.4f}')
print(f'{model_name} Training MAPE: {train_mape:.2f}%')

print(f'{model_name} Test MSE: {test_mse:.6f}')
print(f'{model_name} Test MAPE: {test_mape:.2f}%')




# 2-D view, produced only for the interpolated spiral run: the high-rate
# series on the x-axis against the reference / predicted series on the y-axis.
if dataname == "spiral" + "_after_interpolation_":
    actual_low = y_low[:]
    actual_low_2 = y_high[t_low_index]

    plt.figure(figsize=(6, 5))

    # Reference curve (black), then the model output for each split.
    plt.plot(train_actual_high, train_actual_actual, label="True", c="black")
    plt.plot(test_actual_high, test_actual_actual, c="black")
    plt.plot(train_actual_high, train_predictions_actual, label="Pred Train", c="blue")
    plt.plot(test_actual_high, test_predictions_actual, label="Pred Test", c="orange")

    # Reference values at the low-rate sample positions.
    plt.scatter(actual_high_low, actual_actual_low, label ="Low True")

    # plt.title(f'{dataname}_{model_name}_2d')
    plt.xlabel('X')
    plt.ylabel('Y')
    # plt.legend()
    # savefig does not create missing directories; make sure ./results exists.
    os.makedirs('./results', exist_ok=True)
    plt.savefig('./results/{}_{}_forecasting_2d.pdf'.format(dataname, model_name), format='pdf', dpi=300)

plt.show()