import numpy as np
from moviepy.editor import *
import random
import time
import os
import json
import argparse

# Shapes, colours and solfege notes the generator samples from.
# `colors[i]` is the RGB value whose name is `color_names[i]`, and
# `notes[i]` is the audio file labelled `note_names[i]`.
objects = ['circle', 'square', 'triangle']
colors = [
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
]
color_names = ['red', 'green', 'blue']
notes = [f'{syllable}.wav' for syllable in ('do', 're', 'mi', 'fa', 'so', 'la', 'ti')]
note_names = [os.path.splitext(note)[0].capitalize() for note in notes]

def parse_arguments(argv=None):
    """Parse the generator's command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing a list
            lets the parser be used without touching ``sys.argv``.

    Returns:
        argparse.Namespace with ``tt`` (int, total run time in seconds),
        ``ts`` (int, display duration per object) and ``output_dir`` (str).
    """
    # NOTE(review): the description mentions Pygame/screen recording, but this
    # file renders with moviepy; left unchanged because it is user-facing text.
    parser = argparse.ArgumentParser(description="Pygame Scene Generator with Screen Recording")
    parser.add_argument('--tt', type=int, default=20, help='Total run time of the video in seconds')
    parser.add_argument('--ts', type=int, default=5, help='Display duration for each object')
    parser.add_argument('--output_dir', type=str, default="output/angle1", help='Directory to save output files')
    return parser.parse_args(argv)

args = parse_arguments()

# Run-wide settings taken from the command line.
ts = args.ts  # seconds each shape is displayed
tt = args.tt  # target total run time in seconds
pause_duration = 1  # seconds of blank white frame after each shape
output_dir = args.output_dir
# exist_ok=True replaces the exists()-then-makedirs() check, which could race
# with another process creating the same directory in between.
os.makedirs(output_dir, exist_ok=True)
# The Unix timestamp names this run's .mp4 (and, in main, its .json).
timestamp = str(int(time.time()))
file_name = os.path.join(output_dir, f"{timestamp}.mp4")

def make_frame(t, shape, color):
    """Render a 600x400 RGB frame with one filled shape on a white background.

    Args:
        t: Time in seconds. Unused; every frame produced here is static.
        shape: 'circle', 'square' or 'triangle'. Any other value yields a plain
            white frame.
        color: RGB triple written into the shape's pixels.

    Returns:
        numpy uint8 array of shape (400, 600, 3) (rows = y, columns = x).
    """
    canvas = np.full((400, 600, 3), 255, dtype=np.uint8)

    # All shapes are centred on (x=300, y=200) and span roughly x 200-400,
    # y 100-300, so they are drawn at a comparable size.
    if shape == 'circle':
        center_x, center_y, radius = 300, 200, 100
        Y, X = np.ogrid[:canvas.shape[0], :canvas.shape[1]]
        mask = (X - center_x) ** 2 + (Y - center_y) ** 2 <= radius ** 2
        canvas[mask] = color
    elif shape == 'square':
        canvas[100:300, 200:400] = color
    elif shape == 'triangle':
        # Upright isosceles triangle: apex at (x=300, y=100), base on row 299
        # spanning the square's 200 px width. The half-width grows 0.5 px per
        # row so the sides reach the box corners at the base (a slope of
        # 1 px/row would close the shape after only 100 rows).
        Y, X = np.ogrid[:canvas.shape[0], :canvas.shape[1]]
        in_rows = (Y >= 100) & (Y < 300)
        mask = in_rows & (np.abs(X - 300) <= (Y - 100) / 2)
        canvas[mask] = color
    return canvas

def create_clip_with_sound(shape, color, sound_file, ts):
    """Return a `ts`-second still clip of `shape` in `color` with `sound_file` as audio.

    The image comes from ``make_frame(0, shape, color)``. When the audio file
    is shorter than `ts` it is looped up to `ts` seconds; otherwise only its
    first `ts` seconds are kept.
    """
    still = ImageClip(make_frame(0, shape, color), duration=ts)

    track = AudioFileClip(sound_file)
    track = (
        afx.audio_loop(track, duration=ts)
        if track.duration < ts
        else track.subclip(0, ts)
    )

    return still.set_audio(track)

