from dataclasses import dataclass
from typing import List, Tuple, Optional
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset
from PIL import Image
import scipy.io as sio

from pathlib import Path
from torchvision.datasets import VOCSegmentation
from torchvision.transforms import functional as F
from torchvision.transforms import InterpolationMode
from torchvision.models.segmentation import deeplabv3_resnet50
from torchvision.models import ResNet50_Weights


# ------------------------------- GLOBAL CONFIG (KEEP AS-IS) ------------------------------- #


NUM_CLASSES = 21  # number of segmentation label ids (0..20); class 0 is treated as background by compute_miou
IGNORE_INDEX = 255  # mask value excluded from the loss and from mIoU

IMG_SIZE = (320, 320)  # target size every image/mask is resized to (torchvision order: height, width)
BATCH_SIZE = 4  # batch size for both the training and the validation loaders
EPOCHS = 30  # number of training epochs (default for SegTrainCfg.epochs)
LR = 1e-3  # AdamW learning rate
WEIGHT_DECAY = 1e-4  # AdamW weight decay

EVAL_EVERY = 1  # run validation every EVAL_EVERY epochs
VAL_MAX_BATCHES = None  # None means evaluate on the full validation set

NUM_SOURCES_DEFAULT = 5  # default number of data sources K




# ------------------------------- Dataset utils ------------------------------- #
class JointSegTransform:
    """Apply one consistent geometric transform to an (image, mask) pair.

    Both inputs are resized to ``IMG_SIZE``. The image uses bilinear
    interpolation; the mask uses nearest-neighbour so label ids are never
    blended. In training mode the pair is flipped horizontally together with
    probability 0.5. The image is then converted to a tensor and normalized
    with ImageNet channel statistics. The mask becomes an int64 tensor of raw
    label values.
    """

    def __init__(self, train: bool):
        self.train = train
        # ImageNet per-channel mean / std (the backbone is ImageNet-pretrained).
        self.mean = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)

    def __call__(self, img, mask):
        img = F.resize(img, IMG_SIZE, interpolation=InterpolationMode.BILINEAR)
        mask = F.resize(mask, IMG_SIZE, interpolation=InterpolationMode.NEAREST)

        # Short-circuit keeps the RNG untouched in eval mode.
        do_flip = self.train and torch.rand(()) < 0.5
        if do_flip:
            img, mask = F.hflip(img), F.hflip(mask)

        image_tensor = F.normalize(F.to_tensor(img), self.mean, self.std)
        label_tensor = torch.from_numpy(np.array(mask, dtype=np.int64))
        return image_tensor, label_tensor


class SegWrap(Dataset):
    """Dataset adapter that runs a joint (image, mask) transform on every access."""

    def __init__(self, base_ds, joint_tf: JointSegTransform):
        # base_ds: any indexable dataset whose items unpack into (image, mask).
        self.base = base_ds
        self.tf = joint_tf

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        raw_img, raw_mask = self.base[idx]
        out_img, out_mask = self.tf(raw_img, raw_mask)
        return out_img, out_mask


