import json
import yaml
from pathlib import Path

# Directory whose entries are listed to obtain the task names.
TRAIN_DATA_PATH = '/home/toolkit/data/train'
# JSON file holding the per-task "best_seq_length" records.
BEST_SEQ_LENGTH_FILE = 'assets/best_seq_length.json'
# Folder holding the base YAML templates; generated configs are written here too.
TRAIN_CONFIG_FILE_DIR = 'train_configs'

# Per-task train batch-size overrides; tasks not listed use the default of 4.
BS_EXCEPTIONS = {
    "lima": 2,
    "sharegpt": 2,
}


def read_config(config_file):
    """Parse a YAML file and return the resulting Python object.

    Args:
        config_file: Path (``str`` or ``Path``) of the YAML file to read.

    Returns:
        Whatever ``yaml.load`` produces for the document.
    """
    source = Path(config_file)
    with source.open('r') as stream:
        # NOTE(review): yaml.load should only be fed trusted files; consider
        # yaml.safe_load if configs can ever come from outside the project.
        return yaml.load(stream, Loader=yaml.FullLoader)

def read_best_seq_length(best_seq_length_file):
    """Load the best-sequence-length table keyed by task name.

    The JSON file must contain an iterable of objects, each with a
    ``task`` and a ``best_seq_length`` key.

    Args:
        best_seq_length_file: Path (``str`` or ``Path``) of the JSON file.

    Returns:
        dict mapping each ``task`` value to its ``best_seq_length`` value.
        When a task appears more than once, the last record wins.
    """
    records = json.loads(Path(best_seq_length_file).read_text())

    lengths_by_task = {}
    for record in records:
        length = record['best_seq_length']
        lengths_by_task[record['task']] = length
    return lengths_by_task

def get_base_configs(train_config_folder, prefix):
    """Yield the parsed contents of each base config file matching ``prefix``.

    The loop over the config folder was missing: the body referenced an
    undefined ``config_file`` and ignored both parameters, so iterating the
    generator raised ``NameError``. The loop is restored here.

    NOTE(review): the glob pattern ``{prefix}*_base.yaml`` is reconstructed
    from the ``_ss*_base.yaml`` pattern used in ``main``; confirm it matches
    the intended file naming scheme.

    Args:
        train_config_folder: Directory searched (non-recursively) for base
            config files.
        prefix: Leading part of the file names to match.

    Yields:
        The object produced by ``yaml.load`` for each matching file, in
        sorted path order.
    """
    pattern = f"{prefix}*_base.yaml"
    # Sort so the yield order does not depend on directory enumeration order.
    for config_file in sorted(Path(train_config_folder).glob(pattern)):
        with config_file.open('r') as f:
            # NOTE(review): yaml.load should only be fed trusted files.
            config = yaml.load(f, Loader=yaml.FullLoader)
        yield config

def get_task_names(train_data_path):
    """List the names of all entries directly inside ``train_data_path``.

    Args:
        train_data_path: Directory (``str`` or ``Path``) to enumerate.

    Returns:
        list of entry names (final path components), in the order the
        directory iteration produces them.
    """
    return [entry.name for entry in Path(train_data_path).iterdir()]

def main():
    """Generate per-task 13B training configs from the base YAML templates.

    For every task found under ``TRAIN_DATA_PATH`` that has an entry in the
    best-sequence-length file, a ``baseline_<task>_13b.yaml`` config is
    written to ``TRAIN_CONFIG_FILE_DIR``. Then, for each ``_ss*_base.yaml``
    template in that folder, the template's keys are overlaid on each
    task's baseline config and written as ``<ss_name>_<task>_13b.yaml``.

    Tasks missing from the best-sequence-length file are skipped in both
    passes. Previously the spaced-sampling pass looked them up in
    ``baseline_configs`` and crashed with ``KeyError``.
    """
    model_name = "meta_llama_Llama_2_13b_hf"
    stats = read_best_seq_length(BEST_SEQ_LENGTH_FILE)
    task_names = get_task_names(TRAIN_DATA_PATH)

    config_dir = Path(TRAIN_CONFIG_FILE_DIR)
    baseline_config = read_config(config_dir / "_baseline_base.yaml")
    baseline_configs = {}

    # Baseline
    for task_name in task_names:
        if task_name not in stats:
            print(f"Task {task_name} not in best seq length file. Skipping...")
            continue

        batch_size = BS_EXCEPTIONS.get(task_name, 4)

        # Shallow copy is enough: only top-level keys are (re)assigned below.
        config = baseline_config.copy()
        config['model_name_or_path'] = model_name
        config['train_file'] = task_name
        config['max_seq_length'] = stats[task_name]
        config['per_device_train_batch_size'] = batch_size
        # Evaluation uses twice the training batch size.
        config['per_device_eval_batch_size'] = batch_size * 2

        output_file = config_dir / f"baseline_{task_name}_13b.yaml"
        print(f"Writing config for {task_name} to {output_file}")
        with output_file.open('w') as f:
            yaml.dump(config, f)

        baseline_configs[task_name] = config

    # Spaced sampling
    # sorted() snapshots the matches before new files are written into the
    # same folder and makes the processing order deterministic.
    for config_file in sorted(config_dir.glob("_ss*_base.yaml")):
        # "_ssX_base.yaml" -> "ssX"
        ss_config_name = config_file.name[1:].replace('_base.yaml', '')
        # Read the overlay once per template; an empty YAML file parses to
        # None, which is treated as "no overrides".
        ss_config = read_config(config_file) or {}

        # Iterate only tasks that actually received a baseline config.
        for task_name, task_config in baseline_configs.items():
            config = task_config.copy()
            config.update(ss_config)

            output_file = config_dir / f"{ss_config_name}_{task_name}_13b.yaml"
            print(f"Writing config for {task_name} to {output_file}")
            with output_file.open('w') as f:
                yaml.dump(config, f)
            
            
# Generate the config files only when run as a script, not when imported.
if __name__ == '__main__':
    main()