import argparse
import copy
import os
import shutil
import subprocess
import sys
from itertools import product

import numpy as np
import yaml

# Skeleton of a generated per-run config. generate_yaml_configs deep-copies it
# for every run, so the nested dicts here are never mutated.
yaml_template = dict(
    gamma=None,
    server_params={},
    worker_params={},
)

# Command line written into each run script. The placeholders are filled in
# with str.format by generate_yaml_configs.
run_template = (
    'source setup.sh && '
    'python3 -m simulator.run_experiment '
    '--config {config_name} '
    '--save_path {dumps_path}/{experiments_name}/results/{result_name}'
)

def _split_list_params(params):
    """Split a parameter mapping into swept and fixed parameters.

    A key that ends in ``_list`` and whose value is a list is a sweep: it is
    returned under the key with the ``_list`` suffix removed. Every other
    entry is returned unchanged as a fixed parameter.

    Returns:
        A ``(swept, fixed)`` pair of dicts, each preserving input order.
    """
    swept = {}
    fixed = {}
    for name, value in params.items():
        if name.endswith('_list') and isinstance(value, list):
            swept[name[:-len('_list')]] = value
        else:
            fixed[name] = value
    return swept, fixed


def generate_yaml_configs(yaml_config, output_path, source_yaml_path):
    """Generate individual YAML configs from a master config file.

    For every entry of ``yaml_config['configs']`` one config file and one run
    script are written per combination of gamma (powers of two over
    ``step_size_range``, end exclusive), server ``*_list`` values and worker
    ``*_list`` values.

    The files are laid out under ``<output_path>/<experiment_name>/``:
    ``source_folder/`` (generated configs), ``run_folder/`` (run scripts),
    ``results/`` (empty, referenced by the run scripts), ``cmd.txt`` (the
    invoking command line) and a copy of the master YAML file.

    Args:
        yaml_config: Parsed master configuration (a dict).
        output_path: Directory under which the experiment directory is created.
        source_yaml_path: Path of the master YAML file; copied for reference.

    Returns:
        The number of generated configurations.

    Raises:
        FileExistsError: If the experiment directory already exists.
        OSError: If a directory or file cannot be created or the master YAML
            file cannot be copied.
    """
    experiment_name = yaml_config.get('experiment_name', 'experiment')
    sim_time = yaml_config.get('sim_time')
    metric_check_num = yaml_config.get('metric_check_num')
    batch_size = yaml_config.get('batch_size')
    optimizer = yaml_config.get('optimizer')
    num_workers = yaml_config.get('num_workers')
    times_to_calculate = yaml_config.get('times_to_calculate')
    times_to_communicate = yaml_config.get('times_to_communicate')

    # Algorithms whose workers also receive the step size.
    worker_gamma_algorithms = ('ringmaster_sgd_compcomm', 'local_sgd', 'subset_ring_reduce')

    # Create directory structure
    dumps_path = os.path.abspath(output_path)
    experiments_dir = os.path.join(dumps_path, experiment_name)

    # Refuse to overwrite a previous experiment.
    if os.path.exists(experiments_dir):
        raise FileExistsError(f"Experiment directory already exists: {experiments_dir}")

    os.makedirs(experiments_dir)

    source_dir = os.path.join(experiments_dir, 'source_folder')
    run_dir = os.path.join(experiments_dir, 'run_folder')
    results_dir = os.path.join(experiments_dir, 'results')
    # The parent was just created, so none of these can exist yet.
    for directory in (source_dir, run_dir, results_dir):
        os.makedirs(directory)

    # Store command that was used to generate configs
    with open(os.path.join(experiments_dir, 'cmd.txt'), 'w') as f:
        f.write(" ".join(sys.argv))

    # Store the original YAML config for reference. shutil.copy is used
    # instead of a shell `cp` so paths with spaces or shell metacharacters
    # are handled correctly and no shell is involved.
    shutil.copy(source_yaml_path, experiments_dir)
    print(f"Copied original YAML config to {experiments_dir}")

    # Initialize random number generator for reproducibility
    generator = np.random.default_rng(seed=42)

    # Generate configurations
    index = 0
    # `or []` / `or {}` below: a key present but left empty in YAML loads as None.
    for config in yaml_config.get('configs') or []:
        algorithm = config.get('algorithm')
        step_size_range = config.get('step_size_range')

        server_list_params, server_fixed_params = _split_list_params(
            config.get('server_params') or {}
        )
        worker_list_params, worker_fixed_params = _split_list_params(
            config.get('worker_params') or {}
        )

        # Gamma values are powers of two; the upper exponent is exclusive.
        gamma_values = [2**i for i in range(step_size_range[0], step_size_range[1])]

        # Cartesian product of the swept values. With no swept parameters
        # product() yields a single empty tuple, i.e. exactly one combination.
        server_param_names = list(server_list_params)
        server_param_combinations = list(product(*server_list_params.values()))
        worker_param_names = list(worker_list_params)
        worker_param_combinations = list(product(*worker_list_params.values()))

        # product() iterates with the rightmost factor fastest, which matches
        # nesting the loops as gamma -> server combination -> worker combination.
        for gamma, server_combination, worker_combination in product(
            gamma_values, server_param_combinations, worker_param_combinations
        ):
            yaml_prepared = copy.deepcopy(yaml_template)
            yaml_prepared['gamma'] = gamma
            yaml_prepared['server'] = algorithm
            yaml_prepared['sim_time'] = sim_time
            yaml_prepared['metric_check_num'] = metric_check_num
            yaml_prepared['num_workers'] = num_workers
            yaml_prepared['batch_size'] = batch_size
            yaml_prepared['optimizer'] = optimizer

            # Per-worker times. "sqrt" gives worker k (1-based) time sqrt(k);
            # otherwise a value is drawn for each worker from the given list.
            # The generator is shared across all configs, so every config gets
            # its own draw; the order of the two draws is kept stable.
            if times_to_calculate == "sqrt":
                yaml_prepared['times_to_calculate'] = np.sqrt(np.arange(1, num_workers + 1)).tolist()
            else:
                yaml_prepared['times_to_calculate'] = generator.choice(times_to_calculate, (num_workers,)).tolist()
            yaml_prepared['times_to_communicate'] = generator.choice(times_to_communicate, (num_workers,)).tolist()

            # Server parameters: fixed values first, then this combination's
            # swept values (which win on a name clash).
            server_out = yaml_prepared['server_params']
            server_out.update(server_fixed_params)
            server_out.update(zip(server_param_names, server_combination))

            # Worker parameters: optional gamma, fixed values, swept values.
            worker_out = yaml_prepared.setdefault('worker_params', {})
            if algorithm in worker_gamma_algorithms:
                worker_out['gamma'] = gamma
            worker_out.update(worker_fixed_params)
            worker_out.update(zip(worker_param_names, worker_combination))

            # Save configuration
            config_name = os.path.join(source_dir, f'config_{index}_{algorithm}_{gamma}.yaml')
            with open(config_name, 'w') as fd:
                yaml.dump(yaml_prepared, fd, default_flow_style=None)

            # Generate run script
            sh_name = os.path.join(run_dir, f'run_{index}.sh')
            with open(sh_name, 'w') as fd:
                fd.write(run_template.format(
                    dumps_path=dumps_path,
                    experiments_name=experiment_name,
                    config_name=config_name,
                    result_name=f'config_{index}.json'
                ))

            # Set executable permissions for the run script
            os.chmod(sh_name, 0o755)

            index += 1

    print(f"Generated {index} configuration files in {source_dir}")
    print(f"Generated {index} run scripts in {run_dir}")
    print(f"Experiment directory: {experiments_dir}")
    return index

if __name__ == "__main__":
    # Command-line entry point: load the master YAML file given by --config
    # and write the generated experiment under the directory given by --output.
    parser = argparse.ArgumentParser(description='Generate experiment configurations from YAML file')
    parser.add_argument('--config', required=True, help='Path to the YAML configuration file')
    parser.add_argument('--output', required=True, help='Output directory for generated files')
    args = parser.parse_args()

    # Load YAML configuration
    with open(args.config, 'r') as f:
        yaml_config = yaml.safe_load(f)

    # safe_load returns None for an empty file and a list or scalar for a
    # non-mapping document. Reject those here with a clear message rather than
    # failing later on `.get` after nothing useful can be reported.
    if not isinstance(yaml_config, dict):
        parser.error(
            f"Top-level YAML in {args.config} must be a mapping, "
            f"got {type(yaml_config).__name__}"
        )

    # Generate configurations
    num_configs = generate_yaml_configs(yaml_config, args.output, args.config)
    print(f"Successfully generated {num_configs} configurations")
