import numpy as np
import copy

# Vertex lists for the preassigned polygon shapes, keyed by polygon id (the
# same ids that ``preassigned_names`` maps to names). Each vertex is a 2-D
# coordinate tuple; ids 0-2 have 3 vertices, ids 3-6 have 4, ids 7-9 have 5.
# NOTE(review): in the other shapes the vertices appear to lie at one or two
# distinct distances from the origin (e.g. every vertex of id 4 is at 0.5).
# Id 1's first vertex x-coordinate 1.41059... puts it at ~1.44, whereas
# 0.41059... would put it at exactly 0.5, like that shape's third vertex.
# This looks like a typo -- confirm against whatever generated these shapes
# before changing it, since it also changes ``preassigned_radius`` for id 1.
preassigned_objects = {
    0: [(0.7284984725125085, 0.25382403682354227), (-0.746435544532064, 0.19486057574988944), (0.46032832136564317, -0.6190593692620655)],
    1: [(1.41059124794502966, 0.28533283566905365), (-0.25931395858285117, 0.8275000244258371), (0.44102059889169287, -0.23558614423011498)],
    2: [(0.28607050203575746, 0.5743962094732143), (-0.8256496899495922, 0.14258261835920313), (-0.22572785868734882, -0.6006781762346717)],
    3: [(0.8914449235617385, 0.985253044263275), (-1.2562605785309027, 0.43267420977364307), (-1.1542388057203732, -0.6581264258052942), (1.3274513036435218, -0.057189591965314056)],
    4: [(0.47954437133466, 0.14155280259057307), (-0.060697176883889606, 0.4963021788369721), (-0.4966413556545922, -0.05785640719547688), (0.16242871971369668, -0.47288149785349964)],
    5: [(0.0742895546650564, 1.2856982741024796), (-0.08343989374969488, 1.2851368697899328), (-1.2339578710930303, -0.3686284882235547), (1.1004339170216832, -0.6690173273023221)],
    6: [(0.9408624994001197, 0.7604177155534265), (-1.13170926353535, 0.4274242479471746), (-0.9183752462445788, -0.7874288869416345), (1.1468633359427898, -0.3849179569403168)],
    7: [(1.5786421737380196, 1.1722739142583314), (0.4926255378135824, 1.9035906393572983), (-1.8985407430597876, 0.5117424056073909), (-0.005497012838627081, -1.9662927110694128), (1.558790729938868, -1.1985444101123248)],
    8: [(0.4529929227554307, 0.21165399106440777), (-0.2109978990279942, 0.4532988932324591), (-0.4322840835701934, -0.2512577781718967), (-0.09126448922487189, -0.4916002369878632), (0.2997586658287976, -0.40018088692557424)],
    9: [(0.9628239880547248, 0.750070846069129), (-0.4817583339097068, 1.121403234258314), (-1.2125411303377678, 0.1392131938275693), (-0.9830385672174559, -0.7233750627852162), (1.102355491336986, -0.5238784943244533)]

}

# Name of each preassigned polygon, keyed by polygon id. Names follow the
# pattern "Poly<vertex count>vert<id>form".
preassigned_names = {
    0: "Poly3vert0form",
    1: "Poly3vert1form",
    2: "Poly3vert2form",
    3: "Poly4vert3form",
    4: "Poly4vert4form",
    5: "Poly4vert5form",
    6: "Poly4vert6form",
    7: "Poly5vert7form",
    8: "Poly5vert8form",
    9: "Poly5vert9form",
}

# Inverse lookup (name -> polygon id). It is derived from ``preassigned_names``
# instead of being maintained by hand, so the two mappings cannot drift apart.
preassigned_nid = {name: nid for nid, name in preassigned_names.items()}


# For each polygon id: (smallest, largest) value over all of that polygon's
# vertex coordinates, with x and y pooled together. generate_object_dicts
# reads only the second entry, as the bound of the Poly "radius" feature.
# NOTE(review): this is the extreme coordinate value, not the Euclidean
# distance of the farthest vertex, so it can under-estimate the true radius.
# Confirm this is intended before changing it, because it sets normalisation
# ranges.
preassigned_radius = {
    nid: (np.asarray(vertices).min(), np.asarray(vertices).max())
    for nid, vertices in preassigned_objects.items()
}

# Numeric ids: Control, Ball, one consecutive id per preassigned polygon
# (starting right after BALL_ID), then Target right after the last polygon.
CONTROL_ID = 0
BALL_ID = 1
POLY_IDS = list(range(BALL_ID + 1, BALL_ID + 1 + len(preassigned_names)))
TARGET_ID = max(POLY_IDS) + 1
NUM_DISCRETE_ACTIONS = 8
VEL_STEP_CONSTANT = 60 / 2  # based on the fps of the box2d environment
ACC_STEP_CONSTANT = 4  # hand tuned; should be greater than 2


