#!/usr/bin/env python3
"""
GO/NOGO Analysis Script

This script performs a specialized analysis that:
1. Takes a single GO_ANGLE and a list of NOGO_ANGLES.
2. For each NOGO angle, builds a prior from the stimulus pair [GO_ANGLE, nogo_angle].
3. Fits an aggregated Laplace prior to that pair using the iterative method.
4. Tests the response at that NOGO angle under the fitted prior.
5. Creates a multi-panel plot with one panel per NOGO angle.

The result shows how a combined GO/NOGO prior shapes the population response at each NOGO angle.
"""

import os
import time
from typing import List

import numpy as np
import sys
import os as _os

try:
    # When executed as a module: python -m task_vae.scripts.go_nogo_analysis
    from .core import (
        prepare_data_and_model,
        fit_aggregated_laplace,
        compute_angle_test,
    )
    from .plotting import plot_multi_panel_aggregated
except Exception:
    # When executed directly: python go_nogo_analysis.py
    _this_dir = _os.path.dirname(__file__)
    if _this_dir not in sys.path:
        sys.path.append(_this_dir)
    from core import (  # type: ignore
        prepare_data_and_model,
        fit_aggregated_laplace,
        compute_angle_test,
    )
    from plotting import plot_multi_panel_aggregated  # type: ignore


