#!/usr/bin/env python
# coding: utf-8

# In[1]:


import pandas as pd
from sklearn.decomposition import TruncatedSVD
from sklearn.manifold import TSNE
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import random
import seaborn as sns 
import math
from math import log
from sklearn.decomposition import PCA
from mpl_toolkits.mplot3d import Axes3D
import sys
import os
import pickle
from tqdm.notebook import tqdm 
import logging
from concurrent.futures import ProcessPoolExecutor, as_completed
from concurrent.futures import ThreadPoolExecutor, as_completed


# 
# # Applying SVD and t-SNE to MNIST

# In[2]:


# Load MNIST: column 0 holds the digit label, the remaining columns hold pixels.
data = pd.read_csv('/home/pinki/mnist_train.csv')
labels = data.iloc[:, 0].values
images = data.iloc[:, 1:].values

# Min-max scale each pixel column independently.
scaler = MinMaxScaler()
images_scaled = scaler.fit_transform(images)

# NOTE(review): the 50-d SVD projection is computed but never consumed below --
# t-SNE is fitted on `images_scaled`, not `images_svd_50d`. Switching the input
# would change the embedding (and presumably any distance matrices saved from
# it), so confirm the intent before changing either line.
svd = TruncatedSVD(n_components=50, random_state=42)
images_svd_50d = svd.fit_transform(images_scaled)

# NOTE(review): scikit-learn deprecated TSNE's `n_iter` in favour of `max_iter`
# (1.5) and later removed it -- confirm against the installed version.
tsne = TSNE(n_components=2, random_state=42, perplexity=30, learning_rate=200, n_iter=1000)
images_tsne_2d = tsne.fit_transform(images_scaled)

# Scatter the 2-D embedding, coloured by digit.
custom_palette = sns.color_palette("tab10", 10)
plt.figure(figsize=(8, 4))
sns.scatterplot(x=images_tsne_2d[:, 0], y=images_tsne_2d[:, 1], hue=labels,
                palette=custom_palette, s=60, alpha=0.7)
plt.title("2D t-SNE Visualization", fontsize=16)
plt.xlabel("t-SNE Component 1")
plt.ylabel("t-SNE Component 2")
plt.legend(title="Digit", bbox_to_anchor=(1.05, 1), loc='upper left')
plt.show()

# Group row indices by digit label (keys in order of first appearance).
label_indices = {}
for idx, label in enumerate(labels):
    label_indices.setdefault(label, []).append(idx)

print("Dictionary with labels as keys and indices of corresponding samples (preview):")
for key in label_indices:
    print("Label:", key, "Indices:", label_indices[key][:5], "...")


# # Loading the Distance Matrix

# In[24]:


# Pre-computed weak-oracle (perturbed) distance matrix for the t-SNE embedding.
# Keyed by 0.1 -- presumably the delta it was generated with, per the file name.
tsne_path = "/home/pinki/matrix_mnist/distance_matrix_tsne_delta_0.1.npy"
distance_matrices_tsne = {0.1: np.load(tsne_path)}
# Module-level alias; weak_initialize_kmeans_plus_plus reads this name.
perturbed_distances = distance_matrices_tsne[0.1]


# # Strong Baseline (k-means++)
# 

# In[14]:


