#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
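Simple, self-contained Box2D air hockey environment.
Creates balls, blocks, obstacles and targets inside a walled table, plus a
paddle that is driven by applied forces, and renders the scene with OpenCV.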
"""
# import pygame
# from pygame.locals import (QUIT, KEYDOWN, K_ESCAPE)

import Box2D  # The main library
# Box2D.b2 maps Box2D.b2Vec2 to vec2 (and so on)
from Box2D.b2 import (world, polygonShape, staticBody, dynamicBody)
from Box2D import (b2CircleShape, b2FixtureDef, b2LoopShape, b2PolygonShape,
                   b2RevoluteJointDef, b2_pi, b2_dynamicBody, b2_staticBody, b2Filter)
import cv2
import time, collections
import numpy as np
import re

def strip_instance(name):
    """Return *name* with its trailing run of digits (if any) removed.

    For example ``"Ball12"`` becomes ``"Ball"``; a name without a numeric
    suffix is returned unchanged.
    """
    trailing_digits = re.compile(r'\d+$')
    return trailing_digits.sub('', name)

# --- constants ---
# Fixed simulation rate and the matching per-step duration in seconds.
TARGET_FPS = 60
TIME_STEP = 1.0 / TARGET_FPS

# Box2D works in meters while the display works in pixels, so a conversion
# factor (pixels per meter) is derived from the default screen width.
SCREEN_WIDTH = 120
SCREEN_HEIGHT = 120
PPM = SCREEN_WIDTH / 10  # pixels per meter


class AirHockeyBox2D():
    """Box2D air-hockey scene with a force-controlled paddle.

    The table is a closed loop of walls spanning ``[-width/2, width/2]`` in x
    and ``[-length/2, length/2]`` in y.  Every object is stored as a
    ``(body, color)`` pair; the paddle and cue are stored as
    ``(name, (body, color))`` (the cue's second element is ``None`` when the
    cue is disabled).  Frames are rendered with OpenCV into an array of shape
    ``(render_width, render_length, 3)``.
    """

    def __init__(self, num_balls, num_blocks, num_obstacles, num_targets, absorb_target, use_cue, length, width, force_scaling, paddle_damping, render_size, render_masks=False):
        """Create the world and its walls, then populate it via ``reset()``.

        Args:
            num_balls: number of balls created on every reset.
            num_blocks: number of static "Block" rectangles.
            num_obstacles: number of static, randomly rotated "Obstacle" rectangles.
            num_targets: number of "Target" rectangles.
            absorb_target: if true, a ball touching a target is destroyed.
            use_cue: if true, an extra "Cue" ball is created at the origin.
            length: table extent along y (world units).
            width: table extent along x (world units).
            force_scaling: multiplier turning an action into a paddle force.
            paddle_damping: linear damping applied to the paddle body.
            render_size: number of pixels spanning the table width.
            render_masks: if true, ``render()`` returns per-object masks.
        """
        # --- pybox2d world setup ---
        # Gravity is a constant pull toward -y.
        self.world = world(gravity=(0, -5), doSleep=True)
        self.length, self.width = length, width
        self.force_scaling = force_scaling
        self.num_balls = num_balls
        self.num_blocks = num_blocks
        self.num_obstacles = num_obstacles
        self.num_targets = num_targets
        self.absorb_target = absorb_target
        # Pixel extent of the table width; also the first axis of a frame.
        self.render_width = int(render_size)
        # Pixels per world unit.
        self.ppm = render_size / self.width
        self.render_length = int(self.ppm * self.length)
        self.render_masks = render_masks
        # Spawn bands: balls start at or above ball_min_height, the paddle
        # at or below paddle_max_height, blocks and obstacles at or above
        # block_min_height.
        self.ball_min_height = (-length / 2) + (length / 3)
        self.paddle_max_height = (-length / 2) + (length / 4)
        self.block_min_height = 0
        self.max_speed_start = width
        self.min_speed_start = 0
        self.paddle_damping = paddle_damping
        self.use_cue = use_cue
        # A static body holding the four table walls as a closed loop.
        self.ground_body = self.world.CreateBody(
            shapes=b2LoopShape(vertices=[(-width / 2, -length / 2),
                                         (-width / 2, length / 2),
                                         (width / 2, length / 2),
                                         (width / 2, -length / 2)]),
        )

        self.reset()

    def reset(self, object_state_dict=None, type_instance_dict=None, max_count_dict=None):
        """Destroy all current objects and create a fresh random scene.

        TODO: ``object_state_dict``, ``type_instance_dict`` and
        ``max_count_dict`` are accepted but currently ignored; restoring a
        specific state from them is not implemented yet.
        """
        if hasattr(self, "object_dict"):
            for body in self.object_dict.values():
                self.world.DestroyBody(body)
        self.balls = dict()
        self.blocks = dict()
        self.obstacles = dict()
        self.targets = dict()
        self.paddle_attrs = None
        self.target_attrs = None

        for i in range(self.num_balls):
            rad = max(0.25, np.random.rand() * (self.width / 8))
            name, ball_attrs = self.create_bouncing_ball(i, radius=rad, min_height=self.ball_min_height)
            self.balls[name] = ball_attrs

        for i in range(self.num_blocks):
            name, block_attrs = self.create_block_type(i, name_type="Block", dynamic=False, min_height=self.block_min_height)
            self.blocks[name] = block_attrs

        for i in range(self.num_obstacles):  # could replace with arbitrary polygons
            name, obs_attrs = self.create_block_type(i, name_type="Obstacle", angle=np.random.rand() * np.pi, dynamic=False, color=(0, 127, 127), min_height=self.block_min_height)
            self.obstacles[name] = obs_attrs

        for i in range(self.num_targets):
            name, target_attrs = self.create_block_type(i, name_type="Target", color=(255, 255, 0))
            self.targets[name] = target_attrs

        # The index is irrelevant when an explicit name is given.  Passing the
        # loop variable here raised NameError when every object count was 0.
        self.paddle = self.create_bouncing_ball(-1, name="Paddle", radius=0.25, density=1000, ldamp=self.paddle_damping, color=(0, 255, 0), max_height=self.paddle_max_height)
        if self.use_cue:
            self.cue = self.create_bouncing_ball(-1, name="Cue", radius=0.25, vel=(0, 0), pos=(0, 0), color=(200, 100, 0))
        else:
            self.cue = ("Cue", None)

        # Sorted name lists fix the ordering used by get_contacts().
        self.ball_names = sorted(self.balls)
        self.block_names = sorted(self.blocks)
        self.obstacle_names = sorted(self.obstacles)
        self.target_names = sorted(self.targets)

        # name -> body for every live object.
        self.object_dict = {name: attrs[0] for name, attrs in self.balls.items()}
        if self.paddle[1] is not None:
            self.object_dict[self.paddle[0]] = self.paddle[1][0]
        if self.cue[1] is not None:
            self.object_dict[self.cue[0]] = self.cue[1][0]
        self.object_dict.update({name: attrs[0] for name, attrs in self.blocks.items()})
        self.object_dict.update({name: attrs[0] for name, attrs in self.targets.items()})
        self.object_dict.update({name: attrs[0] for name, attrs in self.obstacles.items()})

        self.frame = self._blank_frame()

    def create_bouncing_ball(self, i, name=None, color=(127, 127, 127), radius=-1, density=10, vel=None, pos=None, ldamp=1, collidable=True, min_height=-30, max_height=30):
        """Create a dynamic circular body and return ``(name, (body, color))``.

        Args:
            i: index used to build the default name ``"Ball<i>"``.
            name: explicit name; overrides the default one.
            color: color used when drawing the body.
            radius: circle radius; a negative value picks a random one.
            density: fixture density (previously ignored and fixed at 1.0).
            vel: initial linear velocity; random when ``None``.
            pos: initial position; random within the table when ``None``.
            ldamp: linear damping of the body.
            collidable: when false the fixture gets category bits 0.
            min_height, max_height: bounds on the random y position, clamped
                to the table.
        """
        if pos is None:
            low = max(min_height, -self.length / 2)
            high = min(max_height, self.length / 2)
            x = (np.random.rand() - 0.5) * 2 * (self.width / 2)
            y = low + np.random.rand() * (high - low)
            pos = (x, y)
        if vel is None:
            speed_range = self.max_speed_start - self.min_speed_start
            vx = np.random.rand() * speed_range + self.min_speed_start
            vy = np.random.rand() * speed_range + self.min_speed_start
            vel = (vx, vy)
        if radius < 0:
            radius = max(1, np.random.rand() * (self.width / 5))
        ball = self.world.CreateDynamicBody(
            fixtures=b2FixtureDef(
                shape=b2CircleShape(radius=radius),
                density=density,
                restitution=1.0,
                filter=b2Filter(maskBits=1,
                                categoryBits=1 if collidable else 0)),
            bullet=True,
            position=pos,
            linearVelocity=vel,
            linearDamping=ldamp
        )
        ball_name = ("Ball" + str(i)) if name is None else name
        return ball_name, (ball, color)

    def create_block_type(self, i, name=None, name_type=None, color=(127, 127, 127), width=-1, height=-1, vel=None, pos=None, dynamic=True, angle=0, angular_vel=0, fixed_rotation=False, collidable=True, min_height=-30):
        """Create a rectangular body and return ``(name, (body, color))``.

        Args:
            i: index appended to ``name_type`` to build the default name.
            name: explicit name; overrides the default one.
            name_type: name prefix such as "Block", "Obstacle" or "Target".
            color: color used when drawing the body.
            width, height: rectangle size; negative values pick random ones.
            vel: initial linear velocity; random when ``None``, and forced to
                zero for static bodies.
            pos: initial position; random within the table when ``None``.
            dynamic: create a dynamic body when true, a static one otherwise.
            angle, angular_vel, fixed_rotation: passed to the body.
            collidable: when false the fixture gets category bits 0.
            min_height: lower bound on the random y position, clamped to the
                table so that the default of -30 cannot spawn a block
                outside the walls.
        """
        if pos is None:
            low = max(min_height, -self.length / 2)
            x = (np.random.rand() - 0.5) * 2 * (self.width / 2)
            y = low + np.random.rand() * (self.length / 2 - low)
            pos = (x, y)
        if vel is None:
            vel = ((np.random.rand() - 0.5) * 2 * (self.width), (np.random.rand() - 0.5) * 2 * (self.length))
        if not dynamic:
            vel = (0.0, 0.0)
        if width < 0:
            width = max(0.75, np.random.rand() * 3)
        if height < 0:
            height = max(0.5, np.random.rand())
        # TODO: possibly create obstacles of arbitrary shape
        vertices = [(-width / 2, -height / 2), (width / 2, -height / 2), (width / 2, height / 2), (-width / 2, height / 2)]

        fixture = b2FixtureDef(
            shape=b2PolygonShape(vertices=vertices),
            density=1,
            restitution=0.1,
            filter=b2Filter(maskBits=1,
                            categoryBits=1 if collidable else 0),
        )

        body = self.world.CreateBody(type=b2_dynamicBody if dynamic else b2_staticBody,
                                     position=pos,
                                     linearVelocity=vel,
                                     angularVelocity=angular_vel,
                                     angle=angle,
                                     fixtures=fixture,
                                     fixedRotation=fixed_rotation,
                                     )
        if name is None:
            # Only build the default name when it is needed, so an explicit
            # ``name`` works even without a ``name_type``.
            name = (name_type if name_type is not None else "Block") + str(i)
        return name, (body, color)

    def _blank_frame(self):
        """Return a black uint8 frame of shape (render_width, render_length, 3)."""
        return np.zeros((self.render_width, self.render_length, 3), dtype=np.uint8)

    def _to_pixels(self, point):
        """Convert a world (x, y) point to OpenCV (column, row) pixel coordinates."""
        shifted = np.array(point) + np.array((self.width / 2, self.length / 2))
        # World y runs along the frame's second axis, world x along the first.
        return np.array((shifted[1], shifted[0])) * self.ppm

    def draw_circle(self, body_attrs):
        """Draw every fixture of a circular ``(body, color)`` pair onto ``self.frame``."""
        body, color = body_attrs
        for fixture in body.fixtures:
            # Naively assumes the fixture's shape is a circle.
            center = self._to_pixels(body.position)
            radius = int(fixture.shape.radius * self.ppm)
            # OpenCV expects plain Python ints for the center point.
            cv2.circle(self.frame, (int(center[0]), int(center[1])), radius, color, -1)

    def draw_polygon(self, body_attrs):
        """Draw every fixture of a polygonal ``(body, color)`` pair onto ``self.frame``."""
        body, color = body_attrs
        for fixture in body.fixtures:
            # Naively assumes the fixture's shape is a polygon.
            shape = fixture.shape
            # Rotate the local vertices by the body's rotation, translate by
            # its position, then convert to pixels.
            rotation = np.stack([body.transform.R.x_axis, body.transform.R.y_axis], axis=1)
            position = np.array(body.position)
            world_vertices = [position + np.matmul(rotation, v) for v in shape.vertices]
            pixel_vertices = [self._to_pixels(v) for v in world_vertices]
            # cv2.fillPoly requires 32-bit integer points; the platform
            # default int is often 64-bit and is rejected.
            pixel_vertices = np.array(pixel_vertices).astype(np.int32)

            cv2.fillPoly(self.frame, pts=[pixel_vertices], color=color)

    def _render_mask(self, draw, body_attrs):
        """Draw one object on a blank frame with ``draw`` and return that frame."""
        self.frame = self._blank_frame()
        draw(body_attrs)
        return self.frame

    def render(self, show=False):
        """Render the scene.

        With ``render_masks`` set, returns a dict mapping every object name to
        a frame containing only that object, plus ``'obs'`` holding the
        (saturated) sum of all of them; ``show`` is ignored in that mode.
        Otherwise returns a single frame with all objects drawn, optionally
        displaying it in an OpenCV window when ``show`` is true.
        """
        if self.render_masks:
            frames = {}
            for ball_name, ball_attrs in self.balls.items():
                frames[ball_name] = self._render_mask(self.draw_circle, ball_attrs)
            for block_name, block_attrs in self.blocks.items():
                frames[block_name] = self._render_mask(self.draw_polygon, block_attrs)
            for obstacle_name, obstacle_attrs in self.obstacles.items():
                frames[obstacle_name] = self._render_mask(self.draw_polygon, obstacle_attrs)
            # The paddle/cue tuples always exist; only their second element
            # tells whether a body was actually created.
            if self.paddle[1] is not None:
                frames[self.paddle[0]] = self._render_mask(self.draw_circle, self.paddle[1])
            if self.cue[1] is not None:
                frames[self.cue[0]] = self._render_mask(self.draw_circle, self.cue[1])
            # Targets are rectangles stored in self.targets.
            for target_name, target_attrs in self.targets.items():
                frames[target_name] = self._render_mask(self.draw_polygon, target_attrs)
            # Accumulate in a wider dtype so overlapping objects saturate
            # instead of wrapping around in uint8.
            total = np.zeros((self.render_width, self.render_length, 3), dtype=np.uint16)
            for mask in frames.values():
                total = total + mask
            frames['obs'] = np.clip(total, 0, 255).astype(np.uint8)
            return frames

        self.frame = self._blank_frame()
        for ball_attrs in self.balls.values():
            self.draw_circle(ball_attrs)
        for block_attrs in self.blocks.values():
            self.draw_polygon(block_attrs)
        for obstacle_attrs in self.obstacles.values():
            self.draw_polygon(obstacle_attrs)
        # Targets are part of the scene and of the mask-mode 'obs' image, so
        # they are drawn here as well.
        for target_attrs in self.targets.values():
            self.draw_polygon(target_attrs)
        if self.paddle[1] is not None:
            self.draw_circle(self.paddle[1])
        if self.cue[1] is not None:
            self.draw_circle(self.cue[1])
        if show:
            cv2.imshow("bouncing", self.frame)
            cv2.waitKey(10)
        return self.frame

    def step(self, action, time_step=0.018, iters=10):
        """Apply ``action`` to the paddle and advance the simulation one step.

        Args:
            action: indexable with two components (x, y); scaled by
                ``force_scaling`` and the paddle mass to give the force.
            time_step: simulated duration of the step.
            iters: velocity and position solver iterations for the step.

        Returns:
            ``(contacts, contact_names, hit_target)`` as produced by
            ``get_contacts()`` and ``respond_contacts()``.
        """
        if self.paddle[1] is not None:
            paddle_body = self.paddle[1][0]
            force = self.force_scaling * paddle_body.mass * np.array((action[0], action[1])).astype(float)
            # Above its height limit the paddle may only be pushed back down.
            if paddle_body.position[1] > self.paddle_max_height:
                force[1] = min(force[1], 0)
            paddle_body.ApplyForceToCenter((float(force[0]), float(force[1])), True)
        self.world.Step(time_step, iters, iters)
        contacts, contact_names = self.get_contacts()
        hit_target = self.respond_contacts(contact_names)
        return contacts, contact_names, hit_target

    def get_contacts(self):
        """Collect which objects are currently touching each other.

        Returns:
            ``(contacts, contact_names)`` where ``contacts`` is a square
            boolean array whose rows and columns both follow the order
            paddle, balls, blocks, obstacles, targets, cue, and
            ``contact_names`` maps each object name to the list of names it
            is touching.
        """
        names = ([self.paddle[0]] if self.paddle[1] is not None else list()) + \
            self.ball_names + self.block_names + self.obstacle_names + self.target_names + \
            ([self.cue[0]] if self.cue[1] is not None else list())
        # Bodies are looked up from the names so both lists share one order;
        # building them separately mislabeled contacts when a cue existed.
        bodies = [self.object_dict[n] for n in names]
        contact_names = {n: list() for n in names}
        contacts = list()
        for name, body in zip(names, bodies):
            row = np.zeros(len(bodies), dtype=bool)
            for edge in body.contacts:
                if not edge.contact.touching:
                    continue
                for col, (other_name, other_body) in enumerate(zip(names, bodies)):
                    if edge.other == other_body:
                        row[col] = True
                        contact_names[name].append(other_name)
            contacts.append(row)
        if not contacts:
            return np.zeros((0, 0), dtype=bool), contact_names
        return np.stack(contacts, axis=0), contact_names

    def respond_contacts(self, contact_names):
        """Find balls touching targets and, if enabled, absorb them.

        Args:
            contact_names: mapping of object name to touched names, as
                returned by ``get_contacts()``.

        Returns:
            List of ball names touching a target (one entry per target
            touched).  With ``absorb_target`` set, those balls are destroyed
            and removed from all bookkeeping structures.
        """
        hit_target = list()
        for tn in self.target_names:
            for cn in contact_names[tn]:
                if cn.find("Ball") != -1:
                    hit_target.append(cn)
        if self.absorb_target:
            for cn in hit_target:
                body = self.object_dict.pop(cn, None)
                if body is None:
                    # Already absorbed (the ball touched more than one target).
                    continue
                self.world.DestroyBody(body)
                # Drop the ball everywhere so later get_contacts()/render()
                # calls never touch the destroyed body.
                self.balls.pop(cn, None)
                if cn in self.ball_names:
                    self.ball_names.remove(cn)
        return hit_target  # TODO: record a destroyed flag

def demonstrate(frame):
    """Show *frame* and translate a key press into a paddle action.

    Waits up to 100 ms for a key.  ``w``/``a``/``s``/``d`` map to the unit
    actions ``[-1, 0]``, ``[0, -1]``, ``[1, 0]`` and ``[0, 1]``; ``q`` returns
    the sentinel ``-1``; any other key (or no key) gives the zero action.
    """
    cv2.imshow('frame', frame)
    pressed = cv2.waitKey(100)
    if pressed == ord('q'):
        return -1
    key_to_action = {
        ord('w'): (-1, 0),
        ord('a'): (0, -1),
        ord('d'): (0, 1),
        ord('s'): (1, 0),
    }
    return np.array(key_to_action.get(pressed, (0, 0)))



if __name__ == "__main__":
    # Interactive demo: drive the paddle with w/a/s/d, quit with q.
    # Arguments: num_balls, num_blocks, num_obstacles, num_targets,
    # absorb_target, use_cue, length, width, force_scaling, paddle_damping,
    # render_size
    air_hockey = AirHockeyBox2D(1, 0, 1, 0, False, False, 15, 5, 200, 10, 240)
    report_every = 1000  # iterations between fps reports
    reset_every = 300    # iterations between scene resets
    start = time.time()
    for i in range(1000000):
        # Skip i == 0: no iterations have been timed yet at that point.
        if i > 0 and i % report_every == 0:
            print("fps", report_every / (time.time() - start))
            start = time.time()

        # Draw the world and read the user's action for this frame.
        frame = air_hockey.render(show=False)
        action = demonstrate(frame)
        # demonstrate() returns the int -1 when the user asks to quit; it must
        # not be forwarded to step(), which indexes into the action.
        if isinstance(action, int) and action == -1:
            break

        # Advance the physics by one step with the chosen action.
        air_hockey.step(action)

        if i % reset_every == 0:
            air_hockey.reset()
    cv2.destroyAllWindows()
