#!/usr/bin/env python3
"""
Plotting utilities extracted from response_profile.py

Includes periodic spline and Gaussian smoothing helpers, a compact plot
function comparing naive vs task responses, and aggregated multi-panel and
task/natural ratio plots.
"""

import warnings
from typing import Dict, Optional, Tuple, List

import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from scipy.interpolate import UnivariateSpline, splrep, splev
from scipy.ndimage import gaussian_filter1d

# Global plotting settings matching response_profile.py: a large base font,
# a smaller legend font, and no top/right axis spines on every new axes.
plt.rcParams.update({
    'font.size': 20,
    'legend.fontsize': 14,
    'axes.spines.top': False,
    'axes.spines.right': False,
})


# Colors matching response_profile.py. Each is written as a 0-255 RGB triple
# and normalised to the 0-1 float tuple that matplotlib expects.
NAIVE_COLOR = tuple(channel / 255.0 for channel in (101, 192, 255))
TRAINED_COLOR_GO = tuple(channel / 255.0 for channel in (79, 143, 0))
TRAINED_COLOR_NON_GO = tuple(channel / 255.0 for channel in (255, 38, 0))
PRIOR_VLINE_COLOR = tuple(channel / 255.0 for channel in (121, 121, 121))
STIM_VLINE_COLOR_GO = tuple(channel / 255.0 for channel in (140, 192, 76))
STIM_VLINE_COLOR_NON_GO = tuple(channel / 255.0 for channel in (255, 136, 116))


def periodic_spline_eval(angles, values, period: float = 180.0, smoothing_s: float = 0.0, num: int = 500):
    """Fit a spline to periodic samples and evaluate it on ``[0, period]``.

    Angles are wrapped into ``[0, period)``. Samples that land on the same
    wrapped angle (e.g. 0 and 180 with the default period) are merged by
    averaging their values; the scipy fitters reject repeated abscissae when
    interpolating, which previously made the whole fit fail. The merged
    samples are tiled over three consecutive periods so the fit near 0 and
    ``period`` sees neighbours on both sides. ``splrep(per=True)`` is tried
    first, with a plain ``UnivariateSpline`` over the tiled data as fallback.

    Args:
        angles: Sample positions, in the same units as ``period``.
        values: One value per angle; pairs with a non-finite angle or value
            are dropped.
        period: Length of one period.
        smoothing_s: Smoothing factor ``s`` for the scipy fitter (0 means
            interpolate through the samples).
        num: Number of evenly spaced evaluation points, endpoints included.

    Returns:
        ``(xs, ys)``: ``xs = linspace(0, period, num)`` and the spline values.

    Raises:
        ValueError: if fewer than 4 finite samples remain.
    """
    angles_arr = np.asarray(angles, dtype=float) % period
    values_arr = np.asarray(values, dtype=float)
    valid_mask = np.isfinite(angles_arr) & np.isfinite(values_arr)
    angles_arr = angles_arr[valid_mask]
    values_arr = values_arr[valid_mask]
    if angles_arr.size < 4:
        raise ValueError("Not enough points for periodic spline fit")
    # np.unique sorts the wrapped angles and groups duplicates; average the
    # values within each group so the abscissae are strictly increasing.
    unique_angles, inverse = np.unique(angles_arr, return_inverse=True)
    unique_values = np.bincount(inverse, weights=values_arr) / np.bincount(inverse)
    angles_ext = np.concatenate([unique_angles - period, unique_angles, unique_angles + period])
    values_ext = np.tile(unique_values, 3)
    xs = np.linspace(0.0, period, num, endpoint=True)
    try:
        tck = splrep(angles_ext, values_ext, s=smoothing_s, per=True)
        ys = splev(xs, tck)
    except Exception:
        # Fall back to a non-periodic smoothing spline over the tiled data.
        spline = UnivariateSpline(angles_ext, values_ext, s=smoothing_s)
        ys = spline(xs)
    return xs, ys


