__author__ = 'Qi'
# Created by Qi on 12/26/21.

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import re
from scipy.signal import savgol_filter
from scipy.ndimage.filters import gaussian_filter1d


#
# #
# # # # cifar10
# cifar10_data_len_train = 25500
# cifar10_data_len_test = 10000
# data_lens = [cifar10_data_len_train, cifar10_data_len_test]
# ylim_min = [60, 40]
# ylim_max = [101, 68]
# phases = ['Training', 'Testing']
# files = ['./csv_res/NCX-CIFAR10_train_multiple_mean_std_1.csv', './csv_res/NCX-CIFAR10_test_multiple_mean_std_1.csv']
# algnames = {'4':'FastDRO', '7':'SCDRO', '1': 'PG-SMD2'}
# colors = {'4': 'orange', '7':'green', '1':'royalblue'}
# for f in range(len(files)):
#     dat = pd.read_csv(files[f])
#
#     plt.figure()
#     for i in [4,1,7]:
#         #print(i)
#         iters = dat[dat.columns[0]] + 1
#         # print(iters[0])
#         start_ind, end_ind = 0, 120
#         #print(list(range(start_ind, end_ind, 2))) range(start_ind, end_ind, 2)
#         # print(dat[dat.columns[i]][start_ind:end_ind:3])
#         plt.plot(iters[start_ind:end_ind:3] * cifar10_data_len_train, dat[dat.columns[i]][start_ind:end_ind:3], linewidth=2,
#                   label=algnames[str(i)], color=colors[str(i)])
#
#         plt.fill_between(iters[start_ind:end_ind:3]*cifar10_data_len_train, dat[dat.columns[i+2]][start_ind:end_ind:3], dat[dat.columns[i+1]][start_ind:end_ind:3],  alpha=0.3, color=colors[str(i)])
#
#     plt.legend(loc = "lower right", fontsize=11)
#     plt.ylabel(phases[f] +" Accuracy (%)", fontname="Times New Roman", fontsize=15, fontweight='bold')
#     plt.xlabel("# of Processing Examples", fontname="Times New Roman", fontsize=15, fontweight='bold')
#     plt.yticks(fontname="Times New Roman", fontsize=15)
#     plt.xticks(fontname="Times New Roman", fontsize=15)
#     plt.title('CIFAR10-ST', fontname="Times New Roman", fontsize=20)
#     plt.ylim(ylim_min[f], ylim_max[f])
#     plt.grid()
#     plt.savefig("./figure_res/CIFAR10_" + phases[f] + "_SAMPLE_COMPLEX.png")
#     plt.show()

#
# # # cifar100
# cifar100_data_len_train = 30000
# cifar100_data_len_test = 10000
# ylim_min = [20, 20]
# ylim_max = [92, 60]
# data_lens = [cifar100_data_len_train, cifar100_data_len_test]
# phases = ['Training', 'Testing']
# files = ['./csv_res/NCX-CIFAR100_train_multiple_mean_std_1.csv', './csv_res/NCX-CIFAR100_test_multiple_mean_std_1.csv']
# dat = pd.read_csv(files[0])
# print("cifar100:", dat.columns)
# algnames = {'4':'FastDRO', '7':'SCDRO', '1': 'PG-SMD2'}
# colors = {'4': 'orange', '7':'green', '1':'royalblue'}
# for f in range(2):
#     dat = pd.read_csv(files[f])
#     # print(dat.shape)
#     # print("cifar100:", dat.columns)
#     # print(dat[['Step', 'alg: PDSGD - train acc1', 'alg: FastDRO - train acc1', 'alg: SCCMA - train acc1']].head())
#
#     plt.figure()
#     for i in [4, 1,7]:
#         # print(i)
#         iters = dat[dat.columns[0]] + 1
#         # print(iters[0])
#         start_ind, end_ind = 0, 120
#         # print(start_ind, end_ind )
#         # print(iters[start_ind:end_ind])
#         plt.plot(iters[start_ind:end_ind:2]*cifar100_data_len_train, dat[dat.columns[i]][start_ind:end_ind:2], linewidth=2,
#                   label=algnames[str(i)], color=colors[str(i)])
#
#         plt.fill_between(iters[start_ind:end_ind:2]*cifar100_data_len_train, dat[dat.columns[i+2]][start_ind:end_ind:2],dat[dat.columns[i+1]][start_ind:end_ind:2],  alpha=0.3, color=colors[str(i)])
#
#     plt.legend(loc = "lower right", fontsize=11)
#     plt.ylabel(phases[f] +" Accuracy (%)", fontname="Times New Roman", fontsize=15, fontweight='bold')
#     plt.xlabel("# of Processing Examples", fontname="Times New Roman", fontsize=15, fontweight='bold')
#     plt.yticks(fontname="Times New Roman", fontsize=15)
#     plt.xticks(fontname="Times New Roman", fontsize=15)
#     plt.title('CIFAR100-ST', fontname="Times New Roman", fontsize=20)
#     plt.ylim(ylim_min[f], ylim_max[f])
#     plt.grid()
#     plt.savefig("./figure_res/CIFAR100_" + phases[f] + "_SAMPLE_COMPLEX.png")
#     plt.show()

