import numpy as np
import jax.numpy as jnp
import tensorflow as tf

import acme
from acme.agents.jax import dqn
import acme.jax.networks as networks_lib
from acme.utils import loggers
from acme.jax import savers
import bsuite

import haiku as hk
import jax_networks

import click
import os
import yaml


def generate_behavior(environment, seed, n_episodes, save_path, epsilon=0.05):
    """Run a DQN agent in `environment` and save its learner state.

    Args:
        environment: environment accepted by `acme.specs.make_environment_spec`
            and `acme.EnvironmentLoop`.
        seed: seed forwarded to the DQN agent.
        n_episodes: number of episodes the environment loop is run for.
        save_path: directory under which the `behavior` checkpoint is written.
        epsilon: epsilon value forwarded to the DQN agent.

    Returns:
        The DQN agent after the environment loop has finished.
    """
    spec = acme.specs.make_environment_spec(environment)
    n_actions = spec.actions.num_values

    # A batch holding a single all-ones observation, used only to initialise
    # the network parameters.
    single_obs = jnp.ones(environment.observation_spec().shape)
    init_batch = jnp.expand_dims(single_obs, 0).astype(jnp.float32)

    def _forward(x):
        # Two hidden layers of width 64 followed by one output per action.
        return jax_networks.MLP(widths=[64, 64, n_actions],)(x)

    transformed = hk.without_apply_rng(hk.transform(_forward))

    def _init(rng):
        return transformed.init(rng, init_batch)

    q_network = networks_lib.FeedForwardNetwork(init=_init,
                                                apply=transformed.apply)

    # Build the agent.
    agent = dqn.DQN(
        environment_spec=spec,
        network=q_network,
        seed=seed,
        samples_per_insert=32,
        epsilon=epsilon,
    )

    # Run the agent in the environment, logging progress to the terminal.
    logger = loggers.TerminalLogger(label='train', time_delta=1.0,
                                    print_fn=print)
    acme.EnvironmentLoop(environment, agent, logger=logger).run(
        num_episodes=n_episodes)

    # Persist whatever state the learner exposes through `save()`.
    path = os.path.join(save_path, 'behavior')
    learner_state = agent._learner.save()
    savers.save_to_path(path, learner_state)
    print(f"Agent saved at: {path}")
    return agent


def generate_dataset(agent_fn, environment, env_id, env_noise, seed,
                        length, save_path, policy_name):
    """Roll out `agent_fn` in `environment` and save the transitions.

    Whole episodes are collected until at least `length` transitions have
    been gathered, so the saved dataset may hold more than `length` rows.
    Each row stores the (current, next) observation pair, the
    (current, next) action pair, the reward, the (current, next) step types,
    the return of the episode the row belongs to, and a discount of 1.

    Args:
        agent_fn: callable mapping an observation to an action.
        environment: environment exposing `reset()` and `step(action)`.
        env_id: environment name, used in the output directory name.
        env_noise: environment noise level, used in the output directory name.
        seed: seed, used in the output directory name only.
        length: minimum number of transitions to collect.
        save_path: root directory for the saved dataset.
        policy_name: policy label, used in the output directory name.

    Returns:
        The `tf.data.Dataset` built from the collected transitions.
    """
    # One array per episode under each key; concatenated after collection.
    episodes = {'action': [], 'discount': [], 'episodic_reward': [],
                'observation': [], 'reward': [], 'step_type': []}

    collected = 0
    episode_idx = 0
    while collected < length:
        obs_pairs, act_pairs, rewards, type_pairs = [], [], [], []
        total_reward = 0

        # Start a fresh episode and pick the first action.
        timestep = environment.reset()
        obs = timestep.observation
        act = agent_fn(obs)
        kind = timestep.step_type

        while not timestep.last():
            # Apply the pending action, then choose the action for the
            # resulting observation so both ends of the transition are known.
            timestep = environment.step(act)
            next_obs = timestep.observation
            next_act = agent_fn(next_obs)
            next_kind = timestep.step_type

            obs_pairs.append([obs, next_obs])
            act_pairs.append([act, next_act])
            rewards.append([timestep.reward])
            type_pairs.append([kind, next_kind])

            total_reward += timestep.reward
            obs, act, kind = next_obs, next_act, next_kind

        collected += len(obs_pairs)
        episode_idx += 1
        print(f'[Episode {episode_idx}]: return {total_reward}')

        # Store this episode's transitions.
        n_steps = len(rewards)
        episodes['action'].append(np.array(act_pairs))
        episodes['observation'].append(np.array(obs_pairs))
        episodes['reward'].append(np.array(rewards))
        episodes['step_type'].append(np.array(type_pairs))
        # Every transition carries the return of its whole episode.
        episodes['episodic_reward'].append(total_reward * np.ones((n_steps, 1)))
        episodes['discount'].append(np.ones((n_steps, 1), dtype=np.float32))

    flat = {key: np.concatenate(chunks, axis=0)
            for key, chunks in episodes.items()}
    print('Mean return: ', np.mean(flat['episodic_reward']))
    dataset = tf.data.Dataset.from_tensor_slices(flat)

    env_dir = env_id + '_' + str(env_noise)
    run_dir = policy_name + '_' + str(seed)
    path = os.path.join(save_path, env_dir, run_dir)
    tf.data.experimental.save(dataset, path)
    print(f"Data saved at: {path}")
    return dataset


