import os
import torch
import numpy as np
from torch.nn import functional as F
from collections import Counter
import re
import re, string


def extract_answer_after_answerability(raw_answer: str) -> str:
    """Pull the answer out of a response that may carry an answerability tag.

    Patterns are tried in order (case-insensitive); the first match wins:

    1. ``[Answerability:] Answerable: <text>``  -> ``<text>``
    2. ``Answerability: <A-D>[. ...]``          -> the upper-cased letter
    3. ``[Answerability:] unanswerable`` at the end of a line -> ``"unanswerable"``
    4. ``Answer: <text>``                        -> ``<text>``

    ``<text>`` stops at the first newline and is stripped. If nothing matches,
    the stripped input is returned unchanged.
    """
    # The \b keeps "answerable:" from matching inside "Unanswerable: <reason>",
    # which would otherwise return the reason as if it were an answer.
    # (This also covers "Answer: answerable: <text>", so no separate branch.)
    match_answerable = re.search(
        r'(?:answerability\s*:\s*)?\banswerable\s*:\s*(.+?)(?:\n|$)',
        raw_answer,
        re.IGNORECASE,
    )
    if match_answerable:
        return match_answerable.group(1).strip()

    # The \b after the letter is required: under IGNORECASE, [A-D] would
    # otherwise match the first letter of words such as "answerable" -> "A".
    match_option = re.search(
        r'answerability\s*:\s*([A-D])\b(?:\s*\.?\s*.*)?(?:\n|$)',
        raw_answer,
        re.IGNORECASE,
    )
    if match_option:
        return match_option.group(1).strip().upper()

    match_unanswerable = re.search(
        r'(?:answerability\s*:\s*)?unanswerable(?:\n|$)',
        raw_answer,
        re.IGNORECASE,
    )
    if match_unanswerable:
        return "unanswerable"

    match_answer_prefix = re.search(
        r'Answer\s*:\s*(.+?)(?:\n|$)',
        raw_answer,
        re.IGNORECASE,
    )
    if match_answer_prefix:
        return match_answer_prefix.group(1).strip()

    return raw_answer.strip()

# Case-insensitive substrings that mark a reply as a refusal / "can't tell".
UNCERTAINTY_PHRASES = [
    "unable to determine", "not possible to determine", "cannot provide",
    "not sure", "not visible", "does not provide enough information",
    "does not provide any information", "does not provide information",
    "cannot answer", "not clear", "insufficient information",
    "unanswerable", "cannot be determined", "is not mentioned", "unknown",
    "cannot be identified",
]

# Deletes every ASCII punctuation character except "/", which is kept so
# answers such as "1/2" survive normalization. Built once, not per call.
_ANSWER_PUNCT_TABLE = str.maketrans(
    '', '', ''.join(ch for ch in string.punctuation if ch != '/')
)


def extract_clean_answer(ans: str) -> str:
    """Normalize a free-form answer into a short string for comparison.

    Steps:
      1. Lower-case and strip; any phrase from ``UNCERTAINTY_PHRASES``
         collapses the whole reply to ``"unanswerable"``.
      2. Remove punctuation (except ``/``).
      3. Take the predicate of "in/on the image is X" or of the last
         "... is/are/called/contains X", dropping a leading article.
      4. Otherwise strip a few known lead-in words ("the", "answer", ...).

    Returns ``""`` for empty or non-``str`` input.
    """
    if not ans or not isinstance(ans, str):
        return ""

    ans = ans.lower().strip()

    if any(phrase in ans for phrase in UNCERTAINTY_PHRASES):
        return "unanswerable"

    ans_no_punct = ans.translate(_ANSWER_PUNCT_TABLE)

    # An article must be followed by whitespace. The old "(?:a|an|the|some)?\s*"
    # consumed the leading "a" of "apple" ("pple") or of "an" ("n apple").
    article = r'(?:(?:a|an|the|some)\s+)?'

    # "in/on the image is X" -> X
    match = re.search(
        r'\b(?:in|on)\s+the\s+(?:image|picture|photo|screenshot)\s+'
        r'(?:is|are)\s+' + article + r'(.+)$',
        ans_no_punct,
    )
    if match:
        return re.sub(r'\s+', ' ', match.group(1).strip())

    # "... X is/are/called/contains Y" -> Y; the greedy ".*" selects the last verb.
    simple_match = re.search(
        r'.*\b(?:is|are|called|contains)\s+' + article + r'(.+)$',
        ans_no_punct,
    )
    if simple_match:
        candidate = re.sub(r'\s+', ' ', simple_match.group(1).strip())
        if candidate:
            return candidate

    # Strip known lead-ins. The \b after each keyword keeps words that merely
    # start with one intact ("island" was becoming "land", "mango" -> "go").
    cleaned = re.sub(
        r'^(?:the\s+|this\s+|that\s+|these\s+|those\s+|of\s+the\s+objects?\s+in\s+the\s+image\s+is\s+)?'
        r'(?:image\s+shows\s+|captcha\s+on\s+this\s+screenshot\s+is\s+)?'
        r'(?:(?:answer|product|item|color|contents?|bottle|can|name|man|motorbike)\b)?\s*'
        r'(?:(?:is|are|contains?|called|shows)\b)?\s*',
        '',
        ans_no_punct,
    )

    return re.sub(r'\s+', ' ', cleaned).strip()

