import torch as th
import numpy as np


class SamplerAgent(object):
    """Inference-time agent that runs a TorchScript policy and picks actions.

    The policy network is loaded with :meth:`fetch_model_parameters` and is
    invoked as ``pnet(states, styles)``; it must return a
    ``(prob, log_prob, state_value)`` triple.

    Attributes:
        sampler: Object supplied by the caller; stored but not used by any
            method of this class.
        sampler_method: How :meth:`get_action` picks an action from ``prob``:
            ``'prob'`` (multinomial sample), ``'argmax'`` (greedy) or
            ``'random'`` (uniform). Any other value falls back to ``'prob'``.
        pnet: The loaded TorchScript policy, or ``None`` until
            :meth:`fetch_model_parameters` has been called.
    """

    def __init__(
            self,
            sampler,
            sampler_method='prob'
    ):
        self.sampler = sampler
        self.sampler_method = sampler_method
        self.pnet = None

    def process_reward(self, reward):
        """Return ``reward`` unchanged.

        Identity transform; presumably an override point for subclasses that
        need reward shaping (TODO confirm against subclasses).
        """
        return reward

    def process_action_mask(self, obs, action_nums, use_mask=False):
        """Build one all-``False`` boolean mask per action head.

        Args:
            obs: Unused; kept for interface compatibility.
            action_nums: Sequence of per-head action counts.
            use_mask: Unused; kept for interface compatibility.

        Returns:
            A list with one ``torch.bool`` tensor of shape ``(action_num,)``
            per entry of ``action_nums``, every element ``False``. Presumably
            ``True`` would mark an unavailable action, so nothing is masked
            out (NOTE(review): confirm against the consumer). An empty
            ``action_nums`` yields an empty list.
        """
        return [th.zeros(int(action_num), dtype=th.bool) for action_num in action_nums]

    def fetch_model_parameters(self, models):
        """Load a TorchScript policy from ``models`` (path or file-like) onto the CPU."""
        self.pnet = th.jit.load(models, map_location=th.device('cpu'))

    def get_action(self, prob):
        """Select one action index per row of ``prob``.

        Args:
            prob: 2-D tensor of shape ``(batch, num_actions)``. The
                ``'prob'`` method treats each row as sampling weights.

        Returns:
            An int64 tensor of shape ``(batch, 1)``, whatever the method.
        """
        method = self.sampler_method
        if method == 'argmax':
            return prob.argmax(1, keepdim=True)
        if method == 'random':
            # Draw from NumPy's global RNG so np.random seeding keeps working,
            # but return the same (batch, 1) tensor layout as the other methods.
            batch, num_actions = prob.shape[0], prob.shape[1]
            picks = np.random.randint(num_actions, size=(batch, 1)).astype(np.int64)
            return th.from_numpy(picks).to(prob.device)
        # 'prob' and any unrecognised method: sample from each row's weights.
        return prob.multinomial(1)

    def get_model_result(self, states, styles):
        """Run the policy network without autograd and choose an action.

        1-D ``states``/``styles`` are promoted to a batch of one, and
        ``styles`` is cast to float before the forward pass.

        Returns:
            ``(action, prob, log_prob, state_value)``. ``action`` comes from
            :meth:`get_action`; the other three are the network's outputs.

        Raises:
            RuntimeError: if no model has been loaded yet.
        """
        if self.pnet is None:
            raise RuntimeError(
                'SamplerAgent.pnet is not loaded; call fetch_model_parameters() first')
        with th.no_grad():
            if states.dim() == 1:
                states = states.unsqueeze(0)
            if styles.dim() == 1:
                styles = styles.unsqueeze(0)

            prob, log_prob, state_value = self.pnet(states, styles.float())
            action = self.get_action(prob)
        return action, prob, log_prob, state_value
