import numpy as np
import pandas as pd

import gurobipy as gp
from gurobipy import GRB

from typing import Any, Tuple
from utils import Console, Convert, Metrics


class Explainer:
    """Sparse linear explainer solved as a mixed-integer QP with GUROBI.

    Given a black-box ``model`` and a sample matrix ``A`` (one sample per
    row), the explainer searches for a weight vector ``w`` with entries in
    ``[-1, 1]`` and at most ``maxsize`` nonzero entries such that

    * ``sum_j w[j] * target[j] == model.predict(target)`` (the explanation
      is exact at the point being explained),
    * every prediction ``A[i] . w`` lies in ``[-1, 1]``, and
    * ``sum_i (A[i] . w - model.predict(A)[i]) ** 2`` is minimal.

    After :meth:`explain`, the solver status, the weight vector
    (``solution``), the wall time and the explanation are kept on the
    instance.
    """

    # Constructor for GUROBI Explainer
    def __init__(self, timeout: int = 120, verbose: bool = False):
        """Create an explainer.

        Args:
            timeout: GUROBI ``TimeLimit`` (seconds) used for each solve.
            verbose: passed to ``Console``; when true, GUROBI also logs to
                the console (see :meth:`optimize`).
        """
        self.console = Console(verbose=verbose)
        self.mse = float("inf")
        self.explanation = None
        self.status = "Unknown"
        self.timeout = timeout
        self.walltime = float("inf")
        self.iteration = 0
        self.test_samples = None
        # Set by explain()/build()/optimize(); initialised here so every
        # attribute exists even before the first solve.
        self.model = None
        self.problem = None
        self.result = None
        self.solution = None

    # Print GUROBI attributes
    def __str__(self):
        """Return a human-readable summary of the last explanation."""
        # explanation is None until explain() has run; report size 0 then.
        size = 0 if self.explanation is None else len(self.explanation)
        tstr = self.console.string("[GUROBI] Information", endl=True)
        tstr += self.console.string("Explanation", self.explanation, endl=True)
        tstr += self.console.string("Size", size, endl=True)
        tstr += self.console.string("Wall time", self.walltime, endl=True)
        tstr += self.console.string("Mean squared error", self.mse, endl=True)
        return tstr

    # Main function
    def explain(
        self,
        *,
        features: pd.Index,
        maxsize: int,
        model: Any,
        samples: np.ndarray,
        target: np.ndarray,
        test_samples=None,
    ) -> Tuple[dict, float]:
        """Compute a sparse linear explanation of ``model`` around ``target``.

        Args:
            features: column labels passed to ``Convert.vec_to_dict`` to name
                the explanation entries.
            maxsize: maximum number of nonzero weights in the explanation.
            model: black-box model; only its ``predict`` method is used.
            samples: matrix ``A`` (one sample per row); the explanation is
                fitted to ``model.predict(A)``.
            target: the point being explained; the explanation must reproduce
                ``model.predict(target)`` there exactly.
                NOTE(review): assumes ``predict(target)`` yields a single
                value usable on the right side of ``==`` — confirm shape.
            test_samples: optional held-out samples; when given, the error of
                every new incumbent on them is printed during the solve.

        Returns:
            ``(explanation, walltime)`` where ``explanation`` is
            ``Convert.vec_to_dict`` of the weight vector, or ``{}`` when the
            solver produced no solution.
        """
        self.console.log("[GUROBI] Find explanation")
        self.model = model
        self.test_samples = test_samples

        # Get data
        A = samples
        b = self.model.predict(A)
        h = (target, self.model.predict(target))
        k = maxsize

        # Solve problem
        self.solve(samples=A, labels=b, hyperplane=h, maxsize=k)

        # Extract explanation and statistics
        self.explanation = {}
        if self._has_solution():
            self.explanation = Convert.vec_to_dict(
                vector=self.solution, columns=features
            )
            self.console.log("[GUROBI] Explanation")
            self.console.log(self.explanation)
            self.console.log("[GUROBI] Explanation Size", len(self.explanation))
            self.console.log("[GUROBI] Wall time", self.walltime)
        return self.explanation, self.walltime

    # Test explanation
    def test(self, *, samples: np.ndarray) -> float:
        """Evaluate the current solution on ``samples``.

        The reference labels are ``self.model.predict(samples)`` and the
        error is computed by ``Metrics.rmse``.

        Returns:
            The error, or ``inf`` when no solution is available.
        """
        error = float("inf")
        if self._has_solution():
            self.console.log("[GUROBI] Test solution")
            error = Metrics.rmse(
                samples=samples,
                labels=self.model.predict(samples),
                solution=self.solution,
            )
            # NOTE(review): the label says "Mean Squared Error" but the value
            # comes from Metrics.rmse — confirm which metric is intended.
            self.console.log("[GUROBI] Mean Squared Error", error)
        return error

    # Build problem model and run GUROBI solver
    def solve(
        self,
        *,
        samples: np.ndarray,
        labels: np.ndarray,
        hyperplane: Tuple[np.ndarray, float],
        maxsize: int,
    ) -> None:
        """Build the model and optimize it.

        Stores the outcome in ``self.status``, ``self.solution`` and
        ``self.walltime``.
        """
        self.console.log("[GUROBI] Solve problem using GUROBI")
        self.build(
            samples=samples, labels=labels, hyperplane=hyperplane, maxsize=maxsize
        )
        self.status, self.solution, self.walltime = self.optimize()

    # Build problem model
    def build(
        self,
        *,
        samples: np.ndarray,
        labels: np.ndarray,
        hyperplane: Tuple[np.ndarray, float],
        maxsize: int,
    ) -> gp.Model:
        """Build the sparse-regression MIQP.

        Args:
            samples: matrix ``A`` of shape ``(m, n)``.
            labels: targets ``b`` (length ``m``) to fit with ``A @ w``.
            hyperplane: ``(x, y)``; the constraint ``x . w == y`` is imposed.
            maxsize: maximum number of nonzero entries of ``w``.

        Returns:
            The GUROBI model, also stored as ``self.problem``. The weight
            variables ``w`` are stored as ``self.result``.

        Raises:
            gp.GurobiError, AttributeError: re-raised after being printed so
                the failure surfaces here rather than as an unrelated error
                later in :meth:`optimize`.
        """
        A = samples
        b = labels
        n = A.shape[1]
        m = A.shape[0]
        k = maxsize
        min_excluded = n - k
        x = hyperplane[0]
        y = hyperplane[1]

        self.problem = None

        try:
            self.problem = gp.Model("mip1")

            # Indicator vector of inverse support set: s[j] = 1 indicates that j is not in S
            s = np.zeros(n, dtype=object)
            for j in range(n):
                s[j] = self.problem.addVar(vtype=GRB.BINARY, name=f"s_{j}")

            # Explanation variables (w is a sparse linear function)
            w = np.zeros(n, dtype=object)
            for j in range(n):
                w[j] = self.problem.addVar(
                    vtype=GRB.CONTINUOUS, lb=-1.0, ub=1.0, name=f"w_{j}"
                )

            # Prediction variables (p = Aw)
            p = np.zeros(m, dtype=object)
            for i in range(m):
                p[i] = self.problem.addVar(
                    vtype=GRB.CONTINUOUS, lb=-1.0, ub=1.0, name=f"p_{i}"
                )

            # Cardinality constraint (sparsity): at least n - k features are
            # excluded, so together with the SOS constraints below at most k
            # weights can be nonzero.
            self.problem.addConstr(
                gp.quicksum(s[j] for j in range(n)) >= min_excluded, name="cd_cons"
            )

            # Linear constraint (consistency): exact at the explained point
            self.problem.addConstr(
                gp.quicksum(w[j] * x[j] for j in range(n)) == y, name="ln_cons"
            )

            # Objective equality constraints: p[i] = A[i] . w
            for i in range(m):
                self.problem.addConstr(
                    p[i] == gp.quicksum(w[j] * A[i, j] for j in range(n)),
                    name=f"eq_cons_{i}",
                )

            # SOS1 constraints: at most one of (w[j], s[j]) is nonzero, so an
            # excluded feature (s[j] = 1) forces w[j] = 0.
            for j in range(n):
                self.problem.addSOS(GRB.SOS_TYPE1, [w[j], s[j]])

            # Objective function: squared error between p and the labels
            self.problem.setObjective(
                gp.quicksum((p[i] - b[i]) ** 2 for i in range(m)), GRB.MINIMIZE
            )

            self.result = w

        except gp.GurobiError as e:
            print(f"Error code {e.errno}: {e}")
            raise
        except AttributeError as e:
            print("Encountered an attribute error", e)
            raise

        return self.problem

    def _has_solution(self) -> bool:
        """Return True when the last solve left a usable incumbent."""
        return (
            self.status in [GRB.OPTIMAL, GRB.TIME_LIMIT]
            and self.problem is not None
            and self.problem.SolCount > 0
        )

    # Run GUROBI optimizer on the model built by build()
    def optimize(self) -> Tuple[int, np.ndarray, float]:
        """Optimize ``self.problem`` once.

        When ``self.test_samples`` is set, a MIPSOL callback prints
        ``"o", iteration, runtime, error`` for every new incumbent.

        Returns:
            ``(status, solution, walltime)``: the GRB status code, the weight
            vector (all zeros when no incumbent exists) and the runtime.
        """
        self.problem.Params.LogToConsole = 1 if self.console.verbose else 0

        self.problem.setParam(gp.GRB.Param.Threads, 1)
        self.problem.setParam("TimeLimit", self.timeout)

        self.iteration += 1
        weights = list(self.result)

        if self.test_samples is not None:
            # Loop-invariant: compute the reference labels once, not per incumbent.
            test_samples = self.test_samples
            test_labels = self.model.predict(test_samples)

            def callback(model, where):
                if where == GRB.Callback.MIPSOL:
                    runtime = model.cbGet(GRB.Callback.RUNTIME)
                    incumbent = np.asarray(
                        model.cbGetSolution(weights), dtype="float64"
                    )
                    error = Metrics.rmse(
                        samples=test_samples,
                        labels=test_labels,
                        solution=incumbent,
                    )
                    print("o", self.iteration, runtime, error)

            self.problem.optimize(callback)
        else:
            # No held-out data: a progress callback would have nothing to score.
            self.problem.optimize()

        status = self.problem.Status
        walltime = self.problem.Runtime

        # Extract solution; TIME_LIMIT may be reached before any incumbent
        # exists, in which case reading .x would raise.
        solution = np.zeros(len(weights), dtype="float64")
        if status in [GRB.OPTIMAL, GRB.TIME_LIMIT] and self.problem.SolCount > 0:
            for j, var in enumerate(weights):
                solution[j] = var.x

        return status, solution, walltime
