import re
import json
import ast
import operator
import tempfile
import logging
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from src.pddl import eval_solution_files
from src.dataset import TwentyFourGameDataset, OmniMathDataset
from src.llm_models import APIModel
from vllm import SamplingParams
# Module-level logger; the 24game checker below reports through it.
logger = logging.getLogger(__name__)
# NOTE: runs at import time and configures the root logger at INFO level
# (a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)


def get_equivalence(args):
    """Build the answer-equivalence checker for ``args.dataset``.

    The returned callable has the signature ``equiv(gt, pred, i)`` where
    ``gt`` is the ground-truth answer, ``pred`` is the model prediction and
    ``i`` is the sample index (only the blocksworld and 24game checkers read
    it).  All checkers return a truthy/falsy value, except the 24game checker
    which returns ``gt`` on success and ``-1`` on failure.

    Datasets with a dedicated checker: gsm8k, chartqa, blocksworld, bbh_sort,
    bbh_shuffle7, clevr, folio, longsort, sudoku, listsynthesis, aime, clutrr,
    24game and the ``bbeh_*`` family.  Any other dataset falls back to a plain
    ``gt == pred`` comparison.

    Args:
        args: Namespace-like object exposing ``dataset`` (str) and, for the
            24game dataset, ``twentyfourgame_n``.

    Returns:
        The checker callable described above.

    Raises:
        ValueError: If the 24game dataset cannot be instantiated.
    """
    # Local import: only needed to clean up blocksworld temporary files.
    import os

    dataset_name = args.dataset

    # Shared by the sort checkers: drop brackets/quotes, turn commas into spaces.
    list_markup = str.maketrans({"[": None, "]": None, "'": None, '"': None, ",": " "})

    def strip_list_markup(text):
        return text.translate(list_markup).strip()

    if dataset_name == "gsm8k":
        # A "-" is part of the number unless it directly follows a digit or ")"
        # (in which case it is a subtraction sign, e.g. "10-5").
        number_re = re.compile(r"(?<![\d)])-?\d+(?:\.\d+)?")

        def equiv(x, y, i):
            """Compare ``x`` with the last number found in ``y`` (tolerance 0.01)."""
            try:
                y = number_re.findall(y)[-1]
                matched = abs(float(x) - float(y)) < 0.01
            except Exception:
                # No number in the prediction, or something was not parseable.
                print(f"GT: {x}, Pred: {y}")
                return False
            if not matched:
                print(f"GT: {x}, Pred: {y}")
            return matched

        return equiv
    elif dataset_name == "chartqa":

        def as_number(text):
            """Return ``text`` as a float if it is a plain number (optionally
            carrying one "%" and/or "$"), otherwise ``None``."""
            stripped = text.replace("%", "", 1).replace("$", "", 1)
            if not stripped.replace(".", "", 1).isdigit():
                return None
            try:
                return float(stripped)
            except ValueError:
                # e.g. digit-like characters that float() does not accept
                return None

        def relaxed_equiv(gt, pred, i):
            """Numeric answers match within 5% relative error; others match
            case-insensitively, with "True"/"False" mapped to "yes"/"no"."""
            if pred is None:
                return False

            gt = str(gt)
            pred = str(pred)

            gt_value = as_number(gt)
            pred_value = as_number(pred)
            if gt_value is not None and pred_value is not None:
                scale = max(abs(gt_value), abs(pred_value))
                if scale == 0:
                    # both are zero: identical, and avoids dividing by zero
                    return True
                return abs(gt_value - pred_value) / scale < 0.05
            elif pred == "True":
                return gt == "yes"
            elif pred == "False":
                return gt == "no"
            else:
                return gt.lower() == pred.lower()

        return relaxed_equiv
    elif dataset_name == "blocksworld":
        problems = "data/mystery_blocksworld/mysteryblocks_3_10.json"
        domain_file = "data/mystery_blocksworld/mystery_pddl/domain.pddl"
        # Filled on first use so the JSON file is parsed once, not per sample.
        problem_cache = {}

        def equiv(gt, pred, i):
            """Validate the predicted plan against the PDDL problem of sample ``i``."""
            if "data" not in problem_cache:
                with open(problems, "r") as handle:
                    problem_cache["data"] = json.load(handle)["data"]
            problem = problem_cache["data"][i]["pddl_problem"]

            try:
                if "pddl" in pred:
                    pred = pred.split("pddl")[-1].strip()
                if "\\text" in pred:
                    pred = pred.split("\\text")[-1].strip()
                pred = pred.replace(" from", "").strip().lower()
                pred = pred.replace(" object", "").strip().lower()
                pred = pred.replace(" then", "").strip().lower()
                pred = pred.replace(",", "\n")
                plan = [re.sub(r"[^a-zA-Z]", " ", s).strip() for s in pred.split("\n")]
                print("Inside equiv:", i, plan)
            except Exception:
                # e.g. pred is not a string
                return False

            # Write the problem to disk; closing the file guarantees it is
            # flushed before the evaluator reads it.
            with tempfile.NamedTemporaryFile(mode="w", delete=False, encoding="ascii") as f:
                f.write(problem)
                problem_file = f.name
            try:
                # NOTE(review): assumes eval_solution_files finishes reading
                # the problem file before it returns.
                return eval_solution_files(problem_file, domain_file, plan)
            finally:
                try:
                    os.unlink(problem_file)
                except OSError:
                    pass

        return equiv
    elif dataset_name == "bbh_sort":

        def equiv(gt, pred, i):
            """Compare a space-separated word list against the prediction."""
            try:
                if "\\text" in pred:
                    pred = re.findall(r"\\text{(.*)}", pred)[0]

                if "1. " in pred:
                    # extract the list of words from numbered list
                    words = re.findall(r"\d\. (.*)", pred)
                    pred = " ".join(words)
                else:
                    pred = " ".join(strip_list_markup(pred).split())
                gt = gt.replace("'", "")
                matched = pred == gt
            except Exception:
                print(f"GT: {gt}, Pred: {pred}")
                return False
            if not matched:
                print(f"GT: {gt}, Pred: {pred}")
            return matched

        return equiv
    elif dataset_name == "bbh_shuffle7":

        def equiv(gt, pred, i):
            """Compare option letters; ``gt`` is the first capital letter in it."""
            try:
                gt = re.findall(r"[A-Z]", gt)[0]
                if "(" in pred:
                    # Prefer a letter written as "(X)"; fall back to the first
                    # letter when the parentheses do not enclose a single letter.
                    bracketed = re.findall(r"\(([a-zA-Z])\)", pred)
                    if bracketed:
                        pred = bracketed[0]
                    else:
                        pred = re.findall(r"[a-zA-Z]", pred)[0]
                else:
                    pred = re.findall(r"[a-zA-Z]", pred)[-1]
                return gt == pred
            except Exception:
                return False

        return equiv
    elif dataset_name == "clevr":

        def equiv(gt, pred, i):
            """Integer answers must match exactly; others by substring."""
            if pred is None:
                return False
            gt = gt.strip().lower()
            pred = pred.strip().lower()
            if gt.isdigit():
                return pred.isdigit() and int(gt) == int(pred)
            return gt in pred

        return equiv
    elif dataset_name == "folio":
        # Checked in this order, so a string containing "true" wins first.
        folio_labels = (("true", "True"), ("false", "False"), ("uncertain", "Uncertain"))

        def conv(string):
            """Map free text to "True"/"False"/"Uncertain", or None."""
            if not isinstance(string, str):
                return None
            lowered = string.strip().lower()
            for needle, label in folio_labels:
                if needle in lowered:
                    return label
            return None

        def equiv(gt, pred, i):
            gt2 = conv(gt)
            pred2 = conv(pred)
            if gt2 is None or pred2 is None:
                return False
            return gt2 == pred2

        return equiv
    elif dataset_name == "longsort":

        def equiv(gt, pred, i):
            """Compare an iterable of sorted items against the prediction text."""
            if pred is None:
                return False
            if "\\text" in pred:
                wrapped = re.findall(r"\\text{(.*)}", pred)
                if wrapped:  # keep the raw prediction when \text has no braces
                    pred = wrapped[0]

            if "1. " in pred:
                # extract the list of words from numbered list
                words = re.findall(r"\d+\. (.*)", pred)
                pred = " ".join([w.strip() for w in words])
            else:
                tokens = strip_list_markup(pred).split()
                # stray "*" tokens (markdown bullets) are ignored
                pred = " ".join([tok for tok in tokens if tok != "*"])

            gt = " ".join([f"{item}" for item in gt])
            return gt == str(pred)

        return equiv
    elif dataset_name == "sudoku":

        def equiv(gt, pred, i):
            """Compare only the digits of both grids, in order."""
            if pred is None:
                return False
            gt = re.sub(r"[^0-9]", "", gt)
            pred = re.sub(r"[^0-9]", "", pred)
            return gt == pred

        return equiv
    elif dataset_name == "listsynthesis":

        def equiv(gt, pred, i):
            """Compare string forms; only the first line of ``pred`` counts."""
            if isinstance(pred, str) and "\n" in pred:
                pred = pred.split("\n")[0]
            print(f"GT: {gt}, Pred: {pred}")
            return str(gt) == str(pred)

        return equiv
    elif dataset_name == "aime":

        def equiv(gt, pred, i):
            """Compare both answers after truncating them to integers."""
            if pred is None:
                return False
            pred_text = str(pred).lower()
            if pred_text in ("none", "", "none found"):
                return False

            try:
                return int(float(str(gt).lower())) == int(float(pred_text))
            except (ValueError, OverflowError):
                # non-numeric text, NaN, or infinity
                return False

        return equiv
    elif dataset_name == "clutrr":

        def equiv(gt, pred, i):
            """True when ``gt`` occurs in the lower-cased prediction."""
            if pred is None:
                return False
            return gt in pred.lower()

        return equiv
    elif dataset_name == "24game":
        try:
            dataset = TwentyFourGameDataset(args.twentyfourgame_n)
            logger.info("Successfully instantiated TwentyFourGameDataset.")
        except Exception as e:
            logger.error(f"Failed to instantiate TwentyFourGameDataset: {e}")
            raise ValueError(f"Failed to instantiate TwentyFourGameDataset: {e}") from e

        # Define allowed operators for safe evaluation
        allowed_operators = {
            ast.Add: operator.add,
            ast.Sub: operator.sub,
            ast.Mult: operator.mul,
            ast.Div: operator.truediv,
        }

        def safe_eval(expr):
            """Evaluate ``expr`` using only numeric literals and + - * /.

            Raises:
                ValueError: For anything else (including unary operators,
                    syntax errors and division by zero).
            """

            def _eval(node):
                if isinstance(node, ast.Expression):
                    return _eval(node.body)
                # ast.Constant replaces the deprecated ast.Num; accept the
                # same literal types ast.Num matched (numbers, never bool).
                if (
                    isinstance(node, ast.Constant)
                    and isinstance(node.value, (int, float, complex))
                    and not isinstance(node.value, bool)
                ):
                    return node.value
                if isinstance(node, ast.BinOp):
                    if type(node.op) not in allowed_operators:
                        raise ValueError(f"Unsupported operator: {ast.dump(node.op)}")
                    return allowed_operators[type(node.op)](_eval(node.left), _eval(node.right))
                if isinstance(node, ast.UnaryOp):
                    # only the four binary operators above are allowed
                    raise ValueError(f"Unsupported operator: {ast.dump(node.op)}")
                raise ValueError(f"Unsupported expression: {ast.dump(node)}")

            try:
                return _eval(ast.parse(expr, mode='eval'))
            except Exception as e:
                raise ValueError(f"Invalid expression: {expr}. Error: {e}") from e

        def all_numbers_used(expr, numbers):
            """
            Check if all numbers are used exactly once in the expression.

            Args:
                expr (str): The arithmetic expression.
                numbers (list): The list of numbers provided.

            Returns:
                bool: True if all numbers are used exactly once, False otherwise.
            """
            expr_numbers = [int(num) for num in re.findall(r'\b\d+\b', expr)]
            return sorted(expr_numbers) == sorted(numbers)

        def equiv(gt, pred, i):
            """
            Equivalence function for the 24 Game.

            Args:
                gt (int): The target value (24).
                pred (str): The predicted arithmetic expression.
                i (int): The index of the current sample.

            Returns:
                int: Returns ``gt`` if the expression is valid, otherwise -1.
            """
            # NOTE(review): -1 is truthy in Python; confirm the caller compares
            # the result against the target instead of using truthiness.

            # Retrieve the numbers for the i-th sample
            try:
                sample = dataset[i]
                numbers = sample[2]  # the numbers to combine, per the original author
                logger.debug(f"Sample {i}: Numbers = {numbers}")
            except IndexError:
                logger.error(f"IndexError: Sample index {i} is out of range.")
                return -1
            except Exception as e:
                logger.error(f"Error accessing sample index {i}: {e}")
                return -1

            # Clean the predicted expression
            pred_clean = pred.strip()
            logger.debug(f"Sample {i}: Cleaned Prediction = '{pred_clean}'")

            # Attempt to evaluate the expression
            try:
                result = safe_eval(pred_clean)
                logger.debug(f"Sample {i}: Evaluated Result = {result}")
                if abs(result - gt) > 1e-6:
                    logger.info(f"Sample {i}: Evaluation failed: {pred_clean} = {result} != {gt}")
                    return -1
            except ValueError as ve:
                logger.info(f"Sample {i}: Error evaluating expression: {pred_clean}. Error: {ve}")
                return -1

            # Check if all numbers are used exactly once
            if not all_numbers_used(pred_clean, numbers):
                logger.info(f"Sample {i}: Numbers usage failed: {pred_clean}, Numbers: {numbers}")
                return -1

            # If all checks pass, return the target value
            return gt

        return equiv

    elif dataset_name == "bbeh_buggy_tables":
        print("Using BBEH tables equiv")

        def equiv(gt, pred, i):
            """Exact float equality of both answers; False if either is not a number."""
            try:
                return float(str(gt)) == float(str(pred))
            except Exception:
                return False

        return equiv

    elif dataset_name == "bbeh_multistep_arithmetic":
        print("Using BBEH arith equiv")

        def equiv(gt, pred, i):
            """Exact float equality of both answers; False if either is not a number."""
            try:
                return float(str(gt)) == float(str(pred))
            except Exception:
                return False

        return equiv

    elif dataset_name in {
        "bbeh_boolean_expressions",
        "bbeh_geometric_shapes",
        "bbeh_shuffled_objects",
        "bbeh_disambiguation_qa",
        "bbeh_hyperbaton",
        "bbeh_movie_recommendation",
        "bbeh_nycc",
    }:
        print("Using BBEH bbool/geo/shuff/disam/hyper expr equiv")
        # Regex to match a single letter A-Z (case-insensitive),
        # optionally in parentheses, at the start of the string,
        # ensuring the next character is NOT another letter.
        pattern = r'^\(?([A-Za-z])\)?(?![A-Za-z])'

        def extract_answer(s):
            """
            Extracts a single letter (A-Z), optionally in parentheses,
            from the start of the string. Excludes cases like 'BFG'.
            """
            m = re.match(pattern, str(s).strip())
            return m.group(1).lower() if m else None

        def equiv(gt, pred, i):
            try:
                gt_answer = extract_answer(gt)
                pred_answer = extract_answer(pred)
                return (
                    gt_answer is not None and
                    pred_answer is not None and
                    gt_answer == pred_answer
                )
            except Exception:
                return False

        return equiv

    elif dataset_name in {"bbeh_web_of_lies", "bbeh_sportqa", "bbeh_sarc_triples"}:
        print("Using BBEH lies equiv")

        def tokenize(s):
            # Remove brackets, parentheses, quotes, whitespace, and angle brackets, and lowercase.
            cleaned = re.sub(r"[\[\]\(\)'\"\s<>]", "", str(s).lower())
            # Split on commas and filter out any empty tokens.
            return [token for token in cleaned.split(",") if token]

        def equiv(gt, pred, i):
            return tokenize(gt) == tokenize(pred)

        return equiv

    elif dataset_name == "bbeh_time_arithmetic":
        print("Using BBEH time arith")

        def is_unanswerable(s: str) -> bool:
            return s.strip().lower() == "unanswerable"

        def is_listlike(s: str) -> bool:
            """
            True if the string is enclosed in matching parentheses or brackets,
            e.g. '[item1, item2]', '(0, 16, 54, 58)', '['Friday']', etc.
            """
            s = s.strip()
            return ((s.startswith("(") and s.endswith(")")) or
                    (s.startswith("[") and s.endswith("]")))

        def parse_listlike(s: str):
            """
            Strips the outer parentheses/brackets and splits on commas.
            e.g. '[thursday]' -> ['thursday'], '(0, 16, 54)' -> ['0','16','54'].
            """
            s_inner = s.strip()[1:-1].strip()  # drop the first and last char
            if not s_inner:
                return []
            return [x.strip().replace("'", "") for x in s_inner.split(",")]

        def compare_listlike(gt: str, pred: str) -> bool:
            """
            Compare listlike items ignoring case, same length, same item order.
            """
            gt_items = [item.lower() for item in parse_listlike(gt)]
            pr_items = [item.lower() for item in parse_listlike(pred)]
            return gt_items == pr_items

        def is_time_expression(s: str) -> bool:
            """
            Check for time expressions like 'same_day, 09:05', '+1, 05:36', etc.
            Must contain a comma, and the left side is 'same_day' or starts with +/-.
            """
            s = s.strip().lower()
            if "," not in s:
                return False
            left_part = s.split(",", 1)[0].strip()
            return left_part == "same_day" or left_part.startswith(("+", "-"))

        def compare_time_expressions(gt: str, pred: str) -> bool:
            """
            Compare time expressions piece by piece.
            """
            def parse_time_expr(expr: str):
                parts = expr.strip().split(",", 1)
                if len(parts) < 2:
                    return ("", "")
                return (parts[0].strip().lower(), parts[1].strip().lower())

            return parse_time_expr(gt) == parse_time_expr(pred)

        def is_unadorned_date(s: str) -> bool:
            """
            Very loose check for a possible date string:
            e.g. '2000-4-28', '22 Jan, 2000', '1997-Oct-16', etc.
            """
            # Just check if there's at least one digit:
            return bool(re.search(r"\d", s.strip().lower()))

        def compare_dates(gt: str, pred: str) -> bool:
            """
            Simple case-insensitive direct match for potential dates.
            Could be replaced with actual date parsing if needed.
            """
            return gt.strip().lower() == pred.strip().lower()

        def equiv(gt, pred, i):
            """
            Overall comparison function.
            """
            try:
                gt = str(gt).strip()
                pred = str(pred).strip()

                # 1. unanswerable => must match literally
                if is_unanswerable(gt):
                    return is_unanswerable(pred)

                # 2. listlike => parse and compare item-by-item ignoring case
                if is_listlike(gt) and is_listlike(pred):
                    return compare_listlike(gt, pred)

                # 3. time expressions
                if is_time_expression(gt) and is_time_expression(pred):
                    return compare_time_expressions(gt, pred)

                # 4. dates
                if is_unadorned_date(gt) and is_unadorned_date(pred):
                    return compare_dates(gt, pred)

                # 5. fallback => direct case-insensitive compare
                return gt.lower() == pred.lower()

            except Exception as e:
                print(f"[ERROR in equiv] {e} -- gt='{gt}', pred='{pred}'")
                return False

        return equiv

    elif dataset_name == "bbeh_causal_understanding":
        print("Using BBEH causal equiv")

        def extract_label(s):
            """Reduce an answer to "yes"/"no"/"ambiguous" by its prefix; any
            other text is returned lower-cased and stripped."""
            s = str(s).strip().lower()
            for label in ("yes", "no", "ambiguous"):
                if s.startswith(label):
                    return label
            return s  # return raw for debugging

        def equiv(gt, pred, i):
            return extract_label(gt) == extract_label(pred)

        return equiv

    elif dataset_name == "bbeh_word_sorting":
        print("Using BBEH word‑sorting equiv")

        def normalize_list(s):
            # 1) lowercase and turn into string
            t = str(s).lower()
            # 2) drop any [,],(,),',",`
            t = re.sub(r"[\[\]\(\)'\"\`]", "", t)
            # 3) split on commas if present, else on whitespace
            parts = t.split(",") if "," in t else t.split()
            # 4) strip each token and drop empties
            parts = [p.strip() for p in parts if p.strip()]
            # 5) sort so order does not matter
            # NOTE(review): this makes the check order-insensitive; confirm
            # that is intended for a word-sorting task.
            return sorted(parts)

        def equiv(gt, pred, i):
            return normalize_list(gt) == normalize_list(pred)

        return equiv

    elif dataset_name.startswith("bbeh"):
        print("Using BBEH equiv")

        def equiv(gt, pred, i):
            """Case-insensitive comparison that ignores "*" characters."""
            return str(gt).strip().lower().replace('*', '') == str(pred).strip().lower().replace('*', '')

        return equiv
    else:

        def equiv(gt, pred, i):
            """Fallback: plain equality."""
            return gt == pred

        return equiv


