#!/usr/bin/env python3
"""
Unified training script for Neural ODE/RNN experiments.

This script replaces all the individual architecture-specific training scripts
with a single configurable entry point.
"""

import torch
import wandb
import numpy as np
import os
import sys
import hydra
import matplotlib.pyplot as plt
import pickle
# from hydra import initialize, compose
from typing import Dict, Any

# Add project root to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from training.utils import merge_windows
from metrics import apply_metric, r2_score, pseudo_r2_score
from utils.args import build_experiment_config, get_experiment_name
from utils.model_factory import (create_model, create_optimizer,
                                 setup_directories, load_model_from_checkpoint,
                                 save_experiment_config)
from data import NeuralData, to_causal_segments
from training.train import train, validate


def _align_population(arr, pname: str, n_target: int, announce: bool):
    """Trim or zero-pad ``arr`` along axis 1 so it has exactly ``n_target`` columns.

    Args:
        arr: per-trial array for one population; axis 1 is the axis that is
            trimmed/padded (axis 0 is left untouched).
        pname: population name, used only in the printed message.
        n_target: desired number of columns.
        announce: whether to print the alignment message. Callers pass True
            only for the first trial so each population is reported once.

    Returns:
        ``arr`` itself when it already has ``n_target`` columns, otherwise a
        trimmed view or a zero-padded copy (padding keeps ``arr.dtype``).
    """
    n_have = arr.shape[1]
    if n_have > n_target:
        if announce:
            print(f"Alignment fallback: trimming population {pname} from {n_have} to {n_target} neurons")
        return arr[:, :n_target]
    if n_have < n_target:
        if announce:
            print(f"Alignment fallback: padding population {pname} from {n_have} to {n_target} neurons ({n_target - n_have} zeros)")
        pad = np.zeros((arr.shape[0], n_target - n_have), dtype=arr.dtype)
        return np.concatenate([arr, pad], axis=1)
    return arr


def _combine_populations(pops, config: Dict[str, Any]) -> list:
    """Concatenate per-population trial arrays into one array per trial.

    Args:
        pops: mapping population name -> list of per-trial arrays. The number
            of trials is taken from the first population.
        config: uses ``target_neurons`` (name -> (start, end)) when present to
            fix the population order and per-population width; otherwise the
            order of ``neurons`` is used and arrays are not resized.

    Returns:
        List with one array per trial, populations concatenated along axis 1.

    Raises:
        ValueError: if ``pops`` is empty.
    """
    if not pops:
        raise ValueError("Dataset contains no populations")
    num_trials = len(next(iter(pops.values())))

    has_target = 'target_neurons' in config
    # Only touch config['neurons'] when target_neurons is absent.
    order_source = config['target_neurons'] if has_target else config['neurons']
    pop_order = list(order_source.keys())  # loop-invariant: compute once

    combined = []
    for t in range(num_trials):
        parts = []
        for pname in pop_order:
            arr = pops[pname][t]
            if has_target:
                start, end = config['target_neurons'][pname]
                arr = _align_population(arr, pname, int(end - start), announce=(t == 0))
            parts.append(arr)
        combined.append(np.concatenate(parts, axis=1))
    return combined