def noisy_data(environment, seed, save_path, length=5000):
    """Build and save a dataset of random transitions.

    Observations are drawn from a standard normal distribution and actions
    uniformly from the environment's discrete action set. Rewards and
    episodic rewards are all zero; discounts and step types are all one.
    Nothing is obtained by stepping the environment: it is only queried for
    its observation shape and number of actions.

    Args:
        environment: environment providing `observation_spec()` and accepted
            by `acme.specs.make_environment_spec`.
        seed: seed for NumPy's global random generator.
        save_path: directory under which `data_noise` is written.
        length: number of transitions to generate (defaults to 5000).

    Returns:
        The `tf.data.Dataset` built from the generated arrays.
    """
    np.random.seed(seed)

    obs_shape = environment.observation_spec().shape
    environment_spec = acme.specs.make_environment_spec(environment)
    num_actions = environment_spec.actions.num_values

    # Draw observations before actions: the order of the two draws determines
    # the values produced for a given seed.
    observations = np.random.normal(0, 1, size=(length, 2, *obs_shape))
    actions = np.random.choice(num_actions, size=(length, 2))

    data_dict = {
        'action': actions,
        'discount': np.ones((length, 1)),
        'episodic_reward': np.zeros((length, 1)),
        'observation': observations,
        'reward': np.zeros((length, 1)),
        'step_type': np.ones((length, 2)),
    }

    dataset = tf.data.Dataset.from_tensor_slices(data_dict)
    path = os.path.join(save_path, 'data_noise')
    tf.data.experimental.save(dataset, path)
    print("Data saved at: " + path)
    return dataset


def uniform_data(environment, env_id, env_noise, seed, length, save_path):
    """Collect a dataset with a uniformly random policy.

    Seeds NumPy's global random generator with `seed`, then delegates to
    `generate_dataset` using the policy label 'uni'.

    Args:
        environment: environment to roll out in.
        env_id: environment name, forwarded to `generate_dataset`.
        env_noise: environment noise level, forwarded to `generate_dataset`.
        seed: seed for the random policy, also forwarded to
            `generate_dataset`.
        length: minimum number of transitions to collect.
        save_path: root directory for the saved dataset.

    Returns:
        The dataset returned by `generate_dataset`.
    """
    spec = acme.specs.make_environment_spec(environment)
    n_actions = spec.actions.num_values

    np.random.seed(seed)

    def random_policy(obs):
        # The observation is ignored; every action is equally likely.
        choice = np.random.choice(n_actions)
        return np.int32(choice)

    return generate_dataset(
        random_policy, environment, env_id, env_noise, seed,
        length, save_path, 'uni',
    )


@click.command()
@click.option('--config', '-c', default='generate_data', help='config file name')
@click.option('--options', '-o', multiple=True, nargs=2, type=click.Tuple([str, str]))
def main(config, options):
    """Generate an offline dataset as described by a YAML config file.

    The config is read from `<cwd>/<config>.yaml`. Each `-o KEY VALUE` pair
    overrides an existing config entry; the value is converted to the type of
    the entry it replaces.
    """
    cfg_file = os.path.join(os.getcwd(), config + '.yaml')
    # Close the config file deterministically instead of leaking the handle.
    with open(cfg_file, 'r') as cfg:
        params = yaml.safe_load(cfg)

    # replacing params with command line options
    for key, raw_value in options:
        # Explicit check rather than `assert`, which is stripped under -O.
        if key not in params:
            raise click.BadParameter(
                f"unknown config key '{key}'", param_hint='--options')
        dtype = type(params[key])
        if dtype is bool:
            # bool('False') is True, so booleans are parsed by name: only the
            # exact string 'True' enables the flag.
            params[key] = raw_value == 'True'
        else:
            params[key] = dtype(raw_value)

    # names:
    #   pure data: env_noise_policy_seed
    #   mixed data:
    #       across policies: env_noise_mix_weights_seed
    #       across seeds: env_noise_mix_nseeds_idx

    # Number of training episodes for each (environment, policy quality).
    episode_dict = {
        'cartpole': {'med': 200,
                     'exp': 500},
        'catch': {'med': 200,
                  'exp': 1000},
    }

    # Minimum number of transitions to collect for each environment.
    data_len_dict = {
        'cartpole': 100000,
        'catch': 5000,
    }

    env_id = params['env_id']
    policy = params['policy']

    raw_environment = bsuite.load(env_id + '_noise',
                                  kwargs={'seed': params['seed'],
                                          'noise_scale': params['env_noise']})
    environment = acme.wrappers.SinglePrecisionWrapper(raw_environment)

    if policy == 'noise':
        noisy_data(environment, params['seed'], params['save_path'])
    elif policy == 'uni':
        uniform_data(environment, env_id, params['env_noise'],
                     params['seed'], data_len_dict[env_id],
                     params['save_path'])
    elif policy in ('med', 'exp'):
        # generate_behavior takes (environment, seed, n_episodes, save_path);
        # passing env_id and env_noise as well raised a TypeError.
        agent = generate_behavior(environment, params['seed'],
                                  episode_dict[env_id][policy],
                                  params['save_path'])
        agent_fn = agent._actor.select_action
        generate_dataset(agent_fn, environment, env_id, params['env_noise'],
                         params['seed'], data_len_dict[env_id],
                         params['save_path'], policy)
    else:
        raise NameError(f"unknown policy: {policy!r}")


# Script entry point: invoke the click command when this file is run directly.
if __name__ == '__main__':
    main()