import argparse
import json
import os
from pathlib import Path
import sys
import re
import time
from typing import List, Tuple, Optional, Dict, Any

# Ensure CUDA arch for some mac/linux setups (mirrors utils/inference.py behavior).
# setdefault only supplies "9.0" when TORCH_CUDA_ARCH_LIST is not already present in the
# environment, so an explicit value exported by the caller is left untouched.
os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "9.0")

# Heavy deps are lazily imported inside functions to speed up module import and avoid environment issues


# Add project root to sys.path so we can import from src.
# current_dir is the directory containing this file; project_root is its parent directory.
current_dir = Path(__file__).parent
project_root = current_dir.parent
# Both directories are prepended; current_dir is inserted last, so it ends up first on sys.path.
sys.path.insert(0, str(project_root))
sys.path.insert(0, str(current_dir))

from src.templates import SYSTEM_PROMPT, modify_user_message_for_reasoning  # noqa: E402
from types import SimpleNamespace

# Base/eval modules are imported lazily where needed


def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding the model, dataset, vLLM and sampling options.
    """
    cli = argparse.ArgumentParser(description="Motivation-style inference: inject reference answer to guide reasoning.")

    # One row per value-carrying option: (flag, type, default, help text).
    valued_options = (
        ("--model_name", str, "../modelscope/Qwen/Qwen3-4B-Thinking-2507",
         "Model name or path"),
        ("--dataset_name", str, "gpqa_diamond_Avg4",
         "Dataset base name (without extension)"),
        ("--dataset_path", str, "../datasets/rlpr/test",
         "Path to the dataset directory containing <dataset_name>.jsonl"),
        ("--tensor_parallel_size", int, 4,
         "Tensor parallel size for vLLM"),
        ("--gpu_memory_utilization", float, 0.95,
         "GPU memory utilization ratio for vLLM"),
        ("--temperature", float, 0.6,
         "Sampling temperature"),
        ("--top_p", float, 0.95,
         "Top-p sampling parameter"),
        ("--top_k", int, 20,
         "Top-k sampling parameter"),
        ("--n_generations", int, 4,
         "Number of generations per prompt"),
        ("--max_tokens", int, 32000,
         "Maximum new tokens to generate"),
        ("--num_samples", int, None,
         "Limit the number of samples to run (None = all)"),
        ("--output_file", str, None,
         "Optional output JSON file path. If omitted, an auto path is used."),
    )
    for flag, value_type, fallback, description in valued_options:
        cli.add_argument(flag, type=value_type, default=fallback, help=description)

    # The only boolean switch; registered last so the --help ordering is unchanged.
    cli.add_argument(
        "--use_system_prompt",
        action="store_true",
        default=False,
        help="Whether to include a SYSTEM message with SYSTEM_PROMPT.",
    )

    return cli.parse_args()


def _auto_output_path(model_name: str, dataset_name: str) -> str:
    """Return the default results path for a model/dataset pair.

    The file lives under ``<project_root>/inference_results/<model basename>/`` and the
    directory is created (including parents) as a side effect.
    """
    results_dir = project_root / "inference_results" / Path(model_name).name
    results_dir.mkdir(parents=True, exist_ok=True)
    target = results_dir / f"{dataset_name}_motivation.json"
    return str(target)


def _apply_reasoning_prompt(sample: Dict[str, Any]) -> Dict[str, Any]:
    """Return ``sample`` with the reference answer injected by the template helper.

    Any exception raised by ``modify_user_message_for_reasoning`` is reported on stdout and
    the untouched input sample is returned instead, so one malformed record does not abort
    the whole run.
    """
    try:
        rewritten = modify_user_message_for_reasoning(sample)
    except Exception as e:
        # Best effort: a structurally odd sample is passed through unchanged.
        print(f"[Warn] modify_user_message_for_reasoning failed, using original sample. Err: {e}")
        return sample
    else:
        return rewritten


def _prepare_chat_messages(sample: Dict[str, Any], use_system_prompt: bool) -> List[Dict[str, str]]:
    msgs = list(sample.get("prompt", []))
    if use_system_prompt:
        if len(msgs) > 0 and msgs[0].get("role") == "system":
            msgs[0] = {"role": "system", "content": SYSTEM_PROMPT}
        else:
            msgs.insert(0, {"role": "system", "content": SYSTEM_PROMPT})
    else:
        # Some datasets include system + user; if not using system, keep only user content
        msgs = [m for m in msgs if m.get("role") == "user"]
    return msgs


def load_and_prepare_dataset(
    dataset_path: str,
    dataset_name: str,
    use_system_prompt: bool = False,
    num_samples: Optional[int] = None,
) -> List[Tuple[List[Dict[str, str]], Any]]:
    """Load dataset and build chats with the reference-answer-injected user message.

    Args:
        dataset_path: Directory containing ``<dataset_name>.jsonl``.
        dataset_name: Dataset base name (without extension).
        use_system_prompt: Forwarded to ``_prepare_chat_messages``.
        num_samples: Maximum number of records to load; ``None`` loads everything and a
            value <= 0 loads nothing.

    Returns:
        A list of tuples: (chat_messages, ground_truth). ``ground_truth`` is read from the
        ORIGINAL record's ``reward_model.ground_truth`` and is ``None`` when absent.

    Raises:
        FileNotFoundError: If the JSONL file does not exist.
        json.JSONDecodeError: If a line is not valid JSON.
    """
    results: List[Tuple[List[Dict[str, str]], Any]] = []
    data_file = Path(dataset_path) / f"{dataset_name}.jsonl"
    with open(data_file, "r", encoding="utf-8") as fr:
        for line in fr:
            # Check the limit BEFORE processing so that num_samples=0 yields no samples,
            # matching how _write_filtered_dataset_jsonl applies the same limit.
            if num_samples is not None and len(results) >= num_samples:
                break
            raw = json.loads(line)
            modified = _apply_reasoning_prompt(raw)
            chat = _prepare_chat_messages(modified, use_system_prompt)
            # `or {}` also covers records that carry an explicit "reward_model": null.
            gt = (raw.get("reward_model") or {}).get("ground_truth")
            results.append((chat, gt))
    return results


def motivation_inference(args, llm: Optional[Any] = None, tokenizer: Optional[Any] = None):
    """Run motivation-style inference using vLLM.

    - Injects the reference answer into the user message via modify_user_message_for_reasoning.
    - Optionally includes a SYSTEM message.
    - Uses chat template with enable_thinking and makes sure the prompt ends with an opened
      <think> tag to encourage detailed reasoning.

    Args:
        args: Namespace with the attributes produced by ``parse_args`` (model_name,
            dataset_name, dataset_path, tensor_parallel_size, gpu_memory_utilization,
            temperature, top_p, top_k, n_generations, max_tokens, num_samples,
            output_file, use_system_prompt).
        llm: Optional already-initialized vLLM engine to reuse; a new one is created when None.
        tokenizer: Optional tokenizer to reuse together with ``llm``; loaded from
            ``args.model_name`` when missing. Ignored (reloaded) when ``llm`` is None.

    Returns: (predictions, output_file_path, llm). For an empty dataset the result is
    ``([], None, llm)`` and nothing is written.
    """
    # Prepare output path
    output_path = args.output_file or _auto_output_path(args.model_name, args.dataset_name)

    # Initialize model + tokenizer once (can be reused across calls)
    if llm is None:
        print(f"Loading model: {args.model_name}")
        # Lazy imports here to avoid module import overhead
        from vllm import LLM  # type: ignore
        from transformers import AutoTokenizer  # type: ignore
        llm = LLM(
            model=args.model_name,
            tensor_parallel_size=args.tensor_parallel_size,
            gpu_memory_utilization=args.gpu_memory_utilization,
        )
        tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    else:
        print("Reusing already initialized LLM instance.")
        if tokenizer is None:
            from transformers import AutoTokenizer  # type: ignore
            tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # Dataset
    print(f"Loading dataset: {args.dataset_name}")
    dataset = load_and_prepare_dataset(
        args.dataset_path,
        args.dataset_name,
        use_system_prompt=args.use_system_prompt,
        num_samples=args.num_samples,
    )
    if not dataset:
        print("Dataset is empty. Nothing to do.")
        return [], None, llm

    chats = [item[0] for item in dataset]

    # Build rendered prompts for vLLM
    think_open = "<think>"
    rendered_prompts: List[str] = []
    for msgs in chats:
        text = tokenizer.apply_chat_template(
            msgs,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,
        )
        # Open the thinking segment, but only if the chat template has not already done so:
        # some thinking-only checkpoints end their generation prompt with "<think>\n", and
        # appending a second tag would hand the model a malformed "<think>\n<think>\n" prefix.
        if not text.rstrip().endswith(think_open):
            text += think_open + "\n"
        rendered_prompts.append(text)

    if rendered_prompts:
        print("=============================Prompt Case=============================")
        print(rendered_prompts[0])
        print("====================================================================")

    # Sampling params
    from vllm import SamplingParams  # type: ignore
    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        n=args.n_generations,
        max_tokens=args.max_tokens,
    )

    # Generate
    outputs = llm.generate(rendered_prompts, sampling_params)

    # Always print first completion case (guarded by existence checks)
    try:
        if outputs and outputs[0].outputs:
            print("=============================Motivation Generation Case (Post-Generation)=============================")
            print("---- First completion text ----")
            print(outputs[0].outputs[0].text)
            print("=====================================================================================================")
    except Exception as e:
        print(f"[Warn] Unable to print motivation completion case: {e}")

    # Collect predictions: outputs are paired positionally with the prepared (chat, gt) tuples.
    predictions = [
        {
            "chat": chat,
            "gt": gt,
            "response_motivation": [candidate.text for candidate in out.outputs],
        }
        for (chat, gt), out in zip(dataset, outputs)
    ]

    # Persist
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(predictions, f, indent=2, ensure_ascii=False)
    print(f"Results saved to: {output_path}")

    return predictions, output_path, llm


def main():
    """CLI entry point: parse the command-line options and run motivation inference."""
    motivation_inference(parse_args())


# Script entry guard. main() only uses parse_args and motivation_inference, both of which
# are defined above this point in the file.
if __name__ == "__main__":
    main()


# ------------------------------
# Utilities over evaluation outputs
# ------------------------------

def find_all_wrong_samples(evaluation_results_path: str, treat_empty_as_wrong: bool = False):
    """From evaluation.py outputs (evaluation_results_*.json), identify samples where all candidate responses are incorrect.

    Compatible with different result structures:
    - Top-level can be a dict with key 'detailed_results'
    - Or a list where each element is a sample result
    - The correctness field within each sample can vary:
        * Boolean list, e.g., judgements, judgement_results, is_correct_list
        * String list, e.g., ["Yes","No"]
        * List of dicts, e.g., [{"label":"Yes"},{"label":"No"}] or {"is_correct": true}

    Args:
        evaluation_results_path: Path to the JSON produced by evaluation.py's LLM-as-Judge
        treat_empty_as_wrong: If a sample has no parsable judgements, treat it as all-wrong. Default False

    Returns:
        A list of dicts (only the all-wrong samples), each containing:
        {
            "index": sample index (0-based),
            "sample_id": first of uid/id/sample_id that is present and non-empty (or None);
                         a numeric id of 0 is a valid id and is preserved,
            "num_responses": number of parsable judgements for the sample,
            "num_correct": number judged as correct,
            "all_wrong": whether all are incorrect (always True in the returned list)
        }

    Raises:
        FileNotFoundError: If the results file does not exist.
        ValueError: If the top-level JSON is neither a list nor a dict with 'detailed_results'.
    """
    path = Path(evaluation_results_path)
    if not path.exists():
        raise FileNotFoundError(f"Evaluation results file not found: {evaluation_results_path}")

    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Normalize to a list of samples
    if isinstance(data, dict) and "detailed_results" in data:
        samples = data["detailed_results"]
    elif isinstance(data, list):
        samples = data
    else:
        raise ValueError("Unrecognized evaluation results format: expected dict with 'detailed_results' or a list")

    def _to_bool(val) -> Optional[bool]:
        # Normalize various markers to a boolean is_correct; None means "not parsable".
        if isinstance(val, bool):
            return val
        if isinstance(val, str):
            v = val.strip().lower()
            if v in ("yes", "correct", "true", "1"):
                return True
            if v in ("no", "incorrect", "false", "0"):
                return False
            return None
        if isinstance(val, (int, float)):
            return bool(val)
        if isinstance(val, dict):
            # Common fields: label or is_correct
            if "is_correct" in val:
                return _to_bool(val["is_correct"])  # type: ignore[index]
            if "label" in val:
                return _to_bool(val["label"])  # type: ignore[index]
        return None

    def _extract_booleans(sample_obj) -> List[bool]:
        # Try multiple possible field names
        candidates = [
            "judgements",
            "judgement_results",
            "is_correct_list",
            "results",
            "labels",
        ]
        vals = None
        for key in candidates:
            if isinstance(sample_obj, dict) and key in sample_obj:
                vals = sample_obj[key]
                break
        if vals is None:
            # Some implementations may place the judgements directly within an arbitrary list field in the sample.
            # Fallback: iterate over values and pick the first list-like field to try.
            if isinstance(sample_obj, dict):
                for v in sample_obj.values():
                    if isinstance(v, list):
                        vals = v
                        break
        if vals is None:
            return []
        bools: List[bool] = []
        if isinstance(vals, list):
            for item in vals:
                b = _to_bool(item)
                if b is not None:
                    bools.append(b)
        return bools

    def _extract_sample_id(sample_obj):
        # Pick the first identifier that is actually present. Only None and "" count as
        # missing, so a legitimate 0-based numeric id (0) is not discarded by truthiness.
        if not isinstance(sample_obj, dict):
            return None
        for key in ("uid", "id", "sample_id"):
            value = sample_obj.get(key)
            if value is not None and value != "":
                return value
        return None

    results = []
    for idx, s in enumerate(samples):
        bools = _extract_booleans(s)
        if not bools:
            # No parsable judgements
            all_wrong = bool(treat_empty_as_wrong)
            num_resp = 0
            num_correct = 0
        else:
            num_resp = len(bools)
            num_correct = sum(1 for b in bools if b)
            all_wrong = num_correct == 0

        results.append({
            "index": idx,
            "sample_id": _extract_sample_id(s),
            "num_responses": num_resp,
            "num_correct": num_correct,
            "all_wrong": all_wrong,
        })

    # Return only the all-wrong samples
    return [r for r in results if r["all_wrong"]]


# ------------------------------
# Judge motivation reasoning quality using an external LLM
# ------------------------------

def _extract_question_from_user_content(user_content: str) -> str:
    """Extract the content between **Question:** … **Reference Answer:** from templated user content.
    If parsing fails, return the first 4096 characters as a fallback.
    """
    try:
        pattern = r"\*\*Question:\*\*(.*?)\*\*Reference Answer:\*\*"
        m = re.search(pattern, user_content, re.DOTALL)
        if m:
            return m.group(1).strip().strip("`\n ")
    except Exception:
        pass
    return user_content[:4096]


def _extract_reasoning_process_from_completion(text: str, end_tag: str = "</think>") -> str:
    """Extract the reasoning process segment from a model completion:
    - If </think> exists, return the content before it;
    - Otherwise return the entire text (the model may not close the tag).
    """
    try:
        pos = text.find(end_tag)
        if pos != -1:
            return text[:pos]
        return text
    except Exception:
        return text


def _extract_post_think_content(text: str, end_tag: str = "</think>") -> str:
    """Extract the visible content after </think> from a model completion.

    - If end_tag is found, return the trailing text (stripping left whitespace);
    - If not found, return the entire text (the model may not close the tag).
    """
    try:
        pos = text.find(end_tag)
        if pos != -1:
            return text[pos + len(end_tag):].lstrip()
        return text
    except Exception:
        return text


def _parse_yes_no_block(text: str, label: str) -> Optional[bool]:
    """Parse a 'Label: Yes/No' judgement from the evaluator's output."""
    try:
        # e.g., "Correctness Assessment: Yes" or in variants without colon spacing
        pattern = rf"{re.escape(label)}\s*:\s*(Yes|No)\b"
        m = re.search(pattern, text, re.IGNORECASE)
        if m:
            return m.group(1).strip().lower() == "yes"
    except Exception:
        pass
    return None


