import os
import time
import argparse
import torch
import subprocess
import yaml

# Per-dataset settings: number of classes, input image shape (channels-first,
# judging by the per-channel mean/std lengths) and normalization mean/std.
_DATASET_SPECS = {
    # Commented for future expansion:
    # 'mnist': dict(num_classes=10, img_shape=(1, 28, 28),
    #               mean=(0.1307,), std=(0.3081,)),
    'cifar10': dict(num_classes=10, img_shape=(3, 32, 32),
                    mean=(0.4914, 0.4822, 0.4465),
                    std=(0.2023, 0.1994, 0.2010)),
    'cifar100': dict(num_classes=100, img_shape=(3, 32, 32),
                     mean=(0.5071, 0.4867, 0.4408),
                     std=(0.2675, 0.2565, 0.2761)),
    'imagenet': dict(num_classes=200,  # tiny-imagenet-200
                     img_shape=(3, 64, 64),
                     mean=(0.485, 0.456, 0.406),
                     std=(0.229, 0.224, 0.225)),
    # Commented for future expansion:
    # 'svhn': dict(num_classes=10, img_shape=(3, 32, 32),
    #              mean=(0.4376821, 0.4437697, 0.47280442),
    #              std=(0.19803012, 0.20101562, 0.19703614)),
    # 'fmnist': dict(num_classes=10, img_shape=(1, 28, 28),
    #                mean=(0.2860,), std=(0.3530,)),
}

# Options declared on the command line as 'true'/'false' strings; they are
# converted to real booleans after parsing.
# Commented for future expansion: 'anderson_block_wise', 'anderson_safeguard',
# 'broyden_line_search', 'broyden_block_wise'.
_STR_BOOL_ARGS = (
    'vanilla_early_stopping',
    'timing_enabled',
    'diagnostics_enabled',
    'norm_clip',
    'norm_learnable_scale',
    'use_deq_solver',  # legacy support
)


