import pandas as pd
from typing import List, Dict, Any
import instructor
from pydantic import BaseModel, Field, RootModel
import os
import json
import numpy as np
from seqeuntial_normalize_prompt import Sequential_findings_review_prompt, Result_Review_w_time_gap_Ex1, Result_Review_w_time_gap_Ex2, Result_Review_w_time_gap_Ex3, Result_Review_w_time_gap_Ex4, Result_Review_w_time_gap_Ex5
from collections import defaultdict
from fuzzywuzzy import fuzz
import matplotlib.pyplot as plt
import seaborn as sns
import torch

# Endpoint path passed as `endpoint=` to `client.batches.create` in
# create_batch_job (the non-Claude / OpenAI-style batch branch).
END_POINT = "/v1/chat/completions"


def create_batch_job(file_name, client, args):
    """Create a batch job from a JSONL file of chat requests.

    The back end is chosen by ``args.LLM_name``:

    * Names starting with ``'claude'``: every non-blank JSONL line is parsed
      and turned into an Anthropic Message Batches ``Request``. Its
      ``custom_id`` and ``messages`` come from the line (defaulting to ``''``
      and ``[]``). Each request uses ``Sequential_findings_review_prompt`` as
      a cache-controlled system prompt. It also forces the
      ``sequential_matching`` tool, whose input schema is
      ``RadiologyOutput.model_json_schema()``.
    * Anything else: the file is uploaded unchanged via
      ``client.files.create(purpose="batch")``. A batch job is then created
      against ``END_POINT`` with a 24h completion window.

    Args:
        file_name: Path to the JSONL request file (read as UTF-8).
        client: An Anthropic client (Claude branch). For the other branch, a
            client exposing ``files.create`` and ``batches.create``.
        args: Object with an ``LLM_name`` string attribute.

    Returns:
        Whatever the client's batch ``create`` call returns.
    """
    if args.LLM_name.startswith('claude'):
        from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
        from anthropic.types.messages.batch_create_params import Request

        schema_format = [
            {
                "name": "sequential_matching",
                "description": "build the sequential matching object",
                "input_schema": RadiologyOutput.model_json_schema(),
            }
        ]
        # NOTE(review): 8092 may be a typo for 8192; kept as-is to preserve
        # behavior. Confirm the intended output-token cap.
        max_tokens = 4096 if '3-haiku' in args.LLM_name else 8092

        requests = []
        with open(file_name, 'r', encoding='utf-8') as jsonl:
            for line in jsonl:
                if not line.strip():
                    # Tolerate blank lines (e.g. a stray empty line at EOF)
                    # instead of letting json.loads raise on them.
                    continue
                task = json.loads(line)
                requests.append(
                    Request(
                        custom_id=task.get('custom_id', ''),
                        params=MessageCreateParamsNonStreaming(
                            model=args.LLM_name,
                            max_tokens=max_tokens,
                            system=[
                                {
                                    "type": "text",
                                    "text": Sequential_findings_review_prompt,
                                    "cache_control": {"type": "ephemeral"},
                                }
                            ],
                            messages=task.get('messages', []),
                            tools=schema_format,
                            tool_choice={"type": "tool", "name": "sequential_matching"},
                        ),
                    )
                )
        return client.beta.messages.batches.create(requests=requests)

    # Use a context manager so the upload handle is always closed. It was
    # previously leaked, even on success.
    with open(file_name, "rb") as upload:
        batch_file = client.files.create(file=upload, purpose="batch")
    return client.batches.create(
        input_file_id=batch_file.id,
        endpoint=END_POINT,
        completion_window="24h",
    )

def _clean_malformed_json(json_str):
    """JSON 문자열에서 일반적인 오류를 수정합니다."""
    import re
    
    # 1. 마크다운 코드 블록 제거 (```json ... ``` 또는 ``` ... ```)
    if json_str.strip().startswith('```'):
        # 시작과 끝의 코드 블록 마커 제거
        lines = json_str.strip().split('\n')
        
        # 첫 번째 줄이 ```json 또는 ```인 경우 제거
        if lines[0].strip().startswith('```'):
            lines = lines[1:]
            
        # 마지막 줄이 ```인 경우 제거
        if lines and lines[-1].strip() == '```':
            lines = lines[:-1]
            
        json_str = '\n'.join(lines)
    
    # 2. 앞뒤 공백 제거
    json_str = json_str.strip()
    
    # 3. 중복된 rationale 필드 처리
    # "rationale": "...", "episodes": [...], "rationale": "..." 패턴을 찾아서 수정
    pattern = r'"rationale":\s*"([^"]*(?:\\.[^"]*)*)",\s*"episodes":\s*(\[[^\]]*\]),\s*"rationale":\s*"([^"]*(?:\\.[^"]*)*)"'
    
    def replace_duplicate_rationale(match):
        rationale1 = match.group(1)
        episodes = match.group(2) 
        rationale2 = match.group(3)
        
        # 두 번째 rationale을 사용 (보통 더 간결하고 정확함)
        return f'"rationale": "{rationale2}", "episodes": {episodes}'
    
    cleaned = re.sub(pattern, replace_duplicate_rationale, json_str, flags=re.DOTALL)
    
    # 4. rationale 내부의 잘못된 ],가 있는 경우 제거
    pattern2 = r'"rationale":\s*"([^"]*)\s*\],([^"]*)"'
    cleaned = re.sub(pattern2, r'"rationale": "\1\2"', cleaned)
    
    # 5. 잘린 JSON인 경우 감지 및 경고
    if not cleaned.strip().endswith('}') and not cleaned.strip().endswith(']'):
        print("Warning: JSON appears to be truncated")
        # 기본적인 닫기 시도 (간단한 경우만)
        if cleaned.count('{') > cleaned.count('}'):
            # 여는 중괄호가 더 많은 경우
            missing_braces = cleaned.count('{') - cleaned.count('}')
            cleaned += '}' * missing_braces
        if cleaned.count('[') > cleaned.count(']'):
            # 여는 대괄호가 더 많은 경우
            missing_brackets = cleaned.count('[') - cleaned.count(']')
            cleaned += ']' * missing_brackets
    
    return cleaned

def _collect_complete_groups(text):
    """Return every top-level ``{...}`` object in *text* that parses on its own.

    *text* is the content that follows ``"results": [``. String literals are
    tracked, so braces and commas inside them are ignored. A top-level comma
    resets the accumulation buffer, and a trailing incomplete object is
    dropped.
    """
    groups = []
    depth = 0
    buf = []
    in_string = False
    escaped = False
    for ch in text:
        buf.append(ch)
        if in_string:
            if escaped:
                escaped = False
            elif ch == '\\':
                escaped = True
            elif ch == '"':
                in_string = False
            continue
        if ch == '"':
            in_string = True
        elif ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
            if depth == 0:
                candidate = ''.join(buf).strip()
                if candidate.startswith(','):
                    candidate = candidate[1:].strip()
                try:
                    groups.append(json.loads(candidate))
                except ValueError:
                    # Not parseable alone; the next top-level comma resets the buffer.
                    pass
                else:
                    buf = []
        elif ch == ',' and depth == 0:
            buf = []
    return groups


def _attempt_json_recovery(json_str):
    """Try harder to salvage JSON from a malformed or truncated LLM reply.

    The text is first cleaned with ``_clean_malformed_json``. Then these
    strategies are tried in order:
      1. If a ``"results"`` array appears opened but not closed (more ``[``
         than ``]``), keep only the group objects that parse individually.
         Return ``{"results": [...]}`` built from them.
      2. If the text starts with ``{``/``[`` but does not end with the matching
         closer, append one closer and return it if that parses.
      3. Regex-extract the body of the ``"results"`` array and return it
         re-wrapped, if that parses.

    Args:
        json_str: Raw response text.

    Returns:
        A JSON string accepted by ``json.loads``, or ``None`` if every
        strategy failed.
    """
    import re

    cleaned = _clean_malformed_json(json_str)

    # Strategy 1: keep only the complete groups of an unterminated results array.
    if '"results"' in cleaned and cleaned.count('[') > cleaned.count(']'):
        results_match = re.search(r'"results":\s*\[', cleaned)
        if results_match:
            groups = _collect_complete_groups(cleaned[results_match.end():])
            if groups:
                recovered = json.dumps({"results": groups})
                print(f"Recovered {len(groups)} complete groups from truncated JSON")
                return recovered

    # Strategy 2: add the single missing closer of the outermost value.
    stripped = cleaned.strip()
    closer = {'{': '}', '[': ']'}.get(stripped[:1])
    if closer and not stripped.endswith(closer):
        candidate = stripped + closer
        try:
            json.loads(candidate)
        except ValueError:
            pass
        else:
            return candidate

    # Strategy 3: pull out whatever sits inside "results": [ ... and re-wrap it.
    match = re.search(r'"results":\s*\[(.*?)(?:\]|}|$)', cleaned, re.DOTALL)
    if match:
        results_content = match.group(1).strip()
        if results_content:
            # Drop a dangling comma left by a cut-off trailing element.
            if results_content.endswith(','):
                results_content = results_content[:-1]
            candidate = '{"results": [' + results_content + ']}'
            try:
                json.loads(candidate)
            except ValueError:
                pass
            else:
                return candidate

    print("Could not recover JSON structure")
    return None

def _validate_content_structure(content_obj):
    """content_obj가 예상된 구조를 가지고 있는지 검증합니다."""
    if not isinstance(content_obj, dict):
        return False
        
    # 'results' 키가 있고 리스트인지 확인
    if 'results' not in content_obj:
        return False
        
    results = content_obj['results']
    if not isinstance(results, list):
        return False
        
    # 각 결과가 필요한 구조를 가지고 있는지 확인
    for result in results:
        if not isinstance(result, dict):
            continue
            
        required_fields = ['group_name', 'findings', 'episodes', 'rationale']
        if not all(field in result for field in required_fields):
            return False
            
        # findings와 episodes가 리스트인지 확인
        if not isinstance(result['findings'], list) or not isinstance(result['episodes'], list):
            return False
            
    return True

def _fix_content_structure(content_obj):
    """content_obj의 구조를 수정합니다."""
    if not isinstance(content_obj, dict):
        return content_obj
        
    # results 키가 없는 경우, 전체 객체를 results로 감싸기
    if 'results' not in content_obj:
        # 객체가 finding group들의 딕셔너리인 경우
        results = []
        for key, value in content_obj.items():
            if isinstance(value, dict):
                # group_name이 없으면 키를 group_name으로 사용
                if 'group_name' not in value:
                    value['group_name'] = key
                results.append(value)
            elif isinstance(value, list):
                # 리스트인 경우 각 항목을 개별 그룹으로 처리
                for item in value:
                    if isinstance(item, dict):
                        if 'group_name' not in item:
                            item['group_name'] = key
                        results.append(item)
        content_obj = {'results': results}
    
    # 각 결과의 구조 수정
    results = content_obj.get('results', [])
    fixed_results = []
    
    for result in results:
        if not isinstance(result, dict):
            continue
            
        fixed_result = {}
        
        # 필수 필드들 설정
        fixed_result['group_name'] = result.get('group_name', 'unnamed_group')
        fixed_result['findings'] = result.get('findings', [])
        fixed_result['episodes'] = result.get('episodes', [])
        fixed_result['rationale'] = result.get('rationale', '')
        
        # findings와 episodes가 리스트가 아닌 경우 수정
        if not isinstance(fixed_result['findings'], list):
            fixed_result['findings'] = []
            
        if not isinstance(fixed_result['episodes'], list):
            fixed_result['episodes'] = []
            
        fixed_results.append(fixed_result)
    
    content_obj['results'] = fixed_results
    return content_obj

def _fix_episodes_structure(episodes, group_name):
    """episodes 리스트의 구조를 검증하고 수정합니다."""
    if not episodes:
        return []
    
    fixed_episodes = []
    
    for i, episode in enumerate(episodes):
        try:
            if isinstance(episode, dict):
                # 필수 필드 확인
                episode_num = episode.get('episode', i + 1)
                days = episode.get('days', [])
                
                # days가 리스트가 아닌 경우 수정
                if not isinstance(days, list):
                    if isinstance(days, (int, float)):
                        days = [int(days)]
                    else:
                        print(f"Warning: Invalid days format in episode {episode_num} for group {group_name}")
                        days = []
                
                # episode 번호가 숫자가 아닌 경우 수정
                if not isinstance(episode_num, (int, float)):
                    episode_num = i + 1
                
                fixed_episodes.append({
                    'episode': int(episode_num),
                    'days': [int(day) for day in days if isinstance(day, (int, float))]
                })
            else:
                print(f"Warning: Episode is not a dict for group {group_name}: {episode}")
                # 기본 에피소드 생성
                fixed_episodes.append({
                    'episode': i + 1,
                    'days': []
                })
        except Exception as e:
            print(f"Error fixing episode structure for group {group_name}: {e}")
            # 기본 에피소드 생성
            fixed_episodes.append({
                'episode': i + 1,
                'days': []
            })
    
    return fixed_episodes

