import json
import re
import argparse
from typing import Any, List, Dict, DefaultDict
from reward import RewardArguments, get_reward_funcs
from collections import defaultdict
import importlib
import asyncio
from tqdm.asyncio import tqdm
import sys
from rich import print

# ================== Settings ================== 
# System prompt sent with every judge request: it describes the Group Think
# setting (a user question worked on by N thinkers) and assigns the model the
# role of judge.
orchestrator_system_prompt = """
A User provided a question to which he wants to receive a correct answer. To provide the answer there was a Group Think session. In this session there were several participants - a set of N Thinkers: 'Thinker 1', 'Thinker 2', .., 'Thinker N'. 
The goal of the Thinkers is to work collaboratively in order to reach the correct answer in the most efficient and effective way possible. 

You are a judge to grade the responses from a thinker.
"""
# Grading rubric appended to every judge request. The score tags it asks for
# (<golden_score>, <collaborative>, <referenceness>, <repetition>) are parsed
# back out of the reply by tag name, so their spelling must stay in sync with
# the metric names used when scores are extracted.
orchestrator_judge_instruction = """
You are given thinking traces of thinkers in a group think session. 

You are expected to analyze the trace of a particular thinker. You must follow the guideline below when you carry out the judge task
Guideline for judge task:
    1. You will be asked to grade different parts of the traces and provide score for different dimensions. 
    
    2. You should compare the answer from the response trace against the golden answer. If the answer from the response trace matches the golden answer, give a score of 1, otherwise 0. If the answer does not exist in the response trace, give a score of -1. Output score must be in the <golden_score>your_score</golden_score> bracket format.

    3. You should compare the answer from the response trace against the reference answer. If the answer from the response trace matches the reference answer, give a score of 1, otherwise 0. If the answer does not exist in the response trace, give a score of -1. Output score must be in the <ref_score>your_score</ref_score> bracket format.
    
    4. Evidences of a collaborative group think session: elaborate examples within the current trace of the thinker in participating and co-working with other thinkers.
    a. Inner voice and adapted inner voice of the thinker are not admissible as evidences in supporting that a thinker proactively participates in a collaborative group think session.
    b. Study the trace of the thinker. Valid behaviour examples can be role specification in the session, role changes in the session, collaborative behaviour in leveraging, directing, and complementing other thinker's responses.
    c. Put each piece of evidence in <evidence> ... </evidence> bracket. You can provide more than 1 piece of evidence. You must include snippet of the thinker's response trace as proof to backup your evidence analysis. Be succinct. 
    d. When you look for evidences, you must focus on the response trace of the thinker.
    e. Each piece of evidence must follow one of the example collaborative behaviour below. If the evidence does not fall under the covered examples, you must defend the evidence as to why it is a collaborative behaviour. The example behaviours are:
        - Example 1: citing or giving credit to other thinker's work
        - Example 2: taking on a path different from other thinkers
        - Example 3: participating or directing in divide-and-conquer of a task
        - Example 4: pointing out complementary points to other thinkers
        - Example 5: giving command or instructions to other thinkers
        - Example 6: following command or instruction from other thinkers
        - Example 7: verifying other thinker's work

    5. Collaborativeness: Examine the evidences from your analysis of thinker collaboration, grade how collaborative a thinker is in the group think session. If the thinker shows no collaborative behaviour, give 0. Each positive evidence of collaboration contributes 1 score. Max accumulated score is 10. Provide the final score in <collaborative>number</collaborative> format.

    6. Reference-ness: Examine the evidences from your analysis of thinker collaboration, grade how much the thinker references other thinkers' responses in the group think session to construct its response. If the thinker doesn't reference others, give 0. If the thinker references other thinkers once, give 1; references other thinkers twice, give 2; ... and so on. Max score is 10. Provide the final score in <referenceness>number</referenceness>

    7. Repetitiveness: how repetitive the current trace of the thinker is. If the current response trace of the thinker consists of no repeated phrases, the score is 0. Each repeated phrase counts toward the score, i.e., one repeated phrase, the score is 1; three repeated phrases, the score is 3; ... and so on. The accumulated max score is 10. The score should be given in the format <repetition>score</repetition>.
"""
# Completion-token budget handed to the judge query (passed as max_new_tokens).
orchestrator_max_tokens = 1024
# Module-level cache for the judge model; stays None until
# get_orchestrator_model() builds the instance on first use.
_ORCHESTRATOR_MODEL = None

