import torch
from huge.algo import buffer, huge, networks
from huge import envs
# Debug script: build the "ravens_pick_or_place" environment and its goal
# selector, then verify that a pairwise-preference BCE loss back-propagates
# gradients into the goal selector's parameters.
env_params = envs.get_env_params("ravens_pick_or_place")
# Kept in place: `variants` is only brought into scope by this import.
from huge.algo import buffer, variants, networks
env = envs.create_env("ravens_pick_or_place", None, 3, False, None, True, True, -1)

# Overrides applied on top of the default environment parameters.
env_params["continuous_action_space"] = True
env_params["use_horizon"] = False
env_params["fourier"] = True
env_params["fourier_goal_selector"] = True
env_params["normalize"] = False
env_params["buffer_size"] = 100
env_params["goal_selector_name"] = ""
env_params["reward_layers"] = "400,600,900,600,300"

env, policy, goal_selector, replay_buffer, goal_selector_buffer, gcsl_kwargs = variants.get_params(env, env_params)
# Optionally restore a trained goal selector before checking gradients:
# goal_selector.load_state_dict(torch.load("checkpoint/goal_selector_model_14_01_2023_13:25:36.h5"))

goal = env.sample_goal()
state = env.reset()
state = env.observation(state)
# Only entries 9..17 of the sampled goal are fed to the goal selector.
# NOTE(review): presumably this is the goal-relevant slice of the state vector - confirm.
goal = goal[9:18]

# Result intentionally discarded (kept from the original script).
# NOTE(review): looks like a smoke-test forward pass - confirm it is still needed.
goal_selector(torch.Tensor(state), torch.Tensor(goal))

# Second state: a copy of the first with its first entry pushed far away, so
# the goal selector is evaluated on two clearly different inputs.
state2 = state.copy()
state2[0] = -100000

v1 = goal_selector(torch.Tensor(state), torch.Tensor(goal))
v2 = goal_selector(torch.Tensor(state2), torch.Tensor(goal))

# Two-way softmax probability assigned to `state` over `state2`. Subtracting
# the pair's mean from both scores leaves the ratio unchanged; algebraically
# this equals sigmoid(v1 - v2). (A two-column variant would concatenate this
# with the complementary probability for `state2`.)
mean = (v1 + v2) / 2
g1g2 = torch.exp(v1 - mean) / (torch.exp(v1 - mean) + torch.exp(v2 - mean))

loss_fn = torch.nn.BCELoss()
# Target of all ones. `torch.Tensor(1)` would allocate an *uninitialized*
# one-element tensor rather than the value 1.0, so build the label with
# `ones_like`, which also matches the shape, dtype and device of `g1g2`
# instead of assuming CUDA is available.
target = torch.ones_like(g1g2)
loss = loss_fn(g1g2, target)
loss.backward()

# Print every parameter gradient to confirm the loss reached the network.
for param in goal_selector.parameters():
    print(param.grad)
