from __future__ import division
# from setproctitle import setproctitle as ptitle
import torch
from synthetic_env import synthetic_env
from utils import ensure_shared_grads
from model import AClinear
from player_util import Agent
import time


def _reset_agent_state(agent, gpu_id):
    """Start a new episode and store its initial observation on ``agent``.

    Calls ``agent.env.reset()``, converts the two returned NumPy arrays to
    float tensors and stores them as ``agent.state`` and
    ``agent.state_onehot``. When ``gpu_id`` is non-negative both tensors are
    moved to that CUDA device.
    """
    state, state_onehot = agent.env.reset()
    agent.state = torch.from_numpy(state).float()
    agent.state_onehot = torch.from_numpy(state_onehot).float()
    if gpu_id >= 0:
        with torch.cuda.device(gpu_id):
            agent.state = agent.state.cuda()
            agent.state_onehot = agent.state_onehot.cuda()


def train(rank, args, shared_model, optimizers, env_conf, lock, counter):
    """Run one actor-critic training worker for a fixed time budget.

    Each iteration copies the weights of ``shared_model`` into a local
    ``AClinear`` model, collects up to ``args.num_steps`` steps with it,
    builds a one-step TD value loss and a policy-gradient loss from the
    collected rollout, back-propagates through the local model and then
    calls ``ensure_shared_grads`` followed by ``optimizers.step()``.

    Args:
        rank: Index of this worker. Selects the device from ``args.gpu_ids``
            (round-robin) and offsets the random seed.
        args: Run configuration. The attributes read here are ``gpu_ids``,
            ``seed``, ``gamma``, ``minutes_per_run`` and ``num_steps``; the
            object is also passed to ``Agent``.
        shared_model: Model whose ``state_dict()`` is loaded into the local
            model at the start of every iteration and which is handed to
            ``ensure_shared_grads`` after the backward pass.
        optimizers: Object whose ``step()`` is called once per iteration
            (presumably an optimizer over ``shared_model``'s parameters --
            confirm against the caller).
        env_conf: Mapping with the keys ``"state_space"``,
            ``"action_space"`` and ``"state_dim"`` used to build the
            environment.
        lock: Context manager guarding updates to ``counter``.
        counter: Object with a ``value`` attribute; incremented by one after
            every rollout (not after every environment step).

    Returns:
        None. The function returns once ``args.minutes_per_run`` minutes
        have elapsed; the check happens only between iterations.
    """
    # ptitle('Training Agent: {}'.format(rank))
    gpu_id = args.gpu_ids[rank % len(args.gpu_ids)]
    use_gpu = gpu_id >= 0  # a negative id selects the CPU code path

    # Seed each worker differently so their rollouts are not identical.
    worker_seed = args.seed + rank
    torch.manual_seed(worker_seed)
    if use_gpu:
        torch.cuda.manual_seed(worker_seed)

    env = synthetic_env(env_conf["state_space"], env_conf["action_space"],
                        env_conf["state_dim"], args.gamma, 0)
    env.seed(worker_seed)

    agent = Agent(None, env, args, None, None)
    agent.gpu_id = gpu_id
    agent.model = AClinear(agent.env.state_space, agent.env.action_space,
                           agent.env.state_dim)

    _reset_agent_state(agent, gpu_id)
    if use_gpu:
        with torch.cuda.device(gpu_id):
            agent.model = agent.model.cuda()
    agent.model.train()

    # Measure the budget with a monotonic clock so that wall-clock
    # adjustments cannot shorten or extend the run.
    run_seconds = 60 * args.minutes_per_run
    start_time = time.monotonic()
    while time.monotonic() - start_time <= run_seconds:
        # Refresh the local weights from the shared model.
        if use_gpu:
            with torch.cuda.device(gpu_id):
                agent.model.load_state_dict(shared_model.state_dict())
        else:
            agent.model.load_state_dict(shared_model.state_dict())

        # Collect a rollout of at most num_steps steps, stopping early when
        # the episode ends. action_train() is expected to fill agent.values,
        # agent.log_probs and agent.rewards and to set agent.done --
        # NOTE(review): confirm against player_util.Agent.
        for _ in range(args.num_steps):
            agent.action_train()
            if agent.done:
                break

        with lock:
            counter.value += 1

        if agent.done:
            agent.eps_len = 0
            _reset_agent_state(agent, gpu_id)

        # Bootstrap value for the state after the last collected step:
        # zero if the episode ended, otherwise the local model's value
        # output for the current state (detached, so it acts as a constant).
        R = torch.zeros(1, 1)
        if not agent.done:
            value, _ = agent.model((agent.state.unsqueeze(0),
                                   agent.state_onehot.unsqueeze(0)))
            R = value.detach()

        if use_gpu:
            with torch.cuda.device(gpu_id):
                R = R.cuda()

        # After this append, agent.values[i + 1] is valid for every reward
        # index i used below.
        agent.values.append(R)

        policy_loss = 0
        value_loss = 0
        for i in reversed(range(len(agent.rewards))):
            # One-step TD error; the next-state value is detached so only
            # agent.values[i] receives gradient from the value loss.
            td_error = (args.gamma * agent.values[i + 1].detach()
                        + agent.rewards[i] - agent.values[i])
            value_loss = value_loss + 0.5 * td_error.pow(2)

            # The TD error is used as the advantage; it is detached so the
            # policy loss does not back-propagate into the value output.
            advantage = td_error.detach()
            policy_loss = policy_loss - agent.log_probs[i] * advantage

        agent.model.zero_grad()
        (policy_loss + value_loss).backward()
        # Presumably makes the local gradients available on shared_model
        # before the optimizer step -- see utils.ensure_shared_grads.
        ensure_shared_grads(agent.model, shared_model, gpu=use_gpu)
        optimizers.step()
        # Presumably empties the per-rollout buffers for the next iteration
        # -- confirm against player_util.Agent.clear_actions.
        agent.clear_actions()
