from absl import app, flags
import csv
import os
import numpy as np
from scipy.stats import sem
import matplotlib.pyplot as plt

# Raise matplotlib's default font size for everything drawn by this script.
plt.rcParams.update({'font.size': 15})
FLAGS = flags.FLAGS

# Command-line flags. --filename is the only required flag. main() selects the
# parser from --raw_data / --equal_transitions and passes --num_burn,
# --pre_burn and --eval_every to the raw-data parsers only.
flags.DEFINE_string('filename', None, 'Input file name')
flags.DEFINE_boolean('raw_data', False, 'Lines must be computed from the raw-pbs-value script output')
flags.DEFINE_boolean('equal_transitions', False, 'Count all transitions as samples when comparing Gibbs Sampler')
flags.DEFINE_integer('pre_burn', 100, 'Number of transitions to burn before counting any Gibbs samples')
# Multi-valued flag: one curve is produced per value given.
flags.DEFINE_multi_integer('num_burn', [10], 'Number of transitions in between each Gibbs sample')
flags.DEFINE_integer('eval_every', 50, 'Number of samples between points on the plot')
# When unset, main() only shows the figure and does not write it to disk.
flags.DEFINE_string('save_loc', None, 'save directory')
flags.mark_flag_as_required('filename')

def plot_learning_curves(ax, avg_errs, samples, std_errs, labels, title="",
                         ylim=(0, 0.2)):
  """Draw one error-bar curve per entry of ``avg_errs`` on ``ax``.

  All styling (hidden top/right spines, y-limits, title, axis labels and
  legend) is applied to ``ax`` itself, so the function also works when
  ``ax`` is not pyplot's current axes.

  Args:
    ax: Axes-like object to draw on.
    avg_errs: sequence of y-value sequences, one per curve.
    samples: x values shared by every curve.
    std_errs: sequence of error-bar sizes, parallel to ``avg_errs``.
    labels: legend label for each curve, parallel to ``avg_errs``.
    title: text placed inside the plot area (x=0.7, y=0.4).
    ylim: ``(bottom, top)`` limits for the y axis.

  Returns:
    List with the object returned by ``ax.errorbar`` for each curve.

  Raises:
    IndexError: if ``std_errs`` or ``labels`` is shorter than ``avg_errs``.
  """
  ax.spines['top'].set_visible(False)
  ax.spines['right'].set_visible(False)
  ax.set_ylim(*ylim)
  ax.set_title(title, y=0.4, x=0.7, size=13)
  lines = []
  for idx, errs in enumerate(avg_errs):
    # Index the parallel sequences so that a length mismatch still fails loudly.
    lines.append(ax.errorbar(samples, errs, yerr=std_errs[idx],
                             label=labels[idx], capsize=2.))
  # Label the axes we were handed instead of whatever pyplot considers current.
  ax.set_ylabel("Value Error")
  ax.set_xlabel("Samples")
  ax.legend()
  return lines

def parse_lines():
	"""Parse the tab-separated summary file named by --filename.

	Rows with fewer than 11 columns, and rows whose first cell is
	``num_suits``, are skipped. For every other row these columns are read:

	* column 4: sample count (int), used as the key of the inner dicts
	* column 5: bracketed, comma-separated list of burn settings
	* column 8: the value every estimate is compared against
	* column 9: estimate stored under ``"true"``
	* column 10: estimate stored under ``"importance"``
	* columns 11 and up: one estimate per burn setting, in list order

	Returns:
		Dict mapping a line name (``"true"``, ``"importance"`` or a burn
		setting as a string) to a dict mapping sample count to the list of
		squared errors ``(value - estimate) ** 2`` gathered for it.

	Raises:
		IndexError: if a row has more estimate columns than burn settings.
		ValueError: if a numeric cell cannot be parsed.
	"""
	line_dict = {"true": {}, "importance": {}}
	# newline='' leaves newline handling to the csv module.
	with open(FLAGS.filename, newline='') as f:
		for row in csv.reader(f, delimiter='\t'):
			# Columns up to index 10 are read unconditionally below, so any
			# shorter row cannot be a data row.
			if len(row) < 11 or row[0] == 'num_suits':
				continue
			samples = int(row[4])
			line_dict["true"].setdefault(samples, [])
			line_dict["importance"].setdefault(samples, [])
			# Strip each name so "[10, 20]" and "[10,20]" yield the same keys.
			burn_list = [
				burn.strip()
				for burn in row[5].replace('[', '').replace(']', '').split(",")
			]
			for burn in burn_list:
				line_dict.setdefault(burn, {}).setdefault(samples, [])
			value = float(row[8])
			line_dict["true"][samples].append((value - float(row[9])) ** 2)
			line_dict["importance"][samples].append((value - float(row[10])) ** 2)
			for offset, cell in enumerate(row[11:]):
				line_dict[burn_list[offset]][samples].append((value - float(cell)) ** 2)
	return line_dict