def get_orchestrator_model():
    """Return the shared orchestrator (judge) model, creating it on first call.

    The factory module is resolved through importlib at call time instead of
    a top-level import, and the created instance is stored in the module-level
    _ORCHESTRATOR_MODEL so that later calls reuse the same object.
    """
    global _ORCHESTRATOR_MODEL
    if _ORCHESTRATOR_MODEL is not None:
        # Already built by an earlier call.
        return _ORCHESTRATOR_MODEL

    factory_module = importlib.import_module("models.model_factory")
    factory = factory_module.ModelFactory
    model_settings = {
        "api_config": "src/model/api_config.json",
        "temperature": 0.7,
        "reasoning": {
            "max_tokens": 2000,
            "enabled": True,
        },
    }
    _ORCHESTRATOR_MODEL = factory.create_model("OR-llama-4-scout", model_settings)
    return _ORCHESTRATOR_MODEL
# ============================================== 

# Lightweight (non-LLM) scorers, built once at import time from the project's
# reward module. Used below as a mapping of scorer name -> callable taking
# `completions` and `solution` keyword arguments.
evaluators = get_reward_funcs(RewardArguments(reward_funcs=["correctness","math_accuracy", "repetition_penalty"]))
# Running totals of in_tokens / out_tokens / time_taken, keyed first by stage
# name and then by statistic name; updated by _query_orchestrator.
state_stage_costs: DefaultDict[str, DefaultDict[str, Any]] = defaultdict(lambda: defaultdict(float))


def extract_text_from_model_response(response: Any) -> str:
    """Pull the text payload out of a model response, whatever its shape.

    A dict with a 'text' key, or a non-empty list whose first element is such
    a dict, yields that 'text' value as a string. Every other input is simply
    converted with str().
    """
    if isinstance(response, list):
        # Only the first element of a list response is inspected.
        holder = response[0] if response else None
    elif isinstance(response, dict):
        holder = response
    else:
        holder = None

    if isinstance(holder, dict) and 'text' in holder:
        return str(holder['text'])
    return str(response)

def extract_judge_score(payload: str, default: int, tag: str) -> int:
    """Return the last integer found inside <tag>...</tag> brackets.

    Args:
        payload: Text to search (typically the judge's reply).
        default: Value returned when no matching tag pair is present.
        tag: Tag name to look for. It is matched literally, so regex
            metacharacters in the name have no special meaning.

    Returns:
        The integer from the last ``<tag>int</tag>`` occurrence (optional
        sign, optional surrounding whitespace), or ``default`` if none match.
    """
    # Escape the tag so that names such as "a.b" or "a+b" are matched as
    # written instead of being interpreted as regex syntax.
    literal_tag = re.escape(tag)
    pattern = rf"<{literal_tag}>\s*([+-]?\d+)\s*</{literal_tag}>"
    found = re.findall(pattern, payload)
    if not found:
        return default
    return int(found[-1])

async def _query_orchestrator(messages: List[Dict[str, str]],
                              stage_name: str = "",
                              max_new_tokens: int = 1024) -> str:
    """Send messages to the orchestrator model and track token/cost stats.

    The query is repeated (at most 10 attempts) while the reported
    ``out_tokens`` exceeds ``max_new_tokens``. The response of the final
    attempt is returned even if it is still over the budget.

    Args:
        messages: Chat messages to send to the model.
        stage_name: Key under which the stats are accumulated in
            ``state_stage_costs``.
        max_new_tokens: Completion-token budget, forwarded as
            ``max_completion_tokens``.

    Returns:
        The text extracted from the final response.
    """
    max_attempts = 10
    resp = None
    stat = None
    for attempts_left in range(max_attempts, 0, -1):
        resp, stat = await get_orchestrator_model().query_response(
            messages,
            max_completion_tokens=max_new_tokens,
        )
        if stat is None:
            # Without stats the output length cannot be checked; accept the
            # response instead of failing on the lookup below.
            break
        out_tokens = stat['out_tokens']
        if out_tokens <= max_new_tokens:
            break
        print(f"({attempts_left}) Number of output tokens larger than specified - {out_tokens} > {max_new_tokens} - from {stat['provider']}.")

    # Only the final attempt's stats are recorded; earlier over-budget
    # attempts are not added to the totals.
    if stat is not None:
        for stat_name in ('in_tokens', 'out_tokens', 'time_taken'):
            state_stage_costs[stage_name][stat_name] += float(stat[stat_name])
    return extract_text_from_model_response(resp)


