import pathlib
from importlib import import_module
from itertools import repeat
from typing import Callable, Dict, List, Optional, Tuple, Union

import joblib
import matplotlib.pyplot as plt
import numpy as np
import ray
from matplotlib.path import Path
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.model_selection import train_test_split

from src.models.models_helper.dp import add_retraining_cost
# Type alias for the scikit-learn estimator classes used in this module's
# annotations (model dictionaries and cost-matrix helpers).
Classifier = Union[SGDClassifier, LogisticRegression, RandomForestClassifier]


def strategy_cost_array(retrains: List[int], C: np.ndarray) -> np.ndarray:
    """Return the per-batch costs of a retraining schedule.

    Batch ``t`` is served by the latest model retrained at or before ``t``
    (the batch-0 model until the first retrain). Its cost is read from
    ``C[serving_model, t]``.

    Args:
        retrains (List[int]): the batches at which retrains occurred
        C (np.ndarray): the upper-triangular cost matrix (model batch x evaluation batch)

    Returns:
        np.ndarray: one cost per batch, of length ``C.shape[0]``
    """
    num_batches = C.shape[0]
    # Mark each retrain position r with the value r; every other batch starts at 0.
    serving_model = np.zeros(num_batches)
    serving_model[retrains] = retrains
    # A running maximum carries each retrain forward until a later one replaces it.
    serving_model = np.maximum.accumulate(serving_model).astype(int)
    return C[serving_model, np.arange(num_batches)]


def compute_metric(retrains: List[int], C: np.ndarray, logic: Optional[Callable] = strategy_cost_array, reduction: Optional[str] = "sum"):
    """Compute per-batch values for a retraining schedule and reduce them to a scalar.

    Args:
        retrains (List[int]): the batches at which retrains occurred
        C (np.ndarray): the upper-triangular matrix the per-batch values are read from
        logic (Callable, optional): function ``logic(retrains, C) -> np.ndarray``
            producing one value per batch. ``None`` falls back to
            ``strategy_cost_array``. Defaults to ``strategy_cost_array``.
        reduction (str, optional): ``"sum"`` or ``"mean"``. Defaults to ``"sum"``.

    Returns:
        The sum or mean of the per-batch values.

    Raises:
        ValueError: if ``reduction`` is not ``"sum"`` or ``"mean"``.
    """
    # The annotation allows None, so honour it instead of crashing on None(...).
    if logic is None:
        logic = strategy_cost_array
    costs = logic(retrains, C)
    if reduction == "sum":
        return costs.sum()
    if reduction == "mean":
        return costs.mean()
    raise ValueError(f"Unknown {reduction=}. Please check parameters")


def _compute_strategy_cost(retrains: List[int], C: np.ndarray, retrain_cost: Optional[float] = 0) -> float:
    """Return the total cost of a retraining schedule.

    The retrain cost is first folded into ``C`` with ``add_retraining_cost``.
    The per-batch costs from ``strategy_cost_array`` are then summed.

    Args:
        retrains (List[int]): the batches at which retrains occurred
        C (np.ndarray): the upper-triangular cost matrix
        retrain_cost (float, optional): explicit cost of one retrain, passed to
            ``add_retraining_cost``. Defaults to 0.

    Returns:
        float: the summed cost of the schedule
    """
    cost_with_retrains = add_retraining_cost(C, retrain_cost)
    return compute_metric(
        retrains,
        cost_with_retrains,
        logic=strategy_cost_array,
        reduction="sum",
    )


def _compute_prequential_accuracy(retrains: List[int], A: np.ndarray) -> float:
    """Return the mean prequential accuracy of a retraining schedule.

    Per-batch accuracies come from ``prequential_cost_array`` (a retrain at
    batch t only takes effect from batch t + 1). They are averaged over all batches.

    Args:
        retrains (List[int]): the batches at which retrains occurred
        A (np.ndarray): the upper-triangular accuracy matrix

    Returns:
        float: the mean per-batch accuracy
    """
    return compute_metric(
        retrains,
        A,
        logic=prequential_cost_array,
        reduction="mean",
    )


