import torch
import torch.nn as nn
import torch.optim as optim
import torchdiffeq
import numpy as np
import matplotlib.pyplot as plt
import random
import get_date_send
import get_date_sendv2
from scipy.interpolate import CubicSpline
import torch.nn.functional as F
from sklearn.decomposition import PCA
import copy
import get_date_send_spiral
import pandas as pd

import matplotlib.pyplot as plt
plt.ion()  #

# Notes: well-organized test data; no need for Q-optimal; shared Neural ODEs; input is HR;
# ODE frozen for LR; LR starts from HR; combined HR/LR input is not used (the helper functions for it remain).
def set_seed(seed):
    """Seed every random number generator this script relies on.

    PyTorch (CPU, current CUDA device and all CUDA devices), NumPy and
    Python's ``random`` module are all seeded with ``seed``. cuDNN is then
    put in deterministic mode with benchmarking switched off.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


# Seed all random number generators once, at import time, with a fixed value.
seed = 42
set_seed(seed)

# Function to split combined data into HR and LR input/output and create sequences
def split_and_create_sequences(combined_data, output_size_HR, output_size_LR, seq_length):
    """Slice a combined HR+LR array by column and window it into sequences.

    The first ``output_size_HR + output_size_LR`` columns are the shared model
    input, the first ``output_size_HR`` columns are the HR targets and every
    column from ``output_size_HR`` onwards is an LR target.

    Returns:
        ``(x_hr_seq, y_hr_seq, x_lr_seq, y_lr_seq)`` as produced by
        ``create_sequences``; the HR and LR inputs come from the same slice.
    """
    input_width = output_size_HR + output_size_LR
    shared_inputs = combined_data[:, :input_width]
    hr_targets = combined_data[:, :output_size_HR]
    lr_targets = combined_data[:, output_size_HR:]

    # Both target sets are paired with the very same input windows.
    hr_windows, hr_next = create_sequences(shared_inputs, hr_targets, seq_length)
    lr_windows, lr_next = create_sequences(shared_inputs, lr_targets, seq_length)
    return hr_windows, hr_next, lr_windows, lr_next


# Function to prepare training and testing data for HR and LR models
def prepare_data_for_training(high_res_data, interpolated_lr_data, train_size, seq_length):
    """Build train/test sequence sets for the HR and the LR model.

    Both models take windows of HR data as input. The HR model is trained to
    predict the next HR sample, the LR model the next (interpolated) LR sample.

    Args:
        high_res_data: HR series, indexable along its first axis by time.
        interpolated_lr_data: LR series interpolated to HR resolution; only
            used as a target, never as an input.
        train_size: Number of leading time steps assigned to the training set.
        seq_length: Window length passed to ``create_sequences``.

    Returns:
        ``(train_x_hr, train_y_hr, train_x_lr, train_y_lr,
        test_x_hr, test_y_hr, test_x_lr, test_y_lr)``.
    """
    # The model input is the HR data alone. An HR+LR concatenation used to be
    # computed here and immediately overwritten; it was wasted work and failed
    # whenever the LR array was not column-compatible with the HR array.
    train_hr_data = high_res_data[:train_size]
    test_hr_data = high_res_data[train_size:]

    # The interpolated LR data is split the same way and serves as target only.
    train_lr_data = np.array(interpolated_lr_data[:train_size])
    test_lr_data = np.array(interpolated_lr_data[train_size:])

    # HR model: HR window -> next HR time step.
    train_x_hr, train_y_hr = create_sequences(train_hr_data, train_hr_data, seq_length)
    test_x_hr, test_y_hr = create_sequences(test_hr_data, test_hr_data, seq_length)

    # LR model: HR window -> next LR time step.
    train_x_lr, train_y_lr = create_sequences(train_hr_data, train_lr_data, seq_length)
    test_x_lr, test_y_lr = create_sequences(test_hr_data, test_lr_data, seq_length)

    return train_x_hr, train_y_hr, train_x_lr, train_y_lr, test_x_hr, test_y_hr, test_x_lr, test_y_lr


# Function to set random seed for reproducibility
def set_seed(seed):
    """Make runs reproducible by seeding all random number generators.

    NOTE: a function with this name is already defined earlier in the file;
    this later definition is the one bound to ``set_seed`` from here on.
    """
    # PyTorch: CPU generator, current CUDA device, then every CUDA device.
    for torch_seeder in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        torch_seeder(seed)

    # NumPy and the Python standard library.
    np.random.seed(seed)
    random.seed(seed)

    # Disable cuDNN benchmarking and request deterministic cuDNN behaviour.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


# Compute Mean Squared Error (MSE)
def compute_mse(predicted, true):
    """Return the mean of the squared element-wise differences."""
    residual = predicted - true
    return np.mean(np.square(residual))


# Compute Mean Absolute Percentage Error (MAPE)
def compute_mape(predicted, true):
    """Return the mean absolute percentage error, expressed in percent.

    Every error is divided by the matching entry of ``true``; there is no
    guard against zero entries in ``true``.
    """
    relative_error = (predicted - true) / true
    return np.mean(np.abs(relative_error)) * 100


# Function to plot the entire HR data and interpolated LR data, emphasizing true LR data points
def plot_full_interpolation_vs_true(hr_data, interpolated_lr_data, lr_indices, low_res_data):
    """Plot column 0 of the HR data, the interpolated LR data and the true LR samples.

    The HR and interpolated LR series are drawn against their own sample
    index; the true LR samples are drawn as large red markers at
    ``lr_indices``. Only the first column of each array is shown (plots for a
    second column were disabled in the original code).
    """
    hr_axis = range(len(hr_data))
    interp_axis = range(len(interpolated_lr_data))

    plt.figure(figsize=(12, 8))

    # True HR series.
    plt.plot(hr_axis, hr_data[:, 0], marker='o', markersize=2,
             label='True HR Data (Column 1)')

    # LR series interpolated to HR resolution.
    plt.plot(interp_axis, interpolated_lr_data[:, 0], linestyle='--', marker='x',
             markersize=2, label='Interpolated LR Data (Column 1)')

    # Actual LR samples, emphasised with larger markers drawn on top.
    plt.scatter(lr_indices, low_res_data[:, 0], s=100, color='red', edgecolor='black',
                zorder=5, label='True LR Data (Column 1)')

    plt.xlabel('Sample Index')
    plt.ylabel('Data Values')
    plt.title('Full Interpolation vs. True Data')
    plt.legend()
    plt.grid(True)
    plt.show()

def compute_and_plot_test_results(predictions_test, targets_test, train_length, model_type):
    """Print MSE/MAPE for a model's predictions and plot column 0 against the targets.

    Errors are printed three times: for the first ``train_length`` rows
    ("Interpolation"), for the remaining rows ("Prediction") and for all rows
    ("Total"). The plot shows the same two row ranges in different colours.

    Args:
        predictions_test: 2-D array of model outputs, one row per time index.
        targets_test: 2-D array of ground-truth values, same shape.
        train_length: Number of leading rows belonging to the training period.
        model_type: Label used in the printed lines and the plot title.
    """
    def _mse_and_mape(pred, true):
        # Same formulas for every row range: mean squared error and MAPE in percent.
        mse = np.mean(np.square(pred - true))
        mape = np.mean(np.abs((pred - true) / (true))) * 100
        return mse, mape

    # Rows inside the training period.
    test_mse, test_mape = _mse_and_mape(predictions_test[:train_length, :], targets_test[:train_length, :])
    print(f"{model_type} Model - Interpolation MSE: {test_mse:.4f}, MAPE: {test_mape:.2f}%")

    # Rows after the training period.
    test_mse, test_mape = _mse_and_mape(predictions_test[train_length:, :], targets_test[train_length:, :])
    print(f"{model_type} Model - Prediction MSE: {test_mse:.4f}, MAPE: {test_mape:.2f}%")

    # All rows together.
    test_mse, test_mape = _mse_and_mape(predictions_test, targets_test)
    print(f"{model_type} Model - Total MSE: {test_mse:.4f}, MAPE: {test_mape:.2f}%")

    plt.figure(figsize=(12, 6))

    train_axis = range(train_length)
    test_axis = range(train_length, len(predictions_test))

    # (x, y, label, colour, line style) for each curve; only column 0 is drawn
    # (the second-column curves were commented out for the 200-bus system).
    curves = (
        (train_axis, targets_test[:train_length, 0], 'True Train Column 1', 'blue', 'solid'),
        (train_axis, predictions_test[:train_length, 0], 'Predicted Train Column 1', 'blue', 'dashed'),
        (test_axis, targets_test[train_length:, 0], 'True Test Column 1', 'red', 'solid'),
        (test_axis, predictions_test[train_length:, 0], 'Predicted Test Column 1', 'red', 'dashed'),
    )
    for x_values, y_values, label, color, linestyle in curves:
        plt.plot(x_values, y_values, label=label, color=color, linestyle=linestyle)

    plt.xlabel('Time Index')
    plt.ylabel('Values')
    plt.title(f'{model_type} Model - Test Results with Distinct Training/Testing Periods')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()

    plt.show()
# Merged function to assign and normalize time indices
def assign_and_normalize_time_indices(high_res_data, downsample_index, days=8):
    """Assign a repeating per-day time index to every HR sample and pick the LR ones.

    Each of the ``days`` equally long days gets the same time grid, running
    from the module-level ``start_time`` to ``start_time + max_time``. Samples
    left over after the last full day start a new day: they begin again at
    ``start_time`` and use the same spacing as the full days.

    Args:
        high_res_data: HR series; only its length is used.
        downsample_index: Positions of the LR samples within the HR series.
        days: Number of full days the HR series is divided into.

    Returns:
        ``(t_high_normalized, t_low_normalized)`` as float32 tensors, the
        second being the first indexed by ``downsample_index``.
    """
    total_samples = len(high_res_data)
    samples_per_day = total_samples // days

    # One day's grid, repeated for every full day.
    one_day = np.linspace(start_time, max_time + start_time, samples_per_day)
    t_high_normalized = np.tile(one_day, days)

    # Samples that do not fill a whole day.
    remainder = total_samples % (samples_per_day * days)
    if remainder > 0:
        interval = t_high_normalized[1] - t_high_normalized[0]
        # Offset by start_time so the partial day begins where every full day
        # begins (it previously restarted at 0 regardless of start_time).
        extra_time_indices = start_time + interval * np.arange(remainder)
        t_high_normalized = np.concatenate((t_high_normalized, extra_time_indices))

    # LR time stamps are the HR time stamps at the down-sampled positions.
    t_low_normalized = t_high_normalized[downsample_index]

    return torch.tensor(t_high_normalized, dtype=torch.float32), torch.tensor(t_low_normalized, dtype=torch.float32)


# Function to interpolate the entire LR data to the HR level
def interpolate_lr_to_hr_full(hr_indices, lr_indices, low_res_data, low_to_high_res_data):
    """Cubic-spline interpolate the LR samples onto the HR index grid.

    A spline is fitted through ``low_res_data`` at ``lr_indices`` (along axis
    0) and evaluated at ``hr_indices``. The MSE and MAPE of the result against
    ``low_to_high_res_data`` are printed.

    Returns:
        The interpolated values as a float32 tensor.
    """
    spline = CubicSpline(lr_indices, low_res_data, axis=0)
    interpolated_lr_to_hr = spline(hr_indices)

    # Report how far the spline is from the reference HR-resolution values.
    residual = interpolated_lr_to_hr - low_to_high_res_data
    spline_mse = np.mean(np.square(residual))
    spline_mape = np.mean(np.abs(residual / (low_to_high_res_data))) * 100
    print("Cublic spline Interpolation MSE: ", spline_mse)
    print("Cublic spline Interpolation MAPE: ", spline_mape)

    return torch.tensor(interpolated_lr_to_hr, dtype=torch.float32)


# Function to create sequences for the RNN
def create_sequences(x_data, y_data, seq_length):
    """Build sliding input windows and their next-step targets.

    Window ``i`` is ``x_data[i:i + seq_length]`` and its target is
    ``y_data[i + seq_length]``; there are ``len(x_data) - seq_length`` windows.

    Args:
        x_data: Input series, sliceable along its first axis.
        y_data: Target series, indexable along its first axis.
        seq_length: Number of consecutive samples per window.

    Returns:
        ``(xs, ys)`` as float32 tensors. Both have shape ``(0,)`` when
        ``x_data`` holds no more than ``seq_length`` samples.
    """
    window_count = len(x_data) - seq_length
    xs = [x_data[i:i + seq_length] for i in range(window_count)]
    ys = [y_data[i + seq_length] for i in range(window_count)]

    # Stack into a single NumPy array first: handing torch.tensor a Python list
    # of ndarrays is the slow path PyTorch warns about.
    return (torch.tensor(np.array(xs), dtype=torch.float32),
            torch.tensor(np.array(ys), dtype=torch.float32))


# Function to split sequences into batches with correct time indices
def split_sequences_into_batches(x_data, y_data, t_data):
    """Cut the sequences into consecutive batches that each end at the last time of a day.

    A batch runs from the current position up to and including the next
    sample whose time equals ``start_time + max_time`` (module-level values).
    If no such sample remains, everything left forms the final batch.

    Returns:
        ``(batches, batch_starts)`` where ``batches`` is a list of
        ``(x_batch, y_batch, t_batch)`` tuples and ``batch_starts`` holds the
        start index of each batch.
    """
    day_end_time = start_time + max_time
    total = len(x_data)

    batches = []
    batch_starts = []
    begin = 0
    while begin < total:
        # Offsets, relative to `begin`, of the samples sitting at the day-end time.
        hits = torch.where(t_data[begin:] == day_end_time)[0]
        if len(hits) > 0:
            stop = begin + hits[0].item() + 1
        else:
            stop = total

        # Batches are kept whatever their size.
        batches.append((x_data[begin:stop], y_data[begin:stop], t_data[begin:stop]))
        batch_starts.append(begin)
        begin = stop

    return batches, batch_starts


# Function to map original LR indices to their positions within each batch
def map_lr_indices_to_batches(lr_indices, batch_starts, batch_lengths, seq_length):
    """Translate global LR indices into positions relative to each batch.

    Indices are first shifted down by ``seq_length`` (those smaller than
    ``seq_length`` are dropped). Each batch, described by its start and
    length, then receives the shifted indices that fall inside it, expressed
    relative to the batch start.

    Returns:
        One list of batch-relative indices per batch.
    """
    shifted_indices = []
    for idx in lr_indices:
        if idx >= seq_length:
            shifted_indices.append(idx - seq_length)

    per_batch = []
    for start, length in zip(batch_starts, batch_lengths):
        stop = start + length
        per_batch.append([pos - start for pos in shifted_indices if start <= pos < stop])

    return per_batch



# Function to handle training and evaluation for the entire dataset
def train_and_evaluate_batches(train_batches, model, criterion, optimizer, num_epochs, is_lr_model=False,
                               lr_indices=None, high_res_data=None, low_to_high_res_data=None, t_high_normalized=None,
                               seq_length=None, y0_hr=None, y0_lr=None, interpolated_lr_data =None):
    """Train ``model`` batch by batch, then predict on the training batches and the full series.

    The state with the lowest summed epoch loss is restored before evaluation.
    The evaluation ("test") set is rebuilt from the complete ``high_res_data``
    series: HR windows as input and, as target, ``low_to_high_res_data`` for
    the LR model or ``high_res_data`` for the HR model.

    Args:
        train_batches: Iterable of ``(x_batch, y_batch, t_batch)`` tuples.
        model: Module called as ``model(x_batch, t_batch, y0)``.
        criterion: Loss called as ``criterion(outputs, y_batch)``.
        optimizer: Optimizer stepping the model parameters.
        num_epochs: Number of passes over ``train_batches``.
        is_lr_model: Selects ``y0_lr`` / LR targets instead of ``y0_hr`` / HR targets.
        lr_indices: Accepted for backward compatibility; unused because the
            loss is taken over every row of a batch.
        high_res_data: Complete HR series used to build the evaluation inputs.
        low_to_high_res_data: Complete LR-at-HR-resolution series (LR targets).
        t_high_normalized: Time stamp of every HR sample.
        seq_length: Window length of the evaluation sequences.
        y0_hr, y0_lr: Initial matrices handed to the HR / LR model.
        interpolated_lr_data: Accepted for backward compatibility; unused
            because the model input is HR-only.

    Returns:
        ``(final_train_predictions, final_test_predictions,
        final_train_targets, final_test_targets)`` as NumPy arrays.
    """
    y0_initial = y0_lr if is_lr_model else y0_hr
    print_flag = 'Train LR ' if is_lr_model else 'Train HR '

    lowest_loss = float('inf')
    best_model_state = None

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0

        for x_batch, y_batch, t_batch in train_batches:
            optimizer.zero_grad()
            outputs = model(x_batch, t_batch, y0_initial)

            # HR and LR models share the same full-batch loss: the LR targets
            # are the spline-interpolated pseudo-labels, so no LR mask is applied.
            loss = criterion(outputs, y_batch)

            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        # Keep a copy of the weights of the best epoch seen so far.
        if epoch_loss < lowest_loss:
            lowest_loss = epoch_loss
            best_model_state = copy.deepcopy(model.state_dict())

        print(print_flag + f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss / len(train_batches):.4f}')

    # No state is recorded when no epoch ran or the loss was never finite;
    # in that case the model is evaluated as it stands.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    final_train_predictions = []
    final_train_targets = []
    final_test_predictions = []
    final_test_targets = []

    model.eval()
    with torch.no_grad():
        for x_batch, y_batch, t_batch in train_batches:
            final_train_predictions.append(model(x_batch, t_batch, y0_initial).cpu().numpy())
            final_train_targets.append(y_batch.cpu().numpy())

        # Evaluation inputs are HR windows; only the targets depend on the model kind.
        test_target_series = low_to_high_res_data if is_lr_model else high_res_data
        test_x, test_y = create_sequences(high_res_data, test_target_series, seq_length)
        # The first target sits at index seq_length, so the time stamps are shifted alike.
        t_test_normalized = t_high_normalized[seq_length:]
        test_batches, _ = split_sequences_into_batches(test_x, test_y, t_test_normalized)

        for x_batch, y_batch, t_batch in test_batches:
            final_test_predictions.append(model(x_batch, t_batch, y0_initial).cpu().numpy())
            final_test_targets.append(y_batch.cpu().numpy())

    final_train_predictions = np.concatenate(final_train_predictions, axis=0)
    final_train_targets = np.concatenate(final_train_targets, axis=0)
    final_test_predictions = np.concatenate(final_test_predictions, axis=0)
    final_test_targets = np.concatenate(final_test_targets, axis=0)

    return final_train_predictions, final_test_predictions, final_train_targets, final_test_targets


def generate_W_hh_seq_for_batches(train_batches_hr, train_batches_lr, model_HR, model_LR, y0_hr, y0_lr):
    """Generate the recurrent-weight sequence of every HR and every LR batch.

    For each batch, the time stamps (third tuple element) are passed to the
    owning model's ``neural_ode._generate_W_hh`` together with the matching
    initial matrix. No LR mask is applied here; the masking code that once
    followed the LR call is disabled.

    Returns:
        ``(W_hh_seq_batches_hr, W_hh_seq_batches_lr)``, one entry per batch.
    """
    W_hh_seq_batches_hr = [
        model_HR.neural_ode._generate_W_hh(t_batch, y0_hr)
        for _x, _y, t_batch in train_batches_hr
    ]
    W_hh_seq_batches_lr = [
        model_LR.neural_ode._generate_W_hh(t_batch, y0_lr)
        for _x, _y, t_batch in train_batches_lr
    ]
    return W_hh_seq_batches_hr, W_hh_seq_batches_lr



def check_orthogonality(Q_optimal):
    """Print how far ``Q_optimal^T Q_optimal`` is from the identity matrix.

    Two lines are printed: the norm of the difference to the identity, and
    whether all entries match the identity within an absolute tolerance of 1e-6.
    """
    identity = torch.eye(Q_optimal.size(0), device=Q_optimal.device)
    gram = torch.matmul(Q_optimal.T, Q_optimal)

    difference = torch.norm(gram - identity)
    verdict = 'Yes' if torch.allclose(gram, identity, atol=1e-6) else 'No'

    print(f"Difference from Identity: {difference}")
    print(f"Is orthogonal? {verdict}")


def _select_lr_sampled(W_hr, W_lr, indices):
    """Return the HR and LR matrices of one batch at the time steps listed in ``indices``."""
    keep = torch.zeros(W_lr.size(0), dtype=torch.bool, device=W_lr.device)
    for idx in indices:
        keep[idx] = True
    return W_hr[keep], W_lr[keep]


def compute_Q_and_calculate_error(W_hh_seq_batches_hr, W_hh_seq_batches_lr, lr_indices):
    """Find the orthogonal matrix that best maps the HR weight matrices onto the LR ones.

    Only time steps listed in ``lr_indices`` (one index list per batch) take
    part. ``M`` accumulates ``W_hr^T @ W_lr`` over those steps, and
    ``Q = U V^T`` from the SVD of ``M`` is the orthogonal matrix applied on
    the right of ``W_hr``.

    Args:
        W_hh_seq_batches_hr: Per-batch HR tensors of square matrices, one matrix per time step.
        W_hh_seq_batches_lr: Per-batch LR tensors with the same layout.
        lr_indices: Per-batch lists of time-step positions that hold LR samples.

    Returns:
        ``(Q_optimal, transformation_error)`` where the error is the mean
        squared difference between ``W_hr @ Q_optimal`` and ``W_lr``, averaged
        over the elements of the selected matrices only.

    Raises:
        ValueError: If ``lr_indices`` selects no time step at all.
    """
    latent_dim = W_hh_seq_batches_hr[0].size(1)  # assuming square matrices
    M = torch.zeros(latent_dim, latent_dim, device=W_hh_seq_batches_hr[0].device)

    # Select the LR-sampled time steps once and reuse them for both passes.
    selected = []
    for i, W_hr in enumerate(W_hh_seq_batches_hr):
        W_hr_sel, W_lr_sel = _select_lr_sampled(W_hr, W_hh_seq_batches_lr[i], lr_indices[i])
        selected.append((W_hr_sel, W_lr_sel))
        M += torch.matmul(W_hr_sel.transpose(1, 2), W_lr_sel).sum(dim=0)

    count = sum(W_lr_sel.numel() for _, W_lr_sel in selected)
    if count == 0:
        raise ValueError("lr_indices selects no time step; cannot compute Q_optimal")

    U, _, V = torch.svd(M)
    Q_optimal = torch.matmul(U, V.t())

    print('Check orthogonality for Q-optimal')
    check_orthogonality(Q_optimal)

    # Sum the squared errors and divide by the number of selected elements.
    # Averaging over the whole batch (unselected steps included) would dilute
    # the error with zeros.
    total_squared_error = 0.0
    for W_hr_sel, W_lr_sel in selected:
        W_hr_transformed = torch.matmul(W_hr_sel, Q_optimal)
        total_squared_error += F.mse_loss(W_hr_transformed, W_lr_sel, reduction='sum').item()

    transformation_error = total_squared_error / count

    print(f"Transformation Error: {transformation_error}")

    return Q_optimal, transformation_error

def visualize_pca_matrix_flow_with_lr_marks(W_hh_seq_batches_hr, W_hh_seq_batches_lr, batch_index=1, lr_indices=None):
    """
    Perform PCA separately on the HR and LR matrices of one batch and plot both flows in 3D.

    Each matrix is flattened to a row. A 3-component PCA is fitted per
    resolution and the rows are projected onto its components without
    subtracting the mean. The projections are written to
    ``W_hh_pca_hr_ours.csv`` and ``W_hh_pca_lr_ours.csv``, and the figure is
    saved to ``./results/Ours.pdf`` and then shown.

    :param W_hh_seq_batches_hr: Sequence of HR matrices.
    :param W_hh_seq_batches_lr: Sequence of LR matrices.
    :param batch_index: Index of the batch to visualize (0-based).
    :param lr_indices: Per-batch indices of non-masked LR data points; when
        given, those points of the LR flow are highlighted with markers.
    """
    import os  # local import: only needed to make sure the output folder exists

    W_hr_batch = W_hh_seq_batches_hr[batch_index]
    W_lr_batch = W_hh_seq_batches_lr[batch_index]

    # One flattened matrix per row; reshape also copes with non-contiguous tensors.
    W_hr_flattened = W_hr_batch.reshape(W_hr_batch.size(0), -1).detach().cpu().numpy()
    W_lr_flattened = W_lr_batch.reshape(W_lr_batch.size(0), -1).detach().cpu().numpy()

    # Project onto the principal axes without centring the data.
    pca_hr = PCA(n_components=3)
    pca_hr.fit(W_hr_flattened)
    W_hr_pca = np.dot(W_hr_flattened, pca_hr.components_.T)

    pca_lr = PCA(n_components=3)
    pca_lr.fit(W_lr_flattened)
    W_lr_pca = np.dot(W_lr_flattened, pca_lr.components_.T)

    # Persist both projections.
    pc_columns = ['PC1', 'PC2', 'PC3']
    pd.DataFrame(W_hr_pca, columns=pc_columns).to_csv('W_hh_pca_hr_ours.csv', index=False)
    pd.DataFrame(W_lr_pca, columns=pc_columns).to_csv('W_hh_pca_lr_ours.csv', index=False)

    fig = plt.figure(figsize=(14, 10))
    ax = fig.add_subplot(111, projection='3d')

    # Trajectories of the HR and LR matrix flows, drawn without markers.
    ax.plot(W_hr_pca[:, 0], W_hr_pca[:, 1], W_hr_pca[:, 2], label='HQ Parameter Flow', color='blue')
    ax.plot(W_lr_pca[:, 0], W_lr_pca[:, 1], W_lr_pca[:, 2], label='LQ Parameter Flow', color='green', linestyle='--')

    # Markers only at the LR time steps that carry real samples.
    if lr_indices is not None:
        W_lr_non_masked = W_lr_pca[lr_indices[batch_index]]
        ax.scatter(W_lr_non_masked[:, 0], W_lr_non_masked[:, 1], W_lr_non_masked[:, 2], color='green', marker='o', s=50, label='LQ Sample Parameter')

    ax.set_xlabel('PC 1')
    ax.set_ylabel('PC 2')
    ax.set_zlabel('PC 3')
    ax.legend()

    # Save through the figure handle and before showing: once a blocking
    # show() returns, the window is closed and pyplot's "current figure"
    # would be a new, empty one.
    os.makedirs('./results', exist_ok=True)
    fig.savefig('./results/Ours.pdf', format='pdf', dpi=300)
    plt.show()

# Step 6: Define the Geo_RNN model class
class NeuralODE(nn.Module):
    """Neural ODE that evolves a flattened ``hidden_size x hidden_size`` matrix over time.

    The derivative is ``S(t) @ Y``, where ``S(t)`` is the coefficient-weighted
    sum of ``num_matrices`` learned skew-symmetric matrices divided by
    ``num_matrices`` (see ``_ode_func``).
    """

    def __init__(self, latent_dim, hidden_size, num_matrices):
        super(NeuralODE, self).__init__()
        self.hidden_size = hidden_size
        self.num_matrices = num_matrices
        # One small network per skew-symmetric generator:
        # time (1) -> latent_dim -> hidden_size * hidden_size.
        self.fc1_list = nn.ModuleList([nn.Linear(1, latent_dim) for _ in range(num_matrices)])
        self.fc2_list = nn.ModuleList([nn.Linear(latent_dim, hidden_size * hidden_size) for _ in range(num_matrices)])
        # Mixing weight of each generator.
        self.coefficients = nn.Parameter(torch.randn(num_matrices))

    def _generate_W_hh(self, t, y0):
        """Integrate the ODE from ``y0`` and return one matrix per entry of ``t``.

        Args:
            t: Time stamps; flattened to 1-D before integration.
            y0: Initial state, moved to the device of ``t``.
                # assumes shape (1, hidden_size * hidden_size) - TODO confirm

        Returns:
            Tensor of shape ``(len(t), hidden_size, hidden_size)``.
        """
        y0 = y0.to(t.device)

        # Flatten to 1-D. reshape(-1) is used instead of squeeze() so that a
        # batch holding a single time stamp stays 1-D rather than turning 0-d.
        t = t.reshape(-1)

        # Solver tolerances passed to odeint.
        # NOTE(review): with the fixed-step 'euler' method these are presumably
        # not used for step-size control - confirm against the torchdiffeq docs.
        rtol = 1e-3
        atol = 1e-4

        # ode_out has one entry per time stamp.
        ode_out = torchdiffeq.odeint(self._ode_func, y0, t, rtol=rtol, atol=atol, method='euler')

        # Drop the singleton dimension that follows the time dimension.
        ode_out = ode_out.squeeze(1)

        return ode_out.reshape(t.size(0), self.hidden_size, self.hidden_size)

    def _ode_func(self, t, y):
        """Return ``dy/dt = S(t) @ Y`` flattened to ``(-1, hidden_size * hidden_size)``."""
        skew_symmetric_matrices = []

        for fc1, fc2 in zip(self.fc1_list, self.fc2_list):
            # Time -> latent features -> flattened square matrix A_i.
            A = fc2(torch.tanh(fc1(t.view(-1, 1))))
            A = A.view(-1, self.hidden_size, self.hidden_size)

            # S_i = A_i - A_i^T is skew-symmetric by construction.
            skew_symmetric_matrices.append(A - A.transpose(1, 2))

        # S = (a_1 * S_1 + ... + a_M * S_M) / M
        combined_S = sum(self.coefficients[i] * skew_symmetric_matrices[i] for i in range(len(self.fc1_list)))
        combined_S = combined_S/self.num_matrices

        # dY/dt = S @ Y, with y viewed as a batch of square matrices.
        out = torch.matmul(combined_S, y.view(-1, self.hidden_size, self.hidden_size))

        return out.view(-1, self.hidden_size * self.hidden_size)


class ModelHR(nn.Module):
    """Recurrent HR model whose hidden-to-hidden weights come from a Neural ODE.

    The ``nn.RNNCell`` supplies the input weights and both bias vectors. The
    recurrent matrix applied to each sample is the one the shared Neural ODE
    generates for that sample's time stamp.
    """

    def __init__(self, hidden_size, input_size, output_size, neural_ode):
        super().__init__()
        self.hidden_size = hidden_size
        self.rnn_cell = nn.RNNCell(input_size, hidden_size)  # source of weight_ih and the biases
        self.fc = nn.Linear(hidden_size, output_size)
        self.neural_ode = neural_ode  # shared Neural ODE

    def forward(self, x, t, y0_hr):
        """Run the recurrence over the window and map the last hidden state to the output.

        ``x`` must be 3-D (batch, sequence, features). ``t`` and ``y0_hr`` are
        forwarded to the Neural ODE, which must return one recurrent matrix
        per sample of the batch.
        """
        batch_size, seq_length, _ = x.size()
        hidden = torch.zeros(batch_size, self.hidden_size).to(x.device)
        recurrent_weights = self.neural_ode._generate_W_hh(t, y0_hr)
        cell = self.rnn_cell

        for step in range(seq_length):
            # Per-sample product h_t @ W_hh, carried out as a batched matmul.
            recurrent_term = torch.bmm(hidden.unsqueeze(1), recurrent_weights).squeeze(1)
            input_term = x[:, step] @ cell.weight_ih.t()
            hidden = torch.tanh(input_term + recurrent_term + cell.bias_ih + cell.bias_hh)

        return self.fc(hidden)


class ModelLR(nn.Module):
    """Low-resolution recurrent model driven by a shared Neural ODE.

    A tanh RNN whose hidden-to-hidden weight is the tensor returned by
    ``neural_ode._generate_W_hh(t, y0_lr)`` instead of the cell's own
    ``weight_hh``. Only ``weight_ih``, ``bias_ih`` and ``bias_hh`` of
    ``rnn_cell`` take part in the forward pass.

    Args:
        hidden_size: Width of the recurrent state.
        input_size: Number of features per time step.
        output_size: Number of features produced by the read-out layer.
        neural_ode: Object exposing ``_generate_W_hh(t, y0)``; the same
            instance may be handed to several models.
    """

    def __init__(self, hidden_size, input_size, output_size, neural_ode):
        super().__init__()
        self.hidden_size = hidden_size
        # The RNNCell only stores the input projection and the bias vectors.
        self.rnn_cell = nn.RNNCell(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)
        self.neural_ode = neural_ode

    def forward(self, x, t, y0_lr):
        """Run the recurrence over ``x`` and read out the last hidden state.

        Args:
            x: Tensor of shape (batch, seq_length, input_size).
            t: Time argument, forwarded unchanged to the Neural ODE.
            y0_lr: Initial condition, forwarded unchanged to the Neural ODE.

        Returns:
            Tensor of shape (batch, output_size).
        """
        n_batch, n_steps, _ = x.shape
        hidden = torch.zeros(n_batch, self.hidden_size, device=x.device)
        # torch.bmm below requires a 3-D tensor of shape
        # (batch, hidden_size, hidden_size); it is reused at every step.
        w_hh = self.neural_ode._generate_W_hh(t, y0_lr)

        w_ih_t = self.rnn_cell.weight_ih.t()
        b_ih = self.rnn_cell.bias_ih
        b_hh = self.rnn_cell.bias_hh
        for k in range(n_steps):
            from_hidden = hidden.unsqueeze(1).bmm(w_hh).squeeze(1)
            hidden = torch.tanh(x[:, k] @ w_ih_t + from_hidden + b_ih + b_hh)

        return self.fc(hidden)



# Enable matplotlib interactive mode for the figures produced below.
plt.ion()


# Train and evaluate models as needed
# Dataset selection: the active loader is the uncommented one; the
# alternatives are kept commented out.


# spiral data
# t_high, high_res_data, t_low, low_res_data, low_to_high_res_data, downsample_index, lr_interval = get_date_send_spiral.get_data_spiral() # load data

# load data (active dataset)
t_high, high_res_data, t_low, low_res_data, low_to_high_res_data, downsample_index, lr_interval = get_date_send.get_load_oncor() # load data

# # pv data
# base_path = "C:\\Software\\Geometric-DL\\HR-LR-DL\\Data\\onemin-Ground-2017-01-"
#
# # Initialize lists to store results
# high_res_data_list = []
# low_res_data_list = []
# low_to_high_res_data_list = []
# downsample_index_list = []
#
# # Initialize offset for downsample_index
# offset = 0
# # Loop through file numbers 1 to day_profile - 1 and construct file paths
# day_profile = 11 # input 10 days' data
# for i in range(1, day_profile):
#     if i < 10:
#         path = f"{base_path}{0}{i}.csv"
#     else:
#         path = f"{base_path}{i}.csv"
#     t_high1, high_res_data1, t_low1, low_res_data1, low_to_high_res_data1, downsample_index1, lr_interval = get_date_sendv2.get_data_PV_online(path)
#     high_res_data_list.append(high_res_data1)
#     low_res_data_list.append(low_res_data1)
#     low_to_high_res_data_list.append(low_to_high_res_data1)
#
#     # Adjust downsample_index with the current offset and add to the list
#     adjusted_downsample_index = downsample_index1 + offset
#     downsample_index_list.append(adjusted_downsample_index)
#
#     # Update the offset based on the current high_res_data length
#     offset += high_res_data1.shape[0]
#
# # Concatenate the data lists into single arrays
# high_res_data = np.concatenate(high_res_data_list, axis=0)
# low_res_data = np.concatenate(low_res_data_list, axis=0)
# low_to_high_res_data = np.concatenate(low_to_high_res_data_list, axis=0)
# downsample_index = np.concatenate(downsample_index_list, axis=0)


# transient voltage data

# base_path = "C:\\Software\\Geometric-DL\\HR-LR-DL\\Data\\VA-event3-"
#
# # Initialize lists to store results
# high_res_data_list = []
# low_res_data_list = []
# low_to_high_res_data_list = []
# downsample_index_list = []
#
# # Initialize offset for downsample_index
# offset = 0
# # Loop through file numbers 1 to day_profile - 1 and construct file paths
# day_profile = 4 # input 3 days' data
# for i in range(1, day_profile):
#     path = f"{base_path}{i}.csv"
#     t_high1, high_res_data1, t_low1, low_res_data1, low_to_high_res_data1, downsample_index1, lr_interval = get_date_sendv2.get_data_200bus_9_10(path)
#     high_res_data_list.append(high_res_data1)
#     low_res_data_list.append(low_res_data1)
#     low_to_high_res_data_list.append(low_to_high_res_data1)
#
#     # Adjust downsample_index with the current offset and add to the list
#     adjusted_downsample_index = downsample_index1 + offset
#     downsample_index_list.append(adjusted_downsample_index)
#
#     # Update the offset based on the current high_res_data length
#     offset += high_res_data1.shape[0]
#
# # Concatenate the data lists into single arrays
# high_res_data = np.concatenate(high_res_data_list, axis=0)
# low_res_data = np.concatenate(low_res_data_list, axis=0)
# low_to_high_res_data = np.concatenate(low_to_high_res_data_list, axis=0)
# downsample_index = np.concatenate(downsample_index_list, axis=0)


# air quality data
# t_high, high_res_data, t_low, low_res_data, low_to_high_res_data, downsample_index, lr_interval = get_date_sendv2.get_data_ari_quality()

plt.show()


# Step 1: Assign time indices of [0, 1], repeated once per day (see `days`)
total_samples = len(high_res_data)

# Number of periods passed to assign_and_normalize_time_indices; pick the
# value matching the active dataset.
# load
days = 8

# spiral
# days = 10
# pv
# days = 10

# 200bus
# days = 5 # for air quality data, 5 weeks

# start_time / max_time are not referenced in this section; presumably read
# as module globals elsewhere in the file -- TODO confirm.
start_time = 0
# NOTE(review): the original remark says "most is 1 except spiral", yet 0.3 is
# set while the load dataset is active -- confirm the intended value.
max_time = 0.3 # most is 1 except spiral
# Define the sequence length (number of past steps in each input window)
seq_length = 10

# pv
# train_test_ratio = 0.8

# 200 bus, air quality
# train_test_ratio = 0.66


# load
train_test_ratio = 0.66

# The first `train_size` HR samples are used for training, the rest for testing.
train_size = int(train_test_ratio * total_samples)

# Assign and normalize time indices
t_high_normalized, t_low_normalized = assign_and_normalize_time_indices(high_res_data, downsample_index, days=days)

# Step 2: Interpolate LR data to HR level
hr_indices = np.arange(total_samples)  # HR indices (full range)
interpolated_lr_to_hr = interpolate_lr_to_hr_full(hr_indices, downsample_index, low_res_data, low_to_high_res_data)

# Plot the entire interpolated LR data with the true LR data points emphasized
plot_full_interpolation_vs_true(low_to_high_res_data, interpolated_lr_to_hr, downsample_index, low_res_data)




# Step 3: Prepare training and testing data for HR and LR models

train_x_hr, train_y_hr, train_x_lr, train_y_lr, test_x_hr, test_y_hr, test_x_lr, test_y_lr = prepare_data_for_training(
    high_res_data, interpolated_lr_to_hr, train_size, seq_length)

# Adjust the time indices for the sequences: drop the first `seq_length`
# entries of each split so the time vector lines up with the sequence targets.
t_train_normalized = t_high_normalized[:train_size][seq_length:]  # Align time indices with sequences
t_test_normalized = t_high_normalized[train_size:][seq_length:]  # Align time indices with sequences

# Step 4: Split sequences into batches with correct time indices

train_batches_hr, train_batch_starts_hr = split_sequences_into_batches(train_x_hr, train_y_hr, t_train_normalized)
test_batches_hr, test_batch_starts_hr = split_sequences_into_batches(test_x_hr, test_y_hr, t_test_normalized)

train_batches_lr, train_batch_starts_lr = split_sequences_into_batches(train_x_lr, train_y_lr, t_train_normalized)
test_batches_lr, test_batch_starts_lr = split_sequences_into_batches(test_x_lr, test_y_lr, t_test_normalized)

# Step 5: Map original LR indices to batch-specific indices, aligned with y_data.
# The third argument is the length of each batch's second element (batch[1]).
train_batch_lr_indices = map_lr_indices_to_batches(downsample_index, train_batch_starts_lr, [len(batch[1]) for batch in train_batches_lr], seq_length)
test_batch_lr_indices = map_lr_indices_to_batches(downsample_index, test_batch_starts_lr, [len(batch[1]) for batch in test_batches_lr], seq_length)

# Step 6: Initialize the Geo_RNN models
# load data parameters
latent_dim = 80
hidden_size = 60

# # pv data parameters
# latent_dim = 120
# hidden_size = 60

# # 200-bus data parameters
# latent_dim = 80
# hidden_size = 60

# air quality data parameters
# latent_dim = 90
# hidden_size = 60


num_matrices = 1  # Number of skew-symmetric matrices to combine  1 for load
input_size = train_x_hr.shape[2]  # Feature dimension of the prepared input sequences
output_size_HR = high_res_data.shape[1]  # Should match the number of columns in high_res_data (HR)
output_size_LR = low_res_data.shape[1]  # Should match the number of columns in low_res_data (LR)

# Hyperparameters
learning_rate = 0.0005 # 0.0005 for pv and load and 200-bus and air, 0.001 for spiral
num_epochs = 400  # 200 for load, 1000 for spiral


# Define two different initial orthogonal matrices, each flattened to
# shape (1, hidden_size * hidden_size).
y0_hr = torch.eye(hidden_size).reshape(1, hidden_size * hidden_size)  # Identity matrix
y0_lr = torch.randn(hidden_size, hidden_size)
# torch.qr is deprecated; torch.linalg.qr with its default 'reduced' mode is
# the documented replacement and yields the same orthogonal factor Q.
y0_lr, _ = torch.linalg.qr(y0_lr)
y0_lr = y0_lr.reshape(1, hidden_size * hidden_size)
print('Check orthogonality for y0_lr:')
check_orthogonality(y0_lr.view(hidden_size, hidden_size))

# Shared Neural ODE
neural_ode = NeuralODE(latent_dim=latent_dim, hidden_size=hidden_size, num_matrices=num_matrices)

# Separate RNN models with shared Neural ODE
model_HR = ModelHR(hidden_size=hidden_size, input_size=input_size, output_size=output_size_HR, neural_ode=neural_ode)
model_LR = ModelLR(hidden_size=hidden_size, input_size=input_size, output_size=output_size_LR, neural_ode=neural_ode)

# Loss and optimizers for the HR and LR models.
# Each model stores `neural_ode` as an attribute, so (NeuralODE presumably
# being an nn.Module) `model.parameters()` already yields the shared ODE
# parameters. They must not be appended a second time: that would hand
# duplicate parameters to Adam.
criterion = nn.MSELoss()
optimizer_HR = optim.Adam(model_HR.parameters(), lr=learning_rate)
optimizer_LR = optim.Adam(model_LR.parameters(), lr=learning_rate)

# Step 7: Training the models

# Training HR model (Geo_RNN_HR)
final_train_predictions_HR, final_test_predictions_HR, final_train_targets_HR, final_test_targets_HR = train_and_evaluate_batches(
    train_batches_hr, model_HR, criterion, optimizer_HR, num_epochs,
    high_res_data=high_res_data, low_to_high_res_data=low_to_high_res_data, t_high_normalized=t_high_normalized,
    seq_length=seq_length, y0_hr=y0_hr, y0_lr=y0_lr, interpolated_lr_data = interpolated_lr_to_hr)

compute_and_plot_test_results(final_test_predictions_HR, final_test_targets_HR, len(final_train_predictions_HR), model_type="HR")

#  Initialize LR model with HR model's weights.
# NOTE(review): load_state_dict is strict about tensor shapes, so this
# requires output_size_HR == output_size_LR (the two `fc` layers are built
# with those sizes) -- confirm for every dataset used with this script.
model_LR.load_state_dict(model_HR.state_dict())

# # Freeze neural_ode parameters
# for param in model_LR.neural_ode.parameters():
#     param.requires_grad = False
#
# # Define optimizer (only optimize parameters that require gradients)
# optimizer_LR = torch.optim.Adam(filter(lambda p: p.requires_grad, model_LR.parameters()), lr=learning_rate)

# Training LR model with correct `lr_indices` for each batch
final_train_predictions_LR, final_test_predictions_LR, final_train_targets_LR, final_test_targets_LR = train_and_evaluate_batches(
    train_batches_lr, model_LR, criterion, optimizer_LR, num_epochs,
    is_lr_model=True, lr_indices=train_batch_lr_indices,
    high_res_data=high_res_data, low_to_high_res_data=low_to_high_res_data, t_high_normalized=t_high_normalized,
    seq_length=seq_length, y0_hr=y0_hr, y0_lr=y0_lr, interpolated_lr_data = interpolated_lr_to_hr)

compute_and_plot_test_results(final_test_predictions_LR, final_test_targets_LR, len(final_train_predictions_LR), model_type="LR")


# Post-training analysis:
# Generate W_hh_seq for both HR and LR training batches
W_hh_seq_batches_hr, W_hh_seq_batches_lr = generate_W_hh_seq_for_batches(
    train_batches_hr, train_batches_lr, model_HR, model_LR, y0_hr, y0_lr)

# W_hh_seq_batches_hr / W_hh_seq_batches_lr and the per-batch LR indices are
# now available for the comparison below.

# Step 1: Compute the optimal Q and the associated transformation error
Q_optimal, transformation_error = compute_Q_and_calculate_error(W_hh_seq_batches_hr, W_hh_seq_batches_lr, train_batch_lr_indices)

# def plot_3d_W_hh_entries_for_batch(W_hh_seq_batches_hr, W_hh_seq_batches_lr, lr_indices, entries, batch_index=1):
#     """
#     Create a 3D plot for three specific entries in W_hh_seq_batches_hr and W_hh_seq_batches_lr for a specific batch.
#     :param W_hh_seq_batches_hr: Sequence of HR matrices.
#     :param W_hh_seq_batches_lr: Sequence of LR matrices.
#     :param lr_indices: Indices where masks are applied to LR matrices.
#     :param entries: List of three tuples specifying the matrix entries to plot, e.g., [(0, 1), (1, 2), (2, 3)].
#     :param batch_index: Index of the batch to plot (0-based).
#     """
#
#     # Ensure exactly three entries are provided
#     assert len(entries) == 3, "Please provide exactly three entries as a list of tuples."
#
#     # Extract the specific batch
#     W_hr_batch = W_hh_seq_batches_hr[batch_index]
#     W_lr_batch = W_hh_seq_batches_lr[batch_index]
#     lr_mask_indices = lr_indices[batch_index]
#
#     # Prepare data for plotting
#     W_hr_values = {i: [] for i in range(3)}
#     W_lr_values = {i: [] for i in range(3)}
#     time_indices = list(range(W_hr_batch.size(0)))
#     masked_time_indices = []
#
#     for t in range(W_hr_batch.size(0)):
#         for i, (row_idx, col_idx) in enumerate(entries):
#             W_hr_values[i].append(W_hr_batch[t, row_idx, col_idx].item())  # Extract HR entry
#             W_lr_values[i].append(W_lr_batch[t, row_idx, col_idx].item())  # Extract LR entry
#
#         if t in lr_mask_indices:
#             masked_time_indices.append(t)
#
#     # Convert to arrays for plotting
#     W_hr_values = [np.array(W_hr_values[i]) for i in range(3)]
#     W_lr_values = [np.array(W_lr_values[i]) for i in range(3)]
#
#     # Create a 3D plot
#     fig = plt.figure(figsize=(14, 10))
#     ax = fig.add_subplot(111, projection='3d')
#
#     # Plot HR values
#     ax.plot(W_hr_values[0], W_hr_values[1], W_hr_values[2], label='HR', color='blue', linestyle='-')
#
#     # Plot LR values
#     ax.plot(W_lr_values[0], W_lr_values[1], W_lr_values[2], label='LR', color='red', linestyle='--')
#
#     # Mark the masked LR points
#     ax.scatter(W_lr_values[0][masked_time_indices], W_lr_values[1][masked_time_indices], W_lr_values[2][masked_time_indices],
#                color='red', edgecolor='black', s=50, label='Masked LR Points')
#
#     ax.set_xlabel(f'W_hh[{entries[0][0]}, {entries[0][1]}]')
#     ax.set_ylabel(f'W_hh[{entries[1][0]}, {entries[1][1]}]')
#     ax.set_zlabel(f'W_hh[{entries[2][0]}, {entries[2][1]}]')
#     ax.set_title(f'3D Plot of Selected W_hh Entries Over Time for Batch {batch_index + 1}')
#     ax.legend()
#     plt.show()

