# import sys
# sys.path.append("../")


from dataclasses import dataclass
from pkgutil import get_data
from analysis import analyze_multiclass_confident_gnmax, analyze_multiclass_confident_fair_gnmax
from analysis import analyze_multiclass_confident_gnmax_jax, analyze_multiclass_confident_fair_gnmax_jax
from analysis import analyze_multiclass_confident_fair_gnmax_new_analysis

import numpy  as np
import pytest

@pytest.fixture(params=(10, 50, 100, 150))
def number_of_models(request):
    """Parametrized count of models whose saved outputs are tested.

    The value is interpolated into the data directory path
    ``./data/{name}_{number_of_models}-models/`` by ``load_data``.
    """
    chosen_count = request.param
    return chosen_count

@pytest.fixture(params=("cmnist", "fface", "celeba", "gaussian"))
def name(request):
    """Parametrized short dataset key.

    ``load_data`` uses it both as the data directory prefix and to pick the
    dataset name embedded in the saved ``.npy`` filenames.
    """
    dataset_key = request.param
    return dataset_key

@pytest.fixture()
def load_data(name, number_of_models):
    """Load the saved raw votes, targets and sensitive attributes for one configuration.

    Files are read from ``./data/{name}_{number_of_models}-models/`` and every
    array is cast to ``float``.

    Args:
        name: Short dataset key ("cmnist", "fface", "celeba" or "gaussian").
        number_of_models: Model count used in the data directory name.

    Returns:
        list: ``[name, number_of_models, raw_votes, targets, sensitives]``.

    Raises:
        ValueError: If ``name`` has no known on-disk dataset name (previously
            this surfaced as an ``UnboundLocalError``).
    """
    # Short key -> dataset name embedded in the saved .npy filenames.
    dataset_by_name = {
        "cmnist": "colormnist",
        "fface": "fairface",
        "celeba": "celebasensitive",
        "gaussian": "gaussian",
    }
    try:
        dataset = dataset_by_name[name]
    except KeyError:
        raise ValueError(
            f"unknown dataset name {name!r}; expected one of {sorted(dataset_by_name)}"
        ) from None

    path = f"./data/{name}_{number_of_models}-models/"
    # The celeba raw-vote file follows a different naming scheme from the others.
    if name == "celeba":
        votes_file = f"model(1)-raw-votes-(mode-random)-dataset-{dataset}.npy"
    else:
        votes_file = "model(1)-raw-votes-mode-random-vote-type-discrete.npy"
    raw_votes = np.load(path + votes_file).astype(float)

    targets = np.load(path + f"model(1)-targets-(mode-random)-dataset-{dataset}.npy").astype(float)
    sensitives = np.load(path + f"model(1)-sensitives-(mode-random)-dataset-{dataset}.npy").astype(float)

    return [name, number_of_models, raw_votes, targets, sensitives]

def test_jax_nonjax(load_data):
    """The JAX and non-JAX fair GNMax analyses must agree on the max query count."""
    _name, _model_count, raw_votes, _targets, sensitives = load_data

    num_classes = raw_votes.shape[1]
    # Sensitive attribute labels are assumed to be 0..K-1, so K = max + 1.
    num_sensitive_attributes = int(sensitives.max()) + 1
    # Fairness threshold: 10% slack on the uniform per-class rate 1/num_classes.
    slack = 0.1
    fair_threshold = slack * (1/num_classes)

    shared_kwargs = {
        "votes": raw_votes,
        "threshold": 2,
        "sigma_threshold": 5,
        "sigma_gnmax": 5,
        "budget": 20,
        "delta": 1e-5,
        "fair_threshold": fair_threshold,
        "sigma_fair_threshold": 0.1,
        "sensitives": sensitives,
        "minimum_group_count": 50,
    }

    def _silent(*messages):
        # No-op logger: swallow whatever the analyses try to log.
        return messages

    # Both analyses return a 10-tuple; only the first entry (max number of
    # answerable queries) is compared.
    (jax_max_queries, _, _, _, _, _, _, _, _, _) = analyze_multiclass_confident_fair_gnmax_jax(
        **shared_kwargs,
        num_sensitive_attributes=num_sensitive_attributes,
        num_classes=num_classes,
        log=_silent,
    )

    (numpy_max_queries, _, _, _, _, _, _, _, _, _) = analyze_multiclass_confident_fair_gnmax(
        **shared_kwargs,
        num_sensitive_attributes=num_sensitive_attributes,
        num_classes=num_classes,
        file=".",
        log=_silent,
    )

    # The JAX variant returns an array scalar; .item() converts it for comparison.
    assert jax_max_queries.item() == numpy_max_queries

def test_new_analysis_vs_old(load_data):
    """Compare the new fair GNMax analysis at minimum_group_count=10 vs 100.

    Asserts that the max number of answerable queries with the smaller minimum
    group count does not exceed that with the larger one.

    NOTE(review): despite the test name, both calls below use
    ``analyze_multiclass_confident_fair_gnmax_new_analysis``; only
    ``minimum_group_count`` differs. Confirm whether the second call was meant
    to use the old analysis.
    """
    _name, _model_count, raw_votes, _targets, sensitives = load_data

    num_classes = raw_votes.shape[1]
    # Sensitive attribute labels are assumed to be 0..K-1, so K = max + 1.
    num_sensitive_attributes = int(sensitives.max()) + 1
    # Fairness threshold: 10% slack on the uniform per-class rate 1/num_classes.
    slack = 0.1
    fair_threshold = slack * (1/num_classes)

    low_count_kwargs = {
        "votes": raw_votes,
        "threshold": 2,
        "sigma_threshold": 5,
        "sigma_gnmax": 5,
        "budget": 20,
        "delta": 1e-5,
        "fair_threshold": fair_threshold,
        "sigma_fair_threshold": 0.1,
        "sensitives": sensitives,
        "minimum_group_count": 10,
    }
    # Same settings, only the minimum group count is raised (key order kept).
    high_count_kwargs = {**low_count_kwargs, "minimum_group_count": 100}

    def _silent(*messages):
        # No-op logger: swallow whatever the analysis tries to log.
        return messages

    # The analysis returns a 10-tuple; only the first entry (max number of
    # answerable queries) is compared.
    (queries_min10, _, _, _, _, _, _, _, _, _) = analyze_multiclass_confident_fair_gnmax_new_analysis(
        **low_count_kwargs,
        num_sensitive_attributes=num_sensitive_attributes,
        num_classes=num_classes,
        file=".",
        log=_silent,
    )

    (queries_min100, _, _, _, _, _, _, _, _, _) = analyze_multiclass_confident_fair_gnmax_new_analysis(
        **high_count_kwargs,
        num_sensitive_attributes=num_sensitive_attributes,
        num_classes=num_classes,
        file=".",
        log=_silent,
    )

    print(queries_min10, queries_min100)

    assert queries_min10 <= queries_min100
    