import matplotlib.pyplot as plt
import numpy as np


# Tableau 20 colour palette, listed as 0-255 RGB triples.
tableau20 = [(31, 119, 180), (174, 199, 232), (255, 127, 14), (255, 187, 120),
             (44, 160, 44), (152, 223, 138), (214, 39, 40), (255, 152, 150),
             (148, 103, 189), (197, 176, 213), (140, 86, 75), (196, 156, 148),
             (227, 119, 194), (247, 182, 210), (127, 127, 127), (199, 199, 199),
             (188, 189, 34), (219, 219, 141), (23, 190, 207), (158, 218, 229)]
# Rescale every channel to [0, 1], the float RGB form matplotlib accepts for `color=`.
# A comprehension avoids the index loop and does not leak loop variables into the module.
tableau20 = [(r / 255., g / 255., b / 255.) for r, g, b in tableau20]


def get_value(value_str):
	"""Return the numeric part of one 'name: value' log field as a float.

	Raises ValueError if the field has no ':' or if the text after the ':'
	does not parse as a float. Surrounding whitespace is tolerated by float().
	"""
	_label, _sep, number_text = value_str.partition(':')
	return float(number_text)


def get_bound(eta, bound_1, bound_2):
	"""Combine two bound terms with eta-dependent weights.

	Returns eta / (1 - e^{-eta}) * bound_1 + 1 / (1 - e^{-eta}) * bound_2.
	eta must be nonzero: at eta == 0 the shared denominator is zero.
	"""
	denom = 1. - np.exp(-eta)
	weight_1 = eta / denom
	weight_2 = 1. / denom
	return weight_1 * bound_1 + weight_2 * bound_2


def get_info(fname, T, eta=1.5):
	"""Parse one training log into per-record step, accuracies and clipped bound.

	The first line of the file is skipped (treated as a header) and blank lines
	are ignored. Every other line is a ', '-separated list of 'name: value'
	fields:
	- field 0 is the step;
	- field 1 is the train accuracy;
	- field 2 is the test accuracy;
	- the third- and second-to-last fields are the two terms that get_bound combines.

	Parsing stops at the first record whose step exceeds T, and that record is
	excluded from every output series.

	Args:
		fname: path to the log file.
		T: last step to include.
		eta: weight passed to get_bound (defaults to 1.5).

	Returns:
		dict with keys 't', 'tr_acc', 'te_acc' (lists of floats) and 'bound'
		(numpy array clipped to [0, 1]). All four have the same length.
	"""
	with open(fname, 'r') as f:
		# splitlines() keeps the final record even when the file has no trailing newline.
		records = f.read().splitlines()[1:]

	t, tr_acc, te_acc, bound = [], [], [], []
	for line in records:
		if not line.strip():
			continue
		items = line.split(', ')
		step = get_value(items[0])
		# Check the cut-off before appending anything so all series stay equally long.
		if step > T:
			break
		t.append(step)
		tr_acc.append(get_value(items[1]))
		te_acc.append(get_value(items[2]))
		bound_1 = get_value(items[-3])
		bound_2 = get_value(items[-2])
		bound.append(get_bound(eta, bound_1, bound_2))

	bound = np.clip(bound, 0, 1)
	return {'t': t, 'tr_acc': tr_acc, 'te_acc': te_acc, 'bound': bound}


def get_infos(k, task='', T=1500):
	"""Aggregate k runs of one task into per-record mean and std.

	Reads log/{task}/{i}.out for i in 0..k-1 via get_info (cut off at step T).

	Returns:
		dict with keys 't', 'tr_acc', 'te_acc' and 'bound'. Each maps to a
		(mean, std) pair of arrays computed across runs (axis 0).

	Runs are stacked with np.array, so presumably every run must yield the
	same number of records. TODO confirm for runs of differing length.
	"""
	runs = [get_info(f'log/{task}/{i}.out', T) for i in range(k)]
	summary = {}
	for key in ('t', 'tr_acc', 'te_acc', 'bound'):
		stacked = np.array([run[key] for run in runs])
		summary[key] = (stacked.mean(axis=0), stacked.std(axis=0))
	return summary