# plot_3d_W_hh_entries_for_batch(W_hh_seq_batches_hr, W_hh_seq_batches_lr, train_batch_lr_indices, [(0, 1), (1, 2), (2, 3)], batch_index=1)

# Visualize the HR and LR W_hh sequences for the batch at index 1, passing the
# per-batch LR indices so they can be marked.
visualize_pca_matrix_flow_with_lr_marks(W_hh_seq_batches_hr, W_hh_seq_batches_lr, batch_index=1, lr_indices=train_batch_lr_indices)


# NOTE(review): `a` is not used anywhere in this section; looks like a
# placeholder statement (e.g. a debugger breakpoint anchor) -- confirm before removing.
a = 1
# plot for spiral

# plt.figure(figsize=(6, 5))
#
# plt.plot(high_res_data, low_to_high_res_data, label="True", c="black")  # train_high_true, train_low_in_high_true
# # plt.plot(test_actual_high, test_actual_actual, c="black")  # test_high_true, test_low_in_high_true
# plt.plot(high_res_data[seq_length:train_size + seq_length,], final_test_predictions_LR[:train_size,], label="Pred Train", c="blue")  # train_high_true, train_predict
# plt.plot(high_res_data[train_size + seq_length:,], final_test_predictions_LR[train_size:,], label="Pred Test", c="orange")  # test_high_true, test_predict
# plt.scatter(high_res_data[downsample_index], low_res_data, label="Low True")
#
# # plt.title(f'{dataname}_{model_name}_2d')
# plt.xlabel('X')
# plt.ylabel('Y')
# # plt.legend()
# plt.savefig('./results/{}_{}_forecasting_2d.pdf'.format('spiral', 'Ours'), format='pdf', dpi=300)




