from dataclasses import dataclass
from typing import List, Tuple, Optional
import os
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset
from PIL import Image

from torchvision.datasets import VOCDetection
from torchvision.transforms import functional as TF
from torchvision.models.detection import ssd300_vgg16
from torchvision.models import VGG16_Weights


# ------------------------------- GLOBAL CONFIG (KEEP AS-IS) ------------------------------- #
NUM_CLASSES = 21  # 20 VOC classes + background (label 0)
NUM_SOURCES_DEFAULT = 5  # default number of source pools (K) for get_voc_det_oracle_multi

# SGD / data-loading hyper-parameters.
BATCH_SIZE = 8
EPOCHS = 20
LR = 1e-3
WEIGHT_DECAY = 5e-4
MOMENTUM = 0.9

EVAL_EVERY = 1  # evaluate only when (epoch + 1) is a multiple of this value
VAL_MAX_BATCHES = None  # None = all

# keep behavior similar to your detection script
EVAL_LAST_K_EPOCHS = 5  # mAP is only computed during the final K epochs of a run
GRAD_CLIP_NORM = 10.0  # max gradient norm for clipping; None disables clipping
USE_AMP = False  # mixed precision; only takes effect when running on CUDA
NUM_WORKERS = 0  # DataLoader worker processes
DEBUG_FINITE = True  # raise RuntimeError as soon as a loss value is NaN/Inf


# ------------------------------- VOC class mapping ------------------------------- #
VOC_CLASSES = [
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
]
# Label ids start at 1 so that 0 stays free for the background class (ids 1..20).
CLASS_TO_ID = {cls_name: label for label, cls_name in enumerate(VOC_CLASSES, start=1)}


# ------------------------------- Dataset utils ------------------------------- #
def _parse_voc_annotation(ann: dict, img_w: int, img_h: int):
    """
    Convert one VOC annotation dict into detection target tensors.

    Args:
      ann: annotation dict; objects are read from ann["annotation"]["object"]
           (a single dict or a list of dicts).
      img_w, img_h: image size in pixels, used to clamp the coordinates.

    Returns:
      boxes (Nx4 float32) in 0-based xyxy, clamped to the image
      labels (N int64), ids taken from CLASS_TO_ID
      difficult (N bool)

    NOTE:
      - Objects with an unknown class, a missing/unparsable bndbox, or a box
        that is degenerate after clamping are skipped.
      - No small box removal.
      - Empty targets allowed (tensors with N == 0).
    """
    def _clamp(value, size):
        # Keep the coordinate inside [0, size - 1].
        return max(0.0, min(value, size - 1.0))

    raw_objects = ann.get("annotation", {}).get("object", [])
    if isinstance(raw_objects, dict):
        # A single object is stored as a bare dict rather than a list.
        raw_objects = [raw_objects]

    kept_boxes = []
    kept_labels = []
    kept_difficult = []

    for entry in raw_objects:
        cls_name = entry.get("name", None)
        if cls_name not in CLASS_TO_ID:
            continue

        bndbox = entry.get("bndbox", None)
        if bndbox is None:
            continue

        try:
            # VOC coordinates are 1-based; shift them to 0-based.
            x1, y1, x2, y2 = (
                float(bndbox[key]) - 1.0 for key in ("xmin", "ymin", "xmax", "ymax")
            )
        except Exception:
            continue

        x1 = _clamp(x1, img_w)
        y1 = _clamp(y1, img_h)
        x2 = _clamp(x2, img_w)
        y2 = _clamp(y2, img_h)

        if x2 <= x1 or y2 <= y1:
            continue

        raw_flag = entry.get("difficult", "0")
        try:
            is_difficult = bool(int(raw_flag))
        except Exception:
            is_difficult = False

        kept_boxes.append([x1, y1, x2, y2])
        kept_labels.append(CLASS_TO_ID[cls_name])
        kept_difficult.append(is_difficult)

    if not kept_boxes:
        # torch.tensor([]) would be 1-D, so build the (0, 4) shape explicitly.
        return (
            torch.zeros((0, 4), dtype=torch.float32),
            torch.zeros((0,), dtype=torch.int64),
            torch.zeros((0,), dtype=torch.bool),
        )

    box_tensor = torch.tensor(kept_boxes, dtype=torch.float32)
    label_tensor = torch.tensor(kept_labels, dtype=torch.int64)
    difficult_tensor = torch.tensor(kept_difficult, dtype=torch.bool)
    return box_tensor, label_tensor, difficult_tensor