def generate_object_dicts(continuous_actions, all_names, object_names, length, width):
    """Build per-object feature sizes, value bounds and position masks.

    For every name ``n`` in ``object_names`` this fills in:

    * ``object_sizes[n]``: number of features in the object's state vector;
    * ``object_range[n]``: ``[low, high]`` float64 arrays bounding the state;
    * ``object_dynamics[n]``: ``[low, high]`` float64 arrays bounding the
      per-step change of the state;
    * ``position_masks[n]``: 1 for the two position features, 0 elsewhere.

    Names are dispatched in this order: ``"Action"``; ``"Reward"`` or
    ``"Done"``; any name containing ``"Poly"``; any name containing
    ``"Ball"``, plus ``"Control"`` and ``"Target"``. Names matching none of
    these are skipped and get no entries.

    Args:
        continuous_actions: if truthy, ``"Action"`` has 2 features in
            ``[-1, 1]``; otherwise it has 1 feature in
            ``[0, NUM_DISCRETE_ACTIONS]``.
        all_names: unused; kept for call-site compatibility.
        object_names: iterable of object names to describe.
        length: scales the bounds of the first position/velocity feature.
        width: scales the bounds of the second position/velocity feature
            (and the Ball-like radius upper bound).

    Returns:
        ``(object_sizes, object_range, object_dynamics, object_range_true,
        object_dynamics_true, position_masks)``. The ``*_true`` dicts are
        independent deep copies of ``object_range`` and ``object_dynamics``.

    Raises:
        KeyError: for a ``"Poly"`` name that is not in ``preassigned_nid``.
    """

    def _bounds(low, high):
        # The [low, high] pair of float64 arrays that every range/dynamics entry uses.
        return [np.array(low, dtype=np.float64), np.array(high, dtype=np.float64)]

    def _symmetric(high):
        # [-high, high] bounds built from the positive half only.
        return _bounds([-value for value in high], high)

    def _motion_high():
        # Positive per-step bounds for the position and velocity features,
        # shared by the Poly and Ball-like objects.
        return [length / VEL_STEP_CONSTANT, width / VEL_STEP_CONSTANT,
                length * ACC_STEP_CONSTANT, width * ACC_STEP_CONSTANT]

    object_sizes, object_range, object_dynamics, position_masks = {}, {}, {}, {}
    for n in object_names:
        if n == "Action":
            if continuous_actions:
                object_sizes[n] = 2
                object_range[n] = _symmetric([1, 1])
                object_dynamics[n] = _symmetric([2, 2])
            else:
                object_sizes[n] = 1
                object_range[n] = _bounds([0], [NUM_DISCRETE_ACTIONS])
                object_dynamics[n] = _symmetric([NUM_DISCRETE_ACTIONS])
            position_masks[n] = np.zeros(object_sizes[n])
        elif n in ("Reward", "Done"):
            object_sizes[n] = 1
            object_range[n] = _symmetric([1])
            object_dynamics[n] = _symmetric([2])
            position_masks[n] = np.zeros(object_sizes[n])
        elif "Poly" in n:
            # TODO: for now, not vertex based encodings
            max_radius = preassigned_radius[preassigned_nid[n]][1]
            object_sizes[n] = 8  # pos, vel, sin angle, cos angle, ang_vel, radius
            object_range[n] = _symmetric(
                [length / 2, width / 2, length, width, 1, 1, 3.15, max_radius])
            object_dynamics[n] = _symmetric(_motion_high() + [2, 2, 15, 0.1])
            position_masks[n] = np.array([1, 1, 0, 0, 0, 0, 0, 0])
        elif "Ball" in n or n in ("Control", "Target"):  # Ball, Control, Target
            object_sizes[n] = 5  # pos, vel, radius
            # TODO: change default radius to be smaller... but would require regenerating all the data
            object_range[n] = _bounds([-length / 2, -width / 2, -5, -5, 0.5],
                                      [length / 2, width / 2, 5, 5, width / 5])
            object_dynamics[n] = _symmetric(_motion_high() + [0.1])
            position_masks[n] = np.array([1, 1, 0, 0, 0])
    object_range_true = copy.deepcopy(object_range)
    object_dynamics_true = copy.deepcopy(object_dynamics)
    return object_sizes, object_range, object_dynamics, object_range_true, object_dynamics_true, position_masks

