import gymnasium as gym
import numpy as np
import Box2D
from Box2D.b2 import (edgeShape, circleShape, fixtureDef, polygonShape, revoluteJointDef, contactListener, distance)
import pygame
from PIL import Image
from stable_baselines3.common.buffers import ReplayBuffer


class ObjectManager(object):
    """Track Box2D bodies so they can be destroyed together.

    Bodies registered with ``append`` are destroyed through
    ``world.DestroyBody`` when ``clear`` is called or when the manager itself
    is garbage-collected (``__del__``).
    """

    def __init__(self, world):
        self.objects = []
        self.world = world

    def clear(self):
        """Destroy every tracked body, in insertion order, and forget them.

        The tracking list is detached before destroying, so a later ``clear``
        (for instance the one triggered by ``__del__`` after an explicit call)
        never destroys the same body twice.
        """
        objects, self.objects = self.objects, []
        for body in objects:
            self.world.DestroyBody(body)

    def __del__(self):
        self.clear()

    def append(self, obj):
        """Register ``obj`` (a body owned by ``self.world``) for destruction."""
        self.objects.append(obj)


class ContactDetector(contactListener):
    """Box2D contact listener that records collisions of the agent body.

    On every new contact involving ``env.you`` it sets ``env.contact``; if the
    other body is ``env.goal`` it additionally sets ``env.achieve_goal``. The
    listener only records flags — the environment decides what they mean.
    """

    def __init__(self, env):
        contactListener.__init__(self)
        self.env = env

    def BeginContact(self, contact):
        body_a = contact.fixtureA.body
        body_b = contact.fixtureB.body
        agent = self.env.you

        # Ignore contacts that do not involve the agent at all.
        if not (agent == body_a or agent == body_b):
            return

        # Any collision of the agent is flagged (ending the episode on it is
        # deliberately not done here).
        self.env.contact = True

        goal = self.env.goal
        if goal == body_a or goal == body_b:
            # Touching the goal body counts as success.
            self.env.achieve_goal = True


