import torch
import numpy as np
from collections import deque
from torch.nn.functional import smooth_l1_loss
from torch.nn import Module, Linear, BatchNorm1d, Sequential, ReLU
from torch.distributions import Categorical
import random
import sys
sys.path.append('../LNet/')
from LMapNet_final import *
ROOM_SIZE_METER = 12
import math

# todo ####################################################################
class LnetsplitRNN(Module):
    """Run one step of the pretrained localisation network and pack a policy state.

    Wraps ``LMapNet_withPoseNet_uncertainty`` (star-imported from
    ``LMapNet_final``) and calls its sub-modules one by one, so the recurrent
    hidden/cell tensors are supplied by the caller and handed back after
    every step instead of being kept inside the network.
    """

    def __init__(self, device, hidden_size):
        """Build the wrapped localisation network.

        Args:
            device: stored on the instance and forwarded to the wrapped network.
            hidden_size: forwarded to the wrapped network as ``hidden_size``.
        """
        super(LnetsplitRNN, self).__init__()
        self.device = device
        self.Lnet = LMapNet_withPoseNet_uncertainty(hidden_size=hidden_size, device=device)

    def forward(self, obs, objs_vector, target_pos, backpack, \
                hidden_state_agent, cell_state_agent, hidden_state_objs_1, cell_state_objs_1, \
                hidden_state_objs_2, cell_state_objs_2, hidden_state_objs_3, cell_state_objs_3):
        """Estimate agent/object poses for one step and build the policy state.

        Args:
            obs: image observations; dim 0 is read as the batch size, dim 1 as
                the sequence length, and the tensor is viewed as
                ``(-1, 3, 80, 80)``.
            objs_vector: per-object descriptors; expanded along a new
                sequence axis and reshaped to ``(-1, 3, 10)``.
            target_pos: tensor indexed as ``[:, i, :]`` for object ``i``; its
                first two components are copied into the state unchanged.
            backpack: tensor appended as the final entry of the state.
            hidden_state_agent, cell_state_agent: recurrent state of the
                agent stream.
            hidden_state_objs_1 .. cell_state_objs_3: recurrent states of the
                three object streams. All three streams go through the same
                ``rnn_objs`` / ``MLP_objs_mean`` / ``MLP_objs_var`` modules.

        Returns:
            ``(state, recurrent)``. ``state`` is ``np.array`` over a list
            that holds, for each of the 3 objects, seven entries (normalised
            mean x, normalised mean y, covariance entries [0,0], [0,1],
            [1,1], target x, target y), then the three predicted agent pose
            components, then ``backpack``. ``recurrent`` is the tuple of the
            eight updated hidden/cell tensors in the same order as the
            arguments. The predicted agent pose is also stored on
            ``self.predicted_agent_pose``.

        NOTE(review): the recurrent modules are fed inputs with a length-1
        axis inserted at dim 1 while PoseNet runs on ``B * seq_len`` rows, so
        this presumably expects ``seq_len == 1`` -- confirm with the caller.
        """
        seq_len = obs.shape[1]
        B_size = obs.shape[0]
        input_obs = obs.view(-1, 3, 80, 80)  ## B*L,3,80,80

        # Repeat the object descriptors for every time step of the sequence.
        objs_vector = objs_vector.unsqueeze(1)
        objs_vector = objs_vector.expand(-1, seq_len, -1, -1)
        objs_vector = objs_vector.reshape(-1, 3, 10)

        agent_pose, obj_poses, obj_binarys, obj_var = self.Lnet.LMapNet.PoseNet(input_obs, objs_vector) ## agent_pose B*L*3, obj_poses B*L*2   binary B*L*1

        features_agent = self.Lnet.LMapNet.PoseNet.posenet.resnet18_agent(input_obs)
        features_objs = self.Lnet.LMapNet.PoseNet.posenet.resnet18_objs(input_obs)

        agent_input = torch.cat((agent_pose, features_agent), dim=1) ## B, 3+128
        ###########  preprocess and process data of agent
        output_agent, (hidden_state_agent,cell_state_agent) = self.Lnet.LMapNet.rnn_agent(agent_input.unsqueeze(1),(hidden_state_agent,cell_state_agent)) ### output shape = (B,1,hidden size)
        predict_agent_pose_RNN = self.Lnet.LMapNet.MLP_agent_mean(output_agent).squeeze(1) # B, 3

        ###########  preprocess and process data of objs
        obj_global_poses_input = []

        # Rotate each local object offset by the predicted agent heading
        # (pose component 2) and translate by the predicted agent position.
        for objs_local_pose in obj_poses:
            objs_global_x = objs_local_pose[:,0:1] * torch.cos(predict_agent_pose_RNN[:,2:3])\
                                - objs_local_pose[:,1:2]* torch.sin(predict_agent_pose_RNN[:,2:3]) + predict_agent_pose_RNN[:,0:1]

            objs_global_y = objs_local_pose[:,1:2] * torch.cos(predict_agent_pose_RNN[:,2:3])\
                                + objs_local_pose[:,0:1]* torch.sin(predict_agent_pose_RNN[:,2:3]) + predict_agent_pose_RNN[:,1:2]

            obj_global_poses_input.append(torch.cat((objs_global_x,objs_global_y), dim=1))

        obj_var = torch.cat([obj_var[i].unsqueeze(1) for i in range(len(obj_var))], dim = 1)
        # The three raw outputs fill an upper-triangular 2x2 factor U with an
        # exponentiated diagonal; inv(U^T U) is then used as the 2x2 matrix.
        # Allocate on obj_var's device so the matmul/inverse stay on the same
        # device as the network instead of bouncing through the CPU.
        covar_matrix = torch.zeros((B_size, 3, 2, 2), device=obj_var.device)
        covar_matrix[:, :,0,0] = torch.exp(obj_var[:, :, 0])
        covar_matrix[:, :,0,1] = obj_var[:, :, 1]
        covar_matrix[:, :,1,1] = torch.exp(obj_var[:, :, 2])

        covar_matrix = torch.linalg.inv(torch.matmul(covar_matrix.permute(0, 1, 3,2),covar_matrix))

        # 2x2 matrix [[cos, sin], [-sin, cos]] of the predicted heading,
        # repeated for the three objects.
        rotate_matrix = torch.cat([torch.cos(predict_agent_pose_RNN[:,2]).unsqueeze(1), torch.sin(predict_agent_pose_RNN[:,2]).unsqueeze(1),
                                    -torch.sin(predict_agent_pose_RNN[:,2]).unsqueeze(1), torch.cos(predict_agent_pose_RNN[:,2]).unsqueeze(1)], 
                                    dim = 1).view(-1, 2, 2).unsqueeze(1).repeat(1,3,1,1)

        # Express the local matrix in the global frame: R^T * C * R.
        covar_matrix = torch.matmul(torch.matmul(rotate_matrix.permute(0,1,3,2),covar_matrix),rotate_matrix)
        obj_var = torch.cat([covar_matrix[:, :,0,0].unsqueeze(-1),covar_matrix[:, :,0,1].unsqueeze(-1),covar_matrix[:, :,1,1].unsqueeze(-1)], dim = 2).permute(1,0,2)

        obj_global_poses_input = torch.cat(obj_global_poses_input,dim=1) ## B,6  
        obj_binarys_input = torch.cat(obj_binarys,dim=1)  ## B,3
        obj_binarys_input = self.Lnet.LMapNet.sigmoid_net(obj_binarys_input)

        # One input row per object: global xy, its three matrix entries, its
        # sigmoid-ed binary output, the predicted agent pose and the shared
        # object features.
        input_data_objs_1 = torch.cat((obj_global_poses_input[:,0:2], obj_var[0], obj_binarys_input[:,0:1], predict_agent_pose_RNN, features_objs),dim=1).unsqueeze(1) ## B,1,2+1+3+256
        input_data_objs_2 = torch.cat((obj_global_poses_input[:,2:4], obj_var[1], obj_binarys_input[:,1:2], predict_agent_pose_RNN, features_objs),dim=1).unsqueeze(1) ## B,1,2+1+3+256
        input_data_objs_3 = torch.cat((obj_global_poses_input[:,4:6], obj_var[2], obj_binarys_input[:,2:3], predict_agent_pose_RNN, features_objs),dim=1).unsqueeze(1) ## B,1,2+1+3+256

        output_objs_1, (hidden_state_objs_1, cell_state_objs_1) = self.Lnet.LMapNet.rnn_objs(input_data_objs_1,(hidden_state_objs_1, cell_state_objs_1)) ### output shape = (B,1,hidden size)
        object_pose_predict_1 = self.Lnet.LMapNet.MLP_objs_mean(output_objs_1) # B,1, 2
        object_var_predict_1 = self.Lnet.MLP_objs_var(output_objs_1)

        output_objs_2, (hidden_state_objs_2, cell_state_objs_2) = self.Lnet.LMapNet.rnn_objs(input_data_objs_2,(hidden_state_objs_2, cell_state_objs_2)) ### output shape = (B,1,hidden size)
        object_pose_predict_2 = self.Lnet.LMapNet.MLP_objs_mean(output_objs_2) # B, 1, 2
        object_var_predict_2 = self.Lnet.MLP_objs_var(output_objs_2)

        output_objs_3, (hidden_state_objs_3, cell_state_objs_3) = self.Lnet.LMapNet.rnn_objs(input_data_objs_3,(hidden_state_objs_3, cell_state_objs_3)) ### output shape = (B,1,hidden size)
        object_pose_predict_3 = self.Lnet.LMapNet.MLP_objs_mean(output_objs_3) # B, 1, 2
        object_var_predict_3 = self.Lnet.MLP_objs_var(output_objs_3)

        object_pose_predict_all = torch.cat((object_pose_predict_1,object_pose_predict_2,object_pose_predict_3),dim=1) ## next obj pos mean  (B, 3, 2)    
        object_var_predict_all = torch.cat((object_var_predict_1,object_var_predict_2,object_var_predict_3),dim=1) ## next obj pos variance  (B, 3, 3)    

        ###########  detach from torch to np array 

        object_globalpose_mean_predict_all = object_pose_predict_all.detach().cpu().numpy()  ## next obj pos mean  (B, 3, 2)    
        object_globalpose_variance_predict_all = object_var_predict_all.detach().cpu().numpy() ## next obj pos variance  (B, 3, 3) 

        # Same upper-triangular parameterisation as above, now in NumPy for
        # the recurrent head's output.
        covar_matrix = np.zeros((B_size, 3, 2, 2))
        covar_matrix[:, :, 0,0] = np.exp(object_globalpose_variance_predict_all[:, :,0])
        covar_matrix[:, :, 0,1] = object_globalpose_variance_predict_all[:, :, 1]
        covar_matrix[:, :, 1,1] = np.exp(object_globalpose_variance_predict_all[:, :, 2])
        covar_matrix_final = np.linalg.inv(np.matmul(np.transpose(covar_matrix,(0, 1, 3,2)),covar_matrix))

        ######  processing state for policy net

        self.predicted_agent_pose = predict_agent_pose_RNN.detach().cpu().numpy()  ## B, 3

        state = []
        target_pos = target_pos.detach().cpu().numpy()
        backpack = backpack.detach().cpu().numpy()

        for i in range(3):
            init_pos = target_pos[:,i,:]
            # Object means are rescaled by half the room size and shifted by -1;
            # matrix entries and target positions are stored as they are.
            state.extend([object_globalpose_mean_predict_all[:, i, 0] / (ROOM_SIZE_METER / 2) - 1,
                          object_globalpose_mean_predict_all[:, i, 1] / (ROOM_SIZE_METER / 2) - 1,
                          covar_matrix_final[:,i,0,0], covar_matrix_final[:,i, 0,1], covar_matrix_final[:,i,1,1], init_pos[:, 0], init_pos[:, 1]])
        state.extend([self.predicted_agent_pose[:, 0], self.predicted_agent_pose[:, 1], self.predicted_agent_pose[:, 2]])
        state.append(backpack)
        state = np.array(state)
        return state, (hidden_state_agent, cell_state_agent, hidden_state_objs_1, cell_state_objs_1, \
                       hidden_state_objs_2, cell_state_objs_2, hidden_state_objs_3, cell_state_objs_3)

