from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os, utils, random
from pickle import FALSE

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset, Subset

from baselines.fairPATE.analysis import analyze_multiclass_confident_fair_gnmax

from baselines.fairPATE.models.ensemble_model import FairEnsembleModel
from baselines.fairPATE.models.private_model import get_private_model_by_id
from baselines.fairPATE.models.utils_models import get_model_name_by_id

from baselines.fairPATE.utils import eval_model, metric, train_model, get_unlabeled_set, \
                                    load_evaluation_dataloader, get_unlabeled_set

from baselines.DPSGDGlobalAdapt.main import main
from sklearn.linear_model import LogisticRegression

import sys
from pathlib import Path
# Directory that contains this file.
current_script_path = Path(__file__).parent
# The 'fairpate' directory located next to this file.
module_b_path = current_script_path / 'fairpate'

# Append that directory to the import search path. Presumably the FairPATE
# code uses top-level imports that only resolve relative to it -- TODO confirm.
sys.path.append(str(module_b_path))
#print('sys.path', sys.path)
#print('module_b_path', str(module_b_path))

from fairpate.fairpate_tabular.fairpate import run_experiment, run_optuna, process_data, DataSet, FairPATERunParams
from fairpate.fairpate_tabular.generate_tabular_pareto import is_pareto_efficient
from fairpate.fairpate_tabular.args import process_arguments

# Iteration cap: used as ``max_iter`` for the LogisticRegression models built
# in ``fairpate_train`` and forwarded to ``FairPATERunParams``.
MAX_ITER = 1000

def fairpate_train(args, param, verbose=True):
    """Run the tabular FairPATE pipeline for one (privacy, fairness) setting.

    ``args`` is mutated in place: the FairPATE-specific options listed below
    are written onto it before it is passed through ``process_arguments``.

    Args:
        args: Experiment namespace. ``args.seed`` is read after
            ``process_arguments`` has run.
        param: Indexable pair. ``param[0]`` is stored as ``args.budget`` and
            ``param[1]`` as ``args.fairness_threshold``.
        verbose: When true, FairPATE progress is logged with ``print``;
            otherwise the logger is a no-op.

    Returns:
        dict: A single result point with the keys ``epsilon``,
        ``fairness_gaps``, ``achieved_epsilon``, ``achieved_fairness_gaps``,
        ``query_fairness_gaps`` (always ``None`` here), ``number_answered``,
        ``accuracy``, ``auc`` and ``coverage``.

    Raises:
        RuntimeError: On the Optuna path, when no Pareto-efficient trial is
            available to report.
    """
    # Add necessary args
    # TODO: Reexamine whether this is needed.
    # Is there a better way to get idiosyncratic args from fairpate into this?
    use_optuna = False

    args.results_dir = '.'
    args.pate_based_model = 'fairpate'
    args.fairness_metric = 'DemParity'
    args.use_optuna = use_optuna
    args.backend = 'sklearn'
    args.num_teachers = 4
    args.log_path = './fairpate/logs/'
    args.verbose = verbose
    args.undersampling_ratio = None
    args.data_path = './fairpate/fairpate_tabular/data/'
    args.use_stratification = False
    args.teacher_query_set_split = 0.7
    args.skip = None
    args.split = 0.75
    args.parallel = None
    args.use_inference_time_postprocessing = False

    # The requested (privacy, fairness) operating point.
    args.budget = param[0]
    args.fairness_threshold = param[1]

    args = process_arguments(args)

    log = print if args.verbose else lambda *x, **y: None

    def generate_model_and_fit(train_features, train_labels, model_id=None):
        # ``model_id`` is accepted for signature compatibility but not used.
        return LogisticRegression(
            max_iter=MAX_ITER, random_state=args.seed
        ).fit(X=train_features, y=train_labels)

    # No pre-built datasets are supplied; ``process_data`` receives ``None``.
    train_set, test_set = None, None

    # PATE analysis needs the legacy global NumPy seed ...
    np.random.seed(args.seed)
    # ... everything else uses an explicit generator.
    np_rng = np.random.default_rng(args.seed)

    # Generate FairPATERunParams
    (train_features, train_labels, train_sensitives,
     test_features, test_labels, test_sensitives) = process_data(
        np_rng, args, log, train_set=train_set, test_set=test_set)
    train_data = DataSet(features=train_features, labels=train_labels,
                         sensitives=train_sensitives, set=train_set)
    test_data = DataSet(features=test_features, labels=test_labels,
                        sensitives=test_sensitives, set=test_set)

    packed_data = FairPATERunParams(
        train=train_data, test=test_data, np_rng=np_rng,
        generate_model_and_fit=generate_model_and_fit, MAX_ITER=MAX_ITER)

    # Split based on Optuna
    if args.use_optuna:
        best_trials = run_optuna(args, packed_data, run_experiment)

        # Filter to pareto frontier of newly added points. Accuracy is negated
        # so that all three columns are quantities to minimise.
        costs = np.array([
            [trial.user_attrs['achieved_eps'],
             trial.user_attrs['test_dem_parity'],
             -trial.user_attrs['student_model_test_accuracy']]
            for trial in best_trials
        ])

        pareto_indices = is_pareto_efficient(costs)

        # NOTE(review): every Pareto point is visited but only the last one is
        # returned, because the function returns a single dict. Confirm
        # whether callers are meant to receive all of them.
        mydict = None
        for idx in pareto_indices:
            trial_attrs = best_trials[idx].user_attrs

            mydict = {
                'epsilon': param[0],
                'fairness_gaps': param[1],
                'achieved_epsilon': trial_attrs['achieved_eps'],
                'achieved_fairness_gaps': trial_attrs['test_dem_parity'],
                'query_fairness_gaps': None,
                'number_answered': trial_attrs['num_queries_answered'],
                'accuracy': trial_attrs['student_model_test_accuracy'],
                'auc': trial_attrs['test_auc'],
                'coverage': trial_attrs['test_coverage']
            }

        if mydict is None:
            # Previously this surfaced as an UnboundLocalError at ``return``.
            raise RuntimeError(
                'Optuna search produced no Pareto-efficient trial for '
                f'param (priv, fair): {param}')
    else:
        (_validation_accuracy, _validation_disparity, _validation_coverage,
         achieved_eps, _max_num_query, _validation_auc, num_queries_answered,
         student_model_test_accuracy, test_disparity, test_coverage,
         test_auc) = run_experiment(args, log, packed_data)

        print(f'achieved_eps: {achieved_eps}, test_disparity: {test_disparity}, param (priv, fair): {param}')

        # Single point
        mydict = {
            'epsilon': param[0],
            'fairness_gaps': param[1],
            'achieved_epsilon': achieved_eps,
            'achieved_fairness_gaps': test_disparity,
            'query_fairness_gaps': None,
            'number_answered': num_queries_answered,
            'accuracy': student_model_test_accuracy,
            'auc': test_auc,
            'coverage': test_coverage
        }

    return mydict


