# Issues identified in the generated code outputs:
# 1. The program returns a hardcoded answer.
# 2. The program does not perform the correct objective (e.g. it returns a number when it should return a letter).
# 3. The program has syntax errors.
# 4. The program reads the symbols incorrectly.
# 5. No code was output.
# 6. No symbols were output.
# 7. The code contains placeholders.
# 8. The return value is input-independent and copied verbatim from a comment.

import ast
import re
import numpy as np
from typing import Dict, List, Tuple, Any
from ast import literal_eval
import argparse
from src.symbol_mapping import LLMNet, PromptedLLM
from src.utils import RawInput
from run_experiments import APIModel, get_dataset
from src.function_evaluation import python_eval
import json

def has_hardcoded_answers(model: APIModel, question: str, code: str) -> bool:
    """Ask the judge LLM whether ``code`` is trivial or returns a hardcoded answer.

    Args:
        model: Judge model wrapped in a ``PromptedLLM``.
        question: The task the generated code is meant to solve.
        code: The generated program to inspect.

    Returns:
        True when the judge's reply is "1" (hardcoded/trivial), False otherwise.
    """
    # Look for hardcoded/trivial code
    llmnet = PromptedLLM(
        model,
        """Examine the provided code which solves the question and determine if it is trivial or if the return value is hardcoded.
For instance, a typical hardcoded function consists of comments with a final return statement, so the code itself is not actually performing any computation.
The code should also be considered hardcoded if it introduces hardcoded values in the code which were not specified in the question or are not easy to determine.
The code can also be trivial if the performed computation is not actually used to produce the final output.
Code which is highly specific to a particular input (or whose logic is hardcoded) is not considered hardcoded as long as the final output is computed by the code.
This means that the logic of the code can itself be hardcoded (as in the input must be a specific value for the code to work) but the return value is not hardcoded, so the code is not trivial.
After examining the code carefully, please answer with 1 if the code is hardcoded/trivial and 0 otherwise."""
    )
    question_code = f"Question: {question}\nCode: {code}"
    verdict = llmnet.forward(RawInput(text_input=question_code, image_input=None))
    # Ignore surrounding whitespace so a reply such as "1\n" still counts.
    return str(verdict).strip() == "1"

def has_input_processing_errors1(model: APIModel, question: str, code: str) -> bool:
    """Ask the judge LLM whether a value hardcoded in ``code`` misstates the question.

    This targets "Erroneous Hardcoded Value" errors. A constant, literal or data
    structure copied from the question into the code is wrong, mistranscribed, or
    mismatched, even though the surrounding logic may be fine.

    Args:
        model: Judge model wrapped in a ``PromptedLLM``.
        question: The task the generated code is meant to solve.
        code: The generated program to inspect.

    Returns:
        True when the judge's reply is "Yes", False otherwise.
    """
    llmnet = PromptedLLM(
        model,
        """Examine the provided code which solves the question and determine if it has any issues due to an **Erroneous Hardcoded Value**.
This type of issue occurs when a constant, string literal, number, or data structure directly embedded (hardcoded) in the code, which is intended to represent information from the question is factually incorrect, mistranscribed, or mismatched with the source data. The program's surrounding logic might be sound if this specific hardcoded value were correct.

**Instructions:**
1.  Carefully compare the values hardcoded in the code against the information presented in the question.
2.  Determine if there is a clear instance where a hardcoded value in the code is a direct but incorrect representation of a piece of information from the problem. If you cannot be absolutely certain, please err on the side of caution and answer "No".
3.  Answer "Yes" if the code contains this issue, otherwise answer "No".
"""
    )
    question_code = f"Question: {question}\nCode: {code}"
    verdict = llmnet.forward(RawInput(text_input=question_code, image_input=None))
    # Ignore surrounding whitespace so a reply such as "Yes\n" still counts.
    return str(verdict).strip() == "Yes"

