import pandas as pd
from rank_bm25 import BM25Okapi
import json
from collections import defaultdict
from rapidfuzz import fuzz, process
import re
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from enum import Enum
from pydantic import BaseModel, Field, root_validator
from typing import List, Dict, Any, Optional
import instructor
from matplotlib.colors import LinearSegmentedColormap
import glob
import seaborn as sns
import os
import matplotlib.ticker as mtick
from matplotlib.lines import Line2D
from adjustText import adjust_text
from multiprocessing import Pool, cpu_count, Process
from eval import SR_EVAL, SRO_EVAL, cal_result_SR_EVAL, cal_result_SRO_EVAL, Gen_report_SR_EVAL, Gen_report_SRO_EVAL, cal_result_gen_report_SR_EVAL, cal_result_gen_report_SRO_EVAL


# Relation / attribute names recognised by this module, in canonical order.
RELATIONS = [
    'cat', 'dx_status', 'dx_certainty',
    'location', 'placement', 'associate', 'evidence',
    'morphology', 'distribution', 'measurement', 'severity', 'comparison',
    'onset', 'no change', 'improved', 'worsened',
    'past hx', 'other source', 'assessment limitations',
]

def initialize_llm_client(llm_name, api_key):
    """Create the client (and tokenizer, if any) for ``llm_name``.

    The backend is selected from the prefix of ``llm_name``:

    * ``gpt*`` -- OpenAI API client wrapped by ``instructor`` in JSON mode.
    * ``deepseek*`` / ``llama4*`` / ``qwen3*`` -- OpenAI client pointed at
      ``https://api.fireworks.ai/inference/v1``, wrapped by ``instructor`` in
      JSON mode.
    * ``MedGemma*`` -- ``google/medgemma-27b-text-it`` loaded with
      ``transformers`` (bfloat16, ``device_map="auto"``, cached in
      ``../medgemma``); ``api_key`` is passed as the Hugging Face token.
    * ``baichuan*`` / ``vllm_medgemma*`` / ``vllm_gpt-oss-120b*`` /
      ``vllm_gpt-oss-20b*`` -- OpenAI client pointed at
      ``http://localhost:$PORT/v1`` (dummy key), wrapped by ``instructor`` in
      JSON mode.

    Args:
        llm_name: Model identifier; only its prefix is inspected.
        api_key: API key for hosted backends, or the HF token for MedGemma.
            Not used for the localhost backends.

    Returns:
        ``(client, tokenizer)``. ``tokenizer`` is ``None`` for the
        OpenAI-compatible backends; for MedGemma ``client`` is the loaded
        ``AutoModelForCausalLM``.

    Raises:
        ValueError: if ``llm_name`` matches no known prefix (previously the
            function silently returned ``None``).
        RuntimeError: if a localhost backend is requested but the ``PORT``
            environment variable is unset or empty.
    """
    if llm_name.startswith('gpt'):
        from openai import OpenAI
        return instructor.from_openai(OpenAI(api_key=api_key), mode=instructor.Mode.JSON), None

    if llm_name.startswith(('deepseek', 'llama4', 'qwen3')):
        from openai import OpenAI
        client = OpenAI(api_key=api_key, base_url="https://api.fireworks.ai/inference/v1")
        return instructor.from_openai(client, mode=instructor.Mode.JSON), None

    if llm_name.startswith('MedGemma'):
        import torch
        from transformers import AutoTokenizer, AutoModelForCausalLM
        model_cache_dir = "../medgemma"
        model_id = "google/medgemma-27b-text-it"
        access_token = api_key
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            cache_dir=model_cache_dir,
            token=access_token
        )

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            cache_dir=model_cache_dir,
            token=access_token
        )
        return model, tokenizer

    if llm_name.startswith(('baichuan', 'vllm_medgemma', 'vllm_gpt-oss-120b', 'vllm_gpt-oss-20b')):
        port = os.getenv('PORT')
        if not port:
            # Without this check the URL silently became "http://localhost:None/v1".
            raise RuntimeError(f"PORT environment variable must be set to use {llm_name!r}")
        from openai import OpenAI
        client = OpenAI(
            api_key="fake-key",
            base_url=f"http://localhost:{port}/v1",
        )
        return instructor.from_openai(client, mode=instructor.Mode.JSON), None

    raise ValueError(f"Unsupported llm_name: {llm_name!r}")


def transform_triplets(triplets):
    """Randomly perturb a list of ``(subject, relation, object)`` triplets.

    For each triplet one of three outcomes is drawn uniformly with
    ``random.randint(0, 2)``:

    * ``0`` -- perturb it and keep it, applying the first matching rule:
        1. ``dx_status`` ``positive`` -> ``negative``;
        2. ``dx_certainty`` ``definitive`` -> ``tentative``;
        3. multi-word subject -> delete one randomly chosen word;
        4. multi-word object -> delete one randomly chosen word, then strip
           leading/trailing commas;
        5. otherwise keep it unchanged;
    * ``1`` -- keep it unchanged;
    * ``2`` -- drop it.

    Only the first three items of each triplet are used. The global
    ``random`` module state is consumed, so seeding ``random`` makes the
    output reproducible.

    Args:
        triplets: Iterable of indexable triplets.

    Returns:
        A new list of 3-tuples (not strings).
    """
    import random

    def _without_random_word(text):
        # Delete exactly one word position. The previous implementation
        # filtered out every occurrence of the chosen word, so e.g. "a a"
        # collapsed to "". randrange(n) consumes the RNG like choice() did.
        words = text.split()
        if not words:
            return text
        del words[random.randrange(len(words))]
        return ' '.join(words)

    transformed_triplets = []
    for triplet in triplets:
        choice = random.randint(0, 2)

        if choice == 2:
            continue
        if choice == 1:
            transformed_triplets.append((triplet[0], triplet[1], triplet[2]))
            continue

        # choice == 0: perturb.
        subject, relation, obj = triplet[0], triplet[1], triplet[2]
        if relation == 'dx_status' and obj == 'positive':
            transformed_triplets.append((subject, 'dx_status', 'negative'))
        elif relation == 'dx_certainty' and obj == 'definitive':
            transformed_triplets.append((subject, 'dx_certainty', 'tentative'))
        elif subject and ' ' in subject:
            transformed_triplets.append((_without_random_word(subject), relation, obj))
        elif obj and ' ' in str(obj):
            # Removing a word can leave a dangling comma at either end.
            modified_value = _without_random_word(str(obj)).strip(',')
            transformed_triplets.append((subject, relation, modified_value))
        else:
            transformed_triplets.append((subject, relation, obj))

    return transformed_triplets

