import numpy as np
import random

from copy import copy, deepcopy
from pathlib import Path

from bal import ImplicitMirrorDescentValueIteration, BoundedAdvantageLearning
from utils import save_as_npy


def get_maze(maze_type):
    """Return the cell layout and the per-cell reward table of a grid maze.

    Args:
        maze_type: ``'small'`` selects the 6x6 grid; any other value selects
            the 17x17 grid.

    Returns:
        A ``(maze, reward_map)`` tuple. Both are lists of row lists of ints
        with identical dimensions. ``maze`` holds 0/1 cell flags (presumably
        1 marks a blocked cell -- confirm against ``GridWorld``) and
        ``reward_map`` holds the reward value attached to each cell.
        Every row is a freshly created list, so callers may mutate the result.
    """
    if maze_type == 'small':
        side = 6
        maze = [
            [0, 0, 0, 1, 1, 1],
            [0, 0, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0],
            [0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0],
            [0, 0, 0, 1, 1, 1],
        ]
        # All-zero rewards except a single +1 cell at row 4, column 1.
        reward_map = [[0] * side for _ in range(side)]
        reward_map[4][1] = 1
        return maze, reward_map

    side = 17
    edge = side - 1

    def on_border(row, col):
        # True for cells on the outermost ring of the grid.
        return row in (0, edge) or col in (0, edge)

    # Outer ring of 1-cells around an interior of 0-cells.
    maze = [
        [1 if on_border(row, col) else 0 for col in range(side)]
        for row in range(side)
    ]
    # Column 8 is flagged in rows 1-5 and 11-15 (rows 6-10 stay open there).
    for row in list(range(1, 6)) + list(range(11, 16)):
        maze[row][8] = 1
    # Row 8 is flagged in columns 0-5 and 11-16 (columns 6-10 stay open).
    for col in list(range(0, 6)) + list(range(11, side)):
        maze[8][col] = 1

    # -1 on the outer ring, 0 inside, and three positive cells near the corners.
    reward_map = [
        [-1 if on_border(row, col) else 0 for col in range(side)]
        for row in range(side)
    ]
    reward_map[1][15] = 1
    reward_map[15][1] = 1
    reward_map[15][15] = 2
    return maze, reward_map


def get_cfg(args):
    """Build the experiment-configuration string used to name result folders.

    The string concatenates, in order: initial psi and algorithm name, the
    BAL bound functions (only when ``args.algo == 'bal'``), gamma, either
    ``tau``/``kappa`` (when ``args.alpha`` is None) or ``alpha``/``beta``,
    and finally the iteration/environment/recording settings.

    Args:
        args: Namespace-like object exposing the attributes read below.

    Returns:
        The configuration string.
    """
    fragments = [args.initial_psi + '_' + args.algo]

    if args.algo == 'bal':
        fragments.append(f'_f{args.bound_f}_g{args.bound_g}')

    fragments.append(f'_gam{args.gamma}')

    # The (alpha, beta) parameterisation takes precedence whenever alpha is set.
    if args.alpha is not None:
        fragments.append(f'_alpha{args.alpha}_beta{args.beta}')
    else:
        fragments.append(f'_tau{args.tau}_kappa{args.kappa}')

    fragments.append(
        f'_nitr{args.num_iterations}'
        f'_randE{args.randomness_of_env}'
        f'_berr{args.backup_error_magnitude}'
        f'_nPE{args.num_pe_backups}'
        f'_dPE{args.max_diff_value}'
        f'_f{args.record_frequency}'
    )
    return ''.join(fragments)


