"""
WILDS dataset utilities.
Loads and creates a dataloader for the WILDS datasets: iWildCam, Camelyon17, and FMOW.
"""

import os
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Subset
from wilds import get_dataset
from wilds.common.grouper import CombinatorialGrouper
from wilds.common.data_loaders import get_train_loader, get_eval_loader
from wilds.datasets.wilds_dataset import WILDSSubset
from experiments.helpers import worker_init_fn
from torchvision.transforms import ToPILImage
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
import open_clip

def get_wilds_dataset_with_domain_filter(
    dataset_name: str = None,
    path_to_dataset_dir: str = None,
    split_type: str = None,
    transform=None,
    domains=None,
    download=True
):
    """Return a WILDS subset for the given split, optionally filtered by domain.

    When ``domains`` is None the split returned by ``dataset.get_subset`` is
    handed back untouched. Otherwise only the examples whose domain metadata
    (``location`` for iwildcam, ``hospital`` for camelyon17, ``year`` for fmow)
    matches one of ``domains`` are kept.

    For the iWildCam train split, if ``domain_to_cluster_map.csv`` exists in
    ``path_to_dataset_dir``, each location is first mapped to its
    ``cluster_label`` and ``domains`` is interpreted as a list of cluster
    labels. Locations absent from the CSV are assigned cluster -1. If the CSV
    is missing, a warning is printed and raw location values are matched.

    Args:
        dataset_name: WILDS dataset name passed to ``get_dataset``.
        path_to_dataset_dir: Dataset root directory (also where the optional
            cluster CSV is looked up).
        split_type: Split name passed to ``dataset.get_subset``.
        transform: Transform attached to the returned subset.
        domains: Iterable of int-convertible domain (or cluster) ids, or None
            to disable filtering.
        download: Forwarded to ``get_dataset``.

    Returns:
        The unfiltered split when ``domains`` is None, otherwise a new
        ``WILDSSubset`` over the original dataset restricted to the matches.

    Raises:
        ValueError: If ``domains`` is given for a dataset without a known
            domain field. Raised before the dataset is loaded.
    """
    domain_field_map = {
        'iwildcam': 'location',
        'camelyon17': 'hospital',
        'fmow': 'year',
    }
    # Fail fast: reject unsupported datasets before get_dataset() is asked to
    # load (and, with download=True, possibly fetch) anything.
    if domains is not None and dataset_name not in domain_field_map:
        raise ValueError(f"Domain filtering not implemented for {dataset_name}")

    dataset = get_dataset(dataset=dataset_name, download=download, root_dir=path_to_dataset_dir)
    subset = dataset.get_subset(split_type, transform=transform)

    if domains is None:
        return subset

    field_idx = dataset.metadata_fields.index(domain_field_map[dataset_name])

    # The metadata column may be a tensor; normalise to an int ndarray.
    domain_values = subset.metadata_array[:, field_idx]
    if hasattr(domain_values, "numpy"):
        domain_values = domain_values.numpy()
    domain_values = domain_values.astype(int, copy=False)

    # Requested ids, shared by every branch below.
    targets = np.array([int(d) for d in domains], dtype=int)

    # By default match on raw domain values; iWildCam train may swap in clusters.
    values_to_match = domain_values
    if dataset_name == "iwildcam" and split_type == "train":
        map_path = os.path.join(path_to_dataset_dir, "domain_to_cluster_map.csv")
        if os.path.exists(map_path):
            df = pd.read_csv(map_path)
            key_col = "domain_name" if "domain_name" in df.columns else "location"
            dom2clu = dict(zip(df[key_col].astype(int), df["cluster_label"].astype(int)))
            # Unmapped locations fall into the sentinel cluster -1.
            values_to_match = np.array(
                [dom2clu.get(int(v), -1) for v in domain_values], dtype=int
            )
        else:
            # tolist() yields plain Python ints so the message stays readable
            # regardless of how NumPy formats its scalar types.
            print(f"[WARN] {map_path} not found. Filtering by raw locations: {targets.tolist()}")

    mask = np.isin(values_to_match, targets)

    # Translate positions within the split back to indices of the full dataset.
    filtered_idx_in_subset = np.where(mask)[0]
    subset_indices = np.array(subset.indices)
    orig_indices = subset_indices[filtered_idx_in_subset]

    return WILDSSubset(
        dataset=dataset,
        indices=orig_indices,
        transform=transform
    )

