import numpy as np
import matplotlib.pyplot as plt
from statsmodels.tsa.stattools import acf
from scipy.signal import welch
from joblib import Parallel, delayed
from typing import List, Optional, Tuple
import pandas as pd
import os

def _compute_acf_for_window(window: np.ndarray, nlags: int) -> np.ndarray:
    """Return the autocorrelation of every column of a 2-D window.

    Each column ``window[:, k]`` is passed to statsmodels' ``acf`` (FFT
    enabled, ``nlags`` lags). The per-column results are stacked so that
    column ``k`` of the returned array holds the ACF of feature ``k``.
    """
    n_columns = window.shape[1]
    per_feature = [
        acf(window[:, col], nlags=nlags, fft=True)
        for col in range(n_columns)
    ]
    # rows = lags, columns = features
    return np.array(per_feature).T

def _compute_psd_for_window(window: np.ndarray, fs: float) -> Tuple[np.ndarray, np.ndarray]:
    """Compute Welch's PSD estimate for every column of a 2-D window.

    ``nperseg`` is set to the full window length, so each column is analysed
    as a single segment.

    Returns:
        A ``(freqs, psd)`` pair. ``freqs`` is the frequency grid of the first
        column; ``psd`` has one column per feature. When the window has no
        columns, ``freqs`` is an empty list.
    """
    seg_len = window.shape[0]
    spectra = [
        welch(window[:, col], fs=fs, nperseg=seg_len)
        for col in range(window.shape[1])
    ]
    if not spectra:
        return [], np.array([]).T
    freqs = spectra[0][0]  # identical for every column (same length and fs)
    return freqs, np.array([power for _, power in spectra]).T

# --- Per-window ACF/PSD comparison plots ---

def _save_feature_comparison(
    x: np.ndarray,
    values_orig: np.ndarray,
    values_gen: np.ndarray,
    feature_names: List[str],
    title: str,
    ylabel: str,
    xlabel: str,
    path: str,
    log_y: bool = False,
    zero_line: bool = False,
) -> None:
    """Draw one stacked subplot per feature (original vs generated) and save it to ``path``.

    ``values_orig`` / ``values_gen`` are indexed as ``[:, k]`` for feature ``k``.
    The figure is always closed, even if plotting or saving raises.
    """
    n_features = values_orig.shape[1]
    # squeeze=False always yields a 2-D Axes array, even for a single feature.
    fig, axes = plt.subplots(
        n_features, 1, figsize=(10, 4 * n_features), sharex=True, squeeze=False
    )
    try:
        fig.suptitle(title, fontsize=16)
        for k, ax in enumerate(axes[:, 0]):
            ax.plot(x, values_orig[:, k], label='Original', color='blue')
            ax.plot(x, values_gen[:, k], label='Generated', color='red', linestyle='--')
            ax.set_title(f'{feature_names[k]}')
            ax.set_ylabel(ylabel)
            if log_y:
                ax.set_yscale('log')
            ax.grid(True, linestyle='--', alpha=0.6)
            if zero_line:
                ax.axhline(0, color='black', linewidth=0.8)
            ax.legend()
        axes[-1, 0].set_xlabel(xlabel)
        # Leave room at the top for the suptitle.
        fig.tight_layout(rect=[0, 0.03, 1, 0.96])
        fig.savefig(path)
    finally:
        plt.close(fig)  # free memory regardless of success


