import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import numpy as np
from tqdm import tqdm
from einops import rearrange, repeat
from omegaconf import OmegaConf
from PIL import Image
from diffusers.models import attention_processor
from torchvision.transforms.functional import to_pil_image
from diffusers import DDIMScheduler, DiffusionPipeline
from masactrl.diffuser_utils import MasaCtrlPipeline
from masactrl.masactrl_utils import AttentionBase
from masactrl.masactrl_utils import regiter_attention_editor_diffusers
from masactrl.masactrl import MutualSelfAttentionControl
from diffusiontrend_utils import My_StableDiffusionXLPipeline
from torchvision.utils import save_image
from torchvision.io import read_image
from pytorch_lightning import seed_everything
import subprocess
import cv2

# Allocator setting for the CUDA caching allocator (see the PyTorch CUDA
# memory-management notes for `max_split_size_mb`).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
# Only select GPU 0 when CUDA is actually usable: the line below explicitly
# falls back to CPU, so importing this module must not fail on CPU-only hosts.
if torch.cuda.is_available():
    torch.cuda.set_device(0)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model_path = "/xxxxxxxxx"

# DDIM scheduler shared by the inversion helpers below (`next_step`, `invert`).
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
# Module-level pipeline instance used by `image2latent` and `invert`.
model = My_StableDiffusionXLPipeline.from_pretrained(model_path, scheduler=scheduler).to(device)
model.enable_vae_tiling()



def change_tensor_to_np(dict_data):
    """Replace every tensor value of `dict_data` with its NumPy equivalent.

    The mapping is modified in place (each value is moved to CPU and converted
    with `.numpy()`) and the same mapping object is returned.
    """
    for name, tensor in dict_data.items():
        # Re-binding an existing key does not change the dict's size, so it is
        # safe to do while iterating.
        dict_data[name] = tensor.cpu().numpy()
    return dict_data

@torch.no_grad()
def self_tokenize(self, prompt):
    """Tokenize `prompt` and map each token position to its token string.

    Args:
        self: Object exposing `tokenizer` (and, as a fallback when `tokenizer`
            is None, `tokenizer_2`).
        prompt: Text passed to the tokenizer.

    Returns:
        dict[int, str]: position -> token for the first encoded sequence,
        truncated right after the first token whose text contains
        "endoftext" (that token is included).
    """
    # Fall back to the secondary tokenizer only when the primary one is absent;
    # `tokenizer_2` is not touched otherwise.
    tokenizer = self.tokenizer if self.tokenizer is not None else self.tokenizer_2
    input_ids = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

    text_dict = {}
    for position, token in enumerate(tokens):
        text_dict[position] = token
        # Everything after the end-of-text marker is padding; stop there.
        if "endoftext" in token:
            break
    return text_dict

def get_attn(emb, res, layer):
    """Build a forward hook that records an attention module's softmax map.

    Args:
        emb: Tensor fed to `to_k` when `layer` names a cross-attention module
            (its name contains "attn2"); ignored otherwise.
        res: Dict that receives the attention probabilities under key `layer`
            every time the hook fires.
        layer: Name of the hooked module; also the key written into `res`.

    Returns:
        A callable with the `(module, inputs, output)` forward-hook signature.
        It returns None, so the module's output is left untouched.
    """
    is_cross_attention = "attn2" in layer

    # The no-grad guard has to wrap the hook itself: decorating the factory
    # only covers the moment the hook is created, not the moment it runs.
    @torch.no_grad()
    def hook(self, sd_in, sd_out):
        hidden_states = sd_in[0]
        # Cross-attention keys come from the supplied embedding, self-attention
        # keys from the module's own input.
        key = self.to_k(emb if is_cross_attention else hidden_states)
        query = self.to_q(hidden_states)
        heads = self.heads
        # Fold the head dimension into the batch dimension.
        query, key = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=heads), (query, key))
        attn = torch.einsum("b i d, b j d -> b i j", query, key)
        attn = attn * self.scale
        attn = attn.softmax(dim=-1)
        res[layer] = attn

    return hook