def get_fewshot(query, df, args, vocab, vocab_lookup, vocab_keys):
    """Build one few-shot ``(user, assistant)`` example from an annotated report.

    The user message is ``"INPUT:\\n" + JSON`` with a ``report_section`` list.
    Its content depends on ``args.unit`` / ``args.candidate_type``:

    * ``no_candidates`` -- every sentence, without candidates.
    * ``args.unit == 'sent'`` -- a single entry for ``query`` with vocabulary
      candidates and their categories (only ``vocab_ent_rcg`` is supported).
    * a type containing ``'gt'`` -- per-sentence candidates derived from the
      annotated rows via ``extract_triplets`` (``gt_sro``, ``gt_sro_review``,
      ``gt_so``, ``gt_s``, ``gt_ent_rcg``).
    * otherwise -- per-sentence vocabulary candidates from ``get_words``
      (``vocab_so``, ``vocab_s``, ``vocab_ent_rcg``). When
      ``args.candidate_usage != 1`` they are subsampled to roughly that
      fraction.

    Candidate lists are de-duplicated in first-seen order, so the prompt is
    reproducible across processes.

    The assistant message is ``"OUTPUT: " + JSON`` with one object per row that
    meets both conditions:

    * its upper-cased ``cat`` is in ``args.entity_types``;
    * its ``study_id`` equals that of the first (section-sorted) row.

    Each object lists the row's Cat / Dx_Status / Dx_Certainty values plus the
    configured relation and attribute columns.

    Args:
        query: Sentence text; used only when ``args.unit == 'sent'``.
        df: Annotation rows of one report. Columns read here: ``section``,
            ``sent_idx``, ``sent``, ``study_id``, ``cat``, ``ent``,
            ``ent_idx``, ``dx_status``, ``dx_certainty`` and the lower-cased
            names in ``args.relation_types`` / ``args.attribute_types``.
        args: Namespace with ``candidate_type``, ``unit``,
            ``candidate_usage``, ``entity_types``, ``relation_types`` and
            ``attribute_types`` (also passed through to ``get_words``).
        vocab: Vocabulary DataFrame with ``target_term`` and ``category``.
        vocab_lookup, vocab_keys: Passed through to ``get_words``.

    Returns:
        ``(user_string, assistant_string)``.

    Raises:
        ValueError: for a ``unit`` / ``candidate_type`` combination this
            function cannot build candidates for. Previously this surfaced as
            an ``UnboundLocalError``.
    """
    ENTITY_CATEGORY = ['pf', 'cf', 'cof', 'cof/ncd', 'oth', 'ncd', 'pf/cf', 'patient info.']
    section_order = {'hist': 0, 'find': 1, 'impr': 2}
    candidate_type = args.candidate_type

    # Stable sort keeps the original row (entity) order inside each section.
    # The default quicksort is not stable and may reorder rows sharing a section.
    df = df.sort_values(by='section', key=lambda x: x.map(section_order), kind='stable')

    # NOTE(review): sent_idx is assumed to run 1..max without gaps; a missing
    # index makes ``.iloc[0]`` below raise IndexError.
    json_data = {'report_section': []}

    if candidate_type == 'no_candidates':
        for sent_idx in range(1, df['sent_idx'].max() + 1):
            sent_rows = df[df['sent_idx'] == sent_idx]
            json_data['report_section'].append({
                'sent_idx': sent_idx,
                'sentence': sent_rows['sent'].iloc[0],
            })

    elif args.unit == 'sent':
        if candidate_type != 'vocab_ent_rcg':
            raise ValueError(f"candidate_type {candidate_type!r} is not supported with unit='sent'")
        # dict.fromkeys de-duplicates in first-seen order (set order depends
        # on the per-process string hash seed).
        candidates = list(dict.fromkeys(get_words(query, vocab, vocab_lookup, vocab_keys, args)))
        word_cat_list = [
            (word, list(dict.fromkeys(vocab[vocab['target_term'] == word]['category'])))
            for word in candidates
        ]
        json_data['report_section'].append({
            'sent_idx': 1,
            'sentence': query,
            'candidates': word_cat_list,
        })

    elif 'gt' in candidate_type:
        for sent_idx in range(1, df['sent_idx'].max() + 1):
            sent_rows = df[df['sent_idx'] == sent_idx]
            gt = []
            for _, row in sent_rows.iterrows():
                triplet_list, _, _ = extract_triplets(row)
                if candidate_type == 'gt_sro_review':
                    # Accumulate over all rows of the sentence. The previous
                    # assignment kept only the last row's perturbed triplets.
                    gt.extend(transform_triplets(triplet_list))
                elif candidate_type == 'gt_sro':
                    gt.extend((t[0], t[1], t[2]) for t in triplet_list)
                elif candidate_type == 'gt_so':
                    for t in triplet_list:
                        gt.append(t[0])
                        # Objects of these relations are not offered as candidates.
                        if t[1] not in ('cat', 'dx_status', 'dx_certainty'):
                            gt.append(t[2])
                elif candidate_type == 'gt_s':
                    gt.extend(t[0] for t in triplet_list)
                elif candidate_type == 'gt_ent_rcg':
                    for t in triplet_list:
                        ent_cat_list = vocab[vocab['target_term'] == t[0]]['category'].tolist()
                        gt.append((t[0], ent_cat_list))
                        if t[1] not in ('cat', 'dx_status', 'dx_certainty', 'associate', 'evidence'):
                            gt.append((t[2], [t[1]]))

            if candidate_type in ('gt_sro', 'gt_sro_review'):
                gt = list(dict.fromkeys(tuple(t) for t in gt))
            elif candidate_type in ('gt_so', 'gt_s'):
                gt = list(dict.fromkeys(gt))
            elif candidate_type == 'gt_ent_rcg':
                # Merge duplicate entities, uniting their category lists.
                merged = {}
                for entity, categories in gt:
                    if entity in merged:
                        merged[entity] = list(dict.fromkeys(merged[entity] + categories))
                    else:
                        merged[entity] = categories
                gt = list(merged.items())

            json_data['report_section'].append({
                'sent_idx': sent_idx,
                'sentence': sent_rows['sent'].iloc[0],
                'candidates': gt,
            })

    else:
        if candidate_type not in ('vocab_so', 'vocab_s', 'vocab_ent_rcg'):
            raise ValueError(f"Unsupported candidate_type: {candidate_type!r}")
        for sent_idx in range(1, df['sent_idx'].max() + 1):
            sent = df[df['sent_idx'] == sent_idx]['sent'].iloc[0]
            candidates = list(dict.fromkeys(get_words(sent.strip(), vocab, vocab_lookup, vocab_keys, args)))

            if args.candidate_usage != 1:
                # Keep roughly a ``candidate_usage`` fraction by taking every
                # ``step``-th candidate. The step is clamped to 1 so an empty
                # list (or a usage above 1) no longer produces a zero slice
                # step, which raised ValueError.
                target_length = int(len(candidates) * args.candidate_usage)
                step = max(len(candidates) // max(target_length, 1), 1)
                candidates = candidates[::step][:target_length]

            if candidate_type == 'vocab_so':
                json_candidates = candidates
            elif candidate_type == 'vocab_s':
                # Keep only words whose categories are all entity categories.
                json_candidates = [
                    word for word in candidates
                    if all(cat in ENTITY_CATEGORY
                           for cat in set(vocab[vocab['target_term'] == word]['category']))
                ]
            else:  # 'vocab_ent_rcg'
                # Report every entity category under the single label 'entity'.
                json_candidates = []
                for word in candidates:
                    categories = dict.fromkeys(vocab[vocab['target_term'] == word]['category'])
                    json_candidates.append(
                        (word, ['entity' if cat in ENTITY_CATEGORY else cat for cat in categories])
                    )

            json_data['report_section'].append({
                'sent_idx': sent_idx,
                'sentence': sent,
                'candidates': json_candidates,
            })

    user_string = f"INPUT:\n{json.dumps(json_data, indent=4)}\n"

    # Expected OUTPUT: one object per qualifying row, in section-sorted order.
    current_study_id = df['study_id'].iloc[0]
    entities = []
    for _, row in df.iterrows():
        if row['cat'].upper() not in args.entity_types:
            continue
        if row['study_id'] != current_study_id:
            continue

        relations = []
        for attr in ['Cat', 'Dx_Status', 'Dx_Certainty'] + args.relation_types + args.attribute_types:
            attr_key = attr.lower()
            if pd.isna(row[attr_key]):
                continue
            if attr in ['Associate', 'Evidence']:
                # Value looks like "pneumonia, obj_ent_idx2, effusion, obj_ent_idx3":
                # alternating (value, object-entity index) pairs.
                tokens = [t.strip() for t in row[attr_key].split(',')]
                for i in range(0, len(tokens), 2):
                    if i + 1 < len(tokens) and tokens[i + 1].startswith('obj_ent_idx'):
                        relations.append({
                            "relation": attr,
                            "value": tokens[i],
                            "obj_ent_idx": int(tokens[i + 1].replace('obj_ent_idx', '')),
                        })
            else:
                value = row[attr_key]
                if attr == 'Location' and isinstance(value, str) and value.lower().startswith("loc:"):
                    value = value[4:].strip()
                if attr in ['Cat', 'Dx_Status', 'Dx_Certainty']:
                    value = value.upper()
                relations.append({"relation": attr, "value": value})

        entities.append({
            "name": row['ent'],
            "sent_idx": int(row['sent_idx']),
            "ent_idx": int(row['ent_idx']),
            "relations": relations,
        })

    assistant_string = "OUTPUT: " + json.dumps({"entities": entities}, indent=2)
    return user_string, assistant_string

def get_fewshot_multi_turn(query, df, args, vocab, vocab_lookup, vocab_keys):
    """Build a multi-turn few-shot example: one ``(user, assistant)`` turn per sentence.

    For every ``sent_idx`` from 1 to ``df['sent_idx'].max()``:

    * the user turn is ``"INPUT:\\n" + JSON`` with a one-element
      ``report_section`` holding that sentence. Unless
      ``args.candidate_type == 'no_candidates'``, the entry also carries its
      vocabulary candidates (``vocab_so``, ``vocab_s`` or ``vocab_ent_rcg``),
      subsampled to roughly ``args.candidate_usage`` when that is not 1;
    * the assistant turn is ``"OUTPUT: " + JSON`` with the entities annotated
      in that sentence. Only rows whose upper-cased ``cat`` is in
      ``args.entity_types`` and whose ``study_id`` matches the first
      section-sorted row are included.

    Args:
        query: Unused.
        df: Annotation rows of one report (same columns as for ``get_fewshot``).
        args: Namespace with ``candidate_type``, ``candidate_usage``,
            ``entity_types``, ``relation_types`` and ``attribute_types``
            (also passed through to ``get_words``).
        vocab: Vocabulary DataFrame with ``target_term`` and ``category``.
        vocab_lookup, vocab_keys: Passed through to ``get_words``.

    Returns:
        ``(user_string_list, assistant_string_list)`` -- parallel lists with
        one entry per sentence.

    Raises:
        ValueError: for a ``candidate_type`` other than ``no_candidates``,
            ``vocab_so``, ``vocab_s`` or ``vocab_ent_rcg``.
    """
    ENTITY_CATEGORY = ['pf', 'cf', 'cof', 'cof/ncd', 'oth', 'ncd', 'pf/cf', 'patient info.']
    section_order = {'hist': 0, 'find': 1, 'impr': 2}
    candidate_type = args.candidate_type
    if candidate_type not in ('no_candidates', 'vocab_so', 'vocab_s', 'vocab_ent_rcg'):
        raise ValueError(f"Unsupported candidate_type: {candidate_type!r}")

    # Stable sort keeps the original row (entity) order inside each section.
    df = df.sort_values(by='section', key=lambda x: x.map(section_order), kind='stable')
    # NOTE(review): sent_idx is assumed to run 1..max without gaps.
    max_sent_idx = df['sent_idx'].max()

    # User turns. The 'no_candidates' case previously referenced an undefined
    # ``json_data`` and never populated the list; it now emits one
    # sentence-only turn per sentence, like the candidate variants.
    user_string_list = []
    for sent_idx in range(1, max_sent_idx + 1):
        sent = df[df['sent_idx'] == sent_idx]['sent'].iloc[0]
        section_entry = {'sent_idx': sent_idx, 'sentence': sent}

        if candidate_type != 'no_candidates':
            # First-seen order keeps the prompt reproducible across processes.
            candidates = list(dict.fromkeys(get_words(sent.strip(), vocab, vocab_lookup, vocab_keys, args)))
            if args.candidate_usage != 1:
                # Clamp the step to 1 so an empty list or usage > 1 does not
                # produce a zero slice step (ValueError).
                target_length = int(len(candidates) * args.candidate_usage)
                step = max(len(candidates) // max(target_length, 1), 1)
                candidates = candidates[::step][:target_length]

            if candidate_type == 'vocab_so':
                json_candidates = candidates
            elif candidate_type == 'vocab_s':
                json_candidates = [
                    word for word in candidates
                    if all(cat in ENTITY_CATEGORY
                           for cat in set(vocab[vocab['target_term'] == word]['category']))
                ]
            else:  # 'vocab_ent_rcg'
                json_candidates = [
                    (word, list(dict.fromkeys(vocab[vocab['target_term'] == word]['category'])))
                    for word in candidates
                ]
            section_entry['candidates'] = json_candidates

        json_input = json.dumps({'report_section': [section_entry]}, indent=4)
        user_string_list.append(f"INPUT:\n{json_input}\n")

    # Assistant turns: the entities annotated in each sentence.
    current_study_id = df['study_id'].iloc[0]
    assistant_string_list = []
    for sent_idx in range(1, max_sent_idx + 1):
        entities = []
        for _, row in df[df['sent_idx'] == sent_idx].iterrows():
            if row['cat'].upper() not in args.entity_types:
                continue
            if row['study_id'] != current_study_id:
                continue

            relations = []
            for attr in ['Cat', 'Dx_Status', 'Dx_Certainty'] + args.relation_types + args.attribute_types:
                attr_key = attr.lower()
                if pd.isna(row[attr_key]):
                    continue
                if attr in ['Associate', 'Evidence']:
                    # Value looks like "pneumonia, obj_ent_idx2, effusion, obj_ent_idx3".
                    tokens = [t.strip() for t in row[attr_key].split(',')]
                    for i in range(0, len(tokens), 2):
                        if i + 1 < len(tokens) and tokens[i + 1].startswith('obj_ent_idx'):
                            relations.append({
                                "relation": attr,
                                "value": tokens[i],
                                "obj_ent_idx": int(tokens[i + 1].replace('obj_ent_idx', '')),
                            })
                else:
                    value = row[attr_key]
                    if attr == 'Location' and isinstance(value, str) and value.lower().startswith("loc:"):
                        value = value[4:].strip()
                    if attr in ['Cat', 'Dx_Status', 'Dx_Certainty']:
                        value = value.upper()
                    relations.append({"relation": attr, "value": value})

            entities.append({
                "name": row['ent'],
                "sent_idx": int(row['sent_idx']),
                "ent_idx": int(row['ent_idx']),
                "relations": relations,
            })

        assistant_string_list.append("OUTPUT: " + json.dumps({"entities": entities}, indent=2))

    return user_string_list, assistant_string_list

# Translation table mapping each normalized punctuation mark to a space.
_PUNCT_TO_SPACE = str.maketrans({mark: ' ' for mark in ".,-/!?()\"'"})


def preprocess_text(text):
    """Normalize ``text`` for vocabulary matching.

    ``str(text)`` is lower-cased and each of ``. , - / ! ? ( ) " '`` is
    replaced by a space. Runs of whitespace then collapse to single spaces,
    with nothing left at either end.
    """
    spaced = str(text).lower().translate(_PUNCT_TO_SPACE)
    return ' '.join(spaced.split())

def find_subsequence_indices(pattern_tokens, tokens, start=0):
    """Return every way ``pattern_tokens`` occurs in order (gaps allowed) in ``tokens``.

    Each result is a tuple of strictly increasing indices into ``tokens``, all
    at or after ``start``, such that ``tokens[idx[k]] == pattern_tokens[k]``.
    Results are ordered by first index, then by the following indices.

    Example: pattern ``["a", "b"]`` in ``["a", "x", "b", "a", "b"]`` gives
    ``[(0, 2), (0, 4), (3, 4)]``. An empty pattern gives ``[()]``.

    The number of results (and the work) can grow combinatorially with
    repeated tokens.
    """
    if not pattern_tokens:
        return [()]
    head, remainder = pattern_tokens[0], pattern_tokens[1:]
    return [
        (pos,) + tail
        for pos in range(start, len(tokens))
        if tokens[pos] == head
        for tail in find_subsequence_indices(remainder, tokens, pos + 1)
    ]

def find_subsequence_indices_with_window(pattern_tokens, tokens, start=0, window=10):
    """Ordered subsequence matches with a bounded look-ahead from the first token.

    For every position ``i >= start`` with ``tokens[i] == pattern_tokens[0]``,
    all remaining pattern tokens must be found, in order, among the ``window``
    tokens directly after ``i`` (``tokens[i+1 : i+1+window]``). The window is
    anchored at the first matched token; it is not re-applied after each
    following token.

    Example: pattern ``["a", "b"]`` in ``["a", "x", "b", "a", "b"]`` with
    ``window=1`` gives ``[(3, 4)]``.

    Returns:
        List of index tuples into ``tokens``; ``[()]`` for an empty pattern.
    """
    if not pattern_tokens:
        return [()]

    def ordered_positions(pattern, seq, lo):
        # Every increasing index tuple into seq (from lo) that spells out pattern.
        if not pattern:
            return [()]
        found = []
        for pos in range(lo, len(seq)):
            if seq[pos] == pattern[0]:
                found.extend((pos,) + rest for rest in ordered_positions(pattern[1:], seq, pos + 1))
        return found

    head, remainder = pattern_tokens[0], pattern_tokens[1:]
    results = []
    for i in range(start, len(tokens)):
        if tokens[i] != head:
            continue
        lo = i + 1
        window_tokens = tokens[lo:min(lo + window, len(tokens))]
        for tail in ordered_positions(remainder, window_tokens, 0):
            # Shift window-relative indices back to positions in ``tokens``.
            results.append((i,) + tuple(lo + j for j in tail))
    return results

def map_tokens_to_original(original_sent, tokens):
    """Locate preprocessed tokens inside the original sentence.

    Tokens (the first element of each ``(token, start, end)`` triple) are
    searched left to right. Each search is case-insensitive, treats the token
    as a literal string, and starts just after the previous match.

    On a hit, the original-cased text and its ``(start, end)`` offsets in
    ``original_sent`` are recorded. On a miss, the token itself is recorded at
    the current cursor, and the cursor advances by the token's length.

    Returns:
        A list of ``(text, start, end)`` tuples, one per input token.
    """
    located = []
    cursor = 0
    for token, _start, _end in tokens:
        hit = re.search(re.escape(token), original_sent[cursor:], flags=re.IGNORECASE)
        if hit is None:
            located.append((token, cursor, cursor + len(token)))
            cursor += len(token)
            continue
        begin = cursor + hit.start()
        finish = cursor + hit.end()
        located.append((original_sent[begin:finish], begin, finish))
        cursor = finish
    return located

def tokenize_sentence(sent):
    """Tokenize a sentence on whitespace after ``preprocess_text`` normalization.

    Returns a 4-tuple:
      * ``preprocessed`` -- ``preprocess_text(sent)``;
      * ``tokens`` -- ``(token, start, end)`` for every run of non-whitespace
        in ``preprocessed`` (offsets into ``preprocessed``);
      * ``original_tokens`` -- those tokens located in ``sent`` by
        ``map_tokens_to_original`` (original text and offsets);
      * ``token_texts`` -- the bare token strings.
    """
    preprocessed = preprocess_text(sent)  # lower-cased, punctuation -> spaces
    tokens = []
    for hit in re.finditer(r'\S+', preprocessed):
        tokens.append((hit.group(), hit.start(), hit.end()))
    original_tokens = map_tokens_to_original(sent, tokens)
    token_texts = [text for text, _, _ in tokens]
    return preprocessed, tokens, original_tokens, token_texts

def get_vocab_lookup(vocab):
    """Index vocabulary terms by their ``preprocess_text`` form.

    For every row of ``vocab`` whose ``target_term`` is present and not null,
    the term is normalized. If the result is non-empty, a record is appended
    under that key containing:

    * the original term and the field it came from;
    * the row's ``target_term``, ``category`` and ``normed_term`` (``None``
      when the column is absent).

    Returns:
        ``(vocab_lookup, vocab_keys)``: a ``defaultdict(list)`` from normalized
        term to records, and its keys in insertion order.
    """
    # Only target_term is indexed; 'raw_term' could be appended here to also
    # index raw spellings.
    fields_to_use = ['target_term']
    vocab_lookup = defaultdict(list)

    for _, row in vocab.iterrows():
        for field in fields_to_use:
            if field not in row or pd.isna(row[field]):
                continue
            term = str(row[field])
            key = preprocess_text(term)
            if not key:
                continue
            vocab_lookup[key].append({
                'matched_term': term,  # later exposed as 'matched_word'
                'source_field': field,
                'target_term': row.get('target_term', None),
                'category': row.get('category', None),
                'normed_term': row.get('normed_term', None),
            })

    return vocab_lookup, list(vocab_lookup)

def find_fuzzy_continuous_matches(text, vocab_lookup, vocab_keys, fuzzy_threshold=100):
    """Match every contiguous token span of ``text`` against the vocabulary.

    ``text`` is tokenized with ``tokenize_sentence``. Every contiguous run of
    tokens is visited, longest runs first. Each run is cut from the original
    text, normalized with ``preprocess_text``, and scored against all
    ``vocab_keys`` with RapidFuzz ``fuzz.ratio``.

    When the best key reaches ``fuzzy_threshold``, one match is emitted per
    vocabulary record stored under that key. Matches are de-duplicated per
    (span, source field, lower-cased term, category).

    Args:
        text: Sentence to search; a null value (``pd.isna``) yields ``[]``.
        vocab_lookup: Mapping from normalized term to a list of records with
            ``matched_term``, ``source_field``, ``target_term``, ``category``
            and ``normed_term``.
        vocab_keys: Keys of ``vocab_lookup`` to score against.
        fuzzy_threshold: Minimum ``fuzz.ratio`` score (0-100) to accept.

    Returns:
        A list of match dicts with these keys:

        * ``word`` -- the original span text;
        * ``matched_word``, ``source_field``, ``target_term``, ``category``,
          ``normed_term`` -- copied from the vocabulary record;
        * ``start`` / ``end`` -- character offsets of the span in ``text``;
        * ``match_type`` -- always ``'fuzzy_continuous'``;
        * ``fuzzy_score`` -- the RapidFuzz score.

    Note:
        The number of spans is quadratic in the token count, and each span is
        scored against every key. Long sentences and large vocabularies are
        therefore expensive.
    """
    if pd.isna(text):
        return []

    _, tokens, original_tokens, _ = tokenize_sentence(text)
    n = len(tokens)
    matches = []
    seen_matches = set()

    for length in range(n, 0, -1):
        for i in range(n - length + 1):
            start_pos = original_tokens[i][1]
            end_pos = original_tokens[i + length - 1][2]
            span_text = text[start_pos:end_pos].strip()
            processed_span = preprocess_text(span_text)
            if not processed_span:
                continue

            # score_cutoff lets RapidFuzz skip keys that cannot reach the
            # threshold. It returns None when no key qualifies, which is
            # equivalent to the former ">= threshold" check on the best key.
            best = process.extractOne(processed_span, vocab_keys, scorer=fuzz.ratio,
                                      score_cutoff=fuzzy_threshold)
            if best is None:
                continue
            best_key, score, _ = best

            for entry in vocab_lookup[best_key]:
                key = (start_pos, end_pos, entry['source_field'], entry['matched_term'].lower(),
                       entry['category'], 'fuzzy_continuous')
                if key in seen_matches:
                    continue
                seen_matches.add(key)
                matches.append({
                    'word': span_text,                      # original text span
                    'matched_word': entry['matched_term'],  # matched vocabulary term
                    'source_field': entry['source_field'],
                    'target_term': entry['target_term'],
                    'category': entry['category'],
                    'normed_term': entry['normed_term'],
                    'start': start_pos,
                    'end': end_pos,
                    'match_type': 'fuzzy_continuous',
                    'fuzzy_score': score,
                })

    return matches

def find_fuzzy_discontinuous_matches(sent, vocab_df, fuzzy_threshold=90):
    """Find vocabulary terms whose tokens occur in order, possibly with gaps, in ``sent``.

    Terms are processed longest ``target_term`` first. For each term:

    * Its ``preprocess_text`` tokens are searched as an ordered subsequence of
      the sentence tokens. Every token after the first must lie within the 5
      tokens following the first (``find_subsequence_indices_with_window`` with
      ``window=5``).
    * Among all occurrences, the one with the shortest character span in
      ``sent`` is kept; the first one wins on ties.
    * It is accepted when ``fuzz.ratio`` between the matched tokens and the
      normalized term is at least ``fuzzy_threshold``.

    NOTE: the matched tokens are by construction equal to the term's tokens,
    so this score is effectively always 100.

    Args:
        sent: Sentence to search; a null value (``pd.isna``) yields ``[]``.
        vocab_df: DataFrame with a ``target_term`` column and optional
            ``category`` / ``normed_term`` columns.
        fuzzy_threshold: Minimum ``fuzz.ratio`` score (0-100).

    Returns:
        A list of match dicts with the same keys as the continuous matcher and
        ``match_type='fuzzy_discontinuous'``. There is at most one match per
        (lower-cased matched text, category).
    """
    if pd.isna(sent):
        return []

    _, _, original_tokens, token_texts = tokenize_sentence(sent)
    matches = []
    seen_matches = set()

    def _char_span(seq):
        # (start, end) character offsets in ``sent`` covered by token indices seq.
        return original_tokens[seq[0]][1], original_tokens[seq[-1]][2]

    def _span_length(seq):
        start, end = _char_span(seq)
        return end - start

    # Longest vocabulary terms first (sort_values already returns a new frame).
    sorted_vocab_df = vocab_df.sort_values(by='target_term', key=lambda x: x.str.len(), ascending=False)

    for _, row in sorted_vocab_df.iterrows():
        term = row['target_term']
        if pd.isna(term):
            continue

        processed_vocab = preprocess_text(term)
        pattern_tokens = processed_vocab.split()  # e.g. 'left lung' -> ['left', 'lung']
        if not pattern_tokens:
            # A term that normalizes to nothing (e.g. '' or '-') matches the
            # empty subsequence () and used to crash on original_tokens[seq[0]].
            continue

        subseq_indices = find_subsequence_indices_with_window(pattern_tokens, token_texts, window=5)
        if not subseq_indices:
            continue

        # e.g. text 'left lower lung' with pattern ['left', 'lung'] -> (0, 2).
        best_seq = min(subseq_indices, key=_span_length)
        start_pos, end_pos = _char_span(best_seq)

        candidate_text = " ".join(token_texts[i] for i in best_seq)  # normalized form
        candidate_original = sent[start_pos:end_pos].strip()          # original span
        fuzzy_score = fuzz.ratio(candidate_text, processed_vocab)
        if fuzzy_score < fuzzy_threshold:
            continue

        category = row.get('category', None)
        key = (candidate_original.lower(), category, 'fuzzy_discontinuous')
        if key in seen_matches:
            continue
        seen_matches.add(key)
        matches.append({
            'word': candidate_original,  # original text span
            'matched_word': term,        # matched vocabulary term
            'source_field': 'target_term',
            'target_term': term,
            'category': category,
            'normed_term': row.get('normed_term', None),
            'start': start_pos,
            'end': end_pos,
            'match_type': 'fuzzy_discontinuous',
            'fuzzy_score': fuzzy_score,
        })

    return matches

def process_with_all_fuzzy_matches(sent, vocab, vocab_lookup, vocab_keys, fuzzy_threshold=90, allow_overlap=True):
    """Run the contiguous and the discontinuous vocabulary matchers on one sentence.

    In every returned match, ``word`` is the actual span of the original text
    and ``matched_word`` is the matched vocabulary term.

    Args:
        sent: Sentence to search.
        vocab: Vocabulary DataFrame (for the discontinuous matcher).
        vocab_lookup, vocab_keys: Lookup structures for the contiguous matcher.
        fuzzy_threshold: Minimum fuzzy score, passed to both matchers.
        allow_overlap: If ``True`` (default), every match is returned.
            If ``False``, matches are visited from the longest character span
            to the shortest, and a match is dropped when its ``[start, end)``
            span overlaps one already kept. This parameter was previously
            accepted but ignored.

    Returns:
        Contiguous matches followed by discontinuous matches, keeping that
        relative order after any filtering.
    """
    continuous_matches = find_fuzzy_continuous_matches(sent, vocab_lookup, vocab_keys, fuzzy_threshold)
    discontinuous_matches = find_fuzzy_discontinuous_matches(sent, vocab, fuzzy_threshold)
    all_matches = continuous_matches + discontinuous_matches
    if allow_overlap:
        return all_matches

    # Longest span first; sorted() is stable, so ties keep their original order.
    by_length = sorted(
        range(len(all_matches)),
        key=lambda i: all_matches[i]['end'] - all_matches[i]['start'],
        reverse=True,
    )
    kept_spans = []
    kept = set()
    for i in by_length:
        start, end = all_matches[i]['start'], all_matches[i]['end']
        if any(start < kept_end and kept_start < end for kept_start, kept_end in kept_spans):
            continue
        kept_spans.append((start, end))
        kept.add(i)
    return [match for i, match in enumerate(all_matches) if i in kept]

def get_target_term_positions(sent_term, target_term, sent_start, sent_end):
    """Locate ``target_term`` inside a sentence span, ignoring case.

    Args:
        sent_term: Text of the span as it appears in the sentence.
        target_term: Vocabulary term to locate within ``sent_term``.
        sent_start: Absolute character offset of ``sent_term`` in the sentence.
        sent_end: Absolute end offset of ``sent_term``; not used.

    Returns:
        ``(target_start, target_end)`` absolute offsets:

        * if ``target_term`` occurs verbatim (ignoring case) in ``sent_term``,
          its first occurrence;
        * otherwise, the longest substring of ``sent_term`` that also occurs in
          ``target_term`` (the earliest in ``sent_term`` on ties);
        * with no character in common, the empty span
          ``(sent_start, sent_start)``.
    """
    from difflib import SequenceMatcher

    sent_term_lower = sent_term.lower()
    target_term_lower = target_term.lower()

    target_pos = sent_term_lower.find(target_term_lower)
    if target_pos != -1:
        target_start = sent_start + target_pos
        # Use the length of the lower-cased text that was actually searched,
        # so the span stays consistent with ``target_pos``.
        return target_start, target_start + len(target_term_lower)

    # Longest common substring. This replaces an O(n^3) brute-force scan.
    # autojunk=False disables difflib's popularity heuristic so the result is
    # the exact longest match. On ties, difflib returns the block that starts
    # earliest in ``sent_term_lower``, and with no match it returns size 0 at
    # offset 0, exactly like the previous scan.
    matcher = SequenceMatcher(None, sent_term_lower, target_term_lower, autojunk=False)
    match = matcher.find_longest_match(0, len(sent_term_lower), 0, len(target_term_lower))
    return sent_start + match.a, sent_start + match.a + match.size

def get_overlap_status(target_start, target_end, used_target_start, used_target_end):
    """
    Return True when span [target_start, target_end] intersects span
    [used_target_start, used_target_end].

    Both ends are treated as inclusive, so spans that merely touch
    (one ends where the other starts) are reported as overlapping.
    """
    disjoint = target_start > used_target_end or target_end < used_target_start
    return not disjoint

def get_words(text, vocab, vocab_lookup, vocab_keys, args):
    """
    Return the vocabulary target terms found in `text`, one per non-overlapping span.

    Matches come from the discontinuous matcher when
    `args.candidate_discontinuous` is set, otherwise from the continuous
    matcher, both with fuzzy_threshold=100. Target terms containing a digit
    are discarded. The remaining matches are visited longest span first.
    Spans are measured on the original text for discontinuous matching and
    on the target-term position for continuous matching. A match is kept
    only if its span neither equals nor overlaps an already kept span.

    Returns:
        list of lower-cased target terms, in selection (longest-first) order.
    """
    use_discontinuous = args.candidate_discontinuous

    if use_discontinuous:
        raw_matches = find_fuzzy_discontinuous_matches(text, vocab, fuzzy_threshold=100)
    else:
        raw_matches = find_fuzzy_continuous_matches(text, vocab_lookup, vocab_keys, fuzzy_threshold=100)

    # Each row: [span_text, target_term, span_start, span_end, term_start, term_end]
    rows = []
    for hit in raw_matches:
        span_begin = hit['start']
        span_finish = hit['end']
        span_text = str(hit['word']).lower()
        term = str(hit['matched_word']).lower()

        # Terms that contain digits are never offered as candidates.
        if any(ch.isdigit() for ch in term):
            continue

        term_begin, term_finish = get_target_term_positions(span_text, term, span_begin, span_finish)
        rows.append([span_text, term, span_begin, span_finish, term_begin, term_finish])

    # Which pair of row columns defines the span used for length ordering and overlap checks.
    lo_col, hi_col = (2, 3) if use_discontinuous else (4, 5)

    # Longest spans first; sorted() is stable so equal-length rows keep match order.
    ranked_rows = sorted(rows, key=lambda row: row[hi_col] - row[lo_col], reverse=True)

    chosen = []
    occupied = set()
    for row in ranked_rows:
        span = (row[lo_col], row[hi_col])

        # A span that was already taken is never selected twice.
        if span in occupied:
            continue

        clashes = any(get_overlap_status(span[0], span[1], used_lo, used_hi) for used_lo, used_hi in occupied)
        if clashes:
            continue

        chosen.append(row)
        occupied.add(span)

    return [row[1] for row in chosen]

def get_cols(query_words, vocab):
    """
    Split the vocabulary categories of `query_words` into attribute columns and entities.

    Args:
        query_words: iterable of target terms.
        vocab: DataFrame with 'target_term' and 'category' columns.

    Returns:
        (cols, entity_cats, entities) as lists:
          - cols: 'location', 'evidence', 'associate' plus every clinical
            attribute category seen for the words;
          - entity_cats: the non-clinical categories seen, with 'lf' folded into 'pf';
          - entities: words that have at least one non-clinical category.
    """
    attribute_cols = {'location', 'evidence', 'associate'}
    entity_categories = set()
    entity_words = set()

    clinical_categories = ['comparison', 'distribution', 'improved', 'location',
                           'measurement', 'morphology', 'no change', 'onset',
                           'other source', 'past hx', 'placement',
                           'severity', 'assessment limitations', 'worsened']

    for word in query_words:
        word_categories = set(vocab.loc[vocab['target_term'] == word, 'category'])
        for category in word_categories:
            if category in clinical_categories:
                attribute_cols.add(category)
                continue
            # 'lf' and 'pf' are reported under the single label 'pf'.
            entity_categories.add('pf' if category in ('lf', 'pf') else category)
            entity_words.add(word)

    return list(attribute_cols), list(entity_categories), list(entity_words)


def get_n_retrieval(query, result, vocab, devset, args):
    """
    Build the candidate-annotated JSON input for `query` and pick up to n few-shot examples.

    Candidate terms are extracted from the query with `get_words`. Few-shot
    sections are then drawn from `devset` in this order:
      1. one section per query entity;
      2. one per attribute column not yet covered;
      3. top-ranked sections until `args.n_retrieval` shots are taken.
    Every selected example removes the entities/columns it already covers.
    Devset rows are ordered by how many query entities their text contains,
    then by their rank in `result`.

    Args:
        query: query text. For section-level runs it is expected in the
            "(1) text (2) text ..." format.
        result: DataFrame whose 'section' column lists devset texts in ranked
            order (used as the secondary sort key).
        vocab: vocabulary DataFrame with 'target_term' and 'category' columns.
        devset: annotated examples. Must contain 'ent', the unit text column,
            the grouping columns used below and the attribute columns from
            `get_cols`.
        args: run configuration. Reads unit, candidate_type, candidate_usage,
            n_retrieval, multi, plus whatever `get_words`/`get_fewshot*` read.

    Returns:
        (json_vocab_input, user_history, assistant_history). Each history
        list holds at most `args.n_retrieval` entries.

    Raises:
        ValueError: if `args.candidate_type` is not a supported value.
    """
    devset = devset.copy()

    ENTITY_CATEGORY = ['pf', 'cf', 'cof', 'cof/ncd', 'oth', 'ncd', 'pf/cf', 'patient info.']

    section_key = 'section_report' if args.unit == 'section' else 'sent'
    vocab_lookup, vocab_keys = get_vocab_lookup(vocab)

    def _word_categories(word):
        # Distinct vocab categories recorded for this target term.
        return list(set(vocab[vocab['target_term'] == word]['category']))

    def _select_candidates(words, map_entity_categories):
        # Build the candidate list shown to the model for one sentence/query.
        candidate_type = args.candidate_type
        if ('gt' in candidate_type) or candidate_type in ('no_candidates', 'vocab_so'):
            return words
        if candidate_type == 'vocab_s':
            # Keep only words whose every category is an entity category.
            return [w for w in words if all(cat in ENTITY_CATEGORY for cat in _word_categories(w))]
        if candidate_type == 'vocab_ent_rcg':
            pairs = []
            for w in words:
                cats = _word_categories(w)
                if map_entity_categories:
                    # Collapse every entity category into the generic label 'entity'.
                    cats = ['entity' if cat in ENTITY_CATEGORY else cat for cat in cats]
                pairs.append((w, cats))
            return pairs
        raise ValueError(f"Unsupported candidate_type: {candidate_type!r}")

    def _subsample(candidates):
        # Keep about args.candidate_usage of the candidates, evenly spaced.
        if args.candidate_usage == 1:
            return candidates
        target_length = int(len(candidates) * args.candidate_usage)
        # Clamp the stride to >= 1: it is 0 for an empty list or usage > 1,
        # and a zero slice step raises ValueError.
        step = max(len(candidates) // max(target_length, 1), 1)
        return candidates[::step][:target_length]

    ###### Build the JSON input for the new user query ######
    # Candidates are extracted from the vocabulary.
    json_vocab_input = {'report_section': []}
    query_words = []

    if section_key == 'sent':
        idx = re.findall(r'\((\d+)\)', query, re.DOTALL)
        # dict.fromkeys de-duplicates while keeping get_words' order, so the
        # candidate order is reproducible across runs (set order is not).
        query_words = list(dict.fromkeys(get_words(query, vocab, vocab_lookup, vocab_keys, args)))
        json_vocab_input['report_section'].append({
            'sent_idx': idx,
            'sentence': query,
            # NOTE: unlike the section branch, categories are not collapsed to
            # 'entity' and candidate_usage is not applied here.
            'candidates': _select_candidates(query_words, map_entity_categories=False),
        })
    else:
        # Extract numbered sentences from the query (format: "(1) text (2) text...")
        sentences = re.findall(r'\(\d+\)\s*(.*?)(?=\s*\(\d+\)|$)', query, re.DOTALL)
        sentences = [s.strip() for s in sentences if s.strip()]

        for idx, sentence in enumerate(sentences):
            all_words = list(dict.fromkeys(get_words(sentence, vocab, vocab_lookup, vocab_keys, args)))
            query_words.extend(all_words)
            candidates = _subsample(_select_candidates(all_words, map_entity_categories=True))
            json_vocab_input['report_section'].append({
                'sent_idx': idx + 1,
                'sentence': sentence,
                'candidates': candidates,
            })

    query_cols, _, query_entities = get_cols(list(set(query_words)), vocab)
    # Sorted lists give a deterministic processing order.
    query_cols = sorted(query_cols)
    query_entities = sorted(query_entities)

    # Order devset rows by entity coverage, then by their rank in `result`.
    order_col = f'{section_key}_order'
    rank = {s: i for i, s in enumerate(result['section'].unique().tolist())}
    devset[order_col] = devset[section_key].map(rank)

    if query_entities:
        def count_entities(text):
            # Number of query entities appearing in the text; non-strings (e.g. NaN) count 0.
            if not isinstance(text, str):
                return 0
            return sum(1 for entity in query_entities if entity in text)

        devset['entity_count'] = devset[section_key].apply(count_entities)
        devset = devset.sort_values(['entity_count', order_col], ascending=[False, True])
    else:
        devset = devset.sort_values(order_col)

    # Drop the temporary sort columns.
    devset = devset.drop(['entity_count', order_col], axis=1, errors='ignore')

    # Columns that identify the group of rows belonging to one example.
    if section_key == 'sent':
        group_cols = ['subject_id', 'study_id', 'sequence', 'sent', 'section']
    else:
        group_cols = ['subject_id', 'study_id', 'section', 'section_report']

    seen_sections = set()
    selected_frames = []
    shot_count = 0

    # Entities first, then attribute columns, then fill up to n_retrieval shots.
    while query_entities or query_cols or shot_count < args.n_retrieval:
        is_fallback = False
        if query_entities:
            entity, query_entities = query_entities[0], query_entities[1:]
            section_list = list(devset[devset['ent'] == entity][section_key].unique())
        elif query_cols:
            col, query_cols = query_cols[0], query_cols[1:]
            section_list = list(devset[devset[col].notna()][section_key].unique())
        else:
            is_fallback = True
            section_list = list(devset[section_key].unique())

        # First not-yet-used section in the sorted devset order.
        selected_section = next((s for s in section_list if s not in seen_sections), None)
        if selected_section is None:
            if is_fallback:
                # Every devset section is already used (or devset is empty):
                # nothing more can be added, so stop instead of looping forever.
                break
            continue
        seen_sections.add(selected_section)

        first_group_keys = devset[devset[section_key] == selected_section][group_cols].iloc[0]
        mask = None
        for key_col in group_cols:
            cond = devset[key_col] == first_group_keys[key_col]
            mask = cond if mask is None else mask & cond
        devset_col = devset[mask]

        # Remove whatever this example already covers from the to-do lists.
        filled_cols = {c for c in query_cols if not devset_col[c].isna().all()}
        query_cols = sorted(set(query_cols) - filled_cols)
        query_entities = sorted(set(query_entities) - set(devset_col['ent'].unique()))

        selected_frames.append(devset_col)
        shot_count += 1

    if selected_frames:
        diverse_result = pd.concat(selected_frames)
        unique_sections = list(diverse_result[section_key].unique())
    else:
        # Nothing could be selected: no few-shot examples.
        unique_sections = []

    user_history = []
    assistant_history = []
    for section in unique_sections[:args.n_retrieval]:
        section_rows = diverse_result[diverse_result[section_key] == section]
        if args.multi:
            user, assistant = get_fewshot_multi_turn(section, section_rows, args, vocab, vocab_lookup, vocab_keys)
        else:
            user, assistant = get_fewshot(section, section_rows, args, vocab, vocab_lookup, vocab_keys)
        user_history.append(user)
        assistant_history.append(assistant)

    return json_vocab_input, user_history, assistant_history

def get_dynamic_retrieval(query, result, vocab, devset, args):
    """
    Build the candidate-annotated JSON input for `query` and keep every few-shot
    section needed to cover the query's entities and attribute columns.

    Selection order:
      1. one section per query entity;
      2. one per attribute column not yet covered;
      3. top-ranked sections until at least `args.n_retrieval` shots exist.
    Unlike a fixed-size retrieval, all selected sections are returned.

    Args:
        query: query text. For section-level runs it is expected in the
            "(1) text (2) text ..." format.
        result: DataFrame whose 'section' column lists devset texts in ranked
            order (used as the secondary sort key).
        vocab: vocabulary DataFrame with 'target_term' and 'category' columns.
        devset: annotated examples. Must contain 'ent', the unit text column,
            'subject_id', 'study_id', 'section', 'section_report' and the
            attribute columns from `get_cols`.
        args: run configuration. Reads unit, candidate_type, candidate_usage,
            n_retrieval, plus whatever `get_words`/`get_fewshot` read.

    Returns:
        (json_vocab_input, user_history, assistant_history).

    Raises:
        ValueError: if `args.candidate_type` is not a supported value (section unit).
    """
    devset = devset.copy()

    ENTITY_CATEGORY = ['pf', 'cf', 'cof', 'cof/ncd', 'oth', 'ncd', 'pf/cf', 'patient info.']

    section_key = 'section_report' if args.unit == 'section' else 'sent'
    vocab_lookup, vocab_keys = get_vocab_lookup(vocab)

    def _word_categories(word):
        # Distinct vocab categories recorded for this target term.
        return list(set(vocab[vocab['target_term'] == word]['category']))

    def _select_candidates(words):
        # Build the candidate list shown to the model for one sentence.
        candidate_type = args.candidate_type
        if ('gt' in candidate_type) or candidate_type in ('no_candidates', 'vocab_so'):
            return words
        if candidate_type == 'vocab_s':
            # Keep only words whose every category is an entity category.
            return [w for w in words if all(cat in ENTITY_CATEGORY for cat in _word_categories(w))]
        if candidate_type == 'vocab_ent_rcg':
            return [(w, _word_categories(w)) for w in words]
        raise ValueError(f"Unsupported candidate_type: {candidate_type!r}")

    def _subsample(candidates):
        # Keep about args.candidate_usage of the candidates, evenly spaced.
        if args.candidate_usage == 1:
            return candidates
        target_length = int(len(candidates) * args.candidate_usage)
        # Clamp the stride to >= 1: it is 0 for an empty list or usage > 1,
        # and a zero slice step raises ValueError.
        step = max(len(candidates) // max(target_length, 1), 1)
        return candidates[::step][:target_length]

    ###### Build the JSON input for the new user query ######
    # Candidates are extracted from the vocabulary.
    json_vocab_input = {'report_section': []}
    query_words = []

    if section_key == 'sent':
        # NOTE: sentence-level queries only contribute retrieval words here;
        # no 'report_section' entry is produced for them.
        query_words = get_words(query, vocab, vocab_lookup, vocab_keys, args)
    else:
        # Extract numbered sentences from the query (format: "(1) text (2) text...")
        sentences = re.findall(r'\(\d+\)\s*(.*?)(?=\s*\(\d+\)|$)', query, re.DOTALL)
        sentences = [s.strip() for s in sentences if s.strip()]

        for idx, sentence in enumerate(sentences):
            # dict.fromkeys de-duplicates while keeping get_words' order, so the
            # candidate order is reproducible across runs (set order is not).
            all_words = list(dict.fromkeys(get_words(sentence, vocab, vocab_lookup, vocab_keys, args)))
            query_words.extend(all_words)
            candidates = _subsample(_select_candidates(all_words))
            json_vocab_input['report_section'].append({
                'sent_idx': idx + 1,
                'sentence': sentence,
                'candidates': candidates,
            })

    query_cols, _, query_entities = get_cols(list(set(query_words)), vocab)
    # Sorted lists give a deterministic processing order.
    query_cols = sorted(query_cols)
    query_entities = sorted(query_entities)

    # Order devset rows by entity coverage, then by their rank in `result`.
    order_col = f'{section_key}_order'
    rank = {s: i for i, s in enumerate(result['section'].unique().tolist())}
    devset[order_col] = devset[section_key].map(rank)

    if query_entities:
        def count_entities(text):
            # Number of query entities appearing in the text; non-strings (e.g. NaN) count 0.
            if not isinstance(text, str):
                return 0
            return sum(1 for entity in query_entities if entity in text)

        devset['entity_count'] = devset[section_key].apply(count_entities)
        devset = devset.sort_values(['entity_count', order_col], ascending=[False, True])
    else:
        devset = devset.sort_values(order_col)

    # Drop the temporary sort columns.
    devset = devset.drop(['entity_count', order_col], axis=1, errors='ignore')

    # Columns that identify the group of rows belonging to one example.
    group_cols = ['subject_id', 'study_id', 'section', 'section_report']

    seen_sections = set()
    selected_frames = []
    shot_count = 0

    # Entities first, then attribute columns, then fill up to n_retrieval shots.
    while query_entities or query_cols or shot_count < args.n_retrieval:
        is_fallback = False
        if query_entities:
            entity, query_entities = query_entities[0], query_entities[1:]
            section_list = list(devset[devset['ent'] == entity][section_key].unique())
        elif query_cols:
            col, query_cols = query_cols[0], query_cols[1:]
            section_list = list(devset[devset[col].notna()][section_key].unique())
        else:
            is_fallback = True
            section_list = list(devset[section_key].unique())

        # First not-yet-used section in the sorted devset order.
        selected_section = next((s for s in section_list if s not in seen_sections), None)
        if selected_section is None:
            if is_fallback:
                # Every devset section is already used (or devset is empty):
                # nothing more can be added, so stop instead of looping forever.
                break
            continue
        seen_sections.add(selected_section)

        first_group_keys = devset[devset[section_key] == selected_section][group_cols].iloc[0]
        mask = None
        for key_col in group_cols:
            cond = devset[key_col] == first_group_keys[key_col]
            mask = cond if mask is None else mask & cond
        devset_col = devset[mask]

        # Remove whatever this example already covers from the to-do lists.
        filled_cols = {c for c in query_cols if not devset_col[c].isna().all()}
        query_cols = sorted(set(query_cols) - filled_cols)
        query_entities = sorted(set(query_entities) - set(devset_col['ent'].unique()))

        selected_frames.append(devset_col)
        shot_count += 1

    if selected_frames:
        diverse_result = pd.concat(selected_frames)
        unique_sections = list(diverse_result[section_key].unique())
    else:
        # Nothing could be selected: no few-shot examples.
        unique_sections = []

    user_history = []
    assistant_history = []
    for section in unique_sections:
        user, assistant = get_fewshot(section, diverse_result[diverse_result[section_key] == section], args, vocab, vocab_lookup, vocab_keys)
        user_history.append(user)
        assistant_history.append(assistant)

    return json_vocab_input, user_history, assistant_history

def retreive_query_related_fewshot(devset, query, query_subject_id, args):
    """
    Rank devset texts against `query` with BM25 and build few-shot examples from them.

    Args:
        devset (pd.DataFrame): annotated examples. Must contain 'subject_id'
            and the text column for the unit ('sent' or 'section_report').
        query (str): query text to search for.
        query_subject_id: subject id of the query. Rows of this subject are
            excluded unless `args.holdout_devset` is set.
        args: run configuration. Reads unit, holdout_devset, vocab_path and
            dynamic_retrieval, plus the fields read by the retrieval helpers.

    Returns:
        tuple: (json_vocab_input, user_history, assistant_history) from
        `get_dynamic_retrieval` (if `args.dynamic_retrieval`) or `get_n_retrieval`.

    Raises:
        ValueError: if `args.unit` is not 'sent' or 'section', or if no corpus
            text is left to rank.
    """
    corpus_columns = {'sent': 'sent', 'section': 'section_report'}
    if args.unit not in corpus_columns:
        raise ValueError(f"Unsupported unit {args.unit!r}; expected 'sent' or 'section'")
    corpus_col = corpus_columns[args.unit]

    # Keep the query's own subject out of the few-shot pool unless a held-out devset is used.
    if not args.holdout_devset:
        devset = devset[devset['subject_id'] != query_subject_id]
    corpus = devset[corpus_col].dropna().tolist()

    if not corpus:
        # BM25Okapi divides by the corpus size and would fail with ZeroDivisionError.
        raise ValueError("No devset text available for BM25 retrieval")

    # BM25 over whitespace tokens
    tokenized_corpus = [doc.split() for doc in corpus]
    bm25 = BM25Okapi(tokenized_corpus)
    corpus_scores = bm25.get_scores(query.split())

    # Highest-scoring texts first
    corpus_ranked = pd.DataFrame({'section': corpus, 'score': corpus_scores})
    corpus_ranked = corpus_ranked.sort_values(by='score', ascending=False)

    vocab = pd.read_csv(args.vocab_path)

    if args.dynamic_retrieval:
        return get_dynamic_retrieval(query, corpus_ranked, vocab, devset, args)
    return get_n_retrieval(query, corpus_ranked, vocab, devset, args)

def parse_eval_format(idx, text):
    """
    Normalize one model prediction into ``(idx, {relation: triplets})``.

    Accepted inputs:
      * a dict with at least one key naming a relation in RELATIONS (any case):
        the known relations are copied with lower-cased keys;
      * any other dict: converted with ``convert_to_sr_structure``;
      * a JSON string shaped like
        {"entities": [{"name", "sent_idx", "ent_idx",
                       "relations": [{"relation", "value", "obj_ent_idx"}]}]}.
        Each known relation becomes a tuple
        (entity_name, relation, value, sent_idx, ent_idx, obj_ent_idx).

    Returns:
        ``(idx, pred_triplet)`` on success, or ``None`` when a dict cannot be
        converted or a string is not JSON of the expected shape. Callers must
        handle the ``None`` case.

    Raises:
        ValueError: for list inputs and any other unsupported type.
    """
    pred_triplet = defaultdict(list)
    if isinstance(text, dict):
        # Already keyed by relation names: only normalize key case
        if any(key.lower() in RELATIONS for key in text.keys()):
            for relation in text:
                relation_lower = relation.lower()
                if relation_lower in RELATIONS:
                    pred_triplet[relation_lower] = text[relation]
            return idx, pred_triplet
        # Otherwise try converting from the structured-output format
        try:
            return idx, convert_to_sr_structure(text)
        except Exception as e:
            print(f"Error converting structured output: {e}")
            print("type(text)", type(text))
            print("text", text)
        return None

    elif isinstance(text, list):
        raise ValueError("여기도 있다면 수정 해야함. Invalid text format")

    elif isinstance(text, str):
        # Only attempt JSON parsing when the string looks like a JSON object
        if "{" not in text or "}" not in text:
            return None
        try:
            parsed_json = json.loads(text)

            for entity in parsed_json["entities"]:
                entity_name = entity["name"].lower().strip()
                sent_idx = entity.get("sent_idx", None)
                ent_idx = entity.get("ent_idx", None)

                for relation in entity["relations"]:
                    relation_name = relation["relation"].lower()
                    relation_value = relation["value"].lower().strip()
                    obj_ent_idx = relation.get("obj_ent_idx", None)

                    if relation_name in RELATIONS:
                        pred_triplet[relation_name].append((
                            entity_name, relation_name, relation_value,
                            sent_idx, ent_idx, obj_ent_idx
                        ))
        except Exception as e:
            # Malformed or unexpectedly shaped output: report and drop it.
            # (Unlike a bare except, this no longer swallows KeyboardInterrupt/SystemExit.)
            print(f"Error parsing JSON output for idx {idx}: {e!r}")
            return None
        return idx, pred_triplet

    else:
        print("type(text)", type(text))
        print("text", text)
        raise ValueError("Invalid text format")
    
def generate_graph(eval_path, args):
    """
    Load evaluation result JSON files from `eval_path` and render their figures.

    In 'rexval' mode, ``SRO_result_by_model.json`` / ``SR_result_by_model.json``
    are loaded and one figure is drawn per model name. The names are taken
    from the triplet results if present, otherwise from the subject results.
    In any other mode, ``SRO_result.json`` / ``SR_result.json`` and their
    ``_no_sent_idx`` variants are each drawn as one standard figure.
    Missing files are skipped.

    Args:
        eval_path: directory containing the result files.
        args: run configuration; ``args.mode`` selects the layout.
    """
    def _load_results(suffix):
        # Map 'triplets'/'subj' to parsed JSON for each existing file; the
        # insertion order (triplets first) is preserved for the plotters.
        loaded = {}
        for key, prefix in (('triplets', 'SRO'), ('subj', 'SR')):
            path = f'{eval_path}/{prefix}_result{suffix}.json'
            if os.path.exists(path):
                with open(path) as f:
                    loaded[key] = json.load(f)
        return loaded

    if args.mode == 'rexval':
        results = _load_results('_by_model')
        # One visualization per model, named after the triplet results when available
        if 'triplets' in results or 'subj' in results:
            source = results['triplets'] if 'triplets' in results else results['subj']
            for model_name in list(source.keys()):
                create_model_visualization(eval_path, model_name, results, args)
    else:
        create_standard_visualization(eval_path, _load_results(''), args)
        create_standard_visualization(eval_path, _load_results('_no_sent_idx'), args, no_sent_idx=True)


def create_model_visualization(eval_path, model_name, results, args):
    """
    Create the rexval-mode evaluation figure for one model.

    Draws a 2x2 grid. The top row is a bar plot and a line plot of the
    per-relation triplet F1. The bottom row shows the same for subject F1.
    Each panel also shows the micro-averaged ("Overall") F1. The figure is
    saved to ``{eval_path}/figures/{model_name}_evaluation_results.png``.

    Args:
        eval_path: output directory; a ``figures`` subdirectory is created.
        model_name: key looked up in ``results['triplets']`` / ``results['subj']``.
        results: dict mapping 'triplets'/'subj' to {model_name: {relation: stats}}.
            Each stats dict must hold 'f1' and may hold 'tp', 'tn', 'fp', and
            'all' (subject results) or 'miss' (triplet results).
        args: run configuration; ``args.output_format`` is used in legend labels.
    """
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(20, 16))
    x = np.arange(len(RELATIONS) + 1)  # one slot per relation plus the trailing "Average" slot
    width = 0.2
    n_relations = len(RELATIONS)

    f1_scores = {}  # key -> F1 value per x slot
    metrics = {}    # key -> {'tp','tn','fp','fn'} dict per x slot

    for key in ['triplets', 'subj']:
        if key not in results or model_name not in results[key]:
            continue

        result = results[key][model_name]
        scores = []
        per_relation = []
        for r in RELATIONS:
            if r in result:
                stats = result[r]
                scores.append(stats['f1'])
                per_relation.append({
                    'tp': stats.get('tp', 0),
                    'tn': stats.get('tn', 0),
                    'fp': stats.get('fp', 0),
                    # Subject results store the gold count as 'all' (fn = all - tp);
                    # triplet results store the miss count directly.
                    'fn': stats.get('all', 0) - stats.get('tp', 0) if key == 'subj' else stats.get('miss', 0),
                })
            else:
                scores.append(0.0)
                per_relation.append({'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0})

        # "Average" slot: integer mean of each count over all relations
        avg_metrics = {m: sum(d[m] for d in per_relation) // n_relations for m in ('tp', 'tn', 'fp', 'fn')}
        metrics[key] = per_relation + [avg_metrics]

        # Average of the non-zero F1 scores. Use 0.0 when every score is zero:
        # np.mean([]) would give NaN and a RuntimeWarning.
        nonzero_scores = [s for s in scores if s > 0]
        scores.append(np.mean(nonzero_scores) if nonzero_scores else 0.0)
        f1_scores[key] = scores

        # Micro-averaged F1 from the summed per-relation counts
        total_tp = sum(d['tp'] for d in per_relation)
        total_fp = sum(d['fp'] for d in per_relation)
        total_fn = sum(d['fn'] for d in per_relation)

        precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
        recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
        overall_f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        # The same overall value is drawn at every slot
        f1_scores[f'{key}_overall'] = [overall_f1] * (n_relations + 1)
        metrics[f'{key}_overall'] = [{
            'tp': total_tp,
            'fp': total_fp,
            'fn': total_fn,
            'tn': 0  # Not used for F1
        }] * (n_relations + 1)

    # (axis, plot type, result key, title). The result key is explicit so that a
    # model name containing "Triplets" cannot flip a subject panel.
    plot_configs = [
        (ax1, 'bar', 'triplets', f'{model_name} - Triplets F1 Scores (Bar Plot)'),
        (ax2, 'line', 'triplets', f'{model_name} - Triplets F1 Scores (Line Plot)'),
        (ax3, 'bar', 'subj', f'{model_name} - Subject F1 Scores (Bar Plot)'),
        (ax4, 'line', 'subj', f'{model_name} - Subject F1 Scores (Line Plot)')
    ]

    for ax, plot_type, result_type, title in plot_configs:
        has_data = result_type in f1_scores
        if has_data:
            scores = f1_scores[result_type]
            counts = metrics[result_type]
            overall_scores = f1_scores[f'{result_type}_overall']
            overall_counts = metrics[f'{result_type}_overall']

            if plot_type == 'bar':
                # Per-relation F1 bars annotated with their counts
                bars = ax.bar(x - width, scores, width, label=f'{args.output_format} (Per Relation)', alpha=0.8)
                for bar, metric in zip(bars, counts):
                    height = bar.get_height()
                    ax.text(bar.get_x() + bar.get_width()/2., height + 0.02,
                            f'TP:{metric["tp"]}\nFP:{metric["fp"]}\nFN:{metric["fn"]}',
                            ha='center', va='bottom', fontsize=8, rotation=0)

                # Overall F1 bars
                ax.bar(x, overall_scores, width, label=f'{args.output_format} (Overall)', alpha=0.8)
            else:  # line plot
                # Per-relation F1 line annotated with counts at each point
                ax.plot(x, scores, marker='s', linestyle='-', label=f'{args.output_format} (Per Relation)', alpha=0.8)
                for xi, yi, metric in zip(x, scores, counts):
                    ax.annotate(f'TP:{metric["tp"]}\nFP:{metric["fp"]}\nFN:{metric["fn"]}',
                                xy=(xi, yi), xytext=(0, 10), textcoords='offset points',
                                ha='center', va='bottom', fontsize=7, bbox=dict(boxstyle='round,pad=0.3', fc='white', alpha=0.7))

                # Overall F1 line, annotated once in the middle to avoid clutter
                ax.plot(x, overall_scores, marker='d', linestyle='--',
                        label=f'{args.output_format} (Overall)', alpha=0.8)
                mid = len(x) // 2
                metric = overall_counts[mid]
                ax.annotate(f'Overall: TP:{metric["tp"]}\nFP:{metric["fp"]}\nFN:{metric["fn"]}',
                            xy=(x[mid], overall_scores[mid]), xytext=(0, -30), textcoords='offset points',
                            ha='center', va='top', fontsize=8,
                            bbox=dict(boxstyle='round,pad=0.3', fc='yellow', alpha=0.7))

                ax.grid(True)

        ax.set_ylabel('F1 Score')
        ax.set_title(title)
        ax.set_xticks(x)
        ax.set_xticklabels(RELATIONS + ['Average'], rotation=45, ha='right')
        if has_data:
            # Only axes with labelled artists get a legend (avoids matplotlib's empty-legend warning)
            ax.legend()

    plt.tight_layout()

    # Save the visualization
    os.makedirs(f'{eval_path}/figures', exist_ok=True)
    plt.savefig(f'{eval_path}/figures/{model_name}_evaluation_results.png', dpi=300, bbox_inches='tight')
    plt.close(fig)

    print(f"Evaluation visualization for {model_name} saved to {eval_path}/figures/{model_name}_evaluation_results.png")

def _std_viz_metric_label(metric):
    """Return the three-line ``TP/FP/FN`` label used to annotate one data point."""
    return f'TP:{metric["tp"]}\nFP:{metric["fp"]}\nFN:{metric["fn"]}'


def _std_viz_label_bars(ax, bars, metric_list):
    """Write each bar's TP/FP/FN counts just above the top of the bar."""
    for bar, metric in zip(bars, metric_list):
        ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + 0.02,
                _std_viz_metric_label(metric),
                ha='center', va='bottom', fontsize=8, rotation=0)


def _std_viz_annotate_points(ax, xs, ys, metric_list):
    """Attach a boxed TP/FP/FN annotation 10 points above each line-plot marker."""
    for xi, yi, metric in zip(xs, ys, metric_list):
        ax.annotate(_std_viz_metric_label(metric),
                    xy=(xi, yi), xytext=(0, 10), textcoords='offset points',
                    ha='center', va='bottom', fontsize=7,
                    bbox=dict(boxstyle='round,pad=0.3', fc='white', alpha=0.7))


def create_standard_visualization(eval_path, results, args, no_sent_idx=False):
    """Create standard visualization for non-rexval mode.

    Draws a 2x2 grid (bar and line plots for ``triplets`` and ``subj``) of
    per-relation F1 scores, with a trailing "Average" slot, plus an "overall"
    series computed from TP/FP/FN summed across relations. Optional
    ``<type>_with_gpt`` / ``<type>_with_gpt2`` series are drawn when present in
    ``results``.

    Args:
        eval_path: directory under which ``figures/`` is created.
        results: mapping result-type key -> {relation: stats dict}. Each stats
            dict needs ``'f1'``; ``'tp'``, ``'tn'``, ``'fp'`` and either
            ``'all'`` (for key ``'subj'``) or ``'miss'`` are read with a
            default of 0.
        args: namespace providing ``output_format`` (used in legend labels).
        no_sent_idx: selects the ``_no_sent_idx`` output filename.

    Side effects:
        Writes a PNG under ``{eval_path}/figures`` and prints its path.
    """
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(20, 16))
    x = np.arange(len(RELATIONS) + 1)  # one slot per relation + trailing 'Average'
    width = 0.2
    n_rel = len(RELATIONS)

    f1_scores = {}
    metrics = {}  # per key: list of {'tp','tn','fp','fn'} dicts aligned with x

    for key, result in results.items():
        scores = []
        per_relation = []

        for r in RELATIONS:
            if r in result:
                stats = result[r]
                scores.append(stats['f1'])
                if key == 'subj':
                    # For 'subj' results FN is derived as all - tp.
                    fn = stats.get('all', 0) - stats.get('tp', 0)
                else:
                    fn = stats.get('miss', 0)
                per_relation.append({
                    'tp': stats.get('tp', 0),
                    'tn': stats.get('tn', 0),
                    'fp': stats.get('fp', 0),
                    'fn': fn,
                })
            else:
                scores.append(0.0)
                per_relation.append({'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0})

        # Integer-averaged counts shown above the 'Average' slot.
        avg_metrics = {k: sum(m[k] for m in per_relation) // n_rel
                       for k in ('tp', 'tn', 'fp', 'fn')}
        metrics[key] = per_relation + [avg_metrics]

        # Average of non-zero scores; 0.0 (not NaN) when every score is zero.
        nonzero = [s for s in scores if s > 0]
        scores.append(np.mean(nonzero) if nonzero else 0.0)
        f1_scores[key] = scores

        # Micro-averaged F1 from TP/FP/FN summed over all relations.
        total_tp = sum(m['tp'] for m in per_relation)
        total_fp = sum(m['fp'] for m in per_relation)
        total_fn = sum(m['fn'] for m in per_relation)

        precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
        recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
        overall_f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        # Constant series so it plots as a flat reference across every slot.
        f1_scores[f'{key}_overall'] = [overall_f1] * (n_rel + 1)
        metrics[f'{key}_overall'] = [{
            'tp': total_tp,
            'fp': total_fp,
            'fn': total_fn,
            'tn': 0  # Not typically used for F1
        }] * (n_rel + 1)

    plot_configs = [
        (ax1, 'bar', 'triplets F1 Scores (Bar Plot)'),
        (ax2, 'line', 'triplets F1 Scores (Line Plot)'),
        (ax3, 'bar', 'Subject F1 Scores (Bar Plot)'),
        (ax4, 'line', 'Subject F1 Scores (Line Plot)')
    ]

    for ax, plot_type, title in plot_configs:
        result_type = 'triplets' if 'triplets' in title else 'subj'

        if result_type in f1_scores:
            overall_key = f'{result_type}_overall'
            gpt_key = f'{result_type}_with_gpt'
            gpt2_key = f'{result_type}_with_gpt2'

            if plot_type == 'bar':
                bars = ax.bar(x - width, f1_scores[result_type], width,
                              label=f'{args.output_format} (Per Relation)', alpha=0.8)
                _std_viz_label_bars(ax, bars, metrics[result_type])

                if overall_key in f1_scores:
                    # Overall bars carry no per-bar text.
                    ax.bar(x, f1_scores[overall_key], width,
                           label=f'{args.output_format} (Overall)', alpha=0.8)

                if gpt_key in f1_scores:
                    bars = ax.bar(x + width, f1_scores[gpt_key], width,
                                  label=f'{args.output_format} with GPT', alpha=0.8)
                    if gpt_key in metrics:
                        _std_viz_label_bars(ax, bars, metrics[gpt_key])

                if gpt2_key in f1_scores:
                    bars = ax.bar(x + width * 2, f1_scores[gpt2_key], width,
                                  label=f'{args.output_format} with GPT2', alpha=0.8)
                    if gpt2_key in metrics:
                        _std_viz_label_bars(ax, bars, metrics[gpt2_key])
            else:  # line plot
                ax.plot(x, f1_scores[result_type], marker='s', linestyle='-',
                        label=f'{args.output_format} (Per Relation)', alpha=0.8)
                _std_viz_annotate_points(ax, x, f1_scores[result_type], metrics[result_type])

                if overall_key in f1_scores:
                    overall_ys = f1_scores[overall_key]
                    ax.plot(x, overall_ys, marker='d', linestyle='--',
                            label=f'{args.output_format} (Overall)', alpha=0.8)
                    # Annotate only the middle point to avoid clutter.
                    mid = len(x) // 2
                    metric = metrics[overall_key][mid]
                    ax.annotate(f'Overall: TP:{metric["tp"]}\nFP:{metric["fp"]}\nFN:{metric["fn"]}',
                                xy=(x[mid], overall_ys[mid]), xytext=(0, -30), textcoords='offset points',
                                ha='center', va='top', fontsize=8,
                                bbox=dict(boxstyle='round,pad=0.3', fc='yellow', alpha=0.7))

                if gpt_key in f1_scores:
                    ax.plot(x, f1_scores[gpt_key], marker='o', linestyle='-',
                            label=f'{args.output_format} with GPT', alpha=0.8)
                    if gpt_key in metrics:
                        _std_viz_annotate_points(ax, x, f1_scores[gpt_key], metrics[gpt_key])

                if gpt2_key in f1_scores:
                    ax.plot(x, f1_scores[gpt2_key], marker='^', linestyle='-',
                            label=f'{args.output_format} with GPT2', alpha=0.8)
                    if gpt2_key in metrics:
                        _std_viz_annotate_points(ax, x, f1_scores[gpt2_key], metrics[gpt2_key])

                ax.grid(True)

        ax.set_ylabel('F1 Score')
        ax.set_title(title)
        ax.set_xticks(x)
        ax.set_xticklabels(RELATIONS + ['Average'], rotation=45, ha='right')
        ax.legend()

    fig.tight_layout()

    # Save the visualization
    figures_dir = f'{eval_path}/figures'
    os.makedirs(figures_dir, exist_ok=True)
    file_name = 'evaluation_results_no_sent_idx.png' if no_sent_idx else 'evaluation_results.png'
    out_path = f'{figures_dir}/{file_name}'
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Evaluation visualization saved to {out_path}")

def create_fp(eval_path, gold_file_path):
    """Export every false-positive triplet in ``{eval_path}/all.json`` to CSV.

    Each record in ``all.json`` is expected to carry ``custom_id``, ``sentence``
    and ``wrong_triplet_list``; ``study_id`` is looked up in the gold file by
    ``custom_id``. One CSV row is written per wrong triplet to
    ``{eval_path}/false_positives.csv`` (header is written even when empty).

    Raises:
        KeyError: if a record's ``custom_id`` is missing from the gold file.
    """
    sro_review_path = os.path.join(eval_path, 'all.json')
    # Close both files promptly instead of leaking the handles.
    with open(gold_file_path, encoding='utf-8') as fh:
        gold_file = json.load(fh)
    with open(sro_review_path, encoding='utf-8') as fh:
        json_data = json.load(fh)

    fp_data = [
        {
            'custom_id': data['custom_id'],
            'study_id': gold_file[data['custom_id']]['study_id'],
            'sentence': data['sentence'],
            'wrong_triplet': triplet,
        }
        for data in json_data
        for triplet in data['wrong_triplet_list']
    ]

    # Explicit columns keep a proper header when there are no false positives.
    fp_df = pd.DataFrame(fp_data, columns=['custom_id', 'study_id', 'sentence', 'wrong_triplet'])
    csv_path = os.path.join(eval_path, 'false_positives.csv')
    fp_df.to_csv(csv_path, index=False)
    print(f"False positives saved to {csv_path}")
    

def create_fn(eval_path, gold_file_path):
    """Export every missed (false-negative) triplet in ``{eval_path}/all.json`` to CSV.

    Each record in ``all.json`` is expected to carry ``custom_id``, ``sentence``
    and ``miss_triplet_list``; ``study_id`` is looked up in the gold file by
    ``custom_id``. One CSV row is written per missed triplet to
    ``{eval_path}/false_negatives.csv`` (header is written even when empty).

    Raises:
        KeyError: if a record's ``custom_id`` is missing from the gold file.
    """
    sro_review_path = os.path.join(eval_path, 'all.json')
    # Close both files promptly instead of leaking the handles.
    with open(gold_file_path, encoding='utf-8') as fh:
        gold_file = json.load(fh)
    with open(sro_review_path, encoding='utf-8') as fh:
        json_data = json.load(fh)

    fn_data = [
        {
            'custom_id': data['custom_id'],
            'study_id': gold_file[data['custom_id']]['study_id'],
            'sentence': data['sentence'],
            'miss_triplet': triplet,
        }
        for data in json_data
        for triplet in data['miss_triplet_list']
    ]

    # Explicit columns keep a proper header when there are no misses.
    fn_df = pd.DataFrame(fn_data, columns=['custom_id', 'study_id', 'sentence', 'miss_triplet'])
    csv_path = os.path.join(eval_path, 'false_negatives.csv')
    fn_df.to_csv(csv_path, index=False)
    print(f"False negatives saved to {csv_path}")
    
# test reports
def extract_triplets(row):
    """Extract triplets from a row of grouped data.

    The row must provide ``ent``, ``ent_idx``, ``sent_idx`` and one field per
    name in ``RELATIONS``; a relation is skipped when ``str(value) == 'nan'``.

    Each triplet is a 6-tuple
    ``(ent, relation, object, sent_idx, int(ent_idx), obj_ent_idx)``.
    For ``evidence``/``associate`` the field is a ``', '``-separated list of
    ``value, obj_ent_idxN`` pairs, and ``obj_ent_idx`` is the integer ``N``;
    for every other relation the whole field is the object and
    ``obj_ent_idx`` is ``None``.

    Returns:
        (triplet_list, same_triplet_list, sent_idx) where ``same_triplet_list``
        wraps each triplet in its own single-element list.

    Raises:
        ValueError: if an ``evidence``/``associate`` value is not followed by an
            ``obj_ent_idx`` token.
    """
    triplet_list = []
    same_triplet_list = []
    sent_idx = row['sent_idx']
    ent = row['ent']

    for relation in RELATIONS:
        value = row[relation]
        if str(value) == 'nan':
            continue

        if relation in ('evidence', 'associate'):
            parts = value.split(', ')
            for i in range(0, len(parts), 2):
                obj = parts[i].strip()
                if i + 1 >= len(parts) or 'obj_ent_idx' not in parts[i + 1]:
                    print("row", row)
                    raise ValueError(f"No obj_ent_idx found for value: {obj}")
                obj_ent_idx = int(parts[i + 1].replace('obj_ent_idx', ''))
                # int() matches the non-linked branch so ent_idx has one type.
                triplet = (ent, relation, obj, sent_idx, int(row['ent_idx']), obj_ent_idx)
                same_triplet_list.append([triplet])
                triplet_list.append(triplet)
        else:
            triplet = (ent, relation, value, sent_idx, int(row['ent_idx']), None)
            same_triplet_list.append([triplet])
            triplet_list.append(triplet)

    return triplet_list, same_triplet_list, sent_idx

def create_report_level_data(df, args):
    """Create report level data entry.

    Concatenates the first ``report`` value of the ``hist``, ``find`` and
    ``impr`` sections (each present section followed by a newline, except
    ``impr``) and gathers triplets from every row of those sections via
    ``extract_triplets``.

    Args:
        df: rows of one study; must have ``section``, ``report``,
            ``subject_id``, ``study_id`` and the columns ``extract_triplets``
            reads.
        args: namespace providing ``mode`` (stored as ``data_from``).

    Returns:
        dict with the passage, the relation list, and flat triplet lists.
    """
    section_list = ['hist', 'find', 'impr']

    hist = df[df['section'] == 'hist']
    find = df[df['section'] == 'find']
    impr = df[df['section'] == 'impr']

    report = ""
    if not hist.empty:
        report += f"{hist['report'].iloc[0]}\n"
    if not find.empty:
        report += f"{find['report'].iloc[0]}\n"
    if not impr.empty:
        report += f"{impr['report'].iloc[0]}"

    report_triplet_list = []
    report_same_triplet_list = []

    for section in section_list:
        df_section = df[df['section'] == section]
        for _, row in df_section.iterrows():
            # extract_triplets returns (triplets, same_triplets, sent_idx);
            # unpacking into two names raised ValueError on the first row.
            triplet_list, same_triplet_list, _ = extract_triplets(row)
            report_triplet_list.extend(triplet_list)
            report_same_triplet_list.extend(same_triplet_list)

    return {
        'subject_id': df['subject_id'].iloc[0],
        'study_id': df['study_id'].iloc[0],
        'passage': report,
        'relations': RELATIONS,
        'triplet_list': report_triplet_list,
        'same_triplet_list': report_same_triplet_list,
        'data_from': args.mode
    }

def _numbered_sentences(text, split_pattern):
    """Split *text* on *split_pattern* and return ``"(k) sentence"`` strings.

    ``k`` is the 1-based position in the raw split result; empty fragments are
    dropped without renumbering the remaining ones.
    """
    return [
        f"({k}) {sent.strip()}"
        for k, sent in enumerate(re.split(split_pattern, text), start=1)
        if sent.strip()
    ]


def prepend_idx(df_section, args, col=None):
    """Add index numbers to sentences based on their position in the report.

    Returns ``(numbered_report, numbered_report_list)``: the pieces
    ``"(k) sentence"`` concatenated without separators (then stripped), and the
    same pieces as a list. The text source depends on ``args.mode``:

    * generated-report modes: the first row whose ``{mode}_report`` (or
      ``report`` for ``silver_eval``) is a string of length >= 2;
      ``("", [])`` when no such row exists.
    * ``rexerr``: every string in ``error_findings`` (``find`` section) or
      ``error_impression`` (``impr`` section); ``("", [])`` otherwise.
    * anything else: an existing ``section_report`` string (split on its
      ``(k)`` markers), else the ``sent``/``sent_idx`` columns, else column
      ``col`` of every row (non-string cells skipped).
    """
    # Split on whitespace after '.', '?' or '!', but not after patterns such
    # as "e.g." (\w.\w.) or "Dr." ([A-Z][a-z].).
    plain_split = r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s'
    # Same, but also never split after a digit followed by '.' (e.g. "1.").
    list_aware_split = r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<!\d\.)(?<=\.|\?|\!)\s'

    generated_modes = ('maira', 'maira_cascade', 'medversa', 'rgrg',
                       'cvt2distilgpt2', 'medgemma', 'lingshu', 'silver_eval')

    if args.mode in generated_modes:
        text_col = 'report' if args.mode == 'silver_eval' else f'{args.mode}_report'
        for _, row in df_section.iterrows():
            text = row[text_col]
            if not isinstance(text, str) or len(text) < 2:
                # Missing / non-string / trivially short text.
                continue
            # Only the first usable row is numbered.
            pieces = _numbered_sentences(text, plain_split)
            return "".join(pieces).strip(), pieces
        # Previously fell through and returned None, crashing callers that unpack.
        return "", []

    if args.mode == 'rexerr':
        column = {'find': 'error_findings', 'impr': 'error_impression'}.get(df_section['section'].iloc[0])
        if column is None:
            return "", []
        if df_section[column].isna().all():
            print(f"Warning: All {column} values are NaN")
            return "", []
        pieces = []
        for _, row in df_section.iterrows():
            text = row[column]
            if not isinstance(text, str):
                continue  # skip this row and move on
            pieces.extend(_numbered_sentences(text, plain_split))
        return "".join(pieces).strip(), pieces

    if 'section_report' in df_section.columns:
        section_report = df_section['section_report'].iloc[0]
        numbered_report_list = []
        pattern = r'\(\d+\)\s*([^(]*?)(?=\s*\(\d+\)|$)'
        for match in re.finditer(pattern, section_report):
            # Recover the "(k)" marker at the start of this match.
            number_match = re.search(r'\(\d+\)', section_report[match.start():match.end()])
            if number_match:
                numbered_report_list.append(f"{number_match.group(0)} {match.group(1).strip()}")
        # Fall back to the raw text when no "(k)" markers were found.
        if not numbered_report_list:
            numbered_report_list = [section_report]
        return section_report, numbered_report_list

    pieces = []
    if 'sent_idx' in df_section.columns:
        for k in range(1, int(df_section['sent_idx'].max()) + 1):
            sent = df_section[df_section['sent_idx'] == k]['sent'].values[0]
            pieces.append(f"({k}) {sent}")
    else:
        for _, row in df_section.iterrows():
            text = row[col]
            if not isinstance(text, str):
                continue  # NaN cells previously crashed re.split
            pieces.extend(_numbered_sentences(text, list_aware_split))
    return "".join(pieces).strip(), pieces

def _section_categories_by_term(vocab):
    """Map each non-NaN vocab ``target_term`` to its ``category`` values in row order."""
    categories_by_term = {}
    for term, category in zip(vocab['target_term'], vocab['category']):
        if pd.isna(term):
            continue
        categories_by_term.setdefault(term, []).append(category)
    return categories_by_term


def _section_merge_entity_categories(pairs):
    """Collapse ``(entity, categories)`` pairs to one pair per entity, in first-seen order.

    Categories of repeated entities are merged without duplicates in a
    deterministic (first-seen) order, unlike ``set``, whose string ordering
    varies between interpreter runs.
    """
    merged = {}
    for entity, categories in pairs:
        if entity in merged:
            merged[entity] = list(dict.fromkeys(merged[entity] + categories))
        else:
            merged[entity] = categories
    return list(merged.items())


def create_section_level_data(df, args):
    """Create section level data entry.

    With a ``section`` column, one entry is produced per non-empty section
    (``find``/``impr`` for generated-report modes, ``hist``/``find``/``impr``
    otherwise). ``silver_eval`` entries carry only the numbered passage. Other
    modes also carry the gold triplets from ``extract_triplets``; annotated
    (non-generated) modes additionally get a per-sentence ``json_input`` whose
    candidates are chosen by ``args.candidate_type``.

    Without a ``section`` column, one passage-only entry is produced per
    column name in ``args.report_col_name``.

    Args:
        df: rows for one study (only the first row's ids are read).
        args: namespace providing ``vocab_path``, ``mode``,
            ``candidate_type`` and, for section-less input, ``report_col_name``.

    Returns:
        list of entry dicts.
    """
    vocab = pd.read_csv(args.vocab_path)
    vocab['category'] = vocab['category'].replace(['lf', 'If'], 'pf')

    if 'section' not in df.columns:
        section_data = []
        for col in args.report_col_name:
            report_section, _ = prepend_idx(df, args, col)
            section_data.append({
                'subject_id': int(df['subject_id'].iloc[0]) if 'subject_id' in df.columns else None,
                'study_id': int(df['study_id'].iloc[0]) if 'study_id' in df.columns else None,
                'section': col,
                'passage': report_section,
                'relations': RELATIONS,
                'triplet_list': None,
                'same_triplet_list': None,
                'data_from': args.mode
            })
        return section_data

    generated_modes = ('maira', 'maira_cascade', 'rexerr', 'medversa', 'rgrg',
                       'cvt2distilgpt2', 'lingshu', 'medgemma')
    if args.mode in generated_modes:
        section_list = ['find', 'impr']
    else:
        section_list = ['hist', 'find', 'impr']

    # Relations whose object is not treated as an extra entity candidate.
    no_object_relations = ('cat', 'dx_status', 'dx_certainty', 'associate', 'evidence')
    categories_by_term = None  # built lazily; O(1) lookups instead of a vocab scan per triplet

    section_data = []
    for section in section_list:
        df_section = df[df['section'] == section]
        if df_section.empty:
            continue  # skip empty sections

        report_section, _ = prepend_idx(df_section, args)

        if args.mode == 'silver_eval':
            section_data.append({
                'subject_id': df['subject_id'].iloc[0],
                'study_id': df['study_id'].iloc[0],
                'section': section,
                'passage': report_section,
                'relations': RELATIONS,
                'triplet_list': None,
                'same_triplet_list': None,
                'data_from': args.mode,
                'json_input': None
            })
            continue

        if categories_by_term is None:
            categories_by_term = _section_categories_by_term(vocab)

        section_triplet_list = []
        section_same_triplet_list = []
        gt_sro = defaultdict(list)      # sent_idx -> (subj, rel, obj)
        gt_so = defaultdict(list)       # sent_idx -> subjects and eligible objects
        gt_s = defaultdict(list)        # sent_idx -> subjects
        gt_ent_rcg = defaultdict(list)  # sent_idx -> (entity, categories)

        for _, row in df_section.iterrows():
            triplet_list, same_triplet_list, sent_idx = extract_triplets(row)
            section_triplet_list.extend(triplet_list)
            section_same_triplet_list.extend(same_triplet_list)

            for subj, rel, obj, *_rest in triplet_list:
                gt_sro[sent_idx].append((subj, rel, obj))
                gt_ent_rcg[sent_idx].append((subj, list(categories_by_term.get(subj, ()))))
                gt_so[sent_idx].append(subj)
                if rel not in no_object_relations:
                    gt_so[sent_idx].append(obj)
                    gt_ent_rcg[sent_idx].append((obj, [rel]))
                gt_s[sent_idx].append(subj)

        # De-duplicated views (first occurrence wins); defaultdict so sentences
        # without triplets yield [] below.
        gt_s_rmd = defaultdict(list, {k: list(dict.fromkeys(v)) for k, v in gt_s.items()})
        gt_so_rmd = defaultdict(list, {k: list(dict.fromkeys(v)) for k, v in gt_so.items()})
        gt_sro_rmd = defaultdict(list, {k: list(dict.fromkeys(v)) for k, v in gt_sro.items()})
        gt_ent_rcg_rmd = defaultdict(
            list, {k: _section_merge_entity_categories(v) for k, v in gt_ent_rcg.items()})

        json_data = None
        if args.mode not in generated_modes:
            final_candidates = {
                'gt_sro_review': gt_sro_rmd,
                'gt_sro': gt_sro_rmd,
                'gt_so': gt_so_rmd,
                'gt_s': gt_s_rmd,
                'gt_ent_rcg': gt_ent_rcg_rmd,
            }.get(args.candidate_type, gt_so_rmd)

            # JSON input used when feeding GT candidates, or no candidates at all.
            json_data = {'report_section': []}
            # int() so a float sent_idx column does not break range().
            for sent_idx in range(1, int(df_section['sent_idx'].max()) + 1):
                entry = {
                    'sent_idx': sent_idx,
                    'sentence': df_section[df_section['sent_idx'] == sent_idx]['sent'].values[0],
                }
                if args.candidate_type != 'no_candidates':
                    entry['candidates'] = final_candidates[sent_idx]
                json_data['report_section'].append(entry)

        section_data.append({
            'subject_id': df['subject_id'].iloc[0],
            'study_id': df['study_id'].iloc[0],
            'section': section,
            'passage': report_section,
            'relations': RELATIONS,
            'triplet_list': section_triplet_list,
            'same_triplet_list': section_same_triplet_list,
            'data_from': args.mode,
            'json_input': json_data
        })

    return section_data


# Relations for which only the subject is added to the subject/object and
# entity-recognition candidates (their object value is skipped).
_SUBJECT_ONLY_RELATIONS = ('cat', 'dx_status', 'dx_certainty', 'associate', 'evidence')


def _dedupe_keep_order(items):
    """Return *items* as a list without duplicates, keeping first occurrences."""
    return list(dict.fromkeys(items))


def _merge_entity_categories(pairs):
    """Collapse ``(entity, categories)`` pairs so each entity appears once.

    Categories of repeated entities are unioned in first-seen order. A ``set``
    union would make the order depend on the per-process string hash seed, so
    the written JSON would differ between runs.
    """
    merged = {}
    for entity, categories in pairs:
        merged[entity] = _dedupe_keep_order(merged.get(entity, []) + list(categories))
    return list(merged.items())


def _build_sentence_candidates(sent_idx_triplets, candidate_type, category_by_term):
    """Build de-duplicated GT candidates per sentence.

    Args:
        sent_idx_triplets: mapping sent_idx -> list of triplets whose items
            0, 1, 2 are subject, relation and object.
        candidate_type: 'gt_sro' / 'gt_sro_review' -> (s, r, o) tuples;
            'gt_s' -> subjects; 'gt_ent_rcg' -> (entity, categories) pairs;
            anything else (incl. 'gt_so') -> subjects and non-skipped objects.
        category_by_term: vocab mapping target_term -> list of categories.

    Returns:
        defaultdict(list): sent_idx -> candidate list ([] for sentences
        without triplets).
    """
    sro, so, subjects, ent_rcg = (defaultdict(list) for _ in range(4))
    for sent_idx, triplets in sent_idx_triplets.items():
        for triplet in triplets:
            subj, rel, obj = triplet[0], triplet[1], triplet[2]
            sro[sent_idx].append((subj, rel, obj))
            subjects[sent_idx].append(subj)
            so[sent_idx].append(subj)
            ent_rcg[sent_idx].append((subj, category_by_term.get(subj, [])))
            if rel not in _SUBJECT_ONLY_RELATIONS:
                so[sent_idx].append(obj)
                # For objects, the relation name stands in for the category.
                ent_rcg[sent_idx].append((obj, [rel]))

    if candidate_type in ('gt_sro_review', 'gt_sro'):
        source, dedupe = sro, _dedupe_keep_order
    elif candidate_type == 'gt_s':
        source, dedupe = subjects, _dedupe_keep_order
    elif candidate_type == 'gt_ent_rcg':
        source, dedupe = ent_rcg, _merge_entity_categories
    else:
        source, dedupe = so, _dedupe_keep_order

    candidates = defaultdict(list)
    for sent_idx, items in source.items():
        candidates[sent_idx] = dedupe(items)
    return candidates


def create_sent_level_data(df, args):
    """Create sentence-level data entries (one record per sentence).

    With a ``section`` column in *df* (annotated input), records are built for
    the sections 'hist', 'find', 'impr' ('find' and 'impr' only for the
    maira / maira_cascade / rexerr modes). Each record carries the GT triplets
    of its sentence. Outside those three modes it also carries a ``json_input``
    with the sentence text and, unless ``args.candidate_type`` is
    'no_candidates', the GT candidates for that sentence.

    Without a ``section`` column, every column in ``args.report_col_name`` is
    split with ``prepend_idx``. One record per returned sentence is emitted,
    without triplets.

    Args:
        df: rows of a single study.
        args: run configuration (mode, candidate_type, vocab_path,
            report_col_name, ...).

    Returns:
        list[dict]: the sentence records in section/sentence order.
    """
    if 'section' not in df.columns:
        section_data = []
        for col in args.report_col_name:
            _, report_section_list = prepend_idx(df, args, col)
            for sent_itr in report_section_list:
                section_data.append({
                    'subject_id': int(df['subject_id'].iloc[0]) if 'subject_id' in df.columns else None,
                    'study_id': int(df['study_id'].iloc[0]) if 'study_id' in df.columns else None,
                    'section': col,
                    'passage': sent_itr,
                    'relations': RELATIONS,
                    'triplet_list': None,
                    'same_triplet_list': None,
                    'data_from': args.mode
                })
        return section_data

    generated_mode = args.mode in ['maira', 'maira_cascade', 'rexerr']
    section_list = ['find', 'impr'] if generated_mode else ['hist', 'find', 'impr']

    category_by_term = None
    if not generated_mode:
        # Build the term -> categories lookup once instead of scanning the
        # whole vocab table for every triplet.
        vocab = pd.read_csv(args.vocab_path)
        category_by_term = (
            vocab.groupby('target_term')['category']
            .apply(lambda s: s.tolist())
            .to_dict()
        )

    section_data = []
    for section in section_list:
        df_section = df[df['section'] == section]
        if df_section.empty:  # skip sections with no annotation rows
            continue

        _, report_section_list = prepend_idx(df_section, args)

        # Group the triplets of all annotation rows by their sentence index.
        sent_idx_triplets = defaultdict(list)
        sent_idx_same_triplets = defaultdict(list)
        for _, row in df_section.copy().iterrows():
            triplet_list, same_triplet_list, sent_idx = extract_triplets(row)
            sent_idx_triplets[sent_idx].extend(triplet_list)
            sent_idx_same_triplets[sent_idx].extend(same_triplet_list)

        ordered_idxs = sorted(sent_idx_triplets)
        section_triplet_list = [sent_idx_triplets[i] for i in ordered_idxs]
        section_same_triplet_list = [sent_idx_same_triplets[i] for i in ordered_idxs]

        max_sent_idx = df_section['sent_idx'].max()
        if not generated_mode:
            # Candidates come from each sentence's own triplets.
            final_candidates = _build_sentence_candidates(
                sent_idx_triplets, args.candidate_type, category_by_term
            )
            # NOTE(review): the lists below are indexed by position
            # (sent_idx - 1), which assumes sent_idx runs 1..max without gaps.
            # This check only warns when it does not.
            if len(section_triplet_list) != max_sent_idx:
                print(f"len(section_triplet_list) != df_section['sent_idx'].max() : {len(section_triplet_list)} != {max_sent_idx}")

        for sent_idx in range(1, max_sent_idx + 1):
            if generated_mode:
                json_input = None
            else:
                # JSON input used with GT candidates, or with no candidates at all.
                json_input = {
                    'sent_idx': sent_idx,
                    'sentence': df_section[df_section['sent_idx'] == sent_idx]['sent'].values[0],
                }
                if args.candidate_type != 'no_candidates':
                    json_input['candidates'] = final_candidates[sent_idx]

            section_data.append({
                'subject_id': df['subject_id'].iloc[0],
                'study_id': df['study_id'].iloc[0],
                'section': section,
                'passage': report_section_list[sent_idx-1],
                'relations': RELATIONS,
                'triplet_list': section_triplet_list[sent_idx-1],
                'same_triplet_list': section_same_triplet_list[sent_idx-1],
                'data_from': args.mode,
                'json_input': json_input
            })

    return section_data


def process_study_data(args_tuple):
    """Build the input records for one study.

    Args:
        args_tuple: ``(study_id, test_df, args)``, packed into one tuple so the
            function can be passed straight to ``Pool.imap``.

    Returns:
        list of ``(position, record)`` pairs, where ``position`` is the record's
        index within this study as a string. The list is empty when
        ``args.unit`` is none of 'report', 'section' or 'sent'.
    """
    study_id, test_df, args = args_tuple
    study_df = test_df[test_df['study_id'] == study_id]
    unit = args.unit

    if unit == 'report':
        return [('0', create_report_level_data(study_df, args))]

    if unit == 'section':
        records = create_section_level_data(study_df, args)
    elif unit == 'sent':
        records = create_sent_level_data(study_df, args)
    else:
        return []

    return [(str(position), record) for position, record in enumerate(records)]

# Subjects whose studies form the test split of the generated-report modes
# (maira, maira_cascade, medversa, rgrg, cvt2distilgpt2, lingshu, medgemma,
# rexerr).
_TEST_SUBJECTS = (
    'p10274145', 'p10523725', 'p10886362', 'p10959054', 'p12433421',
    'p15321868', 'p15446959', 'p15881535', 'p17720924', 'p18079481',
)

# Study IDs left out of the 'gold_eval' test split when args.except_80std is set.
_EXCLUDED_80_STUDY_IDS = (
    's59166131', 's52241282', 's52316568', 's53941324', 's51185902',
    's54594848', 's51723789', 's53652133', 's57513198', 's57761141',
    's59071382', 's59191972', 's57959841', 's54843884', 's50126222',
    's51656138', 's58274962', 's55372843', 's52705433', 's59027235',
    's52246418', 's54669609', 's54164323', 's55947318', 's53100359',
    's52268728', 's50323020', 's51202805', 's51621137', 's57053848',
    's58521232', 's51233868', 's51280998', 's55849664', 's53164365',
    's54932317', 's58466818', 's51807934', 's52440373', 's55049183',
    's50519818', 's57540712', 's51153135', 's50664785', 's58414605',
    's50259315', 's53403421', 's51837713', 's51766355', 's53818162',
    's59816233', 's50857625', 's52523882', 's50894711', 's56400373',
    's54552753', 's58155125', 's52969022', 's56556003', 's54350641',
    's50697229', 's52573647', 's55134684', 's56241369', 's53967875',
    's53002522', 's54541565', 's53607277', 's55420918', 's56998267',
    's59522601', 's52321096', 's57397512', 's59568059', 's52198118',
    's52546911', 's54174765', 's57561035', 's53424979', 's57874436',
)

# Modes whose test set is carved out of the gold annotation table.
_GOLD_SPLIT_MODES = ('gold_eval', 'maira', 'maira_cascade', 'rexerr', 'medversa',
                     'rgrg', 'cvt2distilgpt2', 'lingshu', 'medgemma')
# Modes whose generated reports are joined onto the gold test rows by study.
_GENERATED_REPORT_MODES = ('medversa', 'rgrg', 'cvt2distilgpt2', 'lingshu', 'medgemma')


def _load_gold(gold_path):
    """Read the gold annotation CSV and normalise its label columns."""
    gold = pd.read_csv(gold_path)
    # NOTE(review): 'If' (capital i) presumably is a typo variant of 'lf' in the annotations.
    gold.loc[:, 'cat'] = gold['cat'].replace(['lf', 'If'], 'pf')
    gold.loc[:, 'evidence'] = gold['evidence'].str.replace(r'idx(\d+)', r'obj_ent_idx\1', regex=True)
    gold.loc[:, 'associate'] = gold['associate'].str.replace(r'idx(\d+)', r'obj_ent_idx\1', regex=True)
    gold.loc[:, 'location'] = gold['location'].str.replace(r'loc:\s*', '', regex=True).str.replace(r'det:\s*', '', regex=True)
    if 'dx_uncertainty' in gold.columns:
        gold = gold.rename(columns={'dx_uncertainty': 'dx_certainty'})
    return gold


def _exclude_dev_overlap(gold, test_df, args):
    """Drop test studies from ``args.dev_list`` (in place) and return the dev rows of *gold*."""
    overlap_ids = set(args.dev_list).intersection(set(test_df['study_id'].unique().tolist()))
    print(f"overlap between dev and {args.mode}: {len(overlap_ids)}")
    print(f"Exclude overlapped study IDs: {overlap_ids}")
    args.dev_list = list(set(args.dev_list) - overlap_ids)
    return gold[gold['study_id'].isin(args.dev_list)]


def _split_dev_test(gold, test_mask, args):
    """Return ``(test_df, devset)`` from *gold*.

    *test_mask* selects the candidate test rows (``None`` means all rows).
    With ``args.holdout_devset``, dev studies are removed from the test rows
    and the dev set is restricted to ``args.dev_list``. Otherwise the dev set
    is the whole gold table.
    """
    if args.holdout_devset:
        not_dev = ~gold['study_id'].isin(args.dev_list)
        test_df = gold[not_dev if test_mask is None else (test_mask & not_dev)]
        devset = _exclude_dev_overlap(gold, test_df, args)
    else:
        test_df = gold if test_mask is None else gold[test_mask]
        devset = gold
    return test_df, devset


def _keep_find_impr_with_report(test_df, report_col, dedup_subset):
    """Keep 'find'/'impr' rows whose generated report is longer than 2 chars, de-duplicated on *dedup_subset*."""
    has_report = test_df[report_col].str.len() > 2
    in_section = test_df['section'].isin(['impr', 'find'])
    return test_df[in_section & has_report].drop_duplicates(subset=dedup_subset)


def _load_generated_reports(args):
    """Load the generated reports of a model in _GENERATED_REPORT_MODES with a ``report`` column."""
    if args.mode == 'medversa':
        return pd.read_csv(args.medversa_report_path)
    if args.mode == 'rgrg':
        return pd.read_csv(args.rgrg_report_path)
    if args.mode == 'cvt2distilgpt2':
        return pd.read_csv(args.cvt2distilgpt2_report_path)

    # lingshu / medgemma: JSONL outputs keyed by image path.
    df = pd.read_json(args.lingshu_path if args.mode == 'lingshu' else args.medgemma_path, lines=True)
    df['subject_id'] = df['image_path'].str.extract(r'/home/data_storage/mimic-cxr-jpg/2.0.0/files/p\d+/(p\d+)/s\d+/')[0]
    df['study_id'] = df['image_path'].str.extract(r'/home/data_storage/mimic-cxr-jpg/2.0.0/files/p\d+/p\d+/(s\d+)/')[0]
    if args.mode == 'lingshu':
        df['report'] = df['raw_output'].apply(lambda x: x[0] if isinstance(x, list) and len(x) > 0 else '')
    else:
        df['report'] = df['raw_output'].fillna('')
    df['report'] = df['report'].str.strip()
    return df


def _mirror_missing_find_impr(test_df, study_ids):
    """For each study lacking 'impr' (or else 'find'), copy the other section's rows under the missing name."""
    for study_id in study_ids:
        study_rows = test_df[test_df['study_id'] == study_id]
        sections = study_rows.loc[study_rows['section'] != 'hist', 'section'].unique().tolist()
        if 'impr' not in sections:
            source, target = 'find', 'impr'
        elif 'find' not in sections:
            source, target = 'impr', 'find'
        else:
            continue
        new_rows = study_rows[study_rows['section'] == source].copy()
        new_rows['section'] = target
        test_df = pd.concat([test_df, new_rows])
    return test_df


def _select_gold_split(gold, args):
    """Return ``(test_df, devset)`` for a mode in _GOLD_SPLIT_MODES; sets ``args.test_std_ids``."""
    if args.mode in ('maira', 'maira_cascade'):
        test_df, devset = _split_dev_test(gold, gold['subject_id'].isin(_TEST_SUBJECTS), args)
        report_col = f'{args.mode}_report'
        report_path = args.maira2_cascade_report_path if args.mode == 'maira_cascade' else args.maira2_report_path
        maira_df = pd.read_json(report_path, lines=True).rename(columns={'report': report_col})
        test_df = test_df.merge(maira_df, on=['subject_id', 'sequence'], how='left')
        test_df = _keep_find_impr_with_report(test_df, report_col, ['study_id'])

    elif args.mode in _GENERATED_REPORT_MODES:
        test_df, devset = _split_dev_test(gold, gold['subject_id'].isin(_TEST_SUBJECTS), args)
        report_col = f'{args.mode}_report'
        df = _load_generated_reports(args).rename(columns={'report': report_col})
        if args.mode == 'medversa':
            # MedVersa reports are per section, so every study needs both sections.
            test_df = _mirror_missing_find_impr(test_df, df['study_id'].unique().tolist())
            test_df = test_df.merge(df, on=['subject_id', 'study_id', 'section'], how='left')
            test_df = _keep_find_impr_with_report(test_df, report_col, ['study_id', 'section'])
        else:
            test_df = test_df.merge(df, on=['subject_id', 'study_id'], how='left')
            test_df = _keep_find_impr_with_report(test_df, report_col, ['study_id'])

    elif args.mode == 'rexerr':
        rexerr_df = pd.read_csv(args.rexerr_report_path)
        rexerr_df['study_id'] = 's' + rexerr_df['study_id'].astype(str)
        rexerr_df = rexerr_df[['study_id', 'error_findings', 'error_impression']]
        test_df, devset = _split_dev_test(gold, gold['subject_id'].isin(_TEST_SUBJECTS), args)
        test_df = test_df.merge(rexerr_df, on=['study_id'], how='left')

    else:  # 'gold_eval'
        if args.except_80std:
            # Both branches now exclude the 80 studies. The non-holdout branch used
            # to match subject_id against these *study* IDs, which always
            # produced an empty test set.
            keep = ~gold['study_id'].isin(_EXCLUDED_80_STUDY_IDS)
            test_df, devset = _split_dev_test(gold, keep, args)
            args.test_std_ids = test_df['study_id'].unique().tolist()
            print(f"No batch_itr specified. Processing all {len(args.test_std_ids)} study IDs")
            print(f"Total Samples: {len(args.test_std_ids)}")
        else:
            test_df, devset = _split_dev_test(gold, None, args)
            args.test_std_ids = test_df['study_id'].unique().tolist()
        return test_df, devset

    args.test_std_ids = test_df['study_id'].unique().tolist()
    print(f"Run Subject {len(test_df.subject_id.unique())}, Study {len(test_df.study_id.unique())}")
    return test_df, devset


def _build_study_data(test_df, args):
    """Run process_study_data for every study in ``args.test_std_ids`` in a process pool.

    Returns a dict numbering all records "0", "1", ... in study order.
    """
    n_workers = min(cpu_count(), 64)
    print(f"Starting multiprocessing with {n_workers} workers")
    # Ship each task only its own study's rows instead of pickling the whole
    # test_df once per study; process_study_data filters by study_id anyway.
    rows_by_study = {sid: rows for sid, rows in test_df.groupby('study_id', sort=False)}
    empty = test_df.iloc[0:0]
    process_args = [(sid, rows_by_study.get(sid, empty), args) for sid in args.test_std_ids]

    with Pool(processes=n_workers) as pool:
        results = list(tqdm(
            pool.imap(process_study_data, process_args),
            total=len(process_args),
            desc="Processing study data"
        ))

    records = (data for study_results in results for _, data in study_results)
    return {str(i): data for i, data in enumerate(records)}


def create_input(args):
    """Build (or, for rexval/silver_eval, reuse) the model inputs and the dev set.

    Gold-split modes (_GOLD_SPLIT_MODES) take their test studies from the gold
    annotations, joined with the model's generated reports where applicable.
    'rexval' and 'silver_eval' use the studies of their own report CSV and
    reuse an existing input JSON when one is present.

    Side effects: sets ``args.test_std_ids``, may shrink ``args.dev_list``
    (holdout mode), and writes
    ``{args.output_dir}/{mode}_{candidate_type}_{unit}_input.json``.

    Returns:
        tuple: ``(study_data, devset)``. ``study_data`` maps "0", "1", ... to
        the records produced by ``process_study_data``; ``devset`` is a
        DataFrame of gold rows.

    Raises:
        ValueError: if ``args.mode`` is not supported.
    """
    output_file = os.path.join(
        args.output_dir,
        f'{args.mode}_{args.candidate_type}_{args.unit}_input.json'
    )

    if args.mode in _GOLD_SPLIT_MODES:
        gold = _load_gold(args.gold_path)
        test_df, devset = _select_gold_split(gold, args)
        print(f'\n {len(args.test_std_ids)} studies in test set')
        study_data = _build_study_data(test_df, args)

    elif args.mode in ('rexval', 'silver_eval'):
        gold = _load_gold(args.gold_path)
        report_path = args.rexval_report_path if args.mode == 'rexval' else args.silver_eval_report_path
        test_df = pd.read_csv(report_path)

        if args.holdout_devset:
            # The test set stays the report file; overlapping studies are only
            # removed from the dev list. (This used to rebuild test_df from an
            # undefined `test_subject`, raising NameError.)
            devset = _exclude_dev_overlap(gold, test_df, args)
        else:
            devset = gold

        args.test_std_ids = test_df['study_id'].unique().tolist()
        print(f'\n {len(args.test_std_ids)} studies in test set')
        print(f"Test std ids: {len(args.test_std_ids)}")

        if os.path.isfile(output_file):
            print(f"input_file exists: {output_file}")
            with open(output_file, 'r', encoding='utf-8') as f:
                study_data = json.load(f)
            print(f"input_file loaded: {output_file}")
            return study_data, devset

        print(f"input_file does not exist: {output_file}")
        study_data = _build_study_data(test_df, args)

    else:
        raise ValueError(f"Unsupported mode: {args.mode}")

    # Remove a directory of the same name, if any, so the file can be written.
    if os.path.isdir(output_file):
        import shutil
        shutil.rmtree(output_file)
    os.makedirs(args.output_dir, exist_ok=True)

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(study_data, f, ensure_ascii=False, indent=4)

    print(f'Saved input data to {output_file}')
    return study_data, devset


def generate_task(custom_id, conversation, args, system_message=None):
    """Build one request record for the configured deployment.

    Batch deployments ('gpt-4o-mini-batchapi', 'gpt-4o-batch', 'o3-mini-batch')
    get a chat-completions batch line with ``method``/``url``/``body``. Any
    other deployment gets a plain ``{custom_id, system_message, messages}``
    record.

    Args:
        custom_id: request identifier; stored as a string.
        conversation: chat messages, passed through unchanged.
        args: namespace with a ``deployment_name`` attribute.
        system_message: only included in the non-batch record.

    Returns:
        dict: the task record.
    """
    deployment = args.deployment_name
    # Plain membership test: a str has no ``isin`` (that is a pandas method).
    if deployment in ('gpt-4o-mini-batchapi', 'gpt-4o-batch', 'o3-mini-batch'):
        body = {"model": deployment}
        if deployment != 'o3-mini-batch':
            # o3-mini requests are sent without a temperature.
            body["temperature"] = 0.1
        body["messages"] = conversation
        return {
            "custom_id": f"{custom_id}",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": body,
        }

    return {
        "custom_id": f"{custom_id}",
        "system_message": system_message,
        "messages": conversation,
    }

class RelationEnum(str, Enum):
    # Allowed values of EntityRelation.relation. Because the class mixes in
    # str, each member also compares equal to its value string.
    # It covers the same 19 relations as the module-level RELATIONS list, but
    # spelled in title case as written below.
    # NOTE(review): no class docstring is added here because pydantic may
    # copy it into the generated JSON schema.

    # Exactly one of each is required per entity (see Entity.validate_relations).
    Cat = "Cat"
    Dx_Status = "Dx_Status"
    Dx_Certainty = "Dx_Certainty"
    Location = "Location"
    # These two require obj_ent_idx (see EntityRelation's validator).
    Associate = "Associate"
    Evidence = "Evidence"
    Morphology = "Morphology"
    Distribution = "Distribution"
    Measurement = "Measurement"
    Severity = "Severity"
    Comparison = "Comparison"
    Onset = "Onset"
    NoChange = "No Change"
    Improved = "Improved"
    Worsened = "Worsened"
    Placement = "Placement"
    PastHx = "Past Hx"
    OtherSource = "Other Source"
    AssessmentLimitations = "Assessment Limitations"

class EntityRelation(BaseModel):
    """Represents a single relation of an entity with its value"""
    relation: RelationEnum = Field(..., description="The relation name among the predefined types")
    value: str = Field(..., description="The value corresponding to the relation")
    obj_ent_idx: Optional[int] = Field(
        None,
        description="For Associate/Evidence relations, the ent_idx of the object entity"
    )

    # Associate and Evidence point at another entity, so for them the
    # target's ent_idx is mandatory; every other relation may omit it.
    @root_validator(skip_on_failure=True)
    def require_obj_ent_idx_for_certain_relations(cls, values):
        is_linking = values.get('relation') in {RelationEnum.Associate, RelationEnum.Evidence}
        target_missing = values.get('obj_ent_idx') is None
        if is_linking and target_missing:
            raise ValueError("'obj_ent_idx' must be provided for Associate and Evidence relations")
        return values

class Entity(BaseModel):
    """Represents a single entity with all its extracted relations"""
    name: str = Field(..., description="The name of the entity, exactly as provided")
    sent_idx: int = Field(..., description="Index of the sentence from which this entity was extracted")
    ent_idx: int = Field(..., description="Unique identifier for this entity within the report section")
    relations: List[EntityRelation] = Field(..., description="List of all relations for this entity")

    @root_validator(skip_on_failure=True)
    def validate_relations(cls, values):
        """Require exactly one Cat, one Dx_Status and one Dx_Certainty relation.

        The error names the real relation values and reports what was found,
        so a retry prompt built from it is actionable.
        """
        relation_types = [rel.relation for rel in values.get('relations', [])]
        required = (RelationEnum.Cat, RelationEnum.Dx_Status, RelationEnum.Dx_Certainty)
        counts = {rel: relation_types.count(rel) for rel in required}
        if any(count != 1 for count in counts.values()):
            found = ", ".join(f"{rel.value}={count}" for rel, count in counts.items())
            raise ValueError(
                "Each entity must include exactly one 'Cat' relation, exactly one "
                "'Dx_Status' relation and exactly one 'Dx_Certainty' relation "
                f"(found {found})"
            )
        return values

class StructuredOutput(BaseModel):
    """Schema for structured extraction matching the provided output format"""
    entities: List[Entity] = Field(..., description="All extracted entities with their relations")

    # Rejects an empty extraction, and rejects duplicate ent_idx values
    # (EntityRelation.obj_ent_idx refers to entities by ent_idx).
    @root_validator(skip_on_failure=True)
    def validate_entities(cls, values):
        extracted = values.get('entities', [])
        if not extracted:
            raise ValueError("Output must include at least one entity")
        distinct_idxs = {entity.ent_idx for entity in extracted}
        if len(distinct_idxs) < len(extracted):
            raise ValueError("Each entity must have a unique 'ent_idx'")
        return values

def _find_balanced_json_object(text, start):
    """Return ``text[start:end + 1]`` where ``end`` closes the '{' at ``start``, or None.

    Braces inside double-quoted JSON strings (honouring backslash escapes) are
    ignored, so a value such as "size {approx}" can neither end the object early
    nor keep it from closing.
    """
    depth = 0
    in_string = False
    escaped = False
    for pos in range(start, len(text)):
        char = text[pos]
        if in_string:
            if escaped:
                escaped = False
            elif char == '\\':
                escaped = True
            elif char == '"':
                in_string = False
        elif char == '"':
            in_string = True
        elif char == '{':
            depth += 1
        elif char == '}':
            depth -= 1
            if depth == 0:
                return text[start:pos + 1]
    return None


def _extract_json_text(text):
    """Locate JSON object text inside free-form model output; return it or None.

    Strategies, tried in order:
      1. the contents of the first fenced ```json ... ``` block;
      2. the first balanced {...} object anywhere in the text;
      3. the first balanced {...} object after 'json:', 'output:', 'result:'
         or 'answer:' (keyword match is case-insensitive).
    """
    fence_patterns = (
        r'```json\s*\n(.*?)\n```',
        r'```json\s*(.*?)```',
        r'```\s*json\s*\n(.*?)\n```',
        r'```\s*json\s*(.*?)```',
    )
    for pattern in fence_patterns:
        match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
        if match:
            fenced = match.group(1).strip()
            if fenced:
                return fenced
            # An empty fence stops fence matching and falls through to brace scanning.
            break

    first_brace = text.find('{')
    if first_brace != -1:
        candidate = _find_balanced_json_object(text, first_brace)
        if candidate:
            return candidate

    lowered = text.lower()
    for keyword in ('json:', 'output:', 'result:', 'answer:'):
        idx = lowered.find(keyword)
        if idx == -1:
            continue
        remaining = text[idx + len(keyword):].strip()
        brace = remaining.find('{')
        if brace != -1:
            candidate = _find_balanced_json_object(remaining, brace)
            if candidate:
                return candidate
    return None


def _loads_lenient(json_text):
    """``json.loads`` with one retry after common LLM formatting fixes.

    The retry flattens newlines/tabs (raw control characters are invalid inside
    JSON strings) and drops trailing commas before '}' / ']'.

    Raises:
        json.JSONDecodeError: if the cleaned-up text still does not parse.
    """
    try:
        parsed = json.loads(json_text)
    except json.JSONDecodeError:
        cleaned = json_text.replace('\n', ' ').replace('\t', ' ')
        cleaned = re.sub(r',\s*}', '}', cleaned)
        cleaned = re.sub(r',\s*]', ']', cleaned)
        parsed = json.loads(cleaned)
        print("✅ Successfully parsed JSON after cleanup")
    else:
        print("✅ Successfully extracted and parsed JSON from text")
    return parsed


def _as_clean_text(value):
    """Lower-case and strip *value*.

    Non-strings (e.g. a numeric measurement produced by the LLM) are converted with
    ``str`` instead of crashing on ``.lower()``; ``None`` becomes ''.
    """
    if value is None:
        return ''
    return str(value).lower().strip()


def convert_to_sr_structure(structured_output):
    """
    Convert model output into a dictionary of triplets grouped by relation type.
    Format matches the output from the JSON parsing branch in parse_eval_format.

    Args:
        structured_output (StructuredOutput, dict, or str): Parsed output matching the Entity schema,
                                                           an already formatted relation-keyed dictionary,
                                                           or raw text containing JSON.

    Returns:
        dict: relation name -> list of
            (entity_name, relation_name, value, sent_idx, ent_idx, obj_ent_idx) tuples,
            as a ``defaultdict(list)``. A dict that already has relation-name keys is
            returned unchanged; text from which no JSON can be parsed yields an empty
            ``defaultdict``.

    Raises:
        ValueError: if the (parsed) input is neither relation-keyed nor carries 'entities'.
        KeyError: if a dict entity lacks 'name' or a dict relation lacks 'relation'/'value'.
    """
    pred_triplet = defaultdict(list)

    # Raw LLM text: locate and parse the embedded JSON first.
    if isinstance(structured_output, str):
        raw_text = structured_output
        json_text = None
        try:
            json_text = _extract_json_text(raw_text)
            if not json_text:
                print(f"❌ No JSON found in text. First 300 chars: {raw_text[:300]}...")
                return pred_triplet
            structured_output = _loads_lenient(json_text)
        except json.JSONDecodeError as e:
            print(f"❌ Failed to parse extracted JSON: {e}")
            print(f"Extracted JSON text (first 500 chars): {json_text[:500] if json_text else 'None'}")
            return pred_triplet
        except Exception as e:
            print(f"❌ Error extracting JSON from text: {e}")
            print(f"Input text (first 300 chars): {raw_text[:300]}...")
            return pred_triplet

    # Already in relation-keyed form: return as is.
    if isinstance(structured_output, dict) and any(key in RELATIONS for key in structured_output.keys()):
        return structured_output

    if isinstance(structured_output, dict) and 'entities' in structured_output:
        entities = structured_output['entities']
    elif hasattr(structured_output, 'entities'):
        entities = structured_output.entities
    else:
        print("structured_output", structured_output)
        raise ValueError("Invalid structured_output format: must be a StructuredOutput object or a dictionary with 'entities' key")

    for entity in entities or []:
        # Entities may be plain dicts (parsed JSON) or pydantic objects.
        if isinstance(entity, dict):
            entity_name = _as_clean_text(entity['name'])
            sent_idx = entity.get('sent_idx')
            ent_idx = entity.get('ent_idx')
            relations = entity.get('relations', [])
        else:
            entity_name = _as_clean_text(entity.name)
            sent_idx = entity.sent_idx
            ent_idx = entity.ent_idx
            relations = entity.relations

        # `relations: null` in LLM JSON is treated as no relations.
        for relation in relations or []:
            if isinstance(relation, dict):
                relation_name = _as_clean_text(relation['relation'])
                relation_value = _as_clean_text(relation['value'])
                obj_ent_idx = relation.get('obj_ent_idx')
            else:
                # Enum members expose their string via `.value`; anything else is str()-ed.
                raw_relation = relation.relation
                relation_name = _as_clean_text(getattr(raw_relation, 'value', raw_relation))
                relation_value = _as_clean_text(relation.value)
                obj_ent_idx = relation.obj_ent_idx

            pred_triplet[relation_name].append((
                entity_name,
                relation_name,
                relation_value,
                sent_idx,
                ent_idx,
                obj_ent_idx,
            ))

    return pred_triplet

def convert_to_assistant_string(structured_output):
    """
    Serialize extracted entities into an assistant-turn string of the form "OUTPUT: <json>".

    Entity names are lower-cased and stripped; relation values are stripped only.
    ``obj_ent_idx`` is emitted only for Associate/Evidence relations (matched
    case-insensitively).

    Args:
        structured_output (StructuredOutput or dict): an object with an ``entities``
            attribute, a dict with an 'entities' key, or an already relation-keyed dict.

    Returns:
        str: "OUTPUT: " followed by ``json.dumps({"entities": [...]}, indent=2)``.
            If ``structured_output`` is already a relation-keyed dict (any key in
            RELATIONS), it is returned unchanged (a dict, not a string).

    Raises:
        ValueError: if the input has neither relation keys nor entities.
    """
    # Relations that point at another entity and therefore carry obj_ent_idx.
    object_relations = {'associate', 'evidence'}
    extracted_entities = []

    # Already in relation-keyed form: return as is.
    if isinstance(structured_output, dict) and any(key in RELATIONS for key in structured_output.keys()):
        return structured_output

    if isinstance(structured_output, dict) and 'entities' in structured_output:
        entities = structured_output['entities']
    elif hasattr(structured_output, 'entities'):
        entities = structured_output.entities
    else:
        print("structured_output", structured_output)
        raise ValueError("Invalid structured_output format: must be a StructuredOutput object or a dictionary with 'entities' key")

    for entity in entities:
        # Entities may be plain dicts (parsed JSON) or pydantic objects.
        if isinstance(entity, dict):
            entity_name = entity['name'].lower().strip()
            sent_idx = entity.get('sent_idx')
            ent_idx = entity.get('ent_idx')
            relations = entity.get('relations', [])
        else:
            entity_name = entity.name.lower().strip()
            sent_idx = entity.sent_idx
            ent_idx = entity.ent_idx
            relations = entity.relations

        extracted_relations = []
        for relation in relations:
            if isinstance(relation, dict):
                relation_name = relation['relation']
                relation_value = relation['value'].strip()
                obj_ent_idx = relation.get('obj_ent_idx')
            else:
                relation_name = relation.relation.value if hasattr(relation.relation, 'value') else str(relation.relation)
                relation_value = relation.value.strip()
                obj_ent_idx = relation.obj_ent_idx

            relation_entry = {"relation": relation_name, "value": relation_value}
            # Case-insensitive: 'associate' and 'Associate' both keep their target index.
            if str(relation_name).lower() in object_relations:
                relation_entry["obj_ent_idx"] = obj_ent_idx
            extracted_relations.append(relation_entry)

        extracted_entities.append({
            "name": entity_name,
            "sent_idx": sent_idx,
            "ent_idx": ent_idx,
            "relations": extracted_relations,
        })

    return "OUTPUT: " + json.dumps({"entities": extracted_entities}, indent=2)

def evaluate_funct(gold_file_path, exp_file_path, args):
    """Parse every annotation's model output and write the predicted triplets to disk.

    NOTE(review): a second ``evaluate_funct`` is defined later in this module and
    rebinds this name, so this version is never reachable as ``evaluate_funct``.
    Consider deleting it.

    Args:
        gold_file_path: unused here; kept for signature parity with the later version.
        exp_file_path: experiment JSON with an 'annotations' list of
            {'custom_id', 'model_output'} entries.
        args: namespace providing mode, n_retrieval, candidate_type,
            deployment_name, output_format, unit and candidate_usage.
    """
    # Close the handle deterministically; the JSON is written as UTF-8 by this pipeline.
    with open(exp_file_path, encoding="utf-8") as f:
        exp_file = json.load(f)

    gpt_data = {}
    for ann in exp_file['annotations']:
        custom_id = ann['custom_id']
        try:
            result = parse_eval_format(custom_id, ann['model_output'])
            if result is not None:
                parsed_id, pred_triplet = result
                gpt_data[parsed_id] = pred_triplet
            else:
                print(f"Warning: Failed to parse output for custom_id: {custom_id}")
                gpt_data[custom_id] = {}
        except Exception as e:
            # model_output may be a dict (or missing), so stringify before slicing
            # instead of letting the handler itself raise.
            print(f"Error processing custom_id {custom_id}: {str(e)}")
            print(f"Problematic model_output: {str(ann.get('model_output'))[:100]}...")
            gpt_data[custom_id] = {}

    save_path = f'./singleSR/eval/{args.mode}/{args.n_retrieval}_{args.candidate_type}_{args.deployment_name}/{args.output_format}/{args.unit}/{args.candidate_usage}'
    os.makedirs(save_path, exist_ok=True)

    triplet_path = f'{save_path}/pred_triplet_path.json'
    with open(triplet_path, 'w', encoding="utf-8") as f:
        json.dump(gpt_data, f, ensure_ascii=False, indent=4)

def run_sr_eval(data_path, triplet_path, save_path):
    """Run ``SR_EVAL`` on a predicted-triplet file; used as a ``Process`` target by evaluate_funct."""
    eval_kwargs = dict(data_path=data_path, gpt_pred_path=triplet_path, save_path=save_path)
    SR_EVAL(**eval_kwargs)
    
def run_sro_eval(data_path, triplet_path, save_path, jaccard):
    """Run ``SRO_EVAL`` (forwarding ``jaccard``); used as a ``Process`` target by evaluate_funct."""
    eval_kwargs = dict(data_path=data_path, gpt_pred_path=triplet_path, save_path=save_path, jaccard=jaccard)
    SRO_EVAL(**eval_kwargs)
    
def run_gen_report_sr_eval(data_path, triplet_path, save_path):
    """Run ``Gen_report_SR_EVAL``; used as a ``Process`` target for the 'rexval' mode."""
    eval_kwargs = dict(data_path=data_path, gpt_pred_path=triplet_path, save_path=save_path)
    Gen_report_SR_EVAL(**eval_kwargs)
    
def run_gen_report_sro_eval(data_path, triplet_path, save_path, jaccard):
    """Run ``Gen_report_SRO_EVAL`` (forwarding ``jaccard``); ``Process`` target for 'rexval' mode."""
    eval_kwargs = dict(data_path=data_path, gpt_pred_path=triplet_path, save_path=save_path, jaccard=jaccard)
    Gen_report_SRO_EVAL(**eval_kwargs)

# Modes evaluated with SR_EVAL / SRO_EVAL.
_SR_SRO_EVAL_MODES = ('test_80studies', 'test_2studies', 'gold_eval', 'maira', 'maira_cascade',
                      'rexerr', 'medversa', 'rgrg', 'cvt2distilgpt2', 'lingshu', 'medgemma')
# Modes evaluated with the generated-report variants.
_GEN_REPORT_EVAL_MODES = ('rexval',)


def _collect_pred_triplets(annotations):
    """Parse each annotation's model output into {custom_id: pred_triplet}; failures map to {}."""
    gpt_data = {}
    for ann in annotations:
        custom_id = ann['custom_id']
        try:
            result = parse_eval_format(custom_id, ann['model_output'])
            if result is not None:
                parsed_id, pred_triplet = result
                gpt_data[parsed_id] = pred_triplet
            else:
                print(f"Warning: Failed to parse output for custom_id: {custom_id}")
                gpt_data[custom_id] = {}
        except Exception as e:
            # model_output may be a dict (or missing), so stringify before slicing
            # instead of letting the handler itself raise.
            print(f"Error processing custom_id {custom_id}: {str(e)}")
            print(f"Problematic model_output: {str(ann.get('model_output'))[:100]}...")
            gpt_data[custom_id] = {}
    return gpt_data


def _eval_save_path(args):
    """Build the evaluation output directory from the experiment arguments."""
    if args.dynamic_retrieval:
        retrieval_tag = 'dynamic'
    elif args.multi:
        retrieval_tag = f'M{args.n_retrieval}'
    else:
        retrieval_tag = f'{args.n_retrieval}'
    return (f'./singleSR/eval/{args.mode}/{retrieval_tag}_{args.candidate_type}_{args.deployment_name}'
            f'/{args.output_format}/{args.unit}/{args.candidate_usage}')


def _run_evaluators_in_parallel(jobs):
    """Run each ``(target, args)`` pair in its own process and wait for all of them.

    Raises:
        RuntimeError: if any process exits with a non-zero code. This keeps result
            aggregation from running over missing or stale evaluator outputs.
    """
    processes = [Process(target=target, args=job_args) for target, job_args in jobs]
    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()
    failed = [target.__name__ for (target, _), proc in zip(jobs, processes) if proc.exitcode != 0]
    if failed:
        raise RuntimeError(f"Evaluation subprocess(es) failed: {', '.join(failed)}")


def evaluate_funct(gold_file_path, exp_file_path, args):
    """Parse model outputs, save predicted triplets, and run the mode-specific evaluators.

    Args:
        gold_file_path: gold annotation file passed to the evaluators.
        exp_file_path: experiment JSON with an 'annotations' list of
            {'custom_id', 'model_output'} entries.
        args: namespace providing mode, dynamic_retrieval, multi, n_retrieval,
            candidate_type, deployment_name, output_format, unit, candidate_usage
            and (for evaluated modes) jaccard.

    Returns:
        str: the directory into which pred_triplet_path.json and results were written.

    Raises:
        RuntimeError: if an evaluator subprocess exits with a non-zero code.
    """
    # Close the handle deterministically; the JSON is written as UTF-8 by this pipeline.
    with open(exp_file_path, encoding="utf-8") as f:
        exp_file = json.load(f)

    gpt_data = _collect_pred_triplets(exp_file['annotations'])

    save_path = _eval_save_path(args)
    os.makedirs(save_path, exist_ok=True)

    triplet_path = f'{save_path}/pred_triplet_path.json'
    with open(triplet_path, 'w', encoding="utf-8") as f:
        json.dump(gpt_data, f, ensure_ascii=False, indent=4)

    # The two evaluators run in parallel; result aggregation then runs sequentially.
    if args.mode in _SR_SRO_EVAL_MODES:
        _run_evaluators_in_parallel([
            (run_sr_eval, (gold_file_path, triplet_path, save_path)),
            (run_sro_eval, (gold_file_path, triplet_path, save_path, args.jaccard)),
        ])
        cal_result_SR_EVAL(save_path)
        cal_result_SRO_EVAL(save_path)
    elif args.mode in _GEN_REPORT_EVAL_MODES:
        _run_evaluators_in_parallel([
            (run_gen_report_sr_eval, (gold_file_path, triplet_path, save_path)),
            (run_gen_report_sro_eval, (gold_file_path, triplet_path, save_path, args.jaccard)),
        ])
        cal_result_gen_report_SR_EVAL(save_path)
        cal_result_gen_report_SRO_EVAL(save_path)

    return save_path


def create_relation_dataframe(data_path=None, gpt_pred_path=None, save_path=None, relation_to_evaluate=None):
    """Build one row per predicted entity from ``subject_predict.json`` and save it as CSV.

    Rows are keyed by entity name within each report, so the first occurrence's
    sent_idx/ent_idx is kept. Each relation in RELATIONS becomes a column holding the
    predicted value; the last value wins when an entity repeats a relation.

    Args:
        data_path: gold JSON keyed by custom_id (reads 'study_id' and optional 'subject_id').
        gpt_pred_path: unused; kept for signature compatibility.
        save_path: directory containing ``subject_predict.json`` (one JSON object per
            line); ``pred_SR_df.csv`` is written there.
        relation_to_evaluate: unused; kept for signature compatibility.

    Returns:
        pd.DataFrame: the per-entity table that was written to CSV.
    """
    with open(data_path, 'r') as file:
        data = json.load(file)

    with open(f"{save_path}/subject_predict.json", 'r') as file:
        eval_results = [json.loads(line) for line in file if line.strip()]

    entity_rows = []
    for result in eval_results:
        custom_id = result['custom_id']
        report_type = result.get('report_type')
        sentences = result['sentences']
        study_id = data[custom_id]['study_id']
        subject_id = data[custom_id].get('subject_id')

        all_entities = {}
        for relation in RELATIONS:
            triplets = result['pred_triplet'].get(relation, [])
            right_entities = set(result['right_entities'].get(relation, []))

            for triplet in triplets:
                entity, rel, value = triplet[0], triplet[1], triplet[2]
                # Guard each optional position on its own; a 4-element triplet used
                # to raise IndexError on triplet[4].
                sent_idx = triplet[3] if len(triplet) > 3 else None
                ent_idx = triplet[4] if len(triplet) > 4 else None

                # First sighting of this entity: create its row with every relation empty.
                if entity not in all_entities:
                    entity_data = {
                        'custom_id': custom_id,
                        'report_type': report_type,
                        'data_from': result['data_from'],
                        'study_id': study_id,
                        'subject_id': subject_id,
                        'sentences': sentences,
                        'sent_idx': sent_idx,
                        'ent_idx': ent_idx,
                        'entity': entity,
                        'is_correct': entity in right_entities,
                    }
                    entity_data.update(dict.fromkeys(RELATIONS))
                    all_entities[entity] = entity_data

                # The relation becomes a column holding the predicted value.
                all_entities[entity][rel] = value

        entity_rows.extend(all_entities.values())

    df = pd.DataFrame(entity_rows)
    df.to_csv(f"{save_path}/pred_SR_df.csv", index=False)
    return df

def create_relation_dataframe2(data_path=None, gpt_pred_path=None, save_path=None, relation_to_evaluate=None):
    """Build one row per entity *instance* from ``triplets_predict.json`` and save it as CSV.

    Unlike ``create_relation_dataframe``, the same entity name at different
    (sent_idx, ent_idx) positions produces separate rows. Rows are emitted grouped
    by entity, in first-seen order. Multiple values for one relation are joined
    with ", ". Associate/Evidence values are suffixed with ", idx<obj_ent_idx>".

    Args:
        data_path: gold JSON keyed by custom_id (reads 'study_id' and optional 'subject_id').
        gpt_pred_path: unused; kept for signature compatibility.
        save_path: directory containing ``triplets_predict.json`` (one JSON object per
            line, with 'right_triplet_list' / 'wrong_triplet_list'); ``pred_SR_df.csv``
            is written there.
        relation_to_evaluate: unused; kept for signature compatibility.

    Returns:
        pd.DataFrame: the per-instance table that was written to CSV.
    """
    with open(data_path, 'r') as file:
        data = json.load(file)

    with open(f"{save_path}/triplets_predict.json", 'r') as file:
        eval_results = [json.loads(line) for line in file if line.strip()]

    entity_rows = []
    for result in eval_results:
        custom_id = result['custom_id']
        report_type = result.get('report_type')
        sentences = result['sentence']
        study_id = data[custom_id]['study_id']
        subject_id = data[custom_id].get('subject_id')

        right_triplets = result['right_triplet_list']
        wrong_triplets = result['wrong_triplet_list']

        # Group all triplets (correct and incorrect) by relation name (index 1).
        triplets_by_relation = defaultdict(list)
        for triplet in right_triplets + wrong_triplets:
            triplets_by_relation[triplet[1]].append(triplet)

        # Entity names of correctly predicted triplets, per relation.
        right_entities_by_relation = defaultdict(list)
        for triplet in right_triplets:
            right_entities_by_relation[triplet[1]].append(triplet[0])

        # entity -> {"(sent_idx, ent_idx)" -> row}; nested so rows come out grouped by entity.
        all_entities = {}
        for relation in RELATIONS:
            right_entities = set(right_entities_by_relation.get(relation, []))

            for triplet in triplets_by_relation.get(relation, []):
                entity, rel, value = triplet[0], triplet[1], triplet[2]
                # Guard each optional position on its own; short triplets used to
                # raise IndexError on triplet[4] / triplet[5].
                sent_idx = triplet[3] if len(triplet) > 3 else None
                ent_idx = triplet[4] if len(triplet) > 4 else None
                obj_ent_idx = triplet[5] if len(triplet) > 5 else None
                position_key = f'({sent_idx}, {ent_idx})'

                # First sighting of this (entity, position): create its row with every relation empty.
                positions = all_entities.setdefault(entity, {})
                if position_key not in positions:
                    row = {
                        'custom_id': custom_id,
                        'report_type': report_type,
                        'data_from': result['data_from'],
                        'study_id': study_id,
                        'subject_id': subject_id,
                        'sentences': sentences,
                        'sent_idx': sent_idx,
                        'ent_idx': ent_idx,
                        'entity': entity,
                        'is_correct': entity in right_entities,
                    }
                    row.update(dict.fromkeys(RELATIONS))
                    positions[position_key] = row
                row = positions[position_key]

                # Associate/Evidence values record which entity they point to.
                if rel in ['associate', 'evidence']:
                    new_part = f"{value}, idx{obj_ent_idx}"
                else:
                    new_part = value

                # Set the first value; append later ones comma-separated.
                current_value = row[rel]
                row[rel] = new_part if current_value is None else f"{current_value}, {new_part}"

        for positions in all_entities.values():
            entity_rows.extend(positions.values())

    df = pd.DataFrame(entity_rows)
    df.to_csv(f"{save_path}/pred_SR_df.csv", index=False)
    return df


def create_relation_dataframe3(data_path=None, gpt_pred_path=None, save_path=None, relation_to_evaluate=None):
    """Build one row per predicted entity instance directly from an experiment JSON.

    Reads ``annotations[*].model_output`` from ``gpt_pred_path``. Each output is
    expected to be a dict mapping relation name -> list of
    [entity, relation, value, sent_idx?, ent_idx?, obj_ent_idx?] triples. Rows are
    keyed by (entity, sent_idx, ent_idx). Multiple values for one relation are
    joined with ", ". Associate/Evidence values with a known obj_ent_idx get
    ", idx<obj_ent_idx>" appended.

    Args:
        data_path: gold JSON keyed by custom_id; reads study_id, subject_id,
            section (as report_type) and passage (as sentences).
        gpt_pred_path: experiment JSON with an 'annotations' list.
        save_path: output directory (created if missing) for ``pred_SR_df.csv``.
        relation_to_evaluate: unused; kept for signature compatibility.

    Returns:
        pd.DataFrame: the per-instance table that was written to CSV.
    """
    with open(data_path, 'r') as file:
        data = json.load(file)

    with open(gpt_pred_path, 'r') as file:
        pred_data = json.load(file)

    entity_rows = []
    for ann in pred_data.get('annotations', []):
        custom_id = ann.get('custom_id')
        if custom_id is None or custom_id not in data:
            continue

        # Gold metadata for this report.
        gold = data[custom_id]
        study_id = gold.get('study_id')
        subject_id = gold.get('subject_id')
        report_type = gold.get('section')
        sentences = gold.get('passage')

        model_output = ann.get('model_output', {})
        # Raw-text or null outputs used to crash on `.items()`; skip them visibly.
        if not isinstance(model_output, dict):
            print(f"Warning: skipping custom_id {custom_id}: model_output is "
                  f"{type(model_output).__name__}, expected a relation-keyed dict")
            continue

        # Aggregate per (entity, sent_idx, ent_idx) instance.
        rows_by_key = {}
        for rel_key, triples in model_output.items():
            rel_norm = rel_key.lower()
            if rel_norm not in RELATIONS or not isinstance(triples, (list, tuple)):
                continue
            for t in triples:
                if not isinstance(t, (list, tuple)) or len(t) < 3:
                    continue
                entity = t[0]
                value = t[2]
                sent_idx = t[3] if len(t) > 3 else None
                ent_idx = t[4] if len(t) > 4 else None
                obj_ent_idx = t[5] if len(t) > 5 else None

                key = (entity, sent_idx, ent_idx)
                if key not in rows_by_key:
                    row = {
                        'custom_id': custom_id,
                        'report_type': report_type,
                        'data_from': 'silver_eval',
                        'study_id': study_id,
                        'subject_id': subject_id,
                        'sentences': sentences,
                        'sent_idx': sent_idx,
                        'ent_idx': ent_idx,
                        'entity': entity,
                    }
                    row.update(dict.fromkeys(RELATIONS))
                    rows_by_key[key] = row

                # A relation may have several values, so accumulate them.
                if rel_norm in ['associate', 'evidence'] and obj_ent_idx is not None:
                    val_to_set = f"{value}, idx{obj_ent_idx}"
                else:
                    val_to_set = value

                current_val = rows_by_key[key][rel_norm]
                if current_val is None:
                    rows_by_key[key][rel_norm] = val_to_set
                else:
                    rows_by_key[key][rel_norm] = f"{current_val}, {val_to_set}"

        entity_rows.extend(rows_by_key.values())

    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(save_path, exist_ok=True)

    df = pd.DataFrame(entity_rows)
    df.to_csv(f"{save_path}/pred_SR_df.csv", index=False)
    return df


def visualize_table(data_path, output_path=None, metrics=None, highlight_best=True,
                   figsize=(14, 0.5), style='whitegrid'):
    """Render a results CSV as a table figure, optionally highlighting the best scores.

    Args:
        data_path: CSV with a 'Model' column, a 'Cand. Rate' column and the metric columns.
        output_path: where to save the PNG/PDF. The parent directory is created if
            needed. When falsy, the figure is shown instead.
        metrics: metric columns to display. Defaults to SR/SRO P/R/F1. Rows are
            sorted descending by the last metric.
        highlight_best: shade and bold the maximum value in each metric column.
        figsize: (width, height-per-row). The total height is at least 4 inches.
        style: seaborn style name, applied as matplotlib's 'seaborn-v0_8-<style>'
            when available, otherwise passed to matplotlib as given. It is applied
            only for this figure and no longer leaks into global rcParams.
    """
    df = pd.read_csv(data_path)
    if df.empty:
        print("No data to visualize")
        return

    if metrics is None:
        metrics = ['SR P', 'SR R', 'SR F1', 'SRO P', 'SRO R', 'SRO F1']

    def shorten_model_name(name):
        # Keep at most the first three '-'-separated parts of the model name.
        return '-'.join(name.split('-')[:3])

    df['Short Model'] = df['Model'].apply(shorten_model_name)

    # Column positions in the rendered table ('Short Model' is displayed as 'Model').
    display_cols = ['Model', 'Cand. Rate'] + metrics
    df_display = df[['Short Model', 'Cand. Rate'] + metrics].copy()

    for col in df_display.columns:
        if col not in ['Short Model', 'Cand. Rate'] and df_display[col].dtype in [np.float64, np.int64]:
            df_display[col] = df_display[col].round(1)

    # Sort by the last metric (usually the most important one), best first.
    df_display = df_display.sort_values(metrics[-1], ascending=False)

    style_name = f'seaborn-v0_8-{style}'
    if style_name not in plt.style.available:
        style_name = style

    # A style context restores rcParams afterwards, unlike plt.style.use.
    with plt.style.context(style_name):
        fig_height = max(figsize[1] * len(df_display), 4)
        fig = plt.figure(figsize=(figsize[0], fig_height))

        ax = plt.subplot(111, frame_on=False)
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

        formatted_data = df_display.copy()
        for col in metrics:
            formatted_data[col] = formatted_data[col].apply(lambda x: f"{x:.1f}")
        formatted_data = formatted_data.rename(columns={'Short Model': 'Model'})

        table = ax.table(
            cellText=formatted_data.values,
            colLabels=formatted_data.columns,
            loc='center',
            cellLoc='center',
            colColours=['#f0f0f0'] * len(formatted_data.columns)
        )
        table.auto_set_font_size(False)
        table.set_fontsize(11)
        table.scale(1.1, 1.6)

        # Black borders on every cell give the grid look.
        for cell in table.get_celld().values():
            cell.set_linewidth(0.8)
            cell.set_edgecolor('black')

        if highlight_best:
            cmap = LinearSegmentedColormap.from_list('neurips_highlight', ['#ffffff', '#e1effe'])
            for col in metrics:
                # idxmax is undefined for an all-NaN column; nothing to highlight there.
                if not df_display[col].notna().any():
                    continue
                col_idx = display_cols.index(col)
                best_idx = df_display[col].idxmax()
                # Positional row within the sorted frame; +1 skips the header row.
                row_idx = df_display.index.get_indexer([best_idx])[0]
                cell = table[(row_idx + 1, col_idx)]
                cell.set_facecolor(cmap(0.8))
                cell.get_text().set_fontweight('bold')

        plt.title('Performance Comparison', fontsize=16, fontweight='bold', pad=15)
        plt.tight_layout()

        if output_path:
            # dirname is '' for a bare filename, and os.makedirs('') raises.
            out_dir = os.path.dirname(output_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            plt.savefig(output_path, dpi=300, bbox_inches='tight')
            print(f"Table visualization saved to {output_path}")
        else:
            plt.show()

    plt.close(fig)

def calculate_overall_metrics(results):
    """
    Micro-average TP/FP/FN across relations and derive precision, recall and F1.

    The aggregate 'all' entry of ``results`` is ignored. False negatives come from
    'miss' when present (SRO-style results), otherwise from 'all' - 'tp'
    (SR-style results). Entries with neither key contribute no false negatives.

    Args:
        results (dict): relation name -> counts dict, as loaded from SR_result.json
            or SRO_result.json.

    Returns:
        dict: keys 'precision', 'recall', 'f1', 'tp', 'fp', 'fn'. A ratio is 0
            when its denominator is zero.
    """
    tp_sum = fp_sum = fn_sum = 0

    for relation_name, counts in results.items():
        if relation_name == 'all':
            continue
        tp_sum += counts.get('tp', 0)
        fp_sum += counts.get('fp', 0)
        if 'miss' in counts:
            fn_sum += counts['miss']
        elif 'all' in counts:
            fn_sum += counts.get('all', 0) - counts.get('tp', 0)

    def safe_ratio(numerator, denominator):
        return numerator / denominator if denominator > 0 else 0

    precision = safe_ratio(tp_sum, tp_sum + fp_sum)
    recall = safe_ratio(tp_sum, tp_sum + fn_sum)
    f1 = safe_ratio(2 * precision * recall, precision + recall)

    return dict(precision=precision, recall=recall, f1=f1, tp=tp_sum, fp=fp_sum, fn=fn_sum)

def generate_table_from_rexval(results, output_path=None, relations=None):
    """
    Generate a per-model comparison table from rexval mode results.

    SRO (triplet) scores come from ``results['triplets']`` and SR (subject)
    scores from ``results['subj']``. Each maps a model name to a dict of
    per-relation counts. Only the relations listed in ``relations`` are summed
    (micro-average). SRO recall is ``tp / (tp + miss)`` and SR recall is
    ``tp / all``. Scores are percentages rounded to one decimal. Rows are
    sorted by 'SRO F1', descending, when that column exists.

    Args:
        results (dict): Evaluation results by model, with optional
            'triplets' and 'subj' sub-dicts.
        output_path (str, optional): Path to save the table as CSV. A bare
            filename (no directory part) is allowed. If None, the DataFrame
            is returned instead.
        relations (iterable of str, optional): Relation keys to aggregate.
            Defaults to the module-level ``RELATIONS`` list.

    Returns:
        pd.DataFrame or None: Comparison table if output_path is None, otherwise None
    """
    if relations is None:
        relations = RELATIONS
    relation_set = set(relations)

    def _total(model_results, field):
        # Sum one count field over the selected relations; other keys (e.g. 'all') are ignored
        return sum(counts.get(field, 0) for rel, counts in model_results.items() if rel in relation_set)

    def _pct_scores(tp, fp, fn):
        # Micro-averaged precision / recall / F1 as percentages rounded to one decimal
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
        return round(precision * 100, 1), round(recall * 100, 1), round(f1 * 100, 1)

    print(f"Results keys: {list(results.keys())}")
    if 'triplets' in results:
        print(f"Triplets models: {list(results['triplets'].keys())}")
    if 'subj' in results:
        print(f"Subject models: {list(results['subj'].keys())}")

    # One row per model name, in first-seen order (triplet models first, then
    # subject-only models). A dict gives O(1) lookup when merging SR into SRO rows.
    rows = {}

    # Process triplets results (SRO)
    for model_name, model_results in results.get('triplets', {}).items():
        sro_p, sro_r, sro_f1 = _pct_scores(
            _total(model_results, 'tp'),
            _total(model_results, 'fp'),
            _total(model_results, 'miss'),
        )
        rows[model_name] = {'Model': model_name, 'SRO P': sro_p, 'SRO R': sro_r, 'SRO F1': sro_f1}
    print(f"Generated {len(rows)} table entries")

    # Process subject results (SR); these merge into the SRO row of the same model
    for model_name, model_results in results.get('subj', {}).items():
        entry = rows.setdefault(model_name, {'Model': model_name})
        total_tp = _total(model_results, 'tp')
        total_all = _total(model_results, 'all')
        # fn = all - tp, so recall == tp / all (0 when all == 0)
        entry['SR P'], entry['SR R'], entry['SR F1'] = _pct_scores(
            total_tp, _total(model_results, 'fp'), total_all - total_tp
        )

    df = pd.DataFrame(list(rows.values()))

    # Sort by SRO F1 score (descending); rows without SRO scores (NaN) go last
    if not df.empty and 'SRO F1' in df.columns:
        df = df.sort_values('SRO F1', ascending=False)

    if df.empty:
        print("Warning: No data was generated for the table")
    else:
        print(f"DataFrame columns: {df.columns.tolist()}")
        print(f"DataFrame shape: {df.shape}")

    if output_path:
        out_dir = os.path.dirname(output_path)
        if out_dir:  # dirname is '' for a bare filename, and os.makedirs('') raises
            os.makedirs(out_dir, exist_ok=True)
        df.to_csv(output_path, index=False, float_format='%.1f')
        print(f"Results saved to {output_path}")
        return None

    return df

def _generate_table_rexval_results(base_path):
    """Load rexval-mode 'by model' SRO/SR result files for every model directory under base_path.

    Returns:
        dict: {'triplets': {...}, 'subj': {...}}, keyed by '<model dir>_<metric model>'.
    """
    print(f"Looking for rexval results in: {base_path}")

    # Sorted so that the table order does not depend on filesystem listing order
    model_dirs = sorted(d for d in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, d)))
    print(f"Found model directories: {model_dirs}")

    results = {'triplets': {}, 'subj': {}}
    # (file name, destination bucket, label used in the log message)
    sources = (
        ('SRO_result_by_model.json', 'triplets', 'SRO'),
        ('SR_result_by_model.json', 'subj', 'SR'),
    )

    for model in model_dirs:
        model_path = os.path.join(base_path, model, 'SROSRO', 'section', '1')
        if not os.path.exists(model_path):
            print(f"No results found for model: {model}")
            continue

        for file_name, bucket, label in sources:
            result_path = f'{model_path}/{file_name}'
            if not os.path.exists(result_path):
                continue
            print(f"Loading {label} results from: {result_path}")
            with open(result_path, 'r', encoding='utf-8') as f:
                by_metric_model = json.load(f)
            for metric_model, metric_results in by_metric_model.items():
                results[bucket][f"{model}_{metric_model}"] = metric_results

    print(f"Loaded triplets models: {list(results['triplets'].keys())}")
    print(f"Loaded subject models: {list(results['subj'].keys())}")
    return results


def _generate_table_standard_rows(base_path, unit):
    """Build one SR/SRO summary row per (model, candidate rate) found under base_path/<model>/SROSRO/<unit>/."""
    rows = []
    model_dirs = [d for d in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, d))]

    for model in model_dirs:
        unit_dir = os.path.join(base_path, model, 'SROSRO', unit)
        if not os.path.isdir(unit_dir):
            continue

        # Candidate-rate sub-directories are named by a number such as '1' or '0.5'.
        # Listing (rather than globbing) keeps names containing glob metacharacters safe.
        cand_rates = [
            name for name in os.listdir(unit_dir)
            if not name.startswith('.')
            and os.path.isdir(os.path.join(unit_dir, name))
            and name.replace('.', '', 1).isdigit()
        ]

        for cand_rate in cand_rates:
            rate_dir = os.path.join(unit_dir, cand_rate)
            sr_path = os.path.join(rate_dir, 'SR_result.json')
            sro_path = os.path.join(rate_dir, 'SRO_result.json')

            # Both SR and SRO results are needed for a row
            if not (os.path.exists(sr_path) and os.path.exists(sro_path)):
                continue

            try:
                with open(sr_path, 'r', encoding='utf-8') as f:
                    sr_results = json.load(f)
                with open(sro_path, 'r', encoding='utf-8') as f:
                    sro_results = json.load(f)
            except json.JSONDecodeError:
                print(f"Error loading results for {model} with candidate rate {cand_rate}")
                continue

            sr_metrics = calculate_overall_metrics(sr_results)    # subject recognition
            sro_metrics = calculate_overall_metrics(sro_results)  # subject-relation-object

            # Percentages rounded to one decimal place
            rows.append({
                'Model': model,
                'Cand. Rate': float(cand_rate),
                'SR P': round(sr_metrics['precision'] * 100, 1),
                'SR R': round(sr_metrics['recall'] * 100, 1),
                'SR F1': round(sr_metrics['f1'] * 100, 1),
                'SRO P': round(sro_metrics['precision'] * 100, 1),
                'SRO R': round(sro_metrics['recall'] * 100, 1),
                'SRO F1': round(sro_metrics['f1'] * 100, 1),
            })

    return rows


def generate_table(base_path='./singleSR/eval/test_2studies', args=None, output_path=None, mode=None):
    """
    Generate a comparison table of evaluation results across models and candidate rates.

    In 'rexval' mode, per-model files
    ``<base_path>/<model>/SROSRO/section/1/{SRO,SR}_result_by_model.json`` are
    loaded and delegated to ``generate_table_from_rexval``. Otherwise, every
    numeric candidate-rate directory under ``<base_path>/<model>/SROSRO/<args.unit>/``
    with both SR_result.json and SRO_result.json yields one row. Rows are
    sorted by model and candidate rate.

    Args:
        base_path (str): Base directory containing evaluation results
        args: Namespace-like object; ``args.unit`` is required unless mode == 'rexval'
        output_path (str, optional): Path to save the table as CSV (a bare
            filename is allowed). If None, the DataFrame is returned.
        mode (str, optional): Evaluation mode ('rexval' or None)

    Returns:
        pd.DataFrame or None: Comparison table if output_path is None, otherwise None

    Raises:
        ValueError: if mode is not 'rexval' and ``args.unit`` is unavailable.
    """
    if mode == 'rexval':
        results = _generate_table_rexval_results(base_path)
        return generate_table_from_rexval(results, output_path)

    # Standard mode needs the evaluation unit (sub-directory name under SROSRO)
    unit = getattr(args, 'unit', None)
    if unit is None:
        raise ValueError("generate_table: `args.unit` is required when mode is not 'rexval'")

    df = pd.DataFrame(_generate_table_standard_rows(base_path, unit))

    if not df.empty:
        df = df.sort_values(['Model', 'Cand. Rate'])

    if output_path:
        out_dir = os.path.dirname(output_path)
        if out_dir:  # dirname is '' for a bare filename, and os.makedirs('') raises
            os.makedirs(out_dir, exist_ok=True)
        df.to_csv(output_path, index=False, float_format='%.1f')
        print(f"Results saved to {output_path}")
        return None

    return df

def _concat_batch_sort_key(folder_path):
    """Sort key ordering batch folders by leading integer index ('2_547' before '10_547'), then by name."""
    name = os.path.basename(folder_path)
    prefix = name.split('_', 1)[0]
    if prefix.isdigit():
        return (0, int(prefix), name)
    return (1, 0, name)


def _concat_merge_result_dicts(dict_list):
    """Merge per-batch result dicts, summing numeric values one level deep.

    For each top-level key the first batch's value is kept. When both the
    accumulated value and a later value are dicts, numeric sub-values present
    in both are added, and sub-keys seen only in the later batch are copied.
    Anything else keeps the first value seen.
    NOTE(review): if these dicts also carry derived ratios (precision/recall/f1),
    they are summed rather than recomputed; confirm the result file schema.
    """
    combined = {}
    for data in dict_list:
        for key, value in data.items():
            if key not in combined:
                # Shallow copy so the in-place summing below does not mutate the input batch dict
                combined[key] = dict(value) if isinstance(value, dict) else value
                continue
            current = combined[key]
            if not (isinstance(value, dict) and isinstance(current, dict)):
                continue
            for subkey, subvalue in value.items():
                if subkey not in current:
                    current[subkey] = subvalue
                elif isinstance(subvalue, (int, float)) and isinstance(current[subkey], (int, float)):
                    current[subkey] += subvalue
    return combined


def concat_batch_results(base_path, model=''):
    """
    Merge same-named JSON, JSON Lines and CSV files from every batch folder of a model.

    Batch folders are the sub-directories of ``base_path/model`` whose name
    contains '_' (e.g. '0_547', '1_547'). They are processed in ascending order
    of their leading integer index. Merged files are written to
    ``base_path/model`` under the original file name:

    - 'triplets_predict.json', 'subject_predict.json': read as JSON Lines and
      written back one object per line.
    - 'all_subject.json', 'all.json': JSON arrays, concatenated into one array.
    - 'SR_result.json', 'SRO_result.json' and their '_no_sent_idx' variants:
      dicts merged by summing numeric values one level deep.
    - 'pred_triplet_path.json': the per-batch objects collected into a list.
    - '*.csv': non-empty files concatenated row-wise with a fresh index.

    Other JSON files are skipped with a message. Read and parse errors are
    printed, and the offending file or line is skipped.

    Args:
        base_path (str): Base directory.
        model (str): Model sub-directory name under ``base_path``.

    Returns:
        dict: Maps each merged file name to the path it was written to. Empty
        if the model path or its batch folders are missing.
    """
    jsonl_names = ('triplets_predict.json', 'subject_predict.json')
    array_names = ('all_subject.json', 'all.json')
    result_dict_names = ('SR_result.json', 'SRO_result.json',
                         'SR_result_no_sent_idx.json', 'SRO_result_no_sent_idx.json')
    plain_json_names = result_dict_names + ('pred_triplet_path.json',)

    model_path = os.path.join(base_path, model)
    if not os.path.exists(model_path):
        print(f"경로를 찾을 수 없습니다: {model_path}")
        return {}

    # Batch folders: sub-directories with '_' in the name (e.g. 0_547, 1_547),
    # sorted numerically so the merged output order is deterministic
    batch_folders = sorted(
        (os.path.join(model_path, item) for item in os.listdir(model_path)
         if '_' in item and os.path.isdir(os.path.join(model_path, item))),
        key=_concat_batch_sort_key,
    )
    if not batch_folders:
        print(f"배치 폴더를 찾을 수 없습니다: {model_path}")
        return {}

    print(f"발견된 배치 폴더: {len(batch_folders)}")
    print(f"배치 폴더 목록: {batch_folders}")

    json_data = defaultdict(list)        # regular JSON files (one object per batch)
    csv_data = defaultdict(list)         # CSV files (one DataFrame per batch)
    jsonl_data = defaultdict(list)       # JSON Lines files (one object per line)
    array_json_data = defaultdict(list)  # JSON array files (flattened items)

    for batch_folder in batch_folders:
        print(f"처리 중: {batch_folder}")

        for json_file in glob.glob(os.path.join(batch_folder, "*.json")):
            file_name = os.path.basename(json_file)

            if file_name in jsonl_names:
                # These files are JSON Lines despite the .json extension
                try:
                    with open(json_file, 'r', encoding='utf-8') as f:
                        for line in f:
                            line = line.strip()
                            if not line:  # skip blank lines
                                continue
                            try:
                                jsonl_data[file_name].append(json.loads(line))
                            except json.JSONDecodeError as e:
                                print(f"JSONL 라인 파싱 오류: {json_file}, 라인: {line[:50]}..., 오류: {e}")
                except Exception as e:
                    print(f"JSONL 파일 처리 중 오류 발생: {json_file}, 오류: {e}")

            elif file_name in array_names:
                try:
                    with open(json_file, 'r', encoding='utf-8') as f:
                        content = f.read().strip()
                    # Only content wrapped in [...] is parsed; anything else is skipped silently
                    if content.startswith('[') and content.endswith(']'):
                        try:
                            data = json.loads(content)
                            if isinstance(data, list):
                                array_json_data[file_name].extend(data)
                            else:
                                print(f"배열 형태가 아닌 JSON 파일: {json_file}")
                        except json.JSONDecodeError as e:
                            print(f"배열 JSON 파일 파싱 오류: {json_file}, 오류: {e}")
                except Exception as e:
                    print(f"배열 JSON 파일 처리 중 오류 발생: {json_file}, 오류: {e}")

            elif file_name in plain_json_names:
                try:
                    with open(json_file, 'r', encoding='utf-8') as f:
                        try:
                            json_data[file_name].append(json.load(f))
                        except json.JSONDecodeError:
                            print(f"JSON 파일 형식 오류: {json_file}, 빈 파일이거나 잘못된 형식입니다.")
                except Exception as e:
                    print(f"JSON 파일 처리 중 오류 발생: {json_file}, 오류: {e}")

            else:
                print(f"지원하지 않는 JSON 파일 형식: {file_name}, 건너뜁니다.")

        for csv_file in glob.glob(os.path.join(batch_folder, "*.csv")):
            try:
                # Zero-byte files would make read_csv raise EmptyDataError
                if os.path.getsize(csv_file) == 0:
                    print(f"빈 CSV 파일: {csv_file}, 건너뜁니다.")
                    continue
                df = pd.read_csv(csv_file)
                if df.empty:
                    print(f"빈 DataFrame: {csv_file}, 건너뜁니다.")
                else:
                    csv_data[os.path.basename(csv_file)].append(df)
            except Exception as e:
                print(f"CSV 파일 처리 중 오류 발생: {csv_file}, 오류: {e}")

    result_files = {}

    # Regular JSON: result dicts are summed; other files become a list of per-batch objects
    for file_name, data_list in json_data.items():
        output_file = os.path.join(model_path, file_name)
        if file_name in result_dict_names:
            merged = _concat_merge_result_dicts(data_list)
        else:
            merged = data_list
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(merged, f, ensure_ascii=False, indent=2)
        result_files[file_name] = output_file
        print(f"저장됨: {output_file}")

    # JSON arrays: written as one concatenated array
    for file_name, data_list in array_json_data.items():
        output_file = os.path.join(model_path, file_name)
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(data_list, f, ensure_ascii=False, indent=2)
        result_files[file_name] = output_file
        print(f"저장됨: {output_file}")

    # JSON Lines: one object per line
    for file_name, data_list in jsonl_data.items():
        output_file = os.path.join(model_path, file_name)
        with open(output_file, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(data, ensure_ascii=False) + '\n' for data in data_list)
        result_files[file_name] = output_file
        print(f"저장됨: {output_file}")

    # CSV: row-wise concatenation with a fresh index
    for file_name, df_list in csv_data.items():
        output_file = os.path.join(model_path, file_name)
        try:
            combined_df = pd.concat(df_list, ignore_index=True)
            if combined_df.empty:
                print(f"합쳐진 DataFrame이 비어 있음: {file_name}, 저장하지 않습니다.")
            else:
                combined_df.to_csv(output_file, index=False)
                result_files[file_name] = output_file
                print(f"저장됨: {output_file}")
        except Exception as e:
            print(f"CSV 파일 합치기 중 오류 발생: {file_name}, 오류: {e}")

    return result_files



def visualize_rexval_metrics(data_path, output_path=None, metrics=None,
                             figsize=(16, 12), style='whitegrid'):
    """
    Visualize metrics for different models in rexval mode using line plots.

    One line is drawn per metric, with one x position per short model name.
    The short name is the part of 'Model' after the last '_'. Each non-NaN
    point is annotated with its value. Annotations are spread out with
    ``adjust_text`` when it is importable.

    Args:
        data_path (str): Path to the CSV file with comparison data. It needs a
            'Model' column plus the metric columns.
        output_path (str, optional): Path to save the visualization (a bare
            filename is allowed). If None, the figure is shown with plt.show().
        metrics (list, optional): Metric columns to plot. Defaults to
            ['SR F1', 'SRO F1']. Columns missing from the CSV are skipped
            with a warning.
        figsize (tuple): Figure size (width, height)
        style (str): Seaborn style for the plot
    """
    if metrics is None:
        metrics = ['SR F1', 'SRO F1']

    df = pd.read_csv(data_path)
    if df.empty:
        print("No data to visualize")
        return

    # Short model name, e.g. 'gpt-4.1_radgraph' -> 'radgraph'.
    # NOTE(review): distinct full names that share a suffix collapse onto one x
    # position, and only the first such row's values are plotted.
    df['Short Model'] = df['Model'].apply(lambda x: x.split('_')[-1] if '_' in x else x)

    sns.set_style(style)

    fig = plt.figure(figsize=figsize, dpi=100)
    ax = fig.add_subplot(111)

    # Marker, line style and color are chosen by the metric's position in `metrics`
    markers = ['o', 's', '^', 'D', 'v']
    line_styles = ['-', '--', '-.', ':', '-']
    colors = plt.cm.tab10(np.linspace(0, 1, len(metrics)))

    models = df['Short Model'].unique()
    x = np.arange(len(models))
    # First row per short model, in the same order as `models`
    first_rows = df.drop_duplicates('Short Model').set_index('Short Model')

    texts = []
    legend_handles = []
    plotted_metrics = []

    for j, metric in enumerate(metrics):
        if metric not in df.columns:
            print(f"Warning: Metric '{metric}' not found in the data.")
            continue
        plotted_metrics.append(metric)

        y_values = first_rows.loc[models, metric].tolist()

        line, = ax.plot(x, y_values,
                        marker=markers[j % len(markers)],
                        linestyle=line_styles[j % len(line_styles)],
                        linewidth=3,
                        markersize=16,
                        alpha=0.8,
                        color=colors[j],
                        label=metric)
        legend_handles.append(line)

        # Value labels at each point; NaN (e.g. a model without SRO scores) gets no label
        for xi, yi in zip(x, y_values):
            if pd.isna(yi):
                continue
            t = ax.text(xi, yi, f'{yi:.1f}',
                        ha='center',
                        va='bottom',
                        fontsize=24,
                        fontweight='bold',
                        color=colors[j])
            texts.append(t)

    # Spread overlapping labels when adjustText is available
    if texts:
        try:
            from adjustText import adjust_text
            adjust_text(texts,
                        arrowprops=dict(arrowstyle='-', color='gray', lw=0.8),
                        expand_points=(1.7, 1.7),
                        force_points=(0.6, 0.6))
        except ImportError:
            print("Warning: adjustText package not found. Text labels may overlap.")

    ax.set_title('Model Performance Comparison', fontsize=18, fontweight='bold', pad=20)
    ax.set_xlabel('Model', fontsize=16, fontweight='bold', labelpad=15)
    ax.set_ylabel('Score (%)', fontsize=16, fontweight='bold', labelpad=15)

    ax.set_xticks(x)
    ax.set_xticklabels(models, fontsize=14, fontweight='bold', rotation=45, ha='right')

    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda v, _: f'{v:.1f}%'))
    ax.tick_params(axis='y', which='major', labelsize=14)

    # y-axis starts at 0 and reaches at least 100. Only metrics that exist are
    # inspected, and an all-NaN selection falls back to 100.
    max_value = 100
    if plotted_metrics:
        col_max = df[plotted_metrics].max().max()
        if pd.notna(col_max):
            max_value = max(col_max * 1.1, 100)
    ax.set_ylim(0, max_value)

    ax.tick_params(width=2, length=8)
    ax.grid(True, linestyle='--', alpha=0.7, color='gray', linewidth=1.5)

    if legend_handles:
        ax.legend(handles=legend_handles, fontsize=14, loc='upper right')

    plt.tight_layout()

    if output_path:
        out_dir = os.path.dirname(output_path)
        if out_dir:  # dirname is '' for a bare filename, and os.makedirs('') raises
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Visualization saved to {output_path}")
    else:
        plt.show()

    plt.close()
        
