"""
Sequential Metric Evaluation System

This module implements a two-level evaluation method for structured medical reports:
1. Subject-level F1 score - Finds the best matching entities between predicted and ground truth
2. Row-level F1 score - Finds the best matching rows and calculates F1 score

The system supports both single reports (rexval mode) and sequential reports.
"""
import pandas as pd
import numpy as np
import os
import yaml
import argparse
from scipy.optimize import linear_sum_assignment
from difflib import SequenceMatcher
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
from tqdm import tqdm
import pandas as pd
import math
import textwrap
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.gridspec as gridspec
from matplotlib.ticker import MaxNLocator
import random
import numpy as np

# Set seeds for reproducibility
# These statements run at import time and seed the global generators of
# Python's `random`, NumPy and PyTorch with the same fixed value (42).
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    # When a CUDA device is present, seed the generators of all CUDA devices too.
    torch.cuda.manual_seed_all(42)

class SemanticMatcher:
    """
    Embeds texts with a transformer encoder and compares the embeddings.

    Inputs and the model are moved to ``self.device``, which is CUDA when
    ``torch.cuda.is_available()`` and CPU otherwise.
    """

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def mean_pooling(self, model_output, attention_mask):
        """Average the token embeddings of each text, ignoring masked positions.

        Args:
            model_output: Model output whose item ``0`` holds the token
                embeddings (assumed shape ``(batch, tokens, hidden)`` -- the
                mask is expanded over the last axis and summed over axis 1).
            attention_mask: Tensor with 1 for real tokens and 0 for padding.

        Returns:
            Tensor with one pooled vector per text.
        """
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * input_mask_expanded, 1)
        # Clamp the token count so a fully masked row cannot divide by zero.
        token_counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return summed / token_counts

    def get_embeddings(self, model, tokenizer, texts, model_name=None):
        """Return L2-normalized embeddings for ``texts``.

        Args:
            model: Encoder called with the tokenizer output as keyword arguments.
            tokenizer: Callable invoked as
                ``tokenizer(texts, padding=True, truncation=True, return_tensors='pt')``.
            texts: List of strings to embed.
            model_name: When equal to ``'FremyCompany/BioLORD-2023'`` the
                embeddings are mean-pooled; for any other value (including
                ``None``) the first token of ``last_hidden_state`` is used.

        Returns:
            Tensor with one unit-length row per text.
        """
        model = model.to(self.device)

        encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
        encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}

        # Inference only: no gradients are needed for similarity scoring.
        with torch.no_grad():
            model_output = model(**encoded_input)

        if model_name == 'FremyCompany/BioLORD-2023':
            embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
        else:
            embeddings = model_output.last_hidden_state[:, 0, :]

        return F.normalize(embeddings, p=2, dim=1)

    def matrix_process_sequence_pair(self, gt_text, target_text,
                        model1, tokenizer1, model_name=None):
        """Compute the pairwise similarity matrix between two lists of texts.

        Args:
            gt_text: Iterable of ground-truth strings (rows of the result).
            target_text: Iterable of target strings (columns of the result).
            model1: Encoder passed to :meth:`get_embeddings`.
            tokenizer1: Tokenizer passed to :meth:`get_embeddings`.
            model_name: Forwarded to :meth:`get_embeddings` to select pooling.

        Returns:
            Tensor of shape ``(len(gt_text), len(target_text))``. Because the
            embeddings are L2-normalized, each entry is the dot product of two
            unit vectors. If either input is empty, a zero-filled tensor of
            the matching (empty) shape is returned without calling the model.
        """
        gt_flat = list(gt_text)
        target_flat = list(target_text)

        # Nothing to compare: avoid handing an empty batch to the tokenizer.
        if not gt_flat or not target_flat:
            return torch.zeros((len(gt_flat), len(target_flat)), device=self.device)

        gt_embeddings = self.get_embeddings(model1, tokenizer1, gt_flat, model_name)
        target_embeddings = self.get_embeddings(model1, tokenizer1, target_flat, model_name)

        return torch.mm(gt_embeddings, target_embeddings.t())
       
class DataConverter:
    """Handles conversion of dataframes to structured dictionaries.

    In sequential mode (``'sequential'`` appears in ``mode``) the result is
    ``{subject_id: {study_id: {entity: [entry, ...]}}}``; otherwise it is
    ``{study_id: {entity: [entry, ...]}}``. Each ``entry`` is a dict holding
    the non-null row values for the configured weight keys plus bookkeeping
    fields.
    """

    def __init__(self, STRUCTURE_WEIGHTS, mode='rexval'):
        """
        Args:
            STRUCTURE_WEIGHTS: Indexable pair; item ``0`` is the temporal
                weight mapping and item ``1`` the relation weight mapping.
                Only their keys are used here, as the column names to copy.
            mode: Evaluation mode string; any value containing
                ``'sequential'`` selects the per-subject layout.
        """
        self.rel_weights = STRUCTURE_WEIGHTS[1]
        self.temporal_weights = STRUCTURE_WEIGHTS[0]
        self.mode = mode

    @staticmethod
    def _as_key(raw_id):
        """Return ``raw_id`` as a plain ``str`` (string IDs) or ``int`` (others)."""
        # isinstance (not an exact type check) so str subclasses are not sent
        # to int(), which would raise on non-numeric IDs.
        return str(raw_id) if isinstance(raw_id, str) else int(raw_id)

    @staticmethod
    def _entity_column(df_data):
        """Return the name of the entity column: ``'entity'`` if present, else ``'ent'``."""
        return 'entity' if 'entity' in df_data.columns else 'ent'

    def _weighted_attributes(self, row):
        """Copy the non-null values of every weighted column from ``row``."""
        entry = {}
        # Relation keys first, then temporal keys (later keys win on overlap).
        for weights in (self.rel_weights, self.temporal_weights):
            for key in weights.keys():
                if key in row and pd.notna(row[key]):
                    entry[key] = row[key]
        return entry

    def _convert_sequential(self, df_data):
        """Build ``{subject_id: {study_id: {entity: [entries]}}}``."""
        entity_col = self._entity_column(df_data)
        structured = {}
        for subject_id in df_data['subject_id'].unique():
            subject_rows = df_data[df_data['subject_id'] == subject_id].sort_values(by='sequence')
            studies = {}
            for _, row in subject_rows.iterrows():
                entity = row[entity_col]
                # Rows without an entity carry no finding to evaluate.
                if pd.isna(entity):
                    continue

                entry = self._weighted_attributes(row)

                # Default episode only when the table has no episode column.
                if 'episode' not in row:
                    entry['episode'] = 1

                if 'entity_group' in row:
                    entry['entity_group'] = row['entity_group']

                # Fall back to the study ID when there is no sequence column.
                if 'sequence' not in row:
                    entry['sequence'] = row['study_id']

                studies.setdefault(row['study_id'], {}).setdefault(entity, []).append(entry)

            # Subjects whose rows were all skipped are left out.
            if studies:
                structured[self._as_key(subject_id)] = studies
        return structured

    def _convert_single(self, df_data):
        """Build ``{study_id: {entity: [entries]}}`` for single-report mode."""
        entity_col = self._entity_column(df_data)
        structured = {}
        for study_id in df_data['study_id'].unique():
            study_rows = df_data[df_data['study_id'] == study_id]
            entities = {}
            for _, row in study_rows.iterrows():
                entity = row[entity_col]
                if pd.isna(entity):
                    continue

                entry = self._weighted_attributes(row)
                # Fixed bookkeeping values; these override any same-named
                # weighted column copied above.
                entry['entity_group'] = 'single-eval'
                entry['episode'] = 1
                entry['sequence'] = row['study_id']

                # Every row with an entity is kept, even when it has no
                # attribute columns: the bookkeeping keys are always present.
                entities.setdefault(entity, []).append(entry)

            if entities:
                structured[self._as_key(study_id)] = entities
        return structured

    def convert_to_structured_dict(self, df_data):
        """Convert dataframe to structured dictionary format.

        Args:
            df_data: DataFrame with a ``study_id`` column and an ``entity``
                (or ``ent``) column; sequential mode additionally reads
                ``subject_id`` and sorts by ``sequence``.

        Returns:
            Nested dict as described in the class docstring. Top-level keys
            are ``str`` for string IDs and ``int`` otherwise.
        """
        if 'sequential' in self.mode:
            structured_data = self._convert_sequential(df_data)
            print(f"{len(structured_data.keys())} Patient IDs: {list(structured_data.keys())} with {len(list(structured_data.values()))} groups")
        else:
            structured_data = self._convert_single(df_data)
            print(f"{len(structured_data.keys())} STUDY IDs: {list(structured_data.keys())} with {len(list(structured_data.values()))} groups")

        return structured_data

