import os
import json
from collections import defaultdict
import argparse
from collections import defaultdict, Counter
import random

def balance_dataset(data, max_diff, rng=None):
    """Downsample *data* so answer_type groups differ in size by at most *max_diff*.

    Items are grouped by ``item["answer_type"]``. Every group's target size is
    the size of the smallest group plus ``max_diff``. Groups at or below the
    target are kept whole; larger ones are randomly downsampled to it.

    Within a group, the quota is spread as evenly as possible across the
    ``item["question_type"]`` values. Any shortfall left by question types
    that are too small is then topped up at random from the group's
    not-yet-picked items.

    Args:
        data: Iterable of mappings, each with ``"answer_type"`` and
            ``"question_type"`` keys.
        max_diff: Number of extra items a group may keep beyond the size of
            the smallest answer_type group.
        rng: Optional ``random.Random`` instance for reproducible sampling.
            Defaults to the module-level ``random`` functions.

    Returns:
        A new list of the selected items, emitted one answer_type group at a
        time (groups in order of first appearance). Empty input yields ``[]``.
    """
    rng = random if rng is None else rng

    # Step 1: group by answer_type.
    answer_type_groups = defaultdict(list)
    for item in data:
        answer_type_groups[item["answer_type"]].append(item)

    if not answer_type_groups:
        return []

    # Target size per answer_type. Clamp at 0 so a large negative max_diff
    # cannot request a negative sample size.
    smallest = min(len(items) for items in answer_type_groups.values())
    target = max(0, smallest + max_diff)

    # Step 2: balance question_type within each answer_type.
    balanced_data = []
    for items in answer_type_groups.values():
        question_type_groups = defaultdict(list)
        for item in items:
            question_type_groups[item["question_type"]].append(item)

        per_qtype, remainder = divmod(target, len(question_type_groups))

        subset = []
        for index, group in enumerate(question_type_groups.values()):
            # The first `remainder` question types each take one extra item.
            quota = per_qtype + (1 if index < remainder else 0)
            subset.extend(rng.sample(group, min(len(group), quota)))

        # Small question types may not fill their quota; top up from the rest.
        if len(subset) < target:
            # Track picks by identity, not equality. Otherwise duplicate-but-equal
            # records would all be treated as already picked.
            picked = {id(item) for item in subset}
            remaining = [
                item
                for group in question_type_groups.values()
                for item in group
                if id(item) not in picked
            ]
            extra_needed = target - len(subset)
            subset.extend(rng.sample(remaining, min(len(remaining), extra_needed)))

        balanced_data.extend(subset)

    return balanced_data

def _balanced_output_path(path):
    """Return *path* with ``_balanced`` inserted before its file extension.

    Only the final path component's extension is considered. As a result, a
    ``.json`` inside a directory name is left untouched. The result is also
    never equal to *path*, so the input file cannot be overwritten.
    """
    root, ext = os.path.splitext(path)
    return f"{root}_balanced{ext}"


def _read_jsonl(path):
    """Load a JSON-lines file into a list of objects, skipping blank lines."""
    with open(path, "r", encoding="utf-8") as fh:
        return [json.loads(line) for line in fh if line.strip()]


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Balance a JSON-lines dataset by answer_type and question_type."
    )
    parser.add_argument(
        "--file_path", type=str, required=True,
        help="Input JSON-lines file; output is written next to it with a _balanced suffix.",
    )
    parser.add_argument(
        "--max_diff", type=int, required=True,
        help="Extra items allowed per answer_type beyond the smallest answer_type group.",
    )
    args = parser.parse_args()

    output_file = _balanced_output_path(args.file_path)

    # Parse each line exactly once, then balance.
    data = balance_dataset(_read_jsonl(args.file_path), args.max_diff)

    # Write the balanced data back out as JSON lines.
    with open(output_file, "w", encoding="utf-8") as file:
        file.writelines(json.dumps(element) + "\n" for element in data)

    print(f'Merged {len(data)} entries into {output_file}')
