import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib
import matplotlib.pyplot as plt


# QSAR runs (columns 'Step' and 'Value', consumed by load_data):
# FSS+FC, Uni+MetaFE, and FSS+MetaFE respectively.
data_fc, data_random, data_ours = map(pd.read_csv, (
    './run-shuffle_topk_0995_normalFC-tag-evaluate_f1.csv',
    './run-shuffle_topk_0995_MetaFE_random-tag-evaluate_f1.csv',
    './run-shuffle_topk_0995_Ours-tag-evaluate_f1.csv',
))

# The same three configurations on the Gisette dataset.
# NOTE(review): the " (1)" suffix on the random-run file looks like a duplicate
# download name — confirm it is the intended export.
gisette_fc, gisette_random, gisette_ours = map(pd.read_csv, (
    './gisette-run-shuffle_topk_0995_fc-tag-evaluate_f1.csv',
    './gisette-run-shuffle_topk_0995_MetaFE_random-tag-evaluate_f1 (1).csv',
    './gisette-run-shuffle_topk_0995_MetaFE-tag-evaluate_f1.csv',
))

# sns.lineplot(data=data_fc, x='Step', y='Value', hue="category", fill=True, palette="crest", alpha=.5, linewidth=1)


def load_data(pd_data, slice_: 'slice | None' = None):
    """Return the ``Step`` and ``Value`` columns of *pd_data* as NumPy arrays.

    Args:
        pd_data: DataFrame with ``Step`` and ``Value`` columns.
        slice_: Optional window over ``Step`` *values* (not row positions).
            Unlike Python slicing, both ends are inclusive. A ``None`` start
            or stop leaves that side unbounded; ``slice_.step`` is ignored.

    Returns:
        ``(steps, values)`` arrays, restricted to the window when one is given.
    """
    x = np.array(pd_data['Step'])
    y = np.array(pd_data['Value'])
    if slice_ is None:
        return x, y

    # Build the filter from explicit bounds only. The old code derived an
    # upper bound from x[-1], which drops rows when Step is unsorted and
    # crashes on an empty frame.
    keep = np.ones(len(x), dtype=bool)
    if slice_.start is not None:
        keep &= x >= slice_.start
    if slice_.stop is not None:
        keep &= x <= slice_.stop
    return x[keep], y[keep]


def cal_max(x_seq: list, alpha=0.9):
    """Return a peak-following (upper-envelope) smoothing of *x_seq*.

    ``out[0] = x_seq[0]`` and ``out[i] = max(x, alpha * out[i-1] + (1 - alpha) * x)``
    with ``x = x_seq[i]``. For ``0 <= alpha <= 1`` the blend is below ``x``
    exactly when ``x`` exceeds the previous output. The curve therefore jumps
    straight up to new highs and otherwise decays toward the samples like an
    EMA with factor *alpha*.

    Accepts a list or 1-D NumPy array. An empty input returns an empty array
    instead of raising IndexError.
    """
    # len() rather than truthiness: `not arr` is ambiguous for NumPy arrays.
    if len(x_seq) == 0:
        return np.array([])
    _res = [x_seq[0]]
    for _x in x_seq[1:]:
        _res.append(max(_x, alpha * _res[-1] + (1 - alpha) * _x))
    return np.array(_res)


def cal_min(x_seq: list, alpha=0.9):
    """Return a trough-following (lower-envelope) smoothing of *x_seq*.

    ``out[0] = x_seq[0]`` and ``out[i] = min(x, alpha * out[i-1] + (1 - alpha) * x)``
    with ``x = x_seq[i]``. For ``0 <= alpha <= 1`` the curve drops straight
    to new lows and otherwise rises toward the samples like an EMA with
    factor *alpha*.

    Accepts a list or 1-D NumPy array. An empty input returns an empty array
    instead of raising IndexError.
    """
    # len() rather than truthiness: `not arr` is ambiguous for NumPy arrays.
    if len(x_seq) == 0:
        return np.array([])
    _res = [x_seq[0]]
    for _x in x_seq[1:]:
        _res.append(min(_x, alpha * _res[-1] + (1 - alpha) * _x))
    return np.array(_res)


def cal_mean(x_seq: list, alpha=0.9):
    """Return the exponential moving average (EMA) of *x_seq*.

    ``out[0] = x_seq[0]`` and ``out[i] = alpha * out[i-1] + (1 - alpha) * x_seq[i]``.
    A larger *alpha* gives a smoother but more lagged curve.

    Accepts a list or 1-D NumPy array. An empty input returns an empty array
    instead of raising IndexError.
    """
    # len() rather than truthiness: `not arr` is ambiguous for NumPy arrays.
    if len(x_seq) == 0:
        return np.array([])
    _res = [x_seq[0]]
    for _x in x_seq[1:]:
        _res.append(alpha * _res[-1] + (1 - alpha) * _x)
    return np.array(_res)


