'''
    Module for Runner class.
    This is a superclass containing code to be re-used
    across all training loops that inherit from Runner.
'''
import os
import pandas as pd

import jax
import optax
from acme.jax import savers
import tensorflow as tf

import naming_utils
import jax_networks
import data_utils
from utils import build_logger

# Let TensorFlow allocate GPU memory on demand rather than reserving it all up
# front. TensorFlow requires the memory-growth setting to be the same on every
# visible GPU, so it is applied to all detected devices rather than only the
# first one. With no GPU present the loop is a no-op.
gpu = tf.config.experimental.list_physical_devices('GPU')
for _gpu_device in gpu:
    tf.config.experimental.set_memory_growth(_gpu_device, True)

class Runner():
    """Base class with the setup and training loop shared by all runners.

    The constructor derives the output locations from ``params``, optionally
    detects that the job has already produced a complete evaluation log, and
    otherwise builds the loggers, environment, network, optimizer, RNG key
    and datasets.

    This class never creates ``self.learner``; ``train`` and ``load`` use it,
    so a subclass has to assign it. Subclasses must also override ``eval``.
    """

    # An existing eval log with fewer rows than this is treated as an
    # unfinished run, so the job is executed again.
    _MIN_EVAL_ROWS = 50

    # Suffix appended to the bsuite train directory, keyed by env_id and then
    # by params['train_size'].
    _BSUITE_TRAIN_SIZES = {
        'cartpole_0.0': {'small': '20k', 'large': '100k'},
        'catch_0.0': {'small': '2k', 'large': '10k'},
    }

    def __init__(self, params):
        """Set up paths and, unless the job already ran, all training state.

        Args:
            params: dict of job settings. It is stored by reference and
                mutated: ``'train_dir'`` is written for the 'bsuite' and
                'atari' env types, and ``'width'``/``'depth'`` are reset to
                ``None`` when ``params['network']`` is not ``'mlp'``.

        Raises:
            KeyError: if a key required for the selected configuration is
                missing from ``params`` (including ``'train_dir'`` when the
                env type is neither 'bsuite' nor 'atari' and the caller did
                not provide it).
        """
        self.params = params

        # compute train_dir (left untouched for other env types, so a caller
        # may supply params['train_dir'] directly)
        train_dir = self._compute_train_dir(self.params)
        if train_dir is not None:
            self.params['train_dir'] = train_dir

        print('\n')
        print('Parameters ')
        for key, value in self.params.items():
            print(key, value)
        print('\n')

        # initialize file names
        name = naming_utils.get_job_name(self.params)
        self.job_name = f"{self.params['env_id']}/{self.params['train_dir']}/{name}"

        # NOTE: THESE PATHS NEED TO BE SET TO RUN THE CODE
        self.ckpt_path = f'/mnt/my_output/ckpts/{self.job_name}'
        self.log_path = f'/mnt/my_output/logs/{self.job_name}'
        self.load_path = '/mnt/my_input/ckpts'
        self.data_path = '/mnt/my_input/data'

        # check if job has already run
        self.already_ran = False
        if self.params['check_already_ran']:
            self.already_ran = self._eval_log_is_complete()

        # initialize logger, env, network, and data
        if not self.already_ran:
            self._build_training_state()

    @classmethod
    def _compute_train_dir(cls, params):
        """Return the train directory name for ``params``, or None.

        None is returned when ``params['env_type']`` is neither 'bsuite' nor
        'atari'; in that case the caller leaves ``params`` unchanged.
        """
        env_type = params['env_type']
        if env_type == 'bsuite':
            base = f"{params['train_type']}_{params['train_seed']}"
            if params['env_id'].startswith(('cloud', 'counter')):
                return base
            size = cls._BSUITE_TRAIN_SIZES[params['env_id']][params['train_size']]
            return f"{base}_{size}"
        if env_type == 'atari':
            return f"run_{params['run']}_1percent"
        return None

    def _eval_log_is_complete(self):
        """Return True if the eval log exists and has enough rows.

        Any failure while reading the CSV (empty, truncated, unreadable)
        counts as "not complete" so that the job is simply run again.
        """
        eval_log = os.path.join(self.log_path, 'logs', 'eval', 'logs.csv')
        if not os.path.exists(eval_log):
            return False
        try:
            return len(pd.read_csv(eval_log)) >= self._MIN_EVAL_ROWS
        except Exception:
            # Deliberately best-effort, but unlike a bare ``except`` this
            # lets KeyboardInterrupt / SystemExit propagate.
            return False

    def _build_training_state(self):
        """Create loggers, environment, network, optimizer, RNG and data."""
        if self.params['overwrite']:
            self.train_logger = build_logger('train', self.log_path)
            # NOTE(review): meaning of the trailing 0 is defined by
            # utils.build_logger -- confirm there.
            self.eval_logger = build_logger('eval', self.log_path, 0)
        else:
            self.train_logger, self.eval_logger = None, None

        self.environment = data_utils.load_env(self.params)
        self.num_actions = data_utils.get_num_actions(self.params)
        self.dummy_obs = data_utils.get_dummy_obs(self.params)

        # width/depth are only passed through for the 'mlp' network
        if self.params['network'] != 'mlp':
            self.params['width'] = None
            self.params['depth'] = None
        self.network = jax_networks.build_network(
            self.params['network'], self.dummy_obs, self.num_actions,
            self.params['width'], self.params['depth'])
        self.opt = optax.adam(self.params['lr'])
        self.rng = jax.random.PRNGKey(self.params['seed'])

        loaded_data = data_utils.load_data(self.data_path, self.params)
        self.dataset, self.eval_labels, self.eval_datasets = loaded_data

        self.normalize_fn = data_utils.normalize_fn(self.params['norm'])

    def _flush_loggers(self):
        """Flush the train and eval loggers, skipping any that are None.

        The loggers are None when ``params['overwrite']`` is falsy.
        """
        for logger in (self.train_logger, self.eval_logger):
            if logger is not None:
                # NOTE(review): reaches into the logger's private ``_to``
                # sequence; presumably index 1 is the file-backed sink --
                # confirm against utils.build_logger.
                logger._to[1].flush()

    def train(self):
        """Run ``params['train_steps']`` learner steps.

        Every ``params['eval_period']`` steps ``eval`` is called and the
        loggers are flushed; every ``params['ckpt_period']`` steps the
        learner state is saved under ``self.ckpt_path/<step>``. Evaluation
        and checkpointing happen before the learner step of that iteration.
        The loggers are flushed once more after the loop.
        """
        for t in range(self.params['train_steps']):
            if t % self.params['eval_period'] == 0:
                self.eval(t)
                self._flush_loggers()

            if t % self.params['ckpt_period'] == 0:
                ckpt_file = os.path.join(self.ckpt_path, str(t))
                savers.save_to_path(ckpt_file, self.learner.save())

            self.learner.step()

        self._flush_loggers()

    def load(self, load_path, step):
        """Restore the learner from ``load_path/<job_name>/<step>``."""
        path = os.path.join(load_path, self.job_name, str(step))
        self.learner.restore(savers.restore_from_path(path))

    def eval(self, step):
        """Evaluate at training step ``step``; subclasses must override."""
        raise NotImplementedError
