import numpy as np
import os
import jax
import jax.numpy as jnp
import pandas as pd
import numpyro
import numpyro.distributions as dist
import pickle
import time
import argparse
from sklearn.model_selection import train_test_split
from jax.scipy.special import logsumexp
from nmodels import *
from semivalues import *
from copy import deepcopy
from tqdm import trange

# def generate_pkl():
#     import torch
#     import torchvision
#     import torchvision.transforms as transforms
#     from medmnist import BloodMNIST
#     from torch.utils.data import DataLoader
#     from sklearn.decomposition import PCA
#     from sklearn.model_selection import KFold

#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#     transform = transforms.Compose([
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize for color images
#     ])

#     train_dataset = BloodMNIST(split='train', transform=transform, download=True)
#     test_dataset = BloodMNIST(split='test', transform=transform, download=True)
#     train_loader = DataLoader(train_dataset, batch_size=64, shuffle=False)
#     test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

#     resnet18 = torchvision.models.resnet18(pretrained=True).to(device)
#     resnet18.fc = torch.nn.Identity()

#     def extract_embeddings(data_loader, model, device):
#         model.eval()  
#         embeddings = []
#         labels = []
#         with torch.no_grad():
#             for images, targets in data_loader:
#                 images = images.to(device)
#                 features = model(images)
#                 embeddings.append(features.cpu().numpy())
#                 labels.append(targets.numpy())
#         embeddings = np.vstack(embeddings)
#         labels = np.vstack(labels).squeeze()
#         return embeddings, labels
    
#     X_train, y_train = extract_embeddings(train_loader, resnet18, device)
#     X_test, y_test = extract_embeddings(test_loader, resnet18, device)

#     pca = PCA(n_components=50, whiten=True)
#     X_train = pca.fit_transform(X_train)
#     X_test = pca.transform(X_test)

#     def get_cross_validated_likelihoods(X, y):
#         kf = KFold(n_splits=4, shuffle=True, random_state=5)
#         likelihoods = []
#         indices = []
#         for traini, testi in kf.split(y):
#             X_train, y_train = X[traini], y[traini]
#             X_test, y_test = X[testi], y[testi]
#             test_data = {"x": X_test, "y": y_test}

#             lr =  BayesianLogisticRegressionMulti(input_dim=X_train.shape[1], output_dim=8, scale=1)
#             mcmc_lr = fit_model(lr, X_train, y_train, num_chains=4, burn_in=1000, num_samples=2000)

#             likelihoods.append(get_indiv_log_probs(mcmc_lr, test_data))
#             indices.append(testi)
#         return np.concatenate(likelihoods), np.concatenate(indices)
    
#     likelihoods, indices = get_cross_validated_likelihoods(X_train, y_train)
#     inliers = indices[likelihoods > -1.6]
#     X_train, y_train = X_train[inliers], y_train[inliers]

#     lr = BayesianLogisticRegressionMulti(input_dim=X_train.shape[1], output_dim=8, scale=1)
#     mcmc_lr = fit_model(lr, X_train, y_train, num_chains=4, burn_in=1000, num_samples=2000)
#     test_probs = get_indiv_log_probs(mcmc_lr, {"x": X_test, "y": y_test})
#     test_mask = test_probs > -1.6
#     X_test, y_test = X_test[test_mask], y_test[test_mask]

#     with open("bloodmnist.pkl", "wb") as f:
#         pickle.dump((X_train, y_train), f)
#         pickle.dump((X_test, y_test), f)

# Define data loading and preprocessing
def load_data(filename):
    """Read the train and test splits from a pickle file.

    The file is expected to hold two consecutive pickled ``(X, y)`` pairs:
    the training pair first, then the test pair.

    NOTE: ``pickle`` can execute arbitrary code while loading; only point
    this at files you produced yourself.

    Returns:
        ``(X_train, y_train, test_data)`` where ``test_data`` is the dict
        ``{"x": X_test, "y": y_test}``.
    """
    with open(filename, "rb") as handle:
        train_pair = pickle.load(handle)
        held_out_pair = pickle.load(handle)

    X_train, y_train = train_pair
    X_test, y_test = held_out_pair
    return X_train, y_train, {"x": X_test, "y": y_test}

