import numpy as np
import pandas as pd

from bound import brute_force_optimal
from mechanisms import *


# --- Experiment size --------------------------------------------------------
# TRIALS is handed to the *_tests drivers as the number of repetitions per
# configuration; N_RANGE is the set of population sizes they iterate over.
TRIALS = 25
N_RANGE = range(10, 60, 10)

# --- Samplers centred on 0.5 ------------------------------------------------
# Each sampler takes a sample count and returns that many draws.
UNIFORM = np.random.rand
BETA_SYM = lambda n: np.random.beta(5, 5, n)
TRI_SYM = lambda n: np.random.triangular(0, 0.5, 1, n)
SYM_DISTS = [
    UNIFORM,
    BETA_SYM,
    TRI_SYM,
]
SYM_DIST_LABELS = [
    "uniform",
    "beta_sym",
    "triangle_sym",
]

# --- Samplers skewed towards 0 ----------------------------------------------
BETA_ASYM = lambda n: np.random.beta(1, 9, n)
TRI_ASYM = lambda n: np.random.triangular(0, 0, 1, n)
ASYM_DISTS = [
    BETA_ASYM,
    TRI_ASYM,
]
ASYM_DIST_LABELS = [
    "beta_asym",
    "triangle_asym",
]

# Column header of the CSVs written for the equal-k and different-k runs.
COLUMN_NAMES = [
    "k_val",
    "dist",
    "n",
    "mech",
    "sw_mean",
    "sw_std",
]

# --- k1 = k2 runs -----------------------------------------------------------
ALPHA_VALS = [0.1, 0.2, 0.3, 0.4]
MECHS_EQUAL = [
    brute_force_optimal,
    esp_worst,
    esp_equal_k_best,
]
MECH_EQUAL_LABELS = [
    "opt",
    "esp_sym_worst",
    "esp_sym_best",
]

# --- k1 > k2 runs -----------------------------------------------------------
ALPHA_PAIRS = [
    (0.4, 0.3),
    (0.6, 0.2),
    (0.7, 0.1),
]
MECHS_DIFF = [
    brute_force_optimal,
    esp_worst,
    esp_diff_k_best,
]
MECH_DIFF_LABELS = [
    "opt",
    "esp_asym_worst",
    "esp_asym_best",
]

# --- Mixed-distribution runs ------------------------------------------------
MIXED_COLUMNS = [
    "ratios",
    "alpha",
    "n",
    "mech",
    "sw_mean",
    "sw_std",
]
ALPHAS_MIXED = [0.1, 0.2, 0.3]
MECHS_MIXED = [
    brute_force_optimal,
    esp_mixed_equal_k_best,
]
MIXED_LABELS = [
    "opt",
    "esp_mixed_best",
]

# Share of the population drawn from each of the three component samplers.
DIST_RATIOS_SYM = [
    (0.1, 0.3, 0.6),
    (0.2, 0.5, 0.3),
    (0.3, 0.4, 0.3),
]
DIST_RATIOS_ASYM = [
    (0.1, 0.3, 0.6),
    (0.2, 0.5, 0.3),
    (0.3, 0.4, 0.3),
]