def to_norm(x):
    """Min-max normalise `x` to the range [0, 1].

    Args:
        x: NumPy array (or array-like) of numbers.

    Returns:
        Array of the same shape with the minimum mapped to 0 and the maximum
        mapped to 1. When every element is equal (for example an all-black
        mask) there is no range to scale by, so an all-zero float64 array is
        returned instead of dividing by zero and producing NaN.
    """
    lo = np.min(x)
    span = np.max(x) - lo
    if span == 0:
        return np.zeros(np.shape(x), dtype=np.float64)
    return (x - lo) / span


def viusalize_latents(latents: torch.FloatTensor, result_path: str, t: int):
    """Save the channel-wise mean of `latents` as a PNG for inspection.

    Args:
        latents: Latent tensor; it is averaged over dim 1 and squeezed, so the
            result must reduce to a 2-D map (presumably a single-sample batch
            -- TODO confirm against callers).
        result_path: Directory that receives the image; created if missing.
        t: Step identifier used in the file name `mean_latents_{t}.png`.
    """
    channel_mean = latents.mean(dim=1).squeeze().to(torch.float32).cpu().numpy()

    lo = channel_mean.min()
    value_range = channel_mean.max() - lo
    if value_range > 0:
        scaled = (channel_mean - lo) / value_range * 255
    else:
        # A constant map has nothing to stretch; emit black rather than NaN.
        scaled = np.zeros_like(channel_mean)

    image = Image.fromarray(scaled).convert("RGB")
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(result_path, exist_ok=True)
    image.save(f'{result_path}/mean_latents_{t}.png')

def load_image(image_path):
    """Load an image file as a batched float tensor resized to 1024x768.

    The file is read with `read_image`, at most the first three channels are
    kept, values are mapped with `v / 127.5 - 1` (i.e. [-1, 1] for 8-bit
    input), a batch dimension is added and the result is interpolated to
    (height, width) = (1024, 768).

    Returns:
        Tensor on CUDA when available, otherwise on CPU.
    """
    pixels = read_image(image_path)[:3]
    batch = pixels.unsqueeze(0).float() / 127.5 - 1.  # [-1, 1]
    batch = F.interpolate(batch, (1024, 768))
    target = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    return batch.to(target)

@torch.no_grad()
def image2latent(image):
    """Encode an image into VAE latents using the module-level `model`.

    Args:
        image: Either a tensor that is passed to `model.vae.encode` as-is, or
            a `PIL.Image.Image`, which is converted to an RGB 1xCxHxW float
            tensor with values `v / 127.5 - 1` first.

    Returns:
        The mean of the encoder's latent distribution multiplied by 0.18215.
    """
    target_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # `Image` is the PIL *module*; the class to test against is `Image.Image`.
    if isinstance(image, Image.Image):
        pixels = np.array(image.convert("RGB"))
        image = torch.from_numpy(pixels).float() / 127.5 - 1  # to tensor, normalised
        image = image.permute(2, 0, 1).unsqueeze(0).to(target_device)  # b,c,h,w

    # Release cached blocks before the memory-heavy VAE encode.
    torch.cuda.empty_cache()

    latents = model.vae.encode(image)['latent_dist'].mean
    # NOTE(review): 0.18215 is a fixed constant here rather than being read
    # from the VAE config -- confirm it is the intended scale for this model.
    return latents * 0.18215

def next_step(
        model_output: torch.FloatTensor,
        timestep: int,
        x: torch.FloatTensor,
        eta=0.,
        verbose=False
    ):
    """
    One DDIM inversion update using the module-level `scheduler`.

    `x` is treated as the sample at the step one inference stride *before*
    `timestep`; the returned sample is the one at `timestep`.

    Args:
        model_output: Noise prediction for `x`.
        timestep: Target timestep of this update.
        x: Current sample.
        eta: Accepted for signature compatibility; not used.
        verbose: Print the incoming timestep when True.

    Returns:
        (x_next, pred_x0): the sample at `timestep` and the predicted clean
        sample.
    """
    if verbose:
        print("timestep: ", timestep)

    stride = scheduler.config.num_train_timesteps // scheduler.num_inference_steps
    target_t = timestep
    source_t = min(timestep - stride, 999)

    # Before the first training timestep there is no table entry; use the
    # scheduler's final cumulative alpha instead.
    if source_t >= 0:
        alpha_source = scheduler.alphas_cumprod[source_t]
    else:
        alpha_source = scheduler.final_alpha_cumprod
    alpha_target = scheduler.alphas_cumprod[target_t]

    beta_source = 1 - alpha_source
    pred_x0 = (x - beta_source**0.5 * model_output) / alpha_source**0.5
    direction = (1 - alpha_target)**0.5 * model_output
    x_next = alpha_target**0.5 * pred_x0 + direction
    return x_next, pred_x0

