import numpy as np
from collections import defaultdict
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import mutual_info_score


def mean_std(cnt_1, cnt_0):
    """Return the mean and variance of a 0/1 sample described by two counts.

    Args:
        cnt_1: Count used as the numerator of the mean (the "1" outcomes).
        cnt_0: Count of the complementary "0" outcomes.

    Returns:
        A ``(mean, var)`` tuple with ``mean = cnt_1 / (cnt_1 + cnt_0)`` and
        ``var = cnt_1 * cnt_0 / (cnt_1 + cnt_0) ** 2``.

    Note:
        Despite the function name, the second element is the variance, not
        the standard deviation; apply ``np.sqrt`` to it to obtain the std.
    """
    n = cnt_1 + cnt_0
    positive_rate = cnt_1 / n
    # Algebraically p * (1 - p), computed straight from the counts.
    variance = (cnt_1 * cnt_0) / (n ** 2)
    return positive_rate, variance


def compute_metrics(data_dict):
    """Compute correlation and mutual information between values and labels.

    Args:
        data_dict: Mapping of an int-convertible value key to a mapping of
            int-convertible label key to an integer count. Each
            ``(value, label)`` pair is repeated ``count`` times to rebuild
            the per-sample sequences.

    Returns:
        A tuple ``((pearson, pearson_p), (spearman, spearman_p), mi)`` taken
        from ``pearsonr``, ``spearmanr`` and ``mutual_info_score`` applied to
        the expanded value and label sequences.
    """
    values, labels = [], []
    for raw_value, label_counts in data_dict.items():
        for raw_label, n in label_counts.items():
            label = int(raw_label)
            value = int(raw_value)
            # Expand the aggregated count back into individual observations.
            values += [value] * n
            labels += [label] * n

    pearson_stats = tuple(pearsonr(values, labels))
    spearman_stats = tuple(spearmanr(values, labels))
    mutual_info = mutual_info_score(values, labels)
    return pearson_stats, spearman_stats, mutual_info


def get_bins(stat_question_dict, num_bins=10, pre_clip=0.0, post_clip=0.0):
    """Group labelled samples into value-ordered bins of roughly equal size.

    The aggregated counts are expanded into individual ``(value, label)``
    samples, ordered by value, optionally trimmed at both ends, and then
    packed greedily into bins. All samples sharing a value always land in
    the same bin, so the number of bins can differ from ``num_bins``.

    Args:
        stat_question_dict: Mapping of an int-convertible value key to a
            mapping of int-convertible label key to an integer count.
            Labels other than 0 and 1 raise ``KeyError`` because the
            per-value tally only has ``'1'`` and ``'0'`` slots.
        num_bins: Target number of bins; the target bin size is
            ``len(samples) / num_bins`` after clipping.
        pre_clip: Fraction of the lowest-valued samples to drop.
        post_clip: Fraction of the highest-valued samples to drop.

    Returns:
        A ``(bins, new_dict)`` tuple. ``bins`` is a list of lists of
        ``(value, label)`` tuples in ascending value order. ``new_dict`` is
        a ``defaultdict`` mapping ``str(value)`` to ``{'1': n1, '0': n0}``
        counts for the samples that survived clipping.
    """
    samples = []
    for val_str, count_dict in stat_question_dict.items():
        for label_str, count in count_dict.items():
            samples.extend([(int(val_str), int(label_str))] * count)

    # Order by value before clipping so the clip fractions trim the low/high
    # tails instead of depending on the dict's insertion order. The sort is
    # stable, so input that is already ascending is left untouched.
    samples.sort(key=lambda sample: sample[0])
    start = int(len(samples) * pre_clip)
    stop = int(len(samples) * (1 - post_clip))
    samples = samples[start:stop]

    new_dict = defaultdict(lambda: {'1': 0, '0': 0})
    grouped = {}
    for val, label in samples:
        new_dict[str(val)][str(label)] += 1
        grouped.setdefault(val, []).append((val, label))

    target_bin_size = len(samples) / num_bins
    bins, current_bin, current_count = [], [], 0
    # `grouped` was filled from value-sorted samples, so its insertion order
    # is already ascending by value.
    for sample_list in grouped.values():
        # Close the current bin when adding this value's samples would
        # overshoot the target; a value group is never split across bins.
        if current_count + len(sample_list) > target_bin_size and current_bin:
            bins.append(current_bin)
            current_bin = sample_list.copy()
            current_count = len(sample_list)
        else:
            current_bin.extend(sample_list)
            current_count += len(sample_list)
    if current_bin:
        bins.append(current_bin)
    return bins, new_dict

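# Earlier version of get_bins, kept commented out for reference only. It cut
# the flattened sample list into consecutive chunks of len(samples) // num_bins
# items without grouping equal values together, and returned only the bins.
# def get_bins(stat_question_dict, num_bins=10):
#     x, y = [], []
#     for val_str, count_dict in stat_question_dict.items():
#         for label_str, count in count_dict.items():
#             x.extend([int(val_str)] * count)
#             y.extend([int(label_str)] * count)
#     samples = list(zip(x, y))
#     bin_size = len(samples) // num_bins
#     bins = [samples[i:i + bin_size] for i in range(0, len(samples), bin_size)]
#     return bins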


