from jinja2 import FileSystemLoader, Environment
import os
import re
import glob
import shutil
import platform
import subprocess
import json
from typing import Dict, List, Tuple
from datetime import datetime


def get_code(answer, seperator):
    """Return the text of ``answer`` enclosed between two delimiters.

    Args:
        answer: Text to search (e.g. a raw LLM response).
        seperator: Indexable pair ``(opening, closing)`` of delimiter strings.
            The parameter keeps its historical spelling for caller
            compatibility.

    Returns:
        The substring strictly between the first ``opening`` delimiter and the
        first ``closing`` delimiter that follows it, or ``''`` if either
        delimiter is missing.
    """
    opening, closing = seperator[0], seperator[1]
    start = answer.find(opening)
    if start == -1:
        return ''
    body_start = start + len(opening)
    # Look for the closing delimiter only *after* the opening one. Searching
    # from `start` matched a closing delimiter that is a prefix of (or equal
    # to) the opening one, e.g. "```cpp" / "```", and always returned ''.
    end = answer.find(closing, body_start)
    if end == -1:
        return ''
    return answer[body_start:end]


def get_batch_id(count, batch_size):
    """Map a 1-based global id onto its 0-based slot within a batch.

    Args:
        count: Global id, counted from 1.
        batch_size: Number of slots per batch.

    Returns:
        ``(count - 1) % batch_size``.
    """
    zero_based = count - 1
    return zero_based % batch_size

def revise_file(file_name, save_dir, *args, **kwargs):
    """Render a Jinja2 template and write the result to ``save_dir``.

    Args:
        file_name: Template path, resolved by a ``FileSystemLoader`` rooted at
            the current working directory.
        save_dir: Destination *file* path (despite the name); overwritten if
            it already exists.
        *args, **kwargs: Forwarded to ``Template.render`` as template
            variables.
    """
    env = Environment(loader=FileSystemLoader('.'))
    rendered = env.get_template(file_name).render(*args, **kwargs)
    # Write explicitly as UTF-8: the rendered text can contain non-ASCII
    # characters (e.g. in injected code/comments) that the platform default
    # encoding (such as cp1252 on Windows) cannot represent.
    with open(save_dir, 'w', encoding='utf-8') as f:
        f.write(rendered)


def clean_files(folder_path, mode, *args, **kwargs):
    """Delete files directly inside ``folder_path`` according to ``mode``.

    Sub-directories are never removed.

    Args:
        folder_path: Directory to clean.
        mode: ``"all"`` removes every regular file in the folder; ``"exe"``
            removes only ``*.exe`` files, printing (not raising) on failure;
            ``"folder"`` is currently a no-op.
        *args, **kwargs: Accepted for call-site compatibility; unused.

    Raises:
        NotImplementedError: If ``mode`` is not one of the values above.
    """
    if mode == "all":
        for filename in os.listdir(folder_path):
            file_path = os.path.join(folder_path, filename)
            if os.path.isfile(file_path):
                os.remove(file_path)
    elif mode == "exe":
        for file_path in glob.glob(os.path.join(folder_path, "*.exe")):
            try:
                os.remove(file_path)
            except OSError as e:
                # Best effort: a locked/missing binary should not abort cleanup.
                print(f"Error deleting {file_path}: {e}")
    elif mode == "folder":
        pass
    else:
        # `NotImplemented` is a sentinel value, not an exception class; raising
        # it produced an unrelated TypeError instead of a clear error.
        raise NotImplementedError(f"unsupported clean mode: {mode!r}")


