'''
    Module for offline RL training.
    Contains the RunnerRL class and
    main method to train that Runner.
'''
import numpy as np
import click
import jax
import jax.numpy as jnp
import acme
from acme.agents.jax import actors
from acme.jax import variable_utils

from runner import Runner
import policies
import q_learners
import q_losses
import jax_networks

import naming_utils
from train_uncertainty import RunnerUncertainty
from train_bc import RunnerBC
from uncertainty import UncertaintyWrapper
import utils

def heuristic(dataset, unc_fn, type):
    """Compute a data-dependent scale for an uncertainty hyperparameter.

    Draws ONE batch from ``dataset`` (the batch is consumed from the
    iterator) and evaluates ``unc_fn`` on its observations. It then picks out
    the uncertainty of the logged action for each transition.

    Args:
        dataset: iterator yielding ``(_, transition)`` pairs whose
            ``transition.data`` exposes ``observation``, ``action`` and
            (for ``'pessimism'``) ``reward``.
        unc_fn: callable mapping a batch of observations to an array that is
            indexed as ``[row, action]``.
        type: ``'soft_spibb'`` returns the mean uncertainty of the logged
            actions. ``'pessimism'`` returns the mean reward divided by that
            mean uncertainty. (The name shadows the builtin but is kept for
            caller compatibility.)

    Returns:
        The scale as a JAX scalar.

    Raises:
        ValueError: if ``type`` is not one of the supported values.
    """
    if type not in ('soft_spibb', 'pessimism'):
        # Validate before touching the iterator so a bad call does not
        # silently consume a training batch. Previously this fell through
        # to an UnboundLocalError on the return value.
        raise ValueError(f"unknown heuristic type: {type!r}")

    _, transition = next(dataset)
    batch = transition.data
    actions = batch.action
    # Uncertainty of the action actually taken in each transition.
    unc = unc_fn(batch.observation)[jnp.arange(len(actions)), actions]
    mean_unc = jnp.mean(unc)

    if type == 'soft_spibb':
        return mean_unc
    return jnp.mean(batch.reward) / mean_unc
    