def load_dataset(config: Dict[str, Any]) -> tuple:
    """Load a dataset file and build the training/validation data.

    Supported formats:
      * ``.npy`` — loaded with ``np.load(allow_pickle=True)`` and passed on as-is.
      * ``.pkl`` / ``.pickle`` — either a nested dict
        ``{'populations': {pop: [trial arrays]}, ...}`` or a flat dict
        ``{pop: [trial arrays]}``. Populations are concatenated per trial along
        axis 1, optionally trimmed/padded to ``config['target_neurons']``.

    Args:
        config: experiment configuration; reads ``dataset_path`` and
            ``neurons`` / ``target_neurons``, and is passed to ``NeuralData``.

    Returns:
        ``(train_dl, val_windows)`` taken from the set-up ``NeuralData`` module.

    Raises:
        ValueError: for an unsupported file extension or a pickle with no
            populations.
    """
    dataset_path = config['dataset_path']
    print(f"Loading dataset from: {dataset_path}")

    if dataset_path.endswith('.npy'):
        # NOTE: allow_pickle=True can execute code embedded in the file;
        # only load trusted datasets.
        dataset = np.load(dataset_path, allow_pickle=True)
    elif dataset_path.endswith(('.pkl', '.pickle')):
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(dataset_path, 'rb') as f:
            raw = pickle.load(f)
        # Nested format keeps populations under 'populations'; flat format is
        # the population mapping itself.
        if isinstance(raw, dict) and 'populations' in raw:
            pops = raw['populations']
        else:
            pops = raw
        dataset = _combine_populations(pops, config)
    else:
        raise ValueError(f"Unsupported dataset format: {dataset_path}")

    data_module = NeuralData(config)
    data_module.setup(dataset)

    return data_module.train_dl, data_module.val_windows


# Accepted spellings of the 'ic_aggregate' option for per-population reduction.
_IC_MEAN_MODES = ('population_mean', 'mean', 'pop_mean')
_IC_SUM_MODES = ('population_sum', 'sum', 'pop_sum')


def _build_plot_config(config: Dict[str, Any]) -> Dict[str, Any]:
    """Return a shallow copy of ``config`` with the fields the plot helpers read."""
    plot_config = config.copy()
    plot_config['data'] = 'val'
    # BASE_PATH is normally provided by setup_directories(); fall back otherwise.
    if 'BASE_PATH' not in plot_config:
        plot_config['BASE_PATH'] = config.get('output_dir', './results')
        print("Warning: BASE_PATH not set, using fallback: "
              f"{plot_config['BASE_PATH']}")
    # plot_config is a copy of config, so these only fill genuinely missing keys.
    plot_config.setdefault('epoch', 'final')
    plot_config.setdefault('lr', 'unknown')
    plot_config.setdefault('encoder_size', 'unknown')
    plot_config.setdefault('l1_reg', None)
    return plot_config


def _trial_avg_decoded(windows, spans):
    """Average decoded activity over the selected neuron spans (and windows).

    Args:
        windows: tensor or array shaped ``[num_windows, T, N]`` or ``[T, N]``.
        spans: list of ``(start, end)`` column ranges; all selected columns are
            pooled before averaging.

    Returns:
        1-D trace of length ``T``, or None when ``windows`` is None, ``spans`` is
        empty, or the rank is not 2 or 3.
    """
    if isinstance(windows, torch.Tensor):
        arr = windows.detach().cpu().numpy()
    else:
        arr = windows
    if arr is None:
        return None
    if arr.ndim == 3:
        parts = [arr[:, :, int(s):int(e)] for (s, e) in spans]
        if not parts:
            return None
        # Mean over selected neurons, then over windows.
        return np.concatenate(parts, axis=2).mean(axis=2).mean(axis=0)
    if arr.ndim == 2:
        parts = [arr[:, int(s):int(e)] for (s, e) in spans]
        if not parts:
            return None
        return np.concatenate(parts, axis=1).mean(axis=1)
    print(f"Unexpected decoded array dims: {arr.ndim}, shape: {getattr(arr, 'shape', None)}")
    return None


def _region_spans(neurons) -> Dict[str, list]:
    """Group ``(start, end)`` spans by region (population-name prefix before the first '_')."""
    regions: Dict[str, list] = {}
    for pop_name, (start, end) in neurons.items():
        regions.setdefault(pop_name.split('_')[0], []).append((int(start), int(end)))
    return regions


def _bout_indices(n_bouts: int, max_plots: int = 16):
    """Return up to ``max_plots`` evenly spaced indices into ``range(n_bouts)``."""
    if n_bouts > max_plots:
        return np.linspace(0, n_bouts - 1, max_plots, dtype=int)
    return np.arange(n_bouts)


