
import time
import numpy as np
import logger
import sys
from collections import deque
from memory import ReplayBuffer
from agent import Model
sys.path.append('..')
from runner import Runner

def learn(policy, train_env, test_env, total_timesteps, lr_ac, lr_cr, add_t, ircr, reg, scale_reward, policy_freq, buffer_size, reward_n,
          drop_type, drop_r, rand_r, batch_size=128, log_interval=4000, target_update_interval=1, ntrain=1, gamma=0.99, tau=0.005,
          learning_starts=3000, log_starts=20000):
    """Run the collect/train/evaluate loop for ``total_timesteps`` steps.

    Builds a ``ReplayBuffer``, a ``Model`` and a ``Runner``. Each step then
    does the following:

    * calls ``runner.run`` to collect experience;
    * once the buffer can sample, performs ``ntrain`` ``model.train`` updates;
    * every ``log_interval`` steps (from ``log_starts`` on), runs
      ``runner.eval()`` and dumps statistics through ``logger``.

    Both environments are closed when the function exits, including when it
    exits via an exception.

    Args:
        policy: Policy passed through to ``Model``.
        train_env: Environment used for experience collection. Its
            observation/action spaces size the buffer and the model.
        test_env: Environment handed to ``Runner`` for evaluation.
        total_timesteps: Number of loop iterations (coerced to ``int``).
        lr_ac, lr_cr: Actor / critic learning rates, forwarded to ``Model``.
        add_t: When truthy, the observation shape is widened by 2 features
            (presumably time features appended by ``Runner`` -- confirm).
        ircr, reg, scale_reward, policy_freq, target_update_interval,
            gamma, tau: Forwarded to ``ReplayBuffer`` / ``Model`` unchanged.
        buffer_size: Total buffer budget. The buffer ``limit`` is
            ``buffer_size / reward_n``.
        reward_n: Forwarded to buffer, model and runner. It also divides
            the buffer limit and the warm-up threshold.
        drop_type, drop_r, rand_r: Forwarded to ``Runner`` unchanged.
        batch_size: Minibatch size for ``replay_buffer.sample``.
        log_interval: Log every this many steps.
        ntrain: Gradient updates per environment step once training starts.
        learning_starts: Warm-up amount. Training begins once
            ``replay_buffer.can_sample(int(learning_starts / reward_n))``.
        log_starts: No logging happens before this step index.
    """
    total_timesteps = int(total_timesteps)

    ob_space = train_env.observation_space.shape
    ac_space = train_env.action_space
    if add_t:
        # Two extra observation features when add_t is set (see Runner).
        ob_space = (ob_space[0] + 2, )

    try:
        replay_buffer = ReplayBuffer(limit=int(buffer_size / reward_n), reward_n=reward_n, ircr=ircr, batch_size=batch_size,
                                     action_shape=ac_space.shape, observation_shape=ob_space)

        model = Model(policy=policy, ob_space=ob_space, ac_space=ac_space, batch_size=batch_size, max_ac=train_env.action_space.high[0], reg=reg, reward_n=reward_n,
                      scale_reward=scale_reward, tau=tau, gamma=gamma, lr_ac=lr_ac, lr_cr=lr_cr, policy_freq=policy_freq, target_update_interval=target_update_interval)

        runner = Runner(train_env=train_env, test_env=test_env, model=model, buffer=replay_buffer, add_t=add_t,
                        reward_n=reward_n, ob_shape=ob_space, drop_type=drop_type, drop_r=drop_r, rand_r=rand_r)

        # Rolling windows over the 10 most recent training episodes.
        eplen = deque(maxlen=10)
        eprex = deque(maxlen=10)
        tfirststart = time.time()
        train_iter = 0  # global count of model.train calls

        for step in range(1, total_timesteps + 1):
            should_log = step % log_interval == 0
            log = None
            train = replay_buffer.can_sample(int(learning_starts / reward_n))
            ep_r_ex, ep_len = runner.run(policy=train)
            eplen.extend(ep_len)
            eprex.extend(ep_r_ex)
            if train:
                for i in range(ntrain):
                    batch = replay_buffer.sample(batch_size)
                    # Only the last update of a logging step requests stats.
                    log = model.train(iter=train_iter, batch=batch, log=(should_log and i == ntrain - 1))
                    train_iter += 1

            if should_log and step >= log_starts:
                # Linear extrapolation of remaining wall-clock time, in hours.
                elapsed = time.time() - tfirststart
                time_left = elapsed * (total_timesteps - step) / (step * 3600)

                test_ep_r, test_ep_l = runner.eval()
                logger.logkv("total_timesteps", step)
                logger.logkv("time_left", time_left)
                logger.logkv('eprewardmean', safemean(list(eprex)))
                logger.logkv('eplenmean', safemean(list(eplen)))
                logger.logkv('test_reward_mean', test_ep_r)
                logger.logkv('test_len_mean', test_ep_l)

                # log is None if no update ran this step (buffer not ready yet,
                # or ntrain == 0). Emit NaN so the logged columns stay stable
                # instead of crashing on None[...].
                if log is not None:
                    q_mean, mu_mean, td_loss, r_loss = log[0], log[1], log[2], log[3]
                else:
                    q_mean = mu_mean = td_loss = r_loss = np.nan
                logger.logkv("mu_mean", float(mu_mean))
                logger.logkv("q_mean", float(q_mean))
                logger.logkv("td_loss", float(td_loss))
                logger.logkv("r_loss", float(r_loss))

                logger.dumpkvs()
    finally:
        # Release both environments even if training failed. The nested
        # finally makes sure test_env is closed even if train_env.close() raises.
        try:
            train_env.close()
        finally:
            test_env.close()

def safemean(xs):
    """Return the mean of ``xs``, or NaN if ``xs`` has no elements.

    Checking for emptiness first avoids handing an empty sequence to
    ``np.mean``.
    """
    if len(xs) > 0:
        return np.mean(xs)
    return np.nan
