#!/usr/bin/env python3
"""
Core logic extracted from response_profile.py for aggregated Laplace fitting and testing.

Assumptions simplified per request:
- STIMULUS_CONTRAST is always None
- LAPLACE_ANGLE_COMBINATIONS is always None
- COMPONENT_WEIGHTS is always None
- NATURAL_PRIOR_WEIGHT is always None

This module exposes focused, reusable functions without plotting side-effects.
"""

import os
import time
import pickle
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch

# Preserve import path logic used by scripts
import sys
from os.path import dirname, join
sys.path.append(join(dirname(__file__), '..', '..'))

from task_vae.utils import (
    generate_gratings,
    load_model,
    calculate_avg_abs_z_by_angle,
    find_most_responsive_angles,
    response_by_orientation,
    get_constrained_images,
)


def save_laplace_distribution(laplace_dist: torch.distributions.Laplace, filename: str) -> None:
    """Pickle the parameters of a Laplace distribution to ``filename``.

    Only the ``loc`` and ``scale`` tensors are written. Each is detached from
    autograd, moved to the CPU and converted to a NumPy array, then stored in
    a dict under the keys ``'loc'`` and ``'scale'``.

    Args:
        laplace_dist: Distribution whose parameters are saved.
        filename: Destination path; opened in binary write mode.
    """
    # Build the full payload before touching the file so a conversion error
    # never leaves a partially written pickle behind.
    payload = {
        'loc': laplace_dist.loc.detach().cpu().numpy(),
        'scale': laplace_dist.scale.detach().cpu().numpy(),
    }
    with open(filename, 'wb') as handle:
        pickle.dump(payload, handle)


def load_laplace_distribution(filename: str, device: torch.device) -> torch.distributions.Laplace:
    """Rebuild a Laplace distribution from a pickle written to ``filename``.

    The pickle is expected to hold a mapping with ``'loc'`` and ``'scale'``
    entries. Both are converted to float32 tensors and moved to ``device``.

    Args:
        filename: Path of the pickle file; opened in binary read mode.
        device: Device on which the distribution parameters are placed.

    Returns:
        A ``torch.distributions.Laplace`` built from the stored parameters.
    """
    # NOTE: pickle.load executes arbitrary code from the file; only use this
    # on fit files produced by a trusted source.
    with open(filename, 'rb') as handle:
        stored = pickle.load(handle)

    def _as_float_tensor(values):
        return torch.tensor(values, dtype=torch.float32).to(device)

    loc_tensor = _as_float_tensor(stored['loc'])
    scale_tensor = _as_float_tensor(stored['scale'])
    return torch.distributions.Laplace(loc=loc_tensor, scale=scale_tensor)


