from enum import Enum
from typing import Optional, Dict, Union, NamedTuple
import re, copy
from dataclasses import dataclass

from vllm import LLM, SamplingParams
from typing import List, Tuple, Set
import time
import torch
from tqdm import tqdm

import json
import os
import random
from typing import List, Dict, Any
from pathlib import Path
from itertools import combinations, permutations
import pandas as pd
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from tqdm import tqdm
from dataclasses import dataclass, asdict
import numpy as np
from collections import defaultdict
from IPython import embed

# Seed Python's `random` module once at import time so the `random.sample`
# calls that pick models for each instruction start from a fixed state.
random.seed(42)

@dataclass
class SingleAnnotation:
    """Parsed judge verdict for a single response scored on its own."""

    # Integer taken from the judge's "SCORE:" line.
    score: int
    # Text captured after "EXPLANATION:"; stays None when none was extracted.
    explanation: Optional[str] = None
    
@dataclass
class PairAnnotation:
    """Parsed judge verdict for two responses scored side by side."""

    # Integer taken from the judge's "SCORE_A:" line.
    score_a: int
    # Integer taken from the judge's "SCORE_B:" line.
    score_b: int
    # Text captured after "EXPLANATION:"; stays None when none was extracted.
    explanation: Optional[str] = None

class AnnotationStrategy(Enum):
    """Supported ways of scoring responses; the value doubles as the output folder name."""

    # Single response, scored by a reward model instead of a prompted judge.
    SINGLE_RM = "single_rm"
    # Single response, no guidelines, no explanation.
    SINGLE_BASIC = "single_basic"
    # Single response, with scoring guidelines and an explanation.
    SINGLE_GUIDED_EXPLAINED = "single_guided_explained"
    # Single response, with fine-grained preference questions and an explanation.
    SINGLE_GUIDED_EXPLAINED_FINE_GRAINED = "single_guided_explained_fine_grained"
    # Single response, no guidelines, no explanation; score derived from token probabilities.
    SINGLE_BASIC_PROB = "single_basic_prob"
    # Single response, no guidelines, no explanation; each task is issued several times for a majority vote.
    SINGLE_BASIC_MAJORITY = "single_basic_majority"
    # Pair comparison, no guidelines, no explanation.
    PAIR_BASIC = "pair_basic"
    # Pair comparison, no guidelines, with an explanation.
    PAIR_EXPLAINED = "pair_explained"
    # Pair comparison, with guidelines, no explanation.
    PAIR_GUIDED = "pair_guided"
    # Pair comparison, with guidelines and an explanation.
    PAIR_GUIDED_EXPLAINED = "pair_guided_explained"
    # Pair comparison, with fine-grained preference questions and an explanation.
    PAIR_GUIDED_EXPLAINED_FINE_GRAINED = "pair_guided_explained_fine_grained"

