import torch
import os
import json
import requests 
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from PIL import Image

def load_prompts(file_path):
    """Read one prompt per line from a JSON-lines or plain-text file.

    The file is opened first. Then:

    * if the path contains ``'.json'``, each line is decoded as a JSON object
      and its ``'caption'`` value is used;
    * otherwise, if the path contains ``'.txt'``, each raw line is used.

    Every prompt is whitespace-stripped.

    Args:
        file_path: Path to the prompt file.

    Returns:
        A list of prompt strings in file order.

    Raises:
        ValueError: If the path contains neither ``'.json'`` nor ``'.txt'``.
    """
    with open(file_path, 'r') as handle:
        if '.json' in file_path:
            prompts = [json.loads(row)['caption'].strip() for row in handle]
        elif '.txt' in file_path:
            prompts = [row.strip() for row in handle]
        else:
            raise ValueError('Invalid file type')
    return prompts


def load_images(file_path, output_folder, *, verify=False, timeout=None):
    """Download every image listed in a JSON-lines manifest into a folder.

    Each non-blank line of ``file_path`` must be a JSON object with ``'index'``
    and ``'url'`` keys. The body fetched from each URL is written to
    ``{output_folder}/{position:04d}_{index}.png``. Here ``position`` is the
    zero-based position of the entry among the manifest's non-blank lines.
    Failed downloads are reported on stdout and skipped.

    Args:
        file_path: Path to the manifest; its name must contain ``'.json'``.
        output_folder: Destination directory; created if it does not exist.
        verify: Forwarded to ``requests.get``. Defaults to ``False``, which
            disables TLS certificate verification, to match the historical
            behaviour. Pass ``True`` when the URLs are not trusted.
        timeout: Forwarded to ``requests.get``. ``None`` waits indefinitely.

    Raises:
        ValueError: If ``file_path`` does not look like a JSON manifest.
        json.JSONDecodeError / KeyError: If a manifest line is malformed.
    """
    if '.json' not in file_path:
        raise ValueError('Invalid file type')

    # Parse the whole manifest up front so the input file is closed before the
    # slow network work starts. Each line is decoded exactly once, and blank
    # lines are skipped.
    with open(file_path, 'r') as manifest:
        entries = [json.loads(line) for line in manifest if line.strip()]
    urls = [(entry['index'], entry['url'].strip()) for entry in entries]

    os.makedirs(output_folder, exist_ok=True)
    for i, (index, url) in enumerate(urls):
        try:
            response = requests.get(url, verify=verify, timeout=timeout)
            # Without this check, an HTTP error page (404/500 HTML) would be
            # saved under a .png name and break image loading later.
            response.raise_for_status()
            with open(f'{output_folder}/{i:04d}_{index}.png', 'wb') as out:
                out.write(response.content)
        except (requests.RequestException, OSError):
            # Best-effort download: report the failure and continue with the
            # remaining URLs. Interrupts are no longer swallowed.
            print(f'Error downloading {url}')
                

def transform_image(image, size=512):
    """Convert an image into a normalized, square tensor.

    Steps:

    1. ``F.to_tensor`` produces a ``(C, H, W)`` float tensor (values in
       [0, 1] for 8-bit images).
    2. The shorter side is resized to ``size`` with bilinear interpolation.
    3. The result is center-cropped to ``size`` x ``size``.
    4. Values are mapped with ``(x - 0.5) / 0.5``.

    Args:
        image: A PIL image, or anything else ``F.to_tensor`` accepts.
        size: Target edge length in pixels. Defaults to 512.

    Returns:
        A ``(C, size, size)`` float tensor.
    """
    image = F.to_tensor(image)
    # A bare int (not a tuple) makes resize match the *shorter* edge to
    # ``size`` while keeping the aspect ratio; the crop then trims the excess.
    image = F.resize(image, size, interpolation=F.InterpolationMode.BILINEAR)
    image = F.center_crop(image, (size, size))
    # A single-element mean/std broadcasts across all channels.
    image = F.normalize(image, [0.5], [0.5])
    return image

