# Third-party imports
import numpy as np

class Evaluation:
    """Score a synthetic model2 output against the true model outputs.

    Records are paired by position. Each ``true_output`` record is read as
    ``{"input": ..., "output": {"model1": ..., "model2": ...}}`` and each
    ``synthetic_output`` record as ``{"input": ..., "output": {"model2": ...}}``;
    a missing ``"input"`` key is read as ``None``. Accuracy is the fraction of
    records whose synthetic ``model2`` value equals the true ``model2`` value,
    optionally restricted to records where model1 and model2 agree
    ("matched") or disagree ("mismatched").
    """

    def __init__(self, synthetic_output, true_output):
        """Store the paired outputs.

        Args:
            synthetic_output: Sequence of synthetic records, aligned with
                ``true_output`` by position.
            true_output: Sequence of ground-truth records holding both the
                ``model1`` and ``model2`` outputs.
        """
        self.synthetic_output = synthetic_output
        self.true_output = true_output

    def _agreement_mask(self):
        """Return a bool array that is True where model1 == model2 in true_output."""
        # Compare record by record in Python instead of building one numpy
        # array per model: np.array() coerces a model's mixed-type outputs to
        # a common dtype (e.g. 1 becomes "1" next to string outputs) and
        # rejects ragged list-valued outputs.
        return np.array(
            [
                instance["output"]["model1"] == instance["output"]["model2"]
                for instance in self.true_output
            ],
            dtype=bool,
        )

    def get_mismatched_indices(self):
        """Return the positions where the model1 and model2 outputs differ.

        Returns:
            1-D integer ``numpy.ndarray`` of ascending indices into
            ``true_output``.
        """
        return np.flatnonzero(~self._agreement_mask())

    def get_matched_indices(self):
        """Return the positions where the model1 and model2 outputs are equal.

        Returns:
            1-D integer ``numpy.ndarray`` of ascending indices into
            ``true_output``.
        """
        return np.flatnonzero(self._agreement_mask())

    def get_overall_accuracy(self):
        """Return the synthetic model2 accuracy over every record.

        Returns ``nan`` when ``true_output`` is empty.

        Raises:
            ValueError: If a paired record's ``"input"`` differs between
                ``true_output`` and ``synthetic_output``.
        """
        return self.get_accuracy()

    def get_mismatched_accuracy(self):
        """Return the accuracy on records where model1 and model2 disagree.

        Returns ``nan`` when the two models never disagree.
        """
        return self.get_accuracy(self.get_mismatched_indices())

    def get_matched_accuracy(self):
        """Return the accuracy on records where model1 and model2 agree.

        Returns ``nan`` when the two models never agree.
        """
        return self.get_accuracy(self.get_matched_indices())

    def get_accuracy(self, indices=None):
        """Return the synthetic model2 accuracy over the selected records.

        Args:
            indices: Iterable of positions to score. ``None`` (the default)
                scores every record in ``true_output``.

        Returns:
            Fraction of selected records whose synthetic ``model2`` output
            equals the true one, or ``nan`` when the selection is empty.

        Raises:
            ValueError: If a selected record's ``"input"`` differs between
                ``true_output`` and ``synthetic_output``.
        """
        if indices is None:
            indices = range(len(self.true_output))
        else:
            # Materialize so generators and other one-shot iterables work
            # with len() below.
            indices = list(indices)

        # An empty selection has no defined accuracy: report NaN rather than
        # silently substituting the overall accuracy or dividing by zero.
        if len(indices) == 0:
            return float("nan")

        correct_output = 0
        for i in indices:
            true_instance = self.true_output[i]
            synthetic_instance = self.synthetic_output[i]

            # Records are paired by position; refuse to score misaligned data.
            if true_instance.get("input") != synthetic_instance.get("input"):
                raise ValueError("The input of the models and synthetic output is not the same")

            if true_instance["output"]["model2"] == synthetic_instance["output"]["model2"]:
                correct_output += 1

        return correct_output / len(indices)

    def get_evaluation_scores(self):
        """Return the overall, mismatched and matched accuracies as a dict.

        The ``"mismatched_accuracy"`` / ``"matched_accuracy"`` entries are
        ``nan`` when the corresponding subset of records is empty.
        """
        return {
            "overall_accuracy": self.get_overall_accuracy(),
            "mismatched_accuracy": self.get_mismatched_accuracy(),
            "matched_accuracy": self.get_matched_accuracy(),
        }