def visualize_metrics(data_path, output_path=None, metrics=('SR F1', 'SRO F1'),
                      figsize=(16, 12), style='whitegrid', y_min=40, y_max=100):
    """
    Plot each metric as a line chart of score vs. shot type, one line per model/candidate.

    Each value of the ``Model`` column is split on ``'_'``. The first token is the
    shot type (x-axis) and the last token is the short model name. When the name
    has a middle token, the second token is the candidate setting; otherwise the
    candidate is ``'no'``. Rows that share the same (short model, candidate) pair
    are connected into one line, ordered by shot type.

    Args:
        data_path (str): Path to the CSV file with comparison data. It must contain
            a ``Model`` column plus one column per requested metric.
        output_path (str, optional): Directory in which to save one PNG per metric,
            named ``<metric with spaces replaced by underscores>.png``. When falsy,
            each figure is shown interactively instead.
        metrics (Sequence[str]): Metric columns to visualize (e.g. ``('SR F1', 'SRO F1')``).
            Metrics missing from the CSV are skipped with a warning.
        figsize (tuple): Figure size (width, height).
        style (str): Seaborn style, applied globally via ``sns.set_style``.
        y_min (int): Minimum value for the y-axis.
        y_max (int): Maximum value for the y-axis.

    Returns:
        None. Figures are saved or shown as a side effect.
    """
    df = pd.read_csv(data_path)

    if df.empty:
        print("No data to visualize")
        return

    sns.set_style(style)

    # Parse the model names once; the result does not depend on the metric.
    # astype(str) keeps a missing Model value from crashing the split.
    name_parts = df['Model'].astype(str).str.split('_')
    data = df.assign(**{
        'Shot Type': name_parts.str[0],
        'Short Model': name_parts.str[-1],
        # The candidate token sits between the shot type and the model, so it
        # only exists for names with three or more tokens; a two-token name
        # ("<shot>_<model>") gets candidate 'no'.
        'Candidate': name_parts.apply(lambda parts: parts[1] if len(parts) > 2 else 'no'),
    })

    # Give every shot type a fixed x position in sorted order. All lines then
    # share one consistent x-axis, even when they cover different shot types.
    shot_order = sorted(data['Shot Type'].unique())
    shot_pos = {shot: pos for pos, shot in enumerate(shot_order)}
    data['_x_pos'] = data['Shot Type'].map(shot_pos)

    for metric in metrics:
        if metric not in df.columns:
            print(f"Warning: Metric '{metric}' not found in the data.")
            continue

        plt.figure(figsize=figsize, dpi=100)

        # sort=False keeps the combinations in order of first appearance.
        for (model, candidate), group in data.groupby(['Short Model', 'Candidate'], sort=False):
            # Order points along the x-axis so the connecting line is monotone in x.
            group = group.sort_values('_x_pos', kind='mergesort')

            plt.plot(group['_x_pos'], group[metric],
                     marker='o', linestyle='-', linewidth=2, markersize=8,
                     label=f"{model}-{candidate}")

            # Annotate every point with its value.
            for x_val, y_val in zip(group['_x_pos'], group[metric]):
                plt.text(x_val, y_val, f'{y_val:.1f}',
                         ha='center', va='bottom', fontsize=12)

        # Label the numeric x positions with their shot-type names.
        plt.xticks(range(len(shot_order)), shot_order)

        plt.title(f'{metric} Performance by Shot Type', fontsize=20, pad=20)
        plt.xlabel('Shot Type', fontsize=16, labelpad=10)
        plt.ylabel('Score (%)', fontsize=16, labelpad=10)

        # Values are assumed to already be on a 0-100 scale; only a '%' suffix is added.
        plt.gca().yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x:.1f}%'))
        plt.ylim(y_min, y_max)
        plt.grid(True, linestyle='--', alpha=0.7)

        # Put the legend outside the axes so it does not hide any lines.
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=12)
        plt.tight_layout()

        if output_path:
            metric_output_path = os.path.join(output_path, f'{metric.replace(" ", "_")}.png')
            os.makedirs(os.path.dirname(metric_output_path), exist_ok=True)
            plt.savefig(metric_output_path, dpi=300, bbox_inches='tight')
            print(f"{metric} visualization saved to {metric_output_path}")
        else:
            plt.show()

        # Release the figure before starting the next metric.
        plt.close()