def construct_judge_input_messages(question: str,
                                   ref: str,
                                   gold: str,
                                   system_prompt: str,
                                   judge_instruction: str,
                                   num_thinkers: int,
                                   thinker_id: int,
                                   thinker_trace: str,
                                   thinker_inner_voice: str | None = None,
                                   thinker_adapt_inner_voice: str | None = None) -> List[Dict[str, str]]:
    """Build the message list for the orchestrator-as-judge prompt."""
    content = (
            f"Judge to judge thinker {thinker_id}\n"
            f"<question>\n{question}</question>\n<reference_answer>\n{ref}\n</reference_answer>\n<golden_answer>\n{gold}\n</golden_answer>\n\n"
            f"There are <number_of_thinkers>\n{num_thinkers}\n</number_of_thinkers> thinkers in this session.\n\n"
    )
    if thinker_inner_voice is not None:
        content += f"Here is what thinker {thinker_id}'s inner voice:\n<inner_voice>\n{thinker_inner_voice}\n</inner_voice>\n\n"
    if thinker_adapt_inner_voice is not None:
        content += f"Here is what thinker {thinker_id}'s adapted inner voice:\n<adapted_inner_voice>\n{thinker_adapt_inner_voice}\n</adapted_inner_voice>\n\n"

    content += (
            f"Here is what thinker {thinker_id} have previously responded:\n<trace_of_thinker_{thinker_id}>\n{thinker_trace}\n</trace_of_thinker_{thinker_id}>\n\n"
            f"{judge_instruction}\n\n"
            f"Now, you shall analyze, judge, and grade thinker {thinker_id}:\n"
        )
    input_msgs = [
        dict(role="system", content=system_prompt),
        dict(role="user", content=content),
    ]
    return input_msgs

async def _step_orche_judge_gt_local(question: str, ref: str, gold: str, group_traces: Dict[str, str], local_state_judge_traces: DefaultDict[int, Dict[int, str]]) -> None:
    """Ask the orchestrator judge to grade every thinker's trace, one by one.

    Uses caller-provided local state (rather than module globals) to avoid
    race conditions between concurrently processed records.

    Args:
        question: The user question of the session.
        ref: Reference answer shown to the judge.
        gold: Golden answer shown to the judge.
        group_traces: Thinker traces keyed by the stringified zero-based
            thinker index ("0", "1", ...); the lookup below relies on that.
        local_state_judge_traces: Output mapping; the judge reply for thinker
            ``idx`` is stored at ``[idx][0]``.
    """
    num_thinkers = len(group_traces)
    for idx in range(num_thinkers):
        input_msgs = construct_judge_input_messages(
            question=question,
            ref=ref,
            gold=gold,
            system_prompt=orchestrator_system_prompt,
            judge_instruction=orchestrator_judge_instruction,
            num_thinkers=num_thinkers,
            thinker_id=idx + 1,  # thinkers are numbered from 1 in the prompt
            thinker_trace=group_traces[str(idx)],
        )
        reply = await _query_orchestrator(input_msgs, stage_name="step_orche_judge", max_new_tokens=orchestrator_max_tokens)
        local_state_judge_traces[idx][0] = reply


