import os
import cv2
# from clip_retrieval.clip_client import ClipClient, Modality
import hashlib
from tqdm import tqdm
from torchvision import transforms
from PIL import Image

def compute_clip_score(directory):
    """
    Score every .jpg/.png file under ``directory`` and delete low scorers.

    Each image is sent to the retrieval client together with the name of its
    parent folder as the text query. The ``'score'`` of the first returned
    result is used as the image's score (0.0 when nothing is returned).
    Images scoring below 0.5 are deleted from disk.

    Args:
        directory: Root directory that is walked recursively.

    Returns:
        A tuple ``(file_paths, folder_names, clip_scores)`` of parallel lists
        describing the images that were kept.
    """
    # NOTE(review): ClipClient and Modality are not imported in the visible
    # part of this file (the clip_retrieval import at the top is commented
    # out), so this call raises NameError unless they are brought into scope.
    client = ClipClient(
        url="https://knn.laion.ai/knn-service",
        indice_name="laion5B-L-14",
        aesthetic_score=9,
        aesthetic_weight=0.5,
        modality=Modality.IMAGE,
        num_images=20,
    )

    file_paths, folder_names, clip_scores = [], [], []
    for root, _dirs, files in os.walk(directory):
        for name in files:
            if not name.endswith((".jpg", ".png")):
                continue
            path = os.path.join(root, name)
            # The parent folder name doubles as the text query.
            label = os.path.basename(os.path.dirname(path))
            file_paths.append(path)
            folder_names.append(label)
            hits = client.query(text=label, image=path)
            clip_scores.append(hits[0]['score'] if len(hits) > 0 else 0.0)

    # Decide what to drop before touching the disk, then delete from the back
    # so the remaining indices stay valid while the lists shrink.
    doomed = [idx for idx, value in enumerate(clip_scores) if value < 0.5]
    for idx in reversed(doomed):
        os.remove(file_paths[idx])
        del file_paths[idx]
        del folder_names[idx]
        del clip_scores[idx]

    return file_paths, folder_names, clip_scores


def visualize_images(directory):
    """
    Interactively browse (and optionally delete) images under ``directory``.

    All ``.jpg``/``.png`` files found by a recursive walk are shown one at a
    time in an OpenCV window named after the image's parent folder.

    Key bindings:
        d          delete the current image from disk and show the next one
        f / 83     next image (83 is the code this script treats as the
                   right-arrow key)
        s / 81     previous image (81 is treated as the left-arrow key)
        n          jump to the first image of the previous folder
        m          jump to the first image of the next folder
        q / ESC    quit

    Folder jumps follow the order in which folders were first encountered
    during the walk and wrap around at either end.

    Args:
        directory: Root directory that is walked recursively.

    Returns:
        None
    """
    file_paths = []
    folder_names = []
    for root, _dirs, files in os.walk(directory):
        for file in files:
            if file.endswith((".jpg", ".png")):
                file_paths.append(os.path.join(root, file))
                folder_names.append(os.path.basename(os.path.dirname(file_paths[-1])))

    if not file_paths:
        print("No images found in directory")
        return

    def show(index, close_existing):
        """Display image ``index`` in a window named after its folder."""
        if close_existing:
            cv2.destroyAllWindows()
        title = folder_names[index]
        cv2.imshow(title, cv2.imread(file_paths[index]))
        cv2.setWindowTitle(title, title)
        return title

    current_index = 0
    current_folder = show(current_index, close_existing=False)

    while True:
        key = cv2.waitKey(0)
        if key == ord('d'):
            os.remove(file_paths[current_index])
            del file_paths[current_index]
            del folder_names[current_index]
            if not file_paths:
                break
            # Deleting the last entry leaves the index one past the end.
            current_index = min(current_index, len(file_paths) - 1)
            current_folder = show(current_index, close_existing=True)
        elif key == ord('q') or key == 27:  # exit on q or ESC
            break
        elif key == ord('f') or key == 83:  # right arrow or f
            current_index = min(current_index + 1, len(file_paths) - 1)
            # Only reopen the window when the folder (= window name) changes.
            folder_changed = folder_names[current_index] != current_folder
            current_folder = show(current_index, close_existing=folder_changed)
        elif key == ord('s') or key == 81:  # left arrow or s
            current_index = max(current_index - 1, 0)
            folder_changed = folder_names[current_index] != current_folder
            current_folder = show(current_index, close_existing=folder_changed)
        elif key == ord('n') or key == ord('m'):  # previous / next folder
            # Ordered, de-duplicated folder list. It is rebuilt on every jump
            # because deletions can empty a folder. The position is looked up
            # in THIS list: indexing an unordered set with a per-file index
            # would select an arbitrary folder.
            folders = list(dict.fromkeys(folder_names))
            step = -1 if key == ord('n') else 1
            position = folders.index(current_folder)
            target = folders[(position + step) % len(folders)]
            current_index = folder_names.index(target)
            current_folder = show(current_index, close_existing=True)

    cv2.destroyAllWindows()

