import json
import os
from tqdm import tqdm
from copy import deepcopy
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import entropy
from collections import Counter
import heapq
import math
from collections import defaultdict
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_distances
from transformers import CLIPProcessor, CLIPModel
import torch
from tqdm import tqdm
import faiss
import time
from argparse import ArgumentParser




def plot_data(selected_data, all_data):
    """
    Scatter-plot 'margin' (x) against 'normalized_llm_score' (y).

    Every entry of ``all_data`` is drawn as a semi-transparent gray point and
    every entry of ``selected_data`` is drawn on top in red. The figure is
    displayed with ``plt.show()``; nothing is returned.
    """
    def xy_columns(rows):
        # Split a list of records into the x and y series used by the plot.
        xs = [row['margin'] for row in rows]
        ys = [row['normalized_llm_score'] for row in rows]
        return xs, ys

    # Background layer first so the selected points are painted over it.
    layers = (
        (all_data, {'color': 'gray', 'label': 'All Data', 'alpha': 0.5}),
        (selected_data, {'color': 'red', 'label': 'Selected Data'}),
    )
    for rows, scatter_kwargs in layers:
        xs, ys = xy_columns(rows)
        plt.scatter(xs, ys, **scatter_kwargs)

    # Axis labels, title and legend
    plt.xlabel('Margin')
    plt.ylabel('Normalized LLM Score')
    plt.title('Margin vs Normalized LLM Score')
    plt.legend()

    plt.show()


def plot_caption_distribution(selected_data):
    """Draw a horizontal bar chart of how often each caption occurs in the selected data."""

    # Tally occurrences of every caption; insertion order follows first appearance.
    tally = Counter(entry['caption'] for entry in selected_data)
    captions = list(tally)
    frequencies = [tally[caption] for caption in captions]

    # One horizontal bar per distinct caption
    plt.figure(figsize=(10, 6))
    plt.barh(captions, frequencies, color='skyblue')
    plt.xlabel('Frequency')
    plt.ylabel('Captions')
    plt.title('Distribution of Captions in Selected Data')
    plt.tight_layout()

    plt.show()

def lazy_greedy_subset_selection(pick_data, K, alpha, beta):
    """
    Greedily select up to K entries using a lazily re-evaluated priority queue.

    The gain of adding entry ``i`` to the current selection is

        normalized_margin[i] + alpha * normalized_llm_score[i]
        + beta * (entropy(selection + [i]) - entropy(selection))

    where the entropy is that of the caption distribution of the selection
    (the entropy term is 0 while the selection is still empty).

    Args:
        pick_data: list of dicts with keys 'normalized_margin',
            'normalized_llm_score' and 'caption'.
        K: number of entries to select. Values larger than ``len(pick_data)``
            are clamped; values <= 0 select nothing.
        alpha: weight of the LLM score.
        beta: weight of the caption-entropy gain.

    Returns:
        The selected entries of ``pick_data`` in the order they were picked.
    """
    N = len(pick_data)

    # Asking for more items than exist would otherwise pop from an empty heap.
    K = max(0, min(K, N))
    if K == 0:
        return []

    margins = np.array([d['normalized_margin'] for d in pick_data])
    llm_scores = np.array([d['normalized_llm_score'] for d in pick_data])
    captions = [d['caption'] for d in pick_data]
    # The set of distinct captions never changes, so build it once instead of
    # on every entropy evaluation.
    all_captions = set(captions)

    selected_indices = []
    # Entropy of the current selection; refreshed only when the selection grows.
    current_entropy = 0.0

    def calculate_entropy(indices):
        counts = Counter(captions[i] for i in indices)
        probs = [counts.get(caption, 0) / len(indices) for caption in all_captions]
        return entropy(probs)

    def calculate_gain(i):
        if selected_indices:
            entropy_gain = calculate_entropy(selected_indices + [i]) - current_entropy
        else:
            entropy_gain = 0
        # Negated because heapq is a min-heap and we want the largest gain first.
        return -(margins[i] + alpha * llm_scores[i] + beta * entropy_gain)

    # Heap entries are (negated gain, index, stamp); the stamp is the selection
    # size at the time the gain was computed. Every index has exactly one entry.
    pq = [(calculate_gain(i), i, 0) for i in range(N)]
    heapq.heapify(pq)

    for _ in tqdm(range(K)):
        while True:
            _, index, stamp = heapq.heappop(pq)
            if stamp == len(selected_indices):
                # The gain is up to date for the current selection and it is
                # the best in the heap: take it.
                selected_indices.append(index)
                current_entropy = calculate_entropy(selected_indices)
                break
            # Stale gain: recompute against the current selection and retry.
            heapq.heappush(pq, (calculate_gain(index), index, len(selected_indices)))

    return [pick_data[i] for i in selected_indices]




