import os
import numpy as np
import torch
import torchvision
from torch.utils.data import Dataset
from PIL import Image


class BaseDataset(Dataset):
    """Base dataset whose annotations can be partitioned into tasks.

    Subclasses are expected to populate ``self.annotations`` (a sliceable
    sequence) and implement ``__getitem__``. ``split_data_by_task`` selects
    one contiguous chunk of ``annotations`` into ``self.sub_dataset``, and
    ``random_split`` subsamples ``annotations`` in place.
    """

    def __init__(self, split="train", transform=None, num_tasks=4, seed=42):
        """
        Args:
            split (string): One of 'train', 'valid', 'test'
            transform (callable, optional): Optional transform to be applied on a sample.
            num_tasks (int, optional): Total number of tasks to split the data into.
            seed (int, optional): Seed for random operations.
        """

        self.num_tasks = num_tasks
        # Chunk of ``annotations`` chosen by ``split_data_by_task``; None until then.
        self.sub_dataset = None
        # ``dataset`` and ``attribute_names`` are not read by this base class;
        # presumably populated/used by subclasses.
        self.dataset = None
        self.attribute_names = None
        # Must be set (e.g. by a subclass) before splitting or subsampling.
        self.annotations = None
        self.split = split
        self.transform = transform
        self.seed = seed

    def split_data_by_task(self, task_index):
        """Select the contiguous chunk of annotations belonging to one task.

        The annotations are divided into ``num_tasks`` chunks of
        ``len(annotations) // num_tasks`` items each, and the last task also
        receives the remainder. If there are fewer annotations than tasks,
        every task except the last is empty. The chosen chunk is stored in
        ``self.sub_dataset``, overwriting any previous selection.

        Args:
            task_index (int): Task to select, in ``[0, num_tasks)``.

        Returns:
            BaseDataset: ``self`` with ``sub_dataset`` set, to allow chaining.

        Raises:
            AssertionError: if ``task_index`` is out of range, annotations are
                not set, or ``split`` is 'valid' or 'test'.
        """
        # Also reject negative indices, which would otherwise produce
        # negative slice bounds and a meaningless subset.
        assert 0 <= task_index < self.num_tasks, (
            "task_index must be in [0, num_tasks) but got %d and %d"
            % (task_index, self.num_tasks)
        )
        assert self.annotations is not None, "annotations must be set before splitting"
        assert self.split not in ['valid', 'test'], "split must be 'train' for split_data_by_task func usage"
        data_size = len(self.annotations)
        task_size = data_size // self.num_tasks
        start_index = task_index * task_size
        if task_index == self.num_tasks - 1:
            # The last task absorbs the remainder so no annotation is dropped.
            end_index = data_size
        else:
            end_index = start_index + task_size
        self.sub_dataset = self.annotations[start_index:end_index]
        return self

    def __len__(self):
        """Return the number of samples in the current view.

        This is the size of ``sub_dataset`` once a task has been selected,
        otherwise the size of the full ``annotations``. The fallback matters
        for 'valid'/'test' splits: they cannot call ``split_data_by_task`` and
        previously had no usable length at all.

        Raises:
            AssertionError: if neither ``sub_dataset`` nor ``annotations`` is set.
        """
        if self.sub_dataset is not None:
            data = self.sub_dataset
        else:
            data = self.annotations
        assert data is not None, "annotations must be set before calling len()"
        return len(data)

    def __getitem__(self, idx):
        """Return the sample at ``idx``; must be implemented by subclasses."""
        raise NotImplementedError

    def random_split(self, subsample_rate=0.8):
        """Randomly subsample ``annotations`` in place, keeping a fixed fraction.

        Roughly equivalent to keeping the first part of
        ``torch.utils.data.random_split(dataset, [rate, 1 - rate])``. The
        shuffle is seeded with ``self.seed``, so the same data and seed always
        keep the same subset. ``sub_dataset`` is NOT updated here; call
        ``split_data_by_task`` afterwards if a task subset is needed.

        Args:
            subsample_rate (float): Fraction of annotations to keep, in [0, 1].
                The kept count is ``floor(subsample_rate * len(annotations))``.

        Returns:
            BaseDataset: ``self`` with ``annotations`` subsampled.

        Raises:
            AssertionError: if annotations are not set.
            ValueError: if ``subsample_rate`` is outside [0, 1].
        """
        assert (
            self.annotations is not None
        ), "annotations must be set before splitting the data"
        if not 0.0 <= subsample_rate <= 1.0:
            raise ValueError(
                "subsample_rate must be in [0, 1] but got %r" % (subsample_rate,)
            )
        indices = list(range(len(self.annotations)))
        # Use a private generator: seeding the global NumPy RNG would silently
        # reset random state for the rest of the program. RandomState(seed)
        # yields the same permutation as np.random.seed(seed) + np.random.shuffle.
        rng = np.random.RandomState(self.seed)
        rng.shuffle(indices)
        keep = int(np.floor(subsample_rate * len(self.annotations)))
        self.annotations = [self.annotations[i] for i in indices[:keep]]
        return self