# #
# # # CX ImageNet-LT
imagenetlt_data_len_train = 115800
imagenetlt_data_len_test = 50000
data_lens = [imagenetlt_data_len_train, imagenetlt_data_len_test]
phases = ['Training', 'Testing']
files = ['./csv_res/CX-imagenet-LT_train_multirun_mean_std.csv', './csv_res/CX-imagenet-LT_test_multirun_mean_std.csv']
# Keys are the CSV column index of each algorithm's plotted curve (presumably
# the mean, given the "mean_std" file names); the shaded band is drawn between
# the next two columns, i + 1 and i + 2.
algnames = {'1': 'SPD', '4': 'FastDRO', '7': 'BSCDRO', '10': 'BASCDRO'}
colors = {'1': 'royalblue', '4': 'orange', '7': 'green', '10': 'red'}

# Plot every `step`-th row in [start_ind, end_ind).
start_ind, end_ind, step = 0, 61, 3

# NOTE(review): the shaded bands are widened by these hand-picked offsets
# (added to the upper bound, subtracted from the lower bound), so the bands do
# not reflect the spread stored in the CSVs. Confirm this is intended before
# publishing the figures. Each array needs one entry per plotted point
# (21 for rows 0..60 sampled with step 3).
band_pad_fastdro = np.array([0.5, 0.7] * 5 + [0.1] + [0.2, 0.2, 0.2] * 3 + [0.1])
band_pad_other = np.array([0.5, 0.3, 0.1] * 4 + [0.2, 0.1, 0.2] * 3)

# savefig() fails if the output directory does not exist.
os.makedirs("./figure_res", exist_ok=True)

for f in range(len(files)):
    dat = pd.read_csv(files[f])
    print(dat.shape)
    print(dat.columns)

    # x axis: the first CSV column scaled by the training-set size, giving the
    # "# of Processing Examples" label below. The same scale is used for both
    # phases.
    iters = dat[dat.columns[0]]
    x = iters[start_ind:end_ind:step] * imagenetlt_data_len_train

    plt.figure()
    for i in [4, 1, 7, 10]:
        var = band_pad_fastdro if i == 4 else band_pad_other
        if len(var) != len(x):
            raise ValueError(
                f"{files[f]}: {len(x)} sampled rows but {len(var)} band offsets "
                f"for column {dat.columns[i]!r}"
            )

        curve = dat[dat.columns[i]][start_ind:end_ind:step]
        band_hi = dat[dat.columns[i + 2]][start_ind:end_ind:step] + var
        band_lo = dat[dat.columns[i + 1]][start_ind:end_ind:step] - var

        plt.plot(x, curve, linewidth=2, label=algnames[str(i)], color=colors[str(i)])
        plt.fill_between(x, band_hi, band_lo, alpha=0.3, color=colors[str(i)])

    plt.legend(loc="lower right", fontsize=11)
    plt.ylabel(phases[f] + " Accuracy (%)", fontname="Times New Roman", fontsize=15, fontweight='bold')
    plt.xlabel('# of Processing Examples', fontname="Times New Roman", fontsize=15, fontweight='bold')
    plt.yticks(fontname="Times New Roman", fontsize=15)
    plt.xticks(fontname="Times New Roman", fontsize=15)
    plt.title('ImageNet-LT', fontname="Times New Roman", fontsize=20)
    plt.grid()
    plt.savefig("./figure_res/Imagenet-LT_" + phases[f] + "_SAMPLE_COMPLEX.png")
    plt.show()
    # Release the figure; non-interactive backends otherwise keep every
    # figure alive until the process exits.
    plt.close()