def main():
    """Run the GO/NOGO sweep and save one multi-panel figure.

    For every angle in ``NOGO_ANGLES`` an aggregated Laplace prior is fitted
    on the stimulus pair ``[GO_ANGLE, nogo_angle]``. The response is then
    tested at that NOGO angle. All angles that produced results are drawn
    side by side into a single timestamped file under ``plots/``.

    ``fits/`` and ``plots/`` live next to this script and are created if
    missing. Fitting is called with ``load_existing_fits=True``; presumably
    this reuses fits already stored in ``fits/`` (not confirmed here).
    """
    # --- Experiment configuration ------------------------------------------
    GO_ANGLE: int = 45
    NOGO_ANGLES: List[int] = [90, 75, 70, 65, 60]
    SPATIAL_FREQUENCY = 3
    PRIOR_CONTRAST = 1.0
    TEST_CONTRAST = 0.3
    PRIOR_MOMENT = "abs"
    TEST_MOMENT = "mean"
    TEST_THRESHOLD = 0.0
    NUM_ITERATIONS = 5
    ORIENTATION_BIN_SIZE = 5
    # --- Plot smoothing configuration --------------------------------------
    SMOOTHING_S = 0.05
    PERIODIC_SPLINE = False
    USE_GAUSSIAN_SMOOTHING = True
    GAUSSIAN_SIGMA = 3.0
    SHOW_STD_BAND = True

    banner = (
        "=== GO/NOGO Response Profile Analysis ===",
        f"GO angle: {GO_ANGLE}°",
        f"NOGO angles: {NOGO_ANGLES}",
        f"Iterations: {NUM_ITERATIONS}",
        "",
    )
    for banner_line in banner:
        print(banner_line)

    script_dir = os.path.dirname(__file__)
    fits_dir, plots_dir = [os.path.join(script_dir, sub) for sub in ('fits', 'plots')]
    for out_dir in (fits_dir, plots_dir):
        os.makedirs(out_dir, exist_ok=True)

    # The data and model are shared by every GO/NOGO pair, so build them once.
    print("Preparing data and model...")
    data, model, most_responsive_angles = prepare_data_and_model(
        spatial_frequency=SPATIAL_FREQUENCY,
        prior_contrast=PRIOR_CONTRAST,
        test_moment=TEST_MOMENT,
        test_threshold=TEST_THRESHOLD,
    )
    print("Data and model ready.\n")

    # Keyword arguments common to both the prior fit and the angle test.
    common_kwargs = dict(
        data=data,
        model=model,
        most_responsive_angles=most_responsive_angles,
        spatial_frequency=SPATIAL_FREQUENCY,
    )

    all_test_results = {}
    for nogo_angle in NOGO_ANGLES:
        print(f"=== Processing NOGO angle {nogo_angle}° ===")

        pair = [GO_ANGLE, nogo_angle]
        print(f"Stimulus angles for this iteration: {pair}")

        prior = fit_aggregated_laplace(
            **common_kwargs,
            stimulus_angles=pair,
            prior_contrast=PRIOR_CONTRAST,
            prior_moment=PRIOR_MOMENT,
            num_iterations=NUM_ITERATIONS,
            orientation_bin_size=ORIENTATION_BIN_SIZE,
            fits_dir=fits_dir,
            load_existing_fits=True,
        )

        print(f"Testing response at NOGO angle {nogo_angle}°...")
        outcome = compute_angle_test(
            **common_kwargs,
            test_angle=nogo_angle,
            test_contrast=TEST_CONTRAST,
            test_moment=TEST_MOMENT,
            test_threshold=TEST_THRESHOLD,
            bin_size=ORIENTATION_BIN_SIZE,
            aggregated_laplace=prior,
        )

        if outcome is None:
            print(f"No results for NOGO angle {nogo_angle}° (no images found)")
        else:
            no_prior, no_prior_std, with_priors, with_priors_std = outcome
            all_test_results[nogo_angle] = {
                'no_prior': no_prior,
                'no_prior_std': no_prior_std,
                'with_priors': with_priors,
                'with_priors_std': with_priors_std,
                'stimulus_angles': pair,  # kept so the plot can mark GO/NOGO lines
            }
            print(f"Successfully processed NOGO angle {nogo_angle}°")

        print()

    if not all_test_results:
        print("No test results to plot.")
    else:
        # Preserve the configured order, keeping only angles that produced results.
        plotted_angles = [angle for angle in NOGO_ANGLES if angle in all_test_results]
        print(f"Generating GO/NOGO multi-panel plot for {len(all_test_results)} NOGO angles...")
        stamp = int(time.time())

        nogo_tag = "_".join(str(angle) for angle in plotted_angles)
        base = os.path.join(plots_dir, f'go{GO_ANGLE}_nogo_{nogo_tag}_prior{PRIOR_CONTRAST}_{stamp}')

        saved_path = plot_go_nogo_multi_panel(
            all_test_results=all_test_results,
            go_angle=GO_ANGLE,
            nogo_angles=plotted_angles,
            smoothing_s=SMOOTHING_S,
            periodic_spline=PERIODIC_SPLINE,
            use_gaussian_smoothing=USE_GAUSSIAN_SMOOTHING,
            gaussian_sigma=GAUSSIAN_SIGMA,
            show_std_band=SHOW_STD_BAND,
            outfile_base=base,
        )
        if saved_path:
            print(f"Saved GO/NOGO multi-panel plot: {saved_path}")

    print("\nGO/NOGO analysis complete!")


def _load_plotting_module():
    """Import and return the sibling ``plotting`` module.

    The relative form works when this file runs inside its package
    (``python -m task_vae.scripts.go_nogo_analysis``). The absolute form works
    when it runs as a plain script, where the module-level import fallback has
    put this file's directory on ``sys.path``. A bare ``from plotting import``
    fails in the package case because that fallback never runs there.
    """
    try:
        from . import plotting  # package execution
    except ImportError:
        import plotting  # type: ignore  # direct script execution
    return plotting