def _read_result_file(file_path, timeout, record_all_data=None):
    """Aggregate one tab-separated solver result file.

    Every line other than the ``File name`` header and blank lines is split on
    tabs as ``cnf name``, ``duration`` (int) and ``situation``; the situation
    (lower-cased) must be ``satisfiable``, ``unsatisfiable`` or ``timeout``,
    otherwise a KeyError is raised.

    Returns:
        ``(total_time, situation_counts, par2_sum)``. When ``record_all_data``
        is a list, a ``(cnf_name, duration, situation)`` tuple is appended to
        it for every parsed line.
    """
    total_time = 0
    situation_counts = {"satisfiable": 0, "unsatisfiable": 0, "timeout": 0}
    par2_sum = 0
    with open(file_path, 'r') as file:
        for raw_line in file:
            line = raw_line.strip()
            # Skip the header row and blank lines; a blank line used to crash
            # the parser with an IndexError on `parts[1]`.
            if not line or line.startswith('File name'):
                continue
            parts = line.split('\t')
            duration = int(parts[1])
            situation = parts[2].lower()
            situation_counts[situation] += 1
            total_time += duration
            # PAR-2: a run that reaches the timeout is charged 2 * timeout.
            par2_sum += duration if duration < timeout else 2 * timeout
            if record_all_data is not None:
                record_all_data.append((parts[0], duration, situation))
    return total_time, situation_counts, par2_sum


def process_raw_results(folder_path, timeout, answers=None):
    """Aggregate the raw solver result files found in ``folder_path``.

    Files named ``<global_id>_<num>.txt`` are parsed; results of several
    ``num`` shards belonging to the same ``global_id`` are summed.

    Args:
        folder_path: Directory containing the raw result files.
        timeout: Per-instance timeout used for the PAR-2 penalty.
        answers: Training mode when not ``None``: maps global id (str or int
            keys are both tried) to the prompt that produced it. ``None``
            selects evaluation mode.

    Returns:
        Training mode: dict with keys ``time``, ``prompt``, ``PAR-2``,
        ``satisfiable``, ``unsatisfiable`` and ``timeout``, each mapping the
        global id (str) to a value; PAR-2 is averaged per instance.
        Evaluation mode: ``(result_dict, record_all_data)`` where
        ``result_dict`` holds the totals of global id ``'1'`` (``PAR-2``,
        situation counts, ``total time``, ``#question``) with PAR-2 averaged
        per instance, and ``record_all_data`` lists
        ``(cnf_name, duration, situation)`` tuples.

    Raises:
        KeyError: In evaluation mode when no result file for id ``'1'`` exists.
    """
    result = {
        "time": {},
        "prompt": {},
        "PAR-2": {},
        "satisfiable": {},
        "unsatisfiable": {},
        "timeout": {},
    }
    situation_keys = ("satisfiable", "unsatisfiable", "timeout")
    record_all_data = [] if answers is None else None
    # Sorted so `record_all_data` has a deterministic order.
    for filename in sorted(os.listdir(folder_path)):
        match = re.match(r'(\d+)_(\d+)\.txt', filename)
        if not match:
            continue
        global_id = match.group(1)  # NOTE: ids are kept as str keys
        file_path = os.path.join(folder_path, filename)
        if not os.path.isfile(file_path):
            continue
        total_time, situation_counts, par2_sum = _read_result_file(
            file_path, timeout, record_all_data)
        if global_id in result["time"]:
            result["time"][global_id] += total_time
            result["PAR-2"][global_id] += par2_sum
            for situation_key, count in situation_counts.items():
                result[situation_key][global_id] += count
        else:
            result["time"][global_id] = total_time
            result["PAR-2"][global_id] = par2_sum
            if answers:
                # Prompts may be keyed by str or int id; try both.
                prompt_value = answers.get(
                    global_id, answers.get(int(global_id), 'Evaluation Stage.'))
            else:
                prompt_value = 'Evaluation Stage.'
            result["prompt"][global_id] = prompt_value
            for situation_key, count in situation_counts.items():
                result[situation_key][global_id] = count

    if answers is not None:  # training
        # Average PAR-2 per instance (same definition as evaluation).
        for global_id in result["PAR-2"]:
            total_questions = sum(
                result[key].get(global_id, 0) for key in situation_keys)
            if total_questions > 0:
                result["PAR-2"][global_id] = round(
                    result["PAR-2"][global_id] / total_questions, 2)
        return result

    # evaluation
    result['total time'] = result.pop('time')
    result.pop('prompt')
    # During evaluation every result belongs to global id '1'.
    result_dict = {k: v['1'] for k, v in result.items()}
    result_dict['#question'] = (result_dict['satisfiable']
                                + result_dict['unsatisfiable']
                                + result_dict['timeout'])
    # Guard against a header-only file (previously ZeroDivisionError).
    if result_dict['#question'] > 0:
        result_dict['PAR-2'] = round(
            result_dict['PAR-2'] / result_dict['#question'], 2)
    return result_dict, record_all_data