def initialize_kmeans_plus_plus(dataset, k):
    """Pick ``k`` distinct seed indices with k-means++ (D^2) sampling.

    The first index is drawn uniformly. Each later index is drawn from the
    not-yet-chosen points with probability proportional to the squared
    Euclidean distance to the nearest already-chosen center.

    Parameters
    ----------
    dataset : array-like, shape (n_points, ...)
        Points to seed from; row ``i`` is point ``i``.
    k : int
        Number of centers to choose.

    Returns
    -------
    list[int]
        Indices into ``dataset`` of the chosen centers, in selection order.
        Empty when ``k <= 0`` or ``dataset`` is empty.

    Raises
    ------
    ValueError
        If ``k`` exceeds the number of points.
    """
    points = np.asarray(dataset, dtype=float)
    n_points = len(points)
    if k <= 0 or n_points == 0:
        return []
    if k > n_points:
        raise ValueError(f"cannot choose {k} distinct centers from {n_points} points")
    points = points.reshape(n_points, -1)

    first_center_idx = int(np.random.choice(n_points))
    centers = [first_center_idx]
    chosen = np.zeros(n_points, dtype=bool)
    chosen[first_center_idx] = True
    # Running distance from every point to its nearest chosen center. Updating it
    # with only the newest center avoids rescanning all centers every round.
    min_dist = np.linalg.norm(points - points[first_center_idx], axis=1)

    for _ in range(k - 1):
        weights = np.where(chosen, 0.0, min_dist ** 2)
        total = weights.sum()
        if total > 0 and np.isfinite(total):
            probabilities = weights / total
        else:
            # Every remaining point coincides with a center: sample uniformly
            # among unchosen points instead of dividing by zero.
            remaining = ~chosen
            probabilities = remaining / np.count_nonzero(remaining)
        next_center_idx = int(np.random.choice(n_points, p=probabilities))
        centers.append(next_center_idx)
        chosen[next_center_idx] = True
        min_dist = np.minimum(
            min_dist, np.linalg.norm(points - points[next_center_idx], axis=1)
        )

    return centers


# In[21]:


def calculate_cost(dataset, centers):
    """Return the k-means cost of ``dataset`` for the given center indices.

    The cost is the sum, over every point, of the squared Euclidean distance to
    its nearest center. Centers are given as row indices into ``dataset``.

    Parameters
    ----------
    dataset : array-like, shape (n_points, ...)
        The points.
    centers : sized iterable of int
        Row indices of ``dataset`` acting as centers.

    Returns
    -------
    float
        ``0.0`` for an empty dataset; ``inf`` when there are points but no
        centers.
    """
    points = np.asarray(dataset, dtype=float)
    n_points = len(points)
    if n_points == 0:
        return 0.0
    if len(centers) == 0:
        return float("inf")
    points = points.reshape(n_points, -1)

    # One vectorised pass per center keeps memory at O(n_points), rather than
    # materialising an (n_points, n_centers, dim) array.
    best = np.full(n_points, np.inf)
    for center in centers:
        sq_dist = np.sum((points - points[center]) ** 2, axis=1)
        # fmin ignores NaN, like the strict "<" comparison it replaces.
        best = np.fmin(best, sq_dist)
    return float(best.sum())


# # Weak Baseline (k-means++ with weak oracle)

# In[29]:


def weak_initialize_kmeans_plus_plus(dataset, k, distance_matrix=None):
    """k-means++ seeding that measures distances only through the weak oracle.

    This is the same D^2 sampling as the strong baseline. The distance from
    point ``i`` to center ``j`` is read as ``distance_matrix[i, j]`` instead of
    being computed from coordinates.

    Parameters
    ----------
    dataset : sized
        Only ``len(dataset)`` is used (number of points).
    k : int
        Number of centers to choose.
    distance_matrix : array-like, shape (>= n_points, >= n_points), optional
        Perturbed pairwise distances. Defaults to the module-level
        ``perturbed_distances`` array (the original behaviour).

    Returns
    -------
    list[int]
        Indices of the chosen centers, in selection order. Empty when
        ``k <= 0`` or ``dataset`` is empty.

    Raises
    ------
    ValueError
        If ``k`` exceeds the number of points.
    """
    if distance_matrix is None:
        distance_matrix = perturbed_distances
    matrix = np.asarray(distance_matrix)
    n_points = len(dataset)
    if k <= 0 or n_points == 0:
        return []
    if k > n_points:
        raise ValueError(f"cannot choose {k} distinct centers from {n_points} points")

    first_center_idx = int(np.random.choice(n_points))
    centers = [first_center_idx]
    chosen = np.zeros(n_points, dtype=bool)
    chosen[first_center_idx] = True
    # Copy the column so the running minimum never writes into the oracle matrix.
    min_dist = np.array(matrix[:n_points, first_center_idx], dtype=float)

    for _ in range(k - 1):
        weights = np.where(chosen, 0.0, min_dist ** 2)
        total = weights.sum()
        if total > 0 and np.isfinite(total):
            probabilities = weights / total
        else:
            # All remaining oracle distances are zero: fall back to uniform
            # over the unchosen points instead of dividing by zero.
            remaining = ~chosen
            probabilities = remaining / np.count_nonzero(remaining)
        next_center_idx = int(np.random.choice(n_points, p=probabilities))
        centers.append(next_center_idx)
        chosen[next_center_idx] = True
        # fmin ignores NaN, matching the Python min() it replaces.
        min_dist = np.fmin(min_dist, matrix[:n_points, next_center_idx])

    return centers