def re_map(x_seq, y_seq, to_x_min=0, to_x_max=1):
    """Linearly rescale *x_seq* so it runs from *to_x_min* to *to_x_max*.

    The affine map ``f`` satisfies ``f(x_seq[0]) == to_x_min`` and
    ``f(x_seq[-1]) == to_x_max``; interior points keep their relative
    spacing. *y_seq* is passed through unchanged.

    Returns:
        ``(mapped_x, y_seq)`` where ``mapped_x`` is a float NumPy array.

    Raises:
        ValueError: if *x_seq* is empty, or its first and last elements are
            equal (the stretch factor would be a division by zero).
    """
    x = np.asarray(x_seq)
    if x.size == 0:
        raise ValueError('re_map needs a non-empty x_seq')
    span = x[-1] - x[0]
    if span == 0:
        # Previously this produced ZeroDivisionError or silent inf/nan.
        raise ValueError('re_map needs x_seq[0] != x_seq[-1] to define the mapping')
    k = (to_x_max - to_x_min) / span
    return k * x + (to_x_min - x[0] * k), y_seq


# EMA smoothing factor handed to cal_mean for every plotted curve.
alpha = 0.6

# Serif text throughout the figure. LaTeX rendering stays off; it can be
# enabled with rc('text', usetex=True).
matplotlib.rcParams.update({'font.family': 'serif'})
#
# plt.figure(figsize=(8, 4))
# plt.subplot(121)
#
# x, y = load_data(data_ours, slice(None, 1300))
# ours, = plt.plot(x, y, color='red', alpha=0.5, label='FSS+MetaFE')
# plt.plot(x, cal_mean(y, alpha), color='red')
#
# x, y = load_data(data_fc, slice(None, 1300))
# fc, = plt.plot(x, y, color='blue', alpha=0.5, label='FSS+FC')
# plt.plot(x, cal_mean(y, alpha), color='blue')
#
# x, y = load_data(data_random, slice(None, 1300))
# rand, = plt.plot(x, y, color='green', alpha=0.5, label='Random+MetaFE')
# plt.plot(x, cal_mean(y, alpha), color='green')
#
# plt.ylabel('F1')
# plt.xlabel('Step')
# plt.title('Pre-train Stage')
# plt.ylim(ymin=0.5)
# plt.legend(handles=[ours, fc, rand])
# plt.grid(axis='y')
#
#
# plt.subplot(122)
#
# x, y = load_data(data_ours, slice(1500, 8500))
# ours, = plt.plot(x, y, color='red', alpha=0.5, label='FSS+MetaFE')
# plt.plot(x, cal_mean(y, alpha), color='red')
#
# x, y = load_data(data_fc, slice(1500, 8500))
# fc, = plt.plot(x, y, color='blue', alpha=0.5, label='FSS+FC')
# plt.plot(x, cal_mean(y, alpha), color='blue')
#
# x, y = load_data(data_random, slice(1500, 8500))
# rand, = plt.plot(x, y, color='green', alpha=0.5, label='Random+MetaFE')
# plt.plot(x, cal_mean(y, alpha), color='green')
#
# # plt.ylabel('F1')
# plt.xlabel('Step')
# plt.title('Collaborative Training Stage')
# plt.ylim(ymin=0.5)
# plt.legend(handles=[ours, fc, rand])
# plt.grid(axis='y')
#
#
# plt.tight_layout()
# plt.show()

##############################
# https://blog.csdn.net/CD_Don/article/details/88070453

# Stage boundaries and axis extents, in (re-mapped) training steps.
QSAR_PRETRAIN_END = 1300      # end of the shaded "Pre-train" region, QSAR panel
QSAR_MAX_STEP = 8500          # QSAR data cut-off and right edge of its x-axis
SHADE_END = 8500              # right edge of the shaded stage regions, both panels
GISETTE_PRETRAIN_END = 2950   # Gisette stage boundary after re-mapping
GISETTE_SPLIT_STEP = 7450     # raw Step at which the Gisette MetaFE runs are split into stages
GISETTE_STAGE2_END = 8178     # re-mapped x of the last collaborative-stage point
GISETTE_X_MAX = 8145          # right edge of the Gisette x-axis
# NOTE(review): GISETTE_X_MAX < GISETTE_STAGE2_END, so the tail of the re-mapped
# MetaFE curves falls outside the axis — confirm this clipping is intended.


