import json
import os
from contextlib import contextmanager

import matplotlib.pyplot as plt

import plot


# Algorithm keys passed as ``included_algorithms`` to
# ``plot.plot_best_run_envelopes`` for the main comparison figures.
# Commented-out entries are variants currently left out of those figures;
# uncomment one to include it again.
CORE_ALGORITHMS = (
    'SGD',
    'SGDmwd',
    'Adam',
    # 'SGDQLR_Undamped_Hessian',
    # 'SGDQLR_Damped_Hessian',
    # 'AdamQLR_Undamped_Hessian',
    # 'AdamQLR_Undamped_Fisher',
    # 'AdamQLR_Undamped_Clipped',
    'AdamQLR_Damped',
    # 'AdamQLR_Damped_Clipped',
    # 'AdamQLR_Damped_Enveloped',
    # 'AdamQLR_Damped_Hessian',
    # 'AdamQLR_Damped_Hessian_NoHPO_SFN',
    # 'AdamQLR_Damped_AdamDampedCurvature',
    'AdamQLR_NoHPO',
    # 'AdamQLR_Damped_Hessian_DecreasingLossDamping',
    # 'AdamQLR_NoHPO_DecreasingLossDamping',
    'KFAC',
)
def tight_savefig(directory):
    """Save the current pyplot figure with a tight bounding box, then close it.

    Args:
        directory: Destination handed straight to ``plt.savefig``. Despite
            the name, every caller in this file passes a full output *file*
            path (ending in ``.pdf``), not a folder.

    If the destination is a path (``str`` or ``os.PathLike``), any missing
    parent folders are created first so that saving does not fail merely
    because the output folder has not been made yet. Other destinations
    (e.g. file-like objects) are passed through untouched.
    """
    if isinstance(directory, (str, os.PathLike)):
        parent = os.path.dirname(os.fspath(directory))
        # An empty parent means "current working directory": nothing to create.
        if parent:
            os.makedirs(parent, exist_ok=True)
    plt.savefig(directory,
                bbox_inches='tight',
                pad_inches=0)
    plt.close()


@contextmanager
def paper_theme():
    """Context manager applying the figure styling used for the paper.

    The body of the ``with`` statement runs under matplotlib's
    ``Solarize_Light2`` style. Once the body finishes without raising, and
    while that style is still active, the current figure is given a white
    face colour and resized to 6 x 4 inches.
    """
    with plt.style.context('Solarize_Light2'):
        yield
        # Post-process whatever figure the caller left as the current one.
        figure = plt.gcf()
        figure.set_facecolor('white')
        figure.set_size_inches(6, 4)

def rosenbrock_trajectory_plot():
    """Plot the Rosenbrock optimisation paths and save them as a PDF.

    The paths are drawn by ``plot.plot_rosenbrock_paths`` under the paper
    theme with ``plt.show`` inhibited; the axis labels are added afterwards
    and the figure is written to
    ``./plots/paper_AuthorNameLRClipping/RosenbrockTrajectory.pdf``.
    """
    results_directory = '/path/to/2023-09-26 Rosenbrock AuthorNameLRClipping'
    algorithms = ['GD', 'GDmwd', 'Adam', 'AdamQLR_Damped_Hessian', 'AdamQLR_NoHPO']

    with plot.inhibit_plt_show():
        with paper_theme():
            plot.plot_rosenbrock_paths(results_directory, algorithms)

    # Labels are applied once the themed context has been left.
    plt.xlabel(r'$x$')
    plt.ylabel(r'$y$')
    tight_savefig('./plots/paper_AuthorNameLRClipping/RosenbrockTrajectory.pdf')


