import os.path
from pathlib import Path

import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms

from global_configs import root_dir


class VitDataset(Dataset):
    """Dataset that yields two preprocessed views of every listed image.

    Each row of ``images`` names a file through its ``'filepath'`` entry.
    Only the final path component is used; the file itself is looked up in
    the ``v0`` ... ``v{NUM_PARTS - 1}`` sub-directories of ``root_dir``.
    """

    # Number of ``v{part}`` sub-directories of ``root_dir`` that are searched.
    NUM_PARTS = 30

    def __init__(self, images, transform_clip, transform_blip):
        """Store the image table and the two preprocessing callables.

        Args:
            images: Table supporting ``len()`` and ``.iloc[idx]`` whose rows
                have a ``'filepath'`` string entry (presumably a pandas
                DataFrame; confirm against the caller).
            transform_clip: Callable applied to the opened PIL image; its
                result is returned unchanged as the first item element.
            transform_blip: Callable applied to the opened PIL image; its
                result must expose ``.data['pixel_values']``.
        """
        self.images = images
        self.transform_clip = transform_clip
        self.transform_blip = transform_blip
        self.root_dir = root_dir

    def __len__(self):
        """Return the number of rows in the image table."""
        return len(self.images)

    def _resolve_path(self, file_name):
        """Return the first existing ``root_dir/v{part}/file_name`` path.

        Raises:
            FileNotFoundError: If no searched sub-directory contains the file.
        """
        for part in range(self.NUM_PARTS):
            candidate = os.path.join(self.root_dir, f"v{part}", file_name)
            if os.path.exists(candidate):
                return candidate
        raise FileNotFoundError(
            f"{file_name!r} not found in any of the {self.NUM_PARTS} "
            f"'v<part>' sub-directories of {self.root_dir!r}"
        )

    def __getitem__(self, idx):
        """Load the image for row ``idx`` and apply both transforms.

        Returns:
            A tuple ``(clip_view, blip_pixel_values, stem)`` where ``stem``
            is the file name up to (not including) its first ``"."``.

        Raises:
            IndexError: If ``idx`` is past the end of the image table.
            FileNotFoundError: If the image file cannot be located on disk.
        """
        # Fail with the conventional sequence error instead of falling
        # through to code that relies on variables that were never bound.
        if idx >= len(self.images):
            raise IndexError(
                f"index {idx} is out of range for dataset of size {len(self.images)}"
            )

        row = self.images.iloc[idx]
        file_name = row['filepath'].split("/")[-1]
        final_path = self._resolve_path(file_name)

        # PIL opens files lazily; run the transforms while the file is still
        # open, then let the context manager release the handle.
        with Image.open(final_path) as image:
            image_clip = self.transform_clip(image)
            image_blip = self.transform_blip(image)
        return image_clip, image_blip.data['pixel_values'][0], file_name.split(".")[0]

def get_dataloaders(df1, transform_clip, transform_blip, *,
                    batch_size=64, num_workers=4, shuffle=True):
    """Build a ``DataLoader`` over a ``VitDataset`` made from ``df1``.

    Despite the plural name, a single loader is returned.

    Args:
        df1: Image table passed to ``VitDataset`` as ``images``.
        transform_clip: First preprocessing callable passed to ``VitDataset``.
        transform_blip: Second preprocessing callable passed to ``VitDataset``.
        batch_size: Number of samples per batch (default 64).
        num_workers: Number of loader worker processes (default 4).
        shuffle: Whether sample order is reshuffled (default True).

    Returns:
        The configured ``DataLoader``; the last, possibly smaller, batch is
        kept (``drop_last=False``) and ``pin_memory`` is enabled.
    """
    trn_dataset = VitDataset(df1, transform_clip, transform_blip)
    trn_dataloader = DataLoader(
        dataset=trn_dataset,
        shuffle=shuffle,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=num_workers,
        drop_last=False,
    )

    return trn_dataloader