def compute_strategy_cost(result: dict, C: np.ndarray, retrain_cost: Optional[float] = 0) -> float:
    """Compute the cost of a strategy from its result dictionary.

    Args:
        result (dict): the result dictionary from running a strategy. Must have the "retrains" key
        C (np.ndarray): The upper-triangular cost matrix on which to compute the cost
        retrain_cost (float, optional): The cost of retraining to be explicitly added to the cost matrix diagonal. Defaults to 0

    Returns:
        float: the cost of the strategy based on the retrains

    Raises:
        KeyError: if ``result`` has no "retrains" key.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if "retrains" not in result:
        raise KeyError("result dictionary must contain a 'retrains' key")
    return _compute_strategy_cost(result["retrains"], C, retrain_cost)


def compute_prequential_accuracy(result: dict, A: np.ndarray) -> float:
    """Compute the prequential accuracy from a result dictionary.

    Args:
        result (dict): the result dictionary from running a strategy. Must have the "retrains" key
        A (np.ndarray): The upper-triangular accuracy matrix

    Returns:
        float: the prequential accuracy of the retrains

    Raises:
        KeyError: if ``result`` has no "retrains" key.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if "retrains" not in result:
        raise KeyError("result dictionary must contain a 'retrains' key")
    return _compute_prequential_accuracy(result["retrains"], A)


def prequential_cost_array(retrains: List[int], C: np.ndarray) -> np.ndarray:
    """Compute the prequential (test-then-train) per-batch values of a strategy.

    A model retrained at batch ``t`` is first used at batch ``t + 1``. Until the
    first such switch, the batch-0 model is used. A retrain at the last batch
    therefore never takes effect within the horizon.

    Args:
        retrains (List[int]): the batches when retrains occurred
        C (np.ndarray): The upper-triangular matrix (cost or accuracy) the values are read from

    Returns:
        np.ndarray: one value per batch, ``C[model_in_use, t]``, of length ``C.shape[0]``
    """
    T = C.shape[0] - 1
    # dtype=int so an empty schedule still yields a valid (integer) index array
    retrains = np.asarray(retrains, dtype=int)
    first_use = retrains + 1
    # Drop retrains whose first use falls outside the horizon rather than clipping
    # them onto batch T, which would let batch T see its own retrained model.
    in_horizon = (first_use >= 0) & (first_use <= T)
    indices = np.zeros(T + 1, dtype=int)
    indices[first_use[in_horizon]] = retrains[in_horizon]
    # forward fill: each batch uses the latest model that is already available
    np.maximum.accumulate(indices, out=indices)
    return C[indices, np.arange(T + 1)]


def vary_cost(C: np.ndarray, strategy: Callable, retrain_costs=np.linspace(0, 1, 100)):
    """Run ``strategy`` once for every retrain cost in ``retrain_costs``.

    Args:
        C (np.ndarray): matrix passed unchanged to every call of ``strategy``
        strategy (Callable): called as ``strategy(C, cost)`` for each cost
        retrain_costs (Iterable[float], optional): the retrain costs to sweep.
            Defaults to 100 evenly spaced values in [0, 1].

    Returns:
        list: the value returned by ``strategy`` for each cost, in order
    """
    return [strategy(C, cost) for cost in retrain_costs]