def generate_sequence(tt, ts, pause_duration):
    """Build the random clip list for one video plus its ground-truth record.

    Each iteration appends a ``ts``-second clip showing a random shape in a
    random colour with a random note as its audio, followed by a
    ``pause_duration``-second white pause clip. Neither the shape nor the
    colour repeats the previous iteration's value. Iterations continue while
    the accumulated time is below ``tt``, so the result can run past ``tt``
    by up to one shape+pause segment.

    Args:
        tt: target total duration in seconds.
        ts: seconds each shape is displayed.
        pause_duration: seconds of blank white frame after each shape.

    Returns:
        ``(clips, sequence_data)``: the clips in playback order
        (shape, pause, shape, pause, ...) and one dict per shape with keys
        'shape', 'color' (colour name) and 'note' (note label).

    Raises:
        ValueError: if ``tt > 0`` but ``ts + pause_duration <= 0``; the
            accumulated time would never reach ``tt`` and the loop would
            never terminate.
    """
    step = ts + pause_duration
    if tt > 0 and step <= 0:
        raise ValueError(
            f"ts + pause_duration must be positive (got {ts} + {pause_duration}); "
            "otherwise the sequence never reaches tt"
        )

    clips = []
    sequence_data = []
    current_time = 0
    last_shape = None
    last_color = None

    while current_time < tt:
        shape = random.choice(objects)
        color_index = random.randint(0, len(colors) - 1)

        # Re-draw until shape / colour differ from the previous segment; the
        # length guards avoid an endless retry when only one option exists.
        while shape == last_shape and len(objects) > 1:
            shape = random.choice(objects)
        while colors[color_index] == last_color and len(colors) > 1:
            color_index = random.randint(0, len(colors) - 1)
        color = colors[color_index]
        color_name = color_names[color_index]

        # `notes` and `note_names` are parallel lists: same index, same note.
        sound_index = random.randint(0, len(notes) - 1)
        sound_file = notes[sound_index]
        note_name = note_names[sound_index]

        clips.append(create_clip_with_sound(shape, color, sound_file, ts))
        clips.append(ColorClip(size=(600, 400), color=(255, 255, 255), duration=pause_duration))

        sequence_data.append({
            'shape': shape,
            'color': color_name,
            'note': note_name
        })

        current_time += step
        last_shape = shape
        last_color = color

    return clips, sequence_data

def _label_choices(values, correct):
    """Prefix `values` with "A. ", "B. ", ... and return (choices, labelled correct answer)."""
    choices = [f"{chr(65 + i)}. {value}" for i, value in enumerate(values)]
    # Exact-equality lookup, so one value can never be mistaken for a substring of another.
    return choices, choices[values.index(correct)]


def _in_canonical_order(counts, canonical):
    """Return the keys of `counts` ordered by their position in `canonical`.

    Keys missing from `canonical` go last, in their original order (sorted() is stable).
    """
    rank = {value: i for i, value in enumerate(canonical)}
    return sorted(counts, key=lambda value: rank.get(value, len(rank)))