def split_dataset(data, labels, proportions, random_state=0):
    """Shuffle ``(data, labels)`` with class-dependent weights and cut into parts.

    The points are drawn without replacement using a per-class weight, so
    points of a highly weighted class tend to be drawn earlier and therefore
    land in the earlier parts. This makes the parts unevenly distributed
    over the classes.

    Args:
        data: array indexable by an integer index array along its first axis.
        labels: integer class ids in ``0..7`` (they index the 8-entry weight
            table below), one per row of ``data``.
        proportions: fraction of the points per part; the last entry is not
            read, the final part simply receives the remainder.
        random_state: seed passed to ``np.random.seed``. Note that this
            reseeds NumPy's *global* generator.

    Returns:
        A list of ``(data_part, labels_part)`` tuples, one per proportion.
    """
    np.random.seed(random_state)

    # Relative draw weight of each of the 8 classes (uneven on purpose).
    class_weight = np.array([.2,1,1.2,1,1,.2,1,1])
    draw_prob = class_weight[labels]
    draw_prob = draw_prob / draw_prob.sum()

    n_points = len(labels)
    order = np.random.choice(n_points, n_points, replace=False, p=draw_prob)
    shuffled_x = data[order]
    shuffled_y = labels[order]

    # Cut positions are the running totals of all part sizes except the last.
    boundaries = np.cumsum(
        [int(fraction * len(shuffled_x)) for fraction in proportions[:-1]]
    )
    x_parts = np.split(shuffled_x, boundaries)
    y_parts = np.split(shuffled_y, boundaries)
    return list(zip(x_parts, y_parts))

# Define value function generator
def generate_value_function(
    construct_model, test_data, num_chains=4, burn_in=1000, num_samples=2000,
    sample_prob=0.8, num_tests=20):
    """Build a value function that scores training data on random test subsets.

    Args:
        construct_model: zero-argument callable returning a fresh model; the
            returned object must expose a ``.model`` attribute.
        test_data: dict with keys ``"x"`` and ``"y"`` holding the test set.
        num_chains, burn_in, num_samples: forwarded to ``fit_model``.
        sample_prob: probability with which each test point is kept in one
            random test subset.
        num_tests: number of random test subsets, i.e. the length of the
            score vector the value function returns.

    Returns:
        ``value_function(datasets, trained=None, ...)``; see its docstring.
    """

    model = construct_model()
    # Fit on zero-row slices of the test data, i.e. without any observations.
    # NOTE(review): presumably this yields draws from the prior -- confirm
    # against fit_model.
    prior = fit_model(
        model, test_data["x"][:0], test_data["y"][:0], num_chains=num_chains, burn_in=burn_in, num_samples=num_samples
    )
    # 2-D array, indexed below as [:, test point] (first axis presumably the
    # posterior draws -- confirm against get_log_probs).
    prior_lp = get_log_probs(prior, test_data, intermediate=True)

    # The keyword defaults are evaluated once, right here. Inside the function
    # `prior` is therefore the result of prior.get_samples(), not the fitted
    # object of the enclosing scope.
    def value_function(datasets, trained=None, prior=prior.get_samples(), prior_lp=prior_lp, test_data=test_data):
        """Score the union of ``datasets`` against the no-data baseline.

        Args:
            datasets: sequence of ``(X, y)`` pairs that are stacked into one
                training set. An empty sequence returns the baseline.
            trained: samples to reuse instead of fitting; when given, no
                model is fitted.

        Returns:
            ``(results, samples)``: ``results`` has shape ``(num_tests,)``;
            ``samples`` is what was used for scoring (the baseline samples
            and an all-zero score vector when ``datasets`` is empty).
        """
        if not datasets:
            return np.zeros((num_tests,)), prior

        train_x = np.vstack([d[0] for d in datasets])
        train_y = np.hstack([d[1] for d in datasets])

        model = construct_model()
        if trained is None: 
            mcmc = fit_model(
                model, train_x, train_y,
                num_chains=num_chains, burn_in=burn_in, num_samples=num_samples
            )
            mcmc = mcmc.get_samples()
        else:
            mcmc = trained
        jax.clear_caches()
        curr_lp = get_log_probs(mcmc, test_data, model=model.model, intermediate=True)

        results = np.zeros((num_tests,))
        for s in range(num_tests):
            # Reseeding with the subset index makes subset s identical across
            # calls. Side effect: the global NumPy RNG state is overwritten.
            np.random.seed(s)
            mask = np.random.rand(len(test_data["x"])) <= sample_prob
            # Per row: summed log-prob over the kept test points; then
            # logsumexp over rows, minus the same quantity for the baseline.
            res = logsumexp(curr_lp[:,mask].sum(axis=1), axis=0) - logsumexp(prior_lp[:,mask].sum(axis=1), axis=0)
            # Normalise by the number of kept test points.
            results[s] = res.item() / mask.sum()
        return results, mcmc

    return value_function

