import io
import torch
import os
import pandas as pd
from PIL import Image
import numpy as np

def expand2square(pil_img, background_color):
    """Pad an image onto a square canvas, centred along its shorter axis.

    Args:
        pil_img: Image exposing ``size`` as ``(width, height)`` and ``mode``.
        background_color: Fill colour for the newly created canvas.

    Returns:
        ``pil_img`` itself when it is already square; otherwise a new image
        whose side equals the longer edge of ``pil_img``, with ``pil_img``
        pasted in the middle.
    """
    w, h = pil_img.size
    if w == h:
        # Already square: the same object is handed back, not a copy.
        return pil_img

    side = max(w, h)
    # Only the shorter axis receives a non-zero offset.
    if w > h:
        top_left = (0, (w - h) // 2)
    else:
        top_left = ((h - w) // 2, 0)

    canvas = Image.new(pil_img.mode, (side, side), background_color)
    canvas.paste(pil_img, top_left)
    return canvas

class _ReportImageDataset(torch.utils.data.Dataset):
    """Shared implementation of the parquet-backed image/report datasets.

    Every shard is read with ``pandas.read_parquet`` and must expose an
    ``image`` column (each cell is indexed with ``['bytes']`` to obtain the
    encoded picture) and a ``reports`` column (the report text).

    Args:
        data_dir: Directory the shard names are joined onto.
        txt_tokenizer: Callable tokenizer. It is invoked with
            ``return_tensors``/``padding``/``max_length``/``truncation`` and
            its result must offer ``input_ids`` and ``attention_mask``
            (presumably a Hugging Face tokenizer -- confirm with the caller).
        max_length: Token budget used for truncation and for padding the
            full report.
        data_files: Optional shard names (or a single name) overriding
            ``DEFAULT_FILES``. Each is joined onto ``data_dir``.
        image_size: Side length, in pixels, of the square image tensor.
    """

    # Shard names loaded when ``data_files`` is not supplied; set by subclasses.
    DEFAULT_FILES = ()
    # Fixed sentence prepended to every report before tokenization.
    REPORT_PREFIX = 'The image is a radiograph of the chest, showing the thoracic cavity structures, '
    # RGB fill for the bars added when an image is padded to a square.
    PAD_COLOR = (122, 116, 104)

    def __init__(self, data_dir, txt_tokenizer, max_length=256, *,
                 data_files=None, image_size=512):
        super().__init__()
        self.data_dir = data_dir

        names = self.DEFAULT_FILES if data_files is None else data_files
        if isinstance(names, (str, os.PathLike)):
            # A lone name would otherwise be iterated character by character.
            names = [names]
        self.data_files = [os.path.join(data_dir, name) for name in names]

        self.txt_tokenizer = txt_tokenizer
        self.max_length = max_length
        self.image_size = image_size

        self.data = self._load_data()

    def _load_data(self):
        """Read every shard and return a list of ``[image, report]`` rows."""
        rows = []
        for path in self.data_files:
            frame = pd.read_parquet(path)
            rows.extend(frame[['image', 'reports']].values.tolist())
        return rows

    def __len__(self):
        """Return the number of loaded rows."""
        return len(self.data)

    def _vqgan_input_from(self, img: "Image.Image", target_image_size=512) -> torch.Tensor:
        """Resize, centre-crop and rescale an image into a float tensor.

        Args:
            img: Image to convert. The channel axis is expected last once it
                is turned into an array (``__getitem__`` passes RGB images).
            target_image_size: Side length of the square output.

        Returns:
            Float tensor in channels-first layout with values in ``[-1, 1]``.
            No batch dimension is added.
        """
        # Resize so the shorter side equals the target, keeping aspect ratio.
        s = min(img.size)
        scale = target_image_size / s
        new_size = (round(scale * img.size[0]), round(scale * img.size[1]))
        img = img.resize(new_size, Image.LANCZOS)

        # Centre crop to a target x target square.
        x0 = (img.width - target_image_size) // 2
        y0 = (img.height - target_image_size) // 2
        img = img.crop((x0, y0, x0 + target_image_size, y0 + target_image_size))

        # Map pixel values 0..255 -> [0, 1] -> [-1, 1].
        np_img = np.array(img) / 255.0
        np_img = np_img * 2 - 1

        # Move channels first and cast to float32.
        return torch.from_numpy(np_img).permute(2, 0, 1).float()

    def __getitem__(self, idx):
        """Build one training sample.

        Returns:
            Dict with keys ``image`` (tensor from ``_vqgan_input_from``),
            ``text`` (token ids of the prefixed report, padded to
            ``max_length``), ``size`` (original ``(width, height)``),
            ``mask`` (attention mask matching ``text``) and ``prompt_idx``
            (token count of the prompt part).
        """
        image_info, report = self.data[idx]

        # Decode inside a context manager so the source image is closed as
        # soon as the RGB copy exists.
        with Image.open(io.BytesIO(image_info['bytes'])) as raw:
            image = raw.convert('RGB')
        width, height = image.size
        size = (width, height)

        # The padded image is square, so the helper's resize lands exactly on
        # image_size x image_size and its centre crop is a no-op. A separate
        # resize before calling it would only duplicate that work.
        pad_image = expand2square(image, self.PAD_COLOR)
        image_tensor = self._vqgan_input_from(pad_image, self.image_size)

        report = self.REPORT_PREFIX + report

        # NOTE(review): the first comma always falls inside REPORT_PREFIX, so
        # the prompt (and therefore prompt_idx) is identical for every sample.
        # Confirm this is the intended split point.
        prompt = report.partition(',')[0].strip()

        # Side effect: this changes the padding side of the shared tokenizer
        # object for every other user of it.
        self.txt_tokenizer.padding_side = "right"
        prompt_out = self.txt_tokenizer(
            [prompt],
            return_tensors="pt",
            padding="longest",
            max_length=self.max_length,
            truncation=True,
        )
        prompt_idx = prompt_out['input_ids'][0].shape[-1]

        txt_out = self.txt_tokenizer(
            [report],
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
        )

        return {
            'image': image_tensor,
            'text': txt_out.input_ids[0],
            'size': size,
            'mask': txt_out.attention_mask[0],
            'prompt_idx': prompt_idx,
        }


class TrainDataset(_ReportImageDataset):
    """Training split: by default loads the two ``train-*`` parquet shards."""

    DEFAULT_FILES = (
        'train-00000-of-00002.parquet',
        'train-00001-of-00002.parquet',
    )


class ValDataset(_ReportImageDataset):
    """Validation split: by default loads the single ``validation-*`` shard."""

    DEFAULT_FILES = (
        'validation-00000-of-00001.parquet',
    )

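# NOTE: disabled smoke test kept for reference. `MyDataset` below is not defined
# in this file; use TrainDataset or ValDataset (with a matching data directory)
# if this snippet is revived.
# from evaluation.chameleon import ImageTokenizer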
# from transformers import AutoTokenizer
# from torch.utils.data import DataLoader
#
# tokenizer = AutoTokenizer.from_pretrained(r'D:\CKPTS\liquid-7b',padding_side='left')
# print(tokenizer.model_max_length)
# ds = MyDataset(r'D:\CKPTS\liquid-7b\cc12m-train-0000', tokenizer)
# dl = DataLoader(ds, batch_size=2)
# vqgan_cfg_path = r"D:\CKPTS\vqgan\vqgan.yaml"
# vqgan_ckpt_path = r"D:\CKPTS\vqgan\vqgan.ckpt"
# image_tokenizer = ImageTokenizer(cfg_path=vqgan_cfg_path, ckpt_path=vqgan_ckpt_path, device="cuda:0",)
#
# for data in dl:
#     img = data['image'].to('cuda')
#     print(img.shape)
#     _, _, [_, _, vq_code] = image_tokenizer._vq_model.encode(img)
#     vq_code = vq_code.view(img.shape[0], -1)
#     vq_code = vq_code + len(tokenizer)
#
#     print(vq_code.shape)
