# -*- coding: utf-8 -*-
import os
import shutil
import argparse
import subprocess
from datetime import datetime
import time
import re
import yaml
import math
import warnings

from autosat.utils import revise_file, clean_files, collect_results_eval, copy_folder
from autosat.execution.execution_worker import ExecutionWorker


def _wait_for_finish_markers(results_dir, marker_names, max_wait_seconds, poll_interval=0.5):
    """Block until every ``finished<name>`` marker file exists in *results_dir*.

    Args:
        results_dir: Directory in which the marker files are expected.
        marker_names: File-name suffixes; each marker is ``'finished' + name``.
        max_wait_seconds: Upper bound on the total waiting time.
        poll_interval: Seconds to sleep between two checks, so that waiting
            does not burn a CPU core that the solver processes could use.

    Raises:
        ValueError: If some marker is still missing after *max_wait_seconds*.
    """
    marker_paths = [os.path.join(results_dir, 'finished' + name) for name in marker_names]
    started_at = time.monotonic()
    while not all(os.path.exists(path) for path in marker_paths):
        if time.monotonic() - started_at > max_wait_seconds:
            raise ValueError("Infinite loop error!!!")
        time.sleep(poll_interval)


def evaluate(args, SAT_solver_file_path, method_name=None):
    """Run one SAT solver source file over an evaluation set and collect the results.

    Everything is written below ``<results_save_path>/<dataset_name>/<method_name>/``:
    a ``tmp`` folder holding the temporary solver source/executable, a ``results``
    folder holding the raw per-process output, and a time-stamped
    ``results_<YYYYmmdd_HHMMSS>.txt`` passed to ``collect_results_eval`` as the
    final output path.

    Args:
        args: Namespace providing ``eval_data_dir``, ``results_save_path``,
            ``eval_parallel_size``, ``eval_timeout`` and
            ``keep_intermediate_results``; ``method_name`` / ``model_name`` are
            optional. ``eval_parallel_size`` is lowered in place when it exceeds
            the number of files in ``eval_data_dir``.
        SAT_solver_file_path: Path of the solver ``.cpp`` file. ``EasySAT.hpp``
            and ``heap.hpp`` are copied along if they sit next to it.
        method_name: Fallback label, used only when ``args`` carries neither
            ``method_name`` nor ``model_name``. If none is given, the solver file
            name without ``.cpp`` is used.

    Returns:
        Whatever ``collect_results_eval`` returns for this run.

    Raises:
        ValueError: If ``eval_parallel_size`` is below 1, if ``eval_data_dir``
            contains no files, or if the solver processes do not all report
            completion within the time budget.
        RuntimeError: If the execution worker reports a failure.
    """
    eval_data_dir = args.eval_data_dir
    dataset_name = os.path.basename(os.path.normpath(eval_data_dir))
    # args.method_name wins, then args.model_name (backward compatibility), then
    # the explicit parameter, and finally the solver file name.
    method_name = (getattr(args, 'method_name', None)
                   or getattr(args, 'model_name', None)
                   or method_name
                   or os.path.basename(SAT_solver_file_path).replace('.cpp', ''))

    # Validate the workload BEFORE compiling or launching anything. The timeout
    # budget below divides by eval_parallel_size, so a zero size (or an empty
    # data directory, which would clamp the size to zero) must be rejected here.
    if args.eval_parallel_size < 1:
        raise ValueError(f"eval_parallel_size must be at least 1, got {args.eval_parallel_size}")
    data_num = len([f for f in os.listdir(eval_data_dir) if os.path.isfile(os.path.join(eval_data_dir, f))])
    if data_num == 0:
        raise ValueError(f"no data files found in eval_data_dir: {eval_data_dir}")
    print("data_num:", data_num, "eval_parallel_sizes: ", args.eval_parallel_size)
    if args.eval_parallel_size > data_num:
        warnings.warn(f"The parallel num for evaluation is too large: {args.eval_parallel_size} > {data_num}. "
                      f"It will be replaced with the eval set total num: {data_num}",
                      category=UserWarning, stacklevel=2)
        # Clamp before the solver is launched so that the number of started
        # processes, the awaited finish markers and the timeout budget agree.
        # NOTE(review): assumes the execution worker accepts a parallel size
        # equal to the number of data files - confirm against ExecutionWorker.
        setattr(args, 'eval_parallel_size', data_num)

    # Organized directory: results_save_path/dataset_name/method_name/
    eval_session_dir = os.path.join(args.results_save_path, dataset_name, method_name)
    os.makedirs(eval_session_dir, exist_ok=True)

    # The session directory doubles as the base for temporary files, so that
    # different dataset/method combinations do not collide.
    temp_base_dir = eval_session_dir

    # STEP 1. run SAT solver, get raw results in parallel.
    results_save_path_intermediate = os.path.join(eval_session_dir, 'tmp')
    os.makedirs(results_save_path_intermediate, exist_ok=True)

    # Only the specific solver file is copied (not its whole directory), so that
    # unrelated solver files are not dragged along.
    tmp_cpp_source_path = os.path.join(results_save_path_intermediate, 'SAT_Solver_tmp.cpp')
    tmp_executable_file_path = os.path.join(results_save_path_intermediate, 'SAT_Solver_tmp')
    shutil.copy2(SAT_solver_file_path, tmp_cpp_source_path)

    # Headers the solver may include; copied only when present next to it.
    solver_dir = os.path.dirname(SAT_solver_file_path)
    for header_file in ['EasySAT.hpp', 'heap.hpp']:
        header_path = os.path.join(solver_dir, header_file)
        if os.path.exists(header_path):
            dst_header_path = os.path.join(results_save_path_intermediate, header_file)
            shutil.copy2(header_path, dst_header_path)
            print(f"Copied {header_file} to evaluation directory")

    # Produce the temporary source from the original solver file. revise_file
    # receives the timeout, the (quoted) data directory and the results
    # directory; presumably it substitutes them into the source - see
    # autosat.utils.revise_file for the exact rules.
    revise_file(
        file_name=SAT_solver_file_path,
        save_dir=tmp_cpp_source_path,
        timeout=args.eval_timeout,
        data_dir="\"" + eval_data_dir + "\"",
        results_dir=f"{temp_base_dir}/results",  # Use unique results directory
    )
    cnf_duration_situation_fpath = f'{temp_base_dir}/results/'
    os.makedirs(cnf_duration_situation_fpath, exist_ok=True)  # Create results directory

    execution_worker = ExecutionWorker(temp_base_dir=temp_base_dir)
    success = execution_worker.execute_eval(source_cpp_path=tmp_cpp_source_path,
                                            executable_file_path=tmp_executable_file_path,
                                            data_parallel_size=args.eval_parallel_size)
    if not success:
        raise RuntimeError("cannot correctly execute... plz check again")

    # One marker 'finished1_<k>.txt' is awaited per parallel process
    # (`id` is fixed to 1 during evaluation).
    filenames = [str(1) + "_" + str(num) + ".txt" for num in range(args.eval_parallel_size)]
    max_wait_seconds = args.eval_timeout * (2 * data_num / args.eval_parallel_size + 30)
    _wait_for_finish_markers(cnf_duration_situation_fpath, filenames, max_wait_seconds)

    print('SAT Solver finished...')

    # STEP 2. collect results, tagged with the timestamp of this evaluation run.
    formatted_date_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    result_dict = collect_results_eval(raw_path=cnf_duration_situation_fpath,
                                       final_path=os.path.join(eval_session_dir,
                                                               'results_{}.txt'.format(formatted_date_time)),
                                       args=args)
    print(f'results are saved in {args.results_save_path} ...')

    # STEP 3. remove temporary files if requested. Cleanup is best effort: a
    # failure here must not discard results that were already collected.
    if not args.keep_intermediate_results:
        try:
            # Clean temporary compilation files
            shutil.rmtree(results_save_path_intermediate)
            # Clean temporary results files
            clean_files(folder_path=cnf_duration_situation_fpath, mode="all")
            print(f"Cleaned temporary files in {eval_session_dir}")
        except Exception as e:
            warnings.warn(f"Warning when removing temporary files in {eval_session_dir}: {e}",
                          category=UserWarning, stacklevel=2)
    return result_dict


