#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Convert TIME-Dial style data into VERL v2 format with temporal memory prompts.

Key features:
- No hardcoded local paths (all file paths are CLI args).
- Deterministic shuffling via --seed.
- Token budget packing of dialogue history.
- Robust handling of tiktoken model names and fallbacks.
- Clear typing, docstrings, logging, and error handling.

License: MIT
"""

from __future__ import annotations

import argparse
import json
import logging
import random
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple

# Optional: tqdm is nice to have but not required
try:
    from tqdm import tqdm
    _HAS_TQDM = True
except Exception:  # pragma: no cover
    _HAS_TQDM = False

# tiktoken is used for approximate token counting
try:
    import tiktoken
except Exception as e:  # pragma: no cover
    tiktoken = None  # type: ignore


# ------------------------------
# Data structures
# ------------------------------

@dataclass
class Utterance:
    """One dialogue turn inside a session, plus its gold-relevance label."""

    # Speaker label for the turn.
    speaker: str
    # Raw utterance text.
    text: str
    # Timestamp string attached to the turn.
    utterance_time: str
    # Merged event interval as {"start": ..., "end": ...}.
    event_time: Dict[str, str]
    # Relevance label: "yes" or "no".
    is_relevant: str


@dataclass
class Session:
    """A dialogue session: identifier, timestamp, relevance label, and turns."""

    # Normalized identifier such as "session_3".
    session_id: str
    # Timestamp string for the whole session.
    session_time: str
    # Relevance label: "yes" or "no".
    is_relevant: str
    # Ordered turns belonging to this session.
    utterances: List[Utterance]


# ------------------------------
# Token counting
# ------------------------------

# Resolved tiktoken encodings keyed by the requested model name. Caching means
# the unknown-model fallback (and the exception it raises) is paid once per
# name instead of on every call, which matters because callers invoke this
# repeatedly while packing dialogue history.
_ENCODING_CACHE: Dict[str, Any] = {}


def count_tokens(text: str, model_name: str = "gpt-3.5-turbo") -> int:
    """
    Count tokens in ``text`` using tiktoken.

    The encoding for ``model_name`` is looked up once and cached; if tiktoken
    does not recognize the model name, 'cl100k_base' is used instead.
    If tiktoken is not installed, a whitespace word count is returned.

    Args:
        text: Text to measure.
        model_name: Model whose tokenizer should be used.

    Returns:
        Number of tokens (at least 1 in the no-tiktoken fallback).
    """
    if tiktoken is None:
        # Coarse fallback: whitespace-separated words. Subword tokenizers
        # typically emit at least one token per word, so this tends to
        # UNDER-count relative to tiktoken.
        return max(1, len(text.split()))

    enc = _ENCODING_CACHE.get(model_name)
    if enc is None:
        try:
            enc = tiktoken.encoding_for_model(model_name)
        except KeyError:
            # tiktoken raises KeyError when it cannot map a model name.
            enc = tiktoken.get_encoding("cl100k_base")
        _ENCODING_CACHE[model_name] = enc

    # Dataset text may literally contain strings like "<|endoftext|>"; by
    # default tiktoken raises ValueError on those. Count them as plain text.
    return len(enc.encode(text, disallowed_special=()))


# ------------------------------
# Formatting
# ------------------------------

def format_sessions(sessions: Sequence[Session]) -> str:
    """
    Render sessions as the plain-text memory block shown to the model.

    Each session becomes a header ``"<session_time>: <<session_id>> \\n "``
    followed by its utterances. Each utterance is rendered as
    ``"speaker: utterance_time: text\\n"``, and consecutive utterances are
    joined by a single space. Sessions are separated by a newline.
    """
    def _render(session: Session) -> str:
        # One newline-terminated line per utterance, glued with " ".
        body = " ".join(
            f"{turn.speaker}: {turn.utterance_time}: {turn.text}\n"
            for turn in session.utterances
        )
        return f"{session.session_time}: <{session.session_id}> \n {body}"

    return "\n".join(_render(sess) for sess in sessions)


def build_dialogue_history_with_token_budget(
    gold_sessions: Sequence[Session],
    non_gold_sessions: Sequence[Session],
    question_text: str,
    now: datetime,
    max_tokens: int = 7500,
    model_name: str = "gpt-3.5-turbo",
    rng: Optional[random.Random] = None,
) -> str:
    """
    Pack sessions into a rendered memory block that fits a token budget.

    All gold sessions are always kept. Non-gold sessions are shuffled, then
    appended one at a time. Packing stops at the first non-gold session whose
    addition would push the rendered block over ``max_tokens``. The kept
    sessions are shuffled once more and rendered with ``format_sessions``.

    Gold sessions are never dropped, so the result may exceed the budget when
    the gold sessions alone are too large. ``question_text`` and ``now`` are
    accepted but not used.

    Args:
        gold_sessions: Sessions that must appear in the history.
        non_gold_sessions: Distractor sessions added while budget remains.
        question_text: Unused.
        now: Unused.
        max_tokens: Token budget for the rendered block.
        model_name: Model name forwarded to ``count_tokens``.
        rng: Random source; a fresh unseeded one is created when omitted.

    Returns:
        The rendered dialogue history string.
    """
    shuffler = rng or random.Random()

    kept = list(gold_sessions)
    extras = list(non_gold_sessions)
    shuffler.shuffle(extras)

    for extra in extras:
        # Re-render and re-count the whole block each time; token counts of
        # concatenated text may not equal the sum of the parts.
        tentative = format_sessions([*kept, extra])
        over_budget = count_tokens(tentative, model_name=model_name) > max_tokens
        if over_budget:
            break
        kept.append(extra)

    shuffler.shuffle(kept)
    return format_sessions(kept)


def build_cot_prompt(
    question: str,
    dialogue_history: str,
    now: Optional[datetime] = None,
) -> str:
    """
    Build the user prompt that asks for a JSON answer with selected memory.

    Args:
        question: Question text, including answer options if any.
        dialogue_history: Pre-rendered session memory block.
        now: Timestamp printed before the question as ``"<now>: <question>"``.
            When None, the question is emitted without a timestamp prefix.

    Returns:
        The complete prompt string.
    """
    # NOTE(review): the few-shot example below is internally inconsistent.
    # The Nicolas multiple-choice question does not match the India/Debra
    # session memory, and the "Friday" answer follows none of the listed
    # answer formats. It is left unchanged here because rewriting it changes
    # the generated training data; confirm with the dataset owner.
    question_line = f"{str(now)}: {question}" if now is not None else f"{question}"

    response_format = f"""You are a memory-aware assistant tasked with answering temporal questions based on multi-turn dialogue history.
