"""

Localserver:

    MedGemma: 
        CUDA_VISIBLE_DEVICES=5,6 python singleSR/run.py --deployment_name MedGemma-27b-it --api_key <YOUR_HF_TOKEN> --mode gold_eval --n_retrieval 0
        CUDA_VISIBLE_DEVICES=5,6 python singleSR/run.py --deployment_name MedGemma-27b-it --api_key <YOUR_HF_TOKEN> --mode gold_eval --n_retrieval 5
        
        VLLM:
            huggingface-cli login

            export HF_TOKEN=<YOUR_HF_TOKEN>
            docker rm -f vllm-medgemma 2>/dev/null || true

            docker run --gpus '"device=2,3"' --ipc=host --name vllm-medgemma -p 8002:8000 \
            --user $(id -u):$(id -g) \
            -e USER=jhmoon \
            -e HOME=/tmp \
            -e XDG_CACHE_HOME=/tmp/.cache \
            -e HF_HOME=/tmp/.cache/huggingface \
            -e TRANSFORMERS_CACHE=/tmp/.cache/huggingface/transformers \
            -e HUGGINGFACE_HUB_CACHE=/tmp/.cache/huggingface \
            -e FLASHINFER_WORKSPACE_DIR=/tmp/.cache/flashinfer \
            -e HUGGING_FACE_HUB_TOKEN=$HF_TOKEN \
            vllm/vllm-openai:gptoss \
            --model google/medgemma-27b-it \
            --hf-token $HF_TOKEN \
            --gpu-memory-utilization 0.90 \
            --served-model-name medgemma-27b-text-it \
            --trust-remote-code \
            --tensor-parallel-size 2
        
            ###############################################################

            huggingface-cli login

            export HF_TOKEN=<YOUR_HF_TOKEN>
            docker rm -f vllm-gpt-oss-120b 2>/dev/null || true

            docker run --gpus '"device=4,5,6,7"' --ipc=host --name vllm-gpt-oss-120b -p 8004:8000 \
            --user $(id -u):$(id -g) \
            -e USER=jhmoon \
            -e HOME=/tmp \
            -e XDG_CACHE_HOME=/tmp/.cache \
            -e HF_HOME=/tmp/.cache/huggingface \
            -e TRANSFORMERS_CACHE=/tmp/.cache/huggingface/transformers \
            -e HUGGINGFACE_HUB_CACHE=/tmp/.cache/huggingface \
            -e FLASHINFER_WORKSPACE_DIR=/tmp/.cache/flashinfer \
            -e HUGGING_FACE_HUB_TOKEN=$HF_TOKEN \
            vllm/vllm-openai:v0.10.1.1 \
            --model openai/gpt-oss-120b \
            --hf-token $HF_TOKEN \
            --gpu-memory-utilization 0.90 \
            --served-model-name gpt-oss-120b \
            --trust-remote-code \
            --tensor-parallel-size 4
        
        CMD:
            python singleSR/run.py --deployment_name vllm_gpt-oss-120b --api_key <YOUR_HF_TOKEN> --mode gold_eval --n_retrieval 5 --port 8002
        
    Baichuan:
    - env:
        pip install vllm==0.6.6.post1
        pip install transformers==4.48.0
        git clone https://github.com/baichuan-inc/vllm.git
        cd vllm
        export VLLM_PRECOMPILED_WHEEL_LOCATION=https://pypi.tuna.tsinghua.edu.cn/packages/b0/14/9790c07959456a92e058867b61dc41dde27e1c51e91501b18207aef438c5/vllm-0.6.6.post1-cp38-abi3-manylinux1_x86_64.whl
        pip install --editable .
    
    - serving:

        export HF_TOKEN=<YOUR_HF_TOKEN>
        docker rm -f vllm-baichuan 2>/dev/null || true
        docker run --gpus '"device=0,1"' --ipc=host --name vllm-baichuan -p 8002:8000 \
        -e HUGGING_FACE_HUB_TOKEN=$HF_TOKEN \
        vllm/vllm-openai:gptoss \
        --model baichuan-inc/Baichuan-M1-14B-Instruct \
        --hf-token $HF_TOKEN \
        --gpu-memory-utilization 0.90 \
        --served-model-name baichuan-inc/Baichuan-M1-14B-Instruct \
        --trust-remote-code \
        --tensor-parallel-size 2
        
        CUDA_VISIBLE_DEVICES=4,5 vllm serve baichuan-inc/Baichuan-M1-14B-Base  --trust-remote-code --port 8004  --tensor-parallel-size 2
    
    - running (use the same env and port number as the serving step):
        python singleSR/run.py --deployment_name baichuan-2-72b-chat --api_key fake-key --mode gold_eval --port 8004

"""

import os
import torch
import argparse

# Set environment variable to avoid tokenizer parallelism issues
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from openai import AzureOpenAI
import datetime
from generate_prompt import create_sys_msg
from utils import (retreive_query_related_fewshot, generate_graph, create_input, StructuredOutput,
                  initialize_llm_client, convert_to_sr_structure, evaluate_funct, create_relation_dataframe, create_relation_dataframe2, create_relation_dataframe3,
                  generate_table, visualize_table, visualize_metrics, visualize_rexval_metrics, create_fp, create_fn, get_vocab_lookup, get_words, get_cols)
from tqdm import tqdm
import time
import json
from multiprocessing import Pool, cpu_count
import pandas as pd

def process_single_item(args_tuple):
    """Build one request record for the batch file.

    Args:
        args_tuple: ``(custom_id, input_data, devset, args, system_message)``
            packed into one tuple so the function can be mapped over a
            ``multiprocessing.Pool``.

    Returns:
        dict | None: For deployments whose name starts with ``gpt``, a
        batch request with ``custom_id``/``method``/``url``/``body``;
        otherwise a record holding ``custom_id``, ``system_message`` and
        ``messages``. ``None`` when the passage is not a string or is
        shorter than two characters.
    """
    custom_id, input_data, devset, args, system_message = args_tuple

    query = input_data[custom_id]['passage']
    query_subject_id = input_data[custom_id]['subject_id']
    # Check the type before the length: len() on a non-string such as a float
    # NaN raises TypeError, which would abort the whole pool instead of
    # skipping the item. The isinstance test also covers None and NaN.
    if not isinstance(query, str) or len(query) < 2:
        return None

    json_vocab_input, user_history, assistant_history = retreive_query_related_fewshot(devset, query, query_subject_id, args)

    # System prompt first, then the retrieved few-shot turns in order.
    conversation = [{"role": "system", "content": system_message}]
    for user_turn, assistant_turn in zip(user_history, assistant_history):
        conversation.append({"role": "user", "content": user_turn})
        conversation.append({"role": "assistant", "content": assistant_turn})

    # Input used with GT candidates, and also when no candidates are used.
    if ('gt' in args.candidate_type) or (args.candidate_type == 'no_candidates'):
        payload = input_data[custom_id]['json_input']
    # Input used with vocab candidates: the json_vocab_input returned by retrieval.
    else:
        payload = json_vocab_input
    user_string = f"INPUT:\n{json.dumps(payload, indent=4)}\n"

    conversation.append({"role": "user", "content": user_string})

    if args.deployment_name.startswith('gpt'):
        # The schema is only needed for the gpt request body, so it is read
        # here rather than for every item regardless of deployment.
        schema_file_path = "./singleSR/gpt_json_schema.json"
        with open(schema_file_path, 'r', encoding='utf-8') as f:
            gpt_json_schema = json.load(f)

        return {
            "custom_id": custom_id,
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": args.deployment_name,
                "temperature": 0.0,
                "messages": conversation,
                "response_format": {
                    "type": "json_schema",
                    "json_schema": gpt_json_schema,
                },
            },
        }

    return {
        "custom_id": custom_id,
        "system_message": system_message,
        "messages": conversation,
    }

