from __future__ import division
import torch
import torch.nn.functional as F


class Agent(object):
    """Steps an environment with a model that outputs a ``(value, logit)`` pair.

    ``action_train`` appends one entry per call to each rollout buffer:

    * ``values``    -- model value for the sampled state,
    * ``values_``   -- model value for the successor state, or a ``(1, 1)``
      zero tensor when the environment reports ``done``,
    * ``log_probs`` -- log-probability of the sampled action,
    * ``rewards``   -- the step reward clipped to ``[-1, 1]``.

    ``clear_actions`` empties those buffers; ``action_test`` takes a greedy
    step from ``self.state`` without touching them.
    """

    def __init__(self, model, env, args, state, state_onehot):
        # Collaborators and configuration.
        self.model = model
        self.env = env
        self.args = args
        # Current observation (consumed by action_test) and its one-hot part.
        self.state = state
        self.state_onehot = state_onehot
        # Outcome of the most recent env.step call, plus a step counter.
        self.reward = 0
        self.done = True
        self.eps_len = 0
        # Forwarded to env.set_stationary_state by action_train; starts unset.
        self.mu = None
        # Rollout buffers, emptied by clear_actions.
        self.values = []
        self.values_ = []
        self.log_probs = []
        self.rewards = []

    def _evaluate(self, obs, obs_onehot):
        """Run the model on one observation pair after adding a batch axis."""
        return self.model((obs.unsqueeze(0), obs_onehot.unsqueeze(0)))

    @staticmethod
    def _as_float_tensors(obs, obs_onehot):
        """Convert a pair of numpy arrays into float32 tensors."""
        return torch.from_numpy(obs).float(), torch.from_numpy(obs_onehot).float()

    def action_train(self):
        """Sample one transition and record it in the rollout buffers.

        A state is drawn with ``env.set_stationary_state(self.mu)``.
        An action is then sampled from the softmax of the model's logits
        and the environment is stepped with it. Finally, the value, the
        successor value, the action log-probability and the clipped reward
        are appended to the buffers.

        Returns:
            self, to allow call chaining.
        """
        raw_obs, raw_onehot = self.env.set_stationary_state(self.mu)
        obs, obs_onehot = self._as_float_tensors(raw_obs, raw_onehot)
        value, logit = self._evaluate(obs, obs_onehot)

        # Stochastic policy: draw one action index per row of the softmax.
        action = F.softmax(logit, dim=-1).multinomial(1).detach()
        log_prob = F.log_softmax(logit, dim=-1).gather(1, action)

        # NOTE(review): the action is handed over as a (1, 1) array here,
        # whereas action_test passes a single element -- confirm env.step
        # accepts both forms.
        next_obs, self.reward, next_onehot, self.done = self.env.step(
            action.cpu().numpy())

        if self.done:
            # Terminal transition: nothing to bootstrap from.
            next_value = torch.zeros(1, 1)
        else:
            next_obs, next_onehot = self._as_float_tensors(next_obs, next_onehot)
            # NOTE(review): next_value is not detached, so gradients can flow
            # through it if the loss uses it as a target -- confirm intended.
            next_value, _ = self._evaluate(next_obs, next_onehot)

        # Clip the reward into [-1, 1] before it is stored.
        self.reward = min(max(self.reward, -1), 1)

        self.values.append(value)
        self.values_.append(next_value)
        self.log_probs.append(log_prob)
        self.rewards.append(self.reward)
        self.eps_len += 1
        return self

    def action_test(self):
        """Take one greedy step from ``self.state`` and move to the next state.

        The action is the arg-max of the softmax over the model's logits.
        The rollout buffers are not modified.

        Returns:
            self, to allow call chaining.
        """
        _, logit = self._evaluate(self.state, self.state_onehot)
        _, greedy = F.softmax(logit, dim=-1).max(1)
        choice = greedy.detach().cpu().numpy()[0]

        next_obs, self.reward, next_onehot, self.done = self.env.step(choice)
        self.state = torch.from_numpy(next_obs).float()
        self.state_onehot = torch.from_numpy(next_onehot).float()
        self.eps_len += 1
        return self

    def clear_actions(self):
        """Replace the rollout buffers with empty lists (eps_len is kept).

        Returns:
            self, to allow call chaining.
        """
        self.values, self.values_, self.log_probs, self.rewards = [], [], [], []
        return self
