import json
import sys
import re

def get_file_map(generator_name):
    """Return the ordered mapping of metrics-file path -> model label.

    Supported generator names are 'qwen_rl', 'qwen_math', 'deepseek_rl' and
    'deepseek_math'. Every key has the form
    '<generator_name>/<Family>-<variant>_final_metrics_<split>.json', where
    <split> is the part of the generator name after the first underscore.

    Raises:
        ValueError: if ``generator_name`` is not one of the supported names.
    """
    # NOTE(review): the original code annotated 'qwen_rl' with "== qwen_inst";
    # what that equivalence means is unconfirmed.
    supported = ('qwen_rl', 'qwen_math', 'deepseek_rl', 'deepseek_math')
    if generator_name not in supported:
        raise ValueError(f"Unknown generator name: {generator_name}")

    # Per model family: file-name prefix and the ordered (variant, label) pairs.
    families = {
        'qwen': ('Qwen', (('base', 'Base'), ('math', 'Math'),
                          ('Inst', 'Inst'), ('Distill', 'Distill'))),
        'deepseek': ('Deepseek', (('base', 'Base'), ('math', 'Math'),
                                  ('Inst', 'Inst'), ('rl', 'RL'))),
    }
    family, _, split = generator_name.partition('_')
    prefix, variants = families[family]
    return {
        f'{generator_name}/{prefix}-{variant}_final_metrics_{split}.json': label
        for variant, label in variants
    }
    
def extract_json(file_map, response_type='all'):
    """Load each model's result file and optionally filter by solution type.

    Args:
        file_map: mapping of JSON file path -> model label (for example the
            return value of ``get_file_map``). Each file must hold a JSON array
            of objects; records are read with ``dict.get``.
        response_type: 'all', 'typical', 'creative' or 'hallucinated'
            (case-insensitive). 'all' keeps every record. Any other value keeps
            only records whose ``evaluation.final_decision`` equals
            ``"<Type>_Solution"`` (e.g. "Typical_Solution") and whose
            ``token_entropy_info`` is not an empty list.

    Returns:
        dict mapping each model label to its list of selected records, in the
        iteration order of ``file_map``.

    Raises:
        ValueError: if ``response_type`` is not an allowed value or
            ``file_map`` is empty. JSON decode errors are also ValueErrors.
        OSError: if a listed file cannot be opened.
    """
    allowed = ('all', 'typical', 'creative', 'hallucinated')
    normalized = response_type.lower() if isinstance(response_type, str) else response_type
    if normalized not in allowed:
        raise ValueError("response_type must be one of 'all', 'typical', 'creative', or 'hallucinated'.")
    if not file_map:
        raise ValueError("file_map must contain at least one file.")
    response_type = normalized.capitalize()
    target_decision = response_type + "_Solution"

    # Header: first path component of the first file (the generator directory,
    # e.g. 'qwen_math' for 'qwen_math/Qwen-base_final_metrics_math.json').
    # Single quotes inside the f-string keep this valid before Python 3.12.
    generator_dir = next(iter(file_map)).split('/')[0]
    print(f"{generator_dir} model generation\n")

    data_by_model = {}
    for filename, model_label in file_map.items():
        with open(filename, 'r', encoding='utf-8') as f:
            data = json.load(f)

        if response_type == 'All':
            selected = list(data)
        else:
            # `or {}` also tolerates an explicit "evaluation": null record.
            selected = [
                item for item in data
                if (item.get('evaluation') or {}).get('final_decision') == target_decision
                and item.get('token_entropy_info', []) != []
            ]
        data_by_model[model_label] = selected

    if 'Base' in data_by_model:
        print(f"# of {response_type} Solutions: {len(data_by_model['Base'])}")
    else:
        # No 'Base' label in this map: report each model instead of a KeyError.
        for model_label, records in data_by_model.items():
            print(f"# of {response_type} Solutions ({model_label}): {len(records)}")

    return data_by_model

