from Box2D import b2PolygonShape, b2_pi
import math
from dataclasses import dataclass
import numpy as np


@dataclass
class Ball:
    """Construction parameters for a circular body, consumed by ``create_ball``."""

    # Body position (world coordinates passed straight to the Box2D body).
    x: float
    y: float
    # Radius of the single circle fixture attached to the body.
    radius: float
    # Not read by any builder in this file; presumably used by a renderer.
    color: str = "black"
    # True -> dynamic body, False -> static body.
    dynamic: bool = True


@dataclass
class Basket:
    """Construction parameters for a basket body, consumed by ``create_basket``."""

    # Body position (world coordinates passed straight to the Box2D body).
    x: float
    y: float
    # Multiplier on the base basket dimensions (1.083 wide, 1.67 high,
    # 0.1 wall thickness) used by ``create_basket``.
    scale: float
    # Handed to the Box2D body without unit conversion, so presumably
    # radians (unlike ``Platform.angle``) -- TODO confirm with callers.
    angle: float
    # Not read by any builder in this file; presumably used by a renderer.
    color: str = "gray"
    # True -> dynamic body, False -> static body.
    dynamic: bool = False


@dataclass
class Platform:
    """Construction parameters for a bar-shaped body, consumed by ``create_platform``."""

    # Body position (world coordinates passed straight to the Box2D body).
    x: float
    y: float
    # Passed as the first ``box`` value of the polygon fixture; Box2D
    # treats ``box`` values as half-extents, so the bar presumably spans
    # 2 * length -- TODO confirm this is intended.
    length: float
    # In degrees: ``create_platform`` multiplies by pi / 180 before use.
    angle: float
    # Not read by any builder in this file; presumably used by a renderer.
    color: str = "black"
    # True -> dynamic body, False -> static body.
    dynamic: bool = False


def _add_basket_wall(body, half_width, half_height, center, tilt):
    """Attach one rectangular wall to ``body`` at a body-local centre and tilt."""
    fixture = body.CreatePolygonFixture(
        box=(half_width, half_height),
        density=1,
        friction=0.5,
        restitution=0.5,
    )
    # NOTE(review): the shape is repositioned after the fixture exists, as
    # the original code did; confirm Box2D picks up the moved geometry.
    fixture.shape.SetAsBox(half_width, half_height, center, tilt)
    return fixture


# Function to create the basket
def create_basket(world, basket_args, name):
    """Create a basket body: a floor plus two side walls leaning 5 degrees.

    Args:
        world: Box2D world the body is created in.
        basket_args: object exposing ``x``, ``y``, ``scale``, ``angle`` and
            ``dynamic`` (see ``Basket``). ``angle`` is handed to the body
            without unit conversion.
        name: value stored in the body's ``userData``.

    Returns:
        The created body, carrying three polygon fixtures in the order
        floor, left wall, right wall.
    """
    # Unpack basket arguments
    x = basket_args.x
    y = basket_args.y
    scale = basket_args.scale
    dynamic = basket_args.dynamic
    angle = float(basket_args.angle)  # pretty much always 0

    # Adjust dimensions based on scale
    width = 1.083 * scale
    height = 1.67 * scale
    theta = 5 * b2_pi / 180
    thickness = 0.1 * scale
    angle_shift = math.cos(theta) * thickness

    # Create the basket body; the body itself carries the basket rotation.
    create_body = world.CreateDynamicBody if dynamic else world.CreateStaticBody
    basket_body = create_body(
        position=(x, y),
        angle=angle,
        bullet=True,
    )

    # Fixture geometry is expressed in body-local coordinates, so only the
    # fixed wall tilt belongs here. Adding the body angle again (as was done
    # before) rotated every wall a second time about its own centre whenever
    # the angle was non-zero, and skewed the floor vertices that get_state
    # reads back to recover the scale.
    side_x = width / 2 - thickness / 2 + angle_shift
    side_y = height / 2 + thickness / 2
    walls = (
        # floor
        (width / 2, thickness / 2, (0, thickness / 2), 0.0),
        # left wall
        (thickness / 2, height / 2, (-side_x, side_y), theta),
        # right wall
        (thickness / 2, height / 2, (side_x, side_y), -theta),
    )
    for half_width, half_height, center, tilt in walls:
        _add_basket_wall(basket_body, half_width, half_height, center, tilt)

    basket_body.userData = name
    return basket_body


