import torch
import os
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from torchvision.transforms import ToPILImage
from data import Dataset, DataLoader
import conf
from sklearn.cluster import KMeans
import torch.nn as nn
import torch.nn.functional as F
import cv2
os.environ["CUDA_VISIBLE_DEVICES"] = "3"


def keep_largest_connected_component(mask):
    """Keep only the largest 8-connected foreground region of a mask.

    Args:
        mask: tensor whose pixels are scaled by 255 and cast to uint8; every
            nonzero result counts as foreground. One leading dimension is
            squeezed away if the tensor has more than two dims (e.g. (1, H, W));
            what remains must be 2-D for OpenCV.

    Returns:
        float32 tensor of shape (H, W): 1.0 on the largest connected component,
        0.0 elsewhere. If the mask has no foreground pixels at all, an all-zero
        mask is returned instead of raising.
    """
    if mask.dim() > 2:
        mask = mask.squeeze(0)
    # detach() so this also works on tensors that require grad; cpu() is a
    # no-op for tensors already on the CPU.
    mask_np = (mask.detach().cpu().numpy() * 255).astype(np.uint8)
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask_np, connectivity=8)
    if num_labels <= 1:
        # Only label 0 (background) exists: there is no component to keep,
        # and np.argmax on the empty stats[1:] would raise ValueError.
        return torch.zeros(labels.shape, dtype=torch.float32)
    # Label 0 is the background, so search labels 1.. for the biggest area.
    largest_component = 1 + int(np.argmax(stats[1:, cv2.CC_STAT_AREA]))
    return torch.from_numpy((labels == largest_component).astype(np.float32))

def process_masks(mask, erode_iter=1, dilate_iter=1):
    """Erode, then dilate, a mask with a 3x3 square structuring element.

    Applying erosion before dilation (a morphological opening when the two
    iteration counts are equal) strips small specks and thin bridges.

    Args:
        mask: tensor with values expected in [0, 1]; they are multiplied by 255
            and cast to uint8. When it has more than two dims, only index 0 of
            the first axis is processed.
        erode_iter: number of erosion passes.
        dilate_iter: number of dilation passes run on the eroded result.

    Returns:
        float32 tensor of shape (1, H, W) with values in [0, 1].
    """
    host_tensor = mask.cpu() if mask.is_cuda else mask
    plane = host_tensor.numpy()
    if plane.ndim > 2:
        plane = plane[0]

    square = np.ones((3, 3), np.uint8)
    as_bytes = (plane * 255).astype(np.uint8)
    opened = cv2.dilate(
        cv2.erode(as_bytes, square, iterations=erode_iter),
        square,
        iterations=dilate_iter,
    )

    # Back to [0, 1] floats with a leading channel axis: (1, H, W).
    return torch.from_numpy(opened.astype(np.float32) / 255)[None]

def save_mask(mask, filename):
    """Write a mask tensor to ``filename`` as an image.

    The tensor is converted with torchvision's ``ToPILImage`` and saved by
    PIL, which picks the file format from the extension.
    """
    ToPILImage()(mask).save(filename)

def binary_threshold(mask, threshold=None):
    """Binarize ``mask``: 1.0 where it is strictly greater than the cut-off, else 0.0.

    Args:
        mask: tensor of scores.
        threshold: scalar (or 0-d tensor) cut-off. When omitted, the mean of
            ``mask`` is used.

    Returns:
        float tensor with the same shape as ``mask``.
    """
    cutoff = mask.mean().item() if threshold is None else threshold
    return torch.gt(mask, cutoff).float()

