import functools
import os

import numpy as np
import torch
import torch.nn as nn
from numpy.linalg import norm
from PIL import Image
from torchvision import models, transforms

from src.turtlegfx.eval.lines import get_normalized_images_from_code



@functools.lru_cache(maxsize=1)
def get_feature_extractor():
    """
    Get the feature extractor model. Use ResNet18 with ImageNet weights.

    The last child module of the torchvision ResNet18 is dropped so the
    network outputs pooled features instead of class scores. The model is
    put in eval mode and moved to CUDA when available, otherwise to the CPU.

    The result is memoised: building the network and loading its weights is
    expensive, and the comparison helpers in this module request the
    extractor on every call. Callers therefore receive the same shared model
    instance each time and should treat it as read-only.

    Returns:
        tuple: (model, device) where model is a torch.nn.Sequential already
            moved to device, and device is the torch.device it lives on.
    """
    backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    # Keep every child module except the last one.
    model = nn.Sequential(*list(backbone.children())[:-1])
    model.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return model.to(device), device


def get_image_transform():
    """
    Build the preprocessing pipeline applied to every image before inference.

    The steps run in order: resize to 256x256, center-crop to 224x224,
    convert to a tensor, then normalize each channel with a fixed mean and
    standard deviation.

    Returns:
        torchvision.transforms.Compose: The composed transformation.
    """
    # Per-channel statistics handed to Normalize (the values conventionally
    # used with ImageNet-pretrained torchvision models).
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        transforms.Resize((256, 256)),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    return transforms.Compose(steps)


def extract_features(image, model, device, transform):
    """
    Extract a flat feature vector from a single image.

    Args:
        image (PIL.Image.Image, str or os.PathLike): The image to extract
            features from, or the path to the image file.
        model (torch.nn.Module): The feature extractor to run.
        device (torch.device): The device to run the model on.
        transform (callable): The transformation applied to the RGB image;
            it must return a tensor for a single image (no batch dimension).

    Returns:
        numpy.ndarray: The model output for the image, flattened to 1-D.
    """
    if isinstance(image, (str, os.PathLike)):
        # Open inside a context manager so the file handle is released as
        # soon as convert() has produced an independent in-memory copy.
        with Image.open(image) as opened:
            image = opened.convert('RGB')
    else:
        # Already an image object.
        image = image.convert('RGB')

    # Add the batch dimension the model expects.
    input_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        features = model(input_tensor)
    return features.cpu().numpy().reshape(-1)


def _load_rgb_image(img, max_pixels):
    """Return img as an RGB image, or None when it has more than max_pixels pixels."""
    if isinstance(img, (str, os.PathLike)):
        # The context manager releases the file handle; the size is checked
        # before convert() so oversized files are never fully decoded.
        with Image.open(img) as opened:
            if opened.size[0] * opened.size[1] > max_pixels:
                return None
            return opened.convert('RGB')
    if img.size[0] * img.size[1] > max_pixels:
        return None
    return img.convert('RGB')


def extract_batch_features(batch_images, model, device, transform, batch_size=32,
                           max_pixels=2000 * 2000):
    """
    Extract features from a list of images using mini-batched inference.

    Images that cannot be loaded or transformed are reported on stdout and
    skipped; images with more than max_pixels pixels are skipped silently.
    Skipped images simply do not appear in the returned indices.

    Args:
        batch_images (list): List of PIL images or image paths (str or
            os.PathLike).
        model (torch.nn.Module): Pre-trained model.
        device (torch.device): Device to run the model on.
        transform (callable): Image transformation pipeline; must return
            tensors of identical shape so they can be stacked.
        batch_size (int): Size of mini-batches for processing.
        max_pixels (int): Largest width * height an image may have before
            it is skipped. Defaults to 2000 * 2000.

    Returns:
        tuple: (feature vectors array, list of successfully processed
            indices). Row k of the array belongs to the image at
            batch_images[indices[k]]. The array is empty when nothing was
            processed.
    """
    feature_vectors = []
    processed_indices = []

    for start in range(0, len(batch_images), batch_size):
        batch_tensors = []

        for offset, img in enumerate(batch_images[start:start + batch_size]):
            index = start + offset
            try:
                rgb = _load_rgb_image(img, max_pixels)
                if rgb is None:
                    # Image is too large.
                    continue
                img_tensor = transform(rgb)
            except Exception as e:
                # Deliberately broad: one unreadable image must not abort
                # the whole batch, and image loading can fail in many ways.
                print(f"Error processing image at index {index}: {e}")
                continue
            batch_tensors.append(img_tensor)
            processed_indices.append(index)

        if not batch_tensors:
            continue

        # Process mini-batch
        batch_tensor = torch.stack(batch_tensors).to(device)
        with torch.no_grad():
            batch_features = model(batch_tensor).cpu().numpy()
        # One flat feature row per image.
        feature_vectors.extend(batch_features.reshape(batch_features.shape[0], -1))

        # Drop references before the next mini-batch and release cached GPU memory.
        del batch_tensor, batch_features
        torch.cuda.empty_cache()

    return np.array(feature_vectors), processed_indices