def _build_parser():
    """Create the ArgumentParser holding every supported command-line option."""
    parser = argparse.ArgumentParser(description='Deep Predictive Coding Network')

    # Device
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='Device to use for training')

    # Model - New Args System
    parser.add_argument('--energy_option', type=str, default='vanilla_pc', choices=['vanilla_pc', 'meta_pc'],
                        help='Energy option: vanilla_pc (default) or meta_pc (used for success reproduction)')

    # Marker
    parser.add_argument('--marker', type=str, default=None,
                        help='Marker to use for training')

    # Model settings
    parser.add_argument('--z_init', type=str, default='ff', choices=['ff', 'zero', 'random'],
                        help='Z initialization method')
    parser.add_argument('--update_rule', type=str, default='pcn', choices=['pcn', 'bp'],
                        help='Update rule to use')
    parser.add_argument('--backbone', type=str, default='vgg13',
                        help='Backbone architecture')

    # Training hyperparameters
    parser.add_argument('--eta', type=float, default=0.2,
                        help='Learning rate for predictive coding updates')
    parser.add_argument('--T', type=int, default=20,
                        help='Number of inference iterations')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=50,
                        help='Number of training epochs')
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='Learning rate for optimizer')
    parser.add_argument('--weight_decay', type=float, default=5e-4,
                        help='Weight decay for optimizer')
    parser.add_argument('--optimizer', type=str, default='adamw', choices=['adam', 'adamw', 'sgd'],
                        help='Optimizer type')
    parser.add_argument('--grad_clip_norm', type=float, default=0.0,
                        help='Gradient clipping max norm (0 = no clipping)')

    # Dataset settings
    parser.add_argument('--data_dir', type=str, default='./datasets',
                        help='Directory containing datasets')
    parser.add_argument('--dataset', type=str, default='cifar10',
                        choices=['cifar10', 'cifar100', 'imagenet'],  # 'mnist', 'svhn', 'fmnist' - commented for future expansion
                        help='Dataset to use (currently: cifar10, cifar100, imagenet)')
    parser.add_argument('--num_workers', type=int, default=8,
                        help='Number of data loader workers')

    # === Unified Solver System ===
    parser.add_argument('--solver_type', type=str, default='vanilla',
                        choices=['vanilla'],  # 'anderson', 'broyden' - commented for future expansion
                        help='Type of solver to use for latent inference (currently: vanilla only)')

    # === Update Rule Options ===
    parser.add_argument('--update_latent_rule', type=str, default='jacobi',
                        choices=['jacobi', 'gs', 'block_sweep_gs'],
                        help='Latent update rule')
    # NOTE(review): the default equals update_latent_rule's default, so this
    # function cannot tell whether --loop_scheduler was given explicitly;
    # the "overrides" behavior lives with whoever consumes these fields.
    parser.add_argument('--loop_scheduler', type=str, default='jacobi',
                        choices=['jacobi', 'gs', 'block_sweep_gs', 'gauss_seidel', 'block_sweep'],
                        help='Loop scheduler for latent updates (overrides update_latent_rule if specified)')

    # Update rule specific parameters
    parser.add_argument('--custom_layer_order', type=str, default=None,
                        help='Comma-separated layer order for custom_order (e.g., "1,3,2,4")')

    # === Common Solver Parameters ===
    parser.add_argument('--deq_max_iter', type=int, default=50,
                        help='Maximum iterations for advanced solvers')
    parser.add_argument('--deq_tol', type=float, default=1e-4,
                        help='Convergence tolerance for advanced solvers')

    # === Anderson Acceleration Parameters === (Commented for future expansion)
    # parser.add_argument('--anderson_m', type=int, default=6,
    #                     help='Anderson acceleration memory size')
    # parser.add_argument('--anderson_lam', type=float, default=1e-4,
    #                     help='Anderson regularization parameter')
    # parser.add_argument('--anderson_tau', type=float, default=1.0,
    #                     help='Anderson damping factor')
    # parser.add_argument('--anderson_block_wise', type=str, default='true',
    #                     choices=['true', 'false'], help='Use block-wise Anderson')
    # parser.add_argument('--anderson_safeguard', type=str, default='true',
    #                     choices=['true', 'false'], help='Use Anderson safeguards')

    # === Broyden Method Parameters === (Commented for future expansion)
    # parser.add_argument('--broyden_memory_size', type=int, default=None,
    #                     help='Broyden memory size (None for auto)')
    # parser.add_argument('--broyden_line_search', type=str, default='false',
    #                     choices=['true', 'false'], help='Use line search in Broyden')
    # parser.add_argument('--broyden_block_wise', type=str, default='true',
    #                     choices=['true', 'false'], help='Use block-wise Broyden')
    # parser.add_argument('--broyden_regularization', type=float, default=1e-6,
    #                     help='Broyden regularization parameter')

    # === Vanilla PCN Parameters ===
    parser.add_argument('--vanilla_early_stopping', type=str, default='false',
                        choices=['true', 'false'], help='Enable early stopping for vanilla PCN')

    # === Legacy Support ===
    parser.add_argument('--use_deq_solver', type=str, default='false', choices=['true', 'false'],
                        help='[Legacy] Use DEQ solvers')
    # parser.add_argument('--deq_solver_type', type=str, default='anderson',
    #                     choices=['anderson', 'broyden'], help='[Legacy] DEQ solver type')

    # === Diagnostics and Performance ===
    parser.add_argument('--timing_enabled', type=str, default='false',
                        choices=['true', 'false'], help='Enable timing measurements')
    parser.add_argument('--diagnostics_enabled', type=str, default='false',
                        choices=['true', 'false'], help='Enable solver diagnostics')

    # === Unified Weight Normalization ===
    parser.add_argument('--norm_type', type=str, default='none',
                        choices=['none', 'spectral', 'frobenius', 'variance'],
                        help='Type of weight normalization: spectral, frobenius, variance, or none')
    parser.add_argument('--norm_clip', type=str, default='false',
                        choices=['true', 'false'],
                        help='Enable clipping in normalization')
    parser.add_argument('--norm_clip_value', type=float, default=1.0,
                        help='Clipping value for normalization (max scale factor)')
    parser.add_argument('--norm_learnable_scale', type=str, default='false',
                        choices=['true', 'false'],
                        help='Use learnable g parameter (true) or fixed target_norm (false)')
    parser.add_argument('--norm_target_norm', type=float, default=0.9,
                        help='Target norm value when learnable_scale=false')
    parser.add_argument('--norm_filter_out', type=str, default=None,
                        help='Module names to skip normalization (comma-separated)')

    # === Logging and Performance Settings ===
    parser.add_argument('--train_only', action='store_true',
                        help='Skip plotting and extra computations for faster training')
    parser.add_argument('--log_interval', type=int, default=1,
                        help='Unified logging interval for train/test evaluation (every N epochs)')
    parser.add_argument('--save_detailed_metrics', action='store_true',
                        help='Save detailed experiment metrics to JSON')
    parser.add_argument('--experiment_name', type=str, default=None,
                        help='Experiment name for detailed logging')
    parser.add_argument('--track_batch_metrics', action='store_true',
                        help='Enable batch-level performance tracking in detailed metrics')

    return parser


