import gymnasium
from analysis.viz import *


def make_mujoco(task, horizon=None):
    """Write the raw and the EMA-smoothed learning-curve figures for one MuJoCo task.

    Runs are read from ``data/mujoco/{task}/<variant>`` for the variants
    ``DrAC_0.1``, ``DrAC_0.2``, ``DrAC_0.3`` and ``SAC``. Two figures are saved
    next to the task directory: ``raw_{task}.png`` (no ``ema`` argument passed)
    and ``ema0.1_{task}.png`` (``ema=0.1``).

    Args:
        task: Task name; used in the data path, in the output file names and
            (capitalized) as the figure title.
        horizon: Forwarded unchanged to ``plot_line_X_trials``.
    """
    run_dir = f'data/mujoco/{task}'
    variants = ('DrAC_0.1', 'DrAC_0.2', 'DrAC_0.3', 'SAC')
    # (extra smoothing kwargs, output file) for each figure, in the order they are written.
    passes = (
        ({}, f'data/mujoco/raw_{task}.png'),
        ({'ema': 0.1}, f'data/mujoco/ema0.1_{task}.png'),
    )
    for smoothing, out_path in passes:
        for variant in variants:
            # A fresh kwargs dict per call, so nothing is shared between calls.
            call_kwargs = dict(line_label=variant, horizon=horizon, **smoothing)
            if variant == 'SAC':
                # The SAC baseline is drawn in a fixed dark grey.
                call_kwargs['line_kwargs'] = {'color': '#202020'}
            plot_line_X_trials(f'{run_dir}/{variant}', **call_kwargs)
        after_plot(title=task.capitalize(), fpath=out_path)

def make_layer_comparison(task, horizon=None, ema=None, title=None):
    """Save a one-row figure comparing the ``l2`` ... ``l7`` runs of each algorithm.

    One panel is drawn per algorithm (``SAC``, ``DrAC-AM``, ``DrAC-logGM``) with
    a shared y-axis; each panel overlays the runs found under
    ``data/{task}/l{n}/{algo}`` for ``n`` in 2..7. The figure is written to
    ``data/{task}/reward.png``, or ``data/{task}/reward_ema{ema}.png`` when
    ``ema`` is given (the path is passed through ``gp`` first).

    Args:
        task: Name of the data sub-directory; also the default super-title.
        horizon: Forwarded unchanged to ``plot_line_X_trials``.
        ema: Forwarded unchanged to ``plot_line_X_trials``; also selects the
            output file name.
        title: Super-title of the figure; falls back to ``task`` when ``None``.
    """
    algorithms = ('SAC', 'DrAC-AM', 'DrAC-logGM')
    n_panels = len(algorithms)
    _, panels = plt.subplots(ncols=n_panels, sharey=True, figsize=(3 * n_panels, 3))
    for panel, algorithm in zip(panels, algorithms):
        for depth in range(2, 8):
            plot_line_X_trials(
                f'data/{task}/l{depth}/{algorithm}',
                line_label=f'l {depth}',
                ax=panel,
                horizon=horizon,
                ema=ema,
            )
        panel.legend()
        panel.set_title(algorithm)
        panel.set_xlabel('Training Steps')
        panel.set_ylabel('Return')

    heading = title if title is not None else task
    plt.suptitle(heading)
    plt.tight_layout(pad=0.5)

    if ema is None:
        stem = 'reward'
    else:
        stem = f'reward_ema{ema}'
    plt.savefig(gp(f'data/{task}/{stem}.png'))
    plt.close()