def create_batch_file(args):
    """Build a request record for every input item and write them as JSONL.

    Items are processed in parallel by ``process_single_item``; items for
    which it returns ``None`` are left out. The records are written to
    ``batch_file.jsonl`` under a directory derived from the run settings.

    NOTE(review): this writes under ``./data/batch_files`` while ``main``
    looks under ``./singleSR/data/batch_files`` — confirm the two are meant
    to resolve to the same location.
    """
    input_data, devset = create_input(args)
    system_message = create_sys_msg(args)

    worker_count = min(cpu_count(), 64)
    print(f"Starting multiprocessing with {worker_count} workers")

    # One argument tuple per item, in the shape process_single_item unpacks.
    work_items = [
        (custom_id, input_data, devset, args, system_message)
        for custom_id in input_data.keys()
    ]

    # Fan the items out over the pool and keep only the usable records.
    with Pool(processes=worker_count) as pool:
        progress = tqdm(
            pool.imap(process_single_item, work_items),
            total=len(work_items),
            desc="Generating batch file",
        )
        tasks = [record for record in progress if record is not None]

    if args.dynamic_retrieval:
        run_label = f'dynamic_{args.candidate_type}_{args.deployment_name}'
    else:
        run_label = f'{args.n_retrieval}_{args.candidate_type}_{args.deployment_name}'
    data_dir = f'./data/batch_files/{run_label}/{args.mode}/{args.unit}/{args.candidate_usage}'

    if not os.path.exists(data_dir):
        os.makedirs(data_dir, exist_ok=True)

    with open(f'{data_dir}/batch_file.jsonl', 'w') as file:
        file.writelines(json.dumps(record) + '\n' for record in tasks)
    print('Batch file has been created')


def run_llm_batch(file_name, client, args):
    """Upload a JSONL batch file and start a chat-completions batch job.

    Args:
        file_name: Path of the JSONL batch file to upload.
        client: Client object exposing ``files.create`` and ``batches.create``.
        args: Unused; kept so existing callers keep working.

    Returns:
        The batch job object returned by ``client.batches.create``.
    """
    # A context manager closes the handle even when the upload raises; the
    # previous bare open() leaked the file descriptor.
    with open(file_name, "rb") as batch_stream:
        batch_file = client.files.create(file=batch_stream, purpose="batch")
    print(f"Batch file uploaded: {batch_file}")

    batch_job = client.batches.create(
        input_file_id=batch_file.id,
        endpoint="/v1/chat/completions",
        completion_window="24h",
    )
    print(batch_job)

    return batch_job

def extract_passage_history(message):
    """Split a conversation into the final passage and the earlier turns.

    The content of the last entry is returned as the passage, whatever its
    role. Every earlier entry is sorted by role into the user or assistant
    history; entries with any other role (e.g. ``system``) are ignored.

    Returns:
        tuple: ``(passage, user_history, assistant_history)``; ``passage`` is
        ``None`` when ``message`` is empty.
    """
    passage = None
    user_turns, assistant_turns = [], []
    final_position = len(message) - 1

    for position, entry in enumerate(message):
        if position == final_position:
            passage = entry['content']
            continue

        role = entry['role']
        if role == 'user':
            user_turns.append(entry['content'])
        elif role == 'assistant':
            assistant_turns.append(entry['content'])

    return passage, user_turns, assistant_turns
        
def _extract_result_content(result, custom_id):
    """Pull the model's text out of one batch-result record.

    Returns the content string (possibly empty), or ``None`` when a warning
    has already been printed because the body could not be parsed or held
    no choices.
    """
    if "content" in result:
        # Flat format: the content sits at the top level of the record.
        return result.get("content", "") or ""

    # GPT batch format: content is nested under response -> body -> choices.
    body = result.get('response', {}).get('body', {})
    if isinstance(body, str):
        # A string body is itself JSON and must be decoded first.
        try:
            body = json.loads(body)
        except json.JSONDecodeError:
            print(f"Warning: Could not parse body as JSON for custom_id {custom_id}")
            return None

    choices = body.get('choices', [])
    if not choices:
        print(f"Warning: No choices found for custom_id {custom_id}")
        return None

    return choices[0].get('message', {}).get('content', '') or ""


