import os
import numpy as np
import tensorflow as tf

import jax.numpy as jnp
import jax

import bsuite
import acme
from acme.datasets import NumpyIterator

import atari_utils
import bsuite_utils
from cloud_env import Cloud
from counter_env import Counter


""" Environment loading """

def load_env(params):
    """Build the environment described by ``params``.

    Args:
        params: Mapping that provides ``'env_type'`` (``'bsuite'`` or
            ``'atari'``) and ``'env_id'``. For ``'bsuite'`` the id is split
            on ``'_'`` into a name and a noise level (e.g. ``'catch_0.0'``)
            and ``'seed'`` is read as well.

    Returns:
        The constructed environment. bsuite-type environments (including the
        local ``Cloud`` and ``Counter`` ones) are wrapped in
        ``acme.wrappers.SinglePrecisionWrapper``.

    Raises:
        ValueError: If ``params['env_type']`` is not a supported type.
    """
    env_type = params['env_type']

    if env_type == 'atari':
        return atari_utils.environment(game=params['env_id'])

    if env_type != 'bsuite':
        # Without this guard the function used to reach `return environment`
        # with the name unbound and fail with an opaque UnboundLocalError.
        raise ValueError(
            "Unsupported env_type %r; expected 'bsuite' or 'atari'" % (env_type,))

    # env_id looks like '<name>_<noise>'; split once and reuse both parts.
    id_parts = params['env_id'].split('_')
    env_id = id_parts[0]
    env_noise = float(id_parts[1])

    if env_id == 'cloud':
        raw_environment = Cloud(r_prob=env_noise, seed=params['seed'])
    elif env_id == 'counter':
        raw_environment = Counter(r_prob=env_noise, seed=params['seed'])
    else:
        # Everything else is looked up as the '<name>_noise' bsuite experiment.
        raw_environment = bsuite.load(
            env_id + '_noise',
            kwargs={'seed': params['seed'], 'noise_scale': env_noise})

    return acme.wrappers.SinglePrecisionWrapper(raw_environment)

def get_dummy_obs(params):
    """Return a batch of one all-ones float32 observation for the environment.

    The environment is built with ``load_env(params)`` only to read its
    observation spec; the result has shape ``(1, *observation_shape)``.
    """
    obs_shape = load_env(params).observation_spec().shape
    single_obs = jnp.ones(obs_shape, dtype=jnp.float32)
    # Prepend the batch axis.
    return single_obs[jnp.newaxis, ...]

def get_num_actions(params):
    """Return ``num_values`` of the action spec of the environment built from ``params``."""
    action_spec = load_env(params).action_spec()
    return action_spec.num_values


""" Data loading """

def _indexed_bsuite_iterator(path, params):
    """Load the bsuite dataset at ``path``, pair elements with their index, and batch it."""
    dataset = bsuite_utils.my_bsuite_dataset(path)
    indexed = tf.data.Dataset.zip((tf.data.Dataset.range(len(dataset)), dataset))
    return batch_and_shuffle(indexed, params['batch_size'], params['seed'])


def _bsuite_eval_labels(params):
    """Return the evaluation directory names to use for a bsuite run."""
    labels = params['eval_dirs']
    # TODO: fix this hack
    if params['env_id'] == 'catch_0.0':
        labels = [e[:-3] + '2k' for e in labels]
    if params['env_id'].startswith(('cloud', 'counter')):
        # These environments are evaluated on the training directory itself.
        labels = [params['train_dir']]
    return labels


