from emlp.reps import V,T,Rep
from emlp.groups import Z,C,S,SO,Group
from jax.scipy.spatial.transform import Rotation
import jax.numpy as jnp
import numpy as np

class PseudoScalar(Rep):
    """One-dimensional rep that acts by the sign of the group element's determinant.

    Elements with positive determinant act as +1, elements with negative
    determinant (e.g. reflections) act as -1. An instance created without a
    group is abstract; calling it with a group binds it.
    """

    is_regular = False

    def __init__(self, G=None):
        self.G = G

    @property
    def concrete(self):
        # Only a rep bound to a group is concrete.
        return self.G is not None

    def __call__(self, G):
        # Bind this rep to the group ``G``.
        return PseudoScalar(G)

    def size(self):
        return 1

    def __str__(self):
        return "P"

    def rho(self, M):
        # ``M @ I`` presumably densifies M in case it is a lazy operator
        # before taking the determinant's sign -- TODO confirm.
        dense = M @ jnp.eye(M.shape[0])
        det_sign, _ = jnp.linalg.slogdet(dense)
        return det_sign * jnp.eye(1)

    def __eq__(self, other):
        if type(self) != type(other):
            return False
        return self.G == other.G

    def __hash__(self):
        return hash((type(self), self.G))

    @property
    def T(self):
        # rho is a 1x1 matrix of +-1, so the dual rep equals the rep itself.
        return self


class D(Group):
    """Dihedral group acting on ``k`` points arranged in a cycle.

    It is generated by a cyclic shift and a reflection of the point indices.
    The original implementation hard-coded ``k = 4`` regardless of the
    argument. For ``k = 4`` the generators here are exactly
    ``np.eye(4)[[3, 0, 1, 2]]`` and ``np.eye(4)[[1, 0, 3, 2]]``, as before.

    Args:
        k: number of points (size of the permutation matrices).
    """

    def __init__(self, k):
        idx = np.arange(k)
        # Row i selects coordinate (i - 1) mod k: a cyclic shift by one.
        translation = np.eye(k)[(idx - 1) % k][None]
        # Row i selects coordinate (1 - i) mod k: a reflection of the cycle.
        reflection = np.eye(k)[(1 - idx) % k][None]
        self.discrete_generators = np.concatenate((translation, reflection))
        super().__init__(k)

# TODO(review): could several entries pointing to the same rep object cause a bug?


import numpy as np
import jax.numpy as jnp
from jax import jit
def vector_dual(v):
    """Return the skew-symmetric cross-product matrix ``[v]_x`` of ``v``.

    For ``v`` of shape ``(..., 3)`` the result has shape ``(..., 3, 3)`` and
    satisfies ``[v]_x @ w == cross(v, w)``.
    """
    x, y, z = v[..., 0], v[..., 1], v[..., 2]
    # Zeros come from multiplying the matching component by 0, as in the
    # original, so dtype and non-finite propagation are unchanged.
    rows = (
        (0 * x, -z, y),
        (z, 0 * y, -x),
        (-y, x, 0 * z),
    )
    return jnp.stack([jnp.stack(row, axis=-1) for row in rows], -2)

def quat2rot(q):
    """Convert scalar-first quaternions ``(w, x, y, z)`` to 3x3 rotation matrices.

    Args:
        q: array of shape ``(..., 4)``. It is normalized internally, so it
            need not be unit length.

    Returns:
        Array of shape ``(..., 3, 3)``.
    """
    q = q/jnp.sqrt((q**2).sum(-1))[...,None]
    # First component is the scalar part, the remaining three the vector part.
    q0,v = jnp.split(q,[1],axis=-1)
    v_cross = vector_dual(v)
    R = jnp.eye(3)-2*q0[...,None]*v_cross+2*v_cross@v_cross
    # Transpose only the trailing matrix axes. ``R.T`` reverses *all* axes.
    # That is correct for a single quaternion, but for a batch it returns
    # shape (3, 3, ...), which callers then reshape into scrambled entries.
    return jnp.swapaxes(R, -1, -2)