def _plot_bout_grid(spikes, pred_rates, spans, out_path, max_plots=16, n_cols=4):
    """Save a grid of per-bout true vs. predicted trial-average traces.

    Up to ``max_plots`` evenly spaced bouts are drawn, unused grid cells are
    removed, and the figure is always closed, even if plotting fails.
    """
    idxs = _bout_indices(len(spikes), max_plots)
    n_rows = max(1, int(np.ceil(len(idxs) / n_cols)))
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, n_rows * 3))
    axs = np.atleast_2d(axs)
    try:
        for k, bi in enumerate(idxs):
            r, c = divmod(k, n_cols)
            ax = axs[r, c]
            y_true = _trial_avg_decoded(spikes[bi], spans)
            y_pred = _trial_avg_decoded(pred_rates[bi], spans)
            if y_true is not None:
                ax.plot(y_true, label='True')
            if y_pred is not None:
                ax.plot(y_pred, label='Pred', linestyle='--')
            ax.set_title(f'Bout {bi+1}')
            ax.set_xlabel('Decoded time')
        for k in range(len(idxs), n_rows * n_cols):
            r, c = divmod(k, n_cols)
            fig.delaxes(axs[r, c])
        fig.tight_layout()
        if len(idxs) > 0:
            # Legend on the last populated panel.
            r, c = divmod(len(idxs) - 1, n_cols)
            axs[r, c].legend(loc='upper right')
        fig.savefig(out_path)
    finally:
        plt.close(fig)


def _plot_bout_average(spikes, pred_rates, spans, title, out_path) -> bool:
    """Save one true vs. predicted trace averaged across all bouts.

    Returns False (writing nothing) when no bout yields a usable trace.
    """
    per_bout = []
    for bi in range(len(spikes)):
        y_true = _trial_avg_decoded(spikes[bi], spans)
        y_pred = _trial_avg_decoded(pred_rates[bi], spans)
        if y_true is not None and y_pred is not None:
            per_bout.append((y_true, y_pred))
    if not per_bout:
        return False
    y_true_avg = np.stack([p[0] for p in per_bout], axis=0).mean(axis=0)
    y_pred_avg = np.stack([p[1] for p in per_bout], axis=0).mean(axis=0)

    fig, ax = plt.subplots(figsize=(10, 4))
    try:
        ax.plot(y_true_avg, label='True (avg bouts)')
        ax.plot(y_pred_avg, label='Pred (avg bouts)', linestyle='--')
        ax.set_title(title)
        ax.set_xlabel('Decoded time')
        ax.legend(loc='upper right')
        fig.savefig(out_path)
    finally:
        plt.close(fig)
    return True


def _collect_ic_values(model, val_dl, config: Dict[str, Any]):
    """Pair first-bin spike counts with rates read out from encoder-inferred z0.

    For every validation bout the encoder output's first element (``ic_drop``)
    is passed through ``model._apply_readout`` and exponentiated. The result is
    compared with ``encoding_spikes[:, 0, :]``. With ``ic_aggregate`` set to a
    mean/sum mode, values are reduced per configured population span. Otherwise
    (including ``'none'`` and unknown modes) raw per-neuron values are kept.

    Returns:
        ``(true_vals, pred_vals)`` as 1-D numpy arrays.
    """
    agg_mode = str(config.get('ic_aggregate', 'none')).lower()
    if agg_mode in _IC_MEAN_MODES:
        reduce = np.mean
    elif agg_mode in _IC_SUM_MODES:
        reduce = np.sum
    else:
        reduce = None
    spans = [(int(s), int(e)) for (s, e) in config.get('neurons', {}).values()]
    device = config.get('device')

    true_vals, pred_vals = [], []
    model.eval()
    with torch.no_grad():
        for bout in val_dl:
            # Use the encoder window for IC inference.
            if config['causal_model']:
                encoding_spikes, _ = to_causal_segments(bout, config['ic_window_size'])
            else:
                encoding_spikes = bout
            encoding_spikes = encoding_spikes.to(device)
            ic_drop = model(encoding_spikes)[0]
            logr = model._apply_readout(ic_drop.view(ic_drop.shape[0], 1, -1)).squeeze(1)
            rates = torch.exp(logr).cpu().numpy()
            y_true = encoding_spikes[:, 0, :].cpu().numpy()

            if reduce is None:
                true_vals.extend(y_true.reshape(-1).tolist())
                pred_vals.extend(rates.reshape(-1).tolist())
            else:
                for b in range(rates.shape[0]):
                    for s, e in spans:
                        pred_vals.append(float(reduce(rates[b, s:e])))
                        true_vals.append(float(reduce(y_true[b, s:e])))
    return np.asarray(true_vals), np.asarray(pred_vals)


