import os
import sys
import random
import warnings
import logging
import asyncio
import hydra
import numpy as np

from tqdm import tqdm
from omegaconf import DictConfig
from dotenv import load_dotenv
from transformers import AutoTokenizer, AutoConfig

from src.models.vllm_model import LocalVLLM
from src.models.rule_model import RuleModel
from src.models.litellm_model import LiteLLM, AsyncLiteLLM
from src.models.openai_model import AsyncOpenAILLM
from src.utils.utils import (
    load_json, 
    dump_json_output, 
    parse_evaluation_results,
    convert_question_to_instruction,
    check_all_checklist_completion
)
from src.utils.constant import (
    EVAL_OUTPUT_PATH,
    REASONING_MODEL_NAMES
)
from src.utils.prompts import (
    INITIAL_QUESTION_PROMPT_TEMPLATE,
    INITIAL_PASSAGES_QUESTION_PROMPT_TEMPLATE,
    CHECKLIST_EVALUATION_TEMPLATE,
    NEXT_TURN_EVALUATION_TEMPLATE,
    NEXT_TURN_EVALUATION_DEFAULT_TEMPLATE,
    NEXT_TURN_PARTIAL_FEEDBACK_EVALUATION_TEMPLATE,
    NEXT_TURN_EVALUATION_DEFAULT_TEMPLATE_W_CHECKLIST
)

# Configure the root logger at import time so INFO-level messages are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger shared by the functions in this script.
logger = logging.getLogger(__name__)

def load_refinebench(path: str) -> list:
    """Read the benchmark instances stored in the JSON file at ``path``.

    Thin wrapper around ``load_json``: whatever that helper returns is handed
    back to the caller unchanged.
    """
    instances = load_json(path)
    return instances


def load_models(model_path: str,
                max_tokens: int = None,
                temperature: float = None,
                top_p: float = None,
                batch_size: int = None,
                requests_per_minute: int = None,
                reasoning_effort: str = None):
    """Instantiate the backend for ``model_path`` and its completion parameters.

    Backend selection:
      * ``'user'``: a ``RuleModel`` with no parameters and no tokenizer.
      * a path containing ``'openrouter'``: ``AsyncLiteLLM``.
      * one of a fixed list of model identifiers: ``AsyncOpenAILLM``.
      * anything else: ``LocalVLLM`` together with its ``AutoTokenizer``.

    The parameter dictionary is then chosen from the capabilities the backend
    object advertises (``validate_vllm`` / ``validate_openai`` attributes) and
    from the model identifier.

    Returns:
        A ``(model, params, tokenizer)`` tuple. ``params`` is a keyword
        dictionary (``None`` for ``'user'``); ``tokenizer`` is ``None`` unless
        a local vLLM model was loaded.
    """
    if model_path == 'user':
        return RuleModel(), None, None

    openai_style_models = (
        "openai/o1-mini",
        "openai/o1",
        "openai/o3",
        "openai/o3-mini",
        "xai/grok-3-mini-beta",
        "deepseek/deepseek-r1",
        "qwen/qwq-32b",
        "deepseek/deepseek-r1-distill-qwen-14b",
        "deepseek/deepseek-r1-distill-qwen-32b",
        "qwen/qwen3-32b",
    )
    s1_checkpoints = (
        'simplescaling/s1.1-1.5B',
        'simplescaling/s1.1-3B',
        'simplescaling/s1.1-7B',
    )

    tokenizer = None
    if 'openrouter' in model_path:
        target_model = AsyncLiteLLM(
            model_path=model_path,
            batch_size=batch_size,
            requests_per_minute=requests_per_minute,
        )
    elif model_path in openai_style_models:
        target_model = AsyncOpenAILLM(
            model_path=model_path,
            batch_size=batch_size,
            requests_per_minute=requests_per_minute,
        )
    else:
        # These two checkpoints are loaded with tensor_parallel_size=2; every
        # other local model uses 1.
        tp2_models = ('Qwen/Qwen2.5-14B-Instruct', 'Qwen/Qwen2.5-32B-Instruct')
        vllm_kwargs = {"tensor_parallel_size": 2 if model_path in tp2_models else 1}

        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

        # The quantization scheme is read off the model path suffix.
        for scheme in ("AWQ", "GPTQ"):
            if model_path.endswith(scheme):
                vllm_kwargs["quantization"] = scheme
                break
        target_model = LocalVLLM(model_path, **vllm_kwargs)

    if hasattr(target_model, 'validate_vllm'):
        if model_path in s1_checkpoints:
            # s1.1 checkpoints use a fixed token budget and stop on <|im_end|>.
            params = {
                'max_tokens': 32768,
                'min_tokens': 0,
                'stop_token_ids': tokenizer("<|im_end|>")["input_ids"],
                'skip_special_tokens': False,
                "temperature": temperature,
            }
        elif "deepseek-r1-distill-qwen" in model_path.lower():
            params = {
                "max_tokens": max_tokens,
                "temperature": temperature,
                "top_p": top_p,
            }
        else:
            params = {
                "max_tokens": max_tokens,
                "temperature": temperature,
                "top_p": top_p,
                "repetition_penalty": 1.03,
                "best_of": 1,
            }
    elif hasattr(target_model, 'validate_openai'):
        if model_path in ('openai/o1-mini',):
            params = {"max_completion_tokens": max_tokens}
        elif model_path in ('openai/o1', 'openai/o3', 'openai/o3-mini',
                            'openai/o4-mini', 'openai/o4'):
            params = {
                "max_completion_tokens": max_tokens,
                "reasoning_effort": reasoning_effort,
            }
        else:
            # deepseek-r1 and qwq-32b take "max_tokens"; the remaining models
            # take "max_completion_tokens". The other keys are identical.
            if 'deepseek-r1' in model_path or 'qwq-32b' in model_path:
                token_key = "max_tokens"
            else:
                token_key = "max_completion_tokens"
            params = {
                token_key: max_tokens,
                "temperature": temperature,
                "top_p": top_p,
                "reasoning_effort": reasoning_effort,
            }
    else:
        params = {
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }

    return target_model, params, tokenizer


