import time
import numpy as np
import copy
import logging
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import cross_val_score
from ngboost import NGBRegressor
from ngboost.distns import Normal
from ngboost.scores import LogScore

from naslib.predictors.predictor import Predictor
from naslib.predictors.lcsvr import loguniform
from naslib.predictors.zerocost import ZeroCost
from naslib import utils
from naslib.search_spaces.core.query_metrics import Metric

logger = logging.getLogger(__name__)


def parse_params(params, identifier):
    """Select the entries of ``params`` whose key starts with ``identifier``.

    Matching keys have the leading ``identifier`` stripped; all other entries
    are dropped. Used to split a flat hyperparameter dict such as
    ``{"base:max_depth": 3, "param:learning_rate": 0.01}`` into the keyword
    arguments for the individual estimators.

    Args:
        params: mapping from (prefixed) hyperparameter name to value.
        identifier: key prefix to select, e.g. ``"base:"``.

    Returns:
        A new dict mapping each un-prefixed key to its original value.
    """
    # Slice off only the leading prefix: str.replace would also delete any
    # later occurrence of the identifier inside the key name.
    prefix_len = len(identifier)
    return {
        key[prefix_len:]: value
        for key, value in params.items()
        if key.startswith(identifier)
    }


class OmniNGBPredictor(Predictor):
    """NGBoost regressor over zero-cost, learning-curve and encoding features.

    For each architecture the feature vector is the concatenation of:
      1. standardised zero-cost scores, one column per method in
         ``zero_cost`` (only when ``zero_cost`` is non-empty and the training
         set has at most ``max_zerocost`` samples);
      2. the standardised last training loss (``"sotle"`` in ``lce``) and/or
         last validation accuracy (``"valacc"`` in ``lce``), when the first
         learning curve in ``info`` has at least three points;
      3. the architecture encoding (when ``encoding_type`` is not None).

    Targets are standardised before fitting and mapped back in ``query``.
    If ``fit`` receives fewer than ``min_train_size`` samples, no model is
    trained and ``query`` returns its ``info`` argument unchanged.
    """

    def __init__(
        self,
        zero_cost,
        lce,
        encoding_type,
        ss_type=None,
        config=None,
        n_hypers=35,
        run_pre_compute=True,
        min_train_size=0,
        max_zerocost=np.inf,
    ):
        """
        Args:
            zero_cost: collection of zero-cost method names, each passed to
                ``ZeroCost(method_type=...)``.
            lce: collection of learning-curve feature names; values used in
                this class are ``"sotle"`` and ``"valacc"``.
            encoding_type: passed to ``arch.encode``; None disables the
                encoding features.
            ss_type: search-space type (stored; not read elsewhere in this
                class).
            config: experiment config used to build the train loader for
                the zero-cost methods.
            n_hypers: number of random hyperparameter samples in ``run_hpo``.
            run_pre_compute: if True, zero-cost features come from
                ``pre_compute``; otherwise ``info[i]`` is used directly.
            min_train_size: ``fit`` does not train below this many samples.
            max_zerocost: zero-cost features are only used when the training
                set size is at most this value.
        """
        self.zero_cost = zero_cost
        self.encoding_type = encoding_type
        self.config = config
        self.n_hypers = n_hypers
        self.lce = lce
        self.ss_type = ss_type
        self.run_pre_compute = run_pre_compute
        self.min_train_size = min_train_size
        self.max_zerocost = max_zerocost

    @staticmethod
    def _mean_std(values):
        """Return ``(mean, std)`` of ``values`` for standardisation.

        A zero std is replaced by 1.0 so that standardising a constant
        feature yields zeros instead of NaN/inf.
        """
        values = np.asarray(values)
        mean = np.mean(values)
        std = np.std(values)
        if std == 0:
            std = 1.0
        return mean, std

    @staticmethod
    def _append_column(rows, column):
        """Return a copy of ``rows`` with ``column[i]`` appended to row ``i``."""
        return [[*row, column[i]] for i, row in enumerate(rows)]

    @staticmethod
    def _build_model(params, verbose=False):
        """Build an unfitted NGBRegressor from a sampled ``params`` dict.

        Keys prefixed ``"base:"`` configure the DecisionTreeRegressor base
        learner; keys prefixed ``"param:"`` configure the NGBRegressor.
        """
        base_learner = DecisionTreeRegressor(
            criterion="friedman_mse",
            random_state=None,
            splitter="best",
            **parse_params(params, "base:"),
        )
        return NGBRegressor(
            Dist=Normal,
            Base=base_learner,
            Score=LogScore,
            verbose=verbose,
            **parse_params(params, "param:"),
        )

    def pre_compute(self, xtrain, xtest):
        """Compute standardised zero-cost scores for train and test archs.

        All of this computation could go into fit() and query(), but we do it
        here to save time, so that we don't have to re-compute Jacobian
        covariances for all train_sizes when running experiment_types that
        vary train size or fidelity.

        Results are stored in ``self.xtrain_zc_info`` / ``self.xtest_zc_info``
        under the key ``f"{method_name}_scores"``. Test scores are
        standardised with the *training* mean/std.
        """
        self.xtrain_zc_info = {}
        self.xtest_zc_info = {}

        if len(self.zero_cost) == 0:
            return

        self.train_loader, _, _, _, _ = utils.get_train_val_loaders(
            self.config, mode="train"
        )

        for method_name in self.zero_cost:
            zc_method = ZeroCost(method_type=method_name)
            # Each method gets its own copy of the loader (presumably so that
            # one method's use of it cannot affect the next -- TODO confirm).
            zc_method.train_loader = copy.deepcopy(self.train_loader)
            train_scores = np.array(zc_method.query(xtrain))
            test_scores = np.array(zc_method.query(xtest))

            mean, std = self._mean_std(train_scores)
            key = f"{method_name}_scores"
            self.xtrain_zc_info[key] = (train_scores - mean) / std
            self.xtest_zc_info[key] = (test_scores - mean) / std

    def get_random_params(self):
        """Sample one random hyperparameter configuration.

        Keys prefixed ``"param:"`` go to NGBRegressor, ``"base:"`` to the
        DecisionTreeRegressor base learner.
        """
        params = {
            "param:n_estimators": int(loguniform(128, 512)),
            "param:learning_rate": loguniform(0.001, 0.1),
            "param:minibatch_frac": np.random.uniform(0.1, 1),
            "base:max_depth": np.random.choice(24) + 1,
            "base:max_features": np.random.uniform(0.1, 1),
            "base:min_samples_leaf": np.random.choice(18) + 2,
            "base:min_samples_split": np.random.choice(18) + 2,
        }
        return params

    def run_hpo(self, xtrain, ytrain):
        """Random search over ``n_hypers`` configurations; return the best.

        The configuration with the lowest cross-validation score wins. If no
        configuration produces a score below +inf (e.g. every score is NaN),
        the first sampled configuration is returned so that ``fit`` can still
        build a model. Returns None only when ``n_hypers`` is 0.
        """
        min_score = np.inf
        best_params = None
        fallback_params = None
        # Leaf/split minimums are capped relative to ~1/3 of the data, since
        # cross_validate uses 3 folds; never below 2.
        size_cap = int(len(xtrain) / 3) - 1
        for i in range(self.n_hypers):
            params = self.get_random_params()
            for key in ["base:min_samples_leaf", "base:min_samples_split"]:
                params[key] = max(2, min(params[key], size_cap))

            if fallback_params is None:
                fallback_params = params

            score = self.cross_validate(xtrain, ytrain, params)
            if score < min_score:
                min_score = score
                best_params = params
                logger.info("{} new best {}, {}".format(i, score, params))

        if best_params is None:
            logger.warning(
                "No hyperparameter configuration produced a finite score; "
                "falling back to the first sampled configuration"
            )
            return fallback_params
        return best_params

    def cross_validate(self, xtrain, ytrain, params):
        """Return the mean 3-fold cross-validation score for ``params``.

        No ``scoring`` is passed, so sklearn uses the estimator's own
        ``score`` method. For NGBoost this is presumably the mean
        scoring-rule loss (lower is better), which is why ``run_hpo``
        minimises it. TODO confirm against the installed ngboost version.
        """
        model = self._build_model(params, verbose=False)
        scores = cross_val_score(model, xtrain, ytrain, cv=3)
        return np.mean(scores)

    def prepare_features(self, xdata, info, train=True):
        """Build the feature matrix for ``xdata``.

        Args:
            xdata: architectures; ``arch.encode`` is called on each when
                ``encoding_type`` is set.
            info: per-architecture extra info. Learning-curve features read
                ``info[i]["TRAIN_LOSS_lc"]`` / ``info[i]["VAL_ACCURACY_lc"]``;
                without pre-computation, ``info[i]`` is appended directly as
                the zero-cost feature.
            train: selects the precomputed train vs. test zero-cost scores.

        Returns:
            ``np.ndarray`` with one row per architecture.
        """
        full_xdata = [[] for _ in range(len(xdata))]

        if len(self.zero_cost) > 0 and self.train_size <= self.max_zerocost:
            if self.run_pre_compute:
                zc_info = self.xtrain_zc_info if train else self.xtest_zc_info
                # Iterate train keys in both cases so column order is fixed.
                for key in self.xtrain_zc_info:
                    full_xdata = self._append_column(full_xdata, zc_info[key])
            else:
                # if the zero_cost scores were not precomputed, they are in info
                full_xdata = self._append_column(full_xdata, info)

        # NOTE: learning-curve features are standardised with the statistics
        # of the set being prepared (train or test), not the training set.
        if "sotle" in self.lce:
            if len(info[0]["TRAIN_LOSS_lc"]) >= 3:
                train_losses = np.array([lcs["TRAIN_LOSS_lc"][-1] for lcs in info])
                mean, std = self._mean_std(train_losses)
                full_xdata = self._append_column(
                    full_xdata, (train_losses - mean) / std
                )
            else:
                logger.info("Not enough fidelities to use train loss")

        if "valacc" in self.lce:
            if len(info[0]["VAL_ACCURACY_lc"]) >= 3:
                val_accs = np.array([lcs["VAL_ACCURACY_lc"][-1] for lcs in info])
                mean, std = self._mean_std(val_accs)
                full_xdata = self._append_column(full_xdata, (val_accs - mean) / std)
            else:
                logger.info("Not enough fidelities to use validation accuracy")

        if self.encoding_type is not None:
            xdata_encoded = np.array(
                [arch.encode(encoding_type=self.encoding_type) for arch in xdata]
            )
            full_xdata = [[*x, *xdata_encoded[i]] for i, x in enumerate(full_xdata)]

        return np.array(full_xdata)

    def fit(self, xtrain, ytrain, train_info, learn_hyper=True):
        """Standardise labels, build features, run HPO and fit the model.

        If ``len(xtrain) < min_train_size``, sets ``self.trained = False`` and
        returns without training. ``learn_hyper`` is accepted for interface
        compatibility but is not read: HPO always runs.
        """
        # if we are below the min train size, use the zero_cost and lce info
        if len(xtrain) < self.min_train_size:
            self.trained = False
            return None
        self.trained = True
        self.train_size = len(xtrain)

        # prepare training data labels
        ytrain = np.array(ytrain)
        self.mean, self.std = self._mean_std(ytrain)
        ytrain = (ytrain - self.mean) / self.std

        xtrain = self.prepare_features(xtrain, train_info, train=True)
        params = self.run_hpo(xtrain, ytrain)

        self.model = self._build_model(params, verbose=True)
        self.model.fit(xtrain, ytrain)

    def query(self, xtest, info):
        """Predict performance for ``xtest`` on the original label scale.

        When ``fit`` did not train a model (too few samples), ``info`` is
        returned unchanged.
        """
        if not self.trained:
            logger.info("below the train size, so returning info")
            return info
        test_data = self.prepare_features(xtest, info, train=False)
        return np.squeeze(self.model.predict(test_data)) * self.std + self.mean

    def get_data_reqs(self):
        """
        Returns a dictionary with info about whether the predictor needs
        extra info to train/query.

        When ``lce`` is non-empty, partial learning curves are requested for
        the metrics named in ``lce`` (an unknown name raises ``KeyError``);
        otherwise the base class requirements are returned.
        """
        if len(self.lce) > 0:
            # add the metrics needed for the lce predictors
            required_metric_dict = {
                "sotle": Metric.TRAIN_LOSS,
                "valacc": Metric.VAL_ACCURACY,
            }
            self.metric = [required_metric_dict[key] for key in self.lce]

            reqs = {
                "requires_partial_lc": True,
                "metric": self.metric,
                "requires_hyperparameters": False,
                "hyperparams": {},
                "unlabeled": False,
                "unlabeled_factor": 0,
            }
        else:
            reqs = super().get_data_reqs()

        return reqs
