#!/usr/bin/env python3
"""
CHT-CNN Training Preset Manager

This script manages preset configurations for training experiments, simplifying command line usage.
Users only need to specify preset name and a few variable parameters to run experiments.
Additional parameters can be overridden using the --override flag (e.g. --override num_epochs=50).
"""

import argparse
import shlex
import subprocess
import sys
from typing import Dict, Any


# Base configurations
#
# Every base shares batch_size=128 and num_processes=3 and differs only in
# model, dataset, epoch budget and learning rate.  Key order matters: presets
# are turned into command-line options in dictionary order.


def _base_config(model_type: str, dataset: str, num_epochs: int,
                 learning_rate: float) -> Dict[str, Any]:
    """Return a fresh base configuration with keys in the canonical order."""
    return {
        'model_type': model_type,
        'dataset': dataset,
        'num_epochs': num_epochs,
        'batch_size': 128,
        'num_processes': 3,
        'learning_rate': learning_rate,
    }


VGG_CIFAR10_BASE = _base_config('VGG16_CIFAR_BN', 'CIFAR10', 100, 0.01)
VGG_CIFAR100_BASE = _base_config('VGG16_CIFAR_BN', 'CIFAR100', 240, 0.1)
RESNET50_CIFAR10_BASE = _base_config('ResNet50_CIFAR', 'CIFAR10', 240, 0.1)
RESNET50_CIFAR100_BASE = _base_config('ResNet50_CIFAR', 'CIFAR100', 240, 0.1)

# TinyImageNet base configurations
VGG_TINY_BASE = _base_config('VGG16_CIFAR_BN', 'TINY', 240, 0.1)
RESNET50_TINY_BASE = _base_config('ResNet50_CIFAR', 'TINY', 240, 0.1)


# Method-specific settings layered on top of a base configuration.  Every
# preset is `{**<BASE>, **<METHOD>}`: the six base keys come first, followed
# by the method keys in exactly the order written here.
_DENSE = dict(sparsity=0.0)

_SET = dict(
    sparsity=0.5,
    link_update_ratio=0.3,
    remove_method='wm',
    regrow_method='rand',
    shared_mask_sw=True,
)

# CHTs#4 -- the VGG and ResNet-50 variants differ in zone_sz (1 vs 16) and in
# where zone_sz sits in the key order.
_CHTS4_VGG = dict(
    sparsity=0.5, link_update_ratio=0.3, remove_method='wm', regrow_method='L3n',
    zone_sz=1, shared_mask_zone=True, avg_remove=True, avg_regrow=True,
    use_opt4=True, soft=True,
)
_CHTS4_RESNET = dict(
    sparsity=0.5, link_update_ratio=0.3, remove_method='wm', regrow_method='L3n',
    shared_mask_zone=True, avg_remove=True, avg_regrow=True, use_opt4=True,
    soft=True, zone_sz=16,
)

# Shared_mask_sw CHT4 (no soft flag); same VGG/ResNet-50 split as above.
_SCHT4_VGG = dict(
    sparsity=0.5, link_update_ratio=0.3, remove_method='wm', regrow_method='L3n',
    shared_mask_sw=True, zone_sz=1, shared_mask_zone=True, avg_remove=True,
    avg_regrow=True, use_opt4=True,
)
_SCHT4_RESNET = dict(
    sparsity=0.5, link_update_ratio=0.3, remove_method='wm', regrow_method='L3n',
    shared_mask_sw=True, shared_mask_zone=True, avg_remove=True, avg_regrow=True,
    use_opt4=True, zone_sz=16,
)

# Disabled variants (formerly kept as commented-out presets).  To re-enable
# one, add a PRESET_CONFIGS entry built the same way as those below.
#   SPARSITY_70              : sparsity=0.7, link_update_ratio=0.5
#   vgg-cifar{10,100}-static : sparsity=0.5, link_update_ratio=0.
#   vgg-cifar10-cht1         : sparsity=0.5, link_update_ratio=0.3,
#                              remove_method='wm', regrow_method='L3n',
#                              zone_sz=1, shared_mask_zone=True
#   vgg-cifar10-cht2         : cht1 + avg_regrow=True
#   vgg-cifar10-cht3         : cht1 + avg_remove=True, avg_regrow=True
#   vgg-cifar{10,100}-cht4   : cht3 + use_opt4=True
#   vgg-cifar10-chts{1,2,3}  : cht{1,2,3} + soft=True

