import argparse
import json
import os
import random
import re
import time
from datetime import datetime, timedelta

import pygame
from pygame_screen_record import ScreenRecorder
# Initialize pygame's modules once at import time. main() relies on this
# before building its font: pygame.font.Font needs the font module initialized.
pygame.init()
# Map each RGB color tuple used for drawing to the name shown in questions and JSON output.
COLOR_MAP = {
    # (R, G, B) -> display name
    (255, 0, 0): "Red",
    (0, 255, 0): "Green",
    (0, 0, 255): "Blue",
}
def generate_numeric_choices(answer):
    """Generate four distinct multiple-choice options for a numeric question.

    The options are four consecutive integers labelled A-D. For a positive
    answer ``v`` they are ``v-1, v, v+1, v+2``, so the correct option is B.
    For ``v <= 0`` the window starts at ``v`` itself, so the correct option
    is A. This guarantees no two options ever show the same number.

    Args:
        answer: The correct answer, in any form accepted by ``int()``
            (e.g. ``"3"`` or ``3``).

    Returns:
        A ``(choices, correct_answer)`` tuple. ``choices`` is a list of four
        ``"<letter>. <value>"`` strings, and ``correct_answer`` is the element
        of ``choices`` that holds the correct value.

    Raises:
        ValueError: If ``answer`` is a string that is not an integer literal.
    """
    correct_value = int(answer)

    # Starting at v-1 when v <= 0 would need clamping to stay non-negative.
    # Clamping would duplicate the correct option (e.g. "A. 0" and "B. 0"),
    # so the window starts at v instead.
    start = correct_value - 1 if correct_value > 0 else correct_value
    choices = [
        f"{letter}. {start + offset}"
        for offset, letter in enumerate("ABCD")
    ]
    correct_answer = choices[correct_value - start]
    return choices, correct_answer

def generate_yes_no_choices(answer):
    """Build the two-option choice list for a yes/no question.

    Args:
        answer: The expected answer string. It is compared to ``"yes"``
            case-insensitively.

    Returns:
        ``(choices, correct_answer)``. ``choices`` is always
        ``["A. Yes", "B. No"]``. ``correct_answer`` is ``"A. Yes"`` when
        ``answer`` is "yes" in any letter case, and ``"B. No"`` otherwise.
    """
    yes_option, no_option = "A. Yes", "B. No"
    options = [yes_option, no_option]
    if answer.lower() != "yes":
        return options, no_option
    return options, f"A. {answer.capitalize()}"

def generate_choices_for_question(question, answer):
    """Pick a multiple-choice generator based on the wording of ``question``.

    Rules, checked case-insensitively in this order:

    * Questions containing "how many" get numeric choices from
      ``generate_numeric_choices``.
    * Questions containing "yes or no", or the standalone word "did", get
      yes/no choices from ``generate_yes_no_choices``.
    * Anything else gets ``answer`` as option A plus three placeholder
      distractors.

    Args:
        question: The question text.
        answer: The correct answer (string form).

    Returns:
        ``(choices, correct_answer)`` as produced by the selected rule.
    """
    lowered = question.lower()
    if "how many" in lowered:
        return generate_numeric_choices(answer)
    # Match "did" as a whole word. A plain substring test would route words
    # such as "candidate" or "undid" to the yes/no generator.
    if "yes or no" in lowered or re.search(r"\bdid\b", lowered):
        return generate_yes_no_choices(answer)
    choices = [f"A. {answer}", "B. Option 1", "C. Option 2", "D. Option 3"]
    return choices, choices[0]
def add_one_second(time_str):
    """Return the ``"HH:MM:SS"`` clock time one second after ``time_str``.

    Only the time of day is parsed and formatted, so the result wraps past
    midnight (``"23:59:59"`` -> ``"00:00:00"``).

    Raises:
        ValueError: If ``time_str`` does not match ``"%H:%M:%S"``.
    """
    clock_fmt = "%H:%M:%S"
    later = datetime.strptime(time_str, clock_fmt) + timedelta(seconds=1)
    return later.strftime(clock_fmt)