def plot_test_bound(infos, x_min=0, x_max=100, err_bar=False):
	"""Plot the mean bound against the mean test error over a window of records.

	Only records with index in [x_min, x_max) are shown. The last values in
	that window (bound, train accuracy, test accuracy) are printed first.

	Args:
		infos: dict as returned by get_infos; each key maps to (mean, std).
		x_min, x_max: slice bounds on the record index, not on the step value.
		err_bar: if True, shade a +/-1 std band around each curve.
	"""
	window = slice(x_min, x_max)
	x = infos['t'][0][window]
	# Use the same window for train accuracy so every printed value refers to
	# the same record.
	tr_acc_mean = infos['tr_acc'][0][window]
	te_err_mean = 1. - infos['te_acc'][0][window]
	te_bound_mean = infos['bound'][0][window]
	print('bound:', te_bound_mean[-1], 'tr_acc:', tr_acc_mean[-1], 'te_acc:', 1. - te_err_mean[-1])

	plt.figure()
	plt.plot(x, te_bound_mean, linewidth=2.0, color=tableau20[6])
	plt.plot(x, te_err_mean, linewidth=2.0, color=tableau20[9])

	if err_bar:
		# std(1 - acc) == std(acc), so the accuracy std is the error std.
		te_err_std = infos['te_acc'][1][window]
		te_bound_std = infos['bound'][1][window]
		plt.fill_between(x, te_bound_mean - te_bound_std, te_bound_mean + te_bound_std, color=tableau20[6], alpha=0.1)
		plt.fill_between(x, te_err_mean - te_err_std, te_err_mean + te_err_std, color=tableau20[9], alpha=0.1)

	plt.xlabel('steps', fontsize=20)
	plt.ylim(bottom=0)
	plt.tick_params(labelsize=20)
	plt.legend(['our bound', 'test error'], loc='upper left', fontsize=20)
	plt.grid(axis='both', linestyle=':')
	plt.show()
	plt.close()



def get_mgrad(k, task="mnist"):
	"""Aggregate k gradient-difference logs into per-record mean and std.

	Reads log/{task}/mgrad/{i}.out for i in 0..k-1. No header line is skipped;
	blank lines are ignored. Each remaining line is a ', '-separated list of
	'name: value' fields: field 0 is m, and field 1 is sum_grad2 (plotted by
	plot_mgrad as a sum of squared gradient differences).

	Returns:
		dict mapping 'm' and 'sum_grad2' to (mean, std) numpy arrays computed
		across the k runs (axis 0).
	"""
	infos = {'m': [], 'sum_grad2': []}
	for i in range(k):
		fname = f'log/{task}/mgrad/{i}.out'
		# The context manager closes the file. splitlines() keeps a final record
		# that lacks a trailing newline.
		with open(fname, 'r') as f:
			lines = [line for line in f.read().splitlines() if line.strip()]
		ms, sum_grad2s = [], []
		for line in lines:
			items = line.split(', ')
			ms.append(get_value(items[0]))
			sum_grad2s.append(get_value(items[1]))
		infos['m'].append(ms)
		infos['sum_grad2'].append(sum_grad2s)

	for key in infos:
		infos[key] = (np.mean(infos[key], axis=0), np.std(infos[key], axis=0))
	return infos


def plot_mgrad(infos):
	"""Plot mean sum_grad2 against m, with a +/-1 std band.

	`infos` is the dict returned by get_mgrad: each key maps to (mean, std).
	"""
	m_axis = infos['m'][0]
	grad_mean, grad_std = infos['sum_grad2']
	lower = grad_mean - grad_std
	upper = grad_mean + grad_std

	plt.figure()
	plt.plot(m_axis, grad_mean, linewidth=3.0, color=tableau20[1])
	# NOTE(review): the band uses tableau20[5] while the line uses tableau20[1];
	# the other plots use one colour per curve. Confirm this is intentional.
	plt.fill_between(m_axis, lower, upper, color=tableau20[5], alpha=0.2)

	plt.xlabel('m = |J|', fontsize=20)
	plt.tick_params(labelsize=20)
	plt.legend(['$\\sum_t \\left\\| \\nabla f(w_t, S) - \\nabla f(w_t, S_J)\\right\\|^2$'], fontsize=18.5)
	plt.grid(axis='both', linestyle=':')
	plt.show()
	plt.close()


def plot_tr_te(infos, x_min=0, err_bar=True):
	"""Plot mean train and test error (1 - accuracy) from record index x_min on.

	`infos` is a get_infos result (each key maps to (mean, std)). With err_bar,
	each curve gets a +/-1 std band in its own colour.
	"""
	steps = infos['t'][0][x_min:]
	palette = {'tr_acc': tableau20[5], 'te_acc': tableau20[9]}
	err_means = {key: 1. - infos[key][0][x_min:] for key in palette}

	plt.figure()
	for key, colour in palette.items():
		plt.plot(steps, err_means[key], linewidth=2.0, color=colour)

	if err_bar:
		for key, colour in palette.items():
			err_std = infos[key][1][x_min:]
			plt.fill_between(steps, err_means[key] - err_std, err_means[key] + err_std, color=colour, alpha=0.1)

	plt.xlabel('steps', fontsize=20)
	plt.tick_params(labelsize=20)
	plt.legend(['train error', 'test error'], fontsize=20)
	plt.grid(axis='both', linestyle=':')
	plt.show()
	plt.close()


