import cvxpy as cp
import numpy as np
import pandas as pd
import time
from collections import defaultdict
from typing import Any, Tuple, Union
from utils import Console, Convert, Metrics


class Explainer:
    """Iterative Hard Thresholding (IHT) explainer (beta version).

    Fits a sparse linear surrogate ``w`` for a black-box ``model``. It
    minimises ``||A w - model.predict(A)||^2`` subject to the hyperplane
    constraint ``<w, target> == model.predict(target)``. Every gradient step
    is projected onto vectors with at most ``maxsize`` non-zero entries.
    The pipeline has three stages:

    1. an optional convex warm start with an L1 relaxation (``preprocess``);
    2. projected gradient descent with a greedy hard-thresholding
       projection (``projectedGradientDescent`` / ``projection``);
    3. an optional least-squares refit on the final support
       (``postprocess``).
    """

    def __init__(
        self,
        iterations: int = 1000,
        preprocessing: bool = True,
        postprocessing: bool = True,
        stepsize: float = 1.0,
        verbose: bool = False,
    ):
        """Configure the explainer.

        Args:
            iterations: Number of projected-gradient-descent iterations.
            preprocessing: If True, warm-start PGD from the solution of an
                L1-relaxed convex problem; otherwise start from zeros.
            postprocessing: If True, refit the PGD support with a
                constrained least-squares problem.
            stepsize: Gradient step size ``eta``.
            verbose: Forwarded to ``Console`` and to the cvxpy solvers.
        """
        self.console = Console(verbose=verbose)
        # Postprocessed coefficients with |value| below this are snapped to 0.
        self.epsilon = 1e-6
        self.iterations = iterations
        self.preprocessing = preprocessing
        self.postprocessing = postprocessing
        self.stepsize = stepsize

    def __str__(self):
        """Return a printable summary of the last run.

        Requires ``explain`` to have been called. The error line reads
        ``None`` until ``test`` has been called.
        """
        text = self.console.string("[IHT] Information", endl=True)
        text += self.console.string("Explanation", self.explanation, endl=True)
        text += self.console.string("Size", len(self.explanation), endl=True)
        text += self.console.string("Wall time", self.walltime, endl=True)
        # ``self.error`` only exists after ``test()``; avoid AttributeError.
        text += self.console.string("Error", getattr(self, "error", None), endl=True)
        return text

    def explain(
        self,
        *,
        features: pd.Index,
        maxsize: int,
        model: Any,
        samples: np.ndarray,
        target: np.ndarray,
        test_samples=None
    ) -> Tuple[dict, float]:
        """Compute a sparse linear explanation of ``model`` around ``target``.

        Args:
            features: Column labels passed to ``Convert.vec_to_dict``.
            maxsize: Maximum support size ``k`` used by the projection.
            model: Object exposing ``predict``. It is called on ``samples``
                and on ``target``.
            samples: Design matrix ``A`` whose model predictions are fitted.
            target: Point defining the hyperplane ``<w, target> == predict(target)``.
                NOTE(review): passed to ``predict`` as-is; confirm the model
                accepts this shape.
            test_samples: Unused; kept for interface compatibility.

        Returns:
            ``(explanation, walltime)``. ``explanation`` is whatever
            ``Convert.vec_to_dict`` builds from the solution (presumably
            feature -> weight). ``walltime`` is the solve time in seconds.
        """
        self.console.log("[IHT] Find explanation")
        self.model = model

        # Fit the model's own predictions; pin the surrogate at the target.
        A = samples
        b = self.model.predict(A)
        h = (target, self.model.predict(target))
        k = maxsize

        self.solve(samples=A, labels=b, hyperplane=h, maxsize=k)

        self.explanation = Convert.vec_to_dict(vector=self.solution, columns=features)
        self.console.log("[IHT] Explanation")
        self.console.log(self.explanation)
        self.console.log("[IHT] Explanation Size", len(self.explanation))
        self.console.log("[IHT] Wall time", self.walltime)
        return self.explanation, self.walltime

    def test(self, *, samples: np.ndarray) -> float:
        """Return (and store in ``self.error``) the RMSE of the solution on ``samples``.

        The reference labels are ``self.model.predict(samples)``.
        ``explain`` must have been called first.
        """
        self.console.log("[IHT] Test solution")
        error = Metrics.rmse(
            samples=samples, labels=self.model.predict(samples), solution=self.solution
        )
        self.console.log("[IHT] Error", error)
        self.error = error
        return error

    def solve(
        self,
        *,
        samples: np.ndarray,
        labels: np.ndarray,
        hyperplane: Tuple[np.ndarray, float],
        maxsize: int
    ) -> None:
        """Run preprocess -> PGD -> postprocess.

        Stores the result in ``self.solution`` and the elapsed time in
        ``self.walltime``.
        """
        self.console.log("[IHT] Solve problem")

        # perf_counter is monotonic, so it suits duration measurement.
        time_start = time.perf_counter()
        origin = self.preprocess(
            samples=samples, labels=labels, hyperplane=hyperplane, maxsize=maxsize
        )
        point = self.projectedGradientDescent(
            samples=samples,
            labels=labels,
            hyperplane=hyperplane,
            maxsize=maxsize,
            origin=origin,
        )
        self.solution = self.postprocess(
            samples=samples, labels=labels, hyperplane=hyperplane, point=point
        )
        self.walltime = max(0.0, time.perf_counter() - time_start)

    def preprocess(
        self,
        *,
        samples: np.ndarray,
        labels: np.ndarray,
        hyperplane: Tuple[np.ndarray, float],
        maxsize: int
    ) -> np.ndarray:
        """Return the PGD starting point.

        When preprocessing is enabled, the cardinality constraint is relaxed
        to ``||w||_1 <= maxsize`` and ``min ||A w - b||^2 s.t. <w, x> == y``
        is solved with cvxpy. The zero vector is returned when preprocessing
        is disabled or the solver does not report ``"optimal"``.
        """
        x = hyperplane[0]
        y = hyperplane[1]
        # Compute n up front: the fallback needs it even when preprocessing
        # is disabled (previously a NameError).
        n = x.shape[0]
        if not self.preprocessing:
            return np.zeros(n, dtype="float64")

        self.console.log("[IHT] Preprocess problem")
        w = cp.Variable(n)
        constraints = [w @ x == y, cp.norm(w, 1) <= maxsize]
        objective = cp.Minimize(cp.sum_squares(samples @ w - labels))
        problem = cp.Problem(objective, constraints=constraints)
        problem.solve(verbose=self.console.verbose)
        if problem.status == "optimal":
            return w.value
        return np.zeros(n, dtype="float64")

    def postprocess(
        self,
        *,
        samples: np.ndarray,
        labels: np.ndarray,
        hyperplane: Tuple[np.ndarray, float],
        point: np.ndarray
    ) -> np.ndarray:
        """Refit ``point`` by least squares on its support.

        Solves ``min ||A w - b||^2 s.t. <w, x> == y`` with ``w[j] == 0`` for
        every index where ``point`` is zero. Entries with magnitude below
        ``self.epsilon`` are then snapped to zero. ``point`` is returned
        unchanged if postprocessing is disabled or the solver does not
        report ``"optimal"``.
        """
        if not self.postprocessing:
            return point

        self.console.log("[IHT] Postprocess solution")
        x = hyperplane[0]
        y = hyperplane[1]
        n = point.shape[0]
        w = cp.Variable(n)
        constraints = [w @ x == y]
        # Keep the sparsity pattern found by PGD.
        constraints += [w[j] == 0 for j in np.where(point == 0)[0]]
        objective = cp.Minimize(cp.sum_squares(samples @ w - labels))
        problem = cp.Problem(objective, constraints=constraints)
        problem.solve(verbose=self.console.verbose)
        if problem.status == "optimal":
            refit = w.value.copy()
            refit[np.abs(refit) < self.epsilon] = 0
            return refit
        return point

    def projectedGradientDescent(
        self,
        *,
        samples: np.ndarray,
        labels: np.ndarray,
        hyperplane: Tuple[np.ndarray, float],
        maxsize: int,
        origin: np.ndarray
    ) -> np.ndarray:
        """Run ``self.iterations`` steps of ``w <- P(w - eta * grad(w))``.

        ``P`` is ``self.projection``. Returns the final iterate.
        """
        self.console.log("[IHT] Run projected gradient descent")
        w = origin
        eta = self.stepsize
        for _ in range(self.iterations):
            g = self.gradient(point=w, samples=samples, labels=labels)
            w = self.projection(point=w - eta * g, hyperplane=hyperplane, maxsize=maxsize)
        return w

    def gradient(
        self, point: np.ndarray, samples: np.ndarray, labels: np.ndarray
    ) -> np.ndarray:
        """Return the gradient of ``(1 / (2m)) ||A w - b||^2``, i.e. ``A^T (A w - b) / m``.

        ``m`` is the number of rows of ``A``.
        """
        m = samples.shape[0]
        residual = np.dot(samples, point) - labels
        return (1 / m) * np.dot(samples.T, residual)

    def projection(
        self, point: np.ndarray, hyperplane: Tuple[np.ndarray, float], maxsize: int
    ) -> np.ndarray:
        """Greedy projection onto ``maxsize``-sparse vectors with ``<w, x> == y``.

        Only entries of ``x`` equal to 0 or 1 are candidates, so ``x`` is
        presumably binary. Indices where ``x == 1`` are selected first, one
        at a time, by a greedy score. The remaining budget is filled with
        the largest-|w| indices where ``x == 0``.

        On the ``x == 1`` support, ``w`` is shifted onto the hyperplane
        ``sum(w[S1]) == y``. On the ``x == 0`` support it is copied as is.
        If there are fewer candidates than ``maxsize``, all are used. If no
        ``x == 1`` index is selected, that part is all zeros.
        """
        w = point
        k = maxsize
        x = hyperplane[0]
        y = hyperplane[1]

        def grow(j, support):
            # With a support: distance of w[j] from the shift the hyperplane
            # projection would currently apply. Without one: prefer entries
            # whose sign agrees with y and whose magnitude is larger.
            if support:
                avg = (np.sum(w[support]) - y) / len(support)
                return abs(w[j] - avg)
            return y * w[j]

        # Zero-positions ordered by |w| descending (stable for ties).
        zero_idx = sorted(np.where(x == 0)[0], key=lambda j: abs(w[j]), reverse=True)
        one_idx = np.where(x == 1)[0].tolist()

        # Ones have priority: the first min(k, #ones) picks come from them.
        support_ones = []
        for _ in range(min(k, len(one_idx))):
            scores = [grow(j, support_ones) for j in one_idx]
            support_ones.append(one_idx.pop(int(np.argmax(scores))))

        # Any leftover budget goes to the largest-magnitude zero-positions;
        # slicing caps it at the number available.
        budget = max(0, k - len(support_ones))
        support_zeros = zero_idx[:budget]

        u = self.projectOntoHyperplane(w, indices=support_ones, offset=y)
        v = self.projectOntoSubset(w, indices=support_zeros)
        return u + v

    def projectOntoSubset(self, point: np.ndarray, indices: list) -> np.ndarray:
        """Return a float64 copy of ``point`` restricted to ``indices`` (zeros elsewhere)."""
        idx = np.asarray(indices, dtype=int)
        u = np.zeros(point.shape[0], dtype="float64")
        u[idx] = point[idx]
        return u

    def projectOntoHyperplane(
        self, point: np.ndarray, indices: list, offset: float
    ) -> np.ndarray:
        """Project ``point`` restricted to ``indices`` onto ``sum(w[indices]) == offset``.

        Every selected coordinate is shifted by the same amount
        ``tau = (sum(point[indices]) - offset) / len(indices)``. Coordinates
        outside ``indices`` are zero. With an empty ``indices`` the zero
        vector is returned (previously a ZeroDivisionError).
        """
        idx = np.asarray(indices, dtype=int)
        u = np.zeros(point.shape[0], dtype="float64")
        if idx.size == 0:
            return u
        selected = point[idx]
        tau = (1 / idx.size) * (np.sum(selected) - offset)
        # Vectorised assignment also broadcasts a size-1 array ``offset``.
        u[idx] = selected - tau
        return u