class DetectionTransformLite:
    """
    Minimal detection transform: optional random horizontal flip (train only)
    followed by to_tensor.

    IMPORTANT: NO resize and NO normalize happen here; the SSD model has its
    own internal transform (resize/normalize).
    """
    def __init__(self, train: bool):
        self.train = train

    def __call__(self, img: Image.Image, target: dict):
        # The random draw only happens in train mode (short-circuit).
        do_flip = self.train and bool(torch.rand(()) < 0.5)

        if do_flip:
            img = TF.hflip(img)
            original_boxes = target["boxes"]
            if original_boxes.numel() > 0:
                # Mirror around the vertical axis of a 0-based pixel grid:
                # new_x1 = (w - 1) - old_x2, new_x2 = (w - 1) - old_x1.
                last_column = img.size[0] - 1.0
                mirrored = original_boxes.clone()
                mirrored[:, 0] = last_column - original_boxes[:, 2]
                mirrored[:, 2] = last_column - original_boxes[:, 0]
                target["boxes"] = mirrored

        tensor_img = TF.to_tensor(img)  # [0,1], no normalize
        return tensor_img, target


class VOCDetWrapper(Dataset):
    """
    Adapter turning VOCDetection samples into (img_tensor, target_dict).

    target_dict carries "boxes", "labels", "difficult" and "image_id".
    image_id is `offset + idx`, so giving each wrapped dataset a distinct
    offset keeps ids unique inside a ConcatDataset.
    """
    def __init__(self, voc_ds: VOCDetection, tfm: DetectionTransformLite, offset: int = 0):
        self.base = voc_ds
        self.tfm = tfm
        self.offset = int(offset)

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx: int):
        pil_img, annotation = self.base[idx]
        width, height = pil_img.size

        boxes, labels, difficult = _parse_voc_annotation(annotation, width, height)
        sample_id = torch.tensor([self.offset + int(idx)], dtype=torch.int64)

        raw_target = {
            "boxes": boxes,
            "labels": labels,
            "difficult": difficult,
            "image_id": sample_id,
        }
        out_img, out_target = self.tfm(pil_img, raw_target)
        return out_img, out_target


def build_full_trainset_and_testset(voc_root: str):
    """
    Build the training pool (VOC2007 trainval + VOC2012 trainval) and the
    VOC2007 test set from data already present under `voc_root`.

    Returns:
      (train_full, testset): a ConcatDataset over the two trainval wrappers
      (train-mode transform) and a wrapper over VOC2007 test (eval-mode
      transform). Nothing is downloaded.
    """
    def _open(year, image_set):
        return VOCDetection(root=voc_root, year=year, image_set=image_set, download=False)

    raw_07_trainval = _open("2007", "trainval")
    raw_12_trainval = _open("2012", "trainval")
    raw_07_test = _open("2007", "test")

    train_aug = DetectionTransformLite(train=True)
    eval_tf = DetectionTransformLite(train=False)

    # VOC2012 image ids continue right after the VOC2007 ones.
    train_parts = [
        VOCDetWrapper(raw_07_trainval, train_aug, offset=0),
        VOCDetWrapper(raw_12_trainval, train_aug, offset=len(raw_07_trainval)),
    ]
    train_full = ConcatDataset(train_parts)
    testset = VOCDetWrapper(raw_07_test, eval_tf, offset=0)
    return train_full, testset


def load_source_pools(split_dir: str, K: int) -> List[np.ndarray]:
    """
    Load the per-source index pools `idxs_source{k}.npy` for k = 0..K-1.

    Returns:
      list of K int64 arrays, in source order.

    Raises:
      FileNotFoundError: at the first source whose split file does not exist.
    """
    def _load_one(source_idx):
        npy_path = os.path.join(split_dir, f"idxs_source{source_idx}.npy")
        if os.path.exists(npy_path):
            return np.load(npy_path).astype(np.int64)
        raise FileNotFoundError(f"Missing split file: {npy_path}")

    return [_load_one(source_idx) for source_idx in range(K)]


