import copy
import os
import sys

import h5py
import pandas
import torch
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm

from constants import paths

# from base_case import plant_featute_label_startegy

# In waterbirds, label y == 1 is a water bird and y == 0 is a land bird; the
# background column (place) uses the same encoding (1 = water, 0 = land).
# Species (matched as substrings of img_filename) that should carry the
# land-bird label y == 0; correct_labels() applies this override.
landbirds_to_fix = ("Western_Meadowlark", "Eastern_Towhee", "Western_Wood_Pewee")


def correct_labels(df):
    """Relabel the species listed in ``landbirds_to_fix`` as land birds.

    Every row whose ``img_filename`` contains one of those species names gets
    ``y = 0``. The DataFrame is modified in place and also returned so the
    call can be chained.
    """
    for species in landbirds_to_fix:
        matches_species = df["img_filename"].str.contains(species)
        df.loc[matches_species, "y"] = 0
    return df


def get_splits(seed, test_split=0.15, val_split=0.15):
    """Split the (label-corrected) Waterbirds metadata into train/val/test.

    Args:
        seed: ``random_state`` passed to both ``train_test_split`` calls.
        test_split: fraction of all rows that goes to the test set.
        val_split: fraction of all rows that goes to the validation set.

    Returns:
        ``(train_set, val_set, test_set)`` DataFrames.
    """
    metadata_path = os.path.join(
        paths._datasets_dir,
        "waterbirds",
        "waterbird_complete95_forest2water2",
        "metadata.csv",
    )
    metadata = correct_labels(pandas.read_csv(metadata_path))

    # Take the test rows first; the validation fraction is then rescaled so it
    # is still `val_split` of the *full* metadata, not of the remainder.
    remaining, test_set = train_test_split(
        metadata, test_size=test_split, random_state=seed
    )
    val_fraction_of_remaining = val_split / (1 - test_split)
    train_set, val_set = train_test_split(
        remaining, test_size=val_fraction_of_remaining, random_state=seed
    )
    return train_set, val_set, test_set