def read_batch_results(batch_path, results_path, args):
    """Join batch requests with their model outputs and save them as JSON.

    Reads ``batch_results.jsonl`` from ``results_path`` and
    ``batch_file.jsonl`` from ``batch_path``, pairs them by ``custom_id``
    and writes one summary JSON file into ``results_path``.

    Requests without a usable model output are reported and skipped, so one
    failed request no longer aborts the run with ``KeyError``.

    Args:
        batch_path: Directory containing ``batch_file.jsonl``.
        results_path: Directory containing ``batch_results.jsonl``; the
            summary file is written here as well.
        args: Run settings; the attributes read are the ones copied into
            the summary below.

    Returns:
        tuple: ``(total_results, output_file)`` — the summary dict and the
        path it was written to.
    """
    batch_file_path = batch_path + '/batch_file.jsonl'
    batch_results_file_path = results_path + '/batch_results.jsonl'

    # Pass 1: collect the model output for each custom_id.
    all_model_outputs = {}
    with open(batch_results_file_path, 'r', encoding='utf-8') as file:
        for line in file:
            if not line.strip():
                # Tolerate blank lines such as a trailing newline.
                continue
            result = json.loads(line)
            custom_id = result.get('custom_id', '')
            try:
                content = _extract_result_content(result, custom_id)
            except Exception as e:
                print(f"Error extracting content from result: {e}")
                continue

            if content is None:
                # The helper has already printed the reason.
                continue
            if not content:
                print(f"Warning: No content found for custom_id {custom_id}")
                continue
            all_model_outputs[custom_id] = content

    # Pass 2: walk the requests and attach the matching output.
    all_results = []
    with open(batch_file_path, 'r', encoding='utf-8') as file:
        for line in file:
            if not line.strip():
                continue
            result = json.loads(line)
            custom_id = result.get('custom_id', '')

            if custom_id not in all_model_outputs:
                print(f"Warning: No model output for custom_id {custom_id}; skipping")
                continue

            if 'body' in result:
                # GPT batch format keeps the messages inside the request body.
                message = result.get('body', {}).get('messages', {})
            else:
                message = result.get('messages', {})

            passage, user_history, assistant_history = extract_passage_history(message)

            all_results.append({
                'custom_id': custom_id,
                'passage': passage,
                'user_history': user_history,
                'assistant_history': assistant_history,
                'model_output': all_model_outputs[custom_id],
            })

    total_results = {
        "exp_name": args.mode,
        "exp_date": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "deployment_name": args.deployment_name,
        "entity_types": args.entity_types,
        "relation_types": args.relation_types,
        "attribute_types": args.attribute_types,
        "output_format": args.output_format,
        "input_unit": args.unit,
        "retreive_unit": args.unit,
        "n_retrieval": args.n_retrieval,
        "diverse_retrieval": args.diverse_retrieval,
        "annotations": all_results,
    }

    output_file = os.path.join(results_path, f'{args.deployment_name}_{args.output_format}_{args.mode}_{args.unit}.json')
    with open(output_file, "w", encoding="utf-8") as outfile:
        json.dump(total_results, outfile, indent=4, ensure_ascii=False)
    return total_results, output_file

def monitor_batch_job(batch_job_id, client, args, check_interval=2):
    """Poll a batch job until it reaches a terminal status.

    Status changes are printed on their own line; between changes a single
    progress line is rewritten in place. Pressing Ctrl-C stops the polling
    and cancels the batch.

    Args:
        batch_job_id: Identifier passed to ``client.batches.retrieve``.
        client: Client object exposing ``batches.retrieve`` and
            ``batches.cancel``.
        args: Unused; kept so existing callers keep working.
        check_interval: Seconds to sleep between polls.

    Returns:
        The last status observed, or ``None`` if polling was interrupted
        before the first status was retrieved.
    """
    print("-" * 50)

    previous_status = None
    # Initialised up front so the final return is safe even when Ctrl-C
    # arrives before the first retrieve() completes.
    batch_status = None

    try:
        while True:
            current_time = datetime.datetime.now().strftime("%H:%M:%S")
            batch = client.batches.retrieve(batch_job_id)
            batch_status = batch.status
            # request_counts can be absent on a freshly created job.
            counts = batch.request_counts
            completed_count = counts.completed if counts is not None else 0
            total_count = counts.total if counts is not None else 0

            if batch_status != previous_status:
                print(f"\n[{current_time}] Status changed: {previous_status} → {batch_status}")
                print(f"Completed: {completed_count}/{total_count}")

                if batch_status == "in_progress":
                    print(f"In progress: {completed_count}/{total_count}")

                if batch_status in ["completed", "succeeded", "ended"]:
                    print(f"Completed! Result file ID: {batch.output_file_id}")
                    break

                if batch_status in ["failed", "errored"]:
                    print("Failed!")
                    break

                # These are terminal too; without this branch the loop
                # would poll forever once a job expires or is cancelled.
                if batch_status in ["expired", "cancelled"]:
                    print(f"Stopped with status: {batch_status}")
                    break

                previous_status = batch_status
            else:
                if batch_status == "in_progress" and total_count > 0:
                    completion_percent = (completed_count / total_count) * 100
                    print(f"\r[{current_time}] Progress: {completion_percent:.1f}% ({completed_count}/{total_count})", end="")
                else:
                    print(f"\r[{current_time}] Current status: {batch_status}", end="")

            time.sleep(check_interval)

    except KeyboardInterrupt:
        print("\n\nMonitoring stopped.")
        cancel_batch = client.batches.cancel(batch_job_id)
        print("Batch cancelled", cancel_batch)

    return batch_status

def generate_gt_output(triplet_list):
    """Render ground-truth triplets as the ``OUTPUT: {...}`` JSON string.

    Triplets are grouped by their first element (the entity name). Entities
    appear in first-seen order, so the output is the same on every run; the
    previous ``set``-based grouping could reorder entities between runs.

    Args:
        triplet_list: Iterable of ``(entity, relation, value)`` sequences.
            For ``associate``/``evidence`` relations the value is a
            comma-separated list of alternating ``value, obj_ent_idxN``
            tokens.

    Returns:
        str: ``"OUTPUT: "`` followed by JSON holding an ``entities`` list.
    """
    # dicts keep insertion order, giving a stable one-pass grouping.
    relations_by_entity = {}

    for triplet in triplet_list:
        entity, relation, raw_value = triplet[0], triplet[1], triplet[2]
        relations = relations_by_entity.setdefault(entity, [])

        if relation in ['associate', 'evidence']:
            # Split a value like "pneumonia, obj_ent_idx2, effusion, obj_ent_idx3"
            tokens = [t.strip() for t in raw_value.split(',')]
            for i in range(0, len(tokens), 2):
                # A value without a following obj_ent_idx token is dropped.
                if i + 1 < len(tokens) and tokens[i + 1].startswith('obj_ent_idx'):
                    relations.append({
                        "relation": relation,
                        "value": tokens[i],
                        "obj_ent_idx": int(tokens[i + 1].replace('obj_ent_idx', '')),
                    })
            continue

        value = raw_value
        if relation == 'location' and isinstance(value, str) and value.lower().startswith("loc:"):
            # Strip the "loc:" prefix from location values.
            value = value[4:].strip()
        if relation in ['cat', 'dx_status', 'dx_certainty']:
            value = value.upper()
        relations.append({"relation": relation, "value": value})

    entities = [
        {
            "name": entity,
            "sent_idx": 1,
            "ent_idx": 1,
            "relations": relations,
        }
        for entity, relations in relations_by_entity.items()
    ]

    return "OUTPUT: " + json.dumps({"entities": entities}, indent=2)