# A {...} group that may contain one level of nested braces, e.g. {a^{2}}.
_TEX_BRACED = r'\{((?:[^{}]|\{[^{}]*\})*)\}'
# \frac, \dfrac and \tfrac with two braced arguments.
_TEX_FRAC_RE = re.compile(r'\\[dt]?frac\s*' + _TEX_BRACED + r'\s*' + _TEX_BRACED)
# (command-name pattern, plain-text replacement) for common math symbols.
_TEX_SYMBOLS = (
    (r'times', '×'),
    (r'(?:cdots|ldots|dots)', '...'),
    (r'cdot', '*'),
    (r'(?:leqslant|leq|le)', '<='),
    (r'(?:geqslant|geq|ge)', '>='),
    (r'(?:neq|ne)', '!='),
    (r'equiv', '≡'),
    (r'(?:rightarrow|to)', '->'),
    (r'Rightarrow', '=>'),
)


def tex_to_text(tex_str):
    """Convert a LaTeX-flavoured string into rough single-line plain text.

    Math delimiters are stripped, fractions become ``(num/den)``, a few common
    commands are mapped to ASCII/Unicode symbols, and every other
    ``\\command`` is dropped while its braced argument text is kept. Braces
    are then removed and whitespace collapsed. The conversion is lossy.

    Args:
        tex_str: input string containing LaTeX markup.

    Returns:
        The plain-text approximation, stripped, with runs of whitespace
        collapsed to single spaces.
    """
    # 1. Remove math delimiters: \[...\] (may span lines), \(...\), $$...$$, $...$.
    text = re.sub(r'\\\[(.*?)\\\]', r'\1', tex_str, flags=re.DOTALL)
    text = re.sub(r'\\\((.*?)\\\)', r'\1', text)
    # $$...$$ must be handled before $...$; otherwise the inline rule matches
    # the inner "$x$" and leaves the outer dollar signs behind.
    text = re.sub(r'\$\$(.*?)\$\$', r'\1', text, flags=re.DOTALL)
    text = re.sub(r'\$([^\$]+)\$', r'\1', text)

    # 2. Fractions -> (num/den). Each pass rewrites the outermost \frac it
    #    finds, so repeat until nothing changes to handle nested fractions.
    while True:
        rewritten = _TEX_FRAC_RE.sub(r'(\1/\2)', text)
        if rewritten == text:
            break
        text = rewritten

    # 3. Common symbols. (?![a-zA-Z]) makes each pattern match only the whole
    #    command name, so \cdot does not eat the prefix of \cdots and \right
    #    does not eat the prefix of \rightarrow.
    for name, replacement in _TEX_SYMBOLS:
        text = re.sub(r'\\' + name + r'(?![a-zA-Z])', replacement, text)
    text = re.sub(r'\\pmod\{([^}]+)\}', r'mod \1', text)

    # 4. Drop delimiter sizing and spacing commands, then any remaining
    #    \command (e.g. \text, \textbf, \mathrm) together with trailing space.
    text = re.sub(r'\\(?:left|right)(?![a-zA-Z])', '', text)
    text = re.sub(r'\\[,;:]', ' ', text)
    text = re.sub(r'\\!', '', text)
    text = re.sub(r'\\[a-zA-Z]+\s*', '', text)

    # 5. Remove curly braces.
    text = text.replace('{', '').replace('}', '')

    # 6. Collapse whitespace.
    return re.sub(r'\s+', ' ', text).strip()

if __name__ == "__main__":
    # Usage: python utils.py <generator_name> [response_type]
    if len(sys.argv) < 2:
        print("use: python utils.py [generator_name] [response_type]")
        print("example: python utils.py qwen_math all")
        sys.exit(1)

    generator = sys.argv[1]
    response_type = sys.argv[2] if len(sys.argv) > 2 else 'all'

    # Report an unknown generator as a one-line error instead of a traceback.
    # sys.exit(str) writes the message to stderr and exits with status 1.
    try:
        file_map = get_file_map(generator)
    except ValueError as err:
        sys.exit(f"error: {err}")

    print(f"✅ Generator: {generator}, Response Type: {response_type}")
    for path, model_name in file_map.items():
        print(f"   - {model_name}: {path}")

    # ValueError: invalid response_type or malformed JSON (json.JSONDecodeError
    # subclasses ValueError). OSError: a metrics file is missing/unreadable.
    try:
        data_by_model = extract_json(file_map, response_type)
    except (ValueError, OSError) as err:
        sys.exit(f"error: {err}")
