import sys
import os
import shutil
import pandas as pd
from tqdm import tqdm
import json
import re
import concurrent.futures
import traceback

sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from utils.azure_openai_api import get_openai_llm_response, get_openai_lrm_response
from utils.formula_utils import FormulaParser
from utils.exact_match_utils import judge_exact_match



def get_llm_response(model_name, prompt):
    """Send *prompt* to the model named *model_name* and return its response.

    "gpt-4o" and "gpt-4o-mini" are routed through ``get_openai_llm_response``;
    "o1" is routed through ``get_openai_lrm_response``.

    Raises:
        ValueError: if *model_name* is not one of the supported names.
    """
    if model_name == "o1":
        return get_openai_lrm_response(prompt, model='o1')
    if model_name == "gpt-4o-mini":
        return get_openai_llm_response(prompt, model='gpt-4o-mini')
    if model_name == "gpt-4o":
        return get_openai_llm_response(prompt, model='gpt-4o')
    raise ValueError(f"Unknown LLM: {model_name}")


def extract_solution(solution_str):
    """Return the stripped content of the last ``<answer>...</answer>`` span.

    The match is non-greedy and spans newlines (``re.DOTALL``). Returns ``''``
    when no complete answer tag pair is present.
    """
    found = re.findall(r'<answer>(.*?)</answer>', solution_str, re.DOTALL)
    return found[-1].strip() if found else ''


def extract_reasoning_and_answer(solution_str):
    """Return ``(reasoning, answer)`` from the last ``<think>`` and ``<answer>`` spans.

    Both values are stripped and may span multiple lines. If either tag pair
    is missing, ``('', '')`` is returned; a partial result is never produced.
    """
    think_spans = re.findall(r'<think>(.*?)</think>', solution_str, re.DOTALL)
    answer_spans = re.findall(r'<answer>(.*?)</answer>', solution_str, re.DOTALL)

    if not (think_spans and answer_spans):
        return '', ''
    return think_spans[-1].strip(), answer_spans[-1].strip()


def calculate_answer_score(pred, gold_answer, table_data, task):
    """Score a prediction against the gold answer using exact match.

    Args:
        pred: For ``task == "formula"``, a spreadsheet formula that is executed
            on *table_data* via ``FormulaParser``; for ``task == "text"``, the
            predicted answer itself (a string, or a list of values).
        gold_answer: Gold answer values. Array-likes exposing ``tolist()``
            (e.g. a NumPy array read from parquet), lists/tuples, and single
            scalars are all accepted.
        table_data: Table passed to ``FormulaParser.execute`` (formula task only).
        task: ``"formula"`` or ``"text"``.

    Returns:
        ``(answer_score, pred_results)``: ``answer_score`` is 1 when
        ``judge_exact_match`` accepts the prediction and 0 otherwise;
        ``pred_results`` is the list of stripped prediction strings that was
        compared, or ``[]`` if evaluation failed.

    Any error, including an unknown *task*, is printed and scored as 0 rather
    than raised, so a single bad sample does not abort a batch run.
    """
    try:
        if task == "formula":
            pred_results = FormulaParser(pred).execute(table_data)
        elif task == "text":
            pred_results = pred
        else:
            raise ValueError(f"Unknown task: {task}")

        # judge_exact_match compares lists of values. Sequence predictions are
        # converted element-wise: stringifying a list whole would yield
        # "['a', 'b']", whose ", "-split pieces keep the brackets and quotes
        # and can never match. Scalar/string predictions are split on ", " so
        # a multi-value answer written as one string still compares per value.
        if hasattr(pred_results, "tolist"):
            pred_results = pred_results.tolist()
        if isinstance(pred_results, (list, tuple)):
            pred_results = [str(result).strip() for result in pred_results]
        else:
            pred_results = [str(result).strip() for result in str(pred_results).split(', ')]

        # Accept NumPy arrays (parquet list columns), plain sequences and scalars.
        if hasattr(gold_answer, "tolist"):
            gold_results = gold_answer.tolist()
        elif isinstance(gold_answer, (list, tuple)):
            gold_results = list(gold_answer)
        else:
            gold_results = [gold_answer]

        answer_score = 1 if judge_exact_match(pred_results, gold_results) else 0

    except Exception as e:
        print(f"[Error] Evaluation error: {e}")
        pred_results = []
        answer_score = 0

    return answer_score, pred_results


def _format_hint_answer(answer):
    """Render a gold answer as comma-separated text for the prompt hint.

    Array-likes exposing ``tolist()`` (e.g. NumPy arrays loaded from parquet)
    and lists/tuples are joined with ", ", the separator that string
    predictions are split on when scored. ``str(ndarray)`` would instead give
    NumPy's bracketed, comma-less repr. Anything else is rendered with ``str``.
    """
    if hasattr(answer, "tolist"):
        answer = answer.tolist()
    if isinstance(answer, (list, tuple)):
        return ', '.join(str(value) for value in answer)
    return str(answer)