def generate_questions(object_sequence, n=5):
    """Generate question/answer pairs about the objects shown between two moments.

    Each of ``n`` rounds picks two entries X and Y at random from
    ``object_sequence``, with X strictly before Y. It then asks four questions
    about the span from X to Y: distinct shapes, distinct colors, number of
    objects, and whether any object appeared (yes/no). All counts are
    inclusive, so X and Y themselves are part of the span.

    Args:
        object_sequence: Entries of the form ``{"time": "HH:MM:SS",
            "object": {...}}`` in order of appearance. Each object needs a
            ``"shape"`` and a ``"color"`` whose tuple form is a key of
            ``COLOR_MAP``.
        n: Number of X/Y pairs to sample. Each pair yields four questions.

    Returns:
        A list of ``{"question": str, "answer": str}`` dicts. The list is
        empty when ``object_sequence`` has fewer than two entries.
    """
    qa_pairs = []
    if len(object_sequence) < 2:
        return qa_pairs  # Not enough objects to form an X/Y pair.

    def describe(entry):
        """Return e.g. "Red circle at HH:MM:SS" for a sequence entry.

        The timestamp shown is one second after ``entry["time"]``.
        """
        stamp = add_one_second(entry["time"])
        obj = entry["object"]
        return f"{COLOR_MAP[tuple(obj['color'])]} {obj['shape']} at {stamp}"

    last_index = len(object_sequence) - 1
    for _ in range(n):
        # Randomly select X and Y with X strictly before Y. The sequence never
        # changes inside this loop, so the early return above is the only
        # size check needed.
        obj_x_index = random.randint(0, last_index - 1)
        obj_y_index = random.randint(obj_x_index + 1, last_index)

        obj_x_desc = describe(object_sequence[obj_x_index])
        obj_y_desc = describe(object_sequence[obj_y_index])

        # The analysed span includes X and Y themselves.
        span = [entry["object"] for entry in object_sequence[obj_x_index:obj_y_index + 1]]
        shapes = {obj["shape"] for obj in span}
        colors = {COLOR_MAP[tuple(obj["color"])] for obj in span}
        object_count = len(span)

        between = f"between the {obj_x_desc} and the {obj_y_desc}"
        qa_pairs.extend([
            {"question": f"How many different shapes appeared {between}?", "answer": str(len(shapes))},
            {"question": f"How many different colors appeared {between}?", "answer": str(len(colors))},
            {"question": f"How many objects appeared {between}?", "answer": str(object_count)},
            # NOTE(review): object_count includes X and Y, so it is always
            # >= 2 and this answer is always "yes". Confirm whether an
            # exclusive count (obj_y_index - obj_x_index - 1) was intended.
            {"question": f"Did any object appear {between}, yes or no?", "answer": "yes" if object_count > 0 else "no"},
        ])

    return qa_pairs


