#from workflowsonnet1 import Workflow
from utils.prompts3 import GameOf24, CheckmateInOne, WordSorting, P3_Test, Sonnet
from joblib import Parallel, delayed

import re
import json
import argparse
import os
import datetime
from tqdm import tqdm
import time

from utils.utils import execute_py_code


def is_white_move(san):
    """Return which side is to move after a numbered SAN move list.

    The text after the last move number (``<digits>.`` plus any following
    whitespace) is split into whitespace-separated tokens, and the side to
    move is decided by how many tokens there are:

    * none       ("... 12.")          -> "White"
    * one        ("... 12. Nf3")      -> "Black"
    * two        ("... 12. Nf3 Nc6")  -> "White"

    Using parity (rather than "anything follows -> Black") keeps a move list
    that ends with a complete move pair from being reported as Black to move.

    Args:
        san: Move list such as ``"1. e4 e5 2. Nf3"``.

    Returns:
        "White", "Black", or "Malformed" when no move number is present.
    """
    last_match = None
    # Only the final move number matters; walk to the last match.
    for last_match in re.finditer(r'(\d+)\.\s*', san):
        pass

    if last_match is None:
        return "Malformed"

    trailing_tokens = san[last_match.end():].split()
    # An odd number of half-moves after the number means White has moved.
    return "Black" if len(trailing_tokens) % 2 == 1 else "White"

