import numpy as np
import pickle
import os
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from sklearn.linear_model import LinearRegression
from scipy.stats import shapiro, ttest_1samp
from sklearn.metrics import r2_score
from statsmodels.stats.diagnostic import het_breuschpagan
import statsmodels.api as sm
# Save data to a file
def save_data(data, filename):
    """
    Serialize an object to disk with pickle and report where it went.

    Parameters:
    - data: Any picklable object
    - filename: Path of the file to (over)write in binary mode
    """
    with open(filename, 'wb') as out_file:
        # Pickler with default arguments is what pickle.dump uses internally.
        pickle.Pickler(out_file).dump(data)
    print(f"Data saved to {filename}")

# Load data from a file, or generate if not available
def load_data(N, d_y, d_z, max_length, q, try_load_from_file, filename=None):
    """
    Return the time-series dataset, reading it from a pickle cache when allowed.

    Parameters:
    - N, d_y, d_z, max_length, q: Generation parameters; also used to build the
      default cache filename
    - try_load_from_file: When truthy and the cache file exists, load it instead
      of generating
    - filename: Optional explicit cache path (default: derived from the parameters)

    Returns:
    - data: The unpickled dataset, or a freshly generated one (which is also
      written to `filename`)
    """
    if filename is None:
        filename = f'time_series_N{N}_dy{d_y}_dz{d_z}_len{max_length}_q{q}.pkl'

    use_cache = os.path.exists(filename) and try_load_from_file
    if not use_cache:
        # Also reached when the file exists but loading was not requested.
        print(f"Data not found. Generating new data...")
        generated = generate_time_series_with_global_mean_dependency(N, d_y, d_z, max_length, q)
        save_data(generated, filename)
        return generated

    print(f"Data found in {filename}, loading data...")
    # NOTE: pickle.load can execute arbitrary code; only point this at trusted files.
    with open(filename, 'rb') as cache_file:
        return pickle.load(cache_file)

