#!/usr/bin/env python3
"""
Small runner script that uses the core and plotting modules to perform a
simplified aggregated Laplace fit, run tests on selected angles, and save the
resulting plots (a multi-panel response-profile figure and a ratio figure).

This script does not modify the original response_profile.py; it reuses logic
extracted from it into the core and plotting modules.
"""

import os
import time
from typing import List

import numpy as np
import sys
import os as _os
try:
    # When executed as a module: python -m task_vae.scripts.run_analysis
    from .core import (
        prepare_data_and_model,
        fit_aggregated_laplace,
        compute_angle_test,
    )
    from .plotting import plot_series_pair, plot_multi_panel_aggregated, plot_ratio_aggregated
except ImportError:
    # When executed directly (python run_analysis.py) there is no parent package,
    # so the relative imports above raise ImportError. Fall back to importing the
    # sibling modules by name from this script's directory. Only ImportError is
    # caught so that genuine errors raised while importing core/plotting surface
    # directly instead of silently triggering this fallback.
    _this_dir = _os.path.dirname(_os.path.abspath(__file__))
    if _this_dir not in sys.path:
        # Prepend rather than append so these generic module names resolve to
        # the sibling files, not to unrelated `core`/`plotting` modules that may
        # appear earlier on sys.path.
        sys.path.insert(0, _this_dir)
    from core import (  # type: ignore
        prepare_data_and_model,
        fit_aggregated_laplace,
        compute_angle_test,
    )
    from plotting import plot_series_pair, plot_multi_panel_aggregated, plot_ratio_aggregated  # type: ignore


