import json
import numpy as np
from sklearn.cluster import KMeans
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import argparse
import os

def k_center_greedy(features, n_samples, seed_indices=None):
    """Pick a spread-out subset of rows using the k-center greedy heuristic.

    Starting from the seed centers, repeatedly add the row whose squared
    Euclidean distance to its nearest already-chosen center is largest.

    Args:
        features: 2-D array of shape (n_points, n_features).
        n_samples: Desired total number of centers (seeds included).
        seed_indices: Optional iterable of row indices to start from. When
            omitted, a single random row is used as the first center.

    Returns:
        A list of row indices. Seeds come first, in the order given, and are
        always kept even if there are more of them than ``n_samples``. Rows
        added by the greedy step are never repeated, so the result holds at
        most ``n_points`` greedy-selected entries in total; asking for more
        samples than there are rows simply returns every row once.
    """
    n_points = features.shape[0]
    # Nothing to choose from, or nothing requested and no seeds to honour.
    if n_points == 0 or (seed_indices is None and n_samples <= 0):
        return []

    if seed_indices is None:
        centers = [int(np.random.randint(n_points))]
    else:
        centers = [int(idx) for idx in seed_indices]

    def sq_dist_to(index):
        # Squared Euclidean distance from every row to row `index`.
        return np.sum((features - features[index]) ** 2, axis=1)

    # dists[i] = squared distance from row i to its nearest chosen center.
    dists = np.full(n_points, np.inf)
    for c in centers:
        dists = np.minimum(dists, sq_dist_to(c))
        # Mark chosen rows with -inf so argmax can never select them again,
        # even when every remaining row is at distance 0 (exact duplicates).
        # np.minimum keeps -inf through all later updates.
        dists[c] = -np.inf

    # Cannot return more distinct rows than exist.
    target = min(n_samples, n_points)
    while len(centers) < target:
        new_center = int(np.argmax(dists))
        centers.append(new_center)
        dists = np.minimum(dists, sq_dist_to(new_center))
        dists[new_center] = -np.inf

    return centers

def main(args):
    """Select a diverse subset of an instruction dataset and save it as JSON.

    Steps: load the records, embed ``instruction + " " + input`` with a
    sentence-transformer, cluster the embeddings with K-Means, drop clusters
    smaller than ``args.min_cluster_size``, split ``args.total_samples``
    evenly over the remaining clusters, and within each cluster pick points
    by k-center greedy seeded at the point closest to the cluster centroid.

    Args:
        args: Namespace providing ``dataset_path``, ``output_path``,
            ``total_samples``, ``n_clusters`` and ``min_cluster_size``.

    Raises:
        ValueError: If no cluster reaches ``min_cluster_size``.
    """
    print(f"Loading Dataset from {args.dataset_path}...")
    if args.dataset_path.endswith('.json'):
        # Explicit encoding so the result does not depend on the platform default.
        with open(args.dataset_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    else:
        ds = load_dataset(args.dataset_path, split='train')
        data = list(ds)

    # `input` may be absent or None; treat both as an empty string.
    instructions = [item['instruction'] + " " + (item.get('input') or '') for item in data]

    print("Encoding instructions...")
    model = SentenceTransformer('/root/autodl-tmp/hf_model/bge-m3', device='cuda')
    embeddings = model.encode(instructions, batch_size=32, show_progress_bar=True, normalize_embeddings=True)

    print(f"Running K-Means with k={args.n_clusters}...")
    kmeans = KMeans(n_clusters=args.n_clusters, random_state=42, n_init=10)
    labels = kmeans.fit_predict(embeddings)

    unique_labels, counts = np.unique(labels, return_counts=True)
    cluster_stats = dict(zip(unique_labels, counts))

    valid_clusters = [lbl for lbl, count in cluster_stats.items() if count >= args.min_cluster_size]
    n_valid = len(valid_clusters)

    if n_valid == 0:
        raise ValueError(f"No clusters satisfy min_cluster_size={args.min_cluster_size}. Please reduce the threshold.")

    dropped_clusters = args.n_clusters - n_valid
    print(f"Cluster Filtering: Dropped {dropped_clusters} clusters (size < {args.min_cluster_size}). Remaining valid clusters: {n_valid}")

    # Even split of the budget; the first `remainder` clusters take one extra.
    base_samples = args.total_samples // n_valid
    remainder = args.total_samples % n_valid

    print(f"Sampling Budget: {base_samples} per cluster (First {remainder} clusters get +1)")

    selected_indices = []

    print("Running K-Center Greedy within valid clusters...")
    for i, cluster_id in enumerate(valid_clusters):
        target_n = base_samples + (1 if i < remainder else 0)
        # With fewer requested samples than clusters, some clusters have no
        # budget. Skip them; otherwise the seed point alone would be selected
        # and the total would exceed args.total_samples.
        if target_n <= 0:
            continue

        cluster_indices = np.where(labels == cluster_id)[0]
        cluster_features = embeddings[cluster_indices]

        if len(cluster_indices) <= target_n:
            # Cluster is no larger than its budget: take every member.
            selected_indices.extend(cluster_indices.tolist())
        else:
            # Seed the greedy selection with the member nearest the centroid.
            centroid = kmeans.cluster_centers_[cluster_id]
            dists_to_centroid = np.sum((cluster_features - centroid)**2, axis=1)
            seed_idx_in_cluster = np.argmin(dists_to_centroid)

            local_centers = k_center_greedy(cluster_features, target_n, seed_indices=[seed_idx_in_cluster])

            # Map positions inside the cluster back to dataset row indices.
            global_centers = cluster_indices[local_centers]
            selected_indices.extend(global_centers.tolist())

    selected_data = [data[i] for i in selected_indices]

    print(f"Target samples: {args.total_samples}, Actual selected: {len(selected_data)}")

    # os.makedirs('') raises, so only create a directory when the output path
    # actually contains one.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(args.output_path, 'w', encoding='utf-8') as f:
        json.dump(selected_data, f, indent=2)

    print(f"Saved to {args.output_path}")

if __name__ == "__main__":
    # Script entry point: declare the command-line options from a table,
    # parse them, and pass the resulting namespace to main().
    parser = argparse.ArgumentParser()
    option_specs = (
        ("--dataset_path", {"type": str, "default": "train_dataset/alpaca_data.json"}),
        ("--output_path", {"type": str, "default": "alpaca/geo_sampled.json"}),
        ("--total_samples", {"type": int, "default": 520}),
        ("--n_clusters", {"type": int, "default": 228}),
        (
            "--min_cluster_size",
            {"type": int, "default": 5, "help": "Drop clusters with fewer samples than this"},
        ),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()
    main(args)