import os
import xml.etree.ElementTree as ET
import logging
import math

from PIL import Image
import torch
from torch import tensor
from torch.utils.data import Dataset
from torchvision.transforms.functional import (
    to_tensor,
    crop as torch_crop,
    to_pil_image,
    resize as functional_resize,
)

# Configure logging (errors only, output to stderr).
# NOTE: basicConfig() acts on the root logger and runs when this module is imported;
# it does nothing if the root logger already has handlers configured.
logging.basicConfig(level=logging.ERROR, format="%(asctime)s - %(levelname)s - %(message)s")
# Module-level logger named after this module.
logger = logging.getLogger(__name__)


def parse_xml_for_bbox(xml_file):
    """
    Parse an XML annotation file and extract its valid bounding boxes.

    Every ``<object>`` element directly under the root is inspected for a
    ``<bndbox>`` child holding ``xmin``/``ymin``/``xmax``/``ymax``.
    Coordinates are read as floats and truncated to ints.

    Malformed objects (no ``<bndbox>``, a missing coordinate tag, or a
    non-numeric / non-finite coordinate) are skipped rather than aborting the
    whole file, so one bad entry does not hide the valid boxes around it.
    Degenerate boxes (``xmin >= xmax`` or ``ymin >= ymax``) are skipped too.

    Args:
        xml_file: Path or file object accepted by
            ``xml.etree.ElementTree.parse``.

    Returns:
        A list of ``[xmin, ymin, xmax, ymax]`` integer lists, in document
        order. Empty when the file holds no usable box.

    Raises:
        xml.etree.ElementTree.ParseError: If the file is not well-formed XML.
        OSError: If the file cannot be opened.
    """
    root = ET.parse(xml_file).getroot()

    bboxes = []
    for obj in root.findall("object"):
        bnd = obj.find("bndbox")
        # Compare against None explicitly: an Element's truth value is not a
        # reliable "was it found" test.
        if bnd is None:
            continue
        try:
            # findtext() yields None for a missing tag -> TypeError in float();
            # "", "abc" and "nan" -> ValueError; "inf" -> OverflowError in int().
            xmin, ymin, xmax, ymax = (
                int(float(bnd.findtext(tag)))
                for tag in ("xmin", "ymin", "xmax", "ymax")
            )
        except (TypeError, ValueError, OverflowError):
            continue
        if xmin < xmax and ymin < ymax:
            bboxes.append([xmin, ymin, xmax, ymax])
    return bboxes


def rescale_bbox(bbox, orig_size, target_size):
    """
    Map a bounding box from one image size onto another.

    Args:
        bbox: ``[xmin, ymin, xmax, ymax]`` expressed in ``orig_size`` pixels.
        orig_size: ``(width, height)`` the box currently refers to.
        target_size: ``(width, height)`` to express the box in.

    Returns:
        A new ``[xmin, ymin, xmax, ymax]`` list; each coordinate is scaled by
        the per-axis size ratio and truncated with ``int()``.
    """
    (src_w, src_h), (dst_w, dst_h) = orig_size, target_size
    scale_x = dst_w / src_w
    scale_y = dst_h / src_h

    left, top, right, bottom = bbox
    # x coordinates use the horizontal ratio, y coordinates the vertical one.
    paired = (
        (left, scale_x),
        (top, scale_y),
        (right, scale_x),
        (bottom, scale_y),
    )
    return [int(coord * scale) for coord, scale in paired]