def mat2quat_jax(R):
    """Convert rotation matrices to unit quaternions in scalar-first order.

    Uses the trace-based formula when ``trace(R) > 0``. Otherwise it uses the
    formula keyed on the largest diagonal entry, for numerical stability.
    Every branch is evaluated and the result is picked with ``jnp.where``, so
    the function has no data-dependent Python control flow.

    Args:
        R: array of shape ``(..., 3, 3)``.

    Returns:
        Array of shape ``(..., 4)`` ordered ``(w, x, y, z)`` with unit norm.
        ``q`` and ``-q`` encode the same rotation; no sign convention is
        enforced.
    """
    t = jnp.einsum('...ii->...', R)

    def _safe_sqrt(u):
        # All branches are computed for every input, and only one survives
        # the jnp.where below. A discarded branch can have u <= 0, where sqrt
        # is NaN or has an infinite derivative. Its value is thrown away, but
        # the NaN/inf still leaks into gradients. For a rotation matrix the
        # *selected* branch always has u >= 1, so this clamp never changes a
        # returned value.
        return jnp.sqrt(jnp.maximum(u, 1e-12))

    def case0(R, t):
        # Trace-based formula, used when 1 + t > 1.
        r = _safe_sqrt(1.0 + t)
        w = 0.5 * r
        r4 = 0.5 / (r + 1e-9)
        x = (R[...,2,1] - R[...,1,2]) * r4
        y = (R[...,0,2] - R[...,2,0]) * r4
        z = (R[...,1,0] - R[...,0,1]) * r4
        return jnp.stack([w,x,y,z], -1)

    def casei(R, i):
        # Formula keyed on vector component i (a Python int in 0..2).
        i1, i2 = (i+1)%3, (i+2)%3
        r = _safe_sqrt(1.0 + R[...,i,i] - R[...,i1,i1] - R[...,i2,i2])
        qi = 0.5 * r
        r4 = 0.5 / (r + 1e-9)
        q0 = (R[...,i2,i1] - R[...,i1,i2]) * r4
        # i is static, so the diagonal component is chosen in Python.
        vec = [qi if j == i else (R[...,i,j] + R[...,j,i]) * r4 for j in range(3)]
        return jnp.stack([q0, *vec], -1)

    cond0 = t > 0.0
    q0 = case0(R, t)
    i = jnp.argmax(jnp.stack([R[...,0,0], R[...,1,1], R[...,2,2]], -1), -1)
    q1 = jnp.where(i[...,None]==0, casei(R,0),
         jnp.where(i[...,None]==1, casei(R,1), casei(R,2)))
    q = jnp.where(cond0[...,None], q0, q1)
    q = q / jnp.linalg.norm(q, axis=-1, keepdims=True)
    return q

def bgs(v1, v2, eps: float = 1e-8):
    """Build a right-handed orthonormal frame from two vectors (Gram-Schmidt).

    Args:
        v1: array ``(..., 3)``; its direction becomes the first column.
        v2: array ``(..., 3)``; its component orthogonal to ``v1`` becomes
            the second column.
        eps: lower bound on the norms used when normalizing, to avoid
            division by zero.

    Returns:
        Array ``(..., 3, 3)`` whose columns are ``b1``, ``b2`` and
        ``b1 x b2``.
    """
    def _safe_normalize(u):
        length = jnp.linalg.norm(u, axis=-1, keepdims=True)
        return u / jnp.maximum(length, eps)

    e1 = _safe_normalize(v1)
    # Remove the part of v2 along e1 before normalizing it.
    residual = v2 - jnp.sum(e1 * v2, axis=-1, keepdims=True) * e1
    e2 = _safe_normalize(residual)
    e3 = jnp.cross(e1, e2, axis=-1)
    return jnp.stack((e1, e2, e3), axis=-1)


def ant_state_transform(x):
    """Rewrite an Ant state so the root orientation is a rotation matrix.

    The trailing axis is cut at indices [1, 5, 13, 16, 19, 27].
    - The 4-entry quaternion segment (indices 1:5) is replaced by its 3x3
      rotation matrix, flattened row-major (9 entries).
    - The other segments up to index 27 are kept.
    - Everything from index 27 onward (the forces) is dropped.
    """
    segments = jnp.split(x, [1, 5, 13, 16, 19, 27], axis=-1)
    head, quat = segments[0], segments[1]
    rot_flat = quat2rot(quat).reshape(*quat.shape[:-1], -1)
    kept = segments[2:6]
    return jnp.concatenate([head, rot_flat, *kept], -1)

def ant_inv_state_transform(x):
    """Convert the flattened rotation matrix in an Ant state back to a quaternion.

    Approximate inverse of ``ant_state_transform``. The trailing axis is cut
    at [1, 10, 18, 21, 24, 32]:
    - Entries 1:10 are read as a row-major 3x3 rotation matrix and replaced
      by a scalar-first ``(w, x, y, z)`` quaternion.
    - All other segments are passed through.
    Anything after index 32 is appended unchanged as ``forces``. This is
    empty for the forward transform's output, since it drops the forces.
    """
    z,R,angs,vcom,w,angv,forces = jnp.split(x,[1,5+5,13+5,16+5,19+5,27+5],axis=-1)
    R = R.reshape(*R.shape[:-1],3,3)
    # Rotation.as_quat is scalar-last (x, y, z, w); roll to scalar-first.
    # jnp.roll (not np.roll) keeps the function traceable under jax.jit.
    q = jnp.roll(Rotation.from_matrix(R).as_quat(),1,axis=-1)
    return jnp.concatenate([z,q,angs,vcom,w,angv,forces],-1)