def generate_time_series_with_global_mean_dependency(N, d_y, d_z, max_length,q):
    """
    Simulate N vector time series that live on a shared absolute time axis and
    are coupled through the cross-series mean of the previous step.

    Each series i covers absolute times start_time..end_time (inclusive) and is
    updated as
        Y_i[t] = W_y @ Y_i[t-1] + W_x @ X_i[t-1] + W_global * g[t-1] + noise,
    where g[t-1] is the mean of Y[t-1] over the series available at step t,
    standardized across its d_y components.

    Parameters:
    - N: Number of series to generate
    - d_y: Dimensionality of each series Y
    - d_z: Dimensionality of the covariates X
    - max_length: Exclusive upper bound on absolute time indices
    - q: Forecast horizon; only used to enforce a minimum series length

    Returns:
    - time_series_data: List of tuples (Y_i, X_i, start_time, end_time), with
      Y_i of shape (d_y, length) and X_i of shape (d_z, length), where
      length = end_time - start_time + 1

    All randomness comes from NumPy's global RNG, so results depend on the seed
    and on the order of the draws below.
    """
    time_series_data = []
    
    # Random weights for covariate and time series dependencies
    W_y = np.random.randn(d_y, d_y)*0.5  # Weights for past Y
    W_x = np.random.randn(d_y, d_z)*0.5  # Weights for past X
    W_global = np.random.randn(d_y)  # Weights for the global mean of Y

    eps = 0.1
    # Dividing by (largest |eigenvalue| + eps) brings the spectral radius of W_y below 1.
    scaling_W_y = np.max(np.abs(np.linalg.eig(W_y)[0]))+ eps
    scaling_W_x = d_z
    # Dividing by (largest |entry| + eps) keeps every entry of W_global inside (-1, 1).
    scaling_W_global = np.max(np.abs((W_global))) + eps
    W_y /= scaling_W_y
    W_x /= scaling_W_x
    W_global /= scaling_W_global
    
    for i in range(N):
        # randint's upper bound is exclusive: start_time < max_length // 2.
        start_time = np.random.randint(0, max_length // 2)
        # Guarantees end_time >= start_time + q + 2; needs start_time + q + 2 < max_length.
        end_time = np.random.randint(start_time + q+2, max_length) # assume q < max_length//2
        length = end_time - start_time + 1
        
        X_i = np.random.randn(d_z, length)  # Covariates
        Y_i = np.zeros((d_y, length))  # Time series initialized to zeros
        
        time_series_data.append((Y_i, X_i, start_time, end_time))
    
    # Iterate over each time step and compute Y based on the past values and global mean
    for t in range(1, max_length):  # Start from t=1 because t=0 has no previous data
        # Series that cover both t-1 and t; column 0 of each Y_i (its start time) is never written and stays zero.
        available_series = [
            i for i, (_, _, s_i, t_i) in enumerate(time_series_data)
            if s_i <= t-1 and t <= t_i
        ]
        
        if len(available_series) == 0:
            continue
        
        # Compute global mean of Y's at time t-1, considering only valid series
        # (absolute time t-1 maps to column t-1-start of each series)
        global_mean = np.mean([(time_series_data[i][0][:, t-time_series_data[i][2]-1]) for i in available_series], axis=0)
        # Standardize across the d_y components of the mean vector (not across time or series)
        global_mean = (global_mean - np.mean(global_mean)) / (np.std(global_mean) + 1e-5)  # Standardize
        
        # Update each available Y_i at time t
        for i in available_series:
            Y_i, X_i, s_i, t_i = time_series_data[i]
            
            
            # Y_i is the same array object as time_series_data[i][0], so this writes in place.
            time_series_data[i][0][:, t - s_i] = (W_y @ Y_i[:, t-1 - s_i] +     # Self-dependency on previous Y
                                W_x @ X_i[:, t-1 - s_i] +     # Covariate dependency on previous X
                                W_global * global_mean +      # Element-wise: component k uses W_global[k] * global_mean[k]
                                np.random.randn(d_y) * 0.3)   # Gaussian noise with standard deviation 0.3
    
    return time_series_data







# Define parameters
np_seed = 42
np.random.seed(np_seed)  # Seed NumPy's global RNG, which the generator draws from
N = 100  # Number of series to generate
d_y = 3  # Dimensionality of each series Y
d_z = 3  # Dimensionality of the covariates X
max_length = 100  # Exclusive upper bound on absolute time indices
q = 5  # Forecast horizon; also enforces a minimum series length in the generator
try_load_from_file = False  # False: always regenerate and overwrite the pickle cache

# Load data, generate if not already saved
data = load_data(N, d_y, d_z, max_length, q,try_load_from_file)



def evaluate_predictions(predictions):
    """
    Average the per-series mean squared error over a list of forecasts.

    Parameters:
    - predictions: Iterable of (predicted, true) array pairs of matching size

    Returns:
    - The unweighted mean of the per-pair MSE values
    """
    per_series_mse = [
        mean_squared_error(actual.flatten(), forecast.flatten())
        for forecast, actual in predictions
    ]
    return np.mean(per_series_mse)

def simulate_forecasting_task_with_global_mean_dependency(data, q):
    """
    Run a naive q-step forecast for every series and collect (prediction, truth) pairs.

    For each series a forecast origin t is drawn at random. The forecast for
    every one of the next q steps is the mean of the series' own history up to
    t plus the cross-series mean of Y at absolute time t-1.

    Parameters:
    - data: List of tuples (Y_i, X_i, s_i, t_i); Y_i has shape (d_y, t_i - s_i + 1)
      and column k holds absolute time s_i + k
    - q: Prediction horizon (number of future steps)

    Returns:
    - predictions: List of (Y_future_pred, Y_future_true) pairs, both of shape (d_y, q)
    """
    predictions = []

    for Y_i, _, s_i, t_i in data:
        # randint's upper bound is exclusive, so t <= t_i - q - 1 and the q
        # future columns t+1 .. t+q always exist.
        t = np.random.randint(s_i + 1, t_i - q)

        Y_past = Y_i[:, :t - s_i + 1]  # History up to and including time t

        # Values at absolute time t-1 from every series that covers it. Each
        # series must be indexed relative to its OWN start time; using the
        # current series' offset would mix different absolute times.
        previous_values = [
            Y_other[:, t - 1 - s_other]
            for Y_other, _, s_other, t_other in data
            if s_other <= t - 1 <= t_other
        ]

        if previous_values:
            global_mean_past = np.mean(previous_values, axis=0)
        else:
            # Defensive fallback; the current series itself always covers t-1.
            global_mean_past = np.zeros(Y_i.shape[0])

        # Dummy baseline: historical mean plus global mean, repeated q times
        Y_mean = np.mean(Y_past, axis=1).reshape(-1, 1) + global_mean_past.reshape(-1, 1)
        Y_future_pred = np.repeat(Y_mean, q, axis=1)

        # Store the predicted values and the true future values
        predictions.append((Y_future_pred, Y_i[:, t - s_i + 1:t - s_i + 1 + q]))

    return predictions
 


q = 5  # Prediction horizon (re-binds the module-level q; same value as in the parameter section)
# Each series gets one randomly chosen forecast origin, drawn from NumPy's global RNG
preds_with_global_mean_dependency = simulate_forecasting_task_with_global_mean_dependency(data, q)

# Evaluate the predictions (mean of the per-series MSE values)
mse_with_global_mean_dependency = evaluate_predictions(preds_with_global_mean_dependency)
print(f"Mean Squared Error across all predictions: {mse_with_global_mean_dependency}")

# Function to visualize and save the generated time series data
def save_time_series_plot(data, series_index=0, filename="plot_shifted_ts/time_series_plot.png"):
    """
    Visualize and save the time series Y_i and corresponding covariates X_i for a given index.

    The parent directory of `filename` is created if it does not exist, so the
    function can be called before any other code has set up the output folder.

    Parameters:
    - data: The generated time series dataset, a sequence of (Y_i, X_i, s_i, t_i) tuples
    - series_index: Index of the time series to visualize (default: 0)
    - filename: Path of the image file to write
      (default: "plot_shifted_ts/time_series_plot.png")
    """
    # Get the specific time series and covariates
    Y_i, X_i, s_i, t_i = data[series_index]

    # Absolute time axis: one tick per stored column (s_i .. t_i inclusive)
    time_steps = np.arange(s_i, t_i + 1)

    # Create subplots for Y and X
    fig, axs = plt.subplots(2, 1, figsize=(10, 8))

    # Plot Y (Time series data)
    for dim in range(Y_i.shape[0]):
        axs[0].plot(time_steps, Y_i[dim, :], label=f"Y[{dim}]")
    axs[0].set_title(f"Time Series (Y) for Series {series_index}")
    axs[0].set_xlabel("Time Step")
    axs[0].set_ylabel("Y values")
    axs[0].legend(loc="upper right")
    axs[0].grid(True)

    # Plot X (Covariates)
    for dim in range(X_i.shape[0]):
        axs[1].plot(time_steps, X_i[dim, :], label=f"X[{dim}]", linestyle="--")
    axs[1].set_title(f"Covariates (X) for Series {series_index}")
    axs[1].set_xlabel("Time Step")
    axs[1].set_ylabel("X values")
    axs[1].legend(loc="upper right")
    axs[1].grid(True)

    plt.tight_layout()

    # savefig does not create missing directories, so make sure the target exists.
    output_dir = os.path.dirname(filename)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    fig.savefig(filename)
    plt.close(fig)  # Release the figure so repeated calls do not accumulate open figures

    print(f"Plot saved as {filename}")

# Function to visualize and save multiple time series, each starting at its own start point
def save_multiple_time_series_plot(data, num_series=3, filename="plot_shifted_ts/multiple_time_series_plot.png"):
    """
    Visualize and save multiple time series Y_i with their own start points.

    The parent directory of `filename` is created if it does not exist.

    Parameters:
    - data: The generated time series dataset, a sliceable sequence of
      (Y_i, X_i, s_i, t_i) tuples
    - num_series: The number of time series to plot (default: 3)
    - filename: Path of the image file to write
      (default: "plot_shifted_ts/multiple_time_series_plot.png")
    """
    fig, ax = plt.subplots(figsize=(10, 8))

    # Plot up to 'num_series' time series with their own start times
    for series_index, (Y_i, _, s_i, t_i) in enumerate(data[:num_series]):
        time_steps = np.arange(s_i, t_i + 1)  # Absolute time steps s_i .. t_i inclusive

        # Plot each dimension of the time series Y_i
        for dim in range(Y_i.shape[0]):
            ax.plot(time_steps, Y_i[dim, :], label=f"Y[{dim}] - Series {series_index}, Start={s_i}")

    ax.set_title(f"Multiple Time Series with Different Start Points (Showing {num_series})")
    ax.set_xlabel("Time Step")
    ax.set_ylabel("Y values")
    ax.legend(loc="upper right", fontsize="small")
    ax.grid(True)

    plt.tight_layout()

    # savefig does not create missing directories, so make sure the target exists.
    output_dir = os.path.dirname(filename)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    fig.savefig(filename)
    plt.close(fig)  # Release the figure so repeated calls do not accumulate open figures

    print(f"Plot saved as {filename}")

# Example: Save the plot of multiple time series to a file, plotting only 3 time series by default
# NOTE(review): this filename has no directory part, so the image is written to the
# current working directory, unlike the two calls below that target plot_shifted_ts/.
save_multiple_time_series_plot(data, num_series=3, filename="multiple_time_series_plot.png")
# Save the plot of the first time series
save_time_series_plot(data, series_index=0, filename="plot_shifted_ts/time_series_0_plot.png")

# Save the plot of another time series if needed
save_time_series_plot(data, series_index=1, filename="plot_shifted_ts/time_series_1_plot.png")

def prepare_regression_data(data, cutoff=None, base_filename="global_mean_plot"):
    """
    Prepare data for linear regression by collecting Y_t, Y_{t-1}, X_{t-1}, and global_mean_{t-1}.
    Split into training and testing based on a time cutoff, and plot the global mean over time.

    Each sample's feature vector is [Y_{t-1}, X_{t-1}, global_mean_{t-1}], where
    global_mean_{t-1} is the mean of Y at absolute time t-1 over all series that
    cover t-1, standardized across its components.

    Parameters:
    - data: The generated time series dataset, a list of (Y_i, X_i, s_i, t_i) tuples
    - cutoff: The time step cutoff for splitting the dataset; samples with target
      time t <= cutoff go to the training set, the rest to the test set. With
      cutoff=None no split is made and every sample is a training sample.
    - base_filename: Base name of the file to save the global mean plot (default: "global_mean_plot")

    Returns:
    - X_train, Y_train: The features and targets for training
    - X_test, Y_test: The features and targets for testing
    """
    X_train, Y_train = [], []
    X_test, Y_test = [], []
    # Absolute time step -> standardized global mean at that step. The value
    # depends only on the time step, so it is computed once and reused.
    global_mean_per_timestep = {}

    for Y_i, X_i, s_i, t_i in data:
        for t in range(s_i + 1, t_i):  # Skip the first time step, as we need t-1
            prev = t - 1

            global_mean_prev = global_mean_per_timestep.get(prev)
            if global_mean_prev is None:
                # Mean of Y at absolute time t-1 over every series covering it
                raw_mean = np.mean(
                    [Y_o[:, prev - s_o] for Y_o, _, s_o, t_o in data if s_o <= prev <= t_o],
                    axis=0,
                )
                # Standardize across the components of the mean vector
                global_mean_prev = (raw_mean - np.mean(raw_mean)) / (np.std(raw_mean) + 1e-5)
                global_mean_per_timestep[prev] = global_mean_prev

            # Stack Y_{t-1}, X_{t-1}, and global_mean_{t-1} as features
            features = np.hstack([Y_i[:, prev - s_i], X_i[:, prev - s_i], global_mean_prev])

            # Y_t is the target
            Y_target = Y_i[:, t - s_i]

            # Compare against None explicitly so that cutoff=0 is a valid cutoff
            # and a missing cutoff does not silently discard every sample.
            if cutoff is None or t <= cutoff:
                X_train.append(features)
                Y_train.append(Y_target)
            else:
                X_test.append(features)
                Y_test.append(Y_target)

    # Convert lists to arrays
    X_train = np.array(X_train)
    Y_train = np.array(Y_train)
    X_test = np.array(X_test)
    Y_test = np.array(Y_test)

    # Plot the global mean for each dimension over time
    plt.figure(figsize=(10, 6))
    time_steps = sorted(global_mean_per_timestep)
    global_mean_array = np.array([global_mean_per_timestep[t] for t in time_steps])

    for dim in range(global_mean_array.shape[1]):
        plt.plot(time_steps, global_mean_array[:, dim], label=f"Global Mean Y[{dim+1}]")

    plt.title("Global Mean of Y over Time")
    plt.xlabel("Time Step")
    plt.ylabel("Global Mean (Standardized)")
    plt.legend()
    plt.tight_layout()

    # Save the plot in the 'plot_shifted_ts' folder
    output_dir = "plot_shifted_ts"
    os.makedirs(output_dir, exist_ok=True)

    plt.savefig(os.path.join(output_dir, f"{base_filename}.png"))
    plt.close()

    print(f"Global mean plot saved in the 'plot_shifted_ts' folder as {base_filename}.png")

    return X_train, Y_train, X_test, Y_test


def fit_linear_regression(X_train, Y_train, d_y, d_z):
    """
    Fit a linear regression model to estimate W_y, W_x, and W_global.

    One ordinary-least-squares model is fitted per component of Y. The feature
    columns are expected in the order
        [Y_{t-1} (d_y columns), X_{t-1} (d_z columns), global_mean_{t-1} (d_y columns)].

    Parameters:
    - X_train: The features matrix [Y_{t-1}, X_{t-1}, global_mean_{t-1}]
    - Y_train: The target matrix Y_t, one column per component of Y
    - d_y: The dimensionality of Y
    - d_z: The dimensionality of X

    Returns:
    - W_y: (d_y, d_y) estimated weights for Y_{t-1}
    - W_x: (d_y, d_z) estimated weights for X_{t-1}
    - W_global: (d_y,) estimated weight of global_mean_{t-1}[dim] on Y_t[dim]
    - models: List of fitted linear models (one per dimension of Y)
    """
    W_y = np.zeros((d_y, d_y))
    W_x = np.zeros((d_y, d_z))
    W_global = np.zeros(d_y)
    models = []

    global_offset = d_y + d_z  # First column of the global-mean feature block

    for dim in range(d_y):
        # Fit linear regression for this dimension of Y
        model = LinearRegression()
        model.fit(X_train, Y_train[:, dim])
        models.append(model)

        W_y[dim, :] = model.coef_[:d_y]  # Coefficients for Y_{t-1}
        W_x[dim, :] = model.coef_[d_y:global_offset]  # Coefficients for X_{t-1}
        # W_global is a per-component weight, so Y_t[dim] pairs with the
        # dim-th global-mean column; a fixed index would read the same
        # coefficient for every component.
        W_global[dim] = model.coef_[global_offset + dim]

    return W_y, W_x, W_global, models

def predict_and_evaluate(models, X, Y_true, d_y):
    """
    Predict every component of Y with its own fitted model and score the result.

    Parameters:
    - models: Sequence of fitted models, indexed by component of Y
    - X: The feature matrix passed to each model's predict()
    - Y_true: The true values of Y, one column per component
    - d_y: The number of components to predict

    Returns:
    - Y_pred: Array with the shape and dtype of Y_true holding the predictions
    - mse: The mean squared error between Y_true and Y_pred
    """
    # Same shape and dtype as the targets; filled column by column below.
    Y_pred = np.zeros_like(Y_true)
    for component in range(d_y):
        fitted = models[component]
        Y_pred[:, component] = fitted.predict(X)

    return Y_pred, mean_squared_error(Y_true, Y_pred)


def fit_linear_regression_wo_global_mean(X_train, Y_train, d_y, d_z):
    """
    Fit per-component linear regressions that ignore the global-mean features.

    Only the first d_y + d_z columns of X_train (the Y_{t-1} and X_{t-1} blocks)
    are used; any columns after them are dropped before fitting.

    Parameters:
    - X_train: The features matrix whose leading columns are [Y_{t-1}, X_{t-1}]
    - Y_train: The target matrix Y_t, one column per component of Y
    - d_y: The dimensionality of Y
    - d_z: The dimensionality of X

    Returns:
    - W_y: (d_y, d_y) estimated weights for Y_{t-1}
    - W_x: (d_y, d_z) estimated weights for X_{t-1}
    - models: List of fitted linear models (one per dimension of Y)
    """
    n_used = d_y + d_z  # Number of leading feature columns kept
    W_y = np.zeros((d_y, d_y))
    W_x = np.zeros((d_y, d_z))
    models = []

    for row in range(d_y):
        regressor = LinearRegression()
        regressor.fit(X_train[:, :n_used], Y_train[:, row])
        models.append(regressor)

        coefficients = regressor.coef_
        W_y[row, :] = coefficients[:d_y]  # Block for Y_{t-1}
        W_x[row, :] = coefficients[d_y:n_used]  # Block for X_{t-1}

    return W_y, W_x, models




# Step 1: Split the data into training and test sets based on a cutoff time
cutoff = 70  # Example cutoff time (can be chosen as desired); targets with t <= cutoff are training samples
X_train, Y_train, X_test, Y_test = prepare_regression_data(data, cutoff=cutoff)

# Step 2: Fit the linear regression model on the training data
d_y = Y_train.shape[1]  # Number of dimensions of Y
# Feature layout is [Y_{t-1} (d_y), X_{t-1} (d_z), global_mean_{t-1} (d_y)], hence two d_y blocks are removed
d_z = X_train.shape[1] - d_y - d_y  # Number of dimensions of X (subtracting Y_{t-1} and global_mean)
W_y, W_x, W_global, models = fit_linear_regression(X_train, Y_train, d_y, d_z)

# Same regression without the global-mean feature block, for comparison
W_y_wo_mean, W_x_wo_mean, models_wo_mean = fit_linear_regression_wo_global_mean(X_train, Y_train, d_y, d_z)

# L2 norm of the weights
# (np.linalg.norm with default arguments is the Frobenius norm for the 2-D W_y / W_x)
l2_norm_W_y = np.linalg.norm(W_y)
l2_norm_W_x = np.linalg.norm(W_x)
l2_norm_W_global = np.linalg.norm(W_global)
print(f"L2 Norm of W_y: {l2_norm_W_y:.4f}")
print(f"L2 Norm of W_x: {l2_norm_W_x:.4f}")
print(f"L2 Norm of W_global: {l2_norm_W_global:.4f}")

# W_y - W_y_wo_mean
# How much the Y and X weights change when the global-mean features are dropped
l2_norm_W_y_diff = np.linalg.norm(W_y - W_y_wo_mean)
l2_norm_W_x_diff = np.linalg.norm(W_x - W_x_wo_mean)
print(f"L2 Norm of W_y - W_y_wo_mean: {l2_norm_W_y_diff:.4f}")
print(f"L2 Norm of W_x - W_x_wo_mean: {l2_norm_W_x_diff:.4f}")

# Step 3: Predict and evaluate in-sample (training set)
Y_train_pred, mse_train = predict_and_evaluate(models, X_train, Y_train, d_y)
print(f'In-sample MSE: {mse_train}')

# Step 4: Predict and evaluate out-of-sample (test set)
Y_test_pred, mse_test = predict_and_evaluate(models, X_test, Y_test, d_y)
print(f'Out-of-sample MSE: {mse_test}')

# The reduced models were fitted on the first d_y + d_z columns, so predict on the same slice
Y_train_pred_wo_mean, mse_train_wo_mean = predict_and_evaluate(models_wo_mean, X_train[:,:d_y+d_z], Y_train, d_y)
print(f'In-sample MSE without global mean: {mse_train_wo_mean}')

Y_test_pred_wo_mean, mse_test_wo_mean = predict_and_evaluate(models_wo_mean, X_test[:,:d_y+d_z], Y_test, d_y)
print(f'Out-of-sample MSE without global mean: {mse_test_wo_mean}')

def analyze_residuals(Y_true, Y_pred, dataset_type):
    """
    Report R² and per-component residual diagnostics for a set of predictions.

    For every component of Y the residuals are checked with a Shapiro-Wilk
    normality test and a one-sample t-test against a mean of zero; the
    residual mean and variance are reported as well. Everything is printed
    and also returned.

    Parameters:
    - Y_true: The true values of Y, one column per component
    - Y_pred: The predicted values of Y
    - dataset_type: Label used in the printed R² line (e.g. 'In-Sample')

    Returns:
    - residual_stats: Dict with key 'R²' plus one "Dimension k" entry per component
    """
    residuals = Y_true - Y_pred
    n_components = Y_true.shape[1]

    r2 = r2_score(Y_true, Y_pred)
    residual_stats = {'R²': r2}
    print(f"R² for {dataset_type}: {r2:.4f}")

    for dim in range(n_components):
        component = residuals[:, dim]

        # Only the p-values are reported; the test statistics are discarded.
        _, shapiro_p = shapiro(component)
        _, t_p = ttest_1samp(component, 0)

        mean_value = component.mean()
        variance_value = component.var()

        residual_stats[f"Dimension {dim+1}"] = {
            "Shapiro-Wilk p-value": shapiro_p,
            "T-test p-value (mean=0)": t_p,
            "Residual Mean": mean_value,
            "Residual Variance": variance_value,
        }

        print(f"Dimension {dim+1}:")
        print(f"  Shapiro-Wilk p-value for normality: {shapiro_p:.4f}")
        print(f"  T-test p-value for zero mean: {t_p:.4f}")
        print(f"  Residual Mean: {mean_value:.4f}")
        print(f"  Residual Variance: {variance_value:.4f}")

    return residual_stats

# Example usage:

# Analyze in-sample residuals (training data)
train_stats = analyze_residuals(Y_train, Y_train_pred, 'In-Sample')

# Analyze out-of-sample residuals (test data)
test_stats = analyze_residuals(Y_test, Y_test_pred, 'Out-of-Sample')


# R² of the models fitted without the global-mean features, for comparison with the values printed above
r2_wo_mean_in_sample = r2_score(Y_train, Y_train_pred_wo_mean)
print(f"R² for In-Sample without global mean: {r2_wo_mean_in_sample:.4f}")
r2_wo_mean_out_of_sample = r2_score(Y_test, Y_test_pred_wo_mean)
print(f"R² for Out-of-Sample without global mean: {r2_wo_mean_out_of_sample:.4f}")
# Step 5: Plot residuals
# Example usage:

def save_residual_plots_with_time(Y_true, Y_pred, data, dataset_type, cutoff, base_filename="plot_shifted_ts/residuals_plot"):
    """
    Save residuals plots between true and predicted Y, with both absolute time (t) and relative time (t - s_i).

    The time stamp of every residual row is rebuilt by walking `data` in the
    same order used to build the regression samples (series by series, t from
    s_i + 1 to t_i - 1) and keeping the steps on the requested side of the cutoff.

    Parameters:
    - Y_true: The true values of Y, one column per component
    - Y_pred: The predicted values of Y
    - data: The original time series data, to retrieve absolute and relative time
    - dataset_type: 'In-Sample' (keeps t <= cutoff) or 'Out-of-Sample' (keeps t > cutoff);
      also embedded in plot titles and file names
    - cutoff: Time step cutoff to split data into train and test sets
    - base_filename: Path prefix of the output files (default: "plot_shifted_ts/residuals_plot")

    Raises:
    - ValueError: If the number of reconstructed time stamps differs from the
      number of residual rows (e.g. unknown dataset_type or mismatched cutoff)
    """
    residuals = Y_true - Y_pred
    d_y = Y_true.shape[1]  # Number of components in Y (dimensions)

    # Store absolute time (t) and relative time (t - s_i)
    absolute_times = []
    relative_times = []
    for _, _, s_i, t_i in data:
        for t in range(s_i + 1, t_i):
            selected = (
                (dataset_type == 'In-Sample' and t <= cutoff)
                or (dataset_type == 'Out-of-Sample' and t > cutoff)
            )
            if selected:
                absolute_times.append(t)
                relative_times.append(t - s_i)

    if len(absolute_times) != residuals.shape[0]:
        # Fail early with a clear message instead of inside matplotlib's scatter().
        raise ValueError(
            f"Found {len(absolute_times)} time steps for dataset_type={dataset_type!r} "
            f"but {residuals.shape[0]} residual rows"
        )

    absolute_times = np.array(absolute_times)
    relative_times = np.array(relative_times)

    # savefig does not create missing directories, so make sure the target exists.
    output_dir = os.path.dirname(base_filename)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # (times, word used in the title, x-axis label, file name suffix)
    time_axes = (
        (absolute_times, "Absolute", 'Absolute Time (t)', "absolute_time"),
        (relative_times, "Relative", 'Relative Time (t - s_i)', "relative_time"),
    )

    # Save each component of the residuals with both absolute and relative times
    for dim in range(d_y):
        for times, kind, xlabel, suffix in time_axes:
            plt.figure(figsize=(10, 6))
            plt.scatter(times, residuals[:, dim], alpha=0.5)
            plt.axhline(0, color='red', linestyle='--')
            plt.title(f'Residuals vs {kind} Time - Dimension {dim+1} ({dataset_type})')
            plt.xlabel(xlabel)
            plt.ylabel(f'Residuals (Y[{dim+1}])')
            plt.tight_layout()
            plt.savefig(f"{base_filename}_{dataset_type}_dim{dim+1}_{suffix}.png")
            plt.close()

# Function to test homoscedasticity
def test_homoscedasticity(Y_true, Y_pred, data, dataset_type, cutoff):
    """
    Run Breusch-Pagan tests of the residuals against absolute and relative time.

    The time stamp of every residual row is rebuilt by walking `data` series by
    series (t from s_i + 1 to t_i - 1) and keeping the steps on the requested
    side of the cutoff. For each component of Y, the residuals are tested once
    with absolute time and once with relative time (t - s_i) as the regressor.

    Parameters:
    - Y_true: The true values of Y, one column per component
    - Y_pred: The predicted values of Y
    - data: The original time series data, to retrieve absolute and relative time
    - dataset_type: 'In-Sample' (keeps t <= cutoff) or 'Out-of-Sample' (keeps t > cutoff)
    - cutoff: Time step cutoff to split data into train and test sets

    Returns:
    - homoscedasticity_results: Dict keyed by "Dimension k (Absolute Time)" /
      "Dimension k (Relative Time)" with the test statistic and p-value
    """
    residuals = Y_true - Y_pred
    n_components = Y_true.shape[1]

    # (absolute time, relative time) for every sample on the requested side of the cutoff
    kept_steps = [
        (t, t - s_i)
        for _, _, s_i, t_i in data
        for t in range(s_i + 1, t_i)
        if (dataset_type == 'In-Sample' and t <= cutoff)
        or (dataset_type == 'Out-of-Sample' and t > cutoff)
    ]
    absolute_times = np.array([abs_t for abs_t, _ in kept_steps])
    relative_times = np.array([rel_t for _, rel_t in kept_steps])

    # (label used in result keys, label used in printed lines, regressor values)
    regressors = (
        ("Absolute Time", "absolute time", absolute_times),
        ("Relative Time", "relative time", relative_times),
    )

    homoscedasticity_results = {}

    for dim in range(n_components):
        component = residuals[:, dim]

        print(f"Testing homoscedasticity for Dimension {dim+1} ({dataset_type})...")

        for key_label, printed_label, times in regressors:
            exog = sm.add_constant(times)  # Intercept column plus the time regressor
            bp_result = het_breuschpagan(component, exog)
            homoscedasticity_results[f"Dimension {dim+1} ({key_label})"] = {
                "BP Test Statistic": bp_result[0],
                "BP Test p-value": bp_result[1],
            }

            print(f"  Breusch-Pagan p-value ({printed_label}): {bp_result[1]:.4f}")

    return homoscedasticity_results

# Example usage:
# The dataset_type strings must be exactly 'In-Sample' / 'Out-of-Sample': both functions
# compare against these literals to decide which time steps belong to the given residuals.

# Save residual plots for the training set (in-sample)
save_residual_plots_with_time(Y_train, Y_train_pred, data, 'In-Sample', cutoff, base_filename="plot_shifted_ts/residuals_in_sample")

# Save residual plots for the test set (out-of-sample)
save_residual_plots_with_time(Y_test, Y_test_pred, data, 'Out-of-Sample', cutoff, base_filename="plot_shifted_ts/residuals_out_of_sample")

# Test homoscedasticity for the training set (in-sample)
train_homoscedasticity_results = test_homoscedasticity(Y_train, Y_train_pred, data, 'In-Sample', cutoff)

# Test homoscedasticity for the test set (out-of-sample)
test_homoscedasticity_results = test_homoscedasticity(Y_test, Y_test_pred, data, 'Out-of-Sample', cutoff)



import os

def plot_average_magnitude_Y_per_timestep(data, cutoff, base_filename="avg_magnitude_Y"):
    """
    Plot the mean absolute value of each component of Y per time step.

    Values are averaged over all series that have a sample at that time step,
    separately for training steps (t <= cutoff) and test steps (t > cutoff).
    Two figures are written to the 'plot_shifted_ts' folder: one over absolute
    time t and one over relative time t - s_i.

    Parameters:
    - data: The original time series data, a list of (Y_i, X_i, s_i, t_i) tuples
    - cutoff: The time step cutoff to distinguish train and test sets
    - base_filename: Base name of the output files (default: "avg_magnitude_Y")
    """
    output_dir = "plot_shifted_ts"
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    d_y = data[0][0].shape[0]  # Number of components of Y

    # split -> time axis -> {time key: {'count': samples, 'magnitude': running sum of |Y|}}
    totals = {
        "train": {"abs": {}, "rel": {}},
        "test": {"abs": {}, "rel": {}},
    }

    def _accumulate(bucket, key, value):
        # Add one |Y| sample to the running sum stored under `key`.
        if key not in bucket:
            bucket[key] = {'count': 0, 'magnitude': np.zeros(d_y)}
        bucket[key]['magnitude'] += value
        bucket[key]['count'] += 1

    def _averages(bucket):
        # Return the sorted time keys and the matching array of mean magnitudes.
        ordered = sorted(bucket.keys())
        means = np.array([bucket[k]['magnitude'] / bucket[k]['count'] for k in ordered])
        return ordered, means

    for Y_i, _, s_i, t_i in data:
        for t in range(s_i + 1, t_i):
            magnitude_t = np.abs(Y_i[:, t - s_i])  # |Y| at absolute time t
            split = "train" if t <= cutoff else "test"
            _accumulate(totals[split]["abs"], t, magnitude_t)
            _accumulate(totals[split]["rel"], t - s_i, magnitude_t)

    # Averages for all four (split, axis) combinations are computed before any plotting
    averaged = {
        axis: {split: _averages(totals[split][axis]) for split in ("train", "test")}
        for axis in ("abs", "rel")
    }

    # (axis key, legend tag, figure title, x-axis label, file name suffix)
    figures = (
        ("abs", "Abs Time", "Average Magnitude of Y (Absolute Time)",
         "Absolute Time (t)", "absolute_time"),
        ("rel", "Rel Time", "Average Magnitude of Y (Relative Time)",
         "Relative Time (t - s_i)", "relative_time"),
    )

    for axis, tag, title, xlabel, suffix in figures:
        train_times, train_means = averaged[axis]["train"]
        test_times, test_means = averaged[axis]["test"]

        plt.figure(figsize=(10, 6))
        for dim in range(d_y):
            plt.plot(train_times, train_means[:, dim], label=f"Train Y[{dim+1}] ({tag})", linestyle='-')
            plt.plot(test_times, test_means[:, dim], label=f"Test Y[{dim+1}] ({tag})", linestyle='--')
        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel("Average Magnitude |Y|")
        plt.legend()
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, f"{base_filename}_{suffix}.png"))
        plt.close()

    print(f"Plots saved in the 'plot_shifted_ts' folder as {base_filename}_absolute_time.png and {base_filename}_relative_time.png")