def generate_comparison_plots_per_window(
    ts_generated: np.ndarray,
    ts_original: np.ndarray,
    fs: float = 1.0,
    nlags: Optional[int] = None,
    feature_names: Optional[List[str]] = None,
    plot_dir: str = '.',
    n_jobs: int = -1
):
    """
    Compute and save comparison plots of the window-averaged ACF and PSD of two datasets.

    Each window (index along the first axis) is analysed independently. The
    ACF and the Welch PSD are computed per feature, then averaged across
    windows. Two PNG files are written to ``plot_dir``, which is created if
    it does not exist:

    1. ``comparison_plot_acf_per_window.png``: K stacked subplots comparing
       the mean ACF of the original and generated data for each feature.
    2. ``comparison_plot_psd_per_window.png``: K stacked subplots comparing
       the mean PSD (log y-axis) for each feature.

    Args:
        ts_generated (np.ndarray): The generated time series with shape (B, L, K).
        ts_original (np.ndarray): The original time series with shape (B, L, K).
        fs (float, optional): Sampling frequency of the time series. Defaults to 1.0.
        nlags (int, optional): Number of lags for ACF. Defaults to min(L // 2, 50).
            An explicit value is capped at L - 1.
        feature_names (List[str], optional): One name per feature (exactly K entries).
            Defaults to ['Feature 0', 'Feature 1', ...].
        plot_dir (str, optional): Directory to save the plots. Defaults to current directory.
        n_jobs (int, optional): Number of jobs for joblib. -1 means using all available cores.

    Raises:
        ValueError: If the inputs differ in shape, are not 3-D, contain no
            windows or no features, if ``nlags`` is negative, or if
            ``feature_names`` does not have exactly K entries.
    """
    # --- Input Validation and Setup ---
    if ts_generated.shape != ts_original.shape:
        raise ValueError("Input time series must have the same shape.")

    if ts_generated.ndim != 3:
        raise ValueError("Input time series must be 3-dimensional (B, L, K).")

    B, L, K = ts_original.shape
    if B == 0 or K == 0:
        raise ValueError("Input time series must contain at least one window and one feature.")

    if nlags is None:
        nlags = min(L // 2, 50)
    elif nlags < 0:
        raise ValueError("nlags must be non-negative.")
    else:
        nlags = min(nlags, L - 1)

    if feature_names is None:
        feature_names = [f'Feature {k}' for k in range(K)]
    elif len(feature_names) != K:
        raise ValueError(
            f"feature_names has {len(feature_names)} entries but the data has {K} features."
        )

    # --- Parallel Computation ---
    parallel = Parallel(n_jobs=n_jobs)

    # 1. ACF: one (nlags + 1, K) array per window, averaged over windows.
    all_acfs_gen = parallel(
        delayed(_compute_acf_for_window)(ts_generated[i], nlags) for i in range(B)
    )
    mean_acf_gen = np.mean(np.array(all_acfs_gen), axis=0)

    all_acfs_orig = parallel(
        delayed(_compute_acf_for_window)(ts_original[i], nlags) for i in range(B)
    )
    mean_acf_orig = np.mean(np.array(all_acfs_orig), axis=0)

    # 2. PSD: every window has the same length and fs, so the frequency grid is shared.
    psd_results_gen = parallel(
        delayed(_compute_psd_for_window)(ts_generated[i], fs) for i in range(B)
    )
    freqs = psd_results_gen[0][0]
    mean_psd_gen = np.mean(np.array([res[1] for res in psd_results_gen]), axis=0)

    psd_results_orig = parallel(
        delayed(_compute_psd_for_window)(ts_original[i], fs) for i in range(B)
    )
    mean_psd_orig = np.mean(np.array([res[1] for res in psd_results_orig]), axis=0)

    # --- Plotting and Saving ---
    # An empty plot_dir means "current directory"; os.makedirs('') would raise.
    if plot_dir:
        os.makedirs(plot_dir, exist_ok=True)
    acf_path = os.path.join(plot_dir, 'comparison_plot_acf_per_window.png')
    psd_path = os.path.join(plot_dir, 'comparison_plot_psd_per_window.png')

    _save_feature_comparison(
        np.arange(nlags + 1), mean_acf_orig, mean_acf_gen, feature_names,
        title='Average Autocorrelation (ACF) Comparison',
        ylabel='Autocorrelation',
        xlabel='Lag',
        path=acf_path,
        zero_line=True,
    )
    _save_feature_comparison(
        freqs, mean_psd_orig, mean_psd_gen, feature_names,
        title='Average Power Spectral Density (PSD) Comparison',
        ylabel='Power/Frequency',
        xlabel=f'Frequency (Hz, fs={fs})',
        path=psd_path,
        log_y=True,
    )



def plot_chunk(fake_chunk: np.ndarray, ori_chunk: np.ndarray, plot_path: str = './chunk.png'):
    """
    Plot generated vs original values for every feature and save the figure.

    Accepts either a single window of shape (L, K) or a batch of windows of
    shape (B, L, K). For a batch, each window is drawn as a semi-transparent
    line, so the B samples form a "blob" per feature. The legend then has one
    entry per dataset instead of one per window.

    Args:
        fake_chunk (np.ndarray): Generated data, shape (L, K) or (B, L, K).
        ori_chunk (np.ndarray): Original data with the same shape as ``fake_chunk``.
        plot_path (str, optional): Output image path. Defaults to './chunk.png'.

    Raises:
        AssertionError: If the two inputs differ in shape.
        ValueError: If the inputs are neither 2-D nor 3-D.
    """
    assert fake_chunk.shape == ori_chunk.shape, "Fake and original chunks must have the same shape."
    # Validate before creating a figure so nothing leaks on bad input.
    if fake_chunk.ndim not in (2, 3):
        raise ValueError("Chunks must be 2-D (L, K) or 3-D (B, L, K).")

    is_batch = fake_chunk.ndim == 3
    n_features = fake_chunk.shape[-1]
    # squeeze=False keeps a 2-D Axes array even when n_features == 1;
    # otherwise subplots returns a bare Axes and indexing it fails.
    fig, axs = plt.subplots(n_features, 1, figsize=(16, 5 * n_features), squeeze=False)
    try:
        for f, ax in enumerate(axs[:, 0]):
            if is_batch:
                # (B, L) slice transposed to (L, B): matplotlib draws one line per column.
                fake_lines = ax.plot(fake_chunk[:, :, f].T, alpha=0.5, label="Generated", color='blue')
                ori_lines = ax.plot(ori_chunk[:, :, f].T, alpha=0.5, label="Original", color='orange', linestyle='--')
                if fake_lines and ori_lines:
                    # One legend entry per dataset rather than B duplicates.
                    ax.legend([fake_lines[0], ori_lines[0]], ["Generated", "Original"])
            else:
                ax.plot(fake_chunk[:, f], label="Generated", color='blue')
                ax.plot(ori_chunk[:, f], label="Original", color='orange', linestyle='--')
                ax.legend()

            ax.set_title(f"Feature {f} — original vs generated")
            ax.set_ylabel("Value")
            ax.set_xlabel("Time")
            ax.grid(True)

        # Operate on this figure explicitly rather than pyplot's "current" figure.
        fig.tight_layout()
        fig.savefig(plot_path)
    finally:
        plt.close(fig)

def _find_generated_column(original_col: str, generated_cols: List[str]) -> Optional[str]:
    """Return the generated column paired with ``original_col``, or None if there is none.

    The base name is everything after the first '_' (e.g. 'orig_temp_c' -> 'temp_c').
    A generated column with exactly the same base name is preferred. This stops
    'orig_0' from pairing with 'gen_10'. Otherwise, the first generated column
    that ends with the base name is used.
    """
    if '_' not in original_col:
        return None
    base = original_col.split('_', 1)[1]
    for col in generated_cols:
        if '_' in col and col.split('_', 1)[1] == base:
            return col
    return next((col for col in generated_cols if col.endswith(base)), None)


def generate_comparison_plots(df: pd.DataFrame, num_covariates: int, nlags: int, plot_dir: str = '.'):
    """
    Generate two summary figures with one vertical subplot per covariate.

    1. ``comparison_plot_acf.png``: ACF of each original vs generated covariate,
       with +/- 1.96 / sqrt(N) reference lines (N = length of the original series).
    2. ``comparison_plot_psd.png``: Welch PSD (log scale) of each pair.

    Columns whose lower-cased name contains 'orig' are treated as original
    series; those containing 'gen' are treated as generated series. Each
    original column is paired with a generated column by its suffix after
    the first '_'. A pair is skipped, leaving its subplot blank, when there is
    no generated match or when either series is empty after dropping NaNs.
    At most ``num_covariates`` original columns are plotted.

    Args:
        df (pd.DataFrame): The dataframe containing the time series data.
        num_covariates (int): Number of subplots (covariates) per figure; must be >= 1.
        nlags (int): The number of lags to compute for the ACF.
        plot_dir (str, optional): Directory to save the figures (created if missing).
            Defaults to the current directory.

    Raises:
        ValueError: If ``num_covariates`` is less than 1.
    """
    if num_covariates < 1:
        raise ValueError("num_covariates must be at least 1.")

    plt.style.use('seaborn-v0_8-whitegrid')

    # Figure height scales with the number of covariates for readability.
    # squeeze=False always returns a 2-D Axes array, even for one covariate.
    fig_acf, axes_acf = plt.subplots(
        nrows=num_covariates,
        ncols=1,
        figsize=(10, num_covariates * 3.5),
        sharex=True,
        squeeze=False,
    )
    fig_psd, axes_psd = plt.subplots(
        nrows=num_covariates,
        ncols=1,
        figsize=(10, num_covariates * 3.5),
        squeeze=False,
    )
    try:
        fig_acf.suptitle('Autocorrelation Function (ACF) Comparison', fontsize=16, y=0.99)
        fig_psd.suptitle('Power Spectral Density (PSD) Comparison', fontsize=16, y=0.99)
        acf_axes = axes_acf[:, 0]
        psd_axes = axes_psd[:, 0]

        original_cols = [col for col in df.columns if 'orig' in col.lower()]
        generated_cols = [col for col in df.columns if 'gen' in col.lower()]

        # zip() stops after num_covariates axes, so surplus original columns
        # cannot index past the subplot grid.
        for i, (original_col, ax_acf, ax_psd) in enumerate(zip(original_cols, acf_axes, psd_axes)):
            generated_col = _find_generated_column(original_col, generated_cols)
            if generated_col is None:
                continue

            original_series = df[original_col].dropna()
            generated_series = df[generated_col].dropna()
            if original_series.empty or generated_series.empty:
                continue

            # --- 1. Populate ACF Subplot ---
            acf_original = acf(original_series, nlags=nlags, fft=True)
            acf_generated = acf(generated_series, nlags=nlags, fft=True)
            conf_interval = 1.96 / np.sqrt(len(original_series))

            ax_acf.plot(acf_original, marker='o', linestyle='-', markersize=4, label='Original')
            ax_acf.plot(acf_generated, marker='x', linestyle='--', markersize=4, label='Generated')
            ax_acf.axhline(y=conf_interval, color='gray', linestyle='--', alpha=0.7)
            ax_acf.axhline(y=-conf_interval, color='gray', linestyle='--', alpha=0.7)
            ax_acf.set_title(f'Covariate {i}')
            ax_acf.set_xlim(0, nlags)
            ax_acf.set_ylabel('ACF')

            # --- 2. Populate PSD Subplot ---
            freq_orig, psd_orig = welch(original_series)
            freq_gen, psd_gen = welch(generated_series)

            ax_psd.semilogy(freq_orig, psd_orig, label='Original')
            ax_psd.semilogy(freq_gen, psd_gen, linestyle='--', label='Generated')
            ax_psd.set_title(f'Covariate {i}')
            ax_psd.set_ylabel('Power (log)')
            ax_psd.set_xlabel('Frequency')

        # Shared x-axis label on the bottom-most ACF plot; one legend per figure.
        acf_axes[-1].set_xlabel('Lag')
        acf_axes[0].legend()
        psd_axes[0].legend()

        fig_acf.tight_layout()
        fig_psd.tight_layout()

        # os.path.join keeps plot_dir='' in the current directory instead of '/'.
        if plot_dir:
            os.makedirs(plot_dir, exist_ok=True)
        fig_acf.savefig(os.path.join(plot_dir, 'comparison_plot_acf.png'), dpi=300)
        fig_psd.savefig(os.path.join(plot_dir, 'comparison_plot_psd.png'), dpi=300)
    finally:
        # The original never closed these figures, leaking memory on repeated calls.
        plt.close(fig_acf)
        plt.close(fig_psd)