import numpy as np
from data import Batch

# Default set of model parameters to randomize. Each name is also the
# attribute read from ``env.sim.model`` by ``get_init_params``.
RAND_PARAMS = [
    'body_mass',
    'dof_damping',
    'body_inertia',
    'geom_friction',
]
# Same list plus 'geom_size'. NOTE: get_init_params and the samplers below
# only handle the four names in RAND_PARAMS, so 'geom_size' is ignored by them.
RAND_PARAMS_EXTENDED = [*RAND_PARAMS, 'geom_size']

def get_init_params(env, rand_params = RAND_PARAMS):
    """Snapshot the nominal values of the model parameters to be randomized.

    Args:
        env: environment exposing ``env.sim.model`` with the array attributes
            ``body_mass``, ``body_inertia``, ``dof_damping`` and
            ``geom_friction`` (presumably a mujoco-py model -- TODO confirm).
        rand_params: names of the parameters to read. Names other than the
            four above (e.g. ``'geom_size'``) are ignored.

    Returns:
        dict mapping each selected parameter name to a *copy* of the model's
        array, inserted in the order body_mass, body_inertia, dof_damping,
        geom_friction.
    """
    init_params = {}
    for name in ('body_mass', 'body_inertia', 'dof_damping', 'geom_friction'):
        if name in rand_params:
            # Copy rather than alias: if randomized values are later written
            # into the model arrays in place, an aliased "initial" value would
            # silently change and corrupt every subsequent sample.
            init_params[name] = np.copy(getattr(env.sim.model, name))
    return init_params

def get_random_params(init_params,log_scale_limit, rand_params = RAND_PARAMS):
    """Scale the nominal parameters by symmetric log-uniform multipliers.

    Every element of each selected array gets its own multiplier
    ``base ** e`` with ``e ~ U(-log_scale_limit, log_scale_limit)``. The base
    is 1.3 for ``dof_damping`` and 1.5 for the other parameters.

    Args:
        init_params: dict of nominal numpy arrays (e.g. from get_init_params).
        log_scale_limit: half-width of the symmetric exponent interval.
        rand_params: names of the parameters to randomize.

    Returns:
        dict mapping each selected name to its perturbed array.
    """
    # (name, base) pairs in the fixed order in which the RNG is consumed.
    schedule = (
        ('body_mass', 1.5),
        ('body_inertia', 1.5),
        ('dof_damping', 1.3),
        ('geom_friction', 1.5),
    )
    sampled = {}
    for name, base in schedule:
        if name not in rand_params:
            continue
        nominal = init_params[name]
        exponent = np.random.uniform(-log_scale_limit, log_scale_limit, size=nominal.shape)
        sampled[name] = np.multiply(nominal, np.array(base) ** exponent)
    return sampled

def get_random_params_target(init_params,log_scale_limit, rand_params = RAND_PARAMS):
    """Scale the nominal parameters by multipliers with exponent magnitude in [1, 2).

    For every element a random sign ``s`` in {-1, +1} and a magnitude
    ``m ~ U(1, 2)`` are drawn, and the element is scaled by ``base ** (m * s)``.
    The base is 1.3 for ``dof_damping`` and 1.5 otherwise.

    NOTE(review): ``log_scale_limit`` is not used by this function; the
    magnitude range [1, 2) is hard-coded.

    Args:
        init_params: dict of nominal numpy arrays (e.g. from get_init_params).
        log_scale_limit: unused.
        rand_params: names of the parameters to randomize.

    Returns:
        dict mapping each selected name to its perturbed array.
    """
    def shifted(nominal, base):
        # Sign first, then magnitude: the two draws happen in this order.
        sign = np.random.randint(0, 2, size=nominal.shape) * 2 - 1
        magnitude = np.random.uniform(1, 2, size=nominal.shape)
        return np.multiply(nominal, np.array(base) ** (magnitude * sign))

    target = {}
    if 'body_mass' in rand_params:
        target['body_mass'] = shifted(init_params['body_mass'], 1.5)
    if 'body_inertia' in rand_params:
        target['body_inertia'] = shifted(init_params['body_inertia'], 1.5)
    if 'dof_damping' in rand_params:
        target['dof_damping'] = shifted(init_params['dof_damping'], 1.3)
    if 'geom_friction' in rand_params:
        target['geom_friction'] = shifted(init_params['geom_friction'], 1.5)
    return target

def get_random_params2(init_params,log_scale_limit, rand_params = RAND_PARAMS):
    """Scale the nominal parameters by asymmetric log-uniform multipliers.

    Each element is scaled by ``base ** e`` with
    ``e ~ U(-log_scale_limit[0], log_scale_limit[1])``. The base is 1.3 for
    ``dof_damping`` and 1.5 otherwise.

    Args:
        init_params: dict of nominal numpy arrays (e.g. from get_init_params).
        log_scale_limit: pair whose first entry is negated to form the lower
            exponent bound and whose second entry is the upper bound.
        rand_params: names of the parameters to randomize.

    Returns:
        dict mapping each selected name to its perturbed array.
    """
    def scale(values, base):
        # Limits are indexed lazily, only when a parameter is actually sampled.
        low, high = -log_scale_limit[0], log_scale_limit[1]
        return np.multiply(values, np.array(base) ** np.random.uniform(low, high, size=values.shape))

    # Fixed order so the RNG is consumed identically from call to call.
    schedule = (
        ('body_mass', 1.5),
        ('body_inertia', 1.5),
        ('dof_damping', 1.3),
        ('geom_friction', 1.5),
    )
    return {
        name: scale(init_params[name], base)
        for name, base in schedule
        if name in rand_params
    }

