from __future__ import division
# from setproctitle import setproctitle as ptitle
import torch
from environment import atari_env
from utils import ensure_shared_grads, clip_grad_norm
from model import AClstm
from player_util import Agent
import time


def _to_device(obj, gpu_id):
    """Move a tensor or module to CUDA device ``gpu_id``.

    Returns ``obj`` unchanged when ``gpu_id`` is negative (CPU mode);
    otherwise returns ``obj.cuda()`` evaluated inside
    ``torch.cuda.device(gpu_id)``.
    """
    if gpu_id < 0:
        return obj
    with torch.cuda.device(gpu_id):
        return obj.cuda()


def train(rank, args, shared_model, optimizer, env_conf, lock, counter):
    """Run one training worker until its wall-clock budget is used up.

    Each loop iteration copies the weights of ``shared_model`` into a local
    ``AClstm``, steps the agent for up to ``args.num_steps`` actions, builds
    policy and value losses from the collected rollout, back-propagates
    through the local model, hands the gradients over with
    ``ensure_shared_grads`` and calls ``optimizer.step()``.

    Args:
        rank: Worker index; selects the GPU from ``args.gpu_ids`` and
            offsets the random seeds.
        args: Settings object; this function reads ``gpu_ids``, ``seed``,
            ``env``, ``minutes_per_run``, ``num_steps`` and ``gamma``.
        shared_model: Model whose ``state_dict()`` is loaded into the local
            model at the start of every iteration.
        optimizer: Optimizer stepped once per iteration (presumably built
            over ``shared_model``'s parameters -- confirm at the call site).
        env_conf: Environment configuration forwarded to ``atari_env``.
        lock: Context manager guarding updates to ``counter``.
        counter: Object whose ``value`` is incremented once per update.
    """
    # ptitle('Training Agent: {}'.format(rank))
    gpu_id = args.gpu_ids[rank % len(args.gpu_ids)]
    use_gpu = gpu_id >= 0

    # Give every worker its own deterministic seed.
    seed = args.seed + rank
    torch.manual_seed(seed)
    if use_gpu:
        torch.cuda.manual_seed(seed)

    env = atari_env(args.env, env_conf, args)
    env.seed(seed)
    agent = Agent(None, env, args, None)
    agent.gpu_id = gpu_id
    agent.model = AClstm(agent.env.observation_space.shape[0],
                         agent.env.action_space)

    agent.state = _to_device(
        torch.from_numpy(agent.env.reset()).float(), gpu_id)
    agent.model = _to_device(agent.model, gpu_id)
    agent.model.train()

    start_time = time.time()
    time_budget = 60 * args.minutes_per_run  # minutes -> seconds
    while (time.time() - start_time) <= time_budget:
        # Refresh the local copy with the latest shared weights.
        if use_gpu:
            with torch.cuda.device(gpu_id):
                agent.model.load_state_dict(shared_model.state_dict())
        else:
            agent.model.load_state_dict(shared_model.state_dict())

        if agent.done:
            # Episode boundary: start from a zeroed recurrent state.
            # NOTE(review): 512 presumably matches the LSTM width inside
            # AClstm -- confirm against the model definition.
            agent.cx = _to_device(torch.zeros(1, 512), gpu_id)
            agent.hx = _to_device(torch.zeros(1, 512), gpu_id)
        else:
            # Carry the recurrent state over but cut it out of the previous
            # iteration's autograd graph.
            agent.cx = agent.cx.detach()
            agent.hx = agent.hx.detach()

        # Collect a rollout of at most num_steps actions, stopping early
        # when the agent reports the episode is done.
        for _ in range(args.num_steps):
            agent.action_train()
            if agent.done:
                break

        if agent.done:
            agent.state = _to_device(
                torch.from_numpy(agent.env.reset()).float(), gpu_id)

        # Bootstrap value for the state after the rollout: zero when the
        # episode ended, otherwise the model's own value output.
        R = torch.zeros(1, 1)
        if not agent.done:
            value, _, _ = agent.model((agent.state.unsqueeze(0),
                                       (agent.hx, agent.cx)))
            R = value.detach()
        R = _to_device(R, gpu_id)

        agent.values.append(R)
        policy_loss = 0
        value_loss = 0
        for i in reversed(range(len(agent.rewards))):
            # One-step TD error; the next-state value is held constant so
            # gradients only flow through values[i].
            td_err = (args.gamma * agent.values[i + 1].detach()
                      + agent.rewards[i] - agent.values[i])
            value_loss = value_loss + 0.5 * td_err.pow(2)

            # NOTE(review): the advantage is the plain one-step TD error;
            # nothing is accumulated across steps here.
            advantage = td_err.detach()
            policy_loss = (policy_loss - agent.log_probs[i] * advantage
                           - 0.01 * agent.entropies[i])

        agent.model.zero_grad()
        (policy_loss + value_loss).backward()
        clip_grad_norm(agent.model.parameters(), 0.5)
        ensure_shared_grads(agent.model, shared_model, gpu=use_gpu)
        optimizer.step()
        with lock:
            counter.value += 1
        agent.clear_actions()