class PNet_PPO_GTPos(Module):
    """MLP that maps a flat state vector to a categorical action distribution.

    The network is Linear/ReLU pairs of widths 256, 128, 64 and 32 followed by
    a final Linear layer producing ``out_size`` logits.
    """

    def __init__(self, device, input_size, out_size=4):
        """Create the layer stack.

        Args:
            device: device the input array is moved to in ``forward``.
            input_size: number of features per state vector.
            out_size: number of actions (logits) produced.
        """
        super().__init__()
        self.input_size = input_size
        self.device = device
        widths = (self.input_size, 256, 128, 64, 32)
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.append(Linear(fan_in, fan_out))
            stack.append(ReLU())
        stack.append(Linear(widths[-1], out_size))
        self.model = Sequential(*stack)

    def forward(self, states, batch_size):
        """Return softmax action probabilities for a NumPy batch of states.

        Args:
            states: NumPy array holding ``batch_size * input_size`` values.
            batch_size: number of state vectors contained in ``states``.

        Returns:
            Probabilities with every size-1 dimension squeezed away, so a
            single state yields a 1-D tensor.
        """
        batch = torch.from_numpy(states).float().to(self.device)
        logits = self.model(batch.view(batch_size, self.input_size))
        probs = torch.nn.functional.softmax(logits, dim=1)
        return probs.squeeze()