def periodic_gaussian_eval(angles, values, period: float = 180.0, sigma: float = 6.0, num: int = 500):
    """Smooth periodic samples with a wrapped Gaussian and evaluate on ``[0, period]``.

    The samples are linearly interpolated onto a unit-spaced grid covering
    exactly one period (``0, 1, ...`` below ``period``). The interpolation is
    periodic, so the segment between the last and first sample wraps around
    instead of being padded with an arbitrary neighbour value. The grid is
    blurred with ``gaussian_filter1d(mode='wrap')``; because the grid holds
    no duplicated endpoint, the wrap length equals the true period. The result
    is then re-interpolated, again periodically, onto ``num`` points, so the
    value at ``period`` equals the value at 0.

    Samples that share a wrapped angle (e.g. 0 and 180 with the default
    period) are merged by averaging their values.

    Args:
        angles: Sample positions, in the same units as ``period``.
        values: One value per angle; pairs with a non-finite angle or value
            are dropped.
        period: Length of one period.
        sigma: Gaussian standard deviation in grid steps. The grid step is 1,
            so this is in the same units as ``angles``.
        num: Number of evenly spaced evaluation points, endpoints included.

    Returns:
        ``(xs, ys)``: ``xs = linspace(0, period, num)`` and smoothed values.

    Raises:
        ValueError: if fewer than 2 finite samples remain.
    """
    angles_arr = np.asarray(angles, dtype=float) % period
    values_arr = np.asarray(values, dtype=float)
    valid_mask = np.isfinite(angles_arr) & np.isfinite(values_arr)
    angles_arr = angles_arr[valid_mask]
    values_arr = values_arr[valid_mask]
    if angles_arr.size < 2:
        raise ValueError("Not enough points for Gaussian smoothing")
    # Merge samples that coincide after wrapping (np.unique also sorts).
    unique_angles, inverse = np.unique(angles_arr, return_inverse=True)
    unique_values = np.bincount(inverse, weights=values_arr) / np.bincount(inverse)
    # One full period without repeating the endpoint, so mode='wrap' treats
    # grid[-1] and grid[0] as true neighbours.
    grid_x = np.arange(0.0, period, 1.0)
    grid_vals = np.interp(grid_x, unique_angles, unique_values, period=period)
    smoothed_grid = gaussian_filter1d(grid_vals, sigma=sigma, mode='wrap')
    xs_dense = np.linspace(0.0, period, num, endpoint=True)
    # period= makes points in (grid_x[-1], period] interpolate towards grid[0].
    ys_dense = np.interp(xs_dense, grid_x, smoothed_grid, period=period)
    return xs_dense, ys_dense


def _smooth_profile(orientations, values, smoothing_s, periodic_spline, use_gaussian_smoothing, gaussian_sigma, num=500):
    """Smooth one orientation series with the method selected by the flags.

    Gaussian smoothing takes precedence, then the periodic spline, then a
    plain ``UnivariateSpline``. The plain spline is fitted to finite samples
    only and evaluated between the smallest and largest of their orientations.

    Returns:
        ``(xs, ys)`` arrays of length ``num``.
    """
    if use_gaussian_smoothing:
        return periodic_gaussian_eval(orientations, values, period=180.0, sigma=gaussian_sigma, num=num)
    if periodic_spline:
        return periodic_spline_eval(orientations, values, period=180.0, smoothing_s=smoothing_s, num=num)
    x = np.asarray(orientations, dtype=float)
    y = np.asarray(values, dtype=float)
    finite = np.isfinite(x) & np.isfinite(y)
    x, y = x[finite], y[finite]
    spline = UnivariateSpline(x, y, s=smoothing_s)
    xs = np.linspace(x.min(), x.max(), num)
    return xs, spline(xs)