def calculate_duplication_scores(n, max_rank=50, floor=-15, scale=20):
    """
    Return the duplication penalty for the pair ranked ``n`` within its caption.

    The penalty is ``1 - exp((n - 1) / scale)``: exactly 0 for rank 1 and
    increasingly negative for lower-ranked (higher ``n``) pairs. Ranks above
    ``max_rank`` all receive the constant ``floor``, which also keeps
    ``math.exp`` from overflowing for very large ranks.

    Args:
        n: 1-based rank of the pair inside its caption group.
        max_rank: largest rank for which the exponential formula is used.
        floor: value returned for every rank greater than ``max_rank``.
        scale: divisor applied to ``n - 1`` before exponentiation; larger
            values make the penalty grow more slowly.

    Returns:
        The penalty (<= 0 for n >= 1 with the default parameters).
    """
    if n > max_rank:
        return floor

    return 1 - math.exp((n - 1) / scale)

def select_top_k_pairs_w_decreasing(pick_data, K, alpha, beta):
    """
    Select the K best pairs, penalizing repeated use of the same caption.

    Pairs are grouped by caption and ranked inside each group by
    ``normalized_margin + alpha * normalized_llm_score`` (best first). Each
    pair's final score is that base score plus
    ``beta * calculate_duplication_scores(rank)``, so lower-ranked pairs of a
    caption are pushed down. The K highest final scores over all captions win.

    Args:
        pick_data: list of dicts with keys 'caption', 'normalized_margin' and
            'normalized_llm_score'.
        K: maximum number of pairs to return; values <= 0 return nothing.
        alpha: weight of the LLM score.
        beta: weight of the rank-based duplication penalty.

    Returns:
        Up to K entries of ``pick_data`` ordered by decreasing final score.
    """
    def base_score(entry):
        return entry['normalized_margin'] + alpha * entry['normalized_llm_score']

    # Group data by captions
    caption_groups = defaultdict(list)
    for entry in pick_data:
        caption_groups[entry['caption']].append(entry)

    # Rank every group by base score, then attach the penalized final score.
    scored_pairs = []
    for pairs in caption_groups.values():
        pairs.sort(key=base_score, reverse=True)
        # Rank starts at 1 (the best pair of a caption gets no penalty).
        for rank, pair in enumerate(pairs, start=1):
            final_score = base_score(pair) + beta * calculate_duplication_scores(rank)
            scored_pairs.append((final_score, pair))

    # Sort all pairs by their final score in descending order
    scored_pairs.sort(key=lambda scored: scored[0], reverse=True)

    # A negative K must not turn into a "drop the last |K|" slice.
    top_k_pairs = [pair for _, pair in scored_pairs[:max(K, 0)]]

    # Report what was actually selected, which can be fewer than K.
    print(f"Selected {len(top_k_pairs)} pairs with alpha={alpha}, beta={beta}")

    return top_k_pairs



