import json
import os
import re
from typing import List
import math
from math_verify import parse, verify
import ast

# Delimiters used by slidevqa_compute_score to split multi-part answers before
# comparing them as sets. split_with_multiple_delimiters matches each entry as a
# literal substring, so 'and' also splits inside words (e.g. "brand" -> "br").
split_chars = [',', 'and', '&']
def levenshtein_distance(s1, s2):
    """Return the edit distance between two sequences.

    Counts the minimum number of single-element insertions, deletions and
    substitutions that turn ``s1`` into ``s2``. Only two rows of the dynamic
    programming table are kept, each sized by the shorter input.
    """
    shorter, longer = (s2, s1) if len(s1) > len(s2) else (s1, s2)
    prev_row = list(range(len(shorter) + 1))
    for row_idx, long_item in enumerate(longer, start=1):
        cur_row = [row_idx]
        for col_idx, short_item in enumerate(shorter, start=1):
            if short_item == long_item:
                # Matching elements cost nothing: inherit the diagonal cell.
                cur_row.append(prev_row[col_idx - 1])
            else:
                cur_row.append(
                    1 + min(prev_row[col_idx - 1], prev_row[col_idx], cur_row[col_idx - 1])
                )
        prev_row = cur_row
    return prev_row[-1]

def anls(
    references,
    predictions,
    thresh_hold=0.5,
):
    """Compute the ANLS similarity of one prediction against candidate references.

    Every reference and the prediction are converted with ``str``, lowercased,
    stripped, and have internal whitespace collapsed to single spaces. For each
    reference the normalized distance ``dist / max(len(ref), len(pred))`` is
    computed on those *normalized* strings. The smallest distance over all
    references becomes the similarity ``1 - dist``.

    Args:
        references: Iterable of acceptable answers.
        predictions: The single model answer to score.
        thresh_hold: Similarities strictly below this value are clipped to 0.

    Returns:
        The similarity in [0, 1]. It is 0 when ``references`` is empty.
    """
    det_answer = " ".join(str(predictions).strip().lower().split())

    values = []
    for answer in references:
        gt_answer = " ".join(str(answer).strip().lower().split())
        dist = levenshtein_distance(gt_answer, det_answer)
        # Normalize by the lengths of the strings actually compared. Using the
        # raw lengths let extra whitespace inflate the denominator (and the score).
        length = max(len(gt_answer), len(det_answer))
        values.append(0.0 if length == 0 else float(dist) / float(length))

    if not values:
        # No reference to compare against: nothing can be matched.
        return 0

    question_result = 1 - min(values)
    if question_result < thresh_hold:
        question_result = 0
    return question_result

def calculate_list_iou_as_sets(list1, list2):
    """Return the intersection-over-union (Jaccard index) of two collections.

    Both inputs are turned into sets, so duplicates are ignored and elements
    must be hashable. When both inputs are empty the result is 0.0.
    """
    left, right = set(list1), set(list2)
    union_size = len(left | right)
    if union_size == 0:
        return 0.0
    return len(left & right) / union_size

def is_numeric_string(s):
    """Return True when ``float(s)`` succeeds, otherwise False.

    Anything ``float`` accepts counts as numeric (ints, padded strings,
    "nan", "inf"). Values of unsupported types such as ``None`` or lists
    yield False instead of raising.
    """
    try:
        float(s)
    except (ValueError, TypeError):
        return False
    return True

def split_with_multiple_delimiters(text: str, delimiters: list) -> list:
    """
    Split ``text`` wherever any of the given delimiters occurs.

    Each delimiter is escaped and matched literally as a substring (not as a
    whole word), alternatives are tried in the given order, and empty pieces
    are dropped. Whitespace around the remaining pieces is preserved.

    Args:
        text (str): The string to split.
        delimiters (list): Literal delimiter strings.

    Returns:
        list: The non-empty pieces, in their original order.
    """
    splitter = re.compile("|".join(map(re.escape, delimiters)))
    pieces = splitter.split(text)
    return list(filter(None, pieces))

