#!/usr/bin/env python3

import argparse
import json
import sqlite3
from pathlib import Path
from typing import List, Dict, Any, Optional
import pandas as pd
import yaml
import importlib
from concurrent.futures import ProcessPoolExecutor
import tqdm
import uuid
from joblib import Parallel, delayed
from json_repair import repair_json

from time_utils.time_utils import time_to_seconds, seconds_to_time
from prompts.en_llm_judge import get_prompt

# Strings that count as "no substantive answer": empty / silence markers and
# canned acknowledgements (English and Chinese). Matching is done after
# stripping whitespace and lower-casing, see is_placeholder().
PLACEHOLDERS = {
    "", "silent", "<NO_INFORMATION>", "<SILENT>",
    "got it, i'll let you know.",
    "收到，我会留意的。",
    "没问题，到时候提醒你。",
    "好的，到时候我会提醒你。",
    "没问题，到时候我会告诉你。",
    "No problem, I'll remind you then.",
    "Okay, I will alert you when it happens.",
    "I will make sure to remind you at that time.",
    "好的，那到时候我会提醒你。",
    "好的，到时候我会发提醒给你。",
    "Sure, I will let you know at that time.",
    "Noted, expect a reminder from me then.",
    "Ok, I will remind you then.",
    "Certainly, I will provide the reminder then.",
    "ok", "okay", "sure", "yes", "收到", "好的", "明白了", "got it, i will notify you at that moment.",
}

# Normalised once at import time so that is_placeholder() is a single set
# lookup instead of rebuilding the set on every call.
# NOTE: entries added to PLACEHOLDERS after import are not picked up here.
_NORMALIZED_PLACEHOLDERS = frozenset(p.strip().lower() for p in PLACEHOLDERS)


def is_placeholder(text: Optional[str]) -> bool:
    """Return True when *text* carries no substantive answer.

    The comparison ignores surrounding whitespace and letter case.

    Args:
        text: Response text to classify. ``None`` is accepted and treated
            like an empty response.

    Returns:
        True if the normalised text is one of ``PLACEHOLDERS`` (or is
        ``None``), False otherwise.
    """
    if text is None:
        return True
    return text.strip().lower() in _NORMALIZED_PLACEHOLDERS

def get_llm_prompt(question: str, model_output: str, reference_answer: str) -> str:
    """Build the judge prompt for one question / model answer / reference triple.

    This is a thin adapter: the three values are forwarded to ``get_prompt``
    as keyword arguments and its result is returned unchanged.
    """
    prompt_fields = {
        "question": question,
        "model_output": model_output,
        "reference_answer": reference_answer,
    }
    return get_prompt(**prompt_fields)

class LLMJudger:
    """Grades a free-form answer against a reference by asking a judge LLM.

    The wrapped client must expose ``generate(messages, max_new_tokens=...)``.
    NOTE(review): the return value of ``generate`` is assumed to be a ``str``
    (``.strip()`` is called on it) -- confirm against the client classes.
    """

    def __init__(self, llm_client):
        self.llm = llm_client

    def judge(self, question: str, model_output: str, reference: str, retries=3) -> Dict[str, Any]:
        """Ask the judge LLM to score *model_output* and parse its JSON reply.

        Args:
            question: The question that was asked.
            model_output: The answer produced by the model under evaluation.
            reference: The reference (ground-truth) answer.
            retries: Maximum number of generate-and-parse attempts.

        Returns:
            ``{"explanation": str, "score": float}`` with the score clamped to
            the range [0.0, 5.0]. If every attempt fails, a fallback record
            with score 0.0 is returned instead of raising.
        """
        messages = [{"role": "user", "content": [{"type": "text", "text": get_llm_prompt(question, model_output, reference)}]}]
        for attempt in range(1, retries + 1):
            try:
                resp = self.llm.generate(messages, max_new_tokens=128).strip()
                print(resp)
                # Parse the *repaired* JSON once and read both fields from it;
                # reading the explanation from the raw reply would fail in
                # exactly the cases where the repair was needed.
                parsed = json.loads(repair_json(resp))
                score = max(0.0, min(5.0, float(parsed["score"])))
                explanation = str(parsed.get("explanation", ""))
            except Exception as exc:
                # Any failure (client error, malformed reply, missing score)
                # consumes one attempt; report it and try again.
                print(f"[LLMJudger] attempt {attempt}/{retries} failed: {exc!r}")
                continue
            return {"explanation": explanation, "score": score}
        return {"explanation": "Judger LLM parsing error", "score": 0.0}

