import sys
import pygame
import random

from tqdm import trange
from functools import partial
from pygame.locals import QUIT, KEYDOWN, K_LEFT, K_RIGHT, K_UP, K_DOWN


# Window size in pixels.
WINDOW_WIDTH = 640
WINDOW_HEIGHT = 480
# Ball radius; also used as the thickness of every wall painted by draw_walls.
BALL_RADIUS = 10
# Radius of the drawn goal circle and half-size of the success box tested in main.
GOAL_RADIUS = 30
# Magnitude of the velocity returned by the movement policies (added to the
# position once per Ball.update call).
BALL_SPEED = 4
# RGB colors.
GOAL_COLOR = (255, 215, 0)
BALL1_COLOR = (128, 128, 128)
BALL2_COLOR = (135, 206, 250)
OBSTACLE_COLOR = (102, 102, 153)
# Colors of the trajectory dots drawn by render_traj for ball 1 and ball 2.
LINE_COLOR_1 = (0, 206, 209)
LINE_COLOR_2 = (222, 184, 135)
BACKGROUND_COLOR = (204, 255, 255)
# NOTE(review): not referenced by any code visible in this file.
OBSTACLE_WIDTH = BALL_RADIUS * 2
OBSTACLE_HEIGHT = BALL_RADIUS * 2

# Inner walls: WALL_SIZE is the distance from the window center to the outer
# edge of the wall square; NOTCH is the distance from the center at which the
# wall segments start, so each segment is WALL_SIZE - NOTCH long.
NOTCH = int(4.5 * BALL_RADIUS)
WALL_SIZE = 70

# Experiment settings used by the __main__ script.
EXP_NUM = 10000  # episodes run per policy
INTERVAL = 2  # the policy is queried every INTERVAL steps
TOT_LENGTH = 82  # maximum number of steps per episode


class Ball:
    """A circle with a position, a velocity and a log of visited positions.

    Attributes:
        color: Color passed to pygame when the ball is drawn.
        radius: Always BALL_RADIUS.
        position: Mutable [x, y] list copied from the constructor argument.
        velocity: Mutable [vx, vy] list, initially [0, 0].
        traj: List of [x, y] snapshots, one appended per update() call.
    """

    def __init__(self, color, initial_position):
        """Create a resting ball of the given color at initial_position."""
        self.traj = []
        self.velocity = [0, 0]
        self.position = [*initial_position]
        self.radius = BALL_RADIUS
        self.color = color

    def update(self):
        """Move by one velocity step and record the new position in traj."""
        for axis in (0, 1):
            self.position[axis] += self.velocity[axis]
        # Store a copy so later moves do not alter the recorded point.
        self.traj.append(self.position[:2])

    def draw(self, surface):
        """Draw the ball on surface at its position truncated to integers."""
        center = (int(self.position[0]), int(self.position[1]))
        pygame.draw.circle(surface, self.color, center, self.radius)


def dot_product(v1, v2):
    """Return the sum of pairwise products of v1 and v2.

    Pairs are formed with zip, so extra components of the longer input are
    ignored; two empty inputs give 0.
    """
    total = 0
    for left, right in zip(v1, v2):
        total += left * right
    return total


def scalar_product(v, n):
    """Return a new list holding every component of v multiplied by n."""
    scaled = []
    for component in v:
        scaled.append(component * n)
    return scaled


def normalize(v):
    """Return v scaled to unit Euclidean length as a new list.

    A zero-length vector has no direction, so it is returned as a list of
    zeros of the same length instead of raising ZeroDivisionError.
    """
    magnitude = sum(component ** 2 for component in v) ** 0.5
    if magnitude == 0:
        return [0.0 for _ in v]
    return [component / magnitude for component in v]