def process_radiology_output(response, input_data=None, is_missing_process=False):
    flattened_data = []
    
    stats = {
        'total_input_observations': 0,
        'matched_observations': 0,
        'unmatched_observations': 0
    }
    
    subject_id = input_data.get("subject_id", "") if input_data else ""
    
    content_obj = None
    try:
        if isinstance(response, str):
            try:
                # JSON 파싱 전에 잠재적인 오류 수정
                cleaned_response = _clean_malformed_json(response)
                content_obj = json.loads(cleaned_response)
            except json.JSONDecodeError as e:
                print(f"Error: Could not parse response as JSON after cleaning: {e}")
                print(f"Original response length: {len(response)}")
                print(f"Cleaned response length: {len(cleaned_response) if 'cleaned_response' in locals() else 'N/A'}")
                print(f"Response preview: {response[:200]}...")
                print(f"Cleaned preview: {cleaned_response[:200] if 'cleaned_response' in locals() else 'N/A'}...")
                
                # 추가적인 복구 시도
                try:
                    recovered_json = _attempt_json_recovery(response)
                    if recovered_json:
                        content_obj = json.loads(recovered_json)
                        print("Successfully recovered JSON after additional processing")
                    else:
                        return pd.DataFrame(), stats
                except Exception as recovery_error:
                    print(f"JSON recovery also failed: {recovery_error}")
                    return pd.DataFrame(), stats
        elif hasattr(response, 'root'):
            content_obj = response.root
        elif isinstance(response, dict):
            content_obj = response
        elif isinstance(response, RadiologyOutput):
            # Handle RadiologyOutput type
            content_obj = {"results": response.results}
        else:
            print(f"Error: Unsupported response type: {type(response)}")
            return pd.DataFrame(), stats
            
        if not content_obj:
            print("Error: No content object found in response")
            return pd.DataFrame(), stats
            
        # 추가 검증: content_obj가 예상된 구조를 가지고 있는지 확인
        if not _validate_content_structure(content_obj):
            print("Warning: Content object has unexpected structure, attempting to fix...")
            content_obj = _fix_content_structure(content_obj)
            
    except Exception as e:
        print(f"Error processing response: {e}")
        print(f"Response type: {type(response)}")
        return pd.DataFrame(), stats

    if not is_missing_process:
        all_input_observations = {}
        subject_unmatched = []
        
        if input_data:
            try:
                if isinstance(input_data["Input"], str):
                    input_json = json.loads(input_data["Input"])
                else:
                    input_json = input_data["Input"]
                    
                observations = input_json.get("observations", {})
                ent_idx_dict = input_data.get("Input_ent_idx", {})
                seq_dict = input_data.get("Input_seq", {})
                
                for idx_day_str, findings in observations.items():
                    idx_day_str = idx_day_str.split(", status")[0]
                    
                    for finding in findings:
                        
                        obs_key = f"{idx_day_str}|{finding}"
                        
                        ent_idx_val = ent_idx_dict.get(idx_day_str, -1)
                        seq_val = seq_dict.get(idx_day_str, -1)
                        
                        all_input_observations[obs_key] = {
                            "idx_day_str": idx_day_str,
                            "finding": finding,
                            "ent_idx": ent_idx_val,
                            "seq": seq_val,
                            "matched": False,
                            "batch_idx": input_data.get("batch_idx", 0),
                            "group_name": input_data.get("Group_name", "")
                        }
                        
                        stats['total_input_observations'] += 1
            except Exception as e:
                print(f"Error parsing input data: {e}")
        
        if 'results' in content_obj:
            findings_list = content_obj.get('results', [])
            
            for finding_group in findings_list:
                # Extract data from the finding group with improved error handling
                try:
                    if isinstance(finding_group, dict):
                        group_name = finding_group.get('group_name', 'unnamed_group')
                        findings = finding_group.get('findings', [])
                        episodes = finding_group.get('episodes', [])
                        rationale = finding_group.get('rationale', '')
                    else:  # Pydantic 모델
                        group_name = getattr(finding_group, 'group_name', 'unnamed_group')
                        findings = getattr(finding_group, 'findings', [])
                        episodes = getattr(finding_group, 'episodes', [])
                        rationale = getattr(finding_group, 'rationale', '')
                    
                    # 추가 검증 및 수정
                    if not isinstance(findings, list):
                        print(f"Warning: findings is not a list for group {group_name}, converting...")
                        findings = [] if findings is None else [findings]
                    
                    if not isinstance(episodes, list):
                        print(f"Warning: episodes is not a list for group {group_name}, converting...")
                        episodes = [] if episodes is None else [episodes]
                    
                    # episodes 구조 검증 및 수정
                    episodes = _fix_episodes_structure(episodes, group_name)
                    
                except Exception as e:
                    print(f"Error processing finding group: {e}")
                    print(f"Finding group data: {finding_group}")
                    continue
                
                for finding in findings:
                    temporal_group = 1
                    
                    # 안전하게 속성 접근 with improved validation
                    try:
                        if isinstance(finding, dict):
                            finding_IDX = finding.get('IDX', -1)
                            finding_day = finding.get('DAY', -1)
                            finding_text = finding.get('finding', '')
                        else:  # Pydantic 모델
                            finding_IDX = getattr(finding, 'IDX', -1)
                            finding_day = getattr(finding, 'DAY', -1)
                            finding_text = getattr(finding, 'finding', '')
                        
                        # 데이터 타입 검증 및 수정
                        finding_IDX = int(finding_IDX) if isinstance(finding_IDX, (int, float, str)) and str(finding_IDX).isdigit() else -1
                        finding_day = int(finding_day) if isinstance(finding_day, (int, float, str)) and str(finding_day).isdigit() else -1
                        finding_text = str(finding_text) if finding_text is not None else ''
                        
                    except Exception as e:
                        print(f"Error processing finding in group {group_name}: {e}")
                        print(f"Finding data: {finding}")
                        continue
                    
                    # 에피소드에서 일치하는 날짜 찾기
                    for ep in episodes:
                        if isinstance(ep, dict):
                            episode_days = ep.get('days', [])
                            if finding_day in episode_days:
                                temporal_group = ep.get('episode', -1)
                                break
                        else:  # Pydantic 모델
                            episode_days = getattr(ep, 'days', [])
                            if finding_day in episode_days:
                                temporal_group = getattr(ep, 'episode', -1)
                                break
                    
                    idx_day_str = f"IDX:{finding_IDX}, DAY: {finding_day}"
                    obs_key = f"{idx_day_str}|{finding_text}"
                    
                    if obs_key in all_input_observations:
                        all_input_observations[obs_key]["matched"] = True
                        stats['matched_observations'] += 1
                        
                        data_row = {
                            'subject_id': subject_id,
                            'batch_idx': all_input_observations[obs_key]["batch_idx"],
                            'LLM_cluster': group_name,
                            'episodes': str(episodes),
                            'rationale': rationale,
                            'ent_idx': all_input_observations[obs_key]["ent_idx"],
                            'IDX': finding_IDX,
                            'sequence': all_input_observations[obs_key]["seq"],
                            'temporal_group': temporal_group,
                            'DAY': finding_day,
                            'finding': finding_text,
                            'status': 'matched'
                        }
                        
                        if input_data and input_data.get("Group_name"):
                            data_row["1st_cluster"] = input_data["Group_name"]
                        
                        flattened_data.append(data_row)
        else:
            raise ValueError("No SCHEMA or PYDANTIC ERROR found in content_obj")
        
        for obs_key, obs_data in all_input_observations.items():
            if not obs_data["matched"]:
                stats['unmatched_observations'] += 1
                
                idx_day_parts = obs_data["idx_day_str"].split(", DAY: ")
                idx_day_val = idx_day_parts[1] if len(idx_day_parts) > 1 else ""
                
                IDX_parts = idx_day_parts[0].split(":")
                llm_IDX_val = IDX_parts[1] if len(IDX_parts) > 1 else -1
                
                unmatched_info = {
                    'subject_id': subject_id,
                    'batch_idx': obs_data["batch_idx"],
                    'idx_day_str': obs_data["idx_day_str"],
                    'DAY': idx_day_val,
                    'finding': obs_data["finding"],
                    'ent_idx': obs_data["ent_idx"],
                    'seq': obs_data["seq"],
                    'status': 'not_in_output'
                }
                
                subject_unmatched.append(unmatched_info)
                
                unmatched_row = {
                    'subject_id': subject_id,
                    'batch_idx': obs_data["batch_idx"],
                    'LLM_cluster': 'UNMATCHED',
                    'episodes': '[]',
                    'rationale': '',
                    'IDX': llm_IDX_val,
                    'ent_idx': obs_data["ent_idx"],
                    'sequence': obs_data["seq"],
                    'temporal_group': -1,
                    'DAY': idx_day_val,
                    'finding': obs_data["finding"],
                    'status': 'unmatched'
                }
                
                if input_data and input_data.get("Group_name"):
                    unmatched_row["1st_cluster"] = input_data["Group_name"]
                
                flattened_data.append(unmatched_row)
        
        if subject_unmatched:
            print(f"\n=== Unmatched observations for subject {subject_id} ===")
            for unmatched in subject_unmatched:
                print(f"  - Finding: {unmatched['finding']}, Day: {unmatched['DAY']}, Status: {unmatched['status']}")
    
    else:
        missing_tracking = {
            'total': 0,
            'matched': 0,
            'unmatched': 0
        }
        
        processed_findings = set()
        cluster_name = input_data.get("Group_name", "") if input_data else ""

        if 'results' in content_obj:
            findings_list = content_obj['results']
            
            for finding_group in findings_list:
                # 안전하게 속성 접근 with improved error handling
                try:
                    if isinstance(finding_group, dict):
                        group_name = finding_group.get('group_name', 'unnamed_group')
                        findings = finding_group.get('findings', [])
                        episodes = finding_group.get('episodes', [])
                        rationale = finding_group.get('rationale', '')
                    else:  # Pydantic 모델
                        group_name = getattr(finding_group, 'group_name', 'unnamed_group')
                        findings = getattr(finding_group, 'findings', [])
                        episodes = getattr(finding_group, 'episodes', [])
                        rationale = getattr(finding_group, 'rationale', '')
                    
                    # 추가 검증 및 수정 (missing process에서도 동일하게)
                    if not isinstance(findings, list):
                        print(f"Warning: findings is not a list for group {group_name}, converting...")
                        findings = [] if findings is None else [findings]
                    
                    if not isinstance(episodes, list):
                        print(f"Warning: episodes is not a list for group {group_name}, converting...")
                        episodes = [] if episodes is None else [episodes]
                    
                    # episodes 구조 검증 및 수정
                    episodes = _fix_episodes_structure(episodes, group_name)
                    
                except Exception as e:
                    print(f"Error processing finding group in missing process: {e}")
                    print(f"Finding group data: {finding_group}")
                    continue
                
                # findings 처리
                for finding in findings:
                    missing_tracking['total'] += 1
                    stats['total_input_observations'] += 1
                    
                    # 안전하게 속성 접근 with improved validation
                    try:
                        if isinstance(finding, dict):
                            finding_idx = finding.get('IDX', -1)
                            finding_day = finding.get('DAY', -1)
                            finding_text = finding.get('finding', '')
                        else:  # Pydantic 모델
                            finding_idx = getattr(finding, 'IDX', -1)
                            finding_day = getattr(finding, 'DAY', -1)
                            finding_text = getattr(finding, 'finding', '')
                        
                        # 데이터 타입 검증 및 수정
                        finding_idx = int(finding_idx) if isinstance(finding_idx, (int, float, str)) and str(finding_idx).isdigit() else -1
                        finding_day = int(finding_day) if isinstance(finding_day, (int, float, str)) and str(finding_day).isdigit() else -1
                        finding_text = str(finding_text) if finding_text is not None else ''
                        
                    except Exception as e:
                        print(f"Error processing finding in missing process for group {group_name}: {e}")
                        print(f"Finding data: {finding}")
                        continue
                    
                    # 에피소드 정보 찾기
                    temporal_group = 1
                    for ep in episodes:
                        if isinstance(ep, dict):
                            episode_days = ep.get('days', [])
                            if finding_day in episode_days:
                                temporal_group = ep.get('episode', -1)
                                break
                        else:  # Pydantic 모델
                            episode_days = getattr(ep, 'days', [])
                            if finding_day in episode_days:
                                temporal_group = getattr(ep, 'episode', -1)
                                break
                    
                    idx_day_str = f"IDX:{finding_idx}, DAY: {finding_day}"
                    finding_key = f"{idx_day_str}|{finding_text}"
                    
                    if finding_key in processed_findings:
                        continue
                    
                    processed_findings.add(finding_key)
                    missing_tracking['matched'] += 1
                    
                    data_row = {
                        'subject_id': subject_id,
                        'batch_idx': input_data.get('batch_idx', ''),
                        'LLM_cluster': group_name,
                        'episodes': str(episodes),
                        'rationale': rationale,
                        'ent_idx': -1,
                        'IDX': finding_idx,
                        'sequence': -1,
                        'temporal_group': temporal_group,
                        'DAY': finding_day,
                        'finding': finding_text,
                        'status': 'matched'
                    }
                    
                    data_row["1st_cluster"] = cluster_name
                    
                    flattened_data.append(data_row)
        else:            
            for group_name, group_data in content_obj.items():
                if isinstance(group_data, str):
                    print(f"Warning: Group data is a string: {group_data[:100]}...")
                    continue
                    
                if hasattr(group_data, 'timeframe'):
                    timeframe = group_data.timeframe
                    rationale = group_data.rationale
                    episodes = group_data.episodes
                    findings = group_data.findings
                else:
                    if isinstance(group_data, list):
                        print(f"Warning: group_data is a list with {len(group_data)} items, using first item")
                        if group_data and isinstance(group_data[0], dict):
                            group_data = group_data[0]
                        else:
                            print(f"Warning: Cannot process group_data list: {group_data[:2]}...")
                            continue

                    timeframe = group_data.get('timeframe', '')
                    rationale = group_data.get('rationale', '')
                    episodes = group_data.get('episodes', [])
                    findings = group_data.get('findings', [])
                
                for finding in findings:
                    missing_tracking['total'] += 1
                    stats['total_input_observations'] += 1
                    
                    try:
                        if hasattr(finding, 'DAY'):
                            finding_idx = finding.IDX
                            finding_day = finding.DAY
                            finding_text = finding.finding
                        else:
                            finding_idx = finding.get('IDX', -1)
                            finding_day = finding.get('DAY', -1)
                            finding_text = finding.get('finding', '')
                        
                        # 데이터 타입 검증 및 수정
                        finding_idx = int(finding_idx) if isinstance(finding_idx, (int, float, str)) and str(finding_idx).isdigit() else -1
                        finding_day = int(finding_day) if isinstance(finding_day, (int, float, str)) and str(finding_day).isdigit() else -1
                        finding_text = str(finding_text) if finding_text is not None else ''
                        
                    except Exception as e:
                        print(f"Error processing finding in alternative structure for group {group_name}: {e}")
                        print(f"Finding data: {finding}")
                        continue
                        
                    temporal_group = 1
                    for ep in episodes:
                        if hasattr(ep, 'days'):
                            episode_days = ep.days
                            if finding_day in episode_days:
                                temporal_group = ep.episode
                                break
                        else:
                            episode_days = ep.get('days', [])
                            if finding_day in episode_days:
                                temporal_group = ep.get('episode', -1)
                                break
                    
                    idx_day_str = f"IDX:{finding_idx}, DAY: {finding_day}"
                    finding_key = f"{idx_day_str}|{finding_text}"
                    
                    if finding_key in processed_findings:
                        continue
                    
                    processed_findings.add(finding_key)
                    missing_tracking['matched'] += 1
                    
                    data_row = {
                        'subject_id': subject_id,
                        'batch_idx': input_data.get('batch_idx', ''),
                        'LLM_cluster': group_name,
                        'episodes': str(episodes),
                        'rationale': rationale,
                        'ent_idx': -1,
                        'IDX': finding_idx,
                        'sequence': -1,
                        'temporal_group': temporal_group,
                        'DAY': finding_day,
                        'finding': finding_text,
                        'status': 'matched'
                    }
                    
                    data_row["1st_cluster"] = cluster_name
                    
                    flattened_data.append(data_row)
            
        # 미싱 통계 출력
        print(f"  Missing statistics:")
        print(f"  - Total observations: {missing_tracking['total']}")
        print(f"  - Matched: {missing_tracking['matched']}")
        print(f"  - Unmatched: {missing_tracking['total'] - missing_tracking['matched']} \n")
        
        # 중요: 미싱 처리 완료 후 전체 통계 계산
        stats['matched_observations'] = missing_tracking['matched']
        stats['unmatched_observations'] = missing_tracking['total'] - missing_tracking['matched']
    
    # 통계 출력
    print("=== Statistics ===")
    print(f"Total input observations: {stats['total_input_observations']}")
    print(f"Matched observations: {stats['matched_observations']}")
    print(f"Unmatched observations: {stats['unmatched_observations']}")
    
    # 데이터프레임 생성 및 정렬
    df = pd.DataFrame(flattened_data)
    if not df.empty:
        # 필요한 컬럼이 없는 경우 추가
        if '1st_cluster' not in df.columns:
            df['1st_cluster'] = input_data.get("Group_name", "") if input_data else ""
        
        sort_columns = ['subject_id', 'LLM_cluster', 'temporal_group', 'DAY', 'IDX']
        sort_columns = [col for col in sort_columns if col in df.columns]
        if sort_columns:
            df = df.sort_values(sort_columns)
    else:
        # 빈 데이터프레임인 경우 필요한 컬럼 추가
        df = pd.DataFrame(columns=[
            'subject_id', 'batch_idx', '1st_cluster', 'LLM_cluster',
            'episodes', 'rationale', 'ent_idx', 'IDX', 'sequence',
            'temporal_group', 'DAY', 'finding', 'status'
        ])
    
    return df, stats

def convert_sets_to_lists(obj):
    """Return *obj* with every ``set`` replaced by a ``list``.

    Dicts and lists are walked recursively and rebuilt as new containers;
    a set is converted with ``list()`` without recursing into its elements.
    Any other value (including tuples) is returned unchanged.
    """
    if isinstance(obj, set):
        return list(obj)
    if isinstance(obj, list):
        return list(map(convert_sets_to_lists, obj))
    if isinstance(obj, dict):
        return dict(zip(obj.keys(), map(convert_sets_to_lists, obj.values())))
    return obj

