import numpy as np
import numpy as np


import numpy as np
from pathlib import Path

from attacks.robust_mia import perform_rmia, read_data_from_pickles
from attacks.plotter import get_optimal_alpha_gamma_auc, calculate_confusion_matrix

import numpy as np




if __name__ == '__main__':
    # Script entry point. For every entry in pickles/<task>/ that has a
    # heatmaps/<task>/<name>_auc_matrix.csv file, pick (alpha, gamma) from that
    # AUC matrix via get_optimal_alpha_gamma_auc, rerun perform_rmia with them,
    # print the confusion-matrix counts, and append a `D(...)` record to
    # pickles/auc_matrix_vc.csv.
    from pathlib import Path
    import os
    import argparse

    tasks = ("books_iid", 'books_overlap', "mem", "mem_gpt2", 'samsum', 'mem_pickles')

    parser = argparse.ArgumentParser(description="Get matrix.")
    parser.add_argument(
        '--task',
        type=str,
        required=True,  # without it, paths silently become heatmaps/None/, pickles/None/
        choices=list(tasks),
        help=f"Specify the task to be executed. Choices are: {', '.join(tasks)}",
    )

    args = parser.parse_args()
    auc_dir = Path(f'heatmaps/{args.task}/')
    pickles_dir = Path(f"pickles/{args.task}/")

    # Per-task delta written into each output record. Note: "german_wiki" has a
    # delta here but is not an accepted --task choice.
    deltas = {
        "german_wiki": 1/17500,
        "books_iid": 0.000118,
        'books_overlap': 0.0001,
        "mem": 0.002,
        "mem_gpt2": 0.01,
        'samsum': 0.000068,
        'mem_pickles': 0.002,
    }
    # Prefix passed to read_data_from_pickles: 'mem' for memorization tasks.
    prefix = 'train' if 'mem' not in args.task else 'mem'

    # Loop-invariant settings, hoisted out of the per-file loop.
    # NOTE(review): presumably these grids must match the ones used to
    # generate the heatmap CSVs (rows/cols of auc_matrix) — confirm.
    alpha_values = np.linspace(0, 1.5, 30)
    gamma_values = np.linspace(0, 3, 60)
    to_log = True

    # os.listdir order is arbitrary; sort so printed output and appended
    # records come out in a reproducible order.
    for file_name in sorted(os.listdir(pickles_dir)):
        auc_path = auc_dir / (file_name + '_auc_matrix.csv')
        if not os.path.exists(auc_path):
            # No AUC heatmap for this model; nothing to evaluate.
            continue

        auc_matrix = np.genfromtxt(auc_path, delimiter=',')

        ref_train_x, ref_test_x, ref_z, train_losses, test_losses, z = read_data_from_pickles(
            pickles_dir, file_name, to_log, limit=None, prefix=prefix
        )

        alpha, gamma, auc = get_optimal_alpha_gamma_auc(
            auc_matrix,
            alpha_values=alpha_values,
            gamma_values=gamma_values,
        )

        train_above, test_above = perform_rmia(
            ref_train_x, ref_test_x, ref_z, train_losses, test_losses, z, alpha, gamma
        )

        tn, fp, fn, tp = calculate_confusion_matrix(train_above, test_above)

        print(10*'-')
        print(f"Model {file_name}: {tp}, {fn}, {fp}, {tn}, {auc}")
        print(10*'-')

        # Append one record per model immediately, so results already written
        # survive if a later model fails.
        with open('pickles/auc_matrix_vc.csv', 'a+') as f:
            f.write(f"D(name='{file_name}', dataset = '{args.task}', epsilon=8,  delta={deltas[args.task]}, fn={fn}, fp={fp}, tn={tn}, tp={tp}),\n")
