import pandas as pd
import glob
import numpy as np
from scipy import interpolate
import seaborn as sns
import matplotlib.pyplot as plt
import os
from scipy.ndimage import gaussian_filter
import sys
import argparse
# Apply seaborn's plotting theme with its default settings.
sns.set_theme()

# Command-line interface. --path is forwarded to ExperimentAnalyzer below;
# --multistep selects joining runs instead of mean/std summaries.
parser = argparse.ArgumentParser()
parser.add_argument(
	"--path",
	type=str,
	required=True,
	help="path to experiments",
)
parser.add_argument("--multistep", action="store_true")
args = parser.parse_args()

class Run:
	"""Load the episode CSVs and the summary CSV of a single run directory.

	Attributes:
		path: Run directory, as passed to the constructor.
		episodes: Maps "Episode <i>" to the DataFrame read from the i-th file
			under ``<path>/episodes`` (files taken in sorted filename order).
		run_summary: DataFrame read from ``<path>/Run0.csv``.
	"""

	def __init__(self, path):
		self.path = path
		self.episodes = {}
		self.run_summary = None
		self.__read_episodes__()
		self.__read_summary__()

	def __read_episodes__(self):
		"""Read every file under ``<path>/episodes`` into ``self.episodes``."""
		# glob() yields entries in arbitrary, filesystem-dependent order; sort so
		# that "Episode i" refers to the same file on every machine and run.
		episode_files = sorted(glob.glob(os.path.join(self.path, "episodes", "*")))
		for i, episode_file in enumerate(episode_files):
			self.episodes[f"Episode {i}"] = pd.read_csv(episode_file)

	def __read_summary__(self):
		"""Read ``<path>/Run0.csv`` into ``self.run_summary``.

		The "Unnamed: 0" column (typically what pandas produces when the CSV
		was written together with its index) is dropped when present; a file
		without that column is accepted as-is instead of raising KeyError.
		"""
		summary_path = os.path.join(self.path, "Run0.csv")
		self.run_summary = pd.read_csv(summary_path)
		self.run_summary.drop(columns=["Unnamed: 0"], inplace=True, errors="ignore")

class Experiment:
	"""Aggregate the runs stored in the ``run*`` entries of ``path``.

	If ``multistep`` is truthy, all run summaries are concatenated into
	``joined_run`` and written to ``joined_run.csv``; otherwise the per-row
	mean/std across runs is stored in ``experiment_summary`` and written to
	``summary.csv``. Both files are written inside ``path``.
	"""

	def __init__(self, path, multistep=None):
		self.path = path
		self.runs = []
		self.experiment_summary = None
		self.joined_run = None
		self.__read_runs__()
		if multistep:
			self.__join_runs__()
		else:
			self.__create_experiment_summary__()

	@staticmethod
	def _natural_sort_key(name):
		"""Sort key ordering embedded integers numerically ("run2" < "run10")."""
		import re
		# re.split with one capturing group alternates non-digit / digit chunks,
		# so every odd index is a run of digits and comparisons stay type-safe.
		return [int(chunk) if i % 2 else chunk
				for i, chunk in enumerate(re.split(r"(\d+)", name))]

	def __read_runs__(self):
		"""Load a Run for every ``run*`` entry of ``path``, in natural order."""
		# Natural order matters for __join_runs__: a plain lexical sort would
		# place run10 before run2 and concatenate the runs out of sequence.
		all_runs = sorted(glob.glob(os.path.join(self.path, "run*")),
						  key=self._natural_sort_key)
		print(all_runs)
		self.runs.extend(Run(run_path) for run_path in all_runs)

	def fill_nan(self, A):
		"""Replace non-finite entries of a 1-D array by linear interpolation.

		Positions are the integer indices 0..len(A)-1; entries outside the span
		of finite values are linearly extrapolated.

		Args:
			A: 1-D numeric array.
		Returns:
			A new array with non-finite entries filled, or ``A`` itself when
			scipy cannot build the interpolant (ValueError, e.g. no finite
			entries at all).
		"""
		inds = np.arange(A.shape[0])
		good = np.isfinite(A)
		try:
			f = interpolate.interp1d(inds[good], A[good], bounds_error=False, fill_value="extrapolate")
			return np.where(good, A, f(inds))
		except ValueError:
			print("Unable to interpolate results... Returning original array")
			return A

	def __create_experiment_summary__(self):
		"""Compute the per-row mean and std of every column across all runs.

		Each run's column is NaN-filled with fill_nan first. Produces columns
		"<col>_mean" and "<col>_std" (np.std default ddof=0, i.e. population
		std) and writes them to ``<path>/summary.csv``. All runs must have the
		same number of rows (required by np.vstack); raises IndexError when no
		runs were loaded.
		"""
		pack_channels = {}
		for column in self.runs[0].run_summary.columns:
			# One row per run, one column per time step.
			stacked = np.vstack([self.fill_nan(r.run_summary[column].values) for r in self.runs])
			pack_channels[f"{column}_mean"] = stacked.mean(axis=0)
			pack_channels[f"{column}_std"] = stacked.std(axis=0)

		self.experiment_summary = pd.DataFrame(pack_channels)
		self.experiment_summary.to_csv(os.path.join(self.path, 'summary.csv'))

	def __join_runs__(self):
		"""Concatenate all run summaries row-wise into ``self.joined_run``.

		Each run keeps its own row index (no re-indexing). The result is
		printed and written to ``<path>/joined_run.csv``.
		"""
		# DataFrame.append was removed in pandas 2.0; a single concat is also
		# linear instead of re-copying the growing frame for every run.
		self.joined_run = pd.concat([r.run_summary for r in self.runs])
		print(self.joined_run)
		self.joined_run.to_csv(os.path.join(self.path, 'joined_run.csv'))


class ExperimentAnalyzer:
	"""Load every experiment directory matched by ``path``.

	Attributes:
		path: Directory (or glob pattern) given by the caller.
		experiments: One Experiment per matched directory.
		labels: Base name of each matched directory, parallel to experiments.
		colors: Palette with distinct entries (not read anywhere in the
			visible code; presumably meant for plotting experiments).
		multistep: Forwarded to every Experiment.
	"""

	def __init__(self, path, multistep):
		self.path = path
		self.experiments = []
		# Entries must stay distinct: the previous list repeated 'green'.
		self.colors = ['blue', 'black', 'red', 'green', 'brown', 'cyan', 'yellow', 'magenta']
		self.labels = []
		self.multistep = multistep
		self.__read_experiments__()

	def __read_experiments__(self):
		"""Create an Experiment (and label) for every directory matching path."""
		# A pattern ending in "/" matches directories only. Without wildcards
		# this is just ``path`` itself; a wildcard such as "results/*" selects
		# several experiment directories.
		experiments = glob.glob(self.path + "/")
		if not experiments:
			print(f"No experiment directory matches {self.path!r}")
		for experiment_path in experiments:
			self.experiments.append(Experiment(experiment_path, self.multistep))
			self.labels.append(os.path.basename(os.path.normpath(experiment_path)))

# Entry point: load the experiment(s) named on the command line. Building the
# Experiment objects writes their summary/joined CSV files as a side effect.
EA = ExperimentAnalyzer(path=args.path, multistep=args.multistep)