def viz_maze_maps():
    """Save a picture of each ``MultiGoalPointMaze`` layout to ``envs/mgmaze/{level}.png``.

    For every level (``simple``, ``medium``, ``hard``) the environment is
    created, asked to draw itself onto a 4x4-inch axes via
    ``env.unwrapped.plot`` and the figure is saved through ``after_plot``.
    The environment is closed even when drawing or saving fails, so a failure
    on one level does not leak the environment's resources.
    """
    # Imported lazily so that the other figure functions do not need the maze package.
    from envs.mgmaze.point_maze import register_point_maze
    register_point_maze()
    for level in ('simple', 'medium', 'hard'):
        env = gymnasium.make('MultiGoalPointMaze', maze_map=level)
        try:
            maze = env.unwrapped
            fig = plt.figure(figsize=(4, 4))
            ax = fig.subplots()
            maze.plot(ax)
            # Axis limits come from the first two observation bounds
            # (presumably the x/y position - confirm against the env definition).
            low = maze.observation_space.low
            high = maze.observation_space.high
            ax.set_xlim(low[0], high[0])
            ax.set_ylim(low[1], high[1])
            plt.axis('equal')
            plt.grid(False)
            after_plot(title=level.capitalize(), fpath=f'envs/mgmaze/{level}.png', legend=False)
        finally:
            # Always release the environment, including on plotting errors.
            env.close()

def make_point_maze():
    """Save the PointMaze learning-curve grid to ``formal/PointMaze/PM-curves.png``.

    The figure has one row per metric (``reward``, ``reachable_modes``) and one
    column per maze level (``simple``, ``medium``, ``hard``). Each panel
    overlays six methods read from ``formal/PointMaze/{level}/<algo>`` with
    ``ema=0.5``. ``SQL`` and ``DACER`` are drawn dashed in the colour of the
    curve plotted just before them (``DrAC`` and ``DrDfs`` respectively).
    """
    levels = ('simple', 'medium', 'hard')
    # (metric key, y-axis label), one entry per row of the grid.
    rows = (
        ('reward', 'Success rate (in %)'),
        ('reachable_modes', 'Reachable goals'),
    )
    # Right-most x tick per level; six evenly spaced ticks are placed from 0 to this value.
    last_tick = {'simple': 100_000, 'medium': 200_000, 'hard': 500_000}

    _, axs = plt.subplots(2, 3, figsize=(9.75, 5))
    for col, level in enumerate(levels):
        axs[0][col].set_title(level.capitalize())
        axs[1][col].set_xlabel("Training steps")

    for row, (metric, ylabel) in enumerate(rows):
        for col, level in enumerate(levels):
            ax = axs[row][col]
            run_dir = f'formal/PointMaze/{level}'
            shared = dict(ykey=metric, ax=ax, ema=0.5)

            drac_line = plot_line_X_trials(f'{run_dir}/DrAC', line_label='DrAmort', **shared)
            plot_line_X_trials(
                f'{run_dir}/SQL', line_label='SQL',
                line_kwargs={'c': drac_line.get_color(), 'ls': '--'}, **shared
            )
            # DrDfs is the only curve plotted with only_finished=False.
            drdfs_line = plot_line_X_trials(
                f'{run_dir}/DrDfs', line_label='DrDiffus', only_finished=False, **shared
            )
            plot_line_X_trials(
                f'{run_dir}/DACER', line_label='DACER',
                line_kwargs={'c': drdfs_line.get_color(), 'ls': '--'}, **shared
            )
            plot_line_X_trials(f'{run_dir}/SSAC', line_label='S$^2$AC', **shared)
            plot_line_X_trials(f'{run_dir}/SAC', line_label='SAC', **shared)

            # Fixed y-range per metric; any metric other than the two handled
            # explicitly falls back to [0, 1].
            if metric == 'reachable_modes':
                y_top = 8 if level == 'hard' else 4
            elif metric == 'reward':
                y_top = 105
            else:
                y_top = 1
            ax.set_ylim(0, y_top)

            ticks = np.linspace(0, last_tick[level], 6)
            ax.set_xticks(ticks, [f'{int(x/1000)}K' for x in ticks])

            if col == 0:
                ax.set_ylabel(ylabel)
            if row == 0:
                ax.legend(fontsize=8)

    after_plot(title='', fpath='formal/PointMaze/PM-curves.png', xlabel='Training steps', legend=False, tight_pad=0.7)