def _smooth_profile(
    orientations,
    values,
    *,
    use_gaussian_smoothing: bool,
    periodic_spline: bool,
    gaussian_sigma: float,
    smoothing_s: float,
    period: float = 180.0,
    num: int = 500,
):
    """Return ``(xs, ys)``: ``values`` over ``orientations`` smoothed and resampled.

    Mode precedence: periodic Gaussian smoothing if ``use_gaussian_smoothing``,
    else a periodic spline if ``periodic_spline``. Otherwise a non-periodic
    ``UnivariateSpline`` is evaluated on ``num`` points spanning
    ``[min(orientations), max(orientations)]``.
    """
    if use_gaussian_smoothing:
        return _load_plotting_module().periodic_gaussian_eval(
            orientations, values, period=period, sigma=gaussian_sigma, num=num
        )
    if periodic_spline:
        return _load_plotting_module().periodic_spline_eval(
            orientations, values, period=period, smoothing_s=smoothing_s, num=num
        )
    from scipy.interpolate import UnivariateSpline

    spline = UnivariateSpline(orientations, values, s=smoothing_s)
    xs = np.linspace(min(orientations), max(orientations), num)
    return xs, spline(xs)


def _smooth_std_band(orientations, std_values, fallback_xs, **smooth_kwargs):
    """Smooth a per-orientation std profile for drawing a ±1 std band.

    If smoothing fails (for example, the spline rejects the data), fall back to
    linear interpolation of the raw std values onto ``fallback_xs``. That is
    the grid of the curve the band surrounds.
    """
    try:
        return _smooth_profile(orientations, std_values, **smooth_kwargs)
    except Exception:
        return fallback_xs, np.interp(fallback_xs, orientations, std_values)


def _plot_smoothed_profile(orientations, values, std_values, *, color, label, band_alpha, smooth_kwargs):
    """Draw one smoothed response curve on the current axes.

    When ``std_values`` is not None, also draw its ±1 std band.
    """
    import matplotlib.pyplot as plt

    xs, smoothed = _smooth_profile(orientations, values, **smooth_kwargs)
    plt.plot(xs, smoothed, color=color, linewidth=3.0, label=label)
    if std_values is None:
        return
    band_xs, band_std = _smooth_std_band(orientations, std_values, xs, **smooth_kwargs)
    plt.fill_between(
        band_xs, smoothed - band_std, smoothed + band_std,
        color=color, alpha=band_alpha, linewidth=0,
    )