class IncrementalCollectorMultiDet:
    """
    Incremental per-source sample collector.

    Each source pool is shuffled once with a fixed seed and samples are always
    taken as a prefix of that order, so successive selections are nested.
    """
    def __init__(self, pools: List[np.ndarray]):
        self.K = len(pools)
        self.ptrs = [0] * self.K
        self.selected = [[] for _ in range(self.K)]
        # Seed 42 + k per source: same as the seg oracle.
        self.pools = [
            pool[np.random.RandomState(42 + src).permutation(len(pool))]
            for src, pool in enumerate(pools)
        ]

    def ensure_size(self, q_targets: List[int]) -> List[int]:
        """
        Grow each source's selection up to q_targets[k] (never shrinks).

        Returns the number of samples newly added per source; this can be
        smaller than requested once a pool is exhausted.
        """
        added = [0] * self.K
        for src in range(self.K):
            chosen = self.selected[src]
            shortfall = int(q_targets[src]) - len(chosen)
            if shortfall <= 0:
                continue

            pool = self.pools[src]
            lo = self.ptrs[src]
            hi = min(lo + shortfall, len(pool))
            fresh = pool[lo:hi].tolist()

            self.ptrs[src] = hi
            chosen.extend(fresh)
            added[src] = len(fresh)
        return added

    def indices_flat(self) -> List[int]:
        """All selected indices, concatenated in source order."""
        return [idx for src in range(self.K) for idx in self.selected[src]]

    def current_q_vec(self) -> List[int]:
        """Number of samples currently selected for each source."""
        return list(map(len, self.selected))


def collate_fn(batch):
    """Turn a batch of (img, target) pairs into (list_of_imgs, list_of_targets)."""
    images, targets = map(list, zip(*batch))
    return images, targets


# ------------------------------- VOC07 11-point mAP@0.5 ------------------------------- #
def _box_iou_np(a, b):
    ax1, ay1, ax2, ay2 = a
    bx1 = b[:, 0]; by1 = b[:, 1]; bx2 = b[:, 2]; by2 = b[:, 3]

    inter_x1 = np.maximum(ax1, bx1)
    inter_y1 = np.maximum(ay1, by1)
    inter_x2 = np.minimum(ax2, bx2)
    inter_y2 = np.minimum(ay2, by2)

    iw = np.maximum(inter_x2 - inter_x1 + 1.0, 0.0)
    ih = np.maximum(inter_y2 - inter_y1 + 1.0, 0.0)
    inter = iw * ih

    area_a = (ax2 - ax1 + 1.0) * (ay2 - ay1 + 1.0)
    area_b = (bx2 - bx1 + 1.0) * (by2 - by1 + 1.0)
    union = area_a + area_b - inter
    return inter / np.maximum(union, 1e-12)


def voc07_ap(rec, prec):
    """
    VOC07 11-point interpolated average precision.

    For each recall threshold t in {0.0, 0.1, ..., 1.0} the best precision
    among points with recall >= t is taken (0.0 if no point reaches t), and
    the 11 values are averaged.

    Args:
      rec:  recall values (array-like, one per ranked detection).
      prec: precision values aligned with `rec`.

    Returns:
      AP as a Python float in [0, 1]; 0.0 for empty input.
    """
    rec = np.asarray(rec)
    prec = np.asarray(prec)

    ap = 0.0
    for step in range(11):
        # Build each threshold as step / 10.0 instead of iterating
        # np.arange(0.0, 1.1, 0.1): arange yields 0.30000000000000004,
        # 0.6000000000000001 and 0.7000000000000001, so a recall of exactly
        # 0.3 / 0.6 / 0.7 would miss its own threshold. A plain Python float
        # is also compared in the array's own precision, so float32 recalls
        # are not penalised by a float64 threshold.
        threshold = step / 10.0
        reached = rec >= threshold
        if reached.any():
            ap += float(prec[reached].max()) / 11.0
    return float(ap)