def train_student_governance_game(args, param, verbose=True):
    """Train and evaluate a FairPATE student for one (privacy, fairness) pair.

    Datasets outside the image list checked below are delegated to
    ``fairpate_train``. For the image datasets, stored teacher votes and
    sensitive attributes are loaded from ``args.prev_results_dir``, the
    privacy analysis determines how many queries fit in the budget, the noisy
    aggregation is re-simulated to label those queries, and a student model is
    trained on the answered ones and evaluated.

    ``args`` is mutated in place (``budget`` and ``max_fairness_violation``).

    Args:
        args: Experiment namespace holding seed, dataset, paths and
            PATE hyper-parameters read below.
        param: Indexable pair; ``param[0]`` is the privacy budget and
            ``param[1]`` the maximum fairness violation.
        verbose: Only forwarded to ``fairpate_train``.

    Returns:
        dict: A single result point with the keys ``epsilon``,
        ``fairness_gaps``, ``achieved_epsilon``, ``achieved_fairness_gaps``,
        ``query_fairness_gaps``, ``number_answered``, ``accuracy``, ``auc``
        and ``coverage``.
    """
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # Update the args according to current param
    print(f"Calibrating on {args.dataset} with (priv, fair): {param}", flush = True)

    # Use fairpate-tabular if needed
    if args.dataset not in ['mnist', 'svhn', 'celeba', 'celebasensitive', 'fairface', 'utkface']:
        return fairpate_train(args, param, verbose=verbose)

    args.budget = param[0]
    args.max_fairness_violation = param[1]

    # Logs
    log_name = "logs-(num-models:{})-(num-query-parties:{})-(query-mode:{})-(threshold:{:.1f})-(sigma-gnmax:{:.1f})-(sigma-threshold:{:.1f})-(budget:{:.2f}).txt".format(
        args.num_models,
        1,
        "random",
        args.threshold,
        args.sigma_gnmax,
        args.sigma_threshold,
        args.budget,
    )
    log_dir = os.path.join(args.path, "logs")
    os.makedirs(log_dir, exist_ok=True)

    # Get the whole unlabeled dataset
    unlabeled_dataset = utils.get_unlabeled_set(args=args)

    # Load raw votes and the matching sensitive labels
    votes_path = args.prev_results_dir+"votes/"+args.dataset+"/"
    votes = np.load(os.path.join(
        votes_path,
        "model(1)-raw-votes-(mode-{})-dataset-{}.npy".format(
            "random", args.dataset),
    ))
    sensitive = np.load(os.path.join(
        votes_path,
        "model(1)-sensitives-(mode-{})-dataset-{}.npy".format(
            "random", args.dataset),
    ))

    # Get max num queries. The log file is only needed by the analysis, so it
    # is opened here and closed as soon as the analysis returns.
    with open(os.path.join(log_dir, log_name), "w") as file:
        (
            max_num_query, dp_eps, _, _, _, _,
            _, _, _, _
        ) = analyze_multiclass_confident_fair_gnmax(
            votes=votes, sensitives=sensitive,
            threshold=args.threshold, fair_threshold=args.max_fairness_violation,
            sigma_threshold=args.sigma_threshold, sigma_fair_threshold=0.0, sigma_gnmax=args.sigma_gnmax,
            budget=args.budget, delta=args.delta, file=file, show_dp_budget='disable',
            args=None, num_sensitive_attributes=len(args.sensitive_group_list), num_classes=args.num_classes,
            minimum_group_count=args.min_group_count)

    ensemble_model = FairEnsembleModel(
        model_id=0, private_models=[], args=args
    )

    # Only the first ``max_num_query`` unlabeled samples are queried.
    queried_indices = list(range(args.num_unlabeled_samples))[:max_num_query]
    unlabeled_dataset = Subset(unlabeled_dataset, queried_indices)
    queryloader = DataLoader(
        unlabeled_dataset, batch_size=len(unlabeled_dataset), shuffle=False
    )

    # Get which queries are answered and preds: threshold check on the noisy
    # top vote count, then noisy argmax over all classes.
    num_queried = len(queried_indices)
    votes = votes[:num_queried]
    sensitive = sensitive[:num_queried]
    noise_threshold = np.random.normal(0., args.sigma_threshold,
                                       votes.shape[0])
    vote_counts = votes.max(axis=1)
    answered = (vote_counts + noise_threshold) > args.threshold
    noise_gnmax = np.random.normal(0., args.sigma_gnmax, (
        votes.shape[0], votes.shape[1]))
    preds = (votes + noise_gnmax).argmax(axis=1)

    answered = ensemble_model.apply_fairness_constraint(preds, answered, sensitive, args)

    # Get the train set. The loader's batch size equals the dataset size, so
    # the first batch is the whole queried subset.
    X, _, z = next(iter(queryloader))

    indices = np.where(answered == 1)[0]
    X = X[indices].to(torch.float32)
    y = torch.from_numpy(preds[indices]).to(torch.float32)
    z = z[indices]

    dataset = TensorDataset(X, y, z)
    # This used to pass ``pickle.FALSE``, the truthy bytes constant b'I00\n',
    # which silently enabled shuffling. The builtin ``False`` is what the
    # name expresses.
    trainloader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=False)

    # Get the test set
    evalloader = load_evaluation_dataloader(args)

    # Train
    model_name = get_model_name_by_id(id=0)
    model = get_private_model_by_id(args=args, id=0)
    model.name = model_name
    train_model(args=args, model=model, trainloader=trainloader,
                evalloader=evalloader, patience=10)
    # test
    result, fairness_gaps = eval_model(args=args, model=model, dataloader=evalloader, sensitives=True, preprocessor=True)

    # Single point
    mydict = {'epsilon': param[0],
              'fairness_gaps': param[1],
              # Privacy cost reported by the analysis for the last query used.
              'achieved_epsilon': dp_eps[max_num_query - 1],
              'achieved_fairness_gaps': np.amax(fairness_gaps),
              'query_fairness_gaps': np.amax(ensemble_model.fairness_disparity_gaps),
              'number_answered': sum(answered),
              'accuracy': result[metric.acc],
              'auc': result[metric.auc],
              'coverage': result[metric.coverage]}

    print(mydict)
    return mydict

