import pygame
import random
import time
import json
import os
from pygame_screen_record import ScreenRecorder  # Use ScreenRecorder for recording
import argparse
# ---------------------------- Global Variables Definition ----------------------------

import re

# The three movement labels the plane records for its random decisions
# (the same strings Plane.update appends); appears unused elsewhere in this file.
MOVEMENTS = ["move left", "move right", "stay still"]
# Fixed distractor sequences for the "movement pattern" multiple-choice question.
# Boundary entries mirror what Plane.update records: touching the left edge forces
# a move right, and touching the right edge forces a move left.
MOVEMENT_PATTERNS = [
    "move left, move right, stay still",
    "move left, touch left boundary, move right",
    "move right, move left, stay still",
    "move right, touch right boundary, move left",
]


def generate_movement_pattern_choices(answer, patterns=None):
    """Generate choices for movement pattern questions.

    The correct option is the text after the first ": " in ``answer``, with the
    sentence-final period removed so it can be compared with the bare pattern
    strings. Up to three distractors are sampled from ``patterns`` (by default
    the module-level ``MOVEMENT_PATTERNS``), excluding any entry identical to the
    correct sequence so the same option never appears twice.

    Args:
        answer: Free-text answer containing "...: <movement sequence>."
        patterns: Optional pool of distractor strings; ``None`` uses
            ``MOVEMENT_PATTERNS``.

    Returns:
        (choices, correct_answer): the shuffled options labelled "A. ", "B. ", ...
        and the labelled option holding the correct sequence.

    Raises:
        ValueError: if ``answer`` contains no ": " separator.
    """
    if patterns is None:
        patterns = MOVEMENT_PATTERNS

    _, sep, tail = answer.partition(": ")
    if not sep:
        raise ValueError(f"Answer has no ': ' before the movement sequence: {answer!r}")
    # Drop the period that ends the sentence; without this the equality check
    # below can never match and a duplicate of the correct option slips in.
    correct = tail.rstrip(".")

    distractors = [pattern for pattern in patterns if pattern != correct]
    options = random.sample(distractors, min(3, len(distractors))) + [correct]
    random.shuffle(options)
    choices = [f"{chr(65 + i)}. {option}" for i, option in enumerate(options)]
    # Select by position rather than substring search, which could match a
    # distractor that happens to contain the correct text.
    correct_answer = choices[options.index(correct)]
    return choices, correct_answer

def generate_number_choices(answer):
    """Generate four numeric choices around the first integer found in ``answer``.

    The options are four consecutive integers containing the correct value:
    normally ``n-2 .. n+1``, but the window is shifted up so no option is
    negative (the counted quantities cannot be below zero).

    Args:
        answer: Free-text answer whose first run of digits is the correct value.

    Returns:
        (choices, correct_answer): the shuffled options labelled "A. ".."D. "
        and the labelled option holding the correct value.

    Raises:
        ValueError: if ``answer`` contains no digits.
    """
    match = re.search(r'\d+', answer)
    if match is None:
        raise ValueError(f"No number found in answer: {answer!r}")
    value = int(match.group())

    start = max(value - 2, 0)
    options = list(range(start, start + 4))
    random.shuffle(options)
    choices = [f"{chr(65 + i)}. {option}" for i, option in enumerate(options)]
    # Select by position: substring matching would let "1" match "-1" or "10".
    correct_answer = choices[options.index(value)]
    return choices, correct_answer

def generate_frequency_choices(answer):
    """Generate choices for the "which movement occurred more often" question.

    The answer text is inspected for "left more often" / "right more often".
    When it mentions both directions but neither phrase (the equal-count
    wording), the correct option is "None". An answer that does not mention
    both left and right maps to "stay still".

    Args:
        answer: Free-text answer describing left/right movement counts.

    Returns:
        (choices, correct_answer): the fixed options
        ["A. left", "B. right", "C. stay still", "D. None"] and the correct one.
    """
    options = ["left", "right", "stay still", "None"]
    choices = [f"{chr(65 + i)}. {option}" for i, option in enumerate(options)]

    text = answer.lower()
    if "left" in text and "right" in text:
        if "left more often" in text:
            key = "left"
        elif "right more often" in text:
            key = "right"
        else:
            # Tie: neither direction occurred more often.
            key = "None"
    else:
        key = "stay still"

    correct_answer = choices[options.index(key)]
    return choices, correct_answer

