import numpy as np
import os
import jax
import jax.numpy as jnp
import pandas as pd
import numpyro
import numpyro.distributions as dist
import pickle
import argparse
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from jax.scipy.special import logsumexp
from nmodels import *
from semivalues import *
from copy import deepcopy
from scipy import stats
from tqdm import trange

# def get_cross_validated_likelihoods(X, y):
#     from sklearn.model_selection import KFold
#     kf = KFold(n_splits=5, shuffle=True, random_state=5)
#     likelihoods = []
#     indices = []
#     for traini, testi in kf.split(y):
#         X_train, y_train = X[traini], y[traini]
#         X_test, y_test = X[testi], y[testi]
#         test_data = {"x": X_test, "y": y_test}

#         lr = BayesianLogisticRegression(input_dim=X_train.shape[1], scale=1)
#         mcmc_lr = fit_model(lr, X_train, y_train, num_chains=4, burn_in=1000, num_samples=2000)

#         likelihoods.append(get_indiv_log_probs(mcmc_lr, test_data))
#         indices.append(testi)
#     return np.concatenate(likelihoods), np.concatenate(indices)

# Define data loading and preprocessing
def load_heart_disease(out=False, path="heart.csv", likelihood_path="heart_likelihood.npz",
                       outlier_threshold=-2):
    """Load the heart-disease table as a one-hot encoded feature matrix and labels.

    Args:
        out: If True, keep every row. If False, drop the rows whose stored
            cross-validated log-likelihood is strictly below ``outlier_threshold``.
        path: CSV file to read; it must contain a ``target`` column and the
            categorical columns listed below.
        likelihood_path: ``.npz`` archive holding the arrays ``likelihoods`` and
            ``indices`` (the rows those likelihoods belong to). Only read when
            ``out`` is False.
        outlier_threshold: Likelihood cut-off used when ``out`` is False.

    Returns:
        ``(X, y)``: ``y`` is the ``target`` column and ``X`` all remaining
        columns (numeric columns first, then the indicator columns), both as
        NumPy arrays.
    """
    # url = "https://raw.githubusercontent.com/jmrieck17/CSC-5800-Final-Project-Heart-Disease-Prediction/main/heart.csv"
    df = pd.read_csv(path)
    # drop_first=True keeps k-1 integer indicator columns for a k-level category.
    df = pd.get_dummies(
        df,
        columns=["sex", "fbs", "cp", "exang", "restecg", "slope", "thal"],
        drop_first=True,
        dtype=int,
    )
    if not out:
        # np.load on an .npz keeps the file open until closed; the context
        # manager releases it once the two arrays have been read into memory.
        with np.load(likelihood_path) as loaded_data:
            likelihoods = loaded_data["likelihoods"]
            indices = loaded_data["indices"]
        # `indices` are used as row labels of the freshly read CSV (default
        # RangeIndex, so labels == positions). Presumably they are the KFold
        # test positions from the commented-out get_cross_validated_likelihoods
        # helper — confirm if the archive is regenerated differently.
        outliers = indices[likelihoods < outlier_threshold]
        df = df.drop(outliers, axis=0)
    y = df.pop("target").to_numpy()
    X = df.to_numpy()
    # To regenerate the archive:
    # likelihoods, indices = get_cross_validated_likelihoods(X, y)
    # np.savez(likelihood_path, likelihoods=likelihoods, indices=indices)
    return X, y

def split_dataset(data, labels, proportions, random_state=0):
    """Shuffle ``(data, labels)`` jointly and cut them into consecutive splits.

    Split ``i`` (for all but the last) receives ``int(proportions[i] * len(data))``
    rows; the final split receives whatever remains. Features of every split
    are standardized with a ``StandardScaler`` fitted only on the non-final
    splits, so the final (held-out) split does not influence the scaling.

    Note: reseeds NumPy's global RNG with ``random_state``.

    Returns:
        A list of ``(scaled_features, labels)`` tuples, one per split.
    """
    np.random.seed(random_state)
    order = np.random.permutation(len(data))
    shuffled_x = data[order]
    shuffled_y = labels[order]

    n_rows = len(shuffled_x)
    cut_points = np.cumsum([int(frac * n_rows) for frac in proportions[:-1]])
    x_parts = np.split(shuffled_x, cut_points)
    y_parts = np.split(shuffled_y, cut_points)

    fit_rows = np.concatenate(x_parts[:-1])
    scaler = StandardScaler().fit(fit_rows)

    return [(scaler.transform(x_part), y_part) for x_part, y_part in zip(x_parts, y_parts)]

