import torch
import os
import glob
import re
from PIL import Image
import matplotlib.pyplot as plt
import requests
import pickle
import numpy as np
from collections import Counter
from PIL import Image
from torchvision import transforms
from sklearn.decomposition import PCA

from scipy.ndimage import label
from tqdm import tqdm


def load_images_from_folder(folder):
    """
    Load every file directly inside ``folder`` as an image and stack them.

    Files are read in sorted filename order; sub-directories are skipped.

    Args:
        folder: Directory containing the images.

    Returns:
        np.ndarray with one leading entry per image; all images must share the
        same shape for ``np.stack`` to succeed.

    Raises:
        ValueError: if ``folder`` contains no files.
    """
    images = []
    for filename in sorted(os.listdir(folder)):
        img_path = os.path.join(folder, filename)
        if not os.path.isfile(img_path):
            continue
        # Image.open is lazy and holds the file open; the context manager
        # releases the handle deterministically once the pixels are copied.
        with Image.open(img_path) as img:
            images.append(np.array(img))
    if not images:
        raise ValueError(f"No image files found in {folder!r}")
    return np.stack(images)

def load_images_from_folder_clevrtex_multi(folder):
    """
    Load CLEVRTex images from a two-level folder hierarchy and stack them.

    Walks ``folder/<second>/<third>/`` in sorted order at every level and
    keeps only files whose names match ``CLEVRTEXv2_full_<digits>_<digits>.png``.
    Entries that are not directories at the two intermediate levels are skipped.

    Args:
        folder: Root directory of the hierarchy.

    Returns:
        np.ndarray with one leading entry per matching image.

    Raises:
        ValueError: if no matching image is found.
    """
    pattern = re.compile(r'^CLEVRTEXv2_full_\d+_\d+\.png$')
    images = []
    for second_folder in sorted(os.listdir(folder)):
        second_path = os.path.join(folder, second_folder)
        if not os.path.isdir(second_path):
            continue
        print(second_folder)
        # Sorted here as well so the image order is deterministic across runs.
        for third_folder in sorted(os.listdir(second_path)):
            third_path = os.path.join(second_path, third_folder)
            if not os.path.isdir(third_path):
                continue
            for filename in sorted(os.listdir(third_path)):
                if not pattern.match(filename):
                    continue
                # Close the lazily-opened file once its pixels are copied.
                with Image.open(os.path.join(third_path, filename)) as img:
                    images.append(np.array(img))
    if not images:
        raise ValueError(f"No CLEVRTex images found under {folder!r}")
    return np.stack(images)