def loss_evolution_plots():
    """Plot median best-run metric curves for each benchmark dataset.

    For every source directory and every metric in ``valid_metrics``, the
    curves for ``CORE_ALGORITHMS`` are drawn with
    ``plot.plot_best_run_envelopes`` (``aggregation='median'``, linear x
    axis, x axis broken at a per-dataset value) under the paper theme, the
    y-axis label is set, and the figure is saved as
    ``./plots/paper_AuthorNameLRClipping/<dataset>_<metric>[Loss|Accuracy].pdf``.

    Raises:
        ValueError: If a source directory names a dataset for which no
            x-axis break point is configured.
    """
    source_directories = (
        '2023-09-24 UCI_Energy AuthorNameLRClipping',
        '2023-09-24 UCI_Protein AuthorNameLRClipping',
        '2023-09-24 Fashion-MNIST AuthorNameLRClipping',
        '2023-09-21 CIFAR-10 AuthorNameLRClipping',
        '2023-09-24 SVHN AuthorNameLRClipping',
        # '2023-09-07 PennTreebank_GPT2_Reset VanillaAuthorNameClipping',
    )
    # Alternative set of sources (ASHA time-budget experiments):
    # source_directories = (
    #     '2023-09-27 UCI_Energy ASHA_Time_Training',
    #     '2023-09-27 UCI_Protein ASHA_Time_Training',
    #     '2023-09-27 Fashion-MNIST ASHA_Time_Training',
    #     # '2023-09-27 CIFAR-10 ASHA_Time_Training',
    #     # '2023-09-27 SVHN ASHA_Time_Training',
    #     '2023-09-27 UCI_Energy ASHA_Time_Validation',
    #     '2023-09-27 UCI_Protein ASHA_Time_Validation',
    #     '2023-09-27 Fashion-MNIST ASHA_Time_Validation',
    #     # '2023-09-27 CIFAR-10 ASHA_Time_Validation',
    #     # '2023-09-27 SVHN ASHA_Time_Validation',
    # )

    # Value passed as ``break_x_axis`` for each dataset. A lookup table (as
    # opposed to a ``match`` without a default case) guarantees an unknown
    # dataset is reported instead of silently inheriting the value chosen
    # for the previous directory.
    break_points = {
        'UCI_Energy': 300,
        'UCI_Protein': 300,
        'Fashion-MNIST': 30,
        'SVHN': 250,
        'CIFAR-10': 2000,
        'PennTreebank_GPT2_Reset': 4000,
    }

    for directory in source_directories:
        # Directory names have the form '<date> <dataset> <experiment tag>'.
        dataset_name = directory.split(' ')[1]
        if dataset_name not in break_points:
            raise ValueError(
                f'No x-axis break point configured for dataset '
                f'{dataset_name!r} (from directory {directory!r})')
        break_point = break_points[dataset_name]

        valid_metrics = ['Loss/Training',
                         'Loss/Test',
                         'Adaptive/Learning_Rate',]
        # if not dataset_name.startswith('UCI'):
        #     valid_metrics.extend(['Accuracy/Training',
        #                           'Accuracy/Test'])

        # ASHA runs share dataset names, so the experiment tag is appended to
        # keep their output file names distinct.
        if 'ASHA_Time_' in directory:
            dataset_name += directory.split(' ')[-1]

        for metric in valid_metrics:
            with plot.inhibit_plt_show(), paper_theme():
                axes = plot.plot_best_run_envelopes(
                    f'/path/to/{directory}',
                    metric,
                    log_x_axis=False,
                    included_algorithms=CORE_ALGORITHMS,
                    aggregation='median',
                    break_x_axis=break_point)
                # Only reached when accuracy metrics are enabled above.
                if metric.startswith('Accuracy'):
                    axes[0].set_yscale('linear')
                    axes[0].set_ylim(0, 1.0)

            # 'Loss/Training' -> metric_name 'Training', suffix ' Loss'.
            metric_name = metric.split('/')[1]
            if metric.startswith('Loss'):
                suffix = ' Loss'
            elif metric.startswith('Accuracy'):
                suffix = ' Accuracy'
            else:
                suffix = ''
            pretty_metric_name = ' '.join(metric_name.split('_'))
            axes[0].set_ylabel(f'{pretty_metric_name}{suffix}')
            tight_savefig(f'./plots/paper_AuthorNameLRClipping/{dataset_name}_{metric_name}{suffix.split(" ")[-1]}.pdf')


