import numpy as np
import os
import jax
import jax.numpy as jnp
import pandas as pd
import numpyro
import numpyro.distributions as dist
import pickle
import time
import argparse
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from jax.scipy.special import logsumexp
from nmodels import *
from semivalues import *
from copy import deepcopy
from tqdm import trange

# Define data loading and preprocessing
def load_data():
    """Download the power-plant CSV and return standardized train/test data.

    The CSV is read from the GitHub URL below. Every column except the last
    is used as a feature and the last column is the regression target; both
    are cast to float32. The rows are split 75/25 with a fixed seed, and the
    feature and target scalers are fitted on the training portion only and
    then applied to the test portion.

    Returns:
        ``(X_train, y_train, test_data)`` where ``test_data`` is a dict with
        keys ``"x"`` and ``"y"`` holding the scaled test features and targets.
    """
    url = "https://raw.githubusercontent.com/SantiagoMorenoV/Combined-Cycle-Power-Plant_Regs/refs/heads/main/Data.csv"
    frame = pd.read_csv(url)

    features = frame.iloc[:, :-1].values.astype("float32")
    target = frame.iloc[:, -1].values.astype("float32")
    X_train, X_test, y_train, y_test = train_test_split(
        features, target, test_size=0.25, random_state=42
    )

    # Standardize features using training statistics only.
    scaler_X = StandardScaler()
    X_train = scaler_X.fit_transform(X_train)
    X_test = scaler_X.transform(X_test)

    # The scaler expects 2-D input, so targets go through a column view and
    # are squeezed back afterwards.
    scaler_y = StandardScaler()
    y_train = scaler_y.fit_transform(y_train.reshape(-1, 1)).squeeze()
    y_test = scaler_y.transform(y_test.reshape(-1, 1)).squeeze()

    return X_train, y_train, {"x": X_test, "y": y_test}

def split_dataset(data, labels, proportions, random_state=0):
    """Shuffle ``data``/``labels`` together and cut them into consecutive parts.

    Args:
        data: Array indexed along its first axis by an integer index array.
        labels: Array with the same first-axis length as ``data``.
        proportions: Fractions of the dataset for each part. Only the first
            ``len(proportions) - 1`` entries set the cut points (each is
            truncated to a whole number of rows); the final part receives
            whatever rows remain.
        random_state: Seed passed to ``np.random.seed``. Note that this
            reseeds NumPy's global random state.

    Returns:
        A list of ``(data_part, labels_part)`` tuples, one per part.
    """
    np.random.seed(random_state)
    order = np.random.permutation(len(data))
    shuffled_x, shuffled_y = data[order], labels[order]

    total = len(shuffled_x)
    part_sizes = [int(fraction * total) for fraction in proportions[:-1]]
    boundaries = np.cumsum(part_sizes)

    parts_x = np.split(shuffled_x, boundaries)
    parts_y = np.split(shuffled_y, boundaries)
    return list(zip(parts_x, parts_y))