def draw_walls(surface):
    """Paint the window border and the inner corner walls on surface.

    Every rectangle is given as (left, top, width, height) and filled with
    OBSTACLE_COLOR. The border is BALL_RADIUS thick. The inner walls are eight
    segments, two per corner of a square centered in the window, each
    WALL_SIZE - NOTCH long and BALL_RADIUS thick.
    """
    thickness = BALL_RADIUS
    border = (
        (0, 0, WINDOW_WIDTH, thickness),
        (0, 0, thickness, WINDOW_HEIGHT),
        (0, WINDOW_HEIGHT - thickness, WINDOW_WIDTH, thickness),
        (WINDOW_WIDTH - thickness, 0, thickness, WINDOW_HEIGHT),
    )

    mid_x, mid_y = WINDOW_WIDTH // 2, WINDOW_HEIGHT // 2
    arm = WALL_SIZE - NOTCH
    left, top = mid_x - WALL_SIZE, mid_y - WALL_SIZE
    right = mid_x + WALL_SIZE - BALL_RADIUS
    bottom = mid_y + WALL_SIZE - BALL_RADIUS
    inner = (
        (left, top, thickness, arm),
        (left, top, arm, thickness),
        (mid_x + NOTCH, top, arm, thickness),
        (right, top, thickness, arm),
        (right, mid_y + NOTCH, thickness, arm),
        (mid_x + NOTCH, bottom, arm, thickness),
        (left, bottom, arm, thickness),
        (left, mid_y + NOTCH, thickness, arm),
    )

    for rect in border + inner:
        pygame.draw.rect(surface, OBSTACLE_COLOR, rect)


def draw_goal(surface, position):
    """Draw the goal on surface as a GOAL_COLOR circle of radius GOAL_RADIUS centered at position."""
    pygame.draw.circle(surface, GOAL_COLOR, position, GOAL_RADIUS)


def _inner_wall_rects():
    """Return the eight inner wall segments as (left, top, width, height).

    These are the same rectangles, in the same order, that draw_walls paints,
    so the collision regions always match what is shown on screen.
    """
    cen_x = WINDOW_WIDTH // 2
    cen_y = WINDOW_HEIGHT // 2
    width = BALL_RADIUS
    length = WALL_SIZE - NOTCH
    left = cen_x - WALL_SIZE
    top = cen_y - WALL_SIZE
    right = cen_x + WALL_SIZE - BALL_RADIUS
    bottom = cen_y + WALL_SIZE - BALL_RADIUS
    return (
        (left, top, width, length),
        (left, top, length, width),
        (cen_x + NOTCH, top, length, width),
        (right, top, width, length),
        (right, cen_y + NOTCH, width, length),
        (cen_x + NOTCH, bottom, length, width),
        (left, bottom, length, width),
        (left, cen_y + NOTCH, width, length),
    )


def _touches_wall(ball, rect):
    """Return True when ball overlaps rect under the file's wall test.

    The test is deliberately the same as before: the ball's center x or its
    right edge must lie within the rectangle's x range, and its center y must
    lie within the rectangle's y range (bounds inclusive).
    NOTE(review): the left, top and bottom edges of the ball are not tested.
    """
    left, top, width, height = rect
    x, y = ball.position[0], ball.position[1]
    x_inside = (left <= x <= left + width
                or left <= x + ball.radius <= left + width)
    return x_inside and top <= y <= top + height


def _bounce_off_borders(ball):
    """Reverse a velocity component when ball reaches the window edge on that axis."""
    if ball.position[0] - ball.radius <= 0 or ball.position[0] + ball.radius >= WINDOW_WIDTH:
        ball.velocity[0] = -ball.velocity[0]
    if ball.position[1] - ball.radius <= 0 or ball.position[1] + ball.radius >= WINDOW_HEIGHT:
        ball.velocity[1] = -ball.velocity[1]


