# scripts/calculate_wavelet_errors.py

import os
import glob
import argparse
import logging
import nibabel as nib
import numpy as np
import pywt
import pandas as pd
from tqdm import tqdm

# --- Console logging configuration ---
# Root logger: INFO and above, timestamped records, emitted through a stream handler.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module-level logger used by every function in this script.
logger = logging.getLogger(__name__)

def calculate_reconstruction_mae(original_data: np.ndarray, wavelet_name: str) -> float:
    """
    Performs a single-level DWT/IDWT roundtrip and calculates the Mean Absolute Error (MAE).

    The input is decomposed with ``pywt.wavedecn`` (level 1), reconstructed
    with ``pywt.waverecn``, cropped back to the input shape, and compared
    element-wise against the input.

    Args:
        original_data (np.ndarray): The original volume (``main`` passes 3D arrays).
        wavelet_name (str): The name of the wavelet to use; forwarded to PyWavelets.

    Returns:
        float: The Mean Absolute Error between the input and its reconstruction,
            accumulated in float64 and returned as a built-in ``float``.
    """
    # 1. Perform the single-level Discrete Wavelet Transform (DWT)
    coeffs = pywt.wavedecn(original_data, wavelet=wavelet_name, level=1)

    # 2. Perform the Inverse Discrete Wavelet Transform (IDWT)
    recon_data = pywt.waverecn(coeffs, wavelet=wavelet_name)

    # 3. Crop the reconstruction to the original shape; the roundtrip may
    #    return extra trailing samples along an axis (e.g. odd-sized axes).
    slicing = tuple(slice(0, s) for s in original_data.shape)
    recon_data_cropped = recon_data[slicing]

    # 4. Average the absolute error. np.mean accumulates in the input
    #    precision by default, which is inaccurate for float32 volumes, so a
    #    float64 accumulator is requested explicitly.
    abs_error = np.abs(original_data - recon_data_cropped)
    return float(np.mean(abs_error, dtype=np.float64))

def _parse_args() -> argparse.Namespace:
    """Parse and validate the command-line arguments of the script."""
    parser = argparse.ArgumentParser(
        description="Calculate the reconstruction Mean Absolute Error (MAE) and its standard deviation "
                    "for various wavelets across a dataset of 3D NIfTI files.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="Directory containing the real NIfTI dataset (.nii.gz files)."
    )
    parser.add_argument(
        "--output_csv",
        type=str,
        required=True,
        help="Path to save the resulting CSV file with MAE and STD values."
    )
    parser.add_argument(
        "--wavelets",
        nargs='+',
        default=['haar','db4', 'sym4', 'coif2', 'bior3.3'],
        help="List of wavelet types to test (e.g., 'haar', 'db4', 'sym4')."
    )
    parser.add_argument(
        "--num_files",
        type=int,
        default=None,
        help="(Optional) Limit the number of files to process for a quick test run."
    )
    args = parser.parse_args()

    # A zero limit used to be silently ignored and a negative one silently
    # dropped files from the end of the list; reject both explicitly.
    if args.num_files is not None and args.num_files < 1:
        parser.error("--num_files must be a positive integer.")

    # Repeated wavelet names would share one results bucket and be counted
    # twice per file, so keep only the first occurrence of each name.
    args.wavelets = list(dict.fromkeys(args.wavelets))
    return args


def _collect_mae_scores(nifti_files: list, wavelets: list) -> dict:
    """Return ``{wavelet: [per-file MAE, ...]}`` for every loadable 3D file."""
    results = {w: [] for w in wavelets}

    for file_path in tqdm(nifti_files, desc="Processing Files", unit="file"):
        file_name = os.path.basename(file_path)
        try:
            original_data = nib.load(file_path).get_fdata(dtype=np.float32)
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole run.
            logger.error(f"Failed to load or process file '{file_name}': {e}")
            continue

        if original_data.ndim != 3:
            logger.warning(f"Skipping non-3D file: {file_name}")
            continue

        for wavelet in wavelets:
            try:
                mae = calculate_reconstruction_mae(original_data, wavelet)
            except Exception as e:
                logger.warning(f"Could not process wavelet '{wavelet}' for file {file_name}: {e}")
                continue

            # A NaN/inf score (e.g. from a volume containing NaN voxels) would
            # turn the aggregated mean and std for this wavelet into NaN.
            if not np.isfinite(mae):
                logger.warning(
                    f"Non-finite MAE for wavelet '{wavelet}' in file {file_name}; sample skipped."
                )
                continue

            results[wavelet].append(float(mae))

    return results


def _summarize_scores(results: dict) -> list:
    """Turn per-wavelet MAE lists into summary rows (mean, population std, count)."""
    summary_data = []
    for wavelet_name, mae_scores in results.items():
        if not mae_scores:
            logger.warning(f"No successful reconstructions for wavelet '{wavelet_name}'. It will be excluded from the report.")
            continue

        # Aggregate in float64 regardless of the precision of the stored scores.
        scores = np.asarray(mae_scores, dtype=np.float64)
        summary_data.append({
            "Wavelet": wavelet_name,
            "Mean_MAE": float(scores.mean()),
            "Std_MAE": float(scores.std()),
            "Num_Samples_Processed": len(mae_scores)
        })
    return summary_data


def main():
    """
    Command-line entry point.

    Finds the ``*.nii.gz`` files in ``--input_dir``, computes the wavelet
    roundtrip MAE of every 3D volume for each requested wavelet, writes a
    per-wavelet summary (mean, std, sample count) sorted by mean MAE to
    ``--output_csv``, and prints the same table to stdout. Problems are
    logged; the function returns early when there is nothing to process or
    report.
    """
    args = _parse_args()

    # --- 1. Find all NIfTI files ---
    if not os.path.isdir(args.input_dir):
        logger.error(f"Input directory not found: {args.input_dir}")
        return

    # Escape the directory so glob metacharacters in its name ('[', '*', '?')
    # are matched literally instead of being interpreted as a pattern.
    pattern = os.path.join(glob.escape(args.input_dir), "*.nii.gz")
    nifti_files = sorted(glob.glob(pattern))

    if not nifti_files:
        logger.error(f"No .nii.gz files found in {args.input_dir}")
        return

    if args.num_files is not None:
        nifti_files = nifti_files[:args.num_files]
        logger.info(f"Processing a subset of {len(nifti_files)} files.")

    # --- 2. Process files and collect MAE scores ---
    logger.info(f"Starting analysis on {len(nifti_files)} files for {len(args.wavelets)} wavelets...")
    results = _collect_mae_scores(nifti_files, args.wavelets)

    # --- 3. Aggregate results and create summary ---
    summary_data = _summarize_scores(results)
    if not summary_data:
        logger.error("No results were generated. Aborting.")
        return

    # --- 4. Save to CSV ---
    df = pd.DataFrame(summary_data)
    # Sort by Mean MAE to easily see the best-performing wavelets
    df_sorted = df.sort_values(by="Mean_MAE").reset_index(drop=True)

    try:
        output_dir = os.path.dirname(args.output_csv)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        df_sorted.to_csv(args.output_csv, index=False, float_format='%.6e')
        logger.info(f"Results successfully saved to: {args.output_csv}")
    except Exception as e:
        # A failed write is reported, but the summary is still printed below.
        logger.error(f"Failed to save CSV file: {e}")

    # --- 5. Print summary to console ---
    print("\n--- Wavelet Reconstruction Error Summary ---")
    print(df_sorted.to_string(index=False))
    print("------------------------------------------")


# Run the command-line workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()