def _parse_explanation_block(text: str, label: str) -> str:
    """Parse an explanation paragraph: starts with Label: and goes until the next label or end of text."""
    try:
        # Capture from the given label to next double newline or next 'Assessment' label
        pattern = rf"{re.escape(label)}\s*:\s*(.*?)(?=\n\s*\w+\s*Assessment\s*:|\Z)"
        m = re.search(pattern, text, re.IGNORECASE | re.DOTALL)
        if m:
            return m.group(1).strip()
    except Exception:
        pass
    return ""


def _evaluate_with_retry_reasoning(client, messages, sample_idx: int, resp_idx: int):
    """Ask the judge for a verdict, retrying and rotating through fallback models on failure.

    An attempt fails when the client raises or when the reply lacks either Yes/No field.
    After 3 consecutive failures on one model the next fallback model (if any) is selected;
    at most 9 attempts are made in total. Fallback models come from the comma-separated
    EVAL_FALLBACK_MODELS environment variable, or a built-in list when it is unset/blank.
    ``client.model`` is restored to its initial value before returning.

    Args:
        client: Judge client exposing a mutable ``model`` attribute and
            ``generate_response(messages=..., temperature=...)`` whose result has ``.content``.
        messages: Chat messages sent to the judge.
        sample_idx: 0-based sample index (used in log lines only).
        resp_idx: 0-based response index (used in log lines only).

    Returns:
        dict with keys judgement_raw, is_correct_reasoning, is_independent,
        correct_explanation, independence_explanation, model, model_sequence, retry_count.
        When every attempt fails, judgement_raw is "Invalid" and both verdicts are False.
    """
    max_total_retries = 9
    per_model_retry_limit = 3
    original_model = client.model

    # Read fallback models
    configured = os.getenv("EVAL_FALLBACK_MODELS", "")
    if configured.strip():
        candidates = [name.strip() for name in configured.split(',') if name.strip()]
    else:
        candidates = [
            "gpt-41-mini-0414-global",
            "o4-mini-0416-global",
            "o3-mini-2025-01-31",
            "gpt-4o-1120-global",
            "gemini-2.5-pro-06-17",
            "gemini-2.5-flash-06-17",
            "claude_sonnet4",
            "claude37_sonnet",
        ]
    # Order-preserving de-duplication that also drops the model already in use.
    fallback_models = list(dict.fromkeys(name for name in candidates if name != client.model))
    next_fallback = 0
    model_sequence = [client.model]

    failures_total = 0
    failures_on_model = 0
    while failures_total < max_total_retries:
        try:
            reply = client.generate_response(messages=messages, temperature=0.1)
            text = reply.content.strip()
            # Parse the two judgements and explanations
            correct_verdict = _parse_yes_no_block(text, "Correctness Assessment")
            indep_verdict = _parse_yes_no_block(text, "Independence Assessment")
            correct_exp = _parse_explanation_block(text, "Correctness Explanation")
            indep_exp = _parse_explanation_block(text, "Independence Explanation")

            if correct_verdict is None or indep_verdict is None:
                raise ValueError("Judge output missing required Yes/No fields")

            # Success: put the caller's model back before handing over the verdict.
            if client.model != original_model:
                client.model = original_model
            return {
                "judgement_raw": text,
                "is_correct_reasoning": bool(correct_verdict),
                "is_independent": bool(indep_verdict),
                "correct_explanation": correct_exp,
                "independence_explanation": indep_exp,
                "model": model_sequence[-1],
                "model_sequence": model_sequence,
                "retry_count": failures_total,
            }
        except Exception as e:
            failures_total += 1
            failures_on_model += 1
            print(f"Sample {sample_idx+1}, response {resp_idx+1}, current model {client.model}, retry {failures_on_model} failed: {e}")

        # Switch model once the current one used up its budget, a fallback is left,
        # and the overall budget still allows another attempt.
        if (
            failures_on_model >= per_model_retry_limit
            and next_fallback < len(fallback_models)
            and failures_total < max_total_retries
        ):
            replacement = fallback_models[next_fallback]
            next_fallback += 1
            print(f"--> Switching to fallback model: {replacement} (previous: {client.model})")
            client.model = replacement
            model_sequence.append(replacement)
            failures_on_model = 0

        # Small delay to avoid rate limits (shorter right after a model switch)
        if failures_total < max_total_retries:
            time.sleep(1.2 if failures_on_model > 0 else 0.8)

    print(f"Warning: Sample {sample_idx+1}, response {resp_idx+1}: reached max retries {max_total_retries} without a valid judgement.")
    if client.model != original_model:
        client.model = original_model
    return {
        "judgement_raw": "Invalid",
        "is_correct_reasoning": False,
        "is_independent": False,
        "correct_explanation": "",
        "independence_explanation": "",
        "model": model_sequence[-1],
        "model_sequence": model_sequence,
        "retry_count": failures_total,
    }