def make_point_maze_robustness():
    """Save the PointMaze robustness grid to ``formal/PointMaze/PM-robustness.png``.

    The figure has one row per robustness metric (``removal-SR5``,
    ``obstacle-SR5``) and one column per maze level (``simple``, ``medium``,
    ``hard``). Every curve is read with ``logfile='robustness'`` and
    ``ema=0.5`` from ``formal/PointMaze/{level}/<algo>``. ``SQL`` and ``DACER``
    are drawn dashed in the colour of the curve plotted just before them
    (``DrAC`` and ``DrDfs`` respectively).

    Both metrics use a fixed y-range of ``[0, 1.05]``. Earlier revisions also
    carried ``reward`` / ``reachable_modes`` branches and a second ``set_title``
    call; neither could affect the output for the metrics plotted here, so they
    were removed.
    """
    levels = ('simple', 'medium', 'hard')
    metrics = ('removal-SR5', 'obstacle-SR5')
    ylabels = ('Removal robustness', 'Obstacle robustness')
    # Right-most x tick per level; six evenly spaced ticks are placed from 0 to this value.
    last_tick = {'simple': 100_000, 'medium': 200_000, 'hard': 500_000}

    _, axs = plt.subplots(2, 3, figsize=(9.75, 5))
    for j, level in enumerate(levels):
        axs[0][j].set_title(level.capitalize())
        axs[1][j].set_xlabel("Training steps")

    for i, (metric, ylabel) in enumerate(zip(metrics, ylabels)):
        for j, level in enumerate(levels):
            ax = axs[i][j]
            run_dir = f'formal/PointMaze/{level}'
            cm_kwargs = dict(logfile='robustness', ykey=metric, ax=ax, ema=0.5)

            line = plot_line_X_trials(f'{run_dir}/DrAC', line_label='DrAmort', **cm_kwargs)
            plot_line_X_trials(
                f'{run_dir}/SQL', line_label='SQL',
                line_kwargs={'c': line.get_color(), 'ls': '--'}, **cm_kwargs
            )
            line = plot_line_X_trials(f'{run_dir}/DrDfs', line_label='DrDiffus', **cm_kwargs)
            plot_line_X_trials(
                f'{run_dir}/DACER', line_label='DACER',
                line_kwargs={'c': line.get_color(), 'ls': '--'}, **cm_kwargs
            )
            plot_line_X_trials(f'{run_dir}/SSAC', line_label='S$^2$AC', **cm_kwargs)
            plot_line_X_trials(f'{run_dir}/SAC', line_label='SAC', **cm_kwargs)

            ax.set_ylim(0, 1.05)
            if j == 0:
                ax.set_ylabel(ylabel)

            xticks = np.linspace(0, last_tick[level], 6)
            ax.set_xticks(xticks, [f'{int(x/1000)}K' for x in xticks])

            # Only the top row carries a legend; the bottom row shows the same methods.
            if i == 0:
                ax.legend(fontsize=8, ncol=2)

    after_plot(title='', fpath='formal/PointMaze/PM-robustness.png', xlabel='Training steps', legend=False, tight_pad=0.7)

def make_point_maze_training():
    """Save one discovered-modes figure per PointMaze level.

    For each level the runs listed below are overlaid with
    ``plot_discovered_modes`` and the figure is written to
    ``data/PointMaze/{level}_discovered_modes.png``.
    """
    # (run sub-directory, legend label), in drawing order.
    runs = (
        ('SAC', 'SAC default'),
        ('SAC-1.0', 'SAC $h=1.0$'),
        ('SAC-1.2', 'SAC $h=1.2$'),
        ('DrAC2-0.8', 'DrAC $\\beta=0.8$'),
        ('DrAC2-0.7', 'DrAC $\\beta=0.7$'),
    )
    for level in ('simple', 'medium', 'hard'):
        for run_name, legend_label in runs:
            plot_discovered_modes(f'data/PointMaze/{level}/{run_name}', label=legend_label)
        after_plot(
            title=f'MultiGoalPointMaze-{level}',
            fpath=f'data/PointMaze/{level}_discovered_modes.png',
            ylabel='discovered_modes',
            legend_kwargs={'fontsize': 8, 'ncols': 2},
        )