def compare_embeddings_from_batch_images(batch_images1, batch_images2, metric="euclidean", batch_size=32):
    """
    Compare embeddings of two batches of images efficiently using ResNet18 features.

    Images are processed in mini-batches and a similarity score is computed
    for each pair of images that share the same position in the two
    batches. Both file paths and PIL Image objects are accepted.

    Args:
        batch_images1 (list): First batch of images. Each element can be either:
            - str: Path to an image file
            - PIL.Image.Image: PIL Image object
        batch_images2 (list): Second batch of images (same format as batch_images1)
        metric (str, optional): Similarity metric to use. Defaults to "euclidean".
            - "euclidean": 1 - distance / sqrt(feature_dim), clipped to [0, 1]
            - "cosine": Returns cosine similarity in range [-1, 1]
        batch_size (int, optional): Size of mini-batches for processing. Defaults to 32.

    Returns:
        numpy.ndarray: Object array with the same length as the input batches.
            The entry at index i is the float similarity between
            batch_images1[i] and batch_images2[i], or None when either image
            of the pair could not be processed.

    Examples:
        >>> # Compare two batches using file paths
        >>> paths1 = ['img1.jpg', 'img2.jpg', 'img3.jpg']
        >>> paths2 = ['ref1.jpg', 'ref2.jpg', 'ref3.jpg']
        >>> scores = compare_embeddings_from_batch_images(paths1, paths2)
        >>> print(scores)  # [0.85, 0.92, None]  # None indicates processing failure

        >>> # Compare using PIL Images
        >>> from PIL import Image
        >>> imgs1 = [Image.open(p) for p in paths1]
        >>> imgs2 = [Image.open(p) for p in paths2]
        >>> scores = compare_embeddings_from_batch_images(
        ...     imgs1, imgs2, metric='cosine', batch_size=2
        ... )
        >>> print(scores)  # [0.95, 0.88, 0.76]

    Raises:
        ValueError: If invalid metric is specified
        AssertionError: If input batches have different lengths
    """
    # Raised explicitly (rather than via `assert`) so the check also runs
    # under `python -O`, while keeping the documented exception type.
    if len(batch_images1) != len(batch_images2):
        raise AssertionError("Batches must have same length")

    # Validate up front: fail before the expensive model work, and also
    # when no image pair ends up being comparable.
    if metric not in ("cosine", "euclidean"):
        raise ValueError(f"Invalid metric: {metric}. Use 'cosine' or 'euclidean'.")

    # Initialize model and transform
    model, device = get_feature_extractor()
    transform = get_image_transform()

    # Extract features for both batches
    features1, indices1 = extract_batch_features(batch_images1, model, device, transform, batch_size)
    features2, indices2 = extract_batch_features(batch_images2, model, device, transform, batch_size)

    # Object array pre-filled with None; only comparable pairs are overwritten.
    similarities = np.full(len(batch_images1), None)

    # Map original image index -> row in the feature matrix (O(1) lookups).
    rows1 = {image_idx: row for row, image_idx in enumerate(indices1)}
    rows2 = {image_idx: row for row, image_idx in enumerate(indices2)}

    # Pairs for which both images were processed, in a deterministic order.
    common_indices = sorted(rows1.keys() & rows2.keys())
    if not common_indices:
        return similarities

    valid_features1 = features1[[rows1[i] for i in common_indices]]
    valid_features2 = features2[[rows2[i] for i in common_indices]]

    if metric == "cosine":
        # Normalize each row to unit length, then take row-wise dot products.
        norms1 = np.linalg.norm(valid_features1, axis=1, keepdims=True)
        norms2 = np.linalg.norm(valid_features2, axis=1, keepdims=True)
        valid_similarities = np.sum(
            (valid_features1 / norms1) * (valid_features2 / norms2), axis=1
        )
    else:  # "euclidean"
        distances = np.linalg.norm(valid_features1 - valid_features2, axis=1)

        # Scale distances by sqrt(feature_dim) and turn them into similarities.
        max_distance = np.sqrt(valid_features1.shape[1])
        valid_similarities = np.clip(1 - (distances / max_distance), 0.0, 1.0)

    # Write the computed values back at their original positions.
    for idx, sim in zip(common_indices, valid_similarities):
        similarities[idx] = float(sim)

    return similarities