def process_conversation(user_history, assistant_history, system_message, input_data, custom_id, args):
    """Flatten per-section few-shot histories into one chat message list.

    The list starts with the system message. Each section then contributes a
    separator user message followed by its user/assistant turns, paired by
    position. ``input_data``, ``custom_id`` and ``args`` are not used.

    Returns:
        list[dict]: Chat messages with ``role`` and ``content`` keys.
    """
    conversation = [{"role": "system", "content": system_message}]

    for section_no, section_turns in enumerate(user_history):
        conversation.append({"role": "user", "content": "----------- New Report Section -------------"})
        for turn_no, user_turn in enumerate(section_turns):
            conversation.extend((
                {"role": "user", "content": user_turn},
                {"role": "assistant", "content": assistant_history[section_no][turn_no]},
            ))

    return conversation

def _vllm_served_model_name(deployment_name):
    """Map a deployment name to the model name requested from the local server.

    Returns ``None`` when the deployment name matches none of the locally
    served prefixes handled by ``process_chunk``.
    """
    served_models = (
        ('baichuan', "baichuan-inc/Baichuan-M1-14B-Instruct"),
        ('vllm_medgemma', "medgemma-27b-text-it"),
        ('vllm_gpt-oss-20b', "gpt-oss-20b"),
        ('vllm_gpt-oss-120b', "gpt-oss-120b"),
    )
    for prefix, served_name in served_models:
        if deployment_name.startswith(prefix):
            return served_name
    return None


def _remember_error(error_cases, error_info):
    """Append ``error_info`` unless its file already has an error entry."""
    if not any(error['file_idx'] == error_info['file_idx'] for error in error_cases):
        error_cases.append(error_info)


def _save_chunk_snapshot(chunk_file, chunk_id, args, all_results, error_cases):
    """Write the chunk's current annotations and error cases to ``chunk_file``."""
    chunk_results = {
        "chunk_id": chunk_id,
        "exp_name": args.mode,
        "exp_date": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "deployment_name": args.deployment_name,
        "entity_types": args.entity_types,
        "relation_types": args.relation_types,
        "attribute_types": args.attribute_types,
        "output_format": args.output_format,
        "input_unit": args.unit,
        "retreive_unit": args.unit,
        "n_retrieval": args.n_retrieval,
        "diverse_retrieval": args.diverse_retrieval,
        "annotations": all_results,
        "error_cases": error_cases,
    }
    with open(chunk_file, "w", encoding="utf-8") as outfile:
        json.dump(chunk_results, outfile, indent=4, ensure_ascii=False)


def process_chunk(chunk_data, devset, results_path, args, chunk_id, client=None, tokenizer=None):
    """Run the model over one chunk of inputs, checkpointing after every item.

    Results are kept in ``<results_path>/intermediate/chunk_<chunk_id>.json``.
    If that file already exists the run resumes from it: when it records
    error cases only those are retried, otherwise only items without a
    stored result are processed.

    Args:
        chunk_data: Mapping from file index to a record with ``passage``,
            ``subject_id`` and ``json_input``.
        devset: Few-shot pool passed to ``retreive_query_related_fewshot``.
        results_path: Directory under which ``intermediate/`` is created.
        args: Run settings (deployment name, candidate type, metadata).
        chunk_id: Identifier used in the chunk file name and progress bar.
        client: Optional pre-built model client.
        tokenizer: Optional pre-built tokenizer.

    Returns:
        str: Path of the chunk file.
    """
    # Use the provided client and tokenizer instances, or create new ones if None
    if client is None or tokenizer is None:
        client, tokenizer = initialize_llm_client(args.deployment_name, args.api_key)
    system_message = create_sys_msg(args)
    all_results = []
    error_cases = []

    # Create intermediate results directory
    intermediate_dir = os.path.join(results_path, 'intermediate')
    os.makedirs(intermediate_dir, exist_ok=True)

    # Load existing results if any
    chunk_file = os.path.join(intermediate_dir, f'chunk_{chunk_id}.json')
    processed_indices = set()
    error_indices = set()
    if os.path.exists(chunk_file):
        with open(chunk_file, 'r', encoding='utf-8') as f:
            existing_data = json.load(f)
        all_results = existing_data.get('annotations', [])
        error_cases = existing_data.get('error_cases', [])
        # Indices with a stored result (successful items only)
        processed_indices = {result['custom_id'] for result in all_results}
        # Indices that failed on a previous run
        error_indices = {error['file_idx'] for error in error_cases}

    # If we have error cases, only process those
    if error_cases:
        print(f"Reprocessing {len(error_cases)} error cases in chunk {chunk_id}")
        to_process = {idx: chunk_data[idx] for idx in error_indices if idx in chunk_data}
    else:
        to_process = {idx: data for idx, data in chunk_data.items() if idx not in processed_indices}

    served_model_name = _vllm_served_model_name(args.deployment_name)

    for file_idx in tqdm(to_process, total=len(to_process), desc=f"Processing chunk {chunk_id}"):
        try:
            query = chunk_data[file_idx]['passage']
            query_subject_id = chunk_data[file_idx]['subject_id']
            # Check the type before the length: len() on a non-string such as
            # a float NaN raises TypeError and would record an unusable
            # passage as an error case instead of skipping it.
            if not isinstance(query, str) or len(query) < 2:
                continue

            json_vocab_input, user_history, assistant_history = retreive_query_related_fewshot(devset, query, query_subject_id, args)

            # NOTE: a block that subsampled `query_words` when
            # args.candidate_usage != 1 used to sit here. The name was never
            # assigned before being read, so every item raised
            # UnboundLocalError, and its result was never used; it is removed.

            conversation = [{"role": "system", "content": system_message}]
            for user_turn, assistant_turn in zip(user_history, assistant_history):
                conversation.append({"role": "user", "content": user_turn})
                conversation.append({"role": "assistant", "content": assistant_turn})

            # GT candidates / no candidates use the stored json_input;
            # otherwise use the vocab input returned by retrieval.
            if ('gt' in args.candidate_type) or (args.candidate_type == 'no_candidates'):
                payload = chunk_data[file_idx]['json_input']
            else:
                payload = json_vocab_input
            user_string = f"INPUT:\n{json.dumps(payload, indent=4)}\n"

            conversation.append({"role": "user", "content": user_string})

            try:
                if args.deployment_name.startswith('gemini'):
                    response = client.messages.create(
                        messages=conversation,
                        generation_config={
                            "temperature": 0.0
                        },
                        response_model=StructuredOutput
                    )
                elif args.deployment_name.startswith('MedGemma'):
                    # Local generation: tokenize with the chat template and
                    # decode only the newly generated tokens.
                    inputs = tokenizer.apply_chat_template(
                        conversation,
                        add_generation_prompt=True,
                        tokenize=True,
                        return_dict=True,
                        return_tensors="pt"
                    )
                    with torch.inference_mode():
                        output = client.generate(
                            **inputs.to('cuda'),
                            max_new_tokens=8092
                        )
                        generation = output[0][inputs["input_ids"].shape[1]:]
                        response = tokenizer.decode(generation, skip_special_tokens=True)
                elif served_model_name is not None:
                    response = client.chat.completions.create(
                        model=served_model_name,
                        messages=conversation,
                        response_model=StructuredOutput
                    )
                else:
                    response = client.chat.completions.create(
                        model=f'accounts/fireworks/models/{args.deployment_name}',
                        max_tokens=8092,
                        temperature=0.0,
                        messages=conversation,
                        response_model=StructuredOutput)

                triplets = convert_to_sr_structure(response)

                passage, user_history, assistant_history = extract_passage_history(conversation)

                all_results.append({
                    'custom_id': file_idx,
                    'passage': passage,
                    'user_history': user_history,
                    'assistant_history': assistant_history,
                    'model_output': triplets
                })

                # If this was an error case, remove it from error_cases
                if file_idx in error_indices:
                    error_cases = [error for error in error_cases if error['file_idx'] != file_idx]
                    print(f"Successfully reprocessed error case {file_idx}")

            except Exception as e:
                # Model call or conversion failed: keep the conversation so
                # the case can be inspected and retried later.
                _remember_error(error_cases, {
                    'file_idx': file_idx,
                    'error_type': type(e).__name__,
                    'error_message': str(e),
                    'conversation': conversation
                })
                print(f"Error processing file {file_idx}: {str(e)}")

            # Save intermediate results after each file
            _save_chunk_snapshot(chunk_file, chunk_id, args, all_results, error_cases)

        except Exception as e:
            # Failure before the model call (input lookup, retrieval, prompt
            # building) or while saving the snapshot.
            _remember_error(error_cases, {
                'file_idx': file_idx,
                'error_type': type(e).__name__,
                'error_message': str(e)
            })
            print(f"Error in initial processing of file {file_idx}: {str(e)}")

            # Save intermediate results even when there's an error
            _save_chunk_snapshot(chunk_file, chunk_id, args, all_results, error_cases)

    return chunk_file

