from __future__ import division
import torch
import torch.nn.functional as F


class Agent(object):
    """Actor-critic agent that steps an environment with a policy model.

    ``action_train`` samples an action from the policy and records the value,
    log-probability, policy entropy and clipped reward of the step in
    per-rollout buffers. ``action_test`` acts greedily and records nothing.
    ``clear_actions`` empties the buffers.

    The model must support ``model(x) -> (value, logit)`` and
    ``model.actor_forward(x) -> logit``. The environment's ``step(action)``
    must return ``(state, reward, done, info)`` with ``state`` a NumPy array.
    """

    def __init__(self, model, env, args, state):
        self.model = model
        self.env = env
        # Current observation; a tensor (``unsqueeze(0)`` adds the batch dim).
        self.state = state
        self.eps_len = 0  # steps taken; incremented by both action methods
        self.args = args  # not read in this class; kept for callers
        # Per-rollout buffers filled by ``action_train``.
        self.values = []
        self.log_probs = []
        self.entropies = []
        self.rewards = []
        # Starts True; presumably so the caller resets the env first -- confirm.
        self.done = True
        self.info = None
        self.reward = 0
        # < 0 keeps the state on CPU, otherwise it is the CUDA device index.
        self.gpu_id = -1

    def _set_state(self, state):
        """Store a NumPy observation as a float tensor on ``self.gpu_id``."""
        state = torch.from_numpy(state).float()
        if self.gpu_id >= 0:
            with torch.cuda.device(self.gpu_id):
                state = state.cuda()
        self.state = state

    def action_train(self):
        """Sample one action, step the environment and record the transition.

        Appends to ``values`` (critic output), ``log_probs`` (log-probability
        of the sampled action, shape (1, 1)), ``entropies`` (entropy of the
        full policy distribution) and ``rewards`` (reward clipped to [-1, 1]).

        Returns:
            self, to allow chaining.
        """
        value, logit = self.model(self.state.unsqueeze(0))
        prob = F.softmax(logit, dim=-1)
        log_prob = F.log_softmax(logit, dim=-1)
        # The entropy needs the full log-probability vector, so compute it
        # before narrowing to the sampled action below.
        entropy = -(log_prob * prob).sum(1)
        action = prob.multinomial(1).detach()
        action_log_prob = log_prob.gather(1, action)

        state, self.reward, self.done, self.info = self.env.step(action.item())
        self._set_state(state)

        self.reward = max(min(self.reward, 1), -1)
        self.values.append(value)
        self.log_probs.append(action_log_prob)
        self.entropies.append(entropy)
        self.rewards.append(self.reward)
        self.eps_len += 1
        return self

    def action_test(self):
        """Take the greedy (highest-probability) action and step the environment.

        Gradients are not tracked and nothing is recorded in the rollout
        buffers.

        Returns:
            self, to allow chaining.
        """
        with torch.no_grad():
            logit = self.model.actor_forward(self.state.unsqueeze(0))
        # argmax(softmax(logit)) == argmax(logit); skip the softmax.
        action = logit.argmax(dim=1)
        state, self.reward, self.done, self.info = self.env.step(action.item())
        self._set_state(state)
        self.eps_len += 1
        return self

    def clear_actions(self):
        """Empty the per-rollout buffers filled by ``action_train``."""
        self.values = []
        self.log_probs = []
        self.rewards = []
        self.entropies = []
        return self
