#!/usr/bin/env python
# run_scale_experiment.py - Script to run RNN experiments with a controlled weight-initialization variance (scale)

import numpy as np
import torch
import random
from pathlib import Path
import json
from datetime import datetime
import argparse
import sys

# Add the src directory to the Python path
sys.path.append('./src')

from cpro import CPRO, CPROConfig
from model import RichLazyControlledRNN
from train import train_epoch, evaluate, run_experiment
from utils import set_seed, convert_tensors_to_serializable, log_scale_experiment

def _build_parser():
    """Build the command-line parser for a single scale experiment.

    Defaults shown in help text are interpolated with ``%(default)s`` so the
    documented value can never drift from the actual default.
    """
    parser = argparse.ArgumentParser(description='Run a scale experiment')

    # Required arguments
    parser.add_argument('--training-mode', type=str, required=True,
                        choices=['minimal', 'balanced_16', 'balanced_32', 'balanced_48', 'maximal'],
                        help='Number of tasks used for training')
    parser.add_argument('--scale', type=float, required=True,
                        help='Scale parameter for weight initialization variance')
    parser.add_argument('--seed', type=int, required=True,
                        help='Random seed for reproducibility')

    # Optional arguments with defaults
    parser.add_argument('--hidden-size', type=int, default=128,
                        help='Size of hidden layer (default: %(default)s)')
    parser.add_argument('--optimizer', type=str, default='adamw',
                        choices=['adamw', 'sgd'],
                        help='Optimizer to use (default: %(default)s)')
    parser.add_argument('--weight-decay', type=float, default=0.01,
                        help='Weight decay parameter (default: %(default)s)')
    parser.add_argument('--learning-rate', type=float, default=0.001,
                        help='Learning rate (default: %(default)s)')
    parser.add_argument('--n-epochs', type=int, default=1000,
                        help='Number of training epochs (default: %(default)s)')
    parser.add_argument('--eval-every', type=int, default=100,
                        help='Evaluate every N epochs (default: %(default)s)')
    parser.add_argument('--early-stopping', action='store_true',
                        help='Enable early stopping')
    parser.add_argument('--results-dir', type=str, default=None,
                        help='Custom results directory (default: auto-generated)')
    parser.add_argument('--description', type=str, default="",
                        help='Experiment description for logging')
    parser.add_argument('--use-gpu', action='store_true', default=False,
                        help='Use GPU acceleration if available')
    return parser


def _dump_json(path, payload):
    """Write an already JSON-serializable ``payload`` to ``path``."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(payload, f)


def main(argv=None):
    """Run one scale experiment from command-line options and save its results.

    Parses the options, creates the results directory, seeds the RNGs via
    ``set_seed``, builds the ``CPRO`` environment, delegates to
    ``run_experiment``, writes the returned history to two JSON files in the
    results directory, and finally calls ``log_scale_experiment``.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``, which is the original behavior.

    Raises:
        SystemExit: Raised by argparse on ``--help`` or on invalid/missing
            arguments.
    """
    args = _build_parser().parse_args(argv)

    # Set up result directory
    if args.results_dir:
        results_dir = Path(args.results_dir)
    else:
        # NOTE: the timestamp has one-second resolution and mkdir() below uses
        # exist_ok=True, so two runs started within the same second share a
        # directory and the later one overwrites all_results.json. Pass
        # --results-dir explicitly when launching runs in parallel.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        results_dir = Path("./results") / f"experiment_{timestamp}"

    results_dir.mkdir(exist_ok=True, parents=True)

    # Print experiment configuration
    print("Running experiment with the following configuration:")
    print(f"  Training Mode: {args.training_mode}")
    print(f"  Hidden Size: {args.hidden_size}")
    print(f"  Scale: {args.scale}")
    print(f"  Seed: {args.seed}")
    print(f"  Optimizer: {args.optimizer}")
    print(f"  Weight Decay: {args.weight_decay}")
    print(f"  Results Directory: {results_dir}")

    # Set seed for reproducibility
    set_seed(args.seed)

    # Create environment
    env = CPRO(training_mode=args.training_mode)

    # Store configuration
    config = {
        'hidden_size': args.hidden_size,
        'scales': [args.scale],
        'seeds': [args.seed],
        'training_mode': args.training_mode,
        'optimizer': args.optimizer,
        'weight_decay': args.weight_decay,
        'learning_rate': args.learning_rate
    }

    # Run experiment
    # NOTE(review): --learning-rate is recorded in `config` above but is not
    # forwarded to run_experiment, so the value given on the command line does
    # not reach this call. Confirm run_experiment's signature before wiring it in.
    print("Starting experiment...")
    history = run_experiment(
        env=env,
        hidden_size=args.hidden_size,
        scale=args.scale,
        optimizer_type=args.optimizer,
        weight_decay=args.weight_decay,
        n_epochs=args.n_epochs,
        eval_every=args.eval_every,
        use_early_stopping=args.early_stopping,
        results_dir=results_dir,
        seed=args.seed,
        use_gpu=args.use_gpu,
        save_all_stimuli=True,
        save_initial_states=True
    )

    # Store results, keyed by scale and then seed (both as strings)
    all_results = {
        'config': config,
        'results': {
            str(args.scale): {
                str(args.seed): history
            }
        }
    }

    # Convert tensors before saving the per-run history
    serializable_history = convert_tensors_to_serializable(history)
    _dump_json(results_dir / f"results_scale{args.scale}_seed{args.seed}.json",
               serializable_history)

    # Save complete results (configuration plus history)
    serializable_results = convert_tensors_to_serializable(all_results)
    _dump_json(results_dir / "all_results.json", serializable_results)

    print("Experiment completed!")
    print(f"Results saved in: {results_dir}")

    # Log experiment, falling back to a generated description when none was given
    description = args.description if args.description else f"Scale {args.scale} experiment with {args.training_mode} training mode"
    log_scale_experiment(results_dir, description)

# Run the experiment only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
