import glob
import re
from src.xlogominiprog.translator import python2codejs, extract_python_code_from_text
from src.xlogomini.emulator.executor import execute
from src.xlogomini.components.code.xlogo_code import Code
from src.xlogomini.components.task import Task
import json
from tqdm import tqdm
from multiprocessing import Pool, cpu_count, Manager


def process_chat(chat):
    """Evaluate a single model response against its task.

    The record is checked in three stages: the model output is converted to
    a code JSON and validated, the code is run in the emulator, and the
    emulator feedback is turned into a success flag.

    Args:
        chat: dict with at least the keys ``task_json``, ``code_json`` and
            ``output``. Constraints are read from ``task_json['constraints']``
            if present, otherwise from ``chat['constraints']``, otherwise
            from ``chat['cons_json']``. ``task_json`` is modified in place
            (``goal`` defaults to ``[]`` and ``constraints`` is set).

    Returns:
        A result dict holding the original data, the emulator feedback flags
        and the model input/output. ``exec_res`` is only present when the
        code was well-formed and executed.

    Raises:
        ValueError: if ``execute`` raises for well-formed code.
    """
    task_json = chat['task_json']

    # Constraint lookup precedence: task_json, then chat['constraints'],
    # then chat['cons_json'].
    if 'constraints' in task_json:
        cons_json = task_json['constraints']
    elif 'constraints' in chat:
        cons_json = chat['constraints']
    else:
        cons_json = chat['cons_json']

    if 'goal' not in task_json or task_json['goal'] is None:
        task_json['goal'] = []
    task_json['constraints'] = cons_json
    code_json = chat['code_json']

    # When the output contains the '[/INST]' marker, keep only the text
    # from the marker onwards (the marker itself is retained).
    raw_output = chat['output']
    marker = '[/INST]'
    if marker in raw_output:
        model_output = raw_output[raw_output.index(marker):]
    else:
        model_output = raw_output
    task_ascii = str(Task.init_from_json(task_json))

    # Fields shared by every result record, in output order.
    original_data = {
        'task_json'  : task_json,
        'task_ascii' : task_ascii,
        'code_json'  : code_json,
        'constraints': cons_json,
    }

    # Step 1: Check code format
    try:
        python_code = extract_python_code_from_text(model_output)
        model_codejs = python2codejs(python_code)
        Code(model_codejs)  # Constructed only so that malformed code raises
    except Exception as err:
        print("Code format error:", err)
        # A format error is also reported as an execution error.
        return {
            **original_data,
            # feedback from emulator
            'code_format_error'    : True,
            'code_format_error_msg': str(err),
            'execution_error'      : True,
            'execution_error_msg'  : str(err),
            'success'              : False,
            # model input and output
            'prompt'               : chat,
            'model_output'         : model_output,
            'model_codejs'         : None,
        }

    # Step 2: Execute the task
    try:
        exec_result = execute(task_json=task_json, code_json=model_codejs)
    except Exception as err:
        raise ValueError(f"Got Execution Error: {err}")

    # Step 3: Well-formed and executed; success needs goal and constraints
    # satisfied and no crash.
    goal_ok = exec_result['goal_ok']
    cons_ok = exec_result['cons_ok']
    no_crash = not exec_result['crashed']
    success = all((goal_ok, cons_ok, no_crash))

    return {
        **original_data,
        # feedback from emulator
        'code_format_error'    : False,
        'code_format_error_msg': None,
        'execution_error'      : False,
        'execution_error_msg'  : None,
        'exec_res'             : exec_result,
        'success'              : success,
        # model input and output
        'prompt'               : chat,
        'model_output'         : model_output,
        'model_codejs'         : model_codejs,
    }