# Define value function generator
def generate_value_function(
    construct_model, test_data, num_chains=4, burn_in=1000, num_samples=2000,
    sample_prob=0.8, num_tests=20):
    """Build a value function that scores dataset coalitions on one test set.

    A reference fit is produced by calling ``fit_model`` on zero-row slices
    of the test data (presumably this yields draws from the prior, since no
    observations are supplied -- confirm against ``fit_model``). Its
    log-probabilities on the test set are computed once and reused by every
    call of the returned function.

    Args:
        construct_model: Zero-argument callable returning a fresh model
            object; the object must expose a ``.model`` attribute.
        test_data: Mapping with keys ``"x"`` and ``"y"``. The caller's
            mapping is left untouched; a shallow copy whose ``"y"`` is
            reshaped to a column is used internally.
        num_chains: Forwarded to ``fit_model``.
        burn_in: Forwarded to ``fit_model``.
        num_samples: Forwarded to ``fit_model``.
        sample_prob: A test point is kept in a subsample when its uniform
            draw is ``<= sample_prob``.
        num_tests: Number of random test subsamples scored per call.

    Returns:
        The ``value_function`` closure described in its own docstring.
    """
    model = construct_model()
    prior = fit_model(
        model, test_data["x"][:0], test_data["y"][:0], num_chains=num_chains, burn_in=burn_in, num_samples=num_samples
    )
    # Reshape on a shallow copy instead of writing into the caller's dict.
    # Mutating the argument would change the shape of ``test_data["y"][:0]``
    # seen by ``fit_model`` if the same dict were passed in again.
    test_data = dict(test_data)
    test_data["y"] = test_data["y"].reshape(-1, 1)
    jax.clear_caches()
    prior_lp = get_log_probs(prior, test_data, intermediate=True)

    def value_function(datasets, trained=None, prior=prior.get_samples(), prior_lp=prior_lp, test_data=test_data):
        """Score a coalition of datasets on ``num_tests`` test subsamples.

        Args:
            datasets: Sequence of ``(x, y)`` pairs that are stacked into one
                training set. An empty sequence yields all-zero scores.
            trained: Optional pre-computed samples; when given, no model is
                fitted and these samples are scored instead.
            prior, prior_lp, test_data: Bound at creation time; callers
                normally leave them at their defaults.

        Returns:
            ``(results, samples)``: ``results`` has shape ``(num_tests,)``
            and ``samples`` is what was scored (the prior samples for an
            empty coalition).
        """
        if not datasets:
            return np.zeros((num_tests,)), prior

        train_x = np.vstack([d[0] for d in datasets])
        train_y = np.hstack([d[1] for d in datasets])

        model = construct_model()
        if trained is None:
            mcmc = fit_model(
                model, train_x, train_y,
                num_chains=num_chains, burn_in=burn_in, num_samples=num_samples
            )
            mcmc = mcmc.get_samples()
        else:
            mcmc = trained
        jax.clear_caches()
        curr_lp = get_log_probs(mcmc, test_data, model=model.model, intermediate=True)

        results = np.zeros((num_tests,))
        for s in range(num_tests):
            # Seeding with the subsample index makes every call use the same
            # sequence of test subsamples (this reseeds the global NumPy RNG).
            np.random.seed(s)
            mask = np.random.rand(len(test_data["x"])) <= sample_prob
            # Axis 1 of the log-prob arrays indexes test points: sum the kept
            # points per row, then combine rows with logsumexp. The score is
            # the fitted-minus-reference difference per kept test point.
            res = logsumexp(curr_lp[:,mask].sum(axis=1), axis=0) - logsumexp(prior_lp[:,mask].sum(axis=1), axis=0)
            results[s] = res.item() / mask.sum()
        return results, mcmc

    return value_function

def generate_value_functions(
    construct_model, validation_sets, num_chains=4, burn_in=1000, num_samples=2000):
    """Build a value function that scores coalitions on several validation sets.

    A reference fit is obtained by calling ``fit_model`` on zero-row slices
    of the first validation set (presumably draws from the prior -- confirm
    against ``fit_model``). Its samples are bound as the default ``prior``
    of the returned function.

    Args:
        construct_model: Zero-argument callable returning a fresh model
            object exposing a ``.model`` attribute.
        validation_sets: Sequence of ``(x, y)`` pairs, one per validation set.
        num_chains: Forwarded to ``fit_model``.
        burn_in: Forwarded to ``fit_model``.
        num_samples: Forwarded to ``fit_model``.

    Returns:
        ``value_function(datasets, trained=None, ...)`` which returns
        ``(results, samples)`` with one score per validation set.
    """
    model = construct_model()
    first_x, first_y = validation_sets[0][0], validation_sets[0][1]
    prior_fit = fit_model(
        model, first_x[:0], first_y[:0],
        num_chains=num_chains, burn_in=burn_in, num_samples=num_samples,
    )
    jax.clear_caches()
    prior_samples = prior_fit.get_samples()

    def value_function(datasets, trained=None, prior=prior_samples, validation_sets=validation_sets):
        """Score the stacked ``datasets`` against every validation set.

        An empty ``datasets`` returns zeros together with the prior samples.
        When ``trained`` is given it is scored directly and no fit is run.
        """
        if not datasets:
            return np.zeros((len(validation_sets),)), prior

        stacked_x = np.vstack([pair[0] for pair in datasets])
        stacked_y = np.hstack([pair[1] for pair in datasets])

        model = construct_model()
        if trained is not None:
            posterior = trained
        else:
            fitted = fit_model(
                model, stacked_x, stacked_y,
                num_chains=num_chains, burn_in=burn_in, num_samples=num_samples,
            )
            posterior = fitted.get_samples()
        jax.clear_caches()

        scores = np.zeros((len(validation_sets),))
        for idx, (val_x, val_y) in enumerate(validation_sets):
            subset = {"x": val_x, "y": val_y.reshape(-1, 1)}
            fitted_lp = get_log_probs(posterior, subset, model=model.model)
            reference_lp = get_log_probs(prior, subset, model=model.model)
            # ``.item()`` requires the difference to hold a single value;
            # it is normalized by the number of validation points.
            scores[idx] = (fitted_lp - reference_lp).item() / len(val_y)
        return scores, posterior

    return value_function

