import os
import json
import sys
import re
from argparse import ArgumentParser
from tqdm import tqdm
import concurrent.futures

sys.path.append("./")
# from utils.open_source import generate
from utils.gpt import generate

# Prompt template sent to the judge model. It is filled with str.format() in
# construct_prompt(), which supplies the three placeholders: {html_code},
# {task_instruction} and {expected_result}. The braces of the JSON example are
# doubled ({{ }}) so that str.format() emits them literally.
# extract_and_parse_json() relies on the model answering with a ```json fence,
# so keep the requested output format in sync with that parser.
VLM_JUDGE_PROMPT = """You are an expert web developer. Your task is to determine if a given web development task is feasible based on the provided HTML code.

**HTML Code:**
```html
{html_code}
```

**Task Description:**
{task_instruction}

**Expected Result:**
{expected_result}

Based on the code, can the task be completed to achieve the expected result? The result is only achievable if all elements required for the interaction are present in the HTML.

Please begin your response with the analysis. Then provide your answer in the following JSON format. 

```json
{{
  "feasible": <true or false>
}}
```
"""

# Running token-usage totals; main() adds each result's metadata counts here.
total_tokens = dict.fromkeys(
    ("prompt_token_count", "candidates_token_count", "thoughts_token_count"),
    0,
)

