import tensorflow as tf
import numpy as np

class Evaluator:
    """Evaluate a model on the test streams of a task sequence and record accuracies.

    Two kinds of results are written to the recorder:

    * "first shot" accuracy: measured at the batch on which a task's
      timeline entry ends;
    * "finish" accuracy: measured for every task at the batch on which the
      last timeline entry ends, together with the per-task difference
      ``finish - first`` (stored as ``forget_record``).
    """

    def __init__(self, model, datastream, args, recoder, logger):
        """Store the collaborators and build the loss and metric objects.

        Args:
            model: Callable invoked as ``model(images, mask, training=False)``.
            datastream: Object providing ``__len__`` (number of tasks) and the
                attributes ``TimeLine``, ``TestStream`` and ``MaskSet``.
            args: Stored on ``self.args``; not otherwise used by this class.
            recoder: Result store exposing ``first_acc_record``,
                ``finish_acc_record`` and ``forget_record``. The parameter
                name keeps its original spelling so keyword callers still work.
            logger: Object exposing ``PrintResultPerTask`` and
                ``PrintResultPerRun``.
        """
        super().__init__()
        self.args = args
        self.model = model
        self.task_num = len(datastream)
        self.datastream = datastream
        self.BuildObjective()
        self.recorder = recoder
        self.logger = logger

    def BuildObjective(self):
        """Create the loss functions and one loss/accuracy metric pair per task."""
        self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
        # Not used inside this class; kept because it is a public attribute.
        self.loss_object_for_onehot = tf.keras.losses.CategoricalCrossentropy()
        self.test_loss = [
            tf.keras.metrics.Mean(name='test_loss_{}'.format(i))
            for i in range(self.task_num)
        ]
        self.test_acc = [
            tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy_{}'.format(i))
            for i in range(self.task_num)
        ]

    def test_step(self, images, labels, task_id, mask):
        """Run one test batch and accumulate it into the metrics of ``task_id``.

        Args:
            images: Batch passed to the model as its first argument.
            labels: Targets compared against the model output.
            task_id: Index selecting which loss/accuracy metric pair to update.
            mask: Passed to the model as its second positional argument.
        """
        predictions = self.model(images, mask, training=False)
        t_loss = self.loss_object(labels, predictions)
        self.test_loss[task_id](t_loss)
        self.test_acc[task_id](labels, predictions)

    @staticmethod
    def _reset_metric(metric):
        """Clear one metric, whichever reset method name it provides."""
        # Keras renamed Metric.reset_states() to reset_state(); newer releases
        # only provide the new name and older ones only the old name.
        reset = getattr(metric, 'reset_state', None)
        if reset is None:
            reset = metric.reset_states
        reset()

    def reset_state(self):
        """Clear the accumulated loss and accuracy of every task."""
        for loss_metric, acc_metric in zip(self.test_loss, self.test_acc):
            self._reset_metric(loss_metric)
            self._reset_metric(acc_metric)

    def _evaluate_task(self, task_id):
        """Accumulate the whole test stream of ``task_id`` from a clean metric state."""
        # Start clean so results never include batches accumulated earlier.
        self._reset_metric(self.test_loss[task_id])
        self._reset_metric(self.test_acc[task_id])
        mask = self.datastream.MaskSet[task_id]
        for images, labels in self.datastream.TestStream[task_id]:
            self.test_step(images, labels, task_id, mask)

    def EvaluateTask(self, labelset_id, timeline_id, run_id, batch_id):
        """Evaluate the tasks whose timeline ends at ``batch_id`` and record results.

        ``datastream.TimeLine`` is read as a sequence of 2-item entries whose
        second item is the ending batch of the task at the same index (the
        first item is presumably the starting batch; it is not used here).

        Args:
            labelset_id: First index into the recorder's result stores.
            timeline_id: Second index into the recorder's result stores.
            run_id: Third index into the recorder's result stores.
            batch_id: Current batch, compared against the timeline end values.

        Nothing is evaluated when ``batch_id`` matches no end value, including
        when the timeline is empty.
        """
        run_key = (labelset_id, timeline_id, run_id)
        end_times = [end for _, end in self.datastream.TimeLine]

        # 1. Tasks finishing on this batch (several may end together):
        #    record and log their first-shot accuracy.
        finished = [task_id for task_id, end in enumerate(end_times) if end == batch_id]
        for task_id in finished:
            self._evaluate_task(task_id)
            loss = self.test_loss[task_id].result()
            acc = self.test_acc[task_id].result()
            self.recorder.first_acc_record[run_key + (task_id,)] = acc.numpy()
            self.logger.PrintResultPerTask(
                labelset_id, timeline_id, run_id, task_id, loss, acc, is_first=True
            )
            self.reset_state()

        # 2. On the batch where the last task ends, re-evaluate every task.
        #    The emptiness check avoids max() raising on an empty timeline.
        if end_times and batch_id == max(end_times):
            for task_id in range(len(self.datastream)):
                self._evaluate_task(task_id)
                finish_acc = self.test_acc[task_id].result().numpy()
                self.recorder.finish_acc_record[run_key + (task_id,)] = finish_acc

            # Per-task results of this run; the stores must support
            # element-wise subtraction (presumably NumPy arrays - confirm
            # against the recorder implementation).
            first_accs = self.recorder.first_acc_record[run_key]
            finish_accs = self.recorder.finish_acc_record[run_key]
            forget_record = finish_accs - first_accs
            self.recorder.forget_record[run_key] = forget_record

            self.logger.PrintResultPerRun(
                labelset_id, timeline_id, run_id,
                np.mean(first_accs), np.std(first_accs),
                np.mean(finish_accs), np.std(finish_accs),
                np.mean(forget_record), np.std(forget_record),
            )
            self.reset_state()
        