def _plot_smoothed_series(orientations, values, std_map, color, label, band_alpha, smooth_opts):
    """Draw one smoothed curve on the current axes, plus a ±std band if ``std_map`` is given.

    Best-effort: if the curve cannot be smoothed, a RuntimeWarning is issued
    and nothing is drawn. If only the band fails, the curve is kept and a
    warning is issued.
    """
    try:
        xs, smoothed = _smooth_profile(orientations, values, **smooth_opts)
        plt.plot(xs, smoothed, color=color, linewidth=3.0, label=label)
    except Exception as err:
        warnings.warn(f"Skipping '{label}' curve: smoothing failed ({err})", RuntimeWarning, stacklevel=2)
        return
    if std_map is None:
        return
    try:
        std_vals = [std_map.get(a, 0.0) for a in orientations]
        try:
            band_xs, band_std = _smooth_profile(orientations, std_vals, **smooth_opts)
        except Exception:
            # Fall back to the raw std values; they are interpolated onto xs below.
            band_xs, band_std = orientations, std_vals
        # Resample onto the curve's own grid so band and curve always align
        # (a no-op when both already share the same linspace).
        band_std = np.interp(xs, band_xs, band_std)
        plt.fill_between(xs, smoothed - band_std, smoothed + band_std, color=color, alpha=band_alpha, linewidth=0)
    except Exception as err:
        warnings.warn(f"Skipping std band for '{label}': {err}", RuntimeWarning, stacklevel=2)


def plot_series_pair(
    naive_series: Dict[int, float],
    task_series: Optional[Dict[int, float]] = None,
    smoothing_s: float = 0.05,
    periodic_spline: bool = False,
    use_gaussian_smoothing: bool = True,
    gaussian_sigma: float = 3.0,
    show_std_band: bool = True,
    naive_std: Optional[Dict[int, float]] = None,
    task_std: Optional[Dict[int, float]] = None,
    title: str = "Response Profile",
    show_legend: bool = True,
    prior_angles: Optional[List[int]] = None,
    show_prior_vlines: bool = True,
    colors: Tuple[Tuple[float, float, float], Tuple[float, float, float]] = (NAIVE_COLOR, TRAINED_COLOR_GO),
    filepath: Optional[str] = None,
):
    """Plot smoothed naive ("Natural Prior") and optional task ("Task Prior") responses.

    Both series are evaluated at the sorted keys of ``naive_series``; task
    keys missing there become NaN. Curves that cannot be smoothed are skipped
    with a RuntimeWarning instead of disappearing silently.

    Args:
        naive_series: Orientation -> naive response.
        task_series: Optional orientation -> task response.
        smoothing_s: Spline smoothing factor (spline modes only).
        periodic_spline: Use the periodic spline when Gaussian smoothing is off.
        use_gaussian_smoothing: Use wrapped Gaussian smoothing (takes precedence).
        gaussian_sigma: Gaussian sigma, in orientation units.
        show_std_band: Draw ±std bands when the matching std map is given.
        naive_std / task_std: Optional orientation -> std maps (missing keys -> 0).
        title: Axes title.
        show_legend: Draw a legend if any labelled curve was plotted.
        prior_angles: Positions of dashed vertical reference lines.
        show_prior_vlines: Whether to draw those lines.
        colors: (naive colour, task colour).
        filepath: If given, save as PDF there and close the figure.

    Returns:
        ``None`` when ``filepath`` is given, otherwise the current figure.
    """
    smooth_opts = {
        'smoothing_s': smoothing_s,
        'periodic_spline': periodic_spline,
        'use_gaussian_smoothing': use_gaussian_smoothing,
        'gaussian_sigma': gaussian_sigma,
    }
    plt.figure(figsize=(10, 8))
    orientations = sorted(naive_series.keys())
    naive_values = [naive_series.get(a, np.nan) for a in orientations]
    _plot_smoothed_series(
        orientations, naive_values, naive_std if show_std_band else None,
        colors[0], 'Natural Prior', 0.2, smooth_opts,
    )
    if task_series is not None:
        task_values = [task_series.get(a, np.nan) for a in orientations]
        _plot_smoothed_series(
            orientations, task_values, task_std if show_std_band else None,
            colors[1], 'Task Prior', 0.15, smooth_opts,
        )

    if show_prior_vlines and prior_angles is not None:
        for pa in prior_angles:
            plt.axvline(pa, linestyle='--', color=PRIOR_VLINE_COLOR, alpha=0.9, linewidth=2)

    plt.xlabel('Orientation (°)')
    plt.ylabel('Population Activity')
    plt.title(title)
    plt.grid(False)
    plt.xticks([0, 30, 60, 90, 120, 150, 180])
    # Skip the legend when nothing labelled was drawn (avoids matplotlib's
    # "No artists with labels" warning and an empty legend box).
    if show_legend and plt.gca().get_legend_handles_labels()[0]:
        plt.legend(fontsize=14)  # Match legend fontsize from response_profile.py
    if filepath:
        plt.savefig(filepath, format='pdf', bbox_inches='tight')
        plt.close()
        return None
    return plt.gcf()