def run_all_tests():
    """Run the enabled experiment suites and write one CSV per suite.

    The section banners are always printed; each suite only runs when its
    switch below is set to True.
    """
    # Flip these switches to choose which suites run.
    run_equal_sym = False
    run_equal_asym = False
    run_diff_sym = False
    run_diff_asym = False
    run_mixed_sym = False
    run_mixed_asym = True

    print("k1 = k2  " + "=" * 50)
    if run_equal_sym:
        print("Running symmetric tests...")
        data = equal_tests(TRIALS, ALPHA_VALS, SYM_DISTS, SYM_DIST_LABELS,
                           MECHS_EQUAL, MECH_EQUAL_LABELS)
        save_data(data, COLUMN_NAMES, "results/equal_symmetric.csv")

    if run_equal_asym:
        print("Running asymmetric tests...")
        data = equal_tests(TRIALS, ALPHA_VALS, ASYM_DISTS, ASYM_DIST_LABELS,
                           MECHS_EQUAL, MECH_EQUAL_LABELS)
        save_data(data, COLUMN_NAMES, "results/equal_asymmetric.csv")

    print("k1 > k2  " + "=" * 50)
    if run_diff_sym:
        print("Running symmetric tests...")
        data = diff_tests(TRIALS, ALPHA_PAIRS, SYM_DISTS, SYM_DIST_LABELS,
                          MECHS_DIFF, MECH_DIFF_LABELS)
        save_data(data, COLUMN_NAMES, "results/diff_symmetric.csv")

    if run_diff_asym:
        print("Running asymmetric tests...")
        data = diff_tests(TRIALS, ALPHA_PAIRS, ASYM_DISTS, ASYM_DIST_LABELS,
                          MECHS_DIFF, MECH_DIFF_LABELS)
        save_data(data, COLUMN_NAMES, "results/diff_asymmetric.csv")

    print("Mixed k1 = k2  " + "=" * 50)
    if run_mixed_sym:
        print("Running mixed sym tests...")
        data = mixed_tests(TRIALS, ALPHAS_MIXED, mixed_dist_sym, DIST_RATIOS_SYM,
                           MECHS_MIXED, MIXED_LABELS)
        save_data(data, MIXED_COLUMNS, "results/equal_mixed_sym.csv")

    if run_mixed_asym:
        print("Running mixed asym tests...")
        # The asymmetric mixture uses its own ratio table. It previously
        # borrowed DIST_RATIOS_SYM, so edits to DIST_RATIOS_ASYM had no effect.
        data = mixed_tests(TRIALS, ALPHAS_MIXED, mixed_dist_asym, DIST_RATIOS_ASYM,
                           MECHS_MIXED, MIXED_LABELS)
        save_data(data, MIXED_COLUMNS, "results/equal_mixed_asym.csv")


def equal_tests(trials, alpha_vals, dists, dist_labels, mechs, mech_labels):
    """Collect score statistics for runs where k1 and k2 are the same value.

    For every alpha, every (sampler, label) pair and every n in N_RANGE,
    alpha is converted to k with alpha_to_k and all mechanisms are scored
    over ``trials`` repetitions.

    Returns:
        A list of rows ``[alpha, dist_label, n, mech_label, mean, std]``,
        one per mechanism and configuration.
    """
    rows = []
    for alpha in alpha_vals:
        for sampler, dist_label in zip(dists, dist_labels):
            for n in N_RANGE:
                print(f"\t{alpha}  {dist_label}  {n}")
                budget = alpha_to_k(alpha, n)
                per_mech = test(budget, budget, n, sampler, mechs, trials)
                rows.extend(
                    [alpha, dist_label, n, name, result.mean(), result.std()]
                    for name, result in zip(mech_labels, per_mech)
                )
    return rows


def diff_tests(trials, alpha_pairs, dists, dist_labels, mechs, mech_labels):
    """Collect score statistics for runs where k1 and k2 come from two alphas.

    For every (a1, a2) pair, every (sampler, label) pair and every n in
    N_RANGE, each alpha is converted with alpha_to_k and all mechanisms are
    scored over ``trials`` repetitions.

    Returns:
        A list of rows ``[(a1, a2), dist_label, n, mech_label, mean, std]``,
        one per mechanism and configuration.
    """
    rows = []
    for a1, a2 in alpha_pairs:
        for sampler, dist_label in zip(dists, dist_labels):
            for n in N_RANGE:
                print(f"\t{a1},{a2}  {dist_label}  {n}")
                per_mech = test(
                    alpha_to_k(a1, n), alpha_to_k(a2, n),
                    n, sampler, mechs, trials,
                )
                rows += [
                    [(a1, a2), dist_label, n, name, result.mean(), result.std()]
                    for name, result in zip(mech_labels, per_mech)
                ]
    return rows


def mixed_tests(trials, alphas, dist, ratios, mechs, mech_labels):
    """Collect score statistics for a mixture sampler with k1 equal to k2.

    ``dist`` is called as ``dist(n, ratio)``. For every ratio, alpha and n in
    N_RANGE, alpha is converted to k with alpha_to_k and all mechanisms are
    scored over ``trials`` repetitions.

    Returns:
        A list of rows ``[ratio, alpha, n, mech_label, mean, std]``, one per
        mechanism and configuration.
    """
    rows = []
    for ratio in ratios:
        # Adapter giving the mixture sampler the one-argument shape test() calls.
        def draw(size):
            return dist(size, ratio)

        for alpha in alphas:
            for n in N_RANGE:
                print(f"\t{alpha}  {ratio}  {n}")
                budget = alpha_to_k(alpha, n)
                per_mech = test(budget, budget, n, draw, mechs, trials)
                rows.extend(
                    [ratio, alpha, n, name, result.mean(), result.std()]
                    for name, result in zip(mech_labels, per_mech)
                )
    return rows


