import json
import os
import argparse
import logging
from pebble import ProcessPool
from concurrent.futures import TimeoutError

from tqdm import tqdm
from src.turtlegfx_datagen.utils.extract_python import extract_python_code_from_text
from src.turtlegfx.utils.base64img import save_base64_image
from src.turtlegfx_datagen.utils.img_utils import get_image_filename_from_id
from src.turtlegfx.emulate.emulator import _turtle_worker

# Configure the root logger at import time so INFO-level (and higher)
# messages are emitted.
logging.basicConfig(level=logging.INFO)

# Placeholder `error_type` that process_response stores in a result's
# `exec` entry before the extracted code has actually been run.
ERROR_UNKNOWN = 'ERROR_UNKNOWN'

def process_response(response, src_file):
    """
    Run the code found in one model response and describe what happened.

    The message content of the first choice is passed through
    ``extract_python_code_from_text`` and the outcome is handed to
    ``_turtle_worker`` with ``show_screen=False``.

    Args:
        response: Response dict; ``id``, ``model`` and
            ``choices[0]['message']['content']`` are read from it.
        src_file: Stored unchanged under the ``src_file`` key of the record.

    Returns:
        dict: Record with the keys ``id``, ``model``, ``task_image``,
        ``code``, ``src_file`` and ``model_output``.
        ``model_output['exec']`` holds whatever the worker returned; when
        the worker reports ``status == "success"`` its ``image`` entry is
        moved out of that dict into ``task_image``. If the worker call (or
        reading its result) raises, ``exec`` becomes a ``fail`` entry with
        ``error_type`` ``"exception"`` and the exception text.
    """
    response_id = response['id']
    raw_text = response['choices'][0]['message']['content']
    code = extract_python_code_from_text(raw_text)

    # Pessimistic placeholder; it is replaced below on every path.
    pending_exec = {
        "status": "fail",
        "error_type": ERROR_UNKNOWN,
        "message": "Unknown error."
    }
    record = {
        "id": response_id,
        "model": response['model'],
        "task_image": None,
        "code": code,
        "src_file": src_file,
        "model_output": {
            "raw": raw_text,
            "parsed": code,
            "exec": pending_exec,
        }
    }

    try:
        outcome = _turtle_worker(code, show_screen=False)
        if outcome["status"] == "success":
            # The image is kept at the top level of the record, not inside `exec`.
            record['task_image'] = outcome.pop('image')
        record['model_output']['exec'] = outcome
    except Exception as exc:
        record['model_output']['exec'] = {
            "status": "fail",
            "error_type": "exception",
            "message": str(exc)
        }
    return record

def _build_failure_result(response, src_file, error_type, message):
    """
    Build the record for a response whose worker task produced no result.

    This runs in the parent process after the worker already failed, so it
    must not raise: a malformed response or a failing code extraction yields
    ``None`` for the affected fields instead of aborting the whole batch.
    """
    fields = response if isinstance(response, dict) else {}
    try:
        content = fields['choices'][0]['message']['content']
    except (KeyError, IndexError, TypeError):
        content = None
        python_code = None
    else:
        try:
            # Extract once and reuse for both "code" and "parsed".
            python_code = extract_python_code_from_text(content)
        except Exception:
            logging.warning(
                "Could not extract code for failed response %s",
                fields.get('id'),
            )
            python_code = None

    return {
        "id": fields.get('id'),
        "model": fields.get('model'),
        "task_image": None,
        "code": python_code,
        "src_file": src_file,
        "model_output": {
            "raw": content,
            "parsed": python_code,
            "exec": {
                "status": "fail",
                "error_type": error_type,
                "message": message
            },
        }
    }


def process_responses(responses, src_file, timeout=10, max_workers=128):
    """
    Run ``process_response`` on every response in a pebble ``ProcessPool``.

    All tasks are scheduled up front, then collected in submission order
    (with a tqdm progress bar), so the returned list lines up with
    ``responses``.

    Args:
        responses: Iterable of response dicts, each passed to
            ``process_response``.
        src_file: Forwarded to ``process_response`` and stored in every record.
        timeout: Per-task timeout in seconds given to ``pool.schedule``.
        max_workers: Size of the process pool.

    Returns:
        list[dict]: One record per response. A task that times out yields a
        ``fail`` record with ``error_type`` ``"TIMEOUT"``; any other error
        raised while fetching the task result yields ``error_type``
        ``"EXCEPTION"`` with the exception text as message.
    """
    results = []
    with ProcessPool(max_workers=max_workers, max_tasks=20) as pool:
        scheduled = [
            (
                pool.schedule(process_response, args=(response, src_file), timeout=timeout),
                response,
            )
            for response in responses
        ]

        for future, response in tqdm(scheduled, total=len(scheduled), desc="Processing responses"):
            try:
                result = future.result()
            except TimeoutError:
                result = _build_failure_result(
                    response,
                    src_file,
                    "TIMEOUT",
                    f"Execution timed out after {timeout} seconds.",
                )
            except Exception as e:
                # Anything else raised while fetching the task result,
                # including errors raised inside process_response itself.
                result = _build_failure_result(response, src_file, "EXCEPTION", str(e))
            results.append(result)

    return results