def plot_vary_cost(C, C_baseline, retrain_costs=np.linspace(0, 1, 100), fig=None):
    """Plot how a retraining strategy changes as the retrain cost is swept.

    Draws three side-by-side panels with ``retrain_costs`` on the x-axis:
    "Avg decrease in cost", "Number of Retrains", and each strategy's cost
    evaluated on ``C_baseline`` plus the retrain cost.

    NOTE(review): this function looks out of sync with ``vary_cost``.
    ``vary_cost(C, strategy, retrain_costs)`` returns a single list, but here it
    is called as ``vary_cost(C, retrain_costs)``. That passes ``retrain_costs``
    as ``strategy``, so ``strategy(C, cost)`` would try to call an ndarray, and
    the result is unpacked into two values. TODO: confirm the intended strategy
    callable and return shape.
    NOTE(review): ``compute_strategy_cost`` expects a result dict with a
    "retrains" key, while ``len(s) - 1`` treats ``s`` as a list of retrains.
    Confirm which type the strategy returns.

    Args:
        C: matrix forwarded to ``vary_cost``.
        C_baseline: matrix on which each strategy's cost is evaluated after
            adding the retrain cost via ``add_retraining_cost``.
        retrain_costs: retrain costs to sweep (x-axis of every panel).
        fig: figure to draw into; a new ``plt.figure()`` is created when None.
    """
    if fig is None:
        fig = plt.figure()

    (ax1, ax2, ax3) = fig.subplots(1, 3)  # create 1x3 subplots on fig
    # NOTE(review): argument list does not match vary_cost's signature (see docstring)
    acc_gains, acc_strats = vary_cost(C, retrain_costs)
    acc_num_retrains = [len(s) - 1 for s in acc_strats]
    acc_strat_util = [
        compute_strategy_cost(s, add_retraining_cost(C_baseline, cost))
        for s, cost in zip(acc_strats, retrain_costs)
    ]
    ax1.plot(retrain_costs, acc_gains)
    ax1.set_ylabel("Avg decrease in cost")
    ax1.set_xlabel("Retrain Cost")
    ax2.plot(retrain_costs, acc_num_retrains)
    ax2.set_ylabel("Number of Retrains")
    ax2.set_xlabel("Retrain Cost")
    ax3.plot(retrain_costs, acc_strat_util)
    ax3.set_ylabel("Optimal Strategy Query Misclassifications")
    ax3.set_xlabel("Retrain Cost")


def _plot_single_retrain(T, retrains, ax, color, **kwargs):
    """Draw one retraining schedule as a step path on ``ax``.

    Each model in use is a solid horizontal segment at height equal to its
    retrain batch. The segment runs from that batch to the batch before the
    next retrain, with a marker at its start. A dotted connector then jumps to
    the next model, and the last model's segment runs to batch ``T``. Batch 0
    is always treated as a retrain.

    Args:
        T (int): index of the last batch
        retrains (Iterable[int]): batches at which retrains occurred; not modified
        ax: axes-like object whose ``plot`` method draws the lines
        color: line color forwarded to every ``ax.plot`` call
        **kwargs: extra keyword arguments forwarded to every ``ax.plot`` call
    """
    # Work on a copy so the caller's list is never mutated (and arrays work too).
    points = list(retrains)
    if 0 not in points:
        points.append(0)
    points.sort()

    if len(points) > 1:
        for start, nxt in zip(points, points[1:]):
            ax.plot([start, nxt - 1], [start, start], color=color,
                    linestyle="-", marker="o", markevery=[0], **kwargs)
            ax.plot([nxt - 1, nxt, nxt, nxt], [start, start, start, nxt],
                    color=color, linestyle=":", **kwargs)
        last = points[-1]
        ax.plot([last, T], [last, last], color=color, marker="o",
                linestyle="-", markevery=[0], **kwargs)
    else:
        ax.plot([0, T], [0, 0], color=color,
                marker="o", linestyle="-", **kwargs)


