#!/usr/bin/env python
# coding: utf-8

# In[6]:


import pandas as pd
from sklearn.decomposition import TruncatedSVD
from sklearn.manifold import TSNE
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import random
import seaborn as sns 
import math
from math import log
from sklearn.decomposition import PCA
from mpl_toolkits.mplot3d import Axes3D
import sys
import os
import pickle
from tqdm.notebook import tqdm 
import logging
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import as_completed





# # SBM Data Generation 

# In[7]:


def sbm_data(k, d, total_points, separation=10**5):
    """Generate ``k`` Gaussian clusters in ``d`` dimensions.

    Cluster ``i`` is sampled from a normal distribution with identity
    covariance whose mean is ``separation`` on coordinate axis ``i`` and zero
    on every other axis.  Points are split evenly between the clusters; any
    remainder of ``total_points / k`` is added to the last cluster.

    Args:
        k: Number of clusters (must satisfy ``1 <= k <= d``).
        d: Dimensionality of every point.
        total_points: Total number of points to generate.
        separation: Value placed on the cluster's own axis in its mean vector.

    Returns:
        A tuple ``(data, labels)`` where ``data`` has shape
        ``(total_points, d)`` and ``labels[j]`` is the cluster index of
        ``data[j]``.

    Raises:
        ValueError: If ``k`` is smaller than 1 or larger than ``d``.
    """
    if k < 1:
        raise ValueError("k must be at least 1")
    if k > d:
        # Each cluster mean sits on its own coordinate axis, so there must be
        # at least as many axes as clusters.
        raise ValueError("k cannot exceed d: every cluster needs its own axis")

    base_points, extra_points = divmod(total_points, k)
    points_per_cluster = [base_points] * k
    points_per_cluster[-1] += extra_points
    print("Number of points in the last cluster:", points_per_cluster[-1])

    identity = np.eye(d)  # shared covariance, built once
    clusters = []
    for axis, size in enumerate(points_per_cluster):
        mu = np.zeros(d)
        mu[axis] = separation
        clusters.append(
            np.random.multivariate_normal(mean=mu, cov=identity, size=size)
        )

    data = np.vstack(clusters)
    labels = np.repeat(np.arange(k), points_per_cluster)

    return data, labels


# Generate 50,000 points in 7 dimensions, split across 7 clusters.
data, labels = sbm_data(7, 7, 50000)

# Sanity check: Euclidean distance between two adjacent rows of the data.
# (The label now names the rows that are actually compared.)
print("Distance between data[101] and data[100]:", np.linalg.norm(data[101] - data[100]))

print("Labels:\n", labels)

# Project the data onto its first three principal components for plotting.
pca = PCA(n_components=3)
reduced_data = pca.fit_transform(data)

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
# One scatter series per cluster so that each gets its own legend entry.
for i in range(7):
    cluster_points = reduced_data[labels == i]
    ax.scatter(cluster_points[:, 0], cluster_points[:, 1], cluster_points[:, 2], label=f'Cluster {i}')

# Labels and title
ax.set_xlabel('PCA Component 1')
ax.set_ylabel('PCA Component 2')
ax.set_zlabel('PCA Component 3')
ax.set_title('3D PCA Visualization of 7D SBM Data Points with 7 Clusters')
ax.legend()
plt.show()


# # Loading Distance Matrices for SBM data

# In[10]:


# Location of the precomputed distance matrix on disk.
sbm_50k_path = "/home/pinki/sbm_50k/distance_matrix_50000_delta_0.3.npy"

# `dataset` is the name the clustering code below uses for the generated points.
dataset = data

# Later code indexes this array as perturbed_distances[i, j].
# NOTE(review): the file name suggests a 50000-point matrix with delta = 0.3 -
# confirm it matches the data generated above.
perturbed_distances = np.load(sbm_50k_path)


# # Strong Baseline(k-means++)

# In[8]:


def initialize_kmeans_plus_plus(dataset, k):
    """Pick ``k`` center indices with k-means++ (D^2) seeding.

    The first center is drawn uniformly at random.  Every further center is
    drawn from the points that are not yet centers, with probability
    proportional to the squared Euclidean distance to the nearest center
    chosen so far.

    Args:
        dataset: Array-like of points; ``dataset[i]`` is the i-th point.
        k: Number of centers to select.

    Returns:
        A list of ``k`` distinct integer indices into ``dataset``.

    Raises:
        ValueError: If ``k`` exceeds the number of points.
    """
    points = np.asarray(dataset, dtype=float)
    n_points = len(points)
    if k > n_points:
        raise ValueError("k cannot exceed the number of points in the dataset")
    # Flatten each point so the squared-distance sum below works for any
    # per-point shape.
    points = points.reshape(n_points, -1)

    first_center_idx = int(np.random.choice(n_points))
    centers = [first_center_idx]
    is_center = np.zeros(n_points, dtype=bool)
    is_center[first_center_idx] = True

    # Running squared distance from every point to its nearest chosen center.
    min_sq_dist = ((points - points[first_center_idx]) ** 2).sum(axis=1)

    for _ in range(k - 1):
        # Already-chosen centers must never be drawn again.
        weights = np.where(is_center, 0.0, min_sq_dist)
        total = weights.sum()
        if total > 0:
            next_center_idx = int(np.random.choice(n_points, p=weights / total))
        else:
            # Every remaining point coincides with a center: the D^2
            # distribution is undefined, so fall back to a uniform draw.
            next_center_idx = int(np.random.choice(np.flatnonzero(~is_center)))

        centers.append(next_center_idx)
        is_center[next_center_idx] = True
        # Only the new center can lower a point's nearest-center distance.
        new_sq_dist = ((points - points[next_center_idx]) ** 2).sum(axis=1)
        np.minimum(min_sq_dist, new_sq_dist, out=min_sq_dist)

    return centers


# In[14]:


def calculate_cost(dataset, centers):
    """Return the k-means cost of ``centers`` on ``dataset``.

    The cost is the sum, over all points, of the squared Euclidean distance
    from the point to its nearest center.

    Args:
        dataset: Array-like of points; ``dataset[i]`` is the i-th point.
        centers: Iterable of indices into ``dataset`` that act as centers.

    Returns:
        The total cost.  It is ``0`` for an empty dataset and ``inf`` when
        ``centers`` is empty and the dataset is not.
    """
    points = np.asarray(dataset, dtype=float)
    n_points = len(points)
    if n_points == 0:
        return 0
    points = points.reshape(n_points, -1)

    # Running squared distance to the nearest center, refined one center at a
    # time so that only an (n, d) temporary is ever allocated.
    min_sq_dist = np.full(n_points, np.inf)
    for center in centers:
        sq_dist = ((points - points[center]) ** 2).sum(axis=1)
        np.minimum(min_sq_dist, sq_dist, out=min_sq_dist)

    return float(min_sq_dist.sum())


# # Weak Baseline (k-means++ with Weak Oracle)

# In[9]:


def weak_initialize_kmeans_plus_plus(dataset, k, distance_matrix=None):
    """Pick ``k`` center indices with k-means++ seeding on a distance matrix.

    Works like D^2 seeding, but the distance between point ``i`` and center
    ``c`` is read from ``distance_matrix[i, c]`` instead of being computed
    from the coordinates.

    Args:
        dataset: Sequence of points; only its length is used.
        k: Number of centers to select.
        distance_matrix: 2-D NumPy array of pairwise distances.  When omitted,
            the module-level ``perturbed_distances`` array is used.

    Returns:
        A list of ``k`` distinct integer indices.

    Raises:
        ValueError: If ``k`` exceeds the number of points.
    """
    if distance_matrix is None:
        distance_matrix = perturbed_distances

    n_points = len(dataset)
    if k > n_points:
        raise ValueError("k cannot exceed the number of points in the dataset")
    rows = np.arange(n_points)

    first_center_idx = int(np.random.choice(n_points))
    centers = [first_center_idx]
    is_center = np.zeros(n_points, dtype=bool)
    is_center[first_center_idx] = True

    # Running distance from every point to its nearest chosen center.
    min_dist = np.asarray(distance_matrix[rows, first_center_idx], dtype=float)

    for _ in range(k - 1):
        # Mask chosen centers explicitly: the matrix entry for a point and
        # itself is not guaranteed to be zero.
        weights = np.where(is_center, 0.0, min_dist ** 2)
        total = weights.sum()
        if total > 0:
            next_center_idx = int(np.random.choice(n_points, p=weights / total))
        else:
            # All remaining candidates have zero weight; draw uniformly.
            next_center_idx = int(np.random.choice(np.flatnonzero(~is_center)))

        centers.append(next_center_idx)
        is_center[next_center_idx] = True
        np.minimum(min_dist, distance_matrix[rows, next_center_idx], out=min_dist)

    return centers


# # Computing Cost of Strong baseline

# In[18]:


# NOTE(review): total_points is assigned here but not read anywhere in this cell.
total_points = 10000  # vary the total_points
strong_run_costs = []

# Ten independent seedings of exact-distance k-means++ with 7 centers each.
for i in range(10):
    centers = initialize_kmeans_plus_plus(dataset, 7)
    actual_cost = calculate_cost(dataset, centers)
    print('Run', i+1, ': actual cost =', actual_cost)
    strong_run_costs.append(actual_cost)

total_cost = sum(strong_run_costs)
average_cost = total_cost / 10
print('\nAverage actual cost over 10 runs =', average_cost)


# # Computing Cost of Weak Baseline

# In[15]:


# NOTE(review): total_points is assigned here but not read anywhere in this cell.
total_points = 10000  # vary the total_points
weak_run_costs = []

# Ten independent seedings driven by the loaded distance matrix; each result
# is scored with the exact cost from calculate_cost.
for i in range(10):
    centers = weak_initialize_kmeans_plus_plus(dataset, 7)
    actual_cost = calculate_cost(dataset, centers)
    print('Run', i+1, ': actual cost =', actual_cost)
    weak_run_costs.append(actual_cost)

total_cost = sum(weak_run_costs)
average_cost = total_cost / 10
print('\nAverage actual cost over 10 runs =', average_cost)


# # kmeans in weak strong Oracle Model

# In[11]:


def compute_single_perturbed_median(x, C_i, closest_centers, perturbed_distances):
    """Return ``(x, m)`` where ``m`` is the smallest per-center median for ``x``.

    For every center ``c`` in ``C_i`` the entries ``perturbed_distances[x, cc]``
    are collected for each ``cc`` in ``closest_centers[c]``; the element at
    position ``len // 2`` of their sorted order is that center's median.  The
    minimum of these medians over all centers is returned (``inf`` when
    ``C_i`` is empty).
    """
    best = float('inf')
    for center in C_i:
        row = [perturbed_distances[x, neighbour] for neighbour in closest_centers[center]]
        k_mid = len(row) // 2
        # np.partition puts the k_mid-th smallest value at index k_mid
        # without fully sorting the list.
        candidate = np.partition(row, k_mid)[k_mid]
        best = candidate if candidate < best else best
    return x, best