def generate_choices_for_question(question, answer):
    """Build multiple-choice options for one generated question.

    The (case-insensitive) question text picks the strategy:
    "movement pattern" -> generate_movement_pattern_choices,
    "how many"         -> generate_number_choices,
    "more often"       -> generate_frequency_choices.
    Anything else gets the raw answer as option A plus three placeholders.

    Returns:
        (choices, correct_answer) as produced by the selected strategy.
    """
    lowered = question.lower()
    if "movement pattern" in lowered:
        return generate_movement_pattern_choices(answer)
    if "how many" in lowered:
        return generate_number_choices(answer)
    if "more often" in lowered:
        return generate_frequency_choices(answer)

    # Fallback: the real answer is always option A, followed by placeholders.
    fallback = [f"A. {answer}", "B. Option 1", "C. Option 2", "D. Option 3"]
    return fallback, fallback[0]

def parse_arguments(argv=None):
    """Parse command-line options for the scene generator.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``ts`` (int, run time in seconds),
        ``init_enemy_count`` (int, enemies spawned at start) and
        ``output_dir`` (str).
    """
    parser = argparse.ArgumentParser(description="Pygame Scene Generator with Screen Recording")
    parser.add_argument('--ts', type=int, default=20, help='Total run time of the video in seconds')
    parser.add_argument('--init_enemy_count', type=int, default=10,
                        help='Number of enemies on screen when the game starts')
    parser.add_argument('--output_dir', type=str, default="output/angle1",
                        help='Directory to save output files')
    return parser.parse_args(argv)

args = parse_arguments()

# Game window size
width = 600
height = 600
t_total = args.ts  # Maximum running time of the game (seconds)
change_interval_seconds = 2  # Interval in seconds to decide whether to change movement direction

# Color definitions
# NOTE: despite its name this is light gray (225, 225, 225); main() fills the
# window with it as the background. The name is kept because main() uses it.
BLACK = (225, 225, 225)

# Initialize Pygame
pygame.init()

# Create game window
window = pygame.display.set_mode((width, height))
pygame.display.set_caption("Plane Battle")

# Create clock object to control the game frame rate
clock = pygame.time.Clock()

# Set video parameters
fps = 60  # Frame rate
output_dir = args.output_dir
# exist_ok=True avoids the check-then-create race of os.path.exists() + makedirs().
os.makedirs(output_dir, exist_ok=True)
# Run identifier (Unix timestamp + "_D") shared by the video and Q&A file names.
# NOTE: this shadows the built-in id(); main() reads it under this name.
id = str(int(time.time())) + '_D'
file_name = os.path.join(output_dir, 'plane_battle_' + id + ".mp4")

# Initialize recorder
recorder = ScreenRecorder(fps)

# Every sprite that is updated and drawn each frame
all_sprites = pygame.sprite.Group()

# Bullets fired by the player (used for bullet/enemy collisions)
bullets = pygame.sprite.Group()

# Enemies (used for collisions with bullets and with the player)
enemies = pygame.sprite.Group()

# Run-wide counters reported in the Q&A file
total_enemies_appeared = 0
total_enemies_destroyed = 0

initial_enemy_count = args.init_enemy_count
# ---------------------------- Plane Class Definition ----------------------------