def get_image_transforms(hflip=False, size=512):
    """Build a ``transforms.Compose`` preprocessing pipeline for PIL images.

    The pipeline applies, in order:

    1. ``ToTensor``.
    2. A bilinear resize of the shorter side to ``size``.
    3. A ``size`` x ``size`` center crop.
    4. An optional random horizontal flip.
    5. ``Normalize([0.5], [0.5])``.

    Args:
        hflip: If True, insert ``RandomHorizontalFlip`` (default p=0.5) before
            normalization, as augmentation.
        size: Target edge length in pixels. Defaults to 512.

    Returns:
        A ``transforms.Compose`` instance.
    """
    steps = [
        transforms.ToTensor(),
        transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop((size, size)),
    ]
    if hflip:
        steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.Normalize([0.5], [0.5]))
    return transforms.Compose(steps)

def load_and_encode_image(img_path, vae, device="cuda"):
    """Load an image from disk and encode it into scaled VAE latents.

    The image is converted to RGB and preprocessed with ``transform_image``.
    It is then given a batch dimension, encoded, and one latent is sampled from
    the posterior. That sample is multiplied by ``vae.config.scaling_factor``.

    Args:
        img_path: Path to an image file.
        vae: A module exposing ``encode(x).latent_dist.sample()`` and
            ``config.scaling_factor``. It is presumably a diffusers
            ``AutoencoderKL`` — TODO confirm. It must have at least one
            parameter.
        device: Device the input tensor is moved to. The default ``"cuda"``
            matches the previous ``.cuda()`` call (the current CUDA device).

    Returns:
        The scaled latent tensor, with a leading batch dimension of 1.
    """
    image = Image.open(img_path).convert("RGB")
    image = transform_image(image)
    # Cast the input to the VAE's parameter dtype. Otherwise a half-precision
    # VAE fails on the float32 tensor from transform_image with a dtype
    # mismatch. Assumes the first parameter's dtype is representative.
    dtype = next(vae.parameters()).dtype
    image = image.unsqueeze(0).to(device=device, dtype=dtype)
    latents = vae.encode(image).latent_dist.sample()
    latents = latents * vae.config.scaling_factor
    return latents

class CaptionDataset(torch.utils.data.Dataset):
    """Image-folder dataset with optional per-image caption files.

    Every ``.jpg`` / ``.jpeg`` / ``.png`` file directly inside ``data_folder``
    is one item. Extensions are matched case-insensitively, and items are
    ordered by file name. When captions are requested, the caption for
    ``<stem>.<ext>`` is read from ``<stem>.txt`` in the same folder.

    Args:
        data_folder: Directory containing the images (not searched recursively).
        transform: Optional callable applied to the loaded RGB ``PIL.Image``.
        load_captions: If True, items are ``(image, caption)`` tuples;
            otherwise just the image.
        return_img_path: If True, the image slot holds the file path instead
            of a loaded image, and ``transform`` is not applied.
    """

    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, data_folder, transform=None, load_captions=False, return_img_path=False):
        self.transform = transform
        self.data_folder = data_folder
        self.load_captions = load_captions
        self.return_img_path = return_img_path
        # Lower-case before matching so e.g. ``IMG_01.JPG`` is not silently
        # dropped. The leading dot avoids matching names like ``notjpeg``.
        self.image_files = sorted(
            name for name in os.listdir(data_folder)
            if name.lower().endswith(self._IMAGE_EXTENSIONS)
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.data_folder, self.image_files[idx])

        if self.return_img_path:
            image = img_path
        else:
            image = Image.open(img_path).convert("RGB")
            if self.transform:
                image = self.transform(image)

        if not self.load_captions:
            return image

        # Replace only the final extension. Chained str.replace calls would
        # also rewrite ".jpg"/".png" occurring in directory names.
        caption_path = os.path.splitext(img_path)[0] + '.txt'
        with open(caption_path, 'r', encoding='utf-8') as f:
            caption = f.read().strip()
        return image, caption