def has_input_processing_errors2(model: APIModel, question: str, code: str, symbols=None) -> bool:
    """Ask the judge LLM whether ``code`` needlessly parses raw, unstructured input.

    The issue is flagged only when the code itself parses unstructured data
    (free-form text, images, logs) and the same information could have been
    trivially extracted by a human or an LLM beforehand.

    Args:
        model: Judge model wrapped in a ``PromptedLLM``.
        question: The task the generated code is meant to solve.
        code: The generated program to inspect.
        symbols: Optional symbols passed to the program. When not ``None``,
            they are shown to the judge as ``symbols = <value>``.

    Returns:
        True when the judge's reply is "Yes", False otherwise.
    """
    llmnet = PromptedLLM(
        model,
        """You are an expert at code analysis and finding problems in code.

**Your Task:** Identify instances where code performs complex parsing of raw, unstructured data when it would have been easier to use manual human input or LLM-based extraction.

**Core Problem:** This issue arises when software programmatically parses unstructured raw input (e.g., free-form text, images, raw log files) where the specific information being extracted could have been trivially obtained and provided as direct input, either by a human or through a Large Language Model (LLM). This is particularly relevant in code generation scenarios where a model might generate parsing logic for data that could have been easily extracted beforehand without any custom code.

**Key Distinction:** This is not about condemning all in-code parsing. Parsing is essential for well-defined, structured formats (e.g., JSON, XML, CSV). The focus here is on situations where the *effort and complexity* of the in-code parsing significantly outweigh the effort of manually or LLM-extracting the target information from the raw source.

**Example Scenarios:**

* **Image-to-Text:** Code utilizes a library like OpenCV to perform Optical Character Recognition (OCR) on an image to extract a short, easily readable string (e.g., a name, a date, a simple label). A human could have instantly transcribed this text, or an LLM could have extracted it with a simple prompt.
* **Log File Snippet:** Code implements custom string matching and manipulation to find a specific error code or timestamp within a small, human-readable log snippet. A human could quickly identify this information, or an LLM could extract it.

**Non-Example Scenario:**

* **Large CSV Processing:** Code parses string representing a CSV file with thousands of rows and numerous columns to perform data aggregation. While a human *could* theoretically do this, it's not trivial and is highly error-prone and inefficient compared to programmatic parsing. This is appropriate use of in-code parsing.
* **Pre-extracted and Hardcoded Data:** A long paragraph of text describes a user's inventory: "I have 3 apples that I bought from the store. My friend gave me 2 more apples yesterday, which was very kind. I also found 5 bananas in the kitchen, and my brother has 7 oranges."
  Code for evaluation:
  apples_from_store = 3
  apples_from_friend = 2
  total_apples_i_have = apples_from_store + apples_from_friend
Condition 1 is not met. The provided Python code does not perform any parsing of the raw paragraph of text.
---

**Issue Definition**

The code exhibits this issue if **both** of the following conditions are met:

1.  **Programmatic Parsing of Raw, Unstructured Data:**
    * The code **explicitly** processes raw input that lacks a predefined schema using functions such as string manipulation, regex, image manipulation, etc. Examples include parsing free-form text, interpreting the content of an image file directly, or sifting through unstructured log data.
    * This explicitly **excludes** parsing of data already in a structured format (e.g., JSON, formatted lists, XML, CSV, YAML, protocol buffers). If the data is a string, but it is in a structured format, it is not considered raw input.
    * If all information is already hardcoded into the code or program inputs, **DO NOT** consider this as raw input parsing.

2.  **Feasibility of Simpler Preprocessing:**
    * The specific information targeted by the parsing logic could be **trivially extracted by a human** through manual inspection and input, OR could be **trivially extracted by a Large Language Model (LLM)** with a straightforward prompt.
    * "Trivial" implies that the extraction process for a human would be quick, obvious, and require minimal cognitive load. For an LLM, it means a simple query without complex prompt engineering would suffice.
    * Concurrently, the programmatic parsing implemented in the code is noticeably more complex, potentially more error-prone, or less robust than the manual/LLM alternative due to the unstructured nature of the input.

---

**Instructions for Evaluation:**

1.  Carefully examine the provided code and the problem it aims to solve.
2.  Assess if **Condition 1** is met: Is the code performing parsing on raw, unstructured input (as defined above) or is the parsing done beforehand/the input is provided already in structured form? If the input is already structured, then the code is not performing parsing on raw input.
3.  If Condition 1 is met, then assess if **Condition 2** is also met: Would it be trivial for a human to manually extract the target information, OR for an LLM to extract it? And is this manual/LLM approach significantly simpler or more reliable than the implemented code's parsing approach?
4.  If **both Condition 1 and Condition 2 are true**, your answer should be "Yes".
5.  Otherwise (if either condition is false), your answer should be "No".
"""
    )
    if symbols is not None:
        question_code = f"Question: {question}\nSymbols:\nsymbols = {symbols}\nCode:\n{code}"
    else:
        question_code = f"Question: {question}\nCode:\n{code}"
    print(question_code)
    verdict = llmnet.forward(RawInput(text_input=question_code, image_input=None))
    # Ignore surrounding whitespace so a reply such as "Yes\n" still counts.
    return str(verdict).strip() == "Yes"