def plot_retrains(U, retrains, colors, ax=None, **kwargs):
    """Show ``U`` as an image and overlay one retraining schedule per color.

    Args:
        U (np.ndarray): matrix drawn with ``ax.imshow``; the last batch index
            is taken as ``T = U.shape[0] - 1``.
        retrains (Iterable[List[int]]): one retrain schedule per overlay.
        colors (Iterable): one color per schedule. It is zipped with
            ``retrains``, so the shorter of the two sets the number drawn.
        ax (matplotlib.axes.Axes, optional): axes to draw on; a new figure and
            axes are created with ``plt.subplots()`` when None.
        **kwargs: only two keys are read. ``imshow`` is a dict of keyword
            arguments for ``ax.imshow``. ``retrain`` is a dict forwarded to every
            line drawn for a schedule. Any other key is ignored.

    Returns:
        The image object returned by ``ax.imshow``.
    """
    imshow_kwargs = kwargs.pop("imshow", {})
    retrain_kwargs = kwargs.pop("retrain", {})
    if ax is None:
        _, ax = plt.subplots()
    T = U.shape[0] - 1
    im = ax.imshow(U, **imshow_kwargs)
    for retrain, color in zip(retrains, colors):
        _plot_single_retrain(T, retrain, ax=ax, color=color, **retrain_kwargs)
    return im


def get_online_offline_data_dict(
    X_full: List[np.ndarray],
    y_full: List[np.ndarray],
    offline_split: int = 25,
    queries_split: float = 0.1,
    use_ray: bool = False,
) -> Tuple[dict, dict]:
    """Gather the offline and online data dictionaries from the data stream.

    The first ``offline_split`` batches form the offline dictionary and the
    rest the online dictionary; both are keyed from 0. Each batch is split with
    ``train_test_split(..., test_size=queries_split, random_state=0)`` into
    ``X_train``/``y_train`` and ``X_query``/``y_query``.

    Args:
        X_full (List[np.ndarray]): The stream of points. Each array has a batch of points.
        y_full (List[np.ndarray]): The stream of targets. Each array has a batch of targets.
        offline_split (int, optional): The number of batches to consider as offline. Defaults to 25.
        queries_split (float, optional): Passed as ``test_size``; the share of each batch
            held out as potential query points. Defaults to 0.1.
        use_ray (bool, optional): Whether to replace every batch dictionary with a Ray
            object reference (``ray.put``). Defaults to False.

    Returns:
        Tuple[dict, dict]: the offline and online data dictionaries
    """

    def _split_batches(X_batches, y_batches):
        # One train/query split per batch; the fixed random_state keeps it reproducible.
        batches = {}
        for t, (_X, _y) in enumerate(zip(X_batches, y_batches)):
            X_train, X_query, y_train, y_query = train_test_split(
                _X, _y, test_size=queries_split, random_state=0
            )
            batches[t] = {
                "X_train": X_train,
                "y_train": y_train,
                "X_query": X_query,
                "y_query": y_query,
            }
        return batches

    offline = _split_batches(X_full[:offline_split], y_full[:offline_split])
    online = _split_batches(X_full[offline_split:], y_full[offline_split:])

    # if ray objects are required
    if use_ray:
        offline = {t: ray.put(batch) for t, batch in offline.items()}
        online = {t: ray.put(batch) for t, batch in online.items()}

    return offline, online


def get_online_offline_data_dict_with_onehot(
    X_full: List[np.ndarray],
    X_onehot_full: List[np.ndarray],
    y_full: List[np.ndarray],
    offline_split: int = 25,
    queries_split: float = 0.1,
    use_ray: bool = False,
) -> Tuple[dict, dict]:
    """Gather the offline and online data dictionaries, including one-hot features.

    The first ``offline_split`` batches form the offline dictionary and the
    rest the online dictionary; both are keyed from 0. Each batch's raw
    features, one-hot features and targets are split together with
    ``train_test_split(..., test_size=queries_split, random_state=0)``, so the
    same rows land in the train and query parts of all three.

    Args:
        X_full (List[np.ndarray]): The stream of points. Each array has a batch of points.
        X_onehot_full (List[np.ndarray]): The stream of one-hot encoded points. Each array has a batch of points.
        y_full (List[np.ndarray]): The stream of targets. Each array has a batch of targets.
        offline_split (int, optional): The number of batches to consider as offline. Defaults to 25.
        queries_split (float, optional): Passed as ``test_size``; the share of each batch
            held out as potential query points. Defaults to 0.1.
        use_ray (bool, optional): Whether to replace every batch dictionary with a Ray
            object reference (``ray.put``). Defaults to False.

    Returns:
        Tuple[dict, dict]: the offline and online data dictionaries
    """

    def _split_batches(X_batches, X_onehot_batches, y_batches):
        # One joint train/query split per batch; the fixed random_state keeps it reproducible.
        batches = {}
        for t, (_X, _X_onehot, _y) in enumerate(zip(X_batches, X_onehot_batches, y_batches)):
            X_train, X_query, X_onehot_train, X_onehot_query, y_train, y_query = train_test_split(
                _X, _X_onehot, _y, test_size=queries_split, random_state=0
            )
            batches[t] = {
                "X_train": X_train,
                "X_onehot_train": X_onehot_train,
                "y_train": y_train,
                "X_query": X_query,
                "X_onehot_query": X_onehot_query,
                "y_query": y_query,
            }
        return batches

    offline = _split_batches(
        X_full[:offline_split], X_onehot_full[:offline_split], y_full[:offline_split]
    )
    online = _split_batches(
        X_full[offline_split:], X_onehot_full[offline_split:], y_full[offline_split:]
    )

    # if ray objects are required
    if use_ray:
        offline = {t: ray.put(batch) for t, batch in offline.items()}
        online = {t: ray.put(batch) for t, batch in online.items()}

    return offline, online


