import numpy as np
import pandas as pd
from nmodels import *
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
import pickle


def subsets_gen(dataset0=None, seed=0):
    """Yield growing random prefixes of a dataset.

    Called without a dataset, returns the x-axis description used for
    plotting: a label and the five fractions ``np.linspace(0.2, 1., 5)``.

    Called with ``dataset0 = (X, y)``, seeds NumPy's global RNG with ``seed``,
    shuffles the rows once, and returns a generator that yields
    ``(X_prefix, y_prefix)`` for five cut-offs taken from
    ``np.linspace(0, len(y), 6, dtype=int)[1:]`` (progress shown via tqdm).
    """
    if dataset0 is None:
        return "Fraction of validation set", np.linspace(0.2, 1., 5)

    features, labels = dataset0[0], dataset0[1]
    n_rows = len(labels)
    # Seeds the *global* legacy RNG; this side effect is visible to callers.
    np.random.seed(seed)
    order = np.random.permutation(n_rows)
    shuffled_X = features[order]
    shuffled_y = labels[order]
    cutoffs = np.linspace(0, n_rows, 6, dtype=int)[1:]

    def _prefixes():
        for stop in tqdm(cutoffs):
            yield shuffled_X[:stop], shuffled_y[:stop]

    return _prefixes()
    
def sorted_subsets_gen(dataset0=None, seed=0):
    """Yield growing prefixes of a dataset sorted by a random column order.

    Called without a dataset, returns the x-axis description used for
    plotting: a label and the five fractions ``np.linspace(0.2, 1., 5)``.

    Called with ``dataset0 = (X, y)``, seeds NumPy's global RNG with ``seed``,
    draws a random permutation of the feature columns, lexicographically sorts
    the rows by those columns, and returns a generator that yields
    ``(X_prefix, y_prefix)`` for five cut-offs taken from
    ``np.linspace(0, len(y), 6, dtype=int)[1:]`` (progress shown via tqdm).
    """
    if dataset0 is None:
        return "Fraction of validation set", np.linspace(0.2, 1., 5)

    features, labels = dataset0[0], dataset0[1]
    n_rows = len(labels)
    # Seeds the *global* legacy RNG; this side effect is visible to callers.
    np.random.seed(seed)
    key_columns = np.random.permutation(features.shape[1])
    # np.lexsort treats the LAST key in the sequence as the primary sort key.
    sort_keys = [features[:, col] for col in key_columns]
    order = np.lexsort(sort_keys)
    sorted_X = features[order]
    sorted_y = labels[order]
    cutoffs = np.linspace(0, n_rows, 6, dtype=int)[1:]

    def _prefixes():
        for stop in tqdm(cutoffs):
            yield sorted_X[:stop], sorted_y[:stop]

    return _prefixes()

def flips_gen(dataset0=None, seed=0):
    """Yield copies of a dataset with an increasing fraction of labels flipped.

    Called without a dataset, returns the plotting x-axis description: a label
    and the six noise rates ``np.linspace(0, .05, 6)``.

    Called with ``dataset0 = (X, y)``, returns a generator that, for each noise
    rate ``r``, yields ``(X, y_r)``. Here ``y_r`` is a copy of ``y`` in which
    every position whose pre-drawn uniform value is ``<= r`` gets a different
    class, drawn uniformly from the other classes present in ``y``. The uniform
    mask is drawn once, so the set of corrupted positions grows monotonically
    with ``r``. With more than two classes, the replacement label at a given
    position is re-drawn at every rate.

    All randomness comes from a private ``np.random.RandomState(seed)``. Its
    draw sequence matches what seeding the global RNG with ``seed`` would
    give. The yielded labels therefore no longer depend on whether the
    consumer uses ``np.random`` between iterations, and the global RNG state is
    left untouched.
    """
    if dataset0 is None:
        return "Fraction of labels flipped", np.linspace(0, .05, 6)

    X, y = dataset0[0], dataset0[1]
    rng = np.random.RandomState(seed)
    mask = rng.rand(len(y))
    classes, class_idx = np.unique(y, return_inverse=True)
    n_classes = len(classes)

    def generator():
        for noise_rate in tqdm(np.linspace(0, .05, 6)):
            replace_mask = mask <= noise_rate
            # Shift each selected label by 1..n_classes-1 (mod n_classes) so the
            # new label is guaranteed to differ from the original one.
            shift = rng.choice(n_classes - 1, np.count_nonzero(replace_mask)) + 1
            noisy_idx = (class_idx[replace_mask] + shift) % n_classes
            new_y = y.copy()
            new_y[replace_mask] = classes[noisy_idx]
            yield X, new_y

    return generator()
   