def has_exception(model: APIModel, question: str, code: str, symbols: str) -> bool:
    """Report whether running ``code`` through ``python_eval`` fails.

    Only the ``code`` string itself is executed. ``model``, ``question`` and
    ``symbols`` are accepted so the signature matches the other checks, but
    they are not used here.

    Returns:
        True if ``python_eval`` raises, or reports a non-``None`` error as the
        third element of its result. False otherwise.
    """
    try:
        _, _, error = python_eval(code)
    except Exception:
        # A crash inside the evaluator itself is also treated as a failure.
        return True
    return error is not None

def has_wrong_return_type(model: APIModel, question: str, code: str) -> bool:
    """Ask the judge LLM whether ``code`` returns the wrong type of value.

    For example, the code returns a number when the question asks for a letter.

    Args:
        model: Judge model wrapped in an ``LLMNet``.
        question: The task the generated code is meant to solve.
        code: The generated program to inspect.

    Returns:
        True when the judge's reply is "1", False otherwise.
    """
    eval_input = f"Question: {question}\nCode: {code}"
    llmnet = LLMNet(
        model,
        "a Python function",
        "1 if the return value of the function is the wrong type based on the question (e.g. number instead of letter) and 0 otherwise",
        few_shot=False)
    verdict = llmnet.forward(RawInput(text_input=eval_input, image_input=None))
    # Ignore surrounding whitespace so a reply such as "1\n" still counts.
    return str(verdict).strip() == "1"

def has_syntax_errors(model: "APIModel", code: str) -> bool:
    """Return True if ``code`` cannot be parsed as Python source.

    ``model`` is unused. It is kept so the signature matches the other checks.

    ``ast.parse`` raises ``SyntaxError`` for malformed source. Before Python 3.12
    it raises ``ValueError`` instead when the source contains NUL bytes. Both
    mean the program cannot be compiled, so both are reported as syntax errors
    rather than crashing the evaluation run.
    """
    try:
        ast.parse(code)
    except (SyntaxError, ValueError):
        return True
    return False

def has_wrong_symbol_reading(model: APIModel, symbols: str, code: str) -> bool:
    """Run ``solve(symbols)`` and flag errors whose message mentions ``symbols``.

    The program handed to ``python_eval`` is three parts joined by newlines:
    ``symbols = <repr of symbols>``, then ``code``, then
    ``answer = solve(symbols)``. ``model`` is unused.

    Returns:
        True if the string form of the reported error contains the substring
        "symbols", False otherwise.
    """
    # NOTE: substring matching on the error text is a heuristic. Any error
    # message mentioning "symbols" is attributed to bad symbol reading.
    program = "\n".join(["symbols = " + repr(symbols), code, "answer = solve(symbols)"])
    _, _, error = python_eval(program)
    # str(None) == "None", so a run without an error never matches.
    return "symbols" in str(error)

