import sys
import os
import argparse
import pandas as pd
from tqdm import tqdm
import json
import re
import concurrent.futures
import traceback

sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from utils.azure_openai_api import get_openai_llm_response, get_openai_lrm_response
from utils.formula_utils import FormulaParser
from utils.exact_match_utils import judge_exact_match



def get_llm_response(model_name, prompt):
    """Send ``prompt`` to the backend that serves ``model_name``.

    Args:
        model_name: One of ``"gpt-4o"``, ``"gpt-4o-mini"`` or ``"o1"``.
        prompt: Prompt text forwarded unchanged to the backend helper.

    Returns:
        Whatever the selected backend helper returns for the prompt.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # The two chat models share a helper and differ only in the model id,
    # which is identical to the name the caller passed in.
    if model_name in ("gpt-4o", "gpt-4o-mini"):
        return get_openai_llm_response(prompt, model=model_name)

    # "o1" goes through the separate helper the file imports for it.
    if model_name == "o1":
        return get_openai_lrm_response(prompt, model='o1')

    raise ValueError(f"Unknown LLM: {model_name}")



def extract_solution(solution_str):
    """Return the text inside the last ``<answer>...</answer>`` pair.

    Args:
        solution_str: Full model response to search.

    Returns:
        The content of the final answer block with surrounding whitespace
        removed, or an empty string when no complete block is present.
    """
    # DOTALL lets an answer span several lines; the non-greedy group stops at
    # the first closing tag so consecutive blocks are captured separately.
    found = re.findall(r'<answer>(.*?)</answer>', solution_str, re.DOTALL)
    return found[-1].strip() if found else ''



def calculate_answer_score(pred, gold_answer, table_data, task):
    """Score one prediction against its gold answer by exact match.

    Args:
        pred: The predicted formula (``task == "formula"``) or the predicted
            answer itself (``task == "text"``).
        gold_answer: Gold answer container; must provide ``.tolist()``.
        table_data: Table passed to ``FormulaParser.execute`` for formula
            predictions; unused for text predictions.
        task: ``"formula"`` or ``"text"``.

    Returns:
        ``(score, pred_results)`` where ``score`` is 1 on an exact match and
        0 otherwise, and ``pred_results`` is the list of predicted strings
        that was compared. Any failure while scoring (including an unknown
        ``task``) is printed and reported as ``(0, [])``.
    """
    try:
        if task == "formula":
            # Formula predictions are executed against the table first.
            raw_output = FormulaParser(pred).execute(table_data)
        elif task == "text":
            raw_output = pred
        else:
            raise ValueError(f"Unknown task: {task}")

        # judge_exact_match is given lists: split the prediction on ', '
        # and trim each piece.
        candidates = [piece.strip() for piece in str(raw_output).split(', ')]
        references = gold_answer.tolist()
        matched = judge_exact_match(candidates, references)
    except Exception as e:
        print(f"[Error] Evaluation error: {e}")
        return 0, []

    return (1 if matched else 0), candidates


def _build_zero_shot_prompt(raw_prompt, task, with_reasoning):
    """Turn a stored dataset prompt into the zero-shot prompt sent to the model.

    Keeps only the part from ``# Task`` up to the first ``<|im_end|>`` marker,
    then appends the instruction matching ``task`` and ``with_reasoning``.

    Raises:
        IndexError: If ``raw_prompt`` contains no ``# Task`` marker.
    """
    prompt = '# Task\n' + raw_prompt.split('# Task')[1].strip()
    prompt = prompt.split('<|im_end|>')[0].strip()

    if with_reasoning:
        if task == 'formula':
            prompt += '\nNow, write the spreadsheet formula with reasoning.\n<think>\n'
        elif task == 'text':
            prompt += '\nNow, give the answer with reasoning.\n<think>\n'
        return prompt

    # Without reasoning: remove the instructions that ask for a <think> block
    # and ask for the answer directly.
    prompt = prompt.replace("You first think about the reasoning process in the mind and then provides the user with the answer.", "You need to provide the user with the answer.")
    prompt = prompt.replace("Show your reasoning within <think> </think> tags.", "")
    prompt = prompt.replace("<think>\n[thinking process]\n</think>", "")
    if task == 'formula':
        prompt += '\nNow, write the spreadsheet formula directly.\n<answer>\n'
    elif task == 'text':
        prompt += '\nNow, give the answer directly.\n<answer>\n'
    return prompt


def _score_generation(generated_text, ground_truth, table, task):
    """Extract the answer from a model response and score it.

    Returns:
        ``(answer_text, pred_results, score)``. A response without an
        ``<answer>`` block, or whose answer is not a JSON object holding the
        expected key, gets ``pred_results == []`` and ``score == 0``.
    """
    answer_text = extract_solution(generated_text)
    if not answer_text:
        return answer_text, [], 0

    key = 'formula' if task == 'formula' else 'answer'
    try:
        pred = json.loads(answer_text)[key]
    except (json.JSONDecodeError, KeyError, TypeError) as e:
        # TypeError covers valid JSON that is not an object (a bare string,
        # list or number). Such an answer is a wrong answer, not an
        # evaluation failure, so it must score 0 rather than be dropped.
        print(f"[Error] JSON parsing error: {e}")
        return answer_text, [], 0

    score, pred_results = calculate_answer_score(pred, ground_truth, table, task)
    return answer_text, pred_results, score


def _as_json_list(values):
    """Return ``values`` as a list of native Python objects where possible.

    ``list()`` over a NumPy array yields NumPy scalar objects, and ``json``
    cannot serialize NumPy integer types. ``tolist()`` converts them to native
    Python values, which is also how ``calculate_answer_score`` reads the
    gold answer.
    """
    if hasattr(values, 'tolist'):
        return values.tolist()
    return list(values)


def _accuracy(samples):
    """Mean score over the non-empty sample records (0 when there are none)."""
    scores = [sample['score'] for sample in samples if sample]
    return sum(scores) / len(scores) if scores else 0


# main
def evaluate(model_name, task, dataset, with_reasoning, MAX_WORKERS=10):
    """Run zero-shot evaluation of one model on one dataset split.

    Reads ``data/processed_data/{task}/qwen/{dataset}/test.parquet``, queries
    the model for every sample through a thread pool, scores each response and
    writes the result to
    ``evaluation_outputs/zero_shot/{task}/{dataset}/{model_name}/{wr|wor}/eval.json``.

    If that file already exists, its non-empty per-sample records are reused
    and only the samples stored as ``{}`` (failed in the earlier run) are sent
    to the model again. Delete the file to force a full re-run.

    Args:
        model_name: Model identifier understood by ``get_llm_response``.
        task: ``"formula"`` or ``"text"``.
        dataset: Dataset directory name under the task directory.
        with_reasoning: Ask for reasoning before the answer when true,
            otherwise ask for the answer directly.
        MAX_WORKERS: Number of threads querying the model concurrently.

    Raises:
        ValueError: If an existing ``eval.json`` holds a different number of
            samples than the dataset.
    """
    data_path = f"data/processed_data/{task}/qwen/{dataset}/test.parquet"
    df = pd.read_parquet(data_path)
    num_samples = len(df)

    prompts = df['prompt'].tolist()
    ground_truths = [entry['ground_truth']['answer'] for entry in df['reward_model']]
    table_data = [entry['table'] for entry in df['extra_info']]

    mode_dir = 'wr' if with_reasoning else 'wor'
    output_dir = f"evaluation_outputs/zero_shot/{task}/{dataset}/{model_name}/{mode_dir}"
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, 'eval.json')

    # for re-run
    if os.path.exists(output_path):
        with open(output_path, 'r', encoding='utf-8') as f:
            sample_data = json.load(f)['sample_data']
        # Cached records are matched to dataset rows by position, so a length
        # mismatch would misalign them or leave rows unevaluated.
        if len(sample_data) != num_samples:
            raise ValueError(
                f"Cached results in {output_path} cover {len(sample_data)} samples "
                f"but the dataset has {num_samples}; delete the file to re-run from scratch."
            )
    else:
        # One independent dict per slot (``[{}] * n`` would alias a single dict).
        sample_data = [{} for _ in range(num_samples)]

    def worker(sample_id):
        """Evaluate one sample; return ``(sample_id, record)``, ``{}`` on failure."""
        cached = sample_data[sample_id]
        if cached:
            return sample_id, cached

        ground_truth = ground_truths[sample_id]
        try:
            prompt = _build_zero_shot_prompt(prompts[sample_id], task, with_reasoning)
            generated_text = get_llm_response(model_name, prompt)
            answer_text, pred_results, score = _score_generation(
                generated_text, ground_truth, table_data[sample_id], task
            )
        except Exception as e:
            # Best effort: a failed sample is stored as {} so that a later
            # re-run retries it, and it is left out of the accuracy.
            print(f"[Error] Evaluation error: {e}")
            traceback.print_exc()
            return sample_id, {}

        return sample_id, {
            'id': sample_id,
            'prompt': prompt,
            'ground_truth': _as_json_list(ground_truth),
            'response': generated_text,
            'pred_results': pred_results,
            'final_answer': answer_text,
            'score': score,
        }

    error_count = 0
    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        futures = [executor.submit(worker, sample_id) for sample_id in range(num_samples)]
        pbar = tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="[Evaluating]")
        for future in pbar:
            try:
                sample_id, record = future.result()
            except Exception:
                # The sample keeps its empty record, so count it as an error.
                traceback.print_exc()
                error_count += 1
                continue

            sample_data[sample_id] = record
            if not record:
                error_count += 1
            # Print intermediate results
            pbar.set_description(f"[Evaluating] Current Accuracy: {_accuracy(sample_data):.4f}")

    # Calculate and print final metrics
    final_accuracy = _accuracy(sample_data)
    print("-"*60)
    print("\nFinal Results:")
    print(f"Accuracy: {final_accuracy:.4f}")
    print(f"Error Count: {error_count}")
    print("-"*60)

    # Save results
    evaluation_results = {
        "model_name": model_name,
        "accuracy": final_accuracy,
        "error_count": error_count,
        "total_samples": len(sample_data),
        "sample_data": sample_data
    }
    with open(output_path, "w", encoding='utf-8') as f:
        json.dump(evaluation_results, f)




def main():
    """Parse command-line options and run the evaluation.

    All four options are read as strings. ``--with_reasoning`` is treated as
    true only when its value equals ``"true"`` case-insensitively; any other
    value disables reasoning.
    """
    cli = argparse.ArgumentParser()
    option_defaults = (
        ("--model_name", "gpt-4o"),
        ("--task", "formula"),
        ("--dataset", "hitab"),
        ("--with_reasoning", 'True'),
    )
    for flag, default in option_defaults:
        cli.add_argument(flag, type=str, default=default)
    options = cli.parse_args()

    use_reasoning = options.with_reasoning.lower() == "true"

    evaluate(options.model_name, options.task, options.dataset, use_reasoning)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()