class FeatureGeneratorMulti:
    '''
        Second part of Section 4.1 to create segmented image and image patches
    '''
    def __init__(self, cell_size=None):
        # Size of the patch grid (small) and of the upsampled image (large).
        self.small_image_height, self.small_image_width = 37, 37
        self.large_image_height, self.large_image_width = 518, 518

        # Grid-to-pixel scaling factors (518 / 37 == 14).
        self.scale_y = self.large_image_height / self.small_image_height
        self.scale_x = self.large_image_width / self.small_image_width

        self.step_size = 1
        if cell_size is not None:
            self.cell_size = cell_size
        else:
            self.cell_size = 14  # Dino patch size; size of each grid cell

        print("CELL_SIZE: ", self.cell_size)
        self.transform2 = transforms.Compose([
            transforms.Resize(520),
            transforms.CenterCrop(518),  # should be multiple of model patch_size
        ])

        # Number of images processed so far by get_packets / get_packets_single.
        self.image_index = 0

    def map_pixel(self, small_y, small_x, scale_y, scale_x):
        """
        Map a pixel from the smaller image to the larger image.

        Parameters:
        small_y (int): y-coordinate in the smaller image
        small_x (int): x-coordinate in the smaller image
        scale_y (float): vertical small-to-large scaling factor
        scale_x (float): horizontal small-to-large scaling factor

        Returns:
        (int, int): Corresponding (y, x) coordinates in the larger image (truncated)
        """
        large_y = int(small_y * scale_y)
        large_x = int(small_x * scale_x)
        return large_y, large_x

    def make_bounding_boxes_square(self, threshold_feature, patch_size):
        """
        Replace each connected component of a mask with a square box.

        Components are found with ``scipy.ndimage.label`` (default
        connectivity). Each box is anchored at its component's top-left corner
        and given side length ``max(width, height)``. Boxes that would extend
        past the bottom/right edge are clipped, so they are square only when
        they fit inside the mask.

        Args:
            threshold_feature: 2D mask on the patch grid; non-zero entries are active.
            patch_size: Unused; kept for backward compatibility with callers.

        Returns:
            square_mask: Array like ``threshold_feature``, 1 inside every box and 0 elsewhere.
        """
        labeled_mask, num_features = label(threshold_feature)
        square_mask = np.zeros_like(threshold_feature)
        last_y = threshold_feature.shape[0] - 1
        last_x = threshold_feature.shape[1] - 1

        for label_idx in range(1, num_features + 1):
            y_coords, x_coords = np.where(labeled_mask == label_idx)
            min_y, max_y = y_coords.min(), y_coords.max()
            min_x, max_x = x_coords.min(), x_coords.max()

            side = max(max_x - min_x + 1, max_y - min_y + 1)

            # Grow down/right from the top-left corner, clipped to the mask.
            square_max_y = min(last_y, min_y + side - 1)
            square_max_x = min(last_x, min_x + side - 1)
            square_mask[min_y:square_max_y + 1, min_x:square_max_x + 1] = 1

        return square_mask

    def _build_packets(self, orig_image_up, threshold_feature, square_mask, cell_size):
        """
        Walk the patch grid once, painting and collecting pixel cells.

        Cells flagged in ``square_mask`` are copied from ``orig_image_up`` into
        a black ``large_image``. Cells flagged in ``threshold_feature`` become
        packets: ``[list of cell_size**2 pixel vectors in row-major order,
        flat grid index]``. Cells cut off by the image border are painted
        partially and are not turned into packets.
        """
        self.i_range = np.arange(0, self.small_image_height, self.step_size)
        self.j_range = np.arange(0, self.small_image_width, self.step_size)

        large_image = np.zeros((self.large_image_height, self.large_image_width, 3))
        packet_list = []
        index = 0  # counts visited grid cells in row-major order
        for i in self.i_range:
            for j in self.j_range:
                large_y, large_x = self.map_pixel(i, j, self.scale_y, self.scale_x)
                # Slicing clips at the border exactly like per-pixel bounds checks.
                rows = slice(large_y, min(large_y + cell_size, self.large_image_height))
                cols = slice(large_x, min(large_x + cell_size, self.large_image_width))

                if square_mask[int(i)][int(j)] > 0:
                    large_image[rows, cols] = orig_image_up[rows, cols]

                if threshold_feature[int(i)][int(j)] > 0:
                    cell = orig_image_up[rows, cols]
                    # Only complete cells (not clipped by the border) become packets.
                    if cell.shape[0] * cell.shape[1] == cell_size ** 2:
                        packet_list.append([list(cell.reshape(-1, cell.shape[-1])), index])

                index += 1
        return packet_list, large_image

    def get_packets_single(self, filtered_image, threshold_feature, cell_size):
        """
        Build the segmented image and pixel packets for a channels-first image.

        Args:
            filtered_image: Image indexed as (channel, y, x); only the first 3
                channels are used and values are divided by 255.
            threshold_feature: 2D mask on the patch grid selecting packet cells.
            cell_size: Side length in pixels of each cell.

        Returns:
            (packet_list, large_image) as produced by the grid walk.
        """
        img = filtered_image[:3, :, :]
        img_t = self.transform2(torch.from_numpy(img / 255.0))
        orig_image_up = np.transpose(img_t.numpy(), (1, 2, 0))

        square_mask = self.make_bounding_boxes_square(threshold_feature, patch_size=15)
        packet_list, large_image = self._build_packets(orig_image_up, threshold_feature, square_mask, cell_size)

        self.image_index += 1
        return packet_list, large_image

    def get_packets(self, filtered_image, threshold_feature, cell_size):
        """
        Same as ``get_packets_single`` but for a channels-last image (y, x, channel).
        """
        img = np.moveaxis(filtered_image, -1, 0)[:3, :, :]
        img_t = self.transform2(torch.from_numpy(img / 255.))
        orig_image_up = np.transpose(img_t.numpy(), (1, 2, 0))

        square_mask = self.make_bounding_boxes_square(threshold_feature, patch_size=10)
        packet_list, large_image = self._build_packets(orig_image_up, threshold_feature, square_mask, cell_size)

        self.image_index += 1
        return packet_list, large_image

    def get_segmentation(self, packet_list, cell_size):
        """
        Flatten packets into a 2D array of pixel features.

        Args:
            packet_list: List of ``[packet, index]`` pairs, where ``packet`` holds
                ``cell_size**2`` RGB pixel vectors.
            cell_size: Side length in pixels of each cell.

        Returns:
            (packet_array, filtered_packet_array, packet_index, None), where
            ``packet_array`` has shape (num_packets, cell_size**2 * 3),
            ``filtered_packet_array`` is the same object (no filtering is
            applied), and ``packet_index`` holds the grid indices.
        """
        feature_len = cell_size ** 2 * 3
        if len(packet_list) == 0:
            empty = np.empty((0, feature_len))
            return empty, empty, np.empty((0,), dtype=int), None

        # Build each column separately: stacking the ragged [packet, index]
        # pairs directly raises ValueError on NumPy >= 1.24.
        packet_array = np.stack([np.asarray(packet) for packet, _ in packet_list]).reshape(-1, feature_len)
        packet_index = np.array([index for _, index in packet_list])
        filtered_packet_array = packet_array

        return packet_array, filtered_packet_array, packet_index, None
    