def _split_wilds_subset_train_val(
    subset: WILDSSubset,
    generator: torch.Generator,
    val_fraction: float = 0.2
):
    """Randomly split a WILDSSubset into train/val parts.

    The subset's indices are shuffled with ``torch.randperm`` driven by
    ``generator``; the first ``round((1 - val_fraction) * n)`` shuffled
    indices become the train part and the remainder the validation part.
    For very small subsets the rounding can leave one of the parts empty.

    Args:
        subset: Subset to split; its ``dataset``, ``indices`` and
            ``transform`` attributes are reused for both parts.
        generator: Random generator that determines the permutation.
        val_fraction: Nominal share of examples assigned to validation,
            strictly between 0 and 1.

    Returns:
        ``(train_subset, val_subset)``, two WILDSSubset objects that point to
        the original dataset.

    Raises:
        AssertionError: If ``val_fraction`` is not in (0, 1).
        ValueError: If ``subset`` has no examples.
    """
    assert 0.0 < val_fraction < 1.0, "val_fraction must be in (0,1)"
    orig_dataset = subset.dataset
    orig_indices = np.array(subset.indices)
    n = len(orig_indices)
    if n == 0:
        raise ValueError("Subset is empty; cannot create a validation split.")

    perm = torch.randperm(n, generator=generator).numpy()
    split_at = int(round((1.0 - val_fraction) * n))
    train_idx = orig_indices[perm[:split_at]]
    val_idx   = orig_indices[perm[split_at:]]

    train_subset = WILDSSubset(dataset=orig_dataset, indices=train_idx, transform=subset.transform)
    val_subset   = WILDSSubset(dataset=orig_dataset, indices=val_idx,   transform=subset.transform)

    # Report the requested proportions rather than a fixed 80/20, so the log
    # stays truthful when val_fraction is not the default.
    train_pct = 100.0 * (1.0 - val_fraction)
    val_pct = 100.0 * val_fraction
    print(
        f"[Holdout split] total={n} -> train={len(train_subset)} ({train_pct:g}%), "
        f"val={len(val_subset)} ({val_pct:g}%)"
    )
    return train_subset, val_subset

def get_wilds_dataloader(
    trial_num: int = None,
    dataset_name: str = None,
    path_to_dataset_dir: str = None,
    split_type: str = None,
    domains=None,
    transform=None,
    batch_size=32,
    num_workers=4,
    pin_memory=True,
    is_training=True,
    create_val_from_train: bool = False,
    val_fraction: float = 0.2):
    """Construct a seeded DataLoader for a WILDS subset.

    The requested split is loaded and, if ``domains`` is not None, filtered
    by domain. When ``create_val_from_train`` is True and ``split_type`` is
    ``'train'``, the subset is further divided into train/val parts and the
    part matching ``is_training`` is used (train part if True, val part
    otherwise).

    A ``torch.Generator`` seeded with ``trial_num`` drives both the optional
    train/val permutation and the loader, so ``trial_num`` must be an integer
    seed.

    Returns:
        ``(dataset_subset, loader)``: the subset actually served and its
        DataLoader. A train loader is built when ``is_training`` is True,
        an eval loader otherwise.
    """
    seeded_gen = torch.Generator()
    seeded_gen.manual_seed(trial_num)

    dataset_subset = get_wilds_dataset_with_domain_filter(
        dataset_name=dataset_name,
        path_to_dataset_dir=path_to_dataset_dir,
        split_type=split_type,
        transform=transform,
        domains=domains
    )

    # Optionally carve a held-out validation part out of the train split.
    if create_val_from_train and split_type == 'train':
        train_part, val_part = _split_wilds_subset_train_val(
            subset=dataset_subset, generator=seeded_gen, val_fraction=val_fraction
        )
        dataset_subset = train_part if is_training else val_part

    # Both loader factories take the same arguments; only the factory differs.
    make_loader = get_train_loader if is_training else get_eval_loader
    loader = make_loader(
        loader='standard',
        dataset=dataset_subset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        worker_init_fn=worker_init_fn,
        generator=seeded_gen
    )

    return dataset_subset, loader


def get_domain_list(dataset_name: str, path_to_dataset_dir: str, split_type: str = 'train'):
    """List available domain identifiers for a WILDS dataset split.

    The split's metadata is grouped by the dataset's domain field
    (``location`` for iwildcam, ``hospital`` for camelyon17, ``year`` for
    fmow) using a ``CombinatorialGrouper``, and the distinct group ids found
    in the split are returned.

    Args:
        dataset_name: WILDS dataset name; must be one of the supported names.
        path_to_dataset_dir: Dataset root directory passed to ``get_dataset``.
        split_type: Split whose domains are listed.

    Returns:
        Ascending list of the unique group ids present in the split.
        NOTE(review): these are grouper-assigned ids; confirm they coincide
        with the raw metadata values used when filtering by domain.

    Raises:
        ValueError: If ``dataset_name`` has no known domain field. Raised
            before the dataset is loaded.
    """
    domain_field_map = {
        'iwildcam': 'location',
        'camelyon17': 'hospital',
        'fmow': 'year',
    }
    # Validate first so an unsupported name does not trigger a dataset
    # load/download (get_dataset is called with download=True below).
    if dataset_name not in domain_field_map:
        raise ValueError(f"Domain listing not implemented for {dataset_name}")

    dataset = get_dataset(dataset=dataset_name, root_dir=path_to_dataset_dir, download=True)
    subset = dataset.get_subset(split_type)

    grouper = CombinatorialGrouper(
        dataset=dataset,
        groupby_fields=[domain_field_map[dataset_name]]
    )
    group_array = grouper.metadata_to_group(subset.metadata_array)
    # np.unique already returns its result in ascending order.
    return np.unique(group_array.numpy()).tolist()
