import random
import torch
import pandas as pd
from tqdm.auto import tqdm


def get_indices(
    dataset, forget_labels: None | list, index_file: None | str, seed: int = 42
):
    forget_indices = []
    retain_indices = []
    if index_file is not None:
        indices = pd.read_csv(index_file)
        if forget_labels is not None:
            forget_indices = indices[indices["class"].isin(forget_labels)][
                "index"
            ].tolist()
            retain_indices = indices[~indices["class"].isin(forget_labels)][
                "index"
            ].tolist()
        else:
            forget_indices_df = indices.sample(frac=0.3, random_state=seed)
            forget_indices = forget_indices_df["index"].tolist()
            retain_indices_df = indices[~indices["index"].isin(forget_indices)]
            retain_indices = retain_indices_df["index"].tolist()

    else:
        for i, (_, label) in tqdm(enumerate(dataset), total=len(dataset)):
            if label in forget_labels:
                forget_indices.append(i)
            else:
                retain_indices.append(i)
    return forget_indices, retain_indices


class UnlearnDiscriminativeDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        dataset,
        forget_labels: None | list,
        num_retains: None | int = 1000,
        num_forgets: None | int = 1000,
        index_file: None | str = None,
        indices_dict: None | dict = None,
        seed: int = 42,
    ):
        self.dataset = dataset

        if indices_dict is not None:
            forget_indices, retain_indices = (
                indices_dict["forget_indices"],
                indices_dict["retain_indices"],
            )
        else:
            forget_indices, retain_indices = get_indices(
                dataset, forget_labels, index_file, seed
            )

        if num_forgets is not None:
            if num_forgets >= 0:
                random.seed(seed)
                forget_indices = random.choices(
                    forget_indices, k=min(num_forgets, len(forget_indices))
                )
        else:
            forget_indices = []

        if num_retains is not None:
            if num_retains >= 0:
                random.seed(seed)
                retain_indices = random.choices(
                    retain_indices, k=min(num_retains, len(retain_indices))
                )
        else:
            retain_indices = []

        if len(forget_indices) < 1 and len(retain_indices) < 1:
            raise Exception(
                "Dataset is empty, assign non None value to atleast num_forgets or num_retains"
            )

        self.indices = forget_indices + retain_indices
        self.forget_indices = forget_indices
        self.retain_indices = retain_indices
        self.forget_labels = forget_labels
        self.num_retains = len(retain_indices)
        self.num_forgets = len(forget_indices)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        index = self.indices[idx]
        feature, label = self.dataset[index]
        if self.forget_labels is not None:
            if label in self.forget_labels:
                return feature, 1
        else:
            if index in self.forget_indices:
                return feature, 1
        return feature, 0


class DiscriminativeDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        dataset,
        forget_labels: None | list,
        num_retains: None | int = None,
        num_forgets: None | int = None,
        index_file: None | str = None,
        indices_dict: None | dict = None,
        seed: int = 42,
    ):
        self.dataset = dataset

        if indices_dict is not None:
            forget_indices, retain_indices = (
                indices_dict["forget_indices"],
                indices_dict["retain_indices"],
            )
        else:
            forget_indices, retain_indices = get_indices(
                dataset, forget_labels, index_file, seed
            )

        if num_forgets is not None:
            if num_forgets >= 0:
                random.seed(seed)
                forget_indices = random.choices(
                    forget_indices, k=min(num_forgets, len(forget_indices))
                )
        else:
            forget_indices = []

        if num_retains is not None:
            if num_retains >= 0:
                random.seed(seed)
                retain_indices = random.choices(
                    retain_indices, k=min(num_retains, len(retain_indices))
                )
        else:
            retain_indices = []

        if len(forget_indices) < 1 and len(retain_indices) < 1:
            raise Exception(
                "Dataset is empty, assign non None value to atleast num_forgets or num_retains"
            )
        self.indices = forget_indices + retain_indices
        self.forget_indices = forget_indices
        self.retain_indices = retain_indices
        self.forget_labels = forget_labels
        self.num_retains = len(retain_indices)
        self.num_forgets = len(forget_indices)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        index = self.indices[idx]
        feature, label = self.dataset[index]
        return feature, label