class Plane(pygame.sprite.Sprite):
    """The player's plane: random left/right drift, edge bounces, automatic fire.

    Every ``change_interval_seconds`` it randomly picks 'left', 'right' or
    'stay'. It reverses direction at the window's left/right edges and fires a
    ``Bullet`` whenever its frame-based cooldown reaches zero (every 30 frames).
    Each decision and bounce is appended to ``movement_history``. Only the
    random left/right decisions are counted in ``move_left_count`` /
    ``move_right_count``.
    """

    def __init__(self, change_interval_seconds=2):
        super().__init__()
        try:
            self.image = pygame.image.load('player.png')
        except pygame.error as e:
            print("Cannot load 'player.png'. Please ensure the file exists in the current directory.")
            raise SystemExit(e)
        self.image = pygame.transform.scale(self.image, (50, 50))  # Resize to 50x50
        self.rect = self.image.get_rect()
        self.rect.centerx = width // 2
        self.rect.bottom = height - 50  # Position the plane away from the edge
        self.base_speed = 6  # Horizontal pixels moved per update when moving
        self.speed = 0  # Initial speed is stationary
        self.shoot_cooldown = 0  # Updates (frames) remaining until the next shot

        # Random direction change related parameters
        self.change_interval_seconds = change_interval_seconds
        self.last_change_time = pygame.time.get_ticks()  # Tick (ms) of the last decision
        self.change_interval_ms = self.change_interval_seconds * 1000  # Convert to milliseconds

        # Ordered log of decisions and boundary bounces (passed to the Q&A generator)
        self.movement_history = []
        # Number of times the random decision chose left / right
        self.move_left_count = 0
        self.move_right_count = 0

    def update(self):
        """Advance one frame: maybe pick a new direction, bounce, move, and shoot."""
        current_time = pygame.time.get_ticks()

        # Check if it's time to change direction
        if current_time - self.last_change_time >= self.change_interval_ms:
            self.last_change_time = current_time
            # Randomly choose a new direction: left, right, or stay
            direction = random.choice(['left', 'right', 'stay'])
            if direction == 'left':
                self.speed = -self.base_speed
                self.movement_history.append('move left')
                self.move_left_count += 1
                print(f"Plane chooses to move left, current speed: {self.speed}, left move count: {self.move_left_count}")
            elif direction == 'right':
                self.speed = self.base_speed
                self.movement_history.append('move right')
                self.move_right_count += 1
                print(f"Plane chooses to move right, current speed: {self.speed}, right move count: {self.move_right_count}")
            else:
                self.speed = 0
                self.movement_history.append('stay still')
                print(f"Plane chooses to stay still, current speed: {self.speed}")

        # Boundary check: reverse direction at either edge
        if self.rect.left <= 0:
            self.speed = self.base_speed  # Touch the left boundary, move right
            self.movement_history.append('touch left boundary, move right')
            print(f"Plane touches the left boundary, moves right, current speed: {self.speed}")
        elif self.rect.right >= width:
            self.speed = -self.base_speed  # Touch the right boundary, move left
            self.movement_history.append('touch right boundary, move left')
            print(f"Plane touches the right boundary, moves left, current speed: {self.speed}")

        # Update plane position
        self.rect.x += self.speed
        # Keep the sprite fully on screen. The bounce above only reverses the
        # speed, so a move could otherwise overshoot an edge by a few pixels.
        if self.rect.left < 0:
            self.rect.left = 0
        elif self.rect.right > width:
            self.rect.right = width

        # Automatic shooting strategy: shoot automatically at set intervals
        if self.shoot_cooldown <= 0:
            self.shoot()
            self.shoot_cooldown = 30  # Reset cooldown time (in frames)
        else:
            self.shoot_cooldown -= 1

    def shoot(self):
        """Spawn a Bullet at the plane's top-centre and register it in the sprite groups."""
        bullet = Bullet(self.rect.centerx, self.rect.top)
        all_sprites.add(bullet)
        bullets.add(bullet)
        print(f"Plane shoots, bullet position: ({bullet.rect.x}, {bullet.rect.y})")

# ---------------------------- Bullet Class Definition ----------------------------

class Bullet(pygame.sprite.Sprite):
    """A bullet that travels straight up and is removed once it leaves the top edge.

    The 10x20 image scaled from 'bullet.png' is loaded once and shared by all
    bullets, instead of re-reading and rescaling the file for every shot.
    """

    _image_cache = None  # Shared scaled surface; populated on first construction

    @classmethod
    def _load_image(cls):
        """Return the shared bullet surface, loading and scaling it on first use.

        Raises:
            SystemExit: if 'bullet.png' cannot be loaded.
        """
        if cls._image_cache is None:
            try:
                raw = pygame.image.load('bullet.png')
            except pygame.error as e:
                print("Cannot load 'bullet.png'. Please ensure the file exists in the current directory.")
                raise SystemExit(e)
            cls._image_cache = pygame.transform.scale(raw, (10, 20))  # Resize to 10x20
        return cls._image_cache

    def __init__(self, x, y):
        super().__init__()
        self.image = self._load_image()
        self.rect = self.image.get_rect()
        self.rect.centerx = x  # Horizontal centre of the bullet
        self.rect.bottom = y  # Bottom edge starts at y (Plane.shoot passes its top)
        self.speed = 10  # Pixels moved upward per update

    def update(self):
        """Move up one step; kill the sprite once it is fully above the window."""
        self.rect.y -= self.speed
        if self.rect.bottom < 0:
            self.kill()
            print(f"Bullet moves off-screen, destroys bullet, position: ({self.rect.x}, {self.rect.y})")

