#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import json
import os
import csv
from typing import List, Dict, Any
import numpy as np
from ds import re_response_pro_con_extraction, re_response_number_extraction
from collections import Counter


def word_count(s: str) -> int:
    """Return the number of whitespace-separated words in *s*.

    Anything that is not a non-empty ``str`` (``None``, numbers, ``""``)
    counts as zero words.
    """
    if not isinstance(s, str) or not s:
        return 0
    return len(s.split())


def percent(x: int, total: int) -> float:
    """Express *x* as a percentage of *total*; 0.0 unless *total* is positive."""
    if total > 0:
        return 100.0 * x / total
    return 0.0


def describe(nums: List[int]) -> Dict[str, float]:
    """Summarize *nums* with count, mean, population std, min/max and quartiles.

    An empty input yields ``count == 0`` and ``np.nan`` for every other
    statistic. Otherwise ``count`` is an ``int`` and the rest are ``float``.
    """
    if not nums:
        nan_stats = dict.fromkeys(
            ("mean", "std", "min", "p25", "median", "p75", "max"), np.nan
        )
        return {"count": 0, **nan_stats}

    values = np.asarray(nums, dtype=float)
    q25, q50, q75 = (float(np.percentile(values, q)) for q in (25, 50, 75))
    return {
        "count": int(values.size),
        "mean": float(values.mean()),
        "std": float(values.std(ddof=0)),  # population std (ddof=0), as before
        "min": float(values.min()),
        "p25": q25,
        "median": q50,
        "p75": q75,
        "max": float(values.max()),
    }


def load_data(path: str) -> List[Dict[str, Any]]:
    """Parse and return the UTF-8 JSON document stored at *path*."""
    with open(path, encoding="utf-8") as handle:
        payload = json.load(handle)
    return payload


def _side_rows(t_idx: int, topic: Any, side: str,
               opinions: List[Dict[str, Any]]) -> tuple:
    """Build opinion-level and expansion-level CSV rows for one side of a topic.

    Returns ``(opinion_rows, expanded_rows)``, both in input order.
    """
    opinion_rows: List[Dict[str, Any]] = []
    expanded_rows: List[Dict[str, Any]] = []
    for o_idx, op in enumerate(opinions):
        expanded_list = op.get("expanded", []) or []
        exp_lens = [word_count(x) for x in expanded_list]
        avg_exp_len = float(np.mean(exp_lens)) if exp_lens else 0.0

        opinion_rows.append({
            "topic_index": t_idx,
            "topic": topic,
            "side": side,
            "opinion_index": o_idx,
            "opinion_length_words": word_count(op.get("opinion", "")),
            "num_expanded": len(expanded_list),
            "avg_expanded_length_words": round(avg_exp_len, 2),
        })
        expanded_rows.extend(
            {
                "topic_index": t_idx,
                "topic": topic,
                "side": side,
                "opinion_index": o_idx,
                "expanded_index": e_idx,
                "expanded_length_words": ex_len,
            }
            for e_idx, ex_len in enumerate(exp_lens)
        )
    return opinion_rows, expanded_rows


