import os
import json
import argparse
import csv
from typing import Dict, List, Optional, Tuple

import torch
import numpy as np
import matplotlib
# Use headless backend in no-display environment to avoid QXcbConnection error
matplotlib.use("Agg")
import matplotlib.pyplot as plt


def try_infer_total_categories_from_info(source_path: str) -> Optional[int]:
    """
    Try to infer total_categories from ``source_path/masks/info.json``.

    Args:
      source_path: Dataset root directory.

    Returns:
      ``int(info["total_categories"])`` when the file exists, parses as a JSON
      object and contains a convertible ``total_categories`` entry; otherwise
      ``None``. Read/parse/convert failures print a warning and yield ``None``
      (this is a best-effort lookup, callers fall back to other sources).
    """
    info_json_path = os.path.join(source_path, "masks", "info.json")
    if not os.path.isfile(info_json_path):
        return None

    try:
        with open(info_json_path, "r", encoding="utf-8") as f:
            info = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Warning: failed to read {info_json_path}: {e}")
        return None

    if not isinstance(info, dict) or "total_categories" not in info:
        return None

    try:
        return int(info["total_categories"])
    except (TypeError, ValueError, OverflowError) as e:
        # OverflowError: json accepts Infinity, which int() cannot convert.
        print(f"Warning: invalid total_categories in {info_json_path}: {e}")
        return None


def load_labels(model_path: str, total_categories: int) -> Tuple[Dict[int, torch.Tensor], int]:
    """
    Load the per-category ``label.pth`` files stored under ``model_path/mid_result``.

    For each class id in ``range(total_categories)`` the file
    ``class_id_{class_id:03d}_total_categories_{total_categories:03d}_label.pth``
    is tried. Missing files are skipped silently; files that fail to load or do
    not contain a ``torch.Tensor`` are skipped with a printed message.

    Each loaded tensor is flattened (in row-major order), cast to float32 and
    binarized with threshold 0.5 (values > 0.5 become 1, everything else 0).

    Args:
      model_path: Training output directory that contains ``mid_result``.
      total_categories: Number of categories to look for.

    Returns:
      - A dictionary ``{class_id: label_tensor}``; each tensor is 1-D ``torch.int64`` with values in {0, 1}.
      - The number of successfully loaded categories.

    Raises:
      FileNotFoundError: if ``model_path/mid_result`` does not exist.
    """
    stats_counts_path = os.path.join(model_path, "mid_result")
    if not os.path.exists(stats_counts_path):
        raise FileNotFoundError(f"Directory not found: {stats_counts_path}")

    loaded: Dict[int, torch.Tensor] = {}
    for class_id in range(total_categories):
        label_file = os.path.join(
            stats_counts_path,
            f"class_id_{class_id:03d}_total_categories_{total_categories:03d}_label.pth",
        )
        if not os.path.exists(label_file):
            # Missing categories are allowed; skip them.
            continue
        try:
            # NOTE(review): torch.load unpickles; these files are expected to be
            # this project's own training outputs, not untrusted input.
            label_tensor = torch.load(label_file, map_location="cpu")
        except Exception as e:
            print(f"Load failed, skip {label_file}: {e}")
            continue

        if not isinstance(label_tensor, torch.Tensor):
            print(f"Warning: {label_file} is not torch.Tensor, skipped")
            continue

        # reshape (not view): torch.save preserves strides, so a saved tensor may be
        # non-contiguous (e.g. a transposed slice), which view(-1) rejects.
        flat = label_tensor.reshape(-1).to(torch.float32)
        loaded[class_id] = (flat > 0.5).to(torch.int64)  # Binarize to {0, 1}

    return loaded, len(loaded)


