import hashlib

import holisticai.bias.metrics as hai_metrics
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
from sklearn.ensemble import GradientBoostingClassifier, HistGradientBoostingClassifier
from sklearn.metrics import accuracy_score

from strategies import Collective_Strategy


def hash_config(config):
    """Return an MD5 hex digest computed from an OmegaConf config.

    The config is first converted to plain Python containers with
    interpolations resolved, then dumped to YAML text; the UTF-8 bytes of
    that text are what gets hashed.
    """
    # resolve=True makes the digest depend on the resolved values rather
    # than on the raw interpolation expressions.
    resolved = OmegaConf.to_container(config, resolve=True)
    yaml_bytes = OmegaConf.to_yaml(resolved).encode("utf-8")

    digest = hashlib.md5()
    digest.update(yaml_bytes)
    return digest.hexdigest()


def create_collective(
    in_group,
    alpha,
    majority_samples=None,
    seed=0,
):
    """Build a boolean mask marking the members of the collective.

    in_group: vector whose entries equal 1 for minority members and 0 for
        majority members
    alpha: fraction of the minority members that join the collective
    majority_samples: optional number of majority members to add to the
        collective; when None only minority members are selected
    seed: seed passed to randomize_indices for both random draws

    Returns a boolean array with one entry per element of in_group.
    """
    membership = np.zeros(in_group.shape[0], dtype=bool)

    # Draw the minority part of the collective.
    minority_pool = np.flatnonzero(in_group == 1)
    chosen_minority = randomize_indices(
        indices=minority_pool, alpha=alpha, seed=seed
    )
    membership[chosen_minority] = True

    if majority_samples is None:
        return membership

    # Optionally add a fixed number of majority members (same seed).
    majority_pool = np.flatnonzero(in_group == 0)
    chosen_majority = randomize_indices(
        indices=majority_pool, num_samples=majority_samples, seed=seed
    )
    membership[chosen_majority] = True
    return membership


def randomize_indices(indices, alpha=None, num_samples=None, seed=0):
    """Draw a reproducible random subset of ``indices``.

    indices: one-dimensional array-like of candidate indices (lists and
        tuples are accepted and converted to an array)
    alpha: fraction of ``indices`` to select; only used when
        ``num_samples`` is None, in which case ``int(alpha * len(indices))``
        elements are drawn
    num_samples: number of elements to select; takes precedence over
        ``alpha``
    seed: seed for ``np.random.default_rng``; the same seed and input
        length always yield the same selection

    Returns an array with the selected elements of ``indices``.

    Raises ValueError if neither ``alpha`` nor ``num_samples`` is given.
    """
    # Fancy indexing below needs an array; ndarray inputs pass through as-is.
    indices = np.asarray(indices)

    if num_samples is None:
        if alpha is None:
            raise ValueError("Either alpha or num_samples must be provided.")
        num_samples = int(alpha * len(indices))

    rng = np.random.default_rng(seed)
    # Permute positions (not values) and keep the leading num_samples.
    sampled_positions = rng.permutation(len(indices))[:num_samples]
    return indices[sampled_positions]


def collect_metrics(metrics: np.ndarray):
    """Turn an array of metric dicts into a dict of float arrays.

    metrics: array (any shape) whose entries are either a dict of metric
        values or None

    Returns a dict mapping each metric key to a float array with the same
    shape as ``metrics``. Positions whose entry is None hold NaN. The keys
    are taken from the first entry that is not None; if every entry is None
    (or the array is empty) an empty dict is returned.
    """
    # Entries may be None, so do not assume the first one carries the keys.
    template = next((entry for entry in metrics.flat if entry is not None), None)
    if template is None:
        return {}

    keys = list(template.keys())
    # Start from NaN so that None entries need no further handling.
    collected = {k: np.full(metrics.shape, np.nan, dtype=float) for k in keys}

    for index in np.ndindex(metrics.shape):
        entry = metrics[index]
        if entry is None:
            continue
        for k in keys:
            collected[k][index] = entry[k]
    return collected


def equal_odds_difference(y_true, y_pred, in_group):
    """Average the absolute equal-opportunity and false-positive-rate gaps.

    Both gaps come from holisticai, called with the complement of
    ``in_group`` as the first group and ``in_group`` as the second.

    y_true: ground-truth labels
    y_pred: predicted labels
    in_group: group-membership vector

    Returns the mean of the two absolute differences as a Python float.
    """
    out_group = np.logical_not(in_group)

    opportunity_gap = hai_metrics.equal_opportunity_diff(
        out_group, in_group, y_pred, y_true
    )
    false_positive_gap = hai_metrics.false_positive_rate_diff(
        out_group, in_group, y_pred, y_true
    )

    mean_abs_gap = (np.abs(opportunity_gap) + np.abs(false_positive_gap)) * 0.5
    return float(mean_abs_gap)