def main():
    """Run the lightweight response-profile analysis end to end.

    Steps:
      1. Build data, model and the most responsive angles with
         ``prepare_data_and_model``.
      2. Fit the aggregated Laplace model with ``fit_aggregated_laplace``,
         storing fits under ``fits/`` next to this script. The call passes
         ``load_existing_fits=True``, so previously saved fits are presumably
         reused (behavior lives in ``core``).
      3. Run ``compute_angle_test`` for every test angle. Each non-``None``
         result is unpacked into the naive series/std and the with-priors
         series/std and stored under that angle. Angles that return ``None``
         are reported and excluded from plotting.
      4. Save a multi-panel plot and a ratio plot under ``plots/``. Filenames
         include the stimulus angles, the prior contrast and a Unix timestamp.

    All configuration lives in the constants at the top of this function.
    The output directories are created if they do not already exist.
    """
    STIMULUS_ANGLES: List[int] = [45, 90]
    TEST_ANGLES: List[int] = [45, 90, 75, 70, 65, 60]
    SPATIAL_FREQUENCY = 3
    PRIOR_CONTRAST = 1.0
    TEST_CONTRAST = 0.3
    PRIOR_MOMENT = "abs"
    TEST_MOMENT = "mean"
    TEST_THRESHOLD = 0.0
    NUM_ITERATIONS = 5
    ORIENTATION_BIN_SIZE = 5
    SMOOTHING_S = 0.05
    PERIODIC_SPLINE = False
    USE_GAUSSIAN_SMOOTHING = True
    GAUSSIAN_SIGMA = 3.0
    SHOW_STD_BAND = True

    print("=== Lightweight Response Profile Analysis ===")
    print(f"Stimulus angles: {STIMULUS_ANGLES}")
    print(f"Test angles: {TEST_ANGLES}")
    print(f"Iterations: {NUM_ITERATIONS}")
    print()

    # Resolve to an absolute path so output locations do not depend on the
    # current working directory, and the printed file paths are unambiguous.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    fits_dir = os.path.join(base_dir, 'fits')
    plots_dir = os.path.join(base_dir, 'plots')
    os.makedirs(fits_dir, exist_ok=True)
    os.makedirs(plots_dir, exist_ok=True)

    data, model, most_responsive_angles = prepare_data_and_model(
        spatial_frequency=SPATIAL_FREQUENCY,
        prior_contrast=PRIOR_CONTRAST,
        test_moment=TEST_MOMENT,
        test_threshold=TEST_THRESHOLD,
    )

    aggregated_laplace = fit_aggregated_laplace(
        data=data,
        model=model,
        most_responsive_angles=most_responsive_angles,
        stimulus_angles=STIMULUS_ANGLES,
        spatial_frequency=SPATIAL_FREQUENCY,
        prior_contrast=PRIOR_CONTRAST,
        prior_moment=PRIOR_MOMENT,
        num_iterations=NUM_ITERATIONS,
        orientation_bin_size=ORIENTATION_BIN_SIZE,
        fits_dir=fits_dir,
        load_existing_fits=True,
    )

    # Accumulate results across angles similar to response_profile.py
    all_test_results = {}
    skipped_angles: List[int] = []
    for test_angle in TEST_ANGLES:
        results = compute_angle_test(
            data=data,
            model=model,
            most_responsive_angles=most_responsive_angles,
            test_angle=test_angle,
            spatial_frequency=SPATIAL_FREQUENCY,
            test_contrast=TEST_CONTRAST,
            test_moment=TEST_MOMENT,
            test_threshold=TEST_THRESHOLD,
            bin_size=ORIENTATION_BIN_SIZE,
            aggregated_laplace=aggregated_laplace,
        )
        if results is None:
            # Record the skip instead of dropping the angle silently.
            skipped_angles.append(test_angle)
            continue
        naive_series, naive_std, with_priors, with_priors_std = results
        all_test_results[test_angle] = {
            'no_prior': naive_series,
            'no_prior_std': naive_std,
            'with_priors': with_priors,
            'with_priors_std': with_priors_std,
        }

    if skipped_angles:
        print(f"No results for test angles {skipped_angles}; excluding them from plots.")

    # Generate plots matching response_profile.py structure
    if all_test_results:
        print(f"\nGenerating plots for {len(all_test_results)} test results...")
        # Pass only the angles that produced results, in their original order,
        # so the plotting helpers never get an angle with no entry in
        # all_test_results.
        plotted_angles = [a for a in TEST_ANGLES if a in all_test_results]
        angles_str = "_".join(str(a) for a in STIMULUS_ANGLES)
        timestamp = int(time.time())

        # Multi-panel plot
        base = os.path.join(plots_dir, f'response_profiles_{angles_str}_prior{PRIOR_CONTRAST}_{timestamp}')
        multi_panel_file = plot_multi_panel_aggregated(
            all_test_results=all_test_results,
            test_angles=plotted_angles,
            stimulus_angles=STIMULUS_ANGLES,
            smoothing_s=SMOOTHING_S,
            periodic_spline=PERIODIC_SPLINE,
            use_gaussian_smoothing=USE_GAUSSIAN_SMOOTHING,
            gaussian_sigma=GAUSSIAN_SIGMA,
            show_naive_curve=True,
            show_std_band=SHOW_STD_BAND,
            show_prior_vlines=True,
            go_angle=[],
            outfile_base=base,
        )
        if multi_panel_file:
            print(f"Saved multi-panel plot: {multi_panel_file}")

        # Ratio plot
        ratio_file = os.path.join(plots_dir, f'response_profiles_ratio_{angles_str}_prior{PRIOR_CONTRAST}_{timestamp}.pdf')
        ratio_result = plot_ratio_aggregated(
            all_test_results=all_test_results,
            test_angles=plotted_angles,
            smoothing_s=SMOOTHING_S,
            periodic_spline=PERIODIC_SPLINE,
            use_gaussian_smoothing=USE_GAUSSIAN_SMOOTHING,
            gaussian_sigma=GAUSSIAN_SIGMA,
            outfile=ratio_file,
        )
        if ratio_result:
            print(f"Saved ratio plot: {ratio_result}")
    else:
        print("No test results to plot.")

    print("\nAnalysis complete!")


# Entry point: run the analysis only when this file is executed (directly or
# via `python -m`), not when it is imported.
if __name__ == '__main__':
    main()