def plot_go_nogo_multi_panel(
    all_test_results,
    go_angle: int,
    nogo_angles: List[int],
    smoothing_s: float,
    periodic_spline: bool,
    use_gaussian_smoothing: bool,
    gaussian_sigma: float,
    show_std_band: bool,
    outfile_base: str,
):
    """
    Custom multi-panel plot for GO/NOGO analysis.

    Each panel shows the response at a NOGO angle when trained with
    [GO_ANGLE, nogo_angle]. A panel contains:
    - raw points and a smoothed curve for the natural prior (``'no_prior'``);
    - when present, the same for the task prior
      (``with_priors['aggregated_final']``);
    - optional ±1 std bands;
    - vertical lines at the stimulus angles.

    Args:
        all_test_results: Mapping ``nogo_angle -> result dict``. Each dict has
            the keys ``'no_prior'``, ``'no_prior_std'``, ``'with_priors'``,
            ``'with_priors_std'`` and ``'stimulus_angles'``, as built by
            ``main()``.
        go_angle: The GO stimulus angle. Its vertical line gets the GO label.
        nogo_angles: Panel order, one subplot column per entry. Angles missing
            from ``all_test_results`` leave their column empty.
        smoothing_s: Spline smoothing factor (spline modes only).
        periodic_spline: Use a periodic spline when Gaussian smoothing is off.
        use_gaussian_smoothing: Use periodic Gaussian smoothing. This takes
            precedence over the spline modes.
        gaussian_sigma: Gaussian kernel width for Gaussian smoothing.
        show_std_band: Draw ±1 std bands when std data is available.
        outfile_base: Output path without extension; ``'.pdf'`` is appended.

    Returns:
        Path of the written PDF, or ``None`` when no requested angle has results.
    """
    # Nothing to draw: avoid creating an empty (or zero-width) figure.
    if not all_test_results or not any(a in all_test_results for a in nogo_angles):
        return None

    import matplotlib.pyplot as plt

    plotting = _load_plotting_module()
    smooth_kwargs = dict(
        use_gaussian_smoothing=use_gaussian_smoothing,
        periodic_spline=periodic_spline,
        gaussian_sigma=gaussian_sigma,
        smoothing_s=smoothing_s,
    )

    cols = len(nogo_angles)
    outfile = outfile_base + '.pdf'
    # Match exact figure size from response_profile.py
    fig = plt.figure(figsize=(10 * cols, 8))
    try:
        for i, nogo_angle in enumerate(nogo_angles):
            if nogo_angle not in all_test_results:
                continue

            results = all_test_results[nogo_angle]
            plt.subplot(1, cols, i + 1)

            no_prior = results['no_prior']
            orientations = sorted(no_prior.keys())
            natural_values = [no_prior.get(a, 0) for a in orientations]
            with_priors = results['with_priors']
            task_values = (
                [with_priors['aggregated_final'].get(a, 0) for a in orientations]
                if 'aggregated_final' in with_priors else None
            )

            # Raw data points
            plt.plot(orientations, natural_values, marker='o', linestyle='None',
                     color=plotting.NAIVE_COLOR, markersize=3, alpha=0.4)
            if task_values is not None:
                plt.plot(orientations, task_values, marker='o', linestyle='None',
                         color=plotting.TRAINED_COLOR_NON_GO, markersize=3, alpha=0.4)

            # Smoothed curves stay best-effort, but failures are reported
            # instead of making the curve silently disappear.
            try:
                natural_std = None
                if show_std_band and results.get('no_prior_std') is not None:
                    natural_std = [results['no_prior_std'].get(a, 0.0) for a in orientations]
                _plot_smoothed_profile(
                    orientations, natural_values, natural_std,
                    color=plotting.NAIVE_COLOR, label='Natural Prior',
                    band_alpha=0.2, smooth_kwargs=smooth_kwargs,
                )
            except Exception as exc:
                print(f"Warning: could not draw natural-prior curve for NOGO angle {nogo_angle}°: {exc!r}")

            if task_values is not None:
                try:
                    task_std_map = results.get('with_priors_std')
                    task_std = None
                    if show_std_band and task_std_map and 'aggregated_final' in task_std_map:
                        task_std = [task_std_map['aggregated_final'].get(a, 0.0) for a in orientations]
                    _plot_smoothed_profile(
                        orientations, task_values, task_std,
                        color=plotting.TRAINED_COLOR_NON_GO, label='Task Prior',
                        band_alpha=0.15, smooth_kwargs=smooth_kwargs,
                    )
                except Exception as exc:
                    print(f"Warning: could not draw task-prior curve for NOGO angle {nogo_angle}°: {exc!r}")

            # Vertical lines at the stimulus angles (GO and current NOGO).
            # NOTE: labels are attached but this function never calls
            # plt.legend(), so they are not displayed.
            for stim_angle in results['stimulus_angles']:
                line_kwargs = dict(linestyle='--', color=plotting.PRIOR_VLINE_COLOR,
                                   alpha=0.9, linewidth=5)
                if stim_angle == go_angle:
                    line_kwargs['label'] = f'GO ({stim_angle}°)' if i == 0 else ""
                plt.axvline(stim_angle, **line_kwargs)

            # Highlight the test angle (same as nogo_angle)
            plt.axvline(nogo_angle, linestyle='--', color=plotting.STIM_VLINE_COLOR_NON_GO,
                        alpha=0.9, linewidth=1.0)

            plt.xlim(0, 180)
            plt.xticks([0, 30, 60, 90, 120, 150, 180])
            plt.grid(False)
            plt.xlabel('Orientation (°)')
            plt.ylabel('Population Activity')
            plt.title(f'Stimulus Angle: {nogo_angle}°')

        # Match exact subplot adjustments from response_profile.py
        plt.subplots_adjust(left=0.06, right=0.98, top=0.92, bottom=0.10, wspace=0.30)
        plt.savefig(outfile, format='pdf', bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
    return outfile


if __name__ == "__main__":
    # Run the analysis only when executed as a script or module, not on import.
    main()