def _chunk_order(file_name):
    """Sort key putting ``chunk_<N>`` files in numeric order of ``N``.

    Names whose suffix is not a plain integer sort after the numeric ones,
    in lexicographic order.
    """
    suffix = file_name[len('chunk_'):].split('.', 1)[0]
    if suffix.isdecimal():
        return (0, int(suffix), file_name)
    return (1, 0, file_name)


def merge_results(intermediate_dir, final_output_file):
    """Merge all intermediate chunk files into a single results file.

    Annotations and error cases from every ``chunk_*`` file in
    ``intermediate_dir`` are concatenated in numeric chunk order. The
    experiment metadata is copied from the last chunk read.

    Args:
        intermediate_dir: Directory holding the ``chunk_*`` JSON files.
        final_output_file: Path the merged JSON is written to.

    Raises:
        FileNotFoundError: If ``intermediate_dir`` holds no chunk files.
    """
    # Numeric ordering keeps chunk_2 ahead of chunk_10, which a plain string
    # sort would reverse.
    chunk_files = sorted(
        (f for f in os.listdir(intermediate_dir) if f.startswith('chunk_')),
        key=_chunk_order,
    )
    if not chunk_files:
        # Fail with a clear message instead of an unbound-variable error below.
        raise FileNotFoundError(f"No chunk files found in {intermediate_dir}")

    all_results = []
    all_error_cases = []
    chunk_data = None
    for chunk_file in chunk_files:
        with open(os.path.join(intermediate_dir, chunk_file), 'r', encoding='utf-8') as f:
            chunk_data = json.load(f)
        all_results.extend(chunk_data['annotations'])
        all_error_cases.extend(chunk_data['error_cases'])

    # Create final merged results; metadata comes from the last chunk read.
    merged_results = {
        "exp_name": chunk_data['exp_name'],
        "exp_date": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "deployment_name": chunk_data['deployment_name'],
        "entity_types": chunk_data['entity_types'],
        "relation_types": chunk_data['relation_types'],
        "attribute_types": chunk_data['attribute_types'],
        "output_format": chunk_data['output_format'],
        "input_unit": chunk_data['input_unit'],
        "retreive_unit": chunk_data['retreive_unit'],
        "n_retrieval": chunk_data['n_retrieval'],
        "diverse_retrieval": chunk_data['diverse_retrieval'],
        "annotations": all_results,
        "error_cases": all_error_cases
    }

    # Save merged results
    with open(final_output_file, "w", encoding="utf-8") as outfile:
        json.dump(merged_results, outfile, indent=4, ensure_ascii=False)

def _run_batch_inference(args, client, results_path):
    """Produce the LLM result file through the client's batch endpoint.

    Creates the batch request file when none exists, submits it, waits for
    the job to finish and stores the raw output as ``batch_results.jsonl``
    under ``results_path``. When that file already exists the submission is
    skipped. In both cases the stored output is then parsed with
    ``read_batch_results``.

    Args:
        args: Parsed command-line namespace.
        client: Client exposing ``batches.retrieve`` and ``files.content``.
        results_path: Directory that holds ``batch_results.jsonl``.

    Returns:
        The second value returned by ``read_batch_results``, which the caller
        uses as the path of the LLM result file.

    Raises:
        ValueError: If the gold_eval batch file does not have 3601 lines.
        FileNotFoundError: If no ``.jsonl`` batch file is available.
        RuntimeError: If the batch job ends in a non-success status.
    """
    if not args.dynamic_retrieval:
        batch_path = f'./singleSR/data/batch_files/{args.n_retrieval}_{args.candidate_type}_{args.deployment_name}/{args.mode}/{args.unit}/{args.candidate_usage}'
    else:
        batch_path = f'./singleSR/data/batch_files/dynamic_{args.candidate_type}_{args.deployment_name}/{args.mode}/{args.unit}/{args.candidate_usage}'

    # 1. Generate batch file
    if not os.path.exists(batch_path) or not any(f.endswith('.jsonl') for f in os.listdir(batch_path)):
        create_batch_file(args)

    results_file = f'{results_path}/batch_results.jsonl'
    if os.path.exists(results_file):
        print("LLM result file exists. Reading LLM result...")
    else:
        print("\n LLM result file does not exist. Running LLM Batch Processing!!")

        # Sanity-check the request count. Only gold_eval is validated; an
        # explicit raise is used so the check survives `python -O`.
        if args.mode == 'gold_eval':
            with open(f'{batch_path}/batch_file.jsonl', 'r') as file:
                total_lines = sum(1 for _ in file)
            if total_lines != 3601:
                raise ValueError(f"Total lines should be 3601, found {total_lines}")

        # 2. Get batch file path. Sorting and filtering on the extension keeps
        # the choice deterministic and skips any non-batch file in the folder.
        batch_file_list = sorted(f for f in os.listdir(batch_path) if f.endswith('.jsonl'))
        if not batch_file_list:
            raise FileNotFoundError(f"No .jsonl batch file found in {batch_path}")
        batch_file_path = os.path.join(batch_path, batch_file_list[0])

        # 3. Run batch file
        batch_job = run_llm_batch(batch_file_path, client, args)
        print(f"batch_job: {batch_job}")
        batch_job_id = batch_job.id

        batch_status = monitor_batch_job(batch_job_id, client, args, check_interval=2)

        # Final status check
        batch = client.batches.retrieve(batch_job_id)
        completed_count = batch.request_counts.completed
        total_count = batch.request_counts.total

        print("\n\nFinal batch status:")
        print(f"Status: {batch_status}")
        print(f"Completed: {completed_count}/{total_count}")

        if batch_status not in ["completed", "succeeded", "ended"]:
            # No output was produced, so there is nothing to parse below.
            raise RuntimeError(
                f"Batch job {batch_job_id} finished with status '{batch_status}'; "
                "no results were written."
            )

        print("\nBatch job completed!")
        result_content = client.files.content(batch.output_file_id).text

        os.makedirs(results_path, exist_ok=True)
        with open(results_file, 'w') as f:
            f.write(result_content)
        print("-----")

    _, LLM_SR_file = read_batch_results(batch_path, results_path, args)
    return LLM_SR_file