# Create walls centered around the origin
def create_walls(world, wall_thickness, room_width, room_height):
    """Create four static boundary bodies around a room centred on the origin.

    Each wall's centre is inset by half the wall thickness from the room
    edge. The ``box`` values are passed to ``b2PolygonShape`` unchanged.

    Returns:
        Tuple ``(left_wall, right_wall, top_wall, bottom_wall)``; each body's
        ``userData`` is set to the matching name string.
    """
    half_room_w = room_width / 2
    half_room_h = room_height / 2
    inset = wall_thickness / 2

    upright = (wall_thickness, room_height)
    flat = (room_width, wall_thickness)

    # (userData label, body position, box argument), in creation order.
    layout = (
        ("left_wall", (-half_room_w + inset, 0), upright),
        ("right_wall", (half_room_w - inset, 0), upright),
        ("top_wall", (0, half_room_h - inset), flat),
        ("bottom_wall", (0, -half_room_h + inset), flat),
    )

    walls = []
    for label, position, box in layout:
        wall = world.CreateStaticBody(
            position=position,
            shapes=b2PolygonShape(box=box),
        )
        wall.userData = label
        walls.append(wall)
    return tuple(walls)


def create_platform(world, platform_args, name):
    """Create a bar-shaped body with a single polygon fixture.

    Args:
        world: Box2D world the body is created in.
        platform_args: object exposing ``x``, ``y``, ``length``, ``angle``
            (degrees) and ``dynamic`` (see ``Platform``).
        name: value stored in the body's ``userData``.

    Returns:
        The created body.
    """
    # Read every argument up front so nothing is created on a bad input.
    center = (platform_args.x, platform_args.y)
    box = (platform_args.length, 0.1)  # second extent is fixed
    radians = platform_args.angle * b2_pi / 180
    is_dynamic = platform_args.dynamic

    make_body = world.CreateDynamicBody if is_dynamic else world.CreateStaticBody
    body = make_body(position=center, angle=radians, bullet=True)
    body.CreatePolygonFixture(
        box=box,
        density=1,
        friction=0.5,
        restitution=0.5,
    )

    body.userData = name
    return body


def create_ball(world, ball_args, name):
    """Create a circular body with a single circle fixture.

    Args:
        world: Box2D world the body is created in.
        ball_args: object exposing ``x``, ``y``, ``radius`` and ``dynamic``
            (see ``Ball``).
        name: value stored in the body's ``userData``.

    Returns:
        The created body.
    """
    # Read every argument up front so nothing is created on a bad input.
    center = (ball_args.x, ball_args.y)
    ball_radius = ball_args.radius
    is_dynamic = ball_args.dynamic

    make_body = world.CreateDynamicBody if is_dynamic else world.CreateStaticBody
    body = make_body(position=center, angle=0, bullet=True)
    body.CreateCircleFixture(
        radius=ball_radius,
        density=1,
        friction=0.5,
        restitution=0.5,
    )

    body.userData = name
    return body

def attrs_from_state(factor_state, name):
    """Rebuild constructor arguments from a state vector laid out by ``get_state``.

    Expected layouts (indices into ``factor_state``):
        Ball / Target:     x, y, vx, vy, angular velocity, radius, dynamic
        Basket / Platform: x, y, vx, vy, sin(angle), cos(angle),
                           angular velocity, size, dynamic

    Args:
        factor_state: indexable state vector for one object.
        name: object name; trailing/leading digits are stripped to get the
            type ("Ball", "Target", "Basket" or "Platform").

    Returns:
        A ``Ball``, ``Basket`` or ``Platform`` instance.

    Raises:
        ValueError: if the stripped name is not a known object type.
    """
    type_name = name.strip("0123456789")

    if type_name in ("Ball", "Target"):
        # TODO: colors do not match after setting because color is not a parameter
        color = "green" if type_name == "Target" else "red"
        return Ball(
            factor_state[0],
            factor_state[1],
            factor_state[5],
            color,
            bool(factor_state[6]),
        )

    if type_name in ("Basket", "Platform"):
        # arctan2 keeps the quadrant and is defined when cos == 0, which the
        # previous arctan(sin / (cos + 1e-5)) form only approximated.
        heading = np.arctan2(factor_state[4], factor_state[5])
        is_dynamic = bool(factor_state[8])

        if type_name == "Basket":
            # create_basket hands the angle to Box2D as-is (radians).
            return Basket(
                factor_state[0],
                factor_state[1],
                factor_state[7],
                heading,
                "gray",
                is_dynamic,
            )

        # create_platform expects degrees and treats ``length`` as the
        # half-extent of the box, while get_state reports the full vertex
        # span. Convert both so state -> attrs -> body round-trips.
        return Platform(
            factor_state[0],
            factor_state[1],
            factor_state[7] / 2,
            np.degrees(heading),
            "gray",
            is_dynamic,
        )

    raise ValueError(f"{name} is of unrecognized type: {type_name}")

