import numpy as np
import os
import pickle
import matplotlib.pyplot as plt

# Directory holding the per-run '<split>_<data>_<model>_<sparsity>.pickle'
# inputs; the summary CSV files are written here as well.
result_dir = "/Users/bowenlei/Desktop/study/research/Calibration/code/rigl_make/stats/"
# Directory the generated PNG figures are saved into.
save_dir = "/Users/bowenlei/Desktop/study/research/Calibration/code/rigl_make/img/"

def _sparsity_tag(sp):
    """Return the file-name tag for sparsity ``sp``.

    The decimal text of ``sp`` is cut to at most six characters and the
    decimal point is replaced by an underscore, e.g. ``0.995 -> '0_995'``
    and ``0 -> '0'``.
    """
    whole, sep, frac = str(sp)[:6].partition('.')
    return whole + '_' + frac if sep else whole


def _load_stats(split, data, model, tag):
    """Load ``result_dir/<split>_<data>_<model>_<tag>.pickle`` and return its content.

    The file handle is closed as soon as the object has been read.
    """
    path = result_dir + split + "_" + data + "_" + model + '_' + tag + '.pickle'
    # NOTE: pickle must only be used on trusted files; these are expected to
    # be results written locally by the training runs.
    with open(path, 'rb') as fh:
        return pickle.load(fh)


def _plot_train_test(idx, sp_seq, test_vals, train_vals, diff_vals, label, title, out_path):
    """Save a two-panel figure: test/train curves on top, their difference below."""
    fig = plt.figure(figsize=(6, 8))

    plt.subplot(2, 1, 1)
    plt.plot(idx, test_vals, marker='o', label=label + ' test')
    plt.plot(idx, train_vals, marker='o', label=label + ' train')
    plt.xticks(idx, sp_seq)
    plt.legend()
    plt.title(title)

    plt.subplot(2, 1, 2)
    plt.plot(idx, diff_vals, marker='o', label=label + ' diff')
    plt.xticks(idx, sp_seq)
    plt.legend()
    plt.title(title)

    plt.savefig(out_path)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def organize_stats(data, model='resnet50'):
    """Collect per-sparsity train/test statistics for ``data`` and plot them.

    For every sparsity level the pickles
    ``train_<data>_<model>_<tag>.pickle`` and ``test_<data>_<model>_<tag>.pickle``
    are read from ``result_dir``; each is unpacked as the 4-tuple
    ``(loss, top-1 accuracy, top-5 accuracy, ece)``.

    Outputs:
      * ``result_dir/<data>_<model>.csv`` -- a 7 x len(sp_seq) table rounded to
        4 decimals. Rows: test ECE, test top-1, train top-1, test top-5,
        train top-5, test loss, train loss.
      * ``save_dir/_<data>_top1_acc.png``, ``_<data>_top5_acc.png``,
        ``_<data>_loss.png`` -- test/train curves plus their gap.
      * ``save_dir/_<data>_ece.png`` -- test ECE per sparsity level.

    Args:
        data: Dataset name used in file names and plot titles (e.g. 'CIFAR10').
        model: Model name used in file names; defaults to 'resnet50'.

    Raises:
        FileNotFoundError: if one of the expected pickle files is missing.
    """
    sp_seq = [0, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99, 0.995, 0.9975]
    # One column per sparsity level (sized from sp_seq rather than hard-coded).
    outputs = np.zeros((7, len(sp_seq)))

    for col, sp in enumerate(sp_seq):
        tag = _sparsity_tag(sp)
        # The train-split ECE is stored in the pickle but not used here.
        loss_tr, top1_tr, top5_tr, _ = _load_stats('train', data, model, tag)
        loss_te, top1_te, top5_te, ece_te = _load_stats('test', data, model, tag)

        column = (ece_te, top1_te, top1_tr, top5_te, top5_tr, loss_te, loss_tr)
        for row, value in enumerate(column):
            outputs[row, col] = value

        print(col)  # progress indicator

    outputs = np.round(outputs, 4)
    np.savetxt(result_dir + data + '_' + model + '.csv', outputs, delimiter=",")

    idx = list(range(len(sp_seq)))
    ece, acc_top1_te, acc_top1_tr, acc_top5_te, acc_top5_tr, loss_te, loss_tr = outputs

    # Accuracy gap is train - test; loss gap is test - train.
    _plot_train_test(idx, sp_seq, acc_top1_te, acc_top1_tr,
                     acc_top1_tr - acc_top1_te,
                     'top 1 acc', data + ' top 1 accuracy',
                     save_dir + '_' + data + '_top1_acc.png')
    _plot_train_test(idx, sp_seq, acc_top5_te, acc_top5_tr,
                     acc_top5_tr - acc_top5_te,
                     'top 5 acc', data + ' top 5 accuracy',
                     save_dir + '_' + data + '_top5_acc.png')
    _plot_train_test(idx, sp_seq, loss_te, loss_tr,
                     loss_te - loss_tr,
                     'loss', data + ' loss',
                     save_dir + '_' + data + '_loss.png')

    fig = plt.figure(figsize=(6, 4))
    plt.plot(idx, ece, marker='o', label='ece')
    plt.xticks(idx, sp_seq)
    # Title follows the dataset actually plotted (it used to always read
    # 'CIFAR-10 ECE', even for CIFAR100).
    plt.title(data + ' ECE')
    plt.savefig(save_dir + '_' + data + '_ece.png')
    plt.close(fig)