@torch.no_grad()
def invert(image: torch.Tensor,
           prompt,
           num_inference_steps=50,
           guidance_scale=1,
           eta=0.0,
           selected_mode_prompt=None,
           return_intermediates=False,
           **kwds):
    """
    Invert a real image into a noise map with deterministic DDIM inversion.

    Uses the module-level `model`, `scheduler` and `device`.

    Args:
        image: Image tensor (batch first) handed to `image2latent`.
        prompt: A string or a list of strings. A string is replicated to
            `image.shape[0] + 1` entries and `selected_mode_prompt` is
            appended to the last one.
        num_inference_steps: Number of scheduler steps.
        guidance_scale: Values above 1.0 make `encode_prompt` also return
            negative embeddings, which are concatenated in front.
            NOTE(review): the latents are not expanded and no guidance
            combination is applied in that case -- confirm the path is unused.
        eta: Accepted for signature compatibility; not used.
        selected_mode_prompt: Optional text appended to the last prompt entry.
        return_intermediates: Selects the second return value (see below).
        **kwds: Ignored.

    Returns:
        (latents, latents_list, cross_attn_map_res) when
        `return_intermediates` is true, otherwise
        (latents, start_latents, cross_attn_map_res). `latents_list` holds the
        starting latents followed by the latents after every step.
        `cross_attn_map_res` is an empty dict because this function registers
        no attention hooks.
    """
    # One extra batch entry is reserved for the prompt that carries
    # `selected_mode_prompt`.
    batch_size = image.shape[0] + 1

    if isinstance(prompt, list):
        if batch_size == 1:
            image = image.expand(len(prompt), -1, -1, -1)
    elif isinstance(prompt, str):
        if batch_size > 1:
            prompt = [prompt] * batch_size
            # Appending None to a str would raise; skip when nothing was given.
            if selected_mode_prompt is not None:
                prompt[-1] += selected_mode_prompt

    print('---Inversion Start---')

    # define initial latents
    # `Tensor.to` is not in-place, so the result has to be re-bound.
    image = image.to(device)
    latents = image2latent(image)
    latents = torch.cat([latents] * 2)
    start_latents = latents

    # iterative sampling
    scheduler.set_timesteps(num_inference_steps)
    print("Valid timesteps: ", reversed(scheduler.timesteps))
    latents_list = [latents]

    do_classifier_free_guidance = guidance_scale > 1.0
    (prompt_embeds, negative_prompt_embeds,
     pooled_prompt_embeds, negative_pooled_prompt_embeds) = model.encode_prompt(
        prompt=prompt,
        do_classifier_free_guidance=do_classifier_free_guidance,
        negative_prompt='',
    )

    add_text_embeds = pooled_prompt_embeds
    if model.text_encoder_2 is None:
        text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
    else:
        text_encoder_projection_dim = model.text_encoder_2.config.projection_dim

    # Size conditioning follows the spatial size of the input image instead of
    # a fixed (1024, 768), so other resolutions are described correctly.
    image_size = tuple(image.shape[-2:])
    crops_coords_top_left = (0, 0)
    add_time_ids = model._get_add_time_ids(
        image_size,
        crops_coords_top_left,
        image_size,
        dtype=prompt_embeds.dtype,
        text_encoder_projection_dim=text_encoder_projection_dim,
    )
    negative_add_time_ids = add_time_ids

    if do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0).to(device)
        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0).to(device)
        add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0).to(device).repeat(batch_size * 1, 1)
    else:
        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(batch_size * 1, 1)

    added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

    # Containers for cross-attention capture. No hooks are registered here, so
    # both stay empty; they must exist for the clean-up loop and for the
    # three-value return that callers unpack.
    cross_attn_map_res = {}
    cross_hooks = []

    try:
        # Walk the schedule from the least noisy to the most noisy timestep.
        for t in tqdm(reversed(scheduler.timesteps), desc="DDIM Inversion"):
            # predict the noise
            noise_pred = model.unet(
                latents,
                t,
                encoder_hidden_states=prompt_embeds,
                added_cond_kwargs=added_cond_kwargs,
            ).sample

            # x_{t-1} -> x_t
            latents, _ = next_step(noise_pred, t, latents)
            latents_list.append(latents)
    finally:
        # Always detach hooks, even if the UNet call fails midway.
        for handle in cross_hooks:
            handle.remove()

    if return_intermediates:
        # return the intermediate latents collected during inversion
        return latents, latents_list, cross_attn_map_res
    return latents, start_latents, cross_attn_map_res

