"""
Optimizer configurations for parameter sweeps and experiments
"""
import torch.optim as optim
import itertools
import os
import sys

# Resolve the directory holding this file and its parent directory, then make
# the parent importable. This must run before the imports that follow so that
# modules living one level up (presumably HomOpt -- confirm against the repo
# layout) resolve even when this file is launched directly as a script.
script_dir = os.path.dirname(
    os.path.abspath(__file__)
)
project_root = os.path.abspath(
    os.path.join(script_dir, '..')
)
# Insert at index 0 so the project's modules take precedence over any
# same-named installed packages; skip it if the path is already listed.
if project_root not in sys.path:
    sys.path.insert(0, project_root)


from lion_pytorch import Lion
from HomOpt import HomM


# --- 5. Configuration function with manual LR setting ---
def get_optimizer_configurations(*, verbose: bool = True):
    """Return the hyperparameter search grid for every optimizer in the sweep.

    Learning rates are set by hand below rather than derived automatically;
    edit the ``'lr'`` lists (e.g. after an LR range test) to change them.

    Args:
        verbose: If True (the default, matching the original behaviour),
            print each optimizer's configured learning-rate list followed by
            a blank line. Pass False to build the grid silently.

    Returns:
        A dict mapping an optimizer label (e.g. ``'SGD_Nesterov'``) to a dict
        with two keys:

        * ``'class'``: the optimizer class to instantiate.
        * ``'params'``: a dict mapping each constructor keyword to the list
          of candidate values to sweep over.

        A new dict is built on every call, so callers may mutate it freely.
    """
    # MANUALLY SET YOUR LEARNING RATES HERE
    # Modify these lists based on your LR range test results or preferences
    configs = {
        'SGD': {
            'class': optim.SGD,
            'params': {
                'lr': [0.02, 0.03, 0.05],  # <-- Set SGD learning rates here
                'momentum': [0.9, 0.95],
                'nesterov': [False],
            },
        },
        'SGD_Nesterov': {
            'class': optim.SGD,
            'params': {
                'lr': [0.005, 0.01, 0.02],  # <-- Set SGD Nesterov learning rates here
                # Nesterov in torch.optim.SGD requires a non-zero momentum.
                'momentum': [0.9, 0.95],
                'nesterov': [True],
            },
        },
        'Adam': {
            'class': optim.Adam,
            'params': {
                'lr': [0.0001, 0.0005, 0.001],  # <-- Set Adam learning rates here
                'betas': [(0.9, 0.999), (0.9, 0.99), (0.95, 0.999)],
            },
        },
        'Lion': {
            'class': Lion,
            'params': {
                # Listed largest-first; kept in this order so the sweep
                # sequence is unchanged.
                'lr': [0.0001, 0.00005, 0.00001],  # <-- Set Lion learning rates here
                'betas': [(0.9, 0.99), (0.95, 0.99), (0.9, 0.999)],
            },
        },
        'HomM': {
            'class': HomM,
            'params': {
                'lr': [0.005, 0.01, 0.02],  # <-- Set HomM learning rates here
                'alpha': [-0.75, -0.5, -0.25],
                'beta': [0.1, 0.3, 0.5, 0.7, 0.9],
                'gamma': [0.9, 0.95, 0.99],
            },
        },
    }

    if verbose:
        print("Using manually configured learning rates:")
        for opt_name, config in configs.items():
            print(f"  {opt_name}: {config['params']['lr']}")
        print()

    return configs

def generate_param_combinations(config):
    """Expand an optimizer config's parameter grid into individual settings.

    ``config['params']`` maps each parameter name to an iterable of candidate
    values. The result is the Cartesian product of those candidates, one dict
    per combination, ordered like ``itertools.product``: the last parameter
    varies fastest. An empty grid produces a single empty dict.
    """
    grid = config['params']
    names = tuple(grid)
    return [
        dict(zip(names, values))
        for values in itertools.product(*grid.values())
    ]



def get_default_optimizer_params():
    """Return one fixed hyperparameter setting per optimizer.

    This is meant for quick single-run experiments instead of a full sweep.
    The keys are the same optimizer labels used elsewhere in this module.
    Each value is a dict of constructor keyword arguments.
    A fresh dict is returned on every call.
    """
    return dict(
        SGD=dict(lr=0.01, momentum=0.9, nesterov=False),
        SGD_Nesterov=dict(lr=0.01, momentum=0.9, nesterov=True),
        Adam=dict(lr=0.001, betas=(0.9, 0.999)),
        Lion=dict(lr=0.0001, betas=(0.9, 0.99)),
        HomM=dict(lr=0.005, alpha=-0.5, beta=0.5, gamma=0.9),
    )


