import math
import sys
import numpy as np

import pygame
from pygame.constants import K_w, K_s, K_a, K_d, K_UP, K_DOWN, K_LEFT, K_RIGHT
from ple.games.utils.vec2d import vec2d
from ple.games.utils import percent_round_int

#import base
from ple.games.base.pygamewrapper import PyGameWrapper

class Ball(pygame.sprite.Sprite):

    def __init__(self, radius, speed, rng,
                 pos_init, SCREEN_WIDTH, SCREEN_HEIGHT):
        """
        Create the ball sprite.

        Inputs:
        - radius : ball radius, also used to size the sprite surface
        - speed : base speed; the initial velocity is (speed, -speed)
        - rng : random generator, stored for use when bouncing off paddles
        - pos_init : (x, y) starting position
        - SCREEN_WIDTH, SCREEN_HEIGHT : playfield dimensions
        """
        pygame.sprite.Sprite.__init__(self)

        self.rng = rng
        self.radius = radius
        self.speed = speed
        self.SCREEN_WIDTH = SCREEN_WIDTH
        self.SCREEN_HEIGHT = SCREEN_HEIGHT

        # Current position, position at the end of the previous update, and velocity
        self.pos = vec2d(pos_init)
        self.pos_before = vec2d(pos_init)
        self.vel = vec2d((speed, -1.0 * speed))

        # Black is the colour key, so only the white disc is drawn
        diameter = radius * 2
        sprite = pygame.Surface((diameter, diameter))
        sprite.fill((0, 0, 0, 0))
        sprite.set_colorkey((0, 0, 0))
        pygame.draw.circle(sprite, (255, 255, 255), (radius, radius), radius, 0)

        self.image = sprite
        self.rect = sprite.get_rect()
        self.rect.center = pos_init

    def line_from_pts(self, p0x, p0y, p1x, p1y):
        """ 
        Function to compute the equation of a line between 2 points

        Inputs:
        - p0x, p0y : (x,y) coordinates of p0
        - p1x, p1y : (x,y) coordinates of p1

        Outout:
        - homogenous line equation [a,b,c] such that ax + by + c = 0
        """
        return np.cross(np.array([p0x, p0y, 1.]), np.array([p1x, p1y, 1.]))

    def line_from_pt_vec(self, p0x, p0y, vx, vy):
        """
        Function to get a line through a point in the direction of a vector

        Inputs:
        - p0x, p0y : (x,y) coordinates of point
        - vx, vy : <vx, vy> vector

        Output:
        - homogenous line equation [a,b,c] such that ax + by + c = 0
        """
        pt1 = np.array([p0x, p0y, 1.])
        pt2 = pt1 + np.array([vx, vy, 0.])

        return np.cross(pt1, pt2)

    def line_intersectionX(self, line1, line2):
        """
        Function that finds the point of intersection between two lines

        Inputs:
        - line1, line2: homogenous line equations [a,b,c] such that ax + by + c = 0

        Output:
        - [x,y] : coordinates of intersection
        """
        if (line1.shape[0] != 3) or (line2.shape[0] != 3):
            raise RuntmeError("Lines of incorrect dimensions in call - check ball class in pong.py")

        # Point of intersection computed as the cross product between the line equations
        pt = np.cross(line1, line2)
        # Point is normalized by 3rd vector component to confrom with homogenous coordinate representation
        # if p[3] = 0 -> no intersection, but will result in math error. So a small number is used 
        # (large values are later discarded)
        pt = pt / math.copysign(max(abs(pt[2]), 1e-5), pt[2])

        return pt[:2]

    def line_intersection(self, p0_x, p0_y, p1_x, p1_y, p2_x, p2_y, p3_x, p3_y):
        """
        Original intersection function provided by PLE Pong

        Tests whether segment p0-p1 crosses segment p2-p3.

        Inputs:
        - p0_x, p0_y, p1_x, p1_y : end points of the first segment
        - p2_x, p2_y, p3_x, p3_y : end points of the second segment

        Output:
        - (bool) True when the segments intersect; parallel (or degenerate)
          segments are reported as not intersecting
        """
        s1_x = p1_x - p0_x
        s1_y = p1_y - p0_y
        s2_x = p3_x - p2_x
        s2_y = p3_y - p2_y

        denom = -s2_x * s1_y + s1_x * s2_y
        if denom == 0:
            # Parallel or zero-length segments: no unique crossing point
            return False

        s = (-s1_y * (p0_x - p2_x) + s1_x * (p0_y - p2_y)) / denom
        t = (s2_x * (p0_y - p2_y) - s2_y * (p0_x - p2_x)) / denom

        return (s >= 0 and s <= 1 and t >= 0 and t <= 1)

    def computeIntersections(self, Player, tag='player1'):
        """
        Box colider function

        Inputs:
        - Player : Player object
        - tag : identify player 1 or 2 to set face coordinates appropriately

        Outputs:
        - intersection : (bool) was a collision detected?
        - intersection_point : [x,y] coordinate of intersection point
        - intersection_face : which face of the paddle the intersection was with
        """

        # Convert pad angle to radians
        Pad_ang = Player.angle
        Pad_ang *= np.pi/180.
        # Comptue rotation matrix for projection pad bounds relative to pad angle (following right-handed rule)
        R_mat = np.array([[np.cos(Pad_ang),  np.sin(Pad_ang)],
                          [-np.sin(Pad_ang), np.cos(Pad_ang)]])
        
        # Add a buffer to pad bounds to prevent appearance of ball passing trhough paddle edges
        tolerance_ = 1.
        # Set bounds wrt player - [x,y] coordinates of top front, bottom front, top back, bottom back
        if tag == 'player1':
            Pad_bounds = np.array([[ Player.rect_width / 2, -Player.rect_height / 2 - tolerance_],
                                   [ Player.rect_width / 2,  Player.rect_height / 2 + tolerance_],
                                   [-Player.rect_width / 2, -Player.rect_height / 2 - tolerance_],
                                   [-Player.rect_width / 2,  Player.rect_height / 2 + tolerance_]]).T
        else:
            Pad_bounds = np.array([[-Player.rect_width / 2, -Player.rect_height / 2 - tolerance_],
                                   [-Player.rect_width / 2,  Player.rect_height / 2 + tolerance_],
                                   [ Player.rect_width / 2, -Player.rect_height / 2 - tolerance_],
                                   [ Player.rect_width / 2,  Player.rect_height / 2 + tolerance_]]).T

        # Cast to world coordinates
        Pad_Rbounds = (np.matmul(R_mat, Pad_bounds).T + np.array([Player.pos.x, Player.pos.y])).T
        # Compute the lines representing the different surfaces of the pad
        Pad_face_ln = self.line_from_pts(Pad_Rbounds[0,0], Pad_Rbounds[1,0], Pad_Rbounds[0,1], Pad_Rbounds[1,1])
        Pad_back_ln = self.line_from_pts(Pad_Rbounds[0,2], Pad_Rbounds[1,2], Pad_Rbounds[0,3], Pad_Rbounds[1,3])
        Pad_upside_ln = self.line_from_pts(Pad_Rbounds[0,0], Pad_Rbounds[1,0], Pad_Rbounds[0,2], Pad_Rbounds[1,2])
        Pad_downside_ln = self.line_from_pts(Pad_Rbounds[0,1], Pad_Rbounds[1,1], Pad_Rbounds[0,3], Pad_Rbounds[1,3])

        # Comppute the trajectory line of the ball
        ball_vec = (self.pos.x - self.pos_before.x, self.pos.y - self.pos_before.y)
        ball_line = self.line_from_pt_vec(self.pos.x, self.pos.y, ball_vec[0], ball_vec[1])

        # Determine if the ball will intersect with any of the pad surfaces
        ball_face_intersect = self.line_intersectionX(ball_line, Pad_face_ln)
        ball_back_intersect = self.line_intersectionX(ball_line, Pad_back_ln)
        ball_upside_intersect = self.line_intersectionX(ball_line, Pad_upside_ln)
        ball_downside_intersect = self.line_intersectionX(ball_line, Pad_downside_ln)

        # Determine the points of intersection in the coordinate frame of the paddle
        intersect_pts = np.array([ball_face_intersect, ball_back_intersect, ball_upside_intersect, ball_downside_intersect])
        intersect_pts_ = intersect_pts - np.array([Player.pos.x, Player.pos.y])
        rotated_intersects = np.matmul(R_mat.T, intersect_pts_.T).T

        # Compute intersection 
        # --------------------
        dist = self.radius + 0.1
        intersection = False
        intersection_pt = None
        intersection_face = None

        # Determine closest valid surface intersection
        # If the intersection is within the radius +  a small distance from the ball, then mark as intersection
        # If the point of intersection is closer than a previous valid intersection, 
        eps = 1e-3
        for p_id, pt in enumerate(rotated_intersects):
            if (pt[0] - Player.rect_width/2 <= eps) and (pt[0] + Player.rect_width/2 >= -eps) and (pt[1] - Player.rect_height/2 <= eps) and (pt[1] + Player.rect_height/2 >= -eps):
                pt_0 = intersect_pts[p_id]
                d_ = np.linalg.norm(pt_0 - np.array([self.pos.x, self.pos.y]))
                if d_ < dist:
                    dist = d_
                    intersection = True
                    intersection_pt = pt_0.copy()
                    if p_id == 0: 
                        intersection_face = 'face'
                    elif p_id == 1:
                        intersection_face = 'back'
                    elif p_id == 2:
                        intersection_face = 'upside'
                    elif p_id == 3:
                        intersection_face = 'downside'

        # Return computations
        return intersection, intersection_pt, intersection_face


    def update(self, player1, player2, dt):
        """
        Frame update for Pong game

        -- Determines the physics for the interaction between ball and paddles
        """

        self.pos.x += self.vel.x * dt
        self.pos.y += self.vel.y * dt

        is_pad_hit = False

        # Determine intersections between the player paddles and ball
        player1_intersection, player1_intersection_pt, player1_intersection_face = self.computeIntersections(player1, tag='player1')
        player2_intersection, player2_intersection_pt, player2_intersection_face = self.computeIntersections(player2, tag='player2')

        # Compute collision response
        # --------------------------
        # Method:
        # - Set the normal vector of the collision surface of the paddle according to which paddle and which surface was collided with
        # - Normal vector is perpendicualr to paddle surface
        # - Reflect ball around normal vector of the surface of collision
        # - On collision, increase horizontal speed slightly (within a limit)
        # - On collision, impart some vertical momentum on paddle relative to y-velocity of paddle
        # --------------------------
        if player1_intersection:
            player1_angle = player1.angle * np.pi/180.
            if player1_intersection_face == 'face':
                n = np.array([np.cos(player1_angle), -np.sin(player1_angle)])
            if player1_intersection_face == 'downside':
                n = np.array([np.sin(player1_angle),  np.cos(player1_angle)])
            if player1_intersection_face == 'back':
                n = np.array([-np.cos(player1_angle), np.sin(player1_angle)])
            if player1_intersection_face == 'upside':
                n = np.array([-np.sin(player1_angle), np.cos(player1_angle)])
            
            self.pos.x = max(0, self.pos.x)
            vel = np.array([min(self.speed*3,(self.vel.x + math.copysign(self.speed,self.vel.x) * 0.05)), 
                            self.vel.y + player1.vel.y*0.1]) # 2.0
            mag = np.linalg.norm(vel)
            vel /= mag

            dp = np.dot(vel, n)
            if dp < 0:
                new_vel = (vel - 2*np.dot(vel, n)*n)*mag
            else:
                new_vel = vel * mag
            self.vel.x = new_vel[0]
            self.vel.y = new_vel[1]
            self.pos.x += self.radius
            is_pad_hit = True

        if player2_intersection:
            player2_angle = -player2.angle * np.pi/180.
            if player2_intersection_face == 'face':
                n = np.array([-np.cos(player2_angle),  np.sin(player2_angle)])
            if player2_intersection_face == 'downside':
                n = np.array([-np.sin(player2_angle), -np.cos(player2_angle)])
            if player2_intersection_face == 'back':
                n = np.array([ np.cos(player2_angle), -np.sin(player2_angle)])
            if player2_intersection_face == 'upside':
                n = np.array([ np.sin(player2_angle), -np.cos(player2_angle)])
            
            self.pos.x = min(self.SCREEN_WIDTH, self.pos.x)

            vel = np.array([min(self.speed*3,(self.vel.x + math.copysign(self.speed,self.vel.x) * 0.05)), 
                            self.vel.y + player2.vel.y*0.1]) # 0.006
            mag = np.linalg.norm(vel)
            vel /= mag

            dp = np.dot(vel, n)
            if dp < 0:
                new_vel = (vel - 2*np.dot(vel, n)*n)*mag
            else:
                new_vel = vel * mag
            self.vel.x = new_vel[0]
            self.vel.y = new_vel[1]
            self.pos.x -= self.radius
            is_pad_hit = True
        
        """ HERE BE OLD CODE """
        # if self.pos.x <= player1.pos.x + player1.rect_height:
        # print(rotated_intersects)
        #     if self.line_intersection(self.pos_before.x, self.pos_before.y, self.pos.x, self.pos.y, player1.pos.x + player1.rect_width / 2, player1.pos.y - player1.rect_height / 2, player1.pos.x + player1.rect_width / 2, player1.pos.y + player1.rect_height / 2):
        #     if self.line_intersection(self.pos_before.x, self.pos_before.y, self.pos.x, self.pos.y, aPad_up_x, aPad_up_y, aPad_down_x, aPad_down_y):
        #         self.pos.x = max(0, self.pos.x)
        #         self.vel.x = -1 * min(self.speed*3,(self.vel.x + math.copysign(self.speed,self.vel.x) * 0.05)) # this is wrong
        #         self.vel.y += player1.vel.y * 2.0
        #         self.pos.x += self.radius
        #         is_pad_hit = True

        # if self.pos.x >= player2.pos.x - player2.rect_height:
        #     if self.line_intersection(self.pos_before.x, self.pos_before.y, self.pos.x, self.pos.y, player2.pos.x - player2.rect_width / 2, player2.pos.y - player2.rect_height / 2, player2.pos.x - player2.rect_width / 2, player2.pos.y + player2.rect_height / 2):
        #         self.pos.x = min(self.SCREEN_WIDTH, self.pos.x)
        #         self.vel.x = -1 * min(self.speed*3,(self.vel.x + math.copysign(self.speed,self.vel.x) * 0.05))
        #         self.vel.y += player2.vel.y * 0.006
        #         self.pos.x -= self.radius
        #         is_pad_hit = True
        """ END OF OLD CODE """

        # Little randomness in order not to stuck in a static loop
        if is_pad_hit:
            self.vel.y += self.rng.random_sample() * 0.001 - 0.0005

        if self.pos.y - self.radius <= 0:
            self.vel.y *= -0.99
            self.pos.y += 1.0

        if self.pos.y + self.radius >= self.SCREEN_HEIGHT:
            self.vel.y *= -0.99
            self.pos.y -= 1.0

        self.pos_before.x = self.pos.x
        self.pos_before.y = self.pos.y

        self.rect.center = (self.pos.x, self.pos.y)