def collect_results(answers, repetition_dict, results, args):
    """Gather the current batch's results and pick the best global id.

    Args:
        answers: Maps global id -> prompt for this batch; passed to
            ``process_raw_results`` (training mode).
        repetition_dict: Its values are looked up in ``results["prompt"]``;
            only used when ``args.devoid_duplication`` is truthy.
        results: Accumulated results with at least ``time``, ``prompt`` and
            ``PAR-2`` sub-dicts keyed by global id.
        args: Needs ``timeout`` and ``devoid_duplication``; ``temp_base_dir``
            is optional (defaults to ``./temp``).

    Returns:
        ``(result, best)`` where ``result`` is this batch's result dict and
        ``best`` is ``{best_id: [time, prompt, PAR-2]}`` for the id with the
        lowest PAR-2.
    """
    # Use temp_base_dir from args if available, otherwise use the default.
    temp_base_dir = getattr(args, 'temp_base_dir', './temp')
    folder_path = os.path.join(temp_base_dir, 'results')
    result = process_raw_results(folder_path=folder_path, timeout=args.timeout, answers=answers)

    if args.devoid_duplication:
        # Re-attach time/prompt of previously seen prompts from `results`.
        # NOTE(review): PAR-2 is not copied, so these ids never compete in
        # the best-key selection below, and a single missing prompt stops
        # processing of the remaining ones (`break`) — confirm both are
        # intended.
        for value in repetition_dict.values():
            key = find_key_for_value(results["prompt"], value)
            if key is None:
                break
            result["time"][key] = results["time"][key]
            result["prompt"][key] = results["prompt"][key]

    # Restore the baseline (id "0") from the accumulated results *before*
    # falling back to defaults. Previously the fallback ran first and inserted
    # an infinite "0" entry, so the known baseline was never restored exactly
    # when every candidate had failed.
    if "0" not in result["PAR-2"] and "0" in results["PAR-2"]:
        result["PAR-2"]["0"] = results["PAR-2"]["0"]
        result["time"]["0"] = results["time"]["0"]
        result["prompt"]["0"] = results["prompt"]["0"]
        # Ensure the baseline has an extra_params entry (even if empty).
        result.setdefault("extra_params", {}).setdefault("0", {})

    # Ensure the PAR-2 dictionary is not empty so min() below cannot fail.
    if not result["PAR-2"]:
        print("Warning: All code failed, using default values")
        result["PAR-2"]["0"] = float('inf')
        result["time"]["0"] = float('inf')
        result["prompt"]["0"] = "default_failed_code"

    best_key = min(result["PAR-2"], key=result["PAR-2"].get)  # global_id
    return result, {best_key: [result["time"][best_key], result["prompt"][best_key], result["PAR-2"][best_key]]}


def collect_results_eval(raw_path, final_path, args):
    """Summarise evaluation results and append them to ``final_path``.

    Appends (UTF-8) a header line, one tab-separated line per CNF instance,
    and finally the ``str()`` of the summary dict.

    Args:
        raw_path: Folder with the raw result files.
        final_path: Report file; opened in append mode.
        args: Needs ``eval_timeout`` (used for the PAR-2 penalty).

    Returns:
        The summary dict from ``process_raw_results`` in evaluation mode.
    """
    summary, per_instance = process_raw_results(folder_path=raw_path, timeout=args.eval_timeout, answers=None)  # eval mode
    report_lines = ["cnf File \t Duration \t Situation \n"]
    report_lines.extend(
        f"{cnf_name}\t{duration}\t{situation}\n"
        for cnf_name, duration, situation in per_instance
    )
    report_lines.append(str(summary) + '\n')
    with open(final_path, 'a+', encoding='utf-8') as report:
        report.writelines(report_lines)
    return summary