class RunnerRL(Runner):
    """Runner object for offline RL algorithms.

    Unless the base ``Runner`` reports ``already_ran``, construction:
      1. loads a pretrained behavior model (via ``RunnerBC``) and an
         uncertainty function (via ``RunnerUncertainty`` or a constant),
      2. builds the loss selected by ``params['learner']``,
      3. wraps it in a ``q_learners.OfflineSGDLearner``, and
      4. builds the evaluation actor selected by ``params['actor']``.
    """

    # Network outputs per action for the quantile ('*_qr') learners.
    # NOTE(review): this presumably must match the quantile count assumed by
    # the q_losses *Quantile losses and the *_qr policies. Verify before
    # changing it.
    _NUM_QUANTILES = 201

    # Learners that use the quantile-regression network built in _build_loss.
    _QR_LEARNERS = ('cql_qr', 'bcq_qr', 'soft_spibb_qr', 'pessimism_qr')

    def __init__(self, params):
        super().__init__(params)

        if self.already_ran:
            return

        self.behavior_fn = self._load_behavior_fn()
        self.uncertainty_fn = self._load_uncertainty_fn()

        self.learner_rng, self.actor_rng = jax.random.split(self.rng, 2)
        # epsilon is non-zero only for the soft-SPIBB learners. It is also
        # forwarded to the actor.
        loss_fn, epsilon = self._build_loss()

        self.learner = q_learners.OfflineSGDLearner(
            uncertainty_fn=self.uncertainty_fn,
            behavior_fn=self.behavior_fn,
            loss_fn=loss_fn,
            network=self.network,
            optimizer=self.opt,
            data_iterator=self.dataset,
            target_update_period=self.params['target_update_period'],
            random_key=self.learner_rng,
            logger=self.train_logger,
            normalize_fn=self.normalize_fn,
        )
        # Takes one batch from the training iterator and keeps its data.
        self.eval_batch = next(self.dataset)[1].data
        self._init_actor(epsilon)
        self.epsilon = epsilon

    def _soft_spibb_epsilon(self):
        """Return the SPIBB epsilon, rescaled by ``heuristic`` if enabled.

        With ``params['use_heuristic']`` set, this consumes one batch from
        ``self.dataset`` and prints the resulting value.
        """
        if not self.params['use_heuristic']:
            return self.params['epsilon']
        unc_scale = heuristic(self.dataset, self.uncertainty_fn, 'soft_spibb')
        epsilon = unc_scale * self.params['epsilon']
        print(f'epsilon: {epsilon}')
        return epsilon

    def _pessimism_alpha(self):
        """Return the pessimism alpha, rescaled by ``heuristic`` if enabled.

        With ``params['use_heuristic']`` set, this consumes one batch from
        ``self.dataset`` and prints the resulting value.
        """
        if not self.params['use_heuristic']:
            return self.params['alpha']
        unc_scale = heuristic(self.dataset, self.uncertainty_fn, 'pessimism')
        alpha = unc_scale * self.params['alpha']
        print(f'alpha: {alpha}')
        return alpha

    def _build_loss(self):
        """Build the loss for ``params['learner']``.

        Quantile ('*_qr') learners also replace ``self.network`` with one
        that has ``num_actions * _NUM_QUANTILES`` outputs.

        Returns:
            ``(loss_fn, epsilon)``. ``epsilon`` is 0.0 except for the
            soft-SPIBB learners.

        Raises:
            ValueError: if ``params['learner']`` is not recognised. Previously
                this surfaced later as an UnboundLocalError on ``loss_fn``.
        """
        learner = self.params['learner']
        discount = self.params['discount']

        if learner in self._QR_LEARNERS:
            self.network = jax_networks.build_network(
                self.params['network'],
                self.dummy_obs,
                self.num_actions * self._NUM_QUANTILES,
                self.params['width'],
                self.params['depth'])

        if learner == 'one_step':
            return q_losses.SARSA(discount=discount), 0.0
        if learner == 'soft_spibb':
            epsilon = self._soft_spibb_epsilon()
            return q_losses.SoftSPIBB(discount=discount, epsilon=epsilon), epsilon
        if learner == 'bcq':
            return q_losses.BCQ(discount=discount, tau=self.params['tau']), 0.0
        if learner == 'cql':
            return q_losses.CQL(discount=discount, alpha=self.params['alpha']), 0.0
        if learner == 'pessimism':
            alpha = self._pessimism_alpha()
            return q_losses.Pessimism(discount=discount, alpha=alpha), 0.0
        if learner == 'cql_qr':
            return q_losses.CQLQuantile(discount=discount,
                                        alpha=self.params['alpha'],
                                        num_actions=self.num_actions), 0.0
        if learner == 'bcq_qr':
            return q_losses.BCQQuantile(discount=discount,
                                        tau=self.params['tau'],
                                        num_actions=self.num_actions), 0.0
        if learner == 'soft_spibb_qr':
            epsilon = self._soft_spibb_epsilon()
            return q_losses.SoftSPIBBQuantile(discount=discount,
                                              epsilon=epsilon,
                                              num_actions=self.num_actions), epsilon
        if learner == 'pessimism_qr':
            alpha = self._pessimism_alpha()
            return q_losses.PessimismQuantile(discount=discount,
                                              alpha=alpha,
                                              num_actions=self.num_actions), 0.0
        raise ValueError(f"unknown learner: {learner!r}")

    def _init_actor(self, epsilon):
        """Build ``self.actor`` for ``params['actor']``.

        Args:
            epsilon: forwarded to the SPIBB policies; ignored by the others.

        Raises:
            ValueError: if ``params['actor']`` is not recognised. Previously
                this was an UnboundLocalError on ``policy``.
        """
        kind = self.params['actor']
        if kind == 'greedy':
            policy = policies.epsilon_greedy_policy(self.network, self.normalize_fn)
        elif kind == 'soft_spibb':
            policy = policies.soft_spibb_policy(self.network, self.uncertainty_fn,
                                                self.behavior_fn, self.normalize_fn,
                                                epsilon)
        elif kind == 'greedy_spibb_fixed':
            policy = policies.greedy_spibb_policy(self.network, self.uncertainty_fn,
                                                  self.behavior_fn, self.normalize_fn,
                                                  epsilon)
        elif kind == 'bcq':
            policy = policies.bcq_policy(self.network, self.behavior_fn,
                                         self.normalize_fn, self.params['tau'],
                                         self.params['min_prob'])
        elif kind == 'greedy_qr':
            policy = policies.epsilon_greedy_policy_qr(self.network, self.normalize_fn,
                                                       self.num_actions)
        elif kind == 'greedy_spibb_qr_fixed':
            policy = policies.greedy_spibb_policy_qr(self.network, self.uncertainty_fn,
                                                     self.behavior_fn, self.normalize_fn,
                                                     epsilon,
                                                     self.num_actions)
        elif kind == 'soft_spibb_qr':
            policy = policies.soft_spibb_policy_qr(self.network, self.uncertainty_fn,
                                                   self.behavior_fn, self.normalize_fn,
                                                   epsilon,
                                                   self.num_actions)
        elif kind == 'bcq_qr':
            policy = policies.bcq_policy_qr(self.network, self.behavior_fn,
                                            self.normalize_fn, self.params['tau'],
                                            self.num_actions)
        else:
            raise ValueError(f"unknown actor: {kind!r}")

        # No adder is attached to this actor.
        self.actor = actors.FeedForwardActor(
            policy=policy,
            random_key=self.actor_rng,
            variable_client=variable_utils.VariableClient(self.learner, ''),
            adder=None)

    def eval(self, step):
        """Roll out ``params['eval_episodes']`` episodes with ``self.actor``.

        Writes ``{'step', 'return_mean', 'return_std'}`` of the episode
        returns to ``self.eval_logger``.
        """
        n_episodes = self.params['eval_episodes']
        loop = acme.EnvironmentLoop(environment=self.environment,
                                    actor=self.actor)
        returns = np.zeros((n_episodes))
        for ep in range(n_episodes):
            returns[ep] = loop.run_episode()['episode_return']

        metrics = {'step': step,
                   'return_mean': np.mean(returns),
                   'return_std': np.std(returns)}
        self.eval_logger.write(metrics)

    def _sub_runner_params(self, config_key, prefix):
        """Assemble the params dict for a pretrained sub-model runner.

        Layers are applied in order, so later layers override earlier ones:
          1. the config named by ``self.params[config_key]``,
          2. this runner's ``prefix``-ed params, with the prefix stripped
             (``keep_non_prefix=False``),
          3. fixed flags ``check_already_ran``/``overwrite`` = False plus
             ``local``/``blob`` copied from this runner,
          4. the data params from ``naming_utils.get_data_params``.
        """
        sub_params = utils.config_and_options_to_dict(self.params[config_key], {})
        sub_params.update(naming_utils.prefix_params(self.params, prefix,
                                                     keep_non_prefix=False))
        sub_params.update({'check_already_ran': False, 'overwrite': False,
                           'local': self.params['local'], 'blob': self.params['blob']})
        sub_params.update(naming_utils.get_data_params(self.params))
        return sub_params

    def _load_behavior_fn(self):
        """Load the pretrained ``RunnerBC`` and return its ``learner.get_probs``."""
        b_params = self._sub_runner_params('b_config', 'b_')
        b_runner = RunnerBC(b_params)
        b_runner.load(self.load_path, b_params['step'])
        return b_runner.learner.get_probs

    def _load_uncertainty_fn(self):
        """Return the uncertainty function selected by the ``unc_`` params.

        ``type == 'n'`` gives a constant all-ones function. Otherwise a
        pretrained ``RunnerUncertainty`` is loaded and its
        ``learner.uncertainty`` is used. If ``params['wrap_unc']`` is set and
        the type is not ``'sa'``, the result is wrapped with
        ``UncertaintyWrapper`` around ``self.behavior_fn``. That attribute
        must therefore already be set.
        """
        unc_params = self._sub_runner_params('unc_config', 'unc_')

        if unc_params['type'] == 'n':
            # NOTE(review): returns shape (batch, 1). When indexed by action
            # (as `heuristic` does), this relies on JAX clamping out-of-bounds
            # indices rather than raising. Confirm this is intended.
            def unit_unc(x):
                return jnp.ones((x.shape[0], 1))
            unc_fn = unit_unc
        else:
            unc_runner = RunnerUncertainty(unc_params)
            unc_runner.load(self.load_path, unc_params['step'])
            unc_fn = unc_runner.learner.uncertainty

        if self.params['wrap_unc'] and unc_params['type'] != 'sa':
            unc_fn = UncertaintyWrapper(unc_fn, self.behavior_fn).uncertainty
        return unc_fn

    def load(self, load_path, step):
        """Delegate to ``Runner.load``, then rebuild the actor.

        NOTE(review): relies on ``self.epsilon`` and the other attributes set
        in ``__init__``. Those only exist when ``already_ran`` was False at
        construction.
        """
        super().load(load_path, step)
        self._init_actor(self.epsilon)


@click.command()
@click.option('--config', '-c', default='train_rl', help='config file name')
@click.option('--options', '-o', multiple=True, nargs=2, type=click.Tuple([str, str]))
def main(config, options):
    # CLI entry point. Merge the named config with the repeated `-o KEY VALUE`
    # overrides, build a RunnerRL, and train it unless it reports
    # `already_ran`.
    # (Plain comments rather than a docstring: click would show a docstring
    # as this command's --help text.)
    runner = RunnerRL(utils.config_and_options_to_dict(config, options))
    print("Runner loaded!")
    if runner.already_ran:
        return
    runner.train()


if __name__ == '__main__':
    main()