
import os
import tarfile
import hashlib
import urllib.request
from pathlib import Path
from typing import Optional, Tuple, List

from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms import AutoAugment, AutoAugmentPolicy
from PIL import Image
import torch.distributed as dist

# Directory that contains this module.
THIS_PATH = os.path.dirname(os.path.abspath(__file__))
# Five directory levels above this module (presumably the repository root —
# confirm against the project layout).
ROOT_PATH = os.path.abspath(os.path.join(THIS_PATH, *([os.pardir] * 5)))
# Default on-disk location of the FGVC-Aircraft data.
DATA_ROOT = os.path.join(ROOT_PATH, 'datasets/disentanglement/aircraft')

def is_ddp() -> bool:
    """Return True when torch.distributed is available and initialised."""
    if not dist.is_available():
        return False
    return dist.is_initialized()

def is_main_process() -> bool:
    """Return True when this process should perform one-off work.

    Outside of an initialised process group every process counts as the main
    one; inside one, only the process whose rank is 0 does.
    """
    if not is_ddp():
        return True
    return dist.get_rank() == 0

# --- Official URL & MD5 (VGG FGVC-Aircraft 2013b) ---
FGVC_URL = "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz"
FGVC_MD5 = "32eca553f897e747706144f3bb444df0"
ARCHIVE_NAME = "fgvc-aircraft-2013b.tar.gz"

# --- Directory / file conventions ---
# Expected layout after extraction: <root>/SUBDIR/DATADIR/IMAGES_DIR/<id>.jpg
SUBDIR = "fgvc-aircraft-2013b"
DATADIR = "data"
IMAGES_DIR = "images"

# Annotation level -> file listing that level's class names, one per line.
CLASS_FILES = {
    "variant": "variants.txt",
    "family": "families.txt",
    "manufacturer": "manufacturers.txt",
}

# Annotation level -> split name -> file listing "<image_id> <label>" lines.
# Every entry follows the pattern images_<level>_<split>.txt.
SPLIT_FILES = {
    level: {
        split: f"images_{level}_{split}.txt"
        for split in ("train", "val", "trainval", "test")
    }
    for level in ("variant", "family", "manufacturer")
}

# ---------- helpers ----------
def _md5(path: Path, chunk: int = 1 << 20) -> str:
    h = hashlib.md5()
    with open(path, "rb") as f:
        while True:
            b = f.read(chunk)
            if not b:
                break
            h.update(b)
    return h.hexdigest()

def _download(url: str, dst: Path) -> None:
    dst.parent.mkdir(parents=True, exist_ok=True)
    print(f"[FGVC] Downloading: {url} -> {dst}")
    urllib.request.urlretrieve(url, dst)

def _extract(archive_path: Path, dst_dir: Path) -> None:
    print(f"[FGVC] Extracting: {archive_path} -> {dst_dir}")
    with tarfile.open(archive_path, "r:gz") as tar:
        tar.extractall(path=dst_dir)

def _resolve_base_dir(root: Path) -> Path:
    """Return the directory under ``root`` that holds the ``images`` folder.

    Three layouts are probed in order: ``root/SUBDIR/DATADIR``,
    ``root/DATADIR`` and ``root`` itself. When none of them contains an
    ``IMAGES_DIR`` directory, the first (official archive) layout is returned
    so callers can report where the data was expected.
    """
    official = root / SUBDIR / DATADIR
    layouts = (official, root / DATADIR, root)
    return next(
        (base for base in layouts if (base / IMAGES_DIR).is_dir()),
        official,
    )

def _dataset_exists(root: str) -> bool:
    """Return True if an ``images`` directory exists in a known layout under ``root``."""
    base = Path(root).expanduser()
    # All that is actually required is that the images folder exists.
    image_dirs = (
        f"{SUBDIR}/{DATADIR}/{IMAGES_DIR}",
        f"{DATADIR}/{IMAGES_DIR}",
        f"{IMAGES_DIR}",
    )
    for relative in image_dirs:
        if (base / relative).is_dir():
            return True
    return False