# Preset configurations dictionary (insertion order is the listing order)
PRESET_CONFIGS = {
    # -----Dense networks-----
    'vgg-cifar10-dense': {**VGG_CIFAR10_BASE, **_DENSE},
    'vgg-cifar100-dense': {**VGG_CIFAR100_BASE, **_DENSE},
    'resnet50-cifar10-dense': {**RESNET50_CIFAR10_BASE, **_DENSE},
    'resnet50-cifar100-dense': {**RESNET50_CIFAR100_BASE, **_DENSE},
    'vgg-tiny-dense': {**VGG_TINY_BASE, **_DENSE},
    'resnet50-tiny-dense': {**RESNET50_TINY_BASE, **_DENSE},

    # -----Sparse networks with SET-----
    'vgg-cifar10-set': {**VGG_CIFAR10_BASE, **_SET},
    'vgg-cifar100-set': {**VGG_CIFAR100_BASE, **_SET},
    'resnet50-cifar10-set': {**RESNET50_CIFAR10_BASE, **_SET},
    'resnet50-cifar100-set': {**RESNET50_CIFAR100_BASE, **_SET},
    'vgg-tiny-set': {**VGG_TINY_BASE, **_SET},
    'resnet50-tiny-set': {**RESNET50_TINY_BASE, **_SET},

    # -----CHTs#4-----
    'vgg-cifar10-chts4': {**VGG_CIFAR10_BASE, **_CHTS4_VGG},
    'vgg-cifar100-chts4': {**VGG_CIFAR100_BASE, **_CHTS4_VGG},
    'resnet50-cifar100-chts4': {**RESNET50_CIFAR100_BASE, **_CHTS4_RESNET},
    'vgg-tiny-chts4': {**VGG_TINY_BASE, **_CHTS4_VGG},
    'resnet50-tiny-chts4': {**RESNET50_TINY_BASE, **_CHTS4_RESNET},

    # -----Shared_mask_sw CHT4-----
    'vgg-cifar10-scht4': {**VGG_CIFAR10_BASE, **_SCHT4_VGG},
    'vgg-cifar100-scht4': {**VGG_CIFAR100_BASE, **_SCHT4_VGG},
    'resnet50-cifar10-scht4': {**RESNET50_CIFAR10_BASE, **_SCHT4_RESNET},
    'resnet50-cifar100-scht4': {**RESNET50_CIFAR100_BASE, **_SCHT4_RESNET},
    'vgg-tiny-scht4': {**VGG_TINY_BASE, **_SCHT4_VGG},
    'resnet50-tiny-scht4': {**RESNET50_TINY_BASE, **_SCHT4_RESNET},
}


def get_preset_config(preset_name: str) -> Dict[str, Any]:
    """
    Look up a preset by name and return a copy of it

    Args:
        preset_name: Key into PRESET_CONFIGS

    Returns:
        A shallow copy of the preset dictionary; callers may add, change or
        remove top-level keys without touching PRESET_CONFIGS itself

    Raises:
        ValueError: If preset_name is unknown; the message lists every
            available preset name
    """
    known_presets = PRESET_CONFIGS
    if preset_name in known_presets:
        return dict(known_presets[preset_name])

    choices = ', '.join(known_presets)
    raise ValueError(f'Preset "{preset_name}" not found. Available presets: {choices}')


def get_valid_parameters() -> set:
    """
    Return the parameter names accepted by --override

    The list is hard-coded here rather than read from any parser; it
    presumably mirrors the options of experiments.py and must be kept in sync
    by hand (TODO confirm against experiments.py).

    Returns:
        A new set of parameter names on every call, so callers may mutate it
    """
    # Method / sparsity options used by the presets and their variants
    method_options = (
        'sparsity', 'mlp_sparsity', 'link_update_ratio',
        'remove_method', 'regrow_method',
        'shared_mask_sw', 'shared_mask_zone', 'zone_sz',
        'avg_remove', 'avg_regrow', 'soft', 'use_opt4',
        'delta', 'delta_max', 'delta_d', 'ch_method', 'use_hidden',
    )
    # Keys that appear in the *_BASE configurations, plus the random seed
    base_options = (
        'model_type', 'dataset', 'num_epochs', 'learning_rate',
        'batch_size', 'num_processes', 'seed',
    )
    # Run-level options (gpus and tag are also this launcher's own CLI flags)
    run_options = ('gpus', 'checkpoint', 'tag')

    return set(method_options + base_options + run_options)