def handle_collision(ball1, ball2):
    """Apply ball-ball, window-border and inner-wall collisions in place.

    Steps, in order:
      1. If the balls touch or overlap, exchange an impulse along the line
         joining their centers.
      2. Each ball bounces off the window border (velocity component flipped).
      3. Each ball that touches an inner wall segment is stopped
         (velocity replaced by [0, 0]).

    Args:
        ball1, ball2: Objects exposing position ([x, y]), velocity ([vx, vy])
            and radius. Their velocity is modified; nothing is returned.
    """
    dx = ball2.position[0] - ball1.position[0]
    dy = ball2.position[1] - ball1.position[1]
    distance = (dx ** 2 + dy ** 2) ** 0.5
    # Coincident centers have no collision direction, so the impulse step is
    # skipped instead of dividing by zero while normalizing.
    if 0 < distance <= ball1.radius + ball2.radius:
        collision_vector = normalize([dx, dy])
        relative_velocity = [ball2.velocity[0] - ball1.velocity[0],
                             ball2.velocity[1] - ball1.velocity[1]]
        # NOTE(review): the impulse is scaled by 2 / (sum of radii), not by
        # masses; kept exactly as it was.
        impulse = dot_product(relative_velocity, collision_vector) * 2 / (ball1.radius + ball2.radius)
        ball1.velocity[0] += impulse * collision_vector[0]
        ball1.velocity[1] += impulse * collision_vector[1]
        ball2.velocity[0] -= impulse * collision_vector[0]
        ball2.velocity[1] -= impulse * collision_vector[1]

    walls = _inner_wall_rects()
    for ball in (ball1, ball2):
        # Both balls get the border bounce; previously only ball1 did.
        _bounce_off_borders(ball)
        if any(_touches_wall(ball, rect) for rect in walls):
            ball.velocity = [0, 0]


def greedy_policy(goal_position, ball, steps, speed=None):
    """Return a velocity of magnitude speed pointing straight at the goal.

    Args:
        goal_position: (x, y) of the goal.
        ball: Object whose position[0] and position[1] give its location.
        steps: Current step index; unused, accepted so all policies share
            one call signature.
        speed: Velocity magnitude; defaults to BALL_SPEED.

    Returns:
        (x_velocity, y_velocity). (0.0, 0.0) when the ball is exactly on the
        goal, where no direction is defined.
    """
    if speed is None:
        speed = BALL_SPEED
    res_x = goal_position[0] - ball.position[0]
    res_y = goal_position[1] - ball.position[1]

    # Scaling the offset vector directly is equivalent to the former
    # ratio-based formula but does not divide by res_x or res_y, so an
    # axis-aligned goal no longer raises ZeroDivisionError.
    distance = (res_x ** 2 + res_y ** 2) ** 0.5
    if distance == 0:
        return 0.0, 0.0
    return speed * res_x / distance, speed * res_y / distance


def epsilon_greedy_policy(goal_position, ball, steps, epsilon=0.9, speed=None):
    """Head for the goal with probability epsilon, otherwise move randomly.

    Args:
        goal_position: (x, y) of the goal.
        ball: Object whose position[0] and position[1] give its location.
        steps: Current step index; unused, accepted so all policies share
            one call signature.
        epsilon: Probability of taking the greedy (goal-directed) move.
        speed: Velocity magnitude; defaults to BALL_SPEED.

    Returns:
        (x_velocity, y_velocity) of magnitude speed, or (0.0, 0.0) when the
        greedy move is chosen while the ball is exactly on the goal.
    """
    if speed is None:
        speed = BALL_SPEED
    res_x = goal_position[0] - ball.position[0]
    res_y = goal_position[1] - ball.position[1]

    if random.random() < epsilon:
        # Greedy move: scale the offset vector; no division by res_x or
        # res_y, so axis-aligned goals are safe.
        distance = (res_x ** 2 + res_y ** 2) ** 0.5
        if distance == 0:
            return 0.0, 0.0
        return speed * res_x / distance, speed * res_y / distance

    # Random move. The parentheses matter: without them the expression was
    # r1 + (1e-5 / r2) + 1e-5, which is not a ratio and can divide by zero.
    ratio = (random.random() + 1e-5) / (random.random() + 1e-5)
    x_sign = random.choice((-1, 1))
    y_sign = random.choice((-1, 1))
    # The signs are uniformly random, so they do not need to be combined
    # with the sign of the goal offset (which is undefined at zero).
    y_speed = speed / (1 + ratio ** 2) ** 0.5
    return ratio * y_speed * x_sign, y_speed * y_sign