# ---------------------------- Enemy Class Definition ----------------------------

class Enemy(pygame.sprite.Sprite):
    """An enemy that falls from the top edge and respawns there after leaving the bottom.

    Creation and every respawn each increment the module-level
    ``total_enemies_appeared`` counter. The 30x30 image scaled from 'enemy.png'
    is loaded once and shared by all enemies, instead of per instance.
    """

    _image_cache = None  # Shared scaled surface; populated on first construction

    @classmethod
    def _load_image(cls):
        """Return the shared enemy surface, loading and scaling it on first use.

        Raises:
            SystemExit: if 'enemy.png' cannot be loaded.
        """
        if cls._image_cache is None:
            try:
                raw = pygame.image.load('enemy.png')
            except pygame.error as e:
                print("Cannot load 'enemy.png'. Please ensure the file exists in the current directory.")
                raise SystemExit(e)
            cls._image_cache = pygame.transform.scale(raw, (30, 30))  # Resize to 30x30
        return cls._image_cache

    def __init__(self):
        super().__init__()
        global total_enemies_appeared  # Module-level counter incremented below
        self.image = self._load_image()
        self.rect = self.image.get_rect()
        self.rect.x = random.randint(0, width - self.rect.width)
        self.rect.y = 0
        self.speed = random.randint(1, 3)  # Pixels moved down per update
        total_enemies_appeared += 1  # Increment the count each time an enemy plane is created
        print(f"New enemy appears, position: ({self.rect.x}, {self.rect.y}), speed: {self.speed}")

    def update(self):
        """Fall one step; once fully below the window, respawn at the top with a new x and speed."""
        global total_enemies_appeared  # Module-level counter incremented on respawn
        self.rect.y += self.speed
        if self.rect.top > height:
            self.rect.x = random.randint(0, width - self.rect.width)
            self.rect.y = 0
            # NOTE(review): respawn speed is drawn from 1-2 while new enemies use 1-3;
            # confirm whether this difference is intended.
            self.speed = random.randint(1, 2)
            total_enemies_appeared += 1  # Increment the count each time an enemy plane is recreated
            print(f"Enemy regenerates, position: ({self.rect.x}, {self.rect.y}), speed: {self.speed}")

# ---------------------------- Create sprite objects ----------------------------

# Spawn the player's plane; it reconsiders its direction every change_interval_seconds.
player = Plane(change_interval_seconds=change_interval_seconds)
player.add(all_sprites)

# Spawn the starting wave: initial_enemy_count enemies (from --init_enemy_count).
# Each one is tracked in all_sprites (update/draw) and in enemies (collisions).
for _ in range(initial_enemy_count):
    enemy = Enemy()
    enemy.add(all_sprites, enemies)

# ---------------------------- Function to generate questions and answers ----------------------------

def generate_question_answer(movement_history, move_left_count, move_right_count,
                             total_enemies_appeared, total_enemies_destroyed,
                             interval_seconds=None):
    """Build the four free-text question/answer pairs for a finished run.

    Questions, in order:
    1. What is the movement pattern of the plane?
    2. How many enemies appeared in the video?
    3. How many enemies were destroyed by the player's plane?
    4. Which movement occurred more often, moving left or moving right?

    Args:
        movement_history: Movement labels in the order they happened; joined
            with ", " into answer 1.
        move_left_count: Number of random "move left" decisions.
        move_right_count: Number of random "move right" decisions.
        total_enemies_appeared: Enemies created or respawned during the run.
        total_enemies_destroyed: Enemies destroyed by bullets.
        interval_seconds: Decision interval quoted in answer 1; ``None`` uses
            the module-level ``change_interval_seconds``.

    Returns:
        A list of four dicts, each with "question" and "answer" keys.
    """
    if interval_seconds is None:
        interval_seconds = change_interval_seconds

    movement_sequence = ', '.join(movement_history)
    answer1 = (
        f"The plane's movement pattern is to randomly choose between moving left, right, or staying still every {interval_seconds} seconds. "
        f"The specific movement sequence was: {movement_sequence}."
    )
    answer2 = f"In total, {total_enemies_appeared} enemies appeared during the entire video."
    answer3 = f"The player's plane destroyed {total_enemies_destroyed} enemies in total."

    if move_left_count > move_right_count:
        answer4 = f"The plane moved left more often, with {move_left_count} left movements and {move_right_count} right movements."
    elif move_right_count > move_left_count:
        answer4 = f"The plane moved right more often, with {move_right_count} right movements and {move_left_count} left movements."
    else:
        answer4 = f"The plane moved left and right an equal number of times, with {move_left_count} left movements and {move_right_count} right movements."

    return [
        {"question": "What is the movement pattern of the plane?", "answer": answer1},
        {"question": "How many enemies appeared in the video?", "answer": answer2},
        {"question": "How many enemies were destroyed by the player's plane?", "answer": answer3},
        {"question": "Which movement occurred more often, moving left or moving right?", "answer": answer4},
    ]
    