def compute_point_category_stats(loaded_labels: Dict[int, torch.Tensor]) -> Tuple[np.ndarray, np.ndarray]:
    """
    Aggregate per-category binary labels into per-point statistics.

    Categories are processed in ascending class-id order and all label tensors
    must have the same number of elements N.

    Returns:
      - counts_per_point: np.ndarray [N] (int64), how many categories mark each point as foreground
      - probs_per_point:  np.ndarray [N] (float64), counts divided by the number of loaded categories

    Raises:
      ValueError: if no labels are given or the label lengths differ.
    """
    if not loaded_labels:
        raise ValueError("No label files loaded")

    ordered_ids = sorted(loaded_labels.keys())
    expected_len = loaded_labels[ordered_ids[0]].numel()

    # First category (in sorted order) whose length disagrees with the first one.
    mismatched = next(
        (cid for cid in ordered_ids if loaded_labels[cid].numel() != expected_len),
        None,
    )
    if mismatched is not None:
        raise ValueError(
            f"Label length for category {mismatched} inconsistent with others: {loaded_labels[mismatched].numel()} vs {expected_len}"
        )

    stacked = torch.stack([loaded_labels[cid] for cid in ordered_ids])  # [C, N]
    n_categories = stacked.shape[0]
    counts = stacked.sum(dim=0).long().cpu().numpy()
    probs = counts.astype(np.float64) / float(n_categories)
    return counts, probs


def save_csv(output_csv: str, counts: np.ndarray, probs: np.ndarray) -> None:
    """
    Write per-point statistics to a CSV file.

    Columns are ``point_index``, ``category_count``, ``probability``; one row
    per point, in index order.

    Args:
      output_csv: Destination path. Its parent directory is created when needed;
        a bare filename (no directory part) is written to the current directory.
      counts: Per-point foreground category counts.
      probs: Per-point foreground probabilities, same length as ``counts``.

    Raises:
      ValueError: if ``counts`` and ``probs`` have different lengths (zip would
        otherwise silently drop rows).
    """
    if len(counts) != len(probs):
        raise ValueError(f"counts and probs length mismatch: {len(counts)} vs {len(probs)}")

    parent = os.path.dirname(output_csv)
    if parent:
        # os.makedirs("") raises, so only create a directory when there is one.
        os.makedirs(parent, exist_ok=True)

    with open(output_csv, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["point_index", "category_count", "probability"])  # Header
        writer.writerows(
            [idx, int(c), float(p)] for idx, (c, p) in enumerate(zip(counts, probs))
        )


def save_hist_png(output_png: str, counts: np.ndarray, max_bins: Optional[int] = None) -> None:
    """
    Save a bar chart of how many points have each foreground-category count.

    Args:
      output_png: Destination image path. Its parent directory is created when
        needed; a bare filename is written to the current directory.
      counts: 1-D array of non-negative integer counts, one entry per point.
      max_bins: Optional largest count value to plot; bars for counts greater
        than ``max_bins`` are omitted. ``None`` plots every count up to the
        observed maximum.
    """
    parent = os.path.dirname(output_png)
    if parent:
        # os.makedirs("") raises, so only create a directory when there is one.
        os.makedirs(parent, exist_ok=True)

    max_count = int(counts.max()) if counts.size > 0 else 0
    if max_bins is not None:
        max_count = max(0, min(max_count, max_bins))

    # Count points for each count value. np.bincount's minlength only pads and
    # never truncates, so slice explicitly to honour max_bins.
    hist = np.bincount(counts, minlength=max_count + 1)[: max_count + 1]
    xs = np.arange(hist.shape[0])

    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.bar(xs, hist, color="#4C72B0")
        ax.set_xlabel("Number of categories belonging to foreground (count)")
        ax.set_ylabel("Number of points (num points)")
        ax.set_title("Distribution of number of categories each point is classified as foreground")
        fig.tight_layout()
        fig.savefig(output_png)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)


def save_summary_json(output_json: str, counts: np.ndarray, probs: np.ndarray, num_loaded_categories: int) -> None:
    """
    Write a JSON summary of the per-point statistics.

    The summary contains the number of points, the number of loaded categories,
    a histogram of counts (keys are count values as strings), and min/max/mean
    count plus mean probability (0 for empty inputs).

    Args:
      output_json: Destination path. Its parent directory is created when
        needed; a bare filename is written to the current directory.
      counts: Per-point foreground category counts (non-negative integers).
      probs: Per-point foreground probabilities.
      num_loaded_categories: Number of categories the statistics are based on.
    """
    parent = os.path.dirname(output_json)
    if parent:
        # os.makedirs("") raises, so only create a directory when there is one.
        os.makedirs(parent, exist_ok=True)

    has_points = counts.size > 0
    hist = np.bincount(counts)
    summary = {
        "num_points": int(counts.size),
        "num_loaded_categories": int(num_loaded_categories),
        "counts_histogram": {str(i): int(v) for i, v in enumerate(hist.tolist())},
        "min_count": int(counts.min() if has_points else 0),
        "max_count": int(counts.max() if has_points else 0),
        "mean_count": float(counts.mean() if has_points else 0.0),
        "mean_probability": float(probs.mean() if probs.size > 0 else 0.0),
    }
    with open(output_json, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)


