from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
from typing import NamedTuple, List
from src.utils.images import ALLOWED_IMAGE_EXTENSIONS

class BiasExample(NamedTuple):
    """A single bias together with the samples listed for it.

    Both sample lists hold either PIL images or caption strings, depending on
    the `return_captions` setting of the BiasDataset that built the example.
    """

    # Integer id of the bias.
    bias_id: int
    # Samples listed as correctly classified for this bias.
    correctly_classified: list[Image.Image] | list[str]
    # Samples listed as incorrectly classified for this bias.
    incorrectly_classified: list[Image.Image] | list[str]


class BiasDataset(Dataset):
    """
    This dataset works with the "bias dataset" format and returns one
    BiasExample per bias, containing either the images or the captions of
    the samples listed for that bias.
    You may want to use it with images if:
        - You are using a multimodal LLM that supports multiple images as input
    You may want to use it with captions if:
        - You don't have a multimodal LLM that supports multiple images as input
        and you have already generated the captions

    Folder layout read by this class:
        root/imgs/<example_id>.<ext>                  (ext from ALLOWED_IMAGE_EXTENSIONS)
        root/<caption_folder_name>/<example_id>.txt
        root/<biases_folder_name>/<bias_id>/correctly-classified.txt
        root/<biases_folder_name>/<bias_id>/incorrectly-classified.txt
    The two id files list one example id per line; blank lines are ignored.

    Args:
        root: path to the dataset folder
        return_captions: whether to return captions instead of images
        caption_folder_name: if return_captions is True, the captions will be read
                             from this folder.
        biases_folder_name: folder (inside root) holding one sub-folder per bias,
                            each named with the bias' integer id.
        n_correctly_classified: if not None, the first n_correctly_classified examples in correctly-classified.txt will be selected.
        top_k_incorrectly_classified: if not None, the first top_k_incorrectly_classified examples in incorrectly-classified.txt will be selected.
                                      Note that the incorrectly-classified examples should be sorted by the distance from the decision boundary (descending).
    """

    def __init__(self,
                 root: Path | str,
                 return_captions: bool = False,
                 caption_folder_name: str = 'captions',
                 biases_folder_name: str = 'biases',
                 n_correctly_classified: int | None = None,
                 top_k_incorrectly_classified: int | None = None
                 ) -> None:
        super().__init__()

        self.root = Path(root)
        self.return_captions = return_captions
        self.n_correctly_classified = n_correctly_classified
        self.top_k_incorrectly_classified = top_k_incorrectly_classified

        self.imgs_folder = self.root / 'imgs'
        self.captions_folder = self.root / caption_folder_name
        self.biases_folder = self.root / biases_folder_name

        # Only visible sub-directories are biases: stray files (README,
        # .DS_Store, ...) or hidden folders would otherwise crash int() in
        # __getitem__. Sort numerically so bias "10" comes after bias "2";
        # a visible folder whose name is not an integer fails here, early.
        self.example_folders = sorted(
            (p for p in self.biases_folder.glob('*')
             if p.is_dir() and not p.name.startswith('.')),
            key=lambda p: int(p.name),
        )

    def __len__(self) -> int:
        return len(self.example_folders)

    def __getitem__(self, index: int) -> BiasExample:
        bias_folder = self.example_folders[index]
        bias_id = int(bias_folder.name)

        correctly_classified_ids = self._read_ids(
            bias_folder / 'correctly-classified.txt', self.n_correctly_classified)
        incorrectly_classified_ids = self._read_ids(
            bias_folder / 'incorrectly-classified.txt', self.top_k_incorrectly_classified)

        load = self._load_caption if self.return_captions else self._load_image
        correctly_classified = [load(example_id) for example_id in correctly_classified_ids]
        incorrectly_classified = [load(example_id) for example_id in incorrectly_classified_ids]

        return BiasExample(bias_id, correctly_classified, incorrectly_classified)

    @staticmethod
    def _read_ids(path: Path, limit: int | None) -> list[str]:
        """Return the non-blank, stripped lines of `path`, keeping the first `limit` (all if None)."""
        # splitlines() also handles '\r\n' files; blank lines (including an
        # empty file) must not turn into an empty example id.
        lines = (line.strip() for line in path.read_text(encoding='utf-8').splitlines())
        ids = [line for line in lines if line]
        return ids[:limit]

    def _load_image(self, example_id: str) -> Image.Image:
        """Open `imgs/<example_id>.<ext>` for the first existing extension in ALLOWED_IMAGE_EXTENSIONS.

        Raises:
            FileNotFoundError: if no file exists for any allowed extension.
        """
        candidates = [self.imgs_folder / f'{example_id}.{ext}' for ext in ALLOWED_IMAGE_EXTENSIONS]
        for img_path in candidates:
            if img_path.is_file():
                # Image.open is lazy and keeps the file handle open; read the
                # pixels now and release the handle so that loading many
                # examples does not exhaust the process' file descriptors.
                with Image.open(img_path) as img:
                    img.load()
                return img
        raise FileNotFoundError(
            f'No image found for example id {example_id!r}; tried: '
            + ', '.join(str(p) for p in candidates)
        )

    def _load_caption(self, example_id: str) -> str:
        """Read the caption text stored in `<caption_folder>/<example_id>.txt`."""
        caption_path = self.captions_folder / f'{example_id}.txt'
        text = caption_path.read_text(encoding='utf-8')
        # Captions start with a literal backslash + 'n' (two characters) because
        # the newline was not escaped correctly when they were written; drop it.
        return text.removeprefix('\\n')