def parse_raw_data_equal_transitions(burn_list, eval_every=50, pre_burn=0):
	"""Build Gibbs error curves indexed by the number of transitions.

	The file named by --filename is read line by line as ``name:payload``.
	Blank lines and lines containing ``num_suits`` are skipped. A ``value``
	line sets the reference value for the chain lines that follow it.
	``True`` and ``Importance`` lines are ignored in this mode. Any other
	line is treated as a chain: a bracketed, comma-separated list of floats,
	one per transition.

	For each chain and each ``num_burn`` the transitions are walked with a
	1-based counter ``i``, starting at ``max(pre_burn, 1)``. The value after
	every ``num_burn``-th transition is kept as a sample. Whenever ``i`` is
	a multiple of ``eval_every``, the squared error of the running sample
	mean is stored under key ``i - pre_burn``. If no sample has been kept
	yet at that point, the stored entry is NaN.

	Args:
		burn_list: iterable of ``num_burn`` values; one curve is built per value.
		eval_every: spacing, in transitions, between recorded points.
		pre_burn: number of leading transitions that are never sampled.

	Returns:
		Dict mapping ``'<name> burn=<num_burn>'`` to a dict mapping
		``i - pre_burn`` to the list of squared errors, one per chain line.

	Raises:
		KeyError: if a chain line's name is not ``Gibbs``.
		ValueError: if a non-blank line does not contain exactly one ``:`` or
			a number cannot be parsed.
		TypeError: if a point is recorded before any ``value`` line was read.
	"""
	line_dict = {f'Gibbs burn={burn}': {} for burn in burn_list}
	value = None
	with open(FLAGS.filename) as f:
		for line in f:
			line = line.strip()
			if not line or 'num_suits' in line:
				continue
			name, values_str = line.split(":")
			if name == 'value':
				value = float(values_str)
				continue
			if name in ("True", "Importance"):
				# These estimators are not plotted in equal-transitions mode.
				continue
			# Parse the chain once; it is the same for every burn setting.
			values = [float(x) for x in values_str[1:-1].split(',')]
			for num_burn in burn_list:
				errors = line_dict[f'{name} burn={num_burn}']
				samples = []
				# i counts completed transitions, so values[i - 1] is the state
				# after i of them. Starting at 1 keeps i - 1 from wrapping
				# around to the end of the chain when pre_burn is 0.
				for i in range(max(pre_burn, 1), len(values) + 1):
					if i % num_burn == 0:
						samples.append(values[i - 1])
					if i % eval_every == 0:
						errors.setdefault(i - pre_burn, []).append(
							(value - np.mean(samples)) ** 2)
	return line_dict