class Player(pygame.sprite.Sprite):

    def __init__(self, speed, rect_width, rect_height,
                 pos_init, SCREEN_WIDTH, SCREEN_HEIGHT, color=(255,255,255), r_delta=1, player_tag=None):
        """
        Create a paddle sprite.

        speed is the maximum vertical speed, rect_width / rect_height the paddle
        size, pos_init its (x, y) centre and r_delta the rotation step stored
        for the input handler. player_tag must be 'player1' or 'player2',
        otherwise ValueError is raised.
        """
        pygame.sprite.Sprite.__init__(self)

        # The tag decides the direction in which the sprite is rotated (mirrored between
        # players), which preserves compatibility of trained agents
        if player_tag not in ('player1', 'player2'):
            raise ValueError('Please set player tag')
        self.player_tag = player_tag

        self.speed = speed
        self.r_delta = r_delta
        self.rect_width = rect_width
        self.rect_height = rect_height
        self.SCREEN_WIDTH = SCREEN_WIDTH
        self.SCREEN_HEIGHT = SCREEN_HEIGHT

        self.pos = vec2d(pos_init)
        self.vel = vec2d((0, 0))

        # Unrotated paddle surface; black is the colour key
        surface = pygame.Surface((rect_width, rect_height))
        surface.fill((0, 0, 0, 0))
        surface.set_colorkey((0, 0, 0))
        pygame.draw.rect(surface, color, (0, 0, rect_width, rect_height), 0)

        self.angle = 0
        self.angle_d = 1
        # image_ori is kept untouched; image is what gets drawn
        self.image_ori = surface
        self.image = surface.copy()
        self.rect = self.image.get_rect()
        self.rect.center = pos_init

    def update(self, dy, r_delta, dt):
        """
        Advance a controlled paddle by one frame.

        dy is the vertical input: half of it is added to the y-velocity, which
        is then limited to +/- self.speed. r_delta is added to the paddle angle,
        which is kept within [-30, 30]. dt is the integration time step.
        Raises RuntimeError if player_tag is neither 'player1' nor 'player2'.
        """
        # Accelerate, then cap the magnitude at the maximum speed while keeping the sign
        self.vel.y += dy * 0.5
        capped = min(abs(self.vel.y), self.speed)
        self.vel.y = math.copysign(capped, self.vel.y)

        self.pos.y += self.vel.y * dt

        # Stop at the top edge of the screen
        if self.pos.y - self.rect_height / 2 <= 0:
            self.pos.y = self.rect_height / 2
            self.vel.y = 0.0

        # Stop at the bottom edge of the screen
        if self.pos.y + self.rect_height / 2 >= self.SCREEN_HEIGHT:
            self.pos.y = self.SCREEN_HEIGHT - self.rect_height / 2
            self.vel.y = 0.0

        self.angle = np.clip(self.angle + r_delta, -30, 30)

        # Player 2 is drawn with the opposite rotation so both players are mirror images
        if self.player_tag == 'player1':
            drawn_angle = self.angle
        elif self.player_tag == 'player2':
            drawn_angle = -self.angle
        else:
            raise RuntimeError('Oops! Should not have come here. Check!!')
        self.image = pygame.transform.rotate(self.image_ori, drawn_angle)

        # The rotated surface has a new bounding box; re-centre it on the paddle
        self.rect = self.image.get_rect()
        self.rect.center = (self.pos.x, self.pos.y)

    def updateCpu(self, ball, dt, r_delta=0, player='player2'):
        """
        Advance a computer-controlled paddle by one frame.

        While the ball is in the right half of the screen and not moving left,
        the paddle moves towards the ball at full speed; otherwise it drifts
        towards the vertical centre at a quarter of its speed. r_delta and
        player are accepted but not used.
        """
        chasing = ball.vel.x >= 0 and ball.pos.x >= self.SCREEN_WIDTH * 0.5
        if chasing:
            dy = self.speed
            if self.pos.y > ball.pos.y:
                dy = -1.0 * dy
        elif self.pos.y > self.SCREEN_HEIGHT / 2.0:
            dy = -1.0 * self.speed / 4.0
        else:
            dy = 1.0 * self.speed / 4.0

        # Clamp to the screen; this runs before this frame's movement is applied
        top_limit = self.rect_height / 2
        if self.pos.y - top_limit <= 0:
            self.pos.y = top_limit
            self.vel.y = 0.0

        if self.pos.y + self.rect_height / 2 >= self.SCREEN_HEIGHT:
            self.pos.y = self.SCREEN_HEIGHT - self.rect_height / 2
            self.vel.y = 0.0

        self.image = pygame.transform.rotate(self.image_ori, self.angle)

        self.vel.y = dy
        self.pos.y += self.vel.y * dt
        self.rect.center = (self.pos.x, self.pos.y)