def _run_chunked_inference(args, client, tokenizer, results_path):
    """Produce the LLM result file by processing the input in chunks.

    The input is split into fixed-size chunks, each chunk is handled by
    ``process_chunk`` and the per-chunk outputs are combined by
    ``merge_results``. Nothing is run when ``args.run_model`` is false or the
    final output file already exists.

    Args:
        args: Parsed command-line namespace.
        client: LLM client, forwarded to ``process_chunk`` on the sequential path.
        tokenizer: Tokenizer, forwarded to ``process_chunk`` on the sequential path.
        results_path: Directory for the intermediate and final output files.

    Returns:
        Path of the final output JSON file.
    """
    final_output_file = os.path.join(results_path, f'{args.n_retrieval}_{args.candidate_type}_{args.deployment_name}_{args.output_format}_{args.mode}_{args.unit}.json')

    if not args.run_model or os.path.exists(final_output_file):
        return final_output_file

    input_data, devset = create_input(args)

    # Split data into chunks
    if args.mode in ['medversa', 'rgrg', 'cvt2distilgpt2']:
        chunk_size = 5
    else:
        chunk_size = 225  # Adjust this based on your needs

    data_items = list(input_data.items())
    chunks = [dict(data_items[i:i + chunk_size]) for i in range(0, len(data_items), chunk_size)]

    # Use sequential processing for MedGemma models to avoid CUDA context issues
    if args.deployment_name.startswith('MedGemma'):
        print("Using sequential chunk processing for MedGemma model")
        for i, chunk in enumerate(tqdm(chunks, desc="Processing chunks")):
            process_chunk(chunk, devset, results_path, args, i, client, tokenizer)
    else:
        # Process chunks in parallel. Workers receive no client/tokenizer.
        with Pool(processes=min(cpu_count(), 64)) as pool:
            chunk_args = [(chunk, devset, results_path, args, i, None, None) for i, chunk in enumerate(chunks)]
            # starmap blocks until every chunk is done, so the bar is only
            # drawn once all results are available.
            for _ in tqdm(pool.starmap(process_chunk, chunk_args),
                          total=len(chunks),
                          desc="Processing chunks"):
                pass

    # Merge results. The intermediate chunk files are left on disk.
    merge_results(os.path.join(results_path, 'intermediate'), final_output_file)
    return final_output_file


def _evaluate_and_report(args, gold_file_path, LLM_SR_file):
    """Evaluate the LLM result file and regenerate tables and figures.

    For ``silver_eval`` only the relation dataframe is created. For the other
    supported modes the results are scored, FP/FN files and relation
    dataframes are written, and the model-comparison table and plots under
    ``./singleSR/model_comparison`` are rebuilt. Any other mode is a no-op.

    Args:
        args: Parsed command-line namespace.
        gold_file_path: Path of the gold input JSON file.
        LLM_SR_file: Path of the LLM result file to evaluate.
    """
    if args.mode == 'silver_eval':
        create_relation_dataframe3(gold_file_path, LLM_SR_file, f'./singleSR/eval/{args.mode}/{args.n_retrieval}_{args.candidate_type}_{args.deployment_name}/{args.output_format}/{args.unit}/{args.candidate_usage}')
        return

    if args.mode not in ['test_80studies', 'test_2studies', 'gold_eval', 'rexval', 'maira', 'maira_cascade', 'rexerr', 'medversa', 'rgrg', 'cvt2distilgpt2']:
        return

    eval_path = evaluate_funct(gold_file_path, LLM_SR_file, args)

    create_fp(eval_path, gold_file_path)
    create_fn(eval_path, gold_file_path)

    is_rexval = args.mode == 'rexval'
    if is_rexval:
        create_relation_dataframe(gold_file_path, LLM_SR_file, eval_path)
    else:
        create_relation_dataframe2(gold_file_path, LLM_SR_file, eval_path)

    generate_graph(eval_path, args)

    comparison_csv = f'./singleSR/model_comparison/{args.mode}/all_models_comparison.csv'
    generate_table(base_path=f'./singleSR/eval/{args.mode}', args=args,
                   output_path=comparison_csv,
                   mode=args.mode)

    if is_rexval:
        visualize_rexval_metrics(
            comparison_csv,
            f'./singleSR/model_comparison/figures/{args.mode}/model_comparison.png',
            metrics=['SR F1', 'SRO F1'])
        return

    visualize_table(comparison_csv,
                    output_path=f'./singleSR/model_comparison/figures/{args.mode}/table.png')

    figures_dir = f'./singleSR/model_comparison/figures/{args.mode}'
    visualize_metrics(comparison_csv,
                      output_path=figures_dir,
                      metrics=['SR F1', 'SRO F1'])
    visualize_metrics(comparison_csv,
                      output_path=figures_dir,
                      metrics=['SRO P', 'SRO R'])