def load_data(data_path, params):
    """Load the training data and, optionally, the evaluation data.

    Args:
        data_path: Root directory; datasets are read from
            ``<data_path>/<env_id>/...``.
        params: Mapping that provides ``'env_type'`` (``'bsuite'`` or
            ``'atari'``), ``'env_id'`` and ``'batch_size'``. bsuite also
            reads ``'train_dir'`` and ``'seed'``; atari also reads ``'run'``.
            Evaluation data is loaded only when the key ``'eval_dirs'`` is
            present (for atari its value is not used, only its presence).

    Returns:
        A tuple ``(data, eval_labels, eval_data)``. ``data`` is the batched
        iterator over index/element pairs of the training set. ``eval_labels``
        and ``eval_data`` are parallel lists, or both ``None`` when
        ``'eval_dirs'`` is not in ``params``.

    Raises:
        ValueError: If ``params['env_type']`` is not a supported type.
    """
    env_type = params['env_type']
    if env_type not in ('bsuite', 'atari'):
        # Previously an unknown type reached `return data, ...` with `data`
        # unbound and failed with an opaque UnboundLocalError.
        raise ValueError(
            "Unsupported env_type %r; expected 'bsuite' or 'atari'" % (env_type,))

    want_eval = 'eval_dirs' in params
    eval_labels, eval_data = None, None

    if env_type == 'bsuite':
        train_path = os.path.join(data_path, params['env_id'], params['train_dir'])
        data = _indexed_bsuite_iterator(train_path, params)

        if want_eval:
            eval_labels = _bsuite_eval_labels(params)
            eval_data = [
                _indexed_bsuite_iterator(
                    os.path.join(data_path, params['env_id'], label), params)
                for label in eval_labels
            ]
    else:
        path = os.path.join(data_path, params['env_id'],
                            'run_' + str(params['run']) + '_1percent')
        data = tf.data.experimental.load(path, compression='GZIP')
        data = data.enumerate().repeat()

        if want_eval:
            eval_labels = ['train', '1']
            # The first evaluation set is the (unbatched) training stream itself.
            eval_sets = [
                data,
                atari_utils.atari_dataset(data_path, params['env_id'],
                                          params['run'], [1],
                                          repeat=True, include_idx=True),
            ]
            eval_data = [just_batch(d, params['batch_size']) for d in eval_sets]

        data = just_batch(data, params['batch_size'])

    return data, eval_labels, eval_data


""" Normalization """

def normalize_fn(norm_type, mean=None, std=None):
    """Select the normalization function named by ``norm_type``.

    Args:
        norm_type: One of ``'center'``, ``'unit'``, ``'scale'``, ``'id'``,
            ``'center_unit'`` or ``'center_1d'``.
        mean: Statistic forwarded to the ``'center'``, ``'center_unit'`` and
            ``'center_1d'`` builders.
        std: Statistic forwarded to the ``'center'`` and ``'center_1d'``
            builders.

    Returns:
        A callable that normalizes an array.

    Raises:
        NotImplementedError: If ``norm_type`` is not one of the names above.
    """
    # Each builder is only invoked for its own name, so unused statistics
    # (possibly None) are never touched.
    if norm_type == 'center':
        return batch_center_fn(mean, std)
    if norm_type == 'unit':
        return unit_norm
    if norm_type == 'scale':
        return scale_by_255
    if norm_type == 'id':
        return identity
    if norm_type == 'center_unit':
        return batch_center_and_unit(mean)
    if norm_type == 'center_1d':
        return batch_center_1d(mean, std)

    raise NotImplementedError

def unit_norm(x):
    """Rescale every example in a batch to an L2 norm of sqrt(#features).

    Args:
        x: Array with at least two dimensions; axis 0 is treated as the batch
            axis and all remaining axes as features.

    Returns:
        A float32 array of the same shape as ``x``.
    """
    # Warning: the leading axis is taken to be the batch axis
    assert x.ndim > 1
    x = x.astype(jnp.float32)

    # L2 norm per example, reduced over every non-batch axis.
    feature_axes = tuple(range(1, x.ndim))
    squared_sum = jnp.sum(jnp.square(x), axis=feature_axes, keepdims=True)
    l2 = jnp.sqrt(squared_sum)

    # Number of features per example.
    num_features = jnp.prod(jnp.array(x.shape[1:]))
    return jnp.sqrt(num_features) * x / l2

def scale_by_255(x):
    """Cast ``x`` to float32 and divide it by 255."""
    as_float = x.astype(jnp.float32)
    scaled = as_float / 255.
    return scaled

def identity(x):
    """Return ``x`` unchanged apart from a cast to float32."""
    return x.astype(jnp.float32)