def sensitivity_plots():
    """Plot the hyperparameter-sensitivity figures and a reference figure.

    For each sensitivity study directory, training and test loss trends are
    drawn with ``plot.plot_ablation_trends`` and saved as
    ``Sensitivity_<study>_<Training|Test>Loss.pdf``. The BatchSize study is
    drawn with the x axis broken at 14; every other study instead has its
    x limits set to (0, 14). Finally, the Fashion-MNIST best-run envelopes
    for ``CORE_ALGORITHMS`` are re-plotted with the x axis broken at 14 and
    saved as ``Sensitivity_ReprisedFashion-MNIST_<Training|Test>Loss.pdf``.
    """
    loss_metrics = ('Loss/Training', 'Loss/Test')
    study_directories = (
        '/path/to/ray/2023-09-24T09:38:35.097408__fashion_mnist__AdamQLR_Damped__ASHA Sensitivity_Amplification',
        '/path/to/ray/2023-09-24T09:38:35.097408__fashion_mnist__AdamQLR_Damped__ASHA Sensitivity_BatchSize',
        '/path/to/ray/2023-09-24T09:38:35.097408__fashion_mnist__AdamQLR_Damped__ASHA Sensitivity_InitialDamping',
        '/path/to/ray/2023-09-24T09:38:35.097408__fashion_mnist__AdamQLR_Damped__ASHA Sensitivity_LRClipping',
        '/path/to/ray/2023-09-24T09:38:35.097408__fashion_mnist__AdamQLR_Damped__ASHA Sensitivity_SteppingFactor',
    )

    for study_directory in study_directories:
        # The study name is whatever follows the final underscore.
        ablation_type = study_directory.split('_')[-1]
        is_batch_size_study = ablation_type == 'BatchSize'
        if is_batch_size_study:
            x_axis_break = 14
        else:
            x_axis_break = False

        for metric in loss_metrics:
            # 'Loss/Training' -> 'Training'
            loss_name = metric[5:]
            with plot.inhibit_plt_show(), paper_theme():
                axes = plot.plot_ablation_trends(study_directory,
                                                 metric,
                                                 log_x_axis=False,
                                                 aggregation='median',
                                                 break_x_axis=x_axis_break)
            axes[0].set_ylabel(f'{loss_name} Loss')
            if not is_batch_size_study:
                plt.xlim(0, 14)
            tight_savefig(f'./plots/paper_AuthorNameLRClipping/Sensitivity_{ablation_type}_{loss_name}Loss.pdf')

    # Reference figure: the main Fashion-MNIST comparison, broken at 14.
    for metric in loss_metrics:
        loss_name = metric[5:]
        with plot.inhibit_plt_show(), paper_theme():
            axes = plot.plot_best_run_envelopes(
                '/path/to/2023-09-24 Fashion-MNIST AuthorNameLRClipping',
                metric,
                log_x_axis=False,
                included_algorithms=CORE_ALGORITHMS,
                aggregation='median',
                break_x_axis=14)
        axes[0].set_ylabel(f'{loss_name} Loss')
        tight_savefig(f'./plots/paper_AuthorNameLRClipping/Sensitivity_ReprisedFashion-MNIST_{loss_name}Loss.pdf')


def _load_first_run_config(best_runs_path):
    """Load ``config.json`` from the first entry found in *best_runs_path*.

    Args:
        best_runs_path: Folder holding the best run(s) of one algorithm.

    Returns:
        The parsed JSON configuration.

    Raises:
        FileNotFoundError: If *best_runs_path* contains no entries.
    """
    # NOTE(review): scandir order is arbitrary, so this presumably relies on
    # the folder holding a single run (or runs sharing one config) -- confirm.
    with os.scandir(best_runs_path) as run_entries:
        first_run = next(run_entries, None)
    if first_run is None:
        raise FileNotFoundError(f'No runs found in {best_runs_path!r}')
    with open(os.path.join(first_run, 'config.json'), 'r') as config_raw:
        return json.load(config_raw)


def _hyperparameter_row(algorithm, label, dataset_name, config_data):
    """Format one LaTeX table row describing an algorithm's hyperparameters.

    Cells, in order: label, batch size, learning rate, learning-rate
    clipping, momentum, weight decay (``add_decayed_weights``), initial
    damping, damping decrease factor, damping increase factor. Cells that do
    not apply to *algorithm* are rendered as ``{---}``.
    """
    blank = '{---}'
    optimiser = config_data['optimiser']

    batch_size = f"{config_data['batch_size']}" if dataset_name != 'Rosenbrock' else blank

    if algorithm in ('SGD', 'SGDmwd', 'Adam'):
        learning_rate = f"{optimiser['learning_rate']:.4e}"
    else:
        learning_rate = blank

    if 'lr_clipping' in optimiser:
        lr_clipping = f"{optimiser['lr_clipping']:.3f}"
    else:
        lr_clipping = blank

    if algorithm == 'SGDmwd':
        momentum = f"{optimiser['momentum']:.4f}"
        weight_decay = f"{optimiser['add_decayed_weights']:.4e}"
    else:
        momentum = weight_decay = blank

    if 'Damped' in algorithm or algorithm in ('KFAC', 'AdamQLR_NoHPO'):
        initial_damping = f"{optimiser['initial_damping']:.4e}"
    else:
        initial_damping = blank

    # The increase factor decides whether the pair of damping factors is
    # shown; the decrease factor is printed first.
    if optimiser.get('damping_increase_factor', None):
        damping_decrease = f"{optimiser['damping_decrease_factor']:.1f}"
        damping_increase = f"{optimiser['damping_increase_factor']:.1f}"
    else:
        damping_decrease = damping_increase = blank

    cells = (label, batch_size, learning_rate, lr_clipping, momentum,
             weight_decay, initial_damping, damping_decrease, damping_increase)
    return '& ' + ' \t& '.join(cells) + ' ' + r'\\' + '\n'