#
# #
# #
# #
# # CX-iNaturalist18
iNaturalist18_data_train_len = 437000
iNaturalist18_data_test_len = 30000
ylim_min = [40, 40]
ylim_max = [100, 58]
data_lens = [iNaturalist18_data_train_len, iNaturalist18_data_test_len]
phases = ['Training', 'Testing']
files = ['./csv_res/CX-iNaturalist18_train_multirun_mean_std.csv', './csv_res/CX-iNaturalist18_test_multirun_mean_std.csv']
# Keys are the CSV column index of each algorithm's plotted curve (presumably
# the mean, given the "mean_std" file names); the shaded band is drawn between
# the next two columns, i + 1 and i + 2.
algnames = {'1': 'FastDRO', '4': 'BSCDRO', '7': 'SPD', '10': 'BASCDRO'}
colors = {'1': 'orange', '4': 'green', '7': 'royalblue', '10': 'red'}

# Plot every `step`-th row in [start_ind, end_ind).
start_ind, end_ind, step = 0, 31, 2

# NOTE(review): the shaded bands are widened by these hand-picked offsets
# (added to the upper bound, subtracted from the lower bound; an extra 0.3 is
# added for the Training phase), so the bands do not reflect the spread stored
# in the CSVs. Confirm this is intended before publishing the figures. One
# entry is needed per plotted point (16 for rows 0..30 sampled with step 2).
band_pad = np.array([0.15, 0.3] * 4 + [0.1, 0.1] * 4)

# savefig() fails if the output directory does not exist.
os.makedirs("./figure_res", exist_ok=True)

for f in range(len(files)):
    dat = pd.read_csv(files[f])

    # x axis: the first CSV column scaled by the training-set size, giving the
    # "# of Processing Examples" label below. The same scale is used for both
    # phases.
    iters = dat[dat.columns[0]]
    x = iters[start_ind:end_ind:step] * iNaturalist18_data_train_len

    var = band_pad + 0.3 if f == 0 else band_pad
    if len(var) != len(x):
        raise ValueError(f"{files[f]}: {len(x)} sampled rows but {len(var)} band offsets")

    # A fresh figure per phase. If plt.show() does not close the figure (e.g.
    # with a non-interactive backend), a single shared figure would get the
    # Testing curves drawn over the Training ones, with duplicated legend
    # entries in the saved PNG.
    plt.figure()
    for i in [1, 7, 4, 10]:
        curve = dat[dat.columns[i]][start_ind:end_ind:step]
        band_hi = dat[dat.columns[i + 2]][start_ind:end_ind:step] + var
        band_lo = dat[dat.columns[i + 1]][start_ind:end_ind:step] - var

        plt.plot(x, curve, linewidth=2, label=algnames[str(i)], color=colors[str(i)])
        plt.fill_between(x, band_hi, band_lo, alpha=0.3, color=colors[str(i)])

    plt.legend(loc="lower right", fontsize=11)
    plt.ylabel(phases[f] + " Accuracy (%)", fontname="Times New Roman", fontsize=15, fontweight='bold')
    plt.xlabel('# of Processing Examples', fontname="Times New Roman", fontsize=15, fontweight='bold')
    plt.yticks(fontname="Times New Roman", fontsize=15)
    plt.xticks(fontname="Times New Roman", fontsize=15)
    plt.title('iNaturalist2018', fontname="Times New Roman", fontsize=20)
    plt.grid()
    plt.ylim(ylim_min[f], ylim_max[f])
    plt.savefig("./figure_res/iNaturalist18_" + phases[f] + "_SAMPLE_COMPLEX.png")
    plt.show()
    # Release the figure so the next phase cannot draw onto it.
    plt.close()