# Index permutation for the 27-entry Ant state. Entries 0-4 and 13-18 keep
# their place. Entries 5-12 and 19-26 are each reordered as their
# even-offset members followed by their odd-offset members, so the two
# 4-element groups line up with the 2*T(1) blocks of the Z(4) state rep.
ant_state_perm = np.concatenate([
    np.arange(0, 5),
    np.arange(5, 13, 2),    # 5, 7, 9, 11
    np.arange(6, 13, 2),    # 6, 8, 10, 12
    np.arange(13, 19),
    np.arange(19, 27, 2),   # 19, 21, 23, 25
    np.arange(20, 27, 2),   # 20, 22, 24, 26
])
inv_ant_state_perm = np.argsort(ant_state_perm)
# Same even/odd grouping for the 8 action entries.
ant_action_perm = np.concatenate([np.arange(0, 8, 2), np.arange(1, 8, 2)])
inv_ant_action_perm = np.argsort(ant_action_perm)

def ant_state_transform2(x):
    """Reorder the last axis of an Ant state into the grouped-by-leg layout."""
    return x[...,ant_state_perm]
def ant_inv_state_transform2(x):
    """Undo ``ant_state_transform2``."""
    return x[...,inv_ant_state_perm]

def ant_action_transform2(a):
    """Reorder the last axis of an Ant action into the grouped-by-leg layout."""
    return a[...,ant_action_perm]
def ant_inv_action_transform2(a):
    """Undo ``ant_action_transform2``."""
    return a[...,inv_ant_action_perm]

def _walker_state_transform(x):
    """Place each right-side Walker2d coordinate next to its left-side twin."""
    (y, orient, rh, rk, ra, lh, lk, la, vcomx,
     vcomy, angvel, vrh, vrk, vra, vlh, vlk, vla) = jnp.split(
        x, np.arange(1, x.shape[-1]), axis=-1)
    ordered = [y, orient]
    for right, left in ((rh, lh), (rk, lk), (ra, la)):
        ordered += [right, left]
    ordered += [vcomx, vcomy, angvel]
    for right, left in ((vrh, vlh), (vrk, vlk), (vra, vla)):
        ordered += [right, left]
    return jnp.concatenate(ordered, -1)

# Index form of the reordering above, plus its inverse.
walker_perm = _walker_state_transform(np.arange(17))
inv_walker_perm = np.argsort(walker_perm)

def walker_state_transform(x):
    """Apply the left/right grouping permutation to the last axis of ``x``."""
    return x[...,walker_perm]

def inv_walker_state_transform(x):
    """Undo ``walker_state_transform``."""
    return x[...,inv_walker_perm]

def _inv_walker_action_transform(a):
    """Turn interleaved (right, left) action pairs into all-right then all-left."""
    rh, lh, rk, lk, ra, la = jnp.split(a, np.arange(1, a.shape[-1]), axis=-1)
    right_side = [rh, rk, ra]
    left_side = [lh, lk, la]
    return jnp.concatenate(right_side + left_side, -1)

# Index form of the action reordering above, plus its inverse.
inv_walker_action_perm = _inv_walker_action_transform(np.arange(6))
walker_action_perm = np.argsort(inv_walker_action_perm)

def inv_walker_action_transform(a):
    """Map a grouped action back to the interleaved layout."""
    return a[...,inv_walker_action_perm]

def walker_action_transform(a):
    """Map an interleaved action into the grouped layout."""
    return a[...,walker_action_perm]

def humanoid_state_transform(x):
    """Replace the root quaternion of a Humanoid state with a rotation matrix.

    The trailing axis is cut at [1, 5, 45]:
    - Entry 0 is kept.
    - Entries 1:5 (the quaternion) become a row-major flattened 3x3 rotation
      matrix (9 entries).
    - Entries 5:45 are kept.
    - Everything from index 45 on is dropped.
    """
    head, quat, body, _dropped = jnp.split(x, [1, 5, 45], axis=-1)
    rot_flat = quat2rot(quat).reshape(*quat.shape[:-1], -1)
    return jnp.concatenate([head, rot_flat, body], -1)

