import numpy as np
from enum import Enum
from tqdm import tqdm
import argparse
import json
import os
import logging
from functools import partial
from collections import Counter
from concurrent.futures import ProcessPoolExecutor

from src.turtlegfx.utils.base64img import convert_base64_to_img, save_base64_image
from src.turtlegfx_datagen.utils.img_utils import get_image_filename_from_id

import torch
import torch.nn as nn
from torchvision import models, transforms
from sklearn.cluster import DBSCAN
import torch.nn.parallel

# Configure the root logger at import time: INFO level, with each record
# formatted as "<timestamp> - <level> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)

class ClusterAlgorithm(Enum):
    """Name of the clustering algorithm recorded in the deduplication metadata."""
    # Stored under "cluster_algorithm" -> "name" by process_item. cluster_images
    # instantiates DBSCAN directly and does not read this value.
    ALGORITHM = "DBSCAN"

class ImagePreprocessing(Enum):
    """Preprocessing constants used by get_image_transform and copied into the metadata.

    Always read members through ``.value``. Enum members that share a value are
    aliases of the first member defined with it, so with the current numbers
    ``RESIZE_WIDTH`` is an alias of ``RESIZE_HEIGHT`` and ``CROP_WIDTH`` is an
    alias of ``CROP_HEIGHT``; ``.value`` lookups are unaffected by this.
    """
    # Target size handed to transforms.Resize as (height, width).
    RESIZE_HEIGHT = 256
    RESIZE_WIDTH = 256
    # Target size handed to transforms.CenterCrop as (height, width).
    CROP_HEIGHT = 224
    CROP_WIDTH = 224
    # Per-channel mean / std handed to transforms.Normalize.
    # NOTE(review): presumably the usual ImageNet statistics matching the
    # pretrained backbone -- confirm.
    NORMALIZATION_MEAN = [0.485, 0.456, 0.406]
    NORMALIZATION_STD = [0.229, 0.224, 0.225]

class FeatureExtractor(Enum):
    """Name of the torchvision model used to embed images."""
    # Resolved with getattr(models, ...) in get_feature_extractor and recorded
    # under "feature_extraction" -> "model_name" by process_item.
    EXTRACTOR = "resnet18"

def get_feature_extractor():
    """Build the embedding model and pick the device it runs on.

    The model named by ``FeatureExtractor.EXTRACTOR`` is created with
    ``pretrained=True``, its last child module is dropped, and the remainder is
    put in eval mode. When more than one CUDA device is visible the model is
    wrapped in ``nn.DataParallel``.

    Returns:
        tuple: ``(model, device)`` where ``model`` has already been moved to
        ``device`` ('cuda' when available, otherwise 'cpu').
    """
    # NOTE(review): newer torchvision releases deprecate `pretrained=` in
    # favour of `weights=` -- confirm against the pinned torchvision version.
    build_backbone = getattr(models, FeatureExtractor.EXTRACTOR.value)
    backbone = build_backbone(pretrained=True)

    # Keep every child module except the last one (presumably the
    # classification head -- confirm if the extractor is ever changed).
    kept_layers = list(backbone.children())[:-1]
    model = nn.Sequential(*kept_layers)
    model.eval()

    gpu_count = torch.cuda.device_count()
    if gpu_count > 1:
        logger.info(f"Using {gpu_count} GPUs")
        model = nn.DataParallel(model)

    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_name)
    return model.to(device), device

def get_image_transform():
    """Return the torchvision transform pipeline applied to every image.

    Steps, in order: resize, center crop, conversion to a tensor, and
    per-channel normalization. All sizes and statistics come from
    ``ImagePreprocessing``.
    """
    prep = ImagePreprocessing
    resize_hw = (prep.RESIZE_HEIGHT.value, prep.RESIZE_WIDTH.value)
    crop_hw = (prep.CROP_HEIGHT.value, prep.CROP_WIDTH.value)

    steps = [
        transforms.Resize(resize_hw),
        transforms.CenterCrop(crop_hw),
        transforms.ToTensor(),
        transforms.Normalize(prep.NORMALIZATION_MEAN.value, prep.NORMALIZATION_STD.value),
    ]
    return transforms.Compose(steps)