def get_online_offline_query_dict(
    X_full: List[np.ndarray],
    y_full: Optional[List[np.ndarray]] = None,
    offline_split: int = 25,
    use_ray: bool = False,
) -> Tuple[dict, dict]:
    """Gather the offline and online query dictionaries from the query stream.

    The first ``offline_split`` batches go to the offline dictionary and the
    rest to the online one; both are keyed from 0. Each entry holds
    ``"X_query"``, plus ``"y_query"`` when ground-truth targets are supplied.

    Args:
        X_full (List[np.ndarray]): The stream of query points. Each array is one batch.
        y_full (List[np.ndarray], optional): Ground-truth targets for the queries, if
            available. Defaults to None.
        offline_split (int, optional): The number of batches to consider as offline. Defaults to 25.
        use_ray (bool): Whether to replace every entry with a Ray object reference. Defaults to False.

    Returns:
        Tuple[dict, dict]: the offline and online query dictionaries
    """
    has_labels = y_full is not None

    def _to_query_dict(X_part, y_part):
        batches = {}
        if has_labels:
            for t, (x_batch, y_batch) in enumerate(zip(X_part, y_part)):
                batches[t] = {"X_query": x_batch, "y_query": y_batch}
        else:
            for t, x_batch in enumerate(X_part):
                batches[t] = {"X_query": x_batch}
        return batches

    offline = _to_query_dict(
        X_full[:offline_split], y_full[:offline_split] if has_labels else None
    )
    online = _to_query_dict(
        X_full[offline_split:], y_full[offline_split:] if has_labels else None
    )

    if use_ray:
        offline = {t: ray.put(entry) for t, entry in offline.items()}
        online = {t: ray.put(entry) for t, entry in online.items()}

    return offline, online


def get_model_dict(
    data_dict: Dict[int, dict], model_class: Classifier, use_ray: bool = False, sparse_train: bool = True, **kwargs
) -> Dict[int, Classifier]:
    """Train one model per data batch.

    Args:
        data_dict (Dict[int, dict]): batch key -> batch dictionary with "X_train" and
            "y_train"; Ray object references to such dictionaries when ``use_ray``.
        model_class (Classifier): the sklearn model class, instantiated as ``model_class(**kwargs)``
        use_ray (bool, optional): Whether ``data_dict`` holds Ray references and the returned
            models should be Ray references too. Defaults to False.
        sparse_train (bool, optional): Whether training supports sparse matrices. When False,
            sparse inputs are densified with ``.toarray()``. Defaults to True.
        **kwargs: keyword arguments for the model constructor.

    Returns:
        Dict[int, Classifier]: batch key -> model fitted on that batch
        (Ray object references when ``use_ray``).
    """
    models = {}
    for t, data in data_dict.items():
        if use_ray:
            data = ray.get(data)
        X_train = data["X_train"]
        # Only sparse matrices need densifying; dense arrays have no .toarray().
        if not sparse_train and hasattr(X_train, "toarray"):
            X_train = X_train.toarray()
        model = model_class(**kwargs).fit(X_train, data["y_train"])
        models[t] = ray.put(model) if use_ray else model
    return models