def parse_logical_questions(raw_sqa: List[Dict]) -> List[Dict]:
    """Collapse a raw question/response event list into one record per question.

    Three kinds of entries are recognised:

    * an entry holding both ``"question"`` and ``"response"``: answered at the
      question's own timestamp;
    * an entry holding only ``"question"``: if the entry right after it holds
      a ``"response"`` but no ``"question"``, that entry supplies the ground
      truth and the answer-event time and is consumed as well; otherwise the
      ground truth is ``""`` and the answer time equals the question time;
    * anything else (e.g. a response with no preceding question) is skipped.

    Returns:
        A list of dicts with keys ``question_time_sec``,
        ``answer_event_time_sec``, ``question``, ``ground_truth``,
        ``is_objective`` (True when the entry has ``"options"``), ``options``
        and ``task_type``.
    """
    parsed = []
    total = len(raw_sqa)
    pos = 0
    while pos < total:
        entry = raw_sqa[pos]
        pos += 1  # the current entry is always consumed
        if "question" not in entry:
            continue

        asked_at = time_to_seconds(entry["timestamp"])
        answered_at = asked_at
        if "response" in entry:
            truth = entry["response"]
        else:
            truth = ""
            if pos < total:
                follow_up = raw_sqa[pos]
                if "response" in follow_up and "question" not in follow_up:
                    # A response-only entry directly after the question is
                    # its deferred answer; consume it too.
                    answered_at = time_to_seconds(follow_up["timestamp"])
                    truth = follow_up["response"]
                    pos += 1

        parsed.append({
            "question_time_sec": asked_at,
            "answer_event_time_sec": answered_at,
            "question": entry["question"],
            "ground_truth": truth,
            "is_objective": "options" in entry,
            "options": entry.get("options"),
            "task_type": entry.get("type", entry.get("task_type", "DefaultType"))
        })
    return parsed