def batch_center_fn(mean, std):
    """Build a batched standardizer: ``(x - mean) / (std + 1e-3)``.

    Args:
        mean: Value subtracted from every example.
        std: Value each centered example is divided by, after adding 1e-3.

    Returns:
        A function mapped over axis 0 of its input with ``jax.vmap``; each
        example is cast to float32 before being standardized.
    """
    def _standardize(example):
        centered = example.astype(jnp.float32) - mean
        # The small constant keeps the denominator away from zero.
        return centered / (std + 1e-3)

    return jax.vmap(_standardize, in_axes=0)

def batch_center_1d(mean, std):
    """Build a batched standardizer that uses scalar statistics.

    ``mean`` and ``std`` are each collapsed to a single number with
    ``jnp.mean``; every example is then cast to float32 and transformed as
    ``(x - scalar_mean) / (scalar_std + 1e-3)``.

    Returns:
        A function mapped over axis 0 of its input with ``jax.vmap``.
    """
    scalar_mean = jnp.mean(mean)
    scalar_std = jnp.mean(std)

    def _standardize(example):
        centered = example.astype(jnp.float32) - scalar_mean
        # The small constant keeps the denominator away from zero.
        return centered / (scalar_std + 1e-3)

    return jax.vmap(_standardize, in_axes=0)

def batch_center_and_unit(mean):
    """Build a normalizer that centers a batch and then applies ``unit_norm``.

    Args:
        mean: Value subtracted from every example; it must broadcast against
            the batched input.

    Returns:
        A function taking a batched array (axis 0 is the batch axis). It
        casts the input to float32, subtracts ``mean`` and rescales each
        example with ``unit_norm``.
    """
    def normalize(x):
        x = x.astype(jnp.float32)
        # unit_norm treats axis 0 as the batch axis, so it has to see the
        # whole batch. Wrapping this function in jax.vmap would strip that
        # axis, making unit_norm normalize each slice along the first
        # observation axis separately (and fail its ndim check for 1-D
        # observations).
        return unit_norm(x - mean)
    return normalize


""" TF dataset helpers """

def batch_and_shuffle(dataset, batch_size, seed=0):
    """Repeat, shuffle, batch and prefetch ``dataset``, wrapped in a ``NumpyIterator``.

    Args:
        dataset: Dataset to build the pipeline on.
        batch_size: Number of elements per batch; the shuffle buffer holds
            ten batches' worth of elements.
        seed: Seed passed to the shuffle step.

    Returns:
        A ``NumpyIterator`` over the resulting pipeline.
    """
    shuffle_buffer = batch_size * 10
    pipeline = (
        dataset
        .repeat()
        .shuffle(shuffle_buffer, seed=seed)
        .batch(batch_size)
        .prefetch(tf.data.experimental.AUTOTUNE)
    )
    return NumpyIterator(pipeline)

def just_batch(dataset, batch_size):
    """Batch and prefetch ``dataset`` (no repeat, no shuffle), wrapped in a ``NumpyIterator``."""
    batched = dataset.batch(batch_size)
    pipeline = batched.prefetch(tf.data.experimental.AUTOTUNE)
    return NumpyIterator(pipeline)

def mean_and_std(dataset):
    """Compute the per-entry mean and standard deviation of the observations.

    Each element of ``dataset`` is expected to expose ``.data.observation``.
    The dataset is traversed three times with ``reduce``: to count elements,
    to sum observations, and to sum squared deviations from the mean. The
    variance uses the ``length - 1`` denominator.

    Returns:
        A tuple ``(mean, std)`` of NumPy arrays shaped like one observation.
    """
    # Peek at one observation to get a zero accumulator of matching shape/dtype.
    first_batch = next(iter(dataset.batch(1)))
    sample_obs = first_batch.data.observation[0].numpy()
    zeros = np.zeros_like(sample_obs)

    count = dataset.reduce(0.0, lambda acc, _: acc + 1).numpy()

    obs_sum = dataset.reduce(zeros, lambda acc, elem: acc + elem.data.observation)
    mean = obs_sum.numpy() / count

    sq_dev_sum = dataset.reduce(
        zeros, lambda acc, elem: acc + (elem.data.observation - mean)**2)
    var = sq_dev_sum.numpy() / (count - 1)

    return mean, np.sqrt(var)