Output JSON MUST include:
1. "selected_memory": a list of memory identifiers (e.g., "session_1") or quoted utterances that are relevant to answering the question.
2. "answer": must strictly follow one of the formats below:
    - Single choice: "A", "B", etc.
    - Multiple choice: "A B C", "B D", etc.
    - Time: "10:45:41 pm, January 15, 2024"
    - Sequence: "(1)(4)(3)(2)"

Important: When analyzing the dialogue history:
- Pay close attention to the temporal relationships between events
- Consider the chronological order of events
- Identify explicit and implicit time references
- Ensure the answer is consistent with selected memory and the temporal context

**Expected Response:**
```json
{{
  "selected_memory": ["<session_1>", "<session_2>"],
  "answer": "Friday"
}}
```


### Example

**User Question:**  
2025-07-23 09:35:42.982502: What has Nicolas experienced in just 22 years before 11:49 pm, January 09, 2024?
A.Nicolas feels confident during interviews due to his skincare routine.
B.Nicolas finds recorded video interviews annoying despite completing them.
C.Nicolas avoids interviews by focusing on his culinary career.
D.Nicolas feels anxious during interviews despite taking anxiety medication.

**session memory:**  
03:45:22 PM, July 14, 2025: <session_1>: India: I went to the museum last Friday.  
                                       Debra: Which one?  
05:12:51 PM, July 16, 2025: <session_2>: India: The MoMA.  
06:32:42 PM, July 17, 2025: <session_3>: Debra: That sounds fun.  
07:34:13 PM, July 18, 2025: <session_4>: India: I stayed there until 6 pm.

**Expected Response:**
```json
{{
  "selected_memory": ["<session_1>", "<session_2>"],
  "answer": "Friday"
}}
```

Now Your Turn

**User Question:**
{question_line}

**session memory:** 
{dialogue_history}

