from model_abstractions import BCvMAlgorithm, DataSharingExperiment
from data_generators import NormalNormalDataGenerator, BetaBernDataGenerator
from scipy.stats import norm
from submission_functions import identity_func
from functools import partial

def beta_bern_gfunction(data_points, T, alpha, beta):
    """
    Posterior-predictive CDF at a binary test point T under a Beta-Bernoulli model.

    The posterior conditions on data_points *and* T itself, i.e. on n + 1
    observations.

    data_points: observed points (assumed to be 0/1 values)
    T: the test point (assumed to be 0 or 1)
    alpha, beta: Beta prior parameters

    Returns 1 when T == 1. When T == 0, it returns
    (beta + #zeros) / (alpha + beta + n + 1), where #zeros = (n + 1) - sum(data_points)
    counts the zeros among data_points plus T.
    """
    total = len(data_points) + 1           # observed points plus the test point
    zeros_weight = beta + total - sum(data_points)
    denom = alpha + beta + total
    # (1 - T) switches the probability term off entirely when T == 1.
    return T + (1 - T) * zeros_weight / denom

def normal_normal_gfunction(data_points, T, prior_mean, prior_var, var):
    """
    Posterior-predictive CDF at test point T under a Normal-Normal model.

    The posterior over the mean conditions on data_points *and* T itself
    (n + 1 observations). The predictive distribution is then
    N(mu_tilde, var + var_tilde), evaluated at T.

    data_points: list of data points
    T: the test point
    prior_mean: prior mean μ₀
    prior_var: prior variance σ₀²
    var: observation variance σ²

    Returns Φ((T - mu_tilde) / sqrt(var + var_tilde)).
    """
    n_obs = len(data_points) + 1           # the test point counts as one more observation
    posterior_precision = 1/prior_var + n_obs/var
    weighted_sum = (prior_mean/prior_var) + (sum(data_points) + T)/var

    # posterior mean and variance
    mu_tilde = weighted_sum / posterior_precision
    var_tilde = 1 / posterior_precision

    z = (T - mu_tilde) / (var + var_tilde)**0.5
    return norm.cdf(z)

def normal_normal_exp(prior_mean, prior_var, var, submission_funcs, runs, num_data_per_agent, alg=None):
    """
    Run a DataSharingExperiment on data from a NormalNormalDataGenerator.

    prior_mean: prior mean μ₀ (passed to the data generator and the default g-function)
    prior_var: prior variance σ₀² (same)
    var: observation variance σ² (same)
    submission_funcs: forwarded to DataSharingExperiment as submission_functions
    runs, num_data_per_agent: forwarded unchanged to run_experiment
    alg: algorithm to use. When None, a default BCvMAlgorithm is built with an
         identity feature map, coefficient 1, and normal_normal_gfunction bound
         to this prior.

    Returns whatever DataSharingExperiment.run_experiment returns.
    """
    data_generator = NormalNormalDataGenerator(prior_mean=prior_mean, prior_var=prior_var, var=var)

    if alg is None:
        # functools.partial (rather than a lambda) keeps the g-function pickleable.
        gfunc = partial(normal_normal_gfunction,
                        prior_mean=prior_mean,
                        prior_var=prior_var,
                        var=var)
        alg = BCvMAlgorithm(feature_maps=[identity_func],
                            feature_coeffs=[1],
                            g_functions=[gfunc])

    experiment = DataSharingExperiment(algorithm=alg,
                                       submission_functions=submission_funcs,
                                       data_generator=data_generator)
    return experiment.run_experiment(runs=runs, num_data_per_agent=num_data_per_agent)


def beta_bern_exp(alpha, beta, submission_funcs, runs, num_data_per_agent, alg=None):
    """
    Run a DataSharingExperiment on data from a BetaBernDataGenerator.

    alpha, beta: Beta prior parameters (passed to the data generator and the
                 default g-function)
    submission_funcs: forwarded to DataSharingExperiment as submission_functions
    runs, num_data_per_agent: forwarded unchanged to run_experiment
    alg: algorithm to use. When None, a default BCvMAlgorithm is built with an
         identity feature map, coefficient 1, and beta_bern_gfunction bound to
         this prior.

    Returns whatever DataSharingExperiment.run_experiment returns.
    """
    data_generator = BetaBernDataGenerator(alpha=alpha, beta=beta)

    if alg is None:
        # functools.partial (rather than a lambda) keeps the g-function pickleable.
        g_func = partial(beta_bern_gfunction, alpha=alpha, beta=beta)
        alg = BCvMAlgorithm(
            feature_maps=[identity_func],
            feature_coeffs=[1],
            g_functions=[g_func],
        )

    experiment = DataSharingExperiment(
        algorithm=alg,
        submission_functions=submission_funcs,
        data_generator=data_generator,
    )
    results = experiment.run_experiment(runs=runs, num_data_per_agent=num_data_per_agent)
    return results