def get_metrics_dict(in_group, y_true, y_pred):
    """Collect bias metrics, accuracy and group/label counts in one dict.

    in_group: group-membership vector; its complement is passed to
        holisticai as group_a and the vector itself as group_b
    y_true: ground-truth labels
    y_pred: predicted labels

    Returns a dict holding, in this order: one entry per row of the
    holisticai classification bias table (its "Value" column), "Accuracy",
    "Equal odds difference", and the sample counts ("n_samples",
    "n_minority", "n_majority", "n_positive", "n_negative" and the four
    group-by-label counts).
    """
    out_group = np.logical_not(in_group)
    is_negative = np.logical_not(y_true)

    bias_table = hai_metrics.classification_bias_metrics(
        group_a=out_group,
        group_b=in_group,
        y_pred=y_pred,
        y_true=y_true,
    )
    summary = {name: row["Value"].item() for name, row in bias_table.iterrows()}

    summary["Accuracy"] = accuracy_score(y_true, y_pred)
    summary["Equal odds difference"] = equal_odds_difference(
        y_true, y_pred, in_group
    )

    # Overall counts.
    total = y_true.shape[0]
    positives = np.count_nonzero(y_true)
    summary.update(
        {
            "n_samples": total,
            "n_minority": np.count_nonzero(in_group),
            "n_majority": np.count_nonzero(out_group),
            "n_positive": positives,
            "n_negative": total - positives,
        }
    )

    # Counts per (group, label) combination.
    summary.update(
        {
            "n_pos_minority": np.count_nonzero(np.logical_and(in_group, y_true)),
            "n_neg_minority": np.count_nonzero(
                np.logical_and(in_group, is_negative)
            ),
            "n_pos_majority": np.count_nonzero(np.logical_and(out_group, y_true)),
            "n_neg_majority": np.count_nonzero(
                np.logical_and(out_group, is_negative)
            ),
        }
    )
    return summary


def train_tabular(
    features_train: pd.DataFrame,
    labels_train: np.ndarray,
    in_group_train: np.ndarray,
    features_test: pd.DataFrame,
    labels_test: np.ndarray,
    in_group_test: np.ndarray,
    in_collective: np.ndarray = None,
    strategy: Collective_Strategy = None,
    use_hist_gbm: bool = True,
    seed: int = 0,
):
    """Fit a gradient-boosting classifier and report train/test metrics.

    When ``strategy`` is given, the training labels of the samples that are
    both in the collective and in the group are replaced by the output of
    ``strategy.act`` before fitting. Metrics are always computed against the
    original, unmodified labels.

    features_train, labels_train, in_group_train: training data, labels and
        group-membership vector
    features_test, labels_test, in_group_test: the same for the test split
    in_collective: collective-membership vector for the training set;
        defaults to a copy of ``in_group_train`` when a strategy is given
    strategy: optional object whose ``act`` method supplies the replacement
        labels (presumably one per selected sample — confirm against
        Collective_Strategy)
    use_hist_gbm: pick HistGradientBoostingClassifier when true, otherwise
        GradientBoostingClassifier
    seed: random_state of the classifier

    Returns ``(metrics_train, metrics_test)`` as produced by
    get_metrics_dict. With a strategy, both dicts also carry "pos2neg",
    "neg2pos" and "total_flipped"; the train dict additionally carries
    "collective_size" and "collective_majority".
    """
    model_cls = (
        HistGradientBoostingClassifier if use_hist_gbm else GradientBoostingClassifier
    )
    model = model_cls(random_state=seed)

    uses_strategy = strategy is not None

    # Labels actually used for fitting; the originals stay untouched.
    fit_labels = labels_train.copy()
    if uses_strategy:
        if in_collective is None:
            in_collective = in_group_train.copy()
        acting_mask = np.logical_and(in_collective, in_group_train)
        strategy_labels = strategy.act(
            features_train, labels_train, in_group_train, in_collective
        )
        fit_labels[acting_mask] = strategy_labels

    model.fit(features_train, fit_labels)

    predictions_train = model.predict(features_train)
    predictions_test = model.predict(features_test)

    # Evaluate against the true labels, not the strategy-modified ones.
    metrics_train = get_metrics_dict(in_group_train, labels_train, predictions_train)
    metrics_test = get_metrics_dict(in_group_test, labels_test, predictions_test)

    if not uses_strategy:
        return metrics_train, metrics_test

    # Count how many labels the strategy flipped in each direction.
    original_labels = labels_train[acting_mask]
    pos2neg = np.count_nonzero(
        np.logical_and(strategy_labels == 0, original_labels == 1)
    )
    neg2pos = np.count_nonzero(
        np.logical_and(strategy_labels == 1, original_labels == 0)
    )
    flip_counts = {
        "pos2neg": pos2neg,
        "neg2pos": neg2pos,
        "total_flipped": pos2neg + neg2pos,
    }

    metrics_train.update(flip_counts)
    metrics_train["collective_size"] = np.count_nonzero(
        in_collective & in_group_train
    )
    metrics_train["collective_majority"] = np.count_nonzero(
        in_collective & ~in_group_train
    )
    metrics_test.update(flip_counts)

    return metrics_train, metrics_test
