import argparse
import numpy as np


def argmax(vector):
    """Return the index of a maximal entry of ``vector``, breaking ties at random.

    Args:
        vector: Array-like of comparable numbers. Anything accepted by
            ``np.asarray`` works (ndarray, list, tuple, ...).

    Returns:
        An index into the *flattened* input pointing at one of the entries
        equal to the maximum. When several entries tie, one of them is drawn
        uniformly with ``np.random.choice`` (so the global NumPy RNG state
        is consumed).

    Raises:
        ValueError: If ``vector`` is empty.
    """
    # Coerce so that plain Python sequences work too; ndarrays pass through
    # without a copy.
    values = np.asarray(vector)
    # Flat indices of every entry that attains the maximum.
    best_indices = np.flatnonzero(values == values.max())
    return np.random.choice(best_indices)


class ArgsParser:
    """
    Command-line interface for the learning algorithms.

    ``read_input_args`` builds the argparse parser and returns the parsed
    namespace; every option has a default, so each attribute is always filled.
    """

    @staticmethod
    def _str_to_bool(text):
        """Convert a command-line token such as 'True' or 'no' to a bool.

        ``type=bool`` must not be used with argparse: ``bool('False')`` is
        ``True`` because any non-empty string is truthy.

        Raises:
            argparse.ArgumentTypeError: If the token is not a recognised
                boolean spelling.
        """
        if isinstance(text, bool):
            return text
        token = text.strip().lower()
        if token in ('true', 't', 'yes', 'y', '1'):
            return True
        # The empty string is kept falsy, as it was under ``type=bool``.
        if token in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(
            "expected a boolean value (true/false), got %r" % (text,))

    @staticmethod
    def read_input_args(argv=None):
        """Parse the command line and return the populated namespace.

        Args:
            argv: Optional list of argument strings to parse. When ``None``
                (the default) argparse reads ``sys.argv[1:]``.

        Returns:
            argparse.Namespace with the attributes ``seed``, ``input``,
            ``num_episodes``, ``step_size``, ``step_size_sr``, ``beta``,
            ``gamma``, ``gamma_sr``, ``epsilon``, ``reward_structure``,
            ``divide``, ``step_size_dr``, ``gamma_dr``, ``lambda_dr`` and
            ``transform``.
        """
        parser = argparse.ArgumentParser(
            description="Define algorithm's parameters.")

        parser.add_argument('-s', '--seed', type=int, default=1, help='Seed to be used in the code.')
        parser.add_argument('-i', '--input', type=str, default='mdps/toy.mdp',
                            help='File containing the MDP definition (default: mdps/toy.mdp).')
        parser.add_argument('-n', '--num_episodes', type=int, default=100,
                            help='For how many episodes we are going to learn.')
        parser.add_argument('-a', '--step_size', type=float, default=0.1,
                            help="Algorithm's step size. Alpha parameter in algorithms such as Sarsa.")
        parser.add_argument('-y', '--step_size_sr', type=float, default=0.1,
                            help="Step size to compute the SR with TD when using it in algorithms such as Sarsa.")
        parser.add_argument('-b', '--beta', type=float, default=1.0,
                            help="Real reward = Real reward + beta * Intrinsic Reward.")
        parser.add_argument('-g', '--gamma', type=float, default=0.95,
                            help='Gamma. Discount factor to be used by the algorithm.')
        # Short flags below are registered without a trailing space so that
        # e.g. '-z' is an exact option string and renders correctly in --help.
        parser.add_argument('-z', '--gamma_sr', type=float, default=0.95,
                            help='Gamma value to compute the SR.')
        parser.add_argument('-e', '--epsilon', type=float, default=0.01,
                            help='Epsilon. This is the exploration parameter (trade-off).')
        parser.add_argument('-r', '--reward_structure', type=str, default="",
                            help="Valid values: 'dot-prod', 'diff', 'gamma-diff', 'norm' ")
        # Explicit converter: with type=bool, '-d False' would parse as True.
        parser.add_argument('-d', '--divide', type=ArgsParser._str_to_bool, default=False,
                            help="If true, the reward is equal to 1/reward_structure")

        parser.add_argument('-w', '--step_size_dr', type=float, default=0.1,
                            help="Step size to compute the DR with TD when using it in algorithms such as Sarsa.")
        parser.add_argument('-c', '--gamma_dr', type=float, default=1.0,
                            help='Gamma value to compute the DR. (constant = 1)')
        parser.add_argument('-l', '--lambda_dr', type=float, default=1.3,
                            help='Lambda value to compute the DR.')
        parser.add_argument('-t', '--transform', type=str, default="l1",
                            help='Transform on the norm of DR - l1, l2, log_l1, log_l2')

        return parser.parse_args(argv)