@torch.no_grad()
def evaluate_map_50(model, loader, device, iou_thr=0.5, max_batches=None):
    """
    Compute VOC07 11-point mAP over classes 1..20 at the given IoU threshold.

    Args:
      model: detector called as model(list_of_images); each output is read as
             a dict with "boxes", "scores" and "labels".
      loader: iterable yielding (imgs, targets) batches; every target needs
              "image_id", "boxes", "labels" and optionally "difficult".
      device: device the images are moved to before the forward pass.
      iou_thr: minimum IoU for a detection to match a ground-truth box.
      max_batches: stop after this many batches (None = whole loader).

    Returns:
      Mean AP over the 20 classes as a float in [0, 1]. A class without any
      non-difficult ground truth contributes an AP of 0.0.

    NOTE:
      - The model is switched to eval mode and is not switched back.
      - Ground truth is keyed by image_id, so ids are expected to be unique
        within `loader`; a repeated id overwrites the earlier entry.
    """
    model.eval()

    gt_by_img = {}
    preds_by_class = {c: [] for c in range(1, 21)}

    # Pass 1: run the model and gather ground truth + detections on the CPU.
    for bi, (imgs, targets) in enumerate(loader):
        if max_batches is not None and bi >= max_batches:
            break

        imgs = [im.to(device) for im in imgs]
        outputs = model(imgs)

        for out, tgt in zip(outputs, targets):
            img_id = int(tgt["image_id"].item())

            gt_boxes = tgt["boxes"].detach().cpu().numpy().astype(np.float32)
            gt_labels = tgt["labels"].detach().cpu().numpy().astype(np.int64)
            # Targets without a "difficult" entry are treated as all non-difficult.
            gt_diff = tgt.get("difficult", torch.zeros((len(gt_labels),), dtype=torch.bool)).detach().cpu().numpy().astype(np.bool_)

            gt_by_img[img_id] = {"boxes": gt_boxes, "labels": gt_labels, "difficult": gt_diff}

            boxes = out["boxes"].detach().cpu().numpy().astype(np.float32)
            scores = out["scores"].detach().cpu().numpy().astype(np.float32)
            labels = out["labels"].detach().cpu().numpy().astype(np.int64)

            # Keep only foreground labels (1..20); anything else is dropped.
            for b, s, l in zip(boxes, scores, labels):
                if 1 <= l <= 20:
                    preds_by_class[l].append((img_id, float(s), b.copy()))

    # Pass 2: per-class AP.
    aps = []
    for cls in range(1, 21):
        preds = preds_by_class[cls]
        # Highest score first.
        preds.sort(key=lambda x: -x[1])

        # npos counts non-difficult ground-truth boxes of this class;
        # "det" marks which ground-truth boxes have already been matched.
        npos = 0
        gt_records = {}
        for img_id, gt in gt_by_img.items():
            m = (gt["labels"] == cls)
            if np.any(m):
                diff = gt["difficult"][m]
                npos += int(np.sum(~diff))
                gt_records[img_id] = {
                    "boxes": gt["boxes"][m],
                    "difficult": diff,
                    "det": np.zeros((int(np.sum(m)),), dtype=np.bool_),
                }

        if npos == 0:
            aps.append(0.0)
            continue

        tp = np.zeros((len(preds),), dtype=np.float32)
        fp = np.zeros((len(preds),), dtype=np.float32)

        for i, (img_id, score, box) in enumerate(preds):
            # Image has no ground truth of this class: false positive.
            if img_id not in gt_records:
                fp[i] = 1.0
                continue

            rec = gt_records[img_id]
            gt_boxes = rec["boxes"]
            if gt_boxes.shape[0] == 0:
                fp[i] = 1.0
                continue

            # Match against the single best-overlapping ground-truth box.
            ious = _box_iou_np(np.asarray(box, dtype=np.float32), gt_boxes.astype(np.float32))
            j = int(np.argmax(ious))
            ovmax = float(ious[j])

            if ovmax >= iou_thr:
                if rec["difficult"][j]:
                    # Matches to difficult boxes count as neither TP nor FP.
                    tp[i] = 0.0
                    fp[i] = 0.0
                else:
                    if not rec["det"][j]:
                        tp[i] = 1.0
                        rec["det"][j] = True
                    else:
                        # Box already claimed by a higher-scoring detection.
                        fp[i] = 1.0
            else:
                fp[i] = 1.0

        # Precision/recall along the ranked detection list.
        tp_cum = np.cumsum(tp)
        fp_cum = np.cumsum(fp)
        rec = tp_cum / float(npos)
        prec = tp_cum / np.maximum(tp_cum + fp_cum, 1e-12)
        aps.append(voc07_ap(rec, prec))

    return float(np.mean(aps))


