#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Temporal Memory Retrieval + Local LM Inference (Open-Source Ready)

- No hardcoded private paths; all I/O via CLI args
- Works on NPU / CUDA / CPU with auto-detect (and manual override)
- BM25 session retrieval with time-range filtering
- Qwen-like chat template inference with optional LoRA adapter
- Reproducible runs via --seed (seeds the Python, NumPy and torch RNGs)
- Clean logs and robust error handling

License: MIT
"""

from __future__ import annotations

import argparse
import json
import logging
import os
import re
import time
from dataclasses import dataclass
from datetime import datetime, date
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple

import numpy as np
import torch

# --- Optional deps ---
try:
    import torch_npu  # noqa: F401  # optional NPU support
    _HAS_NPU_LIB = True
except Exception:
    _HAS_NPU_LIB = False

try:
    from tqdm import tqdm
    _HAS_TQDM = True
except Exception:
    _HAS_TQDM = False

try:
    import tiktoken
    _HAS_TIKTOKEN = True
except Exception:
    _HAS_TIKTOKEN = False

try:
    from rank_bm25 import BM25Okapi
    _HAS_BM25 = True
except Exception:
    _HAS_BM25 = False

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel


# =========================
# Dataclasses
# =========================
@dataclass
class TimeRange:
    """Calendar-date window attached to a question; the pipeline keeps sessions with start <= date <= end."""

    start: date  # earliest session date to keep
    end: date  # latest session date to keep


# =========================
# Device utilities
# =========================
def pick_device(prefer: Optional[str] = None) -> torch.device:
    """
    Resolve the torch device to run on.

    ``prefer`` may be "npu", "cuda", "cpu" (case-insensitive) or None.
    An explicit "cpu" is always honoured. An explicit "npu"/"cuda" is used
    when available; otherwise a warning is logged and auto-detection runs.
    Auto-detection order is NPU, then CUDA, then CPU. Unrecognised values
    fall through to auto-detection without a warning.
    """
    choice = None if prefer is None else prefer.lower()

    def npu_ready() -> bool:
        # torch.npu may exist even without the torch_npu module in some envs; guard both.
        return hasattr(torch, "npu") and torch.npu.is_available()

    if choice == "cpu":
        return torch.device("cpu")
    if choice == "npu":
        if npu_ready():
            return torch.device("npu:0")
        logging.warning("NPU requested but not available; falling back to auto.")
    elif choice == "cuda":
        if torch.cuda.is_available():
            return torch.device("cuda:0")
        logging.warning("CUDA requested but not available; falling back to auto.")

    # Auto mode.
    if npu_ready():
        return torch.device("npu:0")
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")


# =========================
# Token counting
# =========================
def count_tokens(text: str, tokenizer: Optional[AutoTokenizer] = None,
                 tiktoken_model: str = "cl100k_base") -> int:
    """
    Return a token count for ``text``, used for prompt budgeting.

    Strategies are tried in order; the first that succeeds wins:
      1. tiktoken, using the *encoding name* ``tiktoken_model`` (looked up
         with ``tiktoken.get_encoding``), when tiktoken is installed;
      2. ``tokenizer.encode(text)`` when a HF tokenizer is supplied;
      3. whitespace word count, floored at 1.
    Any exception raised by strategy 1 or 2 is swallowed and the next
    strategy is tried.
    """
    if _HAS_TIKTOKEN:
        try:
            encoding = tiktoken.get_encoding(tiktoken_model)
            return len(encoding.encode(text))
        except Exception:
            pass  # best effort: fall through to the next strategy
    if tokenizer is not None:
        try:
            return len(tokenizer.encode(text))
        except Exception:
            pass  # best effort: fall through to the naive count
    n_words = len(text.split())
    return n_words if n_words > 1 else 1


# =========================
# Data loading
# =========================
def load_json(path: Path) -> Any:
    """Read ``path`` as UTF-8 text and return the decoded JSON value."""
    return json.loads(path.read_text(encoding="utf-8"))


def extract_time_range_from_question(
    qmap: Dict[str, Any],
    question_id: Any,
    default_unknown_start: str = "1900-01-01",
    default_unknown_end: Optional[str] = None,
) -> TimeRange:
    """
    Look up the annotated time window for ``question_id`` in ``qmap``.

    ``qmap`` is keyed by ``str(question_id)``; an entry is expected to look like
    ``{"question_time_range": {"start": <ISO date>, "end": <ISO date>}}``.
    A bound that is missing, ``None``, blank or ``"unknown"`` (any case) is
    replaced by ``default_unknown_start`` / ``default_unknown_end``, the latter
    defaulting to today's date. Questions without a usable annotation get the
    full default window, so the defaults apply consistently in every case.

    Raises:
        ValueError: if a supplied bound or default is not an ISO date/datetime.
    """
    end_default = default_unknown_end or date.today().isoformat()

    def _bound(raw: Any, fallback: str) -> date:
        # Treat absent / blank / "unknown" uniformly as an open bound.
        if raw is None or (isinstance(raw, str) and raw.strip().lower() in ("", "unknown")):
            raw = fallback
        return datetime.fromisoformat(str(raw).strip()).date()

    item = qmap.get(str(question_id))
    tr = item.get("question_time_range") if isinstance(item, dict) else None
    if not isinstance(tr, dict):
        tr = {}  # no annotation -> both bounds fall back to the defaults

    return TimeRange(
        start=_bound(tr.get("start"), default_unknown_start),
        end=_bound(tr.get("end"), end_default),
    )


# =========================
# Formatting
# =========================
def format_session_context(session_id: str, session: Dict[str, Any]) -> str:
    """
    Render a session into a flat text document (used for BM25 scoring and as prompt memory).

    Output: a header line ``"<session_date>: <<session_id>>"`` followed by one
    ``"<speaker>: <text>"`` line per utterance with non-empty text. Text is read
    from ``"utterance"`` and, only when that key is absent, from ``"content"``.
    Utterance timestamps are intentionally omitted (mostly noise for BM25; keep
    text dominant). A missing or ``None`` ``"content"`` yields just the header.
    """
    lines: List[str] = [f"{session.get('session_date', '')}: <{session_id}>"]
    for utt in session.get("content") or []:
        text = utt.get("utterance", utt.get("content", ""))
        if text:
            lines.append(f"{utt.get('speaker', '')}: {text}")
    return "\n".join(lines)


def build_prompt(question: str, dialogue_sessions: str, now_str: str) -> str:
    """
    Assemble the single user message sent to the model.

    Layout: task instruction, the accepted answer formats, the retrieved
    memory inside ``<previous_memory>`` tags, then the question (stamped
    with ``now_str``) inside ``<question>`` tags, and a closing reminder.
    The result ends with a newline.
    """
    instruction = (
        "You are presented with a temporal question and a previous memory, "
        "please answer the question with the correct format. "
        "The last line of your response should be of the form Answer: $Answer "
        "(without quotes) where $Answer is the answer to the problem."
    )
    format_rules = [
        "### Output Requirements:",
        "1. **Answer Format** (MUST match one of these exactly):",
        "   - Single choice: `A` | `B` | etc. (uppercase, no quotes)",
        "   - Multiple choice: `A B C` (space-separated uppercase)",
        "   - Time: `HH:MM:SS [am|pm], Month Day, Year` (e.g., `10:45:41 pm, January 15, 2024`)",
        "   - Sequence: `(1)(3)(2)(4)(5)(6)(7)(8)` (numbered parentheses, no spaces)",
    ]
    memory_section = ["<previous_memory>", f"{dialogue_sessions}", "</previous_memory>"]
    question_section = ["<question>", f"Time: {now_str}", f"Question: {question}", "</question>"]
    reminder = 'Remember to put your answer on its own line after "Answer:".'

    pieces = [
        instruction, "",
        *format_rules, "",
        *memory_section, "",
        *question_section, "",
        reminder,
    ]
    return "\n".join(pieces) + "\n"


# =========================
# Retrieval
# =========================
def bm25_retriever(query: str, sessions: Dict[str, Any], top_k: int = 10) -> List[str]:
    """
    Rank ``sessions`` against ``query``; return up to ``top_k`` session ids, best first.

    Each session is rendered with ``format_session_context`` and tokenized by
    plain whitespace splitting (case-sensitive, punctuation kept) before
    BM25Okapi scoring. When rank_bm25 is not installed, a crude lower-cased
    term-overlap count is used instead.

    Returns an empty list when ``sessions`` is empty or ``top_k <= 0``.
    """
    if not sessions or top_k <= 0:
        # BM25Okapi cannot be built on an empty corpus, and slicing with
        # [-0:] / [:-k] would return (almost) every session instead of none.
        return []

    if not _HAS_BM25:
        # Fallback: crude term-overlap score.
        logging.warning("rank_bm25 not installed; using length-heuristic fallback.")
        q_terms = set(query.lower().split())
        scored = []
        for sid, sess in sessions.items():
            doc = format_session_context(sid, sess)
            score = sum(1 for w in doc.lower().split() if w in q_terms)
            scored.append((score, sid))
        scored.sort(reverse=True)
        return [sid for _, sid in scored[:top_k]]

    session_ids: List[str] = list(sessions)
    tokenized_docs = [format_session_context(sid, sessions[sid]).split() for sid in session_ids]
    scores = BM25Okapi(tokenized_docs).get_scores(query.split())
    top_idx = np.argsort(scores)[-top_k:][::-1]  # highest scores first
    return [session_ids[int(i)] for i in top_idx]


# =========================
# Inference
# =========================
def load_model_and_tokenizer(
    base_model_path: Path,
    lora_model_path: Optional[Path],
    device: torch.device,
    torch_dtype: torch.dtype = torch.bfloat16,
) -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
    """
    Load a HF tokenizer and causal LM from ``base_model_path`` onto one device.

    Args:
        base_model_path: path (or hub id) of the base model.
        lora_model_path: optional PEFT/LoRA adapter applied on top of the base.
        device: the single device the whole model is moved to.
        torch_dtype: weight dtype passed to ``from_pretrained``.

    Returns:
        ``(tokenizer, model)`` with the model switched to eval mode.

    Note: ``trust_remote_code=True`` lets the checkpoint run its own Python
    code; only point this at model directories you trust.
    """
    tokenizer = AutoTokenizer.from_pretrained(str(base_model_path), trust_remote_code=True)
    # Decoding options (do_sample, temperature, top_p) are passed per call to
    # generate(); they are intentionally not baked into the model config here.
    model = AutoModelForCausalLM.from_pretrained(
        str(base_model_path),
        device_map=None,  # we place everything on one device below
        torch_dtype=torch_dtype,
        trust_remote_code=True,
    ).to(device)

    if lora_model_path:
        model = PeftModel.from_pretrained(model, str(lora_model_path), device_map={"": device})
        logging.info("Loaded LoRA adapter from %s", lora_model_path)

    model.eval()  # explicit: inference only (disables dropout)
    return tokenizer, model


@torch.inference_mode()
def infer_with_local_model(
    prompt: str,
    tokenizer: AutoTokenizer,
    model: AutoModelForCausalLM,
    temperature: float = 0.1,
    top_p: float = 0.95,
    max_new_tokens: int = 128,
    max_input_tokens: Optional[int] = 15500,
) -> str:
    """
    Generate a completion for ``prompt`` and return only the newly generated text.

    The prompt is wrapped as a single user turn with the tokenizer's chat
    template (Qwen-style) when available; otherwise the raw prompt is used.

    Args:
        prompt: user message text.
        tokenizer: HF tokenizer matching ``model``.
        model: HF causal LM (optionally PEFT-wrapped); inputs go to ``model.device``.
        temperature: sampling temperature; ``<= 0`` (or None) selects greedy
            decoding, since sampling requires a strictly positive temperature.
        top_p: nucleus-sampling mass, only used when sampling.
        max_new_tokens: maximum number of generated tokens.
        max_input_tokens: if the encoded prompt is longer, only its LAST
            ``max_input_tokens`` tokens are kept, so the question and the
            assistant generation prompt at the end survive. ``None`` or
            ``<= 0`` disables the cap.

    Returns:
        The decoded continuation (special tokens skipped), whitespace-stripped.
    """
    try:
        prompt_text = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            add_generation_prompt=True,
            tokenize=False,
        )
    except Exception:
        # Tokenizer has no usable chat template: feed the raw prompt.
        prompt_text = prompt

    inputs = tokenizer(prompt_text, return_tensors="pt")
    input_token_len = inputs["input_ids"].shape[1]
    if max_input_tokens is not None and 0 < max_input_tokens < input_token_len:
        # Right-truncation would drop the question and the generation prompt;
        # drop the head (start of the memory block) instead and say so.
        logging.warning(
            "Prompt has %d tokens; keeping only the last %d.",
            input_token_len, max_input_tokens,
        )
        inputs = {k: v[:, -max_input_tokens:] for k, v in inputs.items()}
        input_token_len = max_input_tokens
    logging.debug("Prompt tokens: %d", input_token_len)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    gen_kwargs: Dict[str, Any] = {
        "max_new_tokens": max_new_tokens,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if temperature is not None and temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        gen_kwargs["do_sample"] = False  # greedy decoding

    outputs = model.generate(**inputs, **gen_kwargs)
    new_ids = outputs[0][input_token_len:]  # generate() echoes the prompt first
    text = tokenizer.decode(new_ids, skip_special_tokens=True).strip()

    # NPU cache cleanup if available (best effort).
    if hasattr(torch, "npu"):
        try:
            torch.npu.empty_cache()
        except Exception:
            pass

    return text


_ANSWER_RE = re.compile(r"answer\s*:", re.IGNORECASE)


def parse_model_output(output: str) -> str:
    """
    Extract the final answer from a model response.

    The prompt asks for a last line of the form ``Answer: $Answer``. This
    returns the text after the LAST ``Answer:`` marker (case-insensitive) on
    its line, or on the next non-empty line if the marker ends its line.
    Surrounding whitespace and markdown ``*``/backtick decoration are removed.
    If no marker is present, the whole stripped output is returned.
    """
    if not output:
        return ""
    matches = list(_ANSWER_RE.finditer(output))
    if not matches:
        return output.strip()
    tail = output[matches[-1].end():]
    for line in tail.splitlines():
        cleaned = line.strip().strip("*`").strip()  # e.g. "**Answer:** B" -> "B"
        if cleaned:
            return cleaned
    return ""


# =========================
# Pipeline
# =========================
def _seed_everything(seed: int) -> None:
    """Seed Python's ``random``, NumPy's global RNG and torch's RNG."""
    import random as pyrand

    np.random.seed(seed)
    pyrand.seed(seed)
    torch.manual_seed(seed)