def evaluate_sample(
    sample: Dict[str, Any],
    llm_judger: Optional[LLMJudger],
    time_window: float = 3.0
) -> List[Dict[str, Any]]:
    """Score every question of one sample against the model's timed responses.

    For each question parsed from ``sample["sqa"]`` the earliest not yet used
    response from ``sample["responses"]`` is picked whose timestamp lies in
    ``[question time, answer time + time_window]`` and whose text either
    equals the ground truth or is not a placeholder. The pick is then graded:

    * ``EarlyResponse`` - it came before the answer-event time;
    * ``NoResponse``    - nothing matched (empty text, unless the ground
      truth itself is that text);
    * ``LateResponse``  - it came after the on-time deadline;
    * ``Correct`` / ``WrongAnswer`` - objective questions, compared
      case-insensitively against the ground truth, either as a whole or by
      the first character once whitespace and dots are removed;
    * ``PartlyCorrect`` - subjective questions, judge score (0-5) times 20;
    * ``Error (no LLM)`` - subjective question but no judge was supplied.

    NOTE: when question and answer share a timestamp the on-time deadline is
    that timestamp itself, although responses up to ``time_window`` seconds
    later are still matched; those are graded ``LateResponse``.

    Returns:
        One result dict per question; an empty list if the sample has no
        questions.
    """
    questions = parse_logical_questions(sample["sqa"])
    if not questions:
        return []

    # (time in seconds, text) pairs in chronological order.
    timeline = sorted(
        ((time_to_seconds(resp["timestamp"]), resp["response"]) for resp in sample["responses"]),
        key=lambda pair: pair[0],
    )

    def first_letter(text):
        # Option letter of answers such as " B. the red car".
        return text.strip().replace(".", "")[:1]

    rows = []
    consumed = set()  # positions in `timeline` already assigned to a question

    for q in questions:
        asked_at = q["question_time_sec"]
        due_at = q["answer_event_time_sec"]
        truth = q["ground_truth"]

        latest_matchable = due_at + time_window
        if asked_at == due_at:
            on_time_until = due_at
        else:
            on_time_until = due_at + time_window

        # Defaults used when no response can be matched.
        replied_at = due_at
        reply = ""
        for pos, (stamp, text) in enumerate(timeline):
            if pos in consumed:
                continue
            if not (asked_at <= stamp <= latest_matchable):
                continue
            if text != truth and is_placeholder(text):
                continue
            replied_at, reply = stamp, text
            consumed.add(pos)
            break

        explanation = ""
        if replied_at < due_at:
            score, category = 0.0, "EarlyResponse"
        elif reply != truth and is_placeholder(reply):
            score, category = 0.0, "NoResponse"
        elif replied_at > on_time_until:
            score, category = 0.0, "LateResponse"
        elif q["is_objective"]:
            given, expected = reply.lower(), truth.lower()
            if given == expected or first_letter(reply).lower() == expected:
                score, category = 100.0, "Correct"
            else:
                score, category = 0.0, "WrongAnswer"
        elif llm_judger is None:
            score, category = 0.0, "Error (no LLM)"
        else:
            verdict = llm_judger.judge(q["question"], reply, truth)
            score = verdict["score"] * 20.0
            explanation = verdict.get("explanation", "")
            category = "PartlyCorrect"

        rows.append({
            "sample_id": sample["id"],
            "question_time": seconds_to_time(asked_at),
            "question": q["question"],
            "answer_time": seconds_to_time(due_at),
            "answer": truth,
            "response_time": seconds_to_time(int(replied_at)),
            "response": reply,
            "score": score,
            "category": category,
            "task_type": q["task_type"],
            "is_objective": q["is_objective"],
            "explanation": explanation
        })
    return rows

def load_llm(config):
    """Instantiate the judge client described by *config* and wrap it in an LLMJudger.

    ``config["class"]`` is a dotted path whose last component is the class
    name and whose prefix is the module to import. The optional
    ``config["args"]`` mapping is passed to the class as keyword arguments.
    """
    module_path, class_name = config["class"].rsplit(".", 1)
    client_cls = getattr(importlib.import_module(module_path), class_name)
    client_kwargs = config.get("args", {})
    return LLMJudger(llm_client=client_cls(**client_kwargs))

def worker(samples_chunk, llm_config, time_window):
    """Evaluate one chunk of samples and return their results as a flat list.

    A judge is built through ``load_llm`` for this call when *llm_config* is
    truthy; otherwise samples are evaluated without a judge. Progress over
    the chunk is shown with a tqdm bar.
    """
    judger = None
    if llm_config:
        judger = load_llm(llm_config)

    collected = []
    for current in tqdm.tqdm(samples_chunk):
        collected += evaluate_sample(current, judger, time_window)
    return collected