def run(args):
    """Fit the selected solver on the chosen maze and write results to disk.

    Outputs go to ``./result/<exp_cfg>/seed<seed>/`` (created if missing):
    ``initial.pdf`` / ``final.pdf`` value plots, the recorded psi and V
    sequences, the final psi and policy, optionally the error curve
    (``args.record_curve``) and one ``t<iteration>.pdf`` plot per recorded V
    (``args.plot_intermediate_value``).

    NOTE: ``gridworld_stochastic`` and ``painter`` are module globals bound by
    the ``__main__`` block of this file, so this function only works when the
    file is executed as a script.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` block).

    Raises:
        ValueError: If ``args.algo`` is neither ``'imdvi'`` nor ``'bal'``.
    """
    exp_cfg = get_cfg(args)

    maze, reward_map = get_maze(args.maze)
    world = gridworld_stochastic.GridWorld(maze, reward_map, epsilon=args.randomness_of_env)

    # Value range handed to the painter as min/max for every plot.
    # NOTE(review): log(4) presumably is the maximum entropy over four
    # actions -- confirm against the solver.
    rewards = np.array(reward_map)
    entropy_bonus = args.tau * np.log(4)
    min_V = (rewards.min() + entropy_bonus) / (1 - args.gamma)
    max_V = (rewards.max() + entropy_bonus) / (1 - args.gamma)

    # Keyword arguments shared by both solver classes.
    solver_kwargs = dict(
        gamma=args.gamma,
        alpha=args.alpha, beta=args.beta,
        tau=args.tau, kappa=args.kappa,
        num_pe_backups=args.num_pe_backups,
        backup_error_magnitude=args.backup_error_magnitude,
        seed=args.seed, initial_psi=args.initial_psi,
    )
    if args.algo == 'imdvi':
        solver = ImplicitMirrorDescentValueIteration(world, **solver_kwargs)
    elif args.algo == 'bal':
        solver = BoundedAdvantageLearning(
            world, **solver_kwargs,
            bound_f=args.bound_f, bound_g=args.bound_g,
        )
    else:
        raise ValueError(f'unknown algo: {args.algo!r}')

    result_path = Path('./result').resolve() / exp_cfg / f'seed{args.seed}'
    result_path.mkdir(parents=True, exist_ok=True)

    # Plot taken before fitting.
    maze_painter = painter.Painter(solver)
    maze_painter.draw_state_value_by_table(min=min_V, max=max_V)
    maze_painter.save_grid(result_path / 'initial.pdf')
    maze_painter.reset_plot()

    # Reference solution for the error curve: V_opt is used when its file is
    # given, otherwise psi_opt is loaded.
    V_opt = None
    psi_opt = None
    if args.record_curve:
        if args.V_opt_filename is None:
            psi_opt = np.load(args.psi_opt_filename)
        else:
            V_opt = np.load(args.V_opt_filename)

    record_psi, record_V = solver.fit(
        max_diff_value=args.max_diff_value,
        num_iterations=args.num_iterations,
        record_curve=args.record_curve,
        record_frequency=args.record_frequency,
        V_opt=V_opt, psi_opt=psi_opt,
        verbose=args.verbose)

    print(solver.psi)

    save_as_npy(result_path / 'record_psi', record_psi)
    save_as_npy(result_path / 'record_V', record_V)
    save_as_npy(result_path / 'psi_final', solver.get_psi())
    save_as_npy(result_path / 'policy_final', solver.get_policy())

    if args.record_curve:
        save_as_npy(result_path / 'error_curve', solver.get_error_curve())

    if args.plot_intermediate_value:
        for itr, V in enumerate(record_V):
            maze_painter.draw_state_value_by_table(value_2d=solver.get_V_2D(V=V), min=min_V, max=max_V)
            # File name carries the iteration index: record number times the
            # recording interval. (Previously used an undefined bare name.)
            maze_painter.save_grid(result_path / f't{itr*args.record_frequency}.pdf')
            maze_painter.reset_plot()

    # Plot taken after fitting.
    maze_painter.draw_state_value_by_table(min=min_V, max=max_V)
    maze_painter.save_grid(result_path / 'final.pdf')
    maze_painter.reset_plot()



