import tensorflow as tf
import time
import numpy as np
import logging


class SummaryLog:
    """Buffer per-tag values and write aggregated TensorFlow summaries.

    A tag is registered with :meth:`add_tag` together with an aggregation
    type. Two families of tags exist:

    * Count-based tags ("avg", "total", "max", "min", "sd", or any type that
      contains "histogram") are written by :meth:`add_summary` as soon as the
      number of buffered values reaches the tag's ``output_threshold``.
    * Time-based tags (types starting with "time": "time_avg", "time_total")
      are only buffered by :meth:`add_summary`; they are written by
      :meth:`generate_time_data_output` once more than ``time_threshold``
      seconds have elapsed since the tag was last written.
    """

    def __init__(self, summary_writer_file, json_config):
        """Create the summary writer and the empty per-tag bookkeeping dicts.

        Args:
            summary_writer_file: Passed to ``tf.summary.create_file_writer``.
            json_config: Stored on the instance as ``self.json_config``; it is
                not read by any method of this class.
        """
        self.json_config = json_config

        self.summary_writer = tf.summary.create_file_writer(summary_writer_file)
        # Every dict below is keyed by tag name.
        self.tag_values_dict = {}           # values buffered since last write
        self.tag_step_dict = {}             # next `step` to write for the tag
        self.tag_output_threshold_dict = {}  # buffered-value count that triggers a write
        self.tag_func_dict = {}             # aggregation type of the tag
        self.tag_total_add_count = {}       # total number of add_summary calls
        # The three dicts below only contain time-based tags.
        self.tag_time_threshold = {}        # seconds between writes / sample max age
        self.tag_time_data_timestamp = {}   # timestamps parallel to tag_values_dict
        self.tag_time_last_print_time = {}  # time.time() of the last write
        self.total_tag_type = [
            "time_avg", "time_total",
            "avg", "total", "max", "min"
        ]
        # Only contains tags whose type mentions "histogram".
        self.tag_bins = {}

    def add_tag(self, tag, output_threshold, cal_type, time_threshold=0, bins=100):
        """Register ``tag`` (or reset it if it is already registered).

        Args:
            tag: Name under which the summary is written.
            output_threshold: Number of buffered values that triggers a write
                for count-based tags.
            cal_type: Aggregation type, see the class docstring.
            time_threshold: Seconds between writes for time-based tags; also
                the maximum age of a sample that is still aggregated.
            bins: Stored in ``self.tag_bins`` for histogram tags; not used by
                the write path of this class.
        """
        self.tag_values_dict[tag] = []
        self.tag_step_dict[tag] = 0
        self.tag_output_threshold_dict[tag] = output_threshold
        self.tag_func_dict[tag] = cal_type
        self.tag_total_add_count[tag] = 0
        if cal_type.startswith("time"):
            self.tag_time_threshold[tag] = time_threshold
            self.tag_time_data_timestamp[tag] = []
            # 0 makes the tag due on the first generate_time_data_output call.
            self.tag_time_last_print_time[tag] = 0
        if "histogram" in cal_type:
            self.tag_bins[tag] = bins

    def has_tag(self, tag):
        """Return True if ``tag`` has been registered with :meth:`add_tag`."""
        return tag in self.tag_step_dict

    def get_tag_count(self, tag):
        """Return how many values were ever added to ``tag``.

        Raises:
            KeyError: If ``tag`` has not been registered.
        """
        return self.tag_total_add_count[tag]

    @staticmethod
    def _aggregate(cal_type, values):
        """Reduce the non-empty list ``values`` according to ``cal_type``.

        Raises:
            ValueError: If ``cal_type`` is not a supported scalar aggregation.
        """
        if cal_type in ("avg", "time_avg"):
            return sum(values) / len(values)
        if cal_type in ("total", "time_total"):
            return sum(values)
        if cal_type == "max":
            return max(values)
        if cal_type == "min":
            return min(values)
        if cal_type == "sd":
            return np.array(values).std()
        raise ValueError("unsupported summary aggregation type: %r" % (cal_type,))

    def generate_time_data_output(self, logger=None):
        """Write every time-based tag whose time threshold has elapsed.

        For each due tag, the buffered values whose timestamp is younger than
        the tag's threshold are aggregated (0 is written when there are
        none), the scalar is written, and the tag's buffers are cleared.

        Args:
            logger: Unused; kept so existing callers keep working.

        Raises:
            ValueError: If a due tag with recent samples has a time-based
                type other than "time_avg" or "time_total".
        """
        for tag, threshold in self.tag_time_threshold.items():
            cur_time = time.time()
            if cur_time - self.tag_time_last_print_time[tag] <= threshold:
                continue

            recent_values = [
                value
                for stamp, value in zip(self.tag_time_data_timestamp[tag],
                                        self.tag_values_dict[tag])
                if cur_time - stamp < threshold
            ]
            # Aggregate before touching the writer so that an unsupported
            # type fails loudly instead of reusing a value from another tag.
            if recent_values:
                out = self._aggregate(self.tag_func_dict[tag], recent_values)
            else:
                out = 0

            with self.summary_writer.as_default():
                tf.summary.scalar(tag, out, step=self.tag_step_dict[tag])

            self.tag_step_dict[tag] += 1
            self.tag_values_dict[tag] = []
            self.tag_time_data_timestamp[tag] = []
            self.tag_time_last_print_time[tag] = cur_time

    def add_summary(self, tag, value, timestamp=None):
        """Buffer ``value`` for ``tag`` and write the tag if it is due.

        Time-based tags are only buffered here. Count-based tags are written
        and their buffer cleared once the number of buffered values reaches
        the tag's output threshold.

        Args:
            tag: A tag previously registered with :meth:`add_tag`.
            value: The sample; for histogram tags an iterable of samples.
            timestamp: ``time.time()``-style timestamp of the sample, only
                recorded for time-based tags. Defaults to the time of the
                call (evaluated per call, not once at import).

        Raises:
            KeyError: If ``tag`` has not been registered.
            ValueError: If a count-based, non-histogram tag is due and its
                type is not a supported aggregation.
        """
        if timestamp is None:
            timestamp = time.time()

        self.tag_values_dict[tag].append(value)
        self.tag_total_add_count[tag] += 1

        cal_type = self.tag_func_dict[tag]
        if cal_type.startswith("time"):
            # Written later by generate_time_data_output.
            self.tag_time_data_timestamp[tag].append(timestamp)
            return

        buffered = self.tag_values_dict[tag]
        if len(buffered) < self.tag_output_threshold_dict[tag]:
            return

        step = self.tag_step_dict[tag]
        if "histogram" in cal_type:
            # Each buffered value is itself a list of samples; flatten them.
            flat_values = [item for chunk in buffered for item in chunk]
            with self.summary_writer.as_default():
                tf.summary.histogram(tag, flat_values, step=step)
        else:
            result = self._aggregate(cal_type, buffered)
            with self.summary_writer.as_default():
                tf.summary.scalar(tag, result, step=step)

        self.tag_step_dict[tag] += 1
        self.tag_values_dict[tag] = []


# Usage example stored as a plain string literal; nothing in the visible code
# executes it.
# NOTE(review): the example constructs SummaryLog("./summary_log") with a single
# argument, while SummaryLog.__init__ also requires json_config -- adjust the
# call before reusing this snippet.
s = '''

summary_logger = SummaryLog("./summary_log")

summary_logger.add_tag('sampler/error_per_min', 0, "time_total", time_threshold=10)

while True:
    summary_logger.generate_time_data_output()
    time.sleep(1)
    for i in range(10):
        summary_logger.add_summary('sampler/error_per_min', 1, time.time())
'''