def read_batch_results_to_csv(args, batch_results_file, clustered_df, all_model_outputs, is_missing_process=False):
    """Parse an LLM batch-results JSONL file into one DataFrame and save it as CSV.

    Each line must be a JSON object with a ``custom_id`` of the form
    ``"<subject_id>_<batch_idx>"`` (the format written by ``creating_batch_file``).
    The model text is taken from a top-level ``content`` field when present,
    otherwise from ``response.body.choices[0].message.content`` (``body`` may be
    a JSON string). Each content string is handed to ``process_radiology_output``
    together with the stored prompt inputs for that subject/batch.

    Args:
        args: Namespace; ``args.output_path`` is the directory the CSV is written to.
        batch_results_file: Path of the results JSONL file.
        clustered_df: Unused here; kept for interface compatibility.
        all_model_outputs: Mapping ``subject_id -> batch index -> dict`` with keys
            ``'Input'``, ``'Input_ent_idx'``, ``'Input_seq'`` and ``'Group_name'``.
            Batch indices may be stored as ``int`` or ``str``; both are tried.
        is_missing_process: Selects ``missing_outputs.csv`` instead of
            ``output_df.csv`` and is forwarded to ``process_radiology_output``.

    Returns:
        The concatenation of all successfully processed results (an empty
        DataFrame when nothing could be parsed). Lines that cannot be parsed are
        reported and skipped.
    """
    import traceback

    print("\nTemporal mapping of elements within each cluster:")
    line_count = 0
    success_count = 0
    frames = []  # per-result DataFrames; concatenated once after the loop
    print("batch_results_file", batch_results_file)

    with open(batch_results_file, 'r') as file:
        for line in file:
            line_count += 1
            try:
                result = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"Error parsing JSONL line {line_count}: {e}")
                print(f"Skipping problematic line: {line[:100]}...")
                continue
            if not isinstance(result, dict):
                print(f"Warning: JSONL line {line_count} is not a JSON object; skipping")
                continue

            custom_id = str(result.get('custom_id', ''))
            # The batch index is always the LAST segment; subject IDs may themselves
            # contain underscores, so split from the right exactly once.
            parts = custom_id.rsplit('_', 1)
            if len(parts) != 2:
                print(f"Warning: Unexpected custom_id format: {custom_id}")
                continue
            subject_id, batch_idx = parts
            print(f"\nsubject_id {subject_id}, batch_idx {batch_idx}")

            try:
                try:
                    batch_idx_key = int(batch_idx)
                except ValueError:
                    batch_idx_key = batch_idx

                if subject_id not in all_model_outputs:
                    print(f"Warning: Subject ID {subject_id} not found in model outputs")
                    continue
                subject_outputs = all_model_outputs[subject_id]

                if batch_idx_key not in subject_outputs:
                    # Keys become strings when the mapping has round-tripped through JSON.
                    if isinstance(batch_idx_key, int) and str(batch_idx_key) in subject_outputs:
                        batch_idx_key = str(batch_idx_key)
                    else:
                        print(f"Warning: Batch index {batch_idx} not found for subject {subject_id}")
                        print(f"Available keys: {list(subject_outputs.keys())}")
                        continue

                stored = subject_outputs[batch_idx_key]
                input_data = {
                    "batch_idx": batch_idx,
                    "subject_id": subject_id,
                    "Input": stored['Input'],
                    "Input_ent_idx": stored['Input_ent_idx'],
                    "Input_seq": stored['Input_seq'],
                    "Group_name": stored['Group_name'],
                }

                if "content" in result:
                    content = result.get("content", "")
                else:
                    # `or {}` also covers explicit nulls (e.g. failed batch requests).
                    response = result.get('response') or {}
                    body = response.get('body') or {}

                    if isinstance(body, str):
                        try:
                            body = json.loads(body)
                        except json.JSONDecodeError:
                            print(f"Warning: Could not parse body as JSON for custom_id {custom_id}")
                            continue

                    choices = body.get('choices') or []
                    if not choices:
                        print(f"Warning: No choices found for custom_id {custom_id}")
                        continue

                    message = choices[0].get('message') or {}
                    content = message.get('content', '')

                if not content:
                    print(f"Warning: No content found for custom_id {custom_id}")
                    continue

                try:
                    llm_output_df, _stats = process_radiology_output(content, input_data, is_missing_process)
                except Exception as e:
                    print(f"Error processing content: {e}")
                    print(f"Detailed error: {traceback.format_exc()}")
                    continue
                frames.append(llm_output_df)
                success_count += 1

            except Exception as e:
                print(f"Error extracting content: {e}")
                print(f"Result structure: {list(result.keys())}")
                continue

    output_df = pd.concat(frames) if frames else pd.DataFrame()

    output_name = "missing_outputs.csv" if is_missing_process else "output_df.csv"
    output_path = f"{args.output_path}/{output_name}"
    if not output_df.empty:
        try:
            output_df.to_csv(output_path, index=False)
            print(f"Successfully saved output to {output_path}")
        except Exception as e:
            print(f"Error saving output: {e}")

    print(f"\nProcessed {line_count} lines, successfully parsed {success_count}")

    print(f"\nCreated DataFrame with shape: {output_df.shape}")
    print(f"Columns: {output_df.columns.tolist()}")
    return output_df

def _batch_load_gpt_json_schema(schema_file_path="./sequentialSR/gpt_json_schema.json"):
    """Load the structured-output JSON schema attached to GPT batch requests."""
    with open(schema_file_path, 'r') as f:
        return json.load(f)


def _batch_format_conversation(conversation):
    """Return the well-formed messages of *conversation*, ready for a batch request.

    Sets are converted to lists first. Messages that are dicts with both
    ``role`` and ``content`` are kept (dict/list ``content`` is stringified with
    ``str()``); anything else is reported and dropped.
    """
    formatted = []
    for msg in convert_sets_to_lists(conversation):
        if isinstance(msg, dict) and 'role' in msg and 'content' in msg:
            # NOTE(review): str() produces a Python repr, not JSON; list-of-blocks
            # content (as generate_few_shot builds for Claude) is therefore sent as
            # its repr text. Kept as-is to preserve current prompts — confirm intent.
            if isinstance(msg['content'], (dict, list)):
                msg['content'] = str(msg['content'])
            formatted.append(msg)
        else:
            print(f"Skipping invalid message format: {msg}")
    return formatted


def _batch_build_task(custom_id, messages, args, temperature, json_schema=None):
    """Build one batch-request line for ``args.LLM_name``.

    GPT models get the OpenAI batch envelope (with a ``json_schema`` response
    format when *json_schema* is given); other models get a bare
    ``{"custom_id", "messages"}`` record.
    """
    if not args.LLM_name.startswith('gpt'):
        return {"custom_id": custom_id, "messages": messages}
    body = {
        "model": args.LLM_name,
        "temperature": temperature,
        "messages": messages,
    }
    if json_schema is not None:
        body["response_format"] = {"type": "json_schema", "json_schema": json_schema}
    return {"custom_id": custom_id, "method": "POST", "url": END_POINT, "body": body}


def _batch_append_task(tasks, task, subject_id, group_label):
    """Append *task* to *tasks* only if it is JSON-serializable; report it otherwise."""
    try:
        json.dumps(task)
    except TypeError as e:
        print(f"Error creating task for {subject_id}, group {group_label}: {e}")
        print(f"Problematic task structure: {task}")
    else:
        tasks.append(task)


def creating_batch_file(clustered_df, args, missing_data=None, all_inputs=None, iteration=0):
    """Write the JSONL batch-request file for ``args.LLM_name`` and return ``all_inputs``.

    Two modes:

    * ``missing_data is None``: one request per (subject_id, cluster_name) group
      of ``clustered_df`` that has more than one observation. The prompt input of
      each group is recorded in ``all_inputs[subject_id][group_idx]`` (keys
      ``Input``, ``Input_ent_idx``, ``Input_seq``, ``Group_name``) so results can
      be mapped back by ``custom_id = f"{subject_id}_{group_idx}"``. Written to
      ``{args.batch_path}/llm_batch.jsonl``.
    * otherwise: one request per cluster in ``missing_data``
      (``{subject_id: {cluster_name: content}}``, presumably the output of
      ``prepare_missing_inputs`` — confirm), written to
      ``{args.batch_path}/missing_batch{iteration}.jsonl``. ``all_inputs`` is not
      updated in this mode.

    Args:
        clustered_df: Needs columns subject_id, cluster_name, study_id,
            ELA_cur_ent, day_from_first ("<N> days" strings), dx_status,
            ent_idx and sequence.
        args: Namespace with ``LLM_name`` and ``batch_path``.
        missing_data: Optional missing-observation inputs (see above).
        all_inputs: Dict updated in place (a new one is created when None).
        iteration: Suffix for the missing-data batch file name.

    Returns:
        ``all_inputs``.
    """
    print(f"\n Creating batch file for {args.LLM_name}, number of subject_id: {clustered_df.subject_id.nunique()}")
    if all_inputs is None:
        all_inputs = {}

    tasks = []
    if missing_data is None:
        gpt_json_schema = None  # loaded lazily, once, and only for GPT models
        for subject_id in clustered_df.subject_id.unique():
            print(f'Patient: "{subject_id}"')
            all_inputs[subject_id] = {}

            subject_df = clustered_df[clustered_df['subject_id'] == subject_id]
            group_names = subject_df.cluster_name.unique()
            print("number of groups: ", len(group_names))

            for group_idx, group_name in enumerate(group_names):
                print(f'  Group: "{group_name}"')
                cur_group_df = subject_df[subject_df['cluster_name'] == group_name]
                group_inputs = all_inputs[subject_id].setdefault(group_idx, {})
                print(f"Number of observations in group: {len(cur_group_df)}, std_len: {len(cur_group_df['study_id'].unique())}")

                # A single observation has nothing to order temporally.
                if len(cur_group_df) <= 1:
                    print(f"Skipping group {group_name} - only has {len(cur_group_df)} observation(s)")
                    continue

                dict_obs, ent_idx_info, seq_info = {}, {}, {}
                # to_list() yields plain Python scalars (JSON-serializable).
                rows = zip(
                    cur_group_df['ELA_cur_ent'].to_list(),
                    cur_group_df['day_from_first'].to_list(),
                    cur_group_df['dx_status'].to_list(),
                    cur_group_df['ent_idx'].to_list(),
                    cur_group_df['sequence'].to_list(),
                )
                for idx_counter, (obs, day, status, ent_idx, seq) in enumerate(rows):
                    day_num = int(day.split()[0])  # "X days" -> X
                    idx_day_key = f"IDX:{idx_counter}, DAY: {day_num}"
                    dict_obs[f"{idx_day_key}, status: {status}"] = [obs]
                    ent_idx_info[idx_day_key] = ent_idx
                    seq_info[idx_day_key] = seq

                input_json = {
                    "cluster_name": group_name,
                    "observations": dict_obs
                }
                Input = json.dumps(input_json, ensure_ascii=False)

                group_inputs["Input"] = Input
                group_inputs["Input_ent_idx"] = ent_idx_info
                group_inputs["Input_seq"] = seq_info
                group_inputs["Group_name"] = group_name

                formatted_conversation = _batch_format_conversation(generate_few_shot(Input, args))

                if args.LLM_name.startswith('gpt') and gpt_json_schema is None:
                    gpt_json_schema = _batch_load_gpt_json_schema()
                task = _batch_build_task(
                    f"{subject_id}_{group_idx}", formatted_conversation, args,
                    temperature=0.0, json_schema=gpt_json_schema,
                )
                _batch_append_task(tasks, task, subject_id, group_name)
        file_name = f"{args.batch_path}/llm_batch.jsonl"

    else:
        for subject_id, subject_clusters in missing_data.items():
            print(f"\n=== Processing missing observations for Subject ID: {subject_id} ===")
            print("missing group number of subject_id: ", len(subject_clusters))

            # NOTE(review): group_idx here enumerates the *missing* clusters, not the
            # original cluster order stored in all_inputs; if results are resolved
            # through all_inputs by this index, verify the mapping is intended.
            for group_idx, (cluster_name, cluster_content) in enumerate(subject_clusters.items()):
                print("Processing cluster:", cluster_name)

                formatted_conversation = _batch_format_conversation(generate_few_shot(cluster_content, args))
                task = _batch_build_task(
                    f"{subject_id}_{group_idx}", formatted_conversation, args,
                    temperature=0.1,
                )
                _batch_append_task(tasks, task, subject_id, cluster_name)

        file_name = f"{args.batch_path}/missing_batch{iteration}.jsonl"

    # Every task was already checked to be JSON-serializable when it was appended.
    with open(file_name, 'w', encoding='utf-8') as file:
        for task in tasks:
            file.write(json.dumps(task) + '\n')

    print("Batch file created!")
    print(f"\nSummary:")
    print(f"Total tasks created: {len(tasks)}")
    print(f"Batch file: {file_name} \n")
    return all_inputs

def _parse_episodes_cell(episodes_data):
    """Normalise an ``episodes`` cell into a list of ``{'episode', 'days'}`` dicts.

    Accepted forms:
      * a list of dicts (returned as-is),
      * a list of objects exposing ``episode`` and ``days`` attributes,
      * a repr string such as ``"[Episode(episode=1, days=[0, 5])]"``,
      * a JSON or Python-literal string of a list.
    Unparseable strings and non-list parse results give ``[]``; any other value
    is returned unchanged.
    """
    import ast
    import re

    if isinstance(episodes_data, list):
        # Dicts must be checked first: they have no __dict__, so the object check
        # below is vacuously true for them and would discard every episode.
        if all(isinstance(ep, dict) for ep in episodes_data if ep):
            return episodes_data
        if all(hasattr(ep, 'episode') for ep in episodes_data if hasattr(ep, '__dict__')):
            return [
                {'episode': ep.episode, 'days': ep.days}
                for ep in episodes_data
                if hasattr(ep, 'episode') and hasattr(ep, 'days')
            ]
        return episodes_data

    if not isinstance(episodes_data, str):
        return episodes_data

    if 'Episode(' in episodes_data:
        pattern = r'Episode\(episode=(\d+), days=\[([\d, ]*)\]\)'
        return [
            {
                'episode': int(match.group(1)),
                'days': [int(d.strip()) for d in match.group(2).split(',') if d.strip()],
            }
            for match in re.finditer(pattern, episodes_data)
        ]

    try:
        parsed = json.loads(episodes_data)
    except json.JSONDecodeError:
        try:
            parsed = ast.literal_eval(episodes_data)
        except (SyntaxError, ValueError):
            print(f"Warning: Could not parse episodes: {episodes_data}")
            return []
    return parsed if isinstance(parsed, list) else []


