import numpy as np
import pandas as pd
from pyscipopt import Model, quicksum, SCIP_PARAMSETTING, SCIP_PARAMEMPHASIS
from typing import Any, Tuple
from utils import Console, Convert, Metrics


class Explainer:
    """Sparse linear explainer formulated as a mixed-integer program (SCIP).

    The explainer fits a weight vector ``w`` with at most ``maxsize`` nonzero
    entries that minimises the squared error between ``A @ w`` and the
    predictions of a black-box model on the sample matrix ``A``, subject to
    reproducing the model's prediction on a given target point.
    """

    # Solver statuses after which a solution vector is read back from SCIP.
    _SOLVED_STATUSES = ("optimal", "timelimit")

    def __init__(self, timeout: int = 120, verbose: bool = False):
        """Create an explainer.

        Args:
            timeout: Value passed to SCIP's ``limits/time`` parameter.
            verbose: Forwarded to ``Console`` and used to pick SCIP's
                display verbosity.
        """
        self.console = Console(verbose=verbose)
        # ``mse`` is not read anywhere in this class; kept so that external
        # code accessing it keeps working.
        self.mse = float("inf")
        # ``error`` is what ``test`` writes and ``__str__`` reads; initialise
        # it so printing a fresh explainer does not raise AttributeError.
        self.error = float("inf")
        self.explanation = None
        self.model = None
        self.solution = None
        self.status = "Unknown"
        self.timeout = timeout
        self.walltime = float("inf")

    def __str__(self):
        """Return a multi-line summary built with ``Console.string``."""
        # ``explanation`` stays ``None`` until ``explain`` has run.
        size = 0 if self.explanation is None else len(self.explanation)
        parts = [
            self.console.string("[MIP] Information", endl=True),
            self.console.string("Explanation", self.explanation, endl=True),
            self.console.string("Size", size, endl=True),
            self.console.string("Wall time", self.walltime, endl=True),
            self.console.string("Error", self.error, endl=True),
        ]
        return "".join(parts)

    def explain(
        self,
        *,
        features: pd.Index,
        maxsize: int,
        model: Any,
        samples: np.ndarray,
        target: np.ndarray,
        test_samples=None,
    ) -> Tuple[dict, float]:
        """Compute an explanation of ``model`` around ``target``.

        Args:
            features: Column labels handed to ``Convert.vec_to_dict``.
            maxsize: Maximum number of nonzero weights in the explanation.
            model: Object exposing ``predict``; stored on ``self.model``.
            samples: Sample matrix (rows are samples, columns are features).
            target: Point whose prediction the explanation must reproduce.
            test_samples: Accepted for interface compatibility; not used here.

        Returns:
            ``(explanation, walltime)``. The explanation is an empty dict
            when the solver status is neither ``"optimal"`` nor
            ``"timelimit"``.
        """
        self.console.log("[MIP] Find explanation")
        self.model = model

        # Labels are the black-box predictions on the samples; the hyperplane
        # pairs the target point with the model's prediction on it.
        # NOTE(review): ``build`` uses the second entry as the right-hand side
        # of a single linear constraint, so ``predict(target)`` is presumably
        # scalar-like -- confirm against the model wrapper.
        predictions = self.model.predict(samples)
        hyperplane = (target, self.model.predict(target))

        self.solve(
            samples=samples,
            labels=predictions,
            hyperplane=hyperplane,
            maxsize=maxsize,
        )

        # Extract explanation and statistics
        self.explanation = {}
        if self.status in self._SOLVED_STATUSES:
            self.explanation = Convert.vec_to_dict(
                vector=self.solution, columns=features
            )
            self.console.log("[MIP] Explanation")
            self.console.log(self.explanation)
            self.console.log("[MIP] Explanation Size", len(self.explanation))
            self.console.log("[MIP] Wall time", self.walltime)
        return self.explanation, self.walltime

    def test(self, *, samples: np.ndarray) -> float:
        """Score the current solution on ``samples`` via ``Metrics.rmse``.

        Returns ``inf`` (and stores it on ``self.error``) when no solution is
        available, i.e. the status is neither ``"optimal"`` nor
        ``"timelimit"``.
        """
        error = float("inf")
        if self.status in self._SOLVED_STATUSES:
            self.console.log("[MIP] Test solution")
            error = Metrics.rmse(
                samples=samples,
                labels=self.model.predict(samples),
                solution=self.solution,
            )
            self.console.log("[MIP] Error", error)
        self.error = error
        return error

    def solve(
        self,
        *,
        samples: np.ndarray,
        labels: np.ndarray,
        hyperplane: Tuple[np.ndarray, float],
        maxsize: int,
    ) -> None:
        """Build the MIP and run SCIP; store status, solution and wall time."""
        self.console.log("[MIP] Solve problem using SCIP")
        self.build(
            samples=samples, labels=labels, hyperplane=hyperplane, maxsize=maxsize
        )
        self.status, self.solution, self.walltime = self.optimize()

    def build(
        self,
        *,
        samples: np.ndarray,
        labels: np.ndarray,
        hyperplane: Tuple[np.ndarray, float],
        maxsize: int,
    ) -> None:
        """Formulate the sparse regression problem on ``self.problem``.

        Args:
            samples: Matrix ``A`` with one row per sample.
            labels: Vector ``b`` of target values, one per sample.
            hyperplane: Pair ``(x, y)``; the weights must satisfy
                ``w . x == y``.
            maxsize: Maximum number ``k`` of nonzero weights.

        Side effects:
            Sets ``self.problem`` (the SCIP model) and ``self.result`` (the
            array of weight variables read back by ``optimize``).
        """
        self.problem = Model()

        # Constants
        A = samples
        b = labels
        m, n = A.shape[0], A.shape[1]
        min_zeros = n - maxsize  # at least this many weights must be zero
        x, y = hyperplane[0], hyperplane[1]

        # The variable arrays are filled element by element so that NumPy
        # never tries to interpret a solver variable as a nested sequence.

        # Indicator vector of inverse support set: s[j] = 1 indicates that
        # j is not in the support S (w[j] is forced to zero by the SOS1 below).
        s = np.empty(n, dtype=object)
        for j in range(n):
            s[j] = self.problem.addVar(vtype="B", name=f"s_{j}")

        # Explanation variables (w is a sparse linear function)
        w = np.empty(n, dtype=object)
        for j in range(n):
            w[j] = self.problem.addVar(vtype="C", lb=-1.0, ub=1.0, name=f"w_{j}")

        # Prediction variables (p = A w), bounded to [-1, 1]
        p = np.empty(m, dtype=object)
        for i in range(m):
            p[i] = self.problem.addVar(vtype="C", lb=-1.0, ub=1.0, name=f"p_{i}")

        # Objective variable (epigraph of the squared error)
        o = self.problem.addVar(vtype="C", name="o")

        # Cardinality constraint (sparsity): at least n - k indicators are set
        self.problem.addCons(
            quicksum(s[j] for j in range(n)) >= min_zeros, name="cd_cons"
        )

        # Linear constraint (consistency with the hyperplane point)
        self.problem.addCons(
            quicksum(w[j] * x[j] for j in range(n)) == y, name="ln_cons"
        )

        # Objective equality constraints: p[i] is the linear prediction on row i
        for i in range(m):
            self.problem.addCons(
                p[i] == quicksum(w[j] * A[i, j] for j in range(n)),
                name=f"eq_cons_{i}",
            )

        # SOS constraints (at most one of w[j], s[j] is nonzero)
        for j in range(n):
            self.problem.addConsSOS1([w[j], s[j]], name=f"sos_cons_{j}")

        # Objective function: minimise o subject to o >= sum of squared errors
        self.problem.addCons(
            o >= quicksum((p[i] - b[i]) ** 2 for i in range(m)), name="qd_cons"
        )
        self.problem.setObjective(o, "minimize")

        self.result = w

    def optimize(self) -> Tuple[str, np.ndarray, float]:
        """Run SCIP with feasibility emphasis and read back the weights.

        Returns:
            ``(status, solution, walltime)`` where ``solution`` has one entry
            per weight variable. It is all zeros when the status is neither
            ``"optimal"`` nor ``"timelimit"``, or when SCIP stopped without
            finding any feasible solution.
        """
        # Set SCIP parameters
        self.problem.setHeuristics(SCIP_PARAMSETTING.DEFAULT)
        self.problem.setEmphasis(SCIP_PARAMEMPHASIS.FEASIBILITY)
        verblevel = 4 if self.console.verbose else 1
        self.problem.setParam("display/verblevel", verblevel)
        self.problem.setParam("limits/time", self.timeout)

        # Optimize objective
        self.problem.optimize()
        status = self.problem.getStatus()
        walltime = self.problem.getSolvingTime()

        # Extract solution
        w = self.result
        n = w.shape[0]
        solution = np.zeros(n, dtype="float64")
        # A time limit can be reached before any feasible point is found;
        # only query variable values when SCIP actually stored a solution.
        if status in self._SOLVED_STATUSES and self.problem.getNSols() > 0:
            for j in range(n):
                solution[j] = self.problem.getVal(w[j])

        return status, solution, walltime