def extract_features(data, model, device, transform, batch_size):
    """Embed every decodable image in ``data`` with ``model``, batch by batch.

    Items that cannot be processed are skipped on a best-effort basis: any
    exception raised while decoding or transforming an item is reported and the
    item is dropped, as is any image with more than 2000 * 2000 pixels.

    Args:
        data (list): Items holding a base64 image under ``'task_image'`` and an
            identifier under ``'id'`` (the id is only used in messages).
        model: Callable mapping a stacked image batch to a feature tensor.
        device: Device the batch tensor is moved to before calling ``model``.
        transform: Callable turning a decoded image into a tensor.
        batch_size (int): Number of items per batch; must be at least 1.

    Returns:
        tuple: ``(features, processed_items)``. ``features`` is a 2-D array with
        one flattened feature row per entry of ``processed_items`` (same
        order); it has shape ``(0, 0)`` when nothing could be processed.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        # A negative step would make range() empty and silently return nothing.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    max_pixels = 2000 * 2000  # Adjust the limit as needed
    feature_vectors = []
    processed_items = []  # Items whose features were actually extracted

    for start in tqdm(range(0, len(data), batch_size), desc="Extracting features"):
        batch_images = []
        batch_items = []
        for item in data[start:start + batch_size]:
            # Resolve the id up front without indexing: a missing 'id' must not
            # raise from inside the except handler and abort the whole run.
            item_id = item.get('id', '<unknown>') if isinstance(item, dict) else '<unknown>'
            try:
                img = convert_base64_to_img(item['task_image'])
                # Optional: Skip images that are too large
                if img.size[0] * img.size[1] > max_pixels:
                    # tqdm.write prints to stdout without corrupting the bar.
                    tqdm.write(f"Skipping image {item_id} due to large size.")
                    continue
                batch_images.append(transform(img))
                batch_items.append(item)
            except Exception as e:
                # Deliberate best-effort: report the item and move on.
                tqdm.write(f"Error processing image {item_id}: {e}")
                continue

        if not batch_images:
            continue  # Every item of this batch was skipped

        batch_tensor = torch.stack(batch_images).to(device)
        with torch.no_grad():
            batch_features = model(batch_tensor).cpu().numpy()
        # Flatten each sample's output to a single feature row.
        feature_vectors.extend(batch_features.reshape(batch_features.shape[0], -1))
        processed_items.extend(batch_items)

        # Clean up to free memory before the next batch
        del batch_images, batch_tensor, batch_features
        torch.cuda.empty_cache()

    if not feature_vectors:
        # Keep the result 2-D even when nothing was processed.
        return np.empty((0, 0), dtype=np.float32), processed_items
    return np.array(feature_vectors), processed_items

def cluster_images(feature_vectors, eps, min_samples):
    """Cluster L2-normalized feature rows with DBSCAN (euclidean metric).

    Args:
        feature_vectors: 2-D array-like with one feature row per image.
        eps (float): DBSCAN ``eps`` parameter.
        min_samples (int): DBSCAN ``min_samples`` parameter.

    Returns:
        numpy.ndarray: One integer label per row as produced by DBSCAN
        (``labels_``). An empty integer array is returned for empty input.
    """
    features = np.asarray(feature_vectors)
    if features.size == 0:
        # Nothing to cluster; DBSCAN itself rejects an input with no samples.
        return np.empty(0, dtype=np.intp)

    norms = np.linalg.norm(features, axis=1, keepdims=True)
    # An all-zero row would divide to NaN, which DBSCAN rejects; dividing by 1
    # instead leaves such a row at the origin. Other rows are unaffected.
    norms[norms == 0] = 1
    normalized_features = features / norms

    clustering = DBSCAN(eps=eps, min_samples=min_samples, metric='euclidean')
    return clustering.fit(normalized_features).labels_

def remove_duplicates(data):
    """Keep one representative per cluster plus every outlier.

    For each cluster the item whose ``'code'`` has the fewest lines is kept; on
    a tie the item seen first wins. Items whose cluster id is ``'outlier'`` are
    always kept. A short summary is printed to stdout.

    Args:
        data (list): Items carrying ``item['deduplication']['cluster_id']``;
            clustered (non-outlier) items must also carry ``item['code']``.

    Returns:
        list: Cluster representatives (in order of first appearance of their
        cluster) followed by the outliers (in input order).
    """
    best_per_cluster = {}  # cluster_id -> (item, number of code lines)
    outlier_items = []
    for item in data:
        cluster_id = item['deduplication']['cluster_id']
        if cluster_id == 'outlier':
            outlier_items.append(item)
            continue

        code_length = len(item['code'].splitlines())
        best = best_per_cluster.get(cluster_id)
        # Strict '<' keeps the earliest item when line counts tie.
        if best is None or code_length < best[1]:
            best_per_cluster[cluster_id] = (item, code_length)

    items_to_keep = [item for item, _ in best_per_cluster.values()] + outlier_items

    total = len(data)
    kept = len(items_to_keep)
    print("Duplicates Statistics: ")
    print(f"\tTotal items: {total}")
    print(f"\tItems to keep: {kept}")
    print(f"\tNumber of duplicates removed: {total - kept}")
    if total:
        print(f"\tNon-duplicate rate: {kept / total:.2%}")
    else:
        # The rate is undefined for empty input; avoid dividing by zero.
        print("\tNon-duplicate rate: n/a (no items)")
    return items_to_keep

def process_item(item, label, cluster_sizes, cluster_total, image_total, args):
    """Attach a ``'deduplication'`` metadata record to ``item`` and return it.

    ``item`` is modified in place. A label of ``-1`` is recorded as the
    ``"outlier"`` cluster with size 1; any other label becomes
    ``"cls-<label+1, zero-padded to 6 digits>"`` with its size looked up in
    ``cluster_sizes``.

    Args:
        item (dict): Record to annotate.
        label: Cluster label assigned to the item.
        cluster_sizes (dict): Maps each non-outlier label to its member count.
        cluster_total (int): Number of clusters, stored as-is.
        image_total (int): Number of images, stored as-is.
        args (dict): Must provide ``'eps'``, ``'min_samples'`` and ``'batch_size'``.

    Returns:
        dict: The same ``item`` object.
    """
    if label == -1:
        cluster_id, cluster_size = "outlier", 1
    else:
        cluster_id = f"cls-{label+1:06d}"
        cluster_size = cluster_sizes[label]

    algorithm_info = {
        "name": ClusterAlgorithm.ALGORITHM.value,
        "parameters": {
            "eps": args['eps'],
            "min_samples": args['min_samples'],
            "metric": "euclidean",
        },
    }

    extraction_info = {
        "model_name": FeatureExtractor.EXTRACTOR.value,
        # Hard-coded here; not derived from the model at runtime.
        "model_output_dim": 512,
        "batch_size": args['batch_size'],
    }

    prep = ImagePreprocessing
    preprocessing_info = {
        "resize": {
            "height": prep.RESIZE_HEIGHT.value,
            "width": prep.RESIZE_WIDTH.value,
        },
        "center_crop": {
            "height": prep.CROP_HEIGHT.value,
            "width": prep.CROP_WIDTH.value,
        },
        "normalization": {
            "mean": prep.NORMALIZATION_MEAN.value,
            "std": prep.NORMALIZATION_STD.value,
        },
    }

    item["deduplication"] = {
        "cluster_id": cluster_id,
        "cluster_size": cluster_size,
        "cluster_total": cluster_total,
        "image_total": image_total,
        "cluster_algorithm": algorithm_info,
        "feature_extraction": extraction_info,
        "image_preprocessing": preprocessing_info,
    }
    return item


def process_data(data, labels, args):
    """Annotate every item with its cluster metadata and optionally deduplicate.

    Each item is passed through ``process_item`` together with its label. When
    ``args.remove_duplicates`` is set, ``remove_duplicates`` is applied and a
    statistics file is written next to ``args.output_path``: ``_stats`` is
    inserted before the file extension, and ``.json`` is used when the path
    has no extension.

    Args:
        data (list): Items to annotate, aligned index-by-index with ``labels``.
        labels: Iterable of cluster labels, where ``-1`` marks an outlier.
        args: Object exposing ``eps``, ``min_samples``, ``batch_size`` and
            ``remove_duplicates``; ``output_path`` is read only when
            ``remove_duplicates`` is true.

    Returns:
        list: All annotated items, or only the retained ones when
        ``args.remove_duplicates`` is set.
    """
    # Member count per real cluster; the outlier label (-1) is excluded.
    label_counts = Counter(labels)
    cluster_sizes = {label: count for label, count in label_counts.items() if label != -1}

    cluster_total, image_total = len(cluster_sizes), len(data)

    args_dict = {
        'eps': args.eps,
        'min_samples': args.min_samples,
        'batch_size': args.batch_size,
    }

    # Process data sequentially
    data = [
        process_item(item, label, cluster_sizes, cluster_total, image_total, args_dict)
        for item, label in tqdm(zip(data, labels), total=len(data), desc="Processing data")
    ]

    if not args.remove_duplicates:
        return data

    n_total = len(data)
    # With no items there is nothing to deduplicate, and the rates are undefined.
    data_to_keep = remove_duplicates(data) if n_total else []
    n_kept = len(data_to_keep)
    retention_rate = n_kept / n_total if n_total else None
    duplicate_rate = 1 - retention_rate if n_total else None

    # Derive the stats path from the extension only. str.replace('.json', ...)
    # would also rewrite directory names containing '.json', and would return
    # the output path itself when it has no '.json' at all.
    root, ext = os.path.splitext(args.output_path)
    stats_path = f"{root}_stats{ext or '.json'}"

    # The output directory may not exist yet when this function runs.
    stats_dir = os.path.dirname(stats_path)
    if stats_dir:
        os.makedirs(stats_dir, exist_ok=True)

    # Save statistics to a json file
    with open(stats_path, 'w') as file:
        json.dump({
            "n_samples_total": n_total,
            "n_samples_removed": n_total - n_kept,
            "n_samples_remaining": n_kept,
            "retention_rate": retention_rate,
            "duplicate_rate": duplicate_rate,
        }, file, indent=2)

    return data_to_keep


def save_dup_images(processed_data, output_path):
    """
    Write the images of multi-member clusters into a ``<name>_duplicated`` folder.

    The folder is created beside ``output_path`` and named after its file name
    without extension. Each saved file name is prefixed with the cluster id.

    Args:
        processed_data (list): Items carrying ``'deduplication'`` metadata,
            ``'id'`` and the base64 image under ``'task_image'``.
        output_path (str): Path to the output file.
    """
    print(f"Saving duplicated images to {output_path}")

    parent_dir, file_name = os.path.split(output_path)
    stem = os.path.splitext(file_name)[0]
    dup_folder = os.path.join(parent_dir, f"{stem}_duplicated")
    os.makedirs(dup_folder, exist_ok=True)

    for entry in processed_data:
        dedup_info = entry['deduplication']
        cluster_id = dedup_info['cluster_id']
        cluster_size = dedup_info['cluster_size']

        # Only members of real clusters holding at least two images are saved.
        if cluster_id == 'outlier' or cluster_size <= 1:
            continue

        image_name = get_image_filename_from_id(entry['id'])
        target_path = os.path.join(dup_folder, f"{cluster_id}_{image_name}")
        save_base64_image(entry['task_image'], target_path)

def main(args):
    """Run the full pipeline: load, embed, cluster, annotate and save.

    Args:
        args: Parsed command-line namespace providing ``input_path``,
            ``output_path``, ``eps``, ``min_samples``, ``batch_size``,
            ``remove_duplicates`` and ``save_dup_images``.
    """
    logger.info(f"Loading {args.input_path}")

    with open(args.input_path, 'r') as f:
        data = json.load(f)

    logger.info(f"Loaded {len(data)} samples from {args.input_path}")

    # Create the output directory before any processing: process_data may
    # write its statistics file there. A bare file name has an empty dirname,
    # which os.makedirs rejects, so only create it when there is one.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Extract features --> clustering
    model, device = get_feature_extractor()
    transform = get_image_transform()

    logger.info("Extracting features...")
    feature_vectors, filtered_data = extract_features(data, model, device, transform, args.batch_size)

    logger.info("Clustering images...")
    labels = cluster_images(feature_vectors, args.eps, args.min_samples)

    logger.info("Processing data...")
    processed_data = process_data(filtered_data, labels, args)

    logger.info(f"Saving data to {args.output_path}")
    with open(args.output_path, 'w') as f:
        json.dump(processed_data, f, indent=2)

    # Save the duplicated images if the flag is set
    if args.save_dup_images:
        logger.info(f"Saving duplicated images to {args.output_path}")
        save_dup_images(processed_data, args.output_path)

# Command-line entry point: build the argument parser and hand the parsed
# namespace to main().
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Image clustering and deduplication based on similarity.",
    )

    # Input / output locations (both mandatory).
    parser.add_argument('--input_path', type=str, required=True, help='')
    parser.add_argument('--output_path', type=str, required=True, help='')

    # Clustering and feature-extraction parameters.
    parser.add_argument(
        '--eps',
        type=float,
        default=0.2,
        help='DBSCAN epsilon parameter',
    )
    parser.add_argument(
        '--min_samples',
        type=int,
        default=2,
        help='DBSCAN min_samples parameter',
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        default=32,
        help='Batch size for feature extraction',
    )

    # Optional behaviours.
    parser.add_argument(
        '--remove_duplicates',
        action='store_true',
        help='Whether to remove duplicates from the dataset',
    )
    parser.add_argument(
        '--save_dup_images',
        action='store_true',
        help='Whether to save the duplicated images in a `duplicated` folder',
    )

    # Accepted on the command line, but not read by any code visible in this file.
    parser.add_argument(
        '--num_workers',
        type=int,
        default=32,
        help='Number of worker processes to use for multiprocessing',
    )

    main(parser.parse_args())