class SubjectMatcher:
    """Handles matching between subjects."""
    
    def __init__(self, mode):
        """Store the evaluation mode.

        Args:
            mode: Mode string. Methods of this class branch on whether it
                starts with ``'sequential'``.
        """
        self.mode = mode

    def aggregate_entity_phrases(self, entity_dict):
        """Flatten a structured report dictionary into parallel per-row lists.

        Args:
            entity_dict: In sequential mode (``self.mode`` starts with
                ``'sequential'``) a mapping
                ``{study_id: {entity: [instance, ...]}}``; otherwise
                ``{entity: [instance, ...]}``. Each ``instance`` is a dict of
                row attributes.

        Returns:
            Tuple of six lists: phrases, row dicts, study IDs, entity groups,
            episodes and dx statuses. All but the study-ID list have one item
            per instance; the study-ID list is filled only in sequential mode
            and stays empty otherwise. Missing group/episode/dx_status values
            are reported as ``'unknown'``.
        """
        aggregate_keys = ('location', 'morphology', 'distribution', 'measurement', 'severity',
                          'onset', 'improved', 'worsened', 'no change', 'placement')
        aggregated_phrases = []
        aggregated_row_info = []
        aggregated_study_ids = []
        aggregated_group = []
        aggregated_episode = []
        aggregated_dx_status = []

        def collect(entity, instances):
            """Append one record per instance of ``entity`` to the output lists."""
            for instance in instances:
                # str() so numeric attributes (or a non-string entity) do not
                # make the join raise TypeError; falsy values are left out.
                values = [str(instance[key]) for key in aggregate_keys if instance.get(key)]
                aggregated_phrases.append(' '.join([str(entity)] + values).strip())
                aggregated_row_info.append(instance)
                aggregated_group.append(instance.get('entity_group', 'unknown'))
                aggregated_episode.append(instance.get('episode', 'unknown'))
                aggregated_dx_status.append(instance.get('dx_status', 'unknown'))

        if self.mode.startswith('sequential'):
            for study_id, entries in entity_dict.items():
                for entity, instances in entries.items():
                    before = len(aggregated_phrases)
                    collect(entity, instances)
                    # One study ID per instance just collected.
                    aggregated_study_ids.extend([study_id] * (len(aggregated_phrases) - before))
        else:
            for entity, instances in entity_dict.items():
                collect(entity, instances)

        return aggregated_phrases, aggregated_row_info, aggregated_study_ids, aggregated_group, aggregated_episode, aggregated_dx_status

    def compute_semantic_row_similarity(self, gt_row, pred_row, weights, matcher, model, tokenizer):
        """
        Compute semantic similarity between two rows based on their attributes.
        
        Args:
            gt_row: Ground truth row dictionary
            pred_row: Prediction row dictionary
            weights: Dictionary of weights
            matcher: SemanticMatcher instance
            model: Language model for semantic matching
            tokenizer: Tokenizer for the language model
            
        Returns:
            float: Normalized similarity score between the rows
        """
        score = 0.0
        norm = 0.0
        
        # Define binary/categorical fields that should use exact matching
        binary_fields = ['episode', 'dx_status', 'dx_certainty', 'sequence']
        
        if not self.mode.startswith('sequential'):
            binary_fields += ['entity_group']
        
        # Collect text fields for batch processing
        text_fields = []
        text_weights = []
        gt_texts = []
        pred_texts = []
        
        for rel, weight in weights.items():            
            gt_val = gt_row.get(rel)
            pred_val = pred_row.get(rel)
                        
            # Skip if both values are None/empty
            if (gt_val is None or str(gt_val).strip() == '') and (pred_val is None or str(pred_val).strip() == ''):
                continue
            
            # If either value exists, count this relation in normalization
            norm += weight
            
            # Handle binary/categorical fields with exact matching
            if rel in binary_fields:
                if gt_val is not None and pred_val is not None:
                    # For dx_status: positive/negative
                    if rel == 'dx_status':
                        gt_norm = str(gt_val).lower().strip()
                        pred_norm = str(pred_val).lower().strip()
                        
                        # Check for exact match in normalized values
                        if gt_norm == pred_norm:
                            score += weight
                    
                    # For dx_certainty: definitive/tentative
                    elif rel == 'dx_certainty':
                        gt_norm = str(gt_val).lower().strip()
                        pred_norm = str(pred_val).lower().strip()
                        
                        # Check for exact match in normalized values
                        if gt_norm == pred_norm:
                            score += weight
                    
                    # For episode and sequence: numeric comparison
                    elif rel in ['episode', 'sequence']:
                        try:
                            gt_num = int(gt_val)
                            pred_num = int(pred_val)
                            if gt_num == pred_num:
                                score += weight
                        except (ValueError, TypeError):
                            # If conversion fails, try string comparison
                            if str(gt_val) == str(pred_val):
                                score += weight

                    if not self.mode.startswith('sequential'):
                        if rel == 'entity_group':
                            if gt_val is not None and pred_val is not None:
                                if gt_val == pred_val:
                                    score += weight
                continue  # Skip semantic comparison for binary fields
            
            # For text fields, collect for batch processing
            if gt_val is not None and pred_val is not None and str(gt_val).strip() != '' and str(pred_val).strip() != '':
                text_fields.append(rel)
                text_weights.append(weight)
                gt_texts.append(str(gt_val))
                pred_texts.append(str(pred_val))
        
        
        # Batch process text fields if any exist
        if text_fields:
            try:
                # Process all text fields in a single batch
                batch_sim_matrix = matcher.matrix_process_sequence_pair(
                    gt_texts, pred_texts, model, tokenizer)
                
                # Extract diagonal elements (matching pairs)
                for i in range(min(len(gt_texts), len(pred_texts))):
                    if i < batch_sim_matrix.shape[0] and i < batch_sim_matrix.shape[1]:
                        sim_value = batch_sim_matrix[i, i].item()
                        score += text_weights[i] * sim_value
            except Exception as e:
                print(f"Error in batch similarity computation: {e}")
                # Fall back to individual processing
                for i, (rel, weight) in enumerate(zip(text_fields, text_weights)):
                    if i < len(gt_texts) and i < len(pred_texts):
                        try:
                            # Process individual text pair
                            column_sim = matcher.matrix_process_sequence_pair(
                                [gt_texts[i]], [pred_texts[i]], model, tokenizer)
                            
                            if torch.is_tensor(column_sim) and column_sim.numel() > 0:
                                sim_value = column_sim[0, 0].item()
                                score += weight * sim_value
                        except Exception as e:
                            print(f"Error computing similarity for {rel}: {e}")
                            # If semantic comparison fails, fall back to exact matching
                            if gt_texts[i] == pred_texts[i]:
                                score += weight
        
        # Normalize the score
        final_score = score / norm if norm > 0 else 0.0
        return final_score
            
    def semantic_match(self, gt_data, pred_data, rel_weights=None, temporal_weights=None):
        gt_phrases, gt_row_info, gt_study_ids, gt_group, gt_episode, gt_dx_status = self.aggregate_entity_phrases(gt_data)
        pred_phrases, pred_row_info, pred_study_ids, pred_group, pred_episode, pred_dx_status = self.aggregate_entity_phrases(pred_data)

        if len(gt_study_ids) == 0 or len(pred_study_ids) == 0:
            gt_study_ids = [0]*len(gt_phrases)
            pred_study_ids = [0]*len(pred_phrases)

        # 2. similarity matrix
        matcher = SemanticMatcher()
        models = [
            'FremyCompany/BioLORD-2023',
            'ncbi/MedCPT-Query-Encoder',
            # 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext',
            # 'emilyalsentzer/Bio_ClinicalBERT',
            # 'medicalai/ClinicalBERT',
            # 'dmis-lab/biobert-base-cased-v1.2'
        ]
        print(f"Using {len(models)} models for semantic matching: {models}")

        combined_sim_matrix = None
        combined_row_sim_matrix = None

        for model_idx, model_name in enumerate(models):
            print(f"\nProcessing with model: {model_name}")
            
            # Load model and tokenizer
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModel.from_pretrained(model_name)
            
             # Calculate group similarity matrix
            if self.mode.startswith('sequential'):
                sim_matrix = matcher.matrix_process_sequence_pair(gt_group, pred_group, model, tokenizer, model_name)
            else:
                sim_matrix = matcher.matrix_process_sequence_pair(gt_phrases, pred_phrases, model, tokenizer, model_name)
                
            # Calculate row-level similarity scores
            row_similarity_scores = torch.zeros((len(gt_row_info), len(pred_row_info)))
            if torch.is_tensor(sim_matrix):
                row_similarity_scores = row_similarity_scores.to(sim_matrix.device)
            
            temporal_similarity_scores = torch.zeros((len(gt_row_info), len(pred_row_info)))
            if torch.is_tensor(sim_matrix):
                temporal_similarity_scores = temporal_similarity_scores.to(sim_matrix.device)

            for i, gt_row in enumerate(gt_row_info):
                for j, pred_row in enumerate(pred_row_info):
                    # Compute row similarity
                    similarity = self.compute_semantic_row_similarity(
                        gt_row, pred_row, rel_weights, matcher, model, tokenizer)
                    row_similarity_scores[i, j] = similarity

                    temporal_similarity = self.compute_semantic_row_similarity(
                        gt_row, pred_row, temporal_weights, matcher, model, tokenizer)
                    temporal_similarity_scores[i, j] = temporal_similarity

            # Initialize or add to combined matrices
            if combined_sim_matrix is None:
                # First model - initialize the combined matrices
                combined_sim_matrix = sim_matrix.clone()
                combined_temporal_sim_matrix = temporal_similarity_scores.clone()
                combined_row_sim_matrix = row_similarity_scores.clone()

            else:                
                # Add current model's matrices to combined matrices
                combined_sim_matrix += sim_matrix
                combined_temporal_sim_matrix += temporal_similarity_scores
                combined_row_sim_matrix += row_similarity_scores
        
        # Average the combined matrices
        num_models = len(models)
        combined_sim_matrix /= num_models
        combined_temporal_sim_matrix /= num_models
        combined_row_sim_matrix /= num_models

        print(f"Row similarity matrix shape: {combined_row_sim_matrix.shape}")
        # Calculate final similarity matrix
        final_sim_matrix = combined_sim_matrix * combined_row_sim_matrix * combined_temporal_sim_matrix
        
        if self.mode == 'sequential_mask_maira':
            mask_matrix = torch.zeros((len(gt_study_ids), len(pred_study_ids)), device=final_sim_matrix.device)
            for i in range(len(gt_study_ids)):
                for j in range(len(pred_study_ids)):
                    if gt_study_ids[i] == pred_study_ids[j]:
                        mask_matrix[i, j] = 1.0
                        
            final_sim_matrix = final_sim_matrix * mask_matrix
        print(f"Final similarity matrix shape: {final_sim_matrix.shape}")
        # 3. bipartite 매칭
        if torch.is_tensor(final_sim_matrix):
            try:
                cost_matrix = 1 - final_sim_matrix.detach().cpu().numpy()
            except (RuntimeError, ImportError):
                # Handle case where numpy conversion fails
                n, m = final_sim_matrix.shape
                cost_matrix = np.ones((n, m))
                for i in range(n):
                    for j in range(m):
                        cost_matrix[i, j] = 1 - final_sim_matrix[i, j].item()
        else:
            cost_matrix = 1 - final_sim_matrix
        row_ind, col_ind = linear_sum_assignment(cost_matrix)

        #########################################################
        unique_gt_ids = sorted(list(set(gt_study_ids)))
        unique_pred_ids = sorted(list(set(pred_study_ids)))
        gt_id_to_indices = {id: [] for id in unique_gt_ids}
        pred_id_to_indices = {id: [] for id in unique_pred_ids}

        for i, id in enumerate(gt_study_ids):
            gt_id_to_indices[id].append(i)
            
        for j, id in enumerate(pred_study_ids):
            pred_id_to_indices[id].append(j)

        gt_id_to_count = {id: len(indices) for id, indices in gt_id_to_indices.items()}
        pred_id_to_count = {id: len(indices) for id, indices in pred_id_to_indices.items()}

        os.makedirs('matrix_images', exist_ok=True)
        # linear_sum_assignment 결과 시각화
        plt.figure(figsize=(18, 9))
        
        # 그리드스펙 설정 (두 개의 메인 플롯 + 두 개의 컬러바)
        gs = gridspec.GridSpec(1, 4, width_ratios=[10, 1, 10, 1])
        
        # 왼쪽: final_sim_matrix
        ax1 = plt.subplot(gs[0])
        try:
            final_np = final_sim_matrix.cpu().numpy()
        except (RuntimeError, ImportError):
            # Fallback if numpy conversion fails
            n, m = final_sim_matrix.shape
            final_np = np.zeros((n, m))
            for i in range(n):
                for j in range(m):
                    final_np[i, j] = final_sim_matrix[i, j].item()
        im1 = ax1.imshow(final_np, cmap='viridis', aspect='auto')
        ax1.set_title('Final Similarity Matrix', fontsize=16, fontweight='bold')
        ax1.set_xlabel('Predicted Reports', fontsize=14)
        ax1.set_ylabel('Ground Truth Reports', fontsize=14)
        
        # 왼쪽 플롯에 경계선 그리기
        # GT 경계선 (y축)
        y_tick_positions = []
        y_tick_labels = []
        
        for id in unique_gt_ids:
            indices = gt_id_to_indices[id]
            if indices:
                min_idx = min(indices)
                max_idx = max(indices)
                count = gt_id_to_count[id]
                
                # 경계선 그리기
                if min_idx > 0:
                    ax1.axhline(y=min_idx-0.5, color='#FF5555', linestyle='-', linewidth=1.5)
                
                # 중앙 위치 계산 및 저장 (커스텀 눈금용)
                center_pos = (min_idx + max_idx) / 2
                y_tick_positions.append(center_pos)
                y_tick_labels.append(f'{id} ({count})')
        
        # Pred 경계선 (x축)
        x_tick_positions = []
        x_tick_labels = []
        
        for id in unique_pred_ids:
            indices = pred_id_to_indices[id]
            if indices:
                min_idx = min(indices)
                max_idx = max(indices)
                count = pred_id_to_count[id]
                
                # 경계선 그리기
                if min_idx > 0:
                    ax1.axvline(x=min_idx-0.5, color='#55AA55', linestyle='-', linewidth=1.5)
                
                # 중앙 위치 계산 및 저장 (커스텀 눈금용)
                center_pos = (min_idx + max_idx) / 2
                x_tick_positions.append(center_pos)
                x_tick_labels.append(f'{id} ({count})')
        
        # 왼쪽 플롯의 보조 축 생성 (ID 레이블용)
        ax1_top = ax1.twiny()
        ax1_right = ax1.twinx()
        
        # 보조 축 범위 설정
        ax1_top.set_xlim(ax1.get_xlim())
        ax1_right.set_ylim(ax1.get_ylim())
        
        # 커스텀 눈금 설정
        ax1_top.set_xticks(x_tick_positions)
        ax1_top.set_xticklabels(x_tick_labels, rotation=45, ha='left', fontsize=9)
        ax1_right.set_yticks(y_tick_positions)
        ax1_right.set_yticklabels(y_tick_labels, fontsize=9)
        
        # 왼쪽 플롯의 기본 눈금 제거
        ax1.set_xticks([])
        ax1.set_yticks([])
        
        # 왼쪽 컬러바
        cax1 = plt.subplot(gs[1])
        cbar1 = plt.colorbar(im1, cax=cax1)
        cbar1.set_label('Similarity Score', fontsize=12)
        
        # 오른쪽: 매칭 결과 시각화
        ax2 = plt.subplot(gs[2])
        
        # 매칭 결과 행렬 생성 (0으로 초기화)
        try:
            matching_matrix = np.zeros_like(final_np)
        except (RuntimeError, ImportError, NameError):
            # Fallback if numpy conversion fails
            n, m = final_sim_matrix.shape
            matching_matrix = np.zeros((n, m))
        
        # linear_sum_assignment 결과로 매칭된 셀에 1 할당
        for i, j in zip(row_ind, col_ind):
            if i < matching_matrix.shape[0] and j < matching_matrix.shape[1]:
                matching_matrix[i, j] = 1
        
        # 매칭 결과 시각화
        # 커스텀 컬러맵: 흰색 배경, 빨간색 매칭
        colors = [(1, 1, 1), (1, 0.7, 0.7), (1, 0.5, 0.5), (1, 0.3, 0.3), (1, 0, 0)]
        cmap_match = LinearSegmentedColormap.from_list('custom_red', colors, N=100)
        
        im2 = ax2.imshow(matching_matrix, cmap=cmap_match, aspect='auto')
        ax2.set_title('Optimal Matching Result', fontsize=16, fontweight='bold')
        ax2.set_xlabel('Predicted Reports', fontsize=14)
        ax2.set_ylabel('Ground Truth Reports', fontsize=14)
        
        # 오른쪽 플롯에 경계선 그리기 (왼쪽과 동일)
        # GT 경계선 (y축)
        for id in unique_gt_ids:
            indices = gt_id_to_indices[id]
            if indices:
                min_idx = min(indices)
                if min_idx > 0:
                    ax2.axhline(y=min_idx-0.5, color='#FF5555', linestyle='-', linewidth=1.5)
        
        # Pred 경계선 (x축)
        for id in unique_pred_ids:
            indices = pred_id_to_indices[id]
            if indices:
                min_idx = min(indices)
                if min_idx > 0:
                    ax2.axvline(x=min_idx-0.5, color='#55AA55', linestyle='-', linewidth=1.5)
        
        # 오른쪽 플롯의 보조 축 생성 (ID 레이블용)
        ax2_top = ax2.twiny()
        ax2_right = ax2.twinx()
        
        # 보조 축 범위 설정
        ax2_top.set_xlim(ax2.get_xlim())
        ax2_right.set_ylim(ax2.get_ylim())
        
        # 커스텀 눈금 설정
        ax2_top.set_xticks(x_tick_positions)
        ax2_top.set_xticklabels(x_tick_labels, rotation=45, ha='left', fontsize=9)
        ax2_right.set_yticks(y_tick_positions)
        ax2_right.set_yticklabels(y_tick_labels, fontsize=9)
        
        # 오른쪽 플롯의 기본 눈금 제거
        ax2.set_xticks([])
        ax2.set_yticks([])
        
        # 오른쪽 컬러바
        cax2 = plt.subplot(gs[3])
        cbar2 = plt.colorbar(im2, cax=cax2)
        cbar2.set_label('Match (1) / No Match (0)', fontsize=12)
        
        # 레이아웃 조정
        plt.tight_layout()
        plt.savefig('matrix_images/matching_result.png', dpi=300, bbox_inches='tight')
        plt.savefig('matrix_images/matching_result.pdf', format='pdf', bbox_inches='tight')
        plt.close()
        
        # 매칭 결과와 유사도 점수를 함께 보여주는 시각화
        plt.figure(figsize=(10, 8))
        
        # 유사도 점수가 있는 매칭 결과 행렬 생성
        try:
            weighted_matching = np.zeros_like(final_np)
        except (RuntimeError, ImportError, NameError):
            # Fallback if numpy conversion fails
            n, m = final_sim_matrix.shape
            weighted_matching = np.zeros((n, m))
        
        # linear_sum_assignment 결과로 매칭된 셀에 유사도 점수 할당
        for i, j in zip(row_ind, col_ind):
            if i < weighted_matching.shape[0] and j < weighted_matching.shape[1]:
                weighted_matching[i, j] = final_np[i, j]
        
        # 매칭 결과 시각화
        plt.imshow(weighted_matching, cmap='viridis', aspect='auto')
        plt.colorbar(label='Similarity Score of Matched Pairs')
        plt.title('Optimal Matching with Similarity Scores', fontsize=16, fontweight='bold')
        plt.xlabel('Predicted Reports', fontsize=14)
        plt.ylabel('Ground Truth Reports', fontsize=14)
        
        # 경계선 그리기
        for id in unique_gt_ids:
            indices = gt_id_to_indices[id]
            if indices:
                min_idx = min(indices)
                if min_idx > 0:
                    plt.axhline(y=min_idx-0.5, color='#FF5555', linestyle='-', linewidth=1.5)
        
        for id in unique_pred_ids:
            indices = pred_id_to_indices[id]
            if indices:
                min_idx = min(indices)
                if min_idx > 0:
                    plt.axvline(x=min_idx-0.5, color='#55AA55', linestyle='-', linewidth=1.5)
        
        # 매칭된 셀에 X 표시 추가
        for i, j in zip(row_ind, col_ind):
            if i < weighted_matching.shape[0] and j < weighted_matching.shape[1]:
                plt.plot(j, i, 'rx', markersize=5)
        
        plt.tight_layout()
        plt.savefig('matrix_images/weighted_matching.png', dpi=300, bbox_inches='tight')
        plt.savefig('matrix_images/weighted_matching.pdf', format='pdf', bbox_inches='tight')
        plt.close()
        
        #########################################################

        matched_pairs = []
        matched_scores = []
        matched_indices = set()
        
        print("\n\n Matching result:")
        # Store all matches and record similarity scores
        for i, j in zip(row_ind, col_ind):
            if i < len(gt_study_ids) and j < len(pred_study_ids):
                # Get the similarity score for this match
                sim_score = final_sim_matrix[i, j].item() if torch.is_tensor(final_sim_matrix[i, j]) else final_sim_matrix[i, j]
                com_sim_score = combined_sim_matrix[i, j].item() if torch.is_tensor(combined_sim_matrix[i, j]) else combined_sim_matrix[i, j]
                com_temp_sim_score = combined_temporal_sim_matrix[i, j].item() if torch.is_tensor(combined_temporal_sim_matrix[i, j]) else combined_temporal_sim_matrix[i, j]
                com_row_sim_score = combined_row_sim_matrix[i, j].item() if torch.is_tensor(combined_row_sim_matrix[i, j]) else combined_row_sim_matrix[i, j]
                # Store the match information and score
                matched_pairs.append((gt_study_ids[i], pred_study_ids[j], gt_episode[i], pred_episode[j], gt_phrases[i], pred_phrases[j], sim_score))
                matched_scores.append(sim_score)
                matched_indices.add((i, j))
                print(f"Match: GT={gt_phrases[i]}, Std: {gt_study_ids[i]}, E: {gt_episode[i]}, Group: {gt_group[i]}, DX_STATUS: {gt_dx_status[i]}, \n Pred={pred_phrases[j]} Std: {pred_study_ids[j]} E: {pred_episode[j]}, Group: {pred_group[j]}, DX_STATUS: {pred_dx_status[j]}, \n Total_Score={sim_score:.3f}, Semantic_score = {com_sim_score:.3f}, Temporal_score = {com_temp_sim_score:.3f}, Row_score = {com_row_sim_score:.3f}\n")

        # Identify unmatched items
        gt_matched_indices = {i for i, _ in matched_indices}
        pred_matched_indices = {j for _, j in matched_indices}
        
        gt_unmatched_indices = [i for i in range(len(gt_study_ids)) if i not in gt_matched_indices]
        pred_unmatched_indices = [j for j in range(len(pred_study_ids)) if j not in pred_matched_indices]
        
        # Collect info for unmatched GT items
        gt_unmatched = []
        gt_best_similarities = []
        for i in gt_unmatched_indices:
            if i < final_sim_matrix.shape[0] and final_sim_matrix.shape[1] > 0:
                # Find max similarity
                best_sim = torch.max(final_sim_matrix[i, :]).item() if torch.is_tensor(final_sim_matrix) else np.max(final_sim_matrix[i, :])
                gt_unmatched.append((gt_study_ids[i], gt_phrases[i]))
                gt_best_similarities.append(best_sim)
                print(f"Unmatched GT[{i}]: {gt_phrases[i]} (best sim: {best_sim:.3f})")
            else:
                gt_unmatched.append((gt_study_ids[i], gt_phrases[i]))
                gt_best_similarities.append(0.0)
        
        # Collect info for unmatched Pred items
        pred_unmatched = []
        pred_best_similarities = []
        for j in pred_unmatched_indices:
            if j < final_sim_matrix.shape[1] and final_sim_matrix.shape[0] > 0:
                # Find max similarity
                best_sim = torch.max(final_sim_matrix[:, j]).item() if torch.is_tensor(final_sim_matrix) else np.max(final_sim_matrix[:, j])
                pred_unmatched.append((pred_study_ids[j], pred_phrases[j]))
                pred_best_similarities.append(best_sim)
                print(f"Unmatched Pred[{j}]: {pred_phrases[j]} (best sim: {best_sim:.3f})")
            else:
                pred_unmatched.append((pred_study_ids[j], pred_phrases[j]))
                pred_best_similarities.append(0.0)

        # F1 with similarity-weighted matching
        # Each match is proportional to similarity to TP
        # Unmatched items are proportional to (1-similarity) to FP/FN
        
        # 1. TP from matched results
        TP = sum(matched_scores)
        
        # 2. FP/FN from matched results (1 - sim)
        # An imperfect match contributes its similarity shortfall to both FP and FN.
        FP_from_matches = sum(1 - score for score in matched_scores)
        FN_from_matches = sum(1 - score for score in matched_scores)
        
        # 3. FP/FN from truly unmatched items
        # Each unmatched item is discounted by its best similarity to any candidate.
        FP_from_unmatched = len(pred_unmatched) - sum(pred_best_similarities)
        FN_from_unmatched = len(gt_unmatched) - sum(gt_best_similarities)
        
        if self.mode == 'sequential_mask_maira':
            # Create a mapping of study IDs to track which ones are present
            gt_study_id_set = set()
            pred_study_id_set = set()
            
            # Extract study IDs from matched pairs
            for gt_id, pred_id, *_ in matched_pairs:
                gt_study_id_set.add(gt_id)
                pred_study_id_set.add(pred_id)
                
            # Extract study IDs from unmatched items
            for gt_id, _ in gt_unmatched:
                gt_study_id_set.add(gt_id)
            
            for pred_id, _ in pred_unmatched:
                pred_study_id_set.add(pred_id)
            
            # Filter unmatched items based on study ID presence
            filtered_gt_unmatched = []
            filtered_gt_best_similarities = []
            
            for i, (gt_id, phrase) in enumerate(gt_unmatched):
                # Only count as FN if the study ID exists in predictions
                if gt_id in pred_study_id_set:
                    filtered_gt_unmatched.append((gt_id, phrase))
                    if i < len(gt_best_similarities):
                        filtered_gt_best_similarities.append(gt_best_similarities[i])
            
            filtered_pred_unmatched = []
            filtered_pred_best_similarities = []
            
            for i, (pred_id, phrase) in enumerate(pred_unmatched):
                # Only count as FP if the study ID exists in ground truth
                if pred_id in gt_study_id_set:
                    filtered_pred_unmatched.append((pred_id, phrase))
                    if i < len(pred_best_similarities):
                        filtered_pred_best_similarities.append(pred_best_similarities[i])
            
            # Recalculate FP/FN with filtered unmatched items
            FP_from_unmatched = len(filtered_pred_unmatched) - sum(filtered_pred_best_similarities)
            FN_from_unmatched = len(filtered_gt_unmatched) - sum(filtered_gt_best_similarities)
            
            print(f"\nFiltered unmatched items for sequential_mask_maira mode:")
            print(f"Original GT unmatched: {len(gt_unmatched)}, Filtered: {len(filtered_gt_unmatched)}")
            print(f"Original Pred unmatched: {len(pred_unmatched)}, Filtered: {len(filtered_pred_unmatched)}")
        
        FP = FP_from_matches + FP_from_unmatched
        FN = FN_from_matches + FN_from_unmatched
        
        precision = TP / (TP + FP) if (TP + FP) > 0 else 0.0
        recall = TP / (TP + FN) if (TP + FN) > 0 else 0.0
        
        f1_score = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
        
        print(f"\nWeighted F1 Calculation:")
        print(f"TP: {TP:.3f}")
        print(f"FP: {FP:.3f} (from matches: {FP_from_matches:.3f}, from unmatched: {FP_from_unmatched:.3f})")
        print(f"FN: {FN:.3f} (from matches: {FN_from_matches:.3f}, from unmatched: {FN_from_unmatched:.3f})")
        print(f"Precision: {precision:.3f}")
        print(f"Recall: {recall:.3f}")
        print(f"F1 Score: {f1_score:.3f}")
        # input("ST!!")

        return matched_pairs, f1_score, precision, recall, gt_unmatched, pred_unmatched
        