class SBDFromList(Dataset):
    """Segmentation samples from an SBD directory, restricted to an id list.

    Each item is ``(rgb_image, mask)``. Both are PIL images. The mask is read
    from ``<cls_dir>/<id>.mat`` (``GTcls.Segmentation``) and cast to uint8.

    Raises:
        FileNotFoundError: at construction if the id list, the ``img`` dir or
            both ``cls``/``class`` dirs are missing; at access time if the
            image or ``.mat`` file for an id is missing.
    """

    def __init__(self, sbd_root: str, id_list_path: str):
        self.ids = self._read_ids(id_list_path)

        # Accept either the extracted archive root (with
        # benchmark_RELEASE/dataset inside) or the dataset directory itself.
        nested_root = os.path.join(sbd_root, "benchmark_RELEASE", "dataset")
        self.base = nested_root if os.path.isdir(nested_root) else sbd_root

        self.img_dir = os.path.join(self.base, "img")
        if not os.path.isdir(self.img_dir):
            raise FileNotFoundError(f"SBD img dir not found: {self.img_dir}")

        tried = [os.path.join(self.base, sub) for sub in ("cls", "class")]
        self.cls_dir = None
        for candidate in tried:
            if os.path.isdir(candidate):
                self.cls_dir = candidate
                break
        else:
            raise FileNotFoundError(f"SBD cls/class dir not found. Tried: {tried}")

    @staticmethod
    def _read_ids(path: str):
        """Return the non-blank, whitespace-stripped lines of the id list file."""
        list_file = Path(path)
        if not list_file.exists():
            raise FileNotFoundError(f"SBD id list not found: {list_file}")
        with open(list_file, "r", encoding="utf-8", errors="ignore") as fh:
            stripped = (line.strip() for line in fh)
            return [sid for sid in stripped if sid]

    def __len__(self):
        return len(self.ids)

    def _find_img(self, sid: str):
        """Return the first existing ``img_dir/sid{.jpg,.png,.jpeg}`` path, else None."""
        candidates = (os.path.join(self.img_dir, sid + ext) for ext in (".jpg", ".png", ".jpeg"))
        return next((c for c in candidates if os.path.exists(c)), None)

    def __getitem__(self, idx: int):
        sid = self.ids[idx]

        img_path = self._find_img(sid)
        if img_path is None:
            raise FileNotFoundError(f"Image not found for id={sid} (looked in {self.img_dir})")

        mat_path = os.path.join(self.cls_dir, sid + ".mat")
        if not os.path.exists(mat_path):
            raise FileNotFoundError(f"Mask .mat not found for id={sid} (looked in {self.cls_dir})")

        rgb = Image.open(img_path).convert("RGB")
        # With default options loadmat returns MATLAB structs as 1x1 arrays,
        # so [0, 0] unwraps the Segmentation field's label matrix.
        labels = sio.loadmat(mat_path)["GTcls"]["Segmentation"][0, 0]
        return rgb, Image.fromarray(labels.astype(np.uint8))


def build_full_trainset_and_valset(voc_root: str, sbd_root: str, sbd_list_path: str):
    """Build the combined training set and the VOC 2012 validation set.

    The training set is ``ConcatDataset([VOC train, SBD list])``. Global
    indices ``[0, len(VOC train))`` address VOC samples and the remainder
    address SBD samples. Training samples get the augmenting transform; the
    validation set gets the deterministic one.

    Returns:
        ``(train_full, valset)``.
    """
    train_tf = JointSegTransform(train=True)
    val_tf = JointSegTransform(train=False)

    voc_train_base = VOCSegmentation(root=voc_root, year="2012", image_set="train", download=False)
    voc_val_base = VOCSegmentation(root=voc_root, year="2012", image_set="val", download=False)
    sbd_base = SBDFromList(sbd_root, sbd_list_path)

    train_parts = [SegWrap(voc_train_base, train_tf), SegWrap(sbd_base, train_tf)]
    return ConcatDataset(train_parts), SegWrap(voc_val_base, val_tf)


def load_source_pools(split_dir: str, K: int) -> List[np.ndarray]:
    """Load per-source index pools ``idxs_source{k}.npy`` for ``k = 0..K-1``.

    Each pool is returned as an int64 array, in source order.

    Raises:
        FileNotFoundError: if any of the ``K`` split files is missing.
    """
    root = Path(split_dir)

    def _load_one(k: int) -> np.ndarray:
        # Load a single source's pool, failing loudly if its file is absent.
        npy_path = root / f"idxs_source{k}.npy"
        if not npy_path.exists():
            raise FileNotFoundError(f"Missing split file: {npy_path}")
        return np.load(npy_path).astype(np.int64)

    return [_load_one(k) for k in range(K)]