# ---------- Dataset ----------
class FGVCAircraftDataset(Dataset):
    """FGVC-Aircraft (2013b) image-classification dataset.

    Each item is ``(image, target)``: ``image`` is the RGB PIL image, passed
    through ``transform`` when one is given, and ``target`` is the integer
    index of the image's label within ``class_names``.

    Args:
        root: Directory that contains (or will receive) the dataset. The
            ``images`` folder is looked up in several layouts below it.
        split: One of ``"train"``, ``"val"``, ``"trainval"``, ``"test"``.
        annotation_level: One of ``"variant"``, ``"family"``,
            ``"manufacturer"``; selects the class list and split files.
        transform: Optional callable applied to each PIL image.
        download: When True and no ``images`` folder is found under ``root``,
            download and extract the official archive.
        verify_md5: When downloading, check the archive against ``FGVC_MD5``
            before extracting it.

    Raises:
        ValueError: If ``split`` or ``annotation_level`` is not supported.
        RuntimeError: If the dataset cannot be found, or the archive fails
            the MD5 check.
        FileNotFoundError: If the class file, the split file or a listed
            image is missing.
        KeyError: If a split file names a label absent from the class file.
    """

    _SPLITS = ("train", "val", "trainval", "test")
    _ANNOTATION_LEVELS = ("variant", "family", "manufacturer")

    def __init__(
        self,
        root: str = DATA_ROOT,
        split: str = "train",
        annotation_level: str = "variant",
        transform: Optional[transforms.Compose] = None,
        download: bool = True,
        verify_md5: bool = True,
    ):
        self.root = Path(root).expanduser()
        self.split = split
        self.annotation_level = annotation_level
        self.transform = transform

        # Explicit exceptions instead of ``assert``: assertions are stripped
        # under ``python -O`` and bad values would then only surface later as
        # an unrelated KeyError.
        if self.split not in self._SPLITS:
            raise ValueError(
                f"Unknown split {self.split!r}; expected one of {self._SPLITS}."
            )
        if self.annotation_level not in self._ANNOTATION_LEVELS:
            raise ValueError(
                f"Unknown annotation_level {self.annotation_level!r}; "
                f"expected one of {self._ANNOTATION_LEVELS}."
            )

        if download and not _dataset_exists(self.root):
            self._download_if_needed(verify_md5=verify_md5)

        # Directory that holds images/ plus the class and split text files.
        self.base_dir = _resolve_base_dir(self.root)

        if not self._check_exists():
            raise RuntimeError(
                f"FGVC-Aircraft dataset not found under {self.base_dir}. "
                f"Set download=True to fetch it automatically."
            )

        # Class names in file order; the position in the list is the label id.
        self.class_names = self._read_class_names()
        self.class_to_idx = {name: i for i, name in enumerate(self.class_names)}
        # (image path, label id) pairs for the selected split.
        self.samples: List[Tuple[Path, int]] = self._read_split_list()

    def _check_exists(self) -> bool:
        """Return True if the resolved base directory has an images folder."""
        return (self.base_dir / IMAGES_DIR).is_dir()

    def _download_if_needed(self, verify_md5: bool = True) -> None:
        """Fetch the archive into ``root`` (unless already there) and extract it.

        Raises:
            RuntimeError: If ``verify_md5`` is set and the archive's MD5
                differs from ``FGVC_MD5``.
        """
        self.root.mkdir(parents=True, exist_ok=True)
        # Path of the tar.gz file itself (not the extracted directory).
        archive_path = self.root / ARCHIVE_NAME

        if not archive_path.exists():
            _download(FGVC_URL, archive_path)

        if verify_md5:
            md5sum = _md5(archive_path)
            if md5sum != FGVC_MD5:
                raise RuntimeError(
                    f"MD5 mismatch for {archive_path.name}: {md5sum} != {FGVC_MD5}"
                )
        _extract(archive_path, self.root)

    def _read_class_names(self) -> List[str]:
        """Return the non-empty, stripped lines of the class file, in order."""
        class_file = self.base_dir / CLASS_FILES[self.annotation_level]
        if not class_file.exists():
            raise FileNotFoundError(f"Class file not found: {class_file}")
        with open(class_file, "r", encoding="utf-8") as f:
            names = [ln.strip() for ln in f if ln.strip()]
        return names

    def _read_split_list(self) -> List[Tuple[Path, int]]:
        """Parse the split file into ``(image path, label id)`` pairs.

        Each non-empty line is ``<image_id> <label name>``; the label name may
        itself contain spaces. Every referenced image must exist on disk.
        """
        list_file = self.base_dir / SPLIT_FILES[self.annotation_level][self.split]
        if not list_file.exists():
            raise FileNotFoundError(f"Split list not found: {list_file}")

        items: List[Tuple[Path, int]] = []
        with open(list_file, "r", encoding="utf-8") as f:
            for ln in f:
                ln = ln.strip()
                if not ln:
                    continue
                tok = ln.split()
                # First token is the image id; the rest is the label name.
                image_id, label_name = tok[0], " ".join(tok[1:])
                if label_name not in self.class_to_idx:
                    raise KeyError(f"Label '{label_name}' not in classes list.")
                target = self.class_to_idx[label_name]
                img_path = self.base_dir / IMAGES_DIR / f"{image_id}.jpg"
                if not img_path.exists():
                    raise FileNotFoundError(f"Missing image: {img_path}")
                items.append((img_path, target))
        return items

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Return ``(image, target)`` for sample ``idx``."""
        path, target = self.samples[idx]
        # The context manager closes the underlying file even when decoding
        # fails; convert() returns an independent in-memory image.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, target

# ---------- convenience builder ----------
# Per-channel mean / std handed to transforms.Normalize in get_aircraft_loaders.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Colour jitter that only touches hue (brightness / contrast / saturation are
# disabled); used for the optional test-time augmentation.
always_hue_shift = transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0.5)

# Candidates for transforms.RandomChoice in the test-time augmentation.
geometric_transforms = [transforms.RandomRotation(degrees=180)]
def get_aircraft_loaders(
    data_root: str = DATA_ROOT,
    annotation_level: str = "variant",
    image_size: int = 224,
    batch_size: int = 64,
    num_workers: int = 4,
    download: bool = True,
    pin_memory: bool = True,
    augment: bool = True,
    use_trainval_as_train: bool = False,
    test_aug: bool = False
):
    """Build the FGVC-Aircraft train and test datasets with their transforms.

    Despite the name, this returns ``Dataset`` objects, not ``DataLoader``s;
    ``batch_size``, ``num_workers`` and ``pin_memory`` are accepted but
    currently unused.

    Under an initialised ``torch.distributed`` process group with
    ``download=True``, only rank 0 downloads and every rank then meets at a
    barrier, so all ranks must call this function. Pass ``download=False`` to
    skip both the download and the barrier.

    Args:
        data_root: Directory that contains (or will receive) the dataset.
        annotation_level: ``"variant"``, ``"family"`` or ``"manufacturer"``.
        image_size: Side length of the final square crop; images are first
            resized to ``int(image_size * 256 / 224)``.
        batch_size: Unused.
        num_workers: Unused.
        download: Download the dataset when it is missing under ``data_root``.
        pin_memory: Unused.
        augment: Use the augmented training pipeline; otherwise the training
            set uses the evaluation pipeline.
        use_trainval_as_train: Train on the ``trainval`` split, not ``train``.
        test_aug: Add a hue shift and a random rotation to the evaluation
            pipeline.

    Returns:
        ``(train_dataset, test_dataset)``.

    Raises:
        RuntimeError: If the dataset is missing and was not downloaded.
    """
    resize_to = int(image_size * 256 / 224)

    eval_steps = [
        transforms.Resize(resize_to),
        transforms.CenterCrop(image_size),
    ]
    if test_aug:
        eval_steps += [
            always_hue_shift,
            transforms.RandomChoice(geometric_transforms),
        ]
    eval_steps += [
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
    eval_tfm = transforms.Compose(eval_steps)

    if augment:
        train_tfm = transforms.Compose([
            transforms.Resize(resize_to),
            transforms.RandomResizedCrop(image_size, scale=(0.6, 1.0)),
            transforms.RandomHorizontalFlip(),
            AutoAugment(AutoAugmentPolicy.IMAGENET),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
        ])
    else:
        train_tfm = eval_tfm

    train_split = "trainval" if use_trainval_as_train else "train"

    def _fetch_if_missing() -> None:
        # Building a throwaway dataset triggers download + extraction.
        if not _dataset_exists(data_root):
            FGVCAircraftDataset(
                root=data_root,
                split=train_split,
                annotation_level=annotation_level,
                transform=None,
                download=True,
                verify_md5=True,
            )

    if download:
        if is_ddp():
            # Every rank reaches the barrier unconditionally: gating it on a
            # per-rank filesystem check lets ranks disagree while rank 0 is
            # still extracting. ``finally`` keeps the other ranks from
            # hanging when rank 0 fails.
            try:
                if is_main_process():
                    _fetch_if_missing()
            finally:
                dist.barrier()
        else:
            _fetch_if_missing()

    train_dataset = FGVCAircraftDataset(
        root=data_root,
        split=train_split,
        annotation_level=annotation_level,
        transform=train_tfm,
        download=False,
        verify_md5=True,
    )
    test_dataset = FGVCAircraftDataset(
        root=data_root,
        split="test",
        annotation_level=annotation_level,
        transform=eval_tfm,
        download=False,
        verify_md5=True,
    )

    return train_dataset, test_dataset