def fill_core_codes(origin_file, target_file, answer_code, **kwargs):
    """Render ``origin_file`` into ``target_file`` with ``answer_code`` injected.

    ``answer_code`` is passed as the ``replace_code`` template variable. The
    ``timeout`` and ``data_dir`` variables are rendered as the literal text
    ``{{ timeout }}`` / ``{{ data_dir }}``, presumably so a later rendering
    pass can fill them in. Extra keyword arguments are forwarded as template
    variables (duplicating ``timeout``/``data_dir``/``replace_code`` raises
    TypeError).
    """
    deferred_placeholders = {
        'timeout': '{{ timeout }}',
        'data_dir': '{{ data_dir }}',
    }
    revise_file(
        file_name=origin_file,
        save_dir=target_file,
        **deferred_placeholders,
        replace_code=answer_code,
        **kwargs,
    )


def delete_InfiniteLoopInst(candidates, result_dict, results_folder=None):
    """Mark candidates without a result file as failed, then kill solvers.

    Despite the name, nothing is deleted. For every candidate file name that
    does not exist in ``results_folder``, the derived global id is set to
    ``float('inf')`` in every sub-dict of ``result_dict`` that contains it.
    Afterwards lingering solver processes are killed (best effort).

    Args:
        candidates: Expected result file names. The global id is the text
            before the first ``_`` once any ``finished`` substring is removed.
        result_dict: Mapping of metric name -> {global_id: value}; mutated.
        results_folder: Where result files are looked up (defaults to
            ``./temp/results/``).

    Returns:
        List of failed global ids, each listed once, in discovery order.

    Raises:
        NotImplementedError: On platforms other than Windows and Linux.
    """
    if results_folder is None:
        results_folder = './temp/results/'  # fallback for backward compatibility
    failed_id_list = []
    for file_name in candidates:
        if os.path.isfile(os.path.join(results_folder, file_name)):
            continue  # result present -> candidate finished
        id_str = file_name.replace('finished', '').split('_')[0]
        marked = False
        for values in result_dict.values():
            if id_str in values:
                # Keep the entry but make it lose every comparison.
                values[id_str] = float('inf')
                marked = True
        # Record/print each failed id once (previously once per metric key
        # and per candidate shard, producing duplicates).
        if marked and id_str not in failed_id_list:
            failed_id_list.append(id_str)
            print(f"Code {id_str} failed (compilation error or runtime crash), setting PAR-2 to infinity")

    # Kill lingering solver processes. Maybe dangerous.
    system = platform.system()
    if system == 'Windows':
        try:
            # `/IM` matches full image names such as "EasySAT.exe"; the
            # wildcard is needed because the bare name "EasySAT" never matches.
            subprocess.run(['taskkill', '/F', '/IM', 'EasySAT*'], check=True, text=True)
        except (OSError, subprocess.SubprocessError) as e:
            print(f"wrong when killing procession: {e}")
    elif system == 'Linux':
        try:
            # More specific pattern to avoid killing the main process.
            # pkill exits non-zero when nothing matched, which also lands in
            # the except branch below.
            subprocess.run(['pkill', '-f', 'EasySAT_'], check=True, stdout=subprocess.PIPE,
                           stderr=subprocess.PIPE, text=True)
        except (OSError, subprocess.SubprocessError) as e:
            print(f"wrong when killing procession: {e}")
    else:
        raise NotImplementedError('sorry, we only support Wins Or Linux.')
    return failed_id_list