async def _step_evaluate_trace_local(gold: str, group_traces: Dict[str, str], skip_judge: bool, local_state_judge_traces: DefaultDict[int, Dict[int, str]], local_state_evaluations: DefaultDict[str, DefaultDict[int, List[float]]]) -> None:
    """Score every thinker trace and record the results in local state.

    Uses caller-provided local state (rather than module globals) to avoid
    race conditions between concurrently processed records.

    Args:
        gold: Golden answer, used as the solution for every thinker.
        group_traces: Thinker traces keyed by the stringified zero-based
            thinker index ("0", "1", ...).
        skip_judge: When True, the judge metrics are recorded as None.
        local_state_judge_traces: Judge replies; the reply for thinker ``idx``
            is read from ``[idx][0]`` when present.
        local_state_evaluations: Output mapping. Receives one single-element
            list per thinker under each scorer name and under
            ``judge_<metric>`` for each judge metric.
    """
    num_thinkers = len(group_traces)

    # Lightweight scorers (correctness, repetition, ...) run on all traces.
    completions: List[List[Dict[str, str]]] = [
        [{"content": group_traces[str(idx)]}] for idx in range(num_thinkers)
    ]
    solutions = [gold] * num_thinkers
    results: Dict[str, List[float]] = {
        eval_name: func(completions=completions, solution=solutions)
        for eval_name, func in evaluators.items()
    }
    for eval_name, eval_results in results.items():
        for idx, eval_score in enumerate(eval_results):
            local_state_evaluations[eval_name][idx] = [eval_score]

    # Judge metrics: parsed from the judge reply, or None when the judge step
    # was skipped or no reply was stored for the thinker.
    metric_names = ("golden_score", "collaborative", "referenceness", "repetition")
    for idx in range(num_thinkers):
        judge_traces = None if skip_judge else local_state_judge_traces[idx]
        for metric_name in metric_names:
            score: int | None = None
            if judge_traces is not None and 0 in judge_traces:
                # A reply that lacks the tag yields the default of -1.
                score = extract_judge_score(judge_traces[0], default=-1, tag=metric_name)
            local_state_evaluations[f"judge_{metric_name}"][idx] = [score]

async def _step_evaluate_predictions(gold: str, predictions, local_state_evaluations: DefaultDict[str, DefaultDict[int, List[float]]]) -> None:
    """Score a record's predictions with the correctness and math_accuracy scorers.

    ``predictions`` may be a single string (1-shot), a list (k-shot), or any
    other value, which is stringified and treated as a single prediction.
    Non-empty scorer outputs are stored together, keyed by scorer name, under
    ``local_state_evaluations["predictions_evaluation"]``; nothing is stored
    when every output is empty.
    """
    # Normalise the input into a list of candidate answers.
    if isinstance(predictions, list):
        candidates = predictions
    elif isinstance(predictions, str):
        candidates = [predictions]
    else:
        candidates = [str(predictions)]

    completions = [[{"content": candidate}] for candidate in candidates]
    solutions = [gold for _ in candidates]

    # Run only the two answer-quality scorers, in the order they are registered.
    wanted = ("correctness", "math_accuracy")
    raw_scores: Dict[str, List[float]] = {}
    for name, scorer in evaluators.items():
        if name in wanted:
            raw_scores[name] = scorer(completions=completions, solution=solutions)

    # Drop scorers that produced no output, then store what is left.
    kept = {name: values for name, values in raw_scores.items() if values}
    if kept:
        local_state_evaluations["predictions_evaluation"] = kept

# convert defaultdict to dict and key to string
def clean(d: Any) -> Any:
    """Return a JSON-friendly copy of ``d``.

    Dicts (including defaultdicts) become plain dicts with stringified keys,
    lists are rebuilt element by element, and both are processed recursively.
    Any other value is returned unchanged.
    """
    if isinstance(d, dict):  # defaultdict is a dict subclass
        plain = {}
        for key, value in d.items():
            plain[str(key)] = clean(value)
        return plain
    if isinstance(d, list):
        return list(map(clean, d))
    return d

