#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Simple, self-contained Box2D usage environment.
Creates zero-gravity balls and polygons that bounce around the scene.
"""
# import pygame
# from pygame.locals import (QUIT, KEYDOWN, K_ESCAPE)

import Box2D  # The main library
# Box2D.b2 maps Box2D.b2Vec2 to vec2 (and so on)
from Box2D.b2 import (world, polygonShape, staticBody, dynamicBody)
from Box2D import (b2CircleShape, b2FixtureDef, b2LoopShape, b2PolygonShape,
                   b2RevoluteJointDef, b2_pi, b2_dynamicBody)
import cv2
import time, collections
import numpy as np
from Environment.Environments.Box2D.box2d_init_specs import preassigned_objects, preassigned_nid, preassigned_names
import re

def strip_instance(name):
    """Return ``name`` without the run of digits at its end, if any.

    Only digits that terminate the string are removed; digits elsewhere in
    the name are left in place.
    """
    trailing_digits = r'\d+$'
    base_name = re.sub(trailing_digits, '', name)
    return base_name

# --- constants ---
TARGET_FPS = 60
# Step duration handed to world.Step(); one frame at TARGET_FPS.
TIME_STEP = 1.0 / TARGET_FPS
SCREEN_WIDTH = 120
SCREEN_HEIGHT = 120
# Box2D deals with meters, but we want to display pixels,
# so define a conversion factor (pixels per meter):
PPM = SCREEN_WIDTH / 10


class BouncingShapes():
    def __init__(
        self,
        num_balls,
        num_polys,
        dyn_damping,
        base_density,
        init_velocity,
        general_radius_min,
        general_radius_max,
        use_target,
        target_form,
        target_damping,
        target_radius,
        target_density,
        target_init_vel,
        control_damping,
        control_radius,
        control_density,
        length,
        width,
        max_vertex,
        force_scaling,
        render_size,
        render_masks=False,
    ):
        """Record the scene configuration, build the world and populate it.

        Args:
            num_balls: number of generic balls spawned by a default reset.
            num_polys: number of generic polygons spawned by a default reset.
            dyn_damping: linear damping given to generic balls and polygons.
            base_density: fixture density given to generic balls and polygons.
            init_velocity: scale applied to the random initial velocity of
                generic objects.
            general_radius_min: lower bound of the random radius drawn when
                resetting from a type/instance dictionary.
            general_radius_max: upper bound of that random radius.
            use_target: whether a target object is wanted.
            target_form: ``"Ball"`` or a string containing ``"Poly"``.
            target_damping: linear damping of the target.
            target_radius: radius of the target.
            target_density: fixture density of the target.
            target_init_vel: scale of the target's random initial velocity.
            control_damping: linear damping of the control ball.
            control_radius: radius of the control ball.
            control_density: fixture density of the control ball.
            length: extent of the arena along the second world axis.
            width: extent of the arena along the first world axis.
            max_vertex: largest vertex count for randomly shaped polygons.
            force_scaling: multiplier applied to the action force in ``step``.
            render_size: size in pixels of the first axis of rendered frames.
            render_masks: when true, ``render`` returns one image per object.
        """
        # --- scene geometry and rendering scale ---
        self.length, self.width = length, width
        self.render_masks = render_masks
        # First axis of the rendered frame, in pixels.
        self.render_width = int(render_size)
        # Pixels per world unit, chosen so the arena width spans render_size.
        self.ppm = render_size / self.width
        # Second axis of the rendered frame, in pixels.
        self.render_length = int(self.ppm * self.length)

        # --- generic balls and polygons ---
        self.num_balls = num_balls
        self.num_polys = num_polys
        self.dyn_damping = dyn_damping
        self.base_density = base_density
        self.init_velocity = init_velocity
        self.general_radius_min = general_radius_min
        self.general_radius_max = general_radius_max
        self.max_vertex = max_vertex

        # --- target object ---
        self.use_target = use_target
        self.target_form = target_form
        self.target_damping = target_damping
        self.target_radius = target_radius
        self.target_density = target_density
        self.target_init_vel = target_init_vel

        # --- control ball ---
        self.control_damping = control_damping
        self.control_radius = control_radius
        self.control_density = control_density
        self.force_scaling = force_scaling

        # --- pybox2d world setup ---
        # Gravity is zero, so objects only move through their own velocity,
        # applied forces and collisions.
        self.world = world(gravity=(0, 0), doSleep=True)
        # A body holding a closed loop around the arena, centred on the
        # world origin; it keeps the objects inside.
        half_width, half_length = width / 2, length / 2
        boundary = b2LoopShape(vertices=[(-half_width, -half_length),
                                         (-half_width, half_length),
                                         (half_width, half_length),
                                         (half_width, -half_length)])
        self.ground_body = self.world.CreateBody(shapes=boundary)

        self.reset()


    def reset(self, object_state_dict=None, type_instance_dict=None, max_count_dict=None):
        if hasattr(self, "object_dict"):
            for body in self.object_dict.values():
                self.world.DestroyBody(body)
        self.balls = dict()
        self.polys = dict()
        self.control_ball_attrs = None
        self.use_target = False
        self.target_attrs = None
        # handles setting to a particular value
        if object_state_dict is not None:
            # init from a dictionary to a particular state
            for n in object_state_dict.keys():
                if n.find("Poly") != -1:
                    num_vertices = int(n[4:n.find("vert")])
                    poly_state = object_state_dict[n]
                    pos = poly_state[:2]
                    vel = poly_state[2:4]
                    angle = np.arcsin(poly_state[4])
                    angular_vel = poly_state[6]
                    radius = poly_state[7]
                    select_id = preassigned_nid[strip_instance(n)] # TODO: assumes there are single digit number of instances of any one object
                    _, poly_attrs = self.create_bouncing_polygon(0, num_vertices=num_vertices, pos=pos.astype(float), vel=vel.astype(float), density=self.base_density, angle=float(angle), angular_vel=float(angular_vel), linearDamping=self.dyn_damping, radius=radius, select_id = select_id, name=n)
                    self.polys[n] = poly_attrs
                elif n.find("Ball") != -1:
                    ball_state = object_state_dict[n]
                    pos = ball_state[:2]
                    vel = ball_state[2:4]
                    radius = ball_state[4]
                    _, ball_attrs = self.create_bouncing_ball(0, pos=pos, vel=vel, density=self.base_density, linearDamping=self.dyn_damping, radius=radius, name=n)
                    self.balls[n] = ball_attrs
                elif n.find("Control") != -1: # control
                    ball_state = object_state_dict[n]
                    pos = ball_state[:2]
                    vel = ball_state[2:4]
                    radius = ball_state[4]
                    self.control_ball_name, self.control_ball_attrs = self.create_bouncing_ball(-1,name="Control", density=self.control_density, radius=self.control_radius, linearDamping=self.control_damping, vel=vel, pos=pos, color=(127,0,0,255))
                elif n.find("Target") != -1: # target ball
                    self.use_target = True
                    if self.target_form == "Ball":
                        ball_state = object_state_dict[n]
                        pos = ball_state[:2]
                        vel = ball_state[2:4]
                        radius = ball_state[4]
                        self.target_name, self.target_attrs = self.create_bouncing_ball(-1,name="Target", density = self.target_density, radius=self.target_radius, vel=vel, pos=pos, linearDamping=self.target_damping, color=(0,127,0,255))
                    elif self.target_form.find("Poly") != -1:
                        num_vertices = int(self.target_form[4:self.target_form.find("vert")])
                        poly_state = object_state_dict[n]
                        pos = poly_state[:2]
                        vel = poly_state[2:4]
                        angle = np.arcsin(poly_state[4])
                        angular_vel = poly_state[6]
                        radius = poly_state[7]
                        select_id = preassigned_nid[strip_instance(self.target_form)] # TODO: assumes there are single digit number of instances of any one object
                        self.target_name, self.target_attrs = self.create_bouncing_polygon(-1, name="Target", density= self.target_density,radius=self.target_radius, num_vertices=num_vertices, pos=pos.astype(float), vel=vel.astype(float), angle=float(angle), angular_vel=float(angular_vel), linearDamping=self.target_damping, select_id = select_id,color=(0,127,0,255))
                        self.polys[n] = poly_attrs

            self.num_balls = len(list(self.balls.keys()))
        elif type_instance_dict is not None:
            # init from a dictionary of ids to a random state
            self.num_balls = len(type_instance_dict["Ball"])
            for i in range(len(type_instance_dict["Ball"])):
                name, ball_attrs = self.create_bouncing_ball(i, 
                                                             density=self.base_density, 
                                                             name="Ball" if max_count_dict["Ball"] <= 1 else None, 
                                                             linearDamping=self.dyn_damping,
                                                             init_vel = self.init_velocity,
                                                             radius= self.general_radius_min + (np.random.rand() * (self.general_radius_max - self.general_radius_min))
                                                             )
                self.balls[name] = ball_attrs

            select_id_counter = collections.Counter()
            for i in range(len(type_instance_dict["Poly"])):
                select_id_counter[type_instance_dict["Poly"][i]] += 1

            for i in range(len(type_instance_dict["Poly"])):
                radius =  self.general_radius_min + (np.random.rand() * (self.general_radius_max - self.general_radius_min))
                name, poly_attrs = self.create_bouncing_polygon(select_id_counter[type_instance_dict["Poly"][i]] if select_id_counter[type_instance_dict["Poly"][i]] > 1 else -1, 
                                                                radius=radius, num_vertices=np.random.randint(3,self.max_vertex+1), select_id=type_instance_dict["Poly"][i],
                                                                name = preassigned_names[type_instance_dict["Poly"][i]] if max_count_dict["Poly"] <= 1 else None, 
                                                                 density=self.base_density, linearDamping=self.dyn_damping, init_vel = self.init_velocity)
                self.polys[name] = poly_attrs

            if len(type_instance_dict["Control"]) > 0:
                pos = (0,0) # TODO: could use a random pos to make the domain harder
                self.control_ball_name, self.control_ball_attrs = self.create_bouncing_ball(-1,name="Control",  density=self.control_density,radius=self.control_radius, linearDamping=self.control_damping, vel=(0,0), pos=pos, color=(127,0,0,255))
            else: self.control_ball_name, self.control_ball_attrs = "Control", None
            # print("all reset", self.control_ball_attrs, type_instance_dict["Control"])
            if len(type_instance_dict["Target"]) > 0:
                self.use_target = True
                if self.target_form == "Ball":
                    self.target_name, self.target_attrs = self.create_bouncing_ball(-1,name="Target",  density=self.target_density, radius=self.target_radius, linearDamping=self.target_damping, color=(0,127,0,255), init_vel = self.target_init_vel)
                elif self.target_form.find("Poly") != -1:
                    self.target_name, self.target_attrs = self.create_bouncing_polygon(-1, name= "Target",  density=self.target_density, radius=self.target_radius,
                                                                    select_id = preassigned_nid[strip_instance(self.target_form)], linearDamping=self.target_damping,color=(0,127,0,255), init_vel = self.target_init_vel)

            else:
                self.use_target = False 
                self.target_name, self.target_attrs = "Target", None
        else:
            for i in range(self.num_balls):
                name, ball_attrs = self.create_bouncing_ball(i, linearDamping=self.dyn_damping, density=self.base_density, init_vel = self.init_velocity)
                self.balls[name] = ball_attrs

            for i in range(self.num_polys):
                name, poly_attrs = self.create_bouncing_polygon(i, num_vertices=np.random.randint(3,self.max_vertex+1), linearDamping=self.dyn_damping,  density=self.base_density, init_vel = self.init_velocity)
                self.polys[name] = poly_attrs

            pos = (0,0) # TODO: could use a random pos to make the domain harder
            self.control_ball_name, self.control_ball_attrs = self.create_bouncing_ball(-1,name="Control", density=self.control_density, radius=self.control_radius, linearDamping=self.control_damping, vel=(0,0), pos=pos, color=(127,0,0,255))
            if self.use_target:
                if self.target_form == "Ball":
                    self.target_name, self.target_attrs = self.create_bouncing_ball(-1,name="Target", density=self.target_density, radius=self.target_radius, linearDamping=self.target_damping, color=(0,127,0,255), target_init_vel=self.target_init_vel)
                elif self.target_form.find("Poly") != -1:
                    self.target_name, self.target_attrs = self.create_bouncing_polygon(-1, name= "Target", density=self.target_density, radius=self.target_radius,
                                                                    select_id = preassigned_nid[strip_instance(self.target_form)], linearDamping=self.target_damping,color=(0,127,0,255), target_init_vel=self.target_init_vel)
            else: self.target_name, self.target_attrs = "Target", None
        # names and object dict
        self.ball_names = list(self.balls.keys())
        self.ball_names.sort()
        self.poly_names = list(self.polys.keys())
        self.poly_names.sort()
        self.object_dict = {**{name: self.balls[name][0] for name in self.balls.keys()},
                             **({self.control_ball_name: self.control_ball_attrs[0]} if self.control_ball_attrs is not None else dict()),
                             **({self.target_name: self.target_attrs[0]} if self.target_attrs is not None else dict()),
                             **{name: self.polys[name][0] for name in self.polys.keys()}}

        self.frame = np.zeros((int(self.render_width),int(self.render_length), 3))

    def create_bouncing_ball(self, i, name=None, color=(127, 127, 127), radius=-1, vel=None, pos=None,linearDamping=0, density=1.0, init_vel=1.0):
        """Add a circular dynamic body to the world.

        Args:
            i: index appended to "Ball" to form the default name.
            name: explicit name; replaces the default name when given.
            color: colour stored alongside the body for rendering.
            radius: circle radius; a negative value draws a random radius of
                at least 1.
            vel: initial linear velocity; drawn at random when None.
            pos: initial position; drawn at random inside the arena when None.
            linearDamping: linear damping of the body.
            density: fixture density.
            init_vel: scale of the random velocity (unused when ``vel`` is given).

        Returns:
            ``(name, (body, color))``.
        """
        if pos is None:
            x = (np.random.rand() - 0.5) * 2 * (self.width / 2)
            y = (np.random.rand() - 0.5) * 2 * (self.length / 2)
            pos = (x, y)
        if vel is None:
            vx = (np.random.rand() - 0.5) * 2 * (self.width) * init_vel
            vy = (np.random.rand() - 0.5) * 2 * (self.length) * init_vel
            vel = (vx, vy)
        if radius < 0:
            radius = max(1, np.random.rand() * (self.width/ 5))

        # restitution=1.0 is set on the circle fixture; the body is flagged
        # as a bullet.
        circle_fixture = b2FixtureDef(
            shape=b2CircleShape(radius=radius),
            density=density,
            restitution = 1.0,)
        ball = self.world.CreateDynamicBody(
            fixtures=circle_fixture,
            bullet=True,
            position=pos,
            linearVelocity=vel,
            linearDamping=linearDamping
            )

        default_name = "Ball" + str(i)
        label = default_name if name is None else name
        return label, (ball, color)

    def create_bouncing_polygon(self, i, name=None, color=(127, 127, 127), radius=-1, vel=None, pos=None, angle=0, angular_vel=0, fixed_rotation=False, linearDamping=0, density=1.0, num_vertices=3, select_id=-1, init_vel=1.0):
        """Add a polygonal dynamic body to the world.

        Args:
            i: index appended to the base name to form the default name.
            name: explicit name; replaces the default name when given.
            color: colour stored alongside the body for rendering.
            radius: scale of the shape; a negative value means 1.
            vel: initial linear velocity; drawn at random when None.
            pos: initial position; drawn at random inside the arena when None.
            angle: initial body angle.
            angular_vel: initial angular velocity.
            fixed_rotation: passed to the body as ``fixedRotation``.
            linearDamping: linear damping of the body.
            density: fixture density.
            num_vertices: vertex count of a randomly generated shape; ignored
                when ``select_id`` selects a preassigned shape.
            select_id: index into ``preassigned_objects`` / ``preassigned_names``;
                a negative value requests a random shape.
            init_vel: scale of the random velocity (unused when ``vel`` is given).

        Returns:
            ``(name, (body, color))``. The default name is the preassigned
            name, or ``"Poly<num_vertices>vert"`` for a random shape,
            followed by ``str(i)``.
        """
        if pos is None: pos = ((np.random.rand() - 0.5) * 2 * (self.width / 2), (np.random.rand() - 0.5) * 2 * (self.length / 2))
        if vel is None: vel = ((np.random.rand() - 0.5) * 2 * (self.width) * init_vel,(np.random.rand() - 0.5) * 2 * (self.length) * init_vel)
        if radius < 0: radius = 1

        if select_id >= 0:
            # Preassigned outline, scaled by the radius.
            vertices = (np.array(preassigned_objects[select_id]) * radius).tolist()
            poly_name = preassigned_names[select_id]
        else:
            # One vertex per equal angular sector of the circle, at a random
            # angle inside its sector, so the vertices come out in angular
            # order.
            sector = np.pi * 2 / num_vertices
            angles = [np.random.rand() * np.pi * 2 / num_vertices + sector * k for k in range(num_vertices)]
            vertices = [(np.cos(theta) * radius, np.sin(theta) * radius) for theta in angles]
            poly_name = "Poly" + str(num_vertices) + "vert"

        fixture = b2FixtureDef(
            shape=b2PolygonShape(vertices=vertices),
            density=density,
            restitution=0.1,
        )

        body = self.world.CreateBody(type=b2_dynamicBody,
                                    position=pos,
                                    linearVelocity=vel,
                                    angularVelocity=angular_vel,
                                    angle=angle,
                                    fixtures=fixture,
                                    fixedRotation=fixed_rotation,
                                    linearDamping=linearDamping,
                                    )
        poly_name = poly_name + str(i)
        return poly_name if name is None else name, (body, color)


    def draw_circle(self, body_attrs):
        """Paint every fixture of a circular body onto ``self.frame``.

        Args:
            body_attrs: ``(body, color)`` pair as returned by the ball factory.
        """
        body, color = body_attrs
        # Moves the world origin (arena centre) to the arena corner.
        corner_offset = np.array((self.width / 2, self.length/2))
        for fixture in body.fixtures:
            # Each fixture is assumed to hold a circle shape.
            circle = fixture.shape
            shifted = np.array(body.position) + corner_offset
            # Swap the two coordinates, then scale world units to pixels.
            pixel_center = np.array((shifted[1], shifted[0])) * self.ppm
            pixel_radius = int(circle.radius * self.ppm)
            # Thickness -1 draws a filled circle.
            cv2.circle(self.frame, pixel_center.astype(int), pixel_radius, color, -1)


    def draw_polygon(self, body_attrs):
        """Paint every fixture of a polygonal body onto ``self.frame``.

        Args:
            body_attrs: ``(body, color)`` pair as returned by the polygon factory.
        """
        body, color = body_attrs
        # Moves the world origin (arena centre) to the arena corner.
        corner_offset = np.array((self.width / 2, self.length/2))
        for fixture in body.fixtures:
            # Each fixture is assumed to hold a polygon shape.
            outline = fixture.shape
            # Matrix whose columns are the body's x and y axes.
            rotation = np.stack([body.transform.R.x_axis, body.transform.R.y_axis], axis = 1)
            pixel_points = []
            for local_vertex in outline.vertices:
                # Rotate the local vertex, then translate it to the body position.
                world_vertex = body.position + np.matmul(rotation, local_vertex)
                shifted = np.array(world_vertex) + corner_offset
                # Swap the two coordinates, then scale world units to pixels.
                pixel_points.append(np.array((shifted[1], shifted[0])) * self.ppm)
            polygon = np.array(pixel_points).astype(int)

            cv2.fillPoly(self.frame,pts=[polygon],color=color)

    def render(self, show=False):
        if self.render_masks:
            frames = {}
            for ball_name, ball_attrs in self.balls.items():
                self.frame = np.zeros((self.render_width, self.render_length, 3)).astype(np.uint8)
                self.draw_circle(ball_attrs)
                frames[ball_name] = self.frame
            for poly_name, poly_attrs in self.polys.items():
                self.frame = np.zeros((self.render_width, self.render_length, 3)).astype(np.uint8)
                self.draw_polygon(poly_attrs)
                frames[poly_name] = self.frame
            if self.control_ball_attrs is not None:
                self.frame = np.zeros((self.render_width, self.render_length, 3)).astype(np.uint8)
                self.draw_circle(self.control_ball_attrs)
                frames[self.control_ball_name] = self.frame
            if self.target_attrs is not None:
                self.frame = np.zeros((self.render_width, self.render_length, 3)).astype(np.uint8)
                if self.target_form == "Ball":
                    self.draw_circle(self.target_attrs)
                else:
                    self.draw_polygon(self.target_attrs)
                frames[self.target_name] = self.frame
            all_img = np.zeros((self.render_width, self.render_length, 3)).astype(np.uint8)
            for x in frames.values():
                all_img = all_img + x
            frames['obs'] = all_img
            return frames


        self.frame = np.zeros((self.render_width, self.render_length, 3)).astype(np.uint8)
        for ball_attrs in self.balls.values():
            self.draw_circle(ball_attrs)
        for poly_attrs in self.polys.values():
            self.draw_polygon(poly_attrs)
        if self.control_ball_attrs is not None: self.draw_circle(self.control_ball_attrs)
        if self.target_attrs is not None:
            if self.target_form == "Ball":
                self.draw_circle(self.target_attrs)
            else:
                self.draw_polygon(self.target_attrs)
        if show: 
            cv2.imshow("bouncing", self.frame)
            cv2.waitKey(10)
        return self.frame

    def step(self, action, time_step=0.018, iters=10):
        """Push the control ball with ``action`` and advance the simulation once.

        Args:
            action: indexable whose first two entries are the force direction
                components.
            time_step: currently unused; the world is always stepped by the
                module-level ``TIME_STEP``.
            iters: currently unused; both iteration counts passed to the
                world step are fixed at 10.
        """
        controller = self.control_ball_attrs
        if controller is not None:
            control_body = controller[0]
            # Scaling by the mass makes the action act as an acceleration
            # multiplied by force_scaling.
            push = self.force_scaling * control_body.mass * np.array((action[0], action[1])).astype(float)
            control_body.ApplyForceToCenter(push, True)
        self.world.Step(TIME_STEP, 10, 10)

    def get_contacts(self):
        contacts = list()
        shape_pointers = ([self.control_ball_attrs[0]] if self.control_ball_attrs is not None else list()) + \
                         [self.balls[bn][0] for bn in self.ball_names] + [self.polys[pn][0] for pn in self.poly_names] + \
                         ([self.target_attrs[0]] if self.target_attrs is not None else list())
        names = ([self.control_ball_name] if self.control_ball_attrs is not None else list()) + self.ball_names + self.poly_names + ([self.target_name] if self.target_attrs is not None else list())
        contact_names = {n: list() for n in names}
        for bn in names:
            all_contacts = np.zeros(len(shape_pointers)).astype(bool)
            for contact in self.object_dict[bn].contacts:
                if contact.contact.touching:
                    contact_bool = np.array([(contact.other == bp and contact.contact.touching) for bp in shape_pointers])
                    contact_names[bn] += [sn for sn, bp in zip(names, shape_pointers) if (contact.other == bp)]
                else:
                    contact_bool = np.zeros(len(shape_pointers)).astype(bool)
                all_contacts += contact_bool
            contacts.append(all_contacts)
        return np.stack(contacts, axis=0), contact_names


if __name__ == "__main__":
    # --- main demo loop ---
    # The constructor takes many more arguments than the six positional
    # values this script used to pass (1, 1, 10, 10, 5, 100).
    # NOTE(review): those six values are presumed to have meant num_balls,
    # num_polys, length, width, max_vertex and render_size; every other
    # value below is an arbitrary demo default -- confirm.
    bouncingballs = BouncingShapes(
        num_balls=1,
        num_polys=1,
        dyn_damping=0.0,
        base_density=1.0,
        init_velocity=1.0,
        general_radius_min=0.5,
        general_radius_max=1.5,
        use_target=False,
        target_form="Ball",
        target_damping=0.0,
        target_radius=1.0,
        target_density=1.0,
        target_init_vel=1.0,
        control_damping=0.0,
        control_radius=1.0,
        control_density=1.0,
        length=10,
        width=10,
        max_vertex=5,
        force_scaling=10.0,
        render_size=100,
    )
    start = time.time()
    # The action is "sticky": it is held for several steps and only
    # occasionally replaced by a new random one.
    sticky = (np.random.rand(2) - 0.5) * 2
    for i in range(1000000):
        # Report throughput once per 1000 completed steps; skipping i == 0
        # avoids a meaningless (possibly zero-length) first interval.
        if i > 0 and i % 1000 == 0:
            elapsed = time.time() - start
            if elapsed > 0:
                print("fps", 1000 / elapsed)
            start = time.time()

        # Make Box2D simulate the physics of our world for one step.
        # It is generally best to keep the time step and iterations fixed.
        sticky = (np.random.rand(2) - 0.5) * 2 if np.random.randint(100) > 90 else sticky
        bouncingballs.step(sticky)

        # Draw the world
        bouncingballs.render(show=True)
        if i % 300 == 0:
            bouncingballs.reset()