def copy_folder(src_folder, num, mode='train', target_folder = None):
    """Copy a directory tree into numbered siblings or a single target.

    Existing destination directories are removed before copying.

    Args:
        src_folder: Directory to copy. In ``'train'`` mode the copies are
            named after it: ``path/EasySAT/`` (or ``path/EasySAT``) becomes
            ``path/EasySAT_0/``, ``path/EasySAT_1/``, ...
        num: Number of copies in ``'train'`` mode (ignored in ``'eval'``).
        mode: ``'train'`` or ``'eval'``.
        target_folder: Destination for ``'eval'`` mode (required there).

    Raises:
        ValueError: ``'eval'`` mode without ``target_folder``.
        NotImplementedError: Unknown ``mode``.
    """
    if mode == 'train':
        # Strip trailing separators instead of blindly dropping the last
        # character, which mangled the name when no trailing slash was given.
        base = src_folder.rstrip('/\\')
        for i in range(num):
            new_folder_path = base + "_{}/".format(i)
            if os.path.exists(new_folder_path):
                shutil.rmtree(new_folder_path)
            shutil.copytree(src_folder, new_folder_path)
    elif mode == 'eval':
        if target_folder is None:
            raise ValueError('please set target folder to save source files.')
        if os.path.exists(target_folder):
            shutil.rmtree(target_folder)
        shutil.copytree(src_folder, target_folder)
    else:
        raise NotImplementedError('please choose `mode` between `train` or `eval`.')


def find_key_for_value(results, value_to_find):
    """Return the first key of ``results`` whose value equals ``value_to_find``.

    Keys are checked in the mapping's iteration order; ``None`` is returned
    when nothing matches.
    """
    matches = (key for key, value in results.items() if value == value_to_find)
    return next(matches, None)


def train_init(args, temp_base_dir="./temp", original_dir='./examples/EasySAT/original_EasySAT'):
    """Initialize the training workspace.

    Resulting layout::

        temp_base_dir/
        ├── results/      (existing files removed; created if missing)
        ├── EasySAT_0/    (fresh copy of ``original_dir`` minus *.exe files)
        ├── EasySAT_1/    (copy of EasySAT_0)
        └── ...           (up to EasySAT_{args.batch_size - 1})

    Args:
        args: Must provide ``batch_size``.
        temp_base_dir: Root directory of the workspace.
        original_dir: Source tree copied into ``EasySAT_0``. Previously
            hard-coded; the default keeps the old behavior.
    """
    # Create or clean the results directory.
    results_dir = os.path.join(temp_base_dir, "results")
    if os.path.exists(results_dir):
        clean_files(folder_path=results_dir, mode="all")
    else:
        os.makedirs(results_dir)

    # copy_folder(mode='eval') removes and recreates every target directory,
    # so the targets do not need to be created beforehand.
    easySAT_0_dir = os.path.join(temp_base_dir, "EasySAT_0")
    copy_folder(original_dir, 1, mode='eval', target_folder=easySAT_0_dir)
    clean_files(folder_path=easySAT_0_dir, mode="exe")

    # Clone the cleaned EasySAT_0 for the remaining batch slots.
    for i in range(1, args.batch_size):
        easySAT_dir = os.path.join(temp_base_dir, f"EasySAT_{i}")
        copy_folder(easySAT_0_dir, 1, mode='eval', target_folder=easySAT_dir)

def check_reIteration(round, best_result_dict, baseline):
    """Decide whether the prompt should be restarted.

    Args:
        round: Current iteration number; rounds ``<= 1`` never restart.
        best_result_dict: Dict whose first value is indexable as
            ``[time, prompt, PAR-2]``.
        baseline: Dict with ``'time'`` and ``'PAR-2'`` reference values.

    Returns:
        ``True`` (restart, to escape poor functions) when the best result
        beats the baseline on neither time nor PAR-2, otherwise ``False``.
    """
    if round <= 1:
        return False
    best = next(iter(best_result_dict.values()))
    if best[0] < baseline['time']:
        return False
    return not (best[2] < baseline['PAR-2'])


def extract_json(text):
    '''Return the first balanced ``{...}`` substring of ``text``, or ``None``.

    Brace depth is tracked from the first ``{``; any ``}`` seen before it is
    ignored (previously it made the function return the text prefix up to
    that brace). Braces inside JSON string literals are not special-cased.
    '''
    depth = 0
    json_start = None
    for i, char in enumerate(text):
        if char == '{':
            if json_start is None:
                json_start = i
            depth += 1
        elif char == '}' and json_start is not None:
            depth -= 1
            if depth == 0:
                return text[json_start:i + 1]
    return None