def save_model_dicts(offline_models: Dict[int, Classifier], online_models: Dict[int, Classifier], dir: Union[str, pathlib.Path], use_ray: bool = False):
    """Dump the offline and online model dictionaries to ``<dir>/models.joblib``.

    The file holds the list ``[offline_models, online_models]``.

    Args:
        offline_models (Dict[int, Classifier]): offline models (Ray references when ``use_ray``)
        online_models (Dict[int, Classifier]): online models (Ray references when ``use_ray``)
        dir (str | pathlib.Path): directory to write into. The name shadows the builtin
            but is kept for callers that pass it by keyword.
        use_ray (bool, optional): Whether the dictionaries hold Ray object references that
            must be resolved before saving. Defaults to False.
    """
    file_str = pathlib.Path(dir) / "models.joblib"
    if use_ray:
        # Resolve into new dicts so the caller's dicts keep their ObjectRefs.
        offline_models = {key: ray.get(ref) for key, ref in offline_models.items()}
        online_models = {key: ray.get(ref) for key, ref in online_models.items()}
    joblib.dump([offline_models, online_models], file_str)


def load_model_dicts(file: Union[str, pathlib.Path], use_ray: bool = False):
    """Load the ``[offline_models, online_models]`` pair written by ``save_model_dicts``.

    NOTE: ``joblib.load`` unpickles the file and can execute arbitrary code;
    only load files from trusted sources.

    Args:
        file (str | pathlib.Path): path to the saved ``models.joblib`` file
        use_ray (bool, optional): if True, put every model in the Ray object store and
            return the object references instead. Defaults to False.

    Returns:
        Tuple[dict, dict]: the offline and online model dictionaries
    """
    offline_models, online_models = joblib.load(file)
    if use_ray:
        offline_models = {key: ray.put(model) for key, model in offline_models.items()}
        online_models = {key: ray.put(model) for key, model in online_models.items()}

    return offline_models, online_models


@ray.remote
def cost_fn_i_j(i, j, cost_fn, *args, **kwargs):
    """Ray task: evaluate ``cost_fn(*args, **kwargs)`` and tag the result with ``(i, j)``.

    Returning the matrix indices alongside the value lets the caller place
    results that are collected in completion order.
    """
    value = cost_fn(*args, **kwargs)
    return (i, j, value)


def compute_cost_matrix(
    data_dict: Dict[list, dict],
    model_dict: Dict[int, Classifier],
    cost_fn: Callable,
    query_dict: Dict[list, dict],
    use_ray: bool = False,
    **kwargs
):
    """Evaluate ``cost_fn`` for every (training batch, evaluation batch) pair.

    For ``i <= j``, ``C[i, j]`` is
    ``cost_fn(model_dict[i], data_dict[i], data_dict[j], query_dict[j], **kwargs)``.
    Entries below the diagonal stay ``np.inf``. With ``use_ray``, each entry is
    submitted as a ``cost_fn_i_j`` Ray task, and results are written back as
    they complete.

    Returns:
        np.ndarray: the ``(len(data_dict), len(data_dict))`` cost matrix
    """
    n_batches = len(data_dict)
    C = np.full((n_batches, n_batches), np.inf)
    pending = []
    for i in range(n_batches):
        for j in range(i, n_batches):
            # (model from batch i, its training data, data of batch j, queries of batch j)
            call_args = (model_dict[i], data_dict[i], data_dict[j], query_dict[j])
            if use_ray:
                pending.append(cost_fn_i_j.remote(i, j, cost_fn, *call_args, **kwargs))
            else:
                C[i, j] = cost_fn(*call_args, **kwargs)

    # Collect Ray results one at a time, in completion order.
    while use_ray and pending:
        ready, pending = ray.wait(pending, num_returns=1)
        row, col, value = ray.get(ready[0])
        C[row, col] = value

    return C