class PointMassBox2d(gym.Env):
    """Point-mass navigation environment simulated with Box2D.

    The agent is a dynamic circle (``self.you``) whose linear velocity is set
    from the action on every step; the target is a static circle
    (``self.goal``). A circular risk zone (centre ``self.centers``, radius
    ``self.r``) may add Gaussian reward noise while the agent is inside it.
    Start and goal positions are expressed in Box2D world units within
    ``[0, scale]``; rendering multiplies them by ``screen_scale_rate``.
    """

    radius = 1  # overridden per instance in __init__
    screen_width: int = 600  # render window width, pixels
    screen_height: int = 600  # render window height, pixels
    scale = 100  # world extent in Box2D units (see reset / build_walls)

    @property
    def screen_scale_rate(self):
        """Pixels per Box2D world unit."""
        return self.screen_width / self.scale

    def __init__(self, N=1, risk_prob=0.7, risk_var=50,
                 seed=1, eval_env=True):
        """Create the Box2D world and configure spaces, risk zone and rendering.

        Args:
            N: number of obstacles (stored; not used elsewhere in this class).
            risk_prob: a uniform draw must exceed this value for risk noise
                to be added, i.e. noise occurs with probability
                ``1 - risk_prob`` per step inside the risk zone.
            risk_var: multiplier of the standard-normal risk noise (its
                standard deviation, despite the name).
            seed: seed for ``self.np_rng``.
            eval_env: if True ``reset`` starts near ``(0.8, 0.8) * scale``;
                otherwise the start is sampled uniformly.
        """

        self.world = Box2D.b2World(gravity=(0., 0.,))

        # Step 1: Car parameters
        self.v_max = 0.1
        self.v_sigma = 0.01
        # Step 3: Environment parameters
        self.d_safe = 0.1
        self.d_goal = 0.05
        self.d_sampling = 0.1
        self.init_pos = np.array([1.0, 1.0])
        self.risk_prob = risk_prob
        self.risk_var = risk_var
        self.N = N  # number of obstacles

        self.low_state = 0
        self.high_state = 1

        self.min_actions = np.array(
            [-self.v_max, -self.v_max], dtype=np.float32
        )
        self.max_actions = np.array(
            [self.v_max, self.v_max], dtype=np.float32
        )
        self.action_space = gym.spaces.Box(
            low=-1,
            high=1,
            shape=(2,),
            dtype=np.float32,
        )
        self.observation_space = gym.spaces.Box(
            low=self.low_state,
            high=self.high_state,
            shape=(2 + 2,),
            dtype=np.float32
        )

        self._goal_pos = np.array([0.1, 0.1]) * self.scale
        self.r = 0.3 * self.scale  # risk-zone radius, world units
        self.centers = np.array([0.5, 0.5]) * self.scale  # risk-zone centre

        # Rendering parameters
        self.screen_size = [self.scale, self.scale]
        self.screen_scale = 300
        self.background_color = [222, 222, 222]
        self.wall_color = [0, 0, 0]
        self.circle_color = [227, 122, 84]
        self.safe_circle_color = [200, 0, 0]
        self.lidar_color = [0, 0, 255]
        self.goal_color = [0, 255, 0]
        self.robot_color = [125, 5, 125]
        self.safety_color = [255, 0, 0]
        self.goal_size = 0.03 * self.scale
        self.radius = 0.015 * self.scale
        self.width = 3
        self.pygame_init = False
        self.np_rng = np.random.default_rng(seed)
        self.eval = eval_env
        self.you: Box2D.b2Body | None = None
        self.goal = None
        self.risk_zone = None
        self.map = None
        # Flags written by ContactDetector during world.Step.
        self.achieve_goal = False
        self.contact = False
        self.walls = []

    def _create_walls(self, position, size):
        """Create a static box body at ``position`` with half-extents ``size``.

        The body is returned but NOT registered in ``self.walls``.
        """
        body_def = Box2D.b2BodyDef()
        body_def.position = position
        body_def.type = Box2D.b2_staticBody
        body = self.world.CreateBody(body_def)
        box = Box2D.b2PolygonShape(box=size)
        body.CreateFixture(shape=box, density=0, friction=0.3)
        return body

    def _clean_walls(self):
        """Destroy and forget every body in ``self.walls``."""
        while self.walls:
            self.world.DestroyBody(self.walls.pop())

    @classmethod
    def normalize(cls, xy):
        """Map pixel coordinates to ``[0, 1]`` by the screen size."""
        return xy / np.asarray([cls.screen_width, cls.screen_height])

    @classmethod
    def denormalize(cls, xy):
        """Inverse of ``normalize``."""
        return xy * np.asarray([cls.screen_width, cls.screen_height])

    @classmethod
    def denormalize_scalar(cls, x):
        """Scale a normalized length by the screen diagonal."""
        return x * np.linalg.norm(np.asarray([cls.screen_width, cls.screen_height]))

    @classmethod
    def normalize_scalar(cls, x):
        """Divide a pixel length by the screen diagonal."""
        return x / np.linalg.norm(np.asarray([cls.screen_width, cls.screen_height]))

    def seed(self, seed=None):
        """Re-create the environment RNG from ``seed``."""
        self.np_rng = np.random.default_rng(seed)
        return [seed]

    def __del__(self):
        self._destroy()

    def _destroy(self):
        """Destroy the bodies this env created and detach the contact listener.

        Safe to call before the first ``reset``, more than once, or on a
        partially-initialised instance: missing bodies are skipped and each
        reference is cleared after destruction so nothing is destroyed twice.
        """
        world = getattr(self, "world", None)
        if world is None:
            return
        for attr in ("you", "goal", "map"):
            body = getattr(self, attr, None)
            if body is not None:
                world.DestroyBody(body)
                setattr(self, attr, None)
        if getattr(self, "walls", None):
            self._clean_walls()
        world.contactListener = None

    def build(self, initial_position):
        """Create the agent, the map edge and the goal, and install the listener.

        Args:
            initial_position: agent start position in world units.
        """
        # self.build_walls()
        # Keep a Python reference to the listener so it is not collected
        # while Box2D still points at it.
        self.world.contactListener_keepref = ContactDetector(self)
        self.world.contactListener = self.world.contactListener_keepref
        self.you = self.world.CreateDynamicBody(
            position=initial_position,
            angle=0.,
            fixtures=fixtureDef(shape=circleShape(radius=self.radius),
                                density=5.,
                                categoryBits=0x0010,  # [0x00] [10000]
                                maskBits=0x003,  # [0x00] [11]
                                restitution=0.0),
        )
        # NOTE(review): the edge spans x in [0, screen_width] (pixels) although
        # positions are in world units [0, scale] — confirm this is intended.
        self.map = self.world.CreateStaticBody(
            shapes=edgeShape(vertices=[(0, 0), (self.screen_width, 0)])
        )

        self.you.color1 = (0.5, 0.4, 0.9)
        self.you.color2 = (0.3, 0.3, 0.5)

        self.goal: Box2D.b2Body = self.world.CreateStaticBody(
            position=self._goal_pos,
            fixtures=fixtureDef(shape=circleShape(radius=self.radius),
                                density=10.,
                                friction=0,
                                categoryBits=0x002,
                                maskBits=0x0010, restitution=0.0))

        self.goal.color1 = (0., 0.5, 0)
        self.goal.color2 = (0., 0.5, 0)

    def build_walls(self):
        """Create four static border walls around ``[0, scale]^2``.

        Each wall is appended to ``self.walls``.
        """
        heights = width = self.scale

        world = self.world

        wall_thickness = 3
        # Box2D boxes are specified by half-extents.
        horizontal_wall_size = (width / 2, wall_thickness / 2)
        vertical_wall_size = (wall_thickness / 2, heights / 2)

        def create_wall(world, position, size):
            body = world.CreateStaticBody(
                position=position,
                fixtures=fixtureDef(
                    shape=polygonShape(box=size),
                    density=100, friction=0., categoryBits=0x001, restitution=1.0, )
            )
            self.walls.append(body)
            return body

        # Create walls at the borders of the screen
        create_wall(world, (width / 2, (heights - wall_thickness / 2)),
                    horizontal_wall_size)
        create_wall(world, (width / 2, wall_thickness / 2), horizontal_wall_size)
        create_wall(world, (wall_thickness / 2, heights / 2), vertical_wall_size)
        create_wall(world, ((width - wall_thickness / 2), heights / 2),
                    vertical_wall_size)

    def reset(self, *, seed=None, **kwargs):
        """Sample a start outside the risk zone and rebuild the bodies.

        Args:
            seed: optional new seed for ``self.np_rng``.
            **kwargs: accepted for Gymnasium compatibility (e.g. ``options``)
                and ignored.

        Returns:
            ``(observation, {})``.
        """
        if seed is not None:
            self.seed(seed)
        # Rejection-sample until the start lies outside the risk zone.
        while True:
            if self.eval:
                self.init_pos = self.scale * (np.array([0.8, 0.8]) + self.np_rng.uniform(-0.05, 0, size=(2,)))
            else:
                self.init_pos = self.scale * self.np_rng.uniform(0.11, 0.9, size=(2,))
            if self.is_safe(self.init_pos):
                break
        if self.you is not None:
            self._destroy()

        self.build(self.init_pos)
        self.achieve_goal = False
        self.contact = False
        return np.array(self.state), {}

    @property
    def state(self):
        """Observation: ``[x, y, goal_x - x, goal_y - y]`` shifted and scaled.

        Each component is computed as ``(value - scale) / scale``.
        NOTE(review): with positions in ``[0, scale]`` this yields values in
        about ``[-1, 0]`` (and ``[-2, 0]`` for the offsets), outside the
        declared ``observation_space`` bounds ``[0, 1]`` — confirm intent.
        """
        pos = self.you.position
        goal_pos = self.goal.position

        return (np.asarray([pos[0], pos[1],
                            goal_pos[0] - pos[0], goal_pos[1] - pos[1]]) - self.scale) / self.scale

    def get_dist_to_goal(self):
        """Euclidean distance between the agent and goal centres (world units)."""
        delta = np.array(self.you.position) - np.array(self.goal.position)
        return np.linalg.norm(delta)

    # Check if the state is safe.
    def is_safe(self, pos):
        """Return True when ``pos`` lies strictly outside the risk-zone disc."""
        sq_dist = ((pos - self.centers) ** 2).sum()
        # Written as "not <=" so a NaN distance behaves exactly as before.
        return not (sq_dist <= (self.r ** 2))

    def step(self, action):
        """Advance the simulation by one Box2D step.

        Args:
            action: 2-vector (not validated against ``action_space``); it is
                multiplied by the agent's mass and applied as linear velocity.

        Returns:
            ``(obs, reward, terminated, truncated, info)``. ``terminated`` is
            True once the agent has touched the goal; ``truncated`` is always
            False. ``info['deterministic_reward']`` is the progress toward the
            goal (+100 on success) without risk noise or contact penalty.
        """
        # np.asarray lets plain sequences (lists/tuples) be used as actions.
        action = np.asarray(action) * self.you.mass
        d_goal_prev = self.get_dist_to_goal()

        # Assign through the property (b2Body::SetLinearVelocity) instead of
        # mutating the vector returned by the getter: the setter is the
        # documented way to change velocity and also wakes a sleeping body.
        self.you.linearVelocity = (float(action[0]), float(action[1]))
        self.world.Step(1 / 6, 6 * 30, 2 * 30)
        d_goal = self.get_dist_to_goal()

        # Dense shaping: positive when the agent moved closer to the goal.
        reward_without_risk = (d_goal_prev - d_goal)

        reward = 0
        data_info = {}
        in_the_risk_zone = False
        if not self.is_safe(np.asarray(self.you.position)):
            u = self.np_rng.uniform(0, 1)
            if u > self.risk_prob:
                reward += self.risk_var * self.np_rng.normal(0, 1)
            in_the_risk_zone = True
        done = self.achieve_goal
        out_of_world = False
        if self.contact and not self.achieve_goal:
            reward -= 10  # contact penalty
            self.contact = False
            out_of_world = True
        if self.achieve_goal:
            reward_without_risk += 100
            data_info['is_success'] = True
        data_info.update({'deterministic_reward': reward_without_risk,
                          "in_the_risk_zone": in_the_risk_zone,
                          "out_of_world": out_of_world
                          })

        reward += reward_without_risk
        return np.array(self.state), reward, done, False, data_info

    def render(self):
        """Draw the risk zone, goal and walls with pygame and return the frame.

        Opens the pygame window on first use. Closing the window quits pygame
        and raises ``SystemExit``.

        Returns:
            The screen contents as raw RGB bytes (``pygame.image.tobytes``).
        """
        if not self.pygame_init:
            pygame.init()
            self.pygame_init = True
            self.screen = pygame.display.set_mode([self.screen_width, self.screen_height])
            self.clock = pygame.time.Clock()

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                # Same effect as the site-module exit(), without depending on
                # the ``site`` module having installed that builtin.
                raise SystemExit
        self.screen.fill(self.background_color)
        rate = self.screen_scale_rate

        # Risk zone: a red disc overdrawn by a slightly smaller coloured disc,
        # leaving a red rim.
        c, r = (self.centers[:2] * rate).astype(int), int(self.r)
        pygame.draw.circle(self.screen, (255, 0, 0), c, r * rate)
        pygame.draw.circle(self.screen, self.circle_color, c, (r - 1) * rate)

        # NOTE: the agent body itself is not drawn.

        goal_px = (np.asarray(self.goal.position) * rate).astype(int)
        pygame.draw.circle(self.screen, (34, 139, 34), goal_px,
                           self.goal_size * rate)
        pygame.draw.circle(self.screen, self.goal_color, goal_px,
                           (self.goal_size - 0.5) * rate)

        for body in self.walls:
            for fixture in body.fixtures:
                shape = fixture.shape
                vertices = [(body.transform * v * rate) for v in shape.vertices]
                pygame.draw.polygon(self.screen, [255, 0, 255], vertices)

        pygame.display.flip()
        self.clock.tick(20)
        return pygame.image.tobytes(self.screen, 'RGB')

    def base_img(self):
        """Render the current frame and return it as a PIL RGB image."""
        frame = np.frombuffer(self.render(), dtype=np.uint8)
        # tobytes() emits rows of screen_width pixels, so the array is
        # (height, width, 3).
        return Image.fromarray(frame.reshape(self.screen_height, self.screen_width, 3))