def set_velocities_from_state(factor_state, body, name):
    """Copy linear and angular velocity from a state vector onto ``body``.

    Linear velocity is always taken from indices 2 and 3. The angular
    velocity index depends on the object type derived from ``name`` (digits
    stripped): 4 for Ball/Target, 6 for Basket/Platform. For any other type
    the angular velocity is left untouched.
    """
    kind = name.strip("0123456789")

    body.linearVelocity.x = factor_state[2]
    body.linearVelocity.y = factor_state[3]

    # Position of the angular velocity within each type's state layout.
    angular_slot = {"Ball": 4, "Target": 4, "Basket": 6, "Platform": 6}.get(kind)
    if angular_slot is not None:
        body.angularVelocity = factor_state[angular_slot]


def create_object(world, obj_name, name, object_attrs):
    """Create the body matching the object type encoded in ``name``.

    ``name`` (digits stripped) selects the builder; ``obj_name`` is what the
    builder stores on the body. Returns the created body, or ``None`` when
    the type is not Ball, Target, Basket or Platform.
    """
    kind = name.strip("0123456789")

    if kind == "Basket":
        return create_basket(world, object_attrs, obj_name)
    if kind == "Platform":
        return create_platform(world, object_attrs, obj_name)
    if kind in ("Ball", "Target"):
        # TODO: colors do not match after setting because color is not a parameter
        return create_ball(world, object_attrs, obj_name)
    return None



class PHYREObjectWrapper:
    """Pairs a simulation body with its name/type and exposes a flat state vector."""

    def __init__(self, name, body_name, object, obj_type, dynamic):
        self.name = name
        self.body_name = body_name
        self.object = object
        self.type = obj_type
        self.dynamic = dynamic
        self.interaction_trace = []

    def get_state(self):
        """Return the object's state as a 1-D numpy array.

        Layouts:
            Ball / Target (7): x, y, vx, vy, angular velocity, radius, dynamic
            Platform (9): x, y, vx, vy, sin(angle), cos(angle),
                angular velocity, x-span between the first two fixture
                vertices, dynamic
            Basket (9): same as Platform, except slot 7 is the second
                vertex's x times 2 / 1.083

        When no body is attached, a zero vector is returned (length 7 for
        Ball/Target, 9 for anything else).

        Raises:
            Exception: a body is attached but the type is not recognised.
        """
        kind = self.type
        is_round = kind in ("Ball", "Target")
        body = self.object

        if body is None:
            return np.zeros(7 if is_round else 9)

        # Reject unknown types before reading anything from the body.
        if not (is_round or kind in ("Platform", "Basket")):
            raise Exception(f"{self.name} is of unrecognized type: {self.type}")

        # Position and linear velocity lead every layout.
        features = [
            body.position.x,
            body.position.y,
            body.linearVelocity.x,
            body.linearVelocity.y,
        ]

        if is_round:
            features.append(body.angularVelocity)
            features.append(body.fixtures[0].shape.radius)
        else:
            heading = body.angle
            features.extend([np.sin(heading), np.cos(heading), body.angularVelocity])
            corners = body.fixtures[0].shape.vertices
            if kind == "Platform":
                features.append(float(corners[1][0] - corners[0][0]))
            else:
                features.append(corners[1][0] * 2 / 1.083)

        features.append(self.dynamic)
        return np.array(features)
