import click
import eval_utils
import os
import numpy as np

import yaml

import data_utils
import eval_utils
import utils
import naming_utils


def get_returns(runner, env, eps=1):
    """Collect ``eps`` episodes with ``runner.actor`` and return their summed rewards.

    Each episode is gathered through ``eval_utils.collect_episode(env,
    runner.actor)``; the third item of its result is reduced with ``np.sum``
    and a progress line is printed after every episode.

    Args:
        runner: object exposing an ``actor`` attribute used to act in ``env``.
        env: environment handed unchanged to ``eval_utils.collect_episode``.
        eps: number of episodes to collect; values <= 0 collect nothing.

    Returns:
        A list with one ``np.sum`` result per episode, in collection order.
    """
    episode_returns = []
    for ep_num in range(1, eps + 1):
        # Only the third item is used; the first two are discarded.
        _first, _second, rewards = eval_utils.collect_episode(env, runner.actor)
        total = np.sum(rewards)
        episode_returns.append(total)
        print(f'ep {ep_num} of {eps}: return {total}')
    return episode_returns


@click.command()
@click.option('--config', '-c', default='eval_saved_model', help='config file name')
@click.option('--options', '-o', multiple=True, nargs=2, type=click.Tuple([str, str]))
def main(config, options):
    """Evaluate a saved RL checkpoint over a sweep of test epsilons.

    Rebuilds the training parameters of the run being evaluated, loads its
    checkpoint, and for each test epsilon collects episode returns and saves
    them as one .npy file per epsilon index.
    """
    params = utils.config_and_options_to_dict(config, options)
    env_type = params['env_type']
    is_atari = env_type == 'atari'

    # The training config (not the eval config passed via --config) supplies
    # the base parameters; it is read from the current working directory.
    train_config = 'train_rl_atari' if is_atari else 'train_rl'
    cfg_file = os.path.join(os.getcwd(), train_config + '.yaml')
    # Context manager so the file handle is closed even if parsing fails.
    with open(cfg_file, 'r') as cfg_stream:
        rl_params = yaml.safe_load(cfg_stream)

    rl_params.update({
        'check_already_ran': False,
        'overwrite': False,
        'local': False,
    })

    # Copy the keys that identify the training run from the eval parameters.
    if is_atari:
        rl_params.update({
            'env_id': params['env_id'],
            'run': params['run'],
        })
    elif env_type == 'bsuite':
        rl_params.update({
            'env_id': params['env_id'],
            'train_type': params['train_type'],
            'train_seed': params['seed'],
        })
    else:
        raise NotImplementedError(f'unsupported env_type: {env_type!r}')

    rl_params.update({
        'learner': params['learner'],
        'actor': params['actor'],
        'epsilon': params['train_epsilon'],
        'unc_noise_scale': params['unc_noise_scale'],
    })

    job_name = naming_utils.get_job_name(rl_params)

    # The output location does not depend on the test epsilon, so build it once.
    env_id = params['env_id']
    if is_atari:
        run = params['run']
        save_dir = f'/mnt/my_output/eval/{env_id}/{run}/{job_name}'
    else:
        train_type = params['train_type']
        train_seed = params['seed']
        save_dir = f'/mnt/my_output/eval/{env_id}/{train_type}/{train_seed}/{job_name}'

    env = data_utils.load_env(params)
    # Checkpoint step to load: 1e6 for atari, 1e5 otherwise.
    step = 1000000 if is_atari else 100000
    rl_runner = eval_utils.load_rl_runner('/mnt/my_input/ckpts', step, rl_params)

    # Geometric sweep around the runner's epsilon: factors 3**-6 ... 3**6.
    # NOTE(review): the largest factor is 729, so test epsilons can exceed 1
    # if epsilon is a probability -- confirm this is intended.
    epsilons = [rl_runner.epsilon * (3**i) for i in range(-6, 7)]
    episodes = params['episodes']

    # Created up front so an unwritable output location fails before any
    # evaluation work is done.
    os.makedirs(save_dir, exist_ok=True)
    for ep_idx, test_epsilon in enumerate(epsilons):
        print(f'Test epsilon: {test_epsilon}')
        rl_runner._init_actor(test_epsilon)
        returns = get_returns(rl_runner, env, eps=episodes)
        # Files are named by sweep index, not by epsilon value.
        np.save(f'{save_dir}/{ep_idx}', returns)

# Run the click command only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()