from __future__ import division
# from setproctitle import setproctitle as ptitle
import torch
import torch.nn.functional as F
import numpy as np
from synthetic_env import synthetic_env
from model import AClinear
from player_util import Agent
import time
from utils import setup_logger
import logging


def test_(rank, args, env_conf, shared_model, counter, num_run):
    """Evaluate the shared model back to back for a fixed wall-clock budget.

    For ``60 * args.minutes_per_run`` seconds this worker repeatedly plays
    test episodes with a local ``AClinear`` copy of ``shared_model``. Before
    each episode it snapshots ``counter.value``, the elapsed time, and the
    current ``shared_model`` weights. Once ``agent.eps_len`` reaches
    ``args.test_episode_length``, it does the following:

    1. Records the episode statistics on a text logger.
    2. Appends one row to ``<log_dir>/<log_name>_logdata.sav``. The row
       columns are: num_run, model_time, model_counter, eps_len,
       episode reward, reward mean, critic gap, gap mean.
    3. Resets the environment.
    4. Sleeps one second before the next episode.

    Args:
        rank: Worker index; ``rank + args.seed`` seeds torch and the env.
        args: Parsed options. Every attribute is written to the log at start.
        env_conf: Mapping providing ``state_space``, ``action_space`` and
            ``state_dim`` for ``synthetic_env``.
        shared_model: Model whose ``state_dict()`` is copied before each
            episode.
        counter: Shared object exposing ``.value``. It is recorded as
            "num steps" (presumably the global training-step count).
        num_run: Run index written into every log line and data row.
    """
    # ptitle('Test Agent')
    run_tag = 'worker{0}-lr{1}-synthetic-iid'.format(
        str(args.workers), str(args.lr))
    logger_name = '{}_log'.format(run_tag)
    setup_logger(logger_name, r'{0}{1}_log'.format(args.log_dir, run_tag))
    logger = logging.getLogger(logger_name)
    # Dump every option so each log file is self-describing.
    for option, setting in vars(args).items():
        logger.info('{0}: {1}'.format(option, setting))

    seed = rank + args.seed
    torch.manual_seed(seed)

    n_states = env_conf["state_space"]
    n_actions = env_conf["action_space"]
    feat_dim = env_conf["state_dim"]
    env = synthetic_env(n_states, n_actions, feat_dim, args.gamma, 0, False)
    env.seed(seed)
    # Second instance (last flag True). In this function it is only queried
    # for ``state_space`` and ``get_opt``; it is never stepped.
    env_td_opt = synthetic_env(n_states, n_actions, feat_dim, args.gamma, 0, True)

    episode_return = 0
    episodes_done = 0
    return_total = 0
    gap_total = 0

    agent = Agent(None, env, args, None, None)
    agent.model = AClinear(env.state_space, env.action_space, env.state_dim)

    def start_episode():
        """Reset the env and load its initial observation into the agent."""
        obs, obs_onehot = agent.env.reset()
        agent.state = torch.from_numpy(obs).float()
        agent.state_onehot = torch.from_numpy(obs_onehot).float()

    start_episode()

    needs_reload = True
    start_time = time.time()
    budget_s = 60 * args.minutes_per_run
    while (time.time() - start_time) <= budget_s:
        if needs_reload:
            # Capture the step count and elapsed time that go with the
            # weights about to be copied.
            model_counter = counter.value
            model_time = time.time() - start_time
            agent.model.load_state_dict(shared_model.state_dict())
            agent.model.eval()
            needs_reload = False

        agent.action_test()
        episode_return += agent.reward

        if agent.eps_len < args.test_episode_length:
            continue

        # ---- episode finished: score it, record it, start a new one ----
        needs_reload = True
        episodes_done += 1
        return_total += episode_return
        return_mean = return_total / episodes_done

        # NOTE(review): torch.eye(state_space) presumably feeds one one-hot
        # row per state, so the policy is evaluated on every state. Confirm
        # against AClinear.
        _, logits = agent.model((agent.state.unsqueeze(0),
                                 torch.eye(env_td_opt.state_space)))
        policy = F.softmax(logits, dim=-1).detach().numpy().T
        w_opt = env_td_opt.get_opt(policy)
        # Keeps only the LAST tensor yielded by critic_linear.parameters().
        # NOTE(review): this is the bias if critic_linear has one; confirm
        # that is intended.
        for param in agent.model.critic_linear.parameters():
            w_critic = param.data.numpy()
        gap = np.linalg.norm(w_critic - w_opt)
        gap_total += gap
        gap_mean = gap_total / episodes_done

        elapsed = time.strftime("%Hh %Mm %Ss", time.gmtime(model_time))
        logger.info(
            "Num run {0}, time {1}, num steps {2}, episode length {3}, episode reward {4:.4f}, reward mean {5:.4f}, critic gap {6:.4f}, gap mean {7:.4f}".format(
                num_run, elapsed, model_counter, agent.eps_len,
                episode_return, return_mean, gap, gap_mean))
        row = np.array([num_run, model_time, model_counter, agent.eps_len,
                        episode_return, return_mean, gap, gap_mean])
        with open('{0}/{1}_logdata.sav'.format(args.log_dir, run_tag), 'a') as f:
            np.savetxt(f, row[np.newaxis])

        episode_return = 0
        agent.eps_len = 0
        start_episode()

        time.sleep(1.)
            