# Filler tokens ignored by ``tokenize`` (exact, case-sensitive match).
STOPWORDS = {"a", "an", "the", "this", "that", "item", "product", "kit"}

def tokenize(txt: str):
    """Split ``txt`` on whitespace and drop tokens that are in ``STOPWORDS``.

    The membership test is case-sensitive, so capitalised stop words such as
    "The" are kept unless the text is lower-cased beforehand.
    """
    return list(filter(lambda token: token not in STOPWORDS, txt.split()))

def is_match(res_clean: str, gt_clean: str) -> bool:
    """Decide whether two cleaned answers should count as the same answer.

    Two answers match when:
      * either string contains the other (this includes equality), or
      * their whitespace-token sets share at least two tokens and the
        Jaccard overlap is at least 0.6.

    Empty inputs never match.

    NOTE(review): containment is character-level, so e.g. "1" matches "10"
    and "cat" matches "concatenate"; confirm this is acceptable for numeric
    answers.
    """
    if not (res_clean and gt_clean):
        return False

    if res_clean in gt_clean or gt_clean in res_clean:
        return True

    res_words, gt_words = set(res_clean.split()), set(gt_clean.split())
    if not (res_words and gt_words):
        return False

    shared = len(res_words & gt_words)
    return shared >= 2 and shared / len(res_words | gt_words) >= 0.6

def clean(text):
    """Lower-case ``text`` and trim leading/trailing non-word characters.

    Whitespace, punctuation and underscores are removed from both ends only;
    interior characters are kept, e.g. ``"  Hi, there!! "`` -> ``"hi, there"``.
    """
    normalized = text.strip().lower()
    normalized = re.sub(r'^[\W_]+', '', normalized)
    return re.sub(r'[\W_]+$', '', normalized)

def get_ground_truth(GTAns):
    """Return the most frequent entry of ``GTAns`` (majority vote).

    Ties go to the entry seen first, because ``Counter.most_common`` keeps
    first-encountered order for equal counts. Returns ``None`` when ``GTAns``
    is empty or otherwise falsy.
    """
    if GTAns:
        [(winner, _votes)] = Counter(GTAns).most_common(1)
        return winner
    return None

def str2int(label: str) -> int:
    """Convert an option letter to its 1-based position.

    ``'A'``-``'D'`` map to 1-4 (case-sensitive); every other label maps to 5.
    """
    return {'A': 1, 'B': 2, 'C': 3, 'D': 4}.get(label, 5)

def disable_torch_init():
    """Turn ``reset_parameters`` of ``nn.Linear`` and ``nn.LayerNorm`` into a no-op.

    The default initialisation is treated as redundant and skipped to speed
    up model creation. The patch is applied to the classes themselves, so it
    affects every instance created afterwards in this process.
    """
    import torch

    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        layer_cls.reset_parameters = lambda self: None


def extract_prediction(text):
    """Return the first standalone option letter A-D in ``text``, or ``None``.

    ``text`` may be a string or an iterable of strings (they are concatenated).
    The text is upper-cased before searching, so a lone lowercase article "a"
    also counts as option "A".
    """
    letters = re.findall(r'\b[A-D]\b', ''.join(text).upper().strip())
    return letters[0] if letters else None


