from utils.data import get_object_type_onehot, add_batch_dim, from_numpy, MotionData
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch
import time

class RuleBasedPlanner:
    """Planner that applies per-vehicle acceleration/steering commands.

    The rule-based planning logic itself is not implemented yet; ``act``
    currently applies the ``'next_acceleration'`` / ``'next_steering'``
    commands supplied in ``vehicle_data_dict``.
    """

    # Width of the per-agent, per-step vector stored in ``states``.
    STATE_DIM = 8
    # Width of the per-agent ``types`` vector (presumably an object-type
    # one-hot encoding -- TODO confirm against the data pipeline).
    NUM_OBJECT_TYPES = 5
    # Per-step action width: acceleration and steering, as applied by ``act``.
    ACTION_DIM = 2

    def __init__(self, cfg):
        """Store the config and cache the sub-configs used by this planner.

        Args:
            cfg: Config object. This class reads ``cfg.datasets.rl_waymo``,
                ``cfg.nocturne.steps`` and, in ``reset``,
                ``cfg.train.model.num_reward_components`` and
                ``cfg.datasets.rl_waymo.goal_dim``.
        """
        self.cfg = cfg
        self.cfg_rl_waymo = self.cfg.datasets.rl_waymo
        self.steps = self.cfg.nocturne.steps

    def reset(self, vehicle_data_dict):
        """Allocate zeroed per-episode buffers and vehicle-ID <-> index maps.

        Args:
            vehicle_data_dict: Mapping keyed by vehicle ID. Only its keys are
                used here; their iteration order defines each agent's row
                index in the buffers.
        """
        num_agents = len(vehicle_data_dict)

        # One row per agent; all buffers except ``types`` also have one slot
        # per simulation step.
        self.states = np.zeros((num_agents, self.steps, self.STATE_DIM))
        self.types = np.zeros((num_agents, self.NUM_OBJECT_TYPES))
        self.actions = np.zeros((num_agents, self.steps, self.ACTION_DIM))  # acceleration and steering
        self.rtgs = np.zeros((num_agents, self.steps, self.cfg.train.model.num_reward_components))
        self.goals = np.zeros((num_agents, self.steps, self.cfg_rl_waymo.goal_dim))
        self.timesteps = np.zeros((num_agents, self.steps, 1))
        # Not populated here; presumably filled elsewhere -- TODO confirm.
        self.relevant_agent_idxs = {}

        self.idx_to_veh_id = dict(enumerate(vehicle_data_dict))
        self.veh_id_to_idx = {veh_id: idx for idx, veh_id in self.idx_to_veh_id.items()}

    # TODO: implement the rule-based planning logic here; ``act`` currently
    # only applies commands taken from ``vehicle_data_dict``.

    def act(self, veh, t, vehicle_data_dict):
        """Apply the next acceleration/steering command to ``veh``.

        Args:
            veh: Simulator vehicle handle. This method calls ``getID()``,
                ``setPosition(x, y)`` and ``brake(value)`` on it and assigns
                its ``acceleration`` and ``steering`` attributes.
            t: Current timestep (not used by the current implementation).
            vehicle_data_dict: Mapping from vehicle ID to a per-vehicle dict
                holding an indexable ``'existence'`` entry (only its last
                element is read) and ``'next_acceleration'`` /
                ``'next_steering'`` commands.

        Returns:
            ``(veh, [acceleration, steering])`` with the command applied.
        """
        veh_data = vehicle_data_dict[veh.getID()]

        if veh_data['existence'][-1]:
            acceleration = veh_data['next_acceleration']
            steering = veh_data['next_steering']
        else:
            # Vehicle is flagged as non-existent at the latest step: zero its
            # command and move it far off-scene so it effectively disappears.
            acceleration = 0.0
            steering = 0.0
            veh.setPosition(-1000000, -1000000)

        # Positive commands are applied as acceleration; zero or negative
        # commands are applied as a brake with the same magnitude.
        if acceleration > 0.0:
            veh.acceleration = acceleration
        else:
            veh.brake(np.abs(acceleration))
        veh.steering = steering

        return veh, [acceleration, steering]