def main(args, client, tokenizer):
    """Run inference for the configured experiment, then evaluate it.

    Non-multi-turn runs whose deployment name starts with ``gpt`` go through
    the batch endpoint; every other configuration is processed in chunks.
    The resulting file is then evaluated against the gold input file.

    Args:
        args: Parsed command-line namespace.
        client: LLM client returned by ``initialize_llm_client``.
        tokenizer: Tokenizer returned by ``initialize_llm_client``.
    """
    # The results directory encodes the retrieval setting: "dynamic", or the
    # retrieval count (prefixed with "M" for multi-turn runs).
    if args.dynamic_retrieval:
        results_path = f'./singleSR/exp/dynamic_{args.candidate_type}_{args.deployment_name}/{args.mode}/{args.unit}/{args.candidate_usage}'
    elif args.multi:
        results_path = f'./singleSR/exp/M{args.n_retrieval}_{args.candidate_type}_{args.deployment_name}/{args.mode}/{args.unit}/{args.candidate_usage}'
    else:
        results_path = f'./singleSR/exp/{args.n_retrieval}_{args.candidate_type}_{args.deployment_name}/{args.mode}/{args.unit}/{args.candidate_usage}'

    if not args.multi and args.deployment_name.startswith('gpt'):
        LLM_SR_file = _run_batch_inference(args, client, results_path)
    else:
        LLM_SR_file = _run_chunked_inference(args, client, tokenizer, results_path)

    gold_file_path = f'{args.output_dir}/{args.mode}_{args.candidate_type}_{args.unit}_input.json'

    _evaluate_and_report(args, gold_file_path, LLM_SR_file)
    
        
