import sys
import json
from pathlib import Path

import numpy as np

import matplotlib.pyplot as plt


def _load_results(path):
    """Read the JSON results file at ``path`` and return the decoded object."""
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _plot_mean_band(x, center, spread, color, label, linestyle=''):
    """Plot ``center`` against ``x`` with a translucent ``center +/- spread`` band.

    ``color`` is a single-letter matplotlib colour code; it is concatenated with
    ``linestyle`` to form the format string passed to ``plt.plot`` and is also
    used as the fill colour of the band.
    """
    plt.plot(x, center, color + linestyle, label=label)
    plt.fill_between(x, center - spread, center + spread, color=color, alpha=0.1)


def main():
    """Save one comparison figure per dataset from the JSON files under ``results/``.

    For every dataset two result files are read: the run without data splitting
    (``scale_logp_split_none``) and the cross-fitted run (``scale_p_split_crossfit``).
    Each plotted series is reduced with mean and standard deviation along axis 0
    (presumably repeated runs -- confirm against the script that writes the files)
    and drawn against the ``n_trees`` values as a line with a shaded +/- one
    standard deviation band. The teacher result is reduced to a single mean/std
    and drawn as a horizontal dashed line. The figure is written to
    ``results/tabular_<dataset>_crossfit.pdf``.

    Raises:
        FileNotFoundError: if a results file for any dataset is missing.
        KeyError: if a results file lacks one of the plotted series.
    """
    datasets = ['adult', 'fico', 'higgs', 'magic', 'stumbleupon']
    for dataset in datasets:
        baseline = _load_results(Path(f'results/tabular_{dataset}_scale_logp_split_none_calibrate_False.json'))
        crossfit = _load_results(Path(f'results/tabular_{dataset}_scale_p_split_crossfit_calibrate_False.json'))

        n_trees = baseline['n_trees']
        metric = baseline['metric']

        # (values, colour, legend label) in drawing order. The files also hold
        # 'scratch_classifier', 'kd_relerr' and the cross-fit 'teacher_result';
        # those series are intentionally left out of this figure.
        curves = [
            (baseline['scratch_regressor'], 'b', 'From scratch'),
            (baseline['kd_mse'], 'c', 'KD, no cross-fitting'),
            (crossfit['kd_mse'], 'g', 'KD, w/ cross-fitting'),
            (crossfit['kd_boundfast'], 'r', r'$\gamma$-corrected KD, w/ crossfitting'),
        ]

        plt.figure(figsize=(6, 3.5))
        for values, color, label in curves:
            values = np.array(values)
            _plot_mean_band(n_trees, np.mean(values, 0), np.std(values, 0), color, label)

        # The teacher does not vary with the student's tree count, so its overall
        # mean/std is repeated at every x position to form a flat reference line.
        teacher_result = np.array(baseline['teacher_result'])
        n_points = len(n_trees)
        _plot_mean_band(n_trees, np.full(n_points, np.mean(teacher_result)),
                        np.full(n_points, np.std(teacher_result)),
                        'm', 'Teacher (500 trees)', linestyle='--')

        plt.xticks(n_trees)
        plt.ylabel(f'Test {metric}', fontsize=14)
        plt.xlabel("Student's number of trees", fontsize=14)
        plt.legend()
        plot_path = Path(f'results/tabular_{dataset}_crossfit.pdf')
        plt.savefig(str(plot_path), bbox_inches='tight')
        plt.close()

# Generate the figures only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