def filter_responses(responses, filter_options):
    """
    Keep only the records whose execution status matches the requested type.

    ``filter_options['output_type']`` is compared case-insensitively:
    ``'success'`` keeps records whose ``model_output.exec.status`` equals
    ``'success'``, ``'fail'`` keeps every other record. Any other value
    (or a missing/empty one) returns ``responses`` itself, unfiltered.
    """
    requested = filter_options.get('output_type', None)
    mode = requested.lower() if requested else None

    # Guard clause: nothing to filter unless a known mode was requested.
    if mode not in ('success', 'fail'):
        return responses

    keep_successes = mode == 'success'
    return [
        record
        for record in responses
        if (record['model_output']['exec']['status'] == 'success') == keep_successes
    ]

def postprocessing(input_path, output_path, filter_options=None, save_img=False):
    """
    Parse model responses, filter them, and write results, images and stats.

    Args:
        input_path: Path of the JSON file holding the list of responses.
        output_path: Path of the JSON file the filtered records are written to.
        filter_options: Optional dict. ``output_type`` is forwarded to
            ``filter_responses``; ``max_workers`` (default 128) sets the pool
            size used by ``process_responses``. ``None`` means no options.
        save_img: When true, the ``task_image`` of every successful record is
            saved into an ``images`` directory next to ``output_path``.

    Side effects:
        Writes ``output_path`` and a sibling ``<name>_stats.json`` file, and
        creates the output directory if it does not exist.
    """
    # Avoid a shared mutable default argument.
    if filter_options is None:
        filter_options = {}

    # Load model responses
    with open(input_path, 'r', encoding='utf-8') as file:
        responses = json.load(file)

    # Set timeout and max_workers
    timeout = 10  # seconds
    max_workers = filter_options.get('max_workers', 128)
    logging.info(f"Using {max_workers} workers")

    # Step 1: Parse the inference responses
    parsed_responses = process_responses(responses, src_file=input_path, timeout=timeout, max_workers=max_workers)

    # Step 2: Filter task-code pairs based on filter options
    filtered_responses = filter_responses(parsed_responses, filter_options)

    # Empty when output_path is a bare filename (current working directory).
    output_dir = os.path.dirname(output_path)

    # Save images
    if save_img:
        image_output_dir = os.path.join(output_dir, 'images')
        os.makedirs(image_output_dir, exist_ok=True)

        for fr in filtered_responses:
            if fr['model_output']['exec']['status'] == 'success' and fr['task_image']:
                img_filename = get_image_filename_from_id(fr['id'])
                img_path = os.path.join(image_output_dir, img_filename)
                save_base64_image(fr['task_image'], img_path)
        print(f"Images saved to: {image_output_dir}")

    # Step 3: Save the filtered task-code pairs to a file
    # os.makedirs('') raises, so only create a directory when there is one.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as file:
        json.dump(filtered_responses, file, indent=4)

    print("----------------")
    print(f"Statistics: Total task-code pairs: {len(filtered_responses)}")
    print(f"Task-code pairs saved to: {output_path}")

    # Save statistics to a json file
    # Derive the name from the file stem only: a plain str.replace would also
    # rewrite '.json' inside directory names, and would leave the path
    # unchanged (overwriting the results) when it has no '.json' at all.
    output_root, _ = os.path.splitext(output_path)
    stats_path = f"{output_root}_stats.json"
    with open(stats_path, 'w', encoding='utf-8') as file:
        json.dump({
            "n_responses_total": len(responses),
            "n_responses_removed": len(responses) - len(filtered_responses),
            "n_responses_remaining": len(filtered_responses),
            "filter_output_type": filter_options.get('output_type', None),
        }, file, indent=2)

if __name__ == '__main__':
    # Command-line entry point: declare the flags in a table, register them
    # in order, then forward the parsed values to postprocessing().
    cli = argparse.ArgumentParser(description="Postprocess code generation responses.")
    flag_table = (
        ('--input_path', {'type': str, 'required': True, 'help': 'Path to input JSON file'}),
        ('--output_path', {'type': str, 'required': True, 'help': 'Path to output JSON file'}),
        ('--output_type', {'type': str, 'choices': ['all', 'success', 'fail'],
                           'help': 'Output type to filter results'}),
        ('--max_workers', {'type': int, 'default': 128, 'help': 'Maximum number of workers'}),
        ('--save_img', {'action': 'store_true', 'help': 'Save images'}),
    )
    for flag, spec in flag_table:
        cli.add_argument(flag, **spec)

    parsed = cli.parse_args()

    postprocessing(
        input_path=parsed.input_path,
        output_path=parsed.output_path,
        filter_options={
            'output_type': parsed.output_type,
            'max_workers': parsed.max_workers,
        },
        save_img=parsed.save_img,
    )