def get_omnimath_judge(args):
    """Build an OmniMath checker backed by the local "KbsdJames/Omni-Judge" model.

    The model and tokenizer are loaded once, and up to 200 test problems of
    the difficulty encoded in ``args.dataset`` (the integer after the first
    nine characters) are selected with a permutation seeded with 0.

    Args:
        args: Namespace-like object exposing ``dataset`` (str).

    Returns:
        ``equiv(gt, pred, i)``: asks the judge model whether the student
        answer ``pred`` matches the reference ``gt`` for the ``i``-th selected
        problem, and returns True when the parsed judgement is "TRUE".
    """
    model = AutoModelForCausalLM.from_pretrained("KbsdJames/Omni-Judge", torch_dtype="auto", device_map="auto")
    # get_context / parse_response below are called on this tokenizer object.
    tokenizer = AutoTokenizer.from_pretrained("KbsdJames/Omni-Judge", trust_remote_code=True)
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>")
    ]

    # NOTE: seeds NumPy's global RNG so the selection below is reproducible.
    np.random.seed(0)
    data = OmniMathDataset(split="test", difficulty=int(args.dataset[9:]))
    test_data_ids = list(range(min(200, len(data))))
    shuf = np.random.permutation(test_data_ids)
    test_data = [data[int(idx)] for idx in shuf[:200]]

    def equiv(gt, pred, i):
        # NOTE(review): assumes test_data[i][0][1] is the problem statement
        # and that ``i`` indexes the shuffled selection - confirm with caller.
        question = test_data[i][0][1]

        formatted_context = tokenizer.get_context(question, gt, pred)
        model_inputs = tokenizer(formatted_context, return_tensors="pt").to(model.device)
        prompt_length = len(model_inputs["input_ids"][0])

        generated = model.generate(
            input_ids=model_inputs["input_ids"],
            attention_mask=model_inputs["attention_mask"],
            do_sample=False,
            num_return_sequences=1,
            max_new_tokens=300,
        )[0].cpu().tolist()
        # keep only the newly generated tokens
        generated = generated[prompt_length:]
        # cut the continuation at the first end-of-sequence / end-of-turn id
        for terminator in terminators:
            if terminator in generated:
                generated = generated[:generated.index(terminator)]

        response = tokenizer.decode(generated, skip_special_tokens=True)
        pred_truth = tokenizer.parse_response(response)
        print(pred_truth)
        return pred_truth["judgement"] == "TRUE"

    return equiv