def stretch_bounding_box(top_left, bottom_right, width_scale=1, height_scale=1):
    """Scale a bounding box horizontally about its centre and vertically from its top edge.

    Args:
        top_left: (xmin, ymin) of the box.
        bottom_right: (xmax, ymax) of the box.
        width_scale: Factor applied to the width; the horizontal centre is kept.
        height_scale: Factor applied to the height; `ymin` is kept as-is.

    Returns:
        ((new_xmin, ymin), (new_xmax, new_ymax)) with the three recomputed
        coordinates rounded via `round`.
    """
    left, top = top_left
    right, bottom = bottom_right

    mid_x = (left + right) / 2
    half_width = (right - left) * width_scale / 2
    stretched_bottom = top + (bottom - top) * height_scale

    new_top_left = (round(mid_x - half_width), top)
    new_bottom_right = (round(mid_x + half_width), round(stretched_bottom))
    return new_top_left, new_bottom_right

def calc_hw_scale(p_mask, g_mask, p_top_left_index, p_bottom_right_index, g_top_left_index, g_bottom_right_index, option):
    """Estimate width/height scale factors between a person mask and a garment mask.

    One reference row is sampled from each mask at a fixed fraction of its
    bounding-box height (0.85 for 'upper', 0.025 for 'lower', 0.28 for
    'dress'); the number of pixels equal to 1 on that row, inside the box's
    x-range, is the measured width.

    Args:
        p_mask, g_mask: 2-D mask tensors indexed as [row, column].
        p_top_left_index, p_bottom_right_index: (x, y) corners of the person box.
        g_top_left_index, g_bottom_right_index: (x, y) corners of the garment box.
        option: 'upper', 'lower' or 'dress'.

    Returns:
        (width_scale, height_scale). `width_scale` is garment width / person
        width. `height_scale` is 1 for 'upper'; otherwise it is the ratio of
        the two height-to-width aspect ratios (garment over person).

    Raises:
        ValueError: for an unknown `option`, or when a reference row needed as
            a divisor contains no mask pixels.
    """
    row_fraction = {'upper': 0.85, 'lower': 0.025, 'dress': 0.28}
    if option not in row_fraction:
        print('###Option Error!###')
        raise ValueError(f"Unknown option {option!r}; expected one of {sorted(row_fraction)}")
    fraction = row_fraction[option]

    p_xmin, p_ymin = p_top_left_index
    p_xmax, p_ymax = p_bottom_right_index
    g_xmin, g_ymin = g_top_left_index
    g_xmax, g_ymax = g_bottom_right_index

    p_height = p_ymax - p_ymin
    g_height = g_ymax - g_ymin

    p_row = p_mask[p_ymin + int(p_height * fraction), p_xmin:p_xmax]
    g_row = g_mask[g_ymin + int(g_height * fraction), g_xmin:g_xmax]

    p_length = torch.sum(p_row == 1).item()
    g_length = torch.sum(g_row == 1).item()

    # The person width always divides; the garment width divides only when an
    # aspect ratio is computed.
    needs_aspect_ratio = option != 'upper'
    if p_length == 0 or (needs_aspect_ratio and g_length == 0):
        raise ValueError(
            f"Reference mask row is empty (p_len:{p_length}, g_len:{g_length}); "
            "cannot derive scale factors automatically"
        )

    width_scale = g_length / p_length
    if needs_aspect_ratio:
        g_aspect = g_height / g_length
        p_aspect = p_height / p_length
        height_scale = g_aspect / p_aspect  # scale proportionally
    else:
        height_scale = 1

    print(f'width_scale and height_scale has been calc automatically. p_len:{p_length}  g_len:{g_length}')
    return width_scale, height_scale

