import argparse
import os
from typing import Dict, Optional

import torch
from torch.utils.data import DataLoader, random_split

try:
    from torchvision import datasets, transforms
except Exception as exc:  # pragma: no cover
    datasets = None
    transforms = None


def _build_transforms() -> "Dict[str, transforms.Compose]":
    """Build the image preprocessing pipelines used for training and evaluation.

    The return annotation is deliberately a string. ``transforms`` is ``None``
    when torchvision could not be imported, and an unquoted
    ``transforms.Compose`` would be evaluated while this function is being
    defined, crashing the module import before the explicit ``RuntimeError``
    below could be raised.

    Returns:
        A dict with two pipelines:

        * ``"train"``: random resized crop to 224, random horizontal flip,
          tensor conversion, then normalization.
        * ``"eval"``: resize to 256, center crop to 224, tensor conversion,
          then normalization.

        Both pipelines normalize with the same mean/std constants.

    Raises:
        RuntimeError: If torchvision is not available.
    """
    if transforms is None:
        raise RuntimeError("torchvision is required for transforms and datasets.ImageFolder")

    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    # Training view: stochastic augmentation, so every epoch sees new crops.
    train_tf = transforms.Compose(
        [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
        ]
    )
    # Evaluation view: fully deterministic, so metrics are repeatable.
    eval_tf = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
        ]
    )
    return {"train": train_tf, "eval": eval_tf}


def _load_domain_dataset(domain_root: str, transform) -> "datasets.ImageFolder":
    """Open one domain directory as a torchvision ``ImageFolder`` dataset.

    The return annotation is a string so that defining this function does not
    dereference ``datasets`` when torchvision is missing (``datasets`` is
    ``None`` in that case); the explicit check below then reports the problem.

    Args:
        domain_root: Directory of a single domain, laid out as
            ``<domain_root>/<class>/<image files>``.
        transform: Transform passed through to ``ImageFolder``.

    Returns:
        The ``ImageFolder`` dataset rooted at ``domain_root``.

    Raises:
        RuntimeError: If torchvision is not available.
        FileNotFoundError: If ``domain_root`` is not an existing directory.
    """
    if datasets is None:
        raise RuntimeError("torchvision is required to use datasets.ImageFolder")
    # Check up front so the caller gets a message describing the expected layout.
    if not os.path.isdir(domain_root):
        raise FileNotFoundError(
            f"Domain directory not found: {domain_root}. Ensure data is organized as data_root/<domain>/<class>/*.jpg"
        )
    return datasets.ImageFolder(root=domain_root, transform=transform)


def _resolve_data_root(data_root: Optional[str], source_domain: str, target_domain: str) -> str:
    """Pick the data root directory that contains both domain folders.

    Candidates are probed in this order, and the first one that has both a
    ``<source_domain>`` and a ``<target_domain>`` sub-directory is returned:

    1) ``data_root`` as given, when it is absolute.
    2) ``data_root`` resolved against the current working directory.
    3) ``data_root`` resolved against the project root (two levels up from
       this file).
    4) ``<project_root>/data/domainnet``, then ``<project_root>/data/office_home``.

    Args:
        data_root: User-supplied root, absolute or relative; may be ``None``
            or empty, in which case only the project-root fallbacks are probed.
        source_domain: Name of the source domain sub-directory.
        target_domain: Name of the target domain sub-directory.

    Returns:
        The first matching candidate. When nothing matches, the absolute form
        of ``data_root`` (or ``<project_root>/data/office_home`` if no
        ``data_root`` was given), so that the downstream "directory not found"
        error names an unambiguous path.
    """
    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
    default_root = os.path.join(project_root, "data", "office_home")

    def has_domains(base: str) -> bool:
        return os.path.isdir(os.path.join(base, source_domain)) and os.path.isdir(os.path.join(base, target_domain))

    candidates = []
    if data_root:
        if os.path.isabs(data_root):
            candidates.append(data_root)
        # abspath() resolves relative paths against the current working directory.
        candidates.append(os.path.abspath(data_root))
        candidates.append(os.path.abspath(os.path.join(project_root, data_root)))
    candidates.append(os.path.join(project_root, "data", "domainnet"))
    candidates.append(default_root)

    # dict.fromkeys drops repeated paths (an absolute data_root yields the same
    # candidate several times) while keeping the probing order.
    for path in dict.fromkeys(candidates):
        if has_domains(path):
            return path

    # Nothing matched: return an absolute path and let the loader raise a clear error.
    return os.path.abspath(data_root) if data_root else default_root