# # cifar10
# cifar10_data_len_train = 25500
# cifar10_data_len_test = 10000
# data_lens = [cifar10_data_len_train, cifar10_data_len_test]
# phases = ['Training', 'Testing']
# files = ['./csv_res/CX-CIFAR10_train_multiple_mean_std.csv', './csv_res/CX-CIFAR10_test_multiple_mean_std.csv']
# algnames = {'4':'FastDRO', '7':'BSCDRO', '1': 'SPD', '10':'BASCDRO'}
# colors = {'4': 'orange', '7':'green', '1':'royalblue', '10':'red'}
# ylims = [88, 85]
# for f in range(len(files)):
#     dat = pd.read_csv(files[f])
#
#     plt.figure()
#     for i in [4,1,7,10]:
#         print(i)
#         iters = dat[dat.columns[0]] + 1
#         print(iters)
#         # print(iters[0])
#         start_ind, end_ind = 0, 550
#
#         # if i == 4:
#         #     plt.plot(np.array(iters)[range(start_ind, end_ind, 4)]*128,
#         #              np.array(dat[dat.columns[i]])[range(start_ind, end_ind, 42)]-0.1, linewidth=2,
#         #              label=algnames[str(i)], color=colors[str(i)])
#         #     plt.fill_between(np.array(iters)[range(start_ind,end_ind,4)]*128, np.array(dat[dat.columns[i+2]])[range(start_ind,end_ind,4)]-0.085, np.array(dat[dat.columns[i+1]])[range(start_ind,end_ind,4)]-0.115,  alpha=0.3, color=colors[str(i)])
#         # else:
#         plt.plot(np.array(iters)[range(start_ind, end_ind, 4)]*128,
#                      np.array(dat[dat.columns[i]])[range(start_ind, end_ind, 4)], linewidth=2,
#                      label=algnames[str(i)], color=colors[str(i)])
#         plt.fill_between(np.array(iters)[range(start_ind, end_ind, 4)]*128,
#                              np.array(dat[dat.columns[i + 2]])[range(start_ind, end_ind, 4)]+0.05,
#                              np.array(dat[dat.columns[i + 1]])[range(start_ind, end_ind, 4)]-0.05, alpha=0.3,
#                              color=colors[str(i)])
#
#     plt.legend(loc = "lower right", fontsize=11)
#     plt.ylabel(phases[f] +" Accuracy (%)", fontname="Times New Roman", fontsize=15, fontweight='bold')
#     plt.xlabel("# of Processing Examples", fontname="Times New Roman", fontsize=15, fontweight='bold')
#     plt.yticks(fontname="Times New Roman", fontsize=15)
#     plt.xticks(fontname="Times New Roman", fontsize=15)
#     plt.title('CIFAR10-ST', fontname="Times New Roman", fontsize=20)
#     plt.ticklabel_format(axis='x', style='sci', scilimits=(-2,2))
#     plt.grid()
#     plt.ylim(78, ylims[f])
#     plt.xlim(0, 76500)
#     # plt.savefig("./figure_res/CX-CIFAR10_" + phases[f] + "_SAMPLE_COMPLEX.png")
#     plt.show()

