import lime.lime_tabular
import numpy as np
import pandas as pd
import time
from typing import Any, Tuple
from utils import Console, Convert, Metrics


class Explainer:
    """Wrap LIME's tabular explainer to produce a sparse regression explanation.

    Call :meth:`explain` to compute an explanation for one target instance.
    Optionally call :meth:`test` afterwards to score that explanation against
    the model on other samples.
    """

    def __init__(self, verbose: bool = False) -> None:
        """Create an explainer.

        Args:
            verbose: Passed to ``Console``; its ``verbose`` attribute is later
                forwarded to LIME as well (see :meth:`solve`).
        """
        self.console = Console(verbose=verbose)
        # Pre-initialise result state so that __str__ can run, and test() can
        # fail with a clear message, before explain() has been called.
        self.model: Any = None
        self.solution: Any = None
        self.explanation: dict = {}
        self.walltime: float = 0.0
        self.error: float = float("nan")  # NaN until test() has run

    def __str__(self) -> str:
        """Return a multi-line summary of the last explanation and its metrics."""
        parts = [
            self.console.string("[LIME] Information", endl=True),
            self.console.string("Explanation", self.explanation, endl=True),
            self.console.string("Size", len(self.explanation), endl=True),
            self.console.string("Wall time", self.walltime, endl=True),
            self.console.string("Error", self.error, endl=True),
        ]
        return "".join(parts)

    def explain(
        self,
        *,
        features: pd.Index,
        maxsize: int,
        model: Any,
        samples: np.ndarray,
        target: np.ndarray,
        test_samples=None,
    ) -> Tuple[dict, float]:
        """Compute and log a LIME explanation for ``target``.

        Args:
            features: Feature names, passed to LIME and used to label the result.
            maxsize: Maximum number of features in the explanation
                (LIME's ``num_features``).
            model: The model being explained; must expose ``predict``.
            samples: Training data LIME uses to build its explainer.
            target: The single instance to explain.
            test_samples: Currently unused. It is accepted only so existing
                callers keep working; call :meth:`test` to score the result.

        Returns:
            ``(explanation, walltime)``. ``explanation`` is the result of
            ``Convert.vec_to_dict`` on the solution (presumably a feature ->
            weight mapping; verify against ``utils``). ``walltime`` is the
            number of seconds spent inside LIME's ``explain_instance``.
        """
        self.console.log("[LIME] Find explanation")
        self.model = model

        # Solve problem
        self.solve(features=features, maxsize=maxsize, samples=samples, target=target)

        # Extract explanation and statistics
        self.explanation = Convert.vec_to_dict(vector=self.solution, columns=features)
        self.console.log("[LIME] Explanation")
        self.console.log(self.explanation)
        self.console.log("[LIME] Explanation Size", len(self.explanation))
        self.console.log("[LIME] Wall time", self.walltime)
        return self.explanation, self.walltime

    def test(self, *, samples: np.ndarray) -> float:
        """Score the current solution against the model on ``samples``.

        The error is computed by ``Metrics.rmse`` and compares the solution
        with ``self.model.predict(samples)``. It is stored in ``self.error``
        and also returned.

        Raises:
            RuntimeError: If no solution exists yet, i.e. :meth:`explain` has
                not been called.
        """
        if self.solution is None or self.model is None:
            raise RuntimeError("[LIME] explain() must be called before test()")
        self.console.log("[LIME] Test solution")
        error = Metrics.rmse(
            samples=samples, labels=self.model.predict(samples), solution=self.solution
        )
        self.console.log("[LIME] Error", error)
        self.error = error
        return error

    def solve(
        self,
        *,
        features: pd.Index,
        maxsize: int,
        samples: np.ndarray,
        target: np.ndarray,
    ) -> None:
        """Run LIME on ``target`` and store the resulting solution vector.

        Requires ``self.model`` to be set; :meth:`explain` does this.
        Sets ``self.lime``, ``self.walltime`` and ``self.solution``.
        """
        self.lime = lime.lime_tabular.LimeTabularExplainer(
            samples,
            feature_names=features,
            mode="regression",
            # Without discretisation, as_list() labels terms by feature name,
            # which presumably lets Convert.list_to_vec align them with
            # `features`.
            discretize_continuous=False,
            verbose=self.console.verbose,
        )
        # perf_counter is a monotonic clock, so unlike time.time() it is not
        # affected by system clock adjustments and the elapsed time cannot be
        # negative.
        start = time.perf_counter()
        result = self.lime.explain_instance(
            target, self.model.predict, num_features=maxsize
        )
        self.walltime = time.perf_counter() - start
        # as_list() yields (feature, weight) pairs.
        self.solution = Convert.list_to_vec(list=result.as_list(), columns=features)