# def inv_humanoid_state_transform(x):
#     z,R,x_rest,extra_info = jnp.split(x,[1,5+5,45+5],axis=-1)
#     R = R.reshape(*R.shape[:-1],3,3)
#     q = np.roll(Rotation.from_matrix(R).as_quat(),1,axis=-1)
#     return jnp.concatenate([z,q,x_rest],-1)
def inv_humanoid_state_transform(x):
    """Undo ``humanoid_state_transform``.

    The trailing axis is cut at [1, 10, 50]:
    - Entry 0 is kept.
    - Entries 1:10 are read as a row-major 3x3 rotation matrix and converted
      back to a scalar-first ``(w, x, y, z)`` quaternion.
    - Entries 10:50 are kept.
    - Anything past index 50 is dropped.
    The recovered quaternion may differ from the original by an overall sign
    (``q`` and ``-q`` describe the same rotation).
    """
    z, R_flat, x_rest, extra_info = jnp.split(x, [1, 10, 50], axis=-1)
    R = R_flat.reshape(*R_flat.shape[:-1], 3, 3)
    # mat2quat_jax already returns scalar-first (w, x, y, z), the layout that
    # quat2rot reads in the forward transform. A jnp.roll(..., 1) is only
    # needed for scalar-last output such as Rotation.as_quat. Applied here,
    # it would produce (z, w, x, y).
    q = mat2quat_jax(R)
    return jnp.concatenate([z, q, x_rest], -1)

# Reordering applied to each 14-entry leg/arm block of the Humanoid state.
leg_arm_perm = np.array([0,4,1,5,2,6,3,7,8,11,9,12,10,13])

def _humanoid_state_perm2(x):
    """Permute the two leg/arm blocks (entries 13:27 and 36:) by ``leg_arm_perm``.

    Entries 0:13 and 27:36 are kept in place. Works along axis 0.
    """
    head, limbs, middle, limb_vels = x[:13], x[13:27], x[27:36], x[36:]
    return jnp.concatenate([head, limbs[leg_arm_perm], middle, limb_vels[leg_arm_perm]])

# Index form over the 50-entry transformed state, plus its inverse.
humanoid_state_perm = _humanoid_state_perm2(np.arange(50))
inv_humanoid_state_perm = np.argsort(humanoid_state_perm)
# Actions: the first 3 entries stay, the remaining 14 use the leg/arm order.
humanoid_action_perm = np.concatenate((np.arange(3), leg_arm_perm + 3))
inv_humanoid_action_perm = np.argsort(humanoid_action_perm)

# def fetch_state_transform(x):
#     gripper_pos = x[..., 0:3]  # gripper x, y, z
#     object_pos = x[..., 3:6]  # object x, y, z
#     gripper_right = x[..., 6:7] # right gripper finger
#     gripper_left = x[..., 7:8] # left gripper finger
#     desired_goal = x[..., 8:11]  # desired goal position x, y, z

#     # Concatenate transformed state
#     transformed_state = jnp.concatenate([
#         gripper_pos,
#         object_pos,
#         gripper_right,
#         gripper_left,
#         desired_goal
#     ], axis=-1)
    
#     return transformed_state

# def fetch_inv_state_transform(x):
#     gripper = x[..., 0:1]  # gripper open ratio
#     gripper_pos = x[..., 1:4]  # gripper x, y, z
#     object_pos = x[..., 6:9]  # object x, y, z
    
#     # Concatenate original state format
#     original_state = jnp.concatenate([
#         gripper,        # 1D - gripper open ratio
#         gripper_pos,    # 3D - gripper position
#         object_pos,     # 3D - object position
#     ], axis=-1)
    
#     return original_state

def manip_state_transform(x):
    """Replace both scalar-last quaternions in a manipulator state by 6D rotations.

    The input trailing axis is cut at [3, 7, 10, 13, 16]. Each 4-entry
    quaternion ``(qx, qy, qz, qw)`` is replaced by the first two columns of
    its rotation matrix, laid out column after column (6 entries). The 20
    input entries thus become 24.
    """
    ee_pos, ee_quat, ee_velp, ee_velr, desired_goal_pos, desired_goal_quat = jnp.split(
        x, [3, 7, 10, 13, 16], axis=-1)

    def _rot6d(quat):
        mat = Rotation.from_quat(quat).as_matrix()
        # (..., 3, 2) -> (..., 2, 3) so the flatten puts column 0 then column 1.
        cols = jnp.swapaxes(mat[..., :, :2], -1, -2)
        return cols.reshape(*mat.shape[:-2], -1)

    return jnp.concatenate(
        [ee_pos, _rot6d(ee_quat), ee_velp, ee_velr, desired_goal_pos, _rot6d(desired_goal_quat)], -1)