class FocLVanillaBoxDataset(Dataset):
    """
    Dataset yielding a resized full image and a single crop taken from the
    first bounding box of the image's XML annotation.

    For an image at ``<...>/<cls>/<image_id>.<ext>`` the annotation is looked
    up at ``<annotation_folder>/<cls>/<image_id>.xml``. When the annotation is
    missing, holds no usable box, or anything fails while building the crop,
    the whole image is used as the crop instead.

    Each item is a tuple ``(full, crop, label, bbox)``:
        full:  ``to_tensor`` of the image resized to ``resize_size``.
        crop:  the box region resized to ``resize_size`` and passed through
               ``crop_transform`` (or ``to_tensor`` when no transform is set).
        label: the label stored alongside the path in
               ``subset.dataset.samples``.
        bbox:  float tensor ``[xmin, ymin, xmax, ymax]`` in ORIGINAL image
               pixel coordinates (it is not rescaled to ``resize_size``).

    Args:
        subset: Object exposing ``indices`` and ``dataset.samples`` (a list of
            ``(path, label)`` pairs) and supporting ``len()``.
        annotation_folder: Root folder containing per-class annotation XMLs.
        crop_transform: Optional callable applied to the resized PIL crop.
        resize_size: Size handed to ``torchvision``'s functional ``resize``.
    """
    def __init__(
        self,
        subset,
        annotation_folder,
        crop_transform=None,
        resize_size=(224, 224),
    ):
        self.subset = subset
        self.annotation_folder = annotation_folder
        self.crop_transform = crop_transform
        self.resize_size = resize_size

    def __len__(self):
        """Return the number of samples in the wrapped subset."""
        return len(self.subset)

    def _to_crop_tensor(self, pil_crop):
        """Apply ``crop_transform`` if set, otherwise convert with ``to_tensor``."""
        if self.crop_transform:
            return self.crop_transform(pil_crop)
        return to_tensor(pil_crop)

    def __getitem__(self, idx):
        """Return ``(full, crop, label, bbox)`` for sample ``idx``."""
        img_path, label = self.subset.dataset.samples[self.subset.indices[idx]]
        image_id = os.path.splitext(os.path.basename(img_path))[0]
        cls = os.path.basename(os.path.dirname(img_path))
        ann_file = os.path.join(self.annotation_folder, cls, f"{image_id}.xml")

        # convert() returns a loaded, independent image, so the underlying
        # file can be closed as soon as the block exits.
        with Image.open(img_path) as img:
            orig = img.convert("RGB")
        width, height = orig.size
        full = functional_resize(orig, self.resize_size)

        try:
            bboxes = parse_xml_for_bbox(ann_file) if os.path.exists(ann_file) else []
            if bboxes:
                xmin, ymin, xmax, ymax = bboxes[0]
                if not (0 <= xmin < xmax <= width and 0 <= ymin < ymax <= height):
                    raise ValueError(f"Invalid bbox {bboxes[0]} in {ann_file}")

                # Crop directly on the PIL image; the box was validated to lie
                # inside the image, so no padding is involved.
                crop = functional_resize(
                    orig.crop((xmin, ymin, xmax, ymax)), self.resize_size
                )
                # Always hand back a tensor-like crop, exactly as the fallback
                # path does, so batches collate even without a crop_transform.
                crop = self._to_crop_tensor(crop)
                return to_tensor(full), crop, label, tensor(bboxes[0], dtype=torch.float)
        except Exception as e:
            # Deliberate best-effort: a bad annotation or failing crop must
            # not kill the data loader; log it and use the full image instead.
            logger.error("%s: %s", img_path, e)

        # Fallback: full image as crop, bbox spanning the whole original image.
        crop = self._to_crop_tensor(full)
        bbox = [0, 0, width, height]
        return to_tensor(full), crop, label, tensor(bbox, dtype=torch.float)