@ray.remote
def cost_fn_i_j_k(i, j, k, cost_fn, *args, **kwargs):
    """Ray task: evaluate ``cost_fn(*args, **kwargs)`` tagged with ``(i, j, k)``.

    ``k`` identifies which cost function / matrix the value belongs to, and
    ``(i, j)`` the entry within that matrix.
    """
    value = cost_fn(*args, **kwargs)
    return (i, j, k, value)


def compute_cost_matrix_multiple(
    data_dict: Dict[int, dict],
    model_dict: Dict[int, Classifier],
    cost_fns: List[Callable],
    query_dict: Dict[int, dict],
    use_ray: bool = False,
    cost_kwargs: Optional[List[dict]] = None,
):
    """Evaluate several cost functions for every (training batch, evaluation batch) pair.

    For each ``k`` and every ``i <= j``, ``Cs[k][i, j]`` is
    ``cost_fns[k](model_dict[i], data_dict[i], data_dict[j], query_dict[j], **cost_kwargs[k])``.
    Entries below the diagonal stay ``np.inf``.

    Args:
        data_dict: batch index -> data batch (presumably Ray references when ``use_ray``)
        model_dict: batch index -> model trained on that batch
        cost_fns: cost functions, each called as ``(model, model_data, new_data, query_data, **kwargs)``
        query_dict: batch index -> query batch
        use_ray: submit each evaluation as a ``cost_fn_i_j_k`` Ray task. Defaults to False.
        cost_kwargs: one dict of extra keyword arguments per cost function.
            ``None`` (the default) means no extra arguments for any of them.

    Returns:
        List[np.ndarray]: one ``(len(data_dict), len(data_dict))`` matrix per cost function

    Raises:
        ValueError: if ``cost_kwargs`` has fewer entries than ``cost_fns``.
    """
    K = len(cost_fns)
    if cost_kwargs is None:
        # A fresh dict per cost function; the old `[{}]` default broke for K > 1.
        cost_kwargs = [{} for _ in range(K)]
    elif len(cost_kwargs) < K:
        raise ValueError(
            f"cost_kwargs has {len(cost_kwargs)} entries but {K} cost functions were given"
        )

    n_batches = len(data_dict)
    Cs = [np.full((n_batches, n_batches), np.inf) for _ in range(K)]
    refs = []
    for i in range(n_batches):
        for j in range(i, n_batches):
            # model trained at batch i, its training data, data and queries of batch j
            _model = model_dict[i]
            _model_data = data_dict[i]
            _new_data = data_dict[j]
            _query_data = query_dict[j]
            for k, cost_fn in enumerate(cost_fns):
                if not use_ray:
                    Cs[k][i, j] = cost_fn(
                        _model, _model_data, _new_data, _query_data, **cost_kwargs[k])
                else:
                    refs.append(
                        cost_fn_i_j_k.remote(
                            i, j, k, cost_fn, _model, _model_data, _new_data, _query_data, **cost_kwargs[k]
                        )
                    )

    if use_ray:
        # gather ray jobs one at a time, in completion order
        unfinished = refs
        while unfinished:
            finished, unfinished = ray.wait(unfinished, num_returns=1)
            i, j, k, c = ray.get(finished[0])
            Cs[k][i, j] = c

    return Cs