def get_transform():
    """Return the image preprocessing pipeline used for Waterbirds.

    Resizes to a square 256/224 times the 224-pixel target (i.e. 256x256),
    center-crops to 224x224, converts to a tensor, and normalizes each
    channel with mean 0.5 / std 0.5 (mapping [0, 1] to [-1, 1]).
    """
    target_resolution = 224
    resized_side = int(target_resolution * (256.0 / 224.0))
    steps = [
        transforms.Resize((resized_side, resized_side)),
        transforms.CenterCrop(target_resolution),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    return transforms.Compose(steps)


def load_image(filename, transform=None):
    """Load one Waterbirds image from disk, optionally transforming it.

    Args:
        filename: path relative to
            ``paths._datasets_dir/waterbirds/waterbird_complete95_forest2water2``
            (the ``img_filename`` values of the metadata).
        transform: optional callable applied to the loaded PIL image.

    Returns:
        The in-memory PIL image, or ``transform(image)`` if a transform is given.
    """
    path = os.path.join(
        paths._datasets_dir,
        "waterbirds",
        "waterbird_complete95_forest2water2",
        filename,
    )
    # Use a context manager so the underlying file handle is released even if
    # decoding fails; ``copy()`` loads the pixels into a standalone image first.
    with Image.open(path) as src:
        img = src.copy()

    if transform is not None:
        img = transform(img)

    return img


def load_from_hdf5(filename, h5, transform=None):
    """Read one image array from an open HDF5 file and wrap it as a PIL image.

    Args:
        filename: "/"-separated key path of the image dataset inside ``h5``
            (the metadata ``img_filename`` values, e.g. "group/image.jpg").
            Paths of any depth are walked one key at a time.
        h5: an open HDF5 file or group (any object indexable by key).
        transform: optional callable applied to the PIL image.

    Returns:
        The PIL image, or ``transform(image)`` if a transform is given.

    Raises:
        KeyError: if any component of ``filename`` is missing from ``h5``.
    """
    node = h5
    for key in filename.split("/"):
        node = node[key]
    # h5py's ``dataset[...]`` reads into a new NumPy array owned by this call,
    # so no defensive copy of the resulting image is needed.
    img = Image.fromarray(node[...])

    if transform is not None:
        img = transform(img)

    return img


class Collective_Waterbirds(Dataset):
    """Waterbirds dataset yielding ``(image, sensitive, label)`` triples.

    Labels come from the metadata column ``y`` and the sensitive attribute from
    ``place``. The sensitive attribute may be flipped in ``initialize_labels``
    so that group 1 is the group with the lower positive rate.

    Args:
        split: name of a ``{split}.csv`` metadata file under
            ``paths._datasets_dir/waterbirds``; takes precedence over ``csv_df``.
        csv_df: metadata DataFrame used when ``split`` is None.
        img_transform: callable applied to each PIL image; defaults to
            ``get_transform()``.
        preload: if True, every image is read and transformed up front into
            one tensor and the metadata DataFrame is then discarded.

    Raises:
        ValueError: if neither ``split`` nor ``csv_df`` is given.
    """

    def __init__(
        self, split=None, csv_df=None, img_transform=None, preload=True
    ) -> None:
        super().__init__()

        if split is not None:
            self.df = pandas.read_csv(
                os.path.join(paths._datasets_dir, "waterbirds", f"{split}.csv")
            )
        elif csv_df is not None:
            self.df = csv_df
        else:
            raise ValueError(
                "Collective_Waterbirds requires either `split` or `csv_df`"
            )

        self.split = split
        self.image_size = 224

        self.preload = preload
        if img_transform is None:
            self.img_transform = get_transform()
        else:
            self.img_transform = img_transform

        self.sensitives: torch.Tensor = None
        self.labels: torch.Tensor = None
        self.true_labels: torch.Tensor = None

        self.initialize_labels()

        self.images = None
        if self.preload:
            self._preload_data()

    def _preload_data(self):
        """Read and transform every image into ``self.images``, then drop ``df``.

        Images that raise ``OSError`` are reported on stderr and left as
        all-zero tensors (best effort, so one bad file does not abort loading).
        """
        print("loading waterbirds", file=sys.stderr)
        self.images = torch.zeros(len(self), 3, self.image_size, self.image_size)
        with h5py.File(paths._birds_hdf5, "r") as f:
            for f_i, fn in tqdm(
                enumerate(self.df["img_filename"]), disable=None, total=len(self)
            ):
                try:
                    self.images[f_i] = load_from_hdf5(fn, f, self.img_transform)
                except OSError:
                    print("could not load ", fn, file=sys.stderr)
        # Everything needed later now lives in tensors; free the metadata.
        self.df = None

    def initialize_labels(self):
        """Build the label and sensitive-attribute tensors from ``self.df``.

        ``labels``, ``true_labels`` and ``sensitives`` are 1-D float tensors of
        length ``len(self.df)``. Assuming binary 0/1 values, ``sensitives`` is
        flipped (``1 - s``) when group 0 has the lower rate of ``y == 1``, so
        group 1 ends up with the lower (or equal) positive rate. If either
        group is empty the rates are undefined and ``place`` is used as-is.
        """
        self.labels = torch.tensor(self.df["y"].to_numpy()).float()
        self.true_labels = torch.tensor(self.df["y"].to_numpy()).float()
        self.sensitives = torch.tensor(self.df["place"].to_numpy()).float()

        group0_size = torch.count_nonzero(1 - self.sensitives).item()
        group1_size = torch.count_nonzero(self.sensitives).item()
        if group0_size == 0 or group1_size == 0:
            # Avoid dividing by zero; there is no rate to compare.
            return

        # Make sure that the sensitive attribute has smaller positive rate
        group0_pos = (
            torch.count_nonzero((1 - self.sensitives) * self.true_labels).item()
            / group0_size
        )
        group1_pos = (
            torch.count_nonzero(self.sensitives * self.true_labels).item()
            / group1_size
        )
        if group0_pos < group1_pos:
            self.sensitives = 1 - self.sensitives

    def __len__(self):
        if self.df is None:
            return len(self.images)
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(image, sensitive, label)`` for the sample at position ``index``."""
        if self.preload:
            img = self.images[index]
        else:
            # Positional lookup, matching the positional label tensors. A plain
            # ``Series[index]`` is label-based and selects the wrong row (or
            # raises KeyError) when the DataFrame index is not 0..N-1, e.g. a
            # shuffled split produced by ``train_test_split``.
            filename = self.df["img_filename"].iloc[index]
            with h5py.File(paths._birds_hdf5, "r") as f:
                img = load_from_hdf5(filename, f, self.img_transform)
        sensitive = self.sensitives[index]
        label = self.labels[index]
        return img, sensitive, label