def _plot_run(ax, x, y, color, label):
    """Draw the raw curve (faint) plus its EMA (solid) on *ax*; return the EMA line."""
    # Here `alpha=0.5` is line transparency; the module-level `alpha` is the EMA factor.
    # The raw trace is unlabeled so it can never duplicate a legend entry.
    ax.plot(x, y, color=color, alpha=0.5, linewidth=1.5)
    smoothed, = ax.plot(x, cal_mean(y, alpha), label=label, color=color, linewidth=2)
    return smoothed


def _plot_two_stage(ax, frame, color, label):
    """Plot a Gisette run whose two training stages are re-mapped onto the shared x-axis.

    The first stage (Step <= GISETTE_SPLIT_STEP) keeps its starting step and is
    stretched to end at GISETTE_PRETRAIN_END. The second stage
    (Step >= GISETTE_SPLIT_STEP) spans GISETTE_PRETRAIN_END..GISETTE_STAGE2_END.
    Each stage is smoothed independently. Returns the first stage's EMA line
    for use as the legend handle.
    """
    x, y = load_data(frame, slice(None, GISETTE_SPLIT_STEP))
    x, y = re_map(x, y, x[0], GISETTE_PRETRAIN_END)
    handle = _plot_run(ax, x, y, color, label)
    x, y = load_data(frame, slice(GISETTE_SPLIT_STEP, None))
    x, y = re_map(x, y, GISETTE_PRETRAIN_END, GISETTE_STAGE2_END)
    _plot_run(ax, x, y, color, label)
    return handle


fig = plt.figure(figsize=(8, 5))

# --- Top panel: QSAR -------------------------------------------------------
# Keep explicit Axes handles instead of re-selecting panels later with
# plt.subplot(211)/(212), whose re-selection behaviour varied across Matplotlib 3.x.
ax_top = plt.subplot(211)
ax_top.axvspan(0, QSAR_PRETRAIN_END, alpha=0.8, color='lightgray')
ax_top.axvspan(QSAR_PRETRAIN_END, SHADE_END, alpha=0.8, color='lightyellow')

x, y = load_data(data_ours, slice(None, QSAR_MAX_STEP))
ours = _plot_run(ax_top, x, y, 'indianred', 'FSS+MetaFE')
x, y = load_data(data_fc, slice(None, QSAR_MAX_STEP))
fc = _plot_run(ax_top, x, y, 'midnightblue', 'FSS+FC')
x, y = load_data(data_random, slice(None, QSAR_MAX_STEP))
rand = _plot_run(ax_top, x, y, 'darkgoldenrod', 'Uni+MetaFE')

ax_top.set_ylabel('F1')
ax_top.set_title('QSAR Dataset', fontsize=14)
ax_top.set_ylim(bottom=0.45)
ax_top.set_xlim(0, QSAR_MAX_STEP)
ax_top.legend(handles=[ours, fc, rand])
ax_top.grid(axis='y')
ax_top.set_xticks([])

# --- Bottom panel: Gisette -------------------------------------------------
ax_bottom = plt.subplot(212)
ax_bottom.axvspan(0, GISETTE_PRETRAIN_END, alpha=0.8, color='lightgray')
ax_bottom.axvspan(GISETTE_PRETRAIN_END, SHADE_END, alpha=0.8, color='lightyellow')

# FSS+FC is drawn on its raw Step axis; only the MetaFE runs are re-mapped.
x, y = load_data(gisette_fc, slice(None, None))
fc = _plot_run(ax_bottom, x, y, 'midnightblue', 'FSS+FC')
ours = _plot_two_stage(ax_bottom, gisette_ours, 'indianred', 'FSS+MetaFE')
rand = _plot_two_stage(ax_bottom, gisette_random, 'darkgoldenrod', 'Uni+MetaFE')

ax_bottom.set_ylabel('F1')
ax_bottom.set_ylim(bottom=0.6)
ax_bottom.set_xlim(0, GISETTE_X_MAX)
ax_bottom.legend(handles=[ours, fc, rand])
ax_bottom.grid(axis='y')
ax_bottom.set_xticks([])
ax_bottom.set_title('Gisette Dataset', fontsize=14)
fig.tight_layout()

# Stage labels, positioned in axes-fraction coordinates (after layout, as before).
ax_top.annotate('Pre-train', (0.025, 0.05), xycoords='axes fraction', fontsize=13)
ax_top.annotate('Collaborative Learning', (0.4, 0.05), xycoords='axes fraction', fontsize=13)
ax_bottom.annotate('Pre-train', (0.125, 0.05), xycoords='axes fraction', fontsize=13)
ax_bottom.annotate('Collaborative Learning', (0.55, 0.05), xycoords='axes fraction', fontsize=13)
# NOTE(review): the figure is neither shown nor saved here (plt.show() was
# commented out) — presumably this runs in an interactive session; confirm.
#################################

# sns.set_style(style='darkgrid')
