import numpy as np
import pandas as pd

import argparse

from experiment_utils import graph_together, graph_separate

def main(
    data_dirs,
    data_names,
    output_dir,
    y_label,
):
    """Load each run's results and pass them to the shared graphing helpers.

    For every directory in ``data_dirs`` this reads ``params.csv`` and
    ``scores.csv`` with pandas, then makes three calls into
    ``experiment_utils``:

    * ``graph_together`` on the params frames (``y_name='Parameter'``),
    * ``graph_together`` on the score frames (``y_name='Train'``),
    * ``graph_separate`` on the score frames (``y_name='Train-Val'``).

    Args:
        data_dirs: Directories that each contain ``params.csv`` and
            ``scores.csv``.
        data_names: Legend label for each entry of ``data_dirs``, in the same
            order.
        output_dir: Directory under which the output paths
            ``params.csv``, ``raw_scores.csv`` and ``score_diffs.csv`` are
            built for the graphing helpers.
        y_label: Label forwarded as ``y_label`` for the parameter graph.

    Raises:
        ValueError: If no data directories are given, or if the number of
            names differs from the number of directories (each label is meant
            to name the run at the same position).
    """
    # Materialise both inputs: ``data_names`` is handed to three separate
    # helper calls, so a one-shot iterable would be empty after the first one,
    # and ``len`` below needs real sequences.
    data_dirs = list(data_dirs)
    data_names = list(data_names)

    if not data_dirs:
        raise ValueError('at least one data directory is required')
    if len(data_dirs) != len(data_names):
        raise ValueError(
            f'got {len(data_dirs)} data dir(s) but {len(data_names)} '
            'data name(s); provide exactly one name per directory'
        )

    params = []
    scores = []
    for data_dir in data_dirs:
        # Read params then scores per directory (same order as before), so a
        # missing file is reported for the first directory that lacks it.
        params.append(pd.read_csv(f'{data_dir}/params.csv'))
        scores.append(pd.read_csv(f'{data_dir}/scores.csv'))

    # y_name values are presumably column names in the loaded CSVs — confirm
    # against experiment_utils.
    graph_together(
        path=f'{output_dir}/params.csv',
        data_list=params,
        y_name='Parameter',
        y_label=y_label,
        legend_labels=data_names,
    )

    graph_together(
        path=f'{output_dir}/raw_scores.csv',
        data_list=scores,
        y_name='Train',
        y_label='Train Accuracies',
        legend_labels=data_names,
    )

    # Raw string: the label contains ``$-$``, which looks like mathtext markup
    # for the renderer (assumed matplotlib — confirm in experiment_utils).
    graph_separate(
        path=f'{output_dir}/score_diffs.csv',
        data_list=scores,
        y_name='Train-Val',
        y_label=r'(Train $-$ Val) Accuracy',
        legend_labels=data_names,
    )

if __name__ == '__main__':
    # NOTE(review): this script only graphs results already on disk; the
    # description below may have been copied from the experiment runner —
    # confirm wording.
    parser = argparse.ArgumentParser(description='Runs the experiment described for Figure 1')
    parser.add_argument(
        '-d',
        '--data_dirs',
        nargs='+',
        default=['./data/cart/compas', './data/cart/fico', './data/cart/telco', './data/cart/monks1'],
        help='directories that each contain params.csv and scores.csv',
    )
    parser.add_argument(
        '-n',
        '--data_names',
        nargs='+',
        default=['Compas', 'FICO', 'Telco Bin', 'Monks 1'],
        help='legend label for each data directory, in the same order',
    )
    parser.add_argument(
        '-o',
        '--output_dir',
        default='./data/cart',
        help='base directory for the output paths (default: %(default)s)',
    )
    parser.add_argument(
        '-l',
        '--y_label',
        default='Best Depth',
        help='label for the parameter graph (default: %(default)s)',
    )

    args = parser.parse_args()

    # Each name labels the directory at the same position. Passing only -d
    # while keeping the four default names is the easy way to break this, so
    # report it as a usage error instead of producing mislabelled graphs.
    if len(args.data_dirs) != len(args.data_names):
        parser.error(
            f'got {len(args.data_dirs)} data dir(s) but {len(args.data_names)} '
            'data name(s); pass one --data_names entry per --data_dirs entry'
        )

    main(
        data_dirs=args.data_dirs,
        data_names=args.data_names,
        output_dir=args.output_dir,
        y_label=args.y_label,
    )