def fake_datasets(dataset0, original=False):
    """Yield manipulated variants of one ``(features, labels)`` dataset.

    This is a generator, so each variant (and each reseeding of NumPy's
    global RNG) happens only when the consumer advances it.

    Args:
        dataset0: Pair whose item 0 is a 2-D feature array and item 1 the
            matching 1-D label array.
        original: When true, ``dataset0`` itself is yielded first.

    Yields:
        ``(features, labels)`` pairs, in this order:
        every second sample; labels with Gaussian noise (std 0.2) added;
        the dataset repeated three times; the dataset plus every tenth row
        re-inserted with its first two features pushed 0.1 below the
        column minima and a label of zero; features with uniform
        ``[0, 0.1)`` noise added.
    """
    if original:
        yield dataset0

    features = dataset0[0]
    labels = dataset0[1]

    # Subsample: keep every second point.
    stride = slice(None, None, 2)
    yield features[stride], labels[stride]

    # Noisy labels.
    np.random.seed(3)
    label_noise = np.random.normal(0, .2, size=labels.shape)
    yield features, labels + label_noise

    # Plain duplication of the whole dataset.
    copies = 3
    yield np.tile(features, (copies, 1)), np.tile(labels, (copies,))

    # Extra points placed just outside the observed range of the first two
    # features, all labelled zero.
    injected = features[::10].copy()
    injected[:, :2] = features[:, :2].min(axis=0) - 0.1
    injected_labels = np.zeros(len(injected))
    yield np.vstack((features, injected)), np.hstack((labels, injected_labels))

    # Noisy features.
    np.random.seed(3)
    feature_noise = .1 * np.random.rand(*features.shape)
    yield features + feature_noise, labels


def run_experiment(hid_dim, num_hidden_layers, activation="relu", honest=True):
    """Run the valuation experiment scored on the held-out test set.

    The training data is split 40/30/30 into three player datasets. A
    ``ShapleyValuationM`` is built once, then player 0's dataset is replaced
    by each variant from ``fake_datasets`` and the valuation is updated.
    After every step the semivalues, a deep copy of ``valuation.memo`` and
    the ``valuation.memomodel[(0,)]`` entries (as NumPy arrays) are recorded.
    All three records are pickled, in that order, into one file.

    Args:
        hid_dim: Hidden width passed to ``BayesianNeuralNetwork``.
        num_hidden_layers: Number of hidden layers passed to the model.
        activation: ``"relu"`` selects ``jax.nn.relu``; any other value
            selects ``jax.nn.leaky_relu``. Also used in the output filename.
        honest: When false, roughly 10% of the second player's labels are
            set to zero before valuation and the output filename changes.
    """
    X_train, y_train, test_data = load_data()
    datasets = split_dataset(X_train, y_train, [.4, .3, .3])

    if not honest:
        # Corrupt the second player's labels in place with a fixed seed.
        np.random.seed(42)
        second_labels = datasets[1][1]
        zeroed = np.random.rand(len(second_labels)) <= 0.1
        second_labels[zeroed] = 0.

    activation_fn = jax.nn.relu if activation == "relu" else jax.nn.leaky_relu

    def construct_model():
        return BayesianNeuralNetwork(
            hid_dim=hid_dim,
            activation=activation_fn,
            num_hidden_layers=num_hidden_layers,
            noise_dist=dist.InverseGamma(3, 1),
        )

    semivalue_log, memo_log, sample_log = [], [], []

    def record_state():
        # NOTE(review): the semivalues are stored by reference, while memo is
        # deep-copied -- confirm update_player_dataset rebinds
        # valuation.semivalues rather than modifying it in place.
        semivalue_log.append(valuation.semivalues)
        memo_log.append(deepcopy(valuation.memo))
        sample_log.append(
            {name: np.asarray(arr) for name, arr in valuation.memomodel[(0,)].items()}
        )
        print(valuation.semivalues.mean(1))

    vf = generate_value_function(construct_model, test_data)
    print("Generated value function")
    valuation = ShapleyValuationM(vf, datasets)
    print("Complete Valuation")
    record_state()

    for variant in fake_datasets(datasets[0]):
        valuation.update_player_dataset(0, variant)
        record_state()
    stacked_semivalues = np.stack(semivalue_log)

    filename = (
        f"cycle-hiddim={hid_dim}_layers={num_hidden_layers}_{activation}.pkl"
        if honest
        else f"cycle-lie_hiddim={hid_dim}_layers={num_hidden_layers}_{activation}.pkl"
    )

    # Three consecutive pickles in one file; read back with three loads.
    with open(filename, "wb") as f:
        for payload in (stacked_semivalues, memo_log, sample_log):
            pickle.dump(payload, f)