def compute_initial_closest_centers(C_i, dataset, threshold):
    """Rank, for every center, the other centers by Euclidean distance.

    Args:
        C_i: Iterable of center indices into ``dataset``.
        dataset: Array-like of points; ``dataset[i]`` is the i-th point.
        threshold: Maximum number of nearest centers to keep per center.

    Returns:
        A tuple ``(maintained_distances, closest_centers)``:

        * ``maintained_distances[c]`` is the list of ``(distance, other)``
          pairs for every other center, sorted in ascending order.
        * ``closest_centers[c]`` is the list of the ``threshold`` nearest
          other centers (fewer if not that many exist).
    """
    maintained_distances = {}
    closest_centers = {}
    for c in C_i:
        # Tuples sort by distance first, then by center index on ties.
        ranked = sorted(
            (np.linalg.norm(dataset[c] - dataset[other_c]), other_c)
            for other_c in C_i
            if other_c != c
        )
        maintained_distances[c] = ranked
        closest_centers[c] = [other_c for _, other_c in ranked[:threshold]]
    return maintained_distances, closest_centers


def update_closest_centers(C_i, dataset, threshold):
    """Recompute the closest-center tables after ``C_i`` has changed.

    This is a full recomputation with the same contract and return value as
    ``compute_initial_closest_centers``; it delegates to that function so the
    two cannot drift apart.
    """
    return compute_initial_closest_centers(C_i, dataset, threshold)

def perturbed_k_means_plus_plus(dataset, perturbed_distances, delta, k, *,
                                threshold=20, num_iterations=87):
    """Select centers by D^2 sampling on median-of-perturbed-distance scores.

    ``threshold`` centers are first drawn uniformly at random.  Then, for
    ``num_iterations`` rounds, every non-center point ``x`` gets a score: for
    each center ``c`` take the median of ``perturbed_distances[x, cc]`` over
    the centers ``cc`` closest to ``c``, and keep the smallest of these
    medians.  A new center is sampled with probability proportional to the
    squared score, and the closest-center tables are rebuilt.

    Args:
        dataset: Array-like of points, used for exact center-to-center
            distances.
        perturbed_distances: 2-D NumPy array indexed as ``[point, center]``.
        delta: Not read by this function; kept for interface compatibility.
        k: Not read by this function; kept for interface compatibility.
        threshold: Number of initial random centers, and the maximum number of
            closest centers whose distances enter each median.
        num_iterations: Number of sampling rounds (centers added).

    Returns:
        A tuple ``(C_i, no_so_edge_queries, no_so_point_queries)``: the list of
        selected center indices and the two query counters.
        NOTE(review): the counters presumably track strong-oracle queries -
        confirm against the intended accounting.

    Raises:
        ValueError: If ``threshold`` is below 2 or exceeds ``len(dataset)``.
    """
    n = len(dataset)
    if threshold < 2:
        # With a single center there are no "other" centers to take a median over.
        raise ValueError("threshold must be at least 2")
    if n < threshold:
        raise ValueError(f"dataset must contain at least {threshold} points")

    C_i = random.sample(range(n), threshold)
    remaining = np.setdiff1d(np.arange(n), C_i)  # sorted non-center indices

    no_so_edge_queries = 0
    no_so_point_queries = threshold

    # Initial closest centers
    _, closest_centers = compute_initial_closest_centers(C_i, dataset, threshold)

    print(f" Starting k-means++ center selection with {num_iterations} iterations...")

    for _ in tqdm(range(num_iterations)):
        if remaining.size == 0:
            # Every point is already a center; nothing is left to sample.
            break

        # Gather the distances from all candidates to all current centers
        # once; the per-center medians below only slice columns of this block.
        column_of = {center: j for j, center in enumerate(C_i)}
        center_block = perturbed_distances[np.ix_(remaining, C_i)]

        min_medians = np.full(remaining.size, np.inf)
        for c in C_i:
            cols = [column_of[cc] for cc in closest_centers[c]]
            mid = len(cols) // 2
            medians = np.partition(center_block[:, cols], mid, axis=1)[:, mid]
            # fmin ignores NaN, matching a strict "<" comparison.
            np.fmin(min_medians, medians, out=min_medians)

        # Step 4: Sample a new center
        squared_distances = min_medians ** 2
        total = squared_distances.sum()

        if total == 0:
            position = random.randrange(remaining.size)
        else:
            position = np.random.choice(remaining.size, p=squared_distances / total)

        sampled_index = int(remaining[position])
        C_i.append(sampled_index)
        remaining = np.delete(remaining, position)

        no_so_edge_queries += no_so_point_queries
        no_so_point_queries += 1

        # Step 5: Update closest centers
        _, closest_centers = update_closest_centers(C_i, dataset, threshold)

    print("Selection process completed.")

    return C_i, no_so_edge_queries, no_so_point_queries