class ResponseEvaluator:
    """Build judge prompts for one annotation strategy and parse the judge's reply.

    The prompt builders embed the conversation history, the current query and
    one or two candidate responses, then request scores in a fixed
    ``SCORE:`` / ``SCORE_A:`` / ``SCORE_B:`` layout. The ``extract_*`` methods
    parse that layout back into ``SingleAnnotation`` / ``PairAnnotation``.
    """

    # Inclusive score range that every prompt template asks the judge to use.
    _MIN_SCORE = 0
    _MAX_SCORE = 9

    # Single-response strategies that only ask for a bare score.
    _SINGLE_BASIC_STRATEGIES = frozenset({
        AnnotationStrategy.SINGLE_BASIC,
        AnnotationStrategy.SINGLE_BASIC_PROB,
        AnnotationStrategy.SINGLE_BASIC_MAJORITY,
    })
    # Single-response strategies whose reply carries an EXPLANATION section.
    _SINGLE_EXPLAINED_STRATEGIES = frozenset({
        AnnotationStrategy.SINGLE_GUIDED_EXPLAINED,
        AnnotationStrategy.SINGLE_GUIDED_EXPLAINED_FINE_GRAINED,
    })
    # Every strategy routed to the single-response prompt and parser.
    # SINGLE_RM is deliberately absent: it is scored by a reward model and
    # never goes through a prompted judge.
    _SINGLE_STRATEGIES = _SINGLE_BASIC_STRATEGIES | _SINGLE_EXPLAINED_STRATEGIES
    # Pair strategies whose reply carries an EXPLANATION section.
    _PAIR_EXPLAINED_STRATEGIES = frozenset({
        AnnotationStrategy.PAIR_EXPLAINED,
        AnnotationStrategy.PAIR_GUIDED_EXPLAINED,
        AnnotationStrategy.PAIR_GUIDED_EXPLAINED_FINE_GRAINED,
    })

    def __init__(self, strategy: AnnotationStrategy):
        """Remember the strategy and prepare the shared scoring rubric text."""
        self.strategy = strategy
        self.scoring_guidelines = """
Scoring Guidelines:
- 8-9: Exceptional response that excels in all aspects
- 6-7: Strong response with minor room for improvement
- 4-5: Adequate response with some notable gaps
- 2-3: Poor response with significant issues
- 0-1: Severely inadequate or irrelevant response
"""

    def generate_prompt(self,
                        instruction: str,
                        system_prompt: str,
                        dialogue_history: str,
                        response_a: str,
                        preference_questions: Optional[List[str]] = None,
                        response_b: Optional[str] = None) -> str:
        """Return the judge prompt matching ``self.strategy``.

        Args:
            instruction: The current user query.
            system_prompt: System prompt of the conversation (forwarded to the
                single-response builder, which currently does not embed it).
            dialogue_history: Earlier turns as text; ``""`` is rendered as "None".
            response_a: The (first) response to evaluate.
            preference_questions: Task-specific questions, used only by the
                fine-grained strategies.
            response_b: The second response, used only by pair strategies.

        Raises:
            ValueError: If the strategy has no prompt template.
        """
        if self.strategy in self._SINGLE_STRATEGIES:
            return self._generate_single_prompt(
                instruction, system_prompt, dialogue_history, preference_questions, response_a
            )
        return self._generate_pair_prompt(
            instruction, dialogue_history, preference_questions, response_a, response_b
        )

    def _generate_single_prompt(self,
                                instruction: str,
                                system_prompt: str,
                                dialogue_history: str,
                                preference_questions: List[str],
                                response: str) -> str:
        """Build the prompt that asks the judge to score one response.

        Raises:
            ValueError: If ``self.strategy`` is not a single-response strategy
                with a template (mirrors the pair builder instead of silently
                returning ``None``).
        """
        # TODO: use history or system prompt (system_prompt is accepted but not
        # embedded in the prompt yet).
        history_block = "None" if dialogue_history == "" else dialogue_history

        base_prompt = f"""You are an expert evaluator tasked with providing a simple numerical score for the response given the conversation history and user query.

## Conversation History
<|begin_history|>
{history_block}
<|end_history|>

## Current User Query
<|begin_query|>
{instruction}
<|end_query|>

## Response to Evaluate
<|begin_response|>
{response}
<|end_response|>
"""
        if self.strategy in self._SINGLE_BASIC_STRATEGIES:
            return base_prompt + """
Please provide your evaluation by directly scoring the overall quality of the response from 0 to 9 in the following format exactly, where 0 is the worst and 9 is the best.

Your response MUST follow this exact format:
SCORE: [0-9]
"""

        if self.strategy == AnnotationStrategy.SINGLE_GUIDED_EXPLAINED:
            return base_prompt + f"""
{self.scoring_guidelines}

Please provide your evaluation by first point out the pros and cons of the response without polite phrases as short as you can. Then rate the overall quality of the response from 0 to 9, with 0 being worst and 9 being best.

Your response MUST follow this exact format:
EXPLANATION:
[Your detailed evaluation based on the scoring guidelines]

SCORE: [0-9]
"""

        if self.strategy == AnnotationStrategy.SINGLE_GUIDED_EXPLAINED_FINE_GRAINED:
            # The public entry point defaults preference_questions to None.
            questions = [] if preference_questions is None else preference_questions
            questions_prompt = "\n".join([f"{i+1}. {q}" for i, q in enumerate(questions)])
            explanation_prompt = "\n".join([f"{i+1}. [Yes/No] - [Brief explanation]" for i in range(len(questions))])
            return base_prompt + f"""
Please evaluate the response by first answering the following task-specific preference questions with Yes/No followed by a brief explanation. Then rate the overall quality of the responses from 0 to 9, with 0 being worst and 9 being best. Avoid polite phrases and be as concise as possible. The order of responses should not affect your judgment.

## Task-specific Preference Questions
{questions_prompt}

Your response MUST follow this exact format:

EXPLANATION:
{explanation_prompt}

SCORE: [0-9]

Your scoring should be based on how many task-specific questions were answered with "Yes" and the quality of fulfillment for each criterion."""

        raise ValueError(f"Unsupported annotation strategy: {self.strategy.value}")

    def _generate_pair_prompt(self,
                              instruction: str,
                              dialogue_history: str,
                              preference_questions: List[str],
                              response_a: str,
                              response_b: str) -> str:
        """Build the prompt that asks the judge to score two responses together.

        Raises:
            ValueError: If ``self.strategy`` is not one of the pair strategies.
        """
        history_block = "None" if dialogue_history == "" else dialogue_history

        base_prompt = f"""You are an expert evaluator tasked with comparing two responses to the same query.

## Conversation History
<|begin_history|>
{history_block}
<|end_history|>

## Current User Query
<|begin_query|>
{instruction}
<|end_query|>

## Response A
<|begin_response_a|>
{response_a}
<|end_response_a|>

## Response B
<|begin_response_b|>
{response_b}
<|end_response_b|>
"""

        if self.strategy == AnnotationStrategy.PAIR_BASIC:
            return base_prompt + """
Please provide your evaluation by directly scoring the overall quality of the responses from 0 to 9 in the following format exactly, where 0 is the worst and 9 is the best.

Your response MUST follow this exact format:
SCORE_A: [0-9]
SCORE_B: [0-9]
"""

        if self.strategy == AnnotationStrategy.PAIR_EXPLAINED:
            return base_prompt + """
Please provide your evaluation by first point out the pros and cons of both responses without polite phrases as short as you can, ensuring that the order of the responses does not affect your judgment. Then rate the overall quality of the responses from 0 to 9, with 0 being worst and 9 being best.

Your response MUST follow this exact format:
EXPLANATION:
[Your detailed comparison of both responses]

SCORE_A: [0-9]
SCORE_B: [0-9]
"""

        if self.strategy == AnnotationStrategy.PAIR_GUIDED:
            return base_prompt + f"""
{self.scoring_guidelines}

Please provide your evaluation by directly scoring the overall quality of the responses from 0 to 9 in the following format exactly, where 0 is the worst and 9 is the best.

Your response MUST follow this exact format:
SCORE_A: [0-9]
SCORE_B: [0-9]
"""

        if self.strategy == AnnotationStrategy.PAIR_GUIDED_EXPLAINED:
            return base_prompt + f"""
{self.scoring_guidelines}

Please provide your evaluation by first point out the pros and cons of both responses without polite phrases as short as you can, ensuring that the order of the responses does not affect your judgment. Then rate the overall quality of the responses from 0 to 9, with 0 being worst and 9 being best.

Your response MUST follow this exact format:
EXPLANATION:
[Your detailed evaluation based on the scoring guidelines]

SCORE_A: [0-9]
SCORE_B: [0-9]
"""

        if self.strategy == AnnotationStrategy.PAIR_GUIDED_EXPLAINED_FINE_GRAINED:
            # The public entry point defaults preference_questions to None.
            questions = [] if preference_questions is None else preference_questions
            questions_prompt = "\n".join([f"{i+1}. {q}" for i, q in enumerate(questions)])
            explanation_prompt = "\n".join([f"{i+1}. [Yes/No] - [Brief explanation]" for i in range(len(questions))])
            return base_prompt + f"""
Please evaluate both responses by first answering the following task-specific preference questions with Yes/No followed by a brief explanation. Then rate the overall quality of the responses from 0 to 9, with 0 being worst and 9 being best. Avoid polite phrases and be as concise as possible. The order of responses should not affect your judgment.

## Task-specific Preference Questions
{questions_prompt}

Your response MUST follow this exact format:

EXPLANATION:
Response A:
{explanation_prompt}

Response B:
{explanation_prompt}

[Overall comparison explanation considering the fine-grained evaluation results]

SCORE_A: [0-9]
SCORE_B: [0-9]

Your scoring should be based on how many task-specific questions were answered with "Yes" and the quality of fulfillment for each criterion."""

        raise ValueError(f"Unsupported annotation strategy: {self.strategy.value}")

    def extract_annotation(self, model_output: str) -> Union[SingleAnnotation, PairAnnotation]:
        """Extract a structured annotation from the judge's raw text.

        Raises:
            ValueError: If the expected score line(s) are missing or a score
                falls outside 0-9.
        """
        if self.strategy in self._SINGLE_STRATEGIES:
            return self._extract_single_annotation(model_output)
        return self._extract_pair_annotation(model_output)

    def _parse_score(self, raw: str) -> int:
        """Convert a captured digit string to an int and enforce the 0-9 range."""
        score = int(raw)
        if not self._MIN_SCORE <= score <= self._MAX_SCORE:
            raise ValueError(
                f"Score {score} is outside the allowed range "
                f"{self._MIN_SCORE}-{self._MAX_SCORE}"
            )
        return score

    def _extract_single_annotation(self, model_output: str) -> SingleAnnotation:
        """Parse ``SCORE:`` (and ``EXPLANATION:`` when the strategy asks for it)."""
        score_match = re.search(r'SCORE:\s*(\d+)', model_output)
        if not score_match:
            raise ValueError("Could not find score in model output")
        score = self._parse_score(score_match.group(1))

        # Only strategies that request an explanation get one extracted.
        explanation = None
        if self.strategy in self._SINGLE_EXPLAINED_STRATEGIES:
            explanation_match = re.search(r'EXPLANATION:\s*(.+?)(?=SCORE|$)', model_output, re.DOTALL)
            if explanation_match:
                explanation = explanation_match.group(1).strip()

        return SingleAnnotation(score=score, explanation=explanation)

    def _extract_pair_annotation(self, model_output: str) -> PairAnnotation:
        """Parse ``SCORE_A:`` / ``SCORE_B:`` (and ``EXPLANATION:`` if applicable)."""
        score_a_match = re.search(r'SCORE_A:\s*(\d+)', model_output)
        score_b_match = re.search(r'SCORE_B:\s*(\d+)', model_output)

        if not (score_a_match and score_b_match):
            raise ValueError("Could not find scores in model output")

        score_a = self._parse_score(score_a_match.group(1))
        score_b = self._parse_score(score_b_match.group(1))

        # Only strategies that request an explanation get one extracted.
        explanation = None
        if self.strategy in self._PAIR_EXPLAINED_STRATEGIES:
            explanation_match = re.search(r'EXPLANATION:\s*(.+?)(?=SCORE_|$)', model_output, re.DOTALL)
            if explanation_match:
                explanation = explanation_match.group(1).strip()

        return PairAnnotation(score_a=score_a, score_b=score_b, explanation=explanation)



