import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tqdm
import argparse
import time
import os

from datasets import Recidivism, FICO, Dataset, Schizo, Adults, Diabetes
from sa import BaseAlgorithm, AlgorithmParams
from rules import TwoWeightKnapsackRule, IntegerKnapsackRule, ORRule, Rule
from neighbors import swap_high_rule, swap_low_rule, move_low_to_high, move_high_to_low
from complement import Complement, ComplementParams
from benchmark import BenchmarkRuleMiner
from multiprocessing import Pool, cpu_count
from dataclasses import dataclass, replace

# Run relative to the project root so the relative 'results/...' output path
# used below resolves there. Overridable via PAPER_DIR so the script is not
# tied to one machine; the default is the original hard-coded location.
os.chdir(os.environ.get('PAPER_DIR', '/home/evanyao/paper'))

def run_trial(inpt: tuple[Dataset, list[Rule], ComplementParams]) -> pd.DataFrame:
    """Run one Complement search and compare it against benchmark rules.

    Args:
        inpt: ``(dataset, benchmark_rule_list, param)`` packed into a single
            tuple because ``Pool.imap`` forwards exactly one argument per task.

    Returns:
        A one-row DataFrame with the dataset class name, the search
        parameters (c, p, N), the train/test scores of the found rule and of
        the algorithm's starting rule, and the averaged benchmark scores
        returned by ``score_benchmark``.
    """
    dataset, benchmark_rule_list, param = inpt

    # NOTE(review): the third positional argument (False) is defined by
    # Complement in complement.py; its meaning is not visible here.
    alg = Complement(dataset, param, False)
    rule = alg.run()

    bench_train, bench_test = score_benchmark(alg, benchmark_rule_list)

    return pd.DataFrame([{
        'dataset': type(dataset).__name__,
        'criteria': 'Complement',
        'c': param.c,
        'p': param.p,
        'N': param.N,
        'train_score': alg.evaluate_rule_train(rule),
        'test_score': alg.score_rule(rule),
        'train_start': alg.evaluate_rule_train(alg.starting_rule),
        'test_start': alg.score_rule(alg.starting_rule),
        'bench_train': bench_train,
        'bench_test': bench_test,
    }])

def score_benchmark(alg: BaseAlgorithm, rule_list: list[Rule], top_k: int = 3) -> tuple[float, float]:
    """Score benchmark rules with ``alg`` and average the best ``top_k``.

    Every rule is scored with ``alg.evaluate_rule_train`` and
    ``alg.score_rule``. Rules are then ranked by the train score (descending)
    and both scores are averaged over the ``top_k`` highest-ranked rules.

    Args:
        alg: Algorithm instance providing ``evaluate_rule_train`` and
            ``score_rule``.
        rule_list: Benchmark rules. Any sized iterable works; the caller
            passes the ``.rule`` attribute of ``get_or_rules`` output
            (presumably a pandas Series — verify), so emptiness is checked
            with ``len`` rather than truthiness.
        top_k: How many top-ranked rules to average (default 3).

    Returns:
        ``(mean train score, mean test score)`` over the top ``top_k`` rules,
        or ``(nan, nan)`` when ``rule_list`` is empty.
    """
    # An empty list previously produced a column-less DataFrame and a
    # KeyError in sort_values, aborting the whole pool run.
    if len(rule_list) == 0:
        return float('nan'), float('nan')

    scores = pd.DataFrame(
        [(alg.evaluate_rule_train(r), alg.score_rule(r)) for r in rule_list],
        columns=['train', 'test'],
    )
    # Stable sort: rules with equal train scores keep their mined order
    # instead of quicksort's arbitrary tie-breaking.
    top = scores.sort_values(by='train', ascending=False, kind='mergesort').head(top_k)
    return top['train'].mean(), top['test'].mean()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_cores', type=int, default=cpu_count(), help='Number of cores to use')
    parser.add_argument('--n_runs', type=int, default=1, help='Number of splits')
    parser.add_argument('--n_sa_runs', type=int, default=1, help='Number of SA runs')
    parser.add_argument('--save_file_name', type=str, default='results.csv', help='Name of the file to save (need to include .csv)')
    parser.add_argument('--n_benchmark_rules', type=int, default=500, help='Number of benchmark rules to mine')
    parser.add_argument('--rule_size', type=int, default=6, help='Number of rules in the checklist')
    parser.add_argument('--Q', type=int, default=5, help='Slices for each numerical column')

    args = parser.parse_args()

    # Create the output location before the (long) computation so a missing
    # directory fails immediately rather than after every trial has run.
    out_dir = os.path.join('results', 'goal2')
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, args.save_file_name)

    # Grid for p and c, built from integer indices. np.arange with a float
    # step is rounding-fragile: arange(0.05, p + 0.05, 0.05) can emit a c
    # slightly larger than p. Values: p = 0.05, ..., 0.40; c = 0.05, ..., p.
    grid_step = 0.05
    n_p_steps = 8

    # Accumulate trials across ALL seeds. Initialising this list inside the
    # seed loop discarded every split except the last one.
    parameters = []
    trial_seeds = []  # seed of each entry in `parameters`, same order

    for seed in tqdm.trange(args.n_runs):
        datasets = [
            Recidivism(random_seed=seed, num_features_universe=40, Q=args.Q),
            FICO(random_seed=seed, num_features_universe=40, Q=args.Q),
            Adults(random_seed=seed, num_features_universe=40, Q=args.Q),
            Diabetes(random_seed=seed, num_features_universe=40, Q=args.Q),
            Schizo(random_seed=seed, num_features_universe=40, Q=args.Q),
        ]

        for d in datasets:
            brm = BenchmarkRuleMiner(d)
            brm.get_pareto_rules()
            benchmark_rules = brm.get_or_rules(num=args.n_benchmark_rules).rule

            for _ in range(args.n_sa_runs):
                for p_idx in range(1, n_p_steps + 1):
                    p = round(p_idx * grid_step, 10)
                    for c_idx in range(1, p_idx + 1):
                        c = round(c_idx * grid_step, 10)
                        param = ComplementParams(
                            num_iter=1000,
                            N=args.rule_size,
                            c=c,
                            p=p,
                            allow_high_low_switch=False,
                            should_validate=True,
                        )

                        parameters.append((d, benchmark_rules, param))
                        trial_seeds.append(seed)

    with Pool(args.n_cores) as pool:
        start_time = time.perf_counter()

        results = list(tqdm.tqdm(pool.imap(run_trial, parameters), total=len(parameters)))
        results = pd.concat(results, axis=0).reset_index(drop=True)
        # imap preserves submission order, so rows line up with trial_seeds.
        results['seed'] = trial_seeds
        results.to_csv(out_path, index=False)

        finish_time = time.perf_counter()
        print(f"Program finished in {(finish_time-start_time) / 60} minutes")