# # Cost of Strong Baseline
# 

# In[ ]:


# Average k-means cost of plain k-means++ seeding (exact Euclidean distances).
n_runs = 20
total_cost = 0
dataset = images_tsne_2d

for i in range(n_runs):
    centers = initialize_kmeans_plus_plus(dataset, 10)
    actual_cost = calculate_cost(dataset, centers)
    print('Run', i+1, ': actual cost =', actual_cost)
    total_cost += actual_cost

average_cost = total_cost / n_runs
# The message previously said "10 runs" although 20 were averaged.
print(f'\nAverage actual cost over {n_runs} runs =', average_cost)


# # Cost of Weak Baseline (k-means++ with weak oracle)

# In[ ]:


# Average k-means cost of k-means++ seeding driven only by perturbed distances.
n_runs = 20
total_cost = 0
dataset = images_tsne_2d

for i in range(n_runs):
    centers = weak_initialize_kmeans_plus_plus(dataset, 10)
    actual_cost = calculate_cost(dataset, centers)
    print('Run', i+1, ': actual cost =', actual_cost)
    total_cost += actual_cost

average_cost = total_cost / n_runs
# The message previously said "10 runs" although 20 were averaged.
print(f'\nAverage actual cost over {n_runs} runs =', average_cost)


# # k-means in Weak-Strong Oracle Model

# ## Computing the Proxy for the Closest Center in C_i to x

# In[15]:


def compute_single_perturbed_median(x, C_i, closest_centers, perturbed_distances):
    """Return ``(x, m)``, where ``m`` is the smallest per-center median distance.

    For every center ``c`` in ``C_i``, the perturbed distances
    ``perturbed_distances[x, j]`` for ``j`` in ``closest_centers[c]`` are
    collected. Their upper median (element ``len // 2`` in sorted order) is
    taken. ``m`` is the minimum of those medians, or ``inf`` when ``C_i`` is
    empty.
    """
    def _upper_median(values):
        pivot = len(values) // 2
        return np.partition(values, pivot)[pivot]

    best = float('inf')
    for center in C_i:
        sample = [perturbed_distances[x, neighbour] for neighbour in closest_centers[center]]
        # min() keeps `best` on ties and when the candidate is NaN.
        best = min(best, _upper_median(sample))
    return x, best


# ## Forming Initial Balls Around Every Center in C_i Containing `threshold` Many Centers

# In[16]:


def compute_initial_closest_centers(C_i, dataset, threshold):
    """Rank, for every center in ``C_i``, the other centers by Euclidean distance.

    Parameters
    ----------
    C_i : sequence of int
        Center indices (rows of ``dataset``).
    dataset : array-like
        Point coordinates; ``dataset[c]`` is the location of center ``c``.
    threshold : int
        How many nearest other centers to keep per center.

    Returns
    -------
    (dict, dict)
        ``maintained_distances[c]`` is the list of ``(distance, other)`` pairs
        for every ``other != c`` in ``C_i``, sorted ascending (ties broken by
        ``other``). ``closest_centers[c]`` holds the ``other`` ids of the first
        ``threshold`` pairs: fewer if ``C_i`` is small, none if
        ``threshold <= 0``.
    """
    ranked_by_center = {}
    nearest_by_center = {}
    for anchor in C_i:
        ranked = sorted(
            (np.linalg.norm(dataset[anchor] - dataset[other]), other)
            for other in C_i
            if other != anchor
        )
        ranked_by_center[anchor] = ranked
        nearest_by_center[anchor] = [
            other for position, (_, other) in enumerate(ranked) if position < threshold
        ]
    return ranked_by_center, nearest_by_center


# In[17]:


def update_closest_centers(C_i, dataset, threshold):
    """Recompute nearest-center lists after a new center has been added to ``C_i``.

    This is a full recomputation over the current ``C_i``, identical to
    ``compute_initial_closest_centers``. It delegates to that function rather
    than keeping a second verbatim copy of the same loop.

    Returns
    -------
    (dict, dict)
        ``maintained_distances`` maps each center to its ``(distance, other)``
        pairs sorted ascending. ``closest_centers`` maps each center to the ids
        of its first ``threshold`` entries.
    """
    return compute_initial_closest_centers(C_i, dataset, threshold)


# ## Sampling O(k log n / (0.5 - delta)^2) Centers

# In[18]:


def _nearest_other_centers(centers, dataset, threshold):
    """Map each center id to the ids of its ``threshold`` nearest other centers.

    Distances are Euclidean in ``dataset``; ties are broken by the smaller id.
    """
    ids = np.asarray(centers, dtype=np.intp)
    coords = np.asarray(dataset)[ids].astype(float).reshape(len(ids), -1)
    neighbours = {}
    for row, center in enumerate(ids):
        dist = np.linalg.norm(coords - coords[row], axis=1)
        order = np.lexsort((ids, dist))  # primary key: distance, then id
        others = [int(ids[j]) for j in order if j != row]
        neighbours[int(center)] = others[:threshold]
    return neighbours


def _upper_median_to_neighbours(neighbour_ids, perturbed_distances, n):
    """Per point, the upper median of perturbed distances to ``neighbour_ids``."""
    cols = np.asarray(neighbour_ids, dtype=np.intp)
    block = np.asarray(perturbed_distances[:n, cols], dtype=float)
    mid = cols.size // 2
    return np.partition(block, mid, axis=1)[:, mid]


def perturbed_k_means_plus_plus(dataset, perturbed_distances, delta, k):
    """Grow a candidate center set ``C_i`` using weak-oracle medians.

    ``threshold = int(ln(n) / (0.5 - delta) ** 2)`` uniformly random points
    seed ``C_i``. Then ``(k - 1) * threshold`` further centers are added one at
    a time. Each round, every not-yet-chosen point ``x`` gets a distance proxy:
    for each center ``c``, take the upper median of
    ``perturbed_distances[x, j]`` over the ``threshold`` centers ``j`` nearest
    to ``c`` (Euclidean in ``dataset``, excluding ``c`` itself), then take the
    minimum over ``c``. The next center is drawn with probability proportional
    to the squared proxy, or uniformly when no usable positive mass exists.

    Sampling runs over candidates in ascending index order, so results are
    reproducible under fixed ``random`` / ``np.random`` seeds. The previous
    thread-pool version gathered results in completion order, so they were
    not. Per-center median vectors are cached and recomputed only when that
    center's neighbour list changes; the cache holds up to ``len(C_i)`` arrays
    of length ``n``.

    Returns
    -------
    (list[int], int, int)
        ``C_i``, then the edge- and point-query counters. The point-query
        counter starts at ``threshold`` and grows by one per added center; the
        edge-query counter accumulates the point-query count before each
        increment.

    Raises
    ------
    ValueError
        If ``dataset`` is empty or ``delta >= 0.5``.
    """
    n = len(dataset)
    if n == 0:
        raise ValueError("dataset must contain at least one point")
    p = 0.5 - delta
    if p <= 0:
        raise ValueError(f"delta must be < 0.5, got {delta}")
    log_n = np.log(n)  # Natural log (base e)
    # Cannot seed with more points than exist.
    threshold = min(int(log_n / (p ** 2)), n)
    # Each round adds one new point, so at most n - threshold rounds are possible.
    t = max(0, min((k - 1) * threshold, n - threshold))

    C_i = random.sample(range(n), threshold)
    in_sample = np.zeros(n, dtype=bool)
    in_sample[C_i] = True

    no_so_edge_queries = 0
    no_so_point_queries = threshold

    neighbours = _nearest_other_centers(C_i, dataset, threshold)
    median_cache = {}  # center -> (neighbour tuple, per-point median vector)

    print(f" Starting k-means++ center selection with {t} iterations...")

    for _ in tqdm(range(t)):
        proxy = np.full(n, np.inf)
        for c in C_i:
            key = tuple(neighbours[c])
            if not key:
                continue
            cached = median_cache.get(c)
            if cached is None or cached[0] != key:
                cached = (key, _upper_median_to_neighbours(key, perturbed_distances, n))
                median_cache[c] = cached
            # fmin ignores NaN, like the strict "<" comparison used before.
            proxy = np.fmin(proxy, cached[1])

        # Sample a new center among the points not yet in C_i.
        candidates = np.flatnonzero(~in_sample)
        weights = proxy[candidates] ** 2
        total = weights.sum()
        if total > 0 and np.isfinite(total):
            sampled_index = int(np.random.choice(candidates, p=weights / total))
        else:
            sampled_index = int(random.choice(candidates.tolist()))

        C_i.append(sampled_index)
        in_sample[sampled_index] = True

        no_so_edge_queries += no_so_point_queries
        no_so_point_queries += 1

        # Refresh each center's nearest-neighbour list.
        neighbours = _nearest_other_centers(C_i, dataset, threshold)

    print("Selection process completed.")

    return C_i, no_so_edge_queries, no_so_point_queries