def extract_assistant_answers(raw_output: str) -> str:
    """Return everything after the first ``ASSISTANT:`` marker (any case).

    Whitespace right after the marker is skipped and the rest of the string,
    newlines included, is returned. When the marker is absent, the sentinel
    ``"Assistant is not found"`` is returned instead.
    """
    found = re.search(r'ASSISTANT:\s*(.*)', raw_output, re.IGNORECASE | re.DOTALL)
    if found is None:
        return "Assistant is not found"
    return found.group(1)

def extract_Qwen_assistant_answers(raw_output: str) -> str:
    """Return the text after the first ``assistant`` marker that ends a line.

    The marker is matched case-insensitively and must be followed (after
    optional whitespace) by a newline. Leading whitespace of the reply is
    skipped; the rest of the string, newlines included, is returned. When no
    marker exists, the sentinel ``"Assistant is not found"`` is returned.
    """
    found = re.search(r'assistant\s*\n\s*(.*)', raw_output, re.IGNORECASE | re.DOTALL)
    if found is None:
        return "Assistant is not found"
    return found.group(1)

import matplotlib.pyplot as plt
def plot_topk_softmax(logits_tensor, index, model, dataset, split, img_flag, k=10, save_dir=None):
    """Save a horizontal bar chart of the top-``k`` softmax probabilities.

    The most likely token is drawn at the top in red; each bar is annotated
    with its probability and decoded token.

    Args:
        logits_tensor: Logits for one decoding step. A leading dimension of
            size 1 is squeezed away; the result is assumed to be 1-D over the
            vocabulary (TODO confirm for batched callers).
        index: Sample identifier, used in the title and the file name.
        model: Object exposing ``processor.decode(ids, skip_special_tokens=...)``
            and ``name``.
        dataset, split, img_flag: Components of the default output directory.
        k: Number of top entries to show (clipped to the vocabulary size).
        save_dir: Output directory. Defaults to
            ``./result/{dataset}/{split}/{model.name}/{img_flag}/logits``.

    The figure is written to ``{save_dir}/top{k}_softmax_{index}.png``.
    """
    # detach() + float(): .numpy() below fails on tensors that require grad
    # and on bfloat16 tensors.
    logits = logits_tensor.detach().cpu().float().squeeze(0)
    probs = F.softmax(logits, dim=-1)

    topk_probs, topk_indices = torch.topk(probs, k=min(k, probs.shape[-1]), dim=-1)
    # Reverse so the most likely token is drawn last, i.e. at the top of barh.
    topk_probs = topk_probs.numpy()[::-1]
    topk_indices = topk_indices.numpy().astype(int)[::-1]

    decoded_labels = [
        # "$" would otherwise start matplotlib mathtext.
        model.processor.decode(torch.tensor([idx]), skip_special_tokens=True)
        .strip()
        .replace('$', r'\$')
        for idx in topk_indices
    ]
    positions = np.arange(len(decoded_labels))

    fig = plt.figure(figsize=(12, 6 + k * 0.3))
    try:
        bars = plt.barh(positions,
                        topk_probs,
                        color='#2ca02c',
                        edgecolor='black',
                        alpha=0.7,
                        height=0.8)
        # After the reversal the top-1 token is the LAST bar; highlighting
        # bars[0] marked the least likely of the shown tokens instead.
        if len(bars):
            bars[-1].set_color('#d62728')

        for i, (prob, label) in enumerate(zip(topk_probs, decoded_labels)):
            plt.text(prob + 0.01, i,
                     f'{prob:.3f} ({label})',
                     va='center',
                     fontsize=9,
                     color='darkred')

        plt.title(f'Top-{k} Softmax Probabilities - Index {index}\n', fontsize=14)
        plt.xlabel('Probability', fontsize=12)
        plt.ylabel('Decoded Labels', fontsize=12)
        plt.yticks(positions, decoded_labels)
        plt.xlim(0, 1.1)
        plt.grid(axis='x', linestyle='--', alpha=0.6)

        if k > 5:
            # Leave room on the right for the long annotation texts.
            plt.tight_layout(rect=[0.05, 0, 0.75, 1])
        else:
            plt.tight_layout()

        if save_dir is None:
            save_dir = f"./result/{dataset}/{split}/{model.name}/{img_flag}/logits"
        os.makedirs(save_dir, exist_ok=True)

        # One file per sample: a fixed name made every call overwrite the last plot.
        out_file = os.path.join(save_dir, f"top{k}_softmax_{index}.png")
        plt.savefig(out_file, bbox_inches='tight', dpi=120)
        print(f"Top-{k} plot saved to: {out_file}")
    finally:
        plt.close(fig)