def judge_motivation_reasoning(input_file: str, model: str = "gemini-2.5-pro-06-17", output_file: Optional[str] = None, max_samples: Optional[int] = None):
    """Evaluate motivation_inference results sample-by-sample for:
    - Correct reasoning path (Correctness Assessment)
    - No leakage of the reference answer (Independence Assessment)

    For every completion, only the text after ``</think>`` is shown to the judge.

    Args:
        input_file: Path to the motivation_inference JSON output
        model: Judge model (may be switched to fallbacks on retries)
        output_file: Optional path to save evaluation results; defaults to adding _reasoning_eval.json alongside the input.
            Missing parent directories are created.
        max_samples: Optional limit on the number of samples

    Returns:
        (results_list, saved_path)

    Raises:
        RuntimeError: If ``llm_client.create_llm_client`` cannot be imported (the original
            import error is chained as the cause).
    """
    with open(input_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    if max_samples is not None:
        data = data[:max_samples]

    # Lazy import to avoid heavy init during module import
    try:
        from llm_client import create_llm_client  # type: ignore
    except Exception as e:
        raise RuntimeError(f"Failed to import llm_client.create_llm_client. Please check dependencies and environment. Error: {e}") from e

    client = create_llm_client(model=model)

    results = []
    for idx, sample in enumerate(data):
        # Extract question and ground truth
        gt = sample.get("gt", "")
        question = None
        # Parse question from the first user message in chat
        for msg in sample.get("chat", []):
            if msg.get("role") == "user":
                question = _extract_question_from_user_content(msg.get("content", ""))
                break
        question = question or ""

        evaluations = []
        for j, completion in enumerate(sample.get("response_motivation", [])):
            # Evaluation should be based on content after </think>
            reasoning_process = _extract_post_think_content(completion)
            # Evaluation prompt (readable multi-line string)
            evaluation_prompt = f"""Please evaluate the following reasoning process on two aspects:

**Question:**
```
{question}
```

**Reasoning Process:**
```
{reasoning_process}
```

**Reference Answer:**
```
{gt}
```

Please evaluate separately:

1. **Correctness Assessment**: Does this reasoning process correctly lead to the answer?
    - Judge whether the reasoning logic is correct
    - Judge whether each step is reasonable and coherent
    - Judge whether the final conclusion is correct

2. **Independence Assessment**: Does this reasoning process not rely on the reference answer?
    - Judge whether there are obvious traces of using the reference answer
    - Judge whether there are statements like "I know the answer is X"

Please answer in the following format strictly:

Correctness Assessment: Yes/No
Correctness Explanation: <one or two concise sentences>

Independence Assessment: Yes/No
Independence Explanation: <one or two concise sentences>
"""

            if idx == 0 and j == 0:
                print("=============================Reasoning Judge Prompt Case=============================")
                print(evaluation_prompt)
                print("=====================================================================================")

            messages = [
                {"role": "system", "content": "You are an objective evaluator. Only output the required fields."},
                {"role": "user", "content": evaluation_prompt},
            ]

            judge = _evaluate_with_retry_reasoning(client, messages, idx, j)
            if idx == 0 and j == 0:
                print("=============================Reasoning Judge Completion Case=========================")
                print(judge.get("judgement_raw", ""))
                print("=====================================================================================")
            evaluations.append({
                "is_correct_reasoning": judge["is_correct_reasoning"],
                "is_independent": judge["is_independent"],
                "correct_explanation": judge["correct_explanation"],
                "independence_explanation": judge["independence_explanation"],
                "judgement_raw": judge["judgement_raw"],
                "model": judge.get("model"),
                "model_sequence": judge.get("model_sequence", []),
                "retry_count": judge.get("retry_count", 0),
            })

        # Summary (across all evaluations for the current sample).
        # A sample without any evaluation counts as neither all-correct nor all-independent.
        summary = {
            "num_responses": len(evaluations),
            "num_correct_reasoning": sum(1 for e in evaluations if e["is_correct_reasoning"]),
            "num_independent": sum(1 for e in evaluations if e["is_independent"]),
            "all_correct_reasoning": all(e["is_correct_reasoning"] for e in evaluations) if evaluations else False,
            "all_independent": all(e["is_independent"] for e in evaluations) if evaluations else False,
        }

        # Print a concise summary for each sample (best effort: logging must not abort the run)
        try:
            print(f"=============================Reasoning Judge Summary (Sample {idx+1})=============================")
            q_preview = (question or "").strip().replace("\n", " ")
            if len(q_preview) > 200:
                q_preview = q_preview[:200] + "..."
            print(f"Q: {q_preview}")
            gt_preview = (str(gt) or "").strip().replace("\n", " ")
            if len(gt_preview) > 200:
                gt_preview = gt_preview[:200] + "..."
            print(f"GT: {gt_preview}")
            print(f"#Responses: {summary['num_responses']}")
            for r_i, ev in enumerate(evaluations, start=1):
                c = "Yes" if ev.get("is_correct_reasoning") else "No"
                d = "Yes" if ev.get("is_independent") else "No"
                print(f"- Resp {r_i}: Correct={c}, Independent={d}")
            print(f"Totals -> correct_reasoning={summary['num_correct_reasoning']}, independent={summary['num_independent']}, all_correct={summary['all_correct_reasoning']}, all_independent={summary['all_independent']}")
            print("====================================================================================================")
        except Exception as e:
            # Report instead of silently swallowing, in the same style as the other warnings in this file.
            print(f"[Warn] Unable to print reasoning judge summary for sample {idx+1}: {e}")

        results.append({
            "sample_id": idx,
            "question": question,
            "ground_truth": gt,
            "evaluations": evaluations,
            "summary": summary,
        })

    # Save
    if output_file is None:
        in_path = Path(input_file)
        saved_path = str(in_path.with_name(in_path.stem + "_reasoning_eval" + in_path.suffix))
    else:
        saved_path = output_file
    # Create the destination directory first (as motivation_inference does) so that a missing
    # folder cannot throw away the judge results that were just collected.
    Path(saved_path).parent.mkdir(parents=True, exist_ok=True)
    with open(saved_path, "w", encoding="utf-8") as f:
        json.dump({
            "metadata": {
                "source_file": input_file,
                "judge_model": client.model,
                "total_samples": len(results),
            },
            "detailed_results": results,
        }, f, ensure_ascii=False, indent=2)
    print(f"Reasoning evaluation results saved to: {saved_path}")

    return results, saved_path


# ------------------------------
# Full pipeline: base inference -> judge -> filter all-wrong -> motivation inference -> judge
# ------------------------------

def _write_filtered_dataset_jsonl(dataset_path: str, dataset_name: str, wrong_indices: List[int], num_samples: Optional[int], out_dir: Path) -> Tuple[str, str]:
    """Filter from the original dataset all samples whose responses are all wrong and write a JSONL file.

    Args:
        dataset_path: Original dataset directory
        dataset_name: Original dataset name (without extension)
        wrong_indices: Indices to keep (0-based, per original inference order)
        num_samples: If initial inference limited samples, consider only the first num_samples
        out_dir: Output directory

    Returns:
        (filtered_dir, filtered_name) where filtered_dir is the directory path (str) and filtered_name is the filename without extension
    """
    src_file = Path(dataset_path) / f"{dataset_name}.jsonl"
    wrong_set = set(wrong_indices)
    lines: List[str] = []
    with open(src_file, "r", encoding="utf-8") as fr:
        for i, line in enumerate(fr):
            if num_samples is not None and i >= num_samples:
                break
            if i in wrong_set:
                lines.append(line)
    out_dir.mkdir(parents=True, exist_ok=True)
    filtered_name = f"{dataset_name}_allwrong_subset"
    filtered_file = out_dir / f"{filtered_name}.jsonl"
    with open(filtered_file, "w", encoding="utf-8") as fw:
        for line in lines:
            fw.write(line if line.endswith("\n") else line + "\n")
    return str(out_dir), filtered_name


def _pipeline_judge_accuracy(eval_module: Any, out_dir: Path, results_filename: str, judge_model: str) -> str:
    """Run the accuracy judge from inside ``out_dir`` and return the newest evaluation file.

    The working directory is switched to ``out_dir`` for the duration of the
    call (so that the evaluation module saves its output there) and is always
    restored afterwards, even if judging fails.

    Args:
        eval_module: Module object exposing ``llm_as_judge(filename, model=...)``
        out_dir: Directory containing ``results_filename``
        results_filename: Inference results file name, relative to ``out_dir``
        judge_model: Judge model identifier forwarded to ``llm_as_judge``

    Returns:
        Path (str) of the most recently modified ``evaluation_results_*.json``
        in ``out_dir``, expressed as ``out_dir / <file name>``.

    Raises:
        FileNotFoundError: If no ``evaluation_results_*.json`` exists after judging.
    """
    pattern = "evaluation_results_*.json"
    # Snapshot files left over from earlier runs: the output directory may be
    # reused, and a leftover must not be silently mistaken for fresh results.
    previous_mtimes = {p.name: p.stat().st_mtime_ns for p in out_dir.glob(pattern)}

    cwd_backup = os.getcwd()
    os.chdir(str(out_dir))
    try:
        eval_module.llm_as_judge(results_filename, model=judge_model)
        # Get the newest evaluation_results_* file
        eval_files = sorted(Path('.').glob(pattern), key=lambda p: p.stat().st_mtime, reverse=True)
        if not eval_files:
            raise FileNotFoundError("No evaluation_results_*.json produced by llm_as_judge")
        newest = eval_files[0]
        if previous_mtimes.get(newest.name) == newest.stat().st_mtime_ns:
            # Still return it (unchanged behaviour), but make the risk visible.
            print(
                f"[Pipeline] WARNING: newest evaluation file {newest.name} was not modified by this run; "
                "it may be a stale result from a previous run."
            )
        return str(out_dir / newest.name)
    finally:
        os.chdir(cwd_backup)


def _pipeline_aggregate_metrics(judge_results: List[Dict[str, Any]], all_wrong_count: int) -> Dict[str, Any]:
    """Fold the per-sample ``summary`` dicts of the reasoning judge into pipeline-level metrics.

    Each entry of ``judge_results`` must carry a ``"summary"`` dict with the keys
    ``num_responses``, ``num_correct_reasoning``, ``num_independent``,
    ``all_correct_reasoning`` and ``all_independent``. Rates are 0.0 when their
    denominator is zero.
    """
    summaries = [sample["summary"] for sample in judge_results]
    total_samples = len(summaries)
    total_responses = sum(s["num_responses"] for s in summaries)
    total_correct_reasoning = sum(s["num_correct_reasoning"] for s in summaries)
    total_independent = sum(s["num_independent"] for s in summaries)
    samples_with_any_correct = sum(1 for s in summaries if s["num_correct_reasoning"] > 0)
    samples_with_all_correct = sum(1 for s in summaries if s["all_correct_reasoning"])
    samples_with_any_indep = sum(1 for s in summaries if s["num_independent"] > 0)
    samples_with_all_indep = sum(1 for s in summaries if s["all_independent"])

    def _rate(numerator, denominator):
        return (numerator / denominator) if denominator else 0.0

    return {
        "all_wrong_count": all_wrong_count,
        "motivation_total_samples": total_samples,
        "motivation_total_responses": total_responses,
        "motivation_total_correct_reasoning": total_correct_reasoning,
        "motivation_total_independent": total_independent,
        "samples_with_any_correct_reasoning": samples_with_any_correct,
        "samples_with_all_correct_reasoning": samples_with_all_correct,
        "samples_with_any_independent": samples_with_any_indep,
        "samples_with_all_independent": samples_with_all_indep,
        "correct_reasoning_rate_per_response": _rate(total_correct_reasoning, total_responses),
        "independence_rate_per_response": _rate(total_independent, total_responses),
        "sample_level_any_correct_rate": _rate(samples_with_any_correct, total_samples),
        "sample_level_any_independent_rate": _rate(samples_with_any_indep, total_samples),
    }


def run_motivation_full_pipeline(
    base_model_name: str,
    dataset_name: str,
    dataset_path: str,
    judge_model_for_accuracy: str = "gpt-41-mini-0414-global",
    judge_model_for_reasoning: str = "gemini-2.5-pro-06-17",
    motivation_model_name: Optional[str] = None,
    base_output_dir: Optional[str] = None,
    n_generations: int = 16,
    num_samples: Optional[int] = None,
    temperature: float = 0.6,
    top_p: float = 0.95,
    top_k: int = 20,
    max_tokens: int = 8192,
    tensor_parallel_size: int = 1,
    gpu_memory_utilization: float = 0.95,
    use_system_prompt: bool = False,
) -> Dict[str, Any]:
    """Full pipeline:
    1) Run base model for normal inference (utils/inference.py).
    2) Use evaluation.llm_as_judge to judge each response for correctness, producing evaluation_results_*.json.
    3) Identify indices of samples where all responses are wrong.
    4) On those samples, run motivation_inference again (inject reference answer).
    5) Use judge_motivation_reasoning to assess reasoning correctness and independence.
    6) Aggregate and return key metrics and file paths.

    Args:
        base_model_name: Model name or path used for the initial inference
        dataset_name: Dataset base name (without extension)
        dataset_path: Directory containing ``<dataset_name>.jsonl``
        judge_model_for_accuracy: Judge model for step 2
        judge_model_for_reasoning: Judge model for step 5
        motivation_model_name: Model for step 4; falls back to ``base_model_name``
        base_output_dir: Optional. If provided, all artifacts are written there; otherwise to
            inference_results/<model>/<dataset>/<params>/ under the project root.
        n_generations, num_samples, temperature, top_p, top_k, max_tokens,
        tensor_parallel_size, gpu_memory_utilization, use_system_prompt:
            Forwarded to the inference stages (``num_samples`` only to the initial one,
            since the motivation stage runs on the already-filtered subset).

    Returns:
        When no sample is all-wrong: a dict with ``initial_inference_file``,
        ``accuracy_eval_file``, ``all_wrong_count`` (0) and ``message``.
        Otherwise: a dict with ``initial_inference_file``, ``accuracy_eval_file``,
        ``filtered_dataset_dir``, ``filtered_dataset_name``,
        ``motivation_inference_file``, ``motivation_reasoning_judge_file`` and ``metrics``.

    Raises:
        FileNotFoundError: If the accuracy judge leaves no ``evaluation_results_*.json``.
    """
    # Lazy import to avoid initializing heavy dependencies during module import
    try:
        from utils import inference as base_infer  # type: ignore
    except Exception:
        import importlib
        base_infer = importlib.import_module("utils.inference")
    try:
        from utils import evaluation as eval_module  # type: ignore
    except Exception:
        import importlib
        eval_module = importlib.import_module("utils.evaluation")

    # ---------- Directories and output paths ----------
    if base_output_dir:
        base_out_dir = Path(base_output_dir)
    else:
        model_clean = Path(base_model_name).name
        params_str = f"temp{temperature}_topp{top_p}_topk{top_k}_maxtok{max_tokens}_n{n_generations}_system{use_system_prompt}"
        base_out_dir = project_root / "inference_results" / model_clean / dataset_name / params_str
    base_out_dir.mkdir(parents=True, exist_ok=True)
    initial_out_file = base_out_dir / "initial_inference_results.json"

    # ---------- Step 1: Initial inference ----------
    base_args = SimpleNamespace(
        model_name=base_model_name,
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        tensor_parallel_size=tensor_parallel_size,
        gpu_memory_utilization=gpu_memory_utilization,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        n_generations=n_generations,
        max_tokens=max_tokens,
        num_samples=num_samples,
        output_file=str(initial_out_file),
        given_qa=False,
        use_system_prompt=use_system_prompt,
    )
    print("[Pipeline] 1/6 Running base inference...")
    base_infer.inference(base_args, print_prompt_case=False)

    # ---------- Step 2: Judge correctness (LLM-as-Judge) ----------
    print("[Pipeline] 2/6 Judging base responses for accuracy...")
    eval_file_path = _pipeline_judge_accuracy(
        eval_module, base_out_dir, initial_out_file.name, judge_model_for_accuracy
    )

    # ---------- Step 3: Filter all-wrong samples ----------
    print("[Pipeline] 3/6 Filtering all-wrong samples...")
    wrong_items = find_all_wrong_samples(eval_file_path)
    wrong_indices = [item["index"] for item in wrong_items]
    print(f"[Pipeline] Found {len(wrong_indices)} all-wrong samples.")

    # If no all-wrong samples, return base results directly
    if not wrong_indices:
        return {
            "initial_inference_file": str(initial_out_file),
            "accuracy_eval_file": eval_file_path,
            "all_wrong_count": 0,
            "message": "No all-wrong samples; skipping motivation stage.",
        }

    # ---------- Build a temporary dataset containing only all-wrong samples ----------
    filtered_dir = base_out_dir / "filtered_all_wrong"
    filtered_dataset_path, filtered_dataset_name = _write_filtered_dataset_jsonl(
        dataset_path, dataset_name, wrong_indices, num_samples, filtered_dir
    )

    # ---------- Step 4: Motivation inference ----------
    print("[Pipeline] 4/6 Running motivation inference on all-wrong subset...")
    mot_model = motivation_model_name or base_model_name
    mot_out_file = base_out_dir / "motivation_inference_results.json"
    mot_args = SimpleNamespace(
        model_name=mot_model,
        dataset_name=filtered_dataset_name,
        dataset_path=filtered_dataset_path,
        tensor_parallel_size=tensor_parallel_size,
        gpu_memory_utilization=gpu_memory_utilization,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        n_generations=n_generations,
        max_tokens=max_tokens,
        num_samples=None,  # already filtered subset
        output_file=str(mot_out_file),
        use_system_prompt=use_system_prompt,
    )
    motivation_inference(mot_args)

    # ---------- Step 5: Judge Motivation reasoning ----------
    print("[Pipeline] 5/6 Judging motivation reasoning (correctness + independence)...")
    mot_judge_results, mot_judge_file = judge_motivation_reasoning(
        str(mot_out_file), model=judge_model_for_reasoning, output_file=None
    )

    # ---------- Step 6: Aggregate metrics ----------
    print("[Pipeline] 6/6 Aggregating metrics...")
    metrics = _pipeline_aggregate_metrics(mot_judge_results, len(wrong_indices))

    return {
        "initial_inference_file": str(initial_out_file),
        "accuracy_eval_file": eval_file_path,
        "filtered_dataset_dir": filtered_dataset_path,
        "filtered_dataset_name": filtered_dataset_name,
        "motivation_inference_file": str(mot_out_file),
        "motivation_reasoning_judge_file": mot_judge_file,
        "metrics": metrics,
    }
