import pdb
import pickle
import pandas as pd
import argparse
from tree import *
from weka_parser import parse_tree
from solver import *


def _parse_attribute_name(directive):
    """Return the attribute name from a stripped ARFF ``@attribute`` line.

    ARFF allows a name containing spaces to be wrapped in single or double
    quotes; the surrounding quotes are removed.  Unquoted names end at the
    first whitespace.
    NOTE(review): presumably this matches how the WEKA tree text names
    attributes — verify against weka_parser.

    Raises:
        ValueError: if nothing follows the ``@attribute`` keyword.
    """
    parts = directive.split(None, 1)
    if len(parts) < 2:
        raise ValueError(f"Malformed @attribute line: {directive!r}")
    rest = parts[1]
    if rest[0] in ("'", '"'):
        closing = rest.find(rest[0], 1)
        if closing != -1:
            return rest[1:closing]
    return rest.split(None, 1)[0]


def read_data(filename):
    """Load an ARFF-style data file into a pandas DataFrame.

    Column names come from the ``@attribute`` declarations in the header.
    Every line after the ``@data`` marker is parsed as comma-separated values.
    Directive keywords are matched case-insensitively, and quoted attribute
    names are unquoted.  In the data section, lines starting with ``%`` (ARFF
    comments) are ignored.

    If the file has no ``@data`` marker, the whole file is parsed as data,
    with every ``@``-prefixed line treated as a comment.

    Args:
        filename: path to the data file (decoded as UTF-8).

    Returns:
        DataFrame whose column names are the declared attribute names, in
        declaration order.
    """
    attribute_names = []        # Later called dimensions
    with open(filename, 'r', encoding='utf-8') as file:
        found_data = False
        # Iterate with readline() so the handle sits exactly after the last
        # consumed line; pandas then continues reading from that position.
        for line in iter(file.readline, ''):
            stripped = line.strip()
            keyword = stripped.split(None, 1)[0].lower() if stripped else ''
            if keyword == '@attribute':
                attribute_names.append(_parse_attribute_name(stripped))
            elif keyword == '@data':
                found_data = True
                break

        if found_data:
            # Only the data section remains to be read.  Only '%' is a comment
            # marker here, so values that contain '@' are read intact.
            # NOTE(review): ARFF marks missing values with '?'; they are kept
            # as the literal string '?' — pass na_values='?' if the solver
            # should see NaN instead.
            df = pd.read_csv(file, header=None, names=attribute_names, comment='%')
        else:
            # No @data marker: parse the whole file, skipping '@' directive lines.
            file.seek(0)
            df = pd.read_csv(file, header=None, names=attribute_names, comment='@')

    return df


def main():
    """Command-line entry point: local search on a WEKA tree, then report errors.

    Reads the data file given by ``--data`` via ``read_data`` and the tree
    given by ``--tree`` via ``parse_tree``, then prints the parsed tree.  It
    runs ``local_search_tree`` with the four per-operation budgets and prints
    the resulting ``min_errors`` mapping.  The mapping is keyed by
    ``(k_exch, k_adj, k_rais, k_repl)`` tuples.  If ``--min-errors-output`` is
    given, the mapping is also written as a CSV with header
    ``k_exch,k_adj,k_rais,k_repl,min_errors``.
    """
    # Set up argument parser
    parser = argparse.ArgumentParser(description="Perform local search to improve a decision tree.")
    # --tree and --data are required: without them read_data() / parse_tree()
    # would receive None and fail with an unhelpful traceback instead of an
    # argparse usage error.
    parser.add_argument("--tree", required=True, help="Path to a file containing the WEKA tree string")
    parser.add_argument("--data", required=True, help="Path to a file containing the data")
    # NOTE(review): the four budgets below default to None when omitted;
    # confirm local_search_tree accepts None, or mark them required=True.
    parser.add_argument("--kexch", type=int, help="Upper bound on the number of cut exchange operations.")
    parser.add_argument("--kadj", type=int, help="Upper bound on the number of threshold adjustment operations.")
    parser.add_argument("--krais", type=int, help="Upper bound on the number of subtree raising operations.")
    parser.add_argument("--krepl", type=int, help="Upper bound on the number of subtree replacement operations.")
    parser.add_argument("--min-errors-output", help="Path to save the min_errors CSV file")
    # Disabled: pickling of the result trees (matching block at the end of main).
    # parser.add_argument("--trees-output", help="Path to save the min_error_result_trees pickle file")
    args = parser.parse_args()

    df = read_data(args.data)

    # Parse the tree and print it
    print("Initial tree:")
    tree = parse_tree(args.tree)
    # Subtree sizes are used in the dynamic-program recurrence in the
    # solver.
    tree.compute_subtree_sizes()
    tree.compute_node_depths()
    tree.print_tree()

    min_errors, min_error_result_trees = local_search_tree(tree, df, args.kexch, args.kadj, args.krais, args.krepl)

    print("Minimum number of errors for each budget tuple:")
    print(min_errors)

    if args.min_errors_output:
        with open(args.min_errors_output, 'w') as f:
            f.write("k_exch,k_adj,k_rais,k_repl,min_errors\n")
            for (k_exch, k_adj, k_rais, k_repl), errors in min_errors.items():
                f.write(f"{k_exch},{k_adj},{k_rais},{k_repl},{errors}\n")
        print(f"Min errors saved to: {args.min_errors_output}")
    else:
        print("No output file specified for min_errors. Skipping CSV write.")

    # Disabled: write min_error_result_trees to a pickle file if requested.
    # if args.trees_output:
    #     with open(args.trees_output, 'wb') as f:
    #         pickle.dump(min_error_result_trees, f)
    #     print(f"Result trees saved to: {args.trees_output}")
    # else:
    #     print("No output file specified for min_error_result_trees. Skipping pickle dump.")


if __name__ == "__main__":
    main()