class TrendAnalyzer:
    """Extract per-relation value sequences from rows and compare their direction."""

    def __init__(self, severity_scale=None):
        """Store the severity lookup table.

        Args:
            severity_scale: Optional mapping from a raw severity value to the
                score used in the trend. Values missing from the mapping are
                passed through unchanged. Defaults to an empty mapping.
        """
        self.severity_scale = severity_scale or {}

    def extract_trend(self, sequence_rows, target_rel):
        """Extract trend values for a specific relation from sequence rows.

        Args:
            sequence_rows: Iterable of dict-like rows (accessed via ``.get``).
            target_rel: Key of the relation to read from each row.

        Returns:
            A list of values in row order. A row flagged ``'no changing'``
            repeats the most recent recorded value when one exists. For
            ``'severity'`` the value is translated through
            ``self.severity_scale`` and missing values are skipped; for any
            other relation the raw value is appended as-is (including None).
        """
        trend = []
        last_val = None
        for row in sequence_rows:
            # A "no changing" row carries the previous value forward; with no
            # previous value it falls through and is read like any other row.
            if row.get('no changing', False) and last_val is not None:
                trend.append(last_val)
                continue

            val = row.get(target_rel)
            if target_rel == 'severity':
                score = self.severity_scale.get(val, val)
                # Skip missing values (None, empty string, ...) but keep a
                # legitimate score of 0, which a bare truthiness test drops
                # and thereby distorts the detected direction.
                if not score and score != 0:
                    continue
                trend.append(score)
                last_val = score
            else:
                # Extensible for other relations
                trend.append(val)
                last_val = val
        return trend

    def detect_direction(self, trend_scores):
        """Detect the direction of a trend.

        Args:
            trend_scores: Sequence of values supporting subtraction.

        Returns:
            ``'stable'`` for fewer than two points or when every step is zero,
            ``'increasing'`` when no step goes down and at least one goes up,
            ``'decreasing'`` when no step goes up and at least one goes down,
            otherwise ``'mixed'``.
        """
        if not trend_scores or len(trend_scores) < 2:
            return 'stable'

        diffs = [b - a for a, b in zip(trend_scores, trend_scores[1:])]

        if all(d >= 0 for d in diffs) and any(d > 0 for d in diffs):
            return 'increasing'
        if all(d <= 0 for d in diffs) and any(d < 0 for d in diffs):
            return 'decreasing'
        if all(d == 0 for d in diffs):
            return 'stable'
        return 'mixed'

    def compare_trends_weighted(self, gt_rows, pred_rows, target_rel, rel_weights):
        """Compare trends between ground truth and predicted rows with weighting.

        Args:
            gt_rows: Ground-truth rows, in sequence order.
            pred_rows: Predicted rows, in sequence order.
            target_rel: Relation whose trend is compared.
            rel_weights: Mapping from relation name to its weight; a relation
                absent from the mapping has weight 0.

        Returns:
            1.0 when either side has fewer than two rows or both directions
            agree; ``1.0 - 0.5 * weight`` when exactly one side is stable;
            ``1.0 - weight`` for any other disagreement.
        """
        if len(gt_rows) < 2 or len(pred_rows) < 2:
            return 1.0  # Skip temporal evaluation for single-item sequences

        gt_dir = self.detect_direction(self.extract_trend(gt_rows, target_rel))
        pred_dir = self.detect_direction(self.extract_trend(pred_rows, target_rel))

        if gt_dir == pred_dir:
            return 1.0

        weight = rel_weights.get(target_rel, 0)
        if 'stable' in (gt_dir, pred_dir):
            return 1.0 - weight * 0.5   # Partial penalty
        # Neither side is stable and the directions differ (opposite
        # directions, or one of them mixed): full penalty.
        return 1.0 - weight