def inv_manip_state_transform(x):
    """Invert ``manip_state_transform``: turn 6D rotations back into quaternions.

    The input trailing axis is cut at [3, 9, 12, 15, 18]. Each 6-entry
    rotation (two stacked 3-vectors) is re-orthonormalized with ``bgs`` and
    converted to a scalar-last ``(qx, qy, qz, qw)`` quaternion.
    """
    ee_pos, rot6d_ee, ee_velp, ee_velr, desired_goal_pos, rot6d_goal = jnp.split(
        x, [3, 9, 12, 15, 18], axis=-1)

    def _quat_from_rot6d(r6):
        mat = bgs(r6[..., :3], r6[..., 3:6])
        return Rotation.from_matrix(mat).as_quat()  # qx, qy, qz, qw

    return jnp.concatenate(
        [ee_pos, _quat_from_rot6d(rot6d_ee), ee_velp, ee_velr,
         desired_goal_pos, _quat_from_rot6d(rot6d_goal)], -1)


# Shared building-block reps used by the environment tables below. They are
# abstract (not yet bound to a group); sizes depend on the group they are
# later used with.
P = PseudoScalar()  # sign-of-determinant scalar
vector3 = T(1)+T(0)  # vector part plus one invariant component (3-dim when T(1) is 2-dim)
Rrep = 3 * vector3  # three copies of vector3; used for the flattened 3x3 root rotation
matrix3 = vector3**2  # tensor product vector3 (x) vector3
pseodovector3 = P*vector3  # vector3 times the pseudoscalar sign (name kept as referenced below)
s = T(0)  # scalar
# # L/R symmetries only
# environment_symmetries['Humanoid-v2'] = {
#     'state_rep':T(0)+matrix3+17*T(0)+vector3+pseodovector3+\
#             17*T(0),
#     'state_transform':humanoid_state_transform,
#     'action_rep':17*T(0),
#     'inv_action_transform':Id,
#     'symmetry_group':SO(2),
#     'action_space':"continuous",
# }
Id = lambda x:x  # no-op transform for environments whose layout needs no change

#import collections
# environment_symmetries = collections.defaultdict(lambda: {
#     'state_transform':Id,
#     'inv_state_transform':Id,
#     'action_transform':Id,
#     'inv_action_transform':Id})
# Rep of one 14-entry leg/arm block in the Z(2) Humanoid state/action
# (presumably matches the leg_arm_perm grouping -- TODO confirm).
legarmrep = P*T(1)+3*T(1)+P*T(1)+2*T(1)

# fetch_reach_rep = {
#     # State: [ee_pose, goal]
#     # ee_pose  = [x, y, z] -> (T(1) + T(0)) = 3D
#     # goal = [x, y, z] -> (T(1) + T(0)) = 3D
#     # Total: 6D
#     'state_rep': vector3 + vector3,

#     # Action:
#     # [x, y, z, gripper] -> (T(1) + T(0)) + T(0) = 3+1 = 4D
#     'action_rep': vector3 + s,