def plot_fgd_gd_detail(fgd_infos, gd_infos, x_min, err_bar=True):
	"""Overlay FGD and GD mean train/test error curves from record index x_min on.

	The x axis is taken from fgd_infos, so both runs are assumed to share it.
	With err_bar, each curve gets a +/-1 std band (alpha 0.15) in its colour.
	"""
	steps = fgd_infos['t'][0][x_min:]
	# name -> ((mean, std) accuracy pair, line/band colour)
	curves = {
		'fgd_tr': (fgd_infos['tr_acc'], tableau20[5]),
		'fgd_te': (fgd_infos['te_acc'], tableau20[9]),
		'gd_tr': (gd_infos['tr_acc'], tableau20[6]),
		'gd_te': (gd_infos['te_acc'], tableau20[10]),
	}
	err_mean = {name: 1. - stats[0][x_min:] for name, (stats, _) in curves.items()}

	plt.figure()
	# The label-only legend call matches labels to artists in creation order,
	# so the lines are drawn first and in legend order.
	for name in ('fgd_te', 'gd_te', 'fgd_tr', 'gd_tr'):
		plt.plot(steps, err_mean[name], linewidth=2.0, color=curves[name][1])

	if err_bar:
		for name in ('fgd_tr', 'fgd_te', 'gd_tr', 'gd_te'):
			stats, colour = curves[name]
			err_std = stats[1][x_min:]
			plt.fill_between(steps, err_mean[name] - err_std, err_mean[name] + err_std, color=colour, alpha=0.15)

	plt.legend(['FGD test error', 'GD test error', 'FGD train error', 'GD train error'], fontsize=18)
	plt.grid(axis='both', linestyle=':')
	plt.xlabel('steps', fontsize=20)
	plt.tick_params(labelsize=20)
	plt.show()
	plt.close()

def _plot_random_label_metric(x, all_infos, key, ylabel, err_bar, as_error=False):
	"""Draw one figure of `key` (mean curve, optional +/-1 std band) per random-label run.

	all_infos holds get_infos results in legend order p=0, 0.1, 0.2, 1.0.
	When as_error is True the metric is plotted as 1 - mean (accuracy -> error);
	the std is unchanged by that transform.
	"""
	color_ids = (2, 4, 6, 10)
	curves = []
	plt.figure()
	for infos, color_id in zip(all_infos, color_ids):
		mean = infos[key][0]
		if as_error:
			mean = 1 - mean
		plt.plot(x, mean, linewidth=2.0, color=tableau20[color_id])
		curves.append((mean, infos, color_id))

	# Bands come after all lines: the label-only legend call below matches labels
	# to artists in creation order, so the four lines must be created first.
	if err_bar:
		for mean, infos, color_id in curves:
			std = infos[key][1]
			plt.fill_between(x, mean - std, mean + std, color=tableau20[color_id], alpha=0.1)

	plt.xlabel('steps', fontsize=20)
	plt.ylabel(ylabel, fontsize=20)
	plt.tick_params(labelsize=20)
	plt.legend(['p=0', 'p=0.1', 'p=0.2', 'p=1.0'], loc='upper left', fontsize=10)
	plt.grid(axis='both', linestyle=':')
	plt.show()
	plt.close()


def random_label(n=100, err_bar=True, T=1000):
	"""Compare training accuracy, test error and bound across random-label settings.

	Loads n runs each of mnist/fgd (legend p=0) and mnist/random_label/{0.1, 0.2,
	1.0}, cut off at step T, and draws three figures. The x axis comes from the
	p=0 run. All runs are assumed to log the same steps; TODO confirm, because
	mismatched lengths make plt.plot raise.
	"""
	tasks = ["mnist/fgd", "mnist/random_label/0.1", "mnist/random_label/0.2", "mnist/random_label/1.0"]
	all_infos = [get_infos(n, task=task, T=T) for task in tasks]
	x = all_infos[0]['t'][0]

	_plot_random_label_metric(x, all_infos, "tr_acc", 'training accuracy', err_bar)
	_plot_random_label_metric(x, all_infos, "te_acc", 'test error', err_bar, as_error=True)
	_plot_random_label_metric(x, all_infos, "bound", 'our bound', err_bar)

if __name__ == '__main__':
	# Number of runs (log files 0..n-1) averaged per experiment.
	n = 100

	fgdinfos = get_infos(n, task='mnist/fgd', T=1000)
	for window_start in (0, 50):
		plot_test_bound(fgdinfos, x_min=window_start, x_max=100, err_bar=True)
	plot_tr_te(fgdinfos, x_min=40)

	plot_mgrad(get_mgrad(n))

	gdinfos = get_infos(n, task='mnist/gd', T=1000)
	plot_tr_te(gdinfos, x_min=40)
	plot_fgd_gd_detail(fgdinfos, gdinfos, x_min=30, err_bar=True)

	random_label(n)