# -----------------------------------------------------------------------------
# MIT License
#
# Copyright (c) 2024 Ontolearn Team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# -----------------------------------------------------------------------------

"""Pyhon binders of other concept learners."""
import subprocess
from datetime import datetime
from typing import List, Dict
from .utils import create_experiment_folder
import re
import time
from ontolearn.learning_problem import PosNegLPStandard


class PredictedConcept:
    """Lightweight record exposing a prediction dictionary as attributes.

    Every keyword argument becomes an instance attribute, e.g.
    ``PredictedConcept(Prediction='Child', Accuracy=1.0).Prediction == 'Child'``.
    """

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def __iter__(self):
        """Yield only the predicted concept (the ``Prediction`` attribute)."""
        yield self.Prediction

    def __repr__(self):
        fields = ', '.join(f'{key}={value!r}' for key, value in self.__dict__.items())
        return f'{type(self).__name__}({fields})'


class DLLearnerBinder:
    """
    dl-learner python binder.

    Writes a DL-Learner configuration file for a positive/negative learning problem, runs the DL-Learner
    binary on it in a subprocess and parses the best class expression reported on its standard output.
    Supported models: 'celoe', 'ocel' and 'eltl'.
    """

    # Algorithm-specific configuration lines, keyed by model name.
    _ALGORITHM_SETTINGS = {
        'celoe': ('alg.type = "celoe"',
                  'alg.stopOnFirstDefinition = "true"'),
        'ocel': ('alg.type = "ocel"',
                 'alg.showBenchmarkInformation = "true"'),
        'eltl': ('alg.type = "eltl"',
                 'alg.maxNrOfResults = "1"',
                 'alg.stopOnFirstDefinition = "true"'),
    }

    def __init__(self, binary_path=None, model=None, kb_path=None, storage_path=".", max_runtime=3):
        """Create a binder for one DL-Learner algorithm.

        Args:
            binary_path: Path of the DL-Learner executable; it is invoked as ``binary_path <config file>``.
            model: Name of the algorithm: 'celoe', 'ocel' or 'eltl'.
            kb_path: Path of the knowledge base, written to the config as an "OWL File" source.
            storage_path: Directory for the generated config and output files. If None, a new experiment
                folder is created with ``create_experiment_folder``.
            max_runtime: Default limit in seconds, written as ``alg.maxExecutionTimeInSeconds``.

        Raises:
            AssertionError: If binary_path, model or kb_path is missing. An explicit raise (instead of an
                ``assert`` statement) keeps the check active under ``python -O``.
        """
        if not (binary_path and model and kb_path):
            print(f'binary_path:{binary_path}, model:{model}, kb_path:{kb_path} can not be None')
            raise AssertionError('binary_path, model and kb_path must all be provided.')
        self.binary_path = binary_path
        self.kb_path = kb_path
        self.name = model
        self.max_runtime = max_runtime
        if storage_path is not None:
            self.storage_path = storage_path
        else:
            self.storage_path, _ = create_experiment_folder()
        self.best_predictions = None
        self.config_name_identifier = None

    def write_dl_learner_config(self, pos: List[str], neg: List[str]) -> str:
        """Writes config file for dl-learner.

        Args:
            pos: A list of URIs of individuals indicating positive examples in concept learning problem.
            neg: A list of URIs of individuals indicating negatives examples in concept learning problem.

        Returns:
            str: Path of generated config file, ``<storage_path>/<model>_<timestamp>.conf``.

        Raises:
            AssertionError: If pos or neg is empty or does not hold strings.
            ValueError: If the model name is not one of 'celoe', 'ocel' or 'eltl'.
        """
        assert len(pos) > 0 and isinstance(pos[0], str)
        assert len(neg) > 0 and isinstance(neg[0], str)
        try:
            algorithm_settings = self._ALGORITHM_SETTINGS[self.name]
        except KeyError:
            raise ValueError('Wrong algorithm chosen.') from None

        def as_set(examples):
            # DL-Learner set literal, e.g. { "a","b"}
            return '{ ' + ','.join(f'"{example}"' for example in examples) + '}'

        # Entries equal to "\n" produce empty lines, because every entry is followed by a newline.
        text = [
            'rendering = "dlsyntax"',
            '// knowledge source definition',
            'cli.type = "org.dllearner.cli.CLI"',
            'ks.type = "OWL File"',
            '\n',
            '// knowledge source definition',
            f'ks.fileName = "{self.kb_path}"',
            '\n',
            'reasoner.type = "closed world reasoner"',
            'reasoner.sources = { ks }',
            '\n',
            'lp.type = "PosNegLPStandard"',
            'accuracyMethod.type = "fmeasure"',
            '\n',
            'lp.positiveExamples =' + as_set(pos),
            '\n',
            'lp.negativeExamples =' + as_set(neg),
            '\n',
            'alg.writeSearchTree = "true"',
            'op.type = "rho"',
            'op.useNumericDatatypes = "false"',
            'op.useCardinalityRestrictions = "false"',
            *algorithm_settings,
            f'alg.maxExecutionTimeInSeconds = {self.max_runtime}',
            '\n',
        ]

        path_to_config = f'{self.storage_path}/{self.name}_{datetime.now().strftime("%Y%m%d_%H%M%S_%f")}.conf'
        with open(path_to_config, "wb") as config_file:
            config_file.write(''.join(line + '\n' for line in text).encode("utf-8"))
        return path_to_config

    def fit(self, lp: "PosNegLPStandard", max_runtime: int = None):
        """Fit dl-learner model on a given positive and negative examples.

        Args:
            lp: Learning problem; ``lp.pos`` / ``lp.neg`` hold the positive / negative individuals, whose
                ``.str`` attribute is written to the config.
            max_runtime: Limit to stop the algorithm after n seconds. If given, it also replaces the default
                stored on the binder.

        Returns:
            self. The parsed result (plus its 'Runtime' in seconds) is stored in ``self.best_predictions``.
        """
        if max_runtime:
            self.max_runtime = max_runtime

        path_to_config = self.write_dl_learner_config(pos=[i.str for i in lp.pos],
                                                      neg=[i.str for i in lp.neg])
        # perf_counter is monotonic, so the measured runtime is immune to wall-clock adjustments.
        start = time.perf_counter()
        res = subprocess.run([self.binary_path, path_to_config], capture_output=True, universal_newlines=True)
        total_runtime = round(time.perf_counter() - start, 3)
        self.best_predictions = self.parse_dl_learner_output(res.stdout.splitlines(), path_to_config)
        self.best_predictions['Runtime'] = total_runtime
        return self

    def best_hypotheses(self, n: int = None) -> "PredictedConcept":
        """Return the best prediction wrapped in a PredictedConcept.

        Args:
            n: Currently unused.

        Returns:
            A PredictedConcept built from the prediction dictionary, e.g.
            {'Prediction': 'Child', 'Accuracy': 1.0, 'F-measure': 1.0, 'NumClassTested': 3, 'Runtime': 3.502},
            or None if no prediction is available.
        """
        # TODO: convert the predicted string into an OWL class expression object.
        predictions = self.best_hypothesis()
        if predictions is None:
            return None
        return PredictedConcept(**predictions)

    def best_hypothesis(self):
        """Return predictions if exists.

        Returns:
            The prediction dictionary, or None (after printing 'No prediction found.') if there is none.
        """
        if self.best_predictions:
            return self.best_predictions
        else:
            print('No prediction found.')

    @staticmethod
    def _extract_solutions(output_of_dl_learner: List[str]):
        """Return (solution lines starting at the last 'solutions' header or None, last 'Algorithm' line)."""
        solutions = None
        search_info = None
        for index, sentence in enumerate(output_of_dl_learner):
            # A header only counts if a following line exists and holds the first ranked solution.
            if ('solutions' in sentence and index + 1 < len(output_of_dl_learner)
                    and '1:' in output_of_dl_learner[index + 1]):
                solutions = output_of_dl_learner[index:]
            if 'Algorithm' in sentence:
                search_info = sentence
        return solutions, search_info

    @staticmethod
    def _count_descriptions_tested(search_info):
        """Return N from an '... N descriptions tested ...' line, or None if it is not reported."""
        if search_info is None:
            return None
        matches = re.findall(r'(\d+) descriptions tested', search_info)
        if not matches:
            return None
        assert len(matches) == 1
        return int(matches[0])

    @staticmethod
    def _split_solution(best_solution: str, token: str):
        """Split '1: <concept> <token>...' into the concept string and the quality information."""
        end_index = best_solution.index(token)
        return best_solution[len('1: '):end_index].rstrip(), best_solution[end_index:]

    @staticmethod
    def _extract_percentage(label: str, quality_info: str) -> str:
        """Return the number printed as '<label><number>%' in quality_info, using '.' as decimal separator."""
        # The decimal separator may depend on the JVM locale, so both '.' and ',' are accepted.
        matches = re.findall(re.escape(label) + r'(\d+(?:[.,]\d+)?)%', quality_info)
        assert len(matches) == 1, f'Could not parse {label!r} from {quality_info!r}'
        return matches[0].replace(',', '.')

    def parse_dl_learner_output(self, output_of_dl_learner: List[str], file_path: str) -> Dict:
        """Parse the output received from executing dl-learner.

        The raw output is stored, one entry per line, in ``file_path + '.txt'``.

        DL-Learner does not provide a unified output:
            ELTL  => no information pertaining to the number of concepts tested.
            CELOE => "Algorithm terminated successfully (time: 245ms, 188 descriptions tested, 69 nodes in the
                     search tree)."
            OCEL  => "Algorithm stopped (4505 descriptions tested)."

        Args:
            output_of_dl_learner: The output lines of dl-learner to parse.
            file_path: Base path of the output file; '.txt' is appended to it.

        Returns:
            A dictionary of {'Prediction': ..., 'Accuracy': ..., 'F-measure': ..., 'NumClassTested': ...} with
            accuracy and F-measure scaled into [0, 1]. Undetermined values default to -1 (i.e. -0.01 after
            scaling for accuracy and F-measure). If no solution was found, 'Prediction' is None and the
            dictionary also holds the 'Model' name.

        Raises:
            AssertionError: If the reported best solution cannot be parsed.
            ValueError: If the model name is unknown or the solution line lacks the expected quality token.
        """
        acc = -1.0
        f_measure = -1.0
        num_expression_tested = -1

        solutions, search_info = self._extract_solutions(output_of_dl_learner)

        # (1) Store the raw output. It contains DL-syntax symbols (e.g. ⊔, ⊓, ¬, ∃), so UTF-8 is used
        # explicitly instead of the platform default encoding, which may not be able to encode them.
        with open(f'{file_path}.txt', 'w', encoding='utf-8') as w:
            for sentence in output_of_dl_learner:
                w.write(sentence + '\n')

        if not solutions:
            # (2a) No solution found: report only the search statistics, if any were printed.
            print('#################')
            print('#######{}##########'.format(self.name))
            print('#################')
            for line in output_of_dl_learner[-3:-1]:
                print(line)
                if 'descriptions' in line:
                    search_info = line
            print('#################')
            print('#######{}##########'.format(self.name))
            print('#################')
            num_tested = self._count_descriptions_tested(search_info)
            if num_tested is not None:
                num_expression_tested = num_tested
            return {'Model': self.name, 'Prediction': None, 'Accuracy': float(acc) * .01,
                    'F-measure': float(f_measure) * .01, 'NumClassTested': int(num_expression_tested)}

        # (2b) Sanity check of the relevant part of the output; on mismatch print it for debugging and continue.
        if '1: ' not in solutions[1][:5]:
            print(type(solutions))
            print('####')
            print(solutions[0])
            print('####')
            print(len(solutions))

        # The solutions block has the form:
        #   solutions ......:
        #   1: Parent (pred. acc.: 100.00%, F-measure: 100.00%)
        #   2: ⊤ (pred. acc.: 50.00%, F-measure: 66.67%)
        best_solution = solutions[1].lstrip()

        if self.name == 'ocel':
            # e.g. => 1: Sister ⊔ (Female ⊓ (¬Granddaughter)) (accuracy 100%, length 16, depth 2)
            best_concept_str, quality_info = self._split_solution(best_solution, '(accuracy ')
            acc = self._extract_percentage('accuracy ', quality_info)
            num_tested = self._count_descriptions_tested(search_info)
            assert num_tested is not None, 'OCEL output does not report the number of descriptions tested.'
            num_expression_tested = num_tested
        elif self.name in ('celoe', 'eltl'):
            # e.g. => 1: Sister ⊔ (∃ married.Brother) (pred. acc.: 90.24%, F-measure: 91.11%)
            best_concept_str, quality_info = self._split_solution(best_solution, '(pred. acc.: ')
            acc = self._extract_percentage('pred. acc.: ', quality_info)
            f_measure = self._extract_percentage('F-measure: ', quality_info)
            if search_info is not None:
                num_tested = self._count_descriptions_tested(search_info)
                if num_tested is None:
                    # Only ELTL omits the number of tested descriptions.
                    assert self.name == 'eltl'
                else:
                    num_expression_tested = num_tested
        else:
            raise ValueError(f'Unknown model: {self.name}')
        # Percentages (0-100) are scaled into the range [0.0, 1.0].
        return {'Prediction': best_concept_str, 'Accuracy': float(acc) * .01, 'F-measure': float(f_measure) * .01,
                'NumClassTested': int(num_expression_tested)}

    @staticmethod
    def train(dataset: List = None) -> None:
        """ Dummy method, currently it does nothing."""

    def fit_from_iterable(self, dataset: List = None, max_runtime=None) -> List[Dict]:
        """Fit dl-learner model on a list of given positive and negative examples.

        Not implemented yet.

        Args:
            dataset: A list of tuple (s,p,n) where
                s => string representation of target concept,
                p => positive examples, i.e. s(p)=1 and
                n => negative examples, i.e. s(n)=0.
            max_runtime: Limit to stop the algorithm after n seconds.

        Returns:
            Intended: one prediction dictionary per learning problem.

        Raises:
            NotImplementedError: Always.
        """
        # TODO: build a PosNegLPStandard from each (s, p, n) and collect
        #  self.fit(lp, max_runtime=max_runtime).best_hypothesis().
        raise NotImplementedError('fit_from_iterable is not implemented yet.')
