import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import os
import numpy as np
import random
import torch.nn.functional as F
from .realesrgan import RealESRGAN_degradation


class ImageTextDegradationDataset(Dataset):
    """Image/caption dataset that synthesizes a degraded input for every sample.

    Each item pairs a ground-truth image from ``gt_image_dir`` with the caption
    stored in ``gt_text_dir`` under the same basename and a ``.txt`` extension.
    The low-quality counterpart is produced on the fly by
    ``RealESRGAN_degradation.degrade_process``.

    Only images that have a matching caption file are exposed; the positions of
    those images inside ``image_filenames`` are kept in ``valid_indices``.

    Args:
        gt_image_dir: Directory containing the ground-truth images.
        gt_text_dir: Directory containing one ``<basename>.txt`` caption per image.
        tokenizer: Optional callable tokenizer exposing ``model_max_length``.
            When ``None`` an all-zero id tensor of length 77 is returned instead.
        transform: Stored on the instance; not used by this class itself.
        empty_text_ratio: Probability of replacing the caption with ``""``.
        resize_bak: Forwarded unchanged to ``degrade_process``.
        degradation_config: Path of the degradation settings file. ``None``
            selects ``DEFAULT_DEGRADATION_CONFIG``.
    """

    # Settings file used when the caller does not pass ``degradation_config``.
    DEFAULT_DEGRADATION_CONFIG = '/home/SPIDER/dataloaders/params_realesrgan_diffbir.yml'

    def __init__(self, gt_image_dir, gt_text_dir, tokenizer=None, transform=None, empty_text_ratio=0.0, resize_bak=False,
                 degradation_config=None):
        self.gt_image_dir = gt_image_dir
        self.gt_text_dir = gt_text_dir
        self.tokenizer = tokenizer
        self.transform = transform
        self.empty_text_ratio = empty_text_ratio
        self.resize_bak = resize_bak
        # Sorted so that sample order is reproducible across runs.
        self.image_filenames = sorted([f for f in os.listdir(
            gt_image_dir) if os.path.isfile(os.path.join(gt_image_dir, f))])

        if degradation_config is None:
            degradation_config = self.DEFAULT_DEGRADATION_CONFIG
        self.degradation = RealESRGAN_degradation(
            degradation_config, device='cpu')

        # Without this list __len__ would report 0 and no sample could be read.
        self.valid_indices = self._collect_valid_indices()
        skipped = len(self.image_filenames) - len(self.valid_indices)
        if skipped:
            print(
                f"Skipping {skipped} image(s) without a caption file in {gt_text_dir}")

        self.img_preproc = transforms.Compose([
            transforms.ToTensor(),
        ])

    def _collect_valid_indices(self):
        """Return the indices into ``image_filenames`` that have a caption file."""
        valid = []
        for i, img_filename in enumerate(self.image_filenames):
            basename = os.path.splitext(img_filename)[0]
            text_path = os.path.join(self.gt_text_dir, f"{basename}.txt")
            if os.path.isfile(text_path):
                valid.append(i)
        return valid

    def tokenize_caption(self, caption=""):
        """Tokenize ``caption`` into a tensor of token ids.

        Returns a zero tensor of shape ``(77,)`` when no tokenizer is set;
        otherwise the tokenizer's ``input_ids``, padded/truncated to
        ``tokenizer.model_max_length``.
        """
        if self.tokenizer is None:
            return torch.zeros((77,), dtype=torch.long)

        inputs = self.tokenizer(
            caption, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

    def __len__(self):
        """Return the number of images that have a matching caption file."""
        return len(self.valid_indices)

    def _load_example(self, gt_image_path, gt_text_path):
        """Read one image/caption pair and build the training example dict."""
        # The context manager closes the file handle once the RGB copy exists.
        with Image.open(gt_image_path) as raw_image:
            gt_image = raw_image.convert('RGB')
        gt_img_tensor = self.img_preproc(gt_image)  # range [0, 1]

        # Randomly drop the caption; the caption file is not read in that case.
        if random.random() < self.empty_text_ratio:
            gt_text = ""
        else:
            with open(gt_text_path, 'r', encoding='utf-8') as f:
                gt_text = f.read().strip()

        gt_image_np = np.array(gt_image) / 255.0
        _, lq_img_tensor = self.degradation.degrade_process(
            gt_image_np, resize_bak=self.resize_bak)
        # NOTE(review): presumably degrade_process returns a batched tensor
        # whose leading axis is removed here - confirm against its definition.
        lq_img_tensor = lq_img_tensor.squeeze(0)

        return {
            "conditioning_pixel_values": lq_img_tensor,
            # Rescale the ground truth from [0, 1] to [-1, 1].
            "pixel_values": gt_img_tensor * 2.0 - 1.0,
            "input_ids": self.tokenize_caption(caption=gt_text).squeeze(0),
        }

    def __getitem__(self, idx):
        """Return the example at ``idx``, falling forward past unreadable samples.

        If a sample cannot be loaded, the next index is tried instead, up to
        the end of the dataset.

        Returns:
            dict with ``conditioning_pixel_values`` (degraded image),
            ``pixel_values`` (ground truth scaled to [-1, 1]) and ``input_ids``.

        Raises:
            IndexError: If ``idx`` is outside the dataset.
            FileNotFoundError: If the last candidate sample has a missing file.
            Exception: Whatever the last candidate sample raised otherwise.
        """
        num_samples = len(self)
        if idx >= num_samples:
            raise IndexError(
                f"index {idx} is out of range for dataset of size {num_samples}")

        # Iterating (rather than recursing) keeps long runs of broken samples
        # from exhausting the interpreter's recursion limit.
        for current in range(idx, num_samples):
            original_idx = self.valid_indices[current]
            img_filename = self.image_filenames[original_idx]
            gt_image_path = os.path.join(self.gt_image_dir, img_filename)

            basename = os.path.splitext(img_filename)[0]
            gt_text_path = os.path.join(self.gt_text_dir, f"{basename}.txt")

            is_last = current + 1 >= num_samples
            try:
                return self._load_example(gt_image_path, gt_text_path)
            except FileNotFoundError as err:
                print(f"Error: {gt_image_path} or text: {gt_text_path}")
                if is_last:
                    raise FileNotFoundError(
                        f"can not load{gt_image_path}, {gt_text_path}") from err
            except Exception as e:
                print(f" {current} error (image: {img_filename}): {e}")
                if is_last:
                    raise


if __name__ == '__main__':
    # Manual smoke test: load one batch, print tensor shapes and write preview
    # images (lq_image.png, gt_image.png, comparison.png) to the working dir.
    from transformers import AutoTokenizer
    from torch.utils.data import DataLoader

    tokenizer = AutoTokenizer.from_pretrained(
        '/home/models/Group_AAA_Sr/sd2_base/models--stabilityai--stable-diffusion-2-base/snapshots/fa386bb446685d8ad8a8f06e732a66ad10be6f47/tokenizer')

    dataset = ImageTextDegradationDataset(
        gt_image_dir='/home/data/FFHQ512x512',
        gt_text_dir='/home/data/FFHQ_text2',
        tokenizer=tokenizer,
        empty_text_ratio=1,
    )
    dataloader = DataLoader(
        dataset, batch_size=4, shuffle=True, num_workers=0)

    # Only the first batch is inspected.
    first_batch = next(iter(dataloader), None)
    if first_batch is not None:
        for label, key in (
            ("LQ Images Shape:", "conditioning_pixel_values"),
            ("GT Images Shape:", "pixel_values"),
            ("Input IDs Shape:", "input_ids"),
        ):
            print(label, first_batch[key].shape)

        import matplotlib.pyplot as plt

        def _to_hwc(chw_tensor):
            """Move the channel axis last and convert to a NumPy array."""
            return chw_tensor.permute(1, 2, 0).cpu().numpy()

        # The degraded image is clipped to [0, 1]; the ground truth is mapped
        # back from [-1, 1] to [0, 1] for display.
        lq_img = np.clip(
            _to_hwc(first_batch["conditioning_pixel_values"][0]), 0, 1)
        gt_img = (_to_hwc(first_batch["pixel_values"][0]) + 1) / 2

        panels = (
            (lq_img, "low resolution input image", 'lq_image.png'),
            (gt_img, "high resolution target image", 'gt_image.png'),
        )

        # One standalone figure per image.
        for picture, caption, out_file in panels:
            plt.figure(figsize=(6, 6))
            plt.imshow(picture)
            plt.title(caption)
            plt.savefig(out_file)
            plt.close()

        text = tokenizer.decode(
            first_batch["input_ids"][0], skip_special_tokens=True)
        print("input text:", text)

        # Side-by-side comparison of both images.
        plt.figure(figsize=(12, 6))
        for slot, (picture, caption, _) in enumerate(panels, start=1):
            plt.subplot(1, 2, slot)
            plt.imshow(picture)
            plt.title(caption)
        plt.suptitle(f"input text: {text}")
        plt.savefig('comparison.png')
        plt.close()