def evaluate(args):
    """Evaluate a saved policy and store its value estimates next to it.

    Loads the policy array from ``args.policy_filename``, evaluates it with
    an ``ImplicitMirrorDescentValueIteration`` solver (regardless of
    ``args.algo``) on the selected maze, and saves the two returned value
    arrays as ``V_pi_tau`` and ``V_pi_alpha`` in the policy file's directory.

    NOTE: ``gridworld_stochastic`` is a module global bound by the
    ``__main__`` block of this file, so this function only works when the
    file is executed as a script.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` block).
    """
    maze, reward_map = get_maze(args.maze)
    world = gridworld_stochastic.GridWorld(maze, reward_map, epsilon=args.randomness_of_env)

    policy_path = Path(args.policy_filename).resolve()
    pi = np.load(policy_path)

    solver = ImplicitMirrorDescentValueIteration(
        world, gamma=args.gamma,
        alpha=args.alpha, beta=args.beta,
        tau=args.tau, kappa=args.kappa,
        num_pe_backups=args.num_pe_backups, backup_error_magnitude=args.backup_error_magnitude,
        seed=args.seed, initial_psi=args.initial_psi,
    )
    V_pi_tau, V_pi_alpha = solver.evaluate(pi, args.max_diff_value, args.verbose)

    # Results are written alongside the policy file.
    result_path = policy_path.parent
    save_as_npy(result_path / 'V_pi_tau', V_pi_tau)
    save_as_npy(result_path / 'V_pi_alpha', V_pi_alpha)



if __name__ == '__main__':

    # These two modules are bound as module globals here; run() and
    # evaluate() look them up by name at call time.
    import gridworld_stochastic
    import painter

    import argparse

    arg_parser = argparse.ArgumentParser()

    # Boolean switches (all default to False).
    for flag in ('run', 'evaluate', 'aggregate', 'plot_summary', 'plot_tidy', 'verbose'):
        arg_parser.add_argument(f'--{flag}', action='store_true')

    # Environment, algorithm and solver hyper-parameters.
    arg_parser.add_argument('--maze', type=str, default='large')
    arg_parser.add_argument('--algo', type=str, default='bal', choices=['bal', 'imdvi'])
    arg_parser.add_argument('--seed', type=int, default=459)
    float_options = (
        ('gamma', 0.99),
        ('alpha', None),
        ('beta', None),
        ('tau', 0.1),
        ('kappa', 0.1),
        ('randomness_of_env', 0.1),
        ('backup_error_magnitude', 0.0),
    )
    for option, default in float_options:
        arg_parser.add_argument(f'--{option}', type=float, default=default)
    arg_parser.add_argument(
        '--initial_psi', type=str, default='random',
        choices=['zero', 'optimistic', 'random'],
    )

    # Bounding functions, passed to the solver only when --algo is 'bal'.
    bound_choices = ['identity', 'nclip', 'rclip', 'ntanh', 'rtanh']
    for option in ('bound_f', 'bound_g'):
        arg_parser.add_argument(f'--{option}', type=str, default='rclip', choices=bound_choices)

    # Iteration budget and stopping tolerance.
    arg_parser.add_argument('--num_iterations', type=int, default=100)
    arg_parser.add_argument('--num_pe_backups', type=int, default=1000)
    arg_parser.add_argument('--max_diff_value', type=float, default=0.0001)

    # Recording / plotting options and input files.
    for flag in ('plot_intermediate_value', 'record_curve'):
        arg_parser.add_argument(f'--{flag}', action='store_true')
    arg_parser.add_argument('--record_frequency', type=int, default=10)
    for option in ('V_opt_filename', 'psi_opt_filename', 'policy_filename'):
        arg_parser.add_argument(f'--{option}', type=str, default=None)

    args = arg_parser.parse_args()

    # Only NumPy's global RNG is seeded; the stdlib `random` module is not.
    np.random.seed(args.seed)

    # --run and --evaluate are independent: both may execute in one call.
    if args.run:
        run(args)

    if args.evaluate:
        evaluate(args)
