import os
import sys
import json
import re
import pickle
from tqdm import tqdm
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import torchvision
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image

import numpy as np
import pandas as pd
from PIL import Image

from random import randint
from sklearn.model_selection import train_test_split


from matplotlib import pyplot as plt

from dino_masking import FeatureExtractorMulti, FeatureGeneratorMulti
from model import VAE


def load_images_from_folder(folder):
    """Load every entry of ``folder`` as an image and stack the results.

    Entries are visited in sorted filename order, so index ``i`` of the
    returned array corresponds to the ``i``-th filename alphabetically.

    Args:
        folder: Directory whose entries are all opened with ``Image.open``.

    Returns:
        The per-image arrays combined with ``np.stack`` (new leading axis).

    Raises:
        ValueError: If ``folder`` has no entries, or (from ``np.stack``) if
            the loaded images do not all share one shape.
    """
    filenames = sorted(os.listdir(folder))
    if not filenames:
        # np.stack would fail anyway; fail early with a message naming the folder.
        raise ValueError(f"No images found in folder: {folder}")

    images = []
    for filename in filenames:
        # The context manager releases the file handle as soon as the pixel
        # data has been copied into the array.
        with Image.open(os.path.join(folder, filename)) as img:
            images.append(np.array(img))
    return np.stack(images)

def load_images_from_folder_clevrtex_multi(folder):
    """Load matching PNGs from a two-level folder tree and stack them.

    The tree is walked as ``folder/<second>/<third>/<file>``. Second-level
    folders are visited in sorted order; third-level folders are visited in
    the order ``os.listdir`` returns them. Inside each third-level folder the
    ``.png`` files are sorted and only names matching the pattern below are
    loaded. Each second-level folder name and each loaded path is printed.

    Args:
        folder: Root directory of the tree.

    Returns:
        ``np.stack`` of the loaded image arrays.

    Raises:
        ValueError: From ``np.stack`` when no file matched.
    """
    # Change this to the regex pattern of your images
    name_re = re.compile(r'^CLEVRTEXv2_full_\d+_\d+\.png$')

    collected = []
    for second_folder in sorted(os.listdir(folder)):
        print(second_folder)
        second_dir = folder + "/" + second_folder
        # NOTE: not sorted here; the order is whatever the filesystem reports.
        for third_folder in os.listdir(second_dir):
            third_dir = second_dir + "/" + third_folder
            png_names = sorted(
                name for name in os.listdir(third_dir) if name.endswith('.png')
            )
            for png_name in png_names:
                if not name_re.match(png_name):
                    continue
                img_path = os.path.join(third_dir, png_name)
                print(img_path)
                collected.append(np.array(Image.open(img_path)))
    return np.stack(collected)

def load_images_from_folder_clevrtex_single(folder):
    """Load the PNGs in ``folder`` whose names match the pattern below.

    Matching files are loaded in sorted filename order and stacked.

    Args:
        folder: Directory that directly contains the image files.

    Returns:
        ``np.stack`` of the loaded image arrays.

    Raises:
        ValueError: From ``np.stack`` when no file matched.
    """
    # Change this to the regex pattern of your images
    name_re = re.compile(r'^LOCAL_0-500_\d+\.png$')

    matching = [
        name
        for name in sorted(os.listdir(folder))
        if name.endswith('.png') and name_re.match(name)
    ]
    arrays = [np.array(Image.open(os.path.join(folder, name))) for name in matching]
    return np.stack(arrays)

def load_json_multi(path):
    """Parse every ``.json`` file in a two-level folder tree.

    The tree is walked as ``path/<second>/<third>/<file>.json``. Second-level
    folders are visited in sorted order (each name is printed); third-level
    folders are visited in ``os.listdir`` order; the JSON files inside each
    third-level folder are read in sorted filename order.

    Args:
        path: Root directory of the tree.

    Returns:
        A list with one parsed JSON document per file, in visiting order.
    """
    loaded = []
    for second_folder in sorted(os.listdir(path)):
        print(second_folder)
        second_dir = path + "/" + second_folder
        # NOTE: not sorted here; the order is whatever the filesystem reports.
        for third_folder in os.listdir(second_dir):
            third_dir = second_dir + "/" + third_folder
            json_names = sorted(
                name for name in os.listdir(third_dir) if name.endswith('.json')
            )
            for json_name in json_names:
                with open(third_dir + "/" + json_name, 'r') as handle:
                    loaded.append(json.load(handle))
    return loaded