def generate_value_functions(
    construct_model, validation_sets, num_chains=4, burn_in=1000, num_samples=2000):
    """Build a value function that scores training data on several validation sets.

    Args:
        construct_model: zero-argument callable returning a fresh model; the
            returned object must expose a ``.model`` attribute.
        validation_sets: sequence of ``(X, y)`` pairs; the value function
            returns one score per pair.
        num_chains, burn_in, num_samples: forwarded to ``fit_model``.

    Returns:
        ``value_function(datasets, trained=None, ...)``; see its docstring.
    """

    model = construct_model()
    # Fit on zero-row slices of the first validation set, i.e. without any
    # observations. NOTE(review): presumably this yields draws from the
    # prior -- confirm against fit_model.
    prior = fit_model(
        model, validation_sets[0][0][:0],validation_sets[0][1][:0], num_chains=num_chains, burn_in=burn_in, num_samples=num_samples
    )
    jax.clear_caches()

    # The keyword defaults are evaluated once, right here. Inside the function
    # `prior` is therefore the result of prior.get_samples(), not the fitted
    # object of the enclosing scope.
    def value_function(datasets, trained=None, prior=prior.get_samples(), validation_sets=validation_sets):
        """Score the union of ``datasets`` on every validation set.

        Args:
            datasets: sequence of ``(X, y)`` pairs that are stacked into one
                training set. An empty sequence returns the baseline.
            trained: samples to reuse instead of fitting; when given, no
                model is fitted.

        Returns:
            ``(results, samples)``: ``results`` has one entry per validation
            set; ``samples`` is what was used for scoring (the baseline
            samples and an all-zero score vector when ``datasets`` is empty).
        """
        if not datasets:
            return np.zeros((len(validation_sets),)) , prior

        train_x = np.vstack([d[0] for d in datasets])
        train_y = np.hstack([d[1] for d in datasets])

        model = construct_model()
        if trained is None: 
            mcmc = fit_model(
                model, train_x, train_y,
                num_chains=num_chains, burn_in=burn_in, num_samples=num_samples
            )
            mcmc = mcmc.get_samples()
        else:
            mcmc = trained
        jax.clear_caches()

        results = np.zeros((len(validation_sets),))
        for s, (X_test_subset, y_test_subset) in enumerate(validation_sets):
            test_data_subset = {"x": X_test_subset, "y": y_test_subset}
            # Log-prob of this validation set under the fitted samples minus
            # the same under the baseline samples. The baseline term does not
            # depend on `datasets` but is recomputed on every call.
            diff_log_probs = get_log_probs(mcmc, test_data_subset, model=model.model) - get_log_probs(prior, test_data_subset, model=model.model)
            # Normalise by the number of validation points.
            results[s] = diff_log_probs.item() / len(y_test_subset)
        return results, mcmc
    return value_function