def parse_args(argv=None):
    """Parse the command-line options of the judge/evaluate script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script entry point relies on.

    Returns:
        argparse.Namespace with the attributes ``data_path``, ``output_dir``,
        ``model`` and ``label_path``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--data_path",
        type=str,
        default="WebDevJudge_Unit/data/3_generate_interact_description_gemini_prepare.jsonl",
        help="Path to a JSONL file with one item per line "
             "(each item provides 'question_id', 'task' and 'res_inference').",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="WebDevJudge_Unit/outputs",
        help="Directory where '<model>_judge.jsonl' is written and read back for evaluation.",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="DeepSeek-V3-0324",
        help="Name of the judge model; also used as the output file prefix.",
    )
    parser.add_argument(
        "--label_path",
        type=str,
        default="WebDevJudge_Unit/data/label.json",
        help="Path to a JSON file mapping '<question_id>_<task_id>' to a 0/1 label.",
    )
    return parser.parse_args(argv)

def construct_prompt(html_code, task_instruction, expected_result):
    """Build the single-turn chat payload for the feasibility judge.

    The three arguments are substituted into VLM_JUDGE_PROMPT and the
    resulting text is wrapped as one user message.

    Returns:
        A one-element list: ``[{"role": "user", "content": <prompt text>}]``.
    """
    placeholders = {
        "html_code": html_code,
        "task_instruction": task_instruction,
        "expected_result": expected_result,
    }
    user_message = {
        "role": "user",
        "content": VLM_JUDGE_PROMPT.format(**placeholders),
    }
    return [user_message]

def extract_html(response):
    """Return the HTML found in the last assistant message of a conversation.

    Args:
        response: Sequence of chat messages (dicts with "role" and
            "content"). Only the final message is inspected. "content" may be
            a string or a list of parts, in which case the first part that
            has a "text" entry is used.

    Returns:
        The stripped body of the first ```html fenced block in the message
        text. If the text contains no such fence, the whole text is returned
        unchanged (and a notice is printed). ``None`` is returned when the
        conversation is empty, the last message is not an assistant message,
        or it carries no text.
    """
    # An empty or missing conversation has no last message to inspect.
    if not response:
        return None

    entry = response[-1]
    if not isinstance(entry, dict) or entry.get("role") != "assistant":
        return None

    content = entry.get("content")
    if isinstance(content, list):
        # Multi-part content: take the first part that carries a "text" entry
        # instead of blindly indexing part 0, which fails on empty lists and
        # on leading non-text parts.
        content = next(
            (part["text"] for part in content
             if isinstance(part, dict) and "text" in part),
            None,
        )

    if not isinstance(content, str):
        return None

    match = re.search(r"```html(.*?)```", content, re.DOTALL)
    if match:
        return match.group(1).strip()

    # No fenced block: fall back to the raw text so the caller can still try it.
    print(f"No match found for {content}")
    return content

def extract_and_parse_json(response_str):
    """Parse the first ```json fenced block of a model response.

    Args:
        response_str: Raw text returned by the model.

    Returns:
        The decoded JSON object (a dict) with an extra "reasoning" key that
        holds the stripped text preceding the fenced block.

    Raises:
        ValueError: If the response is not a string, contains no ```json
            fence, the fenced text is not valid JSON (json.JSONDecodeError is
            a ValueError subclass), or the JSON value is not an object.
    """
    # A non-string response (e.g. None) would make re.search raise TypeError;
    # report it as the same "no JSON" condition instead.
    if not isinstance(response_str, str):
        raise ValueError("No JSON found in response")

    match = re.search(r'```json(.*?)```', response_str, re.DOTALL)
    if not match:
        raise ValueError("No JSON found in response")

    parsed_json = json.loads(match.group(1).strip())
    # Only a JSON object can receive the "reasoning" key below; lists, numbers
    # or null are rejected explicitly rather than failing with TypeError.
    if not isinstance(parsed_json, dict):
        raise ValueError(
            f"Expected a JSON object in response, got {type(parsed_json).__name__}"
        )

    parsed_json['reasoning'] = response_str[:match.start()].strip()
    return parsed_json

def process_item(item, model):
    """Ask the judge model whether one item's task is feasible in its HTML.

    Args:
        item: Dict providing "res_inference" (the conversation to extract the
            HTML from), "question_id" and "task" (with "task",
            "expected_result" and "id").
        model: Model name forwarded to generate().

    Returns:
        ``None`` when no HTML could be extracted from the item. Otherwise a
        dict with "question_id", "task_id", "model_response" (parsed JSON plus
        "reasoning"), "raw_response" and "metadata". If no attempt yields
        parseable JSON, "model_response" is a fallback marking the task as not
        feasible, and "raw_response"/"metadata" come from the last attempt.

    Exceptions raised by generate() itself are not caught here and propagate
    to the caller.
    """
    html_code = extract_html(item["res_inference"])
    if not html_code:
        print(f"Skipping item {item.get('question_id')} due to missing HTML code.")
        return None

    task_instruction = item['task']['task']
    expected_result = item['task']['expected_result']
    question_id = item["question_id"]
    task_id = item['task']['id']

    prompt = construct_prompt(html_code, task_instruction, expected_result)

    generate_config = {"max_output_tokens": 8192, "temperature": 0.0}
    response = None
    metadata_response = None
    last_error = None
    max_retries = 5
    for i in range(max_retries):
        # Raise the temperature by 0.1 per retry so a repeated call can give a
        # different answer; i / 10 avoids float noise such as
        # 0.30000000000000004.
        generate_config["temperature"] = i / 10
        response, metadata_response = generate(model=model, messages=prompt, generation_config=generate_config)
        try:
            parsed_response = extract_and_parse_json(response)
        except Exception as e:
            # Keep the cause so the log says why parsing failed.
            last_error = e
            if i < max_retries - 1:
                print(f"Attempt {i+1} failed to parse JSON for item {question_id}: {e}; retrying...")
            continue
        return {
            "question_id": question_id,
            "task_id": task_id,
            "model_response": parsed_response,
            "raw_response": response,
            "metadata": metadata_response
        }

    # Every attempt failed to parse: record a "not feasible" placeholder.
    print(f"Item {question_id} failed to parse JSON after {max_retries} attempts: {last_error}")
    return {
        "question_id": question_id,
        "task_id": task_id,
        "model_response": {"feasible": False, "reasoning": "Failed to parse model output."},
        "raw_response": response,
        "metadata": metadata_response
    }

def main(args):
    """Judge every item of the input JSONL concurrently and save the results.

    Reads ``args.data_path`` (one JSON item per line), runs process_item() on
    each item in a thread pool, adds the token usage reported in each result's
    metadata to the module-level ``total_tokens``, and writes the results to
    ``<args.output_dir>/<args.model>_judge.jsonl`` in completion order.

    Args:
        args: Namespace with ``data_path``, ``output_dir`` and ``model``.
    """
    # Explicit UTF-8: the output is written with ensure_ascii=False, so the
    # platform default encoding must not be relied upon.
    with open(args.data_path, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline) instead of failing on them.
        items = [json.loads(line) for line in f if line.strip()]

    print(f"Processing {len(items)} items")
    os.makedirs(args.output_dir, exist_ok=True)

    results = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=128) as executor:
        future_to_item = {
            executor.submit(process_item, item, args.model): item
            for item in items
        }
        completed = concurrent.futures.as_completed(future_to_item)
        for future in tqdm(completed, total=len(future_to_item)):
            try:
                result = future.result()
            except Exception as e:
                # One failing item must not abort the run; say which one failed.
                failed_item = future_to_item[future]
                failed_id = failed_item.get("question_id") if isinstance(failed_item, dict) else None
                print(f"Error processing item {failed_id}: {e}")
                continue

            if not result:
                continue
            results.append(result)

            # Metadata can be missing, and individual counts can be absent or
            # None; count those as zero rather than raising after the result
            # has already been recorded.
            metadata = result.get("metadata")
            if not isinstance(metadata, dict):
                continue
            for counter in total_tokens:
                total_tokens[counter] += metadata.get(counter) or 0

    print(f"Total tokens: {total_tokens}")
    output_path = os.path.join(args.output_dir, f"{args.model}_judge.jsonl")
    with open(output_path, "w", encoding="utf-8") as f:
        for result in results:
            f.write(json.dumps(result, ensure_ascii=False) + "\n")

def _feasible_to_binary(value):
    """Map a "feasible" value to 1/0, reading "true"/"false" strings literally."""
    if isinstance(value, str):
        # A plain truthiness test would count the string "false" as feasible.
        return 1 if value.strip().lower() in ("true", "yes", "1") else 0
    return 1 if value else 0


def _percent(numerator, denominator):
    """Return numerator/denominator as a percentage (1 decimal); 0.0 if undefined."""
    if denominator == 0:
        return 0.0
    return round((numerator / denominator) * 100, 1)


def evaluate(args):
    """Score the saved judge predictions against the labels and print metrics.

    Reads the labels from ``args.label_path`` (a JSON object keyed by
    ``"<question_id>_<task_id>"``) and the predictions from
    ``<args.output_dir>/<args.model>_judge.jsonl``, then prints the confusion
    counts, precision, recall, F1-score and accuracy. A metric whose
    denominator is zero is printed as 0.0%. Predictions without a label are
    reported and left out of the metrics.

    Args:
        args: Namespace with ``label_path``, ``output_dir`` and ``model``.

    Returns:
        Dict mapping ``"<question_id>_<task_id>"`` to the predicted 1/0.
    """
    with open(args.label_path, "r", encoding="utf-8") as f:
        labels = json.load(f)

    results_path = os.path.join(args.output_dir, f"{args.model}_judge.jsonl")
    with open(results_path, "r", encoding="utf-8") as f:
        results = [json.loads(line) for line in f if line.strip()]

    res = {}
    for result in results:
        key = f"{result['question_id']}_{result['task_id']}"
        # A response without a "feasible" entry is treated as not feasible.
        res[key] = _feasible_to_binary(result["model_response"].get("feasible", False))

    TP = 0
    FP = 0
    TN = 0
    FN = 0
    unlabeled = 0
    for key, value in res.items():
        if key not in labels:
            unlabeled += 1
            continue
        label = labels[key]
        if value == 1 and label == 1:
            TP += 1
        elif value == 1 and label == 0:
            FP += 1
        elif value == 0 and label == 0:
            TN += 1
        elif value == 0 and label == 1:
            FN += 1
    if unlabeled:
        print(f"Warning: {unlabeled} prediction(s) have no label and were excluded from the metrics.")

    print(f"TP: {TP}, FP: {FP}, TN: {TN}, FN: {FN}")
    print(f"Precision: {_percent(TP, TP + FP)}%")
    print(f"Recall: {_percent(TP, TP + FN)}%")
    print(f"F1-score: {_percent(2 * TP, 2 * TP + FP + FN)}%")
    print(f"Accuracy: {_percent(TP + TN, TP + TN + FP + FN)}%")
    return res

if __name__ == "__main__":
    # Script entry point: parse the CLI options, echo the run configuration,
    # run the judge over the data set, then score the saved predictions.
    cli_args = parse_args()
    for caption, setting in (
        ("Model: ", cli_args.model),
        ("Output Dir: ", cli_args.output_dir),
    ):
        print(caption, setting)
    main(cli_args)
    evaluate(cli_args)