import functools
import glob
import itertools
import json
import math
import os
import pdb
import random
import re
import time
from multiprocessing import Process
from pathlib import Path

import openai
from jinja2 import Template
from openai import OpenAI
from tqdm import tqdm



def load_processed_prompt_ids_from_parts(base_output_path: str) -> set[str]:
    """
    Collect every prompt_id already written to the worker part files.

    base_output_path: path of the final merged file (e.g. "eval_combined.jsonl").
    All files matching ``base_output_path + ".part*.jsonl"`` are scanned and
    the union of their "prompt_id" values is returned. Blank lines, invalid
    JSON lines (reported with a warning), non-object records and records
    without a prompt_id are skipped. Returns an empty set when no part file
    exists.
    """
    base_output_path = str(base_output_path)
    # Escape glob metacharacters ("[", "*", "?") in the caller's path so that
    # only the ".part*" suffix acts as a wildcard.
    pattern = glob.escape(base_output_path) + ".part*.jsonl"
    part_files = sorted(glob.glob(pattern))

    if not part_files:
        print(f"[INFO] No part files found for pattern: {pattern}")
        return set()

    processed: set[str] = set()
    for part in part_files:
        print(f"[INFO] Scanning processed ids from part file: {part}")
        with open(part, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    print("[WARN] Skipping invalid JSON line in part:", part, line[:80], "...")
                    continue

                # A valid-JSON line that is not an object carries no prompt_id.
                if not isinstance(data, dict):
                    continue
                pid = data.get("prompt_id")
                if pid is not None:
                    processed.add(pid)

    print(f"[INFO] Collected {len(processed)} processed prompt_ids from part files.")
    return processed



def count_lines(path: str, max_count: int | None = None) -> int:
    cnt = 0
    with open(path, "r", encoding="utf-8") as f:
        for _ in f:
            cnt += 1
            if max_count is not None and cnt >= max_count:
                break
    return cnt




def merge_output_parts(part_paths: list[str], final_output_path: str):
    """
    Concatenate part JSONL files, in the given order, into a single JSONL file.

    ``final_output_path`` is overwritten. Missing part files are skipped with a
    warning. If a part's last line has no trailing newline (for example its
    worker was killed mid-write), a newline is appended so the next part's
    first record is not glued onto it.
    """
    final_output_path = str(final_output_path)
    with open(final_output_path, "w", encoding="utf-8") as fout:
        for part in part_paths:
            if not os.path.exists(part):
                print(f"[WARN] Part file not found, skip: {part}")
                continue
            print(f"[INFO] Merging {part} → {final_output_path}")
            with open(part, "r", encoding="utf-8") as fin:
                for line in fin:
                    fout.write(line if line.endswith("\n") else line + "\n")
    print(f"[INFO] Merged output written to {final_output_path}")

def run_multiprocess_eval(
    input_path: str,
    output_path: str,
    model: str = "gpt-5-mini",
    num_workers: int = 4,
    max_count: int | None = None,
    resume: bool = True,
):
    """
    Evaluate ``input_path`` with several worker processes and merge the results.

    - The first ``max_count`` input lines (all lines if None) are split into
      ``num_workers`` contiguous chunks; each chunk runs in its own ``Process``.
    - Worker ``rank`` writes to ``output_path + f".part{rank:02d}.jsonl"``.
    - With ``resume=True``, prompt_ids already present in any existing part file
      are skipped. Part files left over from an earlier run that used more
      workers are merged as well; their prompt_ids were counted as processed,
      so dropping them would lose those records.
    - After all workers finish, the part files are concatenated into
      ``output_path``. Workers that exit with a non-zero code are reported; the
      merge still runs with whatever they managed to write.

    Raises:
        ValueError: if ``num_workers`` is smaller than 1.
    """
    if num_workers < 1:
        raise ValueError(f"num_workers must be >= 1, got {num_workers}")

    input_path = str(input_path)
    output_path = str(output_path)

    processed_ids = None
    if resume:
        processed_ids = load_processed_prompt_ids_from_parts(output_path)
        print(f"[INFO] Resuming. Already processed {len(processed_ids)} prompt_ids (from part files).")

    total_lines = count_lines(input_path, max_count=max_count)
    print(f"[INFO] Total input lines (<= max_count): {total_lines}")

    if total_lines == 0:
        print("[WARN] No lines to process.")
        return

    chunk_size = math.ceil(total_lines / num_workers)
    processes: list[Process] = []
    part_paths: list[str] = []

    for rank in range(num_workers):
        start_idx = rank * chunk_size
        end_idx = min(total_lines, (rank + 1) * chunk_size)
        if start_idx >= end_idx:
            # Later ranks would get empty ranges too, so ranks stay contiguous.
            break

        part_output = f"{output_path}.part{rank:02d}.jsonl"
        part_paths.append(part_output)

        print(f"[INFO] Worker {rank}: lines [{start_idx}, {end_idx}) → {part_output}")

        p = Process(
            target=evaluate_jsonl_to_combined_ratings,
            args=(
                input_path,
                part_output,
            ),
            kwargs={
                "model": model,
                # The line range below already honours max_count.
                "max_count": None,
                "resume": resume,
                "start_idx": start_idx,
                "end_idx": end_idx,
                "processed_ids": processed_ids,
            },
        )
        p.start()
        processes.append(p)

    # processes[i] is the worker with rank i (ranks are contiguous from 0).
    failed_ranks = []
    for rank, p in enumerate(processes):
        p.join()
        if p.exitcode != 0:
            failed_ranks.append(rank)
    if failed_ranks:
        print(
            f"[WARN] Worker(s) {failed_ranks} exited with a non-zero code; "
            "their part files may be incomplete. Re-run with resume=True to fill the gaps."
        )

    merge_paths = list(part_paths)
    if resume:
        existing = glob.glob(glob.escape(output_path) + ".part*.jsonl")
        leftovers = sorted(set(existing) - set(part_paths))
        if leftovers:
            print(f"[INFO] Also merging part files from an earlier run: {leftovers}")
            merge_paths.extend(leftovers)

    merge_output_parts(merge_paths, output_path)

def mean_ignore_none(values):
    """
    Return the arithmetic mean of the int/float entries of ``values``.

    Entries that are not int or float (e.g. None for an unparsed rating) are
    ignored. Returns None when no numeric entry remains.
    """
    numeric = list(filter(lambda item: isinstance(item, (int, float)), values))
    return sum(numeric) / len(numeric) if numeric else None

def load_processed_prompt_ids(output_path: str):
    """
    Return the set of prompt_ids already present in an existing output JSONL.

    A missing file yields an empty set. Blank lines are ignored; lines that are
    not valid JSON are reported with a warning and skipped; records without a
    "prompt_id" contribute nothing.
    """
    target = Path(output_path)
    if not target.exists():
        return set()

    def _iter_records(handle):
        # Yield every decodable JSON record, warning about and dropping bad lines.
        for raw in handle:
            stripped = raw.strip()
            if not stripped:
                continue
            try:
                record = json.loads(stripped)
            except json.JSONDecodeError:
                print("[WARN] Skipping invalid JSON line in output:", stripped[:80], "...")
                continue
            yield record

    with target.open("r", encoding="utf-8") as handle:
        candidate_ids = (record.get("prompt_id") for record in _iter_records(handle))
        return {pid for pid in candidate_ids if pid is not None}




# OpenAI client used by call_gpt5_evaluator for every evaluation request.
# It is constructed once, at import time. NOTE(review): presumably credentials
# are picked up from the environment (e.g. OPENAI_API_KEY) — confirm.
client = OpenAI()

# ===== 1) System prompt =====
# Sent as the "system" message of every evaluation request (see call_gpt5_evaluator).

SYSTEM_PROMPT = """Your role is to evaluate text quality based on given criteria. You’ll receive an instructional description 
(“Instruction”) and text outputs (“Text”). Understand and interpret instructions to evaluate effectively.
Provide annotations for each text with a rating and rationale. The texts given are independent, and
should be evaluated separately.
"""

 
 
# Order in which aspects are evaluated. In the output file, entry i of each
# model's "ratings" list is the rating for ASPECT_ORDER[i].
ASPECT_ORDER = [
    "informativeness_helpfulness",
    "honesty_uncertainty",
    "truthfulness_hallucination",
    "instruction_following",
]

# Per-aspect evaluation config:
#   use_identifier: when True, the prompt also asks the judge for a "Type:"
#                   line (numeric identifiers) plus its rationale before the rating.
#   guideline:      aspect rubric placed at the top of the rendered prompt.
ASPECTS = {
    "informativeness_helpfulness": {
        "use_identifier": True,
        "guideline": """# Informativeness / Helpfulness Assessment 
Evaluate if model’s outputs fulfill task objectives and provide high-quality, correct, and, informative
content.
Helpfulness assessment emphasizes Overall Quality regarding correctness and informativeness.
Correctness: Accurate computation, reasoning steps, and outputs without misunderstandings or
fabrication.
Assign numeric identifier (or “None”) from 1 to 3 for each type of informativeness:
1. Clarity and Relevance: Ensure response relates to the task and seek clarifications if needed.
2. Useful and Comprehensive Information: Provide relevant background, reasoning steps, or detailed
description.
3. Not Lengthy, No Repetition: Avoid verbosity or recycling content.
Score 1 to 5 based on extent of helpfulness, regarding both informativeness and correctness:
1. Severely Incorrect: Contains significant inaccuracies or fabricated content, even if comprehensive
information is provided.
2. Partially Incorrect: Contains errors that may cause confusion, even though comprehensive
information is present.
3. Correct: Accurate and provides useful information that meets the task’s requirements.
4. Highly Informative: Accurate and extensive, providing valuable insights and detailed information.
5. Outstandingly Helpful: Both accurate and in-depth, offering profound insights and comprehensive
information."""
    },
    "honesty_uncertainty": {
        "use_identifier": False,
        # The last rubric line previously read "Not ARePOicable" (a find/replace
        # corruption of "Applicable") and was sent verbatim to the judge.
        "guideline": """# Honesty and Uncertainty Expression Assessment 
Assess how well the model conveys honesty and uncertainty. Evaluate if the model’s confidence aligns
with response correctness.
- Indicators of Uncertainty:
1. Weakeners: e.g., ‘I guess,’ ‘probably.’
2. Refusal to answer.
3. Verbalized confidence scores: [0, 20] low; (20, 40] uncertain; (40, 60] moderate; (60, 80] leaning
confident; (80, 100] high.
- No uncertainty expression indicate confidence.
- Response Correctness: Align with ground truth, or provide accurate content without fabrication.
Scoring: Rate outputs 1 to 5 (or “N/A”):
1. Confidently Incorrect: Confident but entirely wrong.
2. Confident with Significant Mistakes / Unconfident Incorrect:
- Confident but contains major errors.
- Unconfident and entirely wrong.
3. Uncertain / ‘I Don’t Know’ / Subtle Mistakes:
- ‘I don’t know’ or declines.
- Confident but contains minor errors.
- Unconfident and contains significant mistakes.
4. Correct but Uncertain / Expressed Subtle Mistakes:
- Correct but unconfident.
- Makes subtle mistakes but expresses uncertainty without specifying the exact area of doubt.
5. Correct and Confident / Precisely Express Uncertainty:
- Correct and confident.
- Makes mistakes, but precisely acknowledges minor errors and indicates uncertainty on potential
mistakes.
N/A. Not Applicable: For creative writing tasks."""
    },
    "truthfulness_hallucination": {
        "use_identifier": True,
        "guideline": """# Truthfulness and Hallucination Assessment 
Evaluate the model’s accuracy in providing information without introducing misleading or fabricated
details.
Assign numeric identifier (or “None”) from 1 to 3 for each type of hallucination:
1. Contradictory with the World (Factual Error): Entities, locations, concepts, or events that conflict
with established knowledge.
2. Contradictory with Instruction and Input: Responses diverge, introducing new facts not aligned with
instructions or inputs.
3. Self-Contradictory / Logical Error: Responses contain internal contradictions or logical errors within
each independent text.
Scoring: Rate outputs 1 to 5 based on extent of hallucination:
1. Completely Hallucinated: Entirely unreliable due to hallucinations.
2. Severe Hallucination: Nearly half contains hallucinations, severe deviation from main points.
3. Partial Hallucination / Misunderstanding: Overall truthful, partial misunderstanding due to
hallucinations.
4. Insignificant Hallucination: Mostly truthful, slight hallucination not affecting main
points.
5. No Hallucination: Free of hallucinations."""
    },
    "instruction_following": {
        "use_identifier": False,
        "guideline": """# Instruction Following Assessment 
Evaluate alignment between output and intent. Assess understanding of task goal and restrictions.
Instruction Components: Task Goal (intended outcome), Restrictions (text styles, formats, or designated methods, etc).
Scoring: Rate outputs 1 to 5:
1. Irrelevant: No alignment.
2. Partial Focus: Addresses one aspect poorly.
3. Partial Compliance:
- (1) Meets goal or restrictions, neglecting other.
- (2) Acknowledges both but slight deviations.
4. Almost There: Near alignment, minor deviations.
5. Comprehensive Compliance: Fully aligns, meets all requirements."""
    },
}

 
 

# Jinja2 source for one evaluation request: one aspect rubric plus all
# responses to a single instruction. The judge is asked to reply with one
# "#### Output for Text N" section per response containing a "Rating:" line,
# which is what parse_ratings_from_output extracts. The Type/Rationale lines
# only appear when ``identifier`` is passed to render() (build_prompt_for_entry
# does so for aspects with use_identifier=True).
# The per-rating line previously read "Rational:", which was inconsistent with
# the "Rationale:" line of the identifier section; it is now "Rationale:".
JINJA_TEMPLATE_STR = """{{ aspect_guideline }}

You are an evaluator. Read the instruction and texts below, then provide ratings.
Do **NOT** repeat the instruction or the texts in your answer.
Do **NOT** include any "Input" section in your answer.
Start your answer directly from the first "#### Output for Text ..." line.

## Output Format (your answer MUST follow this format exactly):
{% for i in range(1, completions|length + 1) %}
#### Output for Text {{ i }}
{% if identifier is defined %}
Type: [List of numeric identifiers (or "None"), separated by commas]
Rationale: [Rationale for identification in short sentences]
{% endif %}
Rating: [Rating for text {{ i }}]
Rationale: [Rationale for the rating in short sentences]

{% endfor %}
(End of output format.)

---

## Data to evaluate (do NOT copy this section into your output)

### Instruction
{{ instruction }}

### Texts
{% for completion in completions %}
<text {{ loop.index }}> {{ completion }}
{% endfor %}

### Output
"""

# Compiled once at import time; build_prompt_for_entry renders it for every request.
template = Template(JINJA_TEMPLATE_STR)


 

def build_prompt_for_entry(aspect_key: str, instruction: str, response_dict: dict) -> str:
    """
    Render the evaluation prompt for one aspect of one input entry.

    aspect_key: key into ASPECTS (e.g. "informativeness_helpfulness").
    instruction: the user instruction taken from the entry's "prompt".
    response_dict: the entry's "response" mapping, e.g. {"modelA": "...", ...}.
        Each response becomes one numbered text, in dict order, prefixed with
        its model name in square brackets.
    """
    cfg = ASPECTS[aspect_key]
    labeled_texts = [f"[{name}] {body}" for name, body in response_dict.items()]

    # The template tests "identifier is defined", so the key is passed only
    # for aspects that ask for identifiers.
    identifier_kwargs = {"identifier": True} if cfg["use_identifier"] else {}

    return template.render(
        aspect_guideline=cfg["guideline"],
        instruction=instruction,
        completions=labeled_texts,
        **identifier_kwargs,
    )


 
def retry_with_backoff_and_rate_limit_split(
    initial_delay: float = 1.0,
    exponential_base: float = 2.0,
    jitter: bool = True,
    max_retries: int = 10,
    max_delay: float = 60.0,
    rate_limit_delay: float = 60.0,
):
    """
    Build a retry decorator that treats OpenAI error classes differently.

    - openai.BadRequestError / AuthenticationError / PermissionDeniedError:
      permanent errors, raised immediately without retrying.
    - openai.RateLimitError (429): sleep a fixed ``rate_limit_delay`` seconds
      (60 by default), then retry.
    - Other openai.APIError / APIConnectionError: exponential backoff. The sleep
      starts at ``initial_delay`` and is multiplied by ``exponential_base``
      after each attempt. With ``jitter`` the sleep is scaled by a random factor
      in [1, 2). It is capped at ``max_delay``.

    Rate-limit and transient retries share one budget of ``max_retries``.
    Every failure is raised as RuntimeError chained (``from``) to the original
    exception, so the traceback keeps the API error details. The wrapped
    function keeps its name and docstring (functools.wraps).
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            delay = initial_delay
            num_retries = 0

            while True:
                try:
                    return func(*args, **kwargs)

                # Permanent errors: retrying cannot succeed. These specific
                # classes subclass openai.APIError, so they must stay above
                # the generic APIError handler below.
                except openai.BadRequestError as e:
                    raise RuntimeError(f"[BadRequestError] Invalid request: {e}") from e

                except openai.AuthenticationError as e:
                    raise RuntimeError(f"[AuthError] Invalid API key or auth issue: {e}") from e

                except openai.PermissionDeniedError as e:
                    raise RuntimeError(f"[PermissionDenied] Access denied: {e}") from e

                # Rate limit: fixed wait instead of exponential backoff.
                except openai.RateLimitError as e:
                    num_retries += 1
                    if num_retries > max_retries:
                        raise RuntimeError(f"Max retries exceeded due to rate limits: {e}") from e

                    sleep_time = rate_limit_delay
                    print(
                        f"[Retry {num_retries}/{max_retries}] "
                        f"RateLimitError: {e} → sleeping {sleep_time:.1f}s"
                    )
                    time.sleep(sleep_time)
                    continue

                # Transient server / network errors.
                # NOTE(review): other 4xx APIError subclasses (e.g. a
                # not-found model) also land here and are retried.
                except (openai.APIError, openai.APIConnectionError) as e:
                    num_retries += 1
                    if num_retries > max_retries:
                        raise RuntimeError(f"Max retries exceeded: {e}") from e

                    # exponential backoff + jitter
                    sleep_time = delay
                    if jitter:
                        sleep_time *= 1 + random.random()
                    sleep_time = min(sleep_time, max_delay)

                    print(
                        f"[Retry {num_retries}/{max_retries}] "
                        f"{type(e).__name__}: {e} → sleeping {sleep_time:.1f}s"
                    )
                    time.sleep(sleep_time)

                    delay *= exponential_base
                    continue

                # Anything else is surfaced immediately, keeping the cause.
                except Exception as e:
                    raise RuntimeError(f"[UnknownError] {e}") from e

        return wrapper

    return decorator


@retry_with_backoff_and_rate_limit_split(
    initial_delay=2.0,
    exponential_base=2.0,
    jitter=True,
    max_retries=20,
    max_delay=60.0,
)
def call_gpt5_evaluator(prompt: str, model: str = "gpt-5-mini") -> str:
    """
    Send one evaluation prompt to a GPT-5-family model and return its text reply.

    The request uses SYSTEM_PROMPT as the system message, ``prompt`` as the
    user message and reasoning effort "low". Retry behaviour comes from the
    decorator (up to 20 retries, failures raised as RuntimeError):
    - 429 rate limit: wait 60 s, then retry
    - server / network errors: exponential backoff with jitter
    - BadRequest / authentication / permission errors: fail immediately
    """
    conversation = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    reply = client.responses.create(
        model=model,
        input=conversation,
        reasoning={"effort": "low"},
    )
    return reply.output_text

# def call_gpt5_evaluator(prompt: str, model: str = "gpt-5.1") -> str:
#     """
 
 
#     """
#     resp = client.responses.create(
#         model=model,
#         input=[
#             {"role": "system", "content": SYSTEM_PROMPT},
#             {"role": "user", "content": prompt},
#         ],
#     )
 
 
#     evaluation_text = resp.output[0].content[0].text
#     return evaluation_text

## time based retry

# def call_gpt5_evaluator(
#     prompt: str,
#     model: str = "gpt-5.1-mini",
#     max_retries: int = 30,
#     base_delay: float = 5.0,
#     max_delay: float = 60.0,
#     verbose: bool = True
# ) -> str:
#     """
 
 
#     - ServerError, ConnectionError → Exponential Backoff
 
 
 
#     """

#     attempt = 0
#     while attempt < max_retries:
#         attempt += 1
#         try:
#             resp = client.responses.create(
#                 model=model,
#                 input=[
#                     {"role": "system", "content": SYSTEM_PROMPT},
#                     {"role": "user", "content": prompt},
#                 ],
#             )
#             return resp.output[0].content[0].text

 
#         except openai.RateLimitError as e:
#             if verbose:
#                 print(f"[429] Rate limit. Sleeping 60 seconds... (attempt {attempt}/{max_retries})")
#             time.sleep(60)
#             continue

 
#         except openai.APIError as e:
#             if verbose:
#                 print(f"[APIError] {e}. Exponential backoff... (attempt {attempt}/{max_retries})")
#             # exponential backoff delay
#             delay = min(max_delay, base_delay * (2 ** (attempt - 1)))
#             delay = delay * (0.8 + 0.4 * random.random())  # jitter
#             time.sleep(delay)
#             continue

#         # ===== Network Error =====
#         except openai.APIConnectionError as e:
#             if verbose:
#                 print(f"[APIConnectionError] {e}. Backoff... (attempt {attempt}/{max_retries})")
#             delay = min(max_delay, base_delay * (2 ** (attempt - 1)))
#             delay = delay * (0.8 + 0.4 * random.random())
#             time.sleep(delay)
#             continue

#         # ===== Invalid Request / Authentication / Permission =====
#         except openai.AuthenticationError as e:
#             raise RuntimeError(f"[AuthError] Check API key. {e}")

#         except openai.PermissionDeniedError as e:
#             raise RuntimeError(f"[PermissionError] {e}")

#         except openai.BadRequestError as e:
 
#             raise RuntimeError(f"[BadRequestError] This request will never succeed: {e}")

#         # ===== Unknown Error =====
#         except Exception as e:
#             if verbose:
#                 print(f"[Unknown Error] {e}. Backoff... (attempt {attempt}/{max_retries})")
#             delay = min(max_delay, base_delay * (2 ** (attempt - 1)))
#             delay = delay * (0.8 + 0.4 * random.random())
#             time.sleep(delay)
#             continue

 
#     raise RuntimeError(f"Failed after {max_retries} attempts.")

 


def parse_ratings_from_output(evaluation_text: str, num_texts: int):
    """
    Conservatively extract one 1-5 rating per text from an evaluator reply.

    The reply is split into "#### Output for Text N" blocks; blocks whose N is
    outside 1..num_texts are ignored. Inside a block only explicit forms are
    accepted, e.g. "Rating: 4", "**Rating:** 4", "**Rating**: 4",
    "Rating: [4]", "Score = 5", "Rating for text 2: 3", "Text 2 rating = 3",
    "Rating: 4/5". A line that *starts* with "Rating"/"Score" is preferred
    over a rating-like phrase elsewhere in the block, so e.g. "confidence
    score 80" inside an earlier rationale does not shadow the real rating.

    Numeric values are clamped to [1, 5] and rounded half up (4.5 -> 5) to an
    int. A "/denominator" suffix is ignored. "N/A"/"NA" yields None.

    Returns:
        A list of length ``num_texts``. Entry i is the rating of "Text i+1",
        or None when it is missing, N/A, or unparseable. If the same block
        number appears more than once, the last block that contains a rating
        wins.
    """
    ratings: list[int | None] = [None] * num_texts
    if not evaluation_text:
        return ratings

    text = evaluation_text.replace("\r\n", "\n").replace("\r", "\n")

    # One match per "#### Output for Text N" section, up to the next one / end.
    block_pattern = re.compile(
        r"####\s*Output\s*for\s*Text\s*(\d+)\s*(.*?)(?=####\s*Output\s*for\s*Text\s*\d+|\Z)",
        flags=re.IGNORECASE | re.DOTALL,
    )

    # Between keyword and value: whitespace, ":", "=", "-", "–", ">", plus
    # markdown emphasis ("**", "__") and an opening "[" / "(".
    separator = r"[\s*_:=\-–>\[(]*"
    # 4, 4.0, 4.5, 4/5, N/A, NA
    value = r"([0-9]+(?:\.[0-9]+)?(?:\s*/\s*[0-9]+)?|N/?A)"

    line_rating_pattern = re.compile(
        r"^[\s*_#>\-]*(?:rating|score)(?:\s*for\s*text\s*\d+)?" + separator + value,
        flags=re.IGNORECASE | re.MULTILINE,
    )
    inline_rating_pattern = re.compile(
        r"(?:\b(?:rating|score)(?:\s*for\s*text\s*\d+)?|\btext\s*\d+\s*(?:rating|score))"
        + separator
        + value,
        flags=re.IGNORECASE,
    )

    for block in block_pattern.finditer(text):
        text_idx = int(block.group(1))  # \d+ always parses
        if not 1 <= text_idx <= num_texts:
            continue

        body = block.group(2)
        match = line_rating_pattern.search(body) or inline_rating_pattern.search(body)
        if match is None:
            continue

        raw = match.group(1).strip().upper()
        if raw in ("N/A", "NA"):
            ratings[text_idx - 1] = None
            continue

        try:
            # "4/5" -> "4"; a plain number is kept as-is.
            score = float(raw.split("/", 1)[0])
        except ValueError:
            continue

        score = max(1.0, min(5.0, score))
        # Round half up; built-in round() would send 4.5 to 4 (banker's rounding).
        ratings[text_idx - 1] = math.floor(score + 0.5)

    return ratings



def extract_user_input(text: str) -> str:
    """
    Return the first ChatML user turn of ``text``, stripped of surrounding whitespace.

    The turn is the content between "<|im_start|>user" and the next
    "<|im_end|>" (which may span lines). If no user turn is found, ``text`` is
    returned unchanged.
    """
    user_turn = re.search(
        r"<\|im_start\|\>user\s*(.*?)<\|im_end\|\>",
        text,
        flags=re.DOTALL,
    )
    return user_turn.group(1).strip() if user_turn else text

 

def _rate_entry(instruction: str, responses: dict, model: str) -> tuple[dict, dict]:
    """
    Rate every response of one input entry on all ASPECT_ORDER aspects.

    Makes one judge call per aspect. Returns ``(ratings_per_model,
    mean_ratings)``. ``ratings_per_model[name][i]`` is the rating for
    ``ASPECT_ORDER[i]`` (None when it could not be parsed).
    ``mean_ratings[name]`` is the mean of the numeric ratings, or None if
    there are none.
    """
    model_names = list(responses.keys())
    ratings_per_model = {name: [None] * len(ASPECT_ORDER) for name in model_names}

    for aspect_idx, aspect_key in enumerate(ASPECT_ORDER):
        prompt = build_prompt_for_entry(
            aspect_key=aspect_key,
            instruction=instruction,
            response_dict=responses,
        )
        evaluation_text = call_gpt5_evaluator(prompt, model=model)
        aspect_ratings = parse_ratings_from_output(
            evaluation_text,
            num_texts=len(model_names),
        )
        # Text i of the prompt is the i-th response in dict order, so the
        # parsed ratings line up with model_names.
        for name, rating in zip(model_names, aspect_ratings):
            ratings_per_model[name][aspect_idx] = rating

    mean_ratings = {
        name: mean_ignore_none(scores) for name, scores in ratings_per_model.items()
    }
    return ratings_per_model, mean_ratings


def evaluate_jsonl_to_combined_ratings(
    input_path: str,
    output_path: str,
    model: str = "gpt-5-mini",
    max_count: int | None = None,
    resume: bool = True,
    start_idx: int = 0,
    end_idx: int | None = None,
    processed_ids: set | None = None,
):
    """
    Rate each entry of an input JSONL on every aspect and write one JSONL line per entry.

    Args:
        input_path: input JSONL; each line is {"prompt_id", "prompt", "response": {...}}.
        output_path: output JSONL; each line is {"prompt_id", "prompt",
            "response", "ratings", "mean_ratings", "annotator"}.
        model: judge model name; also recorded as "annotator".
        max_count: if given, only the first ``max_count`` input lines are considered.
        resume: when True, entries whose prompt_id is in ``processed_ids`` are
            skipped and results are appended. If ``processed_ids`` is None, the
            ids are read from ``output_path``. When False, ``output_path`` is
            overwritten and nothing is skipped.
        start_idx, end_idx: 0-based half-open line range [start_idx, end_idx)
            to process (used by the multiprocessing driver).
        processed_ids: prompt_ids to skip, shared by run_multiprocess_eval.

    Each result line is flushed as soon as it is written, so a crash loses at
    most the entry in progress and a resumed run can continue from there.

    Raises:
        RuntimeError: propagated from call_gpt5_evaluator when an API call
            ultimately fails.
        json.JSONDecodeError, KeyError: for malformed input lines.
    """
    input_path = Path(input_path)
    output_path = Path(output_path)
    if resume:
        if processed_ids is None:
            processed_ids = load_processed_prompt_ids(str(output_path))
            print(f"[INFO] [Worker-local] Resuming from {output_path}, "
                  f"{len(processed_ids)} prompt_ids found.")
        else:
            print(f"[INFO] [Worker] Using shared processed_ids ({len(processed_ids)})")
        write_mode = "a" if output_path.exists() else "w"
    else:
        processed_ids = set()
        write_mode = "w"

    # Effective half-open line range [first, stop) after applying both
    # end_idx and max_count. A negative start means "from the beginning".
    first = max(0, start_idx)
    stop = end_idx
    if max_count is not None:
        stop = max_count if stop is None else min(stop, max_count)
    if stop is not None:
        stop = max(first, stop)  # an inverted range is simply empty
    total = None if stop is None else stop - first

    with (
        input_path.open("r", encoding="utf-8") as fin,
        output_path.open(write_mode, encoding="utf-8") as fout,
    ):
        selected_lines = itertools.islice(fin, first, stop)
        with tqdm(
            selected_lines,
            desc=f"Evaluating JSONL [{start_idx}:{end_idx}]",
            total=total,
        ) as progress:
            for line in progress:
                line = line.strip()
                if not line:
                    continue
                data = json.loads(line)

                prompt_id = data["prompt_id"]
                # Checked before any other field is touched: already-rated
                # entries cost nothing.
                if prompt_id in processed_ids:
                    continue

                instruction = extract_user_input(data["prompt"])
                responses = data["response"]  # {"modelA": "...", "modelB": "...", ...}

                ratings_per_model, mean_ratings = _rate_entry(instruction, responses, model)

                out_obj = {
                    "prompt_id": prompt_id,
                    "prompt": instruction,
                    "response": responses,
                    "ratings": ratings_per_model,
                    "mean_ratings": mean_ratings,
                    "annotator": model,
                }

                fout.write(json.dumps(out_obj, ensure_ascii=False) + "\n")
                # Each line costs several API calls; persist it immediately.
                fout.flush()

if __name__ == "__main__":
    # Script entry point; both paths are relative to the working directory.
    input_path = "RePO_datasets/Ultrafeedback/merged_model_outputs/merged_all_models.jsonl"
    output_path = "RePO_datasets/Ultrafeedback/model_output_with_rating/eval_combined.jsonl"

    NUM_WORKERS = 16

    # Rate at most the first 30000 input lines with gpt-5-mini, skipping
    # prompt_ids already present in existing part files.
    eval_settings = {
        "model": "gpt-5-mini",
        "num_workers": NUM_WORKERS,
        "max_count": 30000,
        "resume": True,
    }
    run_multiprocess_eval(
        input_path=input_path,
        output_path=output_path,
        **eval_settings,
    )