class IncrementalCollectorMultiDet:
    """Deterministic, nested per-source sample collector.

    Each source pool is shuffled once with a fixed per-source seed
    (``42 + k``). Requests always take a prefix of that fixed order, so the
    selection for a larger request contains every index handed out for a
    smaller one. Requests beyond a pool's size are capped at the pool size.
    """

    def __init__(self, pools: List[np.ndarray]):
        self.K = len(pools)
        # Fixed seeds make the per-source order reproducible across runs.
        self.pools = [
            arr[np.random.RandomState(42 + k).permutation(len(arr))]
            for k, arr in enumerate(pools)
        ]
        self.ptrs = [0 for _ in range(self.K)]  # next unread position per source
        self.selected = [[] for _ in range(self.K)]  # indices handed out per source

    def ensure_size(self, q_targets: List[int]) -> List[int]:
        """Grow each source's selection to at least ``q_targets[k]`` (capped by pool size).

        Returns the number of indices newly added per source. Selections never shrink.
        """
        added = []
        for k in range(self.K):
            shortfall = int(q_targets[k]) - len(self.selected[k])
            if shortfall <= 0:
                added.append(0)
                continue
            lo = self.ptrs[k]
            hi = min(lo + shortfall, len(self.pools[k]))
            chunk = self.pools[k][lo:hi].tolist()
            self.selected[k].extend(chunk)
            self.ptrs[k] = hi
            added.append(len(chunk))
        return added

    def indices_flat(self) -> List[int]:
        """All selected indices, source 0 first, each source in selection order."""
        return [idx for per_source in self.selected for idx in per_source]

    def current_q_vec(self) -> List[int]:
        """Current selection size per source."""
        return list(map(len, self.selected))


@torch.no_grad()
def compute_miou(model, loader, device, num_classes=21, ignore_index=255, max_batches=None):
    """Compute the mean IoU of a segmentation model over (a prefix of) a loader.

    A ``num_classes x num_classes`` confusion matrix (rows = ground truth,
    columns = prediction) is accumulated over the evaluated pixels. Per-class
    IoU = TP / (TP + FP + FN) is derived from it.

    Args:
        model: module whose call returns a dict with an ``"out"`` tensor of
            class scores along dim 1, spatially matching ``masks``.
        loader: iterable of ``(imgs, masks)`` tensor batches; ``masks`` hold
            integer label ids.
        device: device the images are moved to for the forward pass.
        num_classes: number of classes ``C``.
        ignore_index: mask value whose pixels are excluded.
        max_batches: if not None, evaluate only the first ``max_batches`` batches.

    Returns:
        ``(miou_all, miou_fg)`` as floats in [0, 1]. ``miou_all`` is the mean
        IoU over all classes; ``miou_fg`` is the mean over classes ``1..C-1``
        (class 0, the background, is excluded). A class that appears in
        neither the ground truth nor the predictions of the evaluated pixels
        has undefined IoU and is left out of the mean. A mean over no classes
        is reported as 0.0.
    """
    model.eval()
    conf = np.zeros((num_classes, num_classes), dtype=np.int64)

    for bi, (imgs, masks) in enumerate(loader):
        if max_batches is not None and bi >= max_batches:
            break

        logits = model(imgs.to(device, non_blocking=True))["out"]
        pred_np = logits.argmax(1).cpu().numpy().reshape(-1).astype(np.int64)
        # Masks are only needed on the host. Widen to int64 before the
        # ``gt * C + pred`` encoding so narrow dtypes (e.g. uint8) cannot overflow.
        mask_np = masks.cpu().numpy().reshape(-1).astype(np.int64)

        # Keep only labelled pixels with an in-range class id, so bincount
        # always yields exactly C*C bins.
        valid = (mask_np != ignore_index) & (mask_np >= 0) & (mask_np < num_classes)
        codes = mask_np[valid] * num_classes + pred_np[valid]
        conf += np.bincount(codes, minlength=num_classes * num_classes).reshape(num_classes, num_classes)

    tp = np.diag(conf)
    union = conf.sum(axis=1) + conf.sum(axis=0) - tp
    # A class with an empty union was never seen or predicted: its IoU is
    # undefined and must not drag the mean down as a spurious 0.
    present = union > 0
    iou = np.zeros(num_classes, dtype=np.float64)
    iou[present] = tp[present] / union[present]

    fg = present.copy()
    fg[:1] = False  # exclude class 0 (background)

    miou_all = float(iou[present].mean()) if present.any() else 0.0
    miou_fg = float(iou[fg].mean()) if fg.any() else 0.0
    return miou_all, miou_fg