def validate_override_parameters(overrides: Dict[str, Any]) -> None:
    """
    Reject override keys that are not known parameter names

    Args:
        overrides: Parsed overrides; only the keys are inspected

    Raises:
        ValueError: If any key is not in get_valid_parameters(); the message
            lists every offending key in the order it appears in overrides
    """
    accepted = get_valid_parameters()
    rejected = [name for name in overrides if name not in accepted]
    if rejected:
        raise ValueError(f'Invalid override parameters: {", ".join(rejected)}')


def _parse_override_value(raw: str) -> Any:
    """
    Convert the text after '=' in a key=value override into a Python value

    'true' / 'false' / 'none' (case-insensitive) become True / False / None.
    A literal containing '.' is tried as float, anything else as int. Text
    that does not parse is returned unchanged.

    Integer-looking literals with a leading zero (e.g. the GPU list "0123" or
    a tag such as "007") are returned as strings: int() would silently drop
    the leading zeros, and the value forwarded to experiments.py would change
    meaning ("0123" -> "123" loses GPU 0).
    """
    lowered = raw.lower()
    if lowered == 'true':
        return True
    if lowered == 'false':
        return False
    if lowered == 'none':
        return None

    try:
        number = float(raw) if '.' in raw else int(raw)
    except ValueError:
        return raw

    digits = raw.lstrip('+-')
    if isinstance(number, int) and len(digits) > 1 and digits.startswith('0'):
        return raw
    return number


def parse_override_args(override_args: list) -> Dict[str, Any]:
    """
    Parse override arguments in format key=value

    Keys and values are stripped of surrounding whitespace. Only the first '='
    separates key from value, so a value may itself contain '='. Values are
    converted with _parse_override_value. If a key is repeated, the last
    occurrence wins.

    Args:
        override_args: List of "key=value" strings (as collected by --override)

    Returns:
        Dictionary of override parameters; a value of None asks for that
        parameter to be removed from the preset when merged

    Raises:
        ValueError: If an argument contains no '=' or names a parameter that
            is not in get_valid_parameters()
    """
    overrides = {}

    for arg in override_args:
        if '=' not in arg:
            raise ValueError(f'Invalid override format: {arg}. Use format: key=value')

        key, value = arg.split('=', 1)
        overrides[key.strip()] = _parse_override_value(value.strip())

    # Validate all override parameters
    validate_override_parameters(overrides)

    return overrides