class FeatureExtractorMulti:
    '''
        Class that defines the masking process with PCA (Section 4.1)
    '''
    def __init__(self, model=None):
        """
        Args:
            model: Object exposing ``patch_size`` and ``forward_features``; when
                None, DINOv2 ``dinov2_vits14`` is loaded via torch.hub and moved to CUDA.
        """
        if model is None:
            model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').cuda()
        self.model = model

        # Resize(256)/CenterCrop(224) with ImageNet statistics; not used by the
        # methods of this class.
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )])

        # Transform applied to every image before it is fed to the backbone.
        self.transform1 = transforms.Compose([
            transforms.Resize(520),
            transforms.CenterCrop(518),  # should be multiple of model patch_size
            transforms.Normalize(mean=0.5, std=0.2)
        ])

    def _model_device(self):
        """Return the device of the model's parameters, falling back to CUDA."""
        try:
            return next(self.model.parameters()).device
        except (AttributeError, StopIteration):
            return torch.device('cuda')

    def _extract_features(self, filtered_images, feature_dim, channels_last):
        """Run the backbone on each image and concatenate patch and CLS tokens on CPU."""
        self.patch_size = self.model.patch_size  # 14 for dinov2_vits14
        # 520 // 14 == 518 // 14 == 37 patches per side of the 518 crop.
        self.patch_h = 520 // self.patch_size
        self.patch_w = 520 // self.patch_size
        # 384 for vits14 (768 vitb14, 1024 vitl14, 1536 vitg14).
        self.feat_dim = feature_dim
        # Number of images actually processed (was hard-coded to 50).
        self.num_images = len(filtered_images)

        device = self._model_device()
        total_features, total_cls = [], []
        with torch.no_grad():  # inference only; outputs are detached anyway
            for img in filtered_images:
                if channels_last:
                    img = np.moveaxis(img, -1, 0)
                img = img[:3]
                # NOTE(review): Normalize(0.5, 0.2) inside transform1 acts on raw
                # 0-255 values and only its output is divided by 255. Kept as-is
                # to preserve the features this pipeline produces; confirm
                # whether the /255 was meant to happen first.
                img_t = self.transform1(torch.from_numpy(img.astype(float))).to(device).float() / 255.
                features_dict = self.model.forward_features(img_t.unsqueeze(0))
                total_features.append(features_dict['x_norm_patchtokens'].detach().cpu())
                total_cls.append(features_dict['x_norm_clstoken'].detach().cpu())

        return torch.cat(total_features, dim=0), torch.cat(total_cls, dim=0)

    def get_features_single(self, filtered_images, feature_dim=384):
        """
        Extract backbone features from channels-first images (channel, y, x).

        Returns:
            (patch_tokens, cls_tokens): the model's ``x_norm_patchtokens`` and
            ``x_norm_clstoken`` outputs concatenated over images along dim 0.
        """
        return self._extract_features(filtered_images, feature_dim, channels_last=False)

    def get_features(self, filtered_images, feature_dim=384):
        """
        Extract backbone features from channels-last images (y, x, channel).

        Returns:
            (patch_tokens, cls_tokens): the model's ``x_norm_patchtokens`` and
            ``x_norm_clstoken`` outputs concatenated over images along dim 0.
        """
        return self._extract_features(filtered_images, feature_dim, channels_last=True)

    @staticmethod
    def _min_max(values):
        """Scale ``values`` linearly to [0, 1]; a constant input maps to zeros instead of NaN."""
        span = values.max() - values.min()
        if span == 0:
            return np.zeros_like(values)
        return (values - values.min()) / span

    def perform_pca(self, total_features1, total_cls):
        """
        Two-stage PCA masking of patch features.

        Stage 1 fits a 3-component PCA on all patches and marks patches whose
        min-max-scaled first component is above the median as background.
        Stage 2 refits PCA on the remaining (foreground) patches only.

        Args:
            total_features1: Patch features whose elements reshape to
                (-1, self.feat_dim), e.g. the output of ``get_features``.
            total_cls: Unused; kept for interface compatibility.

        Returns:
            Array of shape (num_images, patch_h, patch_w) holding the scaled
            first foreground component where it is below the mean of all
            non-zero values, and 0 elsewhere.
        """
        features = np.asarray(total_features1).reshape(-1, self.feat_dim)
        num_images = features.shape[0] // (self.patch_h * self.patch_w)

        # First PCA to separate the background.
        pca = PCA(n_components=3)
        pca.fit(features)
        pca_features = pca.transform(features)
        pca_features[:, 0] = self._min_max(pca_features[:, 0])

        median_feature_value = np.median(pca_features[:, 0])
        pca_features_bg = pca_features[:, 0] > median_feature_value
        pca_features_fg = ~pca_features_bg

        # Second PCA on the foreground patches only.
        pca.fit(features[pca_features_fg])
        pca_features_left = pca.transform(features[pca_features_fg])
        for c in range(3):
            pca_features_left[:, c] = self._min_max(pca_features_left[:, c])

        pca_features_rgb = pca_features.copy()
        pca_features_rgb[pca_features_bg] = 0  # black background
        pca_features_rgb[pca_features_fg] = pca_features_left

        # Keep only the first component, one 2D map per image.
        first_component = pca_features_rgb.reshape(num_images, self.patch_h, self.patch_w, 3)[:, :, :, 0]
        average = np.mean(first_component[first_component != 0])
        return np.where(first_component < average, first_component, 0)