def fake_datasets(dataset0, original=False):
    """Yield manipulated variants of the ``(X, y)`` pair ``dataset0``.

    In order (the first item only when ``original`` is true):

    0. ``dataset0`` itself, unchanged (same object).
    1. Every second point.
    2. Original features with about 5% of the labels replaced by a
       different class.
    3. The whole dataset repeated three times.
    4. The dataset plus a copy of every tenth point whose first two feature
       columns are moved just below the per-column minimum and whose label
       is set to class 0.
    5. Original labels with uniform noise in ``[0, 0.2)`` added to every
       feature.

    The random variants reseed NumPy's global generator with 3, so they are
    reproducible but overwrite the global RNG state.
    """
    features, targets = dataset0[0], dataset0[1]

    if original:
        yield dataset0

    # Variant 1: keep every second point.
    yield features[::2], targets[::2]

    def noisify(noise_rate):
        """Return a copy of the labels with a random subset changed."""
        np.random.seed(3)
        flip = np.random.rand(len(targets)) <= noise_rate
        classes, class_index = np.unique(targets, return_inverse=True)
        n_classes = len(classes)
        # A shift in 1..n_classes-1 (mod n_classes) always lands on a
        # different class than the current one.
        offsets = 1 + np.random.choice(n_classes - 1, np.count_nonzero(flip))
        noisy = targets.copy()
        noisy[flip] = classes[(class_index[flip] + offsets) % n_classes]
        return noisy

    # Variant 2: label noise.
    yield features, noisify(.05)

    # Variant 3: duplicated data.
    n_copies = 3
    tiled_x = np.tile(features, (n_copies, 1))
    tiled_y = np.tile(targets, (n_copies,))
    yield tiled_x, tiled_y

    # Variant 4: appended out-of-range points labelled as class 0.
    below_min = features[:, :2].min(axis=0) - 0.1
    outliers = features[::10].copy()
    outliers[:, :2] = below_min
    outlier_labels = np.zeros(len(outliers), dtype=np.int32)
    yield np.vstack((features, outliers)), np.hstack((targets, outlier_labels))

    # Variant 5: feature noise.
    np.random.seed(3)
    jitter = np.random.rand(*features.shape)
    yield features + .2 * jitter, targets

def run_experiment(filename, scale=1, seed=0, honest=True):
    """Value three players' datasets against the test set and pickle the results.

    The training data is split 40/30/30 between three players. Semivalues are
    computed once for the original data and then again after replacing
    player 0's data with each variant produced by ``fake_datasets``.

    Args:
        filename: path of the pickle read by ``load_data``.
        scale: forwarded to ``BayesianLogisticRegressionMulti`` as ``scale``.
        seed: ``random_state`` for ``split_dataset``.
        honest: when false, about 6% of player 1's labels are overwritten
            with class 0 before valuation, and the output name gets the
            ``blood-lie`` prefix.

    Output:
        A file in the current working directory named
        ``blood[-lie]-scale=<scale>_seed=<seed>_source=<basename of filename>``
        containing three consecutive pickles: the stacked semivalues, the
        list of memo snapshots, and the list of ``memomodel[(0,)]`` snapshots
        (values converted with ``np.asarray``).
    """
    X_train, y_train, test_data = load_data(filename)
    datasets = split_dataset(X_train, y_train, [.4, .3, .3], random_state=seed)

    if not honest:
        np.random.seed(42)
        # The label arrays come from np.split, so the in-place assignment
        # changes the split that is handed to the valuation below.
        mask = np.random.rand(len(datasets[1][1])) <= 0.06
        datasets[1][1][mask] = 0

    def construct_model():
        return BayesianLogisticRegressionMulti(input_dim=X_train.shape[1],
                                               output_dim=8, scale=scale)

    semivalues = []
    memo = []
    first_samples = []

    def record(valuation):
        """Snapshot the valuation's current state into the result lists."""
        # NOTE(review): copied defensively in case update_player_dataset
        # updates `semivalues` in place (memo is deep-copied for the same
        # reason) -- confirm against ShapleyValuationM.
        semivalues.append(np.array(valuation.semivalues))
        memo.append(deepcopy(valuation.memo))
        first_samples.append(
            {key: np.asarray(value) for key, value in valuation.memomodel[(0,)].items()}
        )
        print(valuation.semivalues.mean(1))

    vf = generate_value_function(construct_model, test_data)
    valuation = ShapleyValuationM(vf, datasets)
    print("Complete Valuation")
    record(valuation)

    for d0 in fake_datasets(datasets[0]):
        valuation.update_player_dataset(0, d0)
        record(valuation)
    semivalues = np.stack(semivalues)

    # Only the base name of the source goes into the output name: embedding a
    # path such as "data/bloodmnist.pkl" would point the output into a
    # directory that does not exist and lose the results at the very end.
    prefix = 'blood' if honest else 'blood-lie'
    out_path = f'{prefix}-scale={scale}_seed={seed}_source={os.path.basename(filename)}'
    print(out_path)

    with open(out_path, "wb") as f:
        pickle.dump(semivalues, f)
        pickle.dump(memo, f)
        pickle.dump(first_samples, f)

