from __future__ import division
# from setproctitle import setproctitle as ptitle
import torch
import numpy as np
import gym
from model import ACMLP
from player_util import Agent
import time
from utils import setup_logger
import logging
import re


def test_(rank, args, shared_model, counter, num_run):
    """Evaluate the shared model episode after episode until time runs out.

    A local ``ACMLP`` copy is refreshed from ``shared_model`` at the start of
    every episode and stepped with ``agent.action_test()``. Whenever an
    episode finishes, one line is written to the logger and one row is
    appended to ``<log_dir>/<log_name>_logdata.sav``. The row columns are:
    ``num_run``, seconds elapsed when the weights were copied, the counter
    value read at that moment, episode length, episode reward and the running
    mean reward over all finished episodes.

    Args:
        rank: Integer offset added to ``args.seed`` for seeding.
        args: Namespace providing ``gpu_ids``, ``env_name``, ``workers``,
            ``lr``, ``log_dir``, ``seed`` and ``minutes_per_run``.
        shared_model: Model whose ``state_dict()`` is copied into the local
            evaluation model.
        counter: Shared object whose ``.value`` is recorded with each result.
        num_run: Run identifier written into every log line and data row.
    """
    # ptitle('Test Agent')
    gpu_id = args.gpu_ids[-1]
    on_gpu = gpu_id >= 0

    log_name = '{0}-worker{1}-lr{2}-tts'.format(
        args.env_name, str(args.workers), str(args.lr))
    logger_name = '{}_log'.format(log_name)
    setup_logger(logger_name,
                 r'{0}{1}_log'.format(args.log_dir, log_name))
    logger = logging.getLogger(logger_name)
    # Dump the full configuration once so each log file is self-describing.
    for key, value in vars(args).items():
        logger.info('{0}: {1}'.format(key, value))

    seed = rank + args.seed
    torch.manual_seed(seed)
    if on_gpu:
        torch.cuda.manual_seed(seed)

    env = gym.make(args.env_name)
    env.seed(seed)

    agent = Agent(None, env, args, None)
    agent.gpu_id = gpu_id
    agent.model = ACMLP(agent.env.action_space.n,
                        agent.env.observation_space.shape[0])

    def fresh_state():
        """Reset the env and return the observation as a float tensor,
        moved to the selected GPU when one is in use."""
        observation = torch.from_numpy(agent.env.reset()).float()
        if on_gpu:
            with torch.cuda.device(gpu_id):
                observation = observation.cuda()
        return observation

    agent.state = fresh_state()
    if on_gpu:
        with torch.cuda.device(gpu_id):
            agent.model = agent.model.cuda()

    data_path = '{0}/{1}_logdata.sav'.format(args.log_dir, log_name)
    episode_reward = 0
    episodes_done = 0
    cumulative_reward = 0

    needs_sync = True
    start_time = time.time()
    time_budget = 60 * args.minutes_per_run
    while (time.time() - start_time) <= time_budget:
        if needs_sync:
            # Snapshot the counter and elapsed time that belong to the
            # weights about to be evaluated.
            model_counter = counter.value
            model_time = time.time() - start_time
            if on_gpu:
                with torch.cuda.device(gpu_id):
                    agent.model.load_state_dict(shared_model.state_dict())
            else:
                agent.model.load_state_dict(shared_model.state_dict())
            agent.model.eval()
            needs_sync = False

        agent.action_test()
        episode_reward += agent.reward
        if not agent.done:
            continue

        # Episode finished: record it and prepare the next one.
        needs_sync = True
        episodes_done += 1
        cumulative_reward += episode_reward
        reward_mean = cumulative_reward / episodes_done

        elapsed_text = time.strftime("%Hh %Mm %Ss", time.gmtime(model_time))
        logger.info(
            "Num run {0}, time {1}, num steps {2}, episode length {3}, episode reward {4:.4f}, reward mean {5:.4f}".
            format(num_run, elapsed_text, model_counter, agent.eps_len,
                   episode_reward, reward_mean))

        row = np.array([num_run, model_time, model_counter, agent.eps_len,
                        episode_reward, reward_mean])
        with open(data_path, 'a') as f:
            np.savetxt(f, row[np.newaxis, :])

        episode_reward = 0
        agent.eps_len = 0
        agent.state = fresh_state()
        time.sleep(1.)


