
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import re
import csv
import math
import json
import time
import argparse
import random
from pathlib import Path
from typing import List, Tuple, Optional, Dict

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms

# ---------------------------
# Helpers
# ---------------------------

# Lower-case file suffixes that is_image() accepts as image files.
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tiff"}

def is_image(p: Path) -> bool:
    """Return True when the suffix of ``p``, compared case-insensitively, is in IMG_EXTS."""
    suffix = p.suffix.lower()
    return suffix in IMG_EXTS

def default_preprocess(img_size: int = 224):
    """
    Build the evaluation transform: bicubic resize, center crop, tensor conversion
    and normalization.

    Args:
        img_size: Side length of the square center crop.

    Returns:
        A ``transforms.Compose`` pipeline to apply to a PIL image.
    """
    # CenterCrop zero-pads any image that is smaller than the crop, so the resize
    # target must never be below the crop size. For img_size <= 256 this keeps the
    # usual 256 resize; larger crops resize to img_size instead of padding.
    resize_size = max(256, img_size)
    # Standard ImageNet normalization, commonly used with ViT backbones like DINOv2.
    return transforms.Compose([
        transforms.Resize(resize_size, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

def seed_everything(seed: int = 42):
    """Seed Python's ``random`` module, torch, and all torch CUDA devices with ``seed``."""
    seeders = (random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)

# ---------------------------
# Dataset
# ---------------------------

class ImageRowsDataset(Dataset):
    """
    Dataset over row dicts describing an image path and its (model, period) tags.

    Each row must contain "image_path" and may contain "model" and "period".
    When either tag is missing or empty it is inferred from the two directories
    directly above the image file:
       .../MODEL/PERIOD/FILE
    and falls back to "unknown_model" / "unknown_period" if that is not possible.
    """
    def __init__(self, rows: List[Dict[str, str]], preprocess, images_root: Optional[Path] = None):
        # rows: one dict per image; preprocess: callable applied to the RGB PIL image;
        # images_root: base directory used to resolve relative image paths.
        self.rows = rows
        self.preprocess = preprocess
        self.images_root = images_root

    def __len__(self):
        return len(self.rows)

    def _resolve_path(self, p: str) -> Path:
        """Return ``p`` as a Path, joined onto ``images_root`` when it is relative."""
        pth = Path(p)
        if not pth.is_absolute() and self.images_root is not None:
            pth = self.images_root / p
        return pth

    @staticmethod
    def infer_model_period_from_path(path: Path) -> Tuple[Optional[str], Optional[str]]:
        """
        Infer (model, period) from the two directories directly above the file.

        e.g. /data/Flux_Schnell/21st_century/img.png -> ("Flux_Schnell", "21st_century")

        Returns (None, None) when the path has fewer than three components.
        """
        parts = path.parts
        if len(parts) >= 3:
            # parts[-1] is the file name; the two components before it are the tags.
            return parts[-3], parts[-2]
        return None, None

    def __getitem__(self, idx):
        """Return ``(preprocessed_image, meta)`` for row ``idx``.

        ``meta`` holds the resolved "image_path" (as str), "model" and "period".
        """
        row = self.rows[idx]
        img_path = self._resolve_path(row["image_path"])

        model = row.get("model")
        period = row.get("period")

        # None and "" are both treated as "not provided".
        if not model or not period:
            inf_model, inf_period = self.infer_model_period_from_path(img_path)
            model = model or inf_model or "unknown_model"
            period = period or inf_period or "unknown_period"

        # Open inside a context manager so the underlying file is closed promptly;
        # convert() returns a loaded image that stays valid after the file is closed.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        x = self.preprocess(img)
        meta = {
            "image_path": str(img_path),
            "model": model,
            "period": period,
        }
        return x, meta

# ---------------------------
# Backbone + Head
# ---------------------------

def load_dinov2(backbone_name: str):
    """
    Loads a DINOv2 backbone via torch.hub and puts it in eval mode.
    Valid names include: dinov2_vits14, dinov2_vitb14, dinov2_vitl14, dinov2_vitg14.

    Raises:
        RuntimeError: if torch.hub fails to load the model; the original
            exception is attached as the cause.
    """
    try:
        backbone = torch.hub.load('facebookresearch/dinov2', backbone_name)
    except Exception as e:
        # Chain explicitly so the underlying hub/network error stays in the traceback.
        raise RuntimeError(
            f"Failed to load backbone '{backbone_name}' via torch.hub. "
            f"Ensure internet is available or that the model is cached. Error: {e}"
        ) from e
    backbone.eval()
    return backbone

def build_linear_head_from_checkpoint(head_ckpt_path: Path, in_dim: Optional[int] = None, num_classes: Optional[int] = None):
    """
    Build an ``nn.Linear`` head and load its weights from a state-dict checkpoint.

    The head dimensions are taken from the checkpoint's weight tensor whenever one
    can be found, so the returned dimensions always describe the loaded head and
    callers can compare them with their own expectations. ``in_dim`` and
    ``num_classes`` are used only as a fallback when no 2-D tensor is present.

    Args:
        head_ckpt_path: Path to a checkpoint holding the linear layer's state dict.
        in_dim: Fallback input feature dimension.
        num_classes: Fallback number of output classes.

    Returns:
        ``(head, in_dim, num_classes)`` with ``head`` on CPU and in eval mode.

    Raises:
        ValueError: if the dimensions can be neither inferred nor taken from the arguments.
    """
    # NOTE(security): torch.load is pickle-based; only load checkpoints from a trusted source.
    sd = torch.load(head_ckpt_path, map_location="cpu")

    # nn.Linear stores its weight as [num_classes, in_dim]. Prefer the "weight" entry
    # and otherwise fall back to the first 2-D tensor in the checkpoint.
    weight = sd.get("weight")
    if not (isinstance(weight, torch.Tensor) and weight.ndim == 2):
        weight = next(
            (v for v in sd.values() if isinstance(v, torch.Tensor) and v.ndim == 2),
            None,
        )

    if weight is not None:
        num_classes, in_dim = (int(d) for d in weight.shape)
    elif in_dim is None or num_classes is None:
        raise ValueError("Could not infer in_dim/num_classes from checkpoint. Specify them explicitly.")

    head = nn.Linear(in_dim, num_classes, bias=True)
    head.load_state_dict(sd, strict=True)
    head.eval()
    return head, in_dim, num_classes

# ---------------------------
# Inference + Metrics
# ---------------------------

@torch.no_grad()
def extract_features(backbone, batch: torch.Tensor, device: torch.device):
    """
    Move ``batch`` to ``device``, run ``backbone`` on it without gradient tracking
    and return the result. If the backbone returns a list or tuple, only its
    first element is returned.
    """
    outputs = backbone(batch.to(device, non_blocking=True))
    return outputs[0] if isinstance(outputs, (list, tuple)) else outputs

@torch.no_grad()
def run_inference(
    backbone_name: str,
    head_path: Path,
    image_rows: List[Dict[str, str]],
    images_root: Optional[Path],
    label_names: List[str],
    batch_size: int = 32,
    num_workers: int = 4,
    device: Optional[str] = None,
    img_size: int = 224,
    in_dim: Optional[int] = None,
):
    """
    Classify every image in ``image_rows`` with a DINOv2 backbone plus linear head.

    Args:
        backbone_name: Name passed to ``load_dinov2``.
        head_path: Checkpoint of the linear head.
        image_rows: Row dicts consumed by ``ImageRowsDataset``.
        images_root: Base directory for relative image paths.
        label_names: Class names, indexed by the head's output position.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker count.
        device: Device string; defaults to "cuda" when available, else "cpu".
        img_size: Center-crop size passed to ``default_preprocess``.
        in_dim: Expected head input dimension; defaults to the backbone's
            ``embed_dim`` attribute when it has one.

    Returns:
        One dict per image with keys "image_path", "model", "period",
        "pred_idx", "pred_label" and "prob" (softmax probability of the prediction).

    Raises:
        ValueError: if the number of label names differs from the number of
            classes reported for the head.
    """
    device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
    preprocess = default_preprocess(img_size=img_size)

    # Load models
    backbone = load_dinov2(backbone_name).to(device)
    # Infer in_dim from backbone if possible
    if in_dim is None:
        in_dim = getattr(backbone, "embed_dim", None)
    head, in_dim_ckpt, num_classes = build_linear_head_from_checkpoint(head_path, in_dim=in_dim, num_classes=len(label_names))
    head = head.to(device)

    # Sanity checks
    if len(label_names) != num_classes:
        raise ValueError(f"Number of label names ({len(label_names)}) != num_classes in head ({num_classes}).")
    if in_dim is not None and in_dim != in_dim_ckpt:
        print(f"[WARN] Provided in_dim ({in_dim}) != checkpoint in_dim ({in_dim_ckpt}). Proceeding with checkpoint in_dim.")

    # Data. Pinned host memory only helps host-to-GPU copies, so skip it on CPU.
    ds = ImageRowsDataset(image_rows, preprocess, images_root=images_root)
    dl = DataLoader(
        ds,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=(device.type == "cuda"),
    )

    # Inference
    per_image_rows = []
    for xb, meta_batch in dl:
        feats = extract_features(backbone, xb, device)
        logits = head(feats)
        probs = torch.softmax(logits, dim=-1).cpu()
        # One bulk conversion instead of a .item() call per image.
        pred_indices = probs.argmax(dim=-1).tolist()

        batch_items = zip(
            meta_batch["image_path"],
            meta_batch["model"],
            meta_batch["period"],
            pred_indices,
            probs,
        )
        for image_path, model, period, pred_idx, row_probs in batch_items:
            per_image_rows.append({
                "image_path": image_path,
                "model": model,
                "period": period,
                "pred_idx": pred_idx,
                "pred_label": label_names[pred_idx],
                "prob": float(row_probs[pred_idx].item()),
            })

    return per_image_rows

def compute_style_distribution(per_image_rows: List[Dict[str, str]], label_names: List[str]):
    """
    Tally predicted labels for every (model, period) group.

    Returns a tuple ``(counts, totals, distributions)``, each keyed by
    (model, period): per-label counts, number of images in the group, and
    per-label proportions (count divided by the group total).
    """
    counts: Dict[Tuple[str, str], Dict[str, int]] = {}
    totals: Dict[Tuple[str, str], int] = {}
    for row in per_image_rows:
        group = (row["model"], row["period"])
        # Every group starts with a zero count for each known label.
        label_counts = counts.setdefault(group, dict.fromkeys(label_names, 0))
        totals.setdefault(group, 0)
        label_counts[row["pred_label"]] += 1
        totals[group] += 1

    distributions = {
        group: {lbl: label_counts[lbl] / max(1, totals[group]) for lbl in label_names}
        for group, label_counts in counts.items()
    }
    return counts, totals, distributions

def bootstrap_vsd(per_image_rows: List[Dict[str, str]], label_names: List[str], n_boot: int = 1000, seed: int = 42):
    """
    For each (model, period), compute VSD = max_s P(s|t) and a 95% CI via bootstrap.

    Args:
        per_image_rows: Rows with "model", "period" and "pred_label" keys.
        label_names: Labels over which the proportions are computed.
        n_boot: Number of bootstrap resamples. When it is <= 0 no resampling
            is done and both CI bounds equal the point estimate.
        seed: Seed for the resampling generator.

    Returns:
        Dict keyed by (model, period) with "vsd", "ci_low", "ci_high" and
        "n" (number of images in the group).
    """
    # A private generator seeded with ``seed`` produces the same draws as seeding
    # the module-level generator, while leaving global RNG state untouched.
    rng = random.Random(seed)

    # Group predicted labels by (model, period); every group holds at least one label.
    groups: Dict[Tuple[str, str], List[str]] = {}
    for r in per_image_rows:
        groups.setdefault((r["model"], r["period"]), []).append(r["pred_label"])

    results = {}
    for key, labels in groups.items():
        n = len(labels)

        # Empirical VSD: the largest label proportion in the group.
        base_counts = dict.fromkeys(label_names, 0)
        for lbl in labels:
            base_counts[lbl] += 1
        vsd = max(base_counts[lbl] / n for lbl in label_names)

        # Bootstrap: resample n labels with replacement and record the largest proportion.
        boot_stats = []
        for _ in range(n_boot):
            resampled = dict.fromkeys(label_names, 0)
            for _ in range(n):
                resampled[labels[rng.randrange(n)]] += 1
            boot_stats.append(max(resampled.values()) / n)

        if boot_stats:
            boot_stats.sort()
            lo_idx = int(0.025 * n_boot)
            # Clamp so a very small n_boot cannot yield a negative index.
            hi_idx = max(int(0.975 * n_boot) - 1, 0)
            ci_low = boot_stats[lo_idx]
            ci_high = boot_stats[hi_idx]
        else:
            # No resamples requested: report a degenerate interval.
            ci_low = ci_high = vsd

        results[key] = {"vsd": vsd, "ci_low": ci_low, "ci_high": ci_high, "n": n}

    return results

# ---------------------------
# IO
# ---------------------------

def read_rows_from_csv(csv_path: Path) -> List[Dict[str, Optional[str]]]:
    """
    Read image rows from a CSV file.

    The image path column may be named "image_path", "path" or "file"; the
    optional "model" and "period" columns are picked up when present. Header
    names are matched case-insensitively and ignoring surrounding whitespace.

    Returns:
        One dict per data row with keys "image_path", "model" and "period";
        "model" / "period" are None when the column is absent.

    Raises:
        ValueError: if no image path column is found (including an empty file).
    """
    rows = []
    # "utf-8-sig" drops a leading BOM (common in spreadsheet exports) that would
    # otherwise become part of the first header name.
    with open(csv_path, "r", newline="", encoding="utf-8-sig") as f:
        reader = csv.DictReader(f)
        # fieldnames is None for an empty file; treat that as "no columns".
        col_map = {name.strip().lower(): name for name in (reader.fieldnames or [])}
        img_col = col_map.get("image_path") or col_map.get("path") or col_map.get("file")
        model_col = col_map.get("model")
        period_col = col_map.get("period")
        if not img_col:
            raise ValueError("CSV must include an 'image_path' column (or 'path'/'file').")

        for row in reader:
            rows.append({
                "image_path": row[img_col],
                "model": row.get(model_col) if model_col else None,
                "period": row.get(period_col) if period_col else None,
            })
    return rows

def collect_image_rows_from_dir(root: Path) -> List[Dict[str, str]]:
    """
    Recursively collect one row per image file under ``root``.

    The (model, period) tags of each row are inferred from the file's path via
    ``ImageRowsDataset.infer_model_period_from_path``, falling back to
    "unknown_model" / "unknown_period".

    Returns:
        Rows with "image_path", "model" and "period", ordered by sorted path.
    """
    rows = []
    # Sort the traversal so the row order (and every CSV derived from it) is
    # reproducible rather than dependent on filesystem listing order.
    for p in sorted(root.rglob("*")):
        if p.is_file() and is_image(p):
            model, period = ImageRowsDataset.infer_model_period_from_path(p)
            rows.append({
                "image_path": str(p),
                "model": model or "unknown_model",
                "period": period or "unknown_period",
            })
    return rows

def write_csv(path: Path, rows: List[Dict[str, object]], fieldnames: List[str]):
    """
    Write ``rows`` to ``path`` as CSV with a header row, creating parent
    directories as needed. Only the keys listed in ``fieldnames`` are written;
    keys missing from a row are written as empty strings.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(
            {name: row.get(name, "") for name in fieldnames}
            for row in rows
        )

# ---------------------------
# Main
# ---------------------------

def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the classification / VSD script."""
    parser = argparse.ArgumentParser(description="Classify images with DINOv2 + linear head into 5 styles and compute VSD scores with 95% CI.")
    add = parser.add_argument
    add("--images-root", type=str, required=True, help="Root folder for images (and base path for relative CSV paths).")
    add("--head-path", type=str, required=True, help="Path to your trained linear head .pt (e.g., linear_head_final.pt).")
    add("--backbone", type=str, default="dinov2_vitb14", help="DINOv2 backbone name (dinov2_vits14|dinov2_vitb14|dinov2_vitl14|dinov2_vitg14).")
    add("--labels", type=str, nargs="+", default=["drawing","engraving","illustration","painting","photography"], help="Label names in the EXACT order used during training.")
    add("--csv", type=str, default=None, help="Optional CSV with columns: image_path, model, period. If omitted, will scan images-root and infer (model, period) from parent dirs.")
    add("--out-dir", type=str, default="vsd_outputs", help="Output directory for CSV results.")
    add("--batch-size", type=int, default=32)
    add("--num-workers", type=int, default=4)
    add("--device", type=str, default=None, help="cuda|cpu (default auto)")
    add("--img-size", type=int, default=224)
    add("--bootstrap", type=int, default=1000, help="Number of bootstrap resamples for CI (default 1000).")
    add("--seed", type=int, default=42)
    add("--in-dim", type=int, default=None, help="Optional: specify head input dim if auto-infer fails.")
    return parser


def main():
    """
    Command-line entry point: gather image rows, run inference, and write three
    CSV files (per-image predictions, style distributions, VSD summary) to the
    output directory.
    """
    args = _build_arg_parser().parse_args()
    labels = args.labels

    seed_everything(args.seed)

    images_root = Path(args.images_root).resolve()
    head_path = Path(args.head_path).resolve()
    out_dir = Path(args.out_dir).resolve()

    # Input rows: from the CSV when given, otherwise by scanning images_root.
    if args.csv:
        rows = read_rows_from_csv(Path(args.csv).resolve())
    else:
        rows = collect_image_rows_from_dir(images_root)

    if not rows:
        raise RuntimeError("No images found. Provide a CSV or ensure images exist under --images-root.")

    # Run inference
    per_image = run_inference(
        backbone_name=args.backbone,
        head_path=head_path,
        image_rows=rows,
        images_root=images_root,
        label_names=labels,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        device=args.device,
        img_size=args.img_size,
        in_dim=args.in_dim,
    )

    # Save per-image predictions
    per_image_csv = out_dir / "per_image_predictions.csv"
    write_csv(per_image_csv, per_image, ["image_path", "model", "period", "pred_idx", "pred_label", "prob"])

    # Style distribution: one row per (model, period) with proportions and counts.
    counts, totals, dists = compute_style_distribution(per_image, labels)
    prop_fields = [f"prop_{lbl}" for lbl in labels]
    count_fields = [f"count_{lbl}" for lbl in labels]
    dist_rows = []
    for group, props in dists.items():
        model, period = group
        dist_row = {"model": model, "period": period, "n": totals[group]}
        for lbl in labels:
            dist_row[f"prop_{lbl}"] = props[lbl]
        for lbl in labels:
            dist_row[f"count_{lbl}"] = counts[group][lbl]
        dist_rows.append(dist_row)
    dist_csv = out_dir / "style_distributions.csv"
    write_csv(dist_csv, dist_rows, ["model", "period", "n"] + prop_fields + count_fields)

    # VSD + CI
    vsd_stats = bootstrap_vsd(per_image, labels, n_boot=args.bootstrap, seed=args.seed)
    vsd_rows = [
        {
            "model": model,
            "period": period,
            "n": stats.get("n", 0),
            "vsd": stats["vsd"],
            "ci_low": stats["ci_low"],
            "ci_high": stats["ci_high"],
        }
        for (model, period), stats in vsd_stats.items()
    ]
    vsd_csv = out_dir / "vsd_summary.csv"
    write_csv(vsd_csv, vsd_rows, ["model", "period", "n", "vsd", "ci_low", "ci_high"])

    print("Done. Wrote:")
    for written in (per_image_csv, dist_csv, vsd_csv):
        print(f" - {written}")

if __name__ == "__main__":
    main()