def noise_gen(dataset0=None, seed=0):
    """Yield copies of a dataset with increasing Gaussian noise added to ``y``.

    Called without a dataset, returns the plotting x-axis description: a label
    and the six noise scales ``np.linspace(0, .25, 6)``.

    Called with ``dataset0 = (X, y)``, returns a generator that, for each scale
    ``s``, yields ``(X, y + N(0, s))`` with independent noise of ``y``'s shape.

    Noise is drawn from a private ``np.random.RandomState(seed)``. Its
    sequence matches what seeding the global RNG with ``seed`` would give, but
    it no longer depends on whether the consumer uses ``np.random`` between
    iterations. The global RNG state is left untouched.
    """
    if dataset0 is None:
        return "Additional noise", np.linspace(0, .25, 6)

    X, y = dataset0[0], dataset0[1]
    rng = np.random.RandomState(seed)

    def generator():
        for scale in tqdm(np.linspace(0, .25, 6)):
            yield X, y + rng.normal(0, scale, size=y.shape)

    return generator()


def check_diff_tests(generator, test_data, model, first_samples, prior_samples, num_seeds=10):
    """Score each strategy's samples against perturbed versions of the test set.

    For every seed in ``range(num_seeds)`` and every dataset yielded by
    ``generator((test_data["x"], test_data["y"]), seed=seed)``, compute for
    each entry of ``first_samples`` its ``get_log_probs`` value on that data
    minus the value obtained with ``prior_samples`` (both using
    ``model.model``), divided by the number of labels in the yielded data.

    Returns:
        A tuple ``(results, num_seeds, len(first_samples))``. ``results`` is a
        list holding one float array of length ``len(first_samples)`` per
        (seed, yielded dataset) pair, in iteration order.
    """
    n_strategies = len(first_samples)
    scores = []
    for seed in range(num_seeds):
        perturbed = generator((test_data["x"], test_data["y"]), seed=seed)
        for batch in perturbed:
            X_eval, y_eval = batch[0], batch[1]
            row = np.zeros(n_strategies)
            baseline = get_log_probs(prior_samples, {"x": X_eval, "y": y_eval}, model=model.model)
            for idx, samples in enumerate(first_samples):
                row[idx] = get_log_probs(samples, {"x": X_eval, "y": y_eval}, model=model.model) - baseline
                # Normalize to a per-label value.
                row[idx] /= len(y_eval)
            scores.append(row)
    return scores, num_seeds, n_strategies

def plot_diff_tests(results, num_seeds, num_strategies, x_label, x_ax, prepend, xticks=None, append="", ylabel=r'$v(D_0)$', save=False, ncol=2, legend=True):
    """Plot one line per strategy of the scores produced by ``check_diff_tests``.

    ``results`` is expected to hold ``len(x_ax) * num_seeds`` arrays of length
    ``num_strategies``, ordered seed-major and then by x value, as
    ``check_diff_tests`` returns them. Strategy indices 0-4 are relabelled
    T, S, N, D, I. The figure is saved to
    ``_graphs/{prepend}-{append or x_label}.pdf`` when ``save`` is True;
    otherwise it is shown.
    """
    n_points = len(x_ax)
    x_values = np.tile(np.repeat(x_ax, num_strategies), num_seeds)
    strategy_ids = np.tile(np.arange(num_strategies), n_points * num_seeds)
    df = pd.DataFrame({
        x_label: x_values,
        'DVF': np.concatenate(results),
        'Strategy': strategy_ids,
    })
    df['Strategy'] = df['Strategy'].replace({0: "T", 1: "S", 2: "N", 3: "D", 4: "I"})

    sns.set(font_scale=2.6)
    sns.set_style("ticks")
    # Embed TrueType fonts in the PDF output.
    plt.rcParams.update({'pdf.fonttype': 42})
    plt.figure(figsize=(8, 8))

    # First strategy drawn solid, the remaining ones dashed.
    line_dashes = [""] + [(6, 8)] * (num_strategies - 1)
    ax = sns.lineplot(
        data=df, x=x_label, y="DVF", hue="Strategy", errorbar=("ci", 95), marker="o", style="Strategy",
        palette=sns.color_palette("husl", n_colors=num_strategies), markeredgewidth=1, markersize=20,
        err_kws={"alpha": 0.1}, dashes=line_dashes, markers=['o', 's', 'd', 'P', '>'],
        markeredgecolor='gray', lw=3, legend=legend,
    )

    # Earlier lines are drawn on top of later ones.
    for rank, line in enumerate(ax.lines):
        plt.setp(line, zorder=10 - rank)
    plt.ylabel(ylabel, fontweight="normal")
    if legend:
        plt.legend(ncol=ncol, title="Strategy", columnspacing=1.0)
    plt.xticks(xticks)
    plt.tight_layout()

    suffix = append if append else x_label
    if not save:
        plt.show()
        return
    plt.savefig('_graphs/{}-{}.pdf'.format(prepend, suffix), bbox_inches='tight', dpi=300)

def plot_semi(semivalues, prepend, append="", shadow=False, ylabel=None, save=True):
    """
    Plot per-strategy, per-source semivalues as a strip plot with mean markers.

    Intended for the first value from pickle.load in all run_experiments(...).

    Args:
        semivalues: rank-3 array indexed as (strategy, source, value), or a
            list/tuple of rank-2 (source, value) arrays that is stacked along a
            new leading strategy axis. The meaning of the last axis is
            presumably repeated estimates per source; confirm with the producer.
        prepend: filename prefix used when saving.
        append: text appended after "-shapley" in the saved filename.
        shadow: when True, no legend is drawn.
        ylabel: optional y-axis label.
        save: if True, save to ``_graphs/{prepend}-shapley{append}.pdf``;
            otherwise call ``plt.show()``.
    """
    # isinstance (not `type(...) == list`) so tuples and list subclasses work too.
    if isinstance(semivalues, (list, tuple)):
        semivalues = np.stack(semivalues)
    n_strategies, n_sources, n_values = semivalues.shape
    palette = sns.color_palette("rocket", n_sources)
    # Source 0 is drawn fully opaque; every other source at half opacity.
    palette = [(r, g, b, 1.0) if i == 0 else (r, g, b, 0.5) for i, (r, g, b) in enumerate(palette)]

    # Long-form table: one row per (strategy, source, value) entry, in C order
    # to match semivalues.ravel().
    data = {
        'Strategy': np.repeat(np.arange(n_strategies), n_sources * n_values),
        'Source': np.tile(np.repeat(np.arange(n_sources), n_values), n_strategies),
        'Shapley Value': semivalues.ravel()
    }

    df = pd.DataFrame(data)
    strategy_map = {0: "T", 1: "S", 2: "N", 3: "D", 4: "I"}
    df['Strategy'] = df['Strategy'].replace(strategy_map)

    sns.set(font_scale=2.6)
    sns.set_style("ticks")
    # Embed TrueType fonts in the PDF output.
    plt.rcParams.update({
        'pdf.fonttype': 42
    })
    plt.figure(figsize=(8,8))

    # Individual values as faint crosses, separated per source by `dodge`.
    sns.stripplot(data=df,
        x="Strategy", y="Shapley Value", hue="Source", jitter=False,
        dodge=True, legend=False, palette=palette, size=12, alpha=.3, marker="X"
    )
    # Per-source mean (pointplot's default estimator) as a horizontal bar.
    sns.pointplot(
        data=df, x="Strategy", y="Shapley Value", hue="Source",
        dodge=.56, linestyle="none", errorbar=None,
        marker="_", markersize=30, markeredgewidth=6, palette=palette, legend=not shadow, alpha=1,
    )

    # Reference line at the mean of strategy 0 / source 0.
    plt.axhline(y=semivalues[0][0].mean(), color='grey', linestyle=(0,(5,5)), linewidth=2, alpha=0.8)
    if not shadow:
        plt.legend(title="Source", handletextpad=.3, labelspacing=.2, borderpad=.2)

    if ylabel:
        plt.ylabel(ylabel, fontweight="normal")

    plt.tight_layout()
    if save:
        plt.savefig('_graphs/{}-shapley{}.pdf'.format(prepend, append), bbox_inches='tight', dpi=300)
    else:
        plt.show()

from semivalues import BetaShapleyValuation,  IndivValuation

def compute_beta_semi(memo, datasets, a=16, b=1):
    """
    Compute Beta(a, b)-Shapley semivalues for every entry of ``memo``.

    ``memo`` can be obtained as the second value from pickle.load in all
    run_experiments(...). If there are n parties, ``np.arange(n)`` can simply be
    passed as ``datasets``.

    Each memo entry is handed to ``BetaShapleyValuation`` together with a
    utility function that always returns 0. Presumably the utility is unused
    because the values come from the memo; confirm against
    semivalues.BetaShapleyValuation.

    Returns:
        A list of the ``.semivalues`` attribute of each valuation, in the
        order of ``memo``.
    """
    return [
        BetaShapleyValuation(lambda *x: 0, datasets, memo=entry, a=a, b=b).semivalues
        for entry in memo
    ]

if __name__ == "__main__":
    # Load saved experiment results. The three pickle.load calls must stay in
    # this order because they read consecutive objects from the same file.
    with open("heart_results/heart-props=[0.3, 0.2, 0.2, 0.3]_seed=0_outliers=False.pkl", "rb") as f:
        semivalues = pickle.load(f)
        memo = pickle.load(f)
        first_samples = pickle.load(f)

    #plot_semi(semivalues)
    # Beta-Shapley semivalues over 3 parties. These are presumably the first
    # three of the four splits below, since the last split is the test set;
    # confirm.
    beta_semi = compute_beta_semi(memo, np.arange(3))

    from heart import load_heart_disease, split_dataset
    # NOTE(review): redundant with the module-level `from nmodels import *`.
    from nmodels import *
    X, y = load_heart_disease(out=False)
    splits = split_dataset(X, y, [0.3, 0.2, 0.2, 0.3], random_state=0)
    # The last split serves as held-out test data for the diff tests.
    test_split = splits[-1]
    test_data = {"x": test_split[0], "y": test_split[1]}

    def construct_model():
        return BayesianLogisticRegression(input_dim=X.shape[1], scale=1)

    model = construct_model()
    # Fit on empty slices ([:0]), i.e. with no observations, so the resulting
    # samples presumably represent the prior. They are the baseline in
    # check_diff_tests.
    prior = fit_model(
        model, test_data["x"][:0], test_data["y"][:0], num_chains=4, burn_in=1000, num_samples=2000
    )

    # Score each saved strategy's samples on growing sorted prefixes of the
    # test set, relative to the prior samples.
    sresults, num_seeds, num_strategies = check_diff_tests(sorted_subsets_gen, test_data, model, 
                            first_samples, prior.get_samples())
    # plot_diff_tests(sresults, num_seeds, num_strategies,
                    # *sorted_subsets_gen(), prepend="heart", append="vary-sorted", save=False, xticks=np.linspace(.2, 1., 5))

