"""
Generate result from trained model.
"""
import pickle

import argparse
import os
from collections import defaultdict
import re
import json
import logging

from module.result import Result

# Model filenames are expected to look like "<model>_<dataset>_fold_<k>.pkl".
# NOTE(review): ``\w`` also matches "_", so when a model or dataset name itself
# contains an underscore the split is ambiguous; the regex engine gives the
# <model> group as much as it can, leaving only the last segment to <dataset>.
_MODEL_FILE_PATTERN = re.compile(r"([\w-]+)_([\w-]+)_fold_(\d+)\.pkl")


def _dump_pickle(obj, path):
    """Pickle *obj* to *path*, overwriting any existing file."""
    with open(path, "wb") as fh:
        pickle.dump(obj, fh)


def _build_file_structure(model_dir, result_dir, param_dir, datasets=None):
    """Index the ``.pkl`` files in *model_dir* by dataset, model and fold.

    Every ``.pkl`` file in *model_dir* must match ``_MODEL_FILE_PATTERN`` (this
    is checked even for files whose dataset is filtered out). Every indexed
    file must also have a same-named file in both *result_dir* and *param_dir*.

    Args:
        model_dir (str): Directory scanned for model pickles.
        result_dir (str): Directory expected to hold a result file per model.
        param_dir (str): Directory expected to hold a param file per model.
        datasets (set[str] | None): If given, only files whose dataset name is
            in this set are indexed; ``None`` indexes every file.

    Returns:
        defaultdict: Nested mapping ``fs[dataset][model][fold][kind] -> path``.
        ``fold`` is the fold number as a string, and ``kind`` is one of
        ``"model"``, ``"result"`` or ``"param"``.

    Raises:
        ValueError: A ``.pkl`` filename does not match the expected pattern.
        FileNotFoundError: The companion result or param file is missing.
    """
    file_structure = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(str))))

    # sorted() keeps the index (and the logged JSON) deterministic across runs.
    for filename in sorted(os.listdir(model_dir)):
        if not filename.endswith(".pkl"):
            continue

        match = _MODEL_FILE_PATTERN.fullmatch(filename)
        if match is None:
            raise ValueError(
                f"Filename {filename} does not match pattern {_MODEL_FILE_PATTERN.pattern}"
            )
        model, dataset, fold = match.groups()

        if datasets is not None and dataset not in datasets:
            continue

        model_path = os.path.join(model_dir, filename)
        result_path = os.path.join(result_dir, filename)
        param_path = os.path.join(param_dir, filename)
        if not os.path.exists(result_path):
            raise FileNotFoundError(f"Result file {result_path} does not exist")
        if not os.path.exists(param_path):
            raise FileNotFoundError(f"Param file {param_path} does not exist")

        entry = file_structure[dataset][model][fold]
        entry["model"] = model_path
        entry["result"] = result_path
        entry["param"] = param_path

    return file_structure


def main(args):
    """ Main function for generating result from trained model.

    Builds an index of the model/result/param pickles under ``args.output_dir``.
    It then hands that index to ``Result`` and produces the outputs requested
    on the command line.

    Args:
        args (argparse.Namespace): Arguments from command line. This function
            reads ``output_dir``, ``model_dir``, ``result_dir``, ``param_dir``,
            ``dataset_list`` (comma-separated dataset names, or ``None`` for
            all), ``gen_table`` and ``gen_density``. The whole namespace is
            also passed to ``Result``.

    Raises:
        ValueError: A model filename does not match the expected pattern.
        FileNotFoundError: A model has no matching result or param file.
    """
    model_dir = os.path.join(args.output_dir, args.model_dir)
    result_dir = os.path.join(args.output_dir, args.result_dir)
    param_dir = os.path.join(args.output_dir, args.param_dir)

    # Parse the dataset filter once instead of re-splitting it for every file.
    datasets = None
    if args.dataset_list is not None:
        datasets = {name.strip() for name in args.dataset_list.split(",")}

    file_structure = _build_file_structure(model_dir, result_dir, param_dir, datasets)

    logger = logging.getLogger("run_result.py main")
    # json.dumps runs eagerly, so skip it when INFO records would be dropped.
    if logger.isEnabledFor(logging.INFO):
        logger.info("File structure: %s", json.dumps(file_structure, indent=2))

    os.makedirs(result_dir, exist_ok=True)
    result = Result(file_structure, args)

    if args.gen_table:
        table, table_mean, table_std, table_fit_time = result.generate_table(skip_time=False)
        # Write into <output_dir>/<result_dir>, the directory created above.
        # Previously the path was relative to the CWD.
        for name, obj in (
            ("table", table),
            ("table_mean", table_mean),
            ("table_std", table_std),
            ("table_fit_time", table_fit_time),
        ):
            _dump_pickle(obj, os.path.join(result_dir, f"{name}.pkl"))

    if args.gen_density:
        result.plot_density("kantch_evasion_attack_accuracy_on_test")

    # Further optional analyses; uncomment to enable.
    # result.plot_scatter_sparsity()
    # result.plot_lineplot_patterns()
    # result.plot_scatter_patterns()
    # result.report_leaf_dist()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate result from trained model.")

    # Input/Output Section: all *_dir paths except output_dir are joined onto output_dir.
    io_group = parser.add_argument_group('Input/Output', 'Arguments related to file and folder paths')
    io_group.add_argument('--output_dir', type=str, default="out/out_paper", help='Directory of experiment outputs')
    io_group.add_argument('--result_dir', type=str, default="result", help='Directory of results')
    io_group.add_argument('--model_dir', type=str, default="model", help='Directory of models')
    io_group.add_argument('--param_dir', type=str, default="param", help='Directory of parameters')
    io_group.add_argument('--fig_dir', type=str, default="fig", help='Directory for saving figures')
    io_group.add_argument('--table_dir', type=str, default="table", help='Directory for saving tables')

    setup_group = parser.add_argument_group('Setup', 'Arguments related to setup')
    setup_group.add_argument('--log_level', type=str, default="INFO", help='Logger level')
    setup_group.add_argument('--dataset_list', type=str, default=None, help='List of datasets to run on')

    result_group = parser.add_argument_group('Result', 'Arguments related to result generation')
    result_group.add_argument('--gen_table', action='store_true', help='Generate table')
    result_group.add_argument('--gen_density', action='store_true', help='Generate density plot')

    args = parser.parse_args()

    logging.basicConfig(
        # basicConfig accepts a level *name*. Upper-casing lets "--log_level info"
        # work, and avoids the private logging._nameToLevel mapping, which raised
        # KeyError for lowercase input. Unknown names raise ValueError.
        level=args.log_level.upper(),
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",  # Log message format
        datefmt="%Y-%m-%d %H:%M:%S",              # Date and time format
        handlers=[
            logging.FileHandler("result.log"),     # Log to a file (relative to the CWD)
            logging.StreamHandler()               # Log to console
        ]
    )

    main(args)
