import argparse
import os
import json
import re
from pathlib import Path

def parse_arguments():
    """Parse the command-line options of the aggregation script.

    Every option is a string with a default value:
    ``--output`` (directory holding the performance results), ``--task``,
    ``--model`` and ``--save_path`` (directory receiving the processed JSONs).

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    # (flag, default, help) triples, registered in declaration order.
    option_specs = (
        ('--output', '../output', 'Directory containing performance results'),
        ('--task', 'metamath', 'Target task name'),
        ('--model', 'Llama-2-7b', 'Target model name'),
        ('--save_path', './', 'Directory to save processed JSONs'),
    )

    cli = argparse.ArgumentParser()
    for flag, fallback, description in option_specs:
        cli.add_argument(flag, type=str, default=fallback, help=description)
    return cli.parse_args()

def extract_hyperparams_from_folder_name(folder_name):
    """Extract batch size, learning rate, and trial number from subfolder strings.

    Tokens are searched anywhere in ``folder_name`` (e.g. ``bs16-lr2e-5-trial1``):

    * ``bs<int>``    -> ``hyperparam['batch_size']`` as ``int``
    * ``lr<float>``  -> ``hyperparam['lr']`` as ``float`` (integer, decimal or
      scientific notation)
    * ``trial<int>`` -> trial number

    Args:
        folder_name: Name of a hyperparameter subfolder.

    Returns:
        A ``(hyperparam, trial_num)`` tuple. ``hyperparam`` holds only the keys
        whose token was found (so it is empty for unrelated names) and
        ``trial_num`` defaults to 1 when no ``trial<int>`` token is present.
    """
    hyperparam = {}

    bs_match = re.search(r'bs(\d+)', folder_name)
    if bs_match:
        hyperparam['batch_size'] = int(bs_match.group(1))

    # Match a well-formed float literal only. A loose character class such as
    # ``[\d.]+`` would also capture things like "0.001." (from "lr0.001.bak")
    # or "..", which make float() raise ValueError.
    lr_match = re.search(
        r'lr((?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)', folder_name
    )
    if lr_match:
        hyperparam['lr'] = float(lr_match.group(1))

    trial_match = re.search(r'trial(\d+)', folder_name)
    trial_num = int(trial_match.group(1)) if trial_match else 1

    return hyperparam, trial_num

def extract_info_from_exp_folder(folder_name, task, model):
    """Extract method, rank, and task name from an experiment folder name.

    The expected layout is ``<task>-<Method>-<model>-r<rank>``, for example
    ``python:1000-ep3-LoRA-qwen-3-0.6b-r8`` with ``model='qwen-3-0.6b'``.

    Args:
        folder_name: Name of the experiment folder.
        task: Task name returned as-is when the layout is not recognised.
        model: Model name expected between the method and the rank suffix.

    Returns:
        A ``(method, rank, task_name)`` tuple. ``method`` is lower-cased and is
        ``None`` when the layout does not match; ``rank`` is ``None`` when no
        ``-r<digits>`` suffix is found at all.
    """
    # Read method and rank from one anchored match, so the rank is the one that
    # follows the model name. A task or model that itself contains "-r<digits>"
    # (e.g. "deepseek-r1") would otherwise be mistaken for the rank.
    layout = re.search(
        r"-([A-Za-z]+)-" + re.escape(model) + r"-r(\d+)", folder_name
    )
    if layout:
        method = layout.group(1).lower()
        rank = int(layout.group(2))
        # Task name is everything before the matched "-<Method>-" separator.
        task_name = folder_name[:layout.start()]
        return method, rank, task_name

    # Layout not recognised: still report the first rank-like suffix, if any,
    # and fall back to the caller-supplied task name.
    loose_rank = re.search(r"-r(\d+)", folder_name)
    rank = int(loose_rank.group(1)) if loose_rank else None
    return None, rank, task

def generate_json_filename(task, model, rank, batch_size):
    """Compose the result file name ``<task>_<model>_rank<rank>_bs<batch_size>.json``.

    The model name is lower-cased and its hyphens are turned into underscores.
    """
    normalized_model = "_".join(model.lower().split('-'))
    return "{}_{}_rank{}_bs{}.json".format(task, normalized_model, rank, batch_size)

def load_existing_data(filepath):
    """Return the parsed JSON content of ``filepath``.

    An empty dict is returned when the path does not exist, so callers can
    start from a blank structure on the first run.
    """
    if not os.path.exists(filepath):
        return {}

    with open(filepath, 'r', encoding='utf-8') as handle:
        return json.load(handle)

def save_data(filepath, data):
    """Write ``data`` to ``filepath`` as indented UTF-8 JSON.

    The parent directory is created first if needed; a bare file name is
    written relative to the current directory. Non-ASCII characters are kept
    as-is rather than escaped.
    """
    parent_dir = os.path.dirname(filepath) or '.'
    os.makedirs(parent_dir, exist_ok=True)

    with open(filepath, 'w', encoding='utf-8') as sink:
        json.dump(data, sink, indent=2, ensure_ascii=False)

def main():
    """Aggregate ``perf.json`` results into per-configuration JSON files.

    Walks ``<output>/<experiment folder>/<hyperparameter subfolder>/perf.json``.
    Experiment folders are kept only when their name parses to the requested
    task with a recognised method and rank. Every readable ``perf.json``
    becomes one record ``{"hyparam": {...}, "acc-<trial>": <perf data>}``.
    Records are grouped by (task, model, rank, batch size), and each group is
    written to its own file under ``--save_path``. Inside a file the layout is
    ``{task: {method: [records]}}``; methods collected in this run replace any
    entry of the same name already stored in the file.
    """
    args = parse_arguments()
    output_dir = Path(args.output)
    # is_dir() also rejects a path that exists but is a regular file, which
    # would otherwise make iterdir() raise.
    if not output_dir.is_dir():
        print(f"Output directory not found: {output_dir}")
        return

    experiment_groups = {}

    # iterdir() yields entries in arbitrary order; sorting keeps the log and
    # the order of records in the generated files reproducible across runs.
    for exp_folder in sorted(output_dir.iterdir()):
        if not exp_folder.is_dir():
            continue

        method, rank, task_name = extract_info_from_exp_folder(
            exp_folder.name, args.task, args.model
        )

        # Filter by the specified task
        if task_name != args.task:
            continue
        if method is None or rank is None:
            print(f"Skipping invalid folder: {exp_folder.name}")
            continue

        print(f"Processing: {exp_folder.name} (Method: {method}, Rank: {rank})")

        # Iterate through hyperparameter subfolders (e.g., bs16-lr2e-5-trial1)
        for sub_folder in sorted(exp_folder.iterdir()):
            if not sub_folder.is_dir():
                continue

            hyperparam, trial_num = extract_hyperparams_from_folder_name(sub_folder.name)
            if not hyperparam:
                continue

            perf_file = sub_folder / 'perf.json'
            if not perf_file.exists():
                continue

            # Only the read can fail; keep the best-effort "report and move on"
            # behaviour for unreadable or malformed files. ValueError covers
            # both json.JSONDecodeError and UnicodeDecodeError.
            try:
                with open(perf_file, 'r', encoding='utf-8') as f:
                    perf_data = json.load(f)
            except (OSError, ValueError) as e:
                print(f"Error reading {perf_file}: {e}")
                continue

            # NOTE(review): 'epoch' is always recorded as 1, whatever the run
            # actually used — confirm this is intended.
            hyperparam.update({'rank': rank, 'epoch': 1})

            record = {
                "hyparam": hyperparam,
                f"acc-{trial_num}": perf_data
            }

            # Group data by configuration; 16 is the batch size assumed when
            # the subfolder name carries no bs<int> token.
            batch_size = hyperparam.get('batch_size', 16)
            group_key = (task_name, args.model, rank, batch_size)

            experiment_groups.setdefault(group_key, {}).setdefault(method, []).append(record)

    # Save grouped results into specific JSON files
    for (task_name, model, rank, batch_size), methods_data in experiment_groups.items():
        json_filename = generate_json_filename(task_name, model, rank, batch_size)
        json_filepath = os.path.join(args.save_path, json_filename)

        data = load_existing_data(json_filepath)
        # Methods seen in this run overwrite their previous entry; methods
        # already in the file but not seen now are left untouched.
        data.setdefault(task_name, {}).update(methods_data)

        save_data(json_filepath, data)
        print(f"Saved/Updated: {json_filepath}")

# Run the aggregation only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == '__main__':
    main()