def has_no_output_code(code: str) -> bool:
    """Return True when no program was extracted from the model output.

    The log readers in this file store the string ``"None"`` when no code could
    be extracted. A real ``None`` and an empty or whitespace-only code block
    (e.g. an empty ```python fence) are treated the same way, since there is no
    program to evaluate in any of these cases.
    """
    if code is None:
        return True
    return isinstance(code, str) and code.strip() in ("", "None")

def has_no_output_symbols(symbols: str) -> bool:
    """Return True when no symbols were extracted from the model output.

    The log readers in this file store the string ``"None"`` when no symbols
    are available. A real ``None`` and an empty or whitespace-only string are
    treated the same way. Non-string values (e.g. an already-parsed object)
    count as present.
    """
    if symbols is None:
        return True
    return isinstance(symbols, str) and symbols.strip() in ("", "None")

def has_placeholders(model: APIModel, code: str) -> bool:
    """Ask the judge LLM whether ``code`` still contains placeholders.

    Placeholders are, for example, TODOs or comments asking the user to fill in
    missing code. Only the code is shown to the judge, not the question.

    Returns:
        True when the judge's reply is "1", False otherwise.
    """
    llmnet = LLMNet(
        model,
        "a Python function",
        "1 if the code contains placeholders such as TODOs or a related comment that requires the user to fill in the code, and 0 otherwise",
        few_shot=False)
    verdict = llmnet.forward(RawInput(text_input=code, image_input=None))
    # Ignore surrounding whitespace so a reply such as "1\n" still counts.
    return str(verdict).strip() == "1"

def return_none(model: APIModel, symbols: str, code: str) -> bool:
    """Check whether ``solve(symbols)`` runs cleanly but produces ``"None"``.

    The program handed to ``python_eval`` is three parts joined by newlines:
    ``symbols = <repr of symbols>``, then ``code``, then
    ``answer = solve(symbols)``. ``model`` is unused.

    Returns:
        True only if the first element of ``python_eval``'s result equals the
        string "None" and no error was reported.
    """
    program = "\n".join(["symbols = " + repr(symbols), code, "answer = solve(symbols)"])
    output, _, error = python_eval(program)
    return bool(output == "None" and error is None)

def evaluate_our_code(model: APIModel, question: str, code: str, symbols: str) -> Dict[str, bool]:
    """Run all 11 checks for a program that is paired with symbols.

    Besides the static and LLM-judged checks, this includes the ones that
    execute ``solve(symbols)`` (``has_wrong_symbol_reading`` and
    ``return_none``) and the checks for missing code or symbols.

    Args:
        model: Judge model used by the LLM-based checks.
        question: The task the generated code is meant to solve.
        code: The generated program.
        symbols: The symbols extracted alongside the program.

    Returns:
        Mapping from check name to whether that issue was detected. The checks
        run in the listed order.
    """
    return {
        "has_hardcoded_answers": has_hardcoded_answers(model, question, code),
        "has_wrong_return_type": has_wrong_return_type(model, question, code),
        "has_syntax_errors": has_syntax_errors(model, code),
        "has_wrong_symbol_reading": has_wrong_symbol_reading(model, symbols, code),
        "has_no_output_code": has_no_output_code(code),
        "has_no_output_symbols": has_no_output_symbols(symbols),
        "has_placeholders": has_placeholders(model, code),
        "return_none": return_none(model, symbols, code),
        "has_wrong_hardcoded_value": has_input_processing_errors1(model, question, code),
        "has_wrong_unstructured_data_processing": has_input_processing_errors2(model, question, code, symbols),
        "has_exception": has_exception(model, question, code, symbols),
    }