if __name__ == "__main__":
    # Script entry point: parse CLI options, load the gold annotations restricted
    # to a fixed list of dev-set study IDs, and normalise a few columns. Then call
    # `retreive_query_related_fewshot` for `--query` and print the user/assistant
    # pairs it returns.
    import argparse

    parser = argparse.ArgumentParser(description='Process SR evaluation')
    parser.add_argument('--exp_name', type=str, default='test', help='Experiment name')
    parser.add_argument('--deployment_name', type=str, default='gpt-4o-mini', help='Model deployment name')
    parser.add_argument('--mode', type=str, default='section', help='Mode (sent, section, etc.)')
    parser.add_argument('--n_retrieval', type=int, default=5, help='Number of retrievals')
    parser.add_argument('--entity_types', type=str, nargs='+', default=['COF', 'NCD', 'Patient Info.', 'PF', 'CF', 'OTH'], 
                        help='Entity types to include')
    parser.add_argument('--relation_types', type=str, nargs='+', default=['Location', 'Associate', 'Evidence'], 
                        help='Relation types to include')
    parser.add_argument('--attribute_types', type=str, nargs='+', 
                        default=['Morphology', 'Distribution', 'Measurement', 'Severity',
                                'Comparison', 'Onset', 'No Change', 'Improved', 'Worsened',
                                'Placement', 'Past Hx', 'Other Source', 'Assessment Limitations'], 
                        help='Attribute types to include')
    parser.add_argument('--output_format', type=str, default='SRORO', help='Output format')
    # `store_true` with default=True can never yield False, so the original flag is
    # kept for compatibility and paired with an explicit `--no_*` opt-out that
    # writes to the same destination.
    parser.add_argument('--diverse_retrieval', action='store_true', default=True, help='Use diverse retrieval')
    parser.add_argument('--no_diverse_retrieval', dest='diverse_retrieval', action='store_false',
                        help='Disable diverse retrieval')
    parser.add_argument('--query', type=str, 
                        default="postoperative changes of left upper lobectomy are again seen with resection cavity completely opacified, without visualized pneumothorax.",
                        help='Query text')
    parser.add_argument('--gold_path', type=str, default='../gold_v5_add_section_report.csv',
                        help='Path to gold dataset')
    parser.add_argument('--vocab_path', type=str, default='../gold_revise/0.original_vocab_v2.xlsx',
                        help='Path to vocab file')
    parser.add_argument('--fast_retrieval', action='store_true', default=True,
                        help='Use fast retrieval')
    parser.add_argument('--no_fast_retrieval', dest='fast_retrieval', action='store_false',
                        help='Disable fast retrieval')
    parser.add_argument('--candidate_usage', type=float, default=1,
                        help='Candidate usage ratio')

    args = parser.parse_args()
    # Hard-coded study IDs that make up the dev set (stored on `args` so
    # downstream helpers receiving `args` can see it).
    args.dev_list = ['s51749906', 's50853840', 's51264956', 's52145612', 's54224807', 's59798967', 's52901628', 's51907814', 's51357526', 's50035498', 's50336741', 's50546279', 's58786693', 's52707748', 's53423060', 's59608718', 's57678258', 's55874928', 's52939782', 's54523680', 's59956784', 's50770541', 's56010471', 's50476602', 's55853389', 's57849643', 's52998783', 's58068113', 's58566283', 's50022945', 's56193921', 's53086061', 's50273882', 's55615214', 's55683961', 's50710771', 's59081164', 's57917788', 's56321140', 's50620677', 's57674897', 's57988469', 's53138800', 's51765454', 's53759718', 's50100756', 's58352175', 's57525852', 's51727838', 's51144460', 's54151331', 's55755138', 's56130174', 's50968695', 's58961408', 's55883502', 's58093109', 's54093116', 's58517699', 's56581797', 's57778607', 's57523636', 's53462360', 's57529728', 's54589789', 's58286219', 's58478940', 's50281752', 's50399800', 's52381727', 's50971332', 's59716296', 's50243114', 's51285349', 's54704786', 's57554917', 's53051689', 's59221699', 's59044985', 's56470564', 's53957798', 's57746739', 's59523573', 's53663749', 's54224166', 's55553875', 's57399078', 's59842808', 's52177069', 's51464763', 's54232769', 's57782283', 's55599778', 's51678067', 's50924449', 's56196471', 's58349137', 's52428827', 's56081926', 's58137643', 's51044625', 's51231499', 's55604705', 's54432661', 's53572321', 's58898395', 's57120453', 's52072042', 's57420525', 's50255843', 's55233589', 's51271572', 's50989704', 's59121133', 's53641457', 's51791247', 's57779343', 's58393560', 's50178679', 's59986698', 's55698800', 's58085167', 's51683155', 's54381763', 's55594849', 's57664750', 's52195893', 's53897449', 's51742525', 's59775769', 's55414814', 's52775752', 's51288835', 's52189004', 's58369249', 's51900597', 's54335521', 's54917064', 's54133231', 's51199892', 's59249979', 's58352022', 's52835225', 's53177649', 's54066754', 's54899257', 's53418217', 's59357257', 's50421764', 's51223853', 's56018459', 's57801123', 
's59221051', 's50829485', 's55502536', 's57983519', 's53570653', 's51325572', 's58760787', 's59715122', 's58641137', 's54097861', 's58656783', 's51423353', 's54962274', 's58072789', 's53460154', 's50301215', 's52555178', 's57211901', 's56034024', 's54849848', 's55957472', 's54833205', 's54100996', 's53854854', 's59697640', 's57232140', 's52008677', 's55124994', 's56779415', 's55939586', 's56843282', 's56508966', 's59197220', 's52682048', 's54259878', 's51140249', 's58304701', 's58907220', 's53631792', 's55751115', 's53225676', 's53619001', 's56790426', 's58778783', 's56771404', 's59060938', 's53130454', 's51363438', 's56440919', 's59037095', 's55562335', 's50701107', 's54590636', 's54772630', 's59138498', 's58060878', 's55803590', 's55714183', 's57952807', 's58625748', 's57192814', 's51466579', 's58979101', 's50247294', 's54729238', 's55617591', 's57410883', 's58836797', 's56168637', 's58195876', 's50844004', 's50555779', 's51765753', 's52616494', 's54692227', 's50714348', 's54058678', 's51719198', 's50014127', 's59203230', 's50822353', 's55728799', 's50547182', 's50256977', 's53924742', 's56679657', 's51526655', 's55902256', 's57827533', 's57395479', 's52939447', 's54949810', 's52995335', 's50184397', 's53482917', 's57161577', 's50308220', 's51972257', 's58666319', 's54060800', 's51099690', 's59900684', 's52064406', 's55463368', 's51405069', 's57629666', 's52350132', 's58087032', 's51233560', 's54766893', 's59756815', 's57975962', 's56237499', 's52697084', 's56042355', 's57251948', 's50323961', 's55775814', 's59044011', 's59607772', 's52600197', 's53982700', 's55499601', 's55212349', 's57356552', 's53414987', 's55082399', 's59284918', 's59480672', 's57233393', 's58232231', 's58001075', 's56427859', 's53583954', 's58056585', 's51465438', 's57632806', 's50431066', 's58327706', 's58897728', 's56093476', 's58215117', 's50270173', 's50830952', 's54331436', 's54073075', 's51189125', 's54712047']

    # Load the gold set and keep only the dev-set studies. `.copy()` makes the
    # filtered frame independent, so the column assignments below modify it
    # directly instead of raising pandas' SettingWithCopyWarning.
    devset = pd.read_csv(args.gold_path)
    devset = devset[devset['study_id'].isin(args.dev_list)].copy()
    # Map the 'lf' / 'If' category labels to 'pf', then upper-case the column
    # element-wise via the `.str` accessor.
    # NOTE(review): 'If' (capital I) looks like a typo/OCR variant of 'lf' — confirm.
    devset['cat'] = devset['cat'].replace(['lf', 'If'], 'pf').str.upper()
    devset['status'] = devset['status'].str.upper()
    # Drop any 'loc:' / 'det:' prefixes from location values.
    devset['location'] = devset['location'].str.replace(r'(loc|det):\s*', '', regex=True)
    # Rewrite bare 'idxN' references as 'obj_ent_idxN'.
    devset['evidence'] = devset['evidence'].str.replace(r'idx(\d+)', r'obj_ent_idx\1', regex=True)
    devset['associate'] = devset['associate'].str.replace(r'idx(\d+)', r'obj_ent_idx\1', regex=True)

    print(f'{devset["study_id"].nunique()} studies in dev-set')

    # Retrieve few-shot examples related to the query.
    query_words, user_history, assistant_history = retreive_query_related_fewshot(
        devset=devset,
        query=args.query,
        args=args
    )

    # Print each retrieved user/assistant pair, separated by a blank line.
    for user, assistant in zip(user_history, assistant_history):
        print(user)
        print(assistant)
        print()