import numpy as np
import pandas as pd
import dataset
import sampler
import regressor
from explainer_cvx import Explainer as ExplainerCVX
from explainer_iht import Explainer as ExplainerIHT
from explainer_lime import Explainer as ExplainerLIME
from explainer_maple import Explainer as ExplainerMAPLE
from explainer_mip import Explainer as ExplainerMIP
from explainer_gurobi import Explainer as ExplainerGUROBI
from pathlib import Path
from utils import Console


class Experiment:
    """Benchmark a set of explainers on one dataset / regressor / sampler setup.

    Usage: ``init()`` builds the dataset, trains the regressor and sets up the
    sampler; ``run()`` executes the experiments and aggregates statistics;
    ``save()`` writes the textual report (``str(self)``) to disk.
    """

    def __init__(
        self,
        *,
        data_id: int,
        regressor_name: str,
        sampler_name: str,
        explainer_names: list,
        nb_experiments: int,
        nb_train: int,
        nb_test: int,
        maxsize=7,
        nb_bins=4,
        radius=0,
        spread=0.0,
        verbose=True,
        data_home=None,
        verbosity_model=False,
        timeout_mip=2,
    ):
        """Configure an experiment; no data is loaded until ``init()``.

        Args:
            data_id: OpenML dataset id.
            regressor_name: regressor name, passed to ``regressor.Regressor``.
            sampler_name: sampler name, passed to ``sampler.Sampler``.
            explainer_names: keys of ``self.explainers`` to evaluate
                ("CVX", "IHT", "LIME", "MIP", "MAPLE", "GUROBI").
            nb_experiments: number of repeated experiments.
            nb_train: number of samples used to generate each explanation.
            nb_test: number of samples used to evaluate each explanation.
            maxsize: sparsity level passed to the explainers; ``init()``
                lowers it to ``data.nbBinFeatures`` when that is smaller.
            nb_bins: passed to ``dataset.Dataset``.
            radius: passed to ``sampler.Sampler``.
            spread: passed to ``sampler.Sampler``.
            verbose: verbosity of this experiment's console and of the
                dataset, regressor and sampler it creates.
            data_home: passed to ``dataset.Dataset``.
            verbosity_model: verbosity passed to the explainers (MAPLE is
                always built with ``verbose=False``).
            timeout_mip: timeout passed to the MIP and GUROBI explainers
                (presumably seconds -- confirm against those explainers).

        Raises:
            ValueError: if ``explainer_names`` contains an unknown explainer.
        """
        self.console = Console(verbose=verbose)
        self.verbose = verbose
        self.dataId = data_id
        self.explainerNames = explainer_names
        self.regressorName = regressor_name
        self.samplerName = sampler_name
        self.nbExperiments = nb_experiments
        self.nbTrainSamples = nb_train
        self.nbTestSamples = nb_test
        self.maxsize = maxsize
        self.nbBins = nb_bins
        self.radius = radius
        self.spread = spread
        self.data_home = data_home
        self.timeout_mip = timeout_mip

        # Registry of every available explainer, keyed by the names accepted
        # in ``explainer_names``.
        self.explainers = {
            "CVX": ExplainerCVX(norm=1, verbose=verbosity_model),
            "IHT": ExplainerIHT(
                iterations=1000,
                preprocessing=True,
                postprocessing=False,
                stepsize=1.0,
                verbose=verbosity_model,
            ),
            "LIME": ExplainerLIME(verbose=verbosity_model),
            "MIP": ExplainerMIP(timeout=self.timeout_mip, verbose=verbosity_model),
            # NOTE(review): MAPLE ignores verbosity_model and is always quiet;
            # confirm this is intentional.
            "MAPLE": ExplainerMAPLE(verbose=False),
            "GUROBI": ExplainerGUROBI(
                timeout=self.timeout_mip, verbose=verbosity_model
            ),
        }

        # Fail fast, before any dataset download or regressor training.
        self._check_explainer_names()

    def _check_explainer_names(self) -> None:
        """Raise ValueError if a requested explainer is not in ``self.explainers``."""
        unknown = [name for name in self.explainerNames if name not in self.explainers]
        if unknown:
            raise ValueError(
                f"Unknown explainer(s) {unknown}; "
                f"available: {sorted(self.explainers)}"
            )

    def __str__(self):
        """Return the full textual report of the experiment.

        The report contains the dataset, regressor and sampler descriptions,
        the experiment parameters, and one statistics table per requested
        explainer. It reads attributes set by ``init()`` and ``run()``, so
        both must have been called first.
        """
        console = self.console
        parts = [
            console.string("EXPERIMENTAL RESULTS", endl=True),
            "\n",
            str(self.data),
            "\n",
            str(self.regressor),
            "\n",
            str(self.sampler),
            "\n",
            console.string("Parameters", endl=True),
            console.string("# Experiments", second=self.nbExperiments, endl=True),
            console.string("# Train Samples", second=self.nbTrainSamples, endl=True),
            console.string("# Test Samples", second=self.nbTestSamples, endl=True),
            console.string("# Sparsity Level", second=self.maxsize, endl=True),
            "\n",
        ]
        for name in self.explainerNames:
            parts.append(console.string("Explainer", second=name, endl=True))
            parts.append(f"{self.statistics[name]}\n\n")
        return "".join(parts)

    def init(self) -> None:
        """Load the dataset, train the regressor on it and set up the sampler."""
        # Setup data
        self.console.log("Setup dataset")
        self.data = dataset.Dataset(
            data_id=self.dataId,
            nb_bins=self.nbBins,
            verbose=self.verbose,
            data_home=self.data_home,
        )
        self.dataName = self.data.setup()
        # Cap the sparsity level at the number of (binned) features available.
        self.maxsize = min(self.maxsize, self.data.nbBinFeatures)
        # Setup regressor
        self.console.log("Setup regressor", self.regressorName)
        self.regressor = regressor.Regressor(
            name=self.regressorName, verbose=self.verbose
        )
        self.regressor.train(instances=self.data.X, labels=self.data.Y)
        # Setup sampler
        self.console.log("Setup sampler")
        self.sampler = sampler.Sampler(
            features=self.data.X.columns,
            name=self.samplerName,
            radius=self.radius,
            spread=self.spread,
            verbose=self.verbose,
        )
        self.sampler.setup()

    def run(self) -> None:
        """Run all experiments and aggregate per-explainer statistics.

        Each experiment queries a target from the sampler, then draws
        ``nbTrainSamples`` training samples and ``nbTestSamples`` test samples.
        Each requested explainer then calls ``explain`` followed by ``test``,
        and the row ``[error, size, walltime]`` is recorded in
        ``self.results[name]``.

        Afterwards, the per-column mean and standard deviation are stored in
        ``self.statistics[name]`` and the report is printed. ``init()`` must
        have been called first.

        Raises:
            ValueError: if ``explainerNames`` contains an unknown explainer.
        """
        self._check_explainer_names()
        self.console.log("Number of experiments", self.nbExperiments)

        # One row per experiment; columns are error, explanation size, wall time.
        results = np.zeros((self.nbExperiments, 3))
        statistics = pd.DataFrame(
            np.zeros((3, 2)),
            columns=["avg", "std"],
            index=["error", "size", "wall time"],
        )
        self.results = {name: results.copy() for name in self.explainerNames}
        self.statistics = {name: statistics.copy() for name in self.explainerNames}

        for e in range(self.nbExperiments):
            self.console.log("Experiment", e + 1)
            # Every explainer in this experiment sees the same target and samples.
            target = self.sampler.query(instances=self.data.X)
            train_samples = self.sampler.sample(nb_samples=self.nbTrainSamples)
            test_samples = self.sampler.sample(nb_samples=self.nbTestSamples)
            self.console.log("Query", self.sampler.centerID)

            for name in self.explainerNames:
                self.console.log(f"[{name}] Run")
                explainer = self.explainers[name]
                explanation, walltime = explainer.explain(
                    features=self.data.X.columns,
                    maxsize=self.maxsize,
                    model=self.regressor,
                    samples=train_samples,
                    target=target,
                    test_samples=test_samples,
                )
                error = explainer.test(samples=test_samples)
                size = len(explanation)
                self.results[name][e] = [error, size, walltime]
                self.console.log(
                    f"[{name}] Statistics",
                    second=f"error = {error:.3f}, |S| = {size}, time = {walltime:.3f}s",
                )

        # Collect statistics (column-wise over experiments).
        for name in self.explainerNames:
            self.statistics[name]["avg"] = np.average(self.results[name], axis=0)
            self.statistics[name]["std"] = np.std(self.results[name], axis=0)
        print(self)

    def save(self) -> None:
        """Write the textual report (``str(self)``) to a file under ``./expt/``.

        The file name encodes the dataset name, regressor, sparsity level and
        experiment/sample counts. Despite the ``.csv`` extension, the content
        is the human-readable report, not CSV. The extension is kept so
        existing output paths do not change.
        """
        filename = (
            f"[D]{self.dataName}_[M]{self.regressorName}_[S]{self.maxsize}"
            f"_[#E]{self.nbExperiments}_[#Tr]{self.nbTrainSamples}"
            f"_[#Te]{self.nbTestSamples}.csv"
        )
        path = Path("./expt") / filename
        path.parent.mkdir(parents=True, exist_ok=True)
        # Explicit encoding so the report does not depend on the platform locale.
        path.write_text(str(self), encoding="utf-8")
