import json
import time
import socket
import os
import numpy as np

from library import configs


class ResultsJSON(object):
    """Collect the results of one experiment and persist them as a JSON file.

    Everything stored on the instance (its ``__dict__``) is what gets written
    by :meth:`write_to_disk`, so every attribute must be JSON-serializable at
    save time.
    """

    # The annotation is quoted so it is not evaluated when the class is defined.
    def __init__(self, config: 'configs.DifflogicConfig'):
        """Initialise the record from the experiment section of *config*.

        ``raw_values`` and ``logits_stats`` are only created when the matching
        ``store_raw_values`` / ``store_logit_stats`` flag of the experiment
        config is truthy; calling the corresponding ``store_*`` method without
        the flag raises ``AttributeError``.
        """
        experiment_config = config.experiment_config
        self.eid = experiment_config.experiment_id
        self.name = experiment_config.experiment_name
        # NOTE(review): `results_di` looks like a truncated `results_dir` --
        # confirm against the config class before renaming.
        self.path = experiment_config.results_di

        self.init_time = time.time()
        self.save_time = None
        self.total_time = None

        self.args = None

        if experiment_config.store_raw_values:
            self.raw_values = []
        if experiment_config.store_logit_stats:
            # One [TP, FP, TN, FN] counter row per neuron of the last layer.
            self.logits_stats = [
                [0] * 4 for _ in range(config.model_config.last_layer_neurons)
            ]

        # Short host name (everything before the first dot).
        self.server_name = socket.gethostname().split('.')[0]

    def store_args(self, args):
        """Store the attributes of *args* (anything ``vars()`` accepts) as a dict."""
        self.args = vars(args)

    def store_results(self, results: dict):
        """Append each value of *results* to the list attribute named by its key.

        The list is created the first time a key is seen.
        """
        for key, val in results.items():
            if not hasattr(self, key):
                setattr(self, key, [])

            getattr(self, key).append(val)

    def store_final_results(self, results: dict):
        """Store each value of *results* under its key with a trailing underscore.

        The suffix keeps these single values apart from the per-call lists
        built by :meth:`store_results`.
        """
        for key, val in results.items():
            setattr(self, key + '_', val)

    def store_raw_values(self, label, output, class_size):
        """Save the output values of an individual sample as ``[label, output]``.

        *class_size* is accepted for interface compatibility but is not used.
        """
        self.raw_values.append([label, output])

    def store_logit_stats(self, label, output, num_classes):
        """Update the per-output confusion counters for one sample.

        *output* is split into *num_classes* consecutive groups; the outputs in
        the group selected by *label* are expected to be on (>= 0.5), all other
        outputs are expected to be off. Row ``i`` of ``logits_stats`` holds the
        counts ``[TP, FP, TN, FN]`` for output ``i``.
        """
        # Same arithmetic as before (float division, truncated afterwards) so
        # that group boundaries do not move when the sizes do not divide evenly.
        start_index = label * len(output) / num_classes
        end_index = start_index + len(output) / num_classes

        is_target = [False] * len(output)
        for i in range(int(start_index), int(end_index)):
            is_target[i] = True

        for i, expected_on in enumerate(is_target):
            value = output[i]
            if value >= 0.5:
                column = 0 if expected_on else 1  # TP / FP
            elif value < 0.5:
                column = 3 if expected_on else 2  # FN / TN
            else:
                # Neither comparison holds (e.g. NaN): nothing is counted.
                continue
            self.logits_stats[i][column] += 1

    def write_to_disk(self):
        """Write all stored attributes to ``<path>/<name>.json``.

        Also records the save time and the elapsed time since construction.
        """
        self.save_time = time.time()
        self.total_time = self.save_time - self.init_time

        json_str = json.dumps(self.__dict__)

        with open(os.path.join(self.path, self.name + '.json'), mode='w') as f:
            f.write(json_str)

    @staticmethod
    def load(eid: int, path: str, get_dict=False, name=None):
        """Load a stored result from the directory *path*.

        Args:
            eid: Experiment id. Selects the file ``'<eid:08d>.json'`` when
                *name* is not given, and is checked against the stored id.
            path: Directory containing the JSON file.
            get_dict: When true, return the decoded JSON dict as-is.
            name: Optional experiment name; when given, ``'<name>.json'`` is
                read instead, which is the file :meth:`write_to_disk` creates.

        Returns:
            The decoded dict if *get_dict* is true, otherwise a ``ResultsJSON``
            whose attributes are the stored values.
        """
        if name is None:
            file_name = '{:08d}.json'.format(eid)
        else:
            file_name = name + '.json'

        with open(os.path.join(path, file_name), mode='r') as f:
            data = json.loads(f.read())

        if get_dict:
            return data

        # Bypass __init__: it needs a config object, and all of the state is
        # restored from the file anyway.
        self = ResultsJSON.__new__(ResultsJSON)
        self.__dict__.update(data)

        assert eid == self.eid

        return self


if __name__ == '__main__':
    # Smoke test: create a result record, write it to the current directory
    # and read it back.
    from types import SimpleNamespace

    demo_id = 101

    # Minimal stand-in exposing only the config attributes the constructor reads.
    demo_config = SimpleNamespace(
        experiment_config=SimpleNamespace(
            experiment_id=demo_id,
            # Named after the id so that load(), which looks for
            # '<eid:08d>.json', finds the file written below.
            experiment_name='{:08d}'.format(demo_id),
            results_di='./',
            store_raw_values=False,
            store_logit_stats=False,
        ),
    )

    r = ResultsJSON(demo_config)

    print(r.__dict__)

    r.write_to_disk()

    # Read the stored values back as a plain dict.
    r2 = ResultsJSON.load(demo_id, './', get_dict=True)

    print(r2)