import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.cluster import KMeans

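# We run K-means on the calibration dataset to decide what the score buckets should be.
# On our machine K-means could not handle very large inputs, so inputs longer than
# INTERP_SIZE are first resampled down to INTERP_SIZE points (see find_centroids).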
INTERP_SIZE = 10000


def get_interpolated(
    input,
    interp_size=INTERP_SIZE,
):
    """Linearly resample ``input`` onto ``interp_size`` evenly spaced points.

    The N samples of ``input`` are treated as lying at positions 0, 1, ..., N-1.
    The result is ``np.interp`` evaluated at ``interp_size`` evenly spaced
    positions covering that same range, including both endpoints.

    Args:
        input: 1-D sequence of numbers.
        interp_size: number of points in the returned array.

    Returns:
        numpy array of length ``interp_size``.
    """
    count = len(input)
    end = count - 1
    # Positions of the original samples: 0, 1, ..., N-1.
    original_positions = np.linspace(0, end, count)
    # interp_size positions spread evenly over [0, N-1].
    target_positions = np.linspace(0, end, interp_size)
    return np.interp(target_positions, original_positions, input)

def find_centroids(
    values,
    num_clusters=5,
    interp_size=None,
    random_state=0,
):
    """Return the sorted centers of a 1-D K-means clustering of ``values``.

    Inputs longer than ``interp_size`` are first reduced: the values are sorted
    and linearly resampled (via ``get_interpolated``) down to ``interp_size``
    points. This samples the empirical quantile curve at evenly spaced
    positions, which bounds the K-means input size while roughly preserving
    the distribution of the values.

    Args:
        values: array-like of numbers; it is flattened into a single feature column.
        num_clusters: number of K-means clusters (and of returned centers).
        interp_size: length threshold above which inputs are resampled, and
            the resampled length. ``None`` (the default) uses the module-level
            ``INTERP_SIZE`` as read at call time.
        random_state: seed passed to ``KMeans`` (default 0).

    Returns:
        1-D numpy array of ``num_clusters`` cluster centers in ascending order.
    """
    if interp_size is None:
        interp_size = INTERP_SIZE

    X = np.array(values).reshape(-1, 1)
    if len(X) > interp_size:
        # Pass interp_size explicitly so the threshold and the resampled
        # length always agree (get_interpolated's own default is fixed at
        # definition time).
        X = get_interpolated(np.sort(X.flatten()), interp_size=interp_size)
        X = X[:, np.newaxis]

    kmeans = KMeans(n_clusters=num_clusters, random_state=random_state)
    kmeans.fit(X)
    # cluster_centers_ has one row per cluster and one column (one feature).
    return np.sort(kmeans.cluster_centers_.flatten())

def get_bin_targets(
    values,
    num_clusters,
):
    """Map each bin index to its target value (a K-means centroid of ``values``).

    Returns:
        dict ``{0: c0, 1: c1, ...}``, where the centroids are taken in the
        order ``find_centroids`` returns them (ascending).
    """
    return dict(enumerate(find_centroids(values, num_clusters=num_clusters)))

### ASCE and Brier score calculation

# get calibration error for each score bin (each distinct value of score_col)
def get_ce_df(
    df,
    score_col='score_bin',
):
    """Compute per-bin calibration statistics, grouping rows by ``score_col``.

    Each distinct value of ``score_col`` is one bin. For every bin the result has:
      size                       number of rows in the bin
      correct                    mean of the ``correct`` column
      calibration_error          mean score in the bin minus mean ``correct``
      squared_calibration_error  square of calibration_error

    Returns:
        DataFrame indexed by the bin value (index name ``score_col``) with the
        columns above, in that order.
    """
    grouped = df.groupby(score_col)
    mean_score = grouped[score_col].mean()
    mean_correct = grouped['correct'].mean()
    gap = mean_score - mean_correct

    result = grouped.size().to_frame('size')
    result['correct'] = mean_correct
    result['calibration_error'] = gap
    result['squared_calibration_error'] = gap ** 2
    return result

# get ASCE
def get_asce(
    df,
    score_col='score_bin',
):
    """Return the bin-size-weighted mean of the squared calibration error.

    Each bin's ``squared_calibration_error`` (from ``get_ce_df``) is weighted
    by that bin's share of the binned rows, and the weighted values are summed.
    """
    per_bin = get_ce_df(df, score_col=score_col)
    counts = per_bin['size']
    weighted = per_bin['squared_calibration_error'] * counts / counts.sum()
    return weighted.sum()

# get Brier score (MSE)
def get_brier(
    df,
    score_col='score_bin',
):
    """Return the Brier score: the mean over rows of (correct - score) ** 2."""
    return df['correct'].sub(df[score_col]).pow(2).mean()

# get summary of scores
def get_summary(
    score_col_name,
    df_calibration,
    df_test,
    asce_fn,
    brier_fn,
):
    """Tabulate ASCE and Brier score on the calibration and test splits.

    Args:
        score_col_name: name of the value column in the returned frame.
        df_calibration: calibration split, passed unchanged to each metric fn.
        df_test: test split, passed unchanged to each metric fn.
        asce_fn: callable taking one split and returning its ASCE.
        brier_fn: callable taking one split and returning its Brier score.

    Returns:
        4-row DataFrame with columns ``split``, ``metric`` and ``score_col_name``.
    """
    splits = (('calibration', df_calibration), ('test', df_test))
    metrics = (('asce', asce_fn), ('brier', brier_fn))
    # The metric is the outer loop, so the calls happen (and the rows appear) in
    # the order asce/calibration, asce/test, brier/calibration, brier/test.
    rows = [
        [split, metric, metric_fn(split_df)]
        for metric, metric_fn in metrics
        for split, split_df in splits
    ]
    return pd.DataFrame(rows, columns=['split', 'metric', score_col_name])


### Functions for getting results by group

def concat_group_dict(
    group_dict,
    score_col_name
):
    """Stack per-grouping result frames into one long DataFrame.

    Args:
        group_dict: mapping of grouping name -> DataFrame. Each DataFrame is
            indexed by the group key(s) and has exactly two columns, which are
            treated as (size, score), in that order.
        score_col_name: name given to the score column in the output.

    Returns:
        DataFrame with columns ``group_cols`` (the dict key), ``group_vals``
        (the group key; a multi-level key is joined with '; '), ``size`` and
        ``score_col_name``, with a fresh RangeIndex. An empty ``group_dict``
        yields an empty frame with those columns. The input frames are not
        modified.
    """
    columns = ['group_cols', 'group_vals', 'size', score_col_name]

    frames = []
    for group_name, group_df in group_dict.items():
        labels = group_df.index.values
        if group_df.index.nlevels > 1:
            # MultiIndex values are tuples: render each as "v1; v2; ..."
            labels = ['; '.join(np.array(key, dtype=str)) for key in labels]

        # reset_index returns a new frame, so the caller's frame is untouched.
        flat = group_df.reset_index(drop=True)
        flat.columns = ['size', score_col_name]
        flat['group_cols'] = group_name
        flat['group_vals'] = labels
        frames.append(flat[columns])

    if not frames:
        # pd.concat([]) raises "No objects to concatenate".
        return pd.DataFrame(columns=columns)
    return pd.concat(frames).reset_index(drop=True)

def get_group_df(
    df, 
    score_fn, 
    group,
):
    x = df.groupby(group).size()
    y = df.groupby(group).apply(score_fn, include_groups=False)
    df = pd.merge(x.to_frame(), y.to_frame(), left_index=True, right_index=True)
    df.columns = ['size', 'score']
    return df

def get_group_df_all(
    df,
    groups,
    score_fn,
    score_name,
    use_tqdm=True,
):
    """Score ``df`` per group for several groupings and stack the results.

    Args:
        df: DataFrame to score.
        groups: iterable of groupings. Each grouping is either a single column
            name (str) or a list/tuple of column names.
        score_fn: callable applied to each group's rows (see ``get_group_df``).
        score_name: name of the score column in the output.
        use_tqdm: if True, wrap the loop over ``groups`` in a tqdm progress bar.

    Returns:
        The frame built by ``concat_group_dict``: one row per (grouping,
        group value), with the grouping's column names joined by '; ' in
        ``group_cols``.
    """
    per_grouping = {}
    iterator = tqdm(groups) if use_tqdm else groups
    for group in iterator:
        # Normalize to a list of column names:
        # - a bare str would otherwise be joined character by character
        #   ('; '.join('age') == 'a; g; e');
        # - pandas treats a tuple `by` as one single key, not as several columns.
        cols = [group] if isinstance(group, str) else list(group)
        per_grouping['; '.join(cols)] = get_group_df(df, score_fn, cols)
    return concat_group_dict(per_grouping, score_name)

# get uncalibrated scores by group
def get_group_summary(
    score_col_name,
    groups,
    df_calibration,
    df_test,
    asce_fn,
    brier_fn,
):
    """Compute ASCE and Brier scores per group on the calibration and test splits.

    Each (metric, split) pair is scored with ``get_group_df_all``. Two columns,
    ``split`` and ``metric``, are prepended, and the four tables are stacked
    in the order calibration/asce, test/asce, calibration/brier, test/brier.

    Returns:
        DataFrame with ``split`` and ``metric`` followed by the columns produced
        by ``get_group_df_all``, with a fresh RangeIndex.
    """
    tables = []
    for metric, metric_fn in (('asce', asce_fn), ('brier', brier_fn)):
        for split, split_df in (('calibration', df_calibration), ('test', df_test)):
            table = get_group_df_all(split_df, groups, metric_fn, score_col_name)
            # Both are inserted at position 0, so 'split' ends up before 'metric'.
            table.insert(0, 'metric', metric)
            table.insert(0, 'split', split)
            tables.append(table)
    return pd.concat(tables).reset_index(drop=True)