#
# cifar100_data_len_train = 30000
# cifar100_data_len_test = 10000
# data_lens = [cifar100_data_len_train, cifar100_data_len_test]
# phases = ['Training', 'Testing']
# files = ['./csv_res/CX-CIFAR100_train_multiple_mean_std_1.csv', './csv_res/CX-CIFAR100_test_multiple_mean_std_1.csv']
# algnames = {'4':'FastDRO', '7':'BSCDRO', '1': 'SPD', '10':'BASCDRO'}
# colors = {'4': 'orange', '7':'green', '1':'royalblue', '10':'red'}
# ylimmax = [63, 56]
# for f in range(len(files)):
#     dat = pd.read_csv(files[f])
#
#     plt.figure()
#     for i in [4,1,7,10]:
#         print(i)
#         iters = dat[dat.columns[0]] + 1
#         print(iters)
#         # print(iters[0])
#         start_ind, end_ind = 0, 1174
#         # if i == 1:
#         #     plt.plot(np.array(iters)[range(start_ind, end_ind, 2)]*76,
#         #              np.array(dat[dat.columns[i]])[range(start_ind, end_ind, 2)]-0.4, linewidth=2,
#         #              label=algnames[str(i)], color=colors[str(i)])
#         #     plt.fill_between(np.array(iters)[range(start_ind, end_ind, 2)]*76,
#         #                      np.array(dat[dat.columns[i + 2]])[range(start_ind, end_ind, 2)]-0.4,
#         #                      np.array(dat[dat.columns[i + 1]])[range(start_ind, end_ind, 2)]-0.4, alpha=0.3,
#         #                      color=colors[str(i)])
#         # else:
#         plt.plot(np.array(iters)[range(start_ind,end_ind,8)]*78, np.array(dat[dat.columns[i]])[range(start_ind,end_ind,8)], linewidth=2,
#                   label=algnames[str(i)], color=colors[str(i)])
#
#         plt.fill_between(np.array(iters)[range(start_ind, end_ind,8)]*78,
#                              np.array(dat[dat.columns[i + 2]])[range(start_ind, end_ind, 8)],
#                              np.array(dat[dat.columns[i + 1]])[range(start_ind, end_ind, 8)], alpha=0.3,
#                              color=colors[str(i)])
#
#     plt.legend(loc = "lower right", fontsize=11)
#     plt.ylabel(phases[f] +" Accuracy (%)", fontname="Times New Roman", fontsize=15, fontweight='bold')
#     plt.xlabel("# of Processing Examples", fontname="Times New Roman", fontsize=15, fontweight='bold')
#     plt.yticks(fontname="Times New Roman", fontsize=15)
#     plt.xticks(fontname="Times New Roman", fontsize=15)
#     plt.ticklabel_format(axis='x', style='sci', scilimits=(-2,2))
#     plt.title('CIFAR100-ST', fontname="Times New Roman", fontsize=20)
#     plt.xlim(0, 90000)
#     plt.ylim(40, ylimmax[f])
#     plt.grid()
#     plt.savefig("./figure_res/CX-CIFAR100_" + phases[f] + "_SAMPLE_COMPLEX.png")
#     plt.show()

#
# # ImageNet-LT
# imagenetlt_data_len_train = 115800
# imagenetlt_data_len_test = 50000
# data_lens = [imagenetlt_data_len_train, imagenetlt_data_len_test]
# phases = ['Training', 'Testing']
# files = ['./csv_res/NCX-imagenet-LT_train_multirun_mean_std.csv', './csv_res/NCX-imagenet-LT_test_multirun_mean_std.csv']
# algnames = {'1':'PG-SMD2', '4':'FastDRO', '7': 'SCDRO'}
# colors = {'1': 'royalblue', '4':'orange', '7':'green'}
#
# for f in range(len(files)):
#     dat = pd.read_csv(files[f])
#     print(dat.shape)
#     print(dat.columns)
#     # print(dat[['Step', 'alg: PDSGD - train acc1', 'alg: FastDRO - train acc1', 'alg: SCCMA - train acc1']].head())
#
#     plt.figure()
#     for i in [4,1,7]:
#
#         iters = dat[dat.columns[0]] + 1
#         # print(iters[0])
#         start_ind, end_ind = 0, 30
#
#         if f == 0:
#             var = [0.15, 0.1, 0.05, 0.13, 0.1, 0.05, 0.15, 0.1, 0.05, 0.1, 0.05, 0.05, 0.05, 0.05, 0.05]
#         else:
#             var = [0.2, 0.1, 0.05, 0.15, 0.1, 0.05, 0.2, 0.1, 0.05, 0.1, 0.05, 0.05, 0.03, 0.03, 0.03]
#
#         plt.plot(iters[start_ind:end_ind:2]*imagenetlt_data_len_train, dat[dat.columns[i]][start_ind:end_ind:2], linewidth=2, label=algnames[str(i)], color=colors[str(i)])
#         plt.fill_between(iters[start_ind:end_ind:2]*imagenetlt_data_len_train,dat[dat.columns[i+2]][start_ind:end_ind:2]+var, dat[dat.columns[i+1]][start_ind:end_ind:2]-var, alpha=0.3,
#                          color=colors[str(i)])
#
#     plt.legend(loc = "lower right", fontsize=11)
#     plt.ylabel(phases[f] +" Accuracy (%)", fontname="Times New Roman", fontsize=15, fontweight='bold')
#     plt.xlabel('# of Processing Examples', fontname="Times New Roman", fontsize=15, fontweight='bold')
#     plt.yticks(fontname="Times New Roman", fontsize=15)
#     plt.xticks(fontname="Times New Roman", fontsize=15)
#     plt.title('ImageNet-LT', fontname="Times New Roman", fontsize=20)
#     plt.grid()
#     plt.savefig("./figure_res/NCX-Imagenet-LT_"+ phases[f] + "_SAMPLE_COMPLEX.png")
#     plt.show()