# Build the summary CSV and all figures for both datasets.
organize_stats("CIFAR10")
organize_stats("CIFAR100")


def _sparsity_tag(sp):
    """Return the file-name tag for sparsity ``sp``.

    The decimal text of ``sp`` is cut to at most six characters and the
    decimal point is replaced by an underscore, e.g. ``0.995 -> '0_995'``
    and ``0 -> '0'``.
    """
    whole, sep, frac = str(sp)[:6].partition('.')
    return whole + '_' + frac if sep else whole


def _load_stats(split, data, model, tag):
    """Load ``result_dir/<split>_<data>_<model>_<tag>.pickle`` and return its content.

    The file handle is closed as soon as the object has been read.
    """
    path = result_dir + split + "_" + data + "_" + model + '_' + tag + '.pickle'
    # NOTE: pickle must only be used on trusted files; these are expected to
    # be results written locally by the training runs.
    with open(path, 'rb') as fh:
        return pickle.load(fh)


def organize_stats(data, model='resnet50'):
    """Rebuild the stats CSV for ``data`` and plot ECE including sparsity 0.999.

    This redefinition replaces the earlier ``organize_stats`` in the module.
    Statistics are read from the train/test pickles in ``result_dir`` for
    every sparsity level except 0.999, whose ECE is supplied as a constant.

    Outputs:
      * ``result_dir/<data>_<model>.csv`` -- 7 x 10 table rounded to 4
        decimals. Rows: test ECE, test top-1, train top-1, test top-5,
        train top-5, test loss, train loss (0.999 is not included).
      * ``save_dir/_<data>_ece_new.png`` -- test ECE with the x-axis running
        from the highest sparsity down to 0.

    Args:
        data: Dataset name used in file names and the plot title.
        model: Model name used in file names; defaults to 'resnet50'.

    Raises:
        FileNotFoundError: if one of the expected pickle files is missing.
    """
    sp_seq = [0, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99, 0.995, 0.9975, 0.999]

    # No pickle is read for this sparsity; its ECE is appended by hand below.
    # NOTE(review): the same constant is used whatever `data` is -- confirm
    # it belongs to the dataset being plotted.
    manual_sp = 0.999
    manual_ece = 0.0490

    loaded_seq = [sp for sp in sp_seq if sp != manual_sp]
    outputs = np.zeros((7, len(loaded_seq)))

    for col, sp in enumerate(loaded_seq):
        tag = _sparsity_tag(sp)
        # The train-split ECE is stored in the pickle but not used here.
        loss_tr, top1_tr, top5_tr, _ = _load_stats('train', data, model, tag)
        loss_te, top1_te, top5_te, ece_te = _load_stats('test', data, model, tag)

        column = (ece_te, top1_te, top1_tr, top5_te, top5_tr, loss_te, loss_tr)
        for row, value in enumerate(column):
            outputs[row, col] = value

        print(col)  # progress indicator

    outputs = np.round(outputs, 4)
    np.savetxt(result_dir + data + '_' + model + '.csv', outputs, delimiter=",")

    idx = list(range(len(sp_seq)))
    # The manual value goes last, matching 0.999 being the last entry of sp_seq.
    ece = list(outputs[0])
    ece.append(manual_ece)

    fig = plt.figure(figsize=(6, 4))
    # Both the values and the tick labels are reversed so they stay aligned.
    plt.plot(idx, ece[::-1], marker='o', label='ece')
    plt.xticks(idx, sp_seq[::-1])
    # Title follows the dataset actually plotted (it used to always read
    # 'CIFAR-10 ECE', even when called with CIFAR100).
    plt.title(data + ' ECE')
    plt.savefig(save_dir + '_' + data + '_ece_new.png')
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

# Run the redefined organize_stats above for CIFAR100 (writes the
# '_ece_new' figure and rewrites the CSV).
organize_stats("CIFAR100")