# ---------------------------- Main Game Loop ----------------------------

def main(t_total=10):
    """Run the recorded game for up to ``t_total`` seconds, then write the Q&A JSON.

    The loop stops when the time limit passes, a QUIT event arrives, or the
    player collides with an enemy. Cleanup runs in ``finally``: the recording is
    stopped and saved, and ``pygame.quit()`` runs even if saving fails. Finally
    the Q&A file is written (only if saving succeeded).

    Args:
        t_total: Time limit in seconds (the script passes the --ts value).
    """
    global total_enemies_destroyed  # Incremented when bullets destroy enemies
    running = True
    start_time = pygame.time.get_ticks()  # Record the start time

    # Start recording video
    recorder.start_rec()
    print(f"Start recording video, file name: {file_name}")

    try:
        while running:
            # Calculate elapsed time
            elapsed_time = (pygame.time.get_ticks() - start_time) / 1000  # Convert to seconds

            # Check if the set time has been exceeded (the current frame still completes)
            if elapsed_time > t_total:
                running = False
                print("Set game time reached, ending game.")

            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
                    print("Quit event received, ending game.")

            # Update all sprites
            all_sprites.update()

            # Detect collisions between the plane and enemies (colliding enemies are removed)
            if pygame.sprite.spritecollide(player, enemies, True):
                running = False
                print("Plane collides with enemy, game over.")

            # Detect collisions between bullets and enemies (both are removed)
            hits = pygame.sprite.groupcollide(enemies, bullets, True, True)
            for hit in hits:
                print(f"Enemy hit and destroyed, position: ({hit.rect.x}, {hit.rect.y})")
                total_enemies_destroyed += 1  # Increase the destroyed enemy count
                # Replace the destroyed enemy so the enemy count stays constant
                enemy = Enemy()
                all_sprites.add(enemy)
                enemies.add(enemy)

            # Draw background
            window.fill(BLACK)

            # Draw all sprites
            all_sprites.draw(window)

            # Refresh the screen
            pygame.display.flip()

            # Control the frame rate
            clock.tick(fps)

    finally:
        try:
            # Stop recording and save the file
            recorder.stop_rec()
            recorder.save_recording(file_name)
            print(f"Recording stopped, video saved as: {file_name}")
        finally:
            # Release the display even if stopping/saving the recording failed.
            pygame.quit()
            print("Pygame exited.")

        _save_question_answers()


def _save_question_answers():
    """Build this run's Q&A list, attach multiple-choice options, and write it as JSON."""
    qa_file_name = os.path.join(output_dir, 'plane_battle_' + id + "_qa.json")
    qa_list = generate_question_answer(
        player.movement_history,
        player.move_left_count,
        player.move_right_count,
        total_enemies_appeared,
        total_enemies_destroyed,
    )
    # Replace each free-text answer with the labelled correct choice.
    for qa in qa_list:
        qa['choices'], qa['answer'] = generate_choices_for_question(qa['question'], qa['answer'])
    output_data = {"questions_answers": qa_list}

    with open(qa_file_name, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, ensure_ascii=False, indent=4)
    print(f"Q&A saved to {qa_file_name}")

# ---------------------------- Main Program Entry ----------------------------

if __name__ == "__main__":
    # Run for the duration given by --ts (parsed at import time into t_total).
    main(t_total=t_total)