if __name__ == "__main__":
    # Command-line entry point: parse the experiment configuration, build the
    # LLM client and hand both to main().
    parser = argparse.ArgumentParser(description='Create input data for GPT-SR')
    parser.add_argument('--gold_path', type=str, default='./dataset/Lunguage.csv',
                        help='Path to gold dataset')
    
    ########### SILVER EVAL ###########
    parser.add_argument('--silver_eval_report_path', type=str, default='./dataset/processed_reports/mimic_cxr_reports.csv',
                        help='Raw Report Path')
    #######################################
    
    ########### REXVAL  ###########
    parser.add_argument('--rexval_report_path', type=str, default='./benchmark/rexval/50_samples_gt_and_candidates.csv',
                        help='Raw Report Path')
    # nargs='+' collects one or more column names; type=list would have split
    # a single command-line value into individual characters.
    parser.add_argument('--report_col_name', nargs='+', default=['gt_report', 'radgraph', 'bertscore', 's_emb', 'bleu'],
                        help='Report column name')
    #######################################
    
    ########### MAIRA 2 ###########
    parser.add_argument('--maira2_cascade_report_path', type=str, default='./benchmark/maira_cascade.ndjson',
                        help='Path to maira2 cascade report')
    parser.add_argument('--maira2_report_path', type=str, default='./benchmark/maira.ndjson',
                        help='Path to maira2 report')
    #######################################


    ########### MEDVERSA ###########
    parser.add_argument('--medversa_report_path', type=str, default='./benchmark/medversa_generated_reports.csv',
                        help='Path to medversa report')
    #######################################
    
    ########### RGRG ###########
    parser.add_argument('--rgrg_report_path', type=str, default='./benchmark/rgrg_generated_reports.csv',
                        help='Path to rgrg report')
    #######################################
    
    ########### Cvt2distilgpt2 ###########
    parser.add_argument('--cvt2distilgpt2_report_path', type=str, default='./benchmark/cvt2distilgpt2_generated_reports.csv',
                        help='Path to cvt2distilgpt2 report')
    #######################################

    ########### REXERR ###########
    parser.add_argument('--rexerr_report_path', type=str, default='./benchmark/rexerr_gold.csv',
                        help='Path to rexerr report')
    #######################################

    parser.add_argument('--lingshu_path', type=str, default='./benchmark/lingshu_results.jsonl',
                        help='Path to lingshu results')
    parser.add_argument('--medgemma_path', type=str, default='./benchmark/medgemma_results.jsonl',
                        help='Path to medgemma results')

    # Flags that default to True use BooleanOptionalAction so they can be
    # switched off with the matching --no-<name> option; plain store_true with
    # default=True could never be disabled from the command line.
    parser.add_argument('--run_model', action=argparse.BooleanOptionalAction, default=True,
                        help='Run model')
    parser.add_argument('--toy_set', action=argparse.BooleanOptionalAction, default=True,
                        help='Use toy set')
    parser.add_argument('--output_dir', type=str, default='./singleSR/data',
                        help='Directory to save output files')
    parser.add_argument('--deployment_name', type=str, default='vllm_gpt-oss-20b',
                        choices=['gpt-4.1', 
                            'vllm_gpt-oss-20b',
                            'vllm_gpt-oss-120b',
                            'baichuan',
                            'vllm_medgemma',
                            'MedGemma-27b-it',
                            'qwen3-235b-a22b', 
                            'deepseek-v3-0324',
                            'llama4-maverick-instruct-basic'])
    
    parser.add_argument('--mode', type=str, default='medversa',
                        choices=['test_80studies', 'test_2studies', 'rexval', 'rexerr', 'gold_eval', 'maira_cascade', 'maira', 'medversa', 'rgrg', 'cvt2distilgpt2', 'silver_eval'],
                        help='Mode for data processing')
    
    parser.add_argument('--port', type=str, default='8000', help='Port for local LLM server')
    
    parser.add_argument('--candidate_type', type=str, default='vocab_ent_rcg',
                        choices=['gt_sro_review', 'gt_sro', 'gt_so', 'gt_s', 'gt_ent_rcg', 'vocab_so', 'vocab_s', 'vocab_ent_rcg', 'no_candidates'],
                        help='Use gold standard')
    parser.add_argument('--unit', type=str, default='section',
                        choices=['report', 'section', 'sent'],
                        help='Unit for SR extraction')
    parser.add_argument('--n_retrieval', type=int, default=5,
                        help='Number of retrievals')
    parser.add_argument('--multi', action='store_true', default=False,
                        help='Use multi-turn')
    parser.add_argument('--context_width', type=int, default=4,
                        help='Context width')
    
    
    ########### Batch split ###########
    parser.add_argument('--except_80std', action='store_true', default=False,
                        help='Use except 80 std')
    
    parser.add_argument('--batch_itr', type=int, default=None,
                        help='Batch iteration')
    
    parser.add_argument('--batch_file_size', type=int, default=2,
                        help='Batch split size')
    #######################################
    
    
    parser.add_argument('--entity_types', nargs='+', default=['COF', 'NCD', 'PATIENT INFO.', 'PF', 'CF', 'OTH'],
                        help='Entity types')
    
    parser.add_argument('--relation_types', nargs='+', default=['Location', 'Associate', 'Evidence'],
                        help='Relation types')
    
    parser.add_argument('--attribute_types', nargs='+', 
                        default=['Morphology', 'Distribution', 'Measurement', 'Severity',
                                'Comparison', 'Onset', 'No Change', 'Improved', 'Worsened',
                                'Placement', 'Past Hx', 'Other Source', 'Assessment Limitations'],
                        help='Attribute types')
    
    
    parser.add_argument('--output_format', type=str, default='SROSRO',
                        help='Output format')
    
    parser.add_argument('--diverse_retrieval', action=argparse.BooleanOptionalAction, default=True,
                        help='Use diverse retrieval')
    parser.add_argument('--fast_retrieval', action=argparse.BooleanOptionalAction, default=True,
                        help='Use fast retrieval')
    parser.add_argument('--dynamic_retrieval', action='store_true', default=False,
                        help='Use dynamic retrieval')
    # NOTE: the default is the int 1 while values given on the command line
    # are parsed as float, so the two render differently ("1" vs "1.0") in
    # the result paths that main() builds from this value.
    parser.add_argument('--candidate_usage', type=float, default=1,
                        help='Candidate usage ratio')
    parser.add_argument('--candidate_discontinuous', action='store_true', default=False,
                        help='Use discontinuous candidate')
    parser.add_argument('--jaccard', action=argparse.BooleanOptionalAction, default=True,
                        help='Use Jaccard similarity')
    parser.add_argument('--vocab_path', type=str, default='./dataset/Lunguage_vocab.csv',
                        help='Path to vocab file')
    
    parser.add_argument('--holdout_devset', action='store_true', default=False,
                        help='Use holdout devset')
    parser.add_argument('--api_key', type=str, default='',
                        help='API key for the selected model')
    
    args = parser.parse_args()
    
    # args.dev_list is always set: the fixed list of study ids when
    # --holdout_devset is given, otherwise an empty list.
    if args.holdout_devset:
        args.dev_list = ['s51749906', 's50853840', 's51264956', 's52145612', 's54224807', 's59798967', 's52901628', 's51907814', 's51357526', 's50035498', 's50336741', 's50546279', 's58786693', 's52707748', 's53423060', 's59608718', 's57678258', 's55874928', 's52939782', 's54523680', 's59956784', 's50770541', 's56010471', 's50476602', 's55853389', 's57849643', 's52998783', 's58068113', 's58566283', 's50022945', 's56193921', 's53086061', 's50273882', 's55615214', 's55683961', 's50710771', 's59081164', 's57917788', 's56321140', 's50620677', 's57674897', 's57988469', 's53138800', 's51765454', 's53759718', 's50100756', 's58352175', 's57525852', 's51727838', 's51144460', 's54151331', 's55755138', 's56130174', 's50968695', 's58961408', 's55883502', 's58093109', 's54093116', 's58517699', 's56581797', 's57778607', 's57523636', 's53462360', 's57529728', 's54589789', 's58286219', 's58478940', 's50281752', 's50399800', 's52381727', 's50971332', 's59716296', 's50243114', 's51285349', 's54704786', 's57554917', 's53051689', 's59221699', 's59044985', 's56470564', 's53957798', 's57746739', 's59523573', 's53663749', 's54224166', 's55553875', 's57399078', 's59842808', 's52177069', 's51464763', 's54232769', 's57782283', 's55599778', 's51678067', 's50924449', 's56196471', 's58349137', 's52428827', 's56081926', 's58137643', 's51044625', 's51231499', 's55604705', 's54432661', 's53572321', 's58898395', 's57120453', 's52072042', 's57420525', 's50255843', 's55233589', 's51271572', 's50989704', 's59121133', 's53641457', 's51791247', 's57779343', 's58393560', 's50178679', 's59986698', 's55698800', 's58085167', 's51683155', 's54381763', 's55594849', 's57664750', 's52195893', 's53897449', 's51742525', 's59775769', 's55414814', 's52775752', 's51288835', 's52189004', 's58369249', 's51900597', 's54335521', 's54917064', 's54133231', 's51199892', 's59249979', 's58352022', 's52835225', 's53177649', 's54066754', 's54899257', 's53418217', 's59357257', 's50421764', 's51223853', 's56018459', 
            's57801123', 's59221051', 's50829485', 's55502536', 's57983519', 's53570653', 's51325572', 's58760787', 's59715122', 's58641137', 's54097861', 's58656783', 's51423353', 's54962274', 's58072789', 's53460154', 's50301215', 's52555178', 's57211901', 's56034024', 's54849848', 's55957472', 's54833205', 's54100996', 's53854854', 's59697640', 's57232140', 's52008677', 's55124994', 's56779415', 's55939586', 's56843282', 's56508966', 's59197220', 's52682048', 's54259878', 's51140249', 's58304701', 's58907220', 's53631792', 's55751115', 's53225676', 's53619001', 's56790426', 's58778783', 's56771404', 's59060938', 's53130454', 's51363438', 's56440919', 's59037095', 's55562335', 's50701107', 's54590636', 's54772630', 's59138498', 's58060878', 's55803590', 's55714183', 's57952807', 's58625748', 's57192814', 's51466579', 's58979101', 's50247294', 's54729238', 's55617591', 's57410883', 's58836797', 's56168637', 's58195876', 's50844004', 's50555779', 's51765753', 's52616494', 's54692227', 's50714348', 's54058678', 's51719198', 's50014127', 's59203230', 's50822353', 's55728799', 's50547182', 's50256977', 's53924742', 's56679657', 's51526655', 's55902256', 's57827533', 's57395479', 's52939447', 's54949810', 's52995335', 's50184397', 's53482917', 's57161577', 's50308220', 's51972257', 's58666319', 's54060800', 's51099690', 's59900684', 's52064406', 's55463368', 's51405069', 's57629666', 's52350132', 's58087032', 's51233560', 's54766893', 's59756815', 's57975962', 's56237499', 's52697084', 's56042355', 's57251948', 's50323961', 's55775814', 's59044011', 's59607772', 's52600197', 's53982700', 's55499601', 's55212349', 's57356552', 's53414987', 's55082399', 's59284918', 's59480672', 's57233393', 's58232231', 's58001075', 's56427859', 's53583954', 's58056585', 's51465438', 's57632806', 's50431066', 's58327706', 's58897728', 's56093476', 's58215117', 's50270173', 's50830952', 's54331436', 's54073075', 's51189125', 's54712047']
    else:
        args.dev_list = []
    
    # A key/token is required for every deployment; fail before any client
    # is created.
    if not args.api_key:
        raise ValueError(f"API key or Huggingface token is required for {args.deployment_name} models")
    
    # Expose the chosen port through the PORT environment variable.
    if args.port:
        os.environ["PORT"] = args.port
    
    client, tokenizer = initialize_llm_client(args.deployment_name, args.api_key)
    
    main(args, client, tokenizer)