import json
from collections import defaultdict

# Metric keys read from each result entry, in table column order.
_METRIC_KEYS = ("acc1", "acc5", "mean_per_class_recall")

# Characters that are special in LaTeX text mode, mapped to escaped forms.
_LATEX_ESCAPES = {
    "\\": r"\textbackslash{}",
    "&": r"\&",
    "%": r"\%",
    "$": r"\$",
    "#": r"\#",
    "_": r"\_",
    "{": r"\{",
    "}": r"\}",
    "~": r"\textasciitilde{}",
    "^": r"\textasciicircum{}",
}


def _escape_latex(text):
    """Return *text* with LaTeX special characters escaped for text mode."""
    # Map character by character so backslashes introduced by one
    # replacement are never re-escaped by another.
    return "".join(_LATEX_ESCAPES.get(ch, ch) for ch in text)


def _join_names(names):
    """Join names as English prose: 'A', 'A and B', 'A, B, and C'."""
    names = list(names)
    if len(names) <= 2:
        return " and ".join(names)
    return ", ".join(names[:-1]) + ", and " + names[-1]


def _format_metrics(results):
    """Return the metric cells (percent, one decimal) for one dataset/experiment.

    Only the first entry of *results* is used. A missing or empty result
    list, or a missing metric key, is rendered as '-'.
    """
    if not results:
        return ["-"] * len(_METRIC_KEYS)
    metrics = results[0].get("metrics", {})
    cells = []
    for key in _METRIC_KEYS:
        value = metrics.get(key)
        cells.append("-" if value is None else f"{value * 100:.1f}")
    return cells


def create_latex_table(json_file_path, exp_types=("FARE", "LORE", "CLIP")):
    """Build a LaTeX table comparing per-dataset metrics across experiments.

    The JSON file is expected to map each experiment type to a dict of
    ``dataset -> list of results``, where each result holds a ``metrics``
    dict with ``acc1``, ``acc5`` and ``mean_per_class_recall`` as fractions.

    Args:
        json_file_path: Path of the JSON results file.
        exp_types: Experiment types to show, one three-column group each,
            in display order.

    Returns:
        The complete ``table`` environment as a single string. It uses
        booktabs rules (\\toprule etc.) and graphicx's \\resizebox.

    Raises:
        ValueError: If *exp_types* is empty.
    """
    exp_types = tuple(exp_types)
    if not exp_types:
        raise ValueError("exp_types must name at least one experiment type")
    n_groups = len(exp_types)

    with open(json_file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    col_spec = "l|" + "|".join(["ccc"] * n_groups)
    # A \multicolumn replaces the column spec it spans, so every group except
    # the last must restate the trailing "|" to keep the vertical rule.
    group_headers = []
    for i, exp_type in enumerate(exp_types):
        align = "c|" if i < n_groups - 1 else "c"
        group_headers.append(
            "\\multicolumn{3}{" + align + "}{" + _escape_latex(exp_type) + "}"
        )

    latex_output = [
        "\\begin{table}[t]",
        "\\centering",
        "\\resizebox{\\textwidth}{!}{",
        "\\begin{tabular}{" + col_spec + "}",
        "\\toprule",
        "& " + " & ".join(group_headers) + " \\\\",
        "Dataset & "
        + " & ".join(["Acc@1 & Acc@5 & Mean Recall"] * n_groups)
        + " \\\\",
        "\\midrule",
    ]

    # Union of datasets across the selected experiments, sorted for a
    # stable row order.
    all_datasets = set()
    for exp_type in exp_types:
        all_datasets.update(data.get(exp_type, {}))

    for dataset in sorted(all_datasets):
        # Display only the last path component of the dataset key.
        row = [_escape_latex(dataset.split("/")[-1])]
        for exp_type in exp_types:
            row.extend(_format_metrics(data.get(exp_type, {}).get(dataset)))
        latex_output.append(" & ".join(row) + " \\\\")

    caption_names = _join_names(_escape_latex(e) for e in exp_types)
    latex_output.extend([
        "\\bottomrule",
        "\\end{tabular}}",
        f"\\caption{{Comparison of {caption_names} performance across different datasets.}}",
        # Label kept unchanged so existing \ref{} uses still resolve.
        "\\label{tab:fare-lore-comparison}",
        "\\end{table}",
    ])

    return "\n".join(latex_output)

def save_latex_table(latex_content, output_file="results/results_table.tex"):
    """Write *latex_content* to *output_file* as UTF-8 text.

    Missing parent directories of *output_file* are created. An existing
    file at that path is overwritten.
    """
    from pathlib import Path

    path = Path(output_file)
    # Create the target directory so writing into a not-yet-existing
    # directory does not raise FileNotFoundError.
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(latex_content, encoding="utf-8")

if __name__ == "__main__":
    # Generate and save the LaTeX table. A single path variable keeps the
    # reported location in sync with where the file is actually written.
    output_path = "results/results_table.tex"
    latex_content = create_latex_table("results/combined_results.json")
    save_latex_table(latex_content, output_path)
    print(f"LaTeX table has been generated and saved to {output_path}")
    print("\nTable preview:")
    print(latex_content)