def get_random_params3(init_params,log_scale_limit, rand_params = RAND_PARAMS):
    """Scale the nominal parameters with exponents drawn away from zero.

    With ``log_scale_limit = (a, b)`` and ``0 <= a < b``, each element's
    exponent is drawn uniformly from [-b, -a) or [a, b), each side with
    probability 1/2. The element is then scaled by ``base ** exponent``,
    with base 1.3 for ``dof_damping`` and 1.5 otherwise.

    Args:
        init_params: dict of nominal numpy arrays (e.g. from get_init_params).
        log_scale_limit: pair ``(a, b)`` with ``0 <= a < b``.
        rand_params: names of the parameters to randomize.

    Returns:
        dict mapping each selected name to its perturbed array.

    Raises:
        AssertionError: if the limits violate ``0 <= a < b`` (only when
            assertions are enabled).
    """
    assert 0 <= log_scale_limit[0] < log_scale_limit[1]

    def to_exponent(u):
        # u in [0, 1) maps linearly onto [-b, -a); u in [1, 2) onto [a, b).
        width = log_scale_limit[1] - log_scale_limit[0]
        return np.where(u < 1, width * u - log_scale_limit[1], width * (u - 1) + log_scale_limit[0])

    bases = (
        ('body_mass', 1.5),
        ('body_inertia', 1.5),
        ('dof_damping', 1.3),
        ('geom_friction', 1.5),
    )
    sampled = {}
    for key, base in bases:
        if key not in rand_params:
            continue
        nominal = init_params[key]
        draws = np.random.uniform(0, 2, size=nominal.shape)
        sampled[key] = np.multiply(nominal, np.array(base) ** to_exponent(draws))
    return sampled


def _mixture_exponents(shape, log_scale_limit, prob):
    """Draw per-element exponents from a two-component uniform mixture.

    With probability ``prob`` an element comes from
    U(log_scale_limit[0], log_scale_limit[1]); otherwise from
    U(log_scale_limit[2], log_scale_limit[3]).
    """
    component = np.random.choice(2, shape, p=[prob, 1 - prob])
    # Both candidates are always drawn, so the amount of RNG consumed does not
    # depend on which component each element picked.
    first = np.random.uniform(log_scale_limit[0], log_scale_limit[1], size=shape)
    second = np.random.uniform(log_scale_limit[2], log_scale_limit[3], size=shape)
    return np.where(component == 0, first, second)


def get_random_params4(init_params,log_scale_limit, p, rand_params = RAND_PARAMS):
    """Scale the nominal parameters with exponents from a two-interval mixture.

    Each element is scaled by ``1.5 ** e``. With probability ``p``, ``e`` is
    drawn from U(log_scale_limit[0], log_scale_limit[1]); otherwise it is
    drawn from U(log_scale_limit[2], log_scale_limit[3]).

    Args:
        init_params: dict of nominal numpy arrays (e.g. from get_init_params).
        log_scale_limit: four numbers ``(lo1, hi1, lo2, hi2)`` bounding the
            two exponent intervals.
        p: probability of drawing from the first interval.
        rand_params: names of the parameters to randomize.

    Returns:
        dict mapping each selected name to its perturbed array.

    Raises:
        ValueError: from ``np.random.choice`` if ``p`` is not a valid
            probability.
    """
    # NOTE(review): dof_damping uses base 1.5 here, whereas the other samplers
    # in this module use 1.3 for dof_damping. Kept unchanged so previously
    # generated task sets stay reproducible; confirm whether this is intended.
    base = np.array(1.5)
    new_params = {}
    for name in ('body_mass', 'body_inertia', 'dof_damping', 'geom_friction'):
        if name in rand_params:
            nominal = init_params[name]
            exponents = _mixture_exponents(nominal.shape, log_scale_limit, p)
            new_params[name] = np.multiply(nominal, base ** exponents)
    return new_params


# Example usage (kept for reference): sample 20 randomized parameter sets for
# Hopper-v3 with get_random_params4 and save them as JSON to hopper.json.
# import gym
# import json
# env = gym.make("Hopper-v3")
# init_params = get_init_params(env)
# random_params_list = []
# for i in range(20):
#     random_param = get_random_params4(init_params,log_scale_limit=[-0.5,0.5,0.5,1],p=0.7)
#     for k,v in random_param.items():
#         random_param[k] = v.tolist()
#     random_params_list.append(random_param)
# with open("hopper.json",'w',encoding='utf-8') as f:
#     json.dump(random_params_list,f)