# Example usage:
# Writes avg_magnitude_Y_per_timestep_absolute_time.png and ..._relative_time.png into plot_shifted_ts/
plot_average_magnitude_Y_per_timestep(data, cutoff, base_filename="avg_magnitude_Y_per_timestep")


def _summarize_mean_std_by_time(values_by_time):
    """
    Reduce a {time: [Y vectors]} mapping to sorted times plus per-time mean and std arrays.

    Parameters:
    - values_by_time: dict mapping a time key to the list of Y vectors collected at that time.

    Returns:
    - times: the mapping's keys in ascending order.
    - means: array whose k-th row is np.mean(..., axis=0) over the vectors stored at times[k].
    - stds: array whose k-th row is np.std(..., axis=0) over the vectors stored at times[k].
    For an empty mapping, times is an empty list and both arrays are empty 1-D arrays.
    """
    times = sorted(values_by_time)
    means = np.array([np.mean(values_by_time[t], axis=0) for t in times])
    stds = np.array([np.std(values_by_time[t], axis=0) for t in times])
    return times, means, stds


def _save_mean_std_errorbar_figure(train_stats, test_stats, d_y, time_tag, title, xlabel, out_path):
    """
    Draw one mean ± std error-bar figure (train and test curves per dimension) and save it.

    Parameters:
    - train_stats, test_stats: (times, means, stds) triples as returned by
      _summarize_mean_std_by_time.
    - d_y: number of Y dimensions to draw.
    - time_tag: text placed in parentheses in each legend label (e.g. "Abs Time").
    - title, xlabel: figure title and x-axis label.
    - out_path: path passed to plt.savefig.
    """
    plt.figure(figsize=(10, 6))
    has_series = False
    # Per dimension, train is drawn before test so the legend order and colour
    # cycle stay the same as when both splits are present.
    for dim in range(d_y):
        for split_label, linestyle, (times, means, stds) in (
            ("Train", '-', train_stats),
            ("Test", '--', test_stats),
        ):
            if not times:
                # An empty split yields 1-D empty arrays; indexing them with
                # [:, dim] would raise IndexError, so there is nothing to draw.
                continue
            plt.errorbar(times, means[:, dim], yerr=stds[:, dim],
                         label=f"{split_label} Y[{dim+1}] ({time_tag})", linestyle=linestyle)
            has_series = True
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel("Mean ± Std Y")
    if has_series:
        # Only build a legend when at least one labelled curve exists.
        plt.legend()
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()


