import random
import pygame
from pygame.locals import *
from OpenGL.GL import *
from OpenGL.GLU import *
from OpenGL.GLUT import *
import time
import cv2  # 使用 OpenCV 保存视频
import numpy as np
import json
import os
from datetime import datetime  # 导入 datetime 模块
import argparse
import re

def extract_edges_from_answer(answer):
    """Collect every single-quoted word token in *answer*, in order of appearance.

    Example: ``"... 'AB' ..."`` gives ``['AB']``. A token is a non-empty run of
    word characters enclosed in single quotes; text without such tokens gives ``[]``.
    """
    quoted_token = r"'(\w+)'"
    return [match.group(1) for match in re.finditer(quoted_token, answer)]

def generate_edge_choices(answer, all_edges=None):
    """Turn an answer sentence into four shuffled, letter-labelled options.

    The correct edge name(s) are the single-quoted tokens in *answer*. Three
    distractors are sampled from the remaining edges in *all_edges*. The
    combined list is shuffled and labelled ``"A. "``, ``"B. "``, and so on.

    Args:
        answer: answer text with the correct edge in single quotes,
            e.g. ``"... 'AB' ..."``.
        all_edges: pool of edge names to draw distractors from. Defaults to
            the 12 edges of the cube drawn by ``main`` (same names as its
            ``edge_names``).

    Returns:
        ``(choices, correct_answer)``. ``choices`` is the list of labelled
        options. ``correct_answer`` is the first labelled option whose edge is
        one of the correct edges. Returns ``([], None)`` when *answer*
        contains no quoted edge.

    Raises:
        ValueError: from ``random.sample`` if fewer than three distractors
            are available.
    """
    if all_edges is None:
        # Must list every real cube edge (same as ``edge_names`` in main()),
        # so any edge can serve as a distractor and no non-edge is offered.
        all_edges = ['AB', 'BC', 'CD', 'DA', 'EF', 'FG', 'GH', 'HE',
                     'AE', 'BF', 'CG', 'DH']

    correct_edges = extract_edges_from_answer(answer)
    if not correct_edges:
        return [], None

    incorrect_edges = [edge for edge in all_edges if edge not in correct_edges]
    options = random.sample(incorrect_edges, 3) + correct_edges
    random.shuffle(options)
    choices = [f"{chr(65 + i)}. {option}" for i, option in enumerate(options)]
    # Identify the correct option by exact edge match rather than a substring
    # test on the labelled text.
    correct_answer = next(
        (label for label, option in zip(choices, options) if option in correct_edges),
        None,
    )
    return choices, correct_answer

def generate_question_answer(first_edge):
    """Build the Q&A record asking which edge the ball starts moving on.

    Args:
        first_edge: edge name to embed in the answer (e.g. ``'AB'``). It is
            interpolated into the text inside single quotes.

    Returns:
        A dict with keys ``"question"`` (English text) and ``"answer"``
        (a Chinese sentence quoting *first_edge*), in that order.
    """
    return dict(
        question="On which edge does the ball start moving in the first scene?",
        answer=f"小球开始沿着边 '{first_edge}' 移动。",
    )

def parse_arguments(argv=None):
    """Parse command-line options for the cube video generator.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``ts``, ``at``, ``n_q`` and ``output_dir``.

    Exits (via ``parser.error``, SystemExit code 2) if ``--at`` is not
    positive. That value is used as a divisor when positioning the ball.
    """
    parser = argparse.ArgumentParser(description="Pygame Scene Generator with Screen Recording")
    parser.add_argument('--ts', type=int, default=20, help='Total run time of the video in seconds')
    parser.add_argument('--at', type=float, default=0.5,
                        help='Time in seconds the ball takes to traverse one cube edge')
    # NOTE(review): main() currently ignores n_q and always emits one question.
    parser.add_argument('--n_q', type=int, default=5, help='Number of questions to generate')
    parser.add_argument('--output_dir', type=str, default="output/angle1", help='Directory to save output files')
    args = parser.parse_args(argv)
    if args.at <= 0:
        parser.error("--at must be a positive number of seconds")
    return args