def prepare_missing_inputs(result_df):
    """Collect unmatched observations per (subject, cluster) for a follow-up LLM pass.

    A row counts as unmatched when ``llm_processed == 'unmatched'``, or when
    ``llm_processed`` is missing and its (subject_id, cluster_name) cluster has
    more than one row. Every other row is grouped by ``LLM_cluster`` into the
    "existing results" that are shown alongside the unmatched observations.

    Args:
        result_df: DataFrame with columns subject_id, cluster_name, llm_processed,
            DAY, ELA_cur_ent, LLM_cluster, episodes, rationale and optionally IDX
            (``-1`` is used when absent).

    Returns:
        ``{subject_id: {cluster_name: {"existing_results": {llm_cluster: {...}},
        "unprocessed_observations": {"IDX:<i>, DAY: <d>": [finding, ...]},
        "cluster_name": cluster_name}}}`` containing only clusters that have at
        least one unmatched observation.
    """
    unmatched_dict = defaultdict(lambda: defaultdict(list))
    matched_dict = defaultdict(lambda: defaultdict(dict))

    cluster_sizes = result_df.groupby(['subject_id', 'cluster_name']).size()
    multi_obs_clusters = {key for key, size in cluster_sizes.items() if size > 1}

    for _, row in result_df.iterrows():
        subject_id, cluster_name, status = row['subject_id'], row['cluster_name'], row['llm_processed']
        idx_value = row.get('IDX', -1)

        # Missing status only counts as "unmatched" in clusters with several rows;
        # a lone observation was never sent to the LLM in the first place.
        if status == 'unmatched' or (pd.isna(status) and (subject_id, cluster_name) in multi_obs_clusters):
            unmatched_dict[subject_id][cluster_name].append({
                'DAY': row['DAY'],
                'finding': row['ELA_cur_ent'],
                'IDX': idx_value
            })
            continue

        llm_cluster = row['LLM_cluster']
        cluster_groups = matched_dict[subject_id][cluster_name]
        if llm_cluster not in cluster_groups:
            try:
                processed_episodes = _parse_episodes_cell(row['episodes'])
            except Exception as e:
                processed_episodes = []
                print(f"Error parsing episodes: {e}")
            cluster_groups[llm_cluster] = {
                'findings': [],
                'episodes': processed_episodes,
                'rationale': row['rationale']
            }

        cluster_groups[llm_cluster]['findings'].append({
            'IDX': idx_value,
            'DAY': row['DAY'],
            'finding': row['ELA_cur_ent']
        })

    final_inputs = {}
    for subject_id, clusters in unmatched_dict.items():
        subject_inputs = {}
        for cluster_name, unmatched_items in clusters.items():
            if not unmatched_items:
                continue

            unprocessed_observations = {}
            for item in unmatched_items:
                idx_day_str = f"IDX:{item['IDX']}, DAY: {item['DAY']}"
                unprocessed_observations.setdefault(idx_day_str, []).append(item["finding"])

            subject_inputs[cluster_name] = {
                "existing_results": matched_dict[subject_id].get(cluster_name, {}),
                "unprocessed_observations": unprocessed_observations,
                "cluster_name": cluster_name
            }

        if subject_inputs:
            final_inputs[subject_id] = subject_inputs

    return final_inputs

def create_missing_input(subject_id, cluster_name, result_data):
    """Render the follow-up prompt for one cluster's previously missed observations.

    Looks up ``result_data[subject_id][cluster_name]``, which must provide
    ``"existing_results"`` and ``"unprocessed_observations"``. Both are embedded
    in the prompt as indented JSON encoded with ``NumpyEncoder``.

    Returns:
        The prompt string, or None when the subject or cluster is absent.
    """
    has_cluster = subject_id in result_data and cluster_name in result_data[subject_id]
    if not has_cluster:
        return None

    cluster = result_data[subject_id][cluster_name]
    existing_json, missing_json = (
        json.dumps(cluster[section], indent=2, cls=NumpyEncoder)
        for section in ("existing_results", "unprocessed_observations")
    )

    return f'''The following observations were previously missed. Your task is to:
    1. Review each missing observation
    2. If there are existing results, assign each observation to the appropriate existing group and episode when applicable
    3. If existing results are empty or not suitable for some observations, create new groups as needed
    4. Return the COMPLETE JSON for ALL groups (both existing and newly created)

    Missing_observations: {missing_json}
    Existing_results: {existing_json}
    
    Note: 
    - If existing_results is empty, you should create completely new groups for the missing observations
    - Your output should follow the original format directly, with Group Names as top-level keys
    - DO NOT wrap your output in "Missing_observations" or "Existing_results" structure
    - Return a direct JSON object with all groups in the same format as the examples you saw earlier
    '''
 
def _post_process_match_keys(df, id_columns):
    """Build per-row keys ``"<subject_id>_<id...>_<ELA_cur_ent>"`` as a numpy array.

    Each id column is NaN-filled with 0 and rendered as an integer. The array is
    aligned with ``df``'s rows. A missing ``ELA_cur_ent`` may produce a non-string
    key; callers ignore such keys.
    """
    key = df['subject_id'].astype(str)
    for col in id_columns:
        key = key + '_' + df[col].fillna(0).astype(float).astype(int).astype(str)
    return (key + '_' + df['ELA_cur_ent']).to_numpy()


def post_process(llm_output, clustered_df, output_path, is_missing_process=False, iteration=0):
    """Copy LLM grouping results from *llm_output* onto matching rows of *clustered_df*.

    Rows are matched on subject_id plus ``sequence``/``ent_idx`` (first pass) or
    ``IDX`` (``is_missing_process=True``), plus the finding text. For each
    matching key, the first llm_output row with that key supplies
    llm_processed, LLM_cluster, episodes, rationale, temporal_group, DAY and IDX
    for every clustered_df row with that key.

    *clustered_df* is updated in place: the result columns are added when missing
    (first pass), and the matched rows are overwritten. No temporary key column
    is ever added to it.

    Args:
        llm_output: DataFrame from ``read_batch_results_to_csv``; its ``finding``
            and ``status`` columns are read as ``ELA_cur_ent`` / ``llm_processed``.
        clustered_df: Target DataFrame.
        output_path: Unused; kept for interface compatibility.
        is_missing_process: Match on IDX instead of sequence/ent_idx.
        iteration: Unused; kept for interface compatibility.

    Returns:
        *clustered_df* (the same object), or unchanged when *llm_output* is
        None/empty or lacks the columns needed for matching.
    """
    # Guard before touching llm_output: calling .rename on None would raise.
    if llm_output is None or llm_output.empty:
        print("Warning: llm_output is empty; skipping post_process updates.")
        return clustered_df

    llm_output = llm_output.rename(columns={
        'finding': 'ELA_cur_ent',  # match clustered_df's column name
        'status': 'llm_processed'
    })

    if 'subject_id' not in llm_output.columns or 'ELA_cur_ent' not in llm_output.columns:
        print(f"Warning: llm_output missing required columns. Available columns: {list(llm_output.columns)}")
        return clustered_df

    if is_missing_process and 'IDX' not in llm_output.columns:
        print(f"Warning: llm_output missing 'IDX' column in missing process. Available columns: {list(llm_output.columns)}")
        return clustered_df

    if not is_missing_process:
        default_columns = {
            'LLM_cluster': None,
            'episodes': None,
            'rationale': None,
            'llm_processed': None,
            'temporal_group': None,
            'DAY': None,
            'IDX': None
        }
        for col, default_val in default_columns.items():
            if col not in clustered_df.columns:
                clustered_df[col] = default_val
        id_columns = ('sequence', 'ent_idx')
    else:
        id_columns = ('IDX',)

    target_keys = _post_process_match_keys(clustered_df, id_columns)
    llm_keys = _post_process_match_keys(llm_output, id_columns)

    # First llm_output row per key, indexed for O(1) lookup.
    llm_first = (
        llm_output.assign(_match_key=llm_keys)
        .drop_duplicates(subset='_match_key', keep='first')
        .set_index('_match_key')
    )

    matching_keys = {k for k in set(target_keys) & set(llm_keys) if isinstance(k, str)}

    update_columns = ['llm_processed', 'LLM_cluster', 'episodes', 'rationale',
                      'temporal_group', 'DAY', 'IDX']
    match_count = 0
    for key in matching_keys:
        try:
            llm_row = llm_first.loc[key]
            mask = target_keys == key  # computed once per key, reused for every column
            for col in update_columns:
                clustered_df.loc[mask, col] = llm_row[col]
            match_count += 1
        except Exception as e:
            print(f"Error updating match for key {key}: {e}")

    print(f"Updated {match_count} records out of {len(matching_keys)} matching keys")

    return clustered_df


def initialize_llm_client(llm_name):
    """Create the inference client for *llm_name*.

    Returns:
        * ``(instructor client, None)`` for OpenAI-compatible endpoints:
          OpenAI for 'gpt*', Fireworks for 'deepseek*' / 'llama4*' / 'qwen3*',
          and a local server on ``$PORT`` for 'baichuan*' / 'vllm_medgemma*' /
          'vllm_gpt-oss-120b*' / 'vllm_gpt-oss-20b*'. The API key is read
          from ``$API_KEY``.
        * ``(model, tokenizer)`` for 'MedGemma*' (a Hugging Face model loaded
          in bfloat16 with ``device_map="auto"``).
        * None for any other name (including 'claude*'). NOTE(review): callers
          that unpack two values will fail on such names.
    """
    fireworks_prefixes = ('deepseek', 'llama4', 'qwen3')
    local_server_prefixes = ('baichuan', 'vllm_medgemma', 'vllm_gpt-oss-120b', 'vllm_gpt-oss-20b')

    if llm_name.startswith(('gpt',) + fireworks_prefixes + local_server_prefixes):
        from openai import OpenAI
        client_kwargs = {'api_key': os.getenv('API_KEY')}
        if llm_name.startswith(fireworks_prefixes):
            client_kwargs['base_url'] = "https://api.fireworks.ai/inference/v1"
        elif llm_name.startswith(local_server_prefixes):
            client_kwargs['base_url'] = f"http://localhost:{os.getenv('PORT')}/v1"
        return instructor.from_openai(OpenAI(**client_kwargs), mode=instructor.Mode.JSON), None

    if llm_name.startswith('MedGemma'):
        import torch
        from transformers import AutoTokenizer, AutoModelForCausalLM
        cache_dir = "../medgemma"
        hf_model_id = "google/medgemma-27b-text-it"
        hf_token = os.getenv('API_KEY')
        loaded_model = AutoModelForCausalLM.from_pretrained(
            hf_model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            cache_dir=cache_dir,
            token=hf_token,
        )
        loaded_tokenizer = AutoTokenizer.from_pretrained(
            hf_model_id,
            cache_dir=cache_dir,
            token=hf_token,
        )
        return loaded_model, loaded_tokenizer

    return None


def estimate_llm_cost(args, prompt_tokens, completion_tokens, total_cost):
    """Return the USD cost of one call for ``args.LLM_name`` and log comparison costs.

    Appends ``prompt_tokens`` / ``completion_tokens`` and the cost the same call
    would have had under several reference models to the lists in *total_cost*
    (keys ``"prompt_tokens"``, ``"completion_tokens"``, ``"gpt-45-cost"``, ...).

    Args:
        args: Namespace with ``LLM_name``.
        prompt_tokens: Number of input tokens.
        completion_tokens: Number of output tokens.
        total_cost: Mapping of key -> list; must accept every key listed below
            (e.g. a ``defaultdict(list)``).

    Returns:
        ``(cost_in_usd, total_cost)``.

    Raises:
        ValueError: if ``args.LLM_name`` has no configured price. This is checked
            before *total_cost* is modified.
    """
    # USD per 1K tokens: (input, output).
    price_per_1k = {
        'gpt-4o-mini': (0.00015, 0.0006),            # $0.15 / $0.6 per 1M
        'gpt-4-5': (0.075, 0.15),                    # $75 / $150 per 1M
        'gpt-4.1': (0.002, 0.008),                   # $2 / $8 per 1M
        'gpt-4o-batch': (0.00125, 0.005),            # $1.25 / $5 per 1M
        'gpt-4o': (0.0025, 0.01),                    # $2.5 / $10 per 1M
        'o3-mini': (0.0011, 0.0044),                 # $1.1 / $4.4 per 1M
        'o3-mini-batch': (0.00055, 0.0022),          # $0.55 / $2.2 per 1M
        'o1': (0.015, 0.060),                        # $15 / $60 per 1M
        'claude-3-7-sonnet': (0.003, 0.015),         # $3 / $15 per 1M
        'claude-3-opus': (0.015, 0.075),             # $15 / $75 per 1M
        'claude-3-7-sonnet-batch': (0.0015, 0.0075), # $1.5 / $7.5 per 1M
        'claude-3-opus-batch': (0.0075, 0.0375),     # $7.5 / $37.5 per 1M
        'claude-3-haiku-20240307': (0.00025, 0.00125),  # $0.25 / $1.25 per 1M
    }
    # total_cost key -> reference model whose price it records (order preserved).
    comparison_models = {
        "gpt-45-cost": 'gpt-4-5',
        "gpt-4o-batch-cost": 'gpt-4o-batch',
        "gpt-4o-cost": 'gpt-4o',
        "o1-cost": 'o1',
        "claude-3-7-sonnet-cost": 'claude-3-7-sonnet',
        "claude-3-opus-cost": 'claude-3-opus',
        "claude-3-7-sonnet-batch-cost": 'claude-3-7-sonnet-batch',
        "claude-3-opus-batch-cost": 'claude-3-opus-batch',
    }

    try:
        input_cost, output_cost = price_per_1k[args.LLM_name]
    except KeyError:
        raise ValueError(
            f"No pricing configured for LLM_name={args.LLM_name!r}; "
            f"known models: {sorted(price_per_1k)}"
        ) from None

    total_cost["prompt_tokens"].append(prompt_tokens)
    total_cost["completion_tokens"].append(completion_tokens)
    for cost_key, model_name in comparison_models.items():
        in_rate, out_rate = price_per_1k[model_name]
        total_cost[cost_key].append((in_rate*prompt_tokens/1000 + out_rate*completion_tokens/1000))
    return (input_cost*prompt_tokens/1000 + output_cost*completion_tokens/1000), total_cost

def generate_few_shot(Input, args):
    """Build the chat message list for a single sequential-matching request.

    Claude models (``args.LLM_name`` starting with ``'claude'``) always get the
    five worked examples as alternating user/assistant turns followed by
    ``Input``.  Their user turns use the content-block form
    ``[{"type": "text", "text": ...}]``.  No system message is placed in this
    list for Claude.

    Every other model gets a leading system message holding
    ``Sequential_findings_review_prompt``.  The five examples are inserted only
    when ``args.few_shot`` is truthy.

    Args:
        Input: Text of the query, appended as the final user turn.
        args: Run configuration.  ``LLM_name`` is always read; ``few_shot`` is
            read only for non-Claude models.

    Returns:
        list[dict]: Messages in send order.  Assistant example turns hold the
        example output serialized with ``json.dumps``.
    """
    worked_examples = (
        Result_Review_w_time_gap_Ex1,
        Result_Review_w_time_gap_Ex2,
        Result_Review_w_time_gap_Ex3,
        Result_Review_w_time_gap_Ex4,
        Result_Review_w_time_gap_Ex5,
    )

    if args.LLM_name.startswith('claude'):
        def text_block(text):
            # Example turns carry no ``cache_control`` marker.
            return [{"type": "text", "text": text}]

        conversation = []
        for example in worked_examples:
            conversation.append({"role": "user", "content": text_block(example['Input'])})
            conversation.append({"role": "assistant", "content": json.dumps(example['Output'])})
        conversation.append({"role": "user", "content": text_block(Input)})
        return conversation

    conversation = [{"role": "system", "content": Sequential_findings_review_prompt}]
    if args.few_shot:
        for example in worked_examples:
            conversation.append({"role": "user", "content": example['Input']})
            conversation.append({"role": "assistant", "content": json.dumps(example['Output'])})
    conversation.append({"role": "user", "content": Input})
    return conversation

def _cluster_name_similarity(pred_name, true_name, pred_tokens, true_tokens, fuzzy_weight, jaccard_weight):
    """Weighted blend of fuzzy token-sort ratio and token-set Jaccard similarity."""
    fuzzy_similarity = fuzz.token_sort_ratio(pred_name, true_name) / 100.0
    if not pred_tokens or not true_tokens:
        jaccard_similarity = 0
    else:
        # Both sets are non-empty here, so the union is never empty.
        jaccard_similarity = len(pred_tokens & true_tokens) / len(pred_tokens | true_tokens)
    return fuzzy_weight * fuzzy_similarity + jaccard_weight * jaccard_similarity


def _best_cluster_name_match(pred_name, true_clusters, fuzzy_weight, jaccard_weight):
    """Return ``(best_true_name, best_score)`` for one predicted name; ``(None, 0)`` if nothing scores above 0."""
    pred_tokens = set(pred_name.lower().split())
    best_match, best_score = None, 0
    for true_name in true_clusters:
        score = _cluster_name_similarity(
            pred_name, true_name, pred_tokens, set(true_name.lower().split()),
            fuzzy_weight, jaccard_weight,
        )
        # Strict '>' keeps the first true name on ties.
        if score > best_score:
            best_match, best_score = true_name, score
    return best_match, best_score