def fit_aggregated_laplace(
    data: Dict,
    model: torch.nn.Module,
    most_responsive_angles: np.ndarray,
    stimulus_angles: List[int],
    spatial_frequency: int,
    prior_contrast: float,
    prior_moment: str,
    num_iterations: int,
    orientation_bin_size: int,
    fits_dir: str,
    load_existing_fits: bool,
) -> torch.distributions.Laplace:
    """
    Fit or load an aggregated Laplace prior using the simplified flow.

    The fit pools the images returned by ``get_constrained_images`` for every
    angle in ``stimulus_angles``, starts from a zero-loc / unit-scale Laplace,
    and repeatedly feeds the Laplace returned by ``response_by_orientation``
    back in as the prior for ``num_iterations`` rounds. The final distribution
    is pickled into ``fits_dir``.

    Args:
        data: Dataset passed through to ``get_constrained_images``.
        model: Model whose parameters determine the device and which is
            called once to read the latent dimensionality.
        most_responsive_angles: Passed through to ``response_by_orientation``.
        stimulus_angles: Angles whose images are pooled for the fit; also
            part of the cache file name.
        spatial_frequency: Image selection constraint; part of the file name.
        prior_contrast: Image selection constraint; part of the file name.
        prior_moment: Forwarded as ``moment`` to ``response_by_orientation``.
        num_iterations: Number of prior-feedback rounds after the first pass.
        orientation_bin_size: Forwarded as ``bin_size``.
        fits_dir: Directory for the cached fit; created if it does not exist.
        load_existing_fits: When true and the cache file exists, the stored
            fit is returned without collecting images or running the model.

    Returns:
        The fitted (or loaded) Laplace distribution.

    Raises:
        ValueError: If no stimulus images are found for the requested angles.
        RuntimeError: If ``response_by_orientation`` never returned an output
            Laplace, so there is nothing to save.

    Note:
        The cache file name encodes only the angles, spatial frequency and
        contrast. A cached fit produced with a different ``prior_moment``,
        ``num_iterations`` or ``orientation_bin_size`` is reused as-is.
    """
    print(f"Fitting aggregated Laplace for angles {stimulus_angles}")

    angles_str = "_".join(str(a) for a in stimulus_angles)
    laplace_final_filename = f"laplace_aggregated_{angles_str}_sf{spatial_frequency}_contrast{prior_contrast}_final.pkl"
    laplace_final_filepath = os.path.join(fits_dir, laplace_final_filename)
    device = next(model.parameters()).device

    # Check the cache before collecting images: a cache hit needs none of them.
    if load_existing_fits and os.path.exists(laplace_final_filepath):
        print(f"Loading existing Laplace fit from: {laplace_final_filepath}")
        return load_laplace_distribution(laplace_final_filepath, device)

    all_stimulus_images = []
    for stimulus_angle in stimulus_angles:
        stimulus_images = get_constrained_images(
            data, angle=stimulus_angle,
            spatial_frequency=spatial_frequency,
            contrast=prior_contrast,
        )
        print(f"  Found {len(stimulus_images)} images for angle {stimulus_angle}°")
        all_stimulus_images.extend(stimulus_images)
    all_stimulus_images = np.array(all_stimulus_images)
    print(f"Total stimulus images: {len(all_stimulus_images)}")

    if len(all_stimulus_images) == 0:
        # Fail early with a clear message instead of an opaque error from the
        # model being called on an empty batch.
        raise ValueError(
            f"No stimulus images found for angles {stimulus_angles} "
            f"(spatial_frequency={spatial_frequency}, contrast={prior_contrast})"
        )

    # Arguments shared by every response_by_orientation call below.
    response_kwargs = dict(
        absolute=True,
        one_layer=True,
        return_output_laplace=True,
        moment=prior_moment,
        bin_size=orientation_bin_size,
    )

    print("Computing initial response (no prior)...")
    # The result is discarded; the call is kept from the original flow.
    # NOTE(review): confirm whether this pass is needed for side effects.
    _ = response_by_orientation(
        all_stimulus_images, model, most_responsive_angles,
        prior_image=None,
        **response_kwargs,
    )

    print("Initializing Laplace distribution...")
    # Initialize Laplace prior (zero mean, unit scale). The forward pass is
    # only used to read the latent size, so no autograd graph is needed.
    with torch.no_grad():
        sample_output = model(torch.tensor(all_stimulus_images[:1], dtype=torch.float32).to(device))
    # Assumes the model output exposes the latent distribution at
    # [1]["z"][1] -- TODO confirm against the model definition.
    z_dim = sample_output[1]["z"][1].loc.shape[-1]
    loc = torch.zeros(z_dim, device=device)
    scale = torch.ones(z_dim, device=device)
    laplace_dist = torch.distributions.Laplace(loc=loc, scale=scale)

    print("Computing response with initial prior...")
    # First pass with prior to get output laplace
    response_with_prior_result = response_by_orientation(
        all_stimulus_images, model, most_responsive_angles,
        prior_image=laplace_dist,
        **response_kwargs,
    )
    if isinstance(response_with_prior_result, tuple):
        current_laplace = response_with_prior_result[1]
    else:
        current_laplace = None

    print(f"Performing {num_iterations} iterations of Laplace fusion...")
    for iteration in range(1, num_iterations + 1):
        print(f"  Iteration {iteration}/{num_iterations}")
        response_iter_result = response_by_orientation(
            all_stimulus_images, model, most_responsive_angles,
            prior_image=current_laplace,
            **response_kwargs,
        )
        # A non-tuple result carries no output Laplace; keep the previous one.
        if isinstance(response_iter_result, tuple):
            current_laplace = response_iter_result[1]

    if current_laplace is None:
        raise RuntimeError(
            "response_by_orientation did not return an output Laplace; "
            "nothing to save"
        )

    print(f"Saving final Laplace to: {laplace_final_filepath}")
    # Make sure the target directory exists so the finished fit is not lost.
    if fits_dir:
        os.makedirs(fits_dir, exist_ok=True)
    save_laplace_distribution(current_laplace, laplace_final_filepath)
    return current_laplace


def _split_mean_and_std(result):
    """Split a ``response_by_orientation`` result into ``(mean, std)``.

    A tuple result yields its first element as the mean and its second, when
    present, as the std. Any non-tuple result is treated as the mean alone.
    """
    if not isinstance(result, tuple):
        return result, None
    mean = result[0]
    std = result[1] if len(result) > 1 else None
    return mean, std


def compute_angle_test(
    data: Dict,
    model: torch.nn.Module,
    most_responsive_angles: np.ndarray,
    test_angle: int,
    spatial_frequency: int,
    test_contrast: float,
    test_moment: str,
    test_threshold: float,
    bin_size: int,
    aggregated_laplace: torch.distributions.Laplace,
) -> Optional[Tuple[Dict[int, float], Optional[Dict[int, float]], Dict[str, Dict[int, float]], Dict[str, Dict[int, float]]]]:
    """
    Compute naive and task responses for a single test angle, returning mean and std series.
    NATURAL_PRIOR_WEIGHT is assumed None, so not used.

    Args:
        data: Dataset passed through to ``get_constrained_images``.
        model: Model forwarded to ``response_by_orientation``.
        most_responsive_angles: Forwarded to ``response_by_orientation``.
        test_angle: Angle whose images are tested.
        spatial_frequency: Image selection constraint.
        test_contrast: Image selection constraint.
        test_moment: Forwarded as ``moment``.
        test_threshold: Forwarded as ``threshold``.
        bin_size: Forwarded as ``bin_size``.
        aggregated_laplace: Prior used for the "task" response.

    Returns:
        ``None`` when no images exist for ``test_angle``. Otherwise a 4-tuple
        ``(no_prior_mean, no_prior_std, with_priors, with_priors_std)``:
        the two dicts are keyed by ``"aggregated_final"``, and
        ``with_priors_std`` is empty when no std was returned for the task
        response.
    """
    print(f"Testing angle {test_angle}°...")

    test_images = get_constrained_images(
        data, angle=test_angle,
        spatial_frequency=spatial_frequency,
        contrast=test_contrast,
    )
    print(f"  Found {len(test_images)} test images")
    # len() rather than truthiness: the images may be a NumPy array.
    if len(test_images) == 0:
        print(f"  No images found for test angle {test_angle}°, skipping...")
        return None

    # Both passes differ only in the prior they are given.
    shared_kwargs = dict(
        absolute=True,
        one_layer=True,
        return_output_laplace=False,
        moment=test_moment,
        threshold=test_threshold,
        natural_prior_weight=None,
        bin_size=bin_size,
        return_variability=True,
    )

    print("  Computing response without prior...")
    test_response_no_prior, test_response_no_prior_std = _split_mean_and_std(
        response_by_orientation(
            test_images, model, most_responsive_angles,
            prior_image=None,
            **shared_kwargs,
        )
    )

    print("  Computing response with task prior...")
    test_response_task, test_response_task_std = _split_mean_and_std(
        response_by_orientation(
            test_images, model, most_responsive_angles,
            prior_image=aggregated_laplace,
            **shared_kwargs,
        )
    )

    with_priors = {"aggregated_final": test_response_task}
    with_priors_std = {"aggregated_final": test_response_task_std} if test_response_task_std is not None else {}
    print(f"  Completed test for angle {test_angle}°")
    return (test_response_no_prior, test_response_no_prior_std, with_priors, with_priors_std)


def prepare_data_and_model(
    spatial_frequency: int,
    prior_contrast: float,
    test_moment: str,
    test_threshold: float,
    n_cores: int = 4,
) -> Tuple[Dict, torch.nn.Module, np.ndarray]:
    """
    Generate gratings, load model, compute most responsive angles - streamlined version.

    Args:
        spatial_frequency: Forwarded to ``calculate_avg_abs_z_by_angle``.
        prior_contrast: Forwarded as ``contrast`` for the tuning computation.
        test_moment: Forwarded as ``moment`` for the tuning computation.
        test_threshold: Forwarded as ``threshold`` for the tuning computation.
        n_cores: Forwarded to ``generate_gratings``. Defaults to 4, the value
            this function previously hard-coded.

    Returns:
        ``(data, model, most_responsive_angles)``. The second value returned
        by ``find_most_responsive_angles`` is discarded.
    """
    print("Generating gratings data...")
    data = generate_gratings(
        pca_model=None,
        hist_match=False,
        whiten=False,
        train_fraction=1,
        shuffle=True,
        noise_std=None,
        spherical_mask=False,
        n_cores=n_cores,
    )
    print(f"Generated {len(data['train_images'])} training images")

    print("Loading model...")
    model = load_model()
    print("Model loaded successfully")

    print("Computing tuning curves (simplified)...")
    # The full dataset is passed in; no subsetting happens in this function.
    avg_abs_z_by_angle = calculate_avg_abs_z_by_angle(
        data, model,
        spatial_frequency=spatial_frequency,
        contrast=prior_contrast,
        abs=True,
        moment=test_moment,
        threshold=test_threshold,
    )
    print(f"Calculated tuning for {len(avg_abs_z_by_angle)} angles")

    most_responsive_angles, _ = find_most_responsive_angles(
        avg_abs_z_by_angle,
        method='von_mises',
        return_max_values=True,
        filter_by_max_values=False,
        max_values_percentile=90,
    )
    # Entries that are NaN are not counted as responsive.
    num_responsive = np.sum(~np.isnan(most_responsive_angles))
    print(f"Found {num_responsive} responsive angles")

    return data, model, most_responsive_angles