def _load_jsonl(path: str) -> List[Dict[str, Any]]:
    """Read a JSONL file and return its records in file order.

    Blank and whitespace-only lines are skipped, so stray empty lines do not
    abort the load.

    Args:
        path: Path to a UTF-8 encoded JSONL file.

    Returns:
        One parsed JSON value per non-blank line.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records: List[Dict[str, Any]] = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            records.append(json.loads(line))
    return records

def _append_jsonl_line(filepath: str, data: dict) -> None:
    """Serialize ``data`` as one JSON line and append it to ``filepath``.

    Non-ASCII characters are written as-is (UTF-8) rather than escaped.
    """
    with open(filepath, mode="a", encoding="utf-8") as sink:
        serialized = json.dumps(data, ensure_ascii=False)
        sink.write(f"{serialized}\n")
        # Flush right away so the record is handed to the OS without waiting
        # for the file to be closed.
        sink.flush()

def parse_args() -> argparse.Namespace:
    """Parse command line arguments.

    Options:
        --input: Path to the input JSONL file.
        --output: Path to the output JSONL file.
        --skip-judge: Skip the orchestrator judge step.
        --max-concurrency: Maximum number of concurrent evaluation tasks;
            must be at least 1.

    Returns:
        The parsed namespace (``input``, ``output``, ``skip_judge``,
        ``max_concurrency``).
    """
    def positive_int(raw: str) -> int:
        # A concurrency of 0 would create a semaphore nobody can acquire
        # (the run would hang), and a negative one is rejected by asyncio,
        # so refuse both at parse time with a clear message.
        try:
            value = int(raw)
        except ValueError:
            raise argparse.ArgumentTypeError(f"invalid int value: {raw!r}") from None
        if value < 1:
            raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
        return value

    parser = argparse.ArgumentParser(description="Evaluate group think data")
    parser.add_argument(
        "--input", 
        type=str, 
        default="./group_think_data.jsonl",
        help="Path to input JSONL file (default: ./group_think_data.jsonl)"
    )
    parser.add_argument(
        "--output", 
        type=str, 
        default="./output.jsonl",
        help="Path to output JSONL file (default: ./output.jsonl)"
    )
    parser.add_argument(
        "--skip-judge",
        action="store_true",
        help="Skip the orchestrator judge step (orche_judge_gt)"
    )
    parser.add_argument(
        "--max-concurrency",
        type=positive_int,
        default=16,
        help="Maximum number of concurrent evaluation tasks (default: 16)"
    )
    return parser.parse_args()

async def process_single_record(rec: Dict[str, Any], args: argparse.Namespace, semaphore: asyncio.Semaphore, pbar: tqdm) -> None:
    """Evaluate one record under the concurrency limit and append it to the output.

    The record is judged (unless ``args.skip_judge`` is set), scored, and, if
    it carries a 'predictions' field, has those predictions scored as well.
    The cleaned results are stored on the record under "evaluations" and the
    record is appended to ``args.output``. The progress bar advances by one
    whether or not processing succeeded.
    """
    async with semaphore:
        try:
            # Per-record state, so concurrent records never share containers.
            judge_traces: DefaultDict[int, Dict[int, str]] = defaultdict(dict)
            evaluations: DefaultDict[str, DefaultDict[int, List[float]]] = defaultdict(lambda: defaultdict(list))

            question, answer = rec['question'], rec['answer']
            group_traces: Dict[str, str] = rec['group_traces']

            judge_skipped = args.skip_judge
            if not judge_skipped:
                # The record's answer is passed as both the reference and the golden answer.
                await _step_orche_judge_gt_local(
                    question=question,
                    ref=answer,
                    gold=answer,
                    group_traces=group_traces,
                    local_state_judge_traces=judge_traces,
                )

            await _step_evaluate_trace_local(
                gold=answer,
                group_traces=group_traces,
                skip_judge=judge_skipped,
                local_state_judge_traces=judge_traces,
                local_state_evaluations=evaluations,
            )

            if 'predictions' in rec:
                await _step_evaluate_predictions(
                    gold=answer,
                    predictions=rec['predictions'],
                    local_state_evaluations=evaluations,
                )

            rec["evaluations"] = clean(evaluations)

            # Persist immediately so finished records survive a later failure.
            _append_jsonl_line(args.output, rec)
        finally:
            pbar.update(1)

async def main() -> None:
    """Load the input records and evaluate them concurrently.

    Concurrency is capped by a semaphore sized from ``--max-concurrency``.
    Progress is reported on a single bar that each record advances when it
    finishes; the bar is closed even if evaluation fails.
    """
    args = parse_args()
    data = _load_jsonl(args.input)
    print(f"num_data: {len(data)}")
    print(f"max_concurrency: {args.max_concurrency}")

    # Create semaphore to control concurrency
    semaphore = asyncio.Semaphore(args.max_concurrency)

    # One progress bar only: every record already updates `pbar` itself, so
    # the tasks are awaited with plain asyncio.gather rather than a gather
    # variant that would draw a second bar for the same work.
    with tqdm(total=len(data), desc="Evaluating", unit="record",
              bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]") as pbar:
        tasks = [process_single_record(rec, args, semaphore, pbar) for rec in data]
        await asyncio.gather(*tasks)

# Script entry point: run the async pipeline only when executed directly,
# not when the module is imported.
if __name__ == "__main__":
    asyncio.run(main())