def get_omnimath_judge_gpt(args):
    """Build an OmniMath checker that asks an API model to grade answers.

    The judge is ``APIModel(model_name="gemini-2.0-flash")``.  It receives a
    grading prompt with the problem, the reference answer and the student
    answer, and its report is parsed for the "Equivalence Judgement" section.

    Args:
        args: Namespace-like object exposing ``dataset`` (str); the integer
            after its first nine characters is used as the difficulty.

    Returns:
        ``equiv(gt, pred, i)``: True only when the judge's verdict is TRUE;
        False when it is FALSE, unrecognised, or missing from the report.
    """
    model = APIModel(model_name="gemini-2.0-flash")

    # NOTE: seeds NumPy's global RNG.
    np.random.seed(0)
    data = OmniMathDataset(split="test", difficulty=int(args.dataset[9:]))
    test_data_ids = list(range(min(200, len(data))))
    shuf = np.random.permutation(test_data_ids)
    # NOTE(review): test_data is built but never read; equiv() below looks the
    # question up in the unshuffled ``data`` by ``i``.  Confirm which indexing
    # the caller's ``i`` refers to.
    test_data = [data[int(idx)] for idx in shuf[:200]]

    prompt_template = """# CONTEXT #
I am a teacher, and I have some high-level math problems. I am tasked with evaluating the correctness of a student's answer. 
Below, I am provided with a problem and a reference answer. Additionally, a student's answer is provided. My job is to assess whether the student's answer captures the same meaning as the reference answer, even when expressed with different wording or format.

# OBJECTIVE #
I need you to judge whether the student's answer is correct given the ground truth answer.

Your tasks include:
A. Identify Mathematical or Notational Equivalence: Pay special attention to any LaTeX expressions in both answers. Confirm that the mathematical relationships, variables, and operations conveyed are equivalent.
B. Provide a Justification: Conclude with a brief explanation as to why you believe the student's output is correct or incorrect, highlighting any key differences in meaning or content.

# STYLE #
Teaching report.

# TONE #
Professional, scientific.

# AUDIENCE #
Students. Enable them to better understand whether the answer they produce is correct.

# RESPONSE: MARKDOWN REPORT #
## Student Final Answer
[Extract the student's final answer, which may be enclosed in "\\boxed{{}}".]
## Equivalence Judgement
[Whether the student's answer share the same meaning with the reference answer. (TRUE or FALSE)]
## Justification
[Conclude with a brief explanation as to why you believe the student's answer is correct or incorrect.]


# ATTENTION #
 - The reference answer is ALWAYS correct. You should carefully judge whether the student gives the same answer as reference answer.
 - The Equivalence Judgement is only TRUE or FALSE. The answer is FALSE even if the student's final answer almost correct with a minor mistakes.
 - The answer is contained within the "boxed" section, so you can focus solely on comparing the content in the student's answer box with the reference answer, without needing to consider the intermediate steps.
 - Add "=== report over ===" at the end of the report.


<math solution>
**Question**:
{Problem}

**Reference Answer**
{ReferenceAnswer}

**Student Solution**:
{Solution}

</math solution>"""
    judgement_re = re.compile(
        r"## Equivalence Judgement\s*\n(.*?)\s*## Justification", re.DOTALL
    )

    def equiv(gt, pred, i):
        messages = [{"role": "user", "content": prompt_template.format(
            Problem=data[int(i)][0][1],
            ReferenceAnswer=gt,
            Solution=pred
        )}]
        sampling_params = SamplingParams(temperature=0.0, max_tokens=50000, top_p=1.0)
        output = (
            model.chat(messages, sampling_params=sampling_params, use_tqdm=False)[0]
            .outputs[0]
            .text
        )

        # extract the judgement
        judgement_match = judgement_re.search(output)
        if judgement_match is None:
            print("Judgement not found in output.")
            return False

        # Tolerate markdown decoration around the verdict, e.g. "**TRUE**".
        judgement = judgement_match.group(1).strip().strip("*[]`_. \t\n").lower()
        if judgement == "true":
            print(gt, pred, "True")
            return True
        if judgement == "false":
            print(gt, pred, "False")
            return False
        # Anything else is treated as a failed judgement instead of
        # silently returning None.
        print(f"Unrecognised judgement: {judgement!r}")
        return False

    return equiv


def get_equivalence_or_judge(args):
    """Pick the checker factory for ``args.dataset`` and return its checker.

    Datasets whose name starts with "omnimath" are graded by the API-model
    judge; every other dataset uses the rule-based equivalence functions.
    """
    needs_llm_judge = args.dataset.startswith("omnimath")
    factory = get_omnimath_judge_gpt if needs_llm_judge else get_equivalence
    return factory(args)