class MetricEvaluator:
    """Scores predictions against ground truth, one study at a time."""

    def __init__(self, weights, mode='rexval', subject_matching_mode="semantic"):
        """Build the helper objects used during evaluation.

        Args:
            weights: Indexable pair; element 0 is stored as the temporal
                weights and element 1 as the relation weights.
            mode: Evaluation mode, forwarded to the data converter and the
                subject matcher.
            subject_matching_mode: Stored on the instance as given.
        """
        temporal_part, relation_part = weights[0], weights[1]
        self.temporal_weights = temporal_part
        self.rel_weights = relation_part
        self.mode = mode
        self.data_converter = DataConverter(weights, mode)
        self.subject_matcher = SubjectMatcher(mode)
        self.trend_analyzer = TrendAnalyzer()
        self.subject_matching_mode = subject_matching_mode

    def evaluate_subject(self, gt_data, pred_data, mode='semantic',
                         temporal_eval=True):
        """Match one study's ground truth against its prediction.

        ``mode`` and ``temporal_eval`` are accepted for interface
        compatibility; this method does not read them.

        Returns:
            The six values produced by the subject matcher's
            ``semantic_match``: matched pairs, F1, precision, recall,
            unmatched ground-truth items and unmatched predicted items.
        """
        outcome = self.subject_matcher.semantic_match(
            gt_data, pred_data, self.rel_weights, self.temporal_weights)
        pairs, f1, precision, recall, missed_gt, extra_pred = outcome
        return pairs, f1, precision, recall, missed_gt, extra_pred

    def evaluate_dataset(self, gt_data_dict, pred_data_dict):
        """Evaluate every ground-truth study that also has a prediction.

        Studies absent from ``pred_data_dict`` are reported on stdout and
        left out of the returned mappings.

        Returns:
            Three dicts keyed by study id: F1, precision and recall.
        """
        f1_by_study, precision_by_study, recall_by_study = {}, {}, {}
        for study_id in gt_data_dict:
            if study_id not in pred_data_dict:
                print(f"Study ID {study_id}: Missing in predictions")
                continue

            outcome = self.evaluate_subject(
                gt_data_dict[study_id],
                pred_data_dict[study_id],
                mode=self.mode,
                temporal_eval=False,
            )
            # Only the three scalar scores are kept per study.
            _, f1, precision, recall, _, _ = outcome
            f1_by_study[study_id] = f1
            precision_by_study[study_id] = precision
            recall_by_study[study_id] = recall
        return f1_by_study, precision_by_study, recall_by_study