def remove_unopenable_files(file_paths, transform):
    """
    Delete image files that cannot be opened with PIL and transformed.

    Each path is opened with ``PIL.Image.open`` and passed through
    ``transform``. If either step raises, the file is deleted from disk and
    its entry is removed from ``file_paths``. Paths that are not regular files
    (missing, or a directory) are only removed from the list; nothing is
    deleted for them.

    Args:
        file_paths: A list of file paths to check. Modified in place so that
            only the paths that loaded successfully remain.
        transform: A callable applied to the opened PIL image. Any exception
            it raises marks the file as bad.

    Returns:
        None
    """
    remove_indices = []
    for i, file_path in enumerate(file_paths):
        if not os.path.isfile(file_path):
            print(f"Error loading image: File not found: {file_path}")
            remove_indices.append(i)
            continue
        try:
            # The context manager releases the file handle even when decoding
            # or the transform fails, so the later os.remove is not blocked by
            # a handle that is still open.
            with Image.open(file_path) as pil_image:
                transform(pil_image)
        except Exception as e:
            # Deliberately broad: any failure to load counts as "unopenable".
            print(f"Error loading image: {e}")
            remove_indices.append(i)

    # Delete only after the full scan, walking backwards so the pending
    # indices stay valid while the list shrinks.
    for i in reversed(remove_indices):
        if os.path.isfile(file_paths[i]):
            os.remove(file_paths[i])
        del file_paths[i]
        
def traverse_and_remove_unopenable_files(directory):
    """
    Traverse a directory and delete .jpg/.png files that fail to load.

    Every matching file is handed to ``remove_unopenable_files`` together with
    a pipeline that converts the image to RGB, turns it into a tensor and
    normalizes it with 3-channel statistics. Files for which any of these
    steps raises are deleted from disk.

    Args:
        directory: A string specifying the directory to traverse.

    Returns:
        None
    """
    def _to_rgb(image):
        # Normalize below is configured for exactly three channels. Without
        # this conversion a valid grayscale, palette or RGBA image would raise
        # inside the pipeline and be deleted as if it were corrupt.
        return image.convert('RGB')

    transform = transforms.Compose([
        _to_rgb,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    file_paths = [
        os.path.join(root, file)
        for root, _dirs, files in os.walk(directory)
        for file in files
        if file.endswith((".jpg", ".png"))
    ]
    remove_unopenable_files(file_paths, transform)

def traverse_and_resize_images(directory: str, size: tuple = (32, 32)) -> None:
    """
    Traverse a directory and resize every .jpg/.png image in place.

    Each image is opened, converted to RGB, resized with
    ``transforms.Resize(size)`` and saved back to the same path, overwriting
    the original file. Files that fail at any step are reported and skipped.

    Args:
        directory: A string specifying the directory to traverse.
        size: Target size, passed unchanged to ``transforms.Resize``.
            Defaults to (32, 32). NOTE(review): torchvision's Resize also
            accepts a single int, in which case it scales the smaller edge to
            that value instead of forcing an exact shape; confirm which is
            intended when passing an int.

    Returns:
        None
    """
    trans = transforms.Resize(size)
    file_paths = [
        os.path.join(root, file)
        for root, _dirs, files in os.walk(directory)
        for file in files
        if file.endswith((".jpg", ".png"))
    ]

    for file_path in tqdm(file_paths):
        try:
            # Close the source handle before writing back to the same path;
            # convert() returns an independent, fully loaded image.
            with Image.open(file_path) as source:
                img = source.convert('RGB')
            trans(img).save(file_path)
        except Exception as e:
            # Best effort: report the failure and continue with the rest.
            print(f"Error processing image {file_path}: {e}")



def remove_redundant_images(directory):
    """
    Remove byte-identical duplicate images in a directory tree.

    Files ending in .jpg, .jpeg or .png (case-insensitive) are compared by
    file size and MD5 digest of their contents. The first file seen with a
    given (size, digest) pair is kept; every later file with the same pair is
    deleted and its path is printed.

    Args:
        directory: A string specifying the directory to remove redundant
            images in.

    Returns:
        None
    """
    image_exts = ('.jpg', '.jpeg', '.png')
    # Collect the candidates first so the progress bar total matches the
    # number of files that are actually processed.
    image_paths = [
        os.path.join(root, name)
        for root, _, files in os.walk(directory)
        for name in files
        if name.lower().endswith(image_exts)
    ]

    # Every (size, digest) pair seen so far. Remembering only the latest
    # digest per size would miss a duplicate whenever a different file of the
    # same size appears in between.
    seen = set()
    for file_path in tqdm(image_paths):
        file_size = os.path.getsize(file_path)
        digest = hashlib.md5()
        with open(file_path, 'rb') as f:
            # Hash in 1 MiB chunks instead of loading the whole file.
            for chunk in iter(lambda: f.read(1 << 20), b''):
                digest.update(chunk)
        signature = (file_size, digest.hexdigest())

        if signature in seen:
            os.remove(file_path)
            print(f"Deleted file: {file_path}")
        else:
            seen.add(signature)

if __name__ == '__main__':
    # Root of the dataset the cleanup steps operate on.
    data_root = '/Checkpoint/user/data/tree_data/imagenet_images_256/'

    # Only the unopenable-file cleanup is active; the other steps are kept
    # here so they can be switched on by hand.
    # visualize_images(data_root)
    traverse_and_remove_unopenable_files(data_root)
    # traverse_and_resize_images(data_root, size=256)
    # remove_redundant_images(data_root)