if __name__ == '__main__':
    # Command-line entry point: merge CLI arguments with an optional YAML config
    # (explicit CLI options win over the config, the config wins over the
    # built-in defaults), print a summary and run one evaluation.

    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` would turn every non-empty string (including "False")
        into True, so the accepted spellings are listed explicitly. The empty
        string stays falsy, as it was with ``bool('')``.
        """
        if isinstance(value, bool):
            return value
        normalized = str(value).strip().lower()
        if normalized in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        if normalized in ('false', 'f', 'no', 'n', '0', 'off', ''):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--config',
                        type=str,
                        default='./examples/EasySAT/eval_config.yaml',
                        help='Path to the config file')
    parser.add_argument('--SAT_solver_file_path', type=str, default='./template/EasySAT_eval/EasySAT_template.cpp',
                        help='SAT solver file path (NOTICE: auxiliary functions should be in the same directory).')
    parser.add_argument('--eval_data_dir', default='./evaluation/', type=str,
                        help='the directory where cnf files are stored.')
    parser.add_argument('--results_save_path', type=str, default='./temp/eval_results/',
                        help='where the final result are saved.')
    parser.add_argument('--eval_parallel_size', type=int, default=16, help='parallel in K processions.')
    parser.add_argument('--eval_timeout', type=int, default=1500, help='time-out for SAT Solver')
    parser.add_argument('--rand_seed', type=int, default=42, help='random seed')
    parser.add_argument('--keep_intermediate_results', type=_str2bool, default=True,
                        help='whether to keep intermediate results.')
    parser.add_argument('--method_name', type=str, default=None, help='character or name for the SAT Solver')
    parser.add_argument('--model_name', type=str, default=None, help='alias for method_name (for backward compatibility)')

    args = parser.parse_args()

    # Find out which options were really given on the command line. argparse
    # only fills in a default when the namespace lacks the attribute, so after
    # parsing into a namespace pre-filled with a sentinel, exactly the options
    # present in argv have a non-sentinel value. Comparing against the parser
    # default instead would misclassify an explicit value that equals it.
    _UNSET = object()
    known_dests = vars(parser.parse_args([]))
    probe = parser.parse_args(namespace=argparse.Namespace(**{dest: _UNSET for dest in known_dests}))
    cli_provided = {dest for dest, value in vars(probe).items() if value is not _UNSET}

    # Load config file first (as defaults)
    if os.path.exists(args.config):
        print('Loading config file:', args.config)
        with open(args.config, 'r') as file:
            # safe_load returns None for an empty file.
            config = yaml.safe_load(file) or {}
        if not isinstance(config, dict):
            raise ValueError(f'Config file {args.config} must contain a mapping at top level, '
                             f'got {type(config).__name__}')
        for key, value in config.items():
            if key not in cli_provided:
                # Not given on the command line: the config value applies
                # (this also covers keys the parser does not know about).
                setattr(args, key, value)
                print(f'  Loaded from config: {key} = {value}')
            else:
                print(f'  Command line overrides config: {key} = {getattr(args, key)} (config had: {value})')
    else:
        print('No config file found, using command line arguments only')

    dataset_name = os.path.basename(os.path.normpath(args.eval_data_dir))
    method_name = getattr(args, 'method_name', None) or getattr(args, 'model_name', None) or 'baseline'

    print('Evaluation Configuration:')
    print(f'  - Dataset: {dataset_name}')
    print(f'  - Method: {method_name}')
    print(f'  - Timeout: {args.eval_timeout}s')
    print(f'  - SAT Solver: {args.SAT_solver_file_path}')
    print(f'  - Results will be saved to: {args.results_save_path}/{dataset_name}/{method_name}/')

    evaluate(args, SAT_solver_file_path=args.SAT_solver_file_path)