# 
# # # The perturbed_cost function assigns a weight to every center in C_i and returns a weighted instance called center_weights

# In[19]:


def perturbed_cost(dataset, C_i, perturbed_distances, delta):
    """Assign every point to a center of ``C_i`` using weak-oracle medians.

    For each center ``c``, its ``threshold = int(ln(n) / (0.5 - delta))``
    nearest members of ``C_i`` are found by Euclidean distance in ``dataset``.
    This list includes ``c`` itself at distance 0; ties keep ``C_i`` order. A
    point ``x`` is assigned to the center whose list gives the smallest upper
    median of ``perturbed_distances[x, j]``. Ties go to the center appearing
    first in ``C_i``, and NaN medians never win.

    NOTE(review): ``threshold`` here divides by ``(0.5 - delta)``, while
    perturbed_k_means_plus_plus divides by ``(0.5 - delta) ** 2``. Also, this
    neighbour list includes ``c`` itself, whereas
    compute_initial_closest_centers excludes it. Both are kept as-is; confirm
    they are intentional.

    Returns
    -------
    (float, float, dict)
        - The sum of squared minimum medians (k-means proxy cost).
        - The sum of minimum medians (k-median proxy cost).
        - ``center_weights``, mapping each center that received points to its
          point count, in order of first assignment. Points for which no median
          compares below ``inf`` are counted under the key ``None``, as before.
    """
    n = len(dataset)
    log_n = np.log(n)
    p = 0.5 - delta
    threshold = int(log_n / p)

    # Neighbour lists among the centers (computed once: |C_i|^2 distances).
    closest_centers = {}
    for c in C_i:
        ranked = sorted(
            ((np.linalg.norm(dataset[other_c] - dataset[c]), other_c) for other_c in C_i),
            key=lambda item: item[0],
        )
        closest_centers[c] = [other for _, other in ranked[:max(threshold, 0)]]

    # Vectorised assignment: one (n x |list|) median computation per center.
    best = np.full(n, np.inf)
    owner = np.full(n, -1, dtype=np.intp)  # position in C_i; -1 = unassigned
    for pos, c in enumerate(C_i):
        cols = np.asarray(closest_centers[c], dtype=np.intp)
        block = np.asarray(perturbed_distances[:n, cols], dtype=float)
        mid = cols.size // 2
        medians = np.partition(block, mid, axis=1)[:, mid]
        better = medians < best  # strict: earlier centers win ties; NaN never wins
        best[better] = medians[better]
        owner[better] = pos

    total_perturbed_kmeans_cost = float(np.sum(best ** 2))
    total_perturbed_kmedian_cost = float(np.sum(best))

    # Count points per center, keyed in order of each center's first point.
    center_weights = {}
    positions, first_seen, counts = np.unique(owner, return_index=True, return_counts=True)
    for order_idx in np.argsort(first_seen, kind="stable"):
        pos = int(positions[order_idx])
        key = None if pos < 0 else C_i[pos]
        center_weights[key] = int(counts[order_idx])

    return total_perturbed_kmeans_cost, total_perturbed_kmedian_cost, center_weights


