import os
from collections.abc import Callable

import numpy as np
import pandas as pd
import torch

from .base import AnomalyDataset


class Waterbirds(AnomalyDataset):
    """One split of the Waterbirds data, indexed through its metadata CSV.

    The dataset is expected under ``<root>/waterbirds_v1.0``. The file named
    by ``csv_filename`` in that directory is read, and only the rows whose
    ``split`` column matches the requested split are kept.

    Attributes set by ``__init__``:
        attr_names: Names of the metadata columns stored in ``attr``
            (``"y"`` and ``"place"`` — presumably the bird label and the
            background type; confirm against the dataset documentation).
        filename: List of the ``img_filename`` values of the selected rows.
        attr: ``torch.int64`` tensor with one row per selected sample and one
            column per entry of ``attr_names``.
    """

    # Maps a split name to the integer code used in the CSV's "split" column.
    split_dict = {
        "train": 0,
        "validation": 1,
        "test": 2,
    }
    csv_filename = "metadata.csv"

    def __init__(self, root: str, split: str, transform: Callable | None = None):
        """Load the metadata rows belonging to ``split``.

        Args:
            root: Directory that contains the ``waterbirds_v1.0`` folder.
            split: One of the keys of ``split_dict``.
            transform: Optional callable forwarded to the base class.

        Raises:
            ValueError: If ``split`` is not a key of ``split_dict``.
        """
        # Explicit check instead of ``assert``: assertions are removed under
        # ``python -O``, which would let a bad split through and surface later
        # as an unhelpful KeyError after the CSV has already been read.
        if split not in self.split_dict:
            raise ValueError(
                f"Unknown split {split!r}; expected one of {list(self.split_dict)}."
            )
        super().__init__(root=os.path.join(root, "waterbirds_v1.0"), split=split, transform=transform)

        metadata = pd.read_csv(os.path.join(self.root, self.csv_filename))
        # Filter once, then pull the needed columns from the filtered frame.
        rows = metadata.loc[metadata["split"] == self.split_dict[split]]

        self.attr_names = ["y", "place"]
        self.filename = rows["img_filename"].tolist()
        self.attr = torch.tensor(np.array(rows[self.attr_names]), dtype=torch.int64)


def waterbirds(
    root: str = "./data",
    split: str = "train",
    transform: Callable | None = None,
):
    """Build a :class:`Waterbirds` dataset for the requested split.

    ``"valid"`` is accepted as an alias for ``"validation"``; any other value
    is forwarded to the class unchanged.

    Args:
        root: Directory handed to ``Waterbirds`` as its ``root``.
        split: Split name, or the alias ``"valid"``.
        transform: Optional callable handed to ``Waterbirds``.

    Returns:
        The constructed ``Waterbirds`` instance.
    """
    # Normalise the short alias to the name the class knows about.
    resolved_split = "validation" if split == "valid" else split
    return Waterbirds(root=root, split=resolved_split, transform=transform)