def test_itr(rank, args, shared_model, counter, num_run):
    """Evaluate the shared model once per ``args.test_itr_interval`` steps.

    The loop polls ``counter.value`` until ``60 * args.minutes_per_run``
    seconds have elapsed. Each time the counter shows a new value that is a
    multiple of ``args.test_itr_interval``, the weights of ``shared_model``
    are copied into a local ``ACMLP``, one full episode is played with
    ``agent.action_test()`` and a row is appended to
    ``<log_dir>/<log_name>_logdata.sav``. The row columns are: ``num_run``,
    seconds elapsed when the weights were copied, the counter value, episode
    length, episode reward and the running mean reward.

    No logger is configured here; results are written to the data file only.

    Args:
        rank: Integer offset added to ``args.seed`` for seeding.
        args: Namespace providing ``gpu_ids``, ``env_name``, ``workers``,
            ``lr``, ``log_dir``, ``seed``, ``minutes_per_run`` and
            ``test_itr_interval``.
        shared_model: Model whose ``state_dict()`` is copied into the local
            evaluation model.
        counter: Shared object exposing ``.value`` (presumably a
            ``multiprocessing.Value`` updated by the training workers;
            confirm against the caller).
        num_run: Run identifier written into every data row.
    """
    # ptitle('Test Agent')
    gpu_id = args.gpu_ids[-1]
    log_name = '{0}-worker{1}-lr{2}-tts(iter)'.format(
        args.env_name, str(args.workers), str(args.lr))

    torch.manual_seed(rank + args.seed)
    if gpu_id >= 0:
        torch.cuda.manual_seed(rank + args.seed)

    env = gym.make(args.env_name)
    env.seed(rank + args.seed)

    reward_sum = 0
    num_tests = 0
    reward_total_sum = 0
    model_counter = -1

    agent = Agent(None, env, args, None)
    # Keep the agent's device setting in step with where the model is placed.
    # NOTE(review): presumably Agent.action_test consults gpu_id when it
    # builds the next state tensor -- confirm in player_util.
    agent.gpu_id = gpu_id
    agent.model = ACMLP(agent.env.action_space.n,
                        agent.env.observation_space.shape[0])

    state = agent.env.reset()
    agent.state = torch.from_numpy(state).float()
    if gpu_id >= 0:
        with torch.cuda.device(gpu_id):
            agent.model = agent.model.cuda()
            agent.state = agent.state.cuda()

    start_time = time.time()
    end_time = 60 * args.minutes_per_run
    while (time.time() - start_time) <= end_time:
        # Read the shared counter exactly once per poll so the value that
        # passes the interval check is the same one that gets recorded.
        current_count = counter.value
        if (current_count == model_counter
                or current_count % args.test_itr_interval != 0):
            continue

        model_counter = current_count
        model_time = time.time() - start_time
        if gpu_id >= 0:
            with torch.cuda.device(gpu_id):
                agent.model.load_state_dict(shared_model.state_dict())
        else:
            agent.model.load_state_dict(shared_model.state_dict())
        agent.model.eval()

        # Play one complete episode with the freshly copied weights.
        while True:
            agent.action_test()
            reward_sum += agent.reward
            if agent.done:
                break

        num_tests += 1
        reward_total_sum += reward_sum
        reward_mean = reward_total_sum / num_tests

        data = np.array([num_run, model_time, model_counter, agent.eps_len,
                         reward_sum, reward_mean])[np.newaxis, :]
        with open('{0}/{1}_logdata.sav'.format(args.log_dir, log_name), 'a') as f:
            np.savetxt(f, data)

        reward_sum = 0
        agent.eps_len = 0
        state = agent.env.reset()
        agent.state = torch.from_numpy(state).float()
        # The model stays on the GPU, so the new initial state must be moved
        # there too before the next episode starts.
        if gpu_id >= 0:
            with torch.cuda.device(gpu_id):
                agent.state = agent.state.cuda()
            
                    