def combined_precision_recall(true_clusters, pred_clusters, fuzzy_weight=0.6, jaccard_weight=0.4, threshold=0.6):
    """Soft precision/recall between predicted and ground-truth cluster names.

    Each predicted name is matched to its most similar ground-truth name, where
    ``similarity = fuzzy_weight * token_sort_ratio/100
    + jaccard_weight * Jaccard(lower-cased whitespace tokens)``.

    - A best match scoring at least ``threshold`` adds its *score* (not 1) to
      the soft true-positive total and marks that ground-truth name as matched.
    - A best match below ``threshold`` counts the prediction as a false positive.
    - Ground-truth names that no prediction matched are false negatives.

    Several predictions may match the same ground-truth name.  Each one adds
    to the true-positive total, but the name counts once as matched.

    Args:
        true_clusters: Sized collection of ground-truth names (``len`` is used).
        pred_clusters: Iterable of predicted names; it is consumed exactly once.
        fuzzy_weight: Weight of the fuzzy token-sort ratio.
        jaccard_weight: Weight of the token Jaccard similarity.
        threshold: Minimum combined similarity for a match.

    Returns:
        tuple: ``(precision, recall)``; each is 0 when its denominator is 0.
    """
    tp_combined = 0
    fp = 0
    matched_true_clusters = set()

    # One pass: the same best match drives both the TP/FP tally and the
    # matched-set used for FN.
    for pred_cluster in pred_clusters:
        best_match, best_score = _best_cluster_name_match(
            pred_cluster, true_clusters, fuzzy_weight, jaccard_weight
        )
        if best_score >= threshold:
            tp_combined += best_score  # similarity acts as a soft TP weight
            matched_true_clusters.add(best_match)
        else:
            fp += 1

    fn = len(true_clusters) - len(matched_true_clusters)

    precision = tp_combined / (tp_combined + fp) if (tp_combined + fp) > 0 else 0
    recall = tp_combined / (tp_combined + fn) if (tp_combined + fn) > 0 else 0
    return precision, recall

def text_f1_score(pred_name, true_name):
    """Harmonic mean of the soft precision and recall between name collections.

    Note the argument order: this takes predictions first, but
    ``combined_precision_recall`` receives the ground truth first.

    Returns:
        float: ``2 * P * R / (P + R)``, or ``0.0`` when ``P + R`` is 0.
    """
    precision, recall = combined_precision_recall(true_name, pred_name)
    denominator = precision + recall
    return 0.0 if denominator == 0 else 2 * precision * recall / denominator

def group_wise_accuracy(true_groups, pred_groups):
    """Per-ground-truth-group clustering accuracy.

    For every distinct label in ``true_groups``, find the most frequent
    predicted label among that group's items.  Report the fraction of items
    carrying that label.  ``Counter`` is used for the mode, so negative and
    non-numeric labels work.  On ties, the predicted label encountered first
    wins.

    Args:
        true_groups: Array-like of ground-truth labels (list, ndarray, Series).
        pred_groups: Array-like of predicted labels, positionally aligned with
            ``true_groups``.

    Returns:
        dict: Maps each ground-truth label to ``{'accuracy': float,
        'count': int, 'most_common_pred': label}``.  Numeric labels, both the
        keys and ``most_common_pred``, are converted to ``float``.  Labels that
        select no rows (NaN never equals itself) are omitted.
    """
    from collections import Counter

    def as_key(value):
        # Numeric labels are normalised to Python floats; others are kept as-is.
        return float(value) if isinstance(value, (int, float, np.number)) else value

    # With plain lists, ``true_groups == group`` would be a single bool rather
    # than an element-wise mask, and every group would be skipped.
    true_groups = np.asarray(true_groups)
    pred_groups = np.asarray(pred_groups)

    group_accuracies = {}
    for group in np.unique(true_groups):
        group_mask = true_groups == group
        count = int(np.sum(group_mask))
        if count == 0:
            continue

        pred_for_group = pred_groups[group_mask]
        most_common_pred = Counter(pred_for_group).most_common(1)[0][0]
        correct = np.sum(pred_for_group == most_common_pred)

        group_accuracies[as_key(group)] = {
            'accuracy': float(correct / count),
            'count': count,
            'most_common_pred': as_key(most_common_pred),
        }

    return group_accuracies

