'''
    Module for behavior cloning.
    Contains the RunnerBC class and the command-line
    entry point (main) that builds and trains that runner.
'''
import click
import numpy as np
import acme
from acme.agents.jax import actors
from acme.jax import variable_utils

import bc
from runner import Runner
import policies
import utils

class RunnerBC(Runner):
    '''
        Runner specialised for behavior cloning.

        Unless `self.already_ran` is truthy after the base initialiser,
        construction creates a `bc.BCLearner` plus a feed-forward actor
        that acts with a softmax policy built from the same network.
        Attributes such as `network`, `opt`, `dataset`, `normalize_fn`,
        `rng`, `environment`, the loggers and the eval datasets are not
        assigned in this class; presumably `Runner.__init__` provides
        them -- confirm against the base class.
    '''
    def __init__(self, params):
        '''
            Initialise the base runner, then the learner and the actor.

            params: forwarded unchanged to `Runner.__init__`. This class
                itself reads `self.params['seed']` here and
                `self.params['eval_episodes']` in `eval`.
        '''
        super().__init__(params)

        # Nothing to build when the base class flags the run as already done;
        # note that `learner` and `actor` are then left undefined.
        if self.already_ran:
            return

        self.learner = bc.BCLearner(
            net=self.network,
            opt=self.opt,
            dataset=self.dataset,
            normalize_fn=self.normalize_fn,
            seed=self.params['seed'],
            logger=self.train_logger,
        )

        # The policy is built before the variable client, as in the
        # original statement order.
        softmax = policies.softmax_policy(self.network, self.normalize_fn)
        self.actor = actors.FeedForwardActor(
            policy=softmax,
            random_key=self.rng,
            # The actor obtains its variables from the learner through this
            # client; '' is the key handed to the client.
            variable_client=variable_utils.VariableClient(self.learner, ''),
            adder=None,
        )

    def eval(self, step):
        '''
            Run evaluation episodes and log learner metrics.

            First rolls out `self.params['eval_episodes']` episodes in
            `self.environment` with `self.actor`, recording each
            episode's 'episode_return'. Then, for every (label, dataset)
            pair, takes one item from the dataset iterator, passes its
            observations and actions to `self.learner.eval` together
            with the label, the step and the mean/std of the returns,
            and writes the resulting metrics to `self.eval_logger`.

            step: value stored under the 'step' key of the extra info
                given to the learner.
        '''
        env_loop = acme.EnvironmentLoop(environment=self.environment,
                                        actor=self.actor)

        # One slot per evaluation episode, filled in order.
        episode_returns = np.zeros((self.params['eval_episodes']))
        for idx in range(self.params['eval_episodes']):
            episode_returns[idx] = env_loop.run_episode()['episode_return']

        for name, batches in zip(self.eval_labels, self.eval_datasets):
            # Each item is a pair; only its second element is used.
            _, sample = next(batches)
            observations = sample.data.observation
            actions = sample.data.action
            # A new dict per dataset, so nothing is shared between calls.
            extras = {
                'label': name,
                'step': step,
                'return_mean': np.mean(episode_returns),
                'return_std': np.std(episode_returns),
            }
            self.eval_logger.write(
                self.learner.eval(observations, actions, extras))

# Command-line entry point: merges the named config with any `-o KEY VALUE`
# overrides into a params dict, builds a RunnerBC from it and trains it
# unless the runner reports `already_ran`.
# Documented with comments rather than a docstring because click would
# display a docstring as the command's --help text.
@click.command()
@click.option('--config', '-c', default='train_bc', help='config file name')
@click.option('--options', '-o', multiple=True, nargs=2, type=click.Tuple([str, str]))
def main(config, options):
    bc_runner = RunnerBC(utils.config_and_options_to_dict(config, options))
    if bc_runner.already_ran:
        return
    bc_runner.train()


# Invoke the click command only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()