from __future__ import division
# from setproctitle import setproctitle as ptitle
import torch
import torch.nn.functional as F
from synthetic_env import synthetic_env
from utils import ensure_shared_grads
from model import AClinear
from player_util import Agent
import time


def train(rank, args, shared_model, optimizers, env_conf, lock, counter):
    """Run one asynchronous actor-critic training worker.

    The worker builds its own environment and local ``AClinear`` model, then
    repeatedly (1) copies the weights of ``shared_model`` into the local
    model, (2) collects a rollout of at most ``args.num_steps`` calls to
    ``agent.action_train()``, (3) builds one-step TD policy and value losses
    from the rollout, and (4) back-propagates them and steps ``optimizers``.
    The loop ends once ``args.minutes_per_run`` minutes have elapsed.

    Args:
        rank: Worker index; added to ``args.seed`` so that each worker seeds
            torch and its environment differently.
        args: Configuration object. This function reads ``seed``, ``gamma``,
            ``minutes_per_run`` and ``num_steps``; the object is also handed
            to ``Agent``.
        shared_model: Model whose ``state_dict()`` is loaded into the local
            model at the start of every iteration and which is passed to
            ``ensure_shared_grads`` after the backward pass.
        optimizers: Object exposing ``step()``; called once per iteration.
            Presumably an optimizer over ``shared_model``'s parameters --
            confirm against the caller.
        env_conf: Mapping providing ``"state_space"``, ``"action_space"`` and
            ``"state_dim"`` for the environment constructor.
        lock: Context manager guarding updates to ``counter``.
        counter: Object with a mutable ``value`` attribute; incremented once
            per loop iteration (i.e. per rollout, not per environment step).

    Returns:
        None.
    """
    # ptitle('Training Agent: {}'.format(rank))
    torch.manual_seed(args.seed + rank)

    # NOTE(review): the meaning of the trailing positional arguments
    # (0, True) is defined by synthetic_env and is not visible here.
    env = synthetic_env(env_conf["state_space"], env_conf["action_space"],
                        env_conf["state_dim"], args.gamma, 0, True)
    env.seed(args.seed + rank)

    agent = Agent(None, env, args, None, None)
    agent.model = AClinear(agent.env.state_space, agent.env.action_space,
                           agent.env.state_dim)

    agent.model.train()

    # Use the monotonic clock for the run budget: unlike time.time(), it
    # cannot jump when the system clock is adjusted, so the worker trains
    # for the requested duration.
    start_time = time.monotonic()
    budget_seconds = 60 * args.minutes_per_run
    while (time.monotonic() - start_time) <= budget_seconds:

        # Start every rollout from the latest shared weights.
        agent.model.load_state_dict(shared_model.state_dict())

        # Evaluate the current policy on an identity matrix of size
        # state_space (together with a zero tensor of width state_dim) and
        # hand the transposed action probabilities to the environment.
        _, logit = agent.model((torch.zeros(1, agent.env.state_dim),
                                torch.eye(agent.env.state_space)))
        pi = F.softmax(logit, dim=-1)
        # Only the first value returned by get_mu is kept; what it represents
        # is defined by synthetic_env -- TODO confirm.
        agent.mu, _, _ = agent.env.get_mu(pi.detach().numpy().T)

        # Collect at most num_steps transitions, stopping early once the
        # agent reports it is done.
        for _ in range(args.num_steps):
            agent.action_train()
            if agent.done:
                break

        with lock:
            counter.value += 1

        # NOTE(review): only the episode-length counter is reset here; no
        # environment reset happens in this function -- presumably handled
        # inside Agent, verify.
        if agent.done:
            agent.eps_len = 0

        policy_loss = 0
        value_loss = 0
        # Index-based on purpose: the four rollout buffers are addressed by
        # the same index and are not shown here to have equal length.
        for i in reversed(range(len(agent.rewards))):
            # One-step TD error; values_[i] is detached so that no gradient
            # flows through the bootstrap target.
            TD_err = (args.gamma * agent.values_[i].detach()
                      + agent.rewards[i] - agent.values[i])
            value_loss = value_loss + 0.5 * TD_err.pow(2)

            # The advantage is a constant w.r.t. the policy gradient.
            advantage = TD_err.detach()
            policy_loss = policy_loss - agent.log_probs[i] * advantage

        agent.model.zero_grad()
        (policy_loss + value_loss).backward()
        # Presumably makes the local gradients visible on shared_model
        # before the optimizer step -- see utils.ensure_shared_grads.
        ensure_shared_grads(agent.model, shared_model)
        optimizers.step()
        agent.clear_actions()