def load_json_single(path):
    """Parse every ``.json`` file directly inside ``path``.

    Files are read in sorted filename order.

    Args:
        path: Directory containing the JSON files. A trailing path separator
            is optional.

    Returns:
        A list with one parsed JSON document per file.
    """
    filenames = sorted(f for f in os.listdir(path) if f.endswith('.json'))

    datas = []
    for filename in filenames:
        # os.path.join instead of ``path + filename`` so a directory given
        # without a trailing separator still resolves to the right file.
        with open(os.path.join(path, filename), 'r') as file:
            datas.append(json.load(file))

    return datas

def load_json(path):
    """Return the parsed contents of the JSON file at ``path``."""
    with open(path, 'r') as handle:
        return json.load(handle)


    
class EmbeddingDataSet:
    def __init__(self):
        """Load the VAE and DINOv2 models onto CUDA and build the resize transforms.

        Requires a CUDA device, the VAE weights file under ``model_outputs/``
        and access to the ``facebookresearch/dinov2`` torch.hub repository.
        """
        device = 'cuda'

        # Pieces of the output pickle file names (see save_embeddings).
        self.output_string = 'multi'
        self.dataset_size = 500

        # VAE used to encode the image segments.
        self.model15 = VAE(image_channels=3).to(device)
        vae_state = torch.load('model_outputs/vae_14_14_3_60.torch', map_location=device)
        self.model15.load_state_dict(vae_state)

        # NOTE: the attribute is named "vitl14" but the checkpoint loaded here
        # is 'dinov2_vits14'; dino_feature_dim is set to 384 to go with it.
        dino_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
        self.dinov2_vitl14 = dino_model.cuda()
        self.dino_feature_dim = 384

        # Resize to 98x98.
        self.transform_dino = transforms.Compose([
            transforms.Resize((98, 98)),
        ])

        # Resize to 64x64; applied to segments before the VAE encoder.
        self.transform_out = transforms.Compose([
            transforms.Resize((64, 64)),
        ])

        # Resize to 520 then center-crop to 518.
        self.transform2 = transforms.Compose([
            transforms.Resize(520),
            transforms.CenterCrop(518),  # should be multiple of model patch_size
        ])

    def load_data_single(self):
        """Load the single-object images, the multi-object images and the
        single-object scene metadata from their hard-coded locations.

        Returns:
            A tuple ``(images_single, images_multi, data_single)`` where the
            first two are stacked image arrays and the last is a list of
            parsed scene JSON documents.
        """
        single_image_dir = '/home/stefan/Downloads/clevr-dataset-gen/output/images2/'
        multi_image_dir = '/home/stefan/Downloads/CLEVR_v1.0/images/val/'
        # Per-image scene JSON files holding the metadata for the single images.
        single_scene_dir = '/home/stefan/Downloads/clevr-dataset-gen/output/scenes/'

        images_single = load_images_from_folder(single_image_dir)
        images_multi = load_images_from_folder(multi_image_dir)
        data_single = load_json_single(single_scene_dir)

        return images_single, images_multi, data_single


    def load_data_multi(self):
        """Load the multi-object validation images and their scene file
        from their hard-coded locations.

        Returns:
            A tuple ``(images_multi, json_data_val)``: the stacked image array
            and the parsed contents of the scenes JSON file.
        """
        image_dir = '/home/stefan/Downloads/CLEVR_v1.0/images/val/'
        scenes_file = '/home/stefan/Downloads/CLEVR_v1.0/scenes/CLEVR_val_scenes.json'

        images_multi = load_images_from_folder(image_dir)
        json_data_val = load_json(scenes_file)
        return images_multi, json_data_val
    
    def preprocess_data_attributes_single(self, data_single, images_single):
        """Collect the attributes of the first object of every scene.

        For each scene dict the keys, the dict length and the first object's
        attributes are printed, then each attribute is stored wrapped in a
        one-element list. The object count is recorded as ``[1]`` per scene.

        Args:
            data_single: Iterable of scene dicts, each with a non-empty
                ``'objects'`` list.
            images_single: Images belonging to the scenes; passed through
                unchanged as the first element of the result.

        Returns:
            ``[images, colors, shapes, sizes, materials, pixel_coords,
            3d_coords, num_objects]``.
        """
        attribute_keys = ('color', 'shape', 'size', 'material', 'pixel_coords', '3d_coords')
        columns = {key: [] for key in attribute_keys}
        object_counts = []

        for scene in data_single:
            print(scene.keys())
            print(len(scene))
            print(scene.keys())

            first_object = scene['objects'][0]
            for key in attribute_keys:
                print(first_object[key])

            for key in attribute_keys:
                columns[key].append([first_object[key]])
            object_counts.append([1])

        return [images_single] + [columns[key] for key in attribute_keys] + [object_counts]

    def preprocess_data_attributes_multi(self, json_data_val, images_multi):
        """Collect per-object attributes for every scene of a scenes file.

        Args:
            json_data_val: Dict with a ``'scenes'`` list; each scene is a dict
                whose ``'objects'`` entry is a list of object dicts with the
                keys ``color``, ``shape``, ``size``, ``material``,
                ``pixel_coords`` and ``3d_coords``.
            images_multi: Images belonging to the scenes; passed through
                unchanged as the first element of the result.

        Returns:
            ``[images, colors, shapes, sizes, materials, pixel_coords,
            3d_coords, num_objects]`` where every attribute entry is a list
            with one value per object and ``num_objects`` holds one int per
            scene.
        """
        scenes = json_data_val['scenes']
        if len(scenes) > 0:
            print(scenes[0])

        # NOTE(review): scenes without an 'objects' key are skipped here while
        # images_multi is returned unfiltered, so such a scene would leave the
        # attribute lists shorter than the images — confirm this cannot occur.
        per_scene_objects = [scene['objects'] for scene in scenes if 'objects' in scene]

        all_color_items = []
        all_shape_items = []
        all_size_items = []
        all_material_items = []
        all_pixelcoords_items = []
        all_3dcoords_items = []
        all_numobjects_items = []

        for scene_objects in per_scene_objects:
            all_color_items.append([obj['color'] for obj in scene_objects])
            all_shape_items.append([obj['shape'] for obj in scene_objects])
            all_size_items.append([obj['size'] for obj in scene_objects])
            all_material_items.append([obj['material'] for obj in scene_objects])
            all_pixelcoords_items.append([obj['pixel_coords'] for obj in scene_objects])
            all_3dcoords_items.append([obj['3d_coords'] for obj in scene_objects])
            all_numobjects_items.append(len(scene_objects))

        print(len(all_color_items))

        # Debug preview of one scene; guarded so that inputs with fewer than
        # three scenes no longer raise IndexError.
        img_idx = 2
        if len(all_color_items) > img_idx:
            print(all_color_items[img_idx])
            print(all_shape_items[img_idx])
            print(all_size_items[img_idx])
            print(all_material_items[img_idx])
            print(all_pixelcoords_items[img_idx])
            print(all_3dcoords_items[img_idx])

        all_data_items_m = [images_multi, all_color_items, all_shape_items, all_size_items, all_material_items, all_pixelcoords_items, all_3dcoords_items, all_numobjects_items]

        return all_data_items_m

    def get_test_data_splits(self, all_data_items):
        """Split the eight parallel sequences and return only the test parts.

        ``train_test_split`` is called with ``test_size=0.9`` and
        ``random_state=42``, so 90% of the samples land in the returned test
        portion and the split is reproducible.

        Args:
            all_data_items: The eight parallel sequences ``[images, colors,
                shapes, sizes, materials, pixel_coords, 3d_coords,
                num_objects]``.

        Returns:
            An 8-tuple with the test portion of each sequence, in input order.
        """
        splits = train_test_split(*all_data_items, test_size=0.9, random_state=42)

        # train_test_split yields (train, test) pairs in input order; the
        # train halves are discarded.
        (
            _, held_images,
            _, held_colors,
            _, held_shapes,
            _, held_sizes,
            _, held_materials,
            _, held_pixelcoords,
            _, held_3dcoords,
            _, held_numobjects,
        ) = splits

        return (
            held_images,
            held_colors,
            held_shapes,
            held_sizes,
            held_materials,
            held_pixelcoords,
            held_3dcoords,
            held_numobjects,
        )


    def get_segmentation_packets(self, image, X_test, num_images, test_images_multi=None):
        if test_images_multi is None:
            random_indices = np.random.choice(range(len(X_test)), num_images-1, replace=False)
            random_samples = X_test[random_indices]
            filtered_images = np.concatenate((random_samples, np.expand_dims(image, 0)), axis=0)

        else:
            random_indices = np.random.choice(range(len(test_images_multi)), num_images-5, replace=False)
            random_indices_single = np.random.choice(range(len(X_test)), 4, replace=False)
            random_samples = test_images_multi[random_indices]
            random_samples_single = X_test[random_indices_single]
            random_samples = random_samples[:, :, :, :3]

            random_samples_resized = []
            random_samples_single_resized = []

            for i in range(len(random_samples)):
                random_samples_resized.append(self.transform2(torch.moveaxis(torch.from_numpy(random_samples[i].astype(float)).cuda(), -1, 0)).float().detach().cpu().numpy())
            for i in range(len(random_samples_single)):
                random_samples_single_resized.append(self.transform2(torch.moveaxis(torch.from_numpy(random_samples_single[i].astype(float)).cuda(), -1, 0)).float().detach().cpu().numpy())
            
            image = self.transform2(torch.moveaxis(torch.from_numpy(image.astype(float)).cuda(), -1 , 0)).float().detach().cpu().numpy()
            
            random_samples_single_resized = np.stack(random_samples_single_resized)
            random_samples_resized = np.stack(random_samples_resized)
            filtered_images = np.concatenate((random_samples_resized, random_samples_single_resized, np.expand_dims(image, 0)), axis=0)       

        return filtered_images

    
    def get_projections(self, test_images, test_attributes, FeatureExtractor, FeatureGenerator, test_images_multi=None):
        """Segment every test image, embed it and pickle the results.

        For each image a batch is built with ``get_segmentation_packets``,
        features are extracted and thresholded via the extractor, and the
        generator turns the last batch entry into packets. When packets are
        found, the DINO CLS token of the enlarged image, the extractor's CLS
        output for that image and the VAE encoding of the resized segments are
        recorded. Images without packets are skipped.

        Results are written through ``save_embeddings`` every 500 images and
        once more after the last image. The saved attributes always contain
        exactly the rows of the images recorded so far, in the same order.

        Args:
            test_images: Image array indexed ``[n, h, w, c]``; only the first
                three channels are used.
            test_attributes: List of seven per-image sequences (color, shape,
                size, material, pixel_coords, 3d_coords, num_objects), each
                aligned with ``test_images``. If any image is skipped, the
                entries of this list are replaced in place by their filtered
                versions once processing finishes.
            FeatureExtractor: Class constructed with the DINO model; must
                provide ``get_features``, ``get_features_single`` and
                ``perform_pca``.
            FeatureGenerator: Class constructed without arguments; must
                provide ``get_packets``, ``get_packets_single`` and
                ``get_segmentation``.
            test_images_multi: Optional multi-object images; when given, the
                ``*_single`` extractor/generator methods are used.
        """
        feature_extractor = FeatureExtractor(self.dinov2_vitl14)
        feature_generator = FeatureGenerator()

        X_test = test_images[:, :, :, :3]

        print(X_test.shape)

        num_images = 50
        cell_size = 14
        save_every = 500

        full_images = []
        # NOTE: large_images, square_regions and features are never filled in
        # this method; they are saved as empty lists.
        large_images = []
        square_regions = []
        dino_full_atts = []
        dino_projection_atts = []
        projections = []
        features = []

        # Positions (into X_test / test_attributes) of the images that were
        # segmented successfully. Selecting attribute rows through this list
        # keeps every attribute, including num_objects, aligned with
        # full_images without any index arithmetic on a shrinking array.
        kept_indices = []

        def select_rows(values, indices):
            if isinstance(values, np.ndarray):
                return values[np.asarray(indices, dtype=np.intp)]
            return [values[j] for j in indices]

        def checkpoint():
            kept_attributes = [select_rows(values, kept_indices) for values in test_attributes]
            self.save_embeddings(full_images, large_images, square_regions, dino_full_atts, dino_projection_atts, projections, features, kept_attributes)
            return kept_attributes

        for i, image in enumerate(tqdm(X_test, desc='Getting embeddings')):
            # Batch of random context images with the current image last.
            filtered_images = self.get_segmentation_packets(image, X_test, num_images, test_images_multi)

            if test_images_multi is not None:
                total_features1, total_cls = feature_extractor.get_features_single(filtered_images, self.dino_feature_dim)
                thresholded_features = feature_extractor.perform_pca(total_features1, total_cls)
                packet_list, large_image = feature_generator.get_packets_single(filtered_images[-1], thresholded_features[-1], cell_size=cell_size)

            else:
                total_features1, total_cls = feature_extractor.get_features(filtered_images, self.dino_feature_dim)
                thresholded_features = feature_extractor.perform_pca(total_features1, total_cls)
                packet_list, large_image = feature_generator.get_packets(filtered_images[-1], thresholded_features[-1], cell_size=cell_size)

            if len(packet_list) > 0:
                _, image_segments, _, _ = feature_generator.get_segmentation(packet_list, cell_size=cell_size)

                full_images.append(image)
                kept_indices.append(i)

                # Inference only: no autograd graph is needed for either model.
                with torch.no_grad():
                    features_dict_image_t = self.dinov2_vitl14.forward_features(torch.moveaxis(torch.from_numpy(large_image), -1, 0).float().cuda().unsqueeze(0))
                dino_full_atts.append(features_dict_image_t['x_norm_clstoken'].detach().cpu())
                dino_projection_atts.append(total_cls[-1])

                print("Image segments: ", image_segments.shape)

                image_segments = image_segments.reshape(-1, cell_size, cell_size, 3)
                image_segment_vae = self.transform_out(torch.tensor(np.transpose(image_segments, (0, 3, 1, 2)))).cuda().float()
                with torch.no_grad():
                    encoded_segment = self.model15.encode(image_segment_vae)[0].detach().cpu().numpy()
                projections.append(encoded_segment)

                # Free up CUDA memory
                del image_segment_vae
                del total_cls
                del total_features1

                if i % 10 == 0:  # Call empty_cache less frequently
                    torch.cuda.empty_cache()
            else:
                # Nothing is recorded for this image; its attribute rows are
                # left out because its index is not added to kept_indices.
                print("Error in segmentation")

            # Periodic checkpoint
            if i % save_every == 0:
                print("SAVING EMBEDDINGS")
                checkpoint()

        # Final save so images processed after the last checkpoint are kept.
        print("SAVING EMBEDDINGS")
        kept_attributes = checkpoint()

        # Mirror the filtering onto the caller's list when something was skipped.
        if len(kept_indices) != len(X_test):
            for position, kept in enumerate(kept_attributes):
                test_attributes[position] = kept



    def save_embeddings(self, full_images, large_images, square_regions, dino_full_atts, dino_projection_atts, projections, features, test_attributes):
        full_image_data = {'full_images': full_images, 'large_images': large_images, 'square_regions': square_regions, 'dino_full_atts': dino_full_atts, 'dino_projection_atts': dino_projection_atts, 'projections': projections, 'features': features}
        full_image_metadata = {'color': test_attributes[0], 'shape': test_attributes[1], 'size': test_attributes[2], 'material': test_attributes[3], 'pixel_coords': test_attributes[4], '3d_coords': test_attributes[5], 'num_objects': test_attributes[6]}
        with open('pickled_embeddings/'+self.output_string+'_imagef_clevrtex_data_'+str(self.dataset_size)+'_384.pickle', 'wb') as handle:
            pickle.dump(full_image_data, handle, protocol=pickle.HIGHEST_PROTOCOL)

        with open('pickled_embeddings/'+self.output_string+'_imagef_clevrtex_metadata_'+str(self.dataset_size)+'_384.pickle', 'wb') as handle:
            pickle.dump(full_image_metadata, handle, protocol=pickle.HIGHEST_PROTOCOL)


    def get_single_projections(self, num_samples=1000):
        """Run the full pipeline for the single-object dataset.

        Loads the data, extracts the first-object attributes, keeps the test
        portion of the split, truncates everything to ``num_samples`` entries
        and hands the result to ``get_projections`` together with the first
        ``num_samples`` multi-object images.

        Args:
            num_samples: Maximum number of images (and attribute rows) to use.
        """
        print("LOADING DATA")
        images_single, images_multi, data_val = self.load_data_single()

        print("PREPROCESSING DATA")

        all_data_items = self.preprocess_data_attributes_single(data_val, images_single)

        print(all_data_items)

        print("SPLITTING DATA")
        held_out = self.get_test_data_splits(all_data_items)
        held_out_images, held_out_attributes = held_out[0], held_out[1:]

        print(len(held_out_images))

        def first_samples(values):
            # Convert to an array, then keep at most num_samples leading rows.
            return np.array(values)[:num_samples]

        test_images_single = first_samples(held_out_images)
        test_images_multi = first_samples(images_multi)

        # Order: color, shape, size, material, pixel_coords, 3d_coords, num_objects.
        test_image_attributes_single = [first_samples(column) for column in held_out_attributes]

        print("GETTING EMBEDDINGS")

        # Get embeddings
        self.get_projections(test_images_single, test_image_attributes_single, FeatureExtractorMulti, FeatureGeneratorMulti, test_images_multi)


    def get_multi_projections(self, num_samples=1000):
        """Run the full pipeline for the multi-object dataset.

        Loads the data, extracts per-object attributes, keeps the test portion
        of the split, truncates everything to ``num_samples`` entries and
        hands the result to ``get_projections``.

        Args:
            num_samples: Maximum number of images (and attribute rows) to use.
        """
        print("LOADING DATA")
        images_multi, data_val = self.load_data_multi()

        print("PREPROCESSING DATA")

        all_data_items = self.preprocess_data_attributes_multi(data_val, images_multi)

        print("SPLITTING DATA")
        test_images_multi, test_color_items_m, test_shape_items_m, test_size_items_m, test_material_items_m, test_pixelcoords_items_m, test_3dcoords_items_m, test_numobjects_items_m = self.get_test_data_splits(all_data_items)

        def per_image_lists(values):
            # Each image contributes a list with one entry per object, and the
            # lists can differ in length. Recent NumPy versions refuse to build
            # an array from such ragged input via np.array(...), so fill a 1-D
            # object array element by element instead.
            clipped = values[:num_samples]
            out = np.empty(len(clipped), dtype=object)
            for position, entry in enumerate(clipped):
                out[position] = entry
            return out

        test_images_multi = np.array(test_images_multi)[:num_samples]
        test_color_items_m = per_image_lists(test_color_items_m)
        test_shape_items_m = per_image_lists(test_shape_items_m)
        test_size_items_m = per_image_lists(test_size_items_m)
        test_material_items_m = per_image_lists(test_material_items_m)
        test_pixelcoords_items_m = per_image_lists(test_pixelcoords_items_m)
        test_3dcoords_items_m = per_image_lists(test_3dcoords_items_m)
        # One integer per image, so a regular array is fine here.
        test_numobjects_items_m = np.array(test_numobjects_items_m)[:num_samples]

        test_image_attributes_multi = [test_color_items_m, test_shape_items_m, test_size_items_m, test_material_items_m, test_pixelcoords_items_m, test_3dcoords_items_m, test_numobjects_items_m]

        print("GETTING EMBEDDINGS")

        # Get embeddings
        self.get_projections(test_images_multi, test_image_attributes_multi, FeatureExtractorMulti, FeatureGeneratorMulti)



# Define main function to get embeddings
if __name__ == '__main__':

    # --single / --multi select which dataset pipeline to run.
    parser = argparse.ArgumentParser()
    parser.add_argument('--single', action='store_true')
    parser.add_argument('--multi', action='store_true')
    args = parser.parse_args()

    dataset = EmbeddingDataSet()

    # output_string and dataset_size are read while the projections are being
    # saved, so they must be set BEFORE the pipeline runs; otherwise the files
    # are always written with the constructor defaults.
    if args.single:
        num_samples = 510
        dataset.output_string = 'single'
        dataset.dataset_size = num_samples
        dataset.get_single_projections(num_samples=num_samples)
    elif args.multi:
        num_samples = 5010
        dataset.output_string = 'multi'
        dataset.dataset_size = num_samples
        dataset.get_multi_projections(num_samples=num_samples)

    print("DONE")