if __name__ == '__main__':
    # Create image patches to train the VAE.
    batch_size = 50           # images per feature/PCA batch
    max_images = 5000         # only the first 5000 images are processed
    cell_size = 14
    checkpoint_every = 10000  # dump all segments collected so far every N segments
    output_dir = "pickled_output/"
    checkpoint_glob = os.path.join(output_dir, 'all_segments_14__*.pkl')
    os.makedirs(output_dir, exist_ok=True)

    def save_checkpoint(segments, index):
        """Pickle ``segments`` as checkpoint ``index``, keeping at most two checkpoints."""
        # Prune (oldest first) down to one existing checkpoint before writing
        # the new one. The glob targets the same directory the dumps go to.
        files = sorted(glob.glob(checkpoint_glob), key=os.path.getctime)
        while len(files) >= 2:
            os.remove(files.pop(0))
        output_file_path = os.path.join(output_dir, 'all_segments_14__' + str(index) + '.pkl')
        with open(output_file_path, 'wb') as f:
            pickle.dump(segments, f)

    feature_extractor = FeatureExtractorMulti()
    feature_generator = FeatureGeneratorMulti()

    folder_path = '/home/stefan/Downloads/CLEVR_v1.0/images/val/'
    images = load_images_from_folder(folder_path)
    num_images = min(len(images), max_images)

    print("LEN IMAGES: ", images.shape)

    all_segments = []
    image_index = 0    # running count of segments (not images)
    saved_count = -1   # len(all_segments) at the last checkpoint
    for i in tqdm(range(0, num_images, batch_size)):
        filtered_images = images[i:min(i + batch_size, num_images)]

        total_features1, total_cls = feature_extractor.get_features(filtered_images)
        thresholded_features = feature_extractor.perform_pca(total_features1, total_cls)

        for image, threshold in zip(filtered_images, thresholded_features):
            packet_list, large_image = feature_generator.get_packets(image, threshold, cell_size=cell_size)
            print("LEN PACKET_LIST: ", len(packet_list))
            if len(packet_list) == 0:
                continue

            _, image_segments, _, _ = feature_generator.get_segmentation(packet_list, cell_size=cell_size)
            print("OUTPUT SEGMENTS: ", len(image_segments))
            for segment in image_segments:
                all_segments.append(segment)
                if image_index % checkpoint_every == 0:
                    print("IMAGE_INDEX: ", image_index)
                    save_checkpoint(all_segments, image_index)
                    saved_count = len(all_segments)
                image_index += 1

    # Persist segments collected after the last periodic checkpoint.
    if len(all_segments) != saved_count:
        save_checkpoint(all_segments, image_index)