def parse_txt_to_dict(txt):
    '''Leniently parse flat ``{key: value, ...}`` text into a ``dict``.

    More forgiving than ``json.loads``/``eval``. Braces, newlines and spaces
    are first stripped from both ends. A key runs up to the first ``:``; its
    value runs up to the next ``,`` (so ``:`` may appear in values, and ``,``
    in keys). Keys are whitespace-stripped, then stripped of double quotes;
    values are whitespace-stripped strings with quotes kept. Nested
    structures are not supported.
    '''
    body = txt.strip('{}\n ')
    parsed = {}
    key_chars, value_chars = [], []
    reading_key = True

    def flush():
        parsed[''.join(key_chars).strip().strip('"')] = ''.join(value_chars).strip()

    for ch in body:
        if reading_key:
            if ch == ':':
                reading_key = False
            else:
                key_chars.append(ch)
        elif ch == ',':
            flush()
            key_chars.clear()
            value_chars.clear()
            reading_key = True
        else:
            value_chars.append(ch)
    # Flush the trailing entry, which has no terminating comma.
    if key_chars or value_chars:
        flush()
    return parsed

def decodeRawJsonAnswer(raw_text):
    """Decode the first ``{...}`` object embedded in ``raw_text``.

    Falls back to the lenient ``parse_txt_to_dict`` when the snippet is not
    valid JSON (fault tolerance for LLM output).

    Returns:
        The decoded object, or ``{}`` when ``raw_text`` contains no balanced
        ``{...}`` object. Previously the missing object reached
        ``parse_txt_to_dict(None)`` and crashed with AttributeError.
    """
    json_str = extract_json(text=raw_text)
    if json_str is None:
        return {}
    try:
        return json.loads(json_str)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return parse_txt_to_dict(json_str)

def sanitize_filename(filename):
    '''Return ``filename`` with characters unsafe in file names replaced by ``_``.

    Replaced: backslash, forward slash, ``* ? : " < > |``, the dot, and any
    whitespace character.
    '''
    return re.sub(r'[\\/*?:"<>|\.\s]', '_', filename)


def _build_top_k_record(global_id, rank, data, is_baseline):
    """Build the JSON payload stored for one ranked (or baseline) result."""
    return {
        'global_id': global_id,
        'rank': rank,
        'PAR-2': data.get('PAR-2', float('inf')),
        'time': data.get('time', 0),
        'prompt': data.get('prompt', ''),
        'extra_params': data.get('extra_params', {}),
        'task_description': data.get('task_description', ''),
        'modification_direction': data.get('modification_direction', ''),
        'is_baseline': is_baseline,
    }


def _write_json_utf8(path, payload):
    """Write ``payload`` to ``path`` as indented JSON with raw non-ASCII text."""
    # ensure_ascii=False emits non-ASCII characters verbatim, so the file must
    # be UTF-8; the platform default (e.g. cp1252 on Windows) can fail.
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(payload, f, indent=4, ensure_ascii=False)