def evaluate_code(model: APIModel, question: str, code: str, symbols: str) -> Dict[str, bool]:
    """Run the 7 checks that apply to a standalone program without separate symbols.

    The symbol-dependent checks are skipped, and the unstructured-parsing judge
    is called without symbols. ``symbols`` is only forwarded to
    ``has_exception``.

    Args:
        model: Judge model used by the LLM-based checks.
        question: The task the generated code is meant to solve.
        code: The generated program.
        symbols: Value recorded for this sample, passed through to ``has_exception``.

    Returns:
        Mapping from check name to whether that issue was detected. The checks
        run in the listed order.
    """
    return {
        "has_hardcoded_answers": has_hardcoded_answers(model, question, code),
        "has_wrong_return_type": has_wrong_return_type(model, question, code),
        "has_syntax_errors": has_syntax_errors(model, code),
        "has_placeholders": has_placeholders(model, code),
        "has_wrong_hardcoded_value": has_input_processing_errors1(model, question, code),
        "has_wrong_unstructured_data_processing": has_input_processing_errors2(model, question, code),
        "has_exception": has_exception(model, question, code, symbols),
    }

def main_iter(args):
    """Evaluate every generation recorded for each sample and save the issues.

    Reads ``outputs_gen_1_temp_0.0.txt`` from ``logs/<model>/<dataset>/<method>``.
    Each line is parsed with ``literal_eval``, and its third element holds the
    outputs. When that record has ``all_outputs``, each generation contributes
    one (code, symbols) pair. Otherwise a single pair is taken from
    ``program``/``output`` and ``symbols``. Results are written to
    ``code_eval_iter_results.json`` in the same directory.

    Relies on the module-level globals ``test_data`` (question text) and
    ``model`` (judge), which are set in the ``__main__`` block.
    """
    log_dir = f"logs/{args.model}/{args.dataset}/{args.method}"

    # Read code from the log file
    code_outputs = []
    questions = []
    symbols = []
    with open(f"{log_dir}/outputs_gen_1_temp_0.0.txt") as f:
        for i, line in enumerate(f):
            out = literal_eval(line)
            questions.append(test_data[i][0][1])
            record = out[2]
            if 'all_outputs' in record:
                codes = []
                symbs = []
                for gen in record['all_outputs']:
                    # Append exactly one code entry and one symbols entry per
                    # generation so codes[k] and symbs[k] stay paired in the zip below.
                    python_blocks = re.findall(r"```python(.*?)```", gen, re.DOTALL)
                    codes.append(python_blocks[-1] if python_blocks else "None")

                    json_blocks = re.findall(r"```json(.*?)```", gen, re.DOTALL)
                    if json_blocks:
                        symbs.append(json_blocks[-1])
                    else:
                        # Without a JSON block, carry the previous generation's symbols forward.
                        symbs.append(symbs[-1] if symbs else "None")
                code_outputs.append(codes)
                symbols.append(symbs)
            else:
                if 'program' in record:
                    code_outputs.append([record['program']])
                elif 'output' in record:
                    text = record['output']
                    try:
                        if "```python" in text:
                            code = re.findall(r"```python(.*?)```", text, re.DOTALL)[-1]
                        elif "```" in text:
                            code = re.findall(r"```(.*?)```", text, re.DOTALL)[-1]
                        else:
                            code = "None"
                    except (IndexError, TypeError):
                        # Unterminated fence (no regex match) or non-string output.
                        code = "None"
                    code_outputs.append([code])
                else:
                    code_outputs.append(["None"])

                symbols.append([record.get('symbols', "None")])

    # Evaluate each (code, symbols) pair of every sample
    results = []
    for i, (question, output, symbol) in enumerate(zip(questions, code_outputs, symbols)):
        all_issues = [
            evaluate_our_code(model, question, prog, sym)
            for prog, sym in zip(output, symbol)
        ]
        results.append({
            'index': i,
            'code': output,
            'symbols': symbol,
            'issues': all_issues
        })

    # save results
    with open(f"{log_dir}/code_eval_iter_results.json", "w") as f:
        json.dump(results, f)


