"""
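Example: serve Baichuan locally with vLLM before running with ``--LLM_name baichuan``:
    CUDA_VISIBLE_DEVICES=0,1 vllm serve baichuan-inc/Baichuan-M1-14B-Base  --trust-remote-code --port 8004  --tensor-parallel-size 2
Note: run_llm requests the model id "baichuan-inc/Baichuan-M1-14B-Instruct", and this
script's ``--port`` option (exported as the PORT environment variable) defaults to 8003.
Make sure the served model id and port match what the script uses.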
"""
import pandas as pd
import json
import os
import argparse
import datetime
import time
import glob
import os
from llm_utils import *
import torch

parser = argparse.ArgumentParser()
# Backend/model to query. 'gpt*' names use the batch API; all others are
# queried one request at a time (see the __main__ section).
parser.add_argument('--LLM_name', type=str,
                    default='qwen3-235b-a22b',
                    choices=['baichuan',
                             'vllm_medgemma',
                             'MedGemma-27b-it',
                             'gpt-4.1',
                             'qwen3-235b-a22b',
                             'deepseek-v3-0324',
                             'llama4-maverick-instruct-basic',
                             'vllm_gpt-oss-120b',
                             'vllm_gpt-oss-20b'])

# Query the model; without it, previously saved final_processed*.csv results are reloaded.
parser.add_argument('--model_run', action='store_true')
parser.add_argument('--batch_path', type=str, default='./batch_files')
# Only 'api_only' is supported; it selects the output directory layout.
parser.add_argument('--mode', type=str, default='api_only', choices=['api_only'])
# Evaluate every subject instead of only those listed in --subset.
parser.add_argument('--all_eval', action='store_true')
# Subject IDs evaluated when --all_eval is not given. ``nargs='+'`` makes
# ``--subset p1 p2`` yield ['p1', 'p2']; ``type=list`` would instead split a
# single command-line value into its individual characters.
parser.add_argument('--subset', type=str, nargs='+',
                    default=['p10274145', 'p10523725', 'p10886362', 'p10959054', 'p12433421',
                             'p15321868', 'p15446959', 'p15881535', 'p17720924', 'p18079481'],
                    help='Subject IDs to process when --all_eval is not set')
# Retry observations still marked 'unmatched' after the first pass.
parser.add_argument('--process_missing', action='store_true')
parser.add_argument('--few_shot', action='store_true')
# Keep only rows whose 'section' equals this value.
parser.add_argument('--eval_section', type=str, default=None)

parser.add_argument('--input_path', type=str, default='./dataset/Lunguage.csv')
parser.add_argument('--output_path', type=str, default='./sequentialSR/results')
# Add API key arguments
parser.add_argument('--api_key', type=str, default=None, help='API key for LLM API')
parser.add_argument('--port', type=str, default='8003', help='Port for local LLM server')

args = parser.parse_args()

# Export CLI-provided connection settings through the environment so that the
# client initialisation below can pick them up.
if args.api_key:
    os.environ["API_KEY"] = args.api_key
if args.port:
    os.environ["PORT"] = args.port

# Nest outputs and batch files under
# <root>/<LLM_name>/api_only/<all_eval|subsetN>/<few_shot|zero_shot>/<dataset stem>.
if args.mode == 'api_only':
    _dataset_stem = args.input_path.split("/")[-1].split(".")[0]
    _scope_dir = 'all_eval' if args.all_eval else f'subset{len(args.subset)}'
    _shot_dir = 'few_shot' if args.few_shot else 'zero_shot'
    _run_suffix = f'{args.LLM_name}/api_only/{_scope_dir}/{_shot_dir}/{_dataset_stem}'
    args.output_path = f'{args.output_path}/{_run_suffix}'
    args.batch_path = f'{args.batch_path}/{_run_suffix}'


os.makedirs(args.output_path, exist_ok=True)

client, tokenizer = initialize_llm_client(args.LLM_name)


# Attribute columns appended, in this order, to the entity text to build ELA_cur_ent.
_ELA_RELATIONS = [
    'location', 'morphology', 'distribution', 'measurement', 'severity',
    'comparison', 'onset', 'no change', 'improved', 'worsened',
    'placement', 'past hx', 'other source', 'assessment limitations'
]


def _format_study_time(time_str):
    """Convert an ``HHMMSS[.fraction]`` study-time value to a ``pd.Timedelta``.

    Fractional seconds are dropped and short values are left-padded with zeros
    (e.g. ``94501.0`` -> 09:45:01). Missing values give ``Timedelta(0)``;
    unparseable values print a warning and also give ``Timedelta(0)``.
    """
    if pd.isna(time_str):
        return pd.Timedelta(0)

    try:
        time_str = str(time_str).strip()
        main_part = time_str.split('.')[0].zfill(6)
        hours = int(main_part[:2])
        minutes = int(main_part[2:4])
        seconds = int(main_part[4:6])
        return pd.Timedelta(hours=hours, minutes=minutes, seconds=seconds)
    except (ValueError, IndexError) as e:
        print(f"Warning: Could not parse time value: {time_str}, Error: {e}")
        return pd.Timedelta(0)


def _calculate_time_from_first(df):
    """Compute each row's elapsed time since the subject's earliest study.

    The reference point is the earliest non-null ``StudyDateTime`` of each
    subject (not the row with the lowest ``sequence`` number).

    Returns:
        A frame with columns subject_id, sequence, ent_idx, time_from_first,
        StudyDateTime and FirstStudyDateTime, holding one row per unique
        (subject_id, sequence, ent_idx) key so that merging it back cannot
        multiply rows. If ``StudyDateTime`` is absent, an empty frame with
        those columns is returned after printing an error.
    """
    df = df.copy()
    df['sequence'] = df['sequence'].astype(int)

    if 'StudyDateTime' not in df.columns:
        print("Error: StudyDateTime column not found in dataframe")
        return pd.DataFrame(columns=['subject_id', 'sequence', 'ent_idx', 'time_from_first',
                                     'StudyDateTime', 'FirstStudyDateTime'])

    # Values read back from CSV are strings (object or, on newer pandas, string dtype).
    if not pd.api.types.is_datetime64_any_dtype(df['StudyDateTime']):
        df['StudyDateTime'] = pd.to_datetime(df['StudyDateTime'])

    first_studies = (df[['subject_id', 'StudyDateTime']]
                     .dropna(subset=['StudyDateTime'])
                     .sort_values('StudyDateTime')
                     .groupby('subject_id').first()
                     .reset_index()
                     .rename(columns={'StudyDateTime': 'FirstStudyDateTime'}))

    return (df[['subject_id', 'ent_idx', 'sequence', 'StudyDateTime']]
            # One row per merge key: duplicated keys would otherwise turn the
            # caller's left merge into a many-to-many join and duplicate rows.
            .drop_duplicates(['subject_id', 'sequence', 'ent_idx'])
            .sort_values(['subject_id', 'sequence'])
            .merge(first_studies, on='subject_id', how='left')
            .assign(time_from_first=lambda x: x['StudyDateTime'] - x['FirstStudyDateTime'])
            [['subject_id', 'sequence', 'ent_idx', 'time_from_first', 'StudyDateTime', 'FirstStudyDateTime']])


def _load_study_datetimes(metadata_path='./mimic-cxr-2.0.0-metadata.csv'):
    """Return the earliest StudyDate/StudyTime/StudyDateTime row per ``study_id``.

    ``study_id`` values are prefixed with ``s`` when they do not already start
    with it.

    Raises:
        FileNotFoundError: if *metadata_path* does not exist.
    """
    if not os.path.exists(metadata_path):
        raise FileNotFoundError(
            "MIMIC-CXR metadata file not found. Please download 'mimic-cxr-2.0.0-metadata.csv' from PhysioNet "
            "(https://physionet.org/content/mimic-cxr/2.0.0/) and place it in the 'sequentialSR' directory."
        )

    mimic_metadata = pd.read_csv(metadata_path)
    mimic_metadata['StudyDate'] = pd.to_datetime(mimic_metadata['StudyDate'].astype(str).str.zfill(8), format='%Y%m%d')
    mimic_metadata['StudyTime'] = mimic_metadata['StudyTime'].apply(_format_study_time)
    mimic_metadata['StudyDateTime'] = mimic_metadata['StudyDate'] + mimic_metadata['StudyTime']
    mimic_metadata['study_id'] = mimic_metadata['study_id'].astype(str).apply(lambda x: 's' + x if not x.startswith('s') else x)

    return (mimic_metadata[['study_id', 'StudyDate', 'StudyTime', 'StudyDateTime']]
            .dropna()
            .sort_values('StudyDateTime')
            .groupby('study_id').first()
            .reset_index())


def load_and_process_data(input_df):
    """Normalise *input_df* and add the derived columns used to build prompts.

    Steps (each derived column is only computed when it is missing):
      * strip ``loc:`` / ``det:`` prefixes from ``location`` (always done);
      * ``ELA_cur_ent``: entity text (``ent`` column, else ``entity``) followed
        by the non-empty values of the ``_ELA_RELATIONS`` columns;
      * ``sequence``: looked up by ``study_id`` in ./dataset/Lunguage.csv;
      * ``section``: set to ``args.eval_section`` when that option is given;
      * ``cluster_name``: one catch-all group label;
      * ``StudyDateTime``: joined by ``study_id`` from the MIMIC-CXR metadata CSV;
      * ``time_from_first`` / ``FirstStudyDateTime``: elapsed time since the
        subject's earliest study;
      * ``day_from_first``: whole days of ``time_from_first`` as ``"<n> days"``
        (``"0 days"`` when unknown).

    The first few columns are written into *input_df* in place; later steps
    merge and therefore return a new frame.

    Returns:
        The augmented dataframe.

    Raises:
        FileNotFoundError: if ``StudyDateTime`` must be joined and the metadata
            CSV is absent.
    """
    input_df.loc[:, 'location'] = (input_df['location']
                                   .str.replace(r'loc:\s*', '', regex=True)
                                   .str.replace(r'det:\s*', '', regex=True))

    if 'ELA_cur_ent' not in input_df.columns:
        entity_col = 'ent' if 'ent' in input_df.columns else 'entity'
        input_df['ELA_cur_ent'] = input_df.apply(
            lambda row: ' '.join([row[entity_col]] +
                                 [str(row[col]) for col in _ELA_RELATIONS
                                  if pd.notna(row[col]) and row[col] != '']),
            axis=1
        )

    if 'sequence' not in input_df.columns:
        gold = pd.read_csv('./dataset/Lunguage.csv')
        sequence_mapping = gold[['study_id', 'sequence']].drop_duplicates().set_index('study_id')['sequence'].to_dict()
        input_df['sequence'] = input_df['study_id'].map(sequence_mapping)

    if 'section' not in input_df.columns and args.eval_section:
        input_df['section'] = args.eval_section

    if 'cluster_name' not in input_df.columns:
        input_df['cluster_name'] = 'ANALYZE THE FOLLOWING OBSERVATIONS'

    if 'StudyDateTime' not in input_df.columns:
        input_df = input_df.merge(_load_study_datetimes(), on=['study_id'], how='left')

    if 'time_from_first' not in input_df.columns:
        time_diffs = _calculate_time_from_first(input_df)
        # Both frames carry StudyDateTime, so pandas suffixes them
        # (StudyDateTime_x / StudyDateTime_y); kept as-is for downstream code.
        input_df = input_df.merge(time_diffs, on=['subject_id', 'sequence', 'ent_idx'], how='left')

    if 'day_from_first' not in input_df.columns:
        input_df['day_from_first'] = input_df['time_from_first'].apply(
            lambda x: '0 days' if pd.isna(x) else f"{pd.Timedelta(x).days} days")

    return input_df
                
# LLM_name prefixes answered by a local OpenAI-compatible server, mapped to the
# model id requested from it. Checked in order; the first matching prefix wins.
# NOTE(review): the module docstring's example serves
# baichuan-inc/Baichuan-M1-14B-Base, but the Instruct id is requested here;
# confirm which model the server actually exposes.
_LOCAL_SERVER_MODELS = (
    ('baichuan', "baichuan-inc/Baichuan-M1-14B-Instruct"),
    ('vllm_medgemma', "medgemma-27b-text-it"),
    ('vllm_gpt-oss-20b', "gpt-oss-20b"),
    ('vllm_gpt-oss-120b', "gpt-oss-120b"),
)

# Generation budget used by the MedGemma and hosted backends.
# NOTE(review): 8092 looks like a typo for 8192; kept unchanged.
_MAX_NEW_TOKENS = 8092


def _format_conversation(conversation):
    """Return the well-formed messages of *conversation*.

    A message is kept when it is a dict with both 'role' and 'content' keys;
    anything else is reported and dropped. Dict/list contents are converted
    with ``str`` in place, so the input message dicts are mutated too.
    """
    formatted_conversation = []
    for msg in conversation:
        if isinstance(msg, dict) and 'role' in msg and 'content' in msg:
            if isinstance(msg['content'], (dict, list)):
                msg['content'] = str(msg['content'])
            formatted_conversation.append(msg)
        else:
            print(f"Skipping invalid message format: {msg}")
    return formatted_conversation


def _call_llm(messages):
    """Send *messages* to the backend selected by ``args.LLM_name`` and return its reply.

    'MedGemma*' runs in-process via ``tokenizer.apply_chat_template`` and
    ``client.generate`` on CUDA and returns decoded text. Every other backend
    goes through ``client.chat.completions.create`` with
    ``response_model=RadiologyOutput`` and returns that call's result.
    Exceptions from the backend propagate to the caller.
    """
    if args.LLM_name.startswith('MedGemma'):
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        )
        with torch.inference_mode():
            output = client.generate(
                **inputs.to('cuda'),
                max_new_tokens=_MAX_NEW_TOKENS
            )
            # Keep only the newly generated tokens, not the echoed prompt.
            generation = output[0][inputs["input_ids"].shape[1]:]
            return tokenizer.decode(generation, skip_special_tokens=True)

    for prefix, model_id in _LOCAL_SERVER_MODELS:
        if args.LLM_name.startswith(prefix):
            return client.chat.completions.create(
                model=model_id,
                messages=messages,
                response_model=RadiologyOutput
            )

    return client.chat.completions.create(
        model=f'accounts/fireworks/models/{args.LLM_name}',
        max_tokens=_MAX_NEW_TOKENS,
        temperature=0.0,
        messages=messages,
        response_model=RadiologyOutput
    )


def _build_group_input(cur_group_df, group_name):
    """Serialise one observation group into the JSON prompt payload.

    Returns:
        (prompt_json, ent_idx_info, seq_info). The two dicts map each
        "IDX:<i>, DAY: <d>" key back to the row's ``ent_idx`` / ``sequence``.
    """
    dict_obs, ent_idx_info, seq_info = {}, {}, {}
    rows = zip(cur_group_df['ELA_cur_ent'].to_list(),
               cur_group_df['day_from_first'].to_list(),
               cur_group_df['dx_status'].to_list(),
               cur_group_df['ent_idx'].to_list(),
               cur_group_df['sequence'].to_list())
    for idx, (obs, day, status, ent_idx, seq) in enumerate(rows):
        day_num = int(day.split()[0])  # "X days" -> X
        dict_obs[f"IDX:{idx}, DAY: {day_num}, status: {status}"] = [obs]
        key = f"IDX:{idx}, DAY: {day_num}"
        ent_idx_info[key] = ent_idx
        seq_info[key] = seq

    input_json = {
        "cluster_name": group_name,
        "observations": dict_obs
    }
    return json.dumps(input_json, ensure_ascii=False), ent_idx_info, seq_info


def run_llm(total_cost, clustered_df, missing_data=None, is_missing_process=False):
    """Query the configured LLM for each observation group and save the parsed outputs.

    Args:
        total_cost: Accepted for interface compatibility; not used here.
        clustered_df: Processed observations; used only for the initial pass.
            Groups are (subject_id, cluster_name); single-observation groups
            are skipped.
        missing_data: ``{subject_id: {group_name: prompt_input}}`` for a retry
            round; required when ``is_missing_process`` is True.
        is_missing_process: Run the retry round instead of the initial pass.

    Every backend receives the sanitised message list from
    ``_format_conversation``. A failed request is reported and its group skipped.

    Returns:
        The concatenated ``process_radiology_output`` frames (empty DataFrame if
        none). They are also written to ``output_df.csv`` (initial pass) or
        ``missing_outputs.csv`` (retry) under ``args.output_path``.
    """
    outputs = []
    if not is_missing_process:
        if not args.all_eval:
            clustered_df = clustered_df[clustered_df['subject_id'].isin(args.subset)]

        print("number of subject_id: ", clustered_df.subject_id.nunique())

        for subject_id in clustered_df.subject_id.unique():
            print(f'Patient: "{subject_id}"')
            subject_df = clustered_df[clustered_df['subject_id'] == subject_id]
            group_names = subject_df.cluster_name.unique()
            print("number of groups: ", len(group_names))
            for group_name in group_names:
                print(f'  Group: "{group_name}"')
                cur_group_df = subject_df[subject_df['cluster_name'] == group_name]

                # Exclude groups with only one observation
                if len(cur_group_df) <= 1:
                    print("Skipping group with only one observation", group_name)
                    continue

                prompt_input, ent_idx_info, seq_info = _build_group_input(cur_group_df, group_name)
                conversation = convert_sets_to_lists(generate_few_shot(prompt_input, args))
                messages = _format_conversation(conversation)

                try:
                    response = _call_llm(messages)
                except Exception as e:
                    print(f"Error processing patient {subject_id}: {e}")
                    continue

                input_data = {
                    "batch_idx": "non-batch",
                    "subject_id": subject_id,
                    "Input": prompt_input,
                    "Input_ent_idx": ent_idx_info,
                    "Input_seq": seq_info,
                    "Group_name": group_name
                }
                llm_output_df, _ = process_radiology_output(response, input_data)
                outputs.append(llm_output_df)
        output_path = f"{args.output_path}/output_df.csv"

    else:
        for subject_id, subject_clusters in missing_data.items():
            for group_name, cluster_content in subject_clusters.items():
                conversation = convert_sets_to_lists(generate_few_shot(cluster_content, args))
                messages = _format_conversation(conversation)

                try:
                    response = _call_llm(messages)
                except Exception as e:
                    print(f"Error processing patient {subject_id}: {e}")
                    continue

                input_data = {
                    "batch_idx": "non-batch-missing",
                    "Group_name": group_name,
                    "subject_id": subject_id
                }
                llm_output_df, _ = process_radiology_output(response, input_data, is_missing_process=True)
                outputs.append(llm_output_df)
        output_path = f"{args.output_path}/missing_outputs.csv"

    # Concatenate once instead of growing the frame inside the loop.
    output_df = pd.concat(outputs) if outputs else pd.DataFrame()

    try:
        output_df.to_csv(output_path, index=False)
        print(f"Successfully saved output_df to {output_path}")
    except Exception as e:
        print(f"Error saving output_df: {e}")
    return output_df

# Statuses after which polling stops and the result file can be downloaded.
_BATCH_DONE_STATUSES = ("completed", "succeeded", "ended")
# Statuses after which polling stops without usable results. "expired" and
# "cancelled" are terminal in the OpenAI Batch API; without them the polling
# loop would never exit for such jobs.
_BATCH_FAILED_STATUSES = ("failed", "errored", "expired", "cancelled")


def run_batch(batch_file, is_missing_process=False, iteration=0):
    """Submit *batch_file* as a batch job, poll it until it ends, and save its results.

    Progress is printed every ``check_interval`` seconds. Ctrl-C stops polling
    and requests cancellation of the job. When the job finishes successfully,
    the result JSONL is written to ``batch_results.jsonl`` (initial pass) or
    ``missing_outputs{iteration}.jsonl`` (retry round) under ``args.output_path``.

    Raises:
        FileNotFoundError: if *batch_file* does not exist.
    """
    print(f"\nRun batch processing with {batch_file.split('/')[-1]}")

    if not os.path.isfile(batch_file):
        raise FileNotFoundError(f"File not found: {batch_file}")

    batch_job = create_batch_job(batch_file, client, args)

    print(f"Successfully created batch job: {batch_job.id}\n")
    batch_job_id = batch_job.id
    check_interval = 2  # seconds between status polls
    previous_status = None

    print("-" * 50)
    try:
        while True:
            current_time = datetime.datetime.now().strftime("%H:%M:%S")
            batch = client.batches.retrieve(batch_job_id)
            batch_status = batch.status
            completed_count = batch.request_counts.completed
            total_count = batch.request_counts.total

            if batch_status != previous_status:
                print(f"\n[{current_time}] Status changed: {previous_status} → {batch_status}")
                print(f"Completed: {completed_count}/{total_count}")

                if batch_status == "in_progress":
                    print(f"In progress: {completed_count}/{total_count}")

                if batch_status in _BATCH_DONE_STATUSES:
                    print(f"Completed! Result file ID: {batch.output_file_id}")
                    break

                if batch_status in _BATCH_FAILED_STATUSES:
                    print("Failed!")
                    break

                previous_status = batch_status
            else:
                if batch_status == "in_progress" and total_count > 0:
                    completion_percent = (completed_count / total_count) * 100
                    print(f"\r[{current_time}] Progress: {completion_percent:.1f}% ({completed_count}/{total_count})", end="")
                else:
                    print(f"\r[{current_time}] Current status: {batch_status}", end="")

            time.sleep(check_interval)

    except KeyboardInterrupt:
        print("\n\nMonitoring stopped.")
        cancel_batch = client.batches.cancel(batch_job_id)
        print("Batch cancelled", cancel_batch)

    # Re-read the job so the final report reflects its current state (e.g. after
    # a cancellation, or when Ctrl-C arrived before the first poll).
    batch = client.batches.retrieve(batch_job_id)
    batch_status = batch.status
    completed_count = batch.request_counts.completed
    total_count = batch.request_counts.total

    print("\n\nFinal batch status:")
    print(f"Status: {batch_status}")
    print(f"Completed: {completed_count}/{total_count}")

    if batch_status in _BATCH_DONE_STATUSES:
        print("\nBatch job completed!")
        print(f"Result file ID: {batch.output_file_id}")
        if not batch.output_file_id:
            print("No output file was produced; nothing to save.")
            return
        result_content = client.files.content(batch.output_file_id).text
        if is_missing_process:
            result_path = f'{args.output_path}/missing_outputs{iteration}.jsonl'
        else:
            result_path = f'{args.output_path}/batch_results.jsonl'
        with open(result_path, 'w', encoding='utf-8') as f:
            f.write(result_content)
        print("-----")

# Retry rounds for still-unmatched observations stop once the iteration
# counter exceeds this value.
_MAX_MISSING_ITERATION = 4


def _load_latest_processed(pattern, show_candidates=False):
    """Load the most advanced checkpoint CSV matching the glob *pattern*.

    Checkpoints are named ``final_processed{iteration}_{unmatched}missing.csv``.
    The file with the highest iteration number wins, with ties broken by
    modification time. Plain lexicographic sorting would rank
    ``final_processed10_`` below ``final_processed2_``.

    Returns:
        (dataframe, iteration, path) of the chosen file.

    Raises:
        FileNotFoundError: if nothing matches *pattern*.
    """
    processed_files = glob.glob(pattern)
    if show_candidates:
        print(processed_files)
    if not processed_files:
        raise FileNotFoundError(f"No processed data found in {args.output_path}")

    def _iteration_of(path):
        return int(os.path.basename(path).split('final_processed')[1].split('_')[0])

    latest_file = max(processed_files, key=lambda p: (_iteration_of(p), os.path.getmtime(p)))
    df = pd.read_csv(latest_file)
    print(f"Loaded processed data from {latest_file}")
    return df, _iteration_of(latest_file), latest_file


def _build_missing_dict(clustered_df):
    """Turn ``prepare_missing_inputs`` output into ``{subject_id: {cluster_name: prompt_input}}``."""
    result_data = prepare_missing_inputs(clustered_df)
    missing_dict = {}
    for subject_id, subject_clusters in result_data.items():
        print(f"\n=== Subject ID: {subject_id} ===")
        missing_dict[subject_id] = {
            cluster_name: create_missing_input(subject_id, cluster_name, result_data)
            for cluster_name in subject_clusters.keys()
        }
    return missing_dict


def _checkpoint(df, iteration, prefix=""):
    """Report match counts, drop exact duplicate rows, and save a checkpoint CSV.

    The file name records the unmatched count measured before de-duplication.

    Returns:
        (deduplicated dataframe, path of the saved CSV).
    """
    unmatched_count = (df['llm_processed'] == 'unmatched').sum()
    matched_count = (df['llm_processed'] == 'matched').sum()
    print(f"{prefix}Iteration {iteration} completed. Remaining unmatched: {unmatched_count}, Matched: {matched_count}")
    df = df.drop_duplicates()
    path = f"{args.output_path}/final_processed{iteration}_{unmatched_count}missing.csv"
    df.to_csv(path, index=False)
    return df, path


def _evaluate_if_labelled(clustered_df):
    """Run ``eval_func`` when every row has a ``gt_temporal_group`` label.

    Evaluation failures are reported, not raised.
    """
    if 'gt_temporal_group' not in clustered_df.columns or not clustered_df['gt_temporal_group'].notna().all():
        return
    try:
        eval_func(args, clustered_df)
    except Exception as e:
        print(f"Warning: Evaluation function failed: {e}")
        print("Continuing without evaluation...")


if __name__ == "__main__":
    clustered_df = pd.read_csv(args.input_path)
    clustered_df = load_and_process_data(clustered_df)
    if args.eval_section:
        clustered_df = clustered_df[clustered_df['section'] == args.eval_section]

    if not args.all_eval:
        clustered_df = clustered_df[clustered_df['subject_id'].isin(args.subset)]

    iteration = 0
    final_path = None  # last checkpoint written or loaded; reported at the end

    if args.LLM_name.startswith('gpt'):
        # GPT models are processed through the batch API (see run_batch).
        print(f"{args.LLM_name} Batch processing start!")
        all_inputs = creating_batch_file(clustered_df, args)

        if args.model_run:
            os.makedirs(args.batch_path, exist_ok=True)
            run_batch(f"{args.batch_path}/llm_batch.jsonl", is_missing_process=False)
            batch_result = f"{args.output_path}/batch_results.jsonl"
            llm_output = read_batch_results_to_csv(args, batch_result, clustered_df, all_inputs, is_missing_process=False)
            clustered_df = post_process(llm_output, clustered_df, args.output_path, is_missing_process=False)
            clustered_df, final_path = _checkpoint(clustered_df, iteration, prefix="\n ")
        else:
            # Only fully processed checkpoints (zero unmatched) are reused here.
            clustered_df, iteration, final_path = _load_latest_processed(
                f"{args.output_path}/final_processed*_0missing.csv")

        if args.process_missing:
            os.makedirs(args.batch_path, exist_ok=True)
            while (clustered_df['llm_processed'] == 'unmatched').any():
                iteration += 1
                print(f"{args.LLM_name} Missing process iteration {iteration} start!")
                missing_dict = _build_missing_dict(clustered_df)

                _ = creating_batch_file(clustered_df, args, missing_data=missing_dict, all_inputs=all_inputs, iteration=iteration)
                run_batch(f"{args.batch_path}/missing_batch{iteration}.jsonl", is_missing_process=True, iteration=iteration)

                llm_output = read_batch_results_to_csv(args, f"{args.output_path}/missing_outputs{iteration}.jsonl", clustered_df, all_inputs, is_missing_process=True)
                clustered_df = post_process(llm_output, clustered_df, args.output_path, is_missing_process=True, iteration=iteration)
                clustered_df, final_path = _checkpoint(clustered_df, iteration)

                # The while-condition ends the loop once nothing is unmatched;
                # this caps the number of retry rounds.
                if iteration > _MAX_MISSING_ITERATION:
                    break

    else:
        print(f"{args.LLM_name} Single processing start!")
        total_cost = {
            "prompt_tokens": [],
            "completion_tokens": [],
            "gpt-45-cost": [],
            "gpt-4o-batch-cost": [],
            "gpt-4o-cost": [],
            "gpt-4o-mini-cost": [],
        }

        if args.model_run:
            llm_output = run_llm(total_cost, clustered_df, is_missing_process=False)
            clustered_df = post_process(llm_output, clustered_df, args.output_path, is_missing_process=False)
            clustered_df, final_path = _checkpoint(clustered_df, iteration)
        else:
            # Resume from the latest checkpoint, whatever its unmatched count.
            clustered_df, iteration, final_path = _load_latest_processed(
                f"{args.output_path}/final_processed*_*missing.csv", show_candidates=True)

        if args.process_missing:
            while (clustered_df['llm_processed'] == 'unmatched').any():
                iteration += 1
                print(f"{args.LLM_name} Missing process iteration {iteration} start!")
                missing_dict = _build_missing_dict(clustered_df)

                missing_llm_output = run_llm(total_cost, clustered_df, missing_data=missing_dict, is_missing_process=True)
                clustered_df = post_process(missing_llm_output, clustered_df, args.output_path, is_missing_process=True, iteration=iteration)
                clustered_df, final_path = _checkpoint(clustered_df, iteration)

                # The while-condition ends the loop once nothing is unmatched;
                # this caps the number of retry rounds.
                if iteration > _MAX_MISSING_ITERATION:
                    break

    _evaluate_if_labelled(clustered_df)
    # Report the checkpoint that was actually written (or loaded) last.
    print(f"Process completed! Path: ({final_path})")