def compare_embeddings_from_image(image1, image2, metric="euclidean"):
    """
    Compare the embeddings of two images.

    Args:
        image1 (PIL.Image.Image or str): The first image (or image path) to compare.
        image2 (PIL.Image.Image or str): The second image (or image path) to compare.
        metric (str): The similarity metric to use ("euclidean" or "cosine").
            - "euclidean": 1 - distance / sqrt(feature_dim), clamped to [0, 1]
            - "cosine": dot product of the unit-normalized feature vectors

    Returns:
        float: The similarity between the two images calculated using the metric.

    Raises:
        ValueError: If metric is neither "cosine" nor "euclidean".
    """
    # Validate before loading the model so a bad metric fails immediately.
    if metric not in ("cosine", "euclidean"):
        raise ValueError(f"Invalid metric: {metric}. Please use 'cosine' or 'euclidean'.")

    model, device = get_feature_extractor()
    transform = get_image_transform()

    # Extract features for both images
    features1 = extract_features(image1, model, device, transform)
    features2 = extract_features(image2, model, device, transform)

    if metric == "cosine":
        # Normalize feature vectors, then take their dot product.
        features1 = features1 / norm(features1)
        features2 = features2 / norm(features2)
        similarity = np.dot(features1, features2)
    else:  # "euclidean"
        distance = np.linalg.norm(features1 - features2)

        # Scale the distance by sqrt(feature_dim) and clamp into [0, 1].
        max_distance = np.sqrt(len(features1))
        similarity = 1 - (distance / max_distance)
        similarity = max(0.0, min(similarity, 1.0))

    # Always hand back a built-in float, whichever branch produced the value.
    return float(similarity)


def compare_embeddings_from_code(code1, code2, metric="euclidean"):
    """
    Score the similarity of the images produced from two turtle graphics programs.

    The two images are obtained from get_normalized_images_from_code and then
    compared with compare_embeddings_from_image.

    Args:
        code1 (str): The first turtle graphics code.
        code2 (str): The second turtle graphics code.
        metric (str): The similarity metric to use ("euclidean" or "cosine").

    Returns:
        float: The similarity score between the two images, or 0.0 when
            either image is None (treated as a failed case).
    """
    rendered = get_normalized_images_from_code(code1, code2)
    first_image, second_image = rendered

    # Guard clause: nothing to compare when an image is missing.
    if first_image is None or second_image is None:
        return 0.0

    return compare_embeddings_from_image(first_image, second_image, metric=metric)


def main(args):
    """
    Rank images by their similarity to a base image and print the ranking.

    Every file matching args.compare_files is compared with args.base_file;
    the results are printed from most to least similar.

    Args:
        args (argparse.Namespace): The command line arguments. Reads
            base_file (path of the reference image), compare_files (glob
            pattern of images to compare) and metric.
    """
    # Imported here rather than relying on names bound inside the
    # `__main__` guard, so main() also works when this module is imported.
    # They stay local because only this CLI entry point needs them.
    from glob import glob
    from tqdm import tqdm

    base_file = args.base_file
    image_files = glob(args.compare_files)

    scores = [
        (compare_embeddings_from_image(base_file, image_file, metric=args.metric), image_file)
        for image_file in tqdm(image_files, desc="Comparing images")
    ]

    # Highest similarity first.
    scores.sort(key=lambda pair: pair[0], reverse=True)

    base_name = os.path.basename(base_file)
    for score, image_file in scores:
        print(f"Similarity Score: ({base_name}, {os.path.basename(image_file)}, {score:.4f})")
    

if __name__ == "__main__":
    # Imported here so they are only required when the file is run as a script.
    import os
    import argparse
    from glob import glob
    from tqdm import tqdm

    # Command line interface: a reference image, a glob of candidates, and
    # the similarity metric to rank them with.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--base_file",
        type=str,
        default="src/turtlegfx/data/midi/image/midi_2c.png",
    )
    cli.add_argument(
        "--compare_files",
        type=str,
        default="src/turtlegfx/data/midi/image/midi_*.png",
    )
    cli.add_argument("--metric", type=str, default="euclidean")
    args = cli.parse_args()

    main(args)