# Define value function generator
def generate_value_function(
    construct_model, test_data, num_chains=4, burn_in=1000, num_samples=2000,
    sample_prob=0.8, num_tests=20):
    """Build a coalition value function scored on random subsamples of one test set.

    A "prior" is obtained by running ``fit_model`` on zero training rows, and its
    per-point log-probabilities on ``test_data`` are cached once.

    The returned ``value_function(datasets, trained=None, ...)``:
      * returns ``(zeros(num_tests), prior samples)`` for an empty coalition;
      * otherwise fits (or reuses ``trained``) posterior samples on the stacked
        coalition data and, for each seed ``s`` in ``range(num_tests)``, keeps each
        test point with probability ``sample_prob`` and scores
        ``(logsumexp_draws(sum_kept posterior lp) - logsumexp_draws(sum_kept prior lp)) / n_kept``;
      * returns ``(scores, posterior_samples)``.

    The log(#draws) normalizers of the two logsumexp terms only cancel when prior
    and posterior have the same number of draws — presumably true here since both
    use the same MCMC settings; confirm when passing ``trained``.

    Note: the value function reseeds NumPy's global RNG (seeds 0..num_tests-1).
    """
    mcmc_kwargs = dict(num_chains=num_chains, burn_in=burn_in, num_samples=num_samples)

    prior_model = construct_model()
    empty_x = test_data["x"][:0]
    empty_y = test_data["y"][:0]
    prior = fit_model(prior_model, empty_x, empty_y, **mcmc_kwargs)
    # Indexing below treats this as (draws, test points).
    prior_lp = get_log_probs(prior, test_data, intermediate=True)

    def value_function(datasets, trained=None, prior=prior, prior_lp=prior_lp, test_data=test_data):
        if not datasets:
            return np.zeros((num_tests,)), prior.get_samples()

        stacked_x = np.vstack([pair[0] for pair in datasets])
        stacked_y = np.hstack([pair[1] for pair in datasets])

        coalition_model = construct_model()
        if trained is not None:
            samples = trained
        else:
            samples = fit_model(coalition_model, stacked_x, stacked_y, **mcmc_kwargs).get_samples()
        jax.clear_caches()
        curr_lp = get_log_probs(samples, test_data, model=coalition_model.model, intermediate=True)

        n_points = len(test_data["x"])
        scores = np.zeros((num_tests,))
        for seed in range(num_tests):
            # Same seed per index on every call, so all coalitions see identical subsamples.
            np.random.seed(seed)
            keep = np.random.rand(n_points) <= sample_prob
            posterior_term = logsumexp(curr_lp[:, keep].sum(axis=1), axis=0)
            prior_term = logsumexp(prior_lp[:, keep].sum(axis=1), axis=0)
            scores[seed] = (posterior_term - prior_term).item() / keep.sum()
        return scores, samples

    return value_function

def generate_value_functions(
    construct_model, validation_sets, num_chains=4, burn_in=1000, num_samples=2000):
    """Build a coalition value function scored separately on several validation sets.

    A prior is fitted once by running ``fit_model`` on zero training rows (taken
    from the first validation set to get the right feature width).

    The returned ``value_function(datasets, trained=None, ...)``:
      * returns ``(zeros(len(validation_sets)), prior_samples)`` for an empty coalition;
      * otherwise fits (or reuses ``trained``) posterior samples on the stacked
        coalition data and, for each validation set ``(X_val, y_val)``, scores
        ``(get_log_probs(posterior) - get_log_probs(prior)) / len(y_val)``
        (``.item()`` below requires the difference to be a single value);
      * returns ``(scores, posterior_samples)``.

    Raises:
        IndexError: if ``validation_sets`` is empty.
    """
    # Snapshot the container. Callers (e.g. run_experiment_no_gt) reassign
    # entries of their list after building a value function; without a copy,
    # an already-built value function would silently start scoring against the
    # new validation data through the aliased default argument below.
    validation_sets = list(validation_sets)

    model = construct_model()
    prior = fit_model(
        model, validation_sets[0][0][:0], validation_sets[0][1][:0],
        num_chains=num_chains, burn_in=burn_in, num_samples=num_samples
    )
    jax.clear_caches()

    def value_function(datasets, trained=None, prior=prior.get_samples(), validation_sets=validation_sets):
        if not datasets:
            return np.zeros((len(validation_sets),)), prior

        train_x = np.vstack([d[0] for d in datasets])
        train_y = np.hstack([d[1] for d in datasets])

        model = construct_model()
        if trained is None:
            mcmc = fit_model(
                model, train_x, train_y,
                num_chains=num_chains, burn_in=burn_in, num_samples=num_samples
            )
            mcmc = mcmc.get_samples()
        else:
            mcmc = trained
        jax.clear_caches()

        results = np.zeros((len(validation_sets),))
        for s, (X_test_subset, y_test_subset) in enumerate(validation_sets):
            test_data_subset = {"x": X_test_subset, "y": y_test_subset}
            diff_log_probs = (get_log_probs(mcmc, test_data_subset, model=model.model)
                              - get_log_probs(prior, test_data_subset, model=model.model))
            # Normalize to a per-point score so sets of different sizes are comparable.
            results[s] = diff_log_probs.item() / len(y_test_subset)
        return results, mcmc

    return value_function

