import numpy as np


class Runner(object):
    """Step a training env with delayed rewards and evaluate on a test env.

    Environment rewards are accumulated for a randomly drawn number of steps
    (the "delay", drawn from ``[delay_low, reward_n]`` inclusive) and are only
    released -- aggregated by :meth:`env_reward` -- when that delay has
    elapsed or the episode ends. On every other step the reward is 0.

    Args:
        train_env: env used by :meth:`run`; must provide ``reset()``,
            ``step(ac) -> (obs, reward, done, info)`` and an ``action_space``
            with ``high`` and ``sample()``.
        test_env: env used by :meth:`eval` (same interface).
        model: object whose ``model.policy.step(obs_batch, ...)`` returns a
            sequence whose first element is the action.
        buffer: transition store exposing
            ``append(obs, ac, r, n_obs, done, reward_signal)``.
        reward_n: maximum delay, in steps, before rewards are released.
        ob_shape: shape of the observation fed to the policy (this includes
            the two trailing time features when ``add_t`` is true).
        add_t: if true, append two features derived from the number of steps
            since the last released reward (see :meth:`add_time`).
        drop_type: aggregation of the delayed rewards; only ``'Sum'`` exists.
        drop_r: stored on the instance; not used inside this class.
        rand_r: the delay's lower bound is ``int((1 - rand_r) * reward_n)``.
        max_ep_len: a ``done`` from ``train_env`` at an episode length of at
            least this value is stored with done=0 (presumably the env's
            time-limit truncation -- confirm against the env).
    """

    def __init__(self, train_env, test_env, model, buffer, reward_n, ob_shape, add_t, drop_type, drop_r, rand_r,
                 max_ep_len=1000):
        self.env = train_env
        self.eval_env = test_env
        self.model = model
        self.buffer = buffer
        # env parameters
        self.drop_type = drop_type
        self.drop_r = drop_r
        self.reward_n = reward_n
        self.delay_low = int((1. - rand_r) * reward_n)
        self.max_ep_len = max_ep_len

        # Sample the first delay before resetting the env (same RNG order as before).
        self.delay = self._sample_delay()
        self.reward_signal = False
        self.ob_shape = ob_shape
        self.add_t = add_t
        self.obs = self._reset_obs(self.env)

        self.max_ac = self.env.action_space.high[0]
        # Per-window reward accumulator, indexed by the step inside the window.
        self.delay_r_ex = np.zeros([reward_n])
        self.delay_step = np.zeros([1])
        self.ep_r_ex = np.zeros([1])
        self.ep_len = np.zeros([1])

    def _sample_delay(self):
        """Draw a delay uniformly from ``[delay_low, reward_n]``, both ends inclusive."""
        # randint excludes its upper bound, so +1 reproduces the inclusive
        # range of the deprecated np.random.random_integers.
        return np.random.randint(self.delay_low, self.reward_n + 1)

    def _reset_obs(self, env):
        """Reset ``env`` and return its observation in a fresh float32 array of
        shape ``ob_shape``; the two trailing time slots are 0 when ``add_t``."""
        obs = np.zeros(self.ob_shape, dtype=np.float32)
        if self.add_t:
            obs[:-2] = env.reset()
        else:
            obs[:] = env.reset()
        return obs

    def run(self, policy=True):
        """Take one step in the training env and append the transition to the buffer.

        Args:
            policy: act with ``model.policy`` if true; otherwise use
                ``action_space.sample()``.

        Returns:
            ``(ep_r_ex, ep_len)``: two lists that are empty unless the episode
            ended on this step, in which case each holds a single 1-element
            array with that episode's (delayed) return and its length.
        """
        ep_r_ex, ep_len = [], []
        if policy:
            ac = self.model.policy.step(self.obs.reshape((1, -1)))[0]
        else:
            # Random exploration: no need to run the policy at all.
            ac = self.env.action_space.sample()
        n_obs, r_ex, dones, _ = self.env.step(ac)
        self.delay_r_ex[int(self.delay_step[0])] += r_ex
        self.delay_step += 1

        if dones or int(self.delay_step[0]) >= int(self.delay):
            # Release everything accumulated in this window and start a new one.
            r_ex = self.env_reward(self.delay_r_ex, self.delay_step)
            self.delay_step = np.zeros([1])
            self.delay_r_ex = np.zeros([self.reward_n])
            self.delay = self._sample_delay()
            self.reward_signal = True
        else:
            r_ex = 0.
            self.reward_signal = False

        self.ep_r_ex += r_ex
        self.ep_len += 1

        n_obs = self.add_time(n_obs, self.delay_step)

        # A done at (or past) max_ep_len is stored as non-terminal.
        done = bool(dones) and self.ep_len[0] < self.max_ep_len
        self.buffer.append(self.obs, ac, r_ex, n_obs, float(done), self.reward_signal)

        if dones:
            ep_r_ex.append(self.ep_r_ex.copy())
            ep_len.append(self.ep_len.copy())
            self.ep_r_ex[0], self.ep_len[0] = 0, 0
            # Fresh array, so the obs just handed to the buffer is never mutated.
            self.obs = self._reset_obs(self.env)
        else:
            self.obs = n_obs

        return ep_r_ex, ep_len

    def eval(self, num_episodes=7):
        """Run ``num_episodes`` episodes on ``eval_env`` using ``policy.step(..., test=True)``.

        Rewards are released with the same delay scheme as :meth:`run`.

        Returns:
            ``(mean_return, mean_length)`` over the evaluated episodes.
        """
        test_ep_reward = []
        test_ep_len = []
        for _ in range(num_episodes):
            ep_reward = 0.
            ep_len = 0
            done = False
            delay = self._sample_delay()
            delay_step = 0
            delay_reward = np.zeros([self.reward_n])
            obs = self._reset_obs(self.eval_env)

            while not done:
                ac = self.model.policy.step(obs.reshape(1, -1), test=True)[0]
                n_obs, r_ex, done, _ = self.eval_env.step(ac)
                delay_reward[delay_step] += r_ex
                delay_step += 1

                if delay_step >= delay or done:
                    ep_reward += self.env_reward(delay_reward, delay_step)
                    delay_step = 0
                    delay_reward = np.zeros([self.reward_n])
                    delay = self._sample_delay()

                obs = self.add_time(n_obs, delay_step)
                ep_len += 1

            test_ep_len.append(ep_len)
            test_ep_reward.append(ep_reward)

        return np.mean(np.array(test_ep_reward)), np.mean(np.array(test_ep_len))

    def add_time(self, obs, delay_step):
        """Append the features ``5*d/reward_n`` and ``-4*d/reward_n`` to ``obs``
        when ``add_t`` is set (``d`` = steps since the last released reward);
        otherwise return ``obs`` unchanged.

        ``delay_step`` may be a Python scalar or a 1-element array.
        """
        if not self.add_t:
            return obs
        d = np.asarray(delay_step, dtype=np.float64).reshape(-1)
        return np.concatenate((obs, 5 * d / self.reward_n, -4 * d / self.reward_n))

    def env_reward(self, delay_r, delay_step):
        """Aggregate the rewards buffered in ``delay_r`` according to ``drop_type``.

        ``delay_step`` is accepted for interface compatibility; the ``'Sum'``
        aggregation does not use it.

        Raises:
            NotImplementedError: for any ``drop_type`` other than ``'Sum'``.
        """
        if self.drop_type == 'Sum':
            return np.sum(delay_r)
        raise NotImplementedError('unsupported drop_type: %r' % (self.drop_type,))