def _filter_sessions_by_time(sessions: Dict[str, Any], tr: TimeRange) -> Dict[str, Any]:
    """
    Keep sessions whose ``session_date`` lies in ``[tr.start, tr.end]``.

    Sessions with a missing/unparseable date are kept. If nothing survives
    the filter, all ``sessions`` are returned so retrieval has candidates.
    """
    kept: Dict[str, Any] = {}
    for sid, sess in sessions.items():
        s_date = sess.get("session_date", "")
        try:
            s_d = datetime.fromisoformat(s_date).date()
        except (TypeError, ValueError):
            s_d = None
        if s_d is None or tr.start <= s_d <= tr.end:
            kept[sid] = sess
    return kept or sessions


def _select_memory_blocks(
    top_ids: Sequence[str],
    sessions: Dict[str, Any],
    tokenizer: AutoTokenizer,
    budget: int,
) -> List[str]:
    """Render retrieved sessions in rank order, stopping at the first one that exceeds the token budget."""
    blocks: List[str] = []
    used_tokens = 0
    for sid in top_ids:
        block = format_session_context(sid, sessions[sid])
        block_tokens = count_tokens(block, tokenizer=tokenizer)
        if used_tokens + block_tokens > budget:
            break  # preserve rank order rather than skipping to smaller blocks
        blocks.append(block)
        used_tokens += block_tokens
    return blocks