def get_next_turn_user_prompt(cfg: "DictConfig", instance: dict) -> str:
    """Return the user message that opens the next refinement turn.

    The wording depends on ``cfg.experimental_setup.checklist_type``:

    * ``'default'``: the fixed default template, with no checklist content.
    * ``'default_w_criteria'``: the template filled with the whole checklist.
    * ``'full_feedback'``: the template filled with every checklist item whose
      decision in the previous turn's evaluation is ``'no'`` (case and
      surrounding whitespace ignored). When that evaluation is ``None`` every
      item is treated as unmet.
    * ``'partial_feedback'``: the template filled with the first
      ``int(len(checklist) * unknown_ratio)`` checklist items.

    Args:
        cfg: Run configuration; only ``cfg.experimental_setup`` is read.
        instance: Benchmark record holding ``'checklist'`` and, for
            ``'full_feedback'``, the ``'<turn-1>_evaluation'`` mapping.

    Returns:
        The formatted prompt string.

    Raises:
        ValueError: if ``checklist_type`` is not one of the values above.
        IndexError: if an evaluation key does not refer to a checklist item.
    """
    setup = cfg.experimental_setup
    checklist_type = setup.checklist_type

    def as_bullets(items) -> str:
        # One "- <instruction>" line per checklist question.
        return '\n'.join(
            '- {}'.format(convert_question_to_instruction(item)) for item in items
        )

    if checklist_type == 'default':
        return NEXT_TURN_EVALUATION_DEFAULT_TEMPLATE

    if checklist_type == 'default_w_criteria':
        # NOTE(review): the join puts "- " only between items, so the first
        # item has no bullet here; presumably the template supplies it. Confirm.
        return NEXT_TURN_EVALUATION_DEFAULT_TEMPLATE_W_CHECKLIST.format(
            checklist='\n- '.join(instance['checklist'])
        )

    if checklist_type == 'full_feedback':
        prev_result = instance[f'{setup.target_turn_num-1}_evaluation']
        checklist = instance['checklist']
        if prev_result is None:
            prev_result = {str(i + 1): 'No' for i in range(len(checklist))}

        unmet = []
        for item_no, decision in prev_result.items():
            if decision.lower().strip() != 'no':
                continue
            # Evaluation keys are 1-based checklist numbers. Look the item up
            # by its key so the result does not depend on dict ordering.
            position = int(item_no) - 1
            if not 0 <= position < len(checklist):
                raise IndexError(
                    f"Evaluation key {item_no!r} does not match any of the "
                    f"{len(checklist)} checklist items."
                )
            unmet.append(checklist[position])

        return NEXT_TURN_EVALUATION_TEMPLATE.format(feedbacks=as_bullets(unmet))

    if checklist_type == 'partial_feedback':
        checklist = instance['checklist']
        revealed_count = int(len(checklist) * setup.unknown_ratio)
        return NEXT_TURN_PARTIAL_FEEDBACK_EVALUATION_TEMPLATE.format(
            feedbacks=as_bullets(checklist[:revealed_count])
        )

    raise ValueError(
        f"Unsupported checklist_type {checklist_type!r}; expected one of "
        "'default', 'default_w_criteria', 'full_feedback', 'partial_feedback'."
    )