def get_dataloaders(
    source_domain: str,
    target_domain: str,
    batch_size: int,
    data_root: Optional[str] = None,
    num_workers: int = 4,
    seed: int = 42,
    max_samples_per_split: Optional[int] = None,
) -> Dict[str, DataLoader]:
    """
    Return dataloaders for domain adaptation with ImageFolder structure.

    Expects data laid out as:
      data_root/
        <source_domain>/<class_name>/*.jpg
        <target_domain>/<class_name>/*.jpg

    The source domain is split 60/20/20 into train/val/test with a generator
    seeded by ``seed``. The target domain is exposed twice over the same
    files: once with the training transform and once with the eval transform.

    Args:
        source_domain: Sub-directory name of the labelled source domain.
        target_domain: Sub-directory name of the target domain.
        batch_size: Batch size used by every loader.
        data_root: Root directory holding the domain folders; resolved by
            ``_resolve_data_root``.
        num_workers: Worker process count passed to every ``DataLoader``.
        seed: Base seed for the source split and for optional subsampling.
        max_samples_per_split: If given, every split larger than this is
            randomly cut down to this many samples (for fast smoke runs).
            The target train and eval views are subsampled with different
            seeds, so they need not contain the same images.

    Returns:
        Dict with keys: source_train, source_val, source_test, target_train,
        target_eval. The two ``*_train`` loaders shuffle and drop the last
        incomplete batch; the others keep dataset order and every sample.

    Raises:
        RuntimeError: If torchvision is not available.
        FileNotFoundError: If a domain directory does not exist.
    """
    tf = _build_transforms()
    base_dir = _resolve_data_root(data_root, source_domain, target_domain)

    source_root = os.path.join(base_dir, source_domain)
    target_root = os.path.join(base_dir, target_domain)

    # Two views per domain over the same files: augmented (train) and
    # deterministic (eval). The source train view also provides the split
    # indices, so the source directory is scanned twice rather than three times.
    source_train_base = _load_domain_dataset(source_root, transform=tf["train"])
    source_eval_base = _load_domain_dataset(source_root, transform=tf["eval"])
    target_train_ds = _load_domain_dataset(target_root, transform=tf["train"])
    target_eval_ds = _load_domain_dataset(target_root, transform=tf["eval"])

    # Reproducible split for source domain: 60/20/20 (test takes the rounding remainder).
    split_rng = torch.Generator()
    split_rng.manual_seed(seed)
    n_total = len(source_train_base)
    n_train = int(0.6 * n_total)
    n_val = int(0.2 * n_total)
    n_test = n_total - n_train - n_val
    source_train_subset, source_val_subset, source_test_subset = random_split(
        source_train_base, lengths=[n_train, n_val, n_test], generator=split_rng
    )

    # Subsets only hold indices into their base dataset. Re-pointing val/test
    # at the eval view keeps the split but removes augmentation from them.
    source_val_subset.dataset = source_eval_base
    source_test_subset.dataset = source_eval_base

    def _cap(ds, seed_offset: int):
        """Randomly keep at most ``max_samples_per_split`` samples of ``ds``."""
        limit = max_samples_per_split
        if limit is None or len(ds) <= limit:
            return ds
        rng = torch.Generator()
        rng.manual_seed(seed + seed_offset)
        picked = torch.randperm(len(ds), generator=rng)[:limit].tolist()
        if isinstance(ds, torch.utils.data.Subset):
            # Subset of Subset: remap to base-dataset indices to avoid nesting.
            return torch.utils.data.Subset(ds.dataset, [ds.indices[i] for i in picked])
        return torch.utils.data.Subset(ds, picked)

    # Optionally subsample splits for fast smoke runs. All three source splits
    # share offset 1; the target train/eval views use offsets 2 and 3.
    source_train_subset = _cap(source_train_subset, 1)
    source_val_subset = _cap(source_val_subset, 1)
    source_test_subset = _cap(source_test_subset, 1)
    target_train_ds = _cap(target_train_ds, 2)
    target_eval_ds = _cap(target_eval_ds, 3)

    def _loader(ds, training: bool) -> DataLoader:
        # Training loaders shuffle and drop the ragged final batch; evaluation
        # loaders keep dataset order and every sample.
        return DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=training,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=training,
        )

    loaders: Dict[str, DataLoader] = {
        "source_train": _loader(source_train_subset, True),
        "source_val": _loader(source_val_subset, False),
        "source_test": _loader(source_test_subset, False),
        "target_train": _loader(target_train_ds, True),
        "target_eval": _loader(target_eval_ds, False),
    }
    return loaders


if __name__ == "__main__":
    # Command-line smoke test: build every loader once and print how many
    # samples sit behind each of them.
    # NOTE(review): the default domains "Amazon"/"Webcam" look like Office-31
    # names while the default root points at office_home; confirm the defaults
    # match the on-disk layout.
    cli = argparse.ArgumentParser(description="Dataset loader sanity check")
    option_specs = (
        ("--source", str, "Amazon"),
        ("--target", str, "Webcam"),
        ("--batch_size", int, 32),
        ("--data_root", str, "./data/office_home"),
        ("--num_workers", int, 4),
        ("--seed", int, 42),
    )
    for flag, value_type, fallback in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback)
    opts = cli.parse_args()

    loaders = get_dataloaders(
        source_domain=opts.source,
        target_domain=opts.target,
        batch_size=opts.batch_size,
        data_root=opts.data_root,
        num_workers=opts.num_workers,
        seed=opts.seed,
    )
    for split_name in loaders:
        print(split_name, len(loaders[split_name].dataset))