def _normalize_weights(raw_weights):
    """Return a copy of ``raw_weights`` scaled so that its values sum to 1."""
    total = sum(raw_weights.values())
    return {name: value / total for name, value in raw_weights.items()}


def _load_ground_truth(path):
    """Read the ground-truth CSV and keep rows that have a temporal group.

    The columns are renamed to their ``gt_``-prefixed names *before* the
    filter so the lookup works whether the file stores ``temporal_group`` or
    already stores ``gt_temporal_group``.
    """
    gt_df = pd.read_csv(path).rename(
        columns={'entity_group': 'gt_entity_group', 'temporal_group': 'gt_temporal_group'})
    return gt_df[gt_df['gt_temporal_group'].notna()]


def _load_predictions(path):
    """Read a prediction CSV, drop rows lacking a cluster or temporal group,
    and expose ``LLM_cluster`` under the name ``entity_group``."""
    pred_df = pd.read_csv(path)
    usable = pred_df['LLM_cluster'].notna() & pred_df['temporal_group'].notna()
    return pred_df[usable].rename(columns={'LLM_cluster': 'entity_group'})


def _print_mean_scores(title, scores_by_metric):
    """Print ``title`` followed by the mean per-study score of each metric.

    A metric with no per-study scores is reported as 0.000.
    """
    print(title)
    for metric_name, score_dict in scores_by_metric.items():
        avg_score = np.mean(list(score_dict.values())) if score_dict else 0.0
        print(f"{metric_name.upper()}: {avg_score:.3f}")