def _save_ic_scatter(true_vals, pred_vals, out_dir, wandb_run=None):
    """Save (and optionally log to wandb) a true-vs-predicted IC scatter with y=x line."""
    os.makedirs(out_dir, exist_ok=True)
    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        ax.scatter(true_vals, pred_vals, s=6, alpha=0.5)
        lo = float(min(true_vals.min(), pred_vals.min()))
        hi = float(max(true_vals.max(), pred_vals.max()))
        ax.plot([lo, hi], [lo, hi], 'r--', linewidth=1)
        ax.set_title('IC fit scatter (validate, encoder z0)')
        ax.set_xlabel('True first bin')
        ax.set_ylabel('Pred from z0')
        fig.tight_layout()
        fig.savefig(os.path.join(out_dir, 'ic_scatter.png'), dpi=150)
        if wandb_run:
            wandb_run.log({'val/ic_scatter': wandb.Image(fig)})
    finally:
        plt.close(fig)


def generate_plots(model, val_dl, config: Dict[str, Any], wandb_run=None,
                   validate_only=False):
    """Generate and save validation plots, optionally logging them to wandb.

    Runs ``validate`` on ``val_dl`` and writes under ``BASE_PATH``:

    * ``prediction/{pop}_prediction.png`` — decoded prediction per population;
    * ``prediction/ic_scatter.png`` — IC fit from encoder-inferred z0;
    * ``trial_avg_population/{pop}_bouts.png`` and
      ``trial_avg_region/{region}_bouts.png`` — per-bout grids;
    * ``trial_avg_population/{pop}.png`` and ``trial_avg_region/{region}.png`` —
      traces averaged across bouts.

    Flow-field plots (``data_type == 'wang_100T'``) and readout heatmaps
    (``heatmap`` enabled, monotonic readout) are produced when configured.
    Most plot families are guarded individually so one failure does not skip
    the others; errors are printed, never raised. Every figure opened during
    the call is closed before returning.

    Args:
        model: model passed to ``validate`` and the plot helpers.
        val_dl: validation data, also iterated for the IC scatter.
        config: experiment configuration (a copy is used for plotting).
        wandb_run: optional wandb run; when given, images and metrics are logged.
        validate_only: forwarded to ``plot_flowfield``.
    """
    try:
        from plot import (plot_prediction, plot_flowfield,
                          plot_heatmaps, plot_prediction_by_pseudo_r2)
    except ImportError as e:
        print("Warning: Could not generate plots due to missing dependencies:"
              f" {e}")
        return

    figs_before = set(plt.get_fignums())
    try:
        print("Generating plots...")
        plot_config = _build_plot_config(config)
        base_path = plot_config['BASE_PATH']

        # Note: IC solver is used only within validate() to produce predictions; no re-solving here.
        pred_latents, pred_rates, spikes, _, val_metrics = validate(
            model, val_dl, config=plot_config, device=config.get('device'),
            verbose=False
        )
        if isinstance(pred_latents, torch.Tensor):
            pred_latents = pred_latents.cpu().numpy()

        neurons = config['neurons']
        flowfield_plots = []
        heatmap_plots = []

        # --- Per-population decoded prediction plots ---
        print("Generating per-population decoded prediction plots...")
        pred_plots = []
        out_pred_dir = os.path.join(base_path, 'prediction')
        os.makedirs(out_pred_dir, exist_ok=True)
        for name in neurons:
            print(f"  - Generating decoded prediction for: {name}")
            fig_pred = None
            try:
                fig_pred = plot_prediction(spikes, pred_rates, config=plot_config, target=name)
                fig_pred.savefig(os.path.join(out_pred_dir, f'{name}_prediction.png'))
                if wandb_run:
                    pred_plots.append(wandb.Image(fig_pred))
            except Exception as e:
                print(f"    ✗ Error plotting decoded prediction for {name}: {e}")
                if wandb_run:
                    # Placeholder keeps pred_plots aligned with population names.
                    blank = plt.figure()
                    pred_plots.append(wandb.Image(blank))
                    plt.close(blank)
            finally:
                if fig_pred is not None:
                    plt.close(fig_pred)

        # --- Per-bout trial-average grids (population and region) ---
        print("Generating per-bout trial-average decoded traces...")
        print(f"Populations found: {list(neurons.keys())}")
        pop_dir = os.path.join(base_path, 'trial_avg_population')
        region_dir = os.path.join(base_path, 'trial_avg_region')
        os.makedirs(pop_dir, exist_ok=True)
        os.makedirs(region_dir, exist_ok=True)
        region_spans = _region_spans(neurons)
        try:
            # '_bouts' suffix: the bout-averaged summaries below use the plain
            # '{name}.png' paths and would otherwise overwrite these grids.
            for pop_name, (start, end) in neurons.items():
                print(f"  - Population: {pop_name}")
                _plot_bout_grid(spikes, pred_rates, [(int(start), int(end))],
                                os.path.join(pop_dir, f'{pop_name}_bouts.png'))
            for region, spans in region_spans.items():
                print(f"  - Region: {region}")
                _plot_bout_grid(spikes, pred_rates, spans,
                                os.path.join(region_dir, f'{region}_bouts.png'))
        except Exception as e:
            print(f"  ✗ Error generating per-bout trial-average grids: {e}")

        # --- IC fit scatter (encoder-inferred z0; no re-solving) ---
        try:
            print("Generating IC scatter from encoder-inferred z0 (no re-solving)...")
            true_vals, pred_vals = _collect_ic_values(model, val_dl, config)
            if true_vals.size > 0 and pred_vals.size == true_vals.size:
                _save_ic_scatter(true_vals, pred_vals, out_pred_dir, wandb_run)
        except Exception as e:
            print(f"  ✗ Error generating IC scatter: {e}")

        # --- Pseudo-R² example neurons per population (wandb only) ---
        for name in neurons:
            print(f"  - Pseudo-R² examples for: {name}")
            try:
                fig_examples = plot_prediction_by_pseudo_r2(
                    spikes, pred_rates, config=plot_config, target=name,
                    top_k=3, bottom_k=3, trial_index=0)
                try:
                    if wandb_run:
                        wandb_run.log({f'val/{name}_pseudo_r2_examples': wandb.Image(fig_examples)})
                finally:
                    plt.close(fig_examples)
            except Exception as e:
                print(f"    ✗ Error plotting pseudo-R² examples for {name}: {e}")

        # --- Flow field ---
        if config.get('data_type') == 'wang_100T':
            try:
                flowfield_plot = plot_flowfield(model, pred_latents, plot_config,
                                                validate_only=validate_only)
                # TODO: Can't figure out how to add all flowfield plots to wandb
                if wandb_run:
                    flowfield_plots.append(wandb.Image(flowfield_plot))
                print("  ✓ Saved flow field plot")
            except Exception as e:
                print(f"  ✗ Error generating flow field plot: {e}")

        # --- Readout heatmaps ---
        if config.get('heatmap'):
            try:
                if model.readout_type == 'monotonic':
                    weight_matrix = model.readout.log_weight.weight.detach().cpu().numpy()
                    if hasattr(model.readout.log_weight, 'mask'):
                        # Out-of-place: on CPU the numpy array shares storage with
                        # the parameter, so '*=' would overwrite model weights.
                        weight_matrix = weight_matrix * model.readout.log_weight.mask.detach().cpu().numpy()
                    heat_plot = plot_heatmaps(weight_matrix, plot_config)
                    if wandb_run:
                        heatmap_plots.append(wandb.Image(heat_plot))
                    print("  ✓ Saved heatmap plots")
            except Exception as e:
                print(f"  ✗ Error generating heatmap plots: {e}")
        else:
            print("  - Heatmaps not enabled in config, skipping...")

        # --- Trial-average traces averaged across bouts (population and region) ---
        try:
            for pop_name, (start, end) in neurons.items():
                if len(spikes) > 0:
                    try:
                        print(f"Decoded window shapes for {pop_name}: spikes[0]={spikes[0].shape}, pred[0]={pred_rates[0].shape}")
                    except (AttributeError, IndexError, TypeError):
                        pass
                _plot_bout_average(
                    spikes, pred_rates, [(int(start), int(end))],
                    f'{pop_name} decoded trial-average (avg across bouts)',
                    os.path.join(pop_dir, f'{pop_name}.png'))
            for region, spans in region_spans.items():
                _plot_bout_average(
                    spikes, pred_rates, spans,
                    f'{region} decoded trial-average (avg across bouts)',
                    os.path.join(region_dir, f'{region}.png'))
            print("  ✓ Saved trial-average plots for populations and regions")
        except Exception as e:
            print(f"  ✗ Error generating trial-average plots: {e}")

        print("Plots generated successfully!")
        print(f"Plots saved to: {base_path}")

        # Log images and metrics individually so one failure does not drop the rest.
        if wandb_run:
            try:
                for name, img in zip(neurons.keys(), pred_plots):
                    wandb_run.log({f'val/{name}_pred': img})
                for idx, img in enumerate(flowfield_plots):
                    wandb_run.log({f'val/flow_field_{idx}': img})
                for idx, img in enumerate(heatmap_plots):
                    wandb_run.log({f'val/heatmap_{idx}': img})
                wandb_run.log({f'val/{k}': v for k, v in val_metrics.items()})
                print("Plots and metrics logged to Weights & Biases successfully!")
            except Exception as e:
                print(f"Warning: Failed to log to Weights & Biases: {e}")

    except Exception as e:
        print(f"Warning: Plot generation failed: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Close every figure opened during this call, including ones created
        # by the plot helpers, without touching figures that existed before.
        for num in set(plt.get_fignums()) - figs_before:
            plt.close(num)

# Latent sizes applied to CO datasets that define all of these populations.
_CO_LATENT_SIZES = {
    'CFA_E': 10,
    'CFA_I': 5,
    'RFA_E': 20,
    'RFA_I': 5,
}


def _infer_co_latent_sizes(config: Dict[str, Any]) -> None:
    """Replace the implicit all-ones ``population_latent_sizes`` default for CO datasets.

    A dataset is treated as CO when 'co' occurs in ``data_type`` or in the
    basename of ``dataset_path``, and every key of ``_CO_LATENT_SIZES`` is a
    configured population. The sizes are overridden only when they are unset
    or are the all-ones default; ``latent_size`` is then set to their sum.
    Mutates ``config`` in place.
    """
    ds_name = str(config.get('data_type', '')).lower()
    ds_path = str(config.get('dataset_path', '')).lower()
    # NOTE(review): plain substring match is loose (e.g. 'recording.pkl'
    # contains 'co'); the population-name check below is the real gate.
    if 'co' not in ds_name and 'co' not in os.path.basename(ds_path):
        return
    neurons = config.get('neurons', {})
    if not all(k in neurons for k in _CO_LATENT_SIZES):
        return
    current = config.get('population_latent_sizes')
    if current:
        if not isinstance(current, dict):
            return
        vals = list(current.values())
        # Only the implicit all-ones default (utils/args) is overridden.
        if len(vals) != len(neurons) or not all(int(v) == 1 for v in vals):
            return
    config['population_latent_sizes'] = dict(_CO_LATENT_SIZES)
    config['latent_size'] = int(sum(_CO_LATENT_SIZES.values()))


def _auto_select_checkpoint(search_root: str):
    """Return the most recently modified ``.pth`` file under ``search_root``, or None.

    When modification times tie, a ``*_final.pth`` file wins.
    """
    latest_path = None
    latest_mtime = -1
    for root, _, files in os.walk(search_root):
        for fname in files:
            if not fname.endswith('.pth'):
                continue
            fpath = os.path.join(root, fname)
            mtime = os.path.getmtime(fpath)
            if mtime > latest_mtime or (mtime == latest_mtime and fname.endswith('_final.pth')):
                latest_mtime = mtime
                latest_path = fpath
    return latest_path


def _align_config_to_checkpoint(config: Dict[str, Any], checkpoint_path: str) -> None:
    """Copy per-population sizes from a checkpoint's saved config into ``config``.

    ``saved['neurons']`` becomes ``config['target_neurons']``. This drives the
    trim/pad alignment in ``load_dataset``. ``population_latent_sizes`` (and the
    derived ``latent_size``) are preserved from training when present.
    """
    # NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
    ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    if not (isinstance(ckpt, dict) and 'config' in ckpt):
        return
    saved_cfg = ckpt['config']
    if 'neurons' in saved_cfg:
        config['target_neurons'] = saved_cfg['neurons']
    if 'population_latent_sizes' in saved_cfg:
        config['population_latent_sizes'] = saved_cfg['population_latent_sizes']
        try:
            config['latent_size'] = int(sum(saved_cfg['population_latent_sizes'].values()))
        except (AttributeError, TypeError, ValueError):
            pass


@hydra.main(version_base=None, config_path="config")
def main(args):
    """Main training/validation loop.

    Validate-only mode (``args.validate_only``):
      * uses ``args.checkpoint_path``, or auto-selects the newest ``.pth`` under
        ``output_dir`` (default ``./results``);
      * aligns population sizes to the checkpoint's saved config;
      * loads the data, runs validation and generates plots.

    Training mode:
      * sets up output directories;
      * creates and trains a model, runs a final validation and generates plots.

    With ``args.use_wandb``, a wandb run is created, and at the end the
    checkpoint file (if present) is logged as a model artifact. The run is
    always finished.
    """
    config = build_experiment_config(args)

    # Infer sensible defaults for population_latent_sizes when appropriate
    # (best-effort: failures are reported, not fatal).
    try:
        _infer_co_latent_sizes(config)
        print(config.get('population_latent_sizes'))
    except Exception as e:
        print(f"Warning: could not infer population_latent_sizes: {e}")

    # Initialize wandb logging if specified
    if args.use_wandb:
        unique_id = f"{wandb.util.generate_id()}"
        config['wandb_name'] = f"{config['wandb_name']}_{unique_id}"
        run = wandb.init(
            project=config['wandb_project'],
            entity=config['wandb_entity'],
            name=config['wandb_name'],
            config=config,
            dir=config['wandb_dir'],
            mode=config['wandb_mode']
        )
    else:
        run = None

    # Set up device
    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)
    config['device'] = device
    print(f"Using device: {device}")

    experiment_name = get_experiment_name(config, args.experiment_name)
    print(f"Experiment name: {experiment_name}")

    validate_only = getattr(args, 'validate_only', False)

    # Auto-pick the latest checkpoint BEFORE alignment so that an
    # auto-selected checkpoint is aligned too.
    if validate_only and not getattr(args, 'checkpoint_path', None):
        search_root = config.get('output_dir', './results')
        try:
            latest_path = _auto_select_checkpoint(search_root)
            if latest_path:
                print(f"Auto-selected latest checkpoint: {latest_path}")
                setattr(args, 'checkpoint_path', latest_path)
            else:
                print(f"No checkpoints found under: {search_root}")
        except Exception as e:
            print(f"Warning: auto-select checkpoint failed: {e}")

    # For validate-only: align dataset input and latent/pop sizes to checkpoint
    if validate_only and getattr(args, 'checkpoint_path', None):
        try:
            _align_config_to_checkpoint(config, args.checkpoint_path)
        except Exception as e:
            print(f"Warning: Could not pre-load checkpoint config for alignment: {e}")

    train_dl, val_data = load_dataset(config)

    # Set in both branches so the wandb artifact step below always has a path.
    checkpoint_path = None

    if validate_only:
        checkpoint_path = getattr(args, 'checkpoint_path', None)
        if not checkpoint_path:
            raise ValueError("--checkpoint_path required for"
                             " validation-only mode")

        print(f"Loading model from: {checkpoint_path}")
        model, _ = load_model_from_checkpoint(checkpoint_path, config)
        model.to(device)

        print("Running validation...")
        validate(model, val_data, config=config, device=device)

        if config.get('plot', True) and not config.get('disable_plots', False):
            generate_plots(model, val_data, config, wandb_run=run,
                           validate_only=True)

    else:
        dirs = setup_directories(config, experiment_name)
        config.update(dirs)
        print(f"Output directory: {dirs['BASE_PATH']}")

        save_experiment_config(config, experiment_name)

        print("Creating model...")
        model = create_model(config).to(device)

        # Log gradients of model
        if args.use_wandb:
            wandb.watch(model, log="all")

        if config['causal_model']:
            print("Causal training enabled: Model will only use current "
                  "and past data to predict future data...")

        optimizer = create_optimizer(model, config)

        checkpoint_path = os.path.join(
            config['CHECKPOINT_PATH'],
            f"{experiment_name}.pth"
        )

        print(f"Starting training for {config['epochs']} epochs...")
        print(f"Solver: {config['solver_type']},"
              f" Dynamics: {config['dynamics_model_type']},"
              f" Readout: {config['readout_type']}")
        print(f"Dataset: {config['data_type']} ({config['input_size']}"
              f" neurons, {config['window_size']} timepoints)")

        model, _ = train(
            model=model,
            optimizer=optimizer,
            train_dataloader=train_dl,
            checkpoint_path=checkpoint_path,
            config=config,
            log=experiment_name,
            device=device,
            val_dataloader=val_data,
            wandb_run=run
        )

        print("Training completed!")

        print("Running final validation...")
        _, _, _, _, val_metrics = validate(
            model, val_data, config=config, device=device
        )

        if run:
            for k, v in val_metrics.items():
                run.summary[f'final_{k}'] = v

        if config.get('plot', True) and not config.get('disable_plots', False):
            generate_plots(model, val_data, config, wandb_run=run)

    print("Experiment completed successfully!")

    # Log the checkpoint as an artifact (if it exists) and always finish the run.
    if run:
        try:
            if checkpoint_path and os.path.exists(checkpoint_path):
                model_artifact = wandb.Artifact(
                    f"{config['wandb_name']}_model",
                    type='model',
                    description='Trained model for the experiment',
                    metadata={
                        'config': config
                    }
                )
                model_artifact.add_file(checkpoint_path)
                run.log_artifact(model_artifact)
            else:
                print(f"Warning: checkpoint not found, skipping model artifact: {checkpoint_path}")
        finally:
            run.finish()


if __name__ == "__main__":
    # main() is wrapped by @hydra.main, which supplies its `args` parameter,
    # so it is invoked here without arguments.
    main()