def generate_questions_answers(sequence_data, video_file):
    """Build multiple-choice questions about a generated shape/colour/note sequence.

    Args:
        sequence_data: list of dicts with keys 'shape', 'color' and 'note', as
            produced by ``generate_sequence``.
        video_file: path stored verbatim in every question's 'video_path'.

    Returns:
        A list of dicts with keys 'question', 'choices' (strings labelled
        "A. ...", "B. ...", ...), 'correct_answer' (always exactly one of the
        choices) and 'video_path'. Questions whose answer would be ambiguous
        (a tie for the most frequent shape, or more than one absent shape) are
        skipped, and an empty sequence yields an empty list.
    """
    if not sequence_data:
        return []

    shapes_count, colors_count, notes_count = {}, {}, {}
    for item in sequence_data:
        shapes_count[item['shape']] = shapes_count.get(item['shape'], 0) + 1
        colors_count[item['color']] = colors_count.get(item['color'], 0) + 1
        notes_count[item['note']] = notes_count.get(item['note'], 0) + 1

    questions_answers = []

    def add(question, values, correct):
        choices, answer = _label_choices(values, correct)
        questions_answers.append({
            "question": question,
            "choices": choices,
            "correct_answer": answer,
            "video_path": video_file
        })

    # Most frequent shape: only asked when the maximum count is unique,
    # otherwise several choices would be correct.
    top = max(shapes_count.values())
    most_common = [shape for shape, count in shapes_count.items() if count == top]
    if len(most_common) == 1:
        add("Which shape appeared the most times?",
            [shape.capitalize() for shape in objects],
            most_common[0].capitalize())

    # First colour. Choices are the colours that appeared, in the fixed order of
    # `color_names` (iterating a set of strings is not reproducible across runs).
    add("What was the first color that appeared?",
        [color.capitalize() for color in _in_canonical_order(colors_count, color_names)],
        sequence_data[0]['color'].capitalize())

    # Presence of a randomly chosen note. "No" must map to "B. No", not "A. No".
    random_note = random.choice(note_names)
    add('Was the note "{}" played?'.format(random_note),
        ["Yes", "No"],
        "Yes" if random_note in notes_count else "No")

    # Number of distinct notes. Distractors exclude the true count so exactly
    # one choice is correct; choices are ascending so the answer is not always "A".
    unique_notes = len(notes_count)
    candidates = [n for n in range(1, len(note_names) + 1) if n != unique_notes]
    distractors = random.sample(candidates, min(3, len(candidates)))
    add("How many different notes were played in total?",
        [str(n) for n in sorted([unique_notes] + distractors)],
        str(unique_notes))

    add("Did the triangle shape appear?",
        ["Yes", "No"],
        "Yes" if 'triangle' in shapes_count else "No")

    # Last note; choices are the notes that were played, in `note_names` order.
    add("What was the last note played?",
        _in_canonical_order(notes_count, note_names),
        sequence_data[-1]['note'])

    # Absent shape: only asked when exactly one shape never appeared.
    missing = [shape for shape in objects if shape not in shapes_count]
    if len(missing) == 1:
        add("Which shape did not appear at all?",
            [shape.capitalize() for shape in objects],
            missing[0].capitalize())

    return questions_answers

def save_data_to_json(qa_pairs, file_name):
    """Write `qa_pairs` to `file_name` as JSON under a "questions_answers" key.

    Any existing file is overwritten. The output is indented with 4 spaces,
    and a confirmation line naming the file is printed afterwards.
    """
    payload = {"questions_answers": qa_pairs}
    with open(file_name, 'w') as handle:
        json.dump(payload, handle, indent=4)
    print(f"Questions and answers have been saved to {file_name}")

def main(tt=15, ts=2, pause_duration=1):
    """Program entry point: render one video, then write its question file.

    Args:
        tt: target total duration in seconds, forwarded to generate_sequence.
        ts: seconds each shape is shown.
        pause_duration: seconds of blank frame after each shape.

    The video is written to the module-level `file_name` at 24 fps. The
    questions go to `<output_dir>/<timestamp>.json` and refer to the video by
    the relative path "./<timestamp>.mp4".
    """
    clips, sequence_data = generate_sequence(tt, ts, pause_duration)
    concatenate_videoclips(clips).write_videofile(file_name, fps=24)

    qa_list = generate_questions_answers(sequence_data, f"./{timestamp}.mp4")
    save_data_to_json(qa_list, os.path.join(output_dir, f"{timestamp}.json"))

if __name__ == "__main__":
    # Durations come from the parsed CLI options; the pause is fixed at 1 second.
    main(args.tt, args.ts, 1)