def eval_model_parallel(path: str, n_samples=None):
    """Evaluate every record of a JSON output file using a process pool.

    Each record is handed to ``process_chat`` in a worker process; results
    are collected in the same order as the records in the file.

    Args:
        path: path to a JSON file containing a list of chat records.
        n_samples: if not None, only the first ``n_samples`` records are
            evaluated.

    Returns:
        A ``(results, summary)`` tuple: the list of per-record result dicts
        and the dict produced by ``calculate_summary``.

    Raises:
        Any exception raised by ``process_chat`` in a worker is re-raised
        here when its result is collected.
    """
    # Close the file as soon as it is parsed instead of leaking the handle.
    with open(path, 'r') as f:
        outputs = json.load(f)
    if n_samples is not None:
        outputs = outputs[:n_samples]

    pool = Pool(processes=cpu_count())
    print("Using", cpu_count(), "cores")

    # Creating the progress bar
    pbar = tqdm(total=len(outputs))

    try:
        # Enqueue jobs; the callback advances the bar once per finished job.
        async_results = [
            pool.apply_async(process_chat, args=(output,), callback=lambda result: pbar.update())
            for output in outputs
        ]

        # Collect results (re-raises any worker exception)
        results = [async_result.get() for async_result in async_results]
        pool.close()
    except BaseException:
        # Do not leave workers running when a job failed or we were interrupted.
        pool.terminate()
        raise
    finally:
        pool.join()
        pbar.close()

    # Summarize the results
    summary = calculate_summary(results)
    print(summary)
    return results, summary


def eval_model(path: str, n_samples=None):
    """Evaluate every record of a JSON output file sequentially.

    This is the single-process counterpart of ``eval_model_parallel``. Each
    record is evaluated by ``process_chat``, so both entry points produce
    the same result records.

    Args:
        path: path to a JSON file containing a list of chat records.
        n_samples: if not None, only the first ``n_samples`` records are
            evaluated.

    Returns:
        A ``(results, summary)`` tuple: the list of per-record result dicts
        and the dict produced by ``calculate_summary``.

    Raises:
        ValueError: if executing well-formed model code raises (propagated
            from ``process_chat``).
    """
    # Close the file as soon as it is parsed instead of leaking the handle.
    with open(path, 'r') as f:
        outputs = json.load(f)

    if n_samples is not None:
        outputs = outputs[:n_samples]

    results = []

    for chat in tqdm(outputs):
        result = process_chat(chat)
        if result['code_format_error']:
            # process_chat already printed the error; the sequential runner
            # additionally shows the offending model output.
            print("Model output:", result['model_output'])
        results.append(result)

    # get summary
    summary = calculate_summary(results)
    print(summary)

    return results, summary


def calculate_summary(results):
    """Aggregate per-record evaluation results into counts and rates.

    Args:
        results: iterable-with-length of result dicts. Each must contain
            ``code_format_error`` and ``success``. ``exec_res`` (a dict with
            ``crashed``, ``goal_ok`` and ``cons_ok``) is optional; a
            well-formed record without it is counted as crashed.

    Returns:
        A dict with the counts ``n_total``, ``n_format_correct``,
        ``n_no_crash``, ``n_goal_ok``, ``n_cons_ok``, ``n_correct`` and the
        rates ``code_format_correct_rate``, ``success_rate``,
        ``no_crash_rate`` (all relative to ``n_total``) and ``goal_ok_rate``,
        ``cons_ok_rate`` (relative to ``n_no_crash``). A rate whose
        denominator would be zero is reported as 0.
    """
    summary = {
        'n_total'         : len(results),
        'n_format_correct': 0,
        'n_no_crash'      : 0,
        'n_goal_ok'       : 0,
        'n_cons_ok'       : 0,
        'n_correct'       : 0,
    }

    for res in results:
        if not res['code_format_error']:
            summary['n_format_correct'] += 1

            # Records without emulator feedback (e.g. an execution error)
            # carry no 'exec_res'; treat them as crashed instead of failing
            # with a KeyError.
            exec_res = res.get('exec_res') or {}
            if not exec_res.get('crashed', True):
                summary['n_no_crash'] += 1
                if exec_res.get('goal_ok', False):
                    summary['n_goal_ok'] += 1
                if exec_res.get('cons_ok', False):
                    summary['n_cons_ok'] += 1

        if res['success']:
            summary['n_correct'] += 1

    # Calculate rates
    if summary['n_total'] > 0:
        summary['code_format_correct_rate'] = summary['n_format_correct'] / summary['n_total']
        summary['success_rate'] = summary['n_correct'] / summary['n_total']
    else:
        summary['code_format_correct_rate'] = 0
        summary['success_rate'] = 0

    # NOTE(review): the guard is on n_format_correct while the denominator is
    # n_total; kept as is so the reported metric does not change. Confirm
    # whether n_format_correct was the intended denominator.
    if summary['n_format_correct'] > 0:
        summary['no_crash_rate'] = summary['n_no_crash'] / summary['n_total']
    else:
        summary['no_crash_rate'] = 0

    if summary['n_no_crash'] > 0:
        summary['goal_ok_rate'] = summary['n_goal_ok'] / summary['n_no_crash']
        summary['cons_ok_rate'] = summary['n_cons_ok'] / summary['n_no_crash']
    else:
        summary['goal_ok_rate'] = 0
        summary['cons_ok_rate'] = 0

    return summary