def stable_prefix_policy(goal_position, ball, steps, stable_steps=10, epsilon=0.9, speed=None):
    """Epsilon-greedy policy that is always greedy during the first steps.

    Args:
        goal_position: (x, y) of the goal.
        ball: Object whose position[0] and position[1] give its location.
        steps: Current step index; while steps < stable_steps the greedy
            move is always taken.
        stable_steps: Length of the always-greedy prefix.
        epsilon: Probability of the greedy move once the prefix is over.
        speed: Velocity magnitude; defaults to BALL_SPEED.

    Returns:
        (x_velocity, y_velocity) of magnitude speed, or (0.0, 0.0) when the
        greedy move is chosen while the ball is exactly on the goal.
    """
    if speed is None:
        speed = BALL_SPEED
    res_x = goal_position[0] - ball.position[0]
    res_y = goal_position[1] - ball.position[1]

    # The random draw is made even inside the prefix, as before, so the
    # random stream is consumed the same way on every call.
    ran = random.random()
    if ran < epsilon or steps < stable_steps:
        # Greedy move: scale the offset vector; no division by res_x or
        # res_y, so axis-aligned goals are safe.
        distance = (res_x ** 2 + res_y ** 2) ** 0.5
        if distance == 0:
            return 0.0, 0.0
        return speed * res_x / distance, speed * res_y / distance

    # Random move. The parentheses matter: without them the expression was
    # r1 + (1e-5 / r2) + 1e-5, which is not a ratio and can divide by zero.
    ratio = (random.random() + 1e-5) / (random.random() + 1e-5)
    x_sign = random.choice((-1, 1))
    y_sign = random.choice((-1, 1))
    # The signs are uniformly random, so they do not need to be combined
    # with the sign of the goal offset (which is undefined at zero).
    y_speed = speed / (1 + ratio ** 2) ** 0.5
    return ratio * y_speed * x_sign, y_speed * y_sign


def _in_goal_box(ball, goal_position):
    """Return True when ball's center is less than GOAL_RADIUS from the goal on both axes."""
    return (abs(ball.position[0] - goal_position[0]) < GOAL_RADIUS
            and abs(ball.position[1] - goal_position[1]) < GOAL_RADIUS)