def merge_config_with_overrides(base_config: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """
    Return a new configuration with overrides applied on top of base_config

    An override whose value is None deletes that key, and is a no-op if the key
    is absent. Any other value replaces an existing key in place or is appended
    as a new key. base_config itself is left untouched.

    Args:
        base_config: Base preset configuration
        overrides: Override parameters

    Returns:
        Merged configuration
    """
    result = base_config.copy()

    # Keys are unique, so handling every removal before the assignments gives
    # the same contents and key order as processing overrides one by one.
    for doomed in [name for name, value in overrides.items() if value is None]:
        result.pop(doomed, None)
    result.update((name, value) for name, value in overrides.items() if value is not None)

    return result


def build_command_args(preset_config: Dict[str, Any], gpus: str, tag: str) -> list:
    """
    Build the argument list that launches experiments.py

    Each configuration entry becomes a command-line option, in dictionary order:
      * True booleans become a bare '--key' flag;
      * False booleans are omitted entirely;
      * any other value becomes '--key', str(value).
    Non-empty gpus and tag are appended last as '--gpus' / '--tag'.

    The child runs under sys.executable, the interpreter running this launcher,
    instead of a bare 'python'. 'python' may be missing from PATH, may be
    Python 2, or may belong to a different virtualenv/conda environment.
    'python' is used only if sys.executable is unavailable (empty/None).
    'experiments.py' is resolved relative to the current working directory.

    Args:
        preset_config: Preset configuration dictionary
        gpus: GPU devices string (skipped when empty)
        tag: Experiment tag (skipped when empty)

    Returns:
        Command line arguments list (all strings), suitable for subprocess.run
    """
    args = [sys.executable or 'python', 'experiments.py']

    # Add parameters from preset configuration
    for key, value in preset_config.items():
        if isinstance(value, bool):
            if value:
                args.append(f'--{key}')
            continue
        args.extend([f'--{key}', str(value)])

    # Add GPU and tag parameters
    if gpus:
        args.extend(['--gpus', gpus])
    if tag:
        args.extend(['--tag', tag])

    return args


def list_presets():
    """
    Print a summary of every preset, then the names accepted by --override

    Each preset shows its model, dataset and sparsity. It also shows
    ch_method, printed as "CH3" when the preset does not set it, and soft,
    printed as False when unset.
    """
    separator = '=' * 50
    # Fields every preset is expected to define
    required_fields = (
        ('Model', 'model_type'),
        ('Dataset', 'dataset'),
        ('Sparsity', 'sparsity'),
    )
    # Fields that may be missing, with the value displayed in that case
    optional_fields = (
        ('CH method', 'ch_method', 'CH3'),
        ('Soft', 'soft', False),
    )

    print('Available preset configurations:')
    print(separator)

    for name, config in PRESET_CONFIGS.items():
        print(f'\n{name}:')
        for label, key in required_fields:
            print(f'  {label}: {config[key]}')
        for label, key, fallback in optional_fields:
            print(f'  {label}: {config.get(key, fallback)}')

    print('\n' + separator)
    print('Valid override parameters:')
    print(', '.join(sorted(get_valid_parameters())))


def create_parser() -> argparse.ArgumentParser:
    """
    Build the command-line parser for this launcher

    Arguments, in help-output order:
        preset      optional positional preset name
        --gpus      GPU devices string passed on to experiments.py (default '')
        --tag       experiment tag passed on to experiments.py (default '')
        --list      print every preset configuration and exit
        --override  one or more key=value pairs applied on top of the preset
    """
    parser = argparse.ArgumentParser(
        description='CHT-CNN Training Preset Manager',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument('preset', type=str, nargs='?',
                        help='Preset configuration name')

    # --gpus and --tag are both plain strings that default to empty
    string_options = (
        ('--gpus', 'GPU devices string like "0123"'),
        ('--tag', 'Experiment tag'),
    )
    for flag, help_text in string_options:
        parser.add_argument(flag, type=str, default='', help=help_text)

    parser.add_argument('--list', action='store_true',
                        help='List all available preset configurations')

    parser.add_argument(
        '--override', type=str, nargs='+',
        help='Override preset parameters. Format: key=value (e.g., --override num_epochs=50 sparsity=0.7)',
    )

    return parser


def main():
    """
    Command-line entry point

    Behaviour:
      * --list: print a detailed summary of every preset and return.
      * no preset name: print every preset name, one per line, and return.
      * otherwise: resolve the preset, apply --override values, echo the
        command, then run experiments.py and wait for it to finish.

    Diagnostics go to stderr, and the process exits with status 1. They cover
    an unknown preset, malformed or unknown overrides, a non-zero exit from
    the experiment, and Ctrl+C.
    """
    parser = create_parser()
    args = parser.parse_args()

    # If --list is specified, list all presets
    if args.list:
        list_presets()
        return

    # If no preset name is specified and not --list, print all preset names
    if not args.preset:
        for preset_name in PRESET_CONFIGS:
            print(preset_name)
        return

    try:
        # Get preset configuration
        preset_config = get_preset_config(args.preset)

        # Apply overrides if specified
        if args.override:
            try:
                overrides = parse_override_args(args.override)
                preset_config = merge_config_with_overrides(preset_config, overrides)
                print(f'Applied overrides: {overrides}')
            except ValueError as e:
                print(f'Error parsing override arguments: {e}', file=sys.stderr)
                sys.exit(1)

        # Build command arguments
        command_args = build_command_args(preset_config, args.gpus, args.tag)

        # Shell-quote each argument so the echoed line can be pasted into a
        # shell as-is, even when a value (tag, interpreter path) has spaces.
        print('Command to be executed:')
        print(' '.join(shlex.quote(arg) for arg in command_args))
        print()

        print(f'Starting experiment with preset "{args.preset}"...')
        print('=' * 60)

        # The child writes straight to the inherited stdout. Flush our buffer
        # first, otherwise (e.g. when redirected to a log file) the lines above
        # would appear after the experiment's own output.
        sys.stdout.flush()

        # check=True turns a non-zero exit status into CalledProcessError
        subprocess.run(command_args, check=True)

        print('=' * 60)
        print('Experiment completed!')

    except ValueError as e:
        print(f'Error: {e}', file=sys.stderr)
        sys.exit(1)
    except subprocess.CalledProcessError as e:
        print(f'Experiment failed: {e}', file=sys.stderr)
        sys.exit(1)
    except KeyboardInterrupt:
        print('\nExperiment interrupted by user', file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    # Only run the CLI when executed as a script, not when imported.
    main()