def run_experiment_no_gt(filename, scale=1, seed=0, repeats=10, test_split=.25, use=0):
    """Value the players without a shared test set, using held-out validation splits.

    The training data is split 40/30/30 between three players. In every
    repeat each player's data is divided into a training and a validation
    part; the value function scores training data on all players' validation
    parts. Player 0's data is replaced in turn by each variant from
    ``fake_datasets`` (starting with the original).

    Args:
        filename: path of the pickle read by ``load_data``.
        scale: forwarded to ``BayesianLogisticRegressionMulti`` as ``scale``.
        seed: ``random_state`` for ``split_dataset``; also part of the
            output name. The per-repeat train/validation splits are seeded
            with the repeat index instead.
        repeats: number of train/validation repeats.
        test_split: ``test_size`` passed to ``train_test_split``.
        use: unused; kept so existing callers keep working.

    Output:
        A file in the current working directory named
        ``blood-nogt<test_split>_scale=<scale>_seed=<seed>_source=<basename>``
        holding one pickled tuple ``(include_all, exclude_all)``. It is
        rewritten after every repeat, so partial results survive an
        interrupted run.
    """
    X_train, y_train, _ = load_data(filename)
    datasets = split_dataset(X_train, y_train, [.4, .3, .3], random_state=seed)

    def construct_model():
        return BayesianLogisticRegressionMulti(input_dim=X_train.shape[1],
                                               output_dim=8, scale=scale)

    # Only the base name of the source goes into the output name: embedding a
    # path such as "data/bloodmnist.pkl" would point the output into a
    # directory that does not exist and make the write fail.
    out_path = (
        f'blood-nogt{test_split}_scale={scale}_seed={seed}'
        f'_source={os.path.basename(filename)}'
    )

    result_semiv_include_all, result_semiv_exclude_all = [], []
    # The repeat index doubles as the seed of the train/validation split; it
    # has its own name so it no longer shadows the `seed` parameter.
    for rep_seed in trange(repeats):
        validation_datasets = []
        subsets = []
        for d in datasets:
            x_tr, x_val, y_tr, y_val = train_test_split(
                *d, test_size=test_split, random_state=rep_seed)
            validation_datasets.append((x_val, y_val))
            subsets.append((x_tr, y_tr))

        result_semiv_include, result_semiv_exclude = [], []

        valuation = None
        for d0 in fake_datasets(datasets[0], original=True):
            x_tr, x_val, y_tr, y_val = train_test_split(
                *d0, test_size=test_split, random_state=rep_seed)
            subsets[0] = x_tr, y_tr
            validation_datasets[0] = x_val, y_val
            vf = generate_value_functions(construct_model, validation_datasets)
            print("Generated value function")

            if valuation is None:
                valuation = ShapleyValuationM(vf, subsets)
                print("Complete Valuation")
            else:
                valuation.update_v(vf)
                valuation.update_player_dataset(0, subsets[0])

            # Row sums over all validation sets, with and without the
            # diagonal entries.
            row_totals = valuation.semivalues.sum(1)
            result_semiv_include.append(row_totals)
            result_semiv_exclude.append(row_totals - np.diag(valuation.semivalues))

        result_semiv_include_all.append(np.stack(result_semiv_include))
        result_semiv_exclude_all.append(np.stack(result_semiv_exclude))

        # Checkpoint: rewrite the file with everything gathered so far.
        with open(out_path, "wb") as f:
            pickle.dump((result_semiv_include_all, result_semiv_exclude_all), f)

if __name__ == "__main__":
    # Script entry point: run the ground-truth experiment with honest and
    # with partly mislabelled data, then the validation-split experiment for
    # two hold-out fractions.
    source = "bloodmnist.pkl"

    for honest in (True, False):
        run_experiment(source, seed=0, scale=1., honest=honest)

    for split in (.25, .5):
        run_experiment_no_gt(source, scale=1., test_split=split)