def main(move_policy, INTERVALs=200, TOT_LENGTH=1000, render=False):
    """Run one episode with two balls steered by move_policy.

    Ball 1 starts at (50, 50), ball 2 at the opposite corner offset, and the
    goal is the window center. The episode succeeds when both balls are
    inside the goal box at the start of a step.

    Args:
        move_policy: Callable (goal_position, ball, step) -> (vx, vy).
        INTERVALs: The policy is queried on steps where
            step % INTERVALs == 0; must be non-zero.
        TOT_LENGTH: Maximum number of steps.
        render: When True, open a pygame window and draw every step at
            20 frames per second. Closing the window ends the episode early.

    Returns:
        (success, ball1_traj, ball2_traj) where success is 1 or 0 and each
        trajectory is the list of [x, y] positions recorded by Ball.update.
    """
    if render:
        pygame.init()
        window_surface = pygame.display.set_mode((WINDOW_WIDTH, WINDOW_HEIGHT))
        pygame.display.set_caption("Ball Game")
        clock = pygame.time.Clock()

    ball1 = Ball(BALL1_COLOR, (50, 50))
    ball2 = Ball(BALL2_COLOR, (WINDOW_WIDTH - 50, WINDOW_HEIGHT - 50))
    goal_position = (WINDOW_WIDTH // 2, WINDOW_HEIGHT // 2)
    success_flag = False

    try:
        for step in range(TOT_LENGTH):
            # The event queue must be drained every frame, otherwise the
            # window stops responding and cannot be closed.
            if render and any(event.type == QUIT for event in pygame.event.get()):
                break

            # NOTE(review): the goal test runs before the move, so the state
            # reached by the very last step is never tested - confirm intent.
            if _in_goal_box(ball1, goal_position) and _in_goal_box(ball2, goal_position):
                success_flag = True
                break

            ball1.update()
            ball2.update()

            # Re-query the policy only every INTERVALs steps.
            if step % INTERVALs == 0:
                ball1.velocity[0], ball1.velocity[1] = move_policy(goal_position, ball1, step)
                ball2.velocity[0], ball2.velocity[1] = move_policy(goal_position, ball2, step)

            handle_collision(ball1, ball2)

            if render:
                window_surface.fill(BACKGROUND_COLOR)
                draw_walls(window_surface)
                draw_goal(window_surface, goal_position)
                ball1.draw(window_surface)
                ball2.draw(window_surface)
                pygame.display.update()
                clock.tick(20)
    finally:
        # Release the display even if the policy or a draw call raises.
        if render:
            pygame.quit()
    return int(success_flag), ball1.traj, ball2.traj

def render_traj(ball1_trajs, ball2_trajs, name):
    """Draw all recorded trajectories on one image and save it to name.

    Opens a pygame window, plots every point of every ball-1 trajectory in
    LINE_COLOR_1 and every ball-2 point in LINE_COLOR_2 (radius-2 dots), then
    draws the two start markers, the walls, the goal circle and the
    "goal.png" icon scaled to 40x40 and centered on the goal.

    Args:
        ball1_trajs: Iterable of trajectories (iterables of [x, y]) for ball 1.
        ball2_trajs: Same for ball 2.
        name: Output path handed to pygame.image.save.
    """
    pygame.init()
    canvas = pygame.display.set_mode((WINDOW_WIDTH, WINDOW_HEIGHT))
    pygame.display.set_caption("Ball Game")
    canvas.fill(BACKGROUND_COLOR)

    # Ball 1 dots first, then ball 2, so ball 2 is drawn on top where they overlap.
    layers = ((ball1_trajs, LINE_COLOR_1), (ball2_trajs, LINE_COLOR_2))
    for trajectories, dot_color in layers:
        for trajectory in trajectories:
            for point in trajectory:
                center = (int(point[0]), int(point[1]))
                pygame.draw.circle(canvas, dot_color, center, 2)

    # Start markers at the balls' initial positions.
    starts = (
        (BALL1_COLOR, (50, 50)),
        (BALL2_COLOR, (WINDOW_WIDTH - 50, WINDOW_HEIGHT - 50)),
    )
    for marker_color, start in starts:
        pygame.draw.circle(canvas, marker_color, start, BALL_RADIUS)
    draw_walls(canvas)

    goal_x, goal_y = WINDOW_WIDTH // 2, WINDOW_HEIGHT // 2
    draw_goal(canvas, (goal_x, goal_y))
    icon = pygame.transform.scale(
        pygame.image.load("goal.png").convert_alpha(), (40, 40))
    canvas.blit(icon, (goal_x - 20, goal_y - 20))

    pygame.display.update()
    pygame.image.save(canvas, name)


def _run_experiment(label, policy, image_name):
    """Run EXP_NUM episodes with policy, print the success rate, save the plot.

    Args:
        label: Policy name used in the printed summary line.
        policy: Movement policy passed to main.
        image_name: File name handed to render_traj for the trajectory image.
    """
    successes = 0
    ball1_trajs, ball2_trajs = [], []
    for _ in trange(EXP_NUM):
        result, traj1, traj2 = main(policy, INTERVAL, TOT_LENGTH)
        successes += result
        ball1_trajs.append(traj1)
        ball2_trajs.append(traj2)
    print(f"{label} policy successful rate: {successes / EXP_NUM}")
    render_traj(ball1_trajs, ball2_trajs, image_name)


if __name__ == "__main__":
    # Same three experiments, in the same order, as the former copy-pasted loops.
    _run_experiment("greedy", greedy_policy, "greedy.png")
    _run_experiment(
        "epsilon-greedy",
        partial(epsilon_greedy_policy, epsilon=0.9),
        "epsilon-greedy.png",
    )
    _run_experiment(
        "stable-prefix",
        partial(stable_prefix_policy, stable_steps=40, epsilon=0.9),
        "stable-prefix.png",
    )