def plot_multi_panel_aggregated(
    all_test_results: Dict[int, Dict],
    test_angles: List[int],
    stimulus_angles: List[int],
    smoothing_s: float,
    periodic_spline: bool,
    use_gaussian_smoothing: bool,
    gaussian_sigma: float,
    show_naive_curve: bool,
    show_std_band: bool,
    show_prior_vlines: bool,
    go_angle,
    outfile_base: str,
):
    """Save a one-row PDF with one panel per test angle that has results.

    Each panel shows the natural (``results['no_prior']``) and aggregated task
    (``results['with_priors']['aggregated_final']``) responses as faint raw
    markers plus smoothed curves. Optional ±std bands come from
    ``results['no_prior_std']`` and ``results['with_priors_std']['aggregated_final']``.
    Dashed vertical lines mark ``stimulus_angles`` and the panel's test angle.

    Args:
        all_test_results: Test angle -> results dict (keys as listed above).
        test_angles: Panel order; angles without results are skipped and do
            not consume a panel slot.
        stimulus_angles: Positions of the thick dashed prior lines.
        smoothing_s, periodic_spline, use_gaussian_smoothing, gaussian_sigma:
            Smoother selection; Gaussian takes precedence over splines.
        show_naive_curve: Draw the raw natural-response markers.
        show_std_band: Draw ±std bands when std data are present.
        show_prior_vlines: Draw the stimulus-angle lines.
        go_angle: A single angle or a list/tuple/array of angles drawn in the
            "go" colours; other angles use the "non-go" colours.
        outfile_base: Output path without extension; '.pdf' is appended.

    Returns:
        The written PDF path, or ``None`` if ``all_test_results`` is empty.
    """
    if len(all_test_results) == 0:
        return None

    def _smooth(orientations, values):
        # Gaussian > periodic spline > plain spline (fitted to finite samples).
        if use_gaussian_smoothing:
            return periodic_gaussian_eval(orientations, values, period=180.0, sigma=gaussian_sigma, num=500)
        if periodic_spline:
            return periodic_spline_eval(orientations, values, period=180.0, smoothing_s=smoothing_s, num=500)
        x = np.asarray(orientations, dtype=float)
        y = np.asarray(values, dtype=float)
        finite = np.isfinite(x) & np.isfinite(y)
        x, y = x[finite], y[finite]
        spline = UnivariateSpline(x, y, s=smoothing_s)
        xs = np.linspace(x.min(), x.max(), 500)
        return xs, spline(xs)

    def _draw(orientations, values, std_map, color, label, band_alpha):
        # Best-effort curve + optional band; failures warn instead of vanishing.
        try:
            xs, smoothed = _smooth(orientations, values)
            plt.plot(xs, smoothed, color=color, linewidth=3.0, label=label)
        except Exception as err:
            warnings.warn(f"Skipping '{label}' curve: smoothing failed ({err})", RuntimeWarning, stacklevel=2)
            return
        if std_map is None:
            return
        try:
            std_values = [std_map.get(a, 0.0) for a in orientations]
            try:
                band_xs, band_std = _smooth(orientations, std_values)
            except Exception:
                # Fall back to the raw std values; interpolated onto xs below.
                band_xs, band_std = orientations, std_values
            band_std = np.interp(xs, band_xs, band_std)
            plt.fill_between(xs, smoothed - band_std, smoothed + band_std, color=color, alpha=band_alpha, linewidth=0)
        except Exception as err:
            warnings.warn(f"Skipping std band for '{label}': {err}", RuntimeWarning, stacklevel=2)

    # Only angles with results get a panel, so the subplot index always stays
    # within 1..cols (enumerating test_angles directly could overflow).
    panel_angles = [a for a in test_angles if a in all_test_results]
    cols = max(len(panel_angles), 1)
    # Match exact figure size from response_profile.py
    fig = plt.figure(figsize=(10 * cols, 8))
    try:
        for panel_idx, test_angle in enumerate(panel_angles, start=1):
            results = all_test_results[test_angle]
            plt.subplot(1, cols, panel_idx)
            no_prior = results['no_prior']
            orientations = sorted(no_prior.keys())
            natural_values = [no_prior.get(a, 0) for a in orientations]
            task_map = results.get('with_priors', {}).get('aggregated_final')
            task_values = None if task_map is None else [task_map.get(a, 0) for a in orientations]
            if isinstance(go_angle, (list, tuple, np.ndarray)):
                is_go_angle = test_angle in go_angle
            else:
                is_go_angle = test_angle == go_angle
            trained_color = TRAINED_COLOR_GO if is_go_angle else TRAINED_COLOR_NON_GO

            if show_naive_curve:
                plt.plot(orientations, natural_values, marker='o', linestyle='None', color=NAIVE_COLOR, markersize=3, alpha=0.4)
            if task_values is not None:
                plt.plot(orientations, task_values, marker='o', linestyle='None', color=trained_color, markersize=3, alpha=0.4)

            natural_std = results.get('no_prior_std') if show_std_band else None
            _draw(orientations, natural_values, natural_std, NAIVE_COLOR, 'Natural Prior', 0.2)
            if task_values is not None:
                task_std = (results.get('with_priors_std') or {}).get('aggregated_final') if show_std_band else None
                _draw(orientations, task_values, task_std, trained_color, 'Task Prior', 0.15)

            plt.xlim(0, 180)
            if show_prior_vlines:
                for pa in stimulus_angles:
                    plt.axvline(pa, linestyle='--', color=PRIOR_VLINE_COLOR, alpha=0.9, linewidth=5)
            plt.axvline(test_angle, linestyle='--', color=(STIM_VLINE_COLOR_GO if is_go_angle else STIM_VLINE_COLOR_NON_GO), alpha=0.9, linewidth=1.0)
            plt.xticks([0, 30, 60, 90, 120, 150, 180])
            plt.grid(False)
            plt.xlabel('Orientation (°)')
            plt.ylabel('Population Activity')
            plt.title(f'Stimulus Angle: {test_angle}°')
        # Match exact subplot adjustments from response_profile.py
        plt.subplots_adjust(left=0.06, right=0.98, top=0.92, bottom=0.10, wspace=0.30)
        plt.savefig(outfile_base + '.pdf', format='pdf', bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving raised.
        plt.close(fig)
    return outfile_base + '.pdf'


def plot_ratio_aggregated(
    all_test_results: Dict[int, Dict],
    test_angles: List[int],
    smoothing_s: float,
    periodic_spline: bool,
    use_gaussian_smoothing: bool,
    gaussian_sigma: float,
    outfile: str,
):
    """Save a single-axes PDF of smoothed task/natural response ratios per test angle.

    For each angle in ``test_angles`` that has results with
    ``results['with_priors']['aggregated_final']``, the ratio
    ``task / natural`` is computed over the orientations of
    ``results['no_prior']``. The ratio is NaN where the natural value is
    missing, NaN or 0. It is smoothed with the selected method and drawn as
    one coloured line. Angles with fewer than 4 non-NaN ratios are skipped.

    Args:
        all_test_results: Test angle -> results dict.
        test_angles: Curve order; also fixes each curve's colour.
        smoothing_s, periodic_spline, use_gaussian_smoothing, gaussian_sigma:
            Smoother selection; sigma and s are clamped to at least 1e-6.
        outfile: PDF output path.

    Returns:
        ``outfile`` after writing, or ``None`` if ``all_test_results`` is empty.
    """
    if len(all_test_results) == 0:
        return None
    # Match exact figure size from response_profile.py
    fig = plt.figure(figsize=(8, 8))
    try:
        ax_ratio = plt.gca()
        # Match spine settings from response_profile.py
        ax_ratio.spines['right'].set_visible(False)
        ax_ratio.spines['top'].set_visible(False)
        num_curves = len(all_test_results)
        # pyplot.get_cmap still exists in matplotlib >= 3.9, where
        # matplotlib.cm.get_cmap (plt.cm.get_cmap) was removed.
        cmap = plt.get_cmap('tab20' if num_curves > 10 else 'tab10', num_curves)
        # Index colours by position among angles that have results, so the
        # index never runs past the num_curves-entry colour table.
        present_angles = [a for a in test_angles if a in all_test_results]
        for color_idx, test_angle in enumerate(present_angles):
            results = all_test_results[test_angle]
            task_map = results.get('with_priors', {}).get('aggregated_final')
            if task_map is None:
                continue
            no_prior = results['no_prior']
            orientations = sorted(no_prior.keys())
            natural = np.array([no_prior.get(a, np.nan) for a in orientations], dtype=float)
            task = np.array([task_map.get(a, np.nan) for a in orientations], dtype=float)
            usable = ~np.isnan(natural) & (natural != 0)
            ratio = np.full(natural.shape, np.nan)
            ratio[usable] = task[usable] / natural[usable]
            valid_mask = ~np.isnan(ratio)
            if np.count_nonzero(valid_mask) < 4:
                continue
            xs_valid = np.asarray(orientations)[valid_mask]
            ys_valid = ratio[valid_mask]
            try:
                if use_gaussian_smoothing:
                    xs_dense, ys_dense = periodic_gaussian_eval(xs_valid, ys_valid, period=180.0, sigma=max(1e-6, gaussian_sigma), num=500)
                elif periodic_spline:
                    xs_dense, ys_dense = periodic_spline_eval(xs_valid, ys_valid, period=180.0, smoothing_s=max(1e-6, smoothing_s), num=500)
                else:
                    ratio_spline = UnivariateSpline(xs_valid, ys_valid, s=max(1e-6, smoothing_s))
                    xs_dense = np.linspace(float(np.min(xs_valid)), float(np.max(xs_valid)), 500)
                    ys_dense = ratio_spline(xs_dense)
                ax_ratio.plot(xs_dense, ys_dense, color=cmap(color_idx), linewidth=3.0, label=f'{test_angle}°')
            except Exception as err:
                warnings.warn(f"Skipping ratio curve for {test_angle}°: {err}", RuntimeWarning, stacklevel=2)
        ax_ratio.set_xlabel('Orientation (°)')
        ax_ratio.set_ylabel('Task/Natural ratio')
        ax_ratio.set_title('Task/Natural ratio across test angles')
        ax_ratio.set_xlim(0, 180)
        ax_ratio.grid(False)
        ax_ratio.set_xticks([0, 30, 60, 90, 120, 150, 180])
        # Avoid an empty legend (and matplotlib's warning) when no curve was drawn.
        if ax_ratio.get_legend_handles_labels()[0]:
            ax_ratio.legend()
        plt.savefig(outfile, format='pdf', bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving raised.
        plt.close(fig)
    return outfile



