import numpy as np
import pandas as pd
import time
from typing import Any, Tuple
from utils import Console, Convert, Metrics

from sklearn.linear_model import LinearRegression, Ridge
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import root_mean_squared_error


class Explainer:
    """Local surrogate explainer in the style of MAPLE.

    A tree ensemble (``estimator``) is fitted to reproduce a black-box model's
    predictions on ``samples``. For a point ``target``, each training sample
    is weighted by the fraction of trees in which it falls into the same leaf
    as ``target``. A linear ``regressor`` is then fitted on the samples with
    those weights. The non-zero regressor coefficients, plus a non-zero
    intercept under the key ``"Bias"``, form the explanation.

    The estimator must provide ``fit``, ``apply`` and ``n_estimators``.
    The regressor must provide ``fit(..., sample_weight=...)``, ``predict``,
    ``coef_`` and ``intercept_``.
    """

    # Constructor for MAPLE explainer with default parameters
    def __init__(
        self,
        estimator: Any = None,
        regressor: Any = None,
        verbose: bool = False,
    ) -> None:
        """Create an explainer.

        Args:
            estimator: Tree ensemble used to derive locality weights. When
                ``None`` (default), a fresh
                ``RandomForestRegressor(n_estimators=100)`` is created for
                this instance.
            regressor: Linear model fitted around the target. When ``None``
                (default), a fresh ``Ridge(alpha=1.0)`` is created for this
                instance.
            verbose: Forwarded to ``Console`` to control logging.
        """
        # Build defaults per instance. Objects used as default arguments are
        # created once at class-definition time, and would otherwise be shared
        # (and refitted) by every Explainer built without explicit models.
        self.estimator = (
            estimator
            if estimator is not None
            else RandomForestRegressor(n_estimators=100)
        )
        self.regressor = regressor if regressor is not None else Ridge(alpha=1.0)
        self.console = Console(verbose=verbose)

    # print MAPLE attributes
    def __str__(self) -> str:
        """Summarize the last explanation, its size, wall time and test error.

        Values that have not been computed yet (``explain``/``test`` not
        called) are reported as empty/``None`` instead of raising
        ``AttributeError``.
        """
        explanation = getattr(self, "explanation", {})
        text = self.console.string("[MAPLE] Information", endl=True)
        text += self.console.string("Explanation", explanation, endl=True)
        text += self.console.string("Size", len(explanation), endl=True)
        text += self.console.string(
            "Wall time", getattr(self, "walltime", None), endl=True
        )
        text += self.console.string("Error", getattr(self, "error", None), endl=True)
        return text

    # Main function
    def explain(
        self,
        *,
        features: pd.Index,
        maxsize: int,
        model: Any,
        samples: np.ndarray,
        target: np.ndarray,
        test_samples=None,
    ) -> Tuple[dict, float]:
        """Fit MAPLE around ``target`` and return the explanation.

        Args:
            features: Feature names; ``features[j]`` labels column ``j`` of
                ``samples``.
            maxsize: Forwarded to ``solve``. NOTE(review): it does not
                currently limit the number of features in the explanation.
            model: Black-box model exposing ``predict``; stored as
                ``self.model``.
            samples: Data on which the model is queried and MAPLE is fitted.
            target: The single point to explain.
            test_samples: Accepted for interface compatibility but unused
                here. Call ``test`` separately to compute the error.

        Returns:
            ``(explanation, walltime)``. ``explanation`` maps feature names to
            coefficients, plus ``"Bias"`` for a non-zero intercept.
            ``walltime`` is the fitting time in seconds.
        """
        self.console.log("[MAPLE] Find explanation")
        self.model = model

        # Solve problem
        self.solve(features=features, maxsize=maxsize, samples=samples, target=target)

        # solve() has already extracted the solution; copy it instead of
        # re-running extract() so explanation and solution are independent dicts.
        self.explanation = dict(self.solution)
        self.console.log("[MAPLE] Explanation")
        self.console.log(self.explanation)
        self.console.log("[MAPLE] Explanation Size", len(self.explanation))
        self.console.log("[MAPLE] Wall time", self.walltime)
        return self.explanation, self.walltime

    # Test explanation
    def test(self, *, samples: np.ndarray) -> float:
        """Compute the RMSE between the black-box model and the local regressor.

        Both ``self.model`` and ``self.regressor`` predict on ``samples``.
        The RMSE between the two predictions is stored in ``self.error`` and
        returned.
        """
        self.console.log("[MAPLE] Test solution")
        surrogate = self.regressor.predict(samples)
        reference = self.model.predict(samples)
        error = root_mean_squared_error(reference, surrogate)
        self.console.log("[MAPLE] Error", error)
        self.error = error
        return error

    # Build problem model and run MAPLE explainer
    def solve(
        self,
        *,
        features: pd.Index,
        maxsize: int,
        samples: np.ndarray,
        target: np.ndarray,
    ) -> None:
        """Fit the estimator and the locally weighted regressor.

        Sets ``self.walltime`` (seconds spent fitting) and ``self.solution``
        (the dict produced by ``extract``).
        """
        # NOTE(review): maxsize is not used; the explanation may contain more
        # than maxsize features. Confirm whether feature selection is intended.
        labels = self.model.predict(samples)

        # perf_counter is monotonic, so the elapsed time cannot go negative
        # the way a time.time() difference can after a system clock change.
        time_start = time.perf_counter()
        self.fit_estimator(samples, labels)
        self.fit_regressor(target)
        self.walltime = time.perf_counter() - time_start

        # Extract solution
        self.solution = self.extract(features=features)

    def extract(self, features: pd.Index) -> dict:
        """Map each non-zero regressor coefficient to its feature name.

        A non-zero intercept is added under the key ``"Bias"``.
        """
        # ravel lets a (1, n_features) coef_ be handled like a 1-D one.
        coefficients = np.ravel(self.regressor.coef_)
        intercept = self.regressor.intercept_
        dictionary = {
            features[j]: coefficients[j] for j in np.flatnonzero(coefficients)
        }
        if intercept != 0.0:
            dictionary["Bias"] = intercept
        return dictionary

    # Maple function: fit estimation model
    def fit_estimator(self, samples: np.ndarray, labels: np.ndarray):
        """Fit the tree ensemble to ``labels`` and cache the training leaves.

        Copies of the inputs are kept in ``self.A_fitted``/``self.b_fitted``.
        ``self.train_leaf_ix`` holds ``estimator.apply`` of the samples.
        """
        A = samples.copy()
        b = labels.copy()
        self.estimator.fit(A, b)
        self.A_fitted = A
        self.b_fitted = b
        self.train_leaf_ix = self.estimator.apply(A)

    # Maple function: fit linear regression model
    def fit_regressor(self, target: np.ndarray):
        """Fit the regressor on the cached data, weighted by locality to ``target``."""
        weights = self.get_weights(target)
        self.regressor.fit(self.A_fitted, self.b_fitted, sample_weight=weights)

    # Maple function: get weights
    def get_weights(self, target: np.ndarray) -> list:
        """Return one locality weight per training sample.

        The weight of sample ``i`` is the number of trees in which sample
        ``i`` and ``target`` share a leaf, divided by ``n_estimators``.
        """
        K = self.estimator.n_estimators
        self.console.log("Number of trees", K)
        # apply() expects a 2-D array; present the target as a single row.
        leaf_x = self.estimator.apply(np.asarray(target).reshape(1, -1))[0]
        # Vectorized per-row count of matching leaves, replacing a Python loop.
        matches = np.asarray(self.train_leaf_ix) == leaf_x
        weights = (matches.sum(axis=1) / K).tolist()
        self.console.log("Weights", weights)
        return weights
