import pygame
import random
import time
import json
import os
import argparse
import itertools
from pygame_screen_record import ScreenRecorder

# Initialise every pygame module once, at import time, before any display work.
pygame.init()

# RGB color constants.
RED = (255, 0, 0)
GREEN = (0, 255, 0)
BLUE = (0, 0, 255)
BLACK = (0, 0, 0)  # Grid-line color only; deliberately absent from COLORS

# Human-readable name of each shape color. Its insertion order also defines
# the order of COLORS below.
COLOR_NAMES = {
    RED: "red",
    GREEN: "green",
    BLUE: "blue",
}
COLORS = list(COLOR_NAMES)

# Vocabularies of object attributes that scenes draw from.
SIZES = ["small", "medium", "large"]
SHAPES = ["circle", "square", "triangle"]

# Seconds between scene regenerations in the main loop.
change_interval = 3

# Function to parse command-line arguments
def _int_at_least(minimum):
    """Build an argparse ``type=`` callable that accepts integers >= *minimum*."""
    def parse(text):
        try:
            value = int(text)
        except ValueError:
            raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
        if value < minimum:
            raise argparse.ArgumentTypeError(f"must be >= {minimum}, got {value}")
        return value
    return parse


def parse_arguments(argv=None):
    """Parse the command-line options of the scene generator.

    Args:
        argv: argument list to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, so existing no-argument callers behave as before.

    Returns:
        argparse.Namespace with ``ts``, ``m``, ``n``, ``n_q`` and ``output_dir``.
        Invalid values make argparse print a usage error and exit with status 2.
    """
    parser = argparse.ArgumentParser(description="Pygame Scene Generator with Screen Recording")
    # The window is n*BLOCK_SIZE by m*BLOCK_SIZE pixels, so the grid dimensions
    # must be positive; counts and durations merely must not be negative.
    parser.add_argument('--ts', type=_int_at_least(0), default=20, help='Total run time of the video in seconds')
    parser.add_argument('--m', type=_int_at_least(1), default=5, help='Number of rows in the grid')
    parser.add_argument('--n', type=_int_at_least(1), default=5, help='Number of columns in the grid')
    parser.add_argument('--n_q', type=_int_at_least(0), default=5, help='Number of questions to generate')
    parser.add_argument('--output_dir', type=str, default="output/angle1", help='Directory to save output files')
    return parser.parse_args(argv)

# Draw shape function
def draw_shape(shape, color, size, x, y, BLOCK_SIZE):
    """Draw one shape centred in the cell whose top-left corner is (x, y).

    The shape is painted onto the module-level ``screen`` surface. Its
    half-extent (circle radius, half the square's side, half the triangle's
    width/height) is BLOCK_SIZE // 5 for "small", BLOCK_SIZE // 3 for "medium"
    and BLOCK_SIZE // 2 for any other size. An unknown *shape* draws nothing.
    """
    half = (
        BLOCK_SIZE // 5 if size == "small"
        else BLOCK_SIZE // 3 if size == "medium"
        else BLOCK_SIZE // 2  # "large": largest extent that still fits the cell
    )
    cx, cy = x + BLOCK_SIZE // 2, y + BLOCK_SIZE // 2

    if shape == "circle":
        pygame.draw.circle(screen, color, (cx, cy), half)
    elif shape == "square":
        side = 2 * half
        pygame.draw.rect(screen, color, (cx - half, cy - half, side, side))
    elif shape == "triangle":
        # Upward-pointing isosceles triangle: apex on top, base at the bottom.
        base_y = cy + half
        vertices = [(cx, cy - half), (cx - half, base_y), (cx + half, base_y)]
        pygame.draw.polygon(screen, color, vertices)

# Draw grid and generate scene
def draw_grid(m, n, BLOCK_SIZE):
    """Fill an m-row by n-column grid with random objects and draw the cell borders.

    Every cell gets a random color, size and shape (drawn in that order), which
    is rendered via draw_shape onto the module-level ``screen``.

    Args:
        m: number of grid rows.
        n: number of grid columns.
        BLOCK_SIZE: side length of one square cell, in pixels.

    Returns:
        (current_scene, scene_description) where current_scene is a list of m
        rows, each a list of n dicts ``{"shape", "color", "size"}`` (color is
        the RGB tuple), and scene_description is a flat row-major list of
        "<size> <color name> <shape>" strings.
    """
    width = n * BLOCK_SIZE
    height = m * BLOCK_SIZE
    current_scene = []  # Save current scene state
    scene_description = []  # Save current scene text description

    # Row-major traversal so that each `row` really is a grid row (fixed y).
    for row_idx in range(m):
        y = row_idx * BLOCK_SIZE
        row = []
        for col_idx in range(n):
            x = col_idx * BLOCK_SIZE
            color = random.choice(COLORS)
            size = random.choice(SIZES)
            shape = random.choice(SHAPES)
            draw_shape(shape, color, size, x, y, BLOCK_SIZE)
            row.append({"shape": shape, "color": color, "size": size})
            scene_description.append(f"{size} {COLOR_NAMES[color]} {shape}")
        current_scene.append(row)

    # Borders are drawn after the shapes: a "large" square spans the whole cell
    # and would otherwise paint over the lines separating neighbouring cells.
    for col_idx in range(n + 1):
        line_x = col_idx * BLOCK_SIZE
        pygame.draw.line(screen, BLACK, (line_x, 0), (line_x, height), 2)
    for row_idx in range(m + 1):
        line_y = row_idx * BLOCK_SIZE
        pygame.draw.line(screen, BLACK, (0, line_y), (width, line_y), 2)

    return current_scene, scene_description