def calculate_purity_fscore(df, *, zero_division=0.0):
    """Compute purity and pair-counting precision/recall/F1 for both groupings.

    Column pairs (ground truth vs prediction):

    - entity grouping: ``gt_entity_group`` vs ``LLM_cluster``
    - temporal grouping: ``gt_temporal_group`` vs ``temporal_group``

    Missing labels are replaced by the string ``'nan'``, and every label is
    compared as a string, so all missing values form one group.

    Pair-counting scores treat each unordered pair of rows as a binary
    "same cluster?" decision:

    - precision = shared pairs / predicted pairs
    - recall = shared pairs / ground-truth pairs
    - F1 = 2 * shared / (predicted + ground-truth pairs), the binary F1 over
      all row pairs

    They are derived from cluster sizes (sum of C(size, 2)), so cost stays
    near-linear in ``len(df)`` instead of building an n x n pair matrix.

    Args:
        df: DataFrame containing the four label columns above.
        zero_division: Value reported when a ratio's denominator is zero, e.g.
            precision when no two rows share a predicted cluster, or any
            score for an empty frame.

    Returns:
        dict: ``entity_purity``, ``entity_f1``, ``entity_precision``,
        ``entity_recall``, ``temporal_purity``, ``temporal_f1``,
        ``temporal_precision``, ``temporal_recall``.
    """
    import numpy as np

    def labels_of(column):
        # Missing labels collapse into the single string label 'nan'.
        return df[column].fillna('nan').astype(str).to_numpy()

    def calculate_purity(y_true, y_pred):
        """Fraction of rows whose true label is the majority label of their predicted cluster."""
        if len(y_true) == 0:
            return zero_division
        contingency_matrix = pd.crosstab(y_pred, y_true)
        return contingency_matrix.max(axis=1).sum() / contingency_matrix.to_numpy().sum()

    def count_pairs(cluster_sizes):
        """Number of unordered same-cluster pairs: sum of C(size, 2)."""
        sizes = np.asarray(cluster_sizes, dtype=np.int64)
        return int(np.sum(sizes * (sizes - 1) // 2))

    def pair_scores(y_true, y_pred):
        """Pair-counting (precision, recall, f1) of ``y_pred`` against ``y_true``."""
        if len(y_true) == 0:
            return zero_division, zero_division, zero_division
        true_pairs = count_pairs(np.unique(y_true, return_counts=True)[1])
        pred_pairs = count_pairs(np.unique(y_pred, return_counts=True)[1])
        # A pair is shared when both rows agree on the true AND predicted label,
        # i.e. it lies inside one cell of the contingency table.
        shared_pairs = count_pairs(pd.crosstab(y_true, y_pred).to_numpy().ravel())

        precision = shared_pairs / pred_pairs if pred_pairs else zero_division
        recall = shared_pairs / true_pairs if true_pairs else zero_division
        denominator = pred_pairs + true_pairs
        f1 = 2 * shared_pairs / denominator if denominator else zero_division
        return precision, recall, f1

    entity_true, entity_pred = labels_of('gt_entity_group'), labels_of('LLM_cluster')
    temporal_true, temporal_pred = labels_of('gt_temporal_group'), labels_of('temporal_group')

    entity_precision, entity_recall, entity_f1 = pair_scores(entity_true, entity_pred)
    temporal_precision, temporal_recall, temporal_f1 = pair_scores(temporal_true, temporal_pred)

    return {
        'entity_purity': calculate_purity(entity_true, entity_pred),
        'entity_f1': entity_f1,
        'entity_precision': entity_precision,
        'entity_recall': entity_recall,
        'temporal_purity': calculate_purity(temporal_true, temporal_pred),
        'temporal_f1': temporal_f1,
        'temporal_precision': temporal_precision,
        'temporal_recall': temporal_recall
    }

def calculate_subject_purity_fscore(df):
    """Per-subject purity and pair-counting scores, plus their unweighted means.

    ``calculate_purity_fscore`` is run on the rows of each ``subject_id`` that
    has at least two rows.  Subjects with fewer rows are skipped.

    Args:
        df: DataFrame with a ``subject_id`` column plus the label columns that
            ``calculate_purity_fscore`` reads.

    Returns:
        dict: ``subject_purity_fscore`` (subject_id -> metrics dict), followed
        by ``<metric>_mean`` for each of the eight metrics.  Means are plain
        ``np.mean`` over subjects, so a NaN from any subject propagates.  When
        no subject qualifies, every mean is 0; otherwise ``np.mean([])`` would
        yield NaN with a RuntimeWarning.
    """
    metric_names = (
        'entity_purity', 'entity_f1', 'entity_precision', 'entity_recall',
        'temporal_purity', 'temporal_f1', 'temporal_precision', 'temporal_recall',
    )

    subject_metrics = {}
    for subject_id in df['subject_id'].unique():
        subject_df = df[df['subject_id'] == subject_id]
        # Only score subjects with enough rows to form at least one pair.
        if len(subject_df) >= 2:
            subject_metrics[subject_id] = calculate_purity_fscore(subject_df)

    if subject_metrics:
        avg_metrics = {
            f'{name}_mean': np.mean([m[name] for m in subject_metrics.values()])
            for name in metric_names
        }
    else:
        avg_metrics = {f'{name}_mean': 0 for name in metric_names}

    return {
        'subject_purity_fscore': subject_metrics,
        **avg_metrics
    }
        
def analyze_low_performance_subjects(df, threshold=0.6):
    """Flag subjects whose entity-grouping ARI is below ``threshold`` and print diagnostics.

    ARI is computed per ``subject_id`` between ``gt_entity_group`` and
    ``LLM_cluster``.  Missing labels are replaced with ``'nan'`` and labels are
    compared as strings.  All subjects are scored first.  Then, for each flagged
    subject, its ARI and a ground-truth-vs-prediction crosstab are printed.  The
    crosstab uses the raw columns, without the ``'nan'`` fill applied for ARI.

    Args:
        df: DataFrame with ``subject_id``, ``gt_entity_group`` and ``LLM_cluster``.
        threshold: Subjects with ARI strictly below this value are flagged.

    Returns:
        list[tuple]: ``(subject_id, ari)`` for each flagged subject, in the
        order subjects first appear in ``df``.
    """
    from sklearn.metrics import adjusted_rand_score

    def as_labels(series):
        return series.fillna('nan').astype(str)

    # Pass 1: score every subject and keep the frames of the flagged ones.
    flagged = []
    for sid in df['subject_id'].unique():
        rows = df[df['subject_id'] == sid]
        score = adjusted_rand_score(as_labels(rows['gt_entity_group']),
                                    as_labels(rows['LLM_cluster']))
        if score < threshold:
            flagged.append((sid, score, rows))

    # Pass 2: print the diagnostics for each flagged subject.
    for sid, score, rows in flagged:
        print(f"\n분석 대상 주제: {sid}, ARI: {score:.4f}")
        mapping_table = pd.crosstab(rows['gt_entity_group'], rows['LLM_cluster'])
        print("\nGT 그룹과 예측 그룹 간의 매핑 테이블:")
        print(mapping_table)

    return [(sid, score) for sid, score, _ in flagged]

def calculate_text_f1(df):
    """Score how consistently identical texts land in the same entity group.

    Rows are grouped by their text in ``concatenated``, falling back to
    ``ELA_cur_ent`` when that column is absent.  For each text, *consistency*
    is the share of its rows carrying the most frequent label.  It is computed
    separately for the ground truth (``gt_entity_group``) and the prediction
    (``LLM_cluster``).  Rows with a missing label count in the denominator but
    never toward the mode.  Both consistencies are averaged with weights equal
    to the rows per text, then combined with a harmonic mean.

    Args:
        df: DataFrame with a text column plus ``gt_entity_group`` and
            ``LLM_cluster``.

    Returns:
        The harmonic mean (float).  Returns 0 when neither text column exists,
        when there is no non-missing text to score, or when both weighted
        consistencies are 0.
    """
    if 'concatenated' in df.columns:
        text_column = 'concatenated'
    else:
        print("Warning: 'concatenated' column not found. Using 'ELA_cur_ent' as fallback.")
        if 'ELA_cur_ent' not in df.columns:
            print("Warning: Neither 'concatenated' nor 'ELA_cur_ent' columns found. Returning 0 for text F1.")
            return 0
        text_column = 'ELA_cur_ent'

    total_items = 0
    gt_weighted_sum = 0
    pred_weighted_sum = 0
    # One groupby pass instead of re-filtering the frame for every unique text;
    # missing texts are dropped (they previously contributed zero weight).
    for _, text_items in df.groupby(text_column, sort=False, observed=True):
        n_items = len(text_items)

        gt_counts = text_items['gt_entity_group'].value_counts()
        gt_consistency = gt_counts.max() / n_items if len(gt_counts) > 0 else 0

        pred_counts = text_items['LLM_cluster'].value_counts()
        pred_consistency = pred_counts.max() / n_items if len(pred_counts) > 0 else 0

        total_items += n_items
        gt_weighted_sum += n_items * gt_consistency
        pred_weighted_sum += n_items * pred_consistency

    # Nothing to score (empty frame or only missing texts): avoid 0/0.
    if total_items == 0:
        return 0

    weighted_gt_consistency = gt_weighted_sum / total_items
    weighted_pred_consistency = pred_weighted_sum / total_items

    if weighted_gt_consistency + weighted_pred_consistency == 0:
        return 0

    return 2 * (weighted_gt_consistency * weighted_pred_consistency) / (weighted_gt_consistency + weighted_pred_consistency)

def calculate_temporal_group_accuracies(df):
    """Per ground-truth temporal group, how concentrated the predicted groups are.

    For each value of ``gt_temporal_group`` (missing values form their own
    group), report the most frequent ``temporal_group`` among that group's rows.
    Also report the fraction of the group's rows with that prediction.  Rows
    whose prediction is missing count toward ``count`` but never toward the
    mode.

    Args:
        df: DataFrame with ``gt_temporal_group`` and ``temporal_group``.

    Returns:
        dict: Ground-truth group -> ``{'accuracy', 'count', 'most_common_pred'}``.
        Non-missing groups appear in order of first appearance.
        ``most_common_pred`` is None and ``accuracy`` is 0 when none of the
        group's rows has a prediction.
    """
    accuracies = {}

    # dropna=False keeps a NaN ground-truth group with its real rows; filtering
    # with ``df[col == group]`` matches nothing for NaN and reported count 0.
    for group, group_items in df.groupby('gt_temporal_group', sort=False, dropna=False, observed=True):
        pred_groups = group_items['temporal_group'].value_counts()
        if len(pred_groups) > 0:
            most_common_pred = pred_groups.idxmax()
            accuracy = pred_groups.max() / len(group_items)
        else:
            most_common_pred = None
            accuracy = 0

        accuracies[group] = {
            'accuracy': accuracy,
            'count': len(group_items),
            'most_common_pred': most_common_pred
        }

    return accuracies

def calculate_subject_metrics(df):
    """Per-subject clustering agreement metrics and their unweighted means.

    For every ``subject_id`` with at least two rows, two groupings are scored:

    - temporal: ``gt_temporal_group`` vs ``temporal_group``
    - entity: ``gt_entity_group`` vs ``LLM_cluster``

    Each grouping gets ARI, NMI, homogeneity, completeness and V-measure.
    ``calculate_text_f1`` is also run on the subject's rows.  Missing labels
    become the string ``'nan'``, and labels are compared as strings.

    Args:
        df: DataFrame with ``subject_id`` and the four label columns above.

    Returns:
        dict: ``<metric>_mean`` for each metric, the plain mean over qualifying
        subjects (0 for every metric when none qualify).  Also
        ``subject_metrics``, mapping each subject_id to its per-metric dict,
        which includes the row ``count``.
    """
    from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
    from sklearn.metrics import homogeneity_completeness_v_measure

    # (result-key prefix, ground-truth column, predicted column)
    groupings = (
        ('temporal', 'gt_temporal_group', 'temporal_group'),
        ('entity', 'gt_entity_group', 'LLM_cluster'),
    )
    averaged_keys = [
        f'{prefix}_{name}'
        for prefix, _, _ in groupings
        for name in ('ARI', 'NMI', 'homogeneity', 'completeness', 'v_measure')
    ] + ['text_f1']

    subject_metrics = {}
    for subject_id in df['subject_id'].unique():
        subject_df = df[df['subject_id'] == subject_id]

        # Skip subjects with fewer than two rows.
        if len(subject_df) < 2:
            continue

        metrics = {}
        for prefix, true_col, pred_col in groupings:
            # Convert each label column once instead of once per metric call.
            y_true = subject_df[true_col].fillna('nan').astype(str)
            y_pred = subject_df[pred_col].fillna('nan').astype(str)
            # homogeneity_score / completeness_score / v_measure_score each
            # call this same function internally; compute it once.
            homogeneity, completeness, v_measure = homogeneity_completeness_v_measure(y_true, y_pred)
            metrics[f'{prefix}_ARI'] = adjusted_rand_score(y_true, y_pred)
            metrics[f'{prefix}_NMI'] = normalized_mutual_info_score(y_true, y_pred)
            metrics[f'{prefix}_homogeneity'] = homogeneity
            metrics[f'{prefix}_completeness'] = completeness
            metrics[f'{prefix}_v_measure'] = v_measure

        metrics['text_f1'] = calculate_text_f1(subject_df)
        metrics['count'] = len(subject_df)
        subject_metrics[subject_id] = metrics

    count = len(subject_metrics)
    if count == 0:
        result = {f'{key}_mean': 0 for key in averaged_keys}
        result['subject_metrics'] = {}
        return result

    result = {
        f'{key}_mean': sum(m[key] for m in subject_metrics.values()) / count
        for key in averaged_keys
    }
    result['subject_metrics'] = subject_metrics
    return result

def eval_func(args, clustered_df):
    from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
    from sklearn.metrics import homogeneity_score, completeness_score, v_measure_score
    import matplotlib.pyplot as plt
    import seaborn as sns
    import pandas as pd
    import numpy as np
    import os
    from tabulate import tabulate
    
    # Create results directory if it doesn't exist
    os.makedirs(f'{args.output_path}/eval', exist_ok=True)
    
    # Check if required columns exist before filtering
    required_columns = ['gt_temporal_group', 'gt_entity_group', 'temporal_group', 'LLM_cluster']
    missing_columns = [col for col in required_columns if col not in clustered_df.columns]
    
    if missing_columns:
        print(f"Warning: Missing required columns for evaluation: {missing_columns}")
        print("Skipping evaluation function.")
        return
    
    clustered_df = clustered_df[(clustered_df.gt_temporal_group.notna())&
                                (clustered_df.gt_entity_group.notna())&
                                (clustered_df.temporal_group.notna())&
                                (clustered_df.LLM_cluster.notna())]
    
    if len(clustered_df) == 0:
        raise ValueError("Warning: No valid data remaining after filtering. Skipping evaluation.")
    
    analysis_results = analyze_zero_ari_subjects(clustered_df)
    compare_zero_and_high_ari_subjects(clustered_df)
    
    clustered_df.to_csv(f'{args.output_path}/eval/eval_func_clustered_df.csv', index=True)
    
    temporal_ari = adjusted_rand_score(clustered_df['gt_temporal_group'].fillna('nan').astype(str), 
                                      clustered_df['temporal_group'].fillna('nan').astype(str))
    temporal_nmi = normalized_mutual_info_score(clustered_df['gt_temporal_group'].fillna('nan').astype(str), 
                                               clustered_df['temporal_group'].fillna('nan').astype(str))
    
    entity_ari = adjusted_rand_score(clustered_df['gt_entity_group'].fillna('nan').astype(str), 
                                    clustered_df['LLM_cluster'].fillna('nan').astype(str))
    entity_nmi = normalized_mutual_info_score(clustered_df['gt_entity_group'].fillna('nan').astype(str), 
                                             clustered_df['LLM_cluster'].fillna('nan').astype(str))    
    # F-score와 Purity 계산
    purity_fscore = calculate_purity_fscore(clustered_df)
    text_f1 = calculate_text_f1(clustered_df)
    temporal_group_accuracies = calculate_temporal_group_accuracies(clustered_df)
    
    subject_metrics = calculate_subject_metrics(clustered_df)
    subject_purity_fscore = calculate_subject_purity_fscore(clustered_df)
    
    # Create metrics tables with metrics on columns and grouping types on rows
    # 1. Overall metrics table - Homogeneity, Completeness, V-measure 제외
    overall_metrics = [
        ["Grouping Type", "ARI", "NMI", "Purity", "F1-score", "Precision", "Recall"],
        ["Temporal Grouping", f"{temporal_ari:.4f}", f"{temporal_nmi:.4f}", f"{purity_fscore['temporal_purity']:.4f}", 
         f"{purity_fscore['temporal_f1']:.4f}", f"{purity_fscore['temporal_precision']:.4f}", f"{purity_fscore['temporal_recall']:.4f}"],
        ["Entity Grouping", f"{entity_ari:.4f}", f"{entity_nmi:.4f}", f"{purity_fscore['entity_purity']:.4f}", 
         f"{purity_fscore['entity_f1']:.4f}", f"{purity_fscore['entity_precision']:.4f}", f"{purity_fscore['entity_recall']:.4f}"]
    ]
    
    # 2. Subject-wise average metrics table - Homogeneity, Completeness, V-measure 제외
    subject_avg_metrics = [
        ["Grouping Type", "Avg. ARI", "Avg. NMI", "Avg. Purity", "Avg. F1-score", "Avg. Precision", "Avg. Recall"],
        ["Temporal Grouping", f"{subject_metrics.get('temporal_ARI_mean', 0):.4f}", f"{subject_metrics.get('temporal_NMI_mean', 0):.4f}", 
         f"{subject_purity_fscore.get('temporal_purity_mean', 0):.4f}", f"{subject_purity_fscore.get('temporal_f1_mean', 0):.4f}", 
         f"{subject_purity_fscore.get('temporal_precision_mean', 0):.4f}", f"{subject_purity_fscore.get('temporal_recall_mean', 0):.4f}"],
        ["Entity Grouping", f"{subject_metrics.get('entity_ARI_mean', 0):.4f}", f"{subject_metrics.get('entity_NMI_mean', 0):.4f}", 
         f"{subject_purity_fscore.get('entity_purity_mean', 0):.4f}", f"{subject_purity_fscore.get('entity_f1_mean', 0):.4f}", 
         f"{subject_purity_fscore.get('entity_precision_mean', 0):.4f}", f"{subject_purity_fscore.get('entity_recall_mean', 0):.4f}"]
    ]
    
    # 3. Top performing subjects table (top 5)
    temporal_performance = []
    for subject_id, metrics in subject_metrics.get('subject_metrics', {}).items():
        purity = subject_purity_fscore.get('subject_purity_fscore', {}).get(subject_id, {}).get('temporal_purity', 0)
        f1 = subject_purity_fscore.get('subject_purity_fscore', {}).get(subject_id, {}).get('temporal_f1', 0)
        temporal_performance.append((subject_id, metrics.get('temporal_ARI', 0), purity, f1))
    
    temporal_performance.sort(key=lambda x: x[1], reverse=True)
    top_temporal_subjects = [["Subject ID", "ARI", "Purity", "F1-score"]]
    for subject_id, ari, purity, f1 in temporal_performance[:5]:
        top_temporal_subjects.append([subject_id, f"{ari:.4f}", f"{purity:.4f}", f"{f1:.4f}"])
    
    entity_performance = []
    for subject_id, metrics in subject_metrics.get('subject_metrics', {}).items():
        purity = subject_purity_fscore.get('subject_purity_fscore', {}).get(subject_id, {}).get('entity_purity', 0)
        f1 = subject_purity_fscore.get('subject_purity_fscore', {}).get(subject_id, {}).get('entity_f1', 0)
        entity_performance.append((subject_id, metrics.get('entity_ARI', 0), purity, f1))
    
    entity_performance.sort(key=lambda x: x[1], reverse=True)
    top_entity_subjects = [["Subject ID", "ARI", "Purity", "F1-score"]]
    for subject_id, ari, purity, f1 in entity_performance[:5]:
        top_entity_subjects.append([subject_id, f"{ari:.4f}", f"{purity:.4f}", f"{f1:.4f}"])
    
    # Generate LaTeX tables
    def table_to_latex(table_data, caption, label):
        headers = table_data[0]
        rows = table_data[1:]
        
        latex = "\\begin{table}[h]\n\\centering\n"
        latex += "\\caption{" + caption + "}\n"
        latex += "\\label{" + label + "}\n"
        
        # Create column format
        col_format = "|" + "|".join(["c"] * len(headers)) + "|"
        latex += "\\begin{tabular}{" + col_format + "}\n\\hline\n"
        
        # Add headers
        latex += " & ".join(headers) + " \\\\ \\hline\n"
        
        # Add rows
        for row in rows:
            latex += " & ".join([str(cell) for cell in row]) + " \\\\ \\hline\n"
        
        latex += "\\end{tabular}\n\\end{table}"
        return latex
    
    # Generate LaTeX tables
    overall_latex = table_to_latex(overall_metrics, "Overall Evaluation Metrics", "tab:overall_metrics")
    subject_avg_latex = table_to_latex(subject_avg_metrics, "Subject-wise Average Metrics", "tab:subject_avg_metrics")
    top_temporal_latex = table_to_latex(top_temporal_subjects, "Top 5 Subjects by Temporal Grouping Performance", "tab:top_temporal")
    top_entity_latex = table_to_latex(top_entity_subjects, "Top 5 Subjects by Entity Grouping Performance", "tab:top_entity")
    
    # Save LaTeX tables to files
    with open(f'{args.output_path}/eval/overall_metrics.tex', 'w') as f:
        f.write(overall_latex)
    
    with open(f'{args.output_path}/eval/subject_avg_metrics.tex', 'w') as f:
        f.write(subject_avg_latex)
    
    with open(f'{args.output_path}/eval/top_temporal_subjects.tex', 'w') as f:
        f.write(top_temporal_latex)
    
    with open(f'{args.output_path}/eval/top_entity_subjects.tex', 'w') as f:
        f.write(top_entity_latex)
    
    # Generate PNG tables using matplotlib
    def table_to_png(table_data, title, filename):
        fig, ax = plt.figure(figsize=(10, len(table_data) * 0.5 + 1)), plt.gca()
        ax.axis('tight')
        ax.axis('off')
        table = ax.table(cellText=table_data[1:], colLabels=table_data[0], 
                         loc='center', cellLoc='center')
        table.auto_set_font_size(False)
        table.set_fontsize(10)
        table.scale(1.2, 1.5)
        plt.title(title, fontsize=14, pad=20)
        plt.tight_layout()
        plt.savefig(f'{args.output_path}/eval/{filename}.png', dpi=300, bbox_inches='tight')
        plt.close()
    
    # Generate PNG tables
    table_to_png(overall_metrics, "Overall Evaluation Metrics", "overall_metrics")
    table_to_png(subject_avg_metrics, "Subject-wise Average Metrics", "subject_avg_metrics")
    table_to_png(top_temporal_subjects, "Top 5 Subjects by Temporal Grouping Performance", "top_temporal_subjects")
    table_to_png(top_entity_subjects, "Top 5 Subjects by Entity Grouping Performance", "top_entity_subjects")
    
    # 환자 10명 결과와 overall 결과를 함께 보여주는 그래프 생성
    # 1. Temporal Grouping 결과
    top10_temporal = temporal_performance[:10]
    
    # 데이터 준비
    subjects = [subj[0] for subj in top10_temporal] + ['Overall']
    ari_values = [subj[1] for subj in top10_temporal] + [temporal_ari]
    purity_values = [subj[2] for subj in top10_temporal] + [purity_fscore['temporal_purity']]
    f1_values = [subj[3] for subj in top10_temporal] + [purity_fscore['temporal_f1']]
    
    # 그래프 생성
    fig, ax = plt.subplots(figsize=(15, 8))
    
    # 막대 너비 및 위치 설정
    bar_width = 0.25
    r1 = np.arange(len(subjects))
    r2 = [x + bar_width for x in r1]
    r3 = [x + bar_width for x in r2]
    
    # 막대 그래프 그리기
    ax.bar(r1, ari_values, width=bar_width, label='ARI', color='skyblue')
    ax.bar(r2, purity_values, width=bar_width, label='Purity', color='lightgreen')
    ax.bar(r3, f1_values, width=bar_width, label='F1-score', color='salmon')
    
    # 축 및 레이블 설정
    ax.set_xlabel('Subject ID', fontweight='bold', fontsize=12)
    ax.set_ylabel('Score', fontweight='bold', fontsize=12)
    ax.set_title('Temporal Grouping Performance: Top 10 Subjects vs Overall', fontweight='bold', fontsize=14)
    ax.set_xticks([r + bar_width for r in range(len(subjects))])
    ax.set_xticklabels(subjects, rotation=45, ha='right')
    
    # Overall 결과 구분선 추가
    ax.axvline(x=len(subjects)-1.5, color='gray', linestyle='--')
    
    # 범례 추가
    ax.legend()
    
    # 그리드 추가
    ax.grid(True, linestyle='--', alpha=0.7)
    
    # y축 범위 설정 (0~1)
    ax.set_ylim(0, 1.05)
    
    plt.tight_layout()
    plt.savefig(f'{args.output_path}/eval/temporal_top10_vs_overall.png', dpi=300)
    plt.close()
    
    # 2. Entity Grouping 결과
    top10_entity = entity_performance[:10]
    
    # 데이터 준비
    subjects = [subj[0] for subj in top10_entity] + ['Overall']
    ari_values = [subj[1] for subj in top10_entity] + [entity_ari]
    purity_values = [subj[2] for subj in top10_entity] + [purity_fscore['entity_purity']]
    f1_values = [subj[3] for subj in top10_entity] + [purity_fscore['entity_f1']]
    
    # 그래프 생성
    fig, ax = plt.subplots(figsize=(15, 8))
    
    # 막대 너비 및 위치 설정
    bar_width = 0.25
    r1 = np.arange(len(subjects))
    r2 = [x + bar_width for x in r1]
    r3 = [x + bar_width for x in r2]
    
    # 막대 그래프 그리기
    ax.bar(r1, ari_values, width=bar_width, label='ARI', color='skyblue')
    ax.bar(r2, purity_values, width=bar_width, label='Purity', color='lightgreen')
    ax.bar(r3, f1_values, width=bar_width, label='F1-score', color='salmon')
    
    # 축 및 레이블 설정
    ax.set_xlabel('Subject ID', fontweight='bold', fontsize=12)
    ax.set_ylabel('Score', fontweight='bold', fontsize=12)
    ax.set_title('Entity Grouping Performance: Top 10 Subjects vs Overall', fontweight='bold', fontsize=14)
    ax.set_xticks([r + bar_width for r in range(len(subjects))])
    ax.set_xticklabels(subjects, rotation=45, ha='right')
    
    # Overall 결과 구분선 추가
    ax.axvline(x=len(subjects)-1.5, color='gray', linestyle='--')
    
    # 범례 추가
    ax.legend()
    
    # 그리드 추가
    ax.grid(True, linestyle='--', alpha=0.7)
    
    # y축 범위 설정 (0~1)
    ax.set_ylim(0, 1.05)
    
    plt.tight_layout()
    plt.savefig(f'{args.output_path}/eval/entity_top10_vs_overall.png', dpi=300)
    plt.close()
    
    # 맵핑 결과 분석 및 시각화 (예전 코드에서 가져옴)
    print("\n=== 맵핑 결과 분석 ===\n")
    
    # 전체 데이터에 대한 맵핑 테이블 생성 및 시각화
    print("\n=== 전체 데이터에 대한 맵핑 테이블 ===\n")
    
    # Temporal Grouping 전체 맵핑 테이블
    overall_temporal_confusion = pd.crosstab(
        clustered_df['gt_temporal_group'], 
        clustered_df['temporal_group']
    )
    overall_temporal_confusion.to_csv(f'{args.output_path}/eval/overall_temporal_mapping.csv', index=True)
    
    # 시각화 - 크기 조정 (클래스가 많을 경우 대비)
    plt.figure(figsize=(max(12, len(overall_temporal_confusion.columns)//2), 
                       max(10, len(overall_temporal_confusion.index)//2)))
    sns.heatmap(overall_temporal_confusion, annot=True, fmt='d', cmap='Blues', cbar=True)
    plt.title('Overall Temporal Grouping Mapping Table')
    plt.xlabel('Predicted Group')
    plt.ylabel('Actual Group')
    plt.tight_layout()
    plt.savefig(f'{args.output_path}/eval/overall_temporal_mapping.png')
    plt.close()
    
    # Entity Grouping 전체 맵핑 테이블
    overall_entity_confusion = pd.crosstab(
        clustered_df['gt_entity_group'], 
        clustered_df['LLM_cluster']
    )
    overall_entity_confusion.to_csv(f'{args.output_path}/eval/overall_entity_mapping.csv', index=True)
    
    # 시각화 - 크기 조정 (클래스가 많을 경우 대비)
    plt.figure(figsize=(max(15, len(overall_entity_confusion.columns)//2), 
                       max(12, len(overall_entity_confusion.index)//2)))
    sns.heatmap(overall_entity_confusion, annot=True, fmt='d', cmap='Blues', cbar=True)
    plt.title('Overall Entity Grouping Mapping Table')
    plt.xlabel('Predicted Group')
    plt.ylabel('Actual Group')
    plt.tight_layout()
    plt.savefig(f'{args.output_path}/eval/overall_entity_mapping.png')
    plt.close()
    
    # 환자별 맵핑 테이블 생성 및 시각화
    for subject in clustered_df['subject_id'].unique():
        subject_df = clustered_df[clustered_df['subject_id'] == subject]
        
        # temporal_group 분석
        print(f"\n--- subject_id: {subject} ---")
        print(f"항목 수: {len(subject_df)}")
        
        # temporal_group 분석
        gt_col, pred_col = 'gt_temporal_group', 'temporal_group'
        gt_groups = subject_df[gt_col].value_counts()
        pred_groups = subject_df[pred_col].value_counts()
        
        print(f"\nTemporal Grouping:")
        print(f"Actual Group Distribution: {dict(gt_groups)}")
        print(f"Predicted Group Distribution: {dict(pred_groups)}")
        
        # 혼동 행렬(confusion matrix) 출력
        confusion = pd.crosstab(subject_df[gt_col], subject_df[pred_col])
        confusion.to_csv(f'{args.output_path}/eval/{subject}_temporal_mapping.csv', index=True)
        
        # temporal_group 매핑 테이블 시각화
        plt.figure(figsize=(10, 8))
        sns.heatmap(confusion, annot=True, fmt='d', cmap='Blues', cbar=True)
        plt.title(f'Subject ID: {subject} - Temporal Grouping Mapping Table')
        plt.xlabel('Predicted Group')
        plt.ylabel('Actual Group')
        plt.tight_layout()
        plt.savefig(f'{args.output_path}/eval/{subject}_temporal_mapping.png')
        plt.close()
        
        # entity_group 분석
        gt_col, pred_col = 'gt_entity_group', 'LLM_cluster'
        gt_groups = subject_df[gt_col].value_counts()
        pred_groups = subject_df[pred_col].value_counts()
        
        print(f"\nEntity Grouping:")
        print(f"Actual Group Distribution: {dict(gt_groups)}")
        print(f"Predicted Group Distribution: {dict(pred_groups)}")
        
        # 혼동 행렬(confusion matrix) 출력
        confusion = pd.crosstab(subject_df[gt_col], subject_df[pred_col])
        confusion.to_csv(f'{args.output_path}/eval/{subject}_entity_mapping.csv', index=True)
        
        # entity_group 매핑 테이블 시각화
        # entity_group는 클래스가 많을 수 있으므로 크기 조정
        plt.figure(figsize=(max(12, len(pred_groups)//2), max(10, len(gt_groups)//2)))
        sns.heatmap(confusion, annot=True, fmt='d', cmap='Blues', cbar=True)
        plt.title(f'Subject ID: {subject} - Entity Grouping Mapping Table')
        plt.xlabel('Predicted Group')
        plt.ylabel('Actual Group')
        plt.tight_layout()
        plt.savefig(f'{args.output_path}/eval/{subject}_entity_mapping.png')
        plt.close()
        
        # 동일 텍스트 다른 그룹 분류 분석
        print("\n=== 동일 텍스트 다른 그룹 분류 분석 ===")
        
        # concatenated 필드가 있는 경우만 실행
        if 'concatenated' in subject_df.columns:
            # concatenated 텍스트별로 그룹화
            text_groups = {}
            for idx, row in subject_df.iterrows():
                concat_text = row['concatenated']
                if concat_text not in text_groups:
                    text_groups[concat_text] = []
                text_groups[concat_text].append((idx, row['gt_entity_group'], row['LLM_cluster']))
            
            # 동일 텍스트가 다른 GT 그룹이나 예측 그룹으로 분류된 경우 찾기
            inconsistent_texts = []
            for text, items in text_groups.items():
                if len(items) > 1:
                    gt_groups = set([item[1] for item in items])
                    pred_groups = set([item[2] for item in items])
                    
                    # GT 그룹이나 예측 그룹이 여러 개인 경우
                    if len(gt_groups) > 1 or len(pred_groups) > 1:
                        inconsistent_texts.append((text, gt_groups, pred_groups, items))
                        print(f"\n동일 텍스트가 다른 그룹으로 분류됨:")
                        print(f"  텍스트: {text[:100]}..." if len(text) > 100 else f"  텍스트: {text}")
                        print(f"  GT 그룹: {gt_groups}")
                        print(f"  예측 그룹: {pred_groups}")
                        print("  항목 세부 정보:")
                        for item in items:
                            print(f"    * 항목 ID: {item[0]}, GT: {item[1]}, 예측: {item[2]}")
            
            # 불일치 텍스트 결과 저장
            if inconsistent_texts:
                with open(f'{args.output_path}/eval/{subject}_inconsistent_texts.txt', 'w') as f:
                    f.write(f"Subject ID: {subject} - 동일 텍스트 다른 그룹 분류 분석\n\n")
                    for text, gt_groups, pred_groups, items in inconsistent_texts:
                        f.write(f"텍스트: {text}\n")
                        f.write(f"GT 그룹: {gt_groups}\n")
                        f.write(f"예측 그룹: {pred_groups}\n")
                        f.write("항목 세부 정보:\n")
                        for item in items:
                            f.write(f"  * 항목 ID: {item[0]}, GT: {item[1]}, 예측: {item[2]}\n")
                        f.write("\n")
        else:
            print("  'concatenated' 필드가 데이터프레임에 없습니다.")
    
    # 전체 데이터에 대한 동일 텍스트 다른 그룹 분류 분석
    print("\n=== 전체 데이터에 대한 동일 텍스트 다른 그룹 분류 분석 ===")
    
    if 'concatenated' in clustered_df.columns:
        # 전체 데이터에서 동일 텍스트 분석
        overall_text_groups = {}
        for idx, row in clustered_df.iterrows():
            concat_text = row['concatenated']
            subject_id = row['subject_id']
            if concat_text not in overall_text_groups:
                overall_text_groups[concat_text] = []
            overall_text_groups[concat_text].append((idx, subject_id, row['gt_entity_group'], row['LLM_cluster']))
        
        # 동일 텍스트가 다른 GT 그룹이나 예측 그룹으로 분류된 경우 찾기
        overall_inconsistent_texts = []
        for text, items in overall_text_groups.items():
            if len(items) > 1:
                gt_groups = set([item[2] for item in items])
                pred_groups = set([item[3] for item in items])
                
                # GT 그룹이나 예측 그룹이 여러 개인 경우
                if len(gt_groups) > 1 or len(pred_groups) > 1:
                    overall_inconsistent_texts.append((text, gt_groups, pred_groups, items))
        
        # 불일치 텍스트 결과 저장
        if overall_inconsistent_texts:
            with open(f'{args.output_path}/eval/overall_inconsistent_texts.txt', 'w') as f:
                f.write("전체 데이터 - 동일 텍스트 다른 그룹 분류 분석\n\n")
                for text, gt_groups, pred_groups, items in overall_inconsistent_texts:
                    f.write(f"텍스트: {text}\n")
                    f.write(f"GT 그룹: {gt_groups}\n")
                    f.write(f"예측 그룹: {pred_groups}\n")
                    f.write("항목 세부 정보:\n")
                    for item in items:
                        f.write(f"  * 항목 ID: {item[0]}, 환자 ID: {item[1]}, GT: {item[2]}, 예측: {item[3]}\n")
                    f.write("\n")
            
            print(f"전체 데이터에서 {len(overall_inconsistent_texts)}개의 동일 텍스트가 다른 그룹으로 분류됨")
            print(f"자세한 내용은 '{args.output_path}/eval/overall_inconsistent_texts.txt' 파일 참조")
    else:
        print("  'concatenated' 필드가 데이터프레임에 없습니다.")
    
    print("\n========== Debug information ==========")
    print(f"Total items: {len(clustered_df)}")
    print(f"Total subjects: {len(clustered_df['subject_id'].unique())}")
    
    print("\n=== Full dataset evaluation results ===\n")
    print("temporal_group metrics:")
    print(f"Adjusted Rand Index: {temporal_ari:.4f}")
    print(f"Normalized Mutual Information: {temporal_nmi:.4f}")
    print(f"Purity: {purity_fscore['temporal_purity']:.4f}")
    print(f"F1-score: {purity_fscore['temporal_f1']:.4f}")
    print(f"Precision: {purity_fscore['temporal_precision']:.4f}")
    print(f"Recall: {purity_fscore['temporal_recall']:.4f}")
    print(f"Temporal Group Accuracies: {temporal_group_accuracies}")
    
    print("\nentity_group metrics:")
    print(f"Adjusted Rand Index: {entity_ari:.4f}")
    print(f"Normalized Mutual Information: {entity_nmi:.4f}")
    print(f"Purity: {purity_fscore['entity_purity']:.4f}")
    print(f"F1-score: {purity_fscore['entity_f1']:.4f}")
    print(f"Precision: {purity_fscore['entity_precision']:.4f}")
    print(f"Recall: {purity_fscore['entity_recall']:.4f}")
    print(f"Text F1: {text_f1:.4f}")
    print("\n=== case analysis ===\n")

    
    # Print tables in console using tabulate
    print("\nOverall Evaluation Metrics:")
    print(tabulate(overall_metrics, headers="firstrow", tablefmt="grid"))
    
    print("\nSubject-wise Average Metrics:")
    print(tabulate(subject_avg_metrics, headers="firstrow", tablefmt="grid"))
    
    print("\nTop 5 Subjects by Temporal Grouping Performance:")
    print(tabulate(top_temporal_subjects, headers="firstrow", tablefmt="grid"))
    
    print("\nTop 5 Subjects by Entity Grouping Performance:")
    print(tabulate(top_entity_subjects, headers="firstrow", tablefmt="grid"))
    
    # Continue with existing code for printing subject performance
    for subject_id, ari, purity, f1 in temporal_performance:
        print(f"subject_id: {subject_id}, temporal_group ARI: {ari:.4f}, Purity: {purity:.4f}, F1: {f1:.4f}")
    
    for subject_id, ari, purity, f1 in entity_performance:
        print(f"subject_id: {subject_id}, entity_group ARI: {ari:.4f}, Purity: {purity:.4f}, F1: {f1:.4f}")
    
    print("\n=== Low performance subjects ===\n")
    low_perf_subjects = analyze_low_performance_subjects(clustered_df, threshold=0.6)
    
    print(f"\nLow performance subjects: {len(low_perf_subjects)}")
    print("Low performance subjects list:")
    for subject_id, ari in sorted(low_perf_subjects, key=lambda x: x[1]):
        print(f"  - {subject_id}: ARI = {ari:.4f}")

    metrics = {
        'temporal_ARI': temporal_ari,
        'temporal_NMI': temporal_nmi,
        'temporal_purity': purity_fscore['temporal_purity'],
        'temporal_f1': purity_fscore['temporal_f1'],
        'temporal_precision': purity_fscore['temporal_precision'],
        'temporal_recall': purity_fscore['temporal_recall'],
        'entity_ARI': entity_ari,
        'entity_NMI': entity_nmi,
        'entity_purity': purity_fscore['entity_purity'],
        'entity_f1': purity_fscore['entity_f1'],
        'entity_precision': purity_fscore['entity_precision'],
        'entity_recall': purity_fscore['entity_recall'],
        'text_f1': text_f1
    }
    
    return metrics

def analyze_zero_ari_subjects(df):
    """Print a diagnostic report for every subject whose temporal-grouping ARI is exactly 0.

    For each ``subject_id`` the adjusted Rand index between
    ``gt_temporal_group`` and ``temporal_group`` is computed, with missing
    labels compared as the literal string ``'nan'``. Each subject scoring
    exactly 0 gets a console report containing:

    * item and group counts, and GT / predicted label distributions;
    * a GT-vs-predicted crosstab;
    * a heuristic problem-pattern analysis;
    * per-group date ranges, when a ``date`` column exists;
    * improvement suggestions.

    Args:
        df: DataFrame with at least ``subject_id``, ``gt_temporal_group`` and
            ``temporal_group`` columns. ``concatenated`` and ``date`` are used
            when present.

    Returns:
        dict mapping each zero-ARI ``subject_id`` to a summary dict with keys
        ``n_items``, ``n_gt_groups``, ``n_pred_groups``, ``gt_group_counts``,
        ``pred_group_counts`` and ``cross_tab``.
    """
    from sklearn.metrics import adjusted_rand_score
    import pandas as pd

    # Find subjects whose temporal-grouping ARI is exactly 0
    zero_ari_subjects = []

    for subject_id in df['subject_id'].unique():
        subject_df = df[df['subject_id'] == subject_id]

        # Temporal-grouping ARI; NaN labels are compared as the string 'nan'
        temporal_ari = adjusted_rand_score(
            subject_df['gt_temporal_group'].fillna('nan').astype(str),
            subject_df['temporal_group'].fillna('nan').astype(str)
        )

        if temporal_ari == 0:
            zero_ari_subjects.append(subject_id)

    print(f"\n=== 시간적 그룹화 ARI가 0인 주제 분석 ({len(zero_ari_subjects)}개) ===\n")

    analysis_results = {}

    for subject_id in zero_ari_subjects:
        # Explicit copy: a 'date_dt' column is added below, and assigning into a
        # boolean-mask slice of df triggers SettingWithCopyWarning.
        subject_df = df[df['subject_id'] == subject_id].copy()

        print(f"\n## 주제 ID: {subject_id} ##")

        # 1. Basic counts (nunique() ignores NaN labels)
        n_items = len(subject_df)
        n_gt_groups = subject_df['gt_temporal_group'].nunique()
        n_pred_groups = subject_df['temporal_group'].nunique()

        print(f"항목 수: {n_items}")
        print(f"GT 시간 그룹 수: {n_gt_groups}")
        print(f"예측 시간 그룹 수: {n_pred_groups}")

        # 2. GT temporal-group distribution
        gt_group_counts = subject_df['gt_temporal_group'].value_counts().sort_index()
        print("\nGT 시간 그룹 분포:")
        for group, count in gt_group_counts.items():
            print(f"  - 그룹 {group}: {count}개 항목 ({count/n_items*100:.1f}%)")

        # 3. Predicted temporal-group distribution
        pred_group_counts = subject_df['temporal_group'].value_counts().sort_index()
        print("\n예측 시간 그룹 분포:")
        for group, count in pred_group_counts.items():
            print(f"  - 그룹 {group}: {count}개 항목 ({count/n_items*100:.1f}%)")

        # 4. GT vs predicted crosstab
        cross_tab = pd.crosstab(subject_df['gt_temporal_group'], subject_df['temporal_group'])
        print("\nGT vs 예측 시간 그룹 교차표:")
        print(cross_tab)

        # 5. Heuristic problem-pattern analysis
        print("\n문제 패턴 분석:")

        if n_pred_groups == 1:
            # 5.1 Every item was assigned to a single predicted group.
            print("  - 문제: 모든 항목이 하나의 예측 그룹으로 할당됨")
            # dropna(): the first row may hold a NaN prediction even though
            # nunique() (which ignores NaN) reported exactly one real group.
            print(f"    * 예측 그룹: {subject_df['temporal_group'].dropna().iloc[0]}")

        elif n_gt_groups != n_pred_groups:
            # 5.2 The numbers of GT and predicted groups differ.
            print(f"  - 문제: GT 그룹 수({n_gt_groups})와 예측 그룹 수({n_pred_groups})가 다름")

        else:
            # 5.3 Same number of groups, but the assignments disagree.
            # For each GT group, find the predicted group that received most of
            # its items. NaN GT labels are skipped: `== nan` never matches, so
            # the slice would be empty and indexing value_counts() would raise.
            for gt_group in subject_df['gt_temporal_group'].dropna().unique():
                gt_items = subject_df[subject_df['gt_temporal_group'] == gt_group]
                pred_groups = gt_items['temporal_group'].value_counts()
                if pred_groups.empty:
                    # Every prediction for this GT group is NaN; nothing to rank.
                    print(f"  - GT 그룹 {gt_group}: 예측 그룹 없음 (모두 NaN)")
                    continue
                most_common_pred = pred_groups.index[0]
                most_common_count = pred_groups.iloc[0]

                print(f"  - GT 그룹 {gt_group}:")
                print(f"    * 항목 수: {len(gt_items)}")
                print(f"    * 가장 많이 할당된 예측 그룹: {most_common_pred} ({most_common_count}개, {most_common_count/len(gt_items)*100:.1f}%)")

                if len(pred_groups) > 1:
                    print("    * 다른 예측 그룹으로 할당된 항목:")
                    for pred_group, count in pred_groups.items():
                        if pred_group != most_common_pred:
                            print(f"      - 그룹 {pred_group}: {count}개 항목")
                            examples = gt_items[gt_items['temporal_group'] == pred_group].head(2)
                            for _, row in examples.iterrows():
                                # .get(): 'concatenated' is optional in this frame
                                print(f"        * 예시: '{row.get('concatenated', 'N/A')}' (날짜: {row.get('date', 'N/A')})")

        # 6. Date analysis (only when a 'date' column exists)
        if 'date' in subject_df.columns:
            print("\n날짜 정보 분석:")

            # Unparseable dates become NaT
            subject_df['date_dt'] = pd.to_datetime(subject_df['date'], errors='coerce')

            # Date range per GT group
            for gt_group in subject_df['gt_temporal_group'].unique():
                gt_items = subject_df[subject_df['gt_temporal_group'] == gt_group]

                if gt_items['date_dt'].notna().any():
                    min_date = gt_items['date_dt'].min()
                    max_date = gt_items['date_dt'].max()
                    date_range = (max_date - min_date).days

                    print(f"  - GT 그룹 {gt_group}:")
                    print(f"    * 날짜 범위: {min_date.date()} ~ {max_date.date()} ({date_range}일)")

                    # Date range of each predicted group inside this GT group
                    for pred_group in gt_items['temporal_group'].unique():
                        pred_items = gt_items[gt_items['temporal_group'] == pred_group]

                        if pred_items['date_dt'].notna().any():
                            pred_min_date = pred_items['date_dt'].min()
                            pred_max_date = pred_items['date_dt'].max()
                            pred_date_range = (pred_max_date - pred_min_date).days

                            print(f"    * 예측 그룹 {pred_group} 날짜 범위: {pred_min_date.date()} ~ {pred_max_date.date()} ({pred_date_range}일)")

        # 7. Improvement suggestions
        print("\n개선 제안:")

        if n_pred_groups == 1:
            print("  - 시간적 그룹화 알고리즘이 이 주제에 대해 다양한 그룹을 생성하지 못함")
            print("  - 날짜 정보를 더 효과적으로 활용하거나, 시간 관련 텍스트 특성을 추출하는 방법 고려")

        elif n_gt_groups != n_pred_groups:
            print(f"  - 예측 그룹 수({n_pred_groups})를 GT 그룹 수({n_gt_groups})에 맞게 조정")
            print("  - 클러스터링 알고리즘의 파라미터 조정 또는 다른 알고리즘 시도")

        else:
            print("  - 예측 그룹과 GT 그룹 간의 매핑 개선")
            print("  - 날짜 정보를 더 정확하게 활용하는 방법 고려")

        # Store the per-subject summary
        analysis_results[subject_id] = {
            'n_items': n_items,
            'n_gt_groups': n_gt_groups,
            'n_pred_groups': n_pred_groups,
            'gt_group_counts': gt_group_counts.to_dict(),
            'pred_group_counts': pred_group_counts.to_dict(),
            'cross_tab': cross_tab.to_dict()
        }

    return analysis_results

def _print_gt_group_date_ranges(frame):
    """Print the min/max ``date_dt`` and its span in days for each GT temporal group.

    *frame* must already carry a ``date_dt`` datetime column. Groups whose
    dates are all NaT print nothing. A NaN GT label also prints nothing,
    because its ``==`` mask selects no rows.
    """
    for gt_group in frame['gt_temporal_group'].unique():
        gt_items = frame[frame['gt_temporal_group'] == gt_group]

        if gt_items['date_dt'].notna().any():
            min_date = gt_items['date_dt'].min()
            max_date = gt_items['date_dt'].max()
            date_range = (max_date - min_date).days

            print(f"  - GT 그룹 {gt_group}:")
            print(f"    * 날짜 범위: {min_date.date()} ~ {max_date.date()} ({date_range}일)")


def _distinct_date_intervals(dates):
    """Return the gaps in days between consecutive distinct, non-missing dates.

    NaT values are dropped before sorting; otherwise they scramble the order
    and turn gaps into NaN. The computation uses ``Series.diff`` and
    ``.dt.days``, so it works regardless of pandas version. On older pandas,
    ``Series.unique()`` on datetimes yields numpy ``datetime64`` values, whose
    differences have no ``.days`` attribute.
    Returns a list of ``int``; the list is empty when fewer than two distinct dates exist.
    """
    distinct = dates.dropna().drop_duplicates().sort_values()
    return distinct.diff().dropna().dt.days.tolist()


def compare_zero_and_high_ari_subjects(df):
    """Print a temporal-grouping comparison of the best subject vs. a zero-ARI subject.

    The per-subject adjusted Rand index between ``gt_temporal_group`` and
    ``temporal_group`` is computed, with NaN labels compared as the string
    ``'nan'``. The report covers the highest-ARI subject and the first subject
    whose ARI is exactly 0. For each it prints:

    * group counts and crosstabs;
    * GT group date ranges, when a ``date`` column exists;
    * group distributions and gaps between consecutive distinct dates.

    NOTE: if the best ARI is itself 0, the "high" subject can be the same
    subject as the zero-ARI one.

    Args:
        df: DataFrame with ``subject_id``, ``gt_temporal_group`` and
            ``temporal_group`` columns; ``date`` is used when present.

    Returns:
        None. Output is printed only. If ``df`` has no subjects, a short notice
        is printed and the function returns early.
    """
    from sklearn.metrics import adjusted_rand_score
    import pandas as pd
    import numpy as np

    # Per-subject temporal-grouping ARI
    subject_aris = {}

    for subject_id in df['subject_id'].unique():
        subject_df = df[df['subject_id'] == subject_id]

        temporal_ari = adjusted_rand_score(
            subject_df['gt_temporal_group'].fillna('nan').astype(str),
            subject_df['temporal_group'].fillna('nan').astype(str)
        )

        subject_aris[subject_id] = temporal_ari

    if not subject_aris:
        # max() below would raise ValueError on an empty mapping.
        print("\n=== 비교할 주제가 없습니다 ===\n")
        return

    # Subjects with ARI == 0, and the subject with the highest ARI
    zero_ari_subjects = [subject_id for subject_id, ari in subject_aris.items() if ari == 0]
    high_ari_subject = max(subject_aris.items(), key=lambda x: x[1])[0]

    print(f"\n=== ARI가 0인 주제와 가장 높은 주제(ARI={subject_aris[high_ari_subject]:.4f}) 비교 ===\n")

    has_dates = 'date' in df.columns

    # Explicit copy so adding 'date_dt' does not write into a slice of df
    high_ari_df = df[df['subject_id'] == high_ari_subject].copy()
    if has_dates:
        high_ari_df['date_dt'] = pd.to_datetime(high_ari_df['date'], errors='coerce')

    print(f"## 높은 ARI 주제 ID: {high_ari_subject} ##")
    print(f"항목 수: {len(high_ari_df)}")
    print(f"GT 시간 그룹 수: {high_ari_df['gt_temporal_group'].nunique()}")
    print(f"예측 시간 그룹 수: {high_ari_df['temporal_group'].nunique()}")

    high_cross_tab = pd.crosstab(high_ari_df['gt_temporal_group'], high_ari_df['temporal_group'])
    print("\nGT vs 예측 시간 그룹 교차표 (높은 ARI):")
    print(high_cross_tab)

    if has_dates:
        print("\n날짜 정보 분석 (높은 ARI):")
        _print_gt_group_date_ranges(high_ari_df)

    if not zero_ari_subjects:
        return

    # Compare against the first zero-ARI subject
    zero_ari_subject = zero_ari_subjects[0]
    zero_ari_df = df[df['subject_id'] == zero_ari_subject].copy()
    if has_dates:
        zero_ari_df['date_dt'] = pd.to_datetime(zero_ari_df['date'], errors='coerce')

    print(f"\n## ARI=0 주제 ID: {zero_ari_subject} ##")
    print(f"항목 수: {len(zero_ari_df)}")
    print(f"GT 시간 그룹 수: {zero_ari_df['gt_temporal_group'].nunique()}")
    print(f"예측 시간 그룹 수: {zero_ari_df['temporal_group'].nunique()}")

    zero_cross_tab = pd.crosstab(zero_ari_df['gt_temporal_group'], zero_ari_df['temporal_group'])
    print("\nGT vs 예측 시간 그룹 교차표 (ARI=0):")
    print(zero_cross_tab)

    if has_dates:
        print("\n날짜 정보 분석 (ARI=0):")
        _print_gt_group_date_ranges(zero_ari_df)

    # Key differences between the two subjects
    print("\n## 주요 차이점 분석 ##")

    # 1. Group counts
    print("1. 그룹 수:")
    print(f"  - 높은 ARI 주제: GT {high_ari_df['gt_temporal_group'].nunique()}개, 예측 {high_ari_df['temporal_group'].nunique()}개")
    print(f"  - ARI=0 주제: GT {zero_ari_df['gt_temporal_group'].nunique()}개, 예측 {zero_ari_df['temporal_group'].nunique()}개")

    # 2. GT group distributions
    print("\n2. 그룹 분포:")
    print("  - 높은 ARI 주제 GT 분포:")
    for group, count in high_ari_df['gt_temporal_group'].value_counts().sort_index().items():
        print(f"    * 그룹 {group}: {count}개 ({count/len(high_ari_df)*100:.1f}%)")

    print("  - ARI=0 주제 GT 분포:")
    for group, count in zero_ari_df['gt_temporal_group'].value_counts().sort_index().items():
        print(f"    * 그룹 {group}: {count}개 ({count/len(zero_ari_df)*100:.1f}%)")

    # 3. Gaps between consecutive distinct dates
    if has_dates:
        print("\n3. 날짜 패턴:")
        for label, frame in (("높은 ARI 주제", high_ari_df), ("ARI=0 주제", zero_ari_df)):
            intervals = _distinct_date_intervals(frame['date_dt'])
            if intervals:
                print(f"  - {label} 날짜 간격: {intervals} (일)")
                print(f"    * 평균 간격: {np.mean(intervals):.1f}일")
                print(f"    * 최소 간격: {min(intervals)}일")
                print(f"    * 최대 간격: {max(intervals)}일")
                
class NumpyEncoder(json.JSONEncoder):
    """``json.JSONEncoder`` that also serializes NumPy scalars and arrays.

    Conversions:

    * NumPy integers become ``int``.
    * NumPy floats become ``float``.
    * NumPy booleans become ``bool``.
    * ``ndarray`` becomes a (nested) list.

    Any other type falls through to the base class, which raises ``TypeError``.

    Example::

        json.dumps({"n": np.int64(3)}, cls=NumpyEncoder)  # '{"n": 3}'
    """

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        # np.bool_ is not an np.integer subclass, so it needs its own branch
        # (e.g. the result of comparing NumPy metric values).
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)

# A single finding item in the structured LLM output (nested in
# FindingGroup.findings). The Field descriptions are emitted into the JSON
# schema produced by RadiologyOutput.model_json_schema(), so they act as
# instructions to the model; editing them changes the prompt.
# NOTE: documented with '#' comments on purpose. Pydantic v2 copies a class
# docstring into the schema "description", which would also alter the prompt.
class Finding(BaseModel):
    # Index of the finding as given in the input
    IDX: int = Field(description="The index number of the finding from the input")
    # Day number on which the finding was observed
    DAY: int = Field(description="The day number when the finding was observed")
    # Free-text description of the finding
    finding: str = Field(description="The description of the finding")

# A temporal episode: a numbered set of day numbers (nested in
# FindingGroup.episodes). The Field descriptions are part of the JSON schema
# sent to the LLM, so they are kept verbatim. Comments rather than a docstring
# are used for the same reason: pydantic v2 puts class docstrings into the schema.
class Episode(BaseModel):
    # Sequential episode number
    episode: int = Field(description="Sequential episode number")
    # Day numbers that belong to this episode
    days: List[int] = Field(description="Array of day numbers that belong to this episode")

# One named group of related findings, with its temporal episodes and the
# model's rationale (nested in RadiologyOutput.results). The Field descriptions
# are part of the JSON schema sent to the LLM and are kept verbatim. Comments
# rather than a docstring are used because pydantic v2 puts class docstrings into the schema.
class FindingGroup(BaseModel):
    # Name of the finding group
    group_name: str = Field(description="Name of the finding group")
    # All findings assigned to this group
    findings: List[Finding] = Field(description="List of all findings in this group")
    # Temporal groupings (episodes) of the findings above
    episodes: List[Episode] = Field(description="Temporal groupings of these findings")
    # Model-provided explanation of the grouping decisions
    rationale: str = Field(description="Explanation for the grouping decisions")

# Top-level structured-output model. create_batch_job passes
# RadiologyOutput.model_json_schema() as the "input_schema" of the
# "sequential_matching" tool for Claude batch requests.
# NOTE: the docstring below and the Field description are emitted into that
# schema (pydantic v2), so changing either changes what the LLM is asked to produce.
class RadiologyOutput(BaseModel):
    """Output schema for normalized and grouped radiological findings"""
    # All finding groups, each carrying its findings and temporal episodes
    results: List[FindingGroup] = Field(description="List of finding groups including episodes and findings")