def run_experiment_no_gt(hid_dim, num_hidden_layers, activation="relu", repeats=10, test_split=.25):
    """Run the valuation experiment using per-player validation sets.

    The training data is split 40/30/30 into three player datasets. For each
    repeat (used as the split seed) every player's data is divided into a
    training part and a validation part. Player 0's data is then replaced in
    turn by the original dataset and each variant from ``fake_datasets``,
    re-split with the same seed, and the valuation is created or updated.

    For every variant two vectors are recorded: the semivalues summed over
    axis 1, and that sum minus the diagonal of the semivalue matrix. The
    accumulated results are re-written to the pickle file after every repeat.

    Args:
        hid_dim: Hidden width passed to ``BayesianNeuralNetwork``.
        num_hidden_layers: Number of hidden layers passed to the model.
        activation: ``"relu"`` selects ``jax.nn.relu``; any other value
            selects ``jax.nn.leaky_relu``. Also used in the output filename.
        repeats: Number of seeds to run.
        test_split: ``test_size`` given to ``train_test_split`` for every
            player's validation part. Also used in the output filename.
    """
    X_train, y_train, _ = load_data()
    datasets = split_dataset(X_train, y_train, [.4, .3, .3])

    filename = f"cycle-nogt{test_split}-hiddim={hid_dim}_layers={num_hidden_layers}_{activation}.pkl"

    activation_fn = jax.nn.relu if activation == "relu" else jax.nn.leaky_relu

    def construct_model():
        return BayesianNeuralNetwork(
            hid_dim=hid_dim,
            activation=activation_fn,
            num_hidden_layers=num_hidden_layers,
            noise_dist=dist.InverseGamma(3, 1),
        )

    result_semiv_include_all, result_semiv_exclude_all = [], []
    for seed in trange(repeats):
        # train_test_split returns [x_fit, x_val, y_fit, y_val] per player.
        player_splits = [
            train_test_split(*player, test_size=test_split, random_state=seed)
            for player in datasets
        ]
        subsets = [(fit_x, fit_y) for fit_x, _, fit_y, _ in player_splits]
        validation_datasets = [(held_x, held_y) for _, held_x, _, held_y in player_splits]

        include_rows, exclude_rows = [], []
        valuation = None
        for variant in fake_datasets(datasets[0], original=True):
            fit_x, held_x, fit_y, held_y = train_test_split(
                *variant, test_size=test_split, random_state=seed
            )
            subsets[0] = fit_x, fit_y
            validation_datasets[0] = held_x, held_y
            vf = generate_value_functions(construct_model, validation_datasets)
            print("Generated value function")

            if valuation is not None:
                # Later variants: swap in the new value function first, then
                # the new player-0 training data.
                valuation.update_v(vf)
                valuation.update_player_dataset(0, subsets[0])
            else:
                valuation = ShapleyValuationM(vf, subsets)
                print("Complete Valuation")

            totals = valuation.semivalues.sum(1)
            include_rows.append(totals)
            exclude_rows.append(totals - np.diag(valuation.semivalues))

        result_semiv_include_all.append(np.stack(include_rows))
        result_semiv_exclude_all.append(np.stack(exclude_rows))

        # Overwrite the file every repeat so partial results survive a crash.
        with open(filename, "wb") as f:
            pickle.dump((result_semiv_include_all,result_semiv_exclude_all), f)

if __name__ == "__main__":
    # Test-set experiments: untouched labels first, then the run in which
    # part of the second player's labels is zeroed.
    for is_honest in (True, False):
        run_experiment(100, 1, honest=is_honest)

    # Validation-set experiments for two hold-out fractions.
    for split in (.25, .5):
        run_experiment_no_gt(100, 1, test_split=split)
