from __future__ import division
# from setproctitle import setproctitle as ptitle
import torch
import torch.nn
import gym
from utils import ensure_shared_grads
from model import ACMLP
from player_util import Agent
import time


def _set_agent_state(agent, state, gpu_id):
    """Store observation `state` on `agent` as a float tensor.

    The tensor is moved to CUDA device `gpu_id` when `gpu_id >= 0`;
    otherwise it stays on the CPU.
    """
    agent.state = torch.from_numpy(state).float()
    if gpu_id >= 0:
        with torch.cuda.device(gpu_id):
            agent.state = agent.state.cuda()


def _actor_critic_losses(agent, gamma, entropy_coef=0.01):
    """Return ``(policy_loss, value_loss)`` for the rollout stored on `agent`.

    Uses one-step TD errors ``gamma * V(s_{i+1}) + r_i - V(s_i)``, so
    ``agent.values`` must hold one more entry than ``agent.rewards`` (the
    bootstrap value appended by the caller). The next-state value is
    detached, so the value loss only trains ``V(s_i)``. The detached TD error
    is the policy-gradient advantage. An entropy term weighted by
    `entropy_coef` is subtracted from the policy loss.
    """
    policy_loss = 0
    value_loss = 0
    for i in reversed(range(len(agent.rewards))):
        td_err = (gamma * agent.values[i + 1].detach()
                  + agent.rewards[i] - agent.values[i])
        value_loss = value_loss + 0.5 * td_err.pow(2)
        advantage = td_err.detach()
        policy_loss = (policy_loss - agent.log_probs[i] * advantage
                       - entropy_coef * agent.entropies[i])
    return policy_loss, value_loss


def train(rank, args, shared_model, optimizers, lock, counter):
    """Run one asynchronous actor-critic worker until its time budget expires.

    Each iteration does the following:

    - Sync a local ACMLP copy from `shared_model`.
    - Collect up to ``args.num_steps`` transitions, stopping early when the
      episode ends.
    - Compute one-step TD actor-critic losses.
    - Hand the clipped local gradients to `shared_model` via
      ``ensure_shared_grads`` and call ``optimizers.step()``.

    Args:
        rank: Worker index. It offsets the random seeds and selects the
            device from ``args.gpu_ids``; a negative id means CPU.
        args: Parsed options. Reads ``gpu_ids``, ``seed``, ``env_name``,
            ``minutes_per_run``, ``num_steps``, ``gamma`` and
            ``max_grad_norm``.
        shared_model: Global model this worker loads weights from and
            passes gradients to.
        optimizers: Object whose ``step()`` is called after each update.
            Presumably an optimizer over `shared_model`'s parameters;
            confirm with the caller.
        lock: Lock held while incrementing `counter`.
        counter: Shared value whose ``.value`` is incremented once per
            rollout.
    """
    gpu_id = args.gpu_ids[rank % len(args.gpu_ids)]
    torch.manual_seed(args.seed + rank)
    if gpu_id >= 0:
        torch.cuda.manual_seed(args.seed + rank)

    env = gym.make(args.env_name)
    env.seed(args.seed + rank)

    agent = Agent(None, env, args, None)
    agent.gpu_id = gpu_id
    agent.model = ACMLP(agent.env.action_space.n,
                        agent.env.observation_space.shape[0])

    _set_agent_state(agent, agent.env.reset(), gpu_id)
    if gpu_id >= 0:
        with torch.cuda.device(gpu_id):
            agent.model = agent.model.cuda()
    agent.model.train()

    # Use a monotonic clock for the time budget so system clock adjustments
    # cannot shorten or extend the run.
    run_seconds = 60 * args.minutes_per_run
    start_time = time.monotonic()
    while time.monotonic() - start_time <= run_seconds:
        # load_state_dict copies into the existing (possibly CUDA) local
        # parameters, so no explicit device context is needed.
        agent.model.load_state_dict(shared_model.state_dict())

        for _ in range(args.num_steps):
            agent.action_train()
            if agent.done:
                break

        with lock:
            counter.value += 1

        if agent.done:
            agent.eps_len = 0
            _set_agent_state(agent, agent.env.reset(), gpu_id)

        # Bootstrap from the current state's value unless the episode ended,
        # in which case R = 0. agent.done is not cleared by the reset above,
        # so a finished episode is never bootstrapped. The value is only a
        # target, so no autograd graph is needed.
        R = torch.zeros(1, 1)
        if not agent.done:
            with torch.no_grad():
                value, _ = agent.model(agent.state.unsqueeze(0))
            R = value.detach()
        if gpu_id >= 0:
            with torch.cuda.device(gpu_id):
                R = R.cuda()
        agent.values.append(R)

        policy_loss, value_loss = _actor_critic_losses(agent, args.gamma)

        # NOTE(review): zero the grads in place instead of resetting them to
        # None, which is the PyTorch >= 2.0 default. A common
        # ensure_shared_grads implementation aliases the local .grad tensors
        # into shared_model on the first update and skips re-linking later.
        # With set_to_none=True the shared optimizer would then keep stepping
        # on stale first-iteration gradients. Confirm against
        # utils.ensure_shared_grads.
        agent.model.zero_grad(set_to_none=False)
        (policy_loss + value_loss).backward()
        torch.nn.utils.clip_grad_norm_(agent.model.parameters(),
                                       args.max_grad_norm)
        ensure_shared_grads(agent.model, shared_model, gpu=gpu_id >= 0)
        optimizers.step()
        agent.clear_actions()