def _select_with_caption_cap(pick_data, ranked_indices, K, T):
    """Pick up to K indices in rank order with at most T picks per caption.

    When the cap prevents reaching K, the cap is doubled and the scan is
    repeated over the not-yet-picked items. Returns ``(indices, final_cap)``.
    """
    if T < 1:
        # A non-positive cap can never be satisfied and doubling it never helps.
        raise ValueError("T must be at least 1")

    target = min(K, len(ranked_indices))
    selected_indices = []
    taken = set()
    caption_count = defaultdict(int)
    current_T = T

    while len(selected_indices) < target:
        for idx in ranked_indices:
            if len(selected_indices) >= target:
                break
            if idx in taken:
                # Already part of the result; never add the same item twice.
                continue
            caption = pick_data[idx]['caption']
            if caption_count[caption] < current_T:
                selected_indices.append(idx)
                taken.add(idx)
                caption_count[caption] += 1

        # One full scan takes every item allowed under the current cap, so an
        # unfinished selection means the remaining items are blocked by it.
        if len(selected_indices) < target:
            current_T *= 2
            print(f"Doubling T to {current_T}")

    return selected_indices, current_T


def select_optimal_subset_w_embedding(pick_data, K, alpha, beta, T=6, clip_model_name="openai/clip-vit-base-patch32", batch_size=5096, k_neighbors=5):
    """
    Select up to K entries by margin, LLM score and caption distinctiveness.

    Unique captions are embedded with the CLIP text encoder, L2-normalized and
    indexed in a FAISS inner-product index. Each caption gets a "duplication
    score" equal to the mean cosine distance to its ``k_neighbors`` nearest
    other captions, so a higher value means the caption is farther from its
    neighbours. Every entry is then scored as

        normalized_margin + alpha * normalized_llm_score + beta * duplication_score

    and entries are taken in decreasing score order. When ``beta > 0`` at most
    ``T`` entries per caption are taken; if that cap prevents reaching K it is
    doubled until K entries are found. When ``beta <= 0`` the cap is ignored.

    Args:
        pick_data: list of dicts with keys 'caption', 'normalized_margin' and
            'normalized_llm_score'.
        K: number of entries to select (clamped to ``len(pick_data)``).
        alpha: weight of the LLM score.
        beta: weight of the duplication score; also switches the per-caption cap.
        T: initial per-caption cap (must be >= 1 when ``beta > 0``).
        clip_model_name: Hugging Face identifier of the CLIP checkpoint.
        batch_size: number of captions embedded per forward pass.
        k_neighbors: number of nearest neighbours used for the duplication score.

    Returns:
        The selected entries of ``pick_data`` in selection order, without
        repeats. Empty when ``pick_data`` is empty or ``K <= 0``.

    Raises:
        ValueError: if ``beta > 0`` and ``T < 1``.
    """
    # Nothing can be selected: skip model loading and embedding entirely.
    if not pick_data or K <= 0:
        return []

    start_time = time.time()

    # Load CLIP model and processor
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CLIPModel.from_pretrained(clip_model_name).to(device)
    processor = CLIPProcessor.from_pretrained(clip_model_name)
    print(f"Model loading time: {time.time() - start_time:.2f} seconds")

    # Extract unique captions; row i of the embedding matrix is unique_captions[i]
    t0 = time.time()
    unique_captions = list(set(item['caption'] for item in pick_data))
    print(f"Data preparation time: {time.time() - t0:.2f} seconds")

    # Embed unique captions using CLIP
    t0 = time.time()
    embeddings = []
    for i in tqdm(range(0, len(unique_captions), batch_size), desc="Embedding captions"):
        batch = unique_captions[i:i+batch_size]
        inputs = processor(text=batch, return_tensors="pt", padding=True, truncation=True).to(device)
        with torch.no_grad():
            outputs = model.get_text_features(**inputs)
        embeddings.append(outputs.cpu().numpy())
    embeddings = np.vstack(embeddings).astype(np.float32)  # Ensure float32 for FAISS
    print(f"CLIP embedding time: {time.time() - t0:.2f} seconds")

    # Normalize embeddings for cosine similarity
    faiss.normalize_L2(embeddings)

    # Create FAISS index for fast cosine similarity search
    t0 = time.time()
    d = embeddings.shape[1]  # dimensionality of vectors
    index = faiss.IndexFlatIP(d)  # Inner Product is equivalent to cosine similarity for normalized vectors
    # Only move the index to the GPU when one is usable; otherwise stay on CPU,
    # mirroring the device fallback used for the CLIP model above.
    if torch.cuda.is_available() and hasattr(faiss, "StandardGpuResources"):
        gpu_resources = faiss.StandardGpuResources()
        index = faiss.index_cpu_to_gpu(gpu_resources, 0, index)
    index.add(embeddings)
    print(f"FAISS index creation time: {time.time() - t0:.2f} seconds")

    # Calculate duplication scores using FAISS (cosine similarity)
    t0 = time.time()
    # +1 because the nearest result is expected to be the query itself; never
    # request more results than there are vectors in the index.
    n_results = min(k_neighbors + 1, len(unique_captions))
    if n_results > 1:
        # One batched search instead of one search call per caption.
        similarities, _ = index.search(embeddings, n_results)
        # NOTE: column 0 is assumed to be the self-match and is dropped.
        # Convert similarities to cosine distances: distance = 1 - similarity
        mean_distances = np.mean(1 - similarities[:, 1:], axis=1)
    else:
        # No neighbours available; use a constant so the score stays finite.
        mean_distances = np.zeros(len(unique_captions), dtype=np.float32)
    duplication_scores = {
        caption: float(distance)
        for caption, distance in zip(unique_captions, mean_distances)
    }
    print(f"Duplication score calculation time: {time.time() - t0:.2f} seconds")

    # Calculate scores for each item
    t0 = time.time()
    scores = [
        (
            idx,
            item['normalized_margin']
            + alpha * item['normalized_llm_score']
            + beta * duplication_scores[item['caption']],
        )
        for idx, item in enumerate(pick_data)
    ]
    print(f"Score calculation time: {time.time() - t0:.2f} seconds")

    t0 = time.time()
    sorted_scores = sorted(scores, key=lambda x: x[1], reverse=True)
    ranked_indices = [idx for idx, _ in sorted_scores]
    current_T = T

    if beta > 0:
        selected_indices, current_T = _select_with_caption_cap(pick_data, ranked_indices, K, T)
    else:
        # do not consider T: simply take the K best-scoring items
        selected_indices = ranked_indices[:K]

    result = [pick_data[idx] for idx in selected_indices]
    print(f"Final selection time: {time.time() - t0:.2f} seconds")
    print(f"Final T value: {current_T}")

    total_time = time.time() - start_time
    print(f"Total execution time: {total_time:.2f} seconds")

    return result