def alpha_to_k(alpha, n):
    """Return ``alpha * n`` rounded to an integer, but never less than 1.

    Rounding uses the built-in ``round``, so exact halves go to the nearest
    even integer.
    """
    k = round(alpha * n)
    return k if k >= 1 else 1


def test(k1, k2, n, get_x, mechanisms, trials):
    """Score every mechanism on the same freshly sampled input, ``trials`` times.

    Args:
        k1, k2: Passed through unchanged as the second and third arguments of
            each mechanism.
        n: Passed to ``get_x`` once per trial.
        get_x: Callable ``get_x(n)`` returning the input for one trial.
        mechanisms: Sequence of callables ``m(x, k1, k2)`` returning a scalar.
        trials: Number of repetitions.

    Returns:
        A list with one float array of length ``trials`` per mechanism, in
        the order of ``mechanisms``; entry ``i`` is the score on trial ``i``.
    """
    # np.empty is the documented way to get an uninitialised buffer; every
    # slot is overwritten in the loop below.
    scores = [np.empty(trials) for _ in mechanisms]
    for i in range(trials):
        # One sample per trial, shared by all mechanisms so they are compared
        # on identical input.
        x = get_x(n)
        for j, m in enumerate(mechanisms):
            scores[j][i] = m(x, k1, k2)
    return scores


# The name matches pytest's default "test*" pattern; without this marker the
# helper is collected (and errors on missing fixtures) whenever it is imported
# into a test module.
test.__test__ = False


def save_data(data, cols, filename):
    """Write ``data`` to ``filename`` as CSV with ``cols`` as the header row.

    Args:
        data: Sequence of rows, each with one value per column.
        cols: Column names.
        filename: Destination path. Missing parent directories are created.
            Anything that is not a path (e.g. an open buffer) is handed to
            pandas untouched.
    """
    import os  # local import so this function does not depend on the file header

    print("Writing data...")
    # Create the target directory up front: failing on a missing "results/"
    # folder would throw away the results of a finished run.
    if isinstance(filename, (str, os.PathLike)):
        parent = os.path.dirname(os.fspath(filename))
        if parent:
            os.makedirs(parent, exist_ok=True)
    df = pd.DataFrame(data, columns=cols)
    df.to_csv(filename, index=False)


def mixed_dist_sym(n, ratios):
    """Draw exactly ``n`` values from a mixture of UNIFORM, BETA_SYM and TRI_SYM.

    Args:
        n: Total number of samples to return.
        ratios: Shares of the three components, in the order uniform, beta,
            triangular. Expected to sum to 1; the third share is implied by
            the first two.

    Returns:
        A 1-D array of length ``n``: the uniform draws, then the beta draws,
        then the triangular draws (not shuffled).
    """
    n_uniform = round(n * ratios[0])
    n_beta = round(n * ratios[1])
    # Give the last component the remainder. Rounding all three shares
    # independently can add up to n - 1 or n + 1, which would hand the
    # mechanisms a population of a different size than the recorded n.
    n_tri = n - n_uniform - n_beta
    return np.concatenate([
        UNIFORM(n_uniform),
        BETA_SYM(n_beta),
        TRI_SYM(n_tri),
    ])

def mixed_dist_asym(n, ratios):
    """Draw exactly ``n`` values from a mixture of UNIFORM, BETA_ASYM and TRI_ASYM.

    Args:
        n: Total number of samples to return.
        ratios: Shares of the three components, in the order uniform, beta,
            triangular. Expected to sum to 1; the third share is implied by
            the first two.

    Returns:
        A 1-D array of length ``n``: the uniform draws, then the beta draws,
        then the triangular draws (not shuffled).
    """
    n_uniform = round(n * ratios[0])
    n_beta = round(n * ratios[1])
    # Give the last component the remainder. Rounding all three shares
    # independently can add up to n - 1 or n + 1, which would hand the
    # mechanisms a population of a different size than the recorded n.
    n_tri = n - n_uniform - n_beta
    return np.concatenate([
        UNIFORM(n_uniform),
        BETA_ASYM(n_beta),
        TRI_ASYM(n_tri),
    ])

if __name__ == "__main__":
    run_all_tests()