@dataclass
class Instruction:
    """One instruction together with every model response collected for it."""

    # System message content; "" when none was picked up from the prompt.
    system_prompt: str
    # Original id in the json, stored as a string.
    instruction_id: str
    # Earlier conversation turns flattened into "Role: content" lines.
    dialogue_history: str
    # Content of the final message in the raw prompt.
    query: str
    # model_name -> response
    responses: Dict[str, str]
    # Message list exactly as it appears in the response files.
    raw_prompt: List
    # Task-specific preference questions (empty unless a fine-grained strategy is used).
    preference_questions: List

class ResponseAnnotator:
    """Load per-model response files and turn them into annotation tasks and prompts."""

    # Substring that marks the "online" model in file paths and model names.
    _ONLINE_MODEL_MARKER = "Llama-3.1-Tulu-3-8B-SFT"

    # Strategies that produce one task per sampled response.
    _SINGLE_STRATEGIES = frozenset({
        AnnotationStrategy.SINGLE_RM,
        AnnotationStrategy.SINGLE_BASIC,
        AnnotationStrategy.SINGLE_BASIC_PROB,
        AnnotationStrategy.SINGLE_BASIC_MAJORITY,
        AnnotationStrategy.SINGLE_GUIDED_EXPLAINED,
        AnnotationStrategy.SINGLE_GUIDED_EXPLAINED_FINE_GRAINED,
    })
    # Strategies that need the per-instruction preference questions.
    _FINE_GRAINED_STRATEGIES = frozenset({
        AnnotationStrategy.SINGLE_GUIDED_EXPLAINED_FINE_GRAINED,
        AnnotationStrategy.PAIR_GUIDED_EXPLAINED_FINE_GRAINED,
    })

    def __init__(
        self,
        data_dir: str,
        online: bool = False,
        num_samples: int = 3,
        strategy: AnnotationStrategy = AnnotationStrategy.PAIR_GUIDED_EXPLAINED
    ):
        """Load all data found under ``data_dir``.

        Args:
            data_dir: Directory holding one ``*.json`` response file per model;
                its last path component is used as the dataset name.
            online: When True, responses from the online model are kept and
                always added on top of the sampled models.
            num_samples: Number of models sampled per instruction.
            strategy: Annotation strategy used for tasks and prompts.
        """
        self.online = online
        self.data_dir = Path(data_dir)
        # Path.name copes with a trailing slash, where split("/")[-1] yields "".
        self.dataset_name = self.data_dir.name.strip()
        self.num_samples = num_samples
        self.strategy = strategy
        self.evaluator = ResponseEvaluator(strategy)

        # Load and process all data
        self.fine_grained_questions = self._load_fine_grained_questions()
        self.model_responses = self._load_all_responses()
        self.instructions = self._process_instructions()

    def _load_fine_grained_questions(self) -> Dict:
        """Load the fine-grained preference questions for this dataset.

        Best effort: if the file is missing or unreadable the error is printed
        and an empty dict is returned, so strategies that do not need the
        questions can still run.
        """
        try:
            # TODO:
            print(f"Loading from {self.dataset_name}_fine-grained.json...")
            with open(f"path/to/AIR/outputs/{self.dataset_name}_fine-grained.json", 'r', encoding='utf-8') as f:
                return json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both JSON and text-decoding errors.
            print(f"Error loading fine-grained questions: {e}")
            return {}

    def _load_all_responses(self) -> Dict[str, Dict]:
        """Load all JSON files and organize them as model name -> {id -> response info}."""
        model_responses = {}

        for json_file in self.data_dir.glob("*.json"):
            # Offline runs ignore the online model's file altogether.
            if not self.online and self._ONLINE_MODEL_MARKER in str(json_file):
                continue
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            items = data['data']
            if not items:
                # Nothing to index, and no first item to read the model name from.
                continue

            # Assuming model name is consistent within a file
            model_name = items[0]['model']

            # Organize responses by instruction ID
            responses = {
                item['id']: {'pred': item['pred'], 'raw_prompt': item['raw_prompt']}
                for item in items
            }

            if self.online and model_name in model_responses:
                # Several files may share a model name in online mode; keep
                # each of them under the first free "<name>_<i>" key.
                suffix = 1
                while f"{model_name}_{suffix}" in model_responses:
                    suffix += 1
                model_responses[f"{model_name}_{suffix}"] = responses
            else:
                model_responses[model_name] = responses

        return model_responses

    def _process_instructions(self) -> List[Instruction]:
        """Merge the per-model responses into one ``Instruction`` per instruction id."""
        # Get all unique instruction IDs
        # NOTE(review): iteration order of this set decides the task order that
        # callers slice; for string ids it may differ between interpreter runs
        # unless PYTHONHASHSEED is fixed. Confirm before relying on it.
        instruction_ids = set()
        for model_data in self.model_responses.values():
            instruction_ids.update(model_data.keys())

        needs_questions = self.strategy in self._FINE_GRAINED_STRATEGIES

        instructions = []
        for instr_id in instruction_ids:
            # Get raw prompt from any model (should be same for all)
            raw_prompt = next(
                (
                    model_data[instr_id]['raw_prompt']
                    for model_data in self.model_responses.values()
                    if instr_id in model_data
                ),
                None,
            )
            if not raw_prompt:
                continue

            system_prompt = ""
            if self.dataset_name == "openhermes":
                # openhermes prompts contribute no history, only an optional
                # leading system message.
                history = []
                if raw_prompt[0]['role'] == "system":
                    system_prompt = raw_prompt[0]['content']
            else:
                # Extract history and query
                if len(raw_prompt) % 2 == 0:
                    continue  # Skip even-length prompts
                history = raw_prompt[:-1]

            query = raw_prompt[-1]
            if needs_questions:
                preference_questions = self.fine_grained_questions[instr_id].get("preference_questions", [])
            else:
                preference_questions = []

            # Format history into string
            history_str = "\n".join(
                f"{turn['role'].capitalize()}: {turn['content']}" for turn in history
            )

            # Collect responses from all models for this instruction
            responses = {
                model_name: model_data[instr_id]['pred']
                for model_name, model_data in self.model_responses.items()
                if instr_id in model_data
            }

            instructions.append(Instruction(
                system_prompt=system_prompt,
                instruction_id=str(instr_id),
                dialogue_history=history_str.strip(),
                preference_questions=preference_questions,
                query=query['content'],
                responses=responses,
                raw_prompt=raw_prompt
            ))

        return instructions

    def _build_task(self, instruction: Instruction, model_a: str, model_b: Optional[str] = None) -> Dict[str, Any]:
        """Return one task row; ``model_b`` is None for single-response tasks."""
        return {
            'instruction_id': instruction.instruction_id,
            'system_prompt': instruction.system_prompt,
            'dialogue_history': instruction.dialogue_history,
            'preference_questions': instruction.preference_questions,
            'raw_prompt': instruction.raw_prompt,
            'query': instruction.query,
            'response_a': instruction.responses[model_a],
            'response_b': None if model_b is None else instruction.responses[model_b],
            'model_a': model_a,
            'model_b': model_b
        }

    def generate_annotation_tasks(self) -> pd.DataFrame:
        """Generate annotation tasks by sampling models for each instruction.

        Instructions with too few candidate models are skipped. In online mode
        the sample is drawn from the offline models only and every online
        model is appended afterwards.
        """
        tasks = []
        is_single = self.strategy in self._SINGLE_STRATEGIES
        # Majority voting scores the same task three times.
        repeats = 3 if self.strategy == AnnotationStrategy.SINGLE_BASIC_MAJORITY else 1

        for instruction in tqdm(self.instructions, desc="Generating tasks"):
            # Sample models
            available_models = list(instruction.responses.keys())
            if len(available_models) < self.num_samples:
                continue

            if self.online:
                online_models = [m for m in available_models if self._ONLINE_MODEL_MARKER in m]
                offline_models = [m for m in available_models if self._ONLINE_MODEL_MARKER not in m]
                # random.sample raises ValueError when the pool is smaller than
                # the sample size, so check the pool that is actually sampled.
                if len(offline_models) < self.num_samples:
                    continue
                sampled_models = random.sample(offline_models, self.num_samples) + online_models
            else:
                sampled_models = random.sample(available_models, self.num_samples)

            # Generate tasks based on strategy
            if is_single:
                # For single evaluation, create one task per response
                for model in sampled_models:
                    task = self._build_task(instruction, model)
                    tasks.extend([task] * repeats)
            else:
                # For pair comparison, create one task per unordered pair
                # (switch to permutations to evaluate both orders).
                for model_a, model_b in combinations(sampled_models, 2):
                    tasks.append(self._build_task(instruction, model_a, model_b))

        return pd.DataFrame(tasks)

    def generate_prompts(self, tasks_df: pd.DataFrame):
        """Generate one evaluation prompt per task row, in row order.

        For ``SINGLE_RM`` the prompt is the raw message list with the response
        appended as an assistant turn; otherwise it is the judge prompt text.
        """
        prompts = []

        for _, task in tasks_df.iterrows():
            if self.strategy == AnnotationStrategy.SINGLE_RM:
                # Copy so the task's own raw_prompt list is left untouched.
                prompt = copy.deepcopy(task["raw_prompt"])
                prompt.append({"role": "assistant", "content": task['response_a']})
            else:
                prompt = self.evaluator.generate_prompt(
                    instruction=task['query'],
                    system_prompt=task['system_prompt'],
                    dialogue_history=task['dialogue_history'],
                    preference_questions=task['preference_questions'],
                    response_a=task['response_a'],
                    response_b=task['response_b']
                )
            prompts.append(prompt)

        return prompts