def _postprocess(args):
    """Derive compatibility fields and normalize string-typed options in place."""
    # Convert energy_option to boolean for backward compatibility;
    # meta_pc also selects the 'pred_freeze' parameter update rule.
    args.use_pred = args.energy_option == 'meta_pc'
    args.update_param_rule = 'pred_freeze' if args.use_pred else 'default'

    # 'true'/'false' strings -> real booleans.
    for name in _STR_BOOL_ARGS:
        setattr(args, name, getattr(args, name) == 'true')

    # Legacy compatibility: map old arguments to new system (Commented for future expansion)
    # if args.use_deq_solver and args.solver_type == 'vanilla':
    #     args.solver_type = args.deq_solver_type

    # "1,3,2" -> [1, 3, 2]; a malformed value is reported and dropped (None)
    # rather than aborting the run.
    if args.custom_layer_order:
        try:
            args.custom_layer_order = [int(x.strip()) for x in args.custom_layer_order.split(',')]
        except ValueError:
            print(f"Warning: Invalid custom_layer_order format: {args.custom_layer_order}")
            args.custom_layer_order = None


def _apply_dataset_settings(args):
    """Attach num_classes, img_shape, mean and std for ``args.dataset``."""
    try:
        spec = _DATASET_SPECS[args.dataset]
    except KeyError:
        raise ValueError(f"Dataset {args.dataset} not supported") from None
    args.num_classes = spec['num_classes']
    args.img_shape = spec['img_shape']
    # Fresh lists per call so a caller mutating args.mean/args.std cannot
    # corrupt the shared table.
    args.mean = list(spec['mean'])
    args.std = list(spec['std'])


def _setup_logging(args, save_config):
    """Set ``log_dir`` and ``commit_hash``; optionally create the dir and dump args.yaml."""
    # Read the clock once so the run directory name and the commit identifier
    # always describe the same instant (two separate strftime() calls could
    # straddle a second boundary and disagree).
    started = time.localtime()

    # Logging directory - unified under workspace
    run_name = args.marker if args.marker is not None else time.strftime('%Y%m%d-%H%M%S', started)
    args.log_dir = os.path.join('workspace', 'z_log', run_name)
    if save_config:
        os.makedirs(args.log_dir, exist_ok=True)

    # Set commit hash without git operations (removed auto-commit for VastAI compatibility)
    args.commit_hash = f"experiment_{time.strftime('%Y%m%d_%H%M%S', started)}"
    print(f"Using timestamp-based commit identifier: {args.commit_hash}")

    if save_config:
        # NOTE(review): yaml.dump (full Dumper) tags tuples such as img_shape
        # as !!python/tuple, so yaml.safe_load cannot read this file back —
        # confirm readers use a full Loader.
        with open(os.path.join(args.log_dir, 'args.yaml'), 'w') as f:
            yaml.dump(args.__dict__, f)


def get_args(argv=None, *, save_config=True):
    """Parse command-line options into an ``argparse.Namespace`` for training.

    Besides plain parsing, this derives several fields:
    ``use_pred``/``update_param_rule`` from ``--energy_option``; real booleans
    from the 'true'/'false' string options; a list of ints from
    ``--custom_layer_order``; ``num_classes``/``img_shape``/``mean``/``std``
    from ``--dataset``; and ``log_dir``/``commit_hash`` (timestamp-based
    unless ``--marker`` names the run).

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.
        save_config: When true (the default), create ``log_dir`` and write
            every resolved option to ``<log_dir>/args.yaml``. Pass ``False``
            to parse without touching the filesystem (e.g. in tests).

    Returns:
        The populated namespace.

    Raises:
        ValueError: If ``args.dataset`` has no registered settings (not
            reachable from the CLI, where argparse restricts ``--dataset``).
        SystemExit: Raised by argparse on invalid options or ``--help``.
    """
    args = _build_parser().parse_args(argv)
    _postprocess(args)
    _apply_dataset_settings(args)
    _setup_logging(args, save_config)
    return args