def mmlongbench_doc_compute_score(model_output: str, ground_truth: str):
    """Score a model answer against an MMLongBench-Doc ground truth.

    Rules, applied in order:
      1. ``math_verify`` judges the parsed ground truth and answer equivalent
         -> 1.0. In this case 'extracted' is the parsed answer.
      2. Non-numeric ground truth:
         - if both are strings, they are lowercased and stripped. An exact
           match, or the ground truth found inside an answer fewer than 10
           characters longer, scores 1.0;
         - a plain ``str`` equality scores 1.0;
         - if the ground truth is a Python list literal, the score is the
           larger of (a) the set IoU between its items and the comma-separated
           answer items and (b) the ANLS against the items joined with ", ";
         - otherwise the score is the ANLS against the ground truth.
      3. Numeric ground truth: 1.0 if the answer parses as a float within 1e-3
         of it, else 0.0. A warning is printed when the answer does not parse.

    Returns:
        dict with 'score' (float), 'gt' and 'extracted'. A ``None`` model
        output scores 0.0.
    """
    extracted_model_answer = model_output
    if extracted_model_answer is None:
        return {
            'score': 0.0,
            'gt': ground_truth,
            'extracted': None,
        }

    # Bind both names before the try. If parse/verify raises, the rules below
    # must see "not verified" rather than hit an UnboundLocalError.
    is_correct = False
    parsed_model_answer = None
    try:
        parsed_gt = parse(ground_truth)
        parsed_model_answer = parse(extracted_model_answer)
        is_correct = verify(parsed_gt, parsed_model_answer)
    except Exception:
        # math_verify is a best-effort first pass; fall through to the string/number rules.
        pass

    if is_correct:
        return {
            'score': 1.0,
            'gt': ground_truth,
            'extracted': parsed_model_answer,
        }

    if not is_numeric_string(ground_truth):
        if isinstance(ground_truth, str) and isinstance(extracted_model_answer, str):
            ground_truth = ground_truth.lower().strip()
            extracted_model_answer = extracted_model_answer.lower().strip()
            exact = extracted_model_answer == ground_truth
            # Lenient containment: accept short wrappers such as "it is <gt>".
            contained = (
                ground_truth in extracted_model_answer
                and len(extracted_model_answer) - len(ground_truth) < 10
            )
            if exact or contained:
                return {
                    'score': 1.0,
                    'gt': ground_truth,
                    'extracted': extracted_model_answer,
                }
        if str(ground_truth) == str(extracted_model_answer):
            return {
                'score': 1.0,
                'gt': ground_truth,
                'extracted': extracted_model_answer,
            }
        try:
            gt_lst = ast.literal_eval(ground_truth)
            if isinstance(gt_lst, list):
                # Normalize items to stripped strings so numeric items can match
                # the (string) answer items and can be joined.
                gt_items = [str(item).strip() for item in gt_lst]
                model_answer_lst = [item.strip() for item in extracted_model_answer.split(',')]
                iou = calculate_list_iou_as_sets(gt_items, model_answer_lst)
                anls_score = max(iou, anls([', '.join(gt_items)], extracted_model_answer))
            else:
                anls_score = anls([ground_truth], extracted_model_answer)
        except Exception:
            # Not a Python literal (or not shaped as expected): plain ANLS.
            anls_score = anls([ground_truth], extracted_model_answer)
    else:
        # Numeric ground truth that math_verify could not confirm: compare as floats.
        try:
            anls_score = float(abs(float(extracted_model_answer) - float(ground_truth)) < 1e-3)
        except (TypeError, ValueError):
            print(f"Warning: extracted_model_answer {extracted_model_answer}, ground_truth {ground_truth}")
            return {
                'score': 0.0,
                'gt': ground_truth,
                'extracted': extracted_model_answer,
            }
    return {
        'score': float(anls_score),
        'gt': ground_truth,
        'extracted': extracted_model_answer,
    }