class LLMEvaluator:
    """Run the judge: either a reward model or a vLLM-served generative model."""

    # Decoded tokens that count as a digit score for SINGLE_BASIC_PROB.
    _DIGIT_TOKENS = frozenset(str(d) for d in range(10))

    def __init__(self, model_path: str, tensor_parallel_size: int = 1):
        """Load the model selected by ``model_path``.

        A path containing "Skywork-Reward-Gemma-2-27B-v0.2" loads a
        sequence-classification reward model on ``cuda:0``; any other path is
        served through vLLM with ``tensor_parallel_size`` GPUs. A tokenizer is
        loaded in both cases because both scoring paths apply the chat template.
        """
        if "Skywork-Reward-Gemma-2-27B-v0.2" in model_path:
            self.device = "cuda:0"
            # The reward-model path tokenizes conversations itself, so it needs
            # a tokenizer as well.
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.rm = AutoModelForSequenceClassification.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
                device_map=self.device,
                attn_implementation="flash_attention_2",
                num_labels=1,
            )
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            self.sampling_params = SamplingParams(
                temperature=0.7, # TODO
                max_tokens=2048,
                logprobs=100 # TODO
            )
            print(f"Sampling parameters: {self.sampling_params}")
            self.llm = LLM(
                model=model_path,
                tensor_parallel_size=tensor_parallel_size,
                max_logprobs=100, # TODO
            )

    def apply_chat_template(self, prompt: str) -> str:
        """Wrap ``prompt`` as a single user turn and render it with the chat template."""
        messages = [{"role": "user", "content": prompt}]
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

    def _score_with_reward_model(self, tasks_df: pd.DataFrame, prompts) -> List[Dict[str, Any]]:
        """Score every conversation with the reward model, one at a time."""
        results = []
        with torch.no_grad():
            for i, conversation in enumerate(tqdm(prompts)):
                conv_tokenized = self.tokenizer.apply_chat_template(
                    conversation, tokenize=True, return_tensors="pt"
                ).to(self.device)
                score = self.rm(conv_tokenized).logits[0][0].item()
                row = tasks_df.iloc[i]
                results.append({
                    'instruction_id': row['instruction_id'],
                    'raw_prompt': row['raw_prompt'],
                    'model_a': row['model_a'],
                    'response_a': row['response_a'],
                    'model_b': row['model_b'],
                    'response_b': row['response_b'],
                    'score': score
                })
        return results

    def _expected_digit(self, position_logprobs):
        """Return sum(p(token) * digit) over the digit tokens at one position."""
        return sum(
            np.exp(lp.logprob) * int(lp.decoded_token)
            for lp in position_logprobs.values()
            if lp.decoded_token in self._DIGIT_TOKENS
        )

    def evaluate_prompts(self,
                         tasks_df: pd.DataFrame,
                         prompts,
                         strategy,
                         evaluator: ResponseEvaluator,
                         max_retries: int = 10,
                         score_token_index: int = 3):
        """Score every prompt and return one result dict per successful task.

        Args:
            tasks_df: Task rows, positionally aligned with ``prompts``.
            prompts: Prompts produced for ``tasks_df``.
            strategy: Annotation strategy; ``SINGLE_RM`` uses the reward model.
            evaluator: Parser for the judge's text output.
            max_retries: Number of generation rounds; each round only re-runs
                prompts whose output could not be parsed.
            score_token_index: Position of the generated token whose logprobs
                are turned into a score for ``SINGLE_BASIC_PROB``. The default
                of 3 assumes the digit after "SCORE:" is the fourth generated
                token, which is tokenizer dependent (TODO confirm per model).

        Returns:
            A list of result dicts. Prompts that failed every attempt are
            omitted, and results follow completion order rather than the order
            of ``prompts``.
        """
        if strategy == AnnotationStrategy.SINGLE_RM:
            return self._score_with_reward_model(tasks_df, prompts)

        results = []
        # Keep indices and prompts as parallel, explicitly ordered lists.
        current_indices = list(range(len(prompts)))
        current_prompts = prompts
        failed_indices = set(current_indices)

        for retry in range(max_retries):
            if not failed_indices:
                break

            print(f"\nAttempt {retry + 1}, Processing {len(failed_indices)} prompts...")

            # Format all prompts with chat template
            formatted_prompts = [self.apply_chat_template(p) for p in current_prompts]

            # Generate for all prompts at once
            generations = self.llm.generate(formatted_prompts, self.sampling_params)

            # Process results and track failures
            new_failed_indices = set()
            last_error = None

            for current_idx, generation in zip(current_indices, generations):
                completion = generation.outputs[0]
                output = completion.text
                try:
                    annotation = evaluator.extract_annotation(output)

                    row = tasks_df.iloc[current_idx]
                    result = {
                        'instruction_id': row['instruction_id'],
                        'system_prompt': row['system_prompt'],
                        'raw_prompt': row['raw_prompt'],
                        'model_a': row['model_a'],
                        'response_a': row['response_a'],
                        'model_b': row['model_b'],
                        'response_b': row['response_b'],
                        'output': output
                    }

                    if isinstance(annotation, SingleAnnotation):
                        if evaluator.strategy == AnnotationStrategy.SINGLE_BASIC_PROB:
                            # Only this strategy needs logprobs; a completion
                            # that is too short is retried instead of crashing
                            # the whole batch.
                            token_logprobs = completion.logprobs
                            if token_logprobs is None or len(token_logprobs) <= score_token_index:
                                raise ValueError("Output too short to read score logprobs")
                            result['score'] = self._expected_digit(token_logprobs[score_token_index])
                        else:
                            result['score'] = annotation.score
                        result['explanation'] = annotation.explanation
                    else:
                        result['score_a'] = annotation.score_a
                        result['score_b'] = annotation.score_b
                        result['explanation'] = annotation.explanation

                    results.append(result)

                except Exception as exc:
                    # Best effort: any failure on one item just schedules a retry.
                    new_failed_indices.add(current_idx)
                    last_error = exc

            # Prepare for next retry
            current_indices = sorted(new_failed_indices)
            current_prompts = [prompts[i] for i in current_indices]
            failed_indices = new_failed_indices

            print(f"Successfully processed: {len(prompts) - len(failed_indices)}")
            print(f"Failed extractions: {len(failed_indices)}")
            if last_error is not None:
                print(f"Last extraction error: {last_error!r}")

        if failed_indices:
            print(f"Warning: {len(failed_indices)} prompts failed all retry attempts")

        return results