# # For each center in C_i, perturbed_cost assigns a weight and returns the weighted instance center_weights

# In[12]:


def perturbed_cost(dataset, C_i, perturbed_distances, delta):
    """Assign every point to a center and count the points per center.

    For each center ``c`` the ``threshold = int(ln(n) / (0.5 - delta))``
    centers nearest to it (by exact Euclidean distance, ``c`` itself
    included) form its group.  A point's score for ``c`` is the median of
    its perturbed distances to that group; the point is assigned to the
    center with the smallest score (the earliest center in ``C_i`` wins ties).

    Args:
        dataset: Array-like of points; ``dataset[i]`` is the i-th point.
        C_i: Sequence of center indices into ``dataset``.
        perturbed_distances: 2-D NumPy array indexed as ``[point, center]``.
        delta: Value below 0.5 that controls the group size.

    Returns:
        A tuple ``(kmeans_cost, kmedian_cost, center_weights)``: the sum of
        squared scores, the sum of scores, and a dict mapping each center that
        received at least one point to its number of assigned points.

    Raises:
        ValueError: If ``delta`` is 0.5 or larger.
    """
    n = len(dataset)
    p = 0.5 - delta
    if p <= 0:
        raise ValueError("delta must be smaller than 0.5")
    # Keep at least one center per group so that a median always exists.
    threshold = max(1, int(np.log(n) / p))

    centers = list(C_i)
    column_of = {}
    for j, c in enumerate(centers):
        column_of.setdefault(c, j)

    # For every center, the columns of its nearest centers.  The sort is
    # stable, so ties keep the order of C_i.
    closest_columns = []
    for c in centers:
        ranked = sorted(
            ((np.linalg.norm(dataset[other_c] - dataset[c]), other_c) for other_c in centers),
            key=lambda pair: pair[0],
        )
        closest_columns.append([column_of[other_c] for _, other_c in ranked[:threshold]])

    # Distances from every point to every center, gathered once.
    center_block = perturbed_distances[np.ix_(np.arange(n), np.asarray(centers, dtype=np.intp))]

    best = np.full(n, np.inf)
    best_position = np.full(n, -1, dtype=np.intp)
    for j, cols in enumerate(closest_columns):
        mid = len(cols) // 2
        medians = np.partition(center_block[:, cols], mid, axis=1)[:, mid]
        # Strict "<" keeps the first center that attains the minimum.
        improved = medians < best
        best[improved] = medians[improved]
        best_position[improved] = j

    total_perturbed_kmeans_cost = float(np.sum(best ** 2))
    total_perturbed_kmedian_cost = float(np.sum(best))

    # Count assignments in point order so that dict keys appear in order of
    # first use; a point that no center could claim is counted under None.
    center_weights = {}
    for position in best_position.tolist():
        assigned_center = centers[position] if position >= 0 else None
        center_weights[assigned_center] = center_weights.get(assigned_center, 0) + 1

    return total_perturbed_kmeans_cost, total_perturbed_kmedian_cost, center_weights


# # Running weighted k-means++ on the weighted instance (center_weights)

# In[13]:


def weighted_kmeans(dataset, k, center_weights):
    """Choose up to ``k`` centers from a weighted candidate set (k-means++).

    The candidates are the keys of ``center_weights``.  The first center is
    drawn with probability proportional to its weight.  Every further center
    is drawn with probability proportional to
    ``weight(candidate) * D(candidate)^2``, where ``D`` is the Euclidean
    distance to the nearest center chosen so far.

    The score uses the candidate's own weight.  Earlier revisions multiplied
    by the weight of the already-chosen center instead, which ignores how many
    points a candidate represents.

    Args:
        dataset: Array-like of points; ``dataset[c]`` is candidate ``c``.
        k: Number of centers requested.
        center_weights: Dict mapping candidate index to a non-negative weight.

    Returns:
        A list of selected candidate indices.  It is shorter than ``k`` when
        fewer than ``k`` candidates exist.

    Raises:
        ValueError: If ``center_weights`` is empty.
    """
    candidates = list(center_weights.keys())
    if not candidates:
        raise ValueError("center_weights must not be empty")

    n_candidates = len(candidates)
    weights = np.array([center_weights[c] for c in candidates], dtype=float)
    points = np.array([dataset[c] for c in candidates], dtype=float).reshape(n_candidates, -1)

    first = int(np.random.choice(n_candidates, p=weights / weights.sum()))
    selected = [first]
    is_selected = np.zeros(n_candidates, dtype=bool)
    is_selected[first] = True

    # Running squared distance from every candidate to its nearest chosen center.
    min_sq_dist = ((points - points[first]) ** 2).sum(axis=1)

    for _ in range(k - 1):
        if is_selected.all():
            # Fewer candidates than requested centers: return what exists.
            break

        scores = np.where(is_selected, 0.0, weights * min_sq_dist)
        total = scores.sum()
        if total > 0:
            nxt = int(np.random.choice(n_candidates, p=scores / total))
        else:
            # All remaining candidates have zero weight or coincide with a
            # chosen center; fall back to a uniform draw among them.
            nxt = int(np.random.choice(np.flatnonzero(~is_selected)))

        selected.append(nxt)
        is_selected[nxt] = True
        np.minimum(min_sq_dist, ((points - points[nxt]) ** 2).sum(axis=1), out=min_sq_dist)

    return [candidates[position] for position in selected]


# # Cost of kmeans in Weak Strong Oracle Model 

# In[16]:


# NOTE(review): total_points is assigned here but not read anywhere in this cell.
total_points = 50000  # vary the total_points
delta = 0.3       # vary the value of delta for other results

# Per-run results; the totals and averages are derived after the loop.
final_costs = []
edge_query_counts = []
point_query_counts = []

for i in range(5):
    print(f"\n--- Iteration {i+1} ---")

    # Stage 1: over-sample candidate centers using the perturbed distances.
    C_i, no_so_edge_queries, no_so_point_queries = perturbed_k_means_plus_plus(
        dataset, perturbed_distances, delta, 7
    )

    # Stage 2: weight every candidate by the number of points assigned to it.
    a, b, center_weights = perturbed_cost(dataset, C_i, perturbed_distances, delta)

    # Stage 3: reduce the weighted candidates to 7 final centers.
    final_centers = weighted_kmeans(dataset, 7, center_weights)

    # Stage 4: score the final centers with the exact cost.
    final_cost = calculate_cost(dataset, final_centers)

    final_costs.append(final_cost)
    edge_query_counts.append(no_so_edge_queries)
    point_query_counts.append(no_so_point_queries)

    # Report this run.
    print(f"Final cost: {final_cost}")
    print(f"SO edge queries: {no_so_edge_queries}")
    print(f"SO point queries: {no_so_point_queries}")

# Totals over all runs
total_final_cost = sum(final_costs)
total_so_edge_queries = sum(edge_query_counts)
total_so_point_queries = sum(point_query_counts)

# Averages over all runs
average_final_cost = total_final_cost / 5
average_so_edge_queries = total_so_edge_queries / 5
average_so_point_queries = total_so_point_queries / 5

# Print averages
print('\n=== Averages over 5 runs ===')
print('Average final cost =', average_final_cost)
print('Average SO edge queries =', average_so_edge_queries)
print('Average SO point queries =', average_so_point_queries)


# In[ ]:




