import torch
import torch.nn.functional as F
import metrics
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import sklearn
import os
import random
import json
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau, MultiplicativeLR

from metrics import (r2_score, regression_r2_score, pseudo_r2_score,
                     average_r2_score, apply_metric)
from plot import plot_prediction, plot_lograte, plot_units
from data import to_causal_segments
from training.utils import merge_windows
from utils.ic_solver import solve_ic_monotonic as solve_ic_direct, solve_ic_mlp, solve_ic_rollout


def get_module_weights(module):
    """
    Extract the weight tensor(s) from a PyTorch module.

    Lookup order:
      1. ``module.weight`` (e.g. ``nn.Linear``) is returned as-is.
      2. ``module.log_weight.weight`` for modules that expose a
         ``log_weight`` sub-layer instead of a direct ``weight``.
      3. ``nn.ModuleDict`` (used for separate readouts): the weights of
         every contained module are gathered into a single flat list.
         Nested ``ModuleDict`` containers are flattened so callers can
         always iterate over tensors directly.

    Args:
        module: The module (or ``nn.ModuleDict`` of modules) to inspect.

    Returns:
        A single weight tensor, or a flat list of weight tensors when
        ``module`` is an ``nn.ModuleDict``.

    Raises:
        ValueError: If no weights can be located on ``module`` (or on any
            module contained in a ``ModuleDict``).
    """
    # Direct weight attribute (nn.Linear or similar layers).
    if hasattr(module, 'weight'):
        return module.weight

    # Layers that keep their parameters in a `log_weight` sub-module.
    if hasattr(module, 'log_weight'):
        return module.log_weight.weight

    # Container of several readouts: collect every child's weights.
    if isinstance(module, torch.nn.ModuleDict):
        collected = []
        for child in module.values():
            child_weights = get_module_weights(child)
            # A nested ModuleDict yields a list; splice it in so the
            # result stays a flat list of tensors.
            if isinstance(child_weights, list):
                collected.extend(child_weights)
            else:
                collected.append(child_weights)
        return collected

    raise ValueError(
        "Readout weights could not be extracted, "
        "cannot apply L1 regularization."
    )


def train(
        model, optimizer, train_dataloader, checkpoint_path, config,
        log=None, device=None, val_dataloader=None,
        wandb_run=None, **kwargs):
    # Move model to the specified device if it's not already there
    model = model.to(device)
    model.train()

    epochs = config['epochs']
    # Define z0 logging milestones: after 100 epochs (if within range), halfway, and end
    z0_milestones = set()
    if epochs > 0:
        z0_milestones.add(max(0, epochs - 1))  # end
        z0_milestones.add(max(0, epochs // 2)) # halfway
        if epochs > 100:
            z0_milestones.add(100)

    # Define learning rate scheduler
    if config['lr_scheduler']:
        lr_scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.1,
                                         patience=5, threshold=1e-5,
                                         min_lr=1e-4)

    all_epoch_losses = []
    all_epoch_scores = []
    all_epoch_old_r2 = []
    all_epoch_old_pseudo_r2 = []
    all_epoch_old_regression_r2 = []
    all_epoch_old_average_r2 = []

    # Add validation metrics tracking
    all_val_epoch_scores = []
    best_avg_r2_score = -float('inf')
    best_epoch = -1

    # --- Define Relative Paths and Ensure Directories Exist ---
    base_dir = os.path.dirname(checkpoint_path)
    filename = os.path.basename(checkpoint_path)
    name, ext = os.path.splitext(filename)
    final_checkpoint_path = os.path.join(base_dir, f"{name}_final{ext}")

    # Use paths from config
    heatmap_dir = os.path.join(config['BASE_PATH'], 'heatmaps')
    log_dir = config['LOG_PATH']
    metrics_dir = config['METRICS_PATH']

    epochs_per_group = config['epochs_per_group']
    points_per_group = config['points_per_group']

    # Create all necessary directories
    os.makedirs(base_dir, exist_ok=True)
    os.makedirs(heatmap_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(metrics_dir, exist_ok=True)

    # Construct file paths
    best_heatmap_path = os.path.join(heatmap_dir, 'best_heatmap.png')
    final_heatmap_path = os.path.join(heatmap_dir, 'final_heatmap.png')

    # Ensure log is a string and construct log file path
    if log is None:
        # Descriptive default: dataset, window/step, causal flag, dynamics/readout, latent size, seed
        log = (
            f"{config.get('data_type', 'data')}"
            f"_W{config.get('window_size', 'T')}"
            f"_S{config.get('step_size', 'S')}"
            f"_{'causal' if config.get('causal_model', False) else 'acausal'}"
            f"_{config.get('dynamics_model_type', 'dyn')}"
            f"_{config.get('readout_type', 'readout')}"
            f"_L{config.get('latent_size', 'L')}"
            f"_seed{config.get('seed', 0)}"
        )


    elif isinstance(log, str) and log.startswith(config['LOG_PATH']):
        # If log already contains the full path, extract just the filename
        log = os.path.basename(log)

    log_file_path = os.path.join(log_dir, f'{log}.json')
    metrics_file_path = os.path.join(metrics_dir, 'training_metrics.json')
    training_curves_path = os.path.join(metrics_dir, 'training_curves.png')
    training_curves_focused_path = os.path.join(metrics_dir, 'training_curves_focused.png')
    # ---------------------------------------------------------

    for epoch in range(epochs):
        # Initialize last_batch_loss for this epoch
        last_batch_loss = 0.0
        # Whether to calculate training/validation metrics
        compute_metrics = (epoch % 10 == 0 or epoch == epochs - 1)
        # Create empty list to store results
        all_batch_pred_rates = []
        all_batch_decoding_spikes = []

        for batch in train_dataloader:
            #NOTE: Talk to clay about whether this was added for inputs
            # # New windowed batches: (encoding_spikes, decoding_spikes)
            # if isinstance(batch, (list, tuple)) and len(batch) == 2:
            #     encoding_spikes, decoding_spikes = batch
            # else:
            #     encoding_spikes = batch[0]
            #     decoding_spikes = batch[0]

            # Get encoding and decoding windows from batch
            encoding_spikes, decoding_spikes = batch

            # Move batch data to device
            encoding_spikes = encoding_spikes.to(device)
            decoding_spikes = decoding_spikes.to(device)

            optimizer.zero_grad()
            
            # Forward pass using encoding window
            _, pred_logrates, _, _, _, _ = model(encoding_spikes)
            pred_rates = torch.exp(pred_logrates)

            loss_all = F.poisson_nll_loss(
                pred_logrates, decoding_spikes, full=True, reduction="none"
            )

            # Add results to running lists
            all_batch_pred_rates.append(pred_rates.detach())
            all_batch_decoding_spikes.append(decoding_spikes.detach())

            # Incrementally consider more points in the loss
            total_points = loss_all.shape[1]

            group_number = int(epoch / epochs_per_group) + 1
            num_points = min(group_number * points_per_group, total_points)

            # Compute weighted loss
            loss = torch.mean(loss_all[:, :num_points, :])

            ##### L1 Regularization On Readout ########
            if config['l1_reg'] is not None:
                # Get readout weights
                readout_weights = get_module_weights(model.readout)
                # Concatenate weights if they are a list
                if isinstance(readout_weights, list):
                    readout_weights = torch.cat(
                        [w.view(-1) for w in readout_weights]
                    )
                # Calculate L1 regularization term and add to loss
                l1_reg = torch.sum(
                    torch.abs(readout_weights)
                )
                loss += config['l1_reg'] * l1_reg

            # Encourage orthogonality of readout weights
            if config['orth_reg'] is not None:
                # Get readout weights
                readout_weights = get_module_weights(model.readout)
                # If multiple readouts, sum their orthogonality losses
                if isinstance(readout_weights, list):
                    latents_orth_loss = 0.0
                    for rw in readout_weights:
                        latents_orth_loss += torch.norm(
                            torch.eye(rw.size(1), device=device) -
                            torch.matmul(rw.T, rw)
                        )
                else:
                    # Calculate orthogonality loss
                    latents_orth_loss = torch.norm(
                        torch.eye(readout_weights.size(1), device=device) -
                        torch.matmul(readout_weights.T, readout_weights)
                    )
                # Add orthogonality loss to the total loss
                loss += config['orth_reg'] * latents_orth_loss

            # Backpropagation
            loss.backward()
            # Clip gradients to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            # Update model parameters
            optimizer.step()

            # save the loss from the last batch
            last_batch_loss += loss.item()

        # Calculate average loss
        avg_train_loss = last_batch_loss / len(train_dataloader)
        # Run validation
        _, _, _, _, val_metrics = validate(
            model, val_dataloader, config=config, device=device,
            calculate_metrics=compute_metrics, verbose=False
        )
        avg_val_loss = val_metrics['val_loss']
        # Set model back to training mode
        model.train()
        
        if config['lr_scheduler']:
            # Step the scheduler based on validation loss
            lr_scheduler.step(avg_val_loss)
        
        # Log losses to wandb
        if wandb_run:
            wandb_run.log({
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss,
                'epoch': epoch,
                'learning_rate': optimizer.param_groups[0]['lr']
            })

        # Calculate metrics only for every 10th epoch (or the last epoch)
        if compute_metrics:
            # Log validation metrics to wandb
            if wandb_run:
                # Add epoch and remove val loss
                val_metrics['epoch'] = epoch
                val_metrics.pop('val_loss', None)
                wandb_run.log(val_metrics)

            print(f"Epoch {epoch}: Train Loss = {avg_train_loss:.4f}, "
                  f"Val Loss = {avg_val_loss:.4f}")

            if config['l1_reg'] is not None:
                print(f"{config['l1_reg']} L1 Regularization applied")
            
            # Calculate comprehensive R² scores
            with torch.no_grad():
                # Concatentate training results
                y_pred = torch.cat(all_batch_pred_rates, dim=0).cpu()
                y_true = torch.cat(all_batch_decoding_spikes, dim=0).cpu()

                # Calculate metrics using consistent approach w/ validate func
                score = {}

                # Primary metrics using custom funcs (consistent with validate)
                score['r2'] = apply_metric(r2_score, y_pred, y_true)

                # Population-specific custom metrics
                population_scores_filtered = {}
                for name, (start, end) in config.get('neurons', {}).items():
                    pred_slice = y_pred[:, :, start:end].cpu()
                    spike_slice = y_true[:, :, start:end].cpu()

                    # R2 Score
                    score[f'r2_{name}'] = apply_metric(
                        r2_score, pred_slice, spike_slice
                    )
                    # Pseudo R2 Score
                    score[f'pseudo_r2_{name}'] = apply_metric(
                        pseudo_r2_score, pred_slice, spike_slice
                    )
                    # Average R2 Score
                    score[f'average_r2_{name}'] = apply_metric(
                        average_r2_score, pred_slice, spike_slice
                    )

                    # Store population scores for structured output
                    population_scores_filtered[name] = {
                        'r2': score[f"r2_{name}"],
                        'pseudo_r2': score[f"pseudo_r2_{name}"],
                        'average_r2': score[f"average_r2_{name}"]
                    }

                if config['data_type'] != 'wang_100T':
                    # Prediction horizon metrics
                    for horizon in range(min(y_pred.shape[1], 3)):
                        # Get Nth prediction and target
                        pred_horizon = y_pred[:, horizon, :].squeeze()
                        tgt_horizon = y_true[:, horizon, :].squeeze()
                    # Calculate R2 Scores
                    score[f'r2_horizon_{horizon}'] = \
                        apply_metric(r2_score,
                                    pred_horizon,
                                    tgt_horizon)
                    score[f'pseudo_r2_horizon_{horizon}'] = \
                        apply_metric(pseudo_r2_score,
                                    pred_horizon,
                                    tgt_horizon)

                # Calculate summary metrics using both approaches
                population_r2_values = [score.get(f'r2_{name}', float('nan')) for name in config.get('neurons', {}).keys()]
                valid_r2_values = [r for r in population_r2_values if not np.isnan(r)]
                simple_avg_r2 = np.mean(valid_r2_values) if valid_r2_values else float('nan')
                simple_median_r2 = np.median(valid_r2_values) if valid_r2_values else float('nan')

                # Calculate overall metrics using the same functions as metrics.py
                overall_r2_custom = metrics.r2_score(y_pred.cpu(), y_true.cpu()).item()
                overall_pseudo_r2_custom = metrics.pseudo_r2_score(y_pred.cpu(), y_true.cpu()).item()
                overall_average_r2_custom = metrics.average_r2_score(y_pred.cpu(), y_true.cpu()).item()

                # Print metrics
                print(f"Training Metrics:")
                print(f"  Overall R² (median): {overall_r2_custom:.4f} | Simple Pop Avg: {simple_avg_r2:.4f} | Simple Pop Median: {simple_median_r2:.4f}")
                print(f"  Overall Pseudo R² (median): {overall_pseudo_r2_custom:.4f}")
                print(f"  Overall Average R² (trial-averaged): {overall_average_r2_custom:.4f}")
                for name in config.get('neurons', {}).keys():
                    print(f"  {name} - R²: {score.get(f'r2_{name}', float('nan')):.4f}, "
                          f"Pseudo R²: {score.get(f'pseudo_r2_{name}', float('nan')):.4f}, "
                          f"Avg R²: {score.get(f'average_r2_{name}', float('nan')):.4f}")

                # Print Validation Metrics
                print(f"Validation Metrics (Average):")
                for name in config.get('neurons', {}).keys():
                    print(f"  {name} - R²: {val_metrics.get(f'val_r2_{name}', float('nan')):.4f}, "
                          f"Pseudo R²: {val_metrics.get(f'val_pseudo_r2_{name}', float('nan')):.4f}, "
                          f"Avg R²: {val_metrics.get(f'val_average_r2_{name}', float('nan')):.4f}")

                # Store the best score for model selection (use custom metrics for consistency)
                current_r2_score_for_best_model = score["r2"] if not np.isnan(score["r2"]) else -float('inf')

                # Update epoch score metrics for tracking (consistent structure)
                epoch_metrics = {
                    'overall': {
                        'r2': score["r2"],
                        'r2_median': overall_r2_custom,
                        'pseudo_r2_median': overall_pseudo_r2_custom,
                        'average_r2_trial_avg': overall_average_r2_custom,
                        'simple_pop_avg': simple_avg_r2,
                        'simple_pop_median': simple_median_r2
                    },
                    'populations': population_scores_filtered,
                }

                # Weights and Biases logging
                if wandb_run:
                    wandb_run.log({
                        'epoch': epoch,
                        'train_overall_r2': score["r2"],
                        'train_overall_pseudo_r2': overall_pseudo_r2_custom,
                        'train_overall_average_r2': overall_average_r2_custom,
                        **{f'train_r2_horizon_{horizon}': score.get(f'r2_horizon_{horizon}', float('nan')) for horizon in range(min(y_pred.shape[1], 3))},
                        **{f'train_pseudo_r2_horizon_{horizon}': score.get(f'pseudo_r2_horizon_{horizon}', float('nan')) for horizon in range(min(y_pred.shape[1], 3))},
                        **{f'train_{name}_r2': score.get(f'r2_{name}', float('nan')) for name in config.get('neurons', {}).keys()},
                        **{f'train_{name}_pseudo_r2': score.get(f'pseudo_r2_{name}', float('nan')) for name in config.get('neurons', {}).keys()},
                        **{f'train_{name}_average_r2': score.get(f'average_r2_{name}', float('nan')) for name in config.get('neurons', {}).keys()}
                    })

                # Check if this is the best model so far
                if current_r2_score_for_best_model > best_avg_r2_score:
                    best_avg_r2_score = current_r2_score_for_best_model
                    best_epoch = epoch
                    # Save best model checkpoint
                    torch.save({
                        'weight': model.state_dict(),
                        'kwargs': kwargs,
                        'config': config,  # Save the model configuration
                        'epoch': epoch,
                        'avg_r2_score': best_avg_r2_score
                    }, checkpoint_path)

                    # Save the corresponding best heatmap
                    if config['heatmap']:
                        try:
                            if model.readout_type == 'monotonic':
                                plt.figure()  # Create a new figure
                                sns.heatmap(model.readout.log_weight.weight.detach().cpu().numpy(), annot=False, cmap='coolwarm', fmt=".2f", cbar=True)
                                plt.title(f"Best Heatmap (Epoch {epoch}, R2: {best_avg_r2_score:.4f})")
                                plt.savefig(best_heatmap_path)
                                plt.close()  # Close the figure
                        except Exception as e:
                            print(f"Warning: Could not save best heatmap for epoch {epoch}: {e}")

            # Log z0 (encoder-estimated initial conditions) at milestones without re-solving
            if epoch in z0_milestones and val_dataloader is not None:
                try:
                    print(f"Logging z0 snapshots at epoch {epoch}...")
                    model.eval()
                    z0_all = []
                    with torch.no_grad():
                        for bout in val_dataloader:
                            if config['causal_model']:
                                encoding_spikes, _ = to_causal_segments(bout, config['ic_window_size'])
                            else:
                                encoding_spikes = bout
                            encoding_spikes = encoding_spikes.to(device)
                            ic_drop, _, _, _, _, _ = model(encoding_spikes)
                            # ic_drop shape: (B, latent_size)
                            z0_all.append(ic_drop.detach().cpu())
                    if len(z0_all) > 0:
                        z0_cat = torch.cat(z0_all, dim=0).numpy()
                        z0_dir = os.path.join(metrics_dir, 'z0')
                        os.makedirs(z0_dir, exist_ok=True)
                        z0_path = os.path.join(z0_dir, f'z0_epoch_{epoch}.npy')
                        np.save(z0_path, z0_cat)
                        print(f"Saved z0 snapshot to {z0_path} (shape {z0_cat.shape})")
                        if wandb_run is not None:
                            try:
                                wandb_run.log({
                                    'epoch': epoch,
                                    'z0_mean': float(np.mean(z0_cat)),
                                    'z0_std': float(np.std(z0_cat))
                                })
                            except Exception:
                                pass
                    model.train()
                except Exception as e:
                    print(f"Warning: Failed to log z0 at epoch {epoch}: {e}")

            # Add the detailed epoch metrics to history
            all_epoch_scores.append(epoch_metrics)  # Changed from avg_epoch_score

            # Store custom metrics for plotting (consistent with validate function)
            all_epoch_old_r2.append(score["r2"])
            all_epoch_old_pseudo_r2.append(score.get("pseudo_r2_E1", float('nan')))  # Use first population as example
            all_epoch_old_regression_r2.append(float('nan'))  # Not calculated in new approach
            all_epoch_old_average_r2.append(score.get("average_r2_E1", float('nan')))  # Use first population as example

    # --- Save Final Model and Heatmap --- #
    print(f"Saving final model to {final_checkpoint_path} and final heatmap to {final_heatmap_path}.")

    # Get the last valid R² score from all_epoch_scores (consistent with new structure)
    final_r2_score = 0.0
    if all_epoch_scores:
        last_metrics = all_epoch_scores[-1]
        if last_metrics and last_metrics.get('overall') and last_metrics['overall'].get('r2') is not None:
            final_r2_score = last_metrics['overall']['r2']
        elif not np.isnan(last_metrics.get('overall', {}).get('r2', 0.0)):
            final_r2_score = last_metrics.get('overall', {}).get('r2', 0.0)

    torch.save({
        'weight': model.state_dict(),
        'kwargs': kwargs,
        'config': config,  # Save the model configuration
        'epoch': epochs - 1,  # Last epoch index
        'avg_r2_score': final_r2_score  # Score from the last calculation
    }, final_checkpoint_path)

    if config['heatmap']:
        try:
            if model.readout_type == 'monotonic':
                plt.figure()  # Create a new figure
                sns.heatmap(model.readout.log_weight.weight.detach().cpu().numpy(), annot=False, cmap='coolwarm', fmt=".2f", cbar=True)
                plt.title(f"Final Heatmap (Epoch {epochs-1}, R2: {final_r2_score:.4f})")
                plt.savefig(final_heatmap_path)
                plt.close()  # Close the figure
        except Exception as e:
            print(f"Warning: Could not save final heatmap: {e}")

    # --- Plot Training and Validation Curves --- #
    try:
        recorded_epochs = [e for e in range(epochs) if e % 10 == 0 or e == epochs - 1]

        # Extract data for plotting
        losses_to_plot = all_epoch_losses
        training_overall_r2_to_plot = [
            s['overall']['r2'] for s in all_epoch_scores
            if s and s.get('overall') and s['overall'].get('r2') is not None and not np.isnan(s['overall']['r2'])
        ]

        # Get validation R2 if available
        validation_overall_r2_to_plot = []
        if val_dataloader is not None and all_val_epoch_scores:
            validation_overall_r2_to_plot = [
                s['r2'] for s in all_val_epoch_scores
                if s and s.get('r2') is not None and not np.isnan(s['r2'])
            ]

        # Get old metrics (custom metrics from metrics.py)
        old_r2_filtered = [r for r in all_epoch_old_r2 if not np.isnan(r)]
        old_pseudo_r2_filtered = [r for r in all_epoch_old_pseudo_r2 if not np.isnan(r)]
        old_regression_r2_filtered = [r for r in all_epoch_old_regression_r2 if not np.isnan(r)]
        old_average_r2_filtered = [r for r in all_epoch_old_average_r2 if not np.isnan(r)]

        plt.figure(figsize=(20, 12))

        # 1. Training Loss
        plt.subplot(2, 4, 1)
        plt.plot(recorded_epochs[:len(losses_to_plot)], losses_to_plot, 'b-', label='Training Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training Loss')
        plt.legend()
        plt.grid(True)

        # 2. Overall R² Comparison (Training vs Validation)
        plt.subplot(2, 4, 2)
        if training_overall_r2_to_plot:
            plt.plot(recorded_epochs[:len(training_overall_r2_to_plot)], training_overall_r2_to_plot, 'b-', label='Train Custom R²')
        if old_r2_filtered:
            plt.plot(recorded_epochs[:len(old_r2_filtered)], old_r2_filtered, 'r--', label='Custom R²')
        if validation_overall_r2_to_plot:
            plt.plot(recorded_epochs[:len(validation_overall_r2_to_plot)], validation_overall_r2_to_plot, 'm-', label='Val Custom R²')
        plt.xlabel('Epoch')
        plt.ylabel('R² Score')
        plt.title('Overall R² Comparison')
        plt.legend()
        plt.grid(True)

        # 3. Custom Overall Metrics from metrics.py
        plt.subplot(2, 4, 3)

        # Extract overall custom metrics from epoch_metrics
        overall_r2_median_plot = [
            s['overall']['r2_median'] for s in all_epoch_scores
            if s and s.get('overall') and s['overall'].get('r2_median') is not None and not np.isnan(s['overall']['r2_median'])
        ]
        overall_pseudo_r2_median_plot = [
            s['overall']['pseudo_r2_median'] for s in all_epoch_scores
            if s and s.get('overall') and s['overall'].get('pseudo_r2_median') is not None and not np.isnan(s['overall']['pseudo_r2_median'])
        ]
        overall_average_r2_trial_avg_plot = [
            s['overall']['average_r2_trial_avg'] for s in all_epoch_scores
            if s and s.get('overall') and s['overall'].get('average_r2_trial_avg') is not None and not np.isnan(s['overall']['average_r2_trial_avg'])
        ]
        simple_pop_avg_plot = [
            s['overall']['simple_pop_avg'] for s in all_epoch_scores
            if s and s.get('overall') and s['overall'].get('simple_pop_avg') is not None and not np.isnan(s['overall']['simple_pop_avg'])
        ]
        simple_pop_median_plot = [
            s['overall']['simple_pop_median'] for s in all_epoch_scores
            if s and s.get('overall') and s['overall'].get('simple_pop_median') is not None and not np.isnan(s['overall']['simple_pop_median'])
        ]

        if overall_r2_median_plot:
            plt.plot(recorded_epochs[:len(overall_r2_median_plot)], overall_r2_median_plot, 'r-', label='R² (median)', linewidth=2)
        if overall_pseudo_r2_median_plot:
            plt.plot(recorded_epochs[:len(overall_pseudo_r2_median_plot)], overall_pseudo_r2_median_plot, 'b-', label='Pseudo R² (median)', linewidth=2)
        if overall_average_r2_trial_avg_plot:
            plt.plot(recorded_epochs[:len(overall_average_r2_trial_avg_plot)], overall_average_r2_trial_avg_plot, 'g-', label='Avg R² (trial-avg)', linewidth=2)
        if simple_pop_avg_plot:
            plt.plot(recorded_epochs[:len(simple_pop_avg_plot)], simple_pop_avg_plot, 'm--', label='Simple Pop Avg', linewidth=1)
        if simple_pop_median_plot:
            plt.plot(recorded_epochs[:len(simple_pop_median_plot)], simple_pop_median_plot, 'c--', label='Simple Pop Median', linewidth=1)

        plt.xlabel('Epoch')
        plt.ylabel('Score')
        plt.title('Overall Custom Metrics (from metrics.py)')
        plt.legend(fontsize=8)
        plt.grid(True)

        # 4-7. Population-specific R² scores (Custom Metrics)
        neuron_populations = list(config['neurons'].keys())
        for idx, pop_name in enumerate(neuron_populations[:4]):  # Plot up to 4 populations
            plt.subplot(2, 4, 4 + idx)

            # Extract population-specific Custom scores
            pop_r2_scores = []
            pop_pseudo_r2_scores = []
            pop_average_r2_scores = []

            for epoch_metrics in all_epoch_scores:
                if epoch_metrics and epoch_metrics.get('populations') and epoch_metrics['populations'].get(pop_name):
                    pop_data = epoch_metrics['populations'][pop_name]

                    r2_scores = pop_data.get('r2')
                    pseudo_r2_scores = pop_data.get('pseudo_r2')
                    average_r2_scores = pop_data.get('average_r2')

                    pop_r2_scores.append(r2_scores if r2_scores is not None and not np.isnan(r2_scores) else None)
                    pop_pseudo_r2_scores.append(pseudo_r2_scores if pseudo_r2_scores is not None and not np.isnan(pseudo_r2_scores) else None)
                    pop_average_r2_scores.append(average_r2_scores if average_r2_scores is not None and not np.isnan(average_r2_scores) else None)
                else:
                    pop_r2_scores.append(None)
                    pop_pseudo_r2_scores.append(None)
                    pop_average_r2_scores.append(None)

            # Plot custom population scores
            epochs_to_plot = recorded_epochs[:len(pop_r2_scores)]

            # Custom R² scores
            valid_r2 = [(e, s) for e, s in zip(epochs_to_plot, pop_r2_scores) if s is not None]
            if valid_r2:
                epochs_r2, scores_r2 = zip(*valid_r2)
                plt.plot(epochs_r2, scores_r2, 'b-', label='Custom R²', linewidth=2)

            # Custom Pseudo R² scores
            valid_pseudo_r2 = [(e, s) for e, s in zip(epochs_to_plot, pop_pseudo_r2_scores) if s is not None]
            if valid_pseudo_r2:
                epochs_pseudo, scores_pseudo = zip(*valid_pseudo_r2)
                plt.plot(epochs_pseudo, scores_pseudo, 'g--', label='Custom Pseudo R²', linewidth=2)

            # Custom Average R² scores
            valid_avg_r2 = [(e, s) for e, s in zip(epochs_to_plot, pop_average_r2_scores) if s is not None]
            if valid_avg_r2:
                epochs_avg, scores_avg = zip(*valid_avg_r2)
                plt.plot(epochs_avg, scores_avg, 'r:', label='Custom Avg R²', linewidth=2)

            # Add validation population scores if available (Custom metrics)
            if val_dataloader is not None and all_val_epoch_scores:
                val_r2_scores = []
                val_pseudo_r2_scores = []
                val_avg_r2_scores = []

                for val_scores in all_val_epoch_scores:
                    val_r2 = val_scores.get(f'r2_{pop_name}')
                    val_pseudo = val_scores.get(f'pseudo_r2_{pop_name}')
                    val_avg = val_scores.get(f'average_r2_{pop_name}')

                    val_r2_scores.append(val_r2 if val_r2 is not None and not np.isnan(val_r2) else None)
                    val_pseudo_r2_scores.append(val_pseudo if val_pseudo is not None and not np.isnan(val_pseudo) else None)
                    val_avg_r2_scores.append(val_avg if val_avg is not None and not np.isnan(val_avg) else None)

                # Plot validation scores
                val_epochs = recorded_epochs[:len(val_r2_scores)]

                valid_val_r2 = [(e, s) for e, s in zip(val_epochs, val_r2_scores) if s is not None]
                if valid_val_r2:
                    epochs_val_r2, scores_val_r2 = zip(*valid_val_r2)
                    plt.plot(epochs_val_r2, scores_val_r2, 'c-', label='Val Custom R²', linewidth=1.5, alpha=0.8)

                valid_val_pseudo = [(e, s) for e, s in zip(val_epochs, val_pseudo_r2_scores) if s is not None]
                if valid_val_pseudo:
                    epochs_val_pseudo, scores_val_pseudo = zip(*valid_val_pseudo)
                    plt.plot(epochs_val_pseudo, scores_val_pseudo, 'm--', label='Val Custom Pseudo R²', linewidth=1.5, alpha=0.8)

            plt.xlabel('Epoch')
            plt.ylabel('R² Score')
            plt.title(f'{pop_name} Population (Custom Metrics)')
            plt.legend(fontsize=8)
            plt.grid(True)

        plt.suptitle(f'Training and Validation Curves (Custom Metrics)\nBest R²: {best_avg_r2_score:.4f} (Epoch {best_epoch})', fontsize=16)
        plt.tight_layout()

        # Save the plot
        plt.savefig(training_curves_path, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Training and validation curves plot saved to {training_curves_path}")

    except Exception as e:
        print(f"Error creating training curves plot: {e}")
        import traceback
        traceback.print_exc()

    # (TorchEval plots removed)

    # --- Plot Focused Training Curves (Epoch 100+) --- #
    try:
        # Filter data for epochs 100 and onwards
        epoch_100_start_idx = next((i for i, e in enumerate(recorded_epochs) if e >= 100), 0)
        focused_epochs = recorded_epochs[epoch_100_start_idx:]

        if len(focused_epochs) > 0:
            # Custom metrics focused plot
            plt.figure(figsize=(20, 12))

            # 1. Training Loss (focused)
            plt.subplot(2, 4, 1)
            focused_losses = all_epoch_losses[epoch_100_start_idx:]
            if focused_losses:
                plt.plot(focused_epochs[:len(focused_losses)], focused_losses, 'b-', label='Training Loss')
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.title('Training Loss (Epoch 100+)')
            plt.legend()
            plt.grid(True)

            # 2. Overall R² (focused)
            plt.subplot(2, 4, 2)
            focused_training_r2 = training_overall_r2_to_plot[epoch_100_start_idx:] if len(training_overall_r2_to_plot) > epoch_100_start_idx else []
            focused_validation_r2 = validation_overall_r2_to_plot[epoch_100_start_idx:] if len(validation_overall_r2_to_plot) > epoch_100_start_idx else []

            if focused_training_r2:
                plt.plot(focused_epochs[:len(focused_training_r2)], focused_training_r2, 'b-', label='Train Custom R²')
            if focused_validation_r2:
                plt.plot(focused_epochs[:len(focused_validation_r2)], focused_validation_r2, 'm-', label='Val Custom R²')
            plt.xlabel('Epoch')
            plt.ylabel('R² Score')
            plt.title('Overall R² (Epoch 100+)')
            plt.legend()
            plt.grid(True)

            # 3. Custom Overall Metrics (focused)
            plt.subplot(2, 4, 3)

            # Extract focused overall custom metrics
            focused_r2_median = overall_r2_median_plot[epoch_100_start_idx:] if len(overall_r2_median_plot) > epoch_100_start_idx else []
            focused_pseudo_r2_median = overall_pseudo_r2_median_plot[epoch_100_start_idx:] if len(overall_pseudo_r2_median_plot) > epoch_100_start_idx else []
            focused_average_r2_trial_avg = overall_average_r2_trial_avg_plot[epoch_100_start_idx:] if len(overall_average_r2_trial_avg_plot) > epoch_100_start_idx else []
            focused_simple_pop_avg = simple_pop_avg_plot[epoch_100_start_idx:] if len(simple_pop_avg_plot) > epoch_100_start_idx else []
            focused_simple_pop_median = simple_pop_median_plot[epoch_100_start_idx:] if len(simple_pop_median_plot) > epoch_100_start_idx else []

            if focused_r2_median:
                plt.plot(focused_epochs[:len(focused_r2_median)], focused_r2_median, 'r-', label='R² (median)', linewidth=2)
            if focused_pseudo_r2_median:
                plt.plot(focused_epochs[:len(focused_pseudo_r2_median)], focused_pseudo_r2_median, 'b-', label='Pseudo R² (median)', linewidth=2)
            if focused_average_r2_trial_avg:
                plt.plot(focused_epochs[:len(focused_average_r2_trial_avg)], focused_average_r2_trial_avg, 'g-', label='Avg R² (trial-avg)', linewidth=2)
            if focused_simple_pop_avg:
                plt.plot(focused_epochs[:len(focused_simple_pop_avg)], focused_simple_pop_avg, 'm--', label='Simple Pop Avg', linewidth=1)
            if focused_simple_pop_median:
                plt.plot(focused_epochs[:len(focused_simple_pop_median)], focused_simple_pop_median, 'c--', label='Simple Pop Median', linewidth=1)

            plt.xlabel('Epoch')
            plt.ylabel('Score')
            plt.title('Overall Custom Metrics (Epoch 100+)')
            plt.legend(fontsize=8)
            plt.grid(True)

            # 4-7. Population-specific R² scores (focused)
            for idx, pop_name in enumerate(neuron_populations[:4]):
                plt.subplot(2, 4, 4 + idx)

                # Extract focused population data
                focused_pop_r2 = []
                focused_pop_pseudo_r2 = []
                focused_pop_avg_r2 = []

                for i in range(epoch_100_start_idx, min(len(all_epoch_scores), len(recorded_epochs))):
                    if i < len(all_epoch_scores) and all_epoch_scores[i] and all_epoch_scores[i].get('populations') and all_epoch_scores[i]['populations'].get(pop_name):
                        pop_data = all_epoch_scores[i]['populations'][pop_name]
                        focused_pop_r2.append(pop_data.get('r2'))
                        focused_pop_pseudo_r2.append(pop_data.get('pseudo_r2'))
                        focused_pop_avg_r2.append(pop_data.get('average_r2'))
                    else:
                        focused_pop_r2.append(None)
                        focused_pop_pseudo_r2.append(None)
                        focused_pop_avg_r2.append(None)

                # Plot focused population scores
                valid_r2 = [(e, s) for e, s in zip(focused_epochs[:len(focused_pop_r2)], focused_pop_r2) if s is not None and not np.isnan(s)]
                valid_pseudo_r2 = [(e, s) for e, s in zip(focused_epochs[:len(focused_pop_pseudo_r2)], focused_pop_pseudo_r2) if s is not None and not np.isnan(s)]
                valid_avg_r2 = [(e, s) for e, s in zip(focused_epochs[:len(focused_pop_avg_r2)], focused_pop_avg_r2) if s is not None and not np.isnan(s)]

                if valid_r2:
                    epochs_r2, scores_r2 = zip(*valid_r2)
                    plt.plot(epochs_r2, scores_r2, 'b-', label='Custom R²', linewidth=2)
                if valid_pseudo_r2:
                    epochs_pseudo, scores_pseudo = zip(*valid_pseudo_r2)
                    plt.plot(epochs_pseudo, scores_pseudo, 'g--', label='Custom Pseudo R²', linewidth=2)
                if valid_avg_r2:
                    epochs_avg, scores_avg = zip(*valid_avg_r2)
                    plt.plot(epochs_avg, scores_avg, 'r:', label='Custom Avg R²', linewidth=2)

                plt.xlabel('Epoch')
                plt.ylabel('R² Score')
                plt.title(f'{pop_name} Population (Epoch 100+)')
                plt.legend(fontsize=8)
                plt.grid(True)

            plt.suptitle(f'Focused Training Curves (Epoch 100+)\nBest R²: {best_avg_r2_score:.4f} (Epoch {best_epoch})', fontsize=16)
            plt.tight_layout()
            plt.savefig(training_curves_focused_path, dpi=300, bbox_inches='tight')
            plt.close()
            print(f"Focused training curves plot saved to {training_curves_focused_path}")

            # (TorchEval focused plots removed)
        else:
            print("No data available for focused plots (epoch 100+)")

    except Exception as e:
        print(f"Error creating focused training curves plots: {e}")
        import traceback
        traceback.print_exc()

    # --- Final Log Saving --- #
    # epoch_metrics here will be from the last reporting epoch
    last_epoch_metrics = all_epoch_scores[-1] if all_epoch_scores else {}

    final_scores_and_logs = {
        'best_avg_r2_score': best_avg_r2_score,
        'best_epoch': best_epoch,
        'best_checkpoint_path': checkpoint_path,
        'best_heatmap_path': best_heatmap_path,
        'final_checkpoint_path': final_checkpoint_path,
        'final_heatmap_path': final_heatmap_path,
        'training_curves_path': training_curves_path,
        'training_curves_focused_path': training_curves_focused_path,
        'all_epoch_losses': all_epoch_losses,
        'all_epoch_scores': all_epoch_scores,
        'all_epoch_old_r2': all_epoch_old_r2,
        'all_epoch_old_pseudo_r2': all_epoch_old_pseudo_r2,
        'all_epoch_old_regression_r2': all_epoch_old_regression_r2,
        'all_epoch_old_average_r2': all_epoch_old_average_r2,
        'kwargs': {k: str(v) if isinstance(v, torch.device) else v for k, v in kwargs.items()},
        'population_metrics': last_epoch_metrics,  # Add population-specific metrics from the last reporting epoch
        'val_epoch_scores': all_val_epoch_scores,
    }

    try:
        with open(log_file_path, "w") as f:
            json.dump(final_scores_and_logs, f, indent=4)
        print(f"Training logs saved to {log_file_path}")

        # Save metrics separately
        metrics_data = {
            'losses': all_epoch_losses,
            'scores': all_epoch_scores,
            'best_r2': best_avg_r2_score,
            'best_epoch': best_epoch,
            'training_curves_path': training_curves_path,
            'training_curves_focused_path': training_curves_focused_path
        }
        with open(metrics_file_path, "w") as f:
            json.dump(metrics_data, f, indent=4)
        print(f"Training metrics saved to {metrics_file_path}")
    except Exception as e:
        print(f"Error saving log file {log_file_path}: {e}")
    # ------------------------

    # Return the results based on the best model found
    result_data = {
        'best_avg_r2_score': best_avg_r2_score,
        'best_epoch': best_epoch
    }
    result = pd.DataFrame([result_data])

    return model, result


def validate(model, val_data, config=None, device=None,
             calculate_metrics=True, verbose=True,
             use_ic_solver=None, ic_steps=None, ic_l2=None):
    """
    Run the model over validation data and collect predictions and scores.

    Two forward modes are supported:

    * Encoder mode (default): ``model(encoding_spikes)`` is called and its
      2nd and 3rd outputs are used as predicted log-rates and latents.
    * IC-solver mode: for every window an initial condition ``z0`` is solved
      from the first time bin (``solve_ic_mlp`` / ``solve_ic_direct``) or
      from the following ``ic_loss_horizon`` bins (``solve_ic_rollout``),
      then rolled out with ``model.decoder`` and ``model._apply_readout``.

    Args:
        model: Model to validate. It is moved to ``device`` and put in eval
            mode.
        val_data: Iterable of spike tensors. 2-D entries get a leading batch
            dimension added; other entries are used as-is.
        config: Optional configuration dict. ``None`` is treated as an empty
            dict, i.e. every setting falls back to its default. Keys read:
            ``use_ic_solver``, ``ic_steps``, ``ic_l2``, ``ic_lr``,
            ``ic_loss_type``, ``ic_loss_horizon``, ``ic_solver_mode``,
            ``ic_mlp_hidden``, ``ic_aggregate``, ``neurons`` (mapping of
            population name to ``(start, end)``), ``causal_model``,
            ``ic_window_size`` (required when ``causal_model`` is truthy)
            and ``data_type``.
        device: Device to run on. Defaults to the device of the model's
            first parameter.
        calculate_metrics: If True, compute R2-style metrics in addition to
            the validation loss.
        verbose: If True (and ``calculate_metrics`` is True), print every
            score that is neither ``None`` nor NaN.
        use_ic_solver: Overrides ``config['use_ic_solver']`` when not None.
        ic_steps: Overrides ``config['ic_steps']`` when not None.
        ic_l2: Overrides ``config['ic_l2']`` when not None.

    Returns:
        tuple: ``(all_pred_latents, all_pred_rates, all_decoding_spikes,
        all_pred_logrates, val_scores)``. The first four are lists with one
        tensor per element of ``val_data``; ``val_scores`` is a dict that
        always contains ``'val_loss'`` (NaN when ``val_data`` is empty).
    """
    # TODO: Make sure this works with simulated Wang dataset now.

    # A missing config means "use defaults everywhere"; without this the
    # declared default ``config=None`` would crash on the first ``.get``.
    if config is None:
        config = {}

    # Device management
    if device is None:
        device = next(model.parameters()).device
    model = model.to(device)
    model.eval()

    # Determine whether to use the IC solver and its hyperparameters
    use_solver = (bool(use_ic_solver) if use_ic_solver is not None
                  else bool(config.get('use_ic_solver', False)))
    ic_opt_steps = (int(ic_steps) if ic_steps is not None
                    else int(config.get('ic_steps', 200)))
    ic_opt_l2 = (float(ic_l2) if ic_l2 is not None
                 else float(config.get('ic_l2', 1e-2)))
    ic_opt_lr = float(config.get('ic_lr', 0.05))
    # 'poisson' or 'mse'
    ic_loss_type = str(config.get('ic_loss_type', 'poisson')).lower()
    # 1 => single timepoint (t0), >1 => multiple
    ic_loss_horizon = int(config.get('ic_loss_horizon', 1))
    # 'direct' or 'mlp'
    ic_opt_mode = str(config.get('ic_solver_mode', 'direct')).lower()
    ic_mlp_hidden = int(config.get('ic_mlp_hidden', 0))

    # Prepare optional aggregation settings for IC solver
    neurons = config.get('neurons', {})
    agg_mode = str(config.get('ic_aggregate', 'none')).lower()
    spans_list = (
        [(int(s), int(e)) for (s, e) in neurons.values()]
        if agg_mode != 'none' else None
    )

    causal_model = bool(config.get('causal_model', False))
    # Per-horizon metrics are skipped only for the 'wang_100T' data type
    compute_horizons = config.get('data_type') != 'wang_100T'

    # Per-bout predictions and targets, plus the resulting scores
    all_pred_rates = []
    all_pred_logrates = []
    all_pred_latents = []
    all_decoding_spikes = []
    val_scores = {}

    with torch.no_grad():
        running_batch_loss = 0.0
        # Count batches while iterating so that empty data and iterables
        # without __len__ are handled.
        n_batches = 0

        for bout in val_data:
            # Add batch dim if data is unwindowed
            bout = bout.unsqueeze(0) if bout.dim() == 2 else bout

            # Split the bout into encoder evidence and decoding targets
            if use_solver:
                # Expects perturbation to be first bin of perturb dataset:
                # first bin is the z0 evidence, the rest is evaluated.
                encoding_spikes = bout[:, :1, :]
                decoding_spikes = bout[:, 1:, :]
            elif causal_model:
                encoding_spikes, decoding_spikes = to_causal_segments(
                    bout, config['ic_window_size']
                )
            else:
                encoding_spikes, decoding_spikes = bout, bout

            encoding_spikes = encoding_spikes.to(device)
            decoding_spikes = decoding_spikes.to(device)

            if not use_solver:
                # Forward pass via encoder-inferred ICs
                _, pred_logrates, pred_latents, _, _, _ = model(encoding_spikes)
                pred_rates = torch.exp(pred_logrates)
            else:
                # IC-optimized forward: solve z0 per window, then roll out
                # over the full remaining window.
                T_pred = max(bout.shape[1] - 1, 1)
                win_latents = []
                win_logrates = []
                win_rates = []
                for w in range(encoding_spikes.shape[0]):
                    first_bin = encoding_spikes[w, :1, :]  # t=0 of window w
                    # Enable gradients for the solver (we are inside no_grad)
                    with torch.enable_grad():
                        if ic_opt_mode == 'mlp':
                            z0 = solve_ic_mlp(
                                model,
                                first_bin,
                                steps=ic_opt_steps,
                                lr=ic_opt_lr,
                                lam=ic_opt_l2,
                                hidden=ic_mlp_hidden,
                                aggregate=agg_mode,
                                spans=spans_list,
                            )
                        elif ic_loss_horizon <= 1:
                            # Single timepoint loss on t0
                            target_window = first_bin.squeeze(0).view(1, -1)
                            z0 = solve_ic_direct(
                                model,
                                target_window,
                                lam=ic_opt_l2,
                                steps=ic_opt_steps,
                                aggregate=agg_mode,
                                spans=spans_list,
                                loss_type=ic_loss_type,
                            )
                        else:
                            # Multi-step rollout loss on t=1..H. Slice the
                            # target from decoding_spikes (== bout[:, 1:, :]
                            # already on `device`) so it lives on the same
                            # device as the model.
                            avail = decoding_spikes[w, :ic_loss_horizon, :]
                            z0 = solve_ic_rollout(
                                model,
                                avail,
                                steps=ic_opt_steps,
                                lr=ic_opt_lr,
                                lam=ic_opt_l2,
                                loss_type=ic_loss_type,
                                aggregate=agg_mode,
                                spans=spans_list,
                            )
                    z0_b = z0.view(1, -1)
                    # Zero input for the rollout
                    u = torch.zeros((1, T_pred, 1), device=z0_b.device)
                    latents, _, _, _, _ = model.decoder(u, z0_b)
                    logrates = model._apply_readout(latents)
                    win_latents.append(latents)
                    win_logrates.append(logrates)
                    win_rates.append(torch.exp(logrates))
                pred_latents = torch.cat(win_latents, dim=0)
                pred_logrates = torch.cat(win_logrates, dim=0)
                pred_rates = torch.cat(win_rates, dim=0)

            # Add to running lists
            all_pred_rates.append(pred_rates)
            all_pred_logrates.append(pred_logrates)
            all_pred_latents.append(pred_latents)
            all_decoding_spikes.append(decoding_spikes)

            # Calculate validation loss
            val_loss_batch = F.poisson_nll_loss(
                pred_logrates, decoding_spikes, full=True, reduction="mean"
            )
            running_batch_loss += val_loss_batch.item()
            n_batches += 1

        # Mean loss over batches; NaN (instead of ZeroDivisionError) when
        # there was nothing to validate on.
        if n_batches:
            val_scores['val_loss'] = running_batch_loss / n_batches
        else:
            val_scores['val_loss'] = float('nan')

        if calculate_metrics and n_batches:
            # Concatenate rates and spikes, moving them to CPU once instead
            # of once per metric call.
            bout_pred_rates = torch.cat(all_pred_rates, dim=0).cpu()
            bout_decoding_spikes = torch.cat(all_decoding_spikes, dim=0).cpu()
            horizons = range(min(bout_pred_rates.shape[1], 3))

            # Primary metrics using custom functions
            val_scores['val_r2'] = apply_metric(
                r2_score, bout_pred_rates, bout_decoding_spikes)
            val_scores['val_average_r2'] = apply_metric(
                average_r2_score, bout_pred_rates, bout_decoding_spikes)
            val_scores['val_pseudo_r2'] = apply_metric(
                pseudo_r2_score, bout_pred_rates, bout_decoding_spikes)

            # Population-specific custom metrics
            for name, (start, end) in neurons.items():
                # Predicted and true population activity for entire bout
                bout_pred_slice = bout_pred_rates[:, :, start:end]
                bout_spike_slice = bout_decoding_spikes[:, :, start:end]

                val_scores[f'val_r2_{name}'] = apply_metric(
                    r2_score, bout_pred_slice, bout_spike_slice)
                val_scores[f'val_pseudo_r2_{name}'] = apply_metric(
                    pseudo_r2_score, bout_pred_slice, bout_spike_slice)
                val_scores[f'val_average_r2_{name}'] = apply_metric(
                    average_r2_score, bout_pred_slice, bout_spike_slice)

                if compute_horizons:
                    # Prediction horizon metrics for each population
                    for horizon in horizons:
                        pred_horizon = bout_pred_slice[:, horizon, :].squeeze()
                        tgt_horizon = bout_spike_slice[:, horizon, :].squeeze()
                        val_scores[f'val_r2_{name}_horizon_{horizon}'] = \
                            apply_metric(r2_score, pred_horizon, tgt_horizon)
                        val_scores[f'val_pseudo_r2_{name}_horizon_{horizon}'] = \
                            apply_metric(pseudo_r2_score,
                                         pred_horizon, tgt_horizon)

            if compute_horizons:
                # Prediction horizon metrics for total population
                for horizon in horizons:
                    pred_horizon = bout_pred_rates[:, horizon, :].squeeze()
                    tgt_horizon = bout_decoding_spikes[:, horizon, :].squeeze()
                    val_scores[f'val_r2_horizon_{horizon}'] = \
                        apply_metric(r2_score, pred_horizon, tgt_horizon)
                    val_scores[f'val_pseudo_r2_horizon_{horizon}'] = \
                        apply_metric(pseudo_r2_score,
                                     pred_horizon, tgt_horizon)

        # Print results
        if verbose and calculate_metrics:
            print("Validation Scores:")
            for key, value in val_scores.items():
                # Skip missing (None) and NaN scores
                if value is not None and not np.isnan(value):
                    print(f"  {key}: {value:.4f}")

    return (all_pred_latents, all_pred_rates, all_decoding_spikes,
            all_pred_logrates, val_scores)