#
#
#
# # iNaturalist18
# iNaturalist18_data_len_train = 437000
# iNaturalist18_data_len_test = 30000
# data_lens = [iNaturalist18_data_len_train, iNaturalist18_data_len_test]
# phases = ['Training', 'Testing']
# files = ['./csv_res/NCX-iNaturalist18_train_mean_std.csv', './csv_res/NCX-iNaturalist18_test_mean_std.csv']
# algnames = {'1':'PG-SMD2', '4':'FastDRO', '7': 'SCDRO'}
# colors = {'1': 'royalblue', '4':'orange', '7':'green'}
#
# for f in range(len(files)):
#     dat = pd.read_csv(files[f])
#     print(dat.shape)
#     print(dat.columns)
#     # print(dat[['Step', 'alg: PDSGD - train acc1', 'alg: FastDRO - train acc1', 'alg: SCCMA - train acc1']].head())
#
#     plt.figure()
#     for i in [4,1,7]:
#
#         iters = dat[dat.columns[0]] + 1
#         # print(iters[0])
#         start_ind, end_ind = 0, 30
#
#         if f == 0:
#             var = [0.15, 0.1, 0.05, 0.13, 0.1, 0.05, 0.15, 0.1, 0.05, 0.1, 0.05, 0.05, 0.05, 0.05, 0.05]
#         else:
#             var = [0.2, 0.1, 0.05, 0.15, 0.1, 0.05, 0.2, 0.1, 0.05, 0.1, 0.05, 0.05, 0.03, 0.03, 0.03]
#
#
#         plt.plot(iters[start_ind:end_ind:2]*iNaturalist18_data_len_train, dat[dat.columns[i]][start_ind:end_ind:2], linewidth=2, label=algnames[str(i)], color=colors[str(i)])
#         plt.fill_between(iters[start_ind:end_ind:2]*iNaturalist18_data_len_train,dat[dat.columns[i+2]][start_ind:end_ind:2]+var, dat[dat.columns[i+1]][start_ind:end_ind:2]-var, alpha=0.3,
#                          color=colors[str(i)])
#
#     plt.legend(loc = "lower right", fontsize=11)
#     plt.ylabel(phases[f] +" Accuracy (%)", fontname="Times New Roman", fontsize=15, fontweight='bold')
#     plt.xlabel('# of Processing Examples', fontname="Times New Roman", fontsize=15, fontweight='bold')
#     plt.yticks(fontname="Times New Roman", fontsize=15)
#     plt.xticks(fontname="Times New Roman", fontsize=15)
#     plt.title('iNaturalist2018', fontname="Times New Roman", fontsize=20)
#     plt.grid()
#     plt.savefig("./figure_res/NCX-iNaturalist18_"+ phases[f] + "_SAMPLE_COMPLEX.png")
#     plt.show()
#
