import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import time
import tqdm
import argparse

from dataclasses import replace
from multiprocessing import Pool, cpu_count
from datasets import Dataset, Recidivism, FICO
from multiprocessing import Pool
from rule_mining import run_trial_rule_mining, ManyCoverageParams

def prepare_output_path(save_file_name, out_dir='results'):
    """Return ``out_dir/save_file_name``, creating the parent directory if needed.

    Args:
        save_file_name: File name of the CSV to write (may contain
            sub-directories, which are created as well).
        out_dir: Directory the file is placed in; defaults to ``results``.

    Returns:
        The path string to hand to ``DataFrame.to_csv``.

    Raises:
        OSError: If the directory cannot be created (e.g. permissions).
    """
    import os  # local import so this block does not depend on a file-level one

    out_path = '%s/%s' % (out_dir, save_file_name)
    # exist_ok=True makes the call a no-op when the directory is already there.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    return out_path


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_cores', type=int, default=20, help='Number of cores to use')
    parser.add_argument('--n_splits', type=int, default=20, help='Number of test/train splits to test')
    parser.add_argument('--save_file_name', type=str, default='results.csv', help='Name of the file to save (need to include .csv)')
    args = parser.parse_args()

    # Reject values that would otherwise only fail later with an opaque error
    # (Pool refuses fewer than 1 process; pd.concat refuses an empty list).
    if args.n_cores < 1:
        parser.error('--n_cores must be at least 1')
    if args.n_splits < 1:
        parser.error('--n_splits must be at least 1')

    # Create the output location *before* the expensive run so that a missing
    # or unwritable directory fails immediately instead of after all trials.
    output_path = prepare_output_path(args.save_file_name)

    # Settings shared by every parameter template; N is overridden per trial
    # further down.
    shared_settings = dict(
        num_iter=1000,
        num_features_universe=40,
        N=5,
        N_high=0,
        allow_high_low_switch=False,
    )
    # The templates differ only in `c` and in the two keys of `covs`:
    # (lower key, c) pairs, each key mapped to 0.5.
    coverage_levels = [(0.05, 0.10), (0.15, 0.20), (0.25, 0.30)]
    param_templates = [
        ManyCoverageParams(c=c, covs={lower: 0.5, c: 0.5}, **shared_settings)
        for lower, c in coverage_levels
    ]

    # One task per (split, dataset, template, N) combination.
    parameters = []
    for split in range(args.n_splits):
        # NOTE(review): train_size=0.99 together with test_size=0.5 sums to
        # more than 1 -- confirm how the dataset classes interpret these.
        recid = Recidivism(random_seed=split, train_size=0.99, test_size=0.5)
        fico = FICO(random_seed=split, train_size=0.99, test_size=0.5)

        for dataset in [recid, fico]:
            for param in param_templates:
                for N in [5, 8]:
                    # replace() builds a fresh copy with N overridden, so the
                    # shared template is never mutated.
                    parameters.append((dataset, replace(param, N=N)))

    start_time = time.perf_counter()

    # Never start more workers than there are tasks to hand out.
    n_workers = min(args.n_cores, len(parameters))
    with Pool(n_workers) as p:
        # imap yields results in task order; list() drains it fully before
        # the pool is torn down on leaving the with-block.
        trial_frames = list(tqdm.tqdm(p.imap(run_trial_rule_mining, parameters), total=len(parameters)))

    results = pd.concat(trial_frames, axis=0).reset_index(drop=True)
    results.to_csv(output_path, index=False)
    finish_time = time.perf_counter()
    print(f"Program finished in {(finish_time-start_time) / 60} minutes")