def slidevqa_compute_score(model_output: str, ground_truth: str):
    """Score a model answer against a SlideVQA ground truth.

    Both values are converted to lowercased, stripped strings. They are then
    scored by the first rule that applies:
      1. Exact match, or the ground truth found inside an answer fewer than 10
         characters longer -> 1.0.
      2. ``math_verify`` judges the parsed values equivalent -> 1.0.
      3. Non-numeric ground truth -> set IoU of the whitespace-stripped,
         non-empty pieces obtained by splitting both strings on ``split_chars``.
      4. Numeric ground truth -> 1.0 if the answer parses as a float within
         1e-3 of it, else 0.0.

    Returns:
        dict with 'score', 'gt' and 'extracted'. A ``None`` model output scores
        0.0 and echoes the raw ground truth.
    """
    extracted_model_answer = model_output
    if extracted_model_answer is None:
        return {
            'score': 0.0,
            'gt': ground_truth,
            'extracted': extracted_model_answer,
        }

    # str() first, so a non-string ground truth (e.g. an int) no longer crashes on .lower().
    extracted_model_answer = str(extracted_model_answer).lower().strip()
    ground_truth = str(ground_truth).lower().strip()
    if extracted_model_answer == ground_truth:
        return {
            'score': 1.0,
            'gt': ground_truth,
            'extracted': extracted_model_answer,
        }
    # Lenient containment: accept short wrappers such as "it is <gt>".
    if ground_truth in extracted_model_answer and len(extracted_model_answer) - len(ground_truth) < 10:
        return {
            'score': 1.0,
            'gt': ground_truth,
            'extracted': extracted_model_answer,
        }

    is_correct = False
    try:
        parsed_gt = parse(ground_truth)
        parsed_model_answer = parse(extracted_model_answer)
        is_correct = verify(parsed_gt, parsed_model_answer)
    except Exception:
        # math_verify is best-effort; a failure just means "not proven equal".
        pass
    if is_correct:
        return {
            'score': 1.0,
            'gt': ground_truth,
            'extracted': extracted_model_answer,
        }

    if not is_numeric_string(ground_truth):
        # The exact/containment matches on the normalized strings were handled
        # above, so go straight to the token-set IoU.
        # NOTE(review): split_chars contains 'and', which also splits inside
        # words (e.g. "brand" -> "br"); confirm whether whole-word splitting is intended.
        try:
            ground_truth_lst = [
                piece.strip()
                for piece in split_with_multiple_delimiters(ground_truth, split_chars)
                if piece.strip()
            ]
            # Strip pieces so "a, b" and "b,a" yield the same tokens.
            model_answer_lst = [
                piece.strip()
                for piece in split_with_multiple_delimiters(extracted_model_answer, split_chars)
                if piece.strip()
            ]
            iou = calculate_list_iou_as_sets(ground_truth_lst, model_answer_lst)
            return {
                'score': iou,
                'gt': ground_truth,
                'extracted': extracted_model_answer,
            }
        except Exception:
            print(f"slide, {ground_truth} vs {extracted_model_answer}, iou incorrect")
            return {
                'score': 0.0,
                'gt': ground_truth,
                'extracted': extracted_model_answer,
            }

    # Numeric ground truth that math_verify could not confirm: compare as floats.
    try:
        anls_score = float(abs(float(extracted_model_answer) - float(ground_truth)) < 1e-3)
    except ValueError:
        print(f"Warning: extracted_model_answer {extracted_model_answer}, ground_truth {ground_truth}")
        return {
            'score': 0.0,
            'gt': ground_truth,
            'extracted': extracted_model_answer,
        }
    print(f"slide, {ground_truth} vs {extracted_model_answer}, {anls_score}")
    return {
        'score': anls_score,
        'gt': ground_truth,
        'extracted': extracted_model_answer,
    }

def levenshtein_distance_dude(s1: str, s2: str) -> int:
    """
    Return the Levenshtein (edit) distance between two strings.

    This is the fewest single-character insertions, deletions or
    substitutions needed to turn one string into the other. A single DP row,
    sized by the shorter string, is updated in place.
    """
    if len(s1) > len(s2):
        s1, s2 = s2, s1

    # row[j] holds the distance between the current prefix of s2 and s1[:j].
    row = list(range(len(s1) + 1))
    for i, ch2 in enumerate(s2, 1):
        diag = row[0]  # previous row, column 0
        row[0] = i
        for j, ch1 in enumerate(s1, 1):
            up = row[j]  # previous row, column j (not yet overwritten)
            if ch1 == ch2:
                row[j] = diag
            else:
                row[j] = 1 + min(diag, up, row[j - 1])
            diag = up
    return row[-1]