def main():
    """CLI entry point: evaluate a model-output JSONL file and store the scores.

    Steps:
      1. Read the YAML config (its optional ``judger`` section configures the
         judge LLM) and the JSONL model output, one sample per line.
      2. Evaluate the samples on ``--workers`` threads.
      3. Write the per-question results to ``<output-dir>/<collections>.db``
         in a table named after ``--model-name``, replacing any previous one.
      4. Append a one-row score summary to ``<output-dir>/<collections>.csv``.

    Raises:
        SystemExit: if the evaluation produced no results at all.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", required=True)
    parser.add_argument("--model-output", required=True)
    parser.add_argument("--output-dir", default=None)
    parser.add_argument("--config", default="config/stream_config.yaml")
    parser.add_argument("--workers", type=int, default=32)
    parser.add_argument("--collections", type=str, default="v4.4")
    args = parser.parse_args()

    time_window = 3.0

    output_dir = Path(args.output_dir or Path(args.model_output).parent / "eval")
    output_dir.mkdir(exist_ok=True, parents=True)

    # Explicit UTF-8: the data contains non-ASCII text and the platform
    # default encoding is not guaranteed to handle it.
    with open(args.config, encoding="utf-8") as f:
        config = yaml.safe_load(f) or {}  # an empty YAML file loads as None
    llm_config = config.get("judger")

    with open(args.model_output, encoding="utf-8") as f:
        samples = [json.loads(line) for line in f if line.strip()]

    # Round-robin split. Empty chunks are dropped so that no judge client is
    # created for a worker that has nothing to do.
    n_workers = max(1, args.workers)
    chunks = [chunk for chunk in (samples[i::n_workers] for i in range(n_workers)) if chunk]
    per_chunk = Parallel(n_jobs=n_workers, backend='threading')(
        delayed(worker)(chunk, llm_config, time_window) for chunk in chunks
    )
    all_results = [item for sublist in per_chunk for item in sublist]

    if not all_results:
        # Without rows there is no "score" column to aggregate and nothing
        # meaningful to store; fail with a clear message instead.
        raise SystemExit(f"No evaluation results produced from {args.model_output}")

    print("saving")
    df = pd.DataFrame(all_results)
    db_path = output_dir / f"{args.collections}.db"
    conn = sqlite3.connect(db_path)
    try:
        df.to_sql(f"{args.model_name}", conn, if_exists="replace", index=False)
        conn.commit()
    finally:
        conn.close()

    summary = {
        "model_name": args.model_name,
        "#samples": len(df),
        "final_score": df["score"].mean(),
    }

    def mean_score(frame):
        # Mean of the "score" column; 0.0 for an empty frame (avoids NaN).
        return frame["score"].mean() if len(frame) else 0.0

    objective_df = df[df["is_objective"]]
    subjective_df = df[~df["is_objective"]]
    by_kind = (("objective", objective_df), ("subjective", subjective_df))

    # Overall mean per question kind.
    for kind, subset in by_kind:
        summary[kind] = round(mean_score(subset), 1)

    # Mean per task type within each question kind.
    task_types = df["task_type"].dropna().unique()
    for kind, subset in by_kind:
        for typ in task_types:
            summary[f"{typ}({kind})"] = round(mean_score(subset[subset["task_type"] == typ]), 1)

    # Category distribution for "future" questions, formatted "<share>%(<mean score>)".
    categories = df["category"].unique()
    future_df = df[df["task_type"] == "future"]
    future_by_kind = (
        ("objective-future", future_df[future_df["is_objective"]]),
        ("subjective-future", future_df[~future_df["is_objective"]]),
    )
    for subset_name, subset in future_by_kind:
        for cat in categories:
            rows = subset[subset["category"] == cat]
            percent = round(len(rows) / len(subset) * 100, 1) if len(subset) else 0.0
            summary[f"{cat}({subset_name})"] = f"{percent}%({round(mean_score(rows), 1)})"

    # Append to the running summary file if one exists already.
    csv_path = output_dir / f"{args.collections}.csv"
    out_df = pd.DataFrame([summary])
    if csv_path.exists():
        out_df = pd.concat([pd.read_csv(csv_path), out_df], ignore_index=True)
    out_df.round(1).to_csv(csv_path, index=False)

    print(f"✅ Done. Summary saved to: {csv_path}")


if __name__ == "__main__":
    main()
