"""
Spike Agent Communication Experiments
===============================================
This script runs experiments on spike-based communication agents using
pretrained CommsMod models. It supports a single experiment, a comparison
study across several configurations, and a full suite that runs every
configuration.

Date: Jul-2025
"""
import os
import argparse
import torch

from utils.helpers import NeuromorphicExperimentManager, create_experiment_configs
from utils.train   import run_training

def safe_get_final_metric(metrics, key, default=0.0):
    """Return the last recorded value of ``metrics[key]`` as a float, or ``default``.

    Args:
        metrics: Mapping from metric name to a per-epoch history (list or tuple).
            ``None`` is treated as "no metrics recorded".
        key: Metric name to look up.
        default: Value returned when ``metrics`` is ``None``, the key is absent,
            or the stored history is not a non-empty list/tuple.

    Returns:
        The final history entry converted to ``float``. Scalar wrappers that
        expose ``.item()`` (e.g. numpy / torch scalars) are unwrapped first, so
        the result is always a plain Python float when a value is found.
    """
    if metrics is None or key not in metrics:
        return default
    history = metrics[key]
    # Only sequence histories are meaningful here; anything else (a bare scalar,
    # a string, a dict) falls back to the default instead of being misread.
    if not isinstance(history, (list, tuple)) or not history:
        return default
    value = history[-1]
    if hasattr(value, 'item'):
        # numpy/torch scalars: .item() may yield int/bool, so normalise below.
        value = value.item()
    return float(value)

# Configuration keys run by ``--experiment comparison``.
_COMPARISON_CONFIGS = ("rate_fast", "rate_standard", "temporal_standard")
# Configuration keys run by ``--experiment all`` (the comparison set plus staged
# training); also the valid values for ``--config``.
_ALL_CONFIGS = _COMPARISON_CONFIGS + ("staged_training",)


def _build_arg_parser():
    """Build the command-line parser for the experiment runner."""
    parser = argparse.ArgumentParser(
        description="Neuromorphic Communication Experiments"
    )
    parser.add_argument(
        "--pretrained-commsmod", type=str,
        default="fashion_mnist_improved_snn.pth",
        help="Path to pretrained CommsMod model file"
    )
    parser.add_argument(
        "--experiment", type=str,
        choices=["single", "comparison", "all"],
        default="single",
        help="Which experiment to run"
    )
    parser.add_argument(
        "--config", type=str,
        choices=list(_ALL_CONFIGS),
        default="rate_standard",
        help="Configuration key (for a single experiment)"
    )
    parser.add_argument(
        "--output-dir", type=str, default="experiments",
        help="Base directory under which to store all experiment outputs"
    )
    parser.add_argument(
        "--device", type=str, default="auto",
        help="torch device: 'cuda', 'cpu', or 'auto'"
    )
    return parser


def _resolve_device(spec):
    """Map a ``--device`` string to a ``torch.device``; 'auto' prefers cuda:0."""
    if spec == "auto":
        return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    return torch.device(spec)


def _final_comm_success(metrics):
    """Return the final communication-success percentage from ``metrics``.

    Tries 'communication_success', then 'avg_communication_success', and finally
    'protocol_discriminability' scaled by 100. The first non-zero value wins;
    0.0 is returned when none of them is non-zero.
    """
    candidates = (
        ("communication_success", 1.0),
        ("avg_communication_success", 1.0),
        ("protocol_discriminability", 100.0),
    )
    for key, scale in candidates:
        value = safe_get_final_metric(metrics, key, 0.0) * scale
        if value != 0.0:
            return value
    return 0.0


def _run_single(manager, configs, name, pretrained_path):
    """Run one named configuration and print its final accuracy and comm success."""
    metrics, exp_path = manager.run_experiment(
        name=name,
        train_fn=run_training,
        config=configs[name],
        pretrained_path=pretrained_path,
        plot_fn=None,    # run_training handles its own plotting every 5 epochs
    )

    final_accuracy = safe_get_final_metric(metrics, 'accuracy', 0.0)
    final_comm_success = _final_comm_success(metrics)

    print(f"\n Experiment '{name}' done.")
    print(f"   Final accuracy:           {final_accuracy:.1f}%")
    print(f"   Final communication succ: {final_comm_success:.1f}%")
    print(f"   Outputs in:               {exp_path}")


def _run_suite(manager, configs, names, pretrained_path, title):
    """Run several configurations via ``manager.run_multiple`` and print accuracies."""
    experiments = [(name, configs[name]) for name in names]
    all_metrics = manager.run_multiple(
        experiments=experiments,
        pretrained_path=pretrained_path,
        train_fn=run_training,
        plot_fn=None,
    )
    print(f"\n {title}:")
    for (name, _), metrics in zip(experiments, all_metrics):
        # Same tolerant lookup as the single-experiment path, so a run with no
        # recorded accuracy prints 0.0 instead of aborting the whole report.
        accuracy = safe_get_final_metric(metrics, 'accuracy', 0.0)
        print(f"  {name}: {accuracy:.1f}%")


def main():
    """Parse CLI arguments, run the selected experiment(s), and report results.

    Returns:
        Process exit status: 0 on success, 1 if the pretrained CommsMod file is
        missing, 130 if the run was interrupted with Ctrl-C.

    Raises:
        Exception: any error raised while running experiments is reported on
        stdout and re-raised unchanged.
    """
    args = _build_arg_parser().parse_args()
    device = _resolve_device(args.device)
    # NOTE(review): `device` is only reported below; it is not passed to the
    # experiment manager or run_training, so --device currently does not
    # affect training. Confirm how those helpers choose their device.

    print("NEUROMORPHIC AGENT COMMUNICATION EXPERIMENTS")
    print("=" * 60)
    print(f" Device:             {device}")
    print(f" Output base dir:    {args.output_dir}")
    print(f" Pretrained CommsMod: {args.pretrained_commsmod}\n")

    # isfile (not exists) so a directory path is rejected up front.
    if not os.path.isfile(args.pretrained_commsmod):
        print(f"❌ Error: could not find pretrained CommsMod at '{args.pretrained_commsmod}'")
        return 1

    manager = NeuromorphicExperimentManager(base_dir=args.output_dir)
    configs = create_experiment_configs()

    try:
        if args.experiment == "single":
            _run_single(manager, configs, args.config, args.pretrained_commsmod)
        elif args.experiment == "comparison":
            _run_suite(manager, configs, _COMPARISON_CONFIGS,
                       args.pretrained_commsmod, "Comparison results")
        else:  # args.experiment == "all"
            _run_suite(manager, configs, _ALL_CONFIGS,
                       args.pretrained_commsmod, "Full suite results")
    except KeyboardInterrupt:
        print("\n⚠️  Experiment interrupted by user")
        print(f"\n Partial outputs under: {manager.experiment_dir}")
        # 128 + SIGINT: conventional status for a Ctrl-C'd process.
        return 130
    except Exception as e:
        print(f"\n❌ Experiment failed: {e}")
        # Re-raising lets the interpreter print the traceback (exactly once)
        # and exit with a non-zero status.
        raise

    print(f"\n All outputs under: {manager.experiment_dir}")
    print(" Neuromorphic experiments completed!")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())