class MultiGlimpseDistortionAwareDataset(Dataset):
    """
    Multi-glimpse dataset with distortion-aware and offset/scale-jittered crops
    built around the first bounding box of each image's XML annotation.

    Each item is a tuple ``(full, crops, label, bbox)``:
        full:  ``to_tensor`` of the image resized to ``resize_size``.
        crops: one crop when ``multi_crop`` is False, otherwise
               ``num_glimpses`` crops stacked along a new first dimension.
               Every crop goes through ``crop_transform`` (or ``to_tensor``
               when no transform is set).
        label: the label stored alongside the path in
               ``subset.dataset.samples``.
        bbox:  float tensor ``[xmin, ymin, xmax, ymax]`` in ORIGINAL image
               pixel coordinates (it is not rescaled to ``resize_size``).

    When the annotation is missing, unreadable, or its first box does not lie
    inside the image, the resized full image is used for every crop and the
    bbox spans the whole image.

    Args:
        subset: Object exposing ``indices`` and ``dataset.samples`` (a list of
            ``(path, label)`` pairs) and supporting ``len()``.
        annotation_folder: Root folder containing per-class annotation XMLs.
        crop_transform: Optional callable applied to each resized PIL crop.
        resize_size: Size handed to ``torchvision``'s functional ``resize``,
            i.e. ``(height, width)`` when given as a pair.
        train_mode: Enables the distortion-aware expansion for small boxes.
        offset_fraction: Maximum centre shift, as a fraction of the box side.
        scale_jitter: Maximum relative change of the box width/height.
        area_threshold: Boxes covering less than this fraction of the image
            count as "small" (expanded rather than jittered in train mode).
        augmentation_mode: ``"conservative"`` halves and ``"aggressive"``
            multiplies by 1.5 both ``offset_fraction`` and ``scale_jitter``;
            any other value leaves them unchanged. Case-insensitive.
        num_glimpses: Number of crops produced when ``multi_crop`` is True.
        max_crop_ratio: Controls when the distortion-aware expansion kicks in
            (see ``_distortion_aware_expand``).
        multi_crop: Return stacked crops instead of a single crop.
    """
    def __init__(
        self,
        subset,
        annotation_folder,
        crop_transform=None,
        resize_size=(224, 224),
        train_mode=True,
        offset_fraction=0.2,
        scale_jitter=0.1,
        area_threshold=0.2,
        augmentation_mode="medium",
        num_glimpses=3,
        max_crop_ratio=0.2,
        multi_crop=False,
    ):
        self.subset = subset
        self.annotation_folder = annotation_folder
        self.crop_transform = crop_transform
        self.resize_size = resize_size
        self.train_mode = train_mode
        self.offset_fraction = offset_fraction
        self.scale_jitter = scale_jitter
        self.area_threshold = area_threshold
        self.num_glimpses = num_glimpses
        self.max_crop_ratio = max_crop_ratio
        self.multi_crop = multi_crop

        mode = augmentation_mode.lower()
        if mode == "conservative":
            self.offset_fraction *= 0.5
            self.scale_jitter *= 0.5
        elif mode == "aggressive":
            self.offset_fraction *= 1.5
            self.scale_jitter *= 1.5

    def __len__(self):
        """Return the number of samples in the wrapped subset."""
        return len(self.subset)

    def __getitem__(self, idx):
        """Return ``(full, crops, label, bbox)`` for sample ``idx``."""
        path, label = self.subset.dataset.samples[self.subset.indices[idx]]
        # convert() returns a loaded, independent image, so the underlying
        # file can be closed as soon as the block exits.
        with Image.open(path) as img:
            orig = img.convert("RGB")
        cls = os.path.basename(os.path.dirname(path))
        img_id = os.path.splitext(os.path.basename(path))[0]
        ann = os.path.join(self.annotation_folder, cls, f"{img_id}.xml")

        try:
            bboxes = parse_xml_for_bbox(ann) if os.path.exists(ann) else []
        except Exception as e:
            # Best-effort, like FocLVanillaBoxDataset: an unreadable annotation
            # must not kill the data loader; log it and use the fallback.
            logging.getLogger(__name__).error("%s: %s", path, e)
            bboxes = []

        if not bboxes:
            return self._fallback(orig, label)

        xmin, ymin, xmax, ymax = bboxes[0]
        W, H = orig.size
        if not (0 <= xmin < xmax <= W and 0 <= ymin < ymax <= H):
            return self._fallback(orig, label)

        box_w, box_h = xmax - xmin, ymax - ymin
        area_frac = (box_w * box_h) / (W * H)
        # NOTE(review): with train_mode=False every glimpse still goes through
        # the random offset/scale path - confirm this is intended for eval.
        expand = self.train_mode and area_frac < self.area_threshold

        # Glimpses are generated independently, so when a single crop is
        # requested there is no point building num_glimpses and discarding
        # all but one.
        count = self.num_glimpses if self.multi_crop else 1
        crops = []
        for _ in range(count):
            if expand:
                c = self._distortion_aware_expand(orig, xmin, ymin, xmax, ymax)
            else:
                c = self._offset_and_scale_crop(orig, xmin, ymin, xmax, ymax)
            crops.append(self._to_crop_tensor(c))

        full = to_tensor(functional_resize(orig, self.resize_size))
        bbox = tensor([xmin, ymin, xmax, ymax], dtype=torch.float)

        if not self.multi_crop:
            return full, crops[0], label, bbox
        return full, torch.stack(crops), label, bbox

    def _to_crop_tensor(self, pil_crop):
        """Apply ``crop_transform`` if set, otherwise convert with ``to_tensor``."""
        if self.crop_transform:
            return self.crop_transform(pil_crop)
        return to_tensor(pil_crop)

    def _fallback(self, orig, label):
        """
        Build an item from the whole image when no usable box is available.

        The crops go through the same ``crop_transform`` as regular crops so
        that fallback samples are preprocessed consistently with the rest.
        """
        full = functional_resize(orig, self.resize_size)
        t = to_tensor(full)
        if self.multi_crop:
            crops = torch.stack(
                [self._to_crop_tensor(full) for _ in range(self.num_glimpses)]
            )
        else:
            crops = self._to_crop_tensor(full)
        bbox = tensor([0, 0, orig.width, orig.height], dtype=torch.float)
        return t, crops, label, bbox

    def _distortion_aware_expand(self, image, xmin, ymin, xmax, ymax):
        """
        Crop the box, widening it first when it would be upscaled too much.

        ``sf`` is the larger of the per-axis factors needed to bring the box
        up to ``resize_size``. When it exceeds ``1 / (1 - max_crop_ratio)``
        the box is grown symmetrically by ``sqrt(sf / threshold)`` on each
        axis (clamped to the image) before cropping and resizing.
        """
        w, h = image.size
        bw, bh = xmax - xmin, ymax - ymin
        # torchvision's resize takes (height, width); pair each target side
        # with the matching box side.
        target_h, target_w = self.resize_size
        sf = max(target_w / bw, target_h / bh)
        keep = 1.0 - self.max_crop_ratio
        # keep <= 0 would mean an infinite (or negative) threshold: treat it
        # as "never expand" instead of dividing by zero / sqrt of a negative.
        if keep > 0:
            thresh = 1.0 / keep
            if sf > thresh:
                factor = math.sqrt(sf / thresh)
                ew = bw * (factor - 1) / 2
                eh = bh * (factor - 1) / 2
                xmin = max(0, xmin - ew)
                ymin = max(0, ymin - eh)
                xmax = min(w, xmax + ew)
                ymax = min(h, ymax + eh)
        crop = image.crop((xmin, ymin, xmax, ymax))
        return functional_resize(crop, self.resize_size)

    def _offset_and_scale_crop(self, image, xmin, ymin, xmax, ymax):
        """
        Crop a randomly shifted and rescaled version of the box.

        Width and height are each scaled by ``1 +/- scale_jitter`` and the
        centre is moved by up to ``offset_fraction`` of the box side; the
        result is clamped to the image. If nothing at least one pixel wide
        and tall is left, the plain box crop is used instead.
        """
        w, h = image.size
        bw, bh = xmax - xmin, ymax - ymin
        sfw = 1 + (torch.rand(1).item() * 2 - 1) * self.scale_jitter
        sfh = 1 + (torch.rand(1).item() * 2 - 1) * self.scale_jitter
        nw, nh = bw * sfw, bh * sfh
        ox = (torch.rand(1).item() * 2 - 1) * (self.offset_fraction * bw)
        oy = (torch.rand(1).item() * 2 - 1) * (self.offset_fraction * bh)
        cx, cy = (xmin + xmax) / 2 + ox, (ymin + ymax) / 2 + oy
        x0, y0 = max(0, cx - nw / 2), max(0, cy - nh / 2)
        x1, y1 = min(w, cx + nw / 2), min(h, cy + nh / 2)

        left, top = int(x0), int(y0)
        crop_w, crop_h = int(x1 - x0), int(y1 - y0)
        # Check the truncated size: a sub-pixel window would otherwise become
        # a zero-size crop that cannot be resized.
        if crop_w < 1 or crop_h < 1:
            return self._simple_bbox_crop(image, xmin, ymin, xmax, ymax)
        crop = image.crop((left, top, left + crop_w, top + crop_h))
        return functional_resize(crop, self.resize_size)

    def _simple_bbox_crop(self, image, xmin, ymin, xmax, ymax):
        """Crop exactly the given box and resize it to ``resize_size``."""
        crop = image.crop((xmin, ymin, xmax, ymax))
        return functional_resize(crop, self.resize_size)


# Names exported by `from <module> import *`: the two dataset classes and the
# two bounding-box helpers defined above.
__all__ = [
    "FocLVanillaBoxDataset",
    "MultiGlimpseDistortionAwareDataset",
    "parse_xml_for_bbox",
    "rescale_bbox",
]