**Expected Response:**
```json
{{
  "selected_memory": ["<relevant_session_or_utterance1>", "<relevant_session_or_utterance2>"],
  "answer": "<your generated response in the exact required format>"
}}
```
"""
    return response_format

def extract_time_range_from_events(events: List[Dict[str, Any]]) -> Dict[str, str]:
    """
    Merge the ``event_time`` intervals of several events into one range.

    Each event is expected to be a dict whose ``"event_time"`` is a 2-item
    list/tuple ``[start, end]``. Entries that are not dicts or have a
    malformed ``event_time`` are skipped. A bound equal to ``"unknown"``, or
    one that is empty or not a string, counts as unknown.

    Bounds are compared as strings. This yields chronological order only for
    consistently formatted, zero-padded timestamps such as ISO-8601; that
    format is assumed, not verified.

    Args:
        events: Event dicts; ``None`` is treated as an empty list.

    Returns:
        ``{"start": earliest known start, "end": latest known end}``. Either
        side is ``"unknown"`` when no known bound exists for it.
    """
    unknown = "unknown"
    known_starts: List[str] = []
    known_ends: List[str] = []

    for ev in events or []:
        # A non-dict entry previously crashed on .get(); skip it instead.
        if not isinstance(ev, dict):
            continue
        et = ev.get("event_time")
        if not isinstance(et, (list, tuple)) or len(et) != 2:
            continue
        start, end = et
        if isinstance(start, str) and start and start != unknown:
            known_starts.append(start)
        if isinstance(end, str) and end and end != unknown:
            known_ends.append(end)

    return {
        "start": min(known_starts) if known_starts else unknown,
        "end": max(known_ends) if known_ends else unknown,
    }

def extract_time_range_from_question(
    qmap: Dict[str, Any],
    question_id: Any,
    default_start: str = "2021-01-01",
    default_end: str = "2025-12-31",
) -> Dict[str, str]:
    """
    Look up the annotated time range for a question.

    Args:
        qmap: Mapping of ``str(question_id)`` to an annotation dict that may
            contain ``"question_time_range": {"start": ..., "end": ...}``.
        question_id: Question identifier; converted with ``str`` for lookup.
        default_start: Start returned when no usable range is found.
        default_end: End returned when no usable range is found.

    Returns:
        ``{"start": ..., "end": ...}``. This is always a dict; the defaults
        are returned when the question is missing or its range is malformed.
    """
    item = qmap.get(str(question_id))
    if isinstance(item, dict):
        tr = item.get("question_time_range")
        if isinstance(tr, dict) and "start" in tr and "end" in tr:
            return {"start": tr["start"], "end": tr["end"]}
    # This fallback was previously nested inside the ``if``, so unknown
    # questions silently returned None instead of the documented defaults.
    return {"start": default_start, "end": default_end}

def to_yes_no(flag: bool) -> str:
    """Return ``"yes"`` for a truthy ``flag`` and ``"no"`` otherwise."""
    if flag:
        return "yes"
    return "no"

def build_sessions_for_context(
    context_obj: Dict[str, Any],
    evidence_sessions: Sequence[str],
    evidence_utterances: Sequence[Dict[str, Any]],
) -> List[Session]:
    """
    Convert one labeled context into normalized Session objects.

    Args:
        context_obj: Mapping of raw session id -> session dict. Each session
            dict is read for ``"session_date"`` and a ``"content"`` list of
            utterance dicts (keys ``"speaker"``, ``"utterance"``,
            ``"utterance_date"``, ``"events"``). ``None`` is treated as empty.
        evidence_sessions: Raw session ids marked as evidence; compared as
            strings against the keys of ``context_obj``.
        evidence_utterances: Evidence entries; their non-empty
            ``"utterance"`` text flags matching utterances as relevant.

    Returns:
        One Session per entry of ``context_obj``, in iteration order, with
        ids of the form ``"session_<raw id>"``.
    """
    all_sessions: List[Session] = []
    ev_sessions = {str(sid) for sid in evidence_sessions}

    # Evidence utterance texts, for O(1) matching. Entries without an
    # utterance (e.g. session-level evidence) are excluded. Otherwise "" would
    # enter the set, and every utterance with missing or empty text would be
    # wrongly flagged as relevant.
    # NOTE(review): matching is by text only, so identical text in a
    # non-evidence session is also flagged; confirm this is intended.
    ev_ut_texts = {
        ev["utterance"]
        for ev in evidence_utterances
        if isinstance(ev, dict) and ev.get("utterance")
    }

    for raw_id, session_content in (context_obj or {}).items():
        is_session_relevant = to_yes_no(str(raw_id) in ev_sessions)
        utterances: List[Utterance] = []
        for ut in session_content.get("content", []):
            ut_text = ut.get("utterance", "")
            utterances.append(
                Utterance(
                    # Keep only the first whitespace-separated token of the speaker.
                    speaker=str(ut.get("speaker", "")).split(" ")[0],
                    text=ut_text,
                    utterance_time=ut.get("utterance_date", ""),
                    event_time=extract_time_range_from_events(ut.get("events", [])),
                    is_relevant=to_yes_no(ut_text in ev_ut_texts),
                )
            )
        all_sessions.append(
            Session(
                session_id=f"session_{raw_id}",
                session_time=session_content.get("session_date", ""),
                is_relevant=is_session_relevant,
                utterances=utterances,
            )
        )

    return all_sessions

def _session_to_dict(session: Session) -> Dict[str, Any]:
    """Serialize a Session and its utterances into a JSON-ready dict."""
    return {
        "session_id": session.session_id,
        "session_time": session.session_time,
        "is_relevant": session.is_relevant,
        "utterances": [
            {
                "speaker": u.speaker,
                "text": u.text,
                "utterance_time": u.utterance_time,
                "event_time": u.event_time,
                "is_relevant": u.is_relevant,
            }
            for u in session.utterances
        ],
    }


def convert_to_verl_v2_format(
    input_path: Path,
    output_path: Path,
    contexts_path: Path,
    qmap_path: Path,
    max_tokens: int = 7500,
    model_name: str = "gpt-3.5-turbo",
    start_index: Optional[int] = None,
    end_index: Optional[int] = None,
    seed: int = 20250717,
    print_prompts: bool = False,
) -> List[Dict[str, Any]]:
    """
    Convert TIME-Dial-like records into a VERL v2 JSON list and write it out.

    Args:
        input_path: JSON list of question records. Keys read: "Question",
            "Question ID", "Gold Answer", "Evidence", "Context", "Task",
            "Dataset Name".
        output_path: Destination JSON file; parent directories are created.
        contexts_path: JSON object mapping a record's "Context" key to its
            sessions (passed to ``build_sessions_for_context``).
        qmap_path: JSON object mapping question ids to time-range annotations.
        max_tokens: Token budget for the packed dialogue history.
        model_name: Model name used for token counting.
        start_index: 1-based first record to convert (inclusive); None means
            the first record.
        end_index: 1-based last record to convert (inclusive); None means the
            last record. 0 converts nothing.
        seed: Seed for the RNG shared by all records.
        print_prompts: Log each generated prompt at INFO level.

    Returns:
        The list of VERL items written to ``output_path``.

    Raises:
        FileNotFoundError, json.JSONDecodeError: while loading the inputs.

    Notes:
        The RNG is shared across records, so a record's shuffle depends on
        ``start_index``. The prompt embeds the current wall-clock time, so
        output is not byte-identical across runs.
    """
    with contexts_path.open("r", encoding="utf-8") as f:
        labeled_contexts = json.load(f)
    with input_path.open("r", encoding="utf-8") as f:
        time_dial_data = json.load(f)
    with qmap_path.open("r", encoding="utf-8") as f:
        qmap = json.load(f)

    verl_data: List[Dict[str, Any]] = []
    rng = random.Random(seed)
    now = datetime.now()

    # Iteration bounds. Use "is not None" so an explicit 0 is honored
    # (previously end_index=0 meant "convert everything").
    total = len(time_dial_data)
    s_idx = max(0, start_index - 1) if start_index is not None else 0
    e_idx = min(total, end_index) if end_index is not None else total
    e_idx = max(s_idx, e_idx)  # never an inverted or negative-length range

    iterator = range(s_idx, e_idx)
    if _HAS_TQDM:
        iterator = tqdm(iterator, total=e_idx - s_idx, desc="Converting")

    for i in iterator:
        item = time_dial_data[i]

        question_text = item.get("Question", "")
        question_id = item.get("Question ID", "")
        gold_answer = item.get("Gold Answer", "")
        evidence_list = item.get("Evidence", []) or []

        evidence_sessions: List[str] = []
        evidence_utterances: List[Dict[str, Any]] = []
        for ev in evidence_list:
            sid = ev.get("session_id")
            if sid is not None:
                evidence_sessions.append(str(sid))
            evidence_utterances.append(ev)

        question_time_range = extract_time_range_from_question(qmap, question_id)

        context_key = item.get("Context", "")
        if context_key not in labeled_contexts:
            logging.warning(
                "Context %r for question %s not found in %s; using no sessions",
                context_key, question_id, contexts_path,
            )
        full_sessions = labeled_contexts.get(context_key, {})
        all_sessions = build_sessions_for_context(
            full_sessions, evidence_sessions, evidence_utterances
        )

        selected_sessions = sorted({f"session_{sid}" for sid in evidence_sessions})
        gold_ids = set(selected_sessions)

        # Split gold vs non-gold sessions.
        gold = [s for s in all_sessions if s.session_id in gold_ids]
        non_gold = [s for s in all_sessions if s.session_id not in gold_ids]

        dialogue_history = build_dialogue_history_with_token_budget(
            gold, non_gold, question_text, now,
            max_tokens=max_tokens,
            model_name=model_name,
            rng=rng,
        )

        prompt = build_cot_prompt(question_text, dialogue_history, now)
        if print_prompts:
            logging.info("Prompt for item %s:\n%s", question_id, prompt)

        verl_data.append({
            "data_source": item.get("Dataset Name", "TIME-Dial"),
            "prompt": [{"content": prompt, "role": "user"}],
            "ability": item.get("Task", ""),
            "reward_model": {
                "ground_truth": {
                    "answer": gold_answer,
                    "selected_sessions": selected_sessions,
                },
                "style": "rule-lighteval/TIME-Dial_v1",
            },
            "extra_info": {
                "index": question_id,
                "gold_selected_sessions": selected_sessions,
                "question_time_range": question_time_range,
                "all_sessions": [_session_to_dict(s) for s in all_sessions],
            },
        })

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w", encoding="utf-8") as f:
        json.dump(verl_data, f, ensure_ascii=False, indent=2)

    logging.info("Wrote %d examples to %s", len(verl_data), output_path)
    return verl_data

def parse_args() -> argparse.Namespace:
    """
    Parse command-line arguments for the converter.

    Returns:
        Namespace with ``input``, ``output``, ``contexts``, ``qmap`` (Paths),
        ``max_tokens``, ``model_name``, ``seed``, ``start_index``,
        ``end_index``, ``print_prompts`` and ``log_level``.
    """
    p = argparse.ArgumentParser(
        description="Convert TIME-Dial style context + questions to VERL v2 format.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument("--input", type=Path, required=True,
                   help="TIME-Dial context-with-evidence JSON (e.g., 0717.json)")
    p.add_argument("--output", type=Path, required=True,
                   help="Output VERL v2 JSON path")
    p.add_argument("--contexts", type=Path, required=True,
                   help="Context file (e.g., time_range/time_dial_labeled_contexts.json)")
    p.add_argument("--qmap", type=Path, required=True,
                   help="Question → time range map JSON (e.g., question_time_range_annotated.json)")
    p.add_argument("--max-tokens", type=int, default=7500,
                   help="Approximate token budget for dialogue history packing")
    p.add_argument("--model-name", type=str, default="gpt-3.5-turbo",
                   help="Model name for tiktoken tokenization")
    p.add_argument("--seed", type=int, default=20250717,
                   help="Random seed for deterministic shuffling")
    p.add_argument("--start-index", type=int, default=None,
                   help="1-based inclusive start index of records to convert")
    # The converter slices records [start-1, end), so the end is the last
    # 1-based record included. The old help text wrongly said "exclusive".
    p.add_argument("--end-index", type=int, default=None,
                   help="1-based inclusive end index of records to convert")
    p.add_argument("--print-prompts", action="store_true",
                   help="Print generated prompts to logs")
    p.add_argument("--log-level", type=str, default="INFO",
                   choices=["DEBUG", "INFO", "WARNING", "ERROR"])
    return p.parse_args()

def main() -> None:
    """
    CLI entry point: parse arguments, configure logging, and run the conversion.

    Failures are logged and the process exits with status 1, so shell
    pipelines can detect them. Previously errors were logged but the exit
    status was 0.
    """
    args = parse_args()
    logging.basicConfig(
        level=getattr(logging, args.log_level.upper(), logging.INFO),
        format="%(levelname)s - %(message)s",
    )

    try:
        convert_to_verl_v2_format(
            input_path=args.input,
            output_path=args.output,
            contexts_path=args.contexts,
            qmap_path=args.qmap,
            max_tokens=args.max_tokens,
            model_name=args.model_name,
            start_index=args.start_index,
            end_index=args.end_index,
            seed=args.seed,
            print_prompts=args.print_prompts,
        )
    except FileNotFoundError as e:
        logging.error("File not found: %s", e)
        raise SystemExit(1) from None
    except json.JSONDecodeError as e:
        logging.error("JSON decode error: %s", e)
        raise SystemExit(1) from None
    except Exception as e:
        logging.exception("Unexpected error: %s", e)
        raise SystemExit(1) from None


if __name__ == "__main__":
    main()