import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tqdm
import argparse
import time
import os
import copy

from datasets import Recidivism, FICO, Dataset, Schizo, Adults, Diabetes, Readmission
from sa import BaseAlgorithm, AlgorithmParams
from rules import TwoWeightKnapsackRule, IntegerKnapsackRule, ORRule, Rule
from neighbors import swap_high_rule, swap_low_rule, move_low_to_high, move_high_to_low
from complement import Complement, ComplementParams
from consistency import Consistency, ConsistencySoft, CoverageConsistencyParams
from benchmark import BenchmarkRuleMiner
from multiprocessing import Pool, cpu_count
from dataclasses import dataclass, replace

# Switch the working directory at import time (this statement sits at module
# top level). Relative paths used later in this script -- e.g. the
# 'results/<save_file_name>' CSV written by the main section -- resolve
# against this directory.
# NOTE(review): the path is an absolute, machine-specific location; the script
# raises here on any host where it does not exist.
os.chdir('/home/evanyao/paper')

def _score_row(label, metric, p, evaluator, rule):
    """Build one result record for ``rule`` as scored by ``evaluator``.

    ``evaluator`` must expose ``evaluate_rule_train`` and ``score_rule``;
    their return values populate the 'train' and 'test' fields respectively.
    """
    return {
        'name': label,
        'metric': metric,
        'p': p,
        'train': evaluator.evaluate_rule_train(rule),
        'test': evaluator.score_rule(rule),
        'rule': str(rule),
    }


def run_trial(inpt: (Dataset, CoverageConsistencyParams, BenchmarkRuleMiner)) -> pd.DataFrame:
    """Run one trial and return a long-format table of rule scores.

    ``inpt`` is a ``(dataset, param, brm)`` triple, packed into one argument
    so the function can be mapped over a list of work items.

    Steps:
      1. Run ``ConsistencySoft`` on the dataset and keep both the rule it
         returns ('our') and its ``starting_rule`` ('start').
      2. Ask the benchmark miner for its top rules (``num=5``), ranked with
         the soft algorithm's train/score callables ('bench').
      3. Score every candidate rule under three metrics: soft consistency,
         hard consistency, and complement -- the latter once per ``p`` in
         ``np.arange(param.c, 0.4, 0.025)``.

    Returns:
        A DataFrame with one row per (rule, metric, p) carrying the columns
        name, metric, p, train, test, rule, plus the trial-level constants
        dataset, c, N, support and zmax.
    """
    dataset, param, brm = inpt[0], inpt[1], inpt[2]

    soft_alg = ConsistencySoft(dataset, param, False)
    found_rule = soft_alg.run()
    initial_rule = soft_alg.starting_rule

    top_benchmarks = brm.get_top_rule(
        soft_alg.evaluate_rule_train,
        soft_alg.score_rule,
        num=5,
    )

    # Evaluation-only instance: it is never run, only used to score rules.
    hard_alg = Consistency(dataset, param, skeleton=True)

    candidates = [('our', found_rule), ('start', initial_rule)]
    candidates.extend(('bench', bench_rule) for bench_rule in top_benchmarks)

    rows = []
    for label, rule in candidates:
        # For the consistency metrics the 'p' column simply carries param.c.
        rows.append(_score_row(label, 'consistency_soft', param.c, soft_alg, rule))
        rows.append(_score_row(label, 'consistency_hard', param.c, hard_alg, rule))

        # NOTE(review): np.arange with a float step may or may not include a
        # value near the stop bound because of rounding; kept unchanged so the
        # sweep stays identical.
        for p in np.arange(param.c, 0.4, 0.025):
            comp_alg = Complement(
                dataset,
                ComplementParams(
                    num_iter=param.num_iter,
                    N=param.N,
                    c=param.c,
                    p=p,
                    allow_high_low_switch=False,
                    should_validate=True,
                ),
                skeleton=True,
            )
            rows.append(_score_row(label, 'complement', p, comp_alg, rule))

    frame = pd.DataFrame(rows)

    # Broadcast the trial-level settings onto every row.
    trial_constants = {
        'dataset': dataset.__class__.__name__,
        'c': param.c,
        'N': param.N,
        'support': brm.support,
        'zmax': brm.zmax,
    }
    for column, value in trial_constants.items():
        frame[column] = value

    return frame
    