def main(args):
    """Record a random-shape animation and save a QA dataset describing it.

    Opens an 800x600 pygame window and runs for ``args.t_total`` real seconds.
    During that time a single random shape is shown and replaced every
    ``args.t`` seconds, next to a simulated clock that advances once per real
    second. The screen is recorded with ``ScreenRecorder`` and saved to
    ``<output_dir>/angel4_<run id>.mp4``. The generated questions/answers,
    the per-second object history, and the object sequence are written as
    JSON to ``<output_dir>/angel4_<run id>_qa.json``.

    Args:
        args: Parsed CLI namespace with these fields:
            ``t`` (seconds between new objects), ``t_total`` (total run time
            in seconds), ``n_q`` (passed as ``n`` to ``generate_questions``),
            and ``output_dir``.
    """
    # Screen settings
    screen = pygame.display.set_mode((800, 600))
    clock = pygame.time.Clock()

    # Font for the on-screen clock (pygame default font, size 36).
    font = pygame.font.Font(None, 36)

    # Drawing colors. Every object color must be a key of COLOR_MAP.
    RED = (255, 0, 0)
    GREEN = (0, 255, 0)
    BLUE = (0, 0, 255)
    BLACK = (0, 0, 0)

    # Unique run ID (Unix timestamp + "_D") used in the output file names.
    # Named run_id to avoid shadowing the builtin id().
    run_id = str(int(time.time())) + '_D'
    t_total = args.t_total  # Total animation time in seconds.
    t = args.t  # Seconds between generating a new object.

    # File saving paths. exist_ok avoids a check-then-create race.
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    file_name = os.path.join(output_dir, f"angel4_{run_id}.mp4")
    qa_file_name = os.path.join(output_dir, f"angel4_{run_id}_qa.json")

    # Define shapes
    SHAPES = ["circle", "rect", "triangle"]

    def create_random_object():
        """Return a new object dict with a random shape, color and position."""
        shape = random.choice(SHAPES)
        color = random.choice([RED, GREEN, BLUE])
        pos = [random.randint(50, 750), random.randint(50, 550)]

        if shape == "circle":
            return {"shape": "circle", "color": color, "pos": pos, "radius": 20}
        elif shape == "rect":
            return {"shape": "rect", "color": color, "pos": pos, "width": 50, "height": 50}
        elif shape == "triangle":
            return {"shape": "triangle", "color": color, "pos": pos, "size": 30, "angle": 0}

    def draw_object(screen, obj):
        """Draw ``obj`` onto ``screen`` according to its shape."""
        if obj["shape"] == "circle":
            pygame.draw.circle(screen, obj["color"], obj["pos"], obj["radius"])
        elif obj["shape"] == "rect":
            pygame.draw.rect(screen, obj["color"], pygame.Rect(obj["pos"][0], obj["pos"][1], obj["width"], obj["height"]))
        elif obj["shape"] == "triangle":
            points = [
                (obj["pos"][0], obj["pos"][1] - obj["size"]),
                (obj["pos"][0] - obj["size"], obj["pos"][1] + obj["size"]),
                (obj["pos"][0] + obj["size"], obj["pos"][1] + obj["size"])
            ]
            pygame.draw.polygon(screen, obj["color"], points)

    def draw_time(screen, simulated_time):
        """Render the simulated clock as HH:MM:SS in the top-left corner."""
        time_str = simulated_time.strftime("%H:%M:%S")
        time_surface = font.render(time_str, True, BLACK)
        screen.blit(time_surface, (20, 20))

    objects = []  # Objects currently on screen (replaced, never accumulated).
    object_history = {}  # "HH:MM:SS" -> copies of the on-screen objects at each tick.
    object_sequence = []  # Every generated object with the clock value at creation.

    last_object_time = time.time()  # Real time of the last object generation.

    # The simulated clock starts at the current system time and advances by
    # one second for every real second that passes.
    simulated_time = datetime.now()
    real_time = time.time()  # Real time of the last clock tick.

    start_time = time.time()  # Real start time of the animation.
    recorder = ScreenRecorder(60)
    recorder.start_rec()
    try:
        running = True
        while running:
            screen.fill((225, 225, 225))  # Clear to a light-gray background.
            # NOTE(review): the clock is drawn before it is advanced below, so
            # on a tick frame the screen still shows the previous second.
            draw_time(screen, simulated_time)

            # Stop after t_total seconds (the current frame is still drawn).
            if time.time() - start_time >= t_total:
                running = False

            # Advance the simulated clock once per real second.
            if time.time() - real_time >= 1:
                simulated_time += timedelta(seconds=1)
                real_time = time.time()

                # Snapshot the on-screen objects under the new clock value.
                # This runs before this frame's object generation below.
                time_str = simulated_time.strftime("%H:%M:%S")
                object_history[time_str] = [obj.copy() for obj in objects]

            # Replace the on-screen object every t seconds.
            if time.time() - last_object_time >= t:
                new_object = create_random_object()
                objects = [new_object]
                object_sequence.append({
                    "time": simulated_time.strftime("%H:%M:%S"),
                    "object": new_object.copy()
                })
                last_object_time = time.time()

            for obj in objects:
                draw_object(screen, obj)

            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False

            pygame.display.flip()
            clock.tick(60)  # Cap the loop at 60 frames per second.
    finally:
        # Call stop_rec() on every exit path, including an exception in the loop.
        recorder.stop_rec()
    recorder.save_recording(file_name)

    # Generate questions. The raw answer is kept under "ans", and "answer" is
    # replaced by the lettered correct choice.
    qa_pairs = generate_questions(object_sequence, n=args.n_q)
    for qa in qa_pairs:
        qa['ans'] = qa['answer']
        qa['choices'], qa['answer'] = generate_choices_for_question(qa['question'], qa['answer'])

    def serialize_object(obj):
        """Return a copy of ``obj`` with its RGB color replaced by its name."""
        obj_copy = obj.copy()
        obj_copy['color'] = COLOR_MAP[tuple(obj['color'])]
        return obj_copy

    serialized_object_history = {
        time_str: [serialize_object(obj) for obj in objs]
        for time_str, objs in object_history.items()
    }
    serialized_object_sequence = [
        {"time": entry["time"], "object": serialize_object(entry["object"])}
        for entry in object_sequence
    ]

    output_data = {
        "questions_answers": qa_pairs,
        "object_history": serialized_object_history,
        "object_sequence": serialized_object_sequence
    }

    with open(qa_file_name, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, indent=4)

    pygame.quit()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Pygame animation and question generator.")
    parser.add_argument('--t', type=int, default=2,
                        help="Interval in seconds between generating new objects.")
    parser.add_argument('--t_total', type=int, default=26,
                        help="Total running time of the animation in seconds.")
    # --n_q is passed to generate_questions as n. Each sampled object pair
    # produces four questions (shapes, colors, object count, yes/no).
    parser.add_argument('--n_q', type=int, default=5,
                        help="Number of object pairs to ask about; each pair yields four questions.")
    parser.add_argument('--output_dir', type=str, default="output/angle1",
                        help="Directory to save the output files.")

    args = parser.parse_args()
    main(args)