import numpy as np
import torch
import json
import os

from PIL import Image

from eval_classification import transform_fns
from scripts import icl_helpers
from utils import get_class_to_id_mapping, get_paths

class ClassificationDataset(torch.utils.data.Dataset):
    """Classification dataset that pairs every sample with listed neighbours.

    Each item consists of the processed query image and its label, plus the
    processed images and labels of the samples whose positions are listed for
    that index in the ``'idcs'`` table loaded from ``nn_file``.
    """

    def __init__(
        self,
        image_processor,
        nn_file,
        image_root,
        icl_k=2,
        transform_fn=None,
        **kwargs
    ):
        """Index the images below ``image_root`` and load the neighbour table.

        Args:
            image_processor: Callable applied to every opened ``PIL.Image``;
                its return value is what ``open`` hands back.
            nn_file: File readable by ``np.load`` whose ``'idcs'`` entry is
                indexed by sample position and yields the positions of that
                sample's neighbours.
            image_root: Directory handed to ``get_class_to_id_mapping``; each
                class name it returns is joined onto this root and passed to
                ``get_paths`` to collect that class's image paths.
            icl_k: Stored as ``self.icl_k``.
                NOTE(review): nothing in this class uses it to limit the
                number of neighbours returned -- confirm whether
                ``__getitem__`` was meant to truncate to ``icl_k``.
            transform_fn: Optional key into ``transform_fns``. The selected
                callable is applied to every class id; when falsy, class ids
                are used as labels unchanged.
            **kwargs: Accepted and ignored, so that a larger config dict can
                be splatted into the constructor.
        """
        super().__init__()
        self.image_processor = image_processor

        # np.load returns an NpzFile for archives, which keeps the underlying
        # file open until it is closed; copy out the table and release it.
        archive = np.load(nn_file)
        try:
            self.nns = archive['idcs']
        finally:
            # Plain ndarrays (non-archive files) have no close() method.
            if hasattr(archive, "close"):
                archive.close()

        self.label_to_idx = get_class_to_id_mapping(image_root)

        if transform_fn:
            transform_fn = transform_fns[transform_fn]
        else:
            def transform_fn(class_id):
                return class_id

        # self.paths and self.labels are parallel lists; the neighbour table
        # refers to positions in them, so their order must stay in sync.
        self.paths = []
        self.labels = []
        for cls in self.label_to_idx:
            cls_paths = get_paths(os.path.join(image_root, cls))
            class_id = self.label_to_idx[cls]
            self.paths += cls_paths
            # Call transform_fn once per sample, as before, in case it is
            # stateful or returns a fresh object each time.
            self.labels += [transform_fn(class_id) for _ in cls_paths]

        self.image_root = image_root
        self.icl_k = icl_k

    def __len__(self):
        """Return the number of indexed image paths."""
        return len(self.paths)

    def open(self, path):
        """Open the image at ``path`` and return ``image_processor(image)``.

        The file handle is released before returning, even if the processor
        raises.
        """
        with Image.open(path) as img:
            # Decode the pixel data while the file is still open so the image
            # stays usable after the handle is closed, even when the
            # processor returns the image object itself.
            img.load()
            return self.image_processor(img)

    def __getitem__(self, idx):
        """Return ``(x, label, nn_images, nn_labels)`` for sample ``idx``.

        ``x`` and ``label`` are the processed image and label of the sample
        itself; ``nn_images`` and ``nn_labels`` are lists holding the
        processed image and label of every position listed in
        ``self.nns[idx]``, in that order.
        """
        path, label = self.paths[idx], self.labels[idx]
        neighbor_positions = self.nns[idx]

        nn_images = []
        nn_labels = []
        for position in neighbor_positions:
            nn_images.append(self.open(self.paths[position]))
            nn_labels.append(self.labels[position])

        x = self.open(path)

        return x, label, nn_images, nn_labels


def get_dataset(image_processor, config):
    """Build the dataset described by ``config["dataset_kwargs"]``.

    The ``"dataset_type"`` entry is evaluated as a Python expression to obtain
    the dataset factory. That factory is then called with ``image_processor``
    as its only positional argument and every entry of
    ``config["dataset_kwargs"]`` (``"dataset_type"`` included) as keyword
    arguments.

    NOTE(review): ``eval`` runs arbitrary code, so ``config`` must come from a
    trusted source. Consider replacing it with an explicit name-to-class
    registry.
    """
    dataset_factory = eval(config["dataset_kwargs"]["dataset_type"])
    return dataset_factory(image_processor, **config["dataset_kwargs"])