def _infer_total_categories_from_filenames(mid_dir: str) -> Optional[int]:
    """Return the largest N parsed from ``*total_categories_N_*_label.pth`` names in mid_dir, or None."""
    if not os.path.isdir(mid_dir):
        return None
    candidates: List[int] = []
    for name in os.listdir(mid_dir):
        # Only consider label files that encode total_categories in their name.
        if not (name.endswith("_label.pth") and "total_categories_" in name):
            continue
        tail = name.split("total_categories_")[-1]
        try:
            candidates.append(int(tail.split("_")[0]))
        except ValueError:
            # Filename does not follow the expected pattern; ignore it.
            continue
    return max(candidates) if candidates else None


def main():
    """
    Command-line entry point.

    Resolves ``total_categories`` (``--total_categories``, then
    ``source_path/masks/info.json``, then the largest ``total_categories_N``
    found in ``model_path/mid_result`` filenames), loads the per-category
    labels, and writes a per-point CSV, a count histogram PNG and a JSON
    summary into the output directory.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Statistics of foreground probability for each point in different categories: read label.pth for each category under mid_result, "
            "calculate number of categories and probability each point is classified as foreground, and export CSV and bar chart"
        )
    )
    parser.add_argument("--model_path", type=str, required=True, help="Training output model_path, containing mid_result directory")
    parser.add_argument("--total_categories", type=int, default=None, help="Total number of categories; if not provided, try to infer from source_path/masks/info.json")
    parser.add_argument("--source_path", type=str, default=None, help="Optional, dataset source_path, used to infer total_categories")
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Output directory; default write to model_path/analysis_point_category",
    )
    parser.add_argument(
        "--max_bins",
        type=int,
        default=None,
        help="Maximum bin for histogram (upper limit is actual maximum count); default no filtering",
    )
    args = parser.parse_args()

    model_path = args.model_path
    output_dir = args.output_dir or os.path.join(model_path, "analysis_point_category")

    # Resolve total_categories: explicit flag > info.json > mid_result filenames
    total_categories = args.total_categories
    if total_categories is None and args.source_path:
        total_categories = try_infer_total_categories_from_info(args.source_path)
    if total_categories is None:
        # Fallback: use the largest total_categories encoded in mid_result filenames
        total_categories = _infer_total_categories_from_filenames(os.path.join(model_path, "mid_result"))

    if total_categories is None:
        raise ValueError("Cannot determine total_categories, please provide via --total_categories or --source_path")

    print(f"model_path: {model_path}")
    print(f"total_categories(expected): {total_categories}")

    # Load labels
    loaded_labels, num_loaded = load_labels(model_path, total_categories)
    if num_loaded == 0:
        raise RuntimeError("No label files found for any category, check if *_label.pth exists under mid_result")
    if num_loaded != total_categories:
        print(f"Warning: Expected {total_categories} categories, actually loaded only {num_loaded} categories, will calculate probability based on loaded categories")

    # Calculate statistics
    counts_per_point, probs_per_point = compute_point_category_stats(loaded_labels)

    # Export
    os.makedirs(output_dir, exist_ok=True)
    csv_path = os.path.join(output_dir, "points_category_stats.csv")
    png_path = os.path.join(output_dir, "point_category_count_hist.png")
    json_path = os.path.join(output_dir, "summary.json")

    save_csv(csv_path, counts_per_point, probs_per_point)
    save_hist_png(png_path, counts_per_point, max_bins=args.max_bins)
    save_summary_json(json_path, counts_per_point, probs_per_point, num_loaded)

    # Print output locations and a sample of the per-point data
    print(f"Saved CSV: {csv_path}")
    print(f"Saved bar chart: {png_path}")
    print(f"Saved summary: {json_path}")
    print("Example (first 10 points):")
    for i in range(min(10, counts_per_point.shape[0])):
        print(f"Point {i}: count={int(counts_per_point[i])}, prob={probs_per_point[i]:.4f}")


if __name__ == "__main__":
    main()


