import os
import pickle
import tensorflow as tf
import numpy as np
import gym
import joblib
import argparse
import matplotlib.pyplot as plt
from tensorflow.distributions import Normal


def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var,
                                       batch_count):
    """Merge running moments with the moments of a newly observed batch.

    Args:
        mean: Mean accumulated so far.
        var: Variance accumulated so far.
        count: Number of samples behind ``mean`` and ``var``.
        batch_mean: Mean of the incoming batch.
        batch_var: Variance of the incoming batch.
        batch_count: Number of samples in the incoming batch.

    Returns:
        Tuple ``(new_mean, new_var, new_count)`` for the combined data.
    """
    total = count + batch_count
    shift = batch_mean - mean

    # Move the running mean toward the batch mean, weighted by batch size.
    merged_mean = mean + shift * batch_count / total

    # Squared-deviation mass of both parts plus a term for the gap between
    # the two means.
    sum_sq = (var * count + batch_var * batch_count
              + np.square(shift) * count * batch_count / total)

    return merged_mean, sum_sq / total, total


class RunningMeanStd(object):
    """Running mean / variance tracker that can be updated batch by batch."""

    def __init__(self, epsilon=1e-4, shape=()):
        """Start with zero mean, unit variance and ``epsilon`` as the count.

        Args:
            epsilon: Initial value of the sample count.
            shape: Shape of the tracked mean and variance arrays.
        """
        self.count = epsilon
        self.mean = np.zeros(shape, dtype='float64')
        self.var = np.ones(shape, dtype='float64')

    def load(self, path, var_type='obs'):
        """Load statistics from a joblib archive and print its keys.

        Args:
            path: Location of the joblib file.
            var_type: ``'obs'`` reads the ``ob_rms`` entries, ``'ret'`` reads
                the ``ret_rms`` entries; any other value leaves the
                statistics untouched.
        """
        # load_path = os.path.join('../baselines/models', path)
        saved = joblib.load(path)
        print(saved.keys())

        if var_type == 'obs':
            mean_key, spread_key = 'ob_rms/mean:0', 'ob_rms/std:0'
        elif var_type == 'ret':
            mean_key, spread_key = 'ret_rms/mean:0', 'ret_rms/std:0'
        else:
            return

        self.mean = np.array(saved[mean_key])
        # NOTE(review): the archive key is named "std" but the value is kept
        # in ``self.var`` as-is -- confirm which quantity the file holds.
        self.var = np.array(saved[spread_key])

    def load_npz(self, path):
        """Load mean (``arr_0``) and variance (``arr_1``) from an ``.npz``."""
        archive = np.load(path)
        self.mean = archive['arr_0']
        self.var = archive['arr_1']

    def update(self, x):
        """Fold a batch into the statistics; axis 0 of ``x`` indexes samples."""
        self.update_from_moments(np.mean(x, axis=0), np.var(x, axis=0),
                                 x.shape[0])

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """Fold precomputed batch moments into the running statistics."""
        merged = update_mean_var_count_from_moments(
            self.mean, self.var, self.count,
            batch_mean, batch_var, batch_count)
        self.mean, self.var, self.count = merged