def parse_file_name(path):
    """Extract run metadata from a result-file path.

    Every field below is located independently with ``re.search``, so each
    fragment must appear somewhere in ``path``:

        ``./results/ft/<prompt_template>/``, ``/<train_set>/Meta``,
        ``/Meta-<model_name>-``, ``epoch_<epoch>_``, ``testds=<test_set>_``,
        ``sample=<sample>.json``, ``topp=<top_p>_``,
        ``temperature=<temperature>_``, ``rank=<lora_rank>-`` and
        ``alpha=<lora_alpha>-``.

    An example of a path that satisfies all of them:
    ``./results/ft/nl/v1-train-92k/Meta-Llama-3-8B-ds=v1-train-92k-ep=8-rank=32-alpha=128-0328/epoch_6_testds=v1-test-1k_topp=1_temperature=0_sample=3.json``

    Args:
        path: the result-file path to parse.

    Returns:
        A dict mapping each field name to the matched text (all values are
        strings).

    Raises:
        ValueError: as soon as one of the patterns is not found in ``path``.
    """
    # (field name, regex) pairs; the named group always equals the field name.
    field_patterns = (
        ('prompt_template', r'./results/ft/(?P<prompt_template>[^/]+)/'),
        ('train_set', r'/(?P<train_set>[^/]+)/Meta'),
        ('model_name', r'/Meta-(?P<model_name>[^_]+)-'),
        ('epoch', r'epoch_(?P<epoch>\d+)_'),
        ('test_set', r'testds=(?P<test_set>[^_]+)_'),
        ('sample', r'sample=(?P<sample>\d+).json'),
        ('top_p', r'topp=(?P<top_p>\d+)_'),
        ('temperature', r'temperature=(?P<temperature>\d+)_'),
        ('lora_rank', r'rank=(?P<lora_rank>\d+)-'),
        ('lora_alpha', r'alpha=(?P<lora_alpha>\d+)-'),
    )

    parsed = {}
    for field, regex in field_patterns:
        found = re.search(regex, path)
        if found is None:
            raise ValueError(f"Pattern {regex} not found in path {path}")
        parsed[field] = found.group(field)

    return parsed


def evaluate_folder(files):
    """Run the parallel evaluation on every file matching a glob pattern.

    Args:
        files: glob pattern passed to ``glob.glob``; each matching path is
            evaluated with ``eval_model_parallel`` and its summary is printed
            as indented JSON.

    Returns:
        None. Nothing is evaluated or printed when no path matches.
    """
    separator = "\n========================================\n"
    for result_path in glob.glob(files):
        print(f"Evaluating File:\n\t{result_path}")
        # Only the summary is reported; the per-record results are discarded.
        _, summary = eval_model_parallel(result_path)
        pretty_summary = json.dumps(summary, indent=4)
        print(f"Summary:\n\t{pretty_summary}")
        print(separator)
