import cvxpy as cp
import numpy as np
import pandas as pd
from typing import Any, Tuple, Union
from utils import Console, Convert, Metrics


class Explainer:
    """Convex-optimization explainer built on CVXPY.

    Fits a weight vector ``w`` that reproduces the model's predictions on a
    set of samples in the least-squares sense. The fit is subject to two
    constraints: it must pass exactly through the model's prediction at the
    target point, and ``w`` must satisfy a norm bound.
    """

    # Constructor for the convex explainer (any order accepted by cp.norm is allowed)
    def __init__(self, norm: Union[int, str] = 1, verbose: bool = False):
        """Create an explainer.

        Args:
            norm: Order of the norm in the ``cp.norm(w, norm) <= maxsize``
                constraint; passed straight through to ``cp.norm``.
            verbose: Forwarded to ``Console``; ``solve()`` also passes
                ``self.console.verbose`` to the CVXPY solver.
        """
        self.console = Console(verbose=verbose)
        self.norm = norm
        # Result attributes start in a "no solution" state. This lets
        # __str__ and test() run safely before explain()/solve() is called.
        self.model = None
        self.problem = None
        self.status = None
        self.solution = None
        self.explanation = {}
        self.walltime = float("inf")
        self.error = float("inf")

    # print CVX attributes
    def __str__(self):
        """Return a human-readable summary of the last explanation and test error."""
        text = self.console.string("[CVX] Information", endl=True)
        text += self.console.string("Explanation", self.explanation, endl=True)
        text += self.console.string("Size", len(self.explanation), endl=True)
        text += self.console.string("Wall time", self.walltime, endl=True)
        # ``error`` is the attribute populated by test().
        text += self.console.string("ERROR", self.error, endl=True)
        return text

    # Main function
    def explain(
        self,
        *,
        features: pd.Index,
        maxsize: int,
        model: Any,
        samples: np.ndarray,
        target: np.ndarray,
        test_samples=None,
    ) -> Tuple[dict, float]:
        """Compute a linear explanation of ``model`` around ``target``.

        Args:
            features: Feature names, passed to ``Convert.vec_to_dict`` as the
                columns of the solution vector.
            maxsize: Upper bound on ``cp.norm(w, self.norm)``.
            model: Object exposing ``predict``. It is called on ``samples``
                and on ``target``, and is kept as ``self.model`` for test().
            samples: Points whose model predictions the weights must fit.
            target: Point the fitted hyperplane is forced through
                (``w @ target == model.predict(target)``). Its ``shape[0]``
                sets the number of weights, so it is presumably a 1-D feature
                vector; TODO confirm that ``model.predict`` accepts it as given.
            test_samples: Currently unused; accepted for interface
                compatibility.

        Returns:
            ``(explanation, walltime)``. ``explanation`` is the dict built by
            ``Convert.vec_to_dict`` when the solver status is ``"optimal"``
            and ``{}`` otherwise. ``walltime`` is the value stored by
            ``solve()``.
        """
        self.console.log("[CVX] Find explanation")
        self.model = model

        # Get data
        A = samples
        b = self.model.predict(A)
        h = (target, self.model.predict(target))
        k = maxsize

        # Solve problem
        self.solve(samples=A, labels=b, hyperplane=h, maxsize=k)

        # Extract explanation and statistics
        self.explanation = {}
        if self.status == "optimal":
            self.explanation = Convert.vec_to_dict(
                vector=self.solution, columns=features
            )
            self.console.log("[CVX] Explanation")
            self.console.log(self.explanation)
            self.console.log("[CVX] Explanation Size", len(self.explanation))
            self.console.log("[CVX] Wall time", self.walltime)
        return self.explanation, self.walltime

    # Test explanation
    def test(self, *, samples: np.ndarray) -> float:
        """Score the fitted weights against the model on ``samples``.

        Returns ``Metrics.rmse`` of ``self.solution`` against
        ``self.model.predict(samples)`` when the last solve was ``"optimal"``.
        Otherwise, including before any solve, it returns ``inf``. The
        result is also stored in ``self.error``.
        """
        error = float("inf")
        if self.status == "optimal":
            self.console.log("[CVX] Test solution")
            error = Metrics.rmse(
                samples=samples,
                labels=self.model.predict(samples),
                solution=self.solution,
            )
            self.console.log("[CVX] Error", error)
        self.error = error
        return self.error

    # Build and solve convex optimization problem
    def solve(
        self,
        *,
        samples: np.ndarray,
        labels: np.ndarray,
        hyperplane: Tuple[np.ndarray, float],
        maxsize: int,
    ) -> None:
        """Build and solve the constrained least-squares problem.

        Solves ``minimize ||samples @ w - labels||_2^2`` subject to:
            w @ hyperplane[0] == hyperplane[1]
            cp.norm(w, self.norm) <= maxsize

        Sets ``self.problem``, ``self.status``, ``self.solution`` and
        ``self.walltime``. When the status is not ``"optimal"``, the solution
        is a zero vector and the wall time is ``inf``.
        """
        self.console.log("[CVX] Solve problem")
        # Constants
        A = samples
        b = labels
        x = hyperplane[0]
        y = hyperplane[1]
        n = x.shape[0]
        k = maxsize
        q = self.norm
        # Variables
        w = cp.Variable(n)
        # Constraints
        constraints = [w @ x == y, cp.norm(w, q) <= k]
        # Objective
        objective = cp.Minimize(cp.sum_squares(A @ w - b))
        # Problem
        self.problem = cp.Problem(objective, constraints=constraints)
        # Solve problem
        self.problem.solve(verbose=self.console.verbose)
        # Get solution and statistics
        self.status = self.problem.status
        if self.status == "optimal":
            self.solution = w.value
            # NOTE(review): some solvers do not report solve_time, in which
            # case CVXPY leaves it as None. Confirm for the solvers in use.
            self.walltime = self.problem.solver_stats.solve_time
        else:
            self.solution = np.zeros(n, dtype="float64")
            self.walltime = float("inf")