def dude_compute_score(model_output: str, ground_truth: str, threshold: float = 0.5) -> dict:
    """
    Score a model answer against a DUDE ground truth with normalized Levenshtein similarity.

    Both strings are lowercased, stripped, and have internal whitespace
    collapsed to single spaces. The score is then
    ``1 - dist / max(len(gt), len(pred))`` on those normalized strings.

    Args:
        model_output (str): The string generated by the model. ``None`` scores 0.0.
        ground_truth (str): The reference string. When it is empty, an answer
            containing "not answerable" or "no answer" (case-insensitive)
            scores 1.0.
        threshold (float): Scores strictly below this are clipped to 0.0.
                           Defaults to 0.5.

    Returns:
        dict: 'score' (float in [0, 1]), 'gt' and 'extracted'. On the
        similarity path 'gt'/'extracted' are the normalized strings. On the
        early-return paths they are the raw inputs.
    """
    if model_output is None:
        return {
            'score': 0.0,
            'gt': ground_truth,
            'extracted': model_output,
        }
    if ground_truth == "":
        lowered_output = model_output.lower()
        if "not answerable" in lowered_output or "no answer" in lowered_output:
            return {
                'score': 1.0,
                'gt': ground_truth,
                'extracted': model_output,
            }

    # Lowercase, strip, and collapse internal whitespace on both sides.
    processed_gt = " ".join(ground_truth.strip().lower().split())
    processed_mo = " ".join(model_output.strip().lower().split())

    dist = levenshtein_distance_dude(processed_gt, processed_mo)

    # Normalize by the length of the longer normalized string.
    length = max(len(processed_gt), len(processed_mo))
    if length == 0:
        # Both strings are empty after normalization: treat as a perfect match.
        return {
            'score': 1.0,
            'gt': processed_gt,
            'extracted': processed_mo,
        }

    score = 1.0 - (float(dist) / float(length))
    if score < threshold:
        score = 0.0

    print(f"dude, {processed_gt} vs {processed_mo}", score)
    return {
        'score': score,
        'gt': processed_gt,
        'extracted': processed_mo,
    }
        
def _last_boxed_content(text: str):
    r"""Return the content of the last ``\boxed{...}`` whose braces balance, else None."""
    # Walk candidates from the end so the final boxed answer wins. Skip any
    # occurrence whose braces never close (e.g. a truncated generation).
    for match in reversed(list(re.finditer(r'\\boxed\{', text))):
        depth = 0
        for idx in range(match.end(), len(text)):
            char = text[idx]
            if char == '{':
                depth += 1
            elif char == '}':
                if depth == 0:
                    return text[match.end():idx]
                depth -= 1
    return None


def _extract_final_answer(model_output):
    """Pull the final answer text out of a raw model response (``None`` stays ``None``)."""
    if model_output is None:
        return None
    boxed = _last_boxed_content(model_output)
    if boxed is not None:
        return boxed.strip()
    # Non-greedy match, then take the last one. With several <answer> blocks this
    # yields the final block, not one span from the first opening tag to the last closing tag.
    answers = re.findall(r'<answer>(.*?)</answer>', model_output, re.DOTALL)
    if answers:
        return answers[-1].strip()
    return model_output.split("</think>")[-1].strip()


def _default_compute_score(model_output: str, ground_truth: str, source: str):
    r"""Extract the final answer from a model response and score it.

    Extraction order:
      1. The content of the last ``\boxed{...}`` whose braces balance, so nested
         groups such as ``\frac{1}{2}`` stay intact.
      2. Otherwise, the content of the last ``<answer>...</answer>`` block.
      3. Otherwise, everything after the last ``</think>`` (the whole output
         when there is none).
    The extracted text is stripped. A ``None`` response is passed on as ``None``.

    Sources whose name contains "slide" (case-insensitive) are scored with
    ``slidevqa_compute_score``. Every other source is scored with
    ``mmlongbench_doc_compute_score``; ``dude_compute_score`` is not routed here.

    Returns:
        The scorer's dict, with 'score' coerced to float and copied to 'acc'.
    """
    extracted_model_answer = _extract_final_answer(model_output)

    source_name = (source or "").lower()
    if 'slide' in source_name:
        ret = slidevqa_compute_score(extracted_model_answer, ground_truth)
    else:
        ret = mmlongbench_doc_compute_score(extracted_model_answer, ground_truth)
    ret['score'] = float(ret['score'])
    ret['acc'] = ret['score']
    return ret