import torch
from torchvision import transforms
import os
from PIL import Image
import torch.utils.data as data
import numpy as np
import torch.nn.functional as F
from transformers import AutoTokenizer, PretrainedConfig,CLIPImageProcessor, CLIPVisionModelWithProjection,CLIPTextModelWithProjection, CLIPTextModel, CLIPTokenizer
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Literal


class VitonHDTestDataset(data.Dataset):
    """Dataset that loads person/garment samples from a VITON-HD style folder.

    A pair list file in ``dataroot_path`` names the person image and the
    garment image of every sample.  Per-sample files are then read from
    ``<dataroot_path>/<phase>/<subdir>/`` where ``<subdir>`` is one of
    ``cloth``, ``image``, ``image-parse-v3``, ``image-parse-agnostic-v3.2``,
    ``image-densepose`` and ``cloth-mask``.
    """

    def __init__(
        self,
        dataroot_path: str,
        phase: Literal["train", "test"],
        order: Literal["paired", "unpaired"] = "paired",
        size: Tuple[int, int] = (512, 384),
    ):
        """Read the pair list and set up the image transforms.

        Args:
            dataroot_path: Root directory holding the pair list file and the
                ``<phase>`` sub-directory.
            phase: Name of the sub-directory to read samples from.  For
                ``"train"`` the list file is ``train_pairs.txt``; otherwise it
                is chosen by ``order``.
            order: For non-train phases, ``"paired"`` reads ``data_list.txt``
                and pairs every person image with the garment of the same
                file name; any other value reads ``unpaired_idm.txt`` and uses
                the garment named in the second column.
            size: Target ``(height, width)`` used when resizing images.

        Raises:
            ValueError: If a non-blank line of the list file does not contain
                exactly two whitespace-separated fields.
        """
        super().__init__()
        self.dataroot = dataroot_path
        self.phase = phase
        self.order = order
        self.height = size[0]
        self.width = size[1]
        self.size = size
        # ToTensor followed by Normalize(mean=0.5, std=0.5).
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        self.toTensor = transforms.ToTensor()

        if phase == "train":
            filename = os.path.join(dataroot_path, f"{phase}_pairs.txt")
        elif order == "paired":
            filename = os.path.join(dataroot_path, "data_list.txt")
        else:
            filename = os.path.join(dataroot_path, "unpaired_idm.txt")

        # Training and paired evaluation ignore the second column and reuse
        # the person image's file name for the garment.
        same_cloth = phase == "train" or order == "paired"
        self.im_names, self.c_names = self._read_pairs(filename, same_cloth)
        self.dataroot_names = [dataroot_path] * len(self.im_names)
        self.clip_processor = CLIPImageProcessor()

    @staticmethod
    def _read_pairs(filename: str, same_cloth: bool) -> Tuple[List[str], List[str]]:
        """Parse a two-column pair list into ``(im_names, c_names)``.

        Blank lines are skipped.  When ``same_cloth`` is true the second
        column is ignored and the first column is used for both lists.

        Raises:
            ValueError: If a non-blank line does not have exactly two fields.
        """
        im_names: List[str] = []
        c_names: List[str] = []
        with open(filename, "r") as f:
            for line_no, line in enumerate(f, start=1):
                fields = line.split()
                if not fields:
                    # Tolerate empty / whitespace-only lines (e.g. a trailing
                    # blank line) instead of failing on tuple unpacking.
                    continue
                if len(fields) != 2:
                    raise ValueError(
                        f"{filename}:{line_no}: expected 2 whitespace-separated "
                        f"fields, got {len(fields)}"
                    )
                im_name, c_name = fields
                im_names.append(im_name)
                c_names.append(im_name if same_cloth else c_name)
        return im_names, c_names

    def _sample_path(self, subdir: str, name: str) -> str:
        """Return ``<dataroot>/<phase>/<subdir>/<name>``."""
        return os.path.join(self.dataroot, self.phase, subdir, name)

    def return_coords_mask(self, tens, mask=False):
        """Return the bounds of the region of ``tens`` whose values exceed 0.5.

        Args:
            tens: Array-like with a ``shape`` attribute.  Must be 2-D
                ``(rows, cols)`` when ``mask`` is true and 3-D otherwise (the
                first axis is ignored).
            mask: Selects the 2-D layout when true, the 3-D layout otherwise.

        Returns:
            ``(y_min, y_max, x_min, x_max)``.  When at least one element
            exceeds 0.5 these are the smallest and largest row/column indices
            of such elements (both inclusive).  When none does, the result is
            ``(0, rows, 0, cols)``, i.e. the full extent of ``tens``.
        """
        coords = np.where(tens > 0.5)
        if mask:
            y, x = coords
            rows, cols = tens.shape[0], tens.shape[1]
        else:
            _, y, x = coords
            rows, cols = tens.shape[1], tens.shape[2]

        if y.size == 0 or x.size == 0:
            # Nothing above the threshold: fall back to the full extent.
            return 0, rows, 0, cols
        return np.min(y), np.max(y), np.min(x), np.max(x)

    def __getitem__(self, index):
        """Load and preprocess the sample at ``index``.

        Returns:
            A dict with the keys:

            * ``"c_name"`` / ``"im_name"``: garment and person file names.
            * ``"image"``: person image resized to ``(width, height)`` and
              passed through ``self.transform``.
            * ``"cloth_pure"``: garment image passed through
              ``self.transform`` (not resized here).
            * ``"cloth"``: ``pixel_values`` produced by the CLIP image
              processor for the garment image.
            * ``"inpaint_mask"``: ``1 - mask`` where ``mask`` is 1.0 wherever
              the first channel of the agnostic parse is non-zero.
            * ``"im_mask"``: ``image * mask``.
            * ``"caption_cloth"`` / ``"caption"``: fixed text prompts.
            * ``"pose_img"``: densepose image, resized and transformed.
            * ``"pcoords"``: ``((y_min, y_max), (x_min, x_max))`` of the parse
              pixels labelled 5, 6 or 7.
            * ``"gcoords"``: same layout, computed from the cloth mask.
        """
        c_name = self.c_names[index]
        im_name = self.im_names[index]
        parse_name = im_name.replace('.jpg', '.png')
        target_size = (self.width, self.height)  # PIL expects (width, height)

        cloth_annotation = "shirt"
        cloth = Image.open(self._sample_path("cloth", c_name))

        im_pil_big = Image.open(self._sample_path("image", im_name)).resize(target_size)
        image = self.transform(im_pil_big)

        im_parse = Image.open(self._sample_path('image-parse-v3', parse_name))
        im_parse = im_parse.resize(target_size, Image.NEAREST)
        parse_array = np.array(im_parse)
        # 1.0 where the parse label is 5, 6 or 7 (presumably the garment
        # region -- confirm against the parser's label map).
        parse_cloth = np.isin(parse_array, (5, 6, 7)).astype(np.float32)
        py_min, py_max, px_min, px_max = self.return_coords_mask(parse_cloth, mask=True)

        mask = Image.open(
            self._sample_path("image-parse-agnostic-v3.2", parse_name)
        ).resize(target_size)
        mask = self.toTensor(mask)
        # Binarise using only the first channel.
        mask = (mask[:1] > 0) * 1.0
        im_mask = image * mask

        pose_img = Image.open(
            self._sample_path("image-densepose", im_name)
        ).resize(target_size, Image.NEAREST)
        pose_img = self.transform(pose_img)  # [-1,1]

        cm = Image.open(self._sample_path('cloth-mask', c_name))
        cm = cm.resize(target_size)
        # Shift the normalised mask up by 1 before the > 0.5 threshold
        # applied inside return_coords_mask.
        cm = self.transform(cm) + 1
        gy_min, gy_max, gx_min, gx_max = self.return_coords_mask(cm)

        return {
            "c_name": c_name,
            "im_name": im_name,
            "image": image,
            "cloth_pure": self.transform(cloth),
            "cloth": self.clip_processor(images=cloth, return_tensors="pt").pixel_values,
            "inpaint_mask": 1 - mask,
            "im_mask": im_mask,
            "caption_cloth": "a photo of " + cloth_annotation,
            "caption": "model is wearing a " + cloth_annotation,
            "pose_img": pose_img,
            "pcoords": ((py_min, py_max), (px_min, px_max)),
            "gcoords": ((gy_min, gy_max), (gx_min, gx_max)),
        }

    def __len__(self):
        """Return the number of samples listed in the pair file."""
        return len(self.im_names)
