
import numpy as np
import matplotlib.pyplot as plt
import pickle

import os
import torch



# Plot styling shared by every figure below: entry i is used for the i-th method of
# whichever `method` list is being plotted (via col[idx] / marker[idx]).
col = [
    'dodgerblue',
    "tab:orange",
    "mediumaquamarine",
    'lightcoral',
    'skyblue',
    'sandybrown',
]
marker = [
    '*',
    'x',
    'o',
    'd',
    'v',
    'h',
]

# NOTE(review): presumably a workaround for the Intel OpenMP "duplicate libiomp"
# abort when several libraries (e.g. torch and numpy/MKL) each load their own copy
# of the runtime. TODO: confirm, and confirm it still takes effect when set after
# `import torch`.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


if __name__ == '__main__':

    # plot meta learning
    # =============================================================================
    # method = ['cg', 'ns', 'ad']
    #
    # scale = 0.01
    # plt.rcParams["figure.figsize"] = (9,7)
    # plt.figure(1)
    # # plt.yscale("log")
    # for idx, m in enumerate(method):
    #     with open(os.path.join('results', 'RHGD_meta_'+ m + '.pickle'), "rb") as input_file1:
    #         rhgd = pickle.load(input_file1)
    #         rhgd_time = rhgd['runtime']
    #         rhgd_acc = rhgd['test_acc_mean']
    #         rhgd_std = rhgd['test_acc_std']
    #
    #     plt.plot(rhgd_time, rhgd_acc, label='RHGD-'+m, color=col[idx], marker=marker[idx], linewidth=2.4, markersize=2.4)
    #     # plt.fill_between(rhgd_time, rhgd_acc - scale*rhgd_std, rhgd_acc + scale*rhgd_std, alpha=0.3, fc=col[idx])
    #
    #     with open( os.path.join('results','PHGD_meta_'+ m + '.pickle'), "rb") as input_file2:
    #         phgd = pickle.load(input_file2)
    #         phgd_time = phgd['runtime']
    #         phgd_acc = phgd['test_acc_mean']
    #         phgd_std = phgd['test_acc_std']
    #
    #     plt.plot(phgd_time, phgd_acc, label='PHGD-' + m, color=col[idx], linestyle='--', linewidth=2.4)
    #     # plt.fill_between(phgd_time, phgd_acc - scale*phgd_std, phgd_acc + scale*phgd_std, alpha=0.3, fc=col[idx])
    #
    # plt.legend(loc='lower right', fontsize=21)
    # plt.xlabel("Runtime", fontsize=25)
    # plt.xticks(fontsize=15)
    # plt.ylabel("Test Accuracy", fontsize=25)
    # plt.yticks(fontsize=15)
    # plt.savefig(os.path.join('figs', 'meta_learning_time.pdf'), bbox_inches='tight')



    # method = ['cg', 'ns', 'ad']
    #
    # scale = 0.01
    # plt.rcParams["figure.figsize"] = (9, 7)
    # plt.figure(1)
    # # plt.yscale("log")
    # for idx, m in enumerate(method):
    #     with open(os.path.join('results', 'RHGD_meta_' + m + "_1" + '.pickle'), "rb") as input_file1:
    #         rhgd = pickle.load(input_file1)
    #         rhgd_time = rhgd['runtime']
    #         rhgd_acc = rhgd['test_acc_mean']
    #         rhgd_std = rhgd['test_acc_std']
    #
    #     plt.plot(rhgd_time, rhgd_acc, label='RHGD-' + m, color=col[idx], marker=marker[idx], linewidth=2.4,
    #              markersize=2.4)
    #     # plt.fill_between(rhgd_time, rhgd_acc - scale*rhgd_std, rhgd_acc + scale*rhgd_std, alpha=0.3, fc=col[idx])
    #
    #     with open(os.path.join('results', 'PHGD_meta_' + m + "_1" + '.pickle'), "rb") as input_file2:
    #         phgd = pickle.load(input_file2)
    #         phgd_time = phgd['runtime']
    #         phgd_acc = phgd['test_acc_mean']
    #         phgd_std = phgd['test_acc_std']
    #
    #     plt.plot(phgd_time, phgd_acc, label='PHGD-' + m, color=col[idx], linestyle='--', linewidth=2.4)
    #     # plt.fill_between(phgd_time, phgd_acc - scale*phgd_std, phgd_acc + scale*phgd_std, alpha=0.3, fc=col[idx])
    #
    # plt.legend(loc='lower right', fontsize=21)
    # plt.xlabel("Runtime", fontsize=25)
    # plt.xticks(fontsize=15)
    # plt.ylabel("Test Accuracy", fontsize=25)
    # plt.yticks(fontsize=15)
    # plt.savefig(os.path.join('figs', 'meta_learning_time_1.pdf'), bbox_inches='tight')
    # # =============================================================================


    # plot syn
    # ============================================================================
    # Synthetic-problem results, one entry per hypergradient method in `method`.
    # Two runs are read per method: 'syn_<m>_base.pickle' (legend '<m>-20') and
    # 'syn_<m>_liter50.pickle' (legend '<m>-50'). Each list below is indexed like
    # `method`, which is also how `col`/`marker` are indexed.
    method = ['cg', 'ns', 'ad', 'hinv']
    runtime_base = []
    epochs_base = []
    hg_error_base = []
    loss_upper_base = []

    runtime_liter50 = []
    epochs_liter50 = []
    hg_error_liter50 = []
    loss_upper_liter50 = []
    for m in method:
        # NOTE: pickle.load can execute arbitrary code; only load result files this
        # project produced itself.
        with open(os.path.join('results', 'syn_' + m + '_base' + '.pickle'), "rb") as f:
            rhgd = pickle.load(f)
        # 'runtime' is accumulated with cumsum, so it presumably stores per-step
        # durations; the cumulative sum gives elapsed time for the x-axis.
        runtime_base.append(np.cumsum(rhgd['runtime']))
        epochs_base.append(rhgd['epochs'])
        hg_error_base.append(rhgd['hg_error'])
        # The stored upper loss is negated for plotting (unary minus assumes an
        # array/tensor, not a plain list -- TODO confirm the stored type).
        loss_upper_base.append(-rhgd['loss_upper'])

        with open(os.path.join('results', 'syn_' + m + '_liter50' + '.pickle'), "rb") as f:
            rhgd = pickle.load(f)
        runtime_liter50.append(np.cumsum(rhgd['runtime']))
        epochs_liter50.append(rhgd['epochs'])
        hg_error_liter50.append(rhgd['hg_error'])
        loss_upper_liter50.append(-rhgd['loss_upper'])

    # savefig does not create missing directories; make sure the output folder exists.
    os.makedirs('figs', exist_ok=True)

    # upper objective vs epochs
    plt.rcParams["figure.figsize"] = (8, 6)
    plt.figure(1)
    # plt.yscale("log")
    for idx, m in enumerate(method):
        plt.plot(epochs_base[idx], loss_upper_base[idx], label=m + '-20', color=col[idx], marker=marker[idx],
                 linewidth=2.4, markersize=2.4)
        # Plot each run against its own epoch axis: the base and liter50 runs may
        # log a different number of epochs.
        plt.plot(epochs_liter50[idx], loss_upper_liter50[idx], label=m + '-50', color=col[idx], linewidth=2.4,
                 linestyle='--')
    # plt.gca().set_ylim([0.05, 1])
    plt.legend(loc='lower right', fontsize=21)
    plt.xlabel("Epochs", fontsize=25)
    plt.xticks(fontsize=15)
    plt.ylabel("Upper Objective", fontsize=25)
    plt.yticks(fontsize=15)
    # plt.show()
    plt.savefig(os.path.join('figs', 'syn_loss_liter_ep.pdf'), bbox_inches='tight')

    # upper objective vs cumulative runtime
    plt.rcParams["figure.figsize"] = (8, 6)
    plt.figure(2)
    # plt.yscale("log")
    for idx, m in enumerate(method):
        plt.plot(runtime_base[idx], loss_upper_base[idx], label=m + '-20', color=col[idx], marker=marker[idx],
                 linewidth=2.4,
                 markersize=2.4)
        plt.plot(runtime_liter50[idx], loss_upper_liter50[idx], label=m + '-50', color=col[idx], linewidth=2.4,
                 linestyle='--')
    # plt.legend(loc='lower right', fontsize=21)
    plt.xlabel("Time", fontsize=25)
    plt.xticks(fontsize=15)
    plt.ylabel("Upper Objective", fontsize=25)
    plt.yticks(fontsize=15)
    plt.savefig(os.path.join('figs', 'syn_loss_liter_time.pdf'), bbox_inches='tight')
    #
    # # error vs eps
    # method = ['cg', 'ns', 'ad']
    # plt.rcParams["figure.figsize"] = (8,6)
    # plt.figure(3)
    # plt.yscale("log")
    # for idx, m in enumerate(method):
    #     plt.plot(epochs_base[idx], hg_error_base[idx], label=m+'-20', color=col[idx], marker=marker[idx], linewidth=2.4, markersize=2.4)
    #     plt.plot(epochs_base[idx], hg_error_liter50[idx], label=m+'-50', color=col[idx], linewidth=2.4, linestyle='--')
    # # plt.legend(loc='upper right', fontsize=21)
    # plt.xlabel("Epochs", fontsize=25)
    # plt.xticks(fontsize=15)
    # plt.ylabel("Hypergrad error", fontsize=25)
    # plt.yticks(fontsize=15)
    # plt.savefig(os.path.join('figs', 'syn_hg_error_liter_ep.pdf'), bbox_inches='tight')
    #
    # # error vs time
    # plt.rcParams["figure.figsize"] = (8, 6)
    # plt.figure(4)
    # plt.yscale("log")
    # for idx, m in enumerate(method):
    #     plt.plot(runtime_base[idx], hg_error_base[idx], label=m + '-20', color=col[idx], marker=marker[idx], linewidth=2.4,
    #              markersize=2.4)
    #     plt.plot(runtime_liter50[idx], hg_error_liter50[idx], label=m + '-50', color=col[idx], linewidth=2.4, linestyle='--')
    # # plt.legend(loc='upper right', fontsize=21)
    # plt.xlabel("Runtime", fontsize=25)
    # plt.xticks(fontsize=15)
    # plt.ylabel("Hypergrad error", fontsize=25)
    # plt.yticks(fontsize=15)
    # plt.savefig(os.path.join('figs','syn_hg_error_liter_time.pdf'), bbox_inches='tight')


    ### ablation for ns
    # result_folder = 'results'
    # with open(os.path.join(result_folder, "syn_ns_30_1.0.pickle"), "rb") as f:
    #     rhgd = pickle.load(f)
    #     ns_301_epochs = rhgd['epochs']
    #     ns_301_error = rhgd['hg_error']
    # with open(os.path.join(result_folder, "syn_ns_liter50.pickle"), "rb") as f:
    #     rhgd = pickle.load(f)
    #     ns_501_epochs = rhgd['epochs']
    #     ns_501_error = rhgd['hg_error']
    # with open(os.path.join(result_folder, "syn_ns_100_1.0.pickle"), "rb") as f:
    #     rhgd = pickle.load(f)
    #     ns_1001_epochs = rhgd['epochs']
    #     ns_1001_error = rhgd['hg_error']
    # with open(os.path.join(result_folder, "syn_ns_200_1.0.pickle"), "rb") as f:
    #     rhgd = pickle.load(f)
    #     ns_2001_epochs = rhgd['epochs']
    #     ns_2001_error = rhgd['hg_error']
    # with open(os.path.join(result_folder, "syn_ns_200_0.5.pickle"), "rb") as f:
    #     rhgd = pickle.load(f)
    #     ns_20005_epochs = rhgd['epochs']
    #     ns_20005_error = rhgd['hg_error']
    # with open(os.path.join(result_folder, "syn_ns_200_2.pickle"), "rb") as f:
    #     rhgd = pickle.load(f)
    #     ns_2002_epochs = rhgd['epochs']
    #     ns_2002_error = rhgd['hg_error']
    #
    # plt.rcParams["figure.figsize"] = (8, 6)
    # plt.figure(5)
    # plt.yscale("log")
    #
    # plt.plot(ns_301_epochs, ns_301_error, label='30-1.0', color=col[0], marker=marker[0], linewidth=2.4,
    #          markersize=2.4)
    # plt.plot(ns_501_epochs, ns_501_error, label= '50-1.0', color=col[1], marker=marker[1], linewidth=2.4,
    #          markersize=2.4)
    # plt.plot(ns_1001_epochs, ns_1001_error, label='100-1.0', color=col[2], marker=marker[2], linewidth=2.4,
    #          markersize=2.4)
    # plt.plot(ns_2001_epochs, ns_2001_error, label='200-1.0', color=col[3], marker=marker[3], linewidth=2.4,
    #          markersize=2.4)
    # plt.plot(ns_20005_epochs, ns_20005_error, label='200-0.5', color=col[4], marker=marker[4], linewidth=2.4,
    #          markersize=2.4)
    # plt.plot(ns_2002_epochs, ns_2002_error, label='200-2.0', color=col[5], marker=marker[5], linewidth=2.4,
    #          markersize=2.4)
    # plt.legend(loc='upper right', fontsize=21)
    # plt.xlabel("Epochs", fontsize=25)
    # plt.xticks(fontsize=15)
    # plt.ylabel("Hypergrad error", fontsize=25)
    # plt.yticks(fontsize=15)
    # plt.savefig(os.path.join('figs','syn_hg_error_ns_aba.pdf'), bbox_inches='tight')

    # ===========================================================================


    # shallow hyrep SPD
    #============================================================================
    # method = ['cg', 'ns', 'ad', 'hinv']
    # runtime_full = []
    # epochs_full  = []
    # hg_error_full  = []
    # loss_upper_full  = []
    #
    # runtime_stoc = []
    # epochs_stoc = []
    # hg_error_stoc = []
    # loss_upper_stoc = []
    # for idx, m in enumerate(method):
    #     with open(os.path.join('results', 'shallow_hyrep_' + m + '_full' + '.pickle'), "rb") as input_file1:
    #         rhgd = pickle.load(input_file1)
    #         runtime_full.append( np.cumsum(rhgd['runtime']) )
    #         epochs_full.append(rhgd['epochs'])
    #         hg_error_full.append(rhgd['hg_error'])
    #         loss_upper_full.append(rhgd['loss_upper'])
    #
    #     with open(os.path.join('results', 'shallow_hyrep_' + m + '_sto' + '.pickle'), "rb") as f:
    #         rhgd = pickle.load(f)
    #         runtime_stoc.append( np.cumsum(rhgd['runtime']) )
    #         epochs_stoc.append(rhgd['epochs'])
    #         hg_error_stoc.append( rhgd['hg_error'] )
    #         loss_upper_stoc.append(rhgd['loss_upper'])
    #
    #     plt.rcParams["figure.figsize"] = (8, 6)
    #     plt.figure(1)
    #     plt.yscale("log")
    #     plt.plot(epochs_full[-1], loss_upper_full[-1], label= 'RHGD-' + m, color=col[idx], marker=marker[idx],
    #              linewidth=2.4, markersize=2.4)
    #     plt.plot(epochs_stoc[-1], loss_upper_stoc[-1], label= 'RSHGD-' + m, color=col[idx],
    #              linewidth=2.4, linestyle='--')
    #
    # plt.legend(loc='lower left', fontsize=21)
    # plt.xlabel("Epochs", fontsize=25)
    # plt.xticks(fontsize=15)
    # plt.ylabel("Upper loss", fontsize=25)
    # plt.yticks(fontsize=15)
    # plt.savefig(os.path.join('figs','hyrep_spd_loss.pdf'), bbox_inches='tight')


    # OT domain adaptation
    # with open(os.path.join('da_3_2' + '.pickle'), "rb") as f:
    #     rhgd = pickle.load(f)
    #     rhgd

    # deep hyrep SPD ad
    # method = ['cg', 'ns', 'ad']
    # plt.rcParams["figure.figsize"] = (8, 6)
    # plt.figure(1)
    # # plt.yscale("log")
    # for idx, m in enumerate(method):
    #     with open(os.path.join('results', 'hyrep_spd_' + m + '.pickle'), "rb") as f:
    #         rhgd = pickle.load(f)
    #         runtime = torch.cumsum(rhgd['runtime'], dim=1)
    #         loss_upper = rhgd['loss_upper']
    #         accuracy = rhgd['accuracy']
    #
    #         acc_mean = torch.mean(accuracy, dim=0)
    #         acc_std = torch.std(accuracy, dim=0)
    #         runtime_mean = torch.mean(runtime, dim=0)
    #
    #         plt.plot(runtime_mean, acc_mean, label= m, color=col[idx], marker=marker[idx], linewidth=2.4,
    #                  markersize=2.4)
    #         plt.fill_between(runtime_mean, acc_mean - acc_std, acc_mean + acc_std, alpha=0.3, fc=col[idx])
    #
    # plt.legend(loc='lower right', fontsize=21)
    # plt.xlabel("Time", fontsize=25)
    # plt.xticks(fontsize=15)
    # plt.ylabel("Val Accuracy", fontsize=25)
    # plt.yticks(fontsize=15)
    # plt.savefig(os.path.join('figs','deep_hyrep_spd_acc_time.pdf'), bbox_inches='tight')
    #
    # method = ['cg', 'ns', 'ad']
    # plt.rcParams["figure.figsize"] = (8, 6)
    # plt.figure(2)
    # # plt.yscale("log")
    # for idx, m in enumerate(method):
    #     with open(os.path.join('results', 'hyrep_spd_' + m + '.pickle'), "rb") as f:
    #         rhgd = pickle.load(f)
    #         runtime = torch.cumsum(rhgd['runtime'], dim=1)
    #         loss_upper = rhgd['loss_upper']
    #         accuracy = rhgd['accuracy']
    #
    #         acc_mean = torch.mean(accuracy, dim=0)
    #         acc_std = torch.std(accuracy, dim=0)
    #         runtime_mean = torch.mean(runtime, dim=0)
    #
    #         plt.plot(torch.arange(acc_mean.shape[0]), acc_mean, label=m, color=col[idx], marker=marker[idx], linewidth=2.4,
    #                  markersize=2.4)
    #         plt.fill_between(torch.arange(acc_mean.shape[0]), acc_mean - acc_std, acc_mean + acc_std, alpha=0.3, fc=col[idx])
    #
    # plt.legend(loc='lower right', fontsize=21)
    # plt.xlabel("Epochs", fontsize=25)
    # plt.xticks(fontsize=15)
    # plt.ylabel("Val Accuracy", fontsize=25)
    # plt.yticks(fontsize=15)
    # plt.savefig(os.path.join('figs', 'deep_hyrep_spd_acc_ep.pdf'), bbox_inches='tight')

    #============================================================================