import json
import os
import argparse
from tqdm import tqdm
from src.turtlegfx_datagen.utils.extract_python import extract_json_code_from_text
from src.turtlegfx.utils.base64img import save_base64_image

# Rubric keys expected in each parsed model response. parse_responses() reads
# the 'score' entry under every one of these keys and sums them into the
# overall score of a response.
RUBRICS = {
    "geometry",
    "visual",
    "structure",
    "alignment",
    "education",
    "color",
}


def parse_responses(responses, src_file):
    """
    Extract the JSON rubric scores from each model response and total them.

    Args:
        responses: List of response dicts loaded from one inference output
            file. Each entry provides 'id', 'model' and
            ['choices'][0]['message']['content']; the first entry also
            provides 'src_file', the path of the JSON prompt file the
            responses were generated from.
        src_file: Path of the file `responses` was loaded from; recorded
            verbatim under 'src_file' in every parsed entry. Note that this
            is not the same thing as the 'src_file' key stored inside the
            responses themselves.

    Returns:
        A list with one dict per response holding the id, model, the task
        image and code taken from the prompt file, and a 'model_output' dict
        with the raw text, the parsed JSON and the summed rubric score. When
        a response cannot be parsed, 'parsed' is None and 'score' is -1.
        An empty `responses` list yields an empty list.
    """
    # Nothing to parse: without this guard responses[0] raises IndexError
    # and a single empty input file would abort the whole run.
    if not responses:
        return []

    # All responses of one file are assumed to come from the prompt file
    # named by the first response.
    prompt_file = responses[0]['src_file']  # JSON file
    with open(prompt_file, 'r') as file:
        prompt_file_content = json.load(file)

    parsed_responses = []
    for i, response in enumerate(responses):
        content = response['choices'][0]['message']['content']
        try:
            json_code = json.loads(extract_json_code_from_text(content))
            ground_truth_score = sum(int(json_code[rubric]['score']) for rubric in RUBRICS)
        except Exception as e:
            # Deliberate best effort: model output is free-form text, so any
            # failure while extracting or scoring is reported and the
            # response is kept with a sentinel score instead of aborting.
            print(f"Error parsing response {response['id']}: {e}")
            json_code = None
            ground_truth_score = -1

        # Responses are matched to prompts by position: the i-th response is
        # paired with the i-th entry of the prompt file.
        prompt_data = prompt_file_content[i]['params']['data']
        parsed_responses.append({
            "id"          : response['id'],
            "model"       : response['model'],
            "task_image"  : prompt_data['task_image'],
            "code"        : prompt_data['code'],
            "src_file"    : src_file,
            "model_output": {
                "raw"   : content,
                "parsed": json_code,
                "score" : ground_truth_score,
            }
        })
    return parsed_responses

def filter_responses(responses, filter_options):
    """
    Rank parsed responses by score and optionally keep only the best share.

    The `responses` list is sorted in place, highest
    ['model_output']['score'] first. If `filter_options['top_percentage']`
    is a fraction in (0, 1], only that leading share of the ranked list is
    returned (rounded down); for a missing, None or out-of-range value the
    whole ranked list is returned.
    """
    top_percentage = filter_options.get('top_percentage')

    def score_of(entry):
        return entry['model_output']['score']

    # Ranking happens regardless of whether a cutoff is applied afterwards.
    responses.sort(key=score_of, reverse=True)

    cutoff_disabled = top_percentage is None or not (0 < top_percentage <= 1)
    if cutoff_disabled:
        return responses

    top_n = int(len(responses) * top_percentage)
    print(f"FILTERING_RESPONSES: Filtering top {top_percentage * 100}% responses: {top_n}/{len(responses)} responses")
    return responses[:top_n]


def postprocessing(input_paths, output_path, filter_options=None, save_img=False):
    """
    Parse response files, filter the parsed responses, and save the results.

    Args:
        input_paths: Iterable of paths to JSON files with model responses.
        output_path: Path of the JSON file the filtered responses are
            written to. A statistics file is written next to it, named after
            `output_path` with its extension replaced by '_stats.json'.
        filter_options: Optional dict of filter settings passed on to
            filter_responses(); the 'top_percentage' entry is also recorded
            in the statistics file. Defaults to no filtering options.
        save_img: When true, the task image of every kept response is saved
            as a PNG into an 'images' directory beside `output_path`.
    """
    # None instead of a shared mutable {} default.
    if filter_options is None:
        filter_options = {}

    # Empty when output_path is a bare filename (i.e. the current directory).
    output_dir = os.path.dirname(output_path)

    # Load model responses
    n_responses_total = 0
    parsed_responses = []
    for input_path in tqdm(input_paths, desc="Loading files and parsing responses"):
        with open(input_path, 'r') as file:
            responses_part = json.load(file)

        # Step 1: Parse the inference responses
        parsed_responses.extend(parse_responses(responses_part, src_file=input_path))
        # Only the count of raw responses is needed for the statistics.
        n_responses_total += len(responses_part)

    # Step 2: Filter task-code pairs based on filter options
    filtered_responses = filter_responses(parsed_responses, filter_options)

    # Save images
    if save_img:
        image_output_dir = os.path.join(output_dir, 'images')
        os.makedirs(image_output_dir, exist_ok=True)

        for fr in filtered_responses:
            # File name: <first id part>--<score>--<last id part>.png
            id_parts = fr['id'].split('--')
            img_filename = f"{id_parts[0]}--{fr['model_output']['score']}--{id_parts[-1]}.png"
            img_path = os.path.join(image_output_dir, img_filename)
            save_base64_image(fr['task_image'], img_path)
        print(f"Images saved to: {image_output_dir}")

    # Step 3: Save the filtered task-code pairs to a file
    # os.makedirs('') raises FileNotFoundError, so only create the directory
    # when output_path actually has a directory component.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_path, 'w') as file:
        json.dump(filtered_responses, file, indent=4)

    print("----------------")
    print(f"Statistics: Total task-code pairs: {len(filtered_responses)}")
    print(f"Task-code pairs saved to: {output_path}")

    # Save statistics to a json file
    # Derive the name from the extension only: a plain str.replace would also
    # rewrite '.json' inside directory names, and would leave the path
    # unchanged (overwriting the results file) when there is no '.json'.
    output_stem, _ = os.path.splitext(output_path)
    stats_path = f"{output_stem}_stats.json"
    with open(stats_path, 'w') as file:
        json.dump({
            "n_responses_total": n_responses_total,
            "n_responses_removed": n_responses_total - len(filtered_responses),
            "n_responses_remaining": len(filtered_responses),
            "top_percentage": filter_options.get('top_percentage', None),
        }, file, indent=2)


# Command-line entry point: parse the arguments and run the postprocessing.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Postprocess code scoring responses.")

    # Required: without any input path there is nothing to iterate over.
    parser.add_argument('--input_paths', type=str, nargs='+', required=True, help='Paths to input JSON files')
    parser.add_argument('--output_path', type=str, required=True, help='Path to output JSON file')
    # argparse %-formats help strings, so a literal percent sign must be
    # written as '%%'; an unescaped '%' makes --help fail with ValueError.
    parser.add_argument(
        '--top_percentage',
        type=float,
        help='Fraction in (0, 1] of top-scoring tasks to keep (e.g. 0.3 keeps the top 30%%)',
    )
    parser.add_argument('--save_img', action='store_true', help='Save images')

    args = parser.parse_args()

    filter_options = {'top_percentage': args.top_percentage}

    postprocessing(
        input_paths=args.input_paths,
        output_path=args.output_path,
        filter_options=filter_options,
        save_img=args.save_img
    )