def make_robustness_learning_curves():
    """Save one robustness-versus-learning figure per (experiment, level) pair.

    Each experiment is a pair of positional arguments forwarded to
    ``plot_robustness_by_learning`` plus the y-axis label used for the figure.
    For every experiment and level, the ``DrAC``, ``SAC`` and ``SQL`` runs under
    ``formal/PointMaze/{level}/`` are overlaid and the figure is saved as
    ``formal/PointMaze/{level}_{first_arg}_{ylabel}`` (no extension is added
    here).
    """
    levels = ('simple', 'medium', 'hard')
    # (run sub-directory, legend label)
    methods = (('DrAC', 'DrAmort'), ('SAC', 'SAC'), ('SQL', 'SQL'))
    # (positional args for plot_robustness_by_learning, y-axis label)
    experiments = (
        (('removing', 'SR-mean'), 'Success rate'),
        (('obstacle', 'reward'), 'Success rate (in %)'),
        (('obstacle', 'reachable_modes'), 'reachable modes'),
        (('obstacle', 'multi_goal_score'), 'Multi-goal score'),
    )
    for args, ylabel in experiments:
        setting = args[0]
        # The file stem does not depend on the level, so build it once per experiment.
        file_stem = f'{setting}_{ylabel}'
        for level in levels:
            for algo, name in methods:
                plot_robustness_by_learning(f'formal/PointMaze/{level}/{algo}', *args, label=name)
            after_plot(title=f'{level}: {setting}', fpath=f'formal/PointMaze/{level}_{file_stem}', ylabel=ylabel)

def make_point_maze_coordinates():
    """Save the time-versus-performance scatter grid for PointMaze.

    The 3x3 grid has one row per score (``multi_goal_score`` from
    ``final_scores.json``; ``removal-SR5`` and ``obstacle-SR5`` from
    ``robustness.csv``) and one column per maze level. Each algorithm
    contributes one point per panel: x is ``read_unit_time(algo)`` (labelled
    "Seconds per 1K training steps"), y is the first value returned by
    ``read_final_performance``. All panels share both axes, so the limits set
    on the first panel apply to the whole grid.

    An algorithm whose results cannot be read is skipped for that panel. The
    cause is printed along with the message, so a genuine error (bad key,
    malformed file, ...) can be told apart from results that are simply absent.
    """
    levels = ('simple', 'medium', 'hard')
    logfiles = ('final_scores.json', 'robustness.csv', 'robustness.csv')
    keys = ('multi_goal_score', 'removal-SR5', 'obstacle-SR5')
    ylabels = ('Multi-goal score', 'Robustness in removal', 'Robustness in obstacles')
    _, axs = plt.subplots(3, 3, figsize=(10.5, 7.5), sharex=True, sharey=True)
    algos = ('DrAC', 'SAC', 'SQL', 'SSAC')
    algo_names = ('DrAmort', 'SAC', 'SQL', 'S$^2$AC')
    axs[0][0].set_xlim(left=0, right=250)
    axs[0][0].set_ylim(0, 1.02)
    for i, (logfile, key, ylabel) in enumerate(zip(logfiles, keys, ylabels)):
        for j, level in enumerate(levels):
            ax = axs[i][j]
            for algo, name in zip(algos, algo_names):
                try:
                    x = read_unit_time(algo)
                    y, _ = read_final_performance(f'formal/PointMaze/{level}/{algo}', logfile, key)
                    ax.scatter([x], [y], label=name)
                except Exception as exc:
                    # Best effort: keep going, but say why this point is missing.
                    print(f'No {algo} results on {level} ({type(exc).__name__}: {exc})')
            if j == 0:
                ax.set_ylabel(ylabel)
            if i == 0:
                ax.set_title(level.capitalize())
            if i == 2:
                ax.set_xlabel('Seconds per 1K training steps')
            # Skip the legend on panels where nothing was drawn, so matplotlib
            # does not warn about a legend with no labelled artists.
            handles, _ = ax.get_legend_handles_labels()
            if handles:
                ax.legend(fontsize=8)
    # NOTE(review): the file name is misspelled ("performace"); kept as-is so
    # anything that already references this output path keeps working.
    after_plot(title='', fpath='formal/PointMaze/time-performace.png', legend=False, tight_pad=1.0)


if __name__ == '__main__':
    # Figures rebuilt when this file is run as a script, in order.
    # Comment entries in or out to choose which ones are regenerated.
    for build_figure in (
        make_point_maze,
        make_point_maze_robustness,
        # make_point_maze_coordinates,
        # make_robustness_learning_curves,
    ):
        build_figure()
