import json
from collections import Counter

import numpy as np


def select_data_by_category(data, select_per_model, select_per_category, categories):
    """Sample a category-balanced subset of ``data``.

    First pass: from each category in ``categories`` draw up to
    ``select_per_category`` items uniformly at random without replacement;
    categories with fewer items contribute all of them. Second pass: if the
    result is still smaller than ``select_per_model``, fill the shortfall by
    sampling, without replacement, from the items the first pass skipped.

    Args:
        data: Iterable of dicts, each with a ``"category"`` key whose value
            must be one of ``categories`` (otherwise ``KeyError`` is raised).
        select_per_model: Target total number of items to return.
        select_per_category: Maximum number of items taken per category in
            the first pass.
        categories: Categories to bin by; also fixes the order in which the
            first pass visits them.

    Returns:
        A list of the selected items (the original dict objects, not copies).
        It holds ``select_per_model`` items when enough data is available and
        fewer (every available item) otherwise. If the first pass alone already
        exceeds ``select_per_model``, nothing is removed.

    Randomness comes from numpy's global RNG (``np.random``); seed it for
    reproducible output.
    """
    selected_data = []
    rest_data = []
    category_bins = {category: [] for category in categories}
    for d in data:
        category_bins[d["category"]].append(d)
    for category in categories:
        bin_items = category_bins[category]
        if len(bin_items) <= select_per_category:
            selected_data.extend(bin_items)
            continue
        idxs = np.random.choice(len(bin_items), select_per_category, replace=False)
        selected_data.extend(bin_items[idx] for idx in idxs)
        # A set keeps the complement O(n) rather than scanning the ndarray per item.
        chosen = set(idxs.tolist())
        rest_data.extend(item for idx, item in enumerate(bin_items) if idx not in chosen)

    # Never request more than is left: np.random.choice(..., replace=False)
    # raises ValueError when the sample size exceeds the population (including
    # an empty population), which happened whenever ``data`` was smaller than
    # ``select_per_model``.
    select_from_rest = min(select_per_model - len(selected_data), len(rest_data))
    if select_from_rest > 0:
        # Sample indices rather than the dicts so numpy never has to coerce the
        # list of dicts into an object array.
        idxs = np.random.choice(len(rest_data), select_from_rest, replace=False)
        selected_data.extend(rest_data[idx] for idx in idxs)
    return selected_data


def category_distribution(data, categories):
    """Return ``{category: count}`` of ``data``'s ``"category"`` values.

    Keys follow the order of ``categories``; categories absent from ``data``
    map to 0, and values not listed in ``categories`` are not reported.
    """
    counts = Counter(d["category"] for d in data)
    return {category: counts[category] for category in categories}


def main():
    """Build the combined and the category-balanced example sets.

    For each model's classified-results file:
    - tag every example with ``from_model``;
    - fold categories outside ``classified_categories`` into ``"other"``;
    - print the per-model category distribution;
    - draw a balanced sample with ``select_data_by_category``.

    Writes every example to ``data/all_generated_data.json`` and the sampled
    subset, sorted by (``source_row``, ``from_model``), to
    ``data/filtered_generated_data.json``.
    """
    generated_data = {
        "gpt-4o": "data/gpt4o_examples_classified_results.json",
        "claude-3-5-sonnet": "data/claude_examples_classified_results.json",
        "gemini-1.5-pro": "data/gemini_examples_classified_results.json",
    }
    classified_categories = ["only/standard treatment", "no treatment", "inevitable side effect",
                             "causal misattribution", "underestimate risk", "no symptoms means no disease"]
    all_categories = classified_categories + ["other"]

    select_per_model = 266
    # Even per-category quota; select_data_by_category tops up from leftovers.
    select_per_category = select_per_model // len(all_categories)
    # NOTE(review): sampling uses numpy's unseeded global RNG, so the filtered
    # set differs between runs; seed np.random here if it must be reproducible.
    all_data = []
    filtered_data = []
    for model, path in generated_data.items():
        # JSON is UTF-8 by spec; don't depend on the platform locale encoding.
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        print(f"Model: {model}")
        print(f"Number of examples: {len(data)}")
        for d in data:
            d["from_model"] = model
            if d["category"] not in classified_categories:
                d["category"] = "other"
        print("Category distribution:")
        print(json.dumps(category_distribution(data, all_categories), indent=2))
        all_data.extend(data)
        filtered_data.extend(select_data_by_category(data, select_per_model, select_per_category, all_categories))

    filtered_data.sort(key=lambda x: (x["source_row"], x["from_model"]))
    with open("data/all_generated_data.json", "w", encoding="utf-8") as f:
        json.dump(all_data, f, indent=2)
    with open("data/filtered_generated_data.json", "w", encoding="utf-8") as f:
        json.dump(filtered_data, f, indent=2)
    print("Category distribution:")
    print(json.dumps(category_distribution(filtered_data, all_categories), indent=2))


if __name__ == "__main__":
    main()