def import_obj(string):
    """Import and return an object given its fully-qualified dotted path.

    Args:
        string (str): dotted path such as ``"sklearn.linear_model.LogisticRegression"``.
            Everything before the last ``.`` is the module; the rest is the attribute.

    Returns:
        The requested attribute of the imported module.

    Raises:
        ImportError: if ``string`` is not a str, has no module part, or the module
            does not define the requested attribute.
        ModuleNotFoundError: if the module part cannot be imported.
    """
    if not isinstance(string, str):
        raise ImportError('Object type should be string.')

    mod_str, _sep, class_str = string.rpartition('.')
    if not mod_str:
        raise ImportError(
            'Object path {!r} must be of the form "module.object".'.format(string))
    try:
        mod = import_module(mod_str)
    except ModuleNotFoundError as err:
        # chain so the real missing module (possibly a nested import) stays visible
        raise ModuleNotFoundError(
            'Module {} (needed for object {}) cannot be found.'.format(mod_str, class_str)) from err
    try:
        return getattr(mod, class_str)
    except AttributeError as err:
        # mirror `from module import name`, which raises ImportError for a missing name
        raise ImportError(
            'Object {} cannot be found in {}.'.format(class_str, mod_str)) from err


def save_matrices(d: dict, results_dir: Union[str, pathlib.Path]):
    """Save upper-triangular matrices compactly to ``<results_dir>/costs.npz``.

    For each ``name -> mat`` entry, the size ``mat.shape[0]`` is stored as
    ``{name}_N``. The upper triangle (diagonal included, in ``np.triu_indices``
    order) is stored as ``{name}_values``. Entries below the diagonal are not
    saved.

    Args:
        d (dict): name -> square matrix
        results_dir (str | pathlib.Path): directory in which ``costs.npz`` is written
    """
    save_dict = {}
    for name, mat in d.items():
        N = mat.shape[0]
        save_dict[f"{name}_N"] = N
        save_dict[f"{name}_values"] = mat[np.triu_indices(N)]

    np.savez_compressed(pathlib.Path(results_dir) / "costs.npz", **save_dict)


def load_cost_matrices(file: Union[str, pathlib.Path]):
    """Rebuild the matrices stored in a ``costs.npz`` file written by ``save_matrices``.

    Each ``{name}_N`` / ``{name}_values`` pair becomes an ``N x N`` float matrix.
    Its upper triangle (diagonal included) is filled from ``{name}_values``, and
    the entries below the diagonal are ``np.inf``.

    Args:
        file (str | pathlib.Path): path to the ``.npz`` file

    Returns:
        dict: name -> reconstructed matrix, in the order the names appear in the file
    """
    matrices = {}
    # The context manager closes the underlying file handle that np.load keeps open.
    with np.load(file) as npzfile:
        # dict.fromkeys de-duplicates while keeping file order (a set would not)
        names = dict.fromkeys(key.rsplit("_", 1)[0] for key in npzfile.files)
        for name in names:
            N = int(npzfile[f"{name}_N"])
            _C = np.full((N, N), np.inf)
            _C[np.triu_indices(N)] = npzfile[f"{name}_values"]
            matrices[name] = _C
    return matrices


def generate_static_queries(query_centers: List[Tuple[float, float]], std: float = 0.1, num_batches: int = 100, points_per_batch: int = 100, seed: int = 0):
    """Generate a stream of 2-D query batches drawn around fixed centers.

    Every batch holds ``int(points_per_batch / len(query_centers))`` Gaussian
    points (scale ``std``) per center. Centers are concatenated in the given
    order, and the RNG seeded with ``seed`` is drawn batch by batch and center
    by center.

    Args:
        query_centers (List[Tuple[float, float]]): 2-D centers of the query clusters
        std (float, optional): standard deviation around each center. Defaults to 0.1.
        num_batches (int, optional): number of batches to generate. Defaults to 100.
        points_per_batch (int, optional): target points per batch, split evenly across
            centers (remainder dropped). Defaults to 100.
        seed (int, optional): seed for ``np.random.default_rng``. Defaults to 0.

    Returns:
        List[np.ndarray]: one ``(n_points, 2)`` array per batch
    """
    rng = np.random.default_rng(seed=seed)
    per_center = int(points_per_batch / len(query_centers))
    return [
        np.concatenate([
            rng.normal(loc=center, scale=std, size=(per_center, 2))
            for center in query_centers
        ])
        for _ in range(num_batches)
    ]