def hyperparameter_table():
    """Write the LaTeX body of the best-hyperparameters table.

    For each dataset directory, one row is emitted per labelled algorithm in
    ``plot.KEY_TO_LABEL`` whose best-runs folder exists, preceded by a
    ``\\multirow`` cell naming the dataset. Datasets are separated by
    ``\\midrule`` and the table ends with ``\\bottomrule``. The output goes to
    ``./plots/paper_AuthorNameLRClipping/Hyperparameters.tex``.
    """
    source_directories = (
        '2023-09-26 Rosenbrock AuthorNameLRClipping',
        '2023-09-24 UCI_Energy AuthorNameLRClipping',
        '2023-09-24 UCI_Protein AuthorNameLRClipping',
        '2023-09-24 Fashion-MNIST AuthorNameLRClipping',
        '2023-09-21 CIFAR-10 AuthorNameLRClipping',
        '2023-09-24 SVHN AuthorNameLRClipping',
    )
    with open('./plots/paper_AuthorNameLRClipping/Hyperparameters.tex', 'w') as table:
        for directory_idx, directory in enumerate(source_directories):
            dataset_name = directory.split(' ')[1].replace('_', ' ')

            # Build the rows first so the \multirow span counts exactly the
            # rows that are written (the skipped duplicate folder below must
            # not be counted).
            rows = []
            for algorithm, label in plot.KEY_TO_LABEL.items():
                if not label:
                    continue
                if algorithm == 'AdamQLR_Damped_Fisher':
                    # Duplicate folder for ease of labelling ablation plots
                    continue
                best_runs_path = os.path.join('/path/to/', directory, algorithm)
                if not os.path.exists(best_runs_path):
                    continue
                config_data = _load_first_run_config(best_runs_path)
                rows.append(_hyperparameter_row(algorithm, label, dataset_name, config_data))

            table.write(r'\multirow{' + str(len(rows)) + r'}{*}' + f'{{{dataset_name}}}\n')
            table.writelines(rows)
            if directory_idx + 1 != len(source_directories):
                table.write(r'\midrule' + '\n')
        table.write(r'\bottomrule')


def ablation_plots():
    """Plot the damping and curvature ablation figures.

    Each named ablation lists the result directories to read and the
    algorithms to compare. For every (ablation, directory, loss metric)
    combination the median best-run envelopes are drawn on a linear x axis
    under the paper theme and saved as
    ``Ablation_<ablation>_<dataset>_<Training|Test>Loss.pdf``.
    """
    plot_configs = {
        'Damping': {
            'directories': ('2023-09-24 Fashion-MNIST AuthorNameLRClipping',
                            '2023-09-21 CIFAR-10 AuthorNameLRClipping'),
            'algorithms': ('Adam',
                           'AdamQLR_Undamped',
                           'AdamQLR_Damped'),
        },
        'Curvature': {
            'directories': ('2023-09-24 Fashion-MNIST AuthorNameLRClipping',
                            '2023-09-21 CIFAR-10 AuthorNameLRClipping',),
            'algorithms': ('Adam',
                           'AdamQLR_Damped_Hessian',
                           'AdamQLR_Damped_Fisher'),
        },
    }
    loss_metrics = ('Loss/Training', 'Loss/Test')

    for plot_name, plot_config in plot_configs.items():
        compared_algorithms = plot_config['algorithms']
        for directory in plot_config['directories']:
            # Directory names have the form '<date> <dataset> <experiment tag>'.
            dataset_name = directory.split(' ')[1]
            for metric in loss_metrics:
                # 'Loss/Training' -> 'Training'
                loss_name = metric[5:]
                with plot.inhibit_plt_show():
                    with paper_theme():
                        axes = plot.plot_best_run_envelopes(
                            f'/path/to/{directory}',
                            metric,
                            log_x_axis=False,
                            included_algorithms=compared_algorithms,
                            aggregation='median')
                axes[0].set_ylabel(f'{loss_name} Loss')
                # Optional fixed y limits, currently disabled:
                # match dataset_name:
                #     case 'Fashion-MNIST': plt.ylim(7e-2, 2.5e0)
                #     case 'CIFAR-10': plt.ylim(7e-1, 2e1)
                output_path = f'./plots/paper_AuthorNameLRClipping/Ablation_{plot_name}_{dataset_name}_{loss_name}Loss.pdf'
                tight_savefig(output_path)


if __name__ == '__main__':
    # Regenerate every paper artefact, one generator after another.
    for generate in (rosenbrock_trajectory_plot,
                     loss_evolution_plots,
                     sensitivity_plots,
                     hyperparameter_table,
                     ablation_plots):
        generate()