def _build_distill_prompt(prompt, ground_truth, task):
    """Turn a training prompt into a hinted distillation prompt.

    Keeps the text between the first ``# Task`` marker and the next one (if
    any), cut at the first ``<|im_end|>`` token. It then appends a hint that
    reveals the gold answer and an instruction ending with an open ``<think>``
    tag. Tasks other than "formula"/"text" get no hint.

    Raises:
        IndexError: if *prompt* contains no ``# Task`` marker.
    """
    prompt = '# Task\n' + prompt.split('# Task')[1].strip()
    prompt = prompt.split('<|im_end|>')[0].strip()
    hint = _format_hint_answer(ground_truth)

    if task == 'formula':
        prompt = prompt + f'\nHint: the answer is {hint}. The execution result of the generated formula should be the same as the hint.'
        prompt = prompt + '\nAvoid relying on the hint in your reasoning.\n'
        prompt = prompt + '\nNow, write the spreadsheet formula with reasoning.\n<think>\n'
    elif task == 'text':
        prompt = prompt + f'\nHint: the answer is {hint}. Your final answer should be the same as the hint.'
        prompt = prompt + '\nAvoid relying on the hint in your reasoning.\n'
        prompt = prompt + '\nNow, give the answer with reasoning.\n<think>\n'
    return prompt


# main
def generate_reasoning_data(task, dataset, MAX_WORKERS=20):
    """Distill hinted gpt-4o reasoning traces for one (task, dataset) split.

    Reads ``data/processed_data/{task}/qwen/{dataset}/train.parquet`` and asks
    gpt-4o to solve every prompt with the gold answer given as a hint. A
    response is kept only if its ``<answer>`` JSON scores 1 under
    ``calculate_answer_score`` and it contains both a ``<think>`` and an
    ``<answer>`` section.

    Results are written to ``data/distill_data/{task}/{dataset}/distill_data.json``
    as a list aligned with the parquet rows. Accepted samples hold
    ``"<think>...</think>\\n<answer>...</answer>"``; all others hold ``''``.

    If that JSON file already exists, its non-empty entries are reused, so a
    re-run only queries the model for samples that previously failed.

    Args:
        task: ``"formula"`` or ``"text"``.
        dataset: Dataset directory name (e.g. ``"wikitq"``).
        MAX_WORKERS: Number of concurrent request threads.

    Raises:
        ValueError: if an existing cache file has a different number of entries
            than the dataset, because its entries could no longer be matched
            to rows.
    """
    data_path = f"data/processed_data/{task}/qwen/{dataset}/train.parquet"
    df = pd.read_parquet(data_path)
    num_samples = len(df)

    output_dir = f"data/distill_data/{task}/{dataset}"
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, 'distill_data.json')

    # Resume support: reuse results from a previous run when present.
    if os.path.exists(output_path):
        with open(output_path, 'r', encoding='utf-8') as f:
            distill_data = json.load(f)
        if len(distill_data) != num_samples:
            raise ValueError(
                f"Cached {output_path} has {len(distill_data)} entries but "
                f"{data_path} has {num_samples} rows; delete the cache to regenerate."
            )
    else:
        distill_data = [''] * num_samples

    prompts = list(df['prompt'])
    ground_truths = [rm['ground_truth']['answer'] for rm in df['reward_model']]
    table_data = [info['table'] for info in df['extra_info']]

    def worker(idx):
        """Return ``(idx, solution)``; ``solution`` is '' unless the sample is accepted."""
        if distill_data[idx] != '':
            return idx, distill_data[idx]

        ground_truth = ground_truths[idx]
        try:
            prompt = _build_distill_prompt(prompts[idx], ground_truth, task)
            generated_text = get_llm_response('gpt-4o', prompt)

            # The prompt ends with an open "<think>\n" prefill. If the reply
            # continues from it without echoing the tag, restore the tag so
            # the reasoning can still be extracted below.
            if '<think>' not in generated_text:
                generated_text = '<think>\n' + generated_text

            score = 0
            answer_text = extract_solution(generated_text)
            if answer_text:
                try:
                    parsed = json.loads(answer_text)
                    pred = parsed['formula'] if task == 'formula' else parsed['answer']
                    score, _ = calculate_answer_score(pred, ground_truth, table_data[idx], task)
                except (json.JSONDecodeError, KeyError, TypeError) as e:
                    print(f"[Error] JSON parsing error: {e}")
                    score = 0

        except Exception as e:
            print(f"[Error] Evaluation error: {e}")
            traceback.print_exc()
            return idx, ''

        if score != 1:
            return idx, ''

        reasoning, answer = extract_reasoning_and_answer(generated_text)
        # A correct answer without a complete <think>/<answer> pair would
        # otherwise be saved as an empty "<think></think>" training trace.
        if not reasoning or not answer:
            return idx, ''
        return idx, f'<think>{reasoning}</think>\n<answer>{answer}</answer>'

    error_count = 0
    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        futures = [executor.submit(worker, idx) for idx in range(num_samples)]
        pbar = tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="[Distilling]")
        for future in pbar:
            try:
                idx, solution = future.result()
            except Exception:
                traceback.print_exc()
                error_count += 1
                continue

            distill_data[idx] = solution
            if solution == '':
                error_count += 1
            # Print intermediate results
            pbar.set_description(f"[Distilling] Current Error Count: {error_count}")

    # Calculate and print final metrics
    print("-"*60)
    print("\nFinal Results:")
    print(f"Error Count: {error_count}")
    print("-"*60)

    # Save results
    with open(output_path, "w", encoding='utf-8') as f:
        json.dump(distill_data, f)




def main():
    """Run ``generate_reasoning_data`` for every (task, dataset) pair.

    All datasets are processed for the "formula" task first, then for "text".
    """
    all_tasks = ['formula', 'text']
    all_datasets = ['wikitq', 'tabfact', 'finqa', 'hitab', 'multihiertt']

    jobs = [(task_name, dataset_name) for task_name in all_tasks for dataset_name in all_datasets]
    for task_name, dataset_name in jobs:
        generate_reasoning_data(task_name, dataset_name)


# Run the full distillation sweep only when executed as a script, not on import.
if __name__ == "__main__":
    main()