def main():
    """Annotate one slice ("turn") of the generated tasks and save it as JSON.

    All settings are hard-coded below. Tasks are generated for the whole
    dataset, the slice ``[MAX_CNT * turn, MAX_CNT * (turn + 1))`` is scored,
    and the results are written to ``<output_dir>/<strategy>/<turn>.json``.
    """
    # TODO: CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python src/annotate.py
    # Configuration
    dataset_name = "sharegpt_v3" # sharegpt_v3 / ultrafeedback
    data_dir = f"path/to/AIR/outputs/{dataset_name}"
    output_dir = f"path/to/AIR/data/annotation_outputs/{dataset_name}" # TODO
    strategy = AnnotationStrategy.SINGLE_GUIDED_EXPLAINED # TODO
    output_dir = os.path.join(output_dir, strategy.value)
    os.makedirs(output_dir, exist_ok=True) # TODO

    num_samples = 4 # TODO

    if strategy == AnnotationStrategy.SINGLE_RM:
        model_path = "path/to/Skywork-Reward-Gemma-2-27B-v0.2"
    else:
        # Alternative judge: "path/to/Llama-3.1-70B-Instruct" # TODO
        model_path = "path/to/Qwen2.5-72B-Instruct" # TODO
    print(f"Running on {strategy.value}...")

    # Initialize annotator
    annotator = ResponseAnnotator(
        data_dir=data_dir,
        online=True, # TODO
        num_samples=num_samples,
        strategy=strategy
    )

    turn = 1 # TODO
    MAX_CNT = 120000
    start, stop = MAX_CNT * turn, MAX_CNT * (turn + 1)

    tasks_df = annotator.generate_annotation_tasks()
    print(f"Generated totally {len(tasks_df)} tasks")
    tasks_df = tasks_df[start:stop] # TODO

    if tasks_df.empty:
        # Without this guard the demo print below fails with IndexError on
        # prompts[-1] whenever the requested turn lies past the last task.
        print(f"No tasks in range {start} to {stop}; nothing to annotate")
        return

    prompts = annotator.generate_prompts(tasks_df)
    print(f"Sampled {len(tasks_df)} tasks")
    print(f"Demo prompt: {prompts[-1]}")

    # Initialize evaluator
    llm_evaluator = LLMEvaluator(model_path, tensor_parallel_size=torch.cuda.device_count())
    print(f"Sampling from {start} to {stop}")
    results = llm_evaluator.evaluate_prompts(
        tasks_df=tasks_df,
        prompts=prompts,
        strategy=strategy,
        evaluator=annotator.evaluator,
        max_retries=3 # TODO
    )

    save_path = os.path.join(output_dir, f"{turn}.json") # TODO
    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=4, ensure_ascii=False)


if __name__ == "__main__":
    main()