def _evaluate_and_report(evaluator, metric_data):
    """Evaluate every non-``'gt'`` entry of ``metric_data`` against the
    ``'gt'`` entry and print the averaged F1, precision and recall.

    Returns:
        Three dicts keyed by metric name, each holding that metric's
        per-study F1, precision and recall mapping respectively.
    """
    structure_scores = {}
    prec_scores = {}
    recall_scores = {}

    for metric_name, metric_pred_data in metric_data.items():
        if metric_name == 'gt':
            continue
        print(f"\n\nEvaluating {metric_name.upper()} metric:")
        # Use distinct names for the per-study results: reusing the
        # accumulator names here would overwrite them and lose every metric
        # but the last one.
        f1_by_study, prec_by_study, recall_by_study = evaluator.evaluate_dataset(
            metric_data['gt'], metric_pred_data)
        structure_scores[metric_name] = f1_by_study
        prec_scores[metric_name] = prec_by_study
        recall_scores[metric_name] = recall_by_study

    print("\n\nFinal Results:")
    _print_mean_scores("\nSemanctic subject matching & Structure Scores:", structure_scores)
    _print_mean_scores("\nPrecision Scores:", prec_scores)
    _print_mean_scores("\nRecall Scores:", recall_scores)
    return structure_scores, prec_scores, recall_scores


