import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os

class MetricsRecorder(object):
    """Record per-step metric values and export, summarize, or plot them.

    Each metric name maps to a list of recorded values in ``self.hist``.
    ``self.hyperparams`` is kept alongside the metrics: ``saveMetrics`` can
    write it to CSV, and its ``"bs"`` entry is used as the x-axis step
    (samples per recorded value) by ``getSampleIndex`` / ``plotMetric``.
    """

    def __init__(self, metrics_names=None, hyperparams=None):
        """Create a recorder.

        Args:
            metrics_names: Optional iterable of metric names to pre-register.
                It is copied, so names auto-registered later by
                ``recValue`` / ``recDict`` never leak into the caller's list.
            hyperparams: Optional dict of hyperparameters. Defaults to
                ``{"beta": -1, "bs": 256, "lr": 0.1}``.
        """
        # Copy: appending new names must not mutate the caller's list.
        self.names = [] if metrics_names is None else list(metrics_names)
        if hyperparams is None:
            self.hyperparams = {"beta": -1, "bs": 256, "lr": 0.1}
        else:
            self.hyperparams = hyperparams
        self.hist = {}  # metric name -> list of recorded values
        self.reset()

    def reset(self):
        """Clear the history of every registered metric (names are kept)."""
        for name in self.names:
            self.hist[name] = []

    def recValue(self, name, value, add_name=True):
        """Append ``value`` to the history of metric ``name``.

        Args:
            name: Metric name.
            value: Value to record.
            add_name: If true, an unregistered ``name`` is registered first.
                If false, an unregistered ``name`` raises ``KeyError``.
        """
        if add_name and name not in self.names:
            self.names.append(name)
            self.hist[name] = []
        self.hist[name].append(value)

    def recDict(self, d, add_name=True):
        """Record every ``name -> value`` pair of ``d`` via ``recValue``."""
        for name, value in d.items():
            self.recValue(name, value, add_name=add_name)

    def getMetricHist(self, name):
        """Return the (live, not copied) list of values recorded for ``name``."""
        return self.hist[name]

    def getDataFrame(self):
        """Return the histories as a DataFrame with one column per metric.

        pandas raises ``ValueError`` if the histories have different lengths.
        """
        return pd.DataFrame.from_dict(self.hist)

    def saveMetrics(self, job_path, split='train', save_hyper=True):
        """Write metrics (and optionally hyperparameters) to CSV files.

        Metrics go to ``job_path + '_' + split + '_metrics.csv'``.
        Hyperparameters go to ``job_path + '_hyperparams.csv'`` when
        ``save_hyper`` is true. ``job_path`` is used as a filename prefix,
        not a directory.
        """
        metrics_path = job_path + '_' + split + '_metrics.csv'
        metrics_df = self.getDataFrame()
        metrics_df.to_csv(metrics_path)

        if save_hyper:
            hyper_path = job_path + '_hyperparams.csv'
            # Scalar-valued dict -> single-row frame; pandas needs an index.
            hyperparams_df = pd.DataFrame(self.hyperparams, index=[0])
            hyperparams_df.to_csv(hyper_path)

    def add(self, other):
        """Append ``other``'s history to this one for each name registered here.

        ``other`` must provide every name in ``self.names`` through
        ``getMetricHist``. Names that exist only in ``other`` are ignored.
        A new list is built, so lists previously returned by
        ``getMetricHist`` are not modified.
        """
        for name in self.names:
            self.hist[name] = self.hist[name] + other.getMetricHist(name)

    def reduction(self):
        """Return the column-wise mean of all metrics as a pandas Series."""
        return self.getDataFrame().mean()

    def getSampleIndex(self, name=None):
        """Return x-axis sample positions for a metric's history.

        Position ``i`` is ``i * hyperparams["bs"]``.

        Args:
            name: Metric whose history length sizes the index. Defaults to
                the first registered metric. If no metric is registered, an
                empty array is returned.
        """
        if name is None:
            name = self.names[0] if self.names else None
        n_points = 0 if name is None else len(self.hist[name])
        # arange(n) * bs always yields exactly n points. arange(0, n*bs, bs)
        # can yield n + 1 when bs is a float, because the stop value is rounded.
        return np.arange(n_points) * self.hyperparams["bs"]

    def plotMetric(self, name):
        """Plot the history of ``name`` against samples seen and show it."""
        # Size the x-axis from this metric, not the first one registered.
        t = self.getSampleIndex(name)
        _, ax = plt.subplots()

        # A label is required, or ax.legend() finds no artists to show.
        ax.plot(t, self.hist[name], label=name)
        ax.set(xlabel='Samples', ylabel=name)
        ax.legend()
        plt.show()