#     'state_transform': Id,
#     'inv_state_transform': Id,
#     'action_transform': Id,
#     'inv_action_transform': Id,
#     'symmetry_group': C(8),
#     'action_space': "continuous",
# }
# fetch_reach_rep_SO2 = fetch_reach_rep.copy()
# fetch_reach_rep_SO2['symmetry_group'] = SO(2)
# FetchReach under SO(3): position/velocity triples are 3D vectors (T(1)),
# the gripper command is a scalar, and observations/actions are used as-is.
fetch_reach_rep_SO3 = {
    # State: [ee_pose, ee_vel, goal]
    # ee_pose = [x, y, z] -> T(1) = 3D
    # ee_vel = [vx, vy, vz] -> T(1) = 3D
    # goal = [x, y, z] -> T(1) = 3D
    # Total: 9D
    'state_rep': T(1) + T(1) + T(1),

    # Action:
    # [x, y, z, gripper] -> T(1) + T(0) = 3+1 = 4D
    'action_rep': T(1) + s,

    'state_transform': Id,
    'inv_state_transform': Id,
    'action_transform': Id,
    'inv_action_transform': Id,
    'symmetry_group': SO(3),
    'action_space': "continuous",
}
# Object-manipulation Fetch tasks (Push / Slide / PickAndPlace) under the
# planar group C(8); observations/actions are used as-is.
fetch_obj_rep = {
    # State: [ee_pose(3), obj_pose(3), rel_obj_pose(3), gripper_fingers(2), obj_rot(3),
    #         rel_obj_vel(3), obj_ang_vel(3), ee_vel(3), gripper_fingers_vel(2), cur_obj_pos(3), goal(3)]
    # ee_pose  = [x, y, z] -> (T(1) + T(0)) = 3D
    # obj_pose = [x, y, z] -> (T(1) + T(0)) = 3D
    # rel_obj_pose = [x, y, z] -> (T(1) + T(0)) = 3D
    # gripper_fingers = [right, left] -> 2*T(0) = 2D
    # obj_rot = [x, y, z] -> 3*T(0) = 3D
    # rel_obj_vel = [vx, vy, vz] -> (T(1) + T(0)) = 3D
    # obj_ang_vel = [wx, wy, wz] -> 3*T(0) = 3D
    # ee_vel = [vx, vy, vz] -> (T(1) + T(0)) = 3D
    # gripper_fingers_vel = [right, left] -> 2*T(0) = 2D
    # cur_obj_pos = [x, y, z] -> (T(1) + T(0)) = 3D
    # goal = [x, y, z] -> (T(1) + T(0)) = 3D
    # Total: 31D
    'state_rep': vector3 + vector3 + vector3 + 2*s + 3*s + vector3 + 3*s + vector3 + 2*s + vector3 + vector3,

    # Action: RPP representation for equivariance
    # (NOTE(review): "RPP" presumably means residual pathway prior -- confirm.)
    # [x, y, z, gripper] ->(T(1) + T(0)) + T(0) = 3+1 = 4D
    'action_rep': vector3 + s,

    'state_transform': Id,
    'inv_state_transform': Id,
    'action_transform': Id,
    'inv_action_transform': Id,
    'symmetry_group': C(8),
    'action_space': "continuous",
}
# Same spec with the continuous planar group. dict.copy() is shallow, so both
# dicts share the same rep and transform objects.
fetch_obj_rep_SO2 = fetch_obj_rep.copy()
fetch_obj_rep_SO2['symmetry_group'] = SO(2)
# Reduced object-manipulation spec under SO(3); observations/actions are used as-is.
fetch_obj_rep_SO3 = {
    # State: [ee_pose, obj_pose, gripper_fingers, goal]
    # ee_pose  = [x, y, z] -> T(1) = 3D
    # obj_pose = [x, y, z] -> T(1) = 3D
    # gripper_fingers = [right, left] -> 2*T(0) = 2D
    # goal = [x, y, z] -> T(1) = 3D
    # Total: 11D
    'state_rep': T(1) + T(1) + 2*s + T(1),

    # Action:
    # [x, y, z, gripper] -> T(1) + T(0) = 3+1 = 4D
    'action_rep': T(1) + s,

    'state_transform': Id,
    'inv_state_transform': Id,
    'action_transform': Id,
    'inv_action_transform': Id,
    'symmetry_group': SO(3),
    'action_space': "continuous",
}

# UR5e reach under SO(3). manip_state_transform turns the 20-entry raw state
# (two scalar-last quaternions) into the 24-entry layout below, and
# inv_manip_state_transform maps it back.
manip_reach_rep_SO3 = {
    # State: [ee_pos, rot6d_ee, ee_velp, ee_velr, desired_goal_pos, rot6d_goal]
    # ee_pos = [x, y, z] -> T(1) = 3D
    # rot6d_ee -> 2*T(1) = 6D
    # ee_velp = [vx, vy, vz] -> T(1) = 3D
    # ee_velr = [wx, wy, wz] -> T(1) = 3D
    # desired_goal_pos = [x, y, z]  -> T(1) = 3D
    # rot6d_goal -> 2*T(1) = 6D
    # Total: 24D
    'state_rep': 8*T(1),

    # Action:
    # [x, y, z, ax, ay, az] -> 3*T(0) + 3*T(0) = 3+3 = 6D
    # NOTE(review): actions are declared invariant (6 scalars); confirm this
    # is intended for an SO(3)-equivariant policy.
    'action_rep': 6*s,

    'state_transform': manip_state_transform,
    'inv_state_transform': inv_manip_state_transform,
    'action_transform': Id,
    'inv_action_transform': Id,
    'symmetry_group': SO(3),
    'action_space': "continuous",
}