def real_image_editing(model, person_image, origin_garment_image, garment_image, p_mask, g_mask, option, width_scale, height_scale, rotate_degree):
    """Run the try-on edit for one person/garment pair and save the results.

    Steps: fit the masked garment into the (optionally stretched) person-mask
    bounding box, DDIM-invert both the person image and the placed garment,
    run `model` on the two start codes, and write images plus a settings file
    to a new `./tryon_result/sample_<n>` directory.

    Args:
        model: Pipeline called to synthesise the result.
        person_image: Person image tensor, or an HxWxC NumPy array (converted
            and resized to 1024x768).
        origin_garment_image: Unwarped garment image; only saved for reference.
        garment_image: Garment image (possibly warped), tensor or NumPy array.
        p_mask: 2-D person-region mask tensor.
        g_mask: 2-D garment mask tensor matching `garment_image`.
        option: 'upper', 'lower' or 'dress'; selects the category prompt and
            is forwarded to `calc_hw_scale` when scales are derived
            automatically.
        width_scale, height_scale: Stretch factors for the person box. When
            both are None they are computed by `calc_hw_scale`; a single None
            is treated as 1.
        rotate_degree: Not used by this function.

    Raises:
        ValueError: if the stretched person box is empty or does not fit
            inside the garment image.
    """
    seed = 42
    seed_everything(seed)

    out_dir_ori = "./tryon_result"
    os.makedirs(out_dir_ori, exist_ok=True)

    prompts = [
        "A clothes",
        "A model wearing clothes",
    ]

    # Accept HxWxC arrays as well as ready-made tensors.
    if isinstance(person_image, np.ndarray):
        person_image = torch.from_numpy(person_image).to(device) / 127.5 - 1.
        person_image = person_image.unsqueeze(0).permute(0, 3, 1, 2)
        person_image = F.interpolate(person_image, (1024, 768))

    if isinstance(garment_image, np.ndarray):
        garment_image = torch.from_numpy(garment_image).to(device) / 127.5 - 1.
        garment_image = garment_image.unsqueeze(0).permute(0, 3, 1, 2)
        garment_image = F.interpolate(garment_image, (1024, 768))

    selected_mode_prompt = [
        'clothes t-shirt shirt sweater jacket polo pullover',  # upper
        'clothes trousers skirt shorts jeans leggings sweatpants capris slacks',  # lower
        'clothes dress suit uniform',  # suit
    ]
    # Pick the category prompt from `option` instead of a fixed index; unknown
    # options keep the previous fixed choice (index 2).
    selected_mode = {'upper': 0, 'lower': 1, 'dress': 2}.get(option, 2)  # [0:upper, 1:lower, 2:suit]
    # +2 accounts for the start-of-text and end-of-text tokens.
    prompt_len = len(selected_mode_prompt[selected_mode].split(" ")) + 2
    if selected_mode == 0 and 't-shirt' in selected_mode_prompt[selected_mode]:
        prompt_len += 2  # t-shirt
    print(f'attention calc prompt_len:{prompt_len}')

    p_top_left_index, p_bottom_right_index = find_mask_bounding_box(p_mask)
    g_top_left_index, g_bottom_right_index = find_mask_bounding_box(g_mask)  # 0-x, 1-y

    if width_scale is None and height_scale is None:  # derive both automatically
        width_scale, height_scale = calc_hw_scale(p_mask, g_mask, p_top_left_index, p_bottom_right_index, g_top_left_index, g_bottom_right_index, option)
    else:
        if width_scale is None:
            width_scale = 1
        if height_scale is None:
            height_scale = 1
        print(f'Scale have been specified. None changed to 1.')
    print(f'width_scale:{width_scale}  height_scale:{height_scale}')

    p_top_left_index_expand, p_bottom_right_index_expand = stretch_bounding_box(p_top_left_index, p_bottom_right_index, width_scale=width_scale, height_scale=height_scale)

    box_x0, box_y0 = p_top_left_index_expand
    box_x1, box_y1 = p_bottom_right_index_expand
    g_x0, g_y0 = g_top_left_index
    g_x1, g_y1 = g_bottom_right_index

    # A box that leaves the canvas would otherwise surface as an opaque
    # shape-mismatch error in the slice assignments below.
    canvas_h, canvas_w = garment_image.shape[-2:]
    if box_x0 < 0 or box_y0 < 0 or box_x1 > canvas_w or box_y1 > canvas_h or box_x1 <= box_x0 or box_y1 <= box_y0:
        raise ValueError(
            f"Stretched person box {p_top_left_index_expand}-{p_bottom_right_index_expand} "
            f"does not fit inside the {canvas_w}x{canvas_h} (WxH) image; reduce width_scale/height_scale"
        )
    p_box_height = int(box_y1 - box_y0)
    p_box_width = int(box_x1 - box_x0)

    # Cut the masked garment out of its own box and paste it, resized, into the
    # person box on an otherwise empty canvas.
    g_only = torch.zeros_like(garment_image, dtype=torch.float)
    g_content = garment_image * g_mask + g_only * (1 - g_mask)
    g_content_box = g_content[:, :, g_y0:g_y1, g_x0:g_x1]
    g_content_resized = F.interpolate(g_content_box, size=(p_box_height, p_box_width), mode='bilinear', align_corners=False)
    g_only[:, :, box_y0:box_y1, box_x0:box_x1] = g_content_resized

    # Same placement for the garment mask.
    g_mask_box = g_mask[g_y0:g_y1, g_x0:g_x1].unsqueeze(0).unsqueeze(0)
    g_mask_resized = F.interpolate(g_mask_box, size=(p_box_height, p_box_width), mode='bilinear', align_corners=False).squeeze(0).squeeze(0)
    g_mask_only = torch.zeros_like(g_mask, dtype=torch.float)
    g_mask_only[box_y0:box_y1, box_x0:box_x1] = g_mask_resized

    # Inversion: batch entry 0 is the actual sample, entry 1 only carries the
    # category prompt, so everything is cut down to the first entry.
    empty_prompt = ''
    start_code_p, latents_list_p, _ = invert(person_image,
                                             empty_prompt,
                                             selected_mode_prompt=selected_mode_prompt[selected_mode],
                                             return_intermediates=True)
    latents_list_p = [latent[:1] for latent in latents_list_p]
    torch.cuda.empty_cache()

    start_code_g, latents_list_g, _ = invert(g_only,
                                             empty_prompt,
                                             selected_mode_prompt=selected_mode_prompt[selected_mode],
                                             return_intermediates=True)
    latents_list_g = [latent[:1] for latent in latents_list_g]
    torch.cuda.empty_cache()

    # Order matches `prompts`: garment first, person second.
    start_code = torch.cat([start_code_g[:1], start_code_p[:1]], dim=0)

    print('---Denoise Start---')
    per_ratio = 0
    gar_ratio = 0.3
    random_noise_ratio = 0.7

    image_masactrl = model(prompts,
                           latents=start_code,
                           guidance_scale=7.5,
                           bg_intermediate_latents=latents_list_p,
                           latents_list_g=latents_list_g,
                           garment_mask=g_mask_only,
                           person_mask=p_mask,
                           step_replace_latent=10,
                           step_to_change=35,
                           p_top_left_index_expand=p_top_left_index_expand,
                           p_bottom_right_index_expand=p_bottom_right_index_expand,
                           label_height_factor=0.1,
                           label_width_factor=0.2,
                           per_ratio=per_ratio,
                           gar_ratio=gar_ratio,
                           random_noise_ratio=random_noise_ratio).images

    # Number the output directory by how many entries already exist.
    sample_count = len(os.listdir(out_dir_ori))
    out_dir = os.path.join(out_dir_ori, f"sample_{sample_count}")
    os.makedirs(out_dir, exist_ok=True)

    print(f'image_masactrl:{len(image_masactrl)}')

    save_image_pil(person_image, os.path.join(out_dir, f"person_image.png"))
    save_image_pil(origin_garment_image, os.path.join(out_dir, f"ori_garment_image.png"))
    save_image_pil(garment_image, os.path.join(out_dir, f"warp_garment_image.png"))
    image_masactrl[0].save(os.path.join(out_dir, f"garment_inversion.png"))  # prompt[0]
    image_masactrl[1].save(os.path.join(out_dir, f"result.png"))  # prompt[1]
    with open(os.path.join(out_dir, f"settings.txt"), "w") as f:
        for p in prompts:
            f.write(p + "\n")
        f.write(f"seed: {seed}\n")
        f.write(f"width_scale: {width_scale}\n")
        f.write(f"height_scale: {height_scale}\n")
        f.write(f"per: {per_ratio}   gar: {gar_ratio}  noise: {random_noise_ratio}\n")

    print("Try-On images are saved in", os.path.join(out_dir, f"result.png"))

