import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tqdm
import argparse
import time
import os
import copy

from datasets import Recidivism, FICO, Dataset, Schizo, Adults, Diabetes, Readmission
from multiprocessing import Pool, cpu_count
from dataclasses import dataclass, replace
from rulelist import RuleListMinerParams, RuleListMiner
from checklist import ChecklistMiner, ChecklistMinerParams

from scipy.stats import norm

def run_rule_list_trial(inpt):
    """Mine one rule list for a single experiment configuration and score it.

    The three inputs arrive packed in one positional argument so this function
    can be handed directly to ``Pool.imap``.

    Args:
        inpt: Indexable whose first three items are
            ``(dataset, rulelist_params, target_support)``.

    Returns:
        dict: One result row with keys ``dataset``, ``N``, ``target_support``,
        ``max_depth``, ``abbr`` (test-set evaluation of the mined rule),
        ``abbr_start`` (test-set evaluation of ``miner.starting_rule``) and
        ``use_quantile``. The values of ``abbr``/``abbr_start`` are whatever
        ``RuleListMiner.evaluate_rule_test`` returns.
    """
    dataset: Dataset
    params: RuleListMinerParams
    dataset, params, support = inpt[0], inpt[1], inpt[2]

    # NOTE(review): the positional False is passed straight to RuleListMiner;
    # presumably a verbosity flag -- confirm against its signature.
    miner = RuleListMiner(dataset, params, False)
    mined = miner.run(target_support=support)

    # Score the mined rule first, then the miner's starting rule, matching the
    # order in which the result row is assembled.
    row = dict(
        dataset=dataset.__class__.__name__,
        N=params.N,
        target_support=support,
        max_depth=params.max_depth,
    )
    row['abbr'] = miner.evaluate_rule_test(mined)
    row['abbr_start'] = miner.evaluate_rule_test(miner.starting_rule)
    row['use_quantile'] = params.use_quantile
    return row

def run_checklist_trial(inpt):
    """Mine one checklist for a single experiment configuration and score it.

    Packed into a single argument so it can be mapped with ``Pool.imap``.

    Args:
        inpt: Indexable whose first three items are
            ``(dataset, checklist_params, target_support)``.

    Returns:
        dict: One result row with keys ``dataset``, ``N``, ``target_support``,
        ``abbr`` (test-set evaluation of the mined rule) and ``abbr_start``
        (test-set evaluation of ``miner.starting_rule``).
    """
    dataset: Dataset
    params: ChecklistMinerParams
    dataset, params, support = inpt[0], inpt[1], inpt[2]

    miner = ChecklistMiner(dataset, params, False)
    mined = miner.run(target_support=support)

    # NOTE(review): the 'N' column is filled from `params.Q`, while the
    # (commented-out) driver builds ChecklistMinerParams(N=...). Confirm which
    # attribute actually holds the checklist size before re-enabling it.
    row = dict(
        dataset=dataset.__class__.__name__,
        N=params.Q,
        target_support=support,
    )
    row['abbr'] = miner.evaluate_rule_test(mined)
    row['abbr_start'] = miner.evaluate_rule_test(miner.starting_rule)
    return row

def prefix_file_name(prefix, path):
    """Return ``path`` with ``prefix`` prepended to its final component only.

    Plain concatenation (``prefix + path``) breaks once ``path`` contains a
    directory: ``results/run.csv`` would become ``rule_lists_results/run.csv``.
    Splitting first keeps the directory and renames just the file
    (``results/rule_lists_run.csv``). For a bare file name the result is the
    same as plain concatenation.

    Args:
        prefix: Text to put in front of the file name.
        path: Output path, with or without a directory part.

    Returns:
        str: The prefixed path.
    """
    directory, file_name = os.path.split(path)
    return os.path.join(directory, prefix + file_name)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_cores', type=int, default=cpu_count(), help='Number of cores to use')
    parser.add_argument('--n_runs', type=int, default=1, help='Number of splits')
    parser.add_argument('--save_file_name', type=str, help='Name of the file to save (need to include .csv)', required=True)

    args = parser.parse_args()
    # Fail with a usage message rather than a Pool traceback.
    if args.n_cores < 1:
        parser.error('--n_cores must be at least 1')
    # Show both what the machine has and what this run will actually use;
    # --n_cores may differ from cpu_count().
    print('Number of Cores: %d (using %d)' % (cpu_count(), args.n_cores))

    # One task per (seed, dataset, N, max_depth, target_support). Each dataset
    # object is constructed once per seed and reused for every config on it.
    parameters = []

    for seed in tqdm.trange(args.n_runs):
        for ds in [Recidivism, FICO, Adults, Diabetes, Schizo, Readmission]:
            dataset = ds(random_seed=seed)
            for N in [3, 5]:
                for max_depth in [2, 3, 4]:
                    rulelist_params = RuleListMinerParams(
                        N=N,
                        max_depth=max_depth,
                        num_iter=500,
                        tolerance=0.1,
                        use_quantile=True,
                    )

                    for target_support in [0.1, 0.2]:
                        parameters.append((dataset, rulelist_params, target_support))

    with Pool(args.n_cores) as p:
        start_time = time.perf_counter()

        # imap yields results in task order, so rows line up with `parameters`.
        results = list(tqdm.tqdm(p.imap(run_rule_list_trial, parameters), total=len(parameters)))
        results = pd.DataFrame(results)
        results.to_csv(prefix_file_name('rule_lists_', args.save_file_name), index=False)

        finish_time = time.perf_counter()
        print(f"Rule Lists finished in {(finish_time-start_time) / 60} minutes")

    # Checklist experiments (currently disabled).
    # NOTE(review): before re-enabling, confirm ChecklistMinerParams takes `N`:
    # run_checklist_trial reads `checklist_params.Q`, so one side looks stale.
    # parameters = []  # reset, otherwise the rule-list tasks are re-submitted
    # for seed in tqdm.trange(args.n_runs):
    #     for ds in [Recidivism, FICO, Adults, Diabetes, Schizo, Readmission]:
    #         dataset = ds(random_seed=seed)
    #         for N in [5, 7]:
    #             checklist_params = ChecklistMinerParams(
    #                 N=N,
    #                 num_iter=500,
    #             )
    #
    #             for target_support in [0.1, 0.2]:
    #                 parameters.append((dataset, checklist_params, target_support))
    #
    # with Pool(args.n_cores) as p:
    #     start_time = time.perf_counter()
    #
    #     results = list(tqdm.tqdm(p.imap(run_checklist_trial, parameters), total=len(parameters)))
    #     results = pd.DataFrame(results)
    #     results.to_csv(prefix_file_name('checklists_', args.save_file_name), index=False)
    #
    #     finish_time = time.perf_counter()
    #     print(f"Checklists finished in {(finish_time-start_time) / 60} minutes")