# 
# # Running weighted k-means++ on the weighted instance (center_weights)

# In[20]:


def weighted_kmeans(dataset, k, center_weights):
    """Weighted k-means++ seeding over the weighted instance ``center_weights``.

    ``center_weights`` maps a point index (row of ``dataset``) to its weight.
    The first center is drawn with probability proportional to weight. Each
    later one is drawn among the remaining candidates with probability
    proportional to ``weight(x) * min_c ||x - c||^2``, the standard weighted
    D^2 rule.

    The previous version multiplied by the weight of the already-chosen center
    instead of the candidate's own weight. When every remaining score is zero,
    the draw falls back to proportional-to-weight, then to uniform.

    Returns
    -------
    list
        Up to ``k`` keys of ``center_weights`` in selection order. There are
        fewer when fewer candidates exist, and none when ``k <= 0``.
    """
    candidates = list(center_weights.keys())
    if k <= 0 or not candidates:
        return []
    weights = np.array([center_weights[c] for c in candidates], dtype=float)
    coords = np.asarray(dataset)[np.asarray(candidates, dtype=np.intp)].astype(float)
    coords = coords.reshape(len(candidates), -1)

    def _draw(scores, remaining):
        # Try D^2 scores, then plain weights, then uniform over remaining.
        for s in (scores, weights):
            masked = np.where(remaining, s, 0.0)
            total = masked.sum()
            if total > 0 and np.isfinite(total):
                return int(np.random.choice(len(candidates), p=masked / total))
        return int(np.random.choice(np.flatnonzero(remaining)))

    remaining = np.ones(len(candidates), dtype=bool)
    first = _draw(weights, remaining)
    selected = [first]
    remaining[first] = False
    # Squared distance from each candidate to its nearest selected center.
    min_sq = np.sum((coords - coords[first]) ** 2, axis=1)

    while len(selected) < k and remaining.any():
        nxt = _draw(weights * min_sq, remaining)
        selected.append(nxt)
        remaining[nxt] = False
        min_sq = np.minimum(min_sq, np.sum((coords - coords[nxt]) ** 2, axis=1))

    return [candidates[i] for i in selected]


# 
# ## Seeing How the k-means Cost Varies with Different Values of t

# In[28]:


dataset = images_tsne_2d  # vary the dataset 
delta = 0.1  # vary the value of delta for other results
n_runs = 5
n_centers = 10
# Use the weak-oracle matrix matching `delta` instead of always the 0.1 one.
# A KeyError here means the matrix for that delta has not been loaded.
perturbed_matrix = distance_matrices_tsne[delta]
total_final_cost = 0
total_so_edge_queries = 0
total_so_point_queries = 0
for i in range(n_runs):
    print(f"Iteration {i + 1}:")
    C_i, no_so_edge_queries, no_so_point_queries = perturbed_k_means_plus_plus(
        dataset, perturbed_matrix, delta, n_centers
    )

    # Proxy k-means / k-median costs (a, b; not used below) and weighted centers
    a, b, center_weights = perturbed_cost(dataset, C_i, perturbed_matrix, delta)
    final_centers = weighted_kmeans(dataset, n_centers, center_weights)

    # True (strong-oracle) k-means cost of the final centers
    final_cost = calculate_cost(dataset, final_centers)

    # Update total cost and queries
    total_final_cost += final_cost
    total_so_edge_queries += no_so_edge_queries
    total_so_point_queries += no_so_point_queries

    # Output results for this iteration
    print(f"Final cost for iteration {i + 1}: {final_cost}")
    print(f"SO edge queries for iteration {i + 1}: {no_so_edge_queries}")
    print(f"SO point queries for iteration {i + 1}: {no_so_point_queries}")
    print("-" * 50)  # Print separator for clarity

# Compute and output averages
average_final_cost = total_final_cost / n_runs
average_so_edge_queries = total_so_edge_queries / n_runs
average_so_point_queries = total_so_point_queries / n_runs

# Show averages
print(f'\nAverage final cost over {n_runs} runs =', average_final_cost)
print(f'Average SO edge queries over {n_runs} runs =', average_so_edge_queries)
print(f'Average SO point queries over {n_runs} runs =', average_so_point_queries)