def is_binary_tensor(tensor):
    """Return a boolean scalar tensor: True when every element is exactly 0.0 or 1.0."""
    is_zero = tensor == 0.0
    is_one = tensor == 1.0
    return torch.all(torch.logical_or(is_zero, is_one))

def find_mask_bounding_box(image, scale_factor=8):
    """Locate the extent of the non-zero region of a mask.

    Args:
        image: Mask tensor, either 2-D or 3-D (for 3-D input only index 0 of
            the last axis is inspected).
        scale_factor: Accepted for signature parity with
            `find_latent_mask_bounding_box`; not used here.

    Returns:
        ((xmin, ymin), (xmax, ymax)) where each value is the index of the
        first / last row or column containing a non-zero element.
    """
    mask = np.asarray(image.cpu())
    if mask.ndim == 3:
        mask = mask[..., 0]

    occupied_rows = np.nonzero(mask.any(axis=1))[0]
    occupied_cols = np.nonzero(mask.any(axis=0))[0]

    top, bottom = occupied_rows[0], occupied_rows[-1]
    left, right = occupied_cols[0], occupied_cols[-1]
    return (left, top), (right, bottom)

def find_latent_mask_bounding_box(image, scale_factor=8):
    """Locate the non-zero extent of a mask and scale it down by `scale_factor`.

    Args:
        image: Mask tensor, either 2-D or 3-D (for 3-D input only index 0 of
            the last axis is inspected).
        scale_factor: Divisor applied to every coordinate before rounding.

    Returns:
        ((xmin, ymin), (xmax, ymax)) in down-scaled coordinates, each value
        being `round(index / scale_factor)`.
    """
    mask = np.asarray(image.cpu())
    if mask.ndim == 3:
        mask = mask[..., 0]

    occupied_rows = np.nonzero(mask.any(axis=1))[0]
    occupied_cols = np.nonzero(mask.any(axis=0))[0]

    top, bottom = (round(edge / scale_factor) for edge in (occupied_rows[0], occupied_rows[-1]))
    left, right = (round(edge / scale_factor) for edge in (occupied_cols[0], occupied_cols[-1]))
    return (left, top), (right, bottom)