def gather_stats(data: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Compute dataset-level statistics and per-row CSV records.

    Each item of *data* is a topic dict with optional ``"topic"``, ``"pros"``
    and ``"cons"`` keys. Each pro/con is a dict with an optional ``"opinion"``
    text and an optional ``"expanded"`` list of texts. Lengths are word
    counts as computed by ``word_count``.

    Returns a dict with:
      * totals: ``topic_count``, ``total_pro_opinions``, ``total_con_opinions``,
        ``total_opinions``, ``pro_ratio_percent``, ``con_ratio_percent``;
      * ``describe()`` summaries: ``topic_pro_desc``, ``topic_con_desc``,
        ``topic_total_desc``, ``opinion_len_desc_words``,
        ``expanded_len_desc_words``;
      * CSV rows: ``topic_rows``, ``opinion_rows``, ``expanded_rows``.
    """
    topic_rows: List[Dict[str, Any]] = []
    opinion_rows: List[Dict[str, Any]] = []
    expanded_rows: List[Dict[str, Any]] = []

    for t_idx, item in enumerate(data):
        topic = item.get("topic", "")
        pros = item.get("pros", []) or []
        cons = item.get("cons", []) or []

        topic_rows.append({
            "topic_index": t_idx,
            "topic": topic,
            "num_pros": len(pros),
            "num_cons": len(cons),
            "num_opinions_total": len(pros) + len(cons),
        })

        # Pros first, then cons: this fixes the row order of both CSVs.
        for side, opinions in (("pro", pros), ("con", cons)):
            side_opinion_rows, side_expanded_rows = _side_rows(t_idx, topic, side, opinions)
            opinion_rows.extend(side_opinion_rows)
            expanded_rows.extend(side_expanded_rows)

    # Each series below is a column of the rows built above, in the same order.
    topic_pro_counts = [r["num_pros"] for r in topic_rows]
    topic_con_counts = [r["num_cons"] for r in topic_rows]
    topic_total_opinions = [r["num_opinions_total"] for r in topic_rows]
    opinion_lengths = [r["opinion_length_words"] for r in opinion_rows]
    expanded_lengths = [r["expanded_length_words"] for r in expanded_rows]

    total_pro_opinions = sum(topic_pro_counts)
    total_con_opinions = sum(topic_con_counts)
    total_opinions = total_pro_opinions + total_con_opinions

    return {
        "topic_count": len(data),
        "total_pro_opinions": total_pro_opinions,
        "total_con_opinions": total_con_opinions,
        "total_opinions": total_opinions,
        "pro_ratio_percent": percent(total_pro_opinions, total_opinions),
        "con_ratio_percent": percent(total_con_opinions, total_opinions),
        "topic_pro_desc": describe(topic_pro_counts),
        "topic_con_desc": describe(topic_con_counts),
        "topic_total_desc": describe(topic_total_opinions),
        "opinion_len_desc_words": describe(opinion_lengths),
        "expanded_len_desc_words": describe(expanded_lengths),
        "topic_rows": topic_rows,
        "opinion_rows": opinion_rows,
        "expanded_rows": expanded_rows,
    }


def print_summary(stats: Dict[str, Any]) -> None:
    """Print a human-readable report of the dict produced by ``gather_stats``."""
    stat_line = (
        "  count={count}, mean={mean:.2f}, std={std:.2f}, "
        "min={min:.0f}, p25={p25:.0f}, median={median:.0f}, "
        "p75={p75:.0f}, max={max:.0f}"
    )

    def emit(label: str, desc: Dict[str, Any]) -> None:
        # A summary with no samples (count 0 or NaN mean) gets a placeholder line.
        print(f"- {label}:")
        has_data = not (desc["count"] == 0 or np.isnan(desc["mean"]))
        print(stat_line.format(**desc) if has_data else "  (no data)")

    print("=== DATASET SUMMARY ===")
    print(f"# Topics: {stats['topic_count']}")
    print(
        f"# Opinions: {stats['total_opinions']}  "
        f"(pro={stats['total_pro_opinions']}, con={stats['total_con_opinions']})"
    )
    print(f"Pro:Con ratio: {stats['pro_ratio_percent']:.1f}% : {stats['con_ratio_percent']:.1f}%")

    sections = (
        ("\n--- Per-topic opinion counts ---", (
            ("Pros per topic", "topic_pro_desc"),
            ("Cons per topic", "topic_con_desc"),
            ("Total opinions per topic", "topic_total_desc"),
        )),
        ("\n--- Text length (words) ---", (
            ("Opinion length", "opinion_len_desc_words"),
            ("Expanded length", "expanded_len_desc_words"),
        )),
    )
    for heading, entries in sections:
        print(heading)
        for label, key in entries:
            emit(label, stats[key])


def write_csv(path: str, rows: List[Dict[str, Any]], fieldnames: List[str]) -> None:
    """Write *rows* to *path* as UTF-8 CSV with a header of *fieldnames*.

    The parent directory is created if it does not exist. A bare filename
    (no directory component) is written relative to the current directory.
    Rows containing keys not listed in *fieldnames* raise ``ValueError``
    (``csv.DictWriter`` default behavior).
    """
    parent = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so only create a real parent dir.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def get_human_scores_stats(data):
    """Summarize human ratings per evaluation criterion over all opinions.

    Every opinion under a topic's ``"pros"`` and ``"cons"`` may carry a
    ``"scores"`` list; opinions whose scores are missing or empty are ignored.
    Position ``i`` of each scores list is read as the ``i``-th criterion named
    below. The number of scored opinions is printed to stdout.

    Returns:
        Dict mapping criterion name -> ``describe_scores`` summary of that column.

    Raises:
        ValueError: if no scored opinions are found, or the score lists do not
            form a 2-D table with at least one column per criterion.
    """
    criteria = (
        "Fidelity to Original Opinion",
        "Relevance of Added Content",
        "Use of Best Match Post",
        "Naturalness of Writing",
        "Relevance of Best Match Post to the Topic",
    )

    all_ratings = []
    for topic in data:
        for side in ("pros", "cons"):
            for opinion in topic.get(side) or []:
                if opinion.get("scores"):
                    all_ratings.append(opinion["scores"])
    print(len(all_ratings))

    if not all_ratings:
        # np.array([]) is 1-D, so column slicing would fail with an opaque IndexError.
        raise ValueError("get_human_scores_stats: no opinions with non-empty 'scores' found")
    # Ragged score lists raise ValueError here instead of producing an object array.
    scores = np.asarray(all_ratings, dtype=float)
    if scores.ndim != 2 or scores.shape[1] < len(criteria):
        raise ValueError(
            f"get_human_scores_stats: expected each 'scores' list to have at least "
            f"{len(criteria)} numeric entries, got array of shape {scores.shape}"
        )
    return {name: describe_scores(scores[:, col]) for col, name in enumerate(criteria)}


def describe_scores(scores: np.ndarray) -> Dict[str, Any]:
    """Return mean, population std, min, max, median and quartiles of *scores*.

    *scores* may be any array-like of numbers (it is flattened by the NumPy
    reductions). All values are plain Python ``float`` so the result is
    JSON-serializable, matching ``describe``. An empty input yields ``nan``
    for every statistic instead of raising.
    """
    keys = ("mean", "std", "min", "max", "median", "p25", "p75")
    arr = np.asarray(scores, dtype=float)
    if arr.size == 0:
        # np.min/np.max raise on empty arrays; mirror describe() and report NaN.
        return dict.fromkeys(keys, float("nan"))
    return {
        "mean": float(np.mean(arr)),
        "std": float(np.std(arr)),  # population std (ddof=0), as before
        "min": float(np.min(arr)),
        "max": float(np.max(arr)),
        "median": float(np.median(arr)),
        "p25": float(np.percentile(arr, 25)),
        "p75": float(np.percentile(arr, 75)),
    }

def main():
    """CLI entry point: summarize a dataset JSON file and optionally export CSVs.

    ``--input`` is the JSON file to analyze. When ``--outdir`` is given,
    topic-, opinion- and expansion-level CSVs are written there.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--input", required=True, help="Path to input JSON file.")
    cli.add_argument("--outdir", default=None, help="Optional output directory for CSVs.")
    opts = cli.parse_args()

    stats = gather_stats(load_data(opts.input))
    print_summary(stats)

    if not opts.outdir:
        return

    csv_specs = (
        ("topic_stats.csv", "topic_rows",
         ["topic_index", "topic", "num_pros", "num_cons", "num_opinions_total"]),
        ("opinion_stats.csv", "opinion_rows",
         ["topic_index", "topic", "side", "opinion_index", "opinion_length_words",
          "num_expanded", "avg_expanded_length_words"]),
        ("expanded_stats.csv", "expanded_rows",
         ["topic_index", "topic", "side", "opinion_index", "expanded_index",
          "expanded_length_words"]),
    )
    for filename, rows_key, fields in csv_specs:
        write_csv(os.path.join(opts.outdir, filename), stats[rows_key], fields)
    print(f"\nCSV written to: {opts.outdir}/topic_stats.csv, opinion_stats.csv, expanded_stats.csv")


def incorrect_sample_filtering(data, task):
    """Return the prompt text of every item whose extracted answer differs from ``gt``.

    The answer is pulled from ``item["output"]``. For the polarity task this
    uses ``re_response_pro_con_extraction``; for every other task it uses
    ``re_response_number_extraction``. The result is compared with
    ``item["gt"]``. The returned value for a mismatching item is
    ``item["dialog"]["messages"][1]["content"]``, presumably the user turn
    (index 0 likely being the system prompt; TODO confirm).
    """
    # "polatiry_check" is the historical (misspelled) task key and stays
    # supported. The correct spelling used to fall through to the numeric
    # extractor silently.
    if task in ("polatiry_check", "polarity_check"):
        re_find = re_response_pro_con_extraction
    else:
        re_find = re_response_number_extraction
    return [
        item["dialog"]["messages"][1]["content"]
        for item in data
        if re_find(item["output"]) != item["gt"]
    ]


def create_incorrect_sample_for_all_models(root, task, required_count=None, out_dir=""):
    """Find prompts that the models in *root* all answered incorrectly for *task*.

    Each ``*.json`` file directly under *root* is treated as one model's
    prediction list. Its mismatching prompts come from
    ``incorrect_sample_filtering``. A prompt repeated inside one model's file
    is counted once for that model.

    Args:
        root: directory holding one prediction JSON file per model.
        task: task name forwarded to ``incorrect_sample_filtering``; also used
            in the output filename.
        required_count: number of models that must have failed a prompt for it
            to be kept. Defaults to the number of model files found, i.e. "all
            models" (this was previously hard-coded to 12).
        out_dir: directory for ``{task}_incorrect_samples.json``. Defaults to
            the current working directory, as before.

    Returns:
        The selected prompts, in first-seen order (models visited in sorted
        filename order). The count is printed and the list is written to disk.
    """
    model_files = sorted(
        fn for fn in os.listdir(root)
        if fn.lower().endswith(".json") and os.path.isfile(os.path.join(root, fn))
    )

    counter = Counter()
    for fn in model_files:
        with open(os.path.join(root, fn), "r", encoding="utf-8") as f:
            predictions = json.load(f)
        wrong = incorrect_sample_filtering(predictions, task)
        # dict.fromkeys de-duplicates while keeping order, so each model
        # contributes at most one count per prompt.
        counter.update(list(dict.fromkeys(wrong)))

    if required_count is None:
        required_count = len(model_files)

    out = [prompt for prompt, n in counter.items() if n == required_count]
    print(len(out))

    with open(os.path.join(out_dir, f"{task}_incorrect_samples.json"), "w", encoding="utf-8") as f:
        json.dump(out, f)
    return out


import os
import json
from collections import defaultdict

def _write_json_file(path, obj):
    """Write *obj* as pretty-printed UTF-8 JSON to *path*, creating its parent dir."""
    parent = os.path.dirname(path)
    # os.makedirs("") raises, so skip directory creation for bare filenames.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=2)


def collect_all_models_unanimous_incorrect(root_dir, out_json_path):
    """Collect samples that every model answered incorrectly and save them as JSON.

    Every ``*.json`` file directly inside *root_dir* is treated as one model's
    result list. Files that fail to parse, or are not lists, are skipped with a
    warning and do not count as models. Each item must provide
    ``item["dialog"]["messages"][1]["content"]``, which is used as the sample
    id. It may also provide ``"gt"``, ``"extracted"`` and ``"output"``; items
    lacking the id path are skipped. A model is wrong on a sample when
    ``extracted != gt``.

    A sample is kept only if all three conditions hold:
      (1) every model reports the same ``gt`` for it;
      (2) it appears in the results of every successfully loaded model;
      (3) all of those models got it wrong.

    Args:
        root_dir: directory holding one result JSON file per model.
        out_json_path: output path; its parent directory is created if missing.

    Returns:
        List of ``{"id", "gt", "models": [{"model", "extracted", "output"}, ...]}``
        dicts, with ``models`` sorted by file name. The same list is written to
        *out_json_path*, and a short report is printed.
    """
    model_files = [
        os.path.join(root_dir, fn) for fn in os.listdir(root_dir)
        if fn.lower().endswith(".json") and os.path.isfile(os.path.join(root_dir, fn))
    ]
    if not model_files:
        print(f"[WARN] No JSON files found in: {root_dir}")
        _write_json_file(out_json_path, [])
        return []

    # sample id -> {"gts": distinct gt values seen, "by_model": {model name: prediction}}.
    # "gts" is a list (with an equality check) rather than a set, so that
    # unhashable gt values such as lists do not crash the run.
    samples_index = defaultdict(lambda: {"gts": [], "by_model": {}})
    # Only files that actually loaded count toward the coverage requirement.
    # Counting skipped files would make every sample look partially covered.
    num_models = 0

    # Read every model's results and merge them per sample.
    for mf in model_files:
        model_name = os.path.basename(mf)
        with open(mf, "r", encoding="utf-8") as f:
            try:
                data = json.load(f)
            except (json.JSONDecodeError, UnicodeDecodeError) as e:
                print(f"[WARN] Skip invalid JSON: {mf} ({e})")
                continue

        if not isinstance(data, list):
            print(f"[WARN] File is not a list, skip: {mf}")
            continue
        num_models += 1

        for item in data:
            try:
                sample_id = item["dialog"]["messages"][1]["content"]
                gt = item.get("gt")
                extracted = item.get("extracted")
                output = item.get("output")
            except (KeyError, IndexError, TypeError, AttributeError):
                # Skip items whose structure does not match the expected schema.
                continue

            entry = samples_index[sample_id]
            if gt not in entry["gts"]:
                entry["gts"].append(gt)
            entry["by_model"][model_name] = {
                "extracted": extracted,
                "output": output,
                "is_incorrect": (extracted != gt),
            }

    # Filter: same gt everywhere, present for all models, and wrong for all models.
    unanimous_incorrect = []
    gt_conflict_count = 0
    partial_coverage_count = 0

    for sample_id, info in samples_index.items():
        if len(info["gts"]) != 1:
            gt_conflict_count += 1
            continue
        if len(info["by_model"]) != num_models:
            partial_coverage_count += 1
            continue
        if not all(v["is_incorrect"] for v in info["by_model"].values()):
            continue

        unanimous_incorrect.append({
            "id": sample_id,
            "gt": info["gts"][0],
            # Sorted by model file name for stable output.
            "models": [
                {"model": name, "extracted": pred["extracted"], "output": pred["output"]}
                for name, pred in sorted(info["by_model"].items())
            ],
        })

    _write_json_file(out_json_path, unanimous_incorrect)

    # Console report.
    print(f"[INFO] Models detected: {num_models}")
    print(f"[INFO] Samples total: {len(samples_index)}")
    print(f"[INFO] Unanimous incorrect: {len(unanimous_incorrect)}")
    if gt_conflict_count:
        print(f"[WARN] Skipped due to GT conflicts: {gt_conflict_count}")
    if partial_coverage_count:
        print(f"[WARN] Skipped due to partial coverage (<{num_models} models present): {partial_coverage_count}")

    return unanimous_incorrect

# Usage example:
# result = collect_all_models_unanimous_incorrect(
#     root_dir="/path/to/your/models/jsons",
#     out_json_path="/path/to/save/unanimous_incorrect.json"
# )


if __name__ == "__main__":
    # NOTE(review): main() (the --input/--outdir CLI defined above) is not
    # invoked here, so running this file directly does nothing. Call main()
    # or one of the helper functions explicitly when needed.
    pass