def joint_prob(probs):
    """Return the product of ``probs`` as a Python float (0.0 when empty/None).

    ``probs`` may be a sequence of numbers or a tensor. The product is taken
    in float64: in float32 long sequences underflow to exactly 0
    (e.g. 0.9 ** 1000 ~ 1.7e-46 is below float32's smallest subnormal).
    """
    if probs is None:
        return 0.0
    # as_tensor + numel() instead of ``not probs``: truthiness of a
    # multi-element tensor raises.
    values = torch.as_tensor(probs, dtype=torch.float64)
    if values.numel() == 0:
        return 0.0
    return torch.prod(values).item()


def joint_log_prob(probs):
    """Return the sum of ``log(p)`` over ``probs`` as a Python float.

    Returns ``-inf`` for empty/None input, and ``-inf`` if any probability is
    0. ``probs`` may be a sequence of numbers or a tensor. Computed in float64
    so the sum keeps full double precision.
    """
    if probs is None:
        return -float('inf')
    # as_tensor + numel() instead of ``not probs``: truthiness of a
    # multi-element tensor raises.
    values = torch.as_tensor(probs, dtype=torch.float64)
    if values.numel() == 0:
        return -float('inf')
    return torch.log(values).sum().item()

import re

def extract_answer(question, raw_answer):
    """Map a free-form model response to a multiple-choice option letter.

    Strategies, in order (the first one that succeeds wins):

    1. An explicit ``answerable: <A-D>`` tag in the response.
    2. A response that starts with a bare letter (``"B"``, ``"b."``, ``"C)"``).
    3. Option texts parsed from ``question``: the option whose text occurs
       latest in the response (case-insensitive substring search).
    4. Letter patterns such as ``(B)``, ``B.``, ``the answer is B`` or
       ``Answer:B``; accepted only if they yield exactly one distinct letter
       (restricted to the parsed options when any were found).
    5. For a yes/no option pair: the "no" option if the response contains a
       negation, otherwise the "yes" option.
    6. Uncertainty phrases -> ``"unanswerable"``.

    Args:
        question: Prompt containing options like ``"A. cat B. dog"``; text
            after the first ``"Answer:"`` is ignored when parsing options.
        raw_answer: The model's raw response.

    Returns:
        An upper-case letter, ``"unanswerable"``, or ``raw_answer`` unchanged
        when nothing could be extracted.
    """
    # 1. Explicit tag. The leading \b keeps "unanswerable:" out; the trailing
    #    \b stops e.g. "answerable: yes" from being read as the letter "Y".
    tag = re.search(r'\banswerable\s*:\s*([A-D])\b', raw_answer, re.IGNORECASE)
    if tag:
        return tag.group(1).upper()

    # Parse "A. text" / "b) text" / "C: text" options from the question stem.
    # \b before each letter so letters inside words (e.g. the "d" of "card:")
    # are not mistaken for option markers.
    stem = question.split("Answer:")[0]
    option_pattern = re.compile(
        r'\s*\b([A-Da-d])\s*[.:)]\s*(.*?)\s*(?=\b[A-Da-d]\s*[.:)]|\Z)',
        re.DOTALL,
    )
    options = {letter.upper(): text.strip() for letter, text in option_pattern.findall(stem)}

    # 2. Response that starts with a letter, optionally followed by ".", ":" or ")".
    leading = re.match(r'^\s*([A-Da-d])(?:[.:)]|$)', raw_answer)
    if leading:
        return leading.group(1).upper()

    # 3. Option text mentioned in the response; the latest mention wins.
    raw_lower = raw_answer.lower()
    last_mention = {}
    for letter, text in options.items():
        if not text:
            # str.rfind("") returns len(s), so an empty option would always "match".
            continue
        pos = raw_lower.rfind(text.lower())
        if pos != -1:
            last_mention[letter] = pos
    if last_mention:
        return max(last_mention, key=last_mention.get)

    # 4. Explicit letter patterns, in priority order.
    patterns = [
        re.compile(r'\(([A-D])\)', re.IGNORECASE),
        re.compile(r'\b([A-Da-d])\.\s*', re.IGNORECASE),
        re.compile(r'\banswer is ([A-D])\b', re.IGNORECASE),
        re.compile(r'\bThe answer is ([A-D])\b', re.IGNORECASE),
        re.compile(r'\bThe answer is: ([A-D])\b', re.IGNORECASE),
        re.compile(r'\b([A-D])\b(?=\s*[.:)])', re.IGNORECASE),
        # Needs a capture group: without one, findall returned the whole
        # "Answer:B" match, which then leaked out as the "letter".
        re.compile(r'\bAnswer:([A-D])\b', re.IGNORECASE),
        re.compile(r'Answerability:\s*answerable\s*:\s*([A-D])', re.IGNORECASE),
    ]
    found_letters = set()
    for pattern in patterns:
        for letter in pattern.findall(raw_answer):
            letter = letter.upper()
            if not options or letter in options:
                found_letters.add(letter)
        if found_letters:
            break
    if len(found_letters) == 1:
        return found_letters.pop()

    # 5. Yes/no questions: decide by whether the response contains a negation.
    yes_options = [letter for letter, text in options.items() if text.strip().lower() == 'yes']
    no_options = [letter for letter, text in options.items() if text.strip().lower() == 'no']
    if len(yes_options) == 1 and len(no_options) == 1:
        # "n't" is matched without a leading \b: inside "isn't"/"don't" it
        # follows a letter, so the old r"\bn't" alternative never fired.
        negation_pattern = re.compile(r"\b(?:not|no|never|none|neither|nor)\b|n['’]t\b")
        has_negation = negation_pattern.search(raw_lower) is not None
        return no_options[0] if has_negation else yes_options[0]

    # 6. Hedged / refusal responses.
    uncertainty_phrases = [
        "unable to determine", "not possible to determine", "cannot provide", "not sure", "not visible",
        "does not provide enough information", "does not provide any information",
        "does not provide information", "cannot answer", "not clear", "insufficient information",
        "unanswerable", "cannot be determined", "is not mentioned", "unknown"
    ]
    if any(phrase in raw_lower for phrase in uncertainty_phrases):
        return "unanswerable"

    return raw_answer