class ClothSegmentationModel(nn.Module):
    """Prompt-conditioned single-channel segmentation network.

    Pipeline:
      1. A three-stage strided conv encoder (each stage halves H and W)
         produces 128-channel image features.
      2. The embedding row selected by each prompt id is projected
         512 -> 2048 -> 128 by two linear layers.
      3. The projection is broadcast over the spatial grid and multiplied
         into the image features.
      4. A 1x1 conv plus sigmoid turns the gated features into a mask.
      5. The mask is resized to ``output_size``.

    ``forward`` has two modes:
      * ``'train'``: returns the resized soft masks, shape (B, H_out, W_out).
      * ``'test'``: additionally thresholds, clusters and post-processes the
        mask, and writes several PNG files into ``save_path``.
    """

    def __init__(self, embeddings_path, output_size=(1024, 768)):
        """Build the layers and load the prompt-embedding table.

        Args:
            embeddings_path: file loadable with ``torch.load``. Its rows are
                indexed by prompt id, and each row must have 512 values (the
                input width of ``fc1``).
            output_size: (height, width) that predicted masks are resized to.
                The default is the 1024x768 size this model always used.
        """
        super().__init__()
        # Image encoder. For a 1024x768 input the feature maps (H x W x C) are
        # 512x384x32, 256x192x64 and 128x96x128.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)

        self.prompts = ["clothes"]
        # Loaded onto the CPU so the table can be read on any host; forward()
        # moves the selected rows to the input's device. This is a plain
        # attribute (not a parameter or buffer), so it is not part of state_dict.
        self.embeddings = torch.load(embeddings_path, map_location="cpu")
        self.emb_dict = {'clothes': self.embeddings[0]}

        # Text projection: 512-d embedding -> 128 channels (matches conv3).
        self.fc1 = nn.Linear(512, 2048)
        self.fc2 = nn.Linear(2048, 128)

        # 1x1 conv reducing the gated 128-channel features to one mask logit.
        self.conv_fusion = nn.Conv2d(128, 1, kernel_size=1)

        self.output_size = tuple(output_size)

    def _encode_image(self, images):
        """Run the three conv+ReLU stages; output is at 1/8 of the input H and W."""
        x = F.relu(self.conv1(images))
        x = F.relu(self.conv2(x))
        return F.relu(self.conv3(x))

    def _text_features(self, text_prompt_ids, device):
        """Look up one embedding row per prompt id and project it to 128 channels."""
        rows = torch.stack([self.embeddings[int(prompt_id)] for prompt_id in text_prompt_ids])
        # Follow the input's device instead of hard-coding CUDA, so CPU
        # inference works; match fc1's dtype so e.g. half-precision tables load.
        rows = rows.to(device=device, dtype=self.fc1.weight.dtype)
        return self.fc2(F.relu(self.fc1(rows)))

    def _fuse(self, image_features, text_features):
        """Gate image features by the text features; return sigmoid mask (B, 1, h, w)."""
        gate = text_features[:, :, None, None].expand(
            -1, -1, image_features.size(2), image_features.size(3)
        )
        return torch.sigmoid(self.conv_fusion(image_features * gate))

    def forward(self, images, text_prompt_ids, mode, save_path=None):
        """Predict masks for ``images`` conditioned on ``text_prompt_ids``.

        Args:
            images: (B, 3, H, W) tensor.
            text_prompt_ids: iterable of B integer ids (e.g. a 1-D int tensor)
                indexing rows of the embedding table.
            mode: ``'train'`` or ``'test'``.
            save_path: existing directory for the PNGs written in test mode.

        Returns:
            ``'train'``: soft masks of shape (B, *output_size).
            ``'test'``: see ``_forward_test``.

        Raises:
            ValueError: for an unknown ``mode``, or test mode without ``save_path``.
        """
        if mode == 'train':
            return self._forward_train(images, text_prompt_ids)
        if mode == 'test':
            if save_path is None:
                raise ValueError("mode='test' requires save_path")
            return self._forward_test(images, text_prompt_ids, save_path)
        raise ValueError(f"unknown mode {mode!r}; expected 'train' or 'test'")

    def _forward_train(self, images, text_prompt_ids):
        """Soft-mask prediction resized to ``output_size``; shape (B, H_out, W_out)."""
        features = self._encode_image(images)
        mask = self._fuse(features, self._text_features(text_prompt_ids, features.device))
        mask = F.interpolate(mask, size=self.output_size, mode='bilinear', align_corners=False)
        return mask.squeeze(1)

    def _forward_test(self, images, text_prompt_ids, save_path):
        """Threshold, cluster and post-process the mask; write PNGs to ``save_path``.

        Files written: mask_upper.png, mask_lower.png, mask_refined.png and
        mask_processed_<idx>.png.

        Returns:
            A tuple of three tensors:
              * refined_mask, (1, *output_size): the first image's mask,
                thresholded twice and then bilinearly resized.
              * first_cluster, (B, *output_size): the first non-background
                KMeans cluster.
              * second_cluster, (B, *output_size): the second non-background
                KMeans cluster.
        """
        conv3_output = self._encode_image(images)
        mask = self._fuse(conv3_output, self._text_features(text_prompt_ids, conv3_output.device))

        # Two-pass threshold on the first image: the first pass cuts at the mean
        # probability; the second cuts the probabilities at the resulting
        # foreground fraction.
        binary_mask = binary_threshold(mask[0])
        refined_binary_mask = binary_threshold(mask[0], binary_mask.mean())  # (1, h, w)
        # Already on the conv3 grid (the 1x1 fusion conv keeps h, w), so no
        # resize is needed before masking the features.
        coarse_mask = refined_binary_mask.unsqueeze(0)  # (1, 1, h, w)
        masked_features = conv3_output * coarse_mask

        refined_full = F.interpolate(
            coarse_mask, size=self.output_size, mode='bilinear', align_corners=False
        ).squeeze(1)  # (1, H_out, W_out)

        # Cluster per-pixel feature vectors into 3 groups.
        n_clusters = 3
        batch, channels, height, width = masked_features.shape
        pixel_features = (
            masked_features.permute(0, 2, 3, 1).reshape(-1, channels).detach().cpu().numpy()
        )
        labels = KMeans(n_clusters=n_clusters, n_init=10, random_state=0).fit(pixel_features).labels_

        # The cluster with the lowest mean feature value is treated as background
        # (masked-out pixels are zero vectors, which pulls their cluster's mean down).
        cluster_means = [pixel_features[labels == i].mean() for i in range(n_clusters)]
        background_cluster = int(np.argmin(cluster_means))

        label_mask = torch.as_tensor(labels, device=masked_features.device).reshape(batch, height, width)
        cluster_masks = [
            F.interpolate(
                (label_mask == i).float().unsqueeze(1),
                size=self.output_size,
                mode='bilinear',
                align_corners=False,
            ).squeeze(1)
            for i in range(n_clusters)
            if i != background_cluster
        ]
        # NOTE(review): clusters are ordered by KMeans label index, not by
        # content; "first = upper clothes, second = trousers" is the original
        # author's assumption and is not guaranteed.
        first_cluster, second_cluster = cluster_masks[0], cluster_masks[1]

        save_mask(first_cluster, os.path.join(save_path, 'mask_upper.png'))
        save_mask(second_cluster, os.path.join(save_path, 'mask_lower.png'))
        save_mask(refined_full, os.path.join(save_path, 'mask_refined.png'))

        largest = keep_largest_connected_component(refined_full)
        processed_masks = process_masks(largest, erode_iter=10, dilate_iter=10)
        for idx, processed in enumerate(processed_masks):
            save_mask(processed, os.path.join(save_path, f'mask_processed_{idx}.png'))

        return refined_full, first_cluster, second_cluster



def load_model(model_path, device, args):
    """Build a ClothSegmentationModel, wrap it in DataParallel and load its weights.

    Args:
        model_path: checkpoint file holding a dict with a 'model_state_dict' entry.
        device: torch.device the model is moved to. Checkpoint tensors are
            mapped onto it, so a checkpoint saved on a GPU also loads on a
            CPU-only host.
        args: config object; only ``args.embeddings_path`` is read.

    Returns:
        The DataParallel-wrapped model, in eval mode.
    """
    model = ClothSegmentationModel(args.embeddings_path).to(device)
    # Wrap before loading: a DataParallel state_dict prefixes keys with
    # 'module.'. Presumably the checkpoint was saved from a wrapped model;
    # confirm against the training script.
    model = torch.nn.DataParallel(model)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model

def save_mask(mask, filename):
    """Convert ``mask`` to a PIL image with ToPILImage and save it to ``filename``.

    NOTE(review): this is an exact duplicate of the ``save_mask`` defined
    earlier in this module. Being later, it is the definition bound at module
    level; one copy should be removed.
    """
    converter = ToPILImage()
    pil_image = converter(mask)
    pil_image.save(filename)

def binary_threshold(mask, threshold=None):
    """Return a float 0/1 tensor marking where ``mask`` exceeds ``threshold``.

    The comparison is strict (``>``). With no ``threshold``, the mean of
    ``mask`` is used as the cut-off.

    NOTE(review): duplicate of the ``binary_threshold`` defined earlier in this
    module. This later copy is the one bound at module level; keep only one.
    """
    if threshold is not None:
        return (mask > threshold).float()
    return (mask > mask.mean().item()).float()

def main():
    """Segment one image with a pretrained model and write the masks to disk.

    The model and image paths below are placeholders and must be edited before
    running. All outputs are written by the model's test-mode forward pass
    into ./mask_result/<checkpoint file name without extension>/.
    """
    args = conf.get_conf()
    print(f'args.exp_name:{args.exp_name}')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_path = 'your path to pretrain model'  # replace with the path to your model checkpoint
    # Results go into a folder named after the checkpoint file (extension stripped).
    folder_name = os.path.splitext(os.path.basename(model_path))[0]
    model = load_model(model_path, device, args)

    image_path = 'your path to model/garment image'
    prompt_id = 0  # 0 for upper, 1 for pants, 2 for dress

    transform = transforms.Compose([
        transforms.Resize((1024, 768)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Use a context manager so the underlying file handle is closed.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    image = transform(image).unsqueeze(0).to(device)

    save_path = os.path.join('./mask_result', folder_name)
    os.makedirs(save_path, exist_ok=True)

    prompt_ids = torch.tensor([prompt_id], dtype=torch.int64).to(device)
    with torch.no_grad():
        # The masks are saved to save_path inside the forward pass; the
        # returned tensors are not needed here.
        model(images=image, text_prompt_ids=prompt_ids, mode='test', save_path=save_path)


if __name__ == '__main__':
    main()
    