if __name__ == '__main__':

    # dataset name -> (input JSON path, output path prefix,
    #                  record field that holds the pair's original index)
    dataset_settings = {
        "pickapic": ('YOUR_PATH', 'YOUR_OUTPUT_PATH', '__index_level_0__'),
        "hpsv2": ('YOUR_PATH', 'YOUR_PATH', 'index'),
    }

    parser = ArgumentParser()
    # The description belongs in `help`; as a default value it was never a
    # valid dataset name and always ended in an error.
    parser.add_argument(
        "--dataset",
        required=True,
        choices=sorted(dataset_settings),
        help="a type of dataset used",
    )
    args = parser.parse_args()

    data_path, output_base_path, index_attribute = dataset_settings[args.dataset]

    # Load the data (context manager so the file handle is closed promptly)
    with open(data_path, 'r', encoding='utf-8') as f:
        full_data = json.load(f)

    K_lst = [5000]
    alpha_beta_lsts = [
        (0.1, 0.1),
        (1.5, 1.5),
    ]

    for alpha, beta in alpha_beta_lsts:
        for K in K_lst:
            # Alternative selectors defined in this file with the same
            # (data, K, alpha, beta) arguments: lazy_greedy_subset_selection
            # and select_top_k_pairs_w_decreasing.
            selected_data = select_optimal_subset_w_embedding(full_data, K, alpha, beta)

            # Save the selected records.
            # NOTE: the file name does not include K, so several values in
            # K_lst would overwrite each other's output.
            with open(f'{output_base_path}embedding_{alpha}_{beta}.json', 'w') as f:
                json.dump(selected_data, f)

            # Save only the per-record index attribute of the selected data
            # into "pair_embedding_{alpha}_{beta}.json".
            selected_data_index = [d[index_attribute] for d in selected_data]
            with open(f'{output_base_path}pair_embedding_{alpha}_{beta}.json', 'w') as f:
                json.dump(selected_data_index, f)



    

            