import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tqdm
import argparse
import time
import os
import copy

from datasets import Recidivism, FICO, Dataset, Schizo, Adults, Diabetes, Readmission
from multiprocessing import Pool, cpu_count
from dataclasses import dataclass, replace
from rulelist import RuleListMinerParams, RuleListMiner 

from scipy.stats import norm

# Switch the working directory so that relative paths used by this script
# (e.g. the 'results/...' CSV written by the __main__ driver) resolve there.
# NOTE(review): this is a hard-coded, machine-specific absolute path and it
# executes at import time, i.e. in every process that imports this module,
# not only when the file is run as a script. Presumably the dataset loaders
# also depend on this directory -- confirm before changing it.
os.chdir('/home/evanyao/final_paper')

def run_trial(inpt):
    """Run one rule-list mining experiment and return a flat result record.

    The configuration arrives packed in a single indexable argument so the
    function can be mapped directly over a list of configurations:

    * ``inpt[0]`` -- the dataset instance handed to the miner,
    * ``inpt[1]`` -- the ``RuleListMinerParams`` for this run,
    * ``inpt[2]`` -- the target support forwarded to ``RuleListMiner.run``.

    Returns a dict with the dataset's class name, the swept parameters
    (``N``, ``target_support``, ``max_depth``, ``use_quantile``) and two
    scores from ``miner.evaluate_rule_test``: ``'abbr'`` for the rule returned
    by ``run`` and ``'abbr_start'`` for the miner's ``starting_rule``.
    """
    data: Dataset = inpt[0]
    params: RuleListMinerParams = inpt[1]
    support = inpt[2]

    # The third positional argument is passed through unchanged from the
    # original call; its meaning is defined by RuleListMiner.
    rule_miner = RuleListMiner(data, params, False)
    mined_rule = rule_miner.run(target_support=support)

    # Keyword order fixes the key order of the record, which in turn becomes
    # the column order once the records are collected into a DataFrame.
    record = dict(
        dataset=data.__class__.__name__,
        N=params.N,
        target_support=support,
        max_depth=params.max_depth,
        abbr=rule_miner.evaluate_rule_test(mined_rule),
        abbr_start=rule_miner.evaluate_rule_test(rule_miner.starting_rule),
        use_quantile=params.use_quantile,
    )
    return record

if __name__ == '__main__':
    # Command-line driver: build the full grid of experiment configurations
    # (seed x dataset x N x max_depth x use_quantile x target_support), run
    # them in a process pool via run_trial, and save one CSV row per run.
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_cores', type=int, default=cpu_count(), help='Number of cores to use')
    parser.add_argument('--n_runs', type=int, default=1, help='Number of splits')
    parser.add_argument('--save_file_name', type=str, help='Name of the file to save (need to include .csv)', required=True)

    args = parser.parse_args()
    # Reports the machine's core count; the pool size itself is --n_cores.
    print('Number of Cores: %d' % cpu_count())

    # Make sure the output location exists *before* the sweep starts.
    # Otherwise a missing 'results/' directory is only discovered by to_csv
    # after all trials have finished, and every computed result is lost.
    output_path = 'results/%s' % args.save_file_name
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # Parameter grids swept for every (seed, dataset) pair.
    dataset_classes = [Recidivism, FICO, Adults, Diabetes, Schizo, Readmission]
    n_values = [3, 5]
    max_depths = [2, 3, 4]
    use_quantile_options = [True]
    target_supports = [0.1, 0.2]

    parameters = []

    # Each run index doubles as the random seed of the dataset instances.
    for seed in tqdm.trange(args.n_runs):
        for ds in dataset_classes:
            # Built once per (seed, dataset) and shared by all configurations.
            dataset = ds(random_seed=seed)
            for N in n_values:
                for max_depth in max_depths:
                    for use_quantile in use_quantile_options:
                        rulelist_params = RuleListMinerParams(
                            N=N,
                            max_depth=max_depth,
                            num_iter=500,
                            tolerance=0.1,
                            use_quantile=use_quantile,
                        )

                        for target_support in target_supports:
                            parameters.append((dataset, rulelist_params, target_support))

    with Pool(args.n_cores) as p:
        start_time = time.perf_counter()

        # imap yields results in submission order, so the rows of the CSV
        # follow the order of `parameters`; tqdm shows progress as they arrive.
        results = list(tqdm.tqdm(p.imap(run_trial, parameters), total=len(parameters)))
        results = pd.DataFrame(results)
        results.to_csv(output_path, index=False)

        finish_time = time.perf_counter()
        print(f"Program finished in {(finish_time-start_time) / 60} minutes")