def eval_results(benchmark_path, eval_path, task=None):
    """Score a saved predictions file against a benchmark and print accuracy.

    Both files are JSON Lines. The i-th non-blank line of ``eval_path``
    (field ``result``) is compared with the i-th non-blank line of
    ``benchmark_path`` (field ``target``); the benchmark ``input`` field is
    used to build the P3_Test execution harness.

    Args:
        benchmark_path: Path to the benchmark ``.jsonl`` file.
        eval_path: Path to the predictions ``.jsonl`` file.
        task: Task name controlling post-processing ('GameOf24', 'P3_Test',
            anything else = plain equality). Defaults to the module-level
            ``task`` set by the ``__main__`` block, so the original call
            ``eval_results(benchmark_path, eval_path)`` keeps working.

    Returns:
        The accuracy as a float (0.0 when there is nothing to compare).
    """
    if task is None:
        # Backward compatibility: the script used to rely on this global.
        task = globals().get('task')

    inputs = []
    targets = []
    with open(benchmark_path, encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing lines
            record = json.loads(line)
            target = record['target']
            if task == 'GameOf24':
                # NOTE(review): eval() executes benchmark text as Python; this is
                # only acceptable because the benchmark file is trusted.
                target = eval(target)
            targets.append(target)
            inputs.append(record['input'])

    predicts = []
    with open(eval_path, encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            idx = len(predicts)  # index of the matching benchmark example
            result = json.loads(line)['result'].strip()
            if task == 'GameOf24':  # TODO: process when doing the inference
                # Drop a leading "label:" prefix before evaluating the expression.
                if ':' in result:
                    result = result.split(':', 1)[1].strip()
                try:
                    # NOTE(review): eval() executes model output; only run this on
                    # trusted outputs or inside a sandbox.
                    result = eval(result)
                    print(result)
                except Exception:
                    # Unparseable answers stay strings and therefore never match.
                    pass
            elif task == 'P3_Test':
                full_code = f"from typing import *\n{inputs[idx]}\n{result}\nanswer = sol()\nprint(sat(answer))"
                try:
                    output, _ = execute_py_code(full_code)
                except Exception as e:
                    # Previously this dropped into pdb (hanging unattended runs) and
                    # then searched the *solution source* for "True", which could
                    # score a failing solution as correct. Record it as a miss.
                    print(f"P3_Test example {idx}: execution failed: {e}")
                    result = None
                else:
                    # A passing solution prints True; normalize it to "" so it can
                    # match the benchmark target (presumably "" -- TODO confirm).
                    result = "" if "True" in output else output
            predicts.append(result)

    n_targets, n_predicts = len(targets), len(predicts)
    cnt = min(n_targets, n_predicts)
    if n_targets != n_predicts:
        print(f"Warning: {n_targets} benchmark targets but {n_predicts} predictions; "
              f"scoring the first {cnt}.")
    correct = sum(1 for tgt, pred in zip(targets, predicts) if tgt == pred)
    accuracy = correct / cnt if cnt else 0.0
    print(f"Accuracy: {accuracy}")
    return accuracy

def safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or ``0.0`` when the denominator is 0.

    Running metrics are printed after every batch, and early batches can
    easily have e.g. no positive evaluator predictions; a plain division
    would abort the whole run with ZeroDivisionError.
    """
    return numerator / denominator if denominator else 0.0


def print_metrics(metric_dict, naive=False, has_eval=False):
    """Print the running metrics accumulated in ``metric_dict``.

    Args:
        metric_dict: Counters accumulated by the ``__main__`` loop below.
        naive: Also report the naive-model accuracy (``--naive``).
        has_eval: When False, also report selection accuracy and the
            evaluator's confusion-matrix statistics.
    """
    m = metric_dict
    print("Metrics:")

    print(f"Accuracy-any: {safe_ratio(m['any_success'], m['counter'])}")
    print(f"All accuracy: {safe_ratio(m['success'], m['all_counter'])}")
    print(f"All non-sat: {safe_ratio(m['non-sat'], m['all_counter'])}")
    print(f"All error: {safe_ratio(m['error'], m['all_counter'])}")

    print(f"Corrected: {safe_ratio(m['corrected'], m['modify_counter'])}")
    print(f"Wronged: {safe_ratio(m['wronged'], m['modify_counter'])}")

    if naive:
        print(f"Accuracy0: {safe_ratio(m['naive'], m['counter'])}")

    if not has_eval:
        tp, tn, fp, fn = m['tp'], m['tn'], m['fp'], m['fn']
        print(f"Accuracy: {safe_ratio(m['selected'], m['counter'])}")
        print(f"Eval Acc: {safe_ratio(tp + tn, tn + tp + fn + fp)}")
        print(f"Eval Prec: {safe_ratio(tp, tp + fp)}")
        print(f"Eval Recall: {safe_ratio(tp, tp + fn)}")
        print(f"Eval Neg Recall: {safe_ratio(tn, fp + tn)}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--task_name', type=str, default='GameOf24', choices=[
        'GameOf24', 'CheckmateInOne', 'WordSorting', 'P3_Test', 'Sonnet'
    ])
    parser.add_argument('--model_id', type=str, default='gpt-4o', help='Input model id, or local model path')
    parser.add_argument('--api_key', default=None, type=str, help='input your api key here')
    parser.add_argument('--api_base', default=None, type=str, help='input your api base here')
    parser.add_argument('--eval_path', default=None, type=str, help='whether only evaluate the outputs on the given path')
    parser.add_argument('--naive', action='store_true', help='whether to add test to naive model')
    parser.add_argument('--has_eval', action='store_true', help='whether has a separate evaluator')
    parser.add_argument('--parallel', action='store_true', help='whether to run parallely')
    args = parser.parse_args()

    task = args.task_name
    model_id = args.model_id
    api_key = args.api_key
    api_base = args.api_base
    eval_path = args.eval_path

    # Each task has its own workflow module; import only the selected one.
    if task == 'GameOf24':
        if 'Qwen' in model_id:
            from workflows.workflowQwenG24 import Workflow
        else:
            from workflows.workflowg2 import Workflow
    elif task == 'CheckmateInOne':
        from workflows.workflowcm1 import Workflow
    elif task == 'WordSorting':
        from workflows.workflowsort2 import Workflow
    elif task == 'P3_Test':
        from workflows.workflow3 import Workflow
    elif task == 'Sonnet':
        from workflows.workflowsonnet1 import Workflow
    else:
        # Unreachable while argparse restricts --task_name via `choices`.
        raise SystemExit(1)
    print(task)

    output_dir = 'outputs'
    os.makedirs(output_dir, exist_ok=True)

    prompts = {
        'GameOf24': GameOf24,
        'CheckmateInOne': CheckmateInOne,
        'WordSorting': WordSorting,
        'P3_Test': P3_Test,
        'Sonnet': Sonnet,
    }

    benchmark_dir = './benchmarks'

    user_prompt = prompts[task]
    benchmark_path = os.path.join(benchmark_dir, f"{task}.jsonl")

    if eval_path:
        # Evaluation-only mode: recover the task from an output file named
        # "ROT_<task>_<timestamp>.jsonl" (the output_path format below).
        # eval_results() falls back to this module-level `task`.
        task = os.path.basename(eval_path).split('_')[1]
        benchmark_path = os.path.join(benchmark_dir, f"{task}.jsonl")
        eval_results(benchmark_path, eval_path)
        raise SystemExit

    workflow = Workflow(model_id=model_id, api_key=api_key, api_base=api_base)

    timestamp_str = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
    output_path = f'{output_dir}/ROT_{task}_{timestamp_str}.jsonl'

    metric_dict = {
        "success": 0,
        "non-sat": 0,
        "error": 0,
        "all_counter": 0,
        "counter": 0,
        "any_success": 0,

        "corrected": 0,
        "wronged": 0,
        "modify_counter": 0,

        "tp": 0,
        "tn": 0,
        "fp": 0,
        "fn": 0,
        "unsure": 0,

        "naive": 0,
        "selected": 0,
    }
    with open(benchmark_path, encoding='utf-8') as fin:
        lines = fin.readlines()

    start = time.time()
    batch_size = 6
    with open(output_path, 'w', encoding='utf-8') as fout:
        for i in tqdm(range(0, len(lines), batch_size)):
            records = [json.loads(line) for line in lines[i:i + batch_size]]
            targets = [record['target'] for record in records]
            queries = []
            questions = []
            for record in records:
                q = record['input']
                if task == 'WordSorting':
                    q = q.replace('Sort the following words alphabetically: List:', '')
                question = f"{user_prompt}{q}"
                if task == "CheckmateInOne":
                    side = is_white_move(q)
                    if side != "Malformed":
                        question = f"{question}\nNow is {side}'s turn."
                queries.append(q)
                questions.append(question)

            jobs = list(zip(questions, queries, targets))
            if not args.parallel:
                outputs = [
                    workflow.run(question=question, query=query, target=target)
                    for question, query, target in jobs
                ]
            else:
                # One thread per example in the batch (no delay between jobs).
                outputs = Parallel(n_jobs=batch_size, verbose=100, prefer="threads")(
                    delayed(workflow.run)(
                        question=question,
                        query=query,
                        target=target,
                    )
                    for question, query, target in jobs
                )

            for d in outputs:
                metric_dict['success'] += d['success']
                metric_dict['non-sat'] += d['non-sat']
                metric_dict['error'] += d['error']
                metric_dict['all_counter'] += d['num_attempts']

                metric_dict['corrected'] += d['corrected']
                metric_dict['wronged'] += d['wronged']
                # No "bump to 1 when zero" hack here: it permanently inflated the
                # denominator. print_metrics() handles a zero denominator.
                metric_dict['modify_counter'] += d['corrected'] + d['wronged'] + d['others']

                if args.naive:
                    metric_dict['naive'] += 1 if d['naive'] else 0
                if not args.has_eval:
                    metric_dict['selected'] += 1 if d['selected'] else 0
                    metric_dict['tp'] += d['tp']
                    metric_dict['tn'] += d['tn']
                    metric_dict['fp'] += d['fp']
                    metric_dict['fn'] += d['fn']
                    metric_dict['unsure'] += d['unsure']

                if d['success'] != 0:
                    metric_dict["any_success"] += 1
                metric_dict['counter'] += 1

                # Persist each result immediately so partial runs are kept.
                d['elapsed'] = time.time() - start
                fout.write(json.dumps(d) + '\n')
                fout.flush()

            print(output_path)
            print_metrics(metric_dict, naive=args.naive, has_eval=args.has_eval)