def main(args):
    """Render the cube animation to an MP4 and write a matching Q&A JSON file.

    A black wireframe cube, with corners labelled A-H, rotates by 1 degree per
    frame. A red sphere travels along its edges, spending ``args.at`` seconds
    on each edge and starting from a randomly chosen edge. Every frame is
    written to ``<output_dir>/cube_angle1<id>.mp4``. Afterwards a
    multiple-choice question about the starting edge is saved to
    ``<output_dir>/cube_angle1<id>_qa.json``.

    Args:
        args: namespace from ``parse_arguments``. Reads ``ts`` (video length in
            seconds), ``at`` (seconds per edge) and ``output_dir``.

    Raises:
        RuntimeError: if OpenCV cannot open a video writer for the output file.
    """
    vertices = [
        [1, 1, -1],    # A
        [1, -1, -1],   # B
        [-1, -1, -1],  # C
        [-1, 1, -1],   # D
        [1, 1, 1],     # E
        [1, -1, 1],    # F
        [-1, -1, 1],   # G
        [-1, 1, 1]     # H
    ]

    edges = [
        (0, 1), (1, 2), (2, 3), (3, 0),  # back-face edges
        (4, 5), (5, 6), (6, 7), (7, 4),  # front-face edges
        (0, 4), (1, 5), (2, 6), (3, 7)   # edges joining front and back faces
    ]

    letters = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']
    # Index-aligned with `edges`.
    edge_names = ['AB', 'BC', 'CD', 'DA', 'EF', 'FG', 'GH', 'HE', 'AE', 'BF', 'CG', 'DH']
    t_total = args.ts
    total_edges = len(edges)
    t_per_edge = args.at
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    run_id = str(int(time.time())) + '_D'
    file_name = os.path.join(output_dir, 'cube_angle1' + run_id + ".mp4")
    qa_file_name = os.path.join(output_dir, 'cube_angle1' + run_id + "_qa.json")

    width, height = 800, 600
    fps = 60

    # 'X264' is missing from many OpenCV builds, in which case the writer
    # silently drops every frame. Fall back to 'mp4v', and fail loudly if
    # neither codec can be opened.
    out = None
    for codec in ('X264', 'mp4v'):
        writer = cv2.VideoWriter(file_name, cv2.VideoWriter_fourcc(*codec), fps, (width, height))
        if writer.isOpened():
            out = writer
            break
        writer.release()
    if out is None:
        raise RuntimeError(f"Could not open a video writer for {file_name}")

    def draw_cube():
        """Draw all cube edges as GL lines in the current colour."""
        glBegin(GL_LINES)
        for edge in edges:
            glVertex3fv(vertices[edge[0]])
            glVertex3fv(vertices[edge[1]])
        glEnd()

    def draw_letters():
        """Draw each vertex's letter label as black GLUT bitmap text."""
        glColor3f(0, 0, 0)  # Black letters; must be set before glRasterPos.
        for i, vertex in enumerate(vertices):
            glRasterPos3fv(vertex)
            for char in letters[i]:
                glutBitmapCharacter(GLUT_BITMAP_TIMES_ROMAN_24, ord(char))

    def draw_moving_sphere(position, radius=0.1):
        """Draw a red solid sphere centred at *position* (x, y, z)."""
        glPushMatrix()
        glTranslatef(position[0], position[1], position[2])
        glColor3f(1, 0, 0)  # Red sphere
        glutSolidSphere(radius, 20, 20)
        glPopMatrix()

    def get_sphere_position(t):
        """Linearly interpolate the sphere along the edge active at time *t*.

        Edges are visited in `edges` order, starting at ``edge_offset`` and
        wrapping around; each edge takes ``t_per_edge`` seconds.
        """
        edge_index = (int(t // t_per_edge) + edge_offset) % total_edges
        local_t = (t % t_per_edge) / t_per_edge
        start_vertex = vertices[edges[edge_index][0]]
        end_vertex = vertices[edges[edge_index][1]]
        pos = [start_vertex[i] + (end_vertex[i] - start_vertex[i]) * local_t for i in range(3)]
        return pos

    def capture_frame():
        """Read the current GL read buffer as a top-down BGR image for OpenCV."""
        # Tightly packed rows, so the buffer is exactly height * width * 3 bytes.
        glPixelStorei(GL_PACK_ALIGNMENT, 1)
        pixels = glReadPixels(0, 0, width, height, GL_RGB, GL_UNSIGNED_BYTE)
        image = np.frombuffer(pixels, dtype=np.uint8).reshape(height, width, 3)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        # OpenGL rows start at the bottom; image rows start at the top.
        image = cv2.flip(image, 0)
        return image

    edge_offset = random.randint(0, total_edges - 1)
    displayed_objects = []
    # Drive the animation from the frame counter instead of the wall clock.
    # The video then lasts exactly t_total seconds at `fps`, and the ball takes
    # exactly t_per_edge seconds per edge on playback, however fast this loop
    # actually runs. At least one frame is always rendered.
    total_frames = max(1, int(round(t_total * fps)))

    try:
        pygame.init()
        pygame.display.set_mode((width, height), DOUBLEBUF | OPENGL)
        # No glMatrixMode call: projection and camera offset both go into the
        # current (default) matrix, and glRotatef below accumulates onto it.
        gluPerspective(45, (width / height), 0.1, 50.0)
        glTranslatef(0.0, 0.0, -5)
        glEnable(GL_DEPTH_TEST)

        glutInit()
        glClearColor(1, 1, 1, 1)  # Set white background

        first_edge = edge_names[edge_offset]
        displayed_objects.append({
            "edge": first_edge,
            "time": datetime.now().strftime("%H:%M:%S")
        })

        for frame_index in range(total_frames):
            if any(event.type == pygame.QUIT for event in pygame.event.get()):
                break

            elapsed_time = frame_index / fps
            glRotatef(1, 3, 1, 1)
            glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)
            glColor3f(0, 0, 0)  # Set black color for the cube edges
            draw_cube()
            draw_letters()
            draw_moving_sphere(get_sphere_position(elapsed_time))
            # Capture BEFORE swapping: after flip() the back buffer's contents
            # are undefined.
            out.write(capture_frame())
            pygame.display.flip()
    finally:
        out.release()
        pygame.quit()

    qa_pair = generate_question_answer(first_edge)
    qa_pair['choices'], qa_pair['answer'] = generate_edge_choices(qa_pair['answer'])
    output_data = {
        "questions_answers": qa_pair,
        "displayed_objects": displayed_objects
    }

    with open(qa_file_name, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, ensure_ascii=False, indent=4)

    print(f"Video saved to {file_name}")
    print(f"Q&A saved to {qa_file_name}")

if __name__ == "__main__":
    # Script entry point: read CLI options, then render the video and Q&A file.
    args = parse_arguments()
    main(args=args)