# Symmetry specification per environment id. Each entry provides:
#   'state_rep' / 'action_rep'        reps of the (transformed) state / action
#   'state_transform' / 'inv_state_transform'
#                                     raw observation <-> rep layout
#   'action_transform' / 'inv_action_transform'
#                                     raw action <-> rep layout
#   'symmetry_group'                  group the model is made equivariant to
#   'action_space'                    "continuous" or "discrete"
# Some entries also set 'action_std_rep' and/or 'middle_rep'.
environment_symmetries={
    # Z(2) left/right symmetry for Humanoid. This key used to appear twice in
    # the literal, with an SO(2) entry first and this one later. Python keeps
    # only the LAST value for a repeated key, at the FIRST key's position, so
    # only this entry was ever in effect. It is placed first so the dict's
    # iteration order is unchanged. The SO(2) variant is 'PE-Humanoid-v2'.
    'Humanoid-v2': {
        'state_rep':s+matrix3+s+s+P+legarmrep+\
            vector3+pseodovector3+s+s+P+legarmrep,
        'state_transform':lambda x: humanoid_state_transform(x)[...,humanoid_state_perm],
        'inv_state_transform':lambda x: inv_humanoid_state_transform(x[...,inv_humanoid_state_perm]),
        'action_rep':2*s+P+legarmrep,
        'action_transform':lambda a: a[...,humanoid_action_perm],
        'inv_action_transform':lambda a: a[...,inv_humanoid_action_perm],
        'symmetry_group':Z(2),
        'action_space':"continuous",
        'middle_rep':136*T(0)+60*T(1),
    },
    'PE-Humanoid-v2': {
        # humanoid_state_transform emits 1 + 9 + 40 = 50 entries. The previous
        # rep T(0)+Rrep+17*T(0)+vector3+17*T(0) covered only
        # 1 + 9 + 17 + 3 + 17 = 47 under SO(2). The second vector3 restores the
        # 3-entry block that the commented-out variant below models as
        # pseodovector3 (presumably the root angular velocity).
        # NOTE(review): confirm whether that block is expressed in the world
        # frame (vector3) or in the body frame (3*T(0)) for this environment.
        'state_rep':T(0)+Rrep+17*T(0)+vector3+vector3+17*T(0),
        'state_transform':humanoid_state_transform,
        'inv_state_transform':inv_humanoid_state_transform,
        'action_rep':17*T(0),
        'action_transform':Id,
        'inv_action_transform':Id,
        'symmetry_group':SO(2),
        'action_space':"continuous",
    },
    # 'Humanoid-v2': {
    #     'state_rep':T(0)+matrix3+17*T(0)+vector3+pseodovector3+\
    #         17*T(0),
    #     'state_transform':humanoid_state_transform,
    #     'inv_state_transform':inv_humanoid_state_transform,
    #     'action_rep':17*T(0),
    #     'action_transform':Id,
    #     'inv_action_transform':Id,
    #     'symmetry_group':SO(2),
    #     'action_space':"continuous",
    # },
    # 'Ant-v2': {
    #     'state_rep':T(0)+matrix3+8*T(0)+vector3+pseodovector3+\
    #         8*T(0),#+14*vector3+14*pseodovector3,
    #     'state_transform':ant_state_transform,
    #     'inv_state_transform':ant_inv_state_transform, #TODO: write inv
    #     'action_rep':8*T(0),
    #     'action_transform':Id,
    #     'inv_action_transform':Id,
    #     'symmetry_group':SO(2),
    #     'action_space':"continuous",
    # },
    'Ant-v2': {
        'state_rep':T(0)+4*T(0)+2*T(1)+6*T(0)+\
            2*T(1),#+14*vector3+14*pseodovector3,
        'state_transform':ant_state_transform2,
        'inv_state_transform':ant_inv_state_transform2,
        'action_rep':2*T(1),
        'action_transform':ant_action_transform2,
        'inv_action_transform':ant_inv_action_transform2,
        'symmetry_group':Z(4),
        'action_space':"continuous",
        'middle_rep':136*T(0)+30*T(1),
    },
    'PE-Ant-v2': {
        'state_rep':T(0)+4*T(0)+2*T(1)+6*T(0)+\
            2*T(1),#+14*vector3+14*pseodovector3,
        'state_transform':ant_state_transform2,
        'inv_state_transform':ant_inv_state_transform2,
        'action_rep':2*T(1),
        'action_transform':ant_action_transform2,
        'inv_action_transform':ant_inv_action_transform2,
        'symmetry_group':Z(4),
        'action_space':"continuous",
        # 'middle_rep':136*T(0)+30*T(1),
        # 'pA_middle_rep':136*T(0)+30*T(1), # 108*T(0)+47*T(1)
    },
    'Swimmer-v2':{ # Focus just on LR symmetry now, to add front back later
        'state_rep':P+P+P+2*T(0)+P+P+P, # should vcom swap?
        'state_transform':Id,
        'inv_state_transform':Id,
        'action_rep':2*P,
        'action_std_rep':2*T(0),
        'action_transform':Id,
        'inv_action_transform':Id,
        'symmetry_group':Z(2),
        'action_space':"continuous",
        'middle_rep':126*T(0)+55*T(1)+5*T(2),
    },
    'PE-Swimmer-v2':{ # Focus just on LR symmetry now, to add front back later
        'state_rep':P+P+P+2*T(0)+P+P+P, # should vcom swap?
        'state_transform':Id,
        'inv_state_transform':Id,
        'action_rep':2*P,
        'action_std_rep':2*T(0),
        'action_transform':Id,
        'inv_action_transform':Id,
        'symmetry_group':Z(2),
        'action_space':"continuous",
        # 'middle_rep':126*T(0)+55*T(1)+5*T(2),
        # 'pA_middle_rep':126*T(0)+55*T(1)+5*T(2), # 108*T(0)+47*T(1)
    },
    'Walker2d-v2':{
        'state_rep':2*T(0)+3*T(1)+3*T(0)+3*T(1),
        'state_transform':walker_state_transform,
        'inv_state_transform':inv_walker_state_transform,
        'action_rep':3*T(1),
        'action_transform':walker_action_transform,
        'inv_action_transform':inv_walker_action_transform,
        'symmetry_group':Z(2),
        'action_space':"continuous",
        'middle_rep':136*T(0)+60*T(1),
    },
    'Hopper-v2':{
        'state_rep':T(0)+4*P+(P+T(0)) + 4*P, # should vcom swap?
        'state_transform':Id,
        'inv_state_transform':Id,
        'action_rep':3*P,
        'action_std_rep':3*T(0),
        'action_transform':Id,
        'inv_action_transform':Id,
        'symmetry_group':Z(2),
        'action_space':"continuous",
        'middle_rep':126*T(0)+55*T(1)+5*T(2),
    },
    'HalfCheetah-v2':{
        'state_rep':T(0) + 8*P + T(0) + 7*P,
        'state_transform':Id,
        'inv_state_transform':Id,
        'action_rep':6*P,
        'action_std_rep':6*T(0),
        'action_transform':Id,
        'inv_action_transform':Id,
        'symmetry_group':Z(2),
        'action_space':"continuous",
        'middle_rep':126*T(0)+55*T(1)+5*T(2),
    },
    'InclinedCartpole-v0':{
        'state_rep':P + P + P + P,
        'state_transform':Id,
        'inv_state_transform':Id,
        'action_rep':T(1),
        'action_transform':Id,
        'inv_action_transform':Id,
        'symmetry_group':Z(2),
        'action_space':"discrete",
    },
    'CartPole-v0':{
        'state_rep':P + P + P + P,
        'state_transform':Id,
        'inv_state_transform':Id,
        'action_rep':T(1),
        'action_transform':Id,
        'inv_action_transform':Id,
        'symmetry_group':Z(2),
        'action_space':"discrete",
    },
    # 'FetchReach-v2': fetch_reach_rep,
    # 'FetchReachDense-v2': fetch_reach_rep,
    # 'FetchReach-v2_SO(2)': fetch_reach_rep_SO2,
    # 'FetchReachDense-v2_SO(2)': fetch_reach_rep_SO2,
    'FetchReach-v2_SO(3)': fetch_reach_rep_SO3,
    'FetchReachDense-v2_SO(3)': fetch_reach_rep_SO3,
    'FetchPush-v2': fetch_obj_rep,
    'FetchPushDense-v2': fetch_obj_rep,
    'FetchPush-v2_SO(2)': fetch_obj_rep_SO2,
    'FetchPushDense-v2_SO(2)': fetch_obj_rep_SO2,
    'FetchSlide-v2': fetch_obj_rep,
    'FetchSlideDense-v2': fetch_obj_rep,
    'FetchSlide-v2_SO(2)': fetch_obj_rep_SO2,
    'FetchSlideDense-v2_SO(2)': fetch_obj_rep_SO2,
    'FetchPickAndPlace-v2': fetch_obj_rep,
    'FetchPickAndPlaceDense-v2': fetch_obj_rep,
    'FetchPickAndPlace-v2_SO(2)': fetch_obj_rep_SO2,
    'FetchPickAndPlaceDense-v2_SO(2)': fetch_obj_rep_SO2,
    'FetchPickAndPlace-v2_SO(3)': fetch_obj_rep_SO3,
    'FetchPickAndPlaceDense-v2_SO(3)': fetch_obj_rep_SO3,
    'UR5eReach-v0': manip_reach_rep_SO3,
    'UR5eReachDense-v0': manip_reach_rep_SO3,
}


# halfcheetah_rep2 = {
#     'state_rep':2*T(0)+3*T(1)+3*T(0)+3*T(1),
#     'state_transform':walker_state_transform,
#     'action_rep':3*T(1),
#     'inv_action_transform':inv_walker_action_transform,
#     'symmetry_group':Z(2),
# }
# Uncomment the line below to use the second rep version for HalfCheetah.
# environment_symmetries['HalfCheetah-v2']=halfcheetah_rep2