def test_itr(rank, args, env_conf, shared_model, counter, num_run):
    """Evaluate the shared model at fixed step-count checkpoints.

    For ``60 * args.minutes_per_run`` seconds this worker polls
    ``counter.value``. Each time it observes a new value that is a multiple
    of ``args.test_itr_interval``, it does the following:

    1. Copies the current ``shared_model`` weights into a local ``AClinear``
       agent.
    2. Plays one test episode of ``args.test_episode_length`` steps.
    3. Appends one row to ``<log_dir>/<log_name>_logdata.sav``. The row
       columns are: num_run, model_time, model_counter, eps_len,
       episode reward, reward mean, critic gap, gap mean.

    No text logger is configured here; results go only to the .sav file.

    Args:
        rank: Worker index; ``rank + args.seed`` seeds torch and the env.
        args: Parsed options (workers, lr, seed, gamma, minutes_per_run,
            test_itr_interval, test_episode_length, log_dir).
        env_conf: Mapping providing ``state_space``, ``action_space`` and
            ``state_dim`` for ``synthetic_env``.
        shared_model: Model whose ``state_dict()`` is copied at each
            checkpoint.
        counter: Shared object exposing ``.value``. It is presumably the
            global training-step count, advanced by other processes.
        num_run: Run index written into every data row.
    """
    # ptitle('Test Agent')
    log_name = 'worker{0}-lr{1}-synthetic-iid(iter)'.format(
        str(args.workers), str(args.lr))

    torch.manual_seed(rank + args.seed)

    env = synthetic_env(env_conf["state_space"], env_conf["action_space"],
                        env_conf["state_dim"], args.gamma, 0, False)
    env.seed(rank + args.seed)
    # Second instance (last flag True). In this function it is only queried
    # for ``state_space`` and ``get_opt``; it is never stepped.
    env_td_opt = synthetic_env(env_conf["state_space"], env_conf["action_space"],
                               env_conf["state_dim"], args.gamma, 0, True)

    reward_sum = 0
    num_tests = 0
    reward_total_sum = 0
    critic_total_sum = 0
    model_counter = -1  # sentinel: no checkpoint evaluated yet

    agent = Agent(None, env, args, None, None)
    agent.model = AClinear(env.state_space, env.action_space, env.state_dim)

    state, state_onehot = agent.env.reset()
    agent.state = torch.from_numpy(state).float()
    agent.state_onehot = torch.from_numpy(state_onehot).float()

    start_time = time.time()
    budget_s = 60 * args.minutes_per_run
    # NOTE(review): this loop polls without sleeping, so it occupies a CPU
    # core while waiting. Because the trigger is an exact-multiple test, a
    # checkpoint is still skipped if the counter moves past a multiple
    # between two polls.
    while (time.time() - start_time) <= budget_s:
        # Read the shared counter exactly once per poll. With two reads (one
        # for the test, one for the record), an increment landing in between
        # would record a step count that is not the checkpoint that fired.
        current_count = counter.value
        if (current_count == model_counter
                or current_count % args.test_itr_interval != 0):
            continue

        model_counter = current_count
        model_time = time.time() - start_time
        agent.model.load_state_dict(shared_model.state_dict())
        agent.model.eval()

        # Roll out one complete test episode with the copied weights.
        while agent.eps_len < args.test_episode_length:
            agent.action_test()
            reward_sum += agent.reward

        num_tests += 1
        reward_total_sum += reward_sum
        reward_mean = reward_total_sum / num_tests

        # NOTE(review): torch.eye(state_space) presumably feeds one one-hot
        # row per state, so the policy is evaluated on every state. Confirm
        # against AClinear.
        _, logit = agent.model((agent.state.unsqueeze(0),
                                torch.eye(env_td_opt.state_space)))
        prob = F.softmax(logit, dim=-1).detach().numpy().T
        w_opt = env_td_opt.get_opt(prob)
        # Keeps only the LAST tensor yielded by critic_linear.parameters().
        # NOTE(review): this is the bias if critic_linear has one; confirm
        # that is intended.
        for p in agent.model.critic_linear.parameters():
            w = p.data.numpy()
        critic_gap = np.linalg.norm(w - w_opt)
        critic_total_sum += critic_gap
        critic_mean = critic_total_sum / num_tests

        data = np.array([num_run, model_time, model_counter, agent.eps_len,
                         reward_sum, reward_mean, critic_gap, critic_mean])[np.newaxis, :]
        with open('{0}/{1}_logdata.sav'.format(args.log_dir, log_name), 'a') as f:
            np.savetxt(f, data)

        # Reset per-episode state for the next checkpoint.
        reward_sum = 0
        agent.eps_len = 0
        state, state_onehot = agent.env.reset()
        agent.state = torch.from_numpy(state).float()
        agent.state_onehot = torch.from_numpy(state_onehot).float()