@dataclass
class SegTrainCfg:
    """Training / evaluation hyper-parameters for one oracle training run.

    Defaults are taken from the module-level constants of the same names.
    """
    epochs: int = EPOCHS  # number of training epochs; NOTE(review): confirm the training loop reads cfg.epochs
    batch_size: int = BATCH_SIZE  # batch size for both the train and validation DataLoaders
    lr: float = LR  # AdamW learning rate
    weight_decay: float = WEIGHT_DECAY  # AdamW weight decay
    num_workers: int = 0  # DataLoader worker processes (0 = load in the main process)
    val_max_batches: Optional[int] = VAL_MAX_BATCHES  # cap on validation batches per evaluation; None = full set


class VOCSegDeepLabOracleMulti:
    """Oracle mapping per-source sample counts to validation mIoU.

    Each call grows the nested per-source selection (``IncrementalCollectorMultiDet``)
    to the requested sizes. It then trains a fresh DeepLabV3-ResNet50 (ImageNet
    backbone weights) on the union of selected samples and reports validation
    mIoU, in percent, from the evaluation with the best foreground mIoU.
    """

    def __init__(
        self,
        device: str,
        cfg: SegTrainCfg,
        K: int,
        *,
        voc_root: str = "./data/VOC",
        sbd_root: str = "./data/SBD",
        sbd_list_path: str = "./data/SBD/trainval_9118.txt",
        split_dir: str = "./data/segmentation",
    ):
        """
        Args:
            device: torch device string to train on (e.g. ``"cuda"`` or ``"cpu"``).
            cfg: training / evaluation hyper-parameters.
            K: number of data sources. ``split_dir`` must contain
                ``idxs_source0.npy`` ... ``idxs_source{K-1}.npy``.
            voc_root, sbd_root, sbd_list_path, split_dir: dataset locations.
                The defaults are the historical hard-coded paths.
        """
        self.device, self.cfg = device, cfg
        self.K = K
        self.voc_root = voc_root
        self.sbd_root = sbd_root
        self.sbd_list_path = sbd_list_path
        self.split_dir = split_dir

        self.train_full, self.valset = build_full_trainset_and_valset(
            voc_root=self.voc_root,
            sbd_root=self.sbd_root,
            sbd_list_path=self.sbd_list_path,
        )

        pools = load_source_pools(self.split_dir, K=self.K)
        self.collector = IncrementalCollectorMultiDet(pools)

    def ensure_collected(self, q_vec: List[int]) -> List[int]:
        """Grow the selection to ``q_vec``; return the number of indices added per source."""
        return self.collector.ensure_size([int(x) for x in q_vec])

    def current_q_vec(self) -> List[int]:
        """Current selection size per source."""
        return self.collector.current_q_vec()

    def _is_cuda(self) -> bool:
        # Accept both "cuda" and indexed forms such as "cuda:0".
        return str(self.device).startswith("cuda")

    def _make_loaders(self, idxs: List[int]) -> Tuple[DataLoader, DataLoader]:
        """Build the shuffled training loader over ``idxs`` and the fixed-order val loader."""
        pin = self._is_cuda()  # pinned host memory only helps GPU transfers
        trainloader = DataLoader(
            Subset(self.train_full, idxs),
            batch_size=self.cfg.batch_size,
            shuffle=True,
            num_workers=self.cfg.num_workers,
            pin_memory=pin,
            # Drops a trailing partial batch. NOTE: with fewer than batch_size
            # samples no training step runs at all.
            drop_last=True,
        )
        valloader = DataLoader(
            self.valset,
            batch_size=self.cfg.batch_size,
            shuffle=False,
            num_workers=self.cfg.num_workers,
            pin_memory=pin,
        )
        return trainloader, valloader

    def _train_eval(self, idxs: List[int]) -> Tuple[float, float]:
        """Train a fresh model on ``train_full[idxs]`` and evaluate it on the val set.

        Returns:
            ``(miou_all * 100, miou_fg * 100)`` from the evaluation with the
            highest foreground mIoU. Both values are ``-100.0`` if no
            evaluation ran.
        """
        trainloader, valloader = self._make_loaders(idxs)

        model = deeplabv3_resnet50(
            weights=None,
            weights_backbone=ResNet50_Weights.IMAGENET1K_V2,
            num_classes=NUM_CLASSES,
        ).to(self.device)

        if self._is_cuda():
            torch.backends.cudnn.benchmark = False

        criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
        optimizer = optim.AdamW(model.parameters(), lr=self.cfg.lr, weight_decay=self.cfg.weight_decay)

        best_miou_fg = -1.0
        best_miou_all = -1.0

        # Deterministic-algorithm enforcement is disabled while training. The
        # caller's previous setting is restored in `finally`, even if training raises.
        prev_det = torch.are_deterministic_algorithms_enabled()
        torch.use_deterministic_algorithms(False)
        try:
            for epoch in range(self.cfg.epochs):
                model.train()
                for imgs, masks in trainloader:
                    imgs = imgs.to(self.device, non_blocking=True)
                    masks = masks.to(self.device, non_blocking=True)

                    optimizer.zero_grad(set_to_none=True)
                    loss = criterion(model(imgs)["out"], masks)
                    loss.backward()
                    optimizer.step()

                if (epoch + 1) % EVAL_EVERY == 0:
                    miou_all, miou_fg = compute_miou(
                        model, valloader, self.device,
                        num_classes=NUM_CLASSES, ignore_index=IGNORE_INDEX,
                        max_batches=self.cfg.val_max_batches,
                    )
                    # Model selection is by foreground mIoU; the all-class
                    # mIoU reported is the one from that same evaluation.
                    if miou_fg > best_miou_fg:
                        best_miou_fg = miou_fg
                        best_miou_all = miou_all
        finally:
            torch.use_deterministic_algorithms(prev_det)

        return best_miou_all * 100, best_miou_fg * 100

    def __call__(self, q_vec: List[int]) -> float:
        """Grow the selection to ``q_vec``, train and return the best-epoch all-class mIoU (%)."""
        miou_all, _ = self.call_both(q_vec)
        return miou_all

    def call_both(self, q_vec: List[int]) -> Tuple[float, float]:
        """Like ``__call__`` but return ``(miou_all, miou_fg)``, both in percent."""
        self.ensure_collected(q_vec)
        return self._train_eval(self.collector.indices_flat())


def get_voc_seg_oracle_multi(K: int = NUM_SOURCES_DEFAULT, epochs=EPOCHS, lr=LR):
    """Create a ``VOCSegDeepLabOracleMulti`` on CUDA if available, else CPU.

    Args:
        K: number of data sources.
        epochs: number of training epochs, forwarded to ``SegTrainCfg.epochs``.
        lr: AdamW learning rate, forwarded to ``SegTrainCfg.lr``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Forward the caller's arguments; previously the module defaults were
    # passed instead and these parameters were silently ignored.
    return VOCSegDeepLabOracleMulti(device, SegTrainCfg(epochs=epochs, lr=lr), K=K)