# Generate questions, answers, and choices

# Answer text used when no scene contains a single matching object.
_NO_MATCH_ANSWER = "None of the scenes have such objects."

# Question type -> (attributes drawn for the question, question template).
# Attributes are drawn in the listed order; the dict's insertion order is the
# order of question types offered to random.choice.
_QUESTION_TEMPLATES = {
    "shape": (("shape",), "In which scene are there the most {shape}s?"),
    "color": (("color",), "In which scene are there the most {color} objects?"),
    "size": (("size",), "In which scene are there the most objects of size {size}?"),
    "shape_color": (("shape", "color"), "In which scene are there the most {color} {shape}s?"),
    "size_color": (("size", "color"), "In which scene are there the most objects of size {size} {color}?"),
    "size_shape": (("size", "shape"), "In which scene are there the most {size} {shape}s?"),
    "size_shape_color": (("size", "shape", "color"), "In which scene are there the most {size} {color} {shape}s?"),
}

_OPTION_LETTERS = ['A', 'B', 'C', 'D']


def _cell_matches(cell, criteria):
    """Return True if *cell* has every attribute value in *criteria*.

    Cells store colors as RGB tuples, so colors are compared by COLOR_NAMES name.
    """
    for attr, wanted in criteria.items():
        actual = COLOR_NAMES[cell["color"]] if attr == "color" else cell[attr]
        if actual != wanted:
            return False
    return True


def _answer_for_scenes(scene_indices):
    """Format ascending 0-based scene indices as answer text, e.g. 'Scene 1 and Scene 3'."""
    if not scene_indices:
        return _NO_MATCH_ANSWER
    return ' and '.join(f"Scene {i + 1}" for i in scene_indices)


def _scenes_with_most(scene_states, criteria):
    """Return ascending indices of the scenes tied for the highest non-zero match count.

    Returns an empty list when no scene contains a matching object.
    """
    counts = [
        sum(1 for row in scene for cell in row if _cell_matches(cell, criteria))
        for scene in scene_states
    ]
    best = max(counts, default=0)
    if best == 0:
        return []
    return [idx for idx, count in enumerate(counts) if count == best]


def _sample_distractors(correct_indices, num_scenes, how_many=3):
    """Draw up to *how_many* distinct wrong answers, uniformly over all answers.

    The answer space is every subset of scenes, where the empty subset means
    _NO_MATCH_ANSWER: 2**num_scenes options in total. Subsets are drawn as
    random bitmasks rather than listing every combination up front, because
    that list grows exponentially with the number of scenes.
    """
    correct_mask = sum(1 << i for i in correct_indices)
    wanted = min(how_many, (1 << num_scenes) - 1)  # every option except the correct one
    seen = {correct_mask}
    picked = []
    while len(picked) < wanted:
        mask = random.getrandbits(num_scenes)
        if mask in seen:
            continue
        seen.add(mask)
        picked.append(_answer_for_scenes([i for i in range(num_scenes) if (mask >> i) & 1]))
    return picked


def generate_questions_and_answers(n_q, scene_states, video_path):
    """Generate up to *n_q* unique "which scene has the most X" multiple-choice questions.

    Args:
        n_q: number of questions requested. It is capped at the number of
            distinct questions the templates can produce; without the cap the
            duplicate-rejection loop would never finish.
        scene_states: per-scene grids as returned by draw_grid (lists of cell
            dicts with "shape", "color" and "size").
        video_path: stored verbatim in every generated entry.

    Returns:
        A list of dicts with keys "question", "choices" (up to four strings
        formatted "A. ..."), "answer" (the letter of the correct choice) and
        "video_path".
    """
    pools = {
        "shape": SHAPES,
        "color": list(COLOR_NAMES.values()),
        "size": SIZES,
    }
    question_types = list(_QUESTION_TEMPLATES)

    # Count the distinct questions that can exist, so the loop below always ends.
    possible_questions = set()
    for attrs, template in _QUESTION_TEMPLATES.values():
        for values in itertools.product(*(pools[attr] for attr in attrs)):
            possible_questions.add(template.format(**dict(zip(attrs, values))))
    target = min(n_q, len(possible_questions))

    questions = set()
    qa_pairs = []
    num_scenes = len(scene_states)
    while len(qa_pairs) < target:
        attrs, template = _QUESTION_TEMPLATES[random.choice(question_types)]
        criteria = {attr: random.choice(pools[attr]) for attr in attrs}
        question = template.format(**criteria)
        if question in questions:
            continue  # Avoid duplicate questions
        questions.add(question)

        correct_indices = _scenes_with_most(scene_states, criteria)
        correct_answer = _answer_for_scenes(correct_indices)
        choices = _sample_distractors(correct_indices, num_scenes)
        choices.append(correct_answer)
        random.shuffle(choices)

        qa_pairs.append({
            "question": question,
            "choices": [f"{letter}. {choice}" for letter, choice in zip(_OPTION_LETTERS, choices)],
            "answer": _OPTION_LETTERS[choices.index(correct_answer)],
            "video_path": video_path,
        })

    return qa_pairs

