import numpy as np
import torch
from baseline_policy.ai_policy.r_actor import R_Actor

class Policy:
    """Inference wrapper around a recurrent actor network (``R_Actor``).

    The observation shape handed to ``R_Actor`` is selected from ``map_name``
    via ``_OBS_SHAPES``; the second constructor argument (``_NUM_ACTIONS``)
    is the same for every supported map.
    """

    # Observation shape passed to R_Actor for each supported map name.
    # NOTE(review): axis meaning (grid dims + feature channels?) is not visible
    # here — confirm against R_Actor / the environment's observation encoder.
    _OBS_SHAPES = {
        'many_orders': (5, 5, 26),
        'random3': (8, 5, 20),
        'distant_tomato': (5, 7, 26),
        'soup_coordination': (11, 5, 26),
        'unident_s': (9, 5, 20),
    }
    # Second positional argument given to R_Actor for every map; presumably the
    # size of the discrete action space — TODO confirm against R_Actor.
    _NUM_ACTIONS = 6

    def __init__(self, map_name, device=None):
        """Build the actor network for ``map_name``.

        Args:
            map_name: One of the keys of ``_OBS_SHAPES``.
            device: ``torch.device`` or device string to build the actor on.
                When omitted, ``"cuda:0"`` is used if CUDA is available,
                otherwise ``"cpu"``.

        Raises:
            ValueError: If ``map_name`` is not a supported map (otherwise
                ``self.actor`` would be left unset and fail much later).
        """
        try:
            obs_shape = self._OBS_SHAPES[map_name]
        except KeyError:
            raise ValueError(
                f"Unknown map_name {map_name!r}; "
                f"expected one of {sorted(self._OBS_SHAPES)}"
            ) from None

        if device is None:
            # Keep the historical default (cuda:0) but don't crash on CPU-only hosts.
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        device = torch.device(device)

        self.actor = R_Actor(obs_shape, self._NUM_ACTIONS, device)
        self.device = device

    def act(self, obs, rnn_states_actor, masks, available_actions=None, deterministic=False):
        """Run one forward step of the actor.

        All arguments are forwarded positionally to ``self.actor``; their
        expected shapes/dtypes are defined by ``R_Actor`` (not visible here).

        Returns:
            ``(actions, rnn_states_actor)`` — the first and third values
            returned by the actor. The second value (presumably action
            log-probabilities — confirm) is discarded.
        """
        actions, _, rnn_states_actor = self.actor(
            obs, rnn_states_actor, masks, available_actions, deterministic
        )
        return actions, rnn_states_actor

    def load_checkpoint(self, ckpt_path):
        """Load a state dict from ``ckpt_path`` into the actor, mapped onto ``self.device``."""
        # SECURITY NOTE: torch.load unpickles the file; only load trusted
        # checkpoints (or pass weights_only=True on torch versions that support it).
        state_dict = torch.load(ckpt_path, map_location=self.device)
        self.actor.load_state_dict(state_dict)

    def prep_rollout(self):
        """Switch the actor into evaluation mode (``eval()``) before rollouts."""
        self.actor.eval()

    @staticmethod
    def predict(*args, **kwargs):
        """Placeholder with no implementation; accepts any arguments and returns None.

        Declared as a staticmethod because the original signature lacked
        ``self`` and therefore raised TypeError when called on an instance.
        """
        return None