def save_image_pil(tensor, save_path):
    """Write an image tensor with values in [-1, 1] to `save_path`.

    The tensor is moved to CPU, a leading singleton dimension is squeezed,
    values are mapped with `(v + 1) / 2`, clamped to [0, 1] and converted to
    a PIL image before saving.
    """
    unit_range = (tensor.cpu().squeeze(0) + 1) / 2
    picture = to_pil_image(unit_range.clamp(0, 1))
    picture.save(save_path)

def save_mask_pil(tensor, save_path):
    """Write a mask tensor to `save_path`.

    The tensor is moved to CPU, a leading singleton dimension is squeezed and
    values are clamped to [0, 1] (no re-scaling) before conversion to a PIL
    image.
    """
    clipped = tensor.cpu().squeeze(0).clamp(0, 1)
    picture = to_pil_image(clipped)
    picture.save(save_path)


def rotate_image(image_tensor, rotation_degrees):
    """Apply a perspective warp that imitates turning the image left or right.

    Args:
        image_tensor: CxHxW tensor, or a 1xCxHxW single-image batch.
        rotation_degrees: Signed angle. Values >= 0 keep the left edge fixed
            and pull the right edge inwards; negative values mirror that.

    Returns:
        float32 tensor laid out channel-first (a single-channel result gets a
        leading channel dimension), placed on CUDA when available and on CPU
        otherwise.

    Raises:
        ValueError: if the input is neither 3-D nor a single-image 4-D batch.
    """
    if image_tensor.dim() == 4 and image_tensor.size(0) == 1:
        image_tensor = image_tensor.squeeze(0)  # drop the batch dimension
    elif image_tensor.dim() != 3:
        raise ValueError("Expected a 3-dimensional tensor or a single-image batch.")

    # OpenCV works on HxWxC arrays.
    image_np = image_tensor.permute(1, 2, 0).cpu().numpy()
    if image_np.shape[2] == 3:
        image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)

    rows, cols, _ = image_np.shape
    src_points = np.float32([[0, 0], [cols, 0], [0, rows], [cols, rows]])

    # Both directions use the magnitude of the angle; only the destination
    # corners differ.
    angle = np.radians(abs(rotation_degrees))
    horizontal_scale = 0.55 + (np.cos(angle) - 0.1) * 0.5
    vertical_displacement = 0.000001 * np.sin(angle)

    if rotation_degrees >= 0:
        # Right rotation: left corners stay fixed.
        dest_points = np.float32([
            [0, 0],
            [cols * horizontal_scale, rows * (0.2 + vertical_displacement)],
            [0, rows],
            [cols * horizontal_scale, rows * (0.8 - vertical_displacement)],
        ])
    else:
        # Left rotation: right corners stay fixed.
        dest_points = np.float32([
            [cols * (1 - horizontal_scale), rows * (0.2 - vertical_displacement)],
            [cols, 0],
            [cols * (1 - horizontal_scale), rows * (0.8 + vertical_displacement)],
            [cols, rows],
        ])

    perspective_matrix = cv2.getPerspectiveTransform(src_points, dest_points)
    transformed_image_np = cv2.warpPerspective(image_np, perspective_matrix, (cols, rows))

    # Same device choice as the rest of the file, so CPU-only hosts work too.
    out_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    if transformed_image_np.ndim == 2:  # single-channel result
        transformed = torch.from_numpy(transformed_image_np).unsqueeze(0)
    else:
        transformed = torch.from_numpy(cv2.cvtColor(transformed_image_np, cv2.COLOR_BGR2RGB)).permute(2, 0, 1)
    return transformed.to(torch.float32).to(out_device)