# Save data to JSON
def save_data_to_json(qa_pairs, scene_descriptions, file_name):
    """Write the QA pairs and per-scene descriptions to *file_name* as indented JSON.

    The top-level object has "questions_answers" (qa_pairs unchanged) and
    "scene_descriptions", which maps "Scene 1", "Scene 2", ... to each
    scene's description list in order. An existing file is overwritten.
    """
    payload = {
        "questions_answers": qa_pairs,
        "scene_descriptions": {
            f"Scene {number}": text
            for number, text in enumerate(scene_descriptions, start=1)
        },
    }
    with open(file_name, 'w') as handle:
        json.dump(payload, handle, indent=4)

# Main game loop
def main(ts=20, n_q=5, m=5, n=5, output_dir="output/angle1"):
    """Record a video of randomly regenerated grid scenes and write matching QA data.

    A random m x n scene is drawn at start-up. A new one is drawn each time
    more than ``change_interval`` seconds have passed since the previous
    change. This continues until more than *ts* seconds have elapsed or the
    window is closed. The recording is saved to ``qa1_<unix time>.mp4`` and
    the questions plus scene descriptions to ``qa1_<unix time>.json``, both
    inside *output_dir*.

    Args:
        ts: total run time, in seconds.
        n_q: number of questions to generate.
        m: number of grid rows.
        n: number of grid columns.
        output_dir: destination directory, created if missing.
    """
    global SCREEN_WIDTH, SCREEN_HEIGHT, screen, BLOCK_SIZE

    # Set grid dimensions based on arguments
    BLOCK_SIZE = 150  # Pixels per cell; large for better visibility
    SCREEN_WIDTH = n * BLOCK_SIZE
    SCREEN_HEIGHT = m * BLOCK_SIZE

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)
    unique_id = str(int(time.time()))
    video_file_name = os.path.join(output_dir, f"qa1_{unique_id}.mp4")
    qa_file_name = os.path.join(output_dir, f"qa1_{unique_id}.json")

    screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
    pygame.display.set_caption("物体变换游戏")  # Caption text: "object transformation game"

    recorder = ScreenRecorder(60)  # presumably the capture frame rate — confirm against pygame_screen_record

    scene_states = []  # Each scene's cell grid, in display order
    scene_descriptions = []  # Each scene's text description, same order

    recorder.start_rec()
    screen.fill((255, 255, 255))  # White background

    # Draw and record the initial scene
    current_scene, scene_description = draw_grid(m, n, BLOCK_SIZE)
    pygame.display.flip()
    scene_states.append(current_scene)
    scene_descriptions.append(scene_description)

    clock = pygame.time.Clock()
    start_time = pygame.time.get_ticks()  # Milliseconds
    last_change_time = start_time

    try:
        while True:
            now = pygame.time.get_ticks()
            # Leave before drawing anything else once time is up (or the window
            # is closed). Finishing the frame could otherwise record a new scene
            # that is barely visible in the video but still asked about.
            if (now - start_time) / 1000 > ts:
                break
            # pygame.event.get() drains the whole queue once; stop on window close.
            if any(event.type == pygame.QUIT for event in pygame.event.get()):
                break

            if (now - last_change_time) / 1000 > change_interval:
                screen.fill((255, 255, 255))  # Clear the previous scene
                current_scene, scene_description = draw_grid(m, n, BLOCK_SIZE)
                scene_states.append(current_scene)
                scene_descriptions.append(scene_description)
                last_change_time = now
            pygame.display.flip()

            clock.tick(60)  # Cap the loop at 60 frames per second
    finally:
        # Runs even if the loop raises (e.g. Ctrl-C), so whatever was recorded
        # so far is still saved along with its questions.
        recorder.stop_rec()
        recorder.save_recording(video_file_name)
        pygame.quit()

        qa_pairs = generate_questions_and_answers(n_q, scene_states, video_file_name)
        save_data_to_json(qa_pairs, scene_descriptions, qa_file_name)
        print(f"Questions and answers saved to {qa_file_name}")
        print(f"Video saved to {video_file_name}")

if __name__ == "__main__":
    # Forward each parsed CLI option to main() by keyword.
    cli_args = parse_arguments()
    main(
        ts=cli_args.ts,
        n_q=cli_args.n_q,
        m=cli_args.m,
        n=cli_args.n,
        output_dir=cli_args.output_dir,
    )
