import os, json, torch
from pathlib import Path
from src.processing.captioning import ImageCaptioner
from tqdm import tqdm
from src.utils.images import get_images_paths, pad_images_to_max_size, load_images

# Root directory of the project data, read from the NAMING_BIASES_DATA_PATH environment variable.
# The lookup happens at import time, so importing this module raises KeyError when the variable is not set.
DATA_PATH = Path(os.environ['NAMING_BIASES_DATA_PATH'])


class DatasetsCommands:
    """
    Commands used to manage datasets.
    """
    
    def gen_captions(self,
                     dataset_name: str,
                     model_config_path: Path | str,
                     prompt: Path | str,
                     caption_folder_name: str = 'captions',
                     max_new_tokens: int = 200,
                     batch_size: int = 1,
                     pad_images: bool = False
                     ) -> None:
        """
        Generates captions for an existing BiasDataset.

        :param dataset_name: name of the BiasDataset folder, the dataset must be inside "$NAMING_BIASES_DATA_PATH/datasets"
        :param model_config_path: path to the yaml config for the captioning model
        :param prompt: prompt to pass to the captioning model, if this is a valid file path the file will be read, otherwise
                       this string is used directly as a prompt
        :param caption_folder_name: name of the folder in which to store the generated captions
        :param max_new_tokens: maximum number of tokens to generate
        :param batch_size: the batch size
        :param pad_images: if True images are padded to be the same size as the biggest one, this is needed because huggingface
                           has a bug when passing a batch size > 1 and images of different sizes.
        """
        imgs_path = DATA_PATH / 'datasets' / dataset_name / 'imgs'
        captions_path = DATA_PATH / 'datasets' / dataset_name / caption_folder_name

        paths = get_images_paths(imgs_path)

        imgs_count = len(paths)

        paths = [p for p in paths if not (captions_path / f'{p.stem}.txt').is_file()]

        print(f'generating captions for {len(paths)} out of {imgs_count} images.')

        captioner = ImageCaptioner(model_config_path)

        captions_path.mkdir(exist_ok=True)

        if Path(prompt).is_file():
            prompt = Path(prompt).read_text().strip()

        # NOTE: I had to manually pad the images because huggingface has a bug currently
        if pad_images:
            all_images = pad_images_to_max_size(paths)
        else:
            all_images = load_images(paths)

        with tqdm(total=len(paths)) as pbar:
            for i in range(0, len(paths), batch_size):
                batch = all_images[i:i+batch_size]
                paths_batch = paths[i:i+batch_size]

                out = captioner.caption(batch, prompt=prompt, max_new_tokens=max_new_tokens, batch_size=batch_size) # type: ignore

                for path, caption in zip(paths_batch, out):
                    new_path = captions_path / f'{path.stem}.txt'
                    new_path.write_text(caption)

                pbar.update(len(batch))

        metadata = {
            'model_conf': captioner.conf,
            'prompt': prompt,
            'max_new_tokens': max_new_tokens
        }

        metadata_path = imgs_path.parent / f'{caption_folder_name}_metadata.json'
        metadata_path.write_text(json.dumps(metadata, indent=4))

        mem = torch.cuda.max_memory_allocated() / (1024 ** 3)

        print(f'Max memory allocated: {mem:.2f} GiB')
