import re
import string
from copy import deepcopy
from collections import defaultdict


def normalize(s: str) -> str:
    """Lower text and remove pad tokens, punctuation, articles and extra whitespace.

    Steps, in order: lowercase, drop literal ``<pad>`` tokens, delete every
    character in ``string.punctuation``, blank out the standalone words
    "a", "an" and "the", then collapse whitespace runs to single spaces.

    Args:
        s: Text to normalize.

    Returns:
        The normalized text; may be the empty string.
    """
    s = s.lower()
    # The pad token must be removed BEFORE punctuation is stripped: once the
    # angle brackets are gone it would survive as the ordinary word "pad".
    s = s.replace("<pad>", " ")
    s = s.translate(str.maketrans("", "", string.punctuation))
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())

def process_answers(answers_list):
    """Extract and normalize the answers embedded in raw prediction strings.

    Every string is scanned for ``ans:`` markers; the text following each
    marker is one candidate answer. Candidates are cleaned, passed through
    ``normalize`` and discarded when they end up empty or contain
    "not available".

    Args:
        answers_list: Iterable of raw prediction strings.

    Returns:
        A tuple ``(unique, processed)``. ``processed`` holds every normalized
        answer in extraction order, duplicates included. ``unique`` holds the
        same answers with duplicates removed, in first-occurrence order.
    """
    processed = []

    for ans in answers_list:
        # Drop "#..." comments; each one runs to the end of its line.
        ans = re.sub(r'#.*', '', ans).strip()

        # Anything before the first "ans:" marker is not an answer.
        for part in ans.split("ans:")[1:]:
            clean = part.strip()
            if not clean:
                continue

            # Cut a parenthesised trailer: from the first "(" (and the
            # whitespace before it) to the end of that line.
            clean = re.sub(r'\s*\(.+', '', clean).strip()

            # A leading "visit" word is not part of the answer itself.
            parts = clean.split()
            if parts and parts[0].lower() == 'visit':
                clean = ' '.join(parts[1:])

            if not clean.strip():
                continue

            clean_normalized = normalize(clean)

            if "not available" in clean_normalized:
                continue

            if clean_normalized:
                processed.append(clean_normalized)

    # dict.fromkeys de-duplicates while keeping first-seen order; a set would
    # yield an order that changes from run to run.
    return list(dict.fromkeys(processed)), processed


def match(s1: str, s2: str) -> bool:
    """Return whether normalized *s2* occurs as a substring of normalized *s1*.

    Both arguments go through ``normalize`` before the containment test.
    Note that an *s2* which normalizes to the empty string is contained in
    every string and therefore matches anything.
    """
    haystack, needle = normalize(s1), normalize(s2)
    return needle in haystack

def compute_metrics(predictions, references):
    """Score predicted answers against reference answers.

    Predictions and references are paired positionally with ``zip``, so any
    surplus items in the longer input are ignored.

    Args:
        predictions: Iterable where each item is the list of raw prediction
            strings for one sample (fed to ``process_answers``).
        references: Iterable where each item is the list of reference answer
            strings for the same sample.

    Returns:
        A dict with the keys "Micro-F1", "Macro-F1", "Hit", "Hit@1",
        "Precision" and "Recall". "Precision" and "Recall" are the
        micro-averaged values. Every metric is 0 when there are no samples.
    """
    total_tp = 0
    total_fp = 0
    total_fn = 0
    macro_f1_sum = 0
    hit_count = 0
    hit1_count = 0
    sample_count = 0

    for pred, ref in zip(predictions, references):
        unique_preds, all_preds = process_answers(pred)

        # Longest strings first, for both sides of the greedy matching below.
        pred_sorted = sorted(unique_preds, key=len, reverse=True)
        ref_normalized = sorted((normalize(r) for r in ref), key=len, reverse=True)

        # Hit: any prediction matches any reference.
        if any(match(p, r) for r in ref_normalized for p in unique_preds):
            hit_count += 1

        # Hit@1: judge the most frequent prediction. Frequencies must be
        # counted over the list that still contains duplicates; counting the
        # de-duplicated list would make every frequency 1.
        frequency = defaultdict(int)
        for ans in all_preds:
            frequency[ans] += 1
        if all_preds:
            # max() returns the first maximal element, so ties are broken by
            # earliest occurrence.
            first_pred = max(all_preds, key=lambda x: frequency[x])
            if any(match(first_pred, r) for r in ref_normalized):
                hit1_count += 1

        # Greedy one-to-one matching: each reference consumes the first
        # still-unused prediction that matches it.
        remaining = list(pred_sorted)
        matched = 0
        for r in ref_normalized:
            for candidate in remaining:
                if match(candidate, r):
                    matched += 1
                    # Safe to mutate: we leave the inner loop immediately.
                    remaining.remove(candidate)
                    break

        total_tp += matched
        total_fp += len(remaining)
        total_fn += len(ref_normalized) - matched

        precision = matched / len(pred_sorted) if pred_sorted else 0
        recall = matched / len(ref_normalized) if ref_normalized else 0
        denominator = precision + recall
        sample_f1 = (2 * precision * recall) / denominator if denominator != 0 else 0
        macro_f1_sum += sample_f1
        sample_count += 1

    precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) != 0 else 0
    recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) != 0 else 0
    micro_f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) != 0 else 0

    # Without samples the per-sample averages are undefined; report 0 instead
    # of dividing by zero.
    if sample_count == 0:
        macro_f1 = hit = hit1 = 0
    else:
        macro_f1 = macro_f1_sum / sample_count
        hit = hit_count / sample_count
        hit1 = hit1_count / sample_count

    return {
        "Micro-F1": micro_f1,
        "Macro-F1": macro_f1,
        "Hit": hit,
        "Hit@1": hit1,
        "Precision": precision,
        "Recall": recall
    }
    
if __name__ == "__main__":
    # Script entry point: load raw predictions, print the metrics and save a
    # cleaned copy of every record next to its original fields.
    import json
    import os

    data_path = "outputs/cwq/xxx.json"
    clean_data_path = "outputs/cwqnew/cleaned_xxx.json"
    # Explicit UTF-8 so the result does not depend on the platform's locale.
    with open(data_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    predictions = []
    references = []
    processed_data = {}
    for qid, q_info in data.items():
        pred_ans = q_info['answers']
        ref_ans = q_info['a_entity']
        predictions.append(pred_ans)
        references.append(ref_ans)
        clean_pred, _ = process_answers(pred_ans)
        clean_ref = [normalize(r) for r in ref_ans]

        processed_data[qid] = {
            'question': q_info['question'],
            'a_entity': ref_ans,
            'cleaned_a_entity': clean_ref,
            'answers': pred_ans,
            'cleaned_answers': clean_pred,
        }

    metrics = compute_metrics(predictions, references)
    print("Metrics:", metrics)

    output_path = clean_data_path
    # Create the target directory if it is missing so the write cannot fail
    # after all the work above has been done.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    # ensure_ascii=False emits raw non-ASCII characters, so the file must be
    # opened with an encoding that can represent them.
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(processed_data, f, ensure_ascii=False, indent=4)
    print(f"The cleaned data has been saved to {output_path}")
