"""
Generate optimizer configurations from LR test results
"""
import json
import os
import itertools

def extract_all_optimal_lrs(results_dir):
    """Extract optimal learning rates from ALL combination LR test results.

    Reads ``lr_range_test_summary.json`` inside *results_dir*. It keeps every
    entry of its ``results`` list that has a non-null ``recommended_lr`` and
    the ``config_name``/``alpha``/``beta``/``gamma`` fields that code
    generation needs. Incomplete entries are reported and skipped.

    Args:
        results_dir: Directory containing ``lr_range_test_summary.json``.

    Returns:
        A tuple ``(optimal_lrs, successful_configs)``:

        - ``optimal_lrs`` maps ``config_name`` to its recommended LR.
        - ``successful_configs`` is the list of kept raw result dicts, in
          file order.

        Both are empty when the summary file does not exist.
    """
    summary_file = os.path.join(results_dir, 'lr_range_test_summary.json')

    if not os.path.exists(summary_file):
        print(f"Summary file not found: {summary_file}")
        return {}, []

    # Explicit encoding: the generated output is written as UTF-8, so read the
    # same way instead of depending on the platform default.
    with open(summary_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    optimal_lrs = {}
    successful_configs = []
    # Fields generate_optimizer_configs_code() reads from every kept result.
    required_keys = ('config_name', 'alpha', 'beta', 'gamma')

    print("Optimal learning rates found:")
    print("-" * 60)
    for result in data.get('results', []):
        lr = result.get('recommended_lr')
        # An absent key and an explicit JSON null both mean "no usable LR";
        # a null would otherwise crash the numeric format below.
        if lr is None:
            continue

        missing = [key for key in required_keys if key not in result]
        if missing:
            name = result.get('config_name', '<unnamed>')
            print(f"Skipping {name}: missing {', '.join(missing)}")
            continue

        config_name = result['config_name']
        optimal_lrs[config_name] = lr
        successful_configs.append(result)
        # '.4g' instead of '.4f' so very small LRs are not shown as 0.0000.
        print(f"{config_name:<25} LR={lr:.4g}")

    print(f"\nFound {len(successful_configs)} successful configurations")
    return optimal_lrs, successful_configs

def generate_optimizer_configs_code(successful_configs):
    """Generate the optimizer_configs.py code using all successful configurations.

    Every configuration becomes one ``HomM`` entry in the generated
    ``get_optimizer_configurations()`` dict. Each parameter is wrapped in a
    single-element list, which is the shape ``generate_param_combinations``
    expects.

    Args:
        successful_configs: Iterable of result dicts. Each must provide
            ``config_name``, ``alpha``, ``beta``, ``gamma`` and
            ``recommended_lr``, as returned by ``extract_all_optimal_lrs``.

    Returns:
        The complete Python source of the generated module as a string. The
        generated module imports ``torch.optim`` and ``HomOpt.HomM``, so it
        only runs where those are importable.
    """
    config_lines = []

    for config in successful_configs:
        # repr() emits valid, fully escaped Python literals:
        # - Names or values containing quotes no longer break the generated
        #   file.
        # - The learning rate keeps full precision. The former ':.6f' turned
        #   any LR below 5e-7 into 0.000000.
        config_name = repr(config['config_name'])
        lr = repr(float(config['recommended_lr']))
        alpha = repr(config['alpha'])
        beta = repr(config['beta'])
        gamma = repr(config['gamma'])

        config_code = f"""        {config_name}: {{
            'class': HomM,
            'params': {{
                'lr': [{lr}],
                'alpha': [{alpha}],
                'beta': [{beta}],
                'gamma': [{gamma}]
            }}
        }}"""

        config_lines.append(config_code)

    # Joined outside the f-string below. Before Python 3.12 a backslash (as in
    # "\n") is not allowed inside an f-string replacement field.
    configs_joined = ",\n".join(config_lines)

    # Generate the full file content
    full_code = f"""\"\"\"
Optimizer configurations with optimal learning rates from LR range tests
Auto-generated from LR test results
\"\"\"
import torch.optim as optim
import itertools
from HomOpt import HomM

def get_optimizer_configurations():
    \"\"\"Define all HomM configurations with their optimal learning rates\"\"\"
    
    configs = {{
{configs_joined}
    }}
    
    print(f"Loaded {{len(configs)}} HomM configurations with optimal learning rates:")
    for opt_name, config in configs.items():
        params = config['params']
        lr = params['lr'][0]
        alpha = params['alpha'][0]
        beta = params['beta'][0]
        gamma = params['gamma'][0]
        print(f"  {{opt_name}}: LR={{lr:.4g}}, a={{alpha}}, b={{beta}}, c={{gamma}}")
    print()
    
    return configs

def generate_param_combinations(config):
    \"\"\"Generate all combinations of parameters for a given optimizer config\"\"\"
    param_names = list(config['params'].keys())
    param_values = list(config['params'].values())
    
    combinations = []
    for combo in itertools.product(*param_values):
        param_dict = dict(zip(param_names, combo))
        combinations.append(param_dict)
    
    return combinations

def get_best_configs(top_n=5):
    \"\"\"Get the top N configurations by learning rate (higher LR often indicates better optimization)\"\"\"
    configs = get_optimizer_configurations()
    
    # Sort by learning rate (descending)
    sorted_configs = sorted(configs.items(), 
                          key=lambda x: x[1]['params']['lr'][0], 
                          reverse=True)
    
    top_configs = dict(sorted_configs[:top_n])
    
    print(f"Top {{top_n}} configurations by learning rate:")
    for name, config in top_configs.items():
        lr = config['params']['lr'][0]
        print(f"  {{name}}: LR={{lr:.4g}}")
    
    return top_configs
"""

    return full_code

def main():
    """Main function to process LR test results and generate configs.

    Steps:

    1. Search the current working directory for *directories* named
       ``lr_range_tests_all_combos_*`` and take the last one in lexicographic
       order. That is the most recent one only if the name suffix sorts
       chronologically, e.g. a zero-padded timestamp (TODO: confirm against
       the LR test script).
    2. Extract its successful configurations.
    3. Write ``optimizer_configs_auto_generated.py`` (UTF-8).
    4. Print min/max/mean of the recommended learning rates.
    """
    prefix = 'lr_range_tests_all_combos_'
    # Require an actual directory: a file sharing the prefix (e.g. an archive
    # of an old run) would otherwise be chosen and then fail with
    # "Summary file not found".
    result_dirs = [
        entry for entry in os.listdir('.')
        if entry.startswith(prefix) and os.path.isdir(entry)
    ]

    if not result_dirs:
        print("No LR test results found!")
        print("Run your LR test script first.")
        return

    # Lexicographically last name == sorted(result_dirs)[-1].
    results_dir = max(result_dirs)
    print(f"Using results from: {results_dir}")

    # Only the full result dicts are needed; the name->LR map is unused here.
    _, successful_configs = extract_all_optimal_lrs(results_dir)

    if not successful_configs:
        print("No successful configurations found!")
        return

    # Generate optimizer configs code
    config_code = generate_optimizer_configs_code(successful_configs)

    # Save to file with UTF-8 encoding to prevent character corruption
    output_file = "optimizer_configs_auto_generated.py"
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(config_code)

    print(f"\nGenerated optimizer configurations saved to: {output_file}")
    print(f"Total configurations: {len(successful_configs)}")

    # Summary statistics. '.6g' rather than '.6f' so learning rates below
    # 5e-7 are not displayed as 0.000000.
    lrs = [config['recommended_lr'] for config in successful_configs]
    print("\nLearning rate statistics:")
    print(f"  Min LR: {min(lrs):.6g}")
    print(f"  Max LR: {max(lrs):.6g}")
    print(f"  Mean LR: {sum(lrs) / len(lrs):.6g}")

if __name__ == "__main__":
    # Run the generator only when executed as a script, not when imported.
    main()