def parse_raw_data(burn_list, eval_every=50, pre_burn=0):
	"""Build error curves indexed by the number of samples.

	The file named by --filename is read line by line as ``name:payload``.
	Blank lines and lines containing ``num_suits`` are skipped. A ``value``
	line sets the reference value for the lines that follow it.

	* ``True`` lines hold a bracketed, comma-separated list of floats. For
	  each multiple ``count`` of ``eval_every`` up to
	  ``(len(values) - pre_burn) // max(burn_list)``, the squared error of
	  ``mean(values[pre_burn:pre_burn + count - 1])`` is stored under ``count``.
	* ``Importance`` lines are split on ``(``. Each non-empty piece is read
	  as two comma-separated numbers, and a running weighted average of the
	  first number is kept, with the second number as the weight.
	* Any other line is treated as a chain. For each ``num_burn`` the value
	  after every ``num_burn``-th transition from ``pre_burn`` on is kept as
	  a sample, until more than ``(len(values) - pre_burn) // max(burn_list)``
	  samples exist or the chain ends. Each time a sample is taken at
	  transition ``i``, the key is ``(i - pre_burn) // num_burn``. If the key
	  is a multiple of ``eval_every``, the squared error of the sample mean
	  is stored under it, exactly once per chain.

	Args:
		burn_list: non-empty sequence of ``num_burn`` values.
		eval_every: spacing, in samples, between recorded points.
		pre_burn: number of leading entries skipped before anything is used.

	Returns:
		Dict with keys ``"True"``, ``"Importance"`` and
		``'<name> burn=<num_burn>'``, each mapping a sample count to the
		list of squared errors collected for it.

	Raises:
		ValueError: if ``burn_list`` is empty, a non-blank line does not
			contain exactly one ``:``, or a number cannot be parsed.
		KeyError: if a chain line's name is not ``Gibbs``.
		TypeError: if a point is recorded before any ``value`` line was read.
	"""
	line_dict = {"True": {}, "Importance": {}}
	for burn in burn_list:
		line_dict[f'Gibbs burn={burn}'] = {}
	# Every estimator is capped at the sample budget of the slowest sampler.
	max_burn = max(burn_list)
	value = None
	with open(FLAGS.filename) as f:
		for line in f:
			line = line.strip()
			if not line or 'num_suits' in line:
				continue
			name, values_str = line.split(":")
			if name == 'value':
				value = float(values_str)
			elif name == "True":
				values = [float(x) for x in values_str[1:-1].split(',')]
				limit = (len(values) - pre_burn) // max_burn
				for count in range(eval_every, limit + 1, eval_every):
					# NOTE(review): this slice holds count - 1 values, not
					# count. Kept as in the original; confirm it is intended.
					estimate = np.mean(values[pre_burn:pre_burn + count - 1])
					line_dict[name].setdefault(count, []).append((value - estimate) ** 2)
			elif name == "Importance":
				running = []
				weighted_sum = 0.
				weight_total = 0.
				pieces = values_str.split('(')
				stop = (len(pieces) - pre_burn) // max_burn + pre_burn
				for idx in range(pre_burn, stop):
					piece = pieces[idx]
					if not piece:
						continue
					# Drop the two trailing separator characters and any ')'.
					x, y = piece[:-2].replace(')', '').split(',')
					weighted_sum += float(y) * float(x)
					weight_total += float(y)
					running.append(weighted_sum / weight_total)
				for count in range(eval_every, len(running) + 1, eval_every):
					# NOTE(review): running[count - 2] is the estimate after
					# count - 1 pieces. Kept as in the original; confirm.
					line_dict[name].setdefault(count, []).append(
						(value - running[count - 2]) ** 2)
			else:
				# Parse the chain once; it is the same for every burn setting.
				values = [float(x) for x in values_str[1:-1].split(',')]
				limit = (len(values) - pre_burn) // max_burn
				for num_burn in burn_list:
					errors = line_dict[f'{name} burn={num_burn}']
					samples = []
					# First multiple of num_burn that is >= pre_burn and >= 1.
					# Starting at 1 keeps values[i - 1] from wrapping around
					# to the end of the chain when pre_burn is 0.
					first = max(pre_burn, 1)
					first += -first % num_burn
					for i in range(first, len(values) + 1, num_burn):
						if len(samples) > limit:
							break
						samples.append(values[i - 1])
						count = (i - pre_burn) // num_burn
						# Record only when a sample is taken. Recording on every
						# transition would repeat each point num_burn times per
						# chain and shrink the standard error computed later.
						if count % eval_every == 0:
							errors.setdefault(count, []).append(
								(value - np.mean(samples)) ** 2)
	return line_dict

def main(_):
	"""Parse the input chosen by the flags, plot the error curves and show them.

	--raw_data selects one of the raw-data parsers (--equal_transitions
	picks between them); otherwise the summary TSV parser is used. The
	sample counts of the first parsed line define the x axis for all lines.
	For each line the mean and the standard error of its squared errors are
	plotted at every x. When --save_loc is set, the figure is also written
	there as ``<input base name>.png``.

	Args:
		_: positional arguments handed over by ``app.run``; unused.

	Raises:
		IndexError: if the parser returned no lines.
		KeyError: if a line lacks one of the first line's sample counts.
	"""
	if FLAGS.raw_data:
		if FLAGS.equal_transitions:
			parser = parse_raw_data_equal_transitions
		else:
			parser = parse_raw_data
		line_dict = parser(FLAGS.num_burn, pre_burn=FLAGS.pre_burn, eval_every=FLAGS.eval_every)
	else:
		line_dict = parse_lines()

	labels = list(line_dict)
	x = sorted(line_dict[labels[0]])
	ys = [[np.mean(line_dict[label][n]) for n in x] for label in labels]
	sems = [[sem(line_dict[label][n]) for n in x] for label in labels]

	plot_learning_curves(plt.gca(), ys, x, sems, labels)

	if FLAGS.save_loc:
		# Swap only the extension. A substring replace of 'tsv' would also
		# rewrite 'tsv' inside the name, and would leave inputs that are not
		# .tsv files with their original extension.
		stem = os.path.splitext(os.path.basename(FLAGS.filename))[0]
		plt.savefig(os.path.join(FLAGS.save_loc, stem + '.png'), bbox_inches="tight")
	plt.show()


# Script entry point: app.run is handed main as the function to run.
if __name__ == "__main__":
	app.run(main)