class value_net(Module):
    """MLP critic that maps a flat state vector to a single scalar value.

    The network is Linear/ReLU pairs of widths 64, 64 and 32 followed by a
    final Linear layer with one output.
    """

    def __init__(self, device, input_size):
        """Create the layer stack.

        Args:
            device: device the input array is moved to in ``forward``.
            input_size: number of features per state vector.
        """
        super().__init__()
        self.input_size = input_size
        self.device = device
        widths = (self.input_size, 64, 64, 32)
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.append(Linear(fan_in, fan_out))
            stack.append(ReLU())
        stack.append(Linear(widths[-1], 1))
        self.model = Sequential(*stack)

    def forward(self, states, batch_size):
        """Return the value estimate for a NumPy batch of states.

        Args:
            states: NumPy array holding ``batch_size * input_size`` values.
            batch_size: number of state vectors contained in ``states``.

        Returns:
            Value estimates with every size-1 dimension squeezed away, so a
            single state yields a 0-dim tensor.
        """
        batch = torch.from_numpy(states).float().to(self.device)
        estimates = self.model(batch.view(batch_size, self.input_size))
        return estimates.squeeze()

class Agent_PNet_PPO():
    def __init__(self, device, input_size=17, lr=0.0001, batch_size=1000,
                 discount=0.90, clip=0.1, test=False, test_model="",
                 num_envs=1, num_test_envs=1, LMapNet_pretrain_model='', pick_model_path='', drop_model_path=''):
        """Build the frozen sub-models and the trainable target-selection policy.

        Args:
            device: torch device used for every network and tensor.
            input_size: feature count of the decision state fed to the
                policy/target/value networks.
                NOTE(review): ``convert_decision_state`` emits 3 * 7 + 4 = 25
                values, so the default of 17 looks stale -- confirm that
                callers pass the matching size.
            lr: Adam learning rate for the policy; the value net uses ``lr / 2``.
            batch_size: stored on the instance; not used in this constructor.
            discount: reward discount used by ``compute_rtgs``.
            clip: PPO ratio clipping range.
            test: when true, load ``test_model`` into the target network and
                skip optimizer creation.
            test_model: checkpoint path used only when ``test`` is true.
            num_envs: number of parallel training environments.
            num_test_envs: number of parallel evaluation environments.
            LMapNet_pretrain_model: checkpoint path of the localisation network.
            pick_model_path: checkpoint path of the frozen pick policy.
            drop_model_path: checkpoint path of the frozen drop policy.
        """
        self.input_size = input_size
        self.test = test
        self.device = device
        self.batch_size = batch_size
        self.hidden_size = 512
        self.NUM_OBJECTS = 3
        self.ROOM_SIZE_METER = 12

        # map_location keeps checkpoints saved on another device (e.g. a GPU)
        # loadable when this agent is created on a different one.
        self.LoclizationNet = LnetsplitRNN(self.device, self.hidden_size).to(self.device)
        self.LoclizationNet.Lnet.load_state_dict(
            torch.load(LMapNet_pretrain_model, map_location=self.device))
        self.LoclizationNet.eval()
        self.num_envs = num_envs

        # Frozen low-level controllers; both take 19 input features.
        self.pick_model = PNet_PPO_GTPos(self.device, 19).to(self.device)
        self.pick_model.load_state_dict(torch.load(pick_model_path, map_location=self.device))
        self.pick_model.eval()
        self.drop_model = PNet_PPO_GTPos(self.device, 19).to(self.device)
        self.drop_model.load_state_dict(torch.load(drop_model_path, map_location=self.device))
        self.drop_model.eval()

        # High-level policy that chooses which object to go for.
        self.policy_net = PNet_PPO_GTPos(self.device, self.input_size, out_size=self.NUM_OBJECTS).to(self.device)
        self.target_net = PNet_PPO_GTPos(self.device, self.input_size, out_size=self.NUM_OBJECTS).to(self.device)
        self.value_net = value_net(self.device, self.input_size).to(self.device)
        if test:
            self.target_net.load_state_dict(torch.load(test_model, map_location=self.device))
            self.target_net.eval()
        else:
            self.lr = lr
            self.clip = clip
            self.discount = discount
            self.policy_net.train()
            self.target_net.eval()
            self.value_net.train()
            self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=self.lr)
            self.value_optimizer = torch.optim.Adam(self.value_net.parameters(), lr=self.lr/2)

        # One independent (initially empty) slot per environment. The previous
        # ``[[]] * n`` form made every slot an alias of a single list.
        self.states = deque([] for _ in range(self.num_envs))
        self.actions = deque([] for _ in range(self.num_envs))
        self.rewards = deque([] for _ in range(self.num_envs))
        self.rtgs = deque([] for _ in range(self.num_envs))
        self.advantages = deque([])
        self.actor_prob = deque([] for _ in range(self.num_envs))
        self.num_test_envs = num_test_envs

    def init_hidden_test(self):
        self.hidden_batch_agent_test, self.cell_batch_agent_test = self.LoclizationNet.Lnet.init_hidden_states(
            self.num_test_envs)
        self.hidden_batch_objs_1_test, self.cell_batch_objs_1_test = self.LoclizationNet.Lnet.init_hidden_states(
            self.num_test_envs)
        self.hidden_batch_objs_2_test, self.cell_batch_objs_2_test = self.LoclizationNet.Lnet.init_hidden_states(
            self.num_test_envs)
        self.hidden_batch_objs_3_test, self.cell_batch_objs_3_test = self.LoclizationNet.Lnet.init_hidden_states(
            self.num_test_envs)

    def reset_agent(self):
        self.targets = -1 * np.ones(self.num_envs, dtype=int)
        self.decision_states = [[] for _ in range(self.num_envs)]
        self.decision_action_prob = [[] for _ in range(self.num_envs)]
        self.decision_actions = [[] for _ in range(self.num_envs)]
        self.decision_rewards = [[] for _ in range(self.num_envs)]
        self.cumulative_rewards = np.zeros(self.num_envs)
        self.hidden_batch_agent, self.cell_batch_agent = self.LoclizationNet.Lnet.init_hidden_states(self.num_envs)
        self.hidden_batch_objs_1, self.cell_batch_objs_1 = self.LoclizationNet.Lnet.init_hidden_states(
            self.num_envs)
        self.hidden_batch_objs_2, self.cell_batch_objs_2 = self.LoclizationNet.Lnet.init_hidden_states(
            self.num_envs)
        self.hidden_batch_objs_3, self.cell_batch_objs_3 = self.LoclizationNet.Lnet.init_hidden_states(
            self.num_envs)

    def reset_per_memory(self, idx):
        self.targets[idx] = -1
        self.decision_states[idx] = []
        self.decision_action_prob[idx] = [] 
        self.decision_actions[idx] = []
        self.decision_rewards[idx] = []
        self.cumulative_rewards[idx] = 0
        self.hidden_batch_agent[:, idx:idx + 1, :] = torch.zeros(1, 1, self.hidden_size).float().to(self.device)
        self.cell_batch_agent[:, idx:idx + 1, :] = torch.zeros(1, 1, self.hidden_size).float().to(self.device)
        self.hidden_batch_objs_1[:, idx:idx + 1, :] = torch.zeros(1, 1, self.hidden_size).float().to(self.device)
        self.cell_batch_objs_1[:, idx:idx + 1, :] = torch.zeros(1, 1, self.hidden_size).float().to(self.device)
        self.hidden_batch_objs_2[:, idx:idx + 1, :] = torch.zeros(1, 1, self.hidden_size).float().to(self.device)
        self.cell_batch_objs_2[:, idx:idx + 1, :] = torch.zeros(1, 1, self.hidden_size).float().to(self.device)
        self.hidden_batch_objs_3[:, idx:idx + 1, :] = torch.zeros(1, 1, self.hidden_size).float().to(self.device)
        self.cell_batch_objs_3[:, idx:idx + 1, :] = torch.zeros(1, 1, self.hidden_size).float().to(self.device)
    def init_test(self):
        self.test_targets = -1 * np.ones(self.num_test_envs)
    def compute_rtgs(self, idx):
        if len(self.rewards[idx]) == 1:
            if self.rtgs[idx]:
                self.rtgs[idx].append(self.rewards[idx].popleft())
            else:
                self.rtgs[idx] = deque([self.rewards[idx].popleft()])
            return 0
        for i in range(-2, -1-len(self.rewards[idx]), -1):
            self.rewards[idx][i] += self.discount * self.rewards[idx][i+1]
        while self.rewards[idx]:
            if self.rtgs[idx]:
                self.rtgs[idx].append(self.rewards[idx].popleft())
            else:
                self.rtgs[idx] = deque([self.rewards[idx].popleft()])
    def clear_replay_memory(self):
        self.targets = -1 * np.ones(self.num_envs, dtype=int)
        self.decision_states = [[] for _ in range(self.num_envs)]
        self.decision_action_prob = [[] for _ in range(self.num_envs)]
        self.decision_actions = [[] for _ in range(self.num_envs)]
        self.decision_rewards = [[] for _ in range(self.num_envs)]
        self.cumulative_rewards = np.zeros(self.num_envs)
        self.states = deque([[]]*self.num_envs)
        self.actions = deque([[]]*self.num_envs)
        self.rewards = deque([[]]*self.num_envs)
        self.advantages = deque([])
        self.actor_prob = deque([[]]*self.num_envs)
        self.rtgs = deque([[]]*self.num_envs)

    def store_replay_memory_MP(self, rewards, dones):
        self.cumulative_rewards += rewards
        total_num_new_records = 0
        for i, d in enumerate(dones):
            total_num_new_records += self.store_replay_memory(i, d)

        return total_num_new_records

    def store_replay_memory(self, idx, done):
        # state is a deque like [[world1 states], [world2 states], ..., [world5 states]]
        # action is a deque like [[world1 actions], [world2 actions], ..., [world5 actions]]
        num_new_records = 0
        if done:
            if self.decision_rewards[idx]:
                self.decision_rewards[idx].pop(0)
                self.decision_rewards[idx].append(self.cumulative_rewards[idx])
            num_new_records += len(self.decision_actions[idx])
            for i in range(len(self.decision_states[idx])):
                if self.states[idx]:
                    self.states[idx].append(self.decision_states[idx][i])
                else:
                    self.states[idx] = deque([self.decision_states[idx][i]])

                if self.actions[idx]:
                    self.actions[idx].append(self.decision_actions[idx][i])
                else:
                    self.actions[idx] = deque([self.decision_actions[idx][i]])
                if self.rewards[idx]:
                    self.rewards[idx].append(self.decision_rewards[idx][i])
                else:
                    self.rewards[idx] = deque([self.decision_rewards[idx][i]])
                if self.actor_prob[idx]:
                    self.actor_prob[idx].append(self.decision_action_prob[idx][i])
                else:
                    self.actor_prob[idx] = deque([self.decision_action_prob[idx][i]])
            self.reset_per_memory(idx)
        return num_new_records

    def state_estimator(self, states, test=False):
        obs_list = []
        color_vector_list = []
        backpack_list = []
        target_pos_list = [] 
        for item in states:
            current_obs = item[0]
            current_objs_type_and_color_vector = item[1]
            current_obs = current_obs.reshape(3,80,80)
            current_obs = 2*current_obs/255-1
            current_obs_torch = torch.from_numpy(current_obs).float().unsqueeze(0).unsqueeze(0).to(self.device)
            current_objs_type_and_color_vector_torch = torch.from_numpy(current_objs_type_and_color_vector).float().unsqueeze(0).to(self.device)
            current_backpack = torch.from_numpy(item[2]).int().to(self.device)
            current_target_pos = torch.from_numpy(item[3]).float().unsqueeze(0).to(self.device)
            obs_list.append(current_obs_torch)
            color_vector_list.append(current_objs_type_and_color_vector_torch)
            backpack_list.append(current_backpack)
            target_pos_list.append(current_target_pos)
        obs_list= torch.cat(obs_list, dim=0) ### num_env, 1 ,3, 80, 80
        color_vector_list= torch.cat(color_vector_list, dim=0) ### num_env, 3, 10
        backpack_list = torch.cat(backpack_list, dim=0) ### num_env, 1
        target_pos_list = torch.cat(target_pos_list, dim=0) ### num_env,3, 2

        if test:
            with torch.no_grad():
                est_state, hidden_ = self.LoclizationNet(obs_list, color_vector_list, target_pos_list, backpack_list, \
                                                self.hidden_batch_agent_test, self.cell_batch_agent_test, self.hidden_batch_objs_1_test, self.cell_batch_objs_1_test, \
                                                self.hidden_batch_objs_2_test, self.cell_batch_objs_2_test, self.hidden_batch_objs_3_test, self.cell_batch_objs_3_test)
                self.hidden_batch_agent_test, self.cell_batch_agent_test, self.hidden_batch_objs_1_test, self.cell_batch_objs_1_test, \
                self.hidden_batch_objs_2_test, self.cell_batch_objs_2_test, self.hidden_batch_objs_3_test, self.cell_batch_objs_3_test  = hidden_

        else:
            with torch.no_grad():
                est_state, hidden_ = self.LoclizationNet(obs_list, color_vector_list, target_pos_list, backpack_list,\
                                                self.hidden_batch_agent, self.cell_batch_agent, self.hidden_batch_objs_1, self.cell_batch_objs_1, \
                                                self.hidden_batch_objs_2, self.cell_batch_objs_2, self.hidden_batch_objs_3, self.cell_batch_objs_3)
                self.hidden_batch_agent, self.cell_batch_agent, self.hidden_batch_objs_1, self.cell_batch_objs_1, \
                self.hidden_batch_objs_2, self.cell_batch_objs_2, self.hidden_batch_objs_3, self.cell_batch_objs_3  = hidden_
        return est_state

    def step_MP(self, states, test=False, exploration_rate=0.0):
        actions = []
        for idx in range(states.shape[1]):
            actions.append(self.step(states[:, idx], idx, test, exploration_rate))
        return np.array(actions, dtype=int)

    def step(self, state, idx, test=False, exploration_rate=0.0):
        # this test is during training, to not erase training history like self.targets, self.decision_states()
        # different from test in __init__() in which test is for after training
        if state[-1] != -1:  # is carrying an object need to check if last one is backpack
            state = self.convert_drop_state(state, state[-1])
            actions_prob = self.drop_model(state, batch_size=1)
            distribution = Categorical(actions_prob)
            action = distribution.sample().item()
            if action == 3:
                action = 4
            return action

        if test:
            target = self.test_targets[idx]
        else:
            target = self.targets[idx]

        if target == -1:  # not carrying anything and not picking, need to decide the target object
            decision_state = self.convert_decision_state(state)
            with torch.no_grad():
                actions_prob = self.target_net(decision_state, batch_size=1)
            distribution = Categorical(actions_prob)
            target = distribution.sample().item()
            target = np.array(target, dtype=int)
            
            if test:
                self.test_targets[idx] = target
            else:
                self.targets[idx] = target
                self.decision_action_prob[idx].append(actions_prob[target].detach())
                self.decision_states[idx].append(decision_state)
                self.decision_actions[idx].append(self.targets[idx])
                self.decision_rewards[idx].append(self.cumulative_rewards[idx])

        if test:
            target = self.test_targets[idx]
        else:
            target = self.targets[idx]

        state = self.convert_pick_state(state, target)
        actions_prob = self.pick_model(state, batch_size=1)
        distribution = Categorical(actions_prob)
        action = distribution.sample().item()
        if action == 3:
            if test:
                self.test_targets[idx] = -1
            else:
                self.targets[idx] = -1
        return action

    def clear_over_step(self):
        # clear incomplete trajectories
        states = deque([])
        actions = deque([])
        actor_prob = deque([])
        rtgs = deque([])
        # print("***************************")
        # for i in range(self.num_envs):
        #     print(len(self.states[i]))
        #     print(len(self.actor_prob[i]))
        # print("***************************")
        for i in range(self.num_envs):
            for _ in range(len(self.states[i])-len(self.rtgs[i])):
                self.states[i].pop()
                self.actions[i].pop()
                self.actor_prob[i].pop()
            states += self.states[i]
            actions += self.actions[i]
            actor_prob += self.actor_prob[i]
            rtgs += self.rtgs[i]
        # merge the data
        self.states = np.array(states)
        self.actions = actions
        self.actor_prob = torch.tensor(actor_prob).to(self.device).float()
        self.rtgs = rtgs
        self.rtgs = torch.tensor(self.rtgs).to(self.device).float()
    def get_adv_func(self):
        """Finalize the collected rollout by delegating to ``clear_over_step``.

        Despite the name, no advantage is computed here; ``learn`` derives the
        advantages from ``self.rtgs`` and the value network.
        """
        self.clear_over_step()
    def learn(self):
        with torch.no_grad():
            self.advantages = self.rtgs - self.value_net(self.states, batch_size=len(self.states))
        exp_rtgs = self.value_net(self.states, batch_size=len(self.states))
        current_probs = self.policy_net(self.states, batch_size=len(self.states))
        current_prob = deque()
        for i, a in enumerate(self.actions):
            current_prob.append(current_probs[i][a])
        current_prob = torch.stack(list(current_prob)).float()
        ratios = (current_prob / self.actor_prob)
        surr1 = ratios * self.advantages
        surr2 = torch.clamp(ratios, 1 - self.clip, 1 + self.clip) * self.advantages
        actor_loss = (-torch.min(surr1, surr2)).mean()
        crit_loss = torch.nn.MSELoss()(exp_rtgs, self.rtgs)
        # backward
        self.optimizer.zero_grad()
        actor_loss.backward()
        self.optimizer.step()
        self.value_optimizer.zero_grad()
        crit_loss.backward()
        self.value_optimizer.step()

        return actor_loss.item(), crit_loss.item()

    def update_target(self):
        """Copy the current ``policy_net`` weights into ``target_net``."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def convert_pick_state(self, state, target):
        target = int(target)
        new_state = []
        agent_pos, agent_ori = state[21:23], state[23] ## get the agent pos and ori
        x_ori = math.cos(agent_ori)
        y_ori = math.sin(agent_ori)
        # 3 objects * 2 curr_pos = 6. first is the target object to pick up
        order = list(range(self.NUM_OBJECTS))
        order.remove(target)
        order.insert(0, target)
        for i in order:
            pos_mean = state[i*7:i*7+2]
            pos_var = state[i*7+2:i*7+5] / (self.ROOM_SIZE_METER**2/4)
            new_state.extend(pos_mean)
            new_state.extend(pos_var)
        # 3 agent current position and orientation
        new_state.extend([agent_pos[0] / (self.ROOM_SIZE_METER / 2) - 1,
                          agent_pos[1] / (self.ROOM_SIZE_METER / 2) - 1,
                          x_ori,
                          y_ori])
        new_state = np.array(new_state)
        return new_state
        
    def convert_drop_state(self, state, target):
        target = int(target)
        new_state = []
        agent_pos, agent_ori = state[21:23], state[23] ## get the agent pos and ori
        x_ori = math.cos(agent_ori)
        y_ori = math.sin(agent_ori)
        # agent to target position
        pos = state[target*7+5:target*7+7]

        dist = ((agent_pos[0] - pos[0]) ** 2 + (agent_pos[1] - pos[1]) ** 2) ** 0.5
        new_state.extend([pos[0] / (self.ROOM_SIZE_METER / 2) - 1,
                          pos[1] / (self.ROOM_SIZE_METER / 2) - 1,
                          dist / (self.ROOM_SIZE_METER / 2) - 1])
        direct = ((pos[0] - agent_pos[0]) / dist, (pos[1] - agent_pos[1]) / dist)
        angle_cos = direct[0] * x_ori + direct[1] * y_ori
        angle_sin = direct[0] * y_ori - direct[1] * x_ori
        new_state.extend([angle_cos,
                          angle_sin])
        # agent to other objects
        for i in range(self.NUM_OBJECTS):
            if i == target: continue
            pos = state[i*7:i*7+5]
            new_state.extend(pos)
        # 3 agent current position and orientation
        new_state.extend([agent_pos[0] / (self.ROOM_SIZE_METER / 2) - 1,
                          agent_pos[1] / (self.ROOM_SIZE_METER / 2) - 1,
                          x_ori,
                          y_ori])
        new_state = np.array(new_state)
        return new_state
    def convert_decision_state(self, state):
        new_state = []
        agent_pos, agent_ori = state[21:23], state[23] ## get the agent pos and ori
        x_ori = math.cos(agent_ori)
        y_ori = math.sin(agent_ori)
        for i in range(self.NUM_OBJECTS):
            pos = state[i*7:i*7+7]
            new_state.extend(pos)
        new_state.extend([agent_pos[0] / (self.ROOM_SIZE_METER / 2) - 1,
                          agent_pos[1] / (self.ROOM_SIZE_METER / 2) - 1,
                          x_ori,
                          y_ori])
        new_state = np.array(new_state)
        return new_state