def fake_datasets(dataset0, original=False):
    """Lazily yield manipulated variants of a player's ``(features, labels)`` data.

    Variants, in order:
      0. ``dataset0`` itself (only when ``original`` is True);
      1. every second sample;
      2. labels flipped (``1 - y``, so 0/1 labels assumed) where
         ``rand <= 0.05`` under seed 3;
      3. the full dataset tiled three times;
      4. every tenth sample appended again with its first two features set just
         below their column minimum, labelled with the most common label
         (``scipy.stats.mode``);
      5. features plus uniform noise ``0.2 * U[0, 1)`` under seed 3.

    Note: variants 2 and 5 reseed NumPy's global RNG when they are produced.
    """
    if original:
        yield dataset0

    features = dataset0[0]
    targets = dataset0[1]

    # Variant 1: subsample every other row.
    yield features[::2], targets[::2]

    # Variant 2: label noise.
    np.random.seed(3)
    flip = np.random.rand(len(targets)) <= .05
    noisy_targets = targets.copy()
    noisy_targets[flip] = 1 - noisy_targets[flip]
    yield features, noisy_targets

    # Variant 3: duplicate every sample.
    copies = 3
    yield np.tile(features, (copies, 1)), np.tile(targets, (copies,))

    # Variant 4: append out-of-range points carrying the majority label.
    extremes = features[::10].copy()
    extremes[:, :2] = features[:, :2].min(axis=0) - 0.1
    majority = stats.mode(targets)[0]
    padded_x = np.vstack((features, extremes))
    padded_y = np.hstack((targets, np.full(len(extremes), majority)))
    yield padded_x, padded_y

    # Variant 5: additive feature noise.
    np.random.seed(3)
    jitter = .2 * np.random.rand(*features.shape)
    yield features + jitter, targets

def run_experiment(proportions, out=False, random_state=0, save_dir="heart_results", honest=True):
    """Run the ground-truth (shared test set) valuation experiment and pickle the results.

    The data is split with ``split_dataset``; every split except the last is a
    player and the last is the shared test set. When ``honest`` is False,
    player 1 "lies" by flipping ~5% of its labels (seed 42). The Shapley
    valuation is computed once, then recomputed after each variant of player
    0's data from ``fake_datasets``.

    Side effects:
        Writes ``<save_dir>/heart[-lie]-props=..._seed=..._outliers=....pkl``
        containing, as four consecutive pickles: the stacked semivalue
        snapshots, the memo snapshots, the ``memomodel[(0,)]`` sample snapshots
        and the empty-coalition (prior) samples.
    """
    X, y = load_heart_disease(out=out)
    splits = split_dataset(X, y, proportions, random_state=random_state)
    datasets = splits[:-1]
    test_split = splits[-1]
    test_data = {"x": test_split[0], "y": test_split[1]}
    if not honest:
        # Player 1 lies: flip ~5% of its labels (0 -> 1, 1 -> 0) in place.
        np.random.seed(42)
        lying_labels = datasets[1][1]
        flip = np.random.rand(len(lying_labels)) < 0.05
        lying_labels[flip] = 1 - lying_labels[flip]

    def construct_model():
        return BayesianLogisticRegression(input_dim=X.shape[1],
                                          scale=1)

    semivalues = []
    memo = []
    first_samples = []

    def record(valuation):
        """Append a snapshot of the valuation's current state to the result lists."""
        # Copy: if the valuation updates its matrix in place, storing the
        # reference would make every snapshot identical (memo is deep-copied
        # for the same reason).
        semivalues.append(np.array(valuation.semivalues))
        memo.append(deepcopy(valuation.memo))
        first_samples.append({key: np.asarray(value) for key, value in valuation.memomodel[(0,)].items()})
        print(valuation.semivalues.mean(1))

    vf = generate_value_function(construct_model, test_data)
    valuation = ShapleyValuationM(vf, datasets)
    print("Complete Valuation")
    record(valuation)

    for d0 in fake_datasets(datasets[0]):
        valuation.update_player_dataset(0, d0)
        record(valuation)
    stacked_semivalues = np.stack(semivalues)

    os.makedirs(save_dir, exist_ok=True)
    prefix = "heart" if honest else "heart-lie"
    filename = os.path.join(
        save_dir, f"{prefix}-props={proportions}_seed={random_state}_outliers={out}.pkl")

    # Computed before opening the file so a failure does not leave a truncated pickle.
    prior = {key: np.asarray(value) for key, value in valuation.memomodel[()].items()}
    with open(filename, "wb") as f:
        pickle.dump(stacked_semivalues, f)
        pickle.dump(memo, f)
        pickle.dump(first_samples, f)
        pickle.dump(prior, f)