def save_top_k_results_to_folder(results: Dict, args, formatted_date_time: str, k: int = 2, results_base_dir: str = None):
    """Save the baseline and the top-k results to a ``top_k_results`` folder.

    Args:
        results: Maps global id -> result dict (``PAR-2``, ``time``,
            ``prompt``, ...). Id ``"0"`` is treated as the baseline and is
            excluded from ranking.
        args: Needs ``task``.
        formatted_date_time: Timestamp recorded in the summary and used in the
            legacy folder name.
        k: Number of best (lowest PAR-2) results to save.
        results_base_dir: Parent of ``top_k_results``; when falsy the legacy
            ``./<task>_search_<timestamp>/`` folder is used.

    Returns:
        ``(parent_dir, top_k_folder)``.
    """
    if results_base_dir:
        top_k_folder = os.path.join(results_base_dir, "top_k_results")
    else:
        # Legacy layout kept for backward compatibility.
        legacy_dir = f"./{args.task}_search_{formatted_date_time}/"
        top_k_folder = os.path.join(legacy_dir, "top_k_results")
    os.makedirs(top_k_folder, exist_ok=True)

    # Rank candidates (baseline excluded) by PAR-2 - smaller is better.
    candidates = [(global_id, data) for global_id, data in results.items() if global_id != "0"]
    top_k_results = sorted(candidates, key=lambda item: item[1].get('PAR-2', float('inf')))[:k]

    # Baseline first (rank 0).
    if "0" in results:
        _write_json_utf8(os.path.join(top_k_folder, "baseline_id_0.json"),
                         _build_top_k_record("0", 0, results["0"], is_baseline=True))

    for rank, (global_id, result) in enumerate(top_k_results, start=1):
        _write_json_utf8(os.path.join(top_k_folder, f"top_{rank}_id_{global_id}.json"),
                         _build_top_k_record(global_id, rank, result, is_baseline=False))

    summary = {
        'search_task': args.task,
        'search_time': formatted_date_time,
        'total_results': len(results),
        'top_k': k,
        'baseline_PAR2': results.get("0", {}).get('PAR-2', float('inf')),
        'baseline_time': results.get("0", {}).get('time', 0),
        'top_k_results': [
            {
                'rank': rank,
                'global_id': global_id,
                'PAR-2': result.get('PAR-2', float('inf')),
                'time': result.get('time', 0),
                'prompt': result.get('prompt', ''),  # Include prompt in summary
            }
            for rank, (global_id, result) in enumerate(top_k_results, start=1)
        ],
    }
    _write_json_utf8(os.path.join(top_k_folder, "summary.json"), summary)
    print(f"Saved baseline and top-{k} results to {top_k_folder}")
    return os.path.dirname(top_k_folder), top_k_folder

def get_search_folder_name(args, formatted_date_time: str) -> str:
    """Return the search-run folder path ``./<args.task>_search_<timestamp>/``."""
    return "./{}_search_{}/".format(args.task, formatted_date_time)

def load_top_k_results(search_folder: str) -> Dict:
    """Load the baseline and top-k result files from ``search_folder``.

    Reads ``<search_folder>/top_k_results/``: prints a line from
    ``summary.json`` if present, then loads every ``top_*.json`` and
    ``baseline_*.json`` file.

    Args:
        search_folder: Folder that contains the ``top_k_results`` directory.

    Returns:
        Mapping ``global_id -> record`` with keys ``prompt``, ``time``,
        ``PAR-2``, ``extra_params``, ``task_description``,
        ``modification_direction``, ``is_baseline`` and ``rank`` (optional
        fields get defaults when absent).

    Raises:
        FileNotFoundError: If ``top_k_results`` does not exist.
    """
    top_k_folder = os.path.join(search_folder, "top_k_results")
    if not os.path.exists(top_k_folder):
        raise FileNotFoundError(f"Top-k results folder not found: {top_k_folder}")

    # Read as UTF-8 explicitly: the writer may emit raw non-ASCII characters,
    # which the platform default encoding might not decode.
    summary_file = os.path.join(top_k_folder, "summary.json")
    if os.path.exists(summary_file):
        with open(summary_file, 'r', encoding='utf-8') as f:
            summary = json.load(f)
        print(f"Loaded search summary: {summary['search_task']}, {summary['total_results']} total results")

    results = {}
    # Sorted so the insertion order of `results` does not depend on listdir.
    for filename in sorted(os.listdir(top_k_folder)):
        if not (filename.startswith(("top_", "baseline_")) and filename.endswith(".json")):
            continue
        with open(os.path.join(top_k_folder, filename), 'r', encoding='utf-8') as f:
            result_data = json.load(f)
        results[result_data['global_id']] = {
            'prompt': result_data['prompt'],
            'time': result_data['time'],
            'PAR-2': result_data['PAR-2'],
            'extra_params': result_data.get('extra_params', {}),
            'task_description': result_data.get('task_description', ''),
            'modification_direction': result_data.get('modification_direction', ''),
            'is_baseline': result_data.get('is_baseline', False),
            'rank': result_data.get('rank', 0),
        }
    return results

