import os
import yaml
from typing import Any, NamedTuple
from acme.utils import loggers
import jax
import jax.numpy as jnp

# Type alias for transition fields. It is plain ``Any``, so nothing is type-checked;
# presumably each value is an array or a nested structure of arrays — confirm with producers.
NestedArray = Any
class Transition(NamedTuple):
    """Container for a transition.

    An immutable ``NamedTuple`` with one field per component. Every field is
    annotated as ``NestedArray`` (i.e. ``Any``), so no shape or dtype is
    enforced here. Only ``extras`` has a default (an empty tuple).
    """
    observation: NestedArray
    action: NestedArray
    reward: NestedArray
    discount: NestedArray
    next_observation: NestedArray
    # NOTE(review): presumably the action taken at ``next_observation``
    # (e.g. for SARSA-style targets) — confirm against where these are built.
    next_action: NestedArray
    # Optional auxiliary data; defaults to an empty tuple when omitted.
    extras: NestedArray = ()

def config_and_options_to_dict(config, options):
    """Load ``<cwd>/<config>.yaml`` and apply command-line overrides to it.

    Args:
        config: Config file name relative to the current working directory,
            without the ``.yaml`` suffix.
        options: Iterable of ``(key, value)`` pairs (only the first two items
            of each are used), or ``None`` for no overrides. Each ``key`` must
            already exist in the YAML file. ``value`` is converted to the type
            of the value it replaces:
            - ``bool``: case-insensitive ``true/1/yes/on`` or ``false/0/no/off``;
            - ``list``, ``dict`` or ``None`` entries: parsed as YAML;
            - any other type ``T``: ``T(value)``.

    Returns:
        The mapping loaded from the YAML file with the overrides applied.

    Raises:
        KeyError: If an override names a key that is not in the config file.
        ValueError: If an override value cannot be converted to the type of
            the existing entry (e.g. ``'abc'`` for an ``int`` entry, or an
            unrecognised spelling for a ``bool`` entry).
    """
    cfg_file = os.path.join(os.getcwd(), config + '.yaml')
    # ``with`` closes the handle promptly instead of leaving it to the GC.
    with open(cfg_file, 'r', encoding='utf-8') as f:
        # An empty YAML file loads as None; treat it as an empty config so a
        # bad override fails with the clear KeyError below.
        params = yaml.safe_load(f) or {}

    # Replace params with command-line options.
    for opt in options or ():
        key, raw = opt[0], opt[1]
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        if key not in params:
            raise KeyError(f'Unknown config option {key!r} (not in {cfg_file})')
        dtype = type(params[key])
        if dtype is bool:
            # bool('False') is True, so compare the text explicitly.
            text = str(raw).strip().lower()
            if text in ('true', '1', 'yes', 'on'):
                new_opt = True
            elif text in ('false', '0', 'no', 'off'):
                new_opt = False
            else:
                raise ValueError(
                    f'Cannot parse {raw!r} as a boolean for option {key!r}')
        elif dtype in (list, dict, type(None)):
            # list('[1, 2]') would split into characters and NoneType('x')
            # raises TypeError, so parse these values as YAML instead.
            new_opt = yaml.safe_load(raw)
        else:
            new_opt = dtype(raw)
        params[key] = new_opt

    return params

def build_logger(label, path, time_delay=1.0):
    """Create a logger that writes both to the terminal and to a CSV file.

    Any existing ``<path>/logs/<label>/logs.csv`` is deleted first, so a new
    run starts a fresh file rather than reusing the previous run's log.

    Args:
        label: Logger label, passed to both underlying loggers and used as
            a path component of the deleted CSV file.
        path: Base directory passed to ``loggers.CSVLogger``.
        time_delay: Forwarded as ``time_delta`` to ``loggers.TerminalLogger``.

    Returns:
        A ``loggers.Dispatcher`` wrapping the terminal and CSV loggers, with
        ``loggers.base.to_numpy`` as its serialization function.
    """
    # NOTE(review): presumably the file CSVLogger(add_uid=False) writes to —
    # confirm against acme's CSVLogger if its directory layout changes.
    csv_file = os.path.join(path, 'logs', label, 'logs.csv')
    try:
        os.remove(csv_file)
    except FileNotFoundError:
        pass  # Nothing from a previous run to overwrite.
    except OSError as err:
        # Stay best-effort (don't abort logger creation), but don't hide it.
        print(f'Could not remove old log file {csv_file}: {err}')
    else:
        print(f'Overwriting log file: {label}')
    term_logger = loggers.TerminalLogger(label=label, time_delta=time_delay, print_fn=print)
    csv_logger = loggers.CSVLogger(directory_or_file=path, label=label, add_uid=False)
    logger = loggers.Dispatcher([term_logger, csv_logger],
                                loggers.base.to_numpy)
    return logger

def eps_greedy_probs(a, epsilon, num_actions):
    """Return the epsilon-greedy action distribution around greedy action(s) ``a``.

    With probability ``1 - epsilon`` the greedy action is chosen; otherwise an
    action is drawn uniformly from all ``num_actions`` actions (the greedy
    one included). Each row therefore sums to 1:

        p(a) = 1 - epsilon + epsilon / num_actions
        p(b) = epsilon / num_actions            for b != a

    Args:
        a: Integer greedy action index, or an integer array of indices.
        epsilon: Exploration probability, expected to lie in [0, 1].
        num_actions: Number of discrete actions.

    Returns:
        float32 array of shape ``jnp.shape(a) + (num_actions,)``.
    """
    one_hot = jax.nn.one_hot(a, num_actions, dtype=jnp.float32)
    # Spread the exploration mass over all actions so the uniform component
    # sums to epsilon (and each row to 1), rather than epsilon per action.
    uniform = (epsilon / num_actions) * jnp.ones_like(one_hot)
    return (1. - epsilon) * one_hot + uniform