def process_open_answer(raw_answer):
    """Return ``"unanswerable"`` for hedged replies, else the stripped answer.

    A reply is hedged when its lower-cased text contains any of the phrases
    below as a substring.
    """
    # NOTE(review): "unquestionable" means "certain", yet it is treated as a
    # hedge here; confirm this is intended (possibly a typo for "unanswerable").
    hedges = (
        "unable to determine", "not possible to determine", "cannot provide",
        "not sure", "not visible", "unreadable", "unanswerable", "unquestionable",
        "does not provide enough information", "does not provide any information",
        "does not provide information", "cannot answer", "unclear", "unknownable",
        "not clear", "insufficient information", "cannot be determined",
        "is not mentioned", "unknown", "unsure", "can't tell", "can not tell",
    )
    lowered = raw_answer.lower()
    if any(hedge in lowered for hedge in hedges):
        return "unanswerable"
    return raw_answer.strip()

def parse_binary_answer(raw_answer: str) -> int:
    """Parse an answerability judgement into 1 (answerable), 0 (not) or -1.

    Order of checks on the lower-cased, stripped text:
      1. a standalone digit ``0`` or ``1`` is returned as-is;
      2. positive cues ("answerable" as a whole word, "yes", "definitely",
         "can be answered") -> 1;
      3. negative cues ("unanswerable", "no", "cannot be answered",
         "can't answer") -> 0;
      4. otherwise -1.

    Cues other than "answerable" are plain substring tests, so e.g. "no"
    also fires inside "know".
    """
    text = raw_answer.strip().lower()
    digit = re.search(r'\b([01])\b', text)
    if digit:
        return int(digit.group(1))

    # "answerable" must be a whole word: as a bare substring it also matched
    # inside "unanswerable", sending the most common negative reply to 1.
    if re.search(r'\banswerable\b', text) or any(
        cue in text for cue in ("yes", "definitely", "can be answered")
    ):
        return 1
    if any(cue in text for cue in ("unanswerable", "no", "cannot be answered", "can't answer")):
        return 0
    return -1