if __name__ == "__main__":
    # Example driver: load one person/garment pair with their masks and run
    # the try-on edit.

    person_path = 'your path to model image'
    p_mask_path = 'your path to model mask'
    garment_path = 'your path to garment image'
    g_mask_path = 'your path to garment mask'

    # Person image and its mask (greyscale, min-max normalised).
    person_image = load_image(person_path)
    p_mask_np = to_norm(np.array(Image.open(p_mask_path).convert('L')))
    p_mask = torch.from_numpy(p_mask_np).to(device)

    # Garment image and its mask, prepared the same way.
    garment_image = load_image(garment_path)
    g_mask_np = to_norm(np.array(Image.open(g_mask_path).convert('L')))
    g_mask = torch.from_numpy(g_mask_np).to(device)

    if hasattr(torch.cuda, 'empty_cache'):
        torch.cuda.empty_cache()

    # option: upper, lower, dress
    option = 'upper'
    width_scale = 1.3
    height_scale = 1.35

    rotate_degree = 0
    if rotate_degree == 0:
        # No warp requested: use the garment as loaded.
        warped_garment_mask = g_mask
        warped_garment_image = garment_image
    else:
        warped_garment_mask = rotate_image(g_mask.unsqueeze(0), rotate_degree).squeeze(0)
        warped_garment_image = rotate_image(garment_image, rotate_degree).unsqueeze(0)

    real_image_editing(model, person_image, garment_image, warped_garment_image, p_mask, warped_garment_mask, option, width_scale, height_scale, rotate_degree)
