import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tqdm
import argparse
import time
import os
from datasets import Recidivism, FICO, Dataset, Schizo, Adults, Diabetes
from sa import BaseAlgorithm, AlgorithmParams
from rules import TwoWeightKnapsackRule, IntegerKnapsackRule, ORRule, Rule
from neighbors import swap_high_rule, swap_low_rule, move_low_to_high, move_high_to_low
from complement import Complement, ComplementParams
from consistency import Coverage, Consistency, ConsistencySoft, CoverageConsistencyParams
from benchmark import BenchmarkRuleMiner
from multiprocessing import Pool, cpu_count

# NOTE(review): hard-coded, machine-specific working directory. Every relative
# path used later in this script (e.g. the 'results/goal1/...' output file) is
# resolved against it. Because this statement sits at module level, it also
# executes whenever the module is imported, not only when run as a script.
os.chdir('/home/evanyao/paper')

def run_trial(inpt: tuple[Dataset, list[Rule], CoverageConsistencyParams]) -> pd.DataFrame:
    """Run every evaluation criterion once for a single work item.

    The three inputs are packed into one tuple so that this function can be
    used as the single-argument callable handed to ``Pool.imap`` by the
    script's entry point.

    Args:
        inpt: ``(dataset, benchmark_rule_list, param)`` where ``dataset`` is
            the data to learn on, ``benchmark_rule_list`` holds the benchmark
            rules to compare against, and ``param`` configures the algorithm
            (its ``c`` and ``N`` attributes are copied into the output).

    Returns:
        A DataFrame with one row per criterion (``Coverage``, ``Consistency``,
        ``ConsistencySoft``, in that order) and the columns ``dataset``,
        ``criteria``, ``c``, ``N``, ``train_score``, ``test_score``,
        ``train_start``, ``test_start``, ``bench_train`` and ``bench_test``.
    """
    dataset, benchmark_rule_list, param = inpt

    rows = []

    for evaluation_method in (Coverage, Consistency, ConsistencySoft):
        # The third positional argument is passed as False exactly as before;
        # its meaning is defined by the algorithm classes, not in this file.
        alg = evaluation_method(dataset, param, False)
        rule = alg.run()

        # Benchmarks are scored before the learned rule, preserving the
        # original call order on ``alg``.
        bench_train, bench_test = score_benchmark(alg, benchmark_rule_list)

        rows.append({
            'dataset': type(dataset).__name__,
            'criteria': evaluation_method.__name__,
            'c': param.c,
            'N': param.N,
            # Scores of the rule returned by alg.run() ...
            'train_score': alg.evaluate_rule_train(rule),
            'test_score': alg.score_rule(rule),
            # ... and of the rule the algorithm started from.
            'train_start': alg.evaluate_rule_train(alg.starting_rule),
            'test_start': alg.score_rule(alg.starting_rule),
            'bench_train': bench_train,
            'bench_test': bench_test,
        })

    return pd.DataFrame(rows)

def score_benchmark(alg: BaseAlgorithm, rule_list: list[Rule], top_k: int = 3) -> tuple[float, float]:
    """Average the scores of the benchmark rules that score best on train.

    Each rule is scored with ``alg.evaluate_rule_train`` and
    ``alg.score_rule``. The rules are ranked by the former (descending) and
    the means of both scores over the ``top_k`` best-ranked rules are
    returned. If fewer than ``top_k`` rules are supplied, all of them are
    averaged.

    Args:
        alg: Object providing ``evaluate_rule_train(rule)`` and
            ``score_rule(rule)``.
        rule_list: Benchmark rules to score.
        top_k: How many of the best-on-train rules to average. Defaults to 3,
            the value that was previously hard-coded.

    Returns:
        ``(mean_train, mean_test)`` over the selected rules, or
        ``(nan, nan)`` when ``rule_list`` is empty.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    # Train score first, then test score, for each rule in turn.
    scored = [(alg.evaluate_rule_train(r), alg.score_rule(r)) for r in rule_list]

    if not scored:
        # An empty frame would have no 'train' column and sort_values would
        # fail with an opaque KeyError; report "no benchmark" as NaN instead.
        return float('nan'), float('nan')

    ranked = pd.DataFrame(scored, columns=['train', 'test']).sort_values(by='train', ascending=False)
    best = ranked.head(top_k)

    return best['train'].mean(), best['test'].mean()


if __name__ == '__main__':
    # Script entry point: build one work item per (seed, dataset, c, SA run),
    # execute them all with run_trial on a process pool, and write the
    # concatenated results to results/goal1/<save_file_name>.
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_cores', type=int, default=cpu_count(), help='Number of cores to use')
    parser.add_argument('--n_runs', type=int, default=1, help='Number of splits')
    parser.add_argument('--n_sa_runs', type=int, default=1, help='Number of SA runs')
    parser.add_argument('--save_file_name', type=str, default='results.csv', help='Name of the file to save (need to include .csv)')
    parser.add_argument('--n_benchmark_rules', type=int, default=500, help='Number of benchmark rules to mine')
    parser.add_argument('--rule_size', type=int, default=6, help='Number of rules in the checklist')
    parser.add_argument('--Q', type=int, default=5, help='Slices for each numerical column')

    print('Number of Cores: %d' % cpu_count())

    args = parser.parse_args()

    # The work list must be created once, before the seed loop. The pool below
    # only runs after the loop has finished, so re-creating the list for every
    # seed would leave just the final seed's work items to execute.
    parameters = []

    for seed in tqdm.trange(args.n_runs):
        # A fresh set of datasets per seed; the seed is forwarded to each
        # dataset as random_seed.
        datasets = [
            Recidivism(random_seed=seed, num_features_universe=40, Q=args.Q),
            FICO(random_seed=seed, num_features_universe=40, Q=args.Q),
            Adults(random_seed=seed, num_features_universe=40, Q=args.Q),
            Diabetes(random_seed=seed, num_features_universe=40, Q=args.Q),
            Schizo(random_seed=seed, num_features_universe=40, Q=args.Q),
        ]

        for d in datasets:
            # Benchmark rules are mined once per dataset and shared by all of
            # that dataset's work items.
            brm = BenchmarkRuleMiner(d)
            brm.get_pareto_rules()
            benchmark_rules = brm.get_or_rules(num=args.n_benchmark_rules).rule

            # Sweep c over 0.05, 0.10, ..., 0.40 (the stop value is excluded).
            for c in np.arange(0.05, 0.45, 0.05):
                param = CoverageConsistencyParams(
                    num_iter=1000,
                    N=args.rule_size,
                    c=c,
                    allow_high_low_switch=False,
                    should_validate=True,
                )

                # Identical work items repeated n_sa_runs times.
                for _ in range(args.n_sa_runs):
                    parameters.append((d, benchmark_rules, param))

    # Make sure the output directory exists before the run starts, so a
    # missing directory cannot fail the job only at the final to_csv call.
    os.makedirs('results/goal1', exist_ok=True)

    with Pool(args.n_cores) as p:
        start_time = time.perf_counter()

        # imap yields results in submission order, which lets tqdm show progress.
        results = list(tqdm.tqdm(p.imap(run_trial, parameters), total=len(parameters)))
        results = pd.concat(results, axis=0).reset_index(drop=True)
        results.to_csv('results/goal1/%s' % args.save_file_name, index=False)

        finish_time = time.perf_counter()
        print(f"Program finished in {(finish_time-start_time) / 60} minutes")