# ------------------------------- VOCDetSSDOracleMulti ------------------------------- #
@dataclass
class DetTrainCfg:
    """
    Training configuration handed to VOCDetSSDOracleMulti.

    Every field defaults to the module-level constant of the same meaning, so
    editing the global config block changes the defaults here as well.
    """
    epochs: int = EPOCHS
    batch_size: int = BATCH_SIZE
    lr: float = LR
    weight_decay: float = WEIGHT_DECAY
    # Tied to NUM_WORKERS (previously a literal 0) so this default cannot
    # drift away from the global DataLoader setting.
    num_workers: int = NUM_WORKERS
    # Cap on evaluation batches; None evaluates the whole test loader.
    val_max_batches: Optional[int] = VAL_MAX_BATCHES


class VOCDetSSDOracleMulti:
    """
    Multi-source detection oracle: trains a fresh SSD300-VGG16 on the samples
    collected so far and scores it with VOC07 11-point mAP@0.5 (in percent)
    on the VOC2007 test split.

    - ensure_collected(q_vec) -> List[int]
    - current_q_vec() -> List[int]
    - __call__(q_vec) -> float
    - call_both(q_vec) -> Tuple[float,float]
    """
    def __init__(
        self,
        device: str,
        cfg: DetTrainCfg,
        K: int,
        voc_root: str = "./data/VOC",
        split_dir: str = "./data/detection",
    ):
        """
        Args:
          device: device the model is trained/evaluated on (e.g. "cuda", "cpu").
          cfg: training configuration (epochs, batch size, lr, ...).
          K: number of source pools to load from `split_dir`.
          voc_root: root directory of the VOC datasets.
          split_dir: directory holding the `idxs_source{k}.npy` split files.

        Raises:
          FileNotFoundError: if one of the K split files is missing.
        """
        self.device, self.cfg = device, cfg
        self.K = K

        self.voc_root = voc_root
        self.split_dir = split_dir

        self.train_full, self.testset = build_full_trainset_and_testset(self.voc_root)

        pools = load_source_pools(self.split_dir, K=self.K)
        self.collector = IncrementalCollectorMultiDet(pools)

    def ensure_collected(self, q_vec: List[int]) -> List[int]:
        """Grow the per-source selections to `q_vec`; returns samples added per source."""
        return self.collector.ensure_size([int(x) for x in q_vec])

    def current_q_vec(self) -> List[int]:
        """Number of samples currently collected from each source."""
        return self.collector.current_q_vec()

    def _train_eval(self, idxs: List[int]) -> Tuple[float, float]:
        """
        Train a new model on `train_full[idxs]` and return the best test mAP.

        Batch size, learning rate, weight decay and worker count are taken
        from `self.cfg`, so values passed through DetTrainCfg take effect.
        Momentum, gradient clipping, AMP and the evaluation schedule still
        come from the module-level constants.

        Returns:
          (best_map_pct, best_map_pct): the best mAP@0.5 in percent seen
          during the evaluated epochs, returned twice (presumably to mirror
          a two-metric oracle interface). If no epoch is evaluated, e.g.
          epochs == 0, the initial sentinel -1.0 yields -100.0.

        Raises:
          RuntimeError: if DEBUG_FINITE is set and a loss becomes NaN/Inf.
        """
        # Accepts "cuda", "cuda:0" or a torch.device, not just the string "cuda".
        on_cuda = torch.device(self.device).type == "cuda"
        amp_enabled = USE_AMP and on_cuda

        trainloader = DataLoader(
            Subset(self.train_full, idxs),
            batch_size=self.cfg.batch_size,
            shuffle=True,
            num_workers=self.cfg.num_workers,
            pin_memory=True,
            drop_last=True,
            collate_fn=collate_fn,
        )
        testloader = DataLoader(
            self.testset,
            batch_size=self.cfg.batch_size,
            shuffle=False,
            num_workers=self.cfg.num_workers,
            pin_memory=True,
            collate_fn=collate_fn,
        )

        # Detection heads start from scratch; only the backbone is pretrained.
        model = ssd300_vgg16(
            weights=None,
            weights_backbone=VGG16_Weights.IMAGENET1K_V1,
            num_classes=NUM_CLASSES,
        ).to(self.device)

        if on_cuda:
            torch.backends.cudnn.benchmark = False

        params = [p for p in model.parameters() if p.requires_grad]
        optimizer = optim.SGD(
            params,
            lr=self.cfg.lr,
            momentum=MOMENTUM,
            weight_decay=self.cfg.weight_decay,
        )

        scaler = torch.amp.GradScaler(enabled=amp_enabled)

        best_map = -1.0

        def _finite_check_loss(loss_dict):
            # Fail fast, naming the loss component that went NaN/Inf.
            if not DEBUG_FINITE:
                return
            for k, v in loss_dict.items():
                if not torch.isfinite(v).all():
                    print("\n[Non-finite loss detected]", k, v.detach().cpu())
                    raise RuntimeError("NaN/Inf in loss component")

        # Deterministic algorithms are switched off for the duration of the
        # run; the previous global setting (including warn-only mode) is
        # restored in `finally`, even when training raises.
        prev_det = torch.are_deterministic_algorithms_enabled()
        prev_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
        torch.use_deterministic_algorithms(False)

        try:
            for epoch in range(self.cfg.epochs):
                model.train()

                for imgs, targets in trainloader:
                    imgs = [im.to(self.device, non_blocking=True) for im in imgs]
                    targets = [
                        {k: (v.to(self.device) if torch.is_tensor(v) else v) for k, v in t.items()}
                        for t in targets
                    ]

                    optimizer.zero_grad(set_to_none=True)

                    with torch.amp.autocast(device_type="cuda", enabled=amp_enabled):
                        loss_dict = model(imgs, targets)
                        _finite_check_loss(loss_dict)
                        losses = sum(loss_dict.values())

                    if DEBUG_FINITE and (not torch.isfinite(losses).all()):
                        raise RuntimeError("NaN/Inf in total loss")

                    scaler.scale(losses).backward()

                    if GRAD_CLIP_NORM is not None:
                        # Unscale first so the clip threshold applies to true gradients.
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=float(GRAD_CLIP_NORM))

                    scaler.step(optimizer)
                    scaler.update()

                # Evaluate only during the last EVAL_LAST_K_EPOCHS epochs.
                in_eval_window = epoch >= (self.cfg.epochs - EVAL_LAST_K_EPOCHS)
                if in_eval_window and ((epoch + 1) % EVAL_EVERY == 0):
                    mAP = evaluate_map_50(
                        model,
                        testloader,
                        self.device,
                        iou_thr=0.5,
                        max_batches=self.cfg.val_max_batches,
                    )
                    if mAP > best_map:
                        best_map = mAP
        finally:
            torch.use_deterministic_algorithms(prev_det, warn_only=prev_warn_only)

        best_map_pct = best_map * 100.0
        return best_map_pct, best_map_pct

    def __call__(self, q_vec: List[int]) -> float:
        """Collect up to `q_vec`, train on everything collected, return the best mAP (%)."""
        self.ensure_collected(q_vec)
        idxs = self.collector.indices_flat()
        _, best_fg = self._train_eval(idxs)
        return best_fg

    def call_both(self, q_vec: List[int]) -> Tuple[float, float]:
        """Same as __call__ but returns the full two-value result of _train_eval."""
        self.ensure_collected(q_vec)
        idxs = self.collector.indices_flat()
        return self._train_eval(idxs)


def get_voc_det_oracle_multi(K: int = NUM_SOURCES_DEFAULT, epochs=EPOCHS, lr=LR):
    """
    Build a VOCDetSSDOracleMulti with K source pools.

    The device is "cuda" when CUDA is available, otherwise "cpu". `epochs`
    and `lr` are stored on the DetTrainCfg passed to the oracle; its other
    fields keep their defaults.
    """
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    train_cfg = DetTrainCfg(epochs=epochs, lr=lr)
    return VOCDetSSDOracleMulti(device, train_cfg, K=K)