class Pong(PyGameWrapper):
    """
    Loosely based on code from marti1125's `pong game`_.

    .. _pong game: https://github.com/marti1125/pong/

    Parameters
    ----------
    width : int
        Screen width.

    height : int
        Screen height, recommended to be same dimension as width.

    MAX_SCORE : int (default: 11)
        The max number of points the player1 or player2 need to score to cause a terminal state.
        
    player2_speed_ratio: float (default: 0.4)
        Speed of opponent (useful for curriculum learning)

    player1_speed_ratio: float (default: 0.4)
        Speed of player (useful for curriculum learning)

    ball_speed_ratio: float (default: 0.75)
        Speed of ball (useful for curriculum learning)

    p2_enabled: bool (default: False)
        Whether or not to have p2 respond to external controls or use the default CPU behavior

    """

    def __init__(self, width=64, height=48, player1_speed_ratio = 0.4, player2_speed_ratio=0.4, ball_speed_ratio=0.75,  MAX_SCORE=11, p2_enabled=False):
        """Store the game configuration and derive sprite sizes from the screen size."""
        # Player 1 uses W/S to move and A/D to rotate; player 2 uses the arrow keys
        key_bindings = dict(
            up=K_w,
            down=K_s,
            r_cc=K_a,
            r_c=K_d,
            up2=K_UP,
            down2=K_DOWN,
            r_cc2=K_RIGHT,
            r_c2=K_LEFT,
        )

        self.p2_enabled = p2_enabled
        self.width = width
        self.height = height

        PyGameWrapper.__init__(self, width, height, actions=key_bindings)

        self.MAX_SCORE = MAX_SCORE
        self.player1_speed_ratio = player1_speed_ratio
        self.player2_speed_ratio = player2_speed_ratio
        self.ball_speed_ratio = ball_speed_ratio

        # Sizes are fractions of the screen so the proportions survive a change of resolution
        self.ball_radius = percent_round_int(height, 0.03)
        self.paddle_width = percent_round_int(width, 0.023)
        self.paddle_height = percent_round_int(height, 0.15)
        self.paddle_dist_to_wall = percent_round_int(width, 0.0625)

        # Per-frame inputs: vertical move (dy) and rotation (dr) for each player
        self.dy1 = 0.0
        self.dr1 = 0.0
        self.dy2 = 0.0
        self.dr2 = 0.0

        # Accumulated rewards per player
        self.p1score_sum = 0.0
        self.p2score_sum = 0.0

        # Points won per player, compared against MAX_SCORE
        self.score_counts = {"player1": 0.0, "player2": 0.0}

    def _handle_player_events(self):
        """
        Translate input into this frame's paddle commands.

        Resets dy1/dr1/dy2/dr2 to zero, then sets them from the keyboard state
        (when the module is run as a script) or from the queued KEYUP events
        (when driven through the wrapper). A QUIT event shuts pygame down and
        exits the process in both modes.
        """
        self.dy1 = 0.
        self.dr1 = 0.

        self.dy2 = 0.
        self.dr2 = 0.

        # Added control inputs for paddle rotations and simultaneous Player-2 control
        if __name__ == "__main__":
            # for debugging mode
            # QUIT is an event type, not a key: it has to be read from the event queue
            # (draining the queue also keeps the key state up to date).
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    pygame.quit()
                    sys.exit()

            keys = pygame.key.get_pressed()
            # Player 1: one command per frame, movement takes precedence over rotation
            if keys[self.actions['up']]:
                self.dy1 = -self.player1.speed
            elif keys[self.actions['down']]:
                self.dy1 = self.player1.speed
            elif keys[self.actions['r_cc']]:
                self.dr1 = self.player1.r_delta
            elif keys[self.actions['r_c']]:
                self.dr1 = -self.player1.r_delta

            # Player 2: only has an effect when p2 control is enabled in the game
            if keys[self.actions['up2']]:
                self.dy2 = -self.player2.speed
            elif keys[self.actions['down2']]:
                self.dy2 = self.player2.speed
            elif keys[self.actions['r_cc2']]:
                self.dr2 = self.player2.r_delta
            elif keys[self.actions['r_c2']]:
                self.dr2 = -self.player2.r_delta
        else:
            # (action name, attribute to set, value) for every supported command
            commands = (
                ('up', 'dy1', -self.player1.speed),
                ('down', 'dy1', self.player1.speed),
                ('r_cc', 'dr1', self.player1.r_delta),
                ('r_c', 'dr1', -self.player1.r_delta),
                ('up2', 'dy2', -self.player2.speed),
                ('down2', 'dy2', self.player2.speed),
                ('r_cc2', 'dr2', self.player2.r_delta),
                ('r_c2', 'dr2', -self.player2.r_delta),
            )

            # consume events from act
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    pygame.quit()
                    sys.exit()

                # Commands are taken from key releases (KEYUP), not key presses
                if event.type == pygame.KEYUP:
                    key = event.key
                    for action, attr, value in commands:
                        if key == self.actions[action]:
                            setattr(self, attr, value)


    def getGameState(self):
        """
        Gets a non-visual state representation of the game.

        Returns
        -------

        dict
            * player y position.
            * players velocity.
            * player2 y position.
            * ball x position.
            * ball y position.
            * ball x velocity.
            * ball y velocity.
            * paddle angle

            See code for structure.

        Returns a dict of dicts which needs to be processed externally to extract relevant states

        """
        p1state = {
            "player_y": self.player1.pos.y,
            "player_velocity": self.player1.vel.y,
            "player2_y": self.player2.pos.y,
            "ball_x": self.ball.pos.x,
            "ball_y": self.ball.pos.y,
            "ball_velocity_x": self.ball.vel.x,
            "ball_velocity_y": self.ball.vel.y,
            "player_angle": self.player1.angle
        }

        # Player2 state is reflected on x-axis to maintain compatibility with agents trained on player1
        p2state = {
            "player_y": self.player2.pos.y,
            "player_velocity": self.player2.vel.y,
            "player2_y": self.player1.pos.y,
            "ball_x": self.width - self.ball.pos.x,
            "ball_y": self.ball.pos.y,
            "ball_velocity_x": -self.ball.vel.x,
            "ball_velocity_y": self.ball.vel.y,
            "player_angle": self.player2.angle
        }

        return {'player1_state': p1state, 'player2_state': p2state}

    def getScore(self):
        return {'player1': self.getP1Score(), 'player2': self.getP2Score()}

    def getP1Score(self):
        return self.p1score_sum

    def getP2Score(self):
        return self.p2score_sum

    def game_over(self):
        # pong used 11 as max score
        return (self.score_counts['player1'] == self.MAX_SCORE) or (
            self.score_counts['player2'] == self.MAX_SCORE)

    def init(self):
        """Reset scores and rebuild the ball, both paddles and their sprite groups."""
        self.score_counts = {"player1": 0.0, "player2": 0.0}
        self.p1score_sum = 0.0
        self.p2score_sum = 0.0

        mid_y = self.height / 2

        # Ball starts in the centre of the screen
        self.ball = Ball(
            self.ball_radius,
            self.ball_speed_ratio * self.height,
            self.rng,
            (self.width / 2, mid_y),
            self.width,
            self.height
        )

        # Green paddle on the left
        self.player1 = Player(
            self.player1_speed_ratio * self.height,
            self.paddle_width,
            self.paddle_height,
            (self.paddle_dist_to_wall, mid_y),
            self.width,
            self.height,
            player_tag='player1',
            color=(0,255,0))

        # Red paddle on the right
        self.player2 = Player(
            self.player2_speed_ratio * self.height,
            self.paddle_width,
            self.paddle_height,
            (self.width - self.paddle_dist_to_wall, mid_y),
            self.width,
            self.height,
            player_tag='player2',
            color=(255,0,0))

        # Despite its name this group holds both paddles
        self.player1_group = pygame.sprite.Group()
        for paddle in (self.player1, self.player2):
            self.player1_group.add(paddle)

        self.ball_group = pygame.sprite.Group()
        self.ball_group.add(self.ball)


    def reset(self):
        self.init()
        # after game over set random direction of ball otherwise it will always be the same
        self._reset_ball(1 if self.rng.random_sample() > 0.5 else -1)


    def _reset_ball(self, direction):
        self.ball.pos.x = self.width / 2  # move it to the center

        # we go in the same direction that they lost in but at starting vel.
        self.ball.vel.x = self.ball.speed * direction
        self.ball.vel.y = (self.rng.random_sample() *
                           self.ball.speed) - self.ball.speed * 0.5

    def step(self, dt):
        dt /= 1000.0
        self.screen.fill((0, 0, 0))

        self.player1.speed = self.player1_speed_ratio * self.height
        self.player2.speed = self.player2_speed_ratio * self.height
        self.ball.speed = self.ball_speed_ratio * self.height

        self._handle_player_events()

        # doesnt make sense to have this, but include if needed.
        self.p1score_sum += self.rewards["tick"]
        self.p2score_sum += self.rewards["tick"]

        self.ball.update(self.player1, self.player2, dt)

        is_terminal_state = False

        # logic
        if self.ball.pos.x <= 0:
            self.p1score_sum += self.rewards["negative"]
            self.p2score_sum += self.rewards["positive"]
            self.score_counts["player2"] += 1.0
            self._reset_ball(-1)
            is_terminal_state = True

        if self.ball.pos.x >= self.width:
            self.p1score_sum += self.rewards["positive"]
            self.p2score_sum += self.rewards["negative"]
            self.score_counts["player1"] += 1.0
            self._reset_ball(1)
            is_terminal_state = True

        if is_terminal_state:
            # p1 winning
            if self.score_counts['player1'] == self.MAX_SCORE:
                self.p1score_sum += self.rewards["win"]
                self.p2score_sum += self.rewards["loss"]

            # p2 losing
            if self.score_counts['player2'] == self.MAX_SCORE:
                self.p1score_sum += self.rewards["loss"]
                self.p2score_sum += self.rewards["win"]
        else:
            self.player1.update(dy=self.dy1, dt=dt, r_delta=self.dr1)
            if self.p2_enabled:
                self.player2.update(dy=self.dy2, dt=dt, r_delta=self.dr2)
            else:
                self.player2.updateCpu(self.ball, dt=dt)

        self.player1_group.draw(self.screen)
        self.ball_group.draw(self.screen)

if __name__ == "__main__":
    # Manual play / debugging entry point: keyboard-controlled game in a window.
    pygame.init()

    # Set p2_enabled True to enable p2 control (False triggers default behavior)
    game = Pong(width=256, height=200, p2_enabled=True)
    game.rng = np.random.RandomState(24)
    game.clock = pygame.time.Clock()
    game.screen = pygame.display.set_mode(game.getScreenDims(), 0, 32)
    game.init()

    # Runs until a QUIT event makes the input handler exit the process
    while True:
        elapsed_ms = game.clock.tick_busy_loop(60)
        game.step(elapsed_ms)
        pygame.display.update()