def plot_mean_std_Y_per_timestep(data, cutoff, base_filename="mean_std_Y"):
    """
    Plot the mean and standard deviation of Y for each dimension at each time step, averaged across all available series,
    with both absolute and relative time as the x-axis, and save the plots in the 'plot_shifted_ts' folder.

    For every series (Y_i, X_i, s_i, t_i) the absolute times s_i + 1, ..., t_i - 1 are visited and the
    column Y_i[:, t - s_i] is collected. Times with t <= cutoff go to the train set, all others to the
    test set. Each value is recorded twice: under its absolute time t and under its relative time t - s_i.
    Means and standard deviations are computed with np.mean / np.std (default settings) across the
    series that contribute to a given time key.

    A split that receives no values (e.g. every visited time is <= cutoff) is simply left out of the
    figures instead of raising an error.

    Parameters:
    - data: The original time series data, a non-empty sequence of (Y_i, X_i, s_i, t_i) entries.
      The number of Y dimensions is taken from data[0][0].shape[0].
    - cutoff: The time step cutoff to distinguish train and test sets.
    - base_filename: Base name of the file to save the plots (default: "mean_std_Y").

    Output files:
    - plot_shifted_ts/{base_filename}_absolute_time.png
    - plot_shifted_ts/{base_filename}_relative_time.png
    """
    # Create the 'plot_shifted_ts' folder if it doesn't exist (exist_ok avoids a check-then-create race)
    output_dir = "plot_shifted_ts"
    os.makedirs(output_dir, exist_ok=True)

    d_y = data[0][0].shape[0]  # Get the number of dimensions of Y

    # Y values collected per time key, for each split ("train"/"test") and time axis ("abs"/"rel")
    collected = {
        ("train", "abs"): {},
        ("train", "rel"): {},
        ("test", "abs"): {},
        ("test", "rel"): {},
    }

    # Iterate through the dataset to accumulate Y values for train and test sets
    for Y_i, _X_i, s_i, t_i in data:
        for t in range(s_i + 1, t_i):
            Y_t = Y_i[:, t - s_i]  # Y at time t
            split = "train" if t <= cutoff else "test"
            collected[(split, "abs")].setdefault(t, []).append(Y_t)
            collected[(split, "rel")].setdefault(t - s_i, []).append(Y_t)

    # Mean and standard deviation for each time step, with times sorted for plotting
    stats = {key: _summarize_mean_std_by_time(by_time) for key, by_time in collected.items()}

    # Plot the mean and standard deviation for each dimension with absolute time
    _save_mean_std_errorbar_figure(
        stats[("train", "abs")],
        stats[("test", "abs")],
        d_y,
        time_tag="Abs Time",
        title="Mean and Standard Deviation of Y (Absolute Time)",
        xlabel="Absolute Time (t)",
        out_path=os.path.join(output_dir, f"{base_filename}_absolute_time.png"),
    )

    # Plot the mean and standard deviation for each dimension with relative time
    _save_mean_std_errorbar_figure(
        stats[("train", "rel")],
        stats[("test", "rel")],
        d_y,
        time_tag="Rel Time",
        title="Mean and Standard Deviation of Y (Relative Time)",
        xlabel="Relative Time (t - s_i)",
        out_path=os.path.join(output_dir, f"{base_filename}_relative_time.png"),
    )

    print(f"Plots saved in the 'plot_shifted_ts' folder as {base_filename}_absolute_time.png and {base_filename}_relative_time.png")

# Example usage (executed at module level, i.e. whenever this file is run or imported):
# writes mean_std_Y_per_timestep_absolute_time.png and mean_std_Y_per_timestep_relative_time.png
# into the 'plot_shifted_ts' folder.
# NOTE(review): relies on `data` and `cutoff` being defined earlier in this file — confirm.
plot_mean_std_Y_per_timestep(data, cutoff, base_filename="mean_std_Y_per_timestep")