def _gold_session_ids(evidence: Any) -> List[str]:
    """Unique session ids (order preserved) from evidence entries that are dicts with a ``session_id``."""
    ids = [
        str(ev.get("session_id"))
        for ev in (evidence or [])
        if isinstance(ev, dict) and "session_id" in ev
    ]
    # Evidence may cite several utterances of the same session; recall is session-level.
    return list(dict.fromkeys(ids))


def _write_json_atomic(path: Path, payload: Any) -> None:
    """Write ``payload`` as JSON to a sibling temp file, then atomically replace ``path``."""
    tmp_path = path.with_name(path.name + ".tmp")
    with tmp_path.open("w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)
    os.replace(tmp_path, path)


def run(
    contexts_path: Path,
    questions_path: Path,
    qmap_path: Path,
    output_path: Path,
    base_model_path: Path,
    lora_model_path: Optional[Path],
    device_pref: Optional[str],
    max_prompt_tokens: int,
    top_k: int,
    seed: int,
    temperature: float,
    top_p: float,
    max_new_tokens: int,
    log_prompts: bool,
) -> None:
    """
    End-to-end pipeline: for each question, time-filter its context sessions,
    retrieve with BM25, pack retrieved sessions into a token budget, prompt the
    local model and record raw + parsed outputs.

    Results are rewritten to ``output_path`` after every question (atomically,
    so an interrupted run leaves the last complete snapshot). Session-level
    Recall@5 against the ``Evidence`` session ids is logged at the end.
    """
    _seed_everything(seed)

    device = pick_device(device_pref)
    logging.info("Using device: %s", device)

    logging.info("Loading contexts from %s", contexts_path)
    time_contexts: Dict[str, Any] = load_json(contexts_path)

    logging.info("Loading questions from %s", questions_path)
    questions: List[Dict[str, Any]] = load_json(questions_path)

    logging.info("Loading question time map from %s", qmap_path)
    qmap: Dict[str, Any] = load_json(qmap_path)

    logging.info("Loading model/tokenizer from %s", base_model_path)
    tokenizer, model = load_model_and_tokenizer(base_model_path, lora_model_path, device)

    output_path.parent.mkdir(parents=True, exist_ok=True)

    results: List[Dict[str, Any]] = []
    recall_at_5_hits = 0
    recall_at_5_total = 0

    iterator: Any = questions
    if _HAS_TQDM:
        iterator = tqdm(questions, desc="Inference")

    for q in iterator:
        qid = q.get("Question ID")
        query = q.get("Question", "")
        context_key = q.get("Context", "")
        ctx_sessions: Dict[str, Any] = time_contexts.get(context_key, {})

        tr = extract_time_range_from_question(qmap, qid)
        candidate_sessions = _filter_sessions_by_time(ctx_sessions, tr)

        if candidate_sessions:
            top_ids = bm25_retriever(query, candidate_sessions, top_k=top_k)
        else:
            logging.warning("No sessions for context %r (QID=%s); answering without memory.", context_key, qid)
            top_ids = []

        retrieved_blocks = _select_memory_blocks(top_ids, candidate_sessions, tokenizer, max_prompt_tokens)

        gold_ids = _gold_session_ids(q.get("Evidence"))
        if gold_ids:
            top5 = set(top_ids[:5])
            recall_at_5_hits += sum(1 for gid in gold_ids if gid in top5)
            recall_at_5_total += len(gold_ids)

        # NOTE(review): wall-clock time makes prompts differ between runs;
        # confirm this is the intended "current time" for temporal questions.
        now_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        prompt = build_prompt(query, "\n\n".join(retrieved_blocks), now_str)

        if log_prompts:
            logging.debug("Prompt for QID=%s:\n%s", qid, prompt)

        raw_output = infer_with_local_model(
            prompt, tokenizer, model,
            temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens,
        )

        results.append({
            "question_id": qid,
            "question": query,
            "level": q.get("Level"),
            "task": q.get("Task"),
            "bm25_topk_session_ids": top_ids,
            "prompt": prompt if log_prompts else "(omitted; run with --log-prompts or --log-level DEBUG to include)",
            "model_output": raw_output,
            "parsed_output": parse_model_output(raw_output),
            "gold_answer": q.get("Gold Answer", ""),
        })

        # Incremental save: partial progress survives a crash mid-run.
        _write_json_atomic(output_path, results)

    if not results:
        _write_json_atomic(output_path, results)  # still produce a (empty) results file

    if recall_at_5_total > 0:
        logging.info("Recall@5 = %.3f (%d/%d)", recall_at_5_hits / recall_at_5_total, recall_at_5_hits, recall_at_5_total)
    else:
        logging.info("Recall@5 not computed (no gold evidence session ids found).")
    logging.info("Done. Results saved to %s", output_path)


# =========================
# CLI
# =========================
def _positive_int(text: str) -> int:
    """argparse ``type``: parse a strictly positive integer."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value <= 0:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def _non_negative_float(text: str) -> float:
    """argparse ``type``: parse a float >= 0 (NaN rejected)."""
    try:
        value = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid float value: {text!r}") from None
    if not value >= 0.0:  # written this way so NaN is rejected too
        raise argparse.ArgumentTypeError(f"must be >= 0, got {value}")
    return value


def _probability(text: str) -> float:
    """argparse ``type``: parse a float in the interval (0, 1]."""
    try:
        value = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid float value: {text!r}") from None
    if not 0.0 < value <= 1.0:
        raise argparse.ArgumentTypeError(f"must be in (0, 1], got {value}")
    return value


def parse_args() -> argparse.Namespace:
    """
    Parse command-line arguments (from ``sys.argv``).

    Numeric options are validated at parse time so bad values fail fast with
    an argparse usage error instead of misbehaving deep in the pipeline
    (e.g. ``--top-k 0`` is rejected).
    """
    p = argparse.ArgumentParser(
        description="Temporal Memory Retrieval + Local LM Inference",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    # Data
    p.add_argument("--contexts", type=Path, required=True, help="time_dial_labeled_contexts.json")
    p.add_argument("--questions", type=Path, required=True, help="time_dial_context_list_with_evidence.json")
    p.add_argument("--qmap", type=Path, required=True, help="question_time_range_annotated.json")
    p.add_argument("--output", type=Path, required=True, help="Path to save results JSON")

    # Model
    p.add_argument("--base-model", type=Path, required=True, help="Path to HF base model (e.g., Qwen / MemAgent)")
    p.add_argument("--lora", type=Path, default=None, help="Optional path to LoRA adapter")

    # Runtime
    p.add_argument("--device", type=str, default=None, choices=["npu", "cuda", "cpu", None],
                   help="Preferred device (auto if omitted)")
    p.add_argument("--seed", type=int, default=20250717, help="Seed for Python/NumPy/torch RNGs")
    p.add_argument("--log-level", type=str, default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"])

    # Retrieval & Prompt
    p.add_argument("--top-k", type=_positive_int, default=10, help="BM25 top-K sessions")
    p.add_argument("--max-prompt-tokens", type=_positive_int, default=15500,
                   help="Budget for retrieved memory text")

    # Generation
    p.add_argument("--temperature", type=_non_negative_float, default=0.1, help="Sampling temperature (>= 0)")
    p.add_argument("--top-p", type=_probability, default=0.95, help="Nucleus sampling mass in (0, 1]")
    p.add_argument("--max-new-tokens", type=_positive_int, default=128, help="Maximum generated tokens")

    # Logging content
    p.add_argument("--log-prompts", action="store_true", help="Include full prompts in output JSON and DEBUG logs")

    return p.parse_args()


def main() -> None:
    """
    CLI entry point: parse arguments, configure logging and run the pipeline.

    A missing file or malformed JSON is logged as a one-line error; any other
    exception is logged with its traceback. In every failure case the process
    exits with status 1, so shell scripts and CI can detect the failure.
    """
    args = parse_args()
    logging.basicConfig(
        level=getattr(logging, args.log_level.upper(), logging.INFO),
        format="%(levelname)s - %(message)s"
    )

    try:
        run(
            contexts_path=args.contexts,
            questions_path=args.questions,
            qmap_path=args.qmap,
            output_path=args.output,
            base_model_path=args.base_model,
            lora_model_path=args.lora,
            device_pref=args.device,
            max_prompt_tokens=args.max_prompt_tokens,
            top_k=args.top_k,
            seed=args.seed,
            temperature=args.temperature,
            top_p=args.top_p,
            max_new_tokens=args.max_new_tokens,
            # DEBUG logging implies prompts are wanted in the output as well.
            log_prompts=args.log_prompts or (args.log_level.upper() == "DEBUG"),
        )
    except FileNotFoundError as e:
        logging.error("File not found: %s", e)
        raise SystemExit(1) from None
    except json.JSONDecodeError as e:
        logging.error("JSON parse error: %s", e)
        raise SystemExit(1) from None
    except Exception as e:
        logging.exception("Unexpected error: %s", e)
        raise SystemExit(1) from None


if __name__ == "__main__":
    # Script entry point; the exit status is whatever main() ends with.
    raise SystemExit(main())