def _base_messages(model_path):
    """Return the preamble (system/developer message list) used for ``model_path``."""
    if model_path in REASONING_MODEL_NAMES:
        # 'openai/o1-mini' gets no preamble; the other listed reasoning
        # models get a "developer" message.
        if model_path in ['openai/o1-mini']:
            return []
        return [{"role": "developer", "content": "You are a helpful assistant."}]

    if model_path in [
        'simplescaling/s1.1-1.5B',
        'simplescaling/s1.1-3B',
        'simplescaling/s1.1-7B'
    ]:
        return [
            {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."}
        ]

    return [{"role": "system", "content": "You are a helpful assistant."}]


def build_messages_for_turn(cfg,
                            instance,
                            turn_num,
                            tokenizer=None,
                            max_token_num=None):
    """Build the chat messages sent to the agent for turn ``turn_num``.

    Turn 1 concatenates the instance's passages, material contents and
    question (blank-line separated) into a single user message. Later turns
    replay every stored ``'<k>_input'`` / ``'<k>_generation'`` pair as
    user/assistant messages and append the refinement prompt from
    ``get_next_turn_user_prompt``.

    Args:
        cfg: Run configuration; ``cfg.agent.model_path`` and
            ``cfg.experimental_setup`` are read.
        instance: Benchmark record (plus stored turn outputs for later turns).
        turn_num: 1 for the first turn, larger values for refinement turns.
        tokenizer: Used only when ``add_thinking_to_next_turn`` is set, to
            measure the prompt through ``apply_chat_template``.
        max_token_num: Length limit applied in that same case.

    Returns:
        ``(messages, user_prompt)``. When ``add_thinking_to_next_turn`` is set
        and the tokenized conversation is longer than ``max_token_num``, the
        pair ``('Do not process', 'Do not process')`` is returned instead.

    Raises:
        AssertionError: if a previous turn has no stored user input.
    """
    model_path = cfg.agent.model_path
    is_r1_distill = "deepseek-r1-distill-qwen" in model_path.lower()

    passages = instance.get('passages', [])
    question = instance.get('question', "")
    # Tolerate records without a 'materials' key, like the two fields above.
    materials = instance.get('materials') or []

    if turn_num == 1:
        sections = []
        if passages:
            sections.append('\n\n'.join(passages))
        if materials:
            sections.append('\n\n'.join([e['content'] for e in materials]))
        sections.append(question)
        user_prompt = '\n\n'.join(sections)

        messages = _base_messages(model_path)
        # In the first turn the "<think>" suffix is added only when the model
        # is not in REASONING_MODEL_NAMES; that list takes precedence here.
        if is_r1_distill and model_path not in REASONING_MODEL_NAMES:
            first_content = f'{user_prompt} <think>\n'
        else:
            first_content = user_prompt
        messages.append({"role": "user", "content": first_content})
        return messages, user_prompt

    messages = _base_messages(model_path)

    for prev_turn in range(1, turn_num):
        user_input = instance.get(f"{prev_turn}_input", "")
        assistant_gen = instance.get(f"{prev_turn}_generation", "")

        if not user_input:
            # Raised explicitly (not via `assert`) so the check survives -O.
            raise AssertionError(
                f"Instance has no '{prev_turn}_input'; cannot rebuild the "
                f"conversation history for turn {turn_num}."
            )

        if cfg.experimental_setup.add_thinking_to_next_turn:
            # Prepend the stored reasoning trace to the assistant message.
            reasoning_gen = instance[f'{prev_turn}_reasoning_output']
            assistant_gen = f'{reasoning_gen}\n{assistant_gen}'

        messages.append({"role": "user", "content": user_input})
        messages.append({"role": "assistant", "content": assistant_gen})

    user_prompt = get_next_turn_user_prompt(cfg, instance)

    next_content = f'{user_prompt} <think>\n' if is_r1_distill else user_prompt
    messages.append({"role": "user", "content": next_content})

    if cfg.experimental_setup.add_thinking_to_next_turn:
        # With reasoning traces in the history the prompt can outgrow the
        # limit; signal the caller to skip this instance.
        prompt_tokens = tokenizer.apply_chat_template(messages)
        if len(prompt_tokens) > max_token_num:
            return 'Do not process', 'Do not process'

    return messages, user_prompt

def _dump_unparsed_output(dump_path, output):
    """Write a model output that could not be parsed to ``dump_path``."""
    # The output can be a dict or None here, so anything that is not already
    # a string is written as its repr instead of failing inside the fallback.
    text = output if isinstance(output, str) else repr(output)
    with open(dump_path, 'w', encoding='utf-8') as f:
        f.write(text)


def _fields_from_api_output(output, dump_path):
    """Extract generation, token counts and reasoning from a dict-style output.

    If ``output`` is not a mapping with all four keys, it is dumped to
    ``dump_path`` and stored whole as the generation with zeroed counts.
    """
    try:
        return {
            'generation': output['output'],
            'prompt_tokens': output['prompt_tokens'],
            'completion_tokens': output['completion_tokens'],
            'reasoning_output': output['reasoning_output'],
        }
    except (KeyError, TypeError, IndexError):
        _dump_unparsed_output(dump_path, output)
        return {
            'generation': output,
            'prompt_tokens': 0,
            'completion_tokens': 0,
            'reasoning_output': '',
        }


def _fields_from_tagged_text(output, separator, dump_path, strip_marker=None):
    """Split a raw text output into reasoning and final answer at ``separator``.

    The text must contain ``separator`` exactly once; otherwise it is dumped to
    ``dump_path`` and kept whole as the generation with empty reasoning.
    """
    if output is None:
        return {'generation': output, 'reasoning_output': ''}
    try:
        reasoning, final_answer = output.split(separator)
    except ValueError:
        _dump_unparsed_output(dump_path, output)
        return {'generation': output.strip(), 'reasoning_output': ''}
    if strip_marker:
        reasoning = reasoning.replace(strip_marker, '')
    return {
        'generation': final_answer.strip(),
        'reasoning_output': reasoning.strip(),
    }


def _generation_fields(model_path, output):
    """Map one raw model output to the per-turn fields stored on the record.

    Keys are un-prefixed field names; the caller adds the turn number.
    Membership in REASONING_MODEL_NAMES is checked first and takes precedence.
    """
    if model_path in REASONING_MODEL_NAMES:
        return {
            'generation': output['output'],
            'prompt_tokens': output['prompt_tokens'],
            'completion_tokens': output['completion_tokens'],
            'reasoning_tokens': output['reasoning_tokens'],
        }
    if model_path in ['xai/grok-3-mini-beta']:
        return {
            'generation': output['output'],
            'prompt_tokens': output['prompt_tokens'],
            'completion_tokens': output['completion_tokens'],
            'reasoning_tokens': output['reasoning_tokens'],
            'reasoning_output': output['reasoning_output'],
        }
    if model_path in ["qwen/qwq-32b"]:
        return _fields_from_api_output(output, 'qwq32b.txt')
    if model_path in [
        'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B',
        'deepseek-ai/DeepSeek-R1-Distill-Qwen-7B'
    ]:
        return _fields_from_tagged_text(output, '</think>', 'r1-distill.txt')
    if model_path in [
        'deepseek/deepseek-r1',
        'deepseek/deepseek-r1-distill-qwen-14b',
        'deepseek/deepseek-r1-distill-qwen-32b'
    ]:
        return _fields_from_api_output(output, 'deepseek.txt')
    if model_path in [
        'simplescaling/s1.1-1.5B',
        'simplescaling/s1.1-3B',
        'simplescaling/s1.1-7B'
    ]:
        return _fields_from_tagged_text(
            output,
            '<|im_start|>answer',
            's1_output.txt',
            strip_marker='<|im_start|>think',
        )
    return {'generation': output}


async def run_generation(cfg: DictConfig,
                         refinebench: list,
                         output_dir: str) -> list:
    """Generate the agent's answer for the configured turn and save the results.

    Instances whose previous turn already finished are skipped: for
    ``'full_feedback'`` that means ``check_all_checklist_completion`` is true
    for the previous evaluation, otherwise that the previous generation
    contains ``'[TERMINATE]'``. Instances whose prompt exceeds the length
    limit (only checked with ``add_thinking_to_next_turn``) are skipped too.

    Args:
        cfg: Run configuration (``cfg.agent`` and ``cfg.experimental_setup``).
        refinebench: Benchmark records, including stored outputs of earlier turns.
        output_dir: Directory that receives ``generation_results.json``.

    Returns:
        Copies of the processed records with ``'<turn>_input'``,
        ``'<turn>_generation'`` and any model-specific fields added.
    """
    setup = cfg.experimental_setup
    target_model, params, tokenizer = load_models(
        cfg.agent.model_path,
        max_tokens=cfg.agent.max_tokens,
        temperature=cfg.agent.temperature,
        top_p=cfg.agent.top_p,
        requests_per_minute=cfg.agent.requests_per_minute,
        batch_size=cfg.agent.batch_size,
        reasoning_effort=setup.reasoning_effort
    )

    max_token_num = None
    if setup.add_thinking_to_next_turn:
        # Map the agent identifier to the Hugging Face repository whose
        # tokenizer and max_position_embeddings are used for the length check.
        tokenizer_path = {
            'deepseek/deepseek-r1-distill-qwen-14b': 'deepseek-ai/DeepSeek-R1-Distill-Qwen-14B',
            'deepseek/deepseek-r1-distill-qwen-32b': 'deepseek-ai/DeepSeek-R1-Distill-Qwen-32B',
            'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B': 'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B',
            'deepseek-ai/DeepSeek-R1-Distill-Qwen-7B': 'deepseek-ai/DeepSeek-R1-Distill-Qwen-7B',
            'deepseek/deepseek-r1': 'deepseek-ai/DeepSeek-R1'
        }
        tok_name = tokenizer_path[cfg.agent.model_path]
        config = AutoConfig.from_pretrained(tok_name)
        tokenizer = AutoTokenizer.from_pretrained(tok_name)
        max_token_num = config.max_position_embeddings

    turn_num = setup.target_turn_num
    records, inputs = [], []

    for instance in tqdm(refinebench, desc="Generation", total=len(refinebench)):
        if turn_num > 1:
            if setup.checklist_type != 'full_feedback':
                if '[TERMINATE]' in instance[f'{turn_num-1}_generation']:
                    continue
            elif check_all_checklist_completion(instance[f'{turn_num-1}_evaluation']):
                # Every checklist item was already satisfied: terminate.
                continue

        cp_instance = instance.copy()
        messages, user_prompt = build_messages_for_turn(
            cfg, instance, turn_num, tokenizer=tokenizer, max_token_num=max_token_num
        )
        if user_prompt == 'Do not process':
            print('Maximum length overflow!')
            continue

        cp_instance[f'{turn_num}_input'] = user_prompt
        records.append(cp_instance)
        inputs.append(messages)

    if hasattr(target_model, "validate_vllm"):
        logger.info("Using VLLM-based model (synchronous completions).")
        outputs = target_model.completions(inputs, **params)
    else:
        logger.info("Using async model completions.")
        outputs = await target_model.completions(inputs, **params)

    assert len(inputs) == len(outputs), "Input/Output length mismatch."

    generation_results = []
    for record, output in zip(records, outputs):
        cp_record = record.copy()
        for field, value in _generation_fields(cfg.agent.model_path, output).items():
            cp_record[f'{turn_num}_{field}'] = value
        generation_results.append(cp_record)

    gen_save_path = os.path.join(output_dir, 'generation_results.json')
    dump_json_output(generation_results, save_path=gen_save_path)
    logger.info(f"Saved generation results to {gen_save_path}")

    return generation_results


def apply_template_for_evaluation(instance: dict, cfg: DictConfig) -> list:
    """Build the evaluator chat messages for one generated answer.

    The user message is ``CHECKLIST_EVALUATION_TEMPLATE`` filled with the
    original query (passages included when the instance has any), the answer
    stored under ``'<target_turn_num>_generation'`` and the checklist rendered
    as ``- Q<n>: <item>`` lines. A missing or empty checklist is logged and
    rendered as an empty string.

    Returns:
        A two-element list: the evaluator system message and the user message.
    """
    target_turn = cfg.experimental_setup.target_turn_num

    checklist = instance.get('checklist', [])
    if checklist:
        numbered_items = [
            f"- Q{number}: {item}" for number, item in enumerate(checklist, start=1)
        ]
        checklist_str = '\n'.join(numbered_items)
    else:
        logger.warning("No checklist found in instance. Proceeding with empty checklist.")
        checklist_str = ""

    question = instance.get('question', "")
    passages = instance.get('passages', [])
    if not passages:
        query_prompt = INITIAL_QUESTION_PROMPT_TEMPLATE.format(question=question)
    else:
        query_prompt = INITIAL_PASSAGES_QUESTION_PROMPT_TEMPLATE.format(
            passages='\n\n'.join(passages),
            question=question
        )

    evaluation_prompt = CHECKLIST_EVALUATION_TEMPLATE.format(
        query=query_prompt,
        model_answer=instance.get(f'{target_turn}_generation', ""),
        checklist=checklist_str
    )

    return [
        {"role": "system", "content": "You are a helpful assistant and an excellent evaluator."},
        {"role": "user", "content": evaluation_prompt},
    ]


async def run_evaluation(cfg: DictConfig,
                         generation_results: list,
                         output_dir: str):
    """Score each generated answer against its checklist and save the results.

    Evaluator outputs that ``parse_evaluation_results`` cannot parse are
    re-requested up to three times. Any that still fail are recorded with
    every checklist item marked ``'No'``. Results are written to
    ``evaluation_results.json`` inside ``output_dir``. The process exits with
    status 0 when there is nothing to evaluate.
    """
    evaluator_model, params, _ = load_models(
        cfg.evaluator.model_path,
        max_tokens=cfg.evaluator.max_tokens,
        temperature=cfg.evaluator.temperature,
        top_p=cfg.evaluator.top_p,
        requests_per_minute=cfg.evaluator.requests_per_minute,
        batch_size=cfg.evaluator.batch_size,
    )
    turn_num = cfg.experimental_setup.target_turn_num
    # For this evaluator each output is a mapping with 'output' and
    # 'reasoning_output' entries rather than the bare text.
    structured_output = cfg.evaluator.model_path in ['qwen/qwen3-32b']
    uses_vllm = hasattr(evaluator_model, "validate_vllm")

    def parse(raw_output, record):
        text = raw_output['output'] if structured_output else raw_output
        return parse_evaluation_results(text, len(record.get('checklist', [])))

    async def complete(batch):
        if uses_vllm:
            return evaluator_model.completions(batch, **params)
        return await evaluator_model.completions(batch, **params)

    skip_terminated = (
        turn_num > 1 and cfg.experimental_setup.checklist_type == 'full_feedback'
    )
    records, inputs = [], []
    for instance in tqdm(generation_results, desc="Evaluation", total=len(generation_results)):
        if skip_terminated and instance[f'{turn_num-1}_generation'] in ['[TERMINATE]', 'TERMINATE']:
            continue
        records.append(instance)
        inputs.append(apply_template_for_evaluation(instance, cfg))

    if not records and not inputs:
        sys.exit(0)

    if uses_vllm:
        logger.info("Using VLLM-based evaluator (synchronous).")
    else:
        logger.info("Using async evaluator completions.")
    outputs = await complete(inputs)

    assert len(inputs) == len(outputs), "Input/Output length mismatch in evaluation."

    # Retry logic for failed cases: track the indices whose output did not parse.
    pending = [
        position
        for position, (record, output) in enumerate(zip(records, outputs))
        if parse(output, record) is None
    ]

    max_retries = 3
    attempt = 0
    while pending and attempt < max_retries:
        attempt += 1
        logger.warning(
            f"Retrying failed instances: Attempt {attempt}/{max_retries}, "
            f"# of failed instances: {len(pending)}"
        )
        retry_messages = [apply_template_for_evaluation(records[position], cfg) for position in pending]
        retry_outputs = await complete(retry_messages)

        assert len(retry_messages) == len(retry_outputs), "Retry input/output mismatch."

        still_failing = []
        for position, retry_output in zip(pending, retry_outputs):
            if parse(retry_output, records[position]) is None:
                still_failing.append(position)
            else:
                # Keep the first output that parses for this record.
                outputs[position] = retry_output
        pending = still_failing

    if pending:
        warnings.warn(
            f"Failed to generate explanation for {len(pending)} instances after {max_retries} retries."
        )

    final_result = []
    for record, output in zip(records, outputs):
        verdicts = parse(output, record)
        if verdicts is None:
            # Still unparseable after the retries: count every item as unmet.
            verdicts = {str(number): 'No' for number, _ in enumerate(record['checklist'], start=1)}

        scored = record.copy()
        scored[f'{turn_num}_evaluation'] = verdicts
        if structured_output:
            scored[f'{turn_num}_evaluation_reasoning_output'] = output['reasoning_output']
        final_result.append(scored)

    eval_save_path = os.path.join(output_dir, 'evaluation_results.json')
    dump_json_output(final_result, save_path=eval_save_path)
    logger.info(f"Saved evaluation results to {eval_save_path}")


# Agent models whose result directories get an extra reasoning-effort level.
_EFFORT_SCOPED_AGENT_PATHS = (
    "openai/o1-mini",
    "openai/o1",
    "openai/o3",
    "openai/o3-mini",
    "openrouter/openai/o1-mini",
    "openrouter/openai/o1",
    "openrouter/openai/o3",
    "openrouter/openai/o3-mini",
    "deepseek/deepseek-r1",
    "qwen/qwq-32b",
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    "deepseek/deepseek-r1-distill-qwen-14b",
    "deepseek/deepseek-r1-distill-qwen-32b",
)
_VALID_REASONING_EFFORTS = ('low', 'medium', 'high')


def _turn_result_dir(cfg: DictConfig, setup_dir: str, turn: int) -> str:
    """Return the directory that holds the results of ``turn`` for this run.

    Layout: ``EVAL_OUTPUT_PATH/<dataset stem>/<setup_dir>/<seed>/
    <temperature>_<top_p>/<evaluator name>/<agent name>/<turn>``, followed by
    the reasoning effort for agents in ``_EFFORT_SCOPED_AGENT_PATHS``.
    """
    path = os.path.join(
        EVAL_OUTPUT_PATH,
        os.path.splitext(os.path.basename(cfg.dataset.data_path))[0],
        setup_dir,
        str(cfg.experimental_setup.seed),
        f'{cfg.agent.temperature}_{cfg.agent.top_p}',
        f'{cfg.evaluator.model_name}',
        f'{cfg.agent.model_name}',
        f'{turn}'
    )
    if cfg.agent.model_path in _EFFORT_SCOPED_AGENT_PATHS:
        path = os.path.join(path, f'{cfg.experimental_setup.reasoning_effort}')
    return path


async def evaluate_refinebench_async(cfg: DictConfig):
    """Run generation followed by evaluation for the configured turn.

    Turn 1 reads the dataset at ``cfg.dataset.data_path``. Later turns read
    the previous turn's ``evaluation_results.json``; turn 2 reads it from the
    ``'single_turn'`` setup directory. Results are written under a directory
    derived from the same configuration fields.

    Raises:
        ValueError: if the agent is in ``_EFFORT_SCOPED_AGENT_PATHS`` and
            ``reasoning_effort`` is not ``'low'``, ``'medium'`` or ``'high'``.
    """
    setup = cfg.experimental_setup
    seed = setup.seed
    random.seed(seed)
    np.random.seed(seed)

    # Validate before the value is used in any path, and with a real
    # exception so the check also runs under `python -O`.
    if (cfg.agent.model_path in _EFFORT_SCOPED_AGENT_PATHS
            and setup.reasoning_effort not in _VALID_REASONING_EFFORTS):
        raise ValueError(
            f"reasoning_effort must be one of {_VALID_REASONING_EFFORTS} for "
            f"{cfg.agent.model_path!r}, got {setup.reasoning_effort!r}."
        )

    turn_num = setup.target_turn_num
    if setup.checklist_type == 'partial_feedback':
        setup_dir = f'{setup.checklist_type}_{setup.unknown_ratio}'
    else:
        setup_dir = setup.checklist_type

    data_path = cfg.dataset.data_path
    if turn_num > 1:
        # Turn 2 starts from the single-turn run; later turns continue from
        # the previous turn of the same feedback setup.
        prev_setup_dir = 'single_turn' if turn_num == 2 else setup_dir
        # NOTE(review): the 'add_thinking' sub-directory added to output_dir
        # below is never part of this input path, so a turn > 2 run with
        # add_thinking_to_next_turn does not read the previous add_thinking
        # output. Confirm this is intended.
        data_path = os.path.join(
            _turn_result_dir(cfg, prev_setup_dir, turn_num - 1),
            'evaluation_results.json'
        )

    refinebench = load_refinebench(data_path)
    if setup.debug:
        refinebench = refinebench[:2]
    logger.info(f"[INFO] refinebench length: {len(refinebench)}")

    output_dir = _turn_result_dir(cfg, setup_dir, turn_num)
    if setup.add_thinking_to_next_turn:
        output_dir = os.path.join(output_dir, 'add_thinking')
    os.makedirs(output_dir, exist_ok=True)

    generation_results = await run_generation(cfg, refinebench, output_dir)

    await run_evaluation(cfg, generation_results, output_dir)


@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig):
    """Hydra entry point: call ``load_dotenv`` and run the pipeline to completion."""
    load_dotenv()
    pipeline = evaluate_refinebench_async(cfg)
    asyncio.run(pipeline)


# Only start the pipeline when the file is executed as a script.
if __name__ == "__main__":
    main()