def main(args):
    """Evaluate one program per sample, save the issues and print a summary.

    Reads ``outputs_gen_1_temp_0.0.txt`` from ``logs/<model>/<dataset>/<method>``.
    Each line is parsed with ``literal_eval``, and its third element holds the
    outputs. Code comes from ``program``, or else from the last fenced block in
    ``output``. Symbols come from ``symbols``, and ``"None"`` marks anything
    missing. ``evaluate_code`` is used for ``--method code``, otherwise
    ``evaluate_our_code``. Results go to ``code_eval_results.json`` in the same
    directory.

    Relies on the module-level globals ``test_data`` (question text) and
    ``model`` (judge), which are set in the ``__main__`` block.
    """
    log_dir = f"logs/{args.model}/{args.dataset}/{args.method}"

    # Read code from the log file
    code_outputs = []
    questions = []
    symbols = []
    with open(f"{log_dir}/outputs_gen_1_temp_0.0.txt") as f:
        for i, line in enumerate(f):
            out = literal_eval(line)
            questions.append(test_data[i][0][1])
            record = out[2]
            if 'program' in record:
                code_outputs.append(record['program'])
            elif 'output' in record:
                text = record['output']
                try:
                    if "```python" in text:
                        code = re.findall(r"```python(.*?)```", text, re.DOTALL)[-1]
                    elif "```" in text:
                        code = re.findall(r"```(.*?)```", text, re.DOTALL)[-1]
                    else:
                        code = "None"
                except (IndexError, TypeError):
                    # Unterminated fence (no regex match) or non-string output.
                    code = "None"
                code_outputs.append(code)
            else:
                code_outputs.append("None")

            symbols.append(record.get('symbols', "None"))

    # Evaluate each code output
    evaluate = evaluate_code if args.method == "code" else evaluate_our_code
    results = []
    for i, (question, output, symbol) in enumerate(zip(questions, code_outputs, symbols)):
        issues = evaluate(model, question, output, symbol)
        print(issues)
        results.append({
            'index': i,
            'code': output,
            'issues': issues
        })

    # save results
    with open(f"{log_dir}/code_eval_results.json", "w") as f:
        json.dump(results, f)

    # Print results
    print("\nEvaluation Results:")
    print("-" * 80)
    if not results:
        # Avoid dividing by zero when the log file had no entries.
        print("No outputs were evaluated.")
        return

    summary = (
        ("has_hardcoded_answers", "is hardcoded"),
        ("has_wrong_return_type", "has wrong return type"),
        ("has_syntax_errors", "has syntax errors"),
    )
    for key, description in summary:
        count = sum(result['issues'][key] for result in results)
        print(f"Percent of code which {description}: {count / len(results):.2%}")
    

if __name__ == "__main__":
    # Set up the command-line interface.
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str)
    parser.add_argument("--model", type=str, default="meta-llama/Llama-3.2-90B-Vision-Instruct")
    parser.add_argument("--method", default="zs_cot")
    # Judge model for the LLM-based checks. The default keeps the previous behavior.
    parser.add_argument("--eval_model", type=str, default="gemini-2.0-flash")
    args = parser.parse_args()

    # `main` and `main_iter` read `model` and `test_data` as module-level globals.
    model = APIModel(args.eval_model)

    np.random.seed(0)
    data = get_dataset(args)

    # Seeded shuffle of (at most) the first 200 examples. `main`/`main_iter`
    # pair log line i with test_data[i], so this ordering must reproduce the
    # one used when the logs were written.
    test_data_ids = list(range(min(200, len(data))))
    shuf = np.random.permutation(test_data_ids)
    test_data = [data[int(i)] for i in shuf]

    # Per-generation evaluation is intentionally switched off. Set this to
    # True to run `main_iter` for the checks-based method.
    use_iterative_eval = False
    if use_iterative_eval and args.method == "gen_sym_reason_prog_checks":
        print("evaluating iteratively")
        main_iter(args)
    else:
        print("evaluating normally")
        main(args)