def main():
    """Main function to run the metric evaluation.

    Parses the command line, loads and normalizes the weights from
    ``metric_weights.yaml``, builds the structured inputs for the selected
    mode and prints the averaged F1 / precision / recall per evaluated source.

    Modes:
        * ``rexval``: compares the radgraph / bertscore / s_emb / bleu rows of
          ``./benchmark/final_rexval.csv`` with its ``gt_report`` rows.
        * ``sequential_maira`` / ``sequential_mask_maira`` / ``single_maira``:
          compares ``./benchmark/maira.csv`` and
          ``./benchmark/maira_cascade.csv`` with ``--gt_path``
          (``--pred_path`` is not read in these modes).
        * any other mode: compares ``--pred_path`` with ``--gt_path``.
    """
    # Argument parser setup
    parser = argparse.ArgumentParser(description='Run metric evaluation.')

    parser.add_argument('--mode', type=str, default='rexval',
                        choices=['rexval', 'single_maira', 'single_medgemma', 'single_libra', 'single_medversa','single_lingshu',
                         'sequential_maira', 'sequential_mask_maira', 'sequential_medgemma', 'sequential_libra', 'sequential_lingshu'],
                        help='Mode to run the metric evaluation.')

    parser.add_argument('--subject_matching_mode', type=str, default='semantic',
                        choices=['semantic', 'string'],
                        help="Subject matching mode")

    parser.add_argument('--gt_path', type=str, default='./dataset/gold_SR.csv',
                        help="Path to ground truth report CSV")

    parser.add_argument('--pred_path', type=str, default='./eval/pred_SR_df.csv',
                        help="Path to prediction report CSV")
    args = parser.parse_args()

    # Load metric weights
    with open('metric_weights.yaml', 'r') as file:
        metric_weights = yaml.safe_load(file)

    temporal_weights = _normalize_weights(metric_weights['TEMPORAL_WEIGHTS'])
    rel_weights = _normalize_weights(metric_weights['REL_WEIGHTS'])

    # Initialize evaluator
    evaluator = MetricEvaluator([temporal_weights, rel_weights], args.mode, args.subject_matching_mode)
    convert = evaluator.data_converter.convert_to_structured_dict

    if args.mode == 'rexval':
        rexval_result = pd.read_csv('./benchmark/final_rexval.csv')

        # (metric name, value of the `report_type` column); 'gt' is the reference.
        report_types = [
            ('gt', 'gt_report'),
            ('radgraph', 'radgraph'),
            ('bertscore', 'bertscore'),
            ('s_emb', 's_emb'),
            ('bleu', 'bleu'),
        ]
        metric_data = {
            metric_name: convert(rexval_result[rexval_result['report_type'] == report_type])
            for metric_name, report_type in report_types
        }
        _evaluate_and_report(evaluator, metric_data)

    elif args.mode in ['sequential_maira', 'sequential_mask_maira', 'single_maira']:
        # Sequential reports must contain a temporal_group (episode) and an
        # entity_group (subject) column.
        sequential_gt = _load_ground_truth(args.gt_path)

        # All three MAIRA modes read the same two benchmark files.
        maira_standard = _load_predictions('./benchmark/maira.csv')
        maira_cascade = _load_predictions('./benchmark/maira_cascade.csv')

        # Restrict every frame to the studies present in both prediction sets.
        interset_study = set(maira_cascade['study_id']) & set(maira_standard['study_id'])
        maira_cascade = maira_cascade[maira_cascade['study_id'].isin(interset_study)]
        maira_standard = maira_standard[maira_standard['study_id'].isin(interset_study)]
        sequential_gt = sequential_gt[sequential_gt['study_id'].isin(interset_study)]

        metric_data = {
            'maira_standard': convert(maira_standard),
            'maira_cascade': convert(maira_cascade),
            'gt': convert(sequential_gt)
        }
        _evaluate_and_report(evaluator, metric_data)

    elif 'maira' not in args.mode:
        # Sequential reports must contain a temporal_group (episode) and an
        # entity_group (subject) column.
        gt_df = _load_ground_truth(args.gt_path)
        pred_report = _load_predictions(args.pred_path)

        # Restrict both frames to the studies they have in common.
        interset_study = set(pred_report['study_id']) & set(gt_df['study_id'])
        pred_report = pred_report[pred_report['study_id'].isin(interset_study)]
        gt_df = gt_df[gt_df['study_id'].isin(interset_study)]

        metric_data = {
            'pred_report': convert(pred_report),
            'gt': convert(gt_df)
        }
        _evaluate_and_report(evaluator, metric_data)

# Run the evaluation only when this file is executed as a script; importing
# the module must not trigger argument parsing or file reads.
if __name__ == "__main__":
    
    main()