def run_experiment_no_gt(proportions, out=False, random_state=0, save_dir="heart_results", honest=True,
                         repeats=10, test_split=.25):
    """Run the valuation without a shared test set, scoring on held-out player data.

    For each of ``repeats`` seeds, every player's split is divided by
    ``train_test_split`` into a training subset and a validation subset
    (``test_split`` fraction). For each variant of player 0's data from
    ``fake_datasets(..., original=True)`` a value function is rebuilt on the
    current validation subsets and the Shapley semivalues are recomputed. When
    ``honest`` is False, player 1 lies exactly as in ``run_experiment``
    (~5% of its labels flipped, seed 42).

    Per variant two vectors are stored: the row sums of ``valuation.semivalues``
    ("include") and the row sums minus the diagonal ("exclude").

    Side effects:
        Writes ``<save_dir>/heart-nogt{test_split}[-lie]-props=..._seed=..._outliers=....pkl``
        holding ``(include_per_repeat, exclude_per_repeat)``; the file is
        rewritten after every repeat so partial progress is kept.
    """
    X, y = load_heart_disease(out=out)
    splits = split_dataset(X, y, proportions, random_state=random_state)
    # The final split (held-out test set) is intentionally unused here.
    datasets = splits[:-1]
    if not honest:
        # Player 1 lies: flip ~5% of its labels (0 -> 1, 1 -> 0) in place.
        # Previously honest=False only changed the output filename.
        np.random.seed(42)
        lying_labels = datasets[1][1]
        flip = np.random.rand(len(lying_labels)) < 0.05
        lying_labels[flip] = 1 - lying_labels[flip]

    def construct_model():
        return BayesianLogisticRegression(input_dim=X.shape[1],
                                          scale=1)

    os.makedirs(save_dir, exist_ok=True)
    if honest:
        filename = f"heart-nogt{test_split}-props={proportions}_seed={random_state}_outliers={out}.pkl"
    else:
        filename = f"heart-nogt{test_split}-lie-props={proportions}_seed={random_state}_outliers={out}.pkl"
    filename = os.path.join(save_dir, filename)
    print(filename)

    result_semiv_include_all, result_semiv_exclude_all = [], []
    for seed in trange(repeats):
        validation_datasets = []
        subsets = []
        for d in datasets:
            x_train, x_val, y_train, y_val = train_test_split(*d, test_size=test_split, random_state=seed)
            validation_datasets.append((x_val, y_val))
            subsets.append((x_train, y_train))

        result_semiv_include, result_semiv_exclude = [], []

        valuation = None
        for d0 in fake_datasets(datasets[0], original=True):
            # Player 0's train/validation subsets follow the current variant.
            x_train, x_val, y_train, y_val = train_test_split(*d0, test_size=test_split, random_state=seed)
            subsets[0] = x_train, y_train
            validation_datasets[0] = x_val, y_val
            vf = generate_value_functions(construct_model, validation_datasets)
            print("Generated value function")

            if valuation is None:
                valuation = ShapleyValuationM(vf, subsets)
                print("Complete Valuation")
            else:
                valuation.update_v(vf)
                valuation.update_player_dataset(0, subsets[0])

            # np.diag below treats semivalues as a square matrix — presumably
            # players x players; confirm against ShapleyValuationM.
            result_semiv_include.append(valuation.semivalues.sum(1))
            result_semiv_exclude.append(valuation.semivalues.sum(1) - np.diag(valuation.semivalues))

        result_semiv_include_all.append(np.stack(result_semiv_include))
        result_semiv_exclude_all.append(np.stack(result_semiv_exclude))

        # Checkpoint: rewrite the results after every repeat.
        with open(filename, "wb") as f:
            pickle.dump((result_semiv_include_all, result_semiv_exclude_all), f)


if __name__ == "__main__":
    proportions = [0.3, 0.2, 0.2, 0.3]
    include_outliers = False
    # Ground-truth experiment: once with honest players, once with player 1 lying.
    for is_honest in (True, False):
        run_experiment(proportions=proportions, out=include_outliers, random_state=0, honest=is_honest)

    # Validation-only (no ground truth) variant, disabled by default:
    # for split in [.25, .5]:
    #     run_experiment_no_gt(proportions=proportions, out=include_outliers,
    #                          random_state=0, honest=True, test_split=split)


   