def train_dpsgd_g_a(args, param):
    """Train a DPSGD-Global-Adapt model and summarise it as one result point.

    ``param[1]`` is appended to ``args.config`` as a ``threshold=...`` entry
    and ``param[0]`` is passed to ``main``. The pair of accuracies that
    ``main`` reports under ``'achieved_fairness_gaps'`` is replaced, in the
    returned ``result`` mapping itself, by the absolute difference of their
    offsets from two hard-coded non-private baselines.

    Returns:
        dict: keys ``epsilon``, ``fairness_gaps``, ``achieved_epsilon``,
        ``achieved_fairness_gaps``, ``accuracy`` (scaled by 100) and
        ``coverage`` (always 1).
    """
    # Hard-coded non-private reference accuracies for the two tracked groups
    # (named "2" and "8"; presumably class labels -- confirm).
    baseline_acc_2 = 0.9312015175819397
    baseline_acc_8 = 0.8634496927261353

    print(f"Preparing to train DPSGD-G.-A. model with: {param!s}", flush=True)

    # set tau
    # NOTE(review): this appends on every call, so reusing the same ``args``
    # accumulates several ``threshold=`` entries -- confirm ``main`` copes.
    args.config.append(f"threshold={param[1]!s}")

    outcome = main(True, param[0], args)

    acc_2, acc_8 = outcome['achieved_fairness_gaps']
    offset_2 = acc_2 - baseline_acc_2
    offset_8 = acc_8 - baseline_acc_8
    outcome['achieved_fairness_gaps'] = abs(offset_2 - offset_8)

    return {
        'epsilon': param[0],
        'fairness_gaps': param[1],
        'achieved_epsilon': outcome['achieved_epsilon'],
        'achieved_fairness_gaps': outcome['achieved_fairness_gaps'],
        'accuracy': 100 * outcome['accuracy'],
        'coverage': 1,
    }
    