class NormalizeEnv(object):
    """Gym environment wrapper that also returns normalized observations.

    Observations are shifted by a running mean, scaled by the square root of
    a running variance and clipped to ``[-clip, clip]``.
    """

    def __init__(self, env_name, gamma=0.99, update=True, random_seed=6666):
        """Create and seed the wrapped environment.

        Args:
            env_name: Gym environment id passed to ``gym.make``.
            gamma: Accepted for interface compatibility; not used here.
            update: Whether ``obs_normalize`` refreshes the running statistics.
            random_seed: Seed forwarded to the wrapped environment.
        """
        self.update = update
        self.env = gym.make(env_name)
        self.seed(random_seed)

        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space
        self._max_episode_steps = self.env._max_episode_steps

        self.obs_rms = RunningMeanStd(shape=self.observation_space.shape[0])

        self.env_name = env_name
        self.clip = 10
        self.epsilon = 1e-8

    def seed(self, random_seed):
        """Seed the wrapped environment and return its result."""
        return self.env.seed(random_seed)

    def get_obs(self, state):
        """Return ``state`` normalized with the current statistics and clipped.

        The statistics are not modified.
        """
        # epsilon keeps the denominator away from zero for constant features.
        scale = np.sqrt(self.obs_rms.var + self.epsilon)
        return np.clip((state - self.obs_rms.mean) / scale,
                       -self.clip, self.clip)

    def obs_normalize(self, state):
        """Optionally update the running statistics, then normalize ``state``.

        Args:
            state: A single observation or a batch of observations stacked
                along axis 0.

        Returns:
            The normalized, clipped observation(s), same shape as ``state``.
        """
        if self.update:
            batch = np.asarray(state)
            # The statistics update reduces over axis 0. A lone observation
            # has the same rank as the running mean, so give it a batch axis;
            # otherwise it would be averaged across its own features.
            if batch.ndim == np.ndim(self.obs_rms.mean):
                batch = batch[np.newaxis, ...]
            self.obs_rms.update(batch)
        return self.get_obs(state)

    def reset(self):
        """Reset the environment; return ``(raw_state, normalized_obs)``."""
        state = self.env.reset()
        obs = self.get_obs(state)
        return state, obs

    def step(self, act):
        """Step the environment.

        Returns:
            Tuple ``(raw_state, normalized_obs, reward, done, info)``.
        """
        state, rew, done, info = self.env.step(act)
        obs = self.get_obs(state)
        return state, obs, rew, done, info

    def render(self, mode=None):
        """Forward ``mode`` to the wrapped environment's ``render``."""
        return self.env.render(mode)


class VanillaEnv(object):
    """Thin Gym wrapper with helpers to read and perturb the simulator state."""

    # Number of leading ``qpos`` entries that are left out of the observation
    # returned by ``get_state`` and left unperturbed by ``set_state``.
    _LEADING_QPOS_SKIPPED = {
        'Swimmer-v2': 2,
        'Hopper-v2': 1,
        'Walker2d-v2': 1,
        'HalfCheetah-v2': 1,
    }

    def __init__(self, env_name, gamma=0.99, random_seed=6666):
        """Create and seed the wrapped environment.

        Args:
            env_name: Gym environment id passed to ``gym.make``.
            gamma: Accepted for interface compatibility; not used here.
            random_seed: Seed forwarded to the wrapped environment.
        """
        self.env_name = env_name
        self.env = gym.make(env_name)

        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space
        self._max_episode_steps = self.env._max_episode_steps

        self.seed(random_seed)

    def seed(self, random_seed):
        """Seed the wrapped environment and return its result."""
        return self.env.seed(random_seed)

    def reset(self):
        """Reset the wrapped environment and return its result."""
        return self.env.reset()

    def step(self, act):
        """Step the wrapped environment and return its result."""
        return self.env.step(act)

    def render(self, mode=None):
        """Forward ``mode`` to the wrapped environment's ``render``."""
        return self.env.render(mode)

    def set_state(self, delta):
        """Offset the simulator state by ``delta`` and write it back.

        Args:
            delta: 2-D array; row 0 holds the offsets, first for the
                non-skipped ``qpos`` entries and then for every ``qvel``
                entry (the same layout ``get_state`` returns).

        Raises:
            ValueError: If ``env_name`` is not one of the supported ids.
        """
        if self.env_name not in self._LEADING_QPOS_SKIPPED:
            raise ValueError(
                'set_state does not support environment %r; supported: %s'
                % (self.env_name, sorted(self._LEADING_QPOS_SKIPPED)))
        skipped = self._LEADING_QPOS_SKIPPED[self.env_name]

        position = self.env.sim.data.qpos.flat.copy()
        velocity = self.env.sim.data.qvel.flat.copy()

        # Everything after the skipped leading coordinates gets an offset;
        # the remainder of ``delta`` applies to the velocities.
        n_perturbed = position.size - skipped
        new_pos = np.concatenate(
            (position[:skipped],
             position[skipped:] + delta[0, :n_perturbed]))
        new_vel = velocity + delta[0, n_perturbed:]
        self.env.set_state(new_pos, new_vel)

    def get_state(self):
        """Return the flat vector of non-skipped ``qpos`` followed by ``qvel``.

        For environment ids without an entry in the table, the full ``qpos``
        is used.
        """
        position = self.env.sim.data.qpos.flat.copy()
        velocity = self.env.sim.data.qvel.flat.copy()

        skipped = self._LEADING_QPOS_SKIPPED.get(self.env_name, 0)
        return np.concatenate([position[skipped:], velocity]).ravel()