def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of the experiment sweep."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_cores', type=int, default=cpu_count(), help='Number of cores to use')
    parser.add_argument('--n_runs', type=int, default=1, help='Number of splits')
    parser.add_argument('--n_sa_runs', type=int, default=1, help='Number of SA runs')
    parser.add_argument('--save_file_name', type=str, help='Name of the file to save (need to include .csv)', required=True)
    parser.add_argument('--Q', type=int, default=5, help='Slices for each numerical column')

    ## Benchmark
    parser.add_argument('--n_benchmark_rules', type=int, default=2000, help='Number of benchmark rules to mine')
    # parser.add_argument('--n_benchmark_support', type=int, default=3, help='Number of rules in each rule set')
    # NOTE(review): --association_rule_size is parsed but not referenced
    # anywhere else in this script.
    parser.add_argument('--association_rule_size', type=int, default=2, help='Number of conditions in association rule')

    ## Checklist
    # parser.add_argument('--N_low', type=int, default=5, help='Number of rules in the checklist')
    # parser.add_argument('--N_high', type=int, default=5, help='Number of rules in the checklist')

    return parser.parse_args()


def _build_parameters(args: argparse.Namespace) -> list:
    """Build the list of ``(dataset, param, brm)`` work items for ``run_trial``.

    The grid is the product of: seeds ``0..n_runs-1`` x enabled datasets x
    ``zmax`` in {2, 3} x ``support`` in {5} x ``n_sa_runs`` repetitions x
    ``c`` in ``np.arange(0.025, 0.4, 0.025)`` x ``N`` in {5, 7}.
    """
    parameters = []

    for seed in tqdm.trange(args.n_runs):
        # Re-enable datasets here by uncommenting them; the seed is passed to
        # each dataset as ``random_seed``.
        datasets = [
            # Recidivism(random_seed=seed, num_features_universe=40, Q=args.Q),
            # FICO(random_seed=seed, num_features_universe=40, Q=args.Q),
            # # Adults(random_seed=seed, num_features_universe=40, Q=args.Q),
            # Diabetes(random_seed=seed, num_features_universe=40, Q=args.Q),
            # Schizo(random_seed=seed, num_features_universe=40, Q=args.Q),
            Readmission(random_seed=seed, num_features_universe=40, Q=args.Q),
        ]

        for d in datasets:
            for zmax in [2, 3]:
                # Pareto rules are mined once per (dataset, zmax); each support
                # level then works on its own deep copy of the miner.
                brm_master = BenchmarkRuleMiner(d)
                brm_master.get_pareto_rules(zmin=2, zmax=zmax)
                for support in [5]:
                    brm = copy.deepcopy(brm_master)
                    brm.get_or_rules(num=args.n_benchmark_rules, support=support)

                    for _ in range(args.n_sa_runs):
                        # NOTE(review): np.arange with a float step may or may
                        # not include a value near the stop bound because of
                        # rounding; kept unchanged to preserve the sweep.
                        for c in np.arange(0.025, 0.4, 0.025):
                            for N in [5, 7]:
                                param = CoverageConsistencyParams(
                                    num_iter=1000,
                                    N=N,
                                    c=c,
                                    allow_high_low_switch=False,
                                    should_validate=True,
                                )

                                parameters.append((d, param, brm))

    return parameters


def main() -> None:
    """Run the full sweep in a process pool and save the results as CSV.

    The CSV is written to ``results/<save_file_name>`` relative to the current
    working directory.
    """
    args = _parse_args()

    print('Number of Cores: %d' % cpu_count())

    # Make sure the output directory exists *before* doing any expensive work:
    # DataFrame.to_csv does not create missing directories, so otherwise the
    # whole sweep could be computed and then lost on the final write.
    output_path = 'results/%s' % args.save_file_name
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    parameters = _build_parameters(args)

    with Pool(args.n_cores) as p:
        start_time = time.perf_counter()

        # imap preserves input order, so rows appear in grid order.
        results = list(tqdm.tqdm(p.imap(run_trial, parameters), total=len(parameters)))
        results = pd.concat(results, axis=0).reset_index(drop=True)
        results.to_csv(output_path, index=False)

        finish_time = time.perf_counter()
        print(f"Program finished in {(finish_time-start_time) / 60} minutes")


if __name__ == '__main__':
    main()