#!/usr/bin/env python
# coding: utf-8

# In[3]:


import pandas as pd
from sklearn.decomposition import TruncatedSVD
from sklearn.manifold import TSNE
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import random
import seaborn as sns 
import math
from math import log
from sklearn.decomposition import PCA
from mpl_toolkits.mplot3d import Axes3D
import sys
import os
import pickle
from tqdm.notebook import tqdm 
import logging
from concurrent.futures import ProcessPoolExecutor, as_completed
from concurrent.futures import ThreadPoolExecutor, as_completed


# 
# # Generating SBM Data

# In[6]:


def sbm_data(k, d, total_points):
    """
    Generate k Gaussian clusters in d dimensions.

    Cluster i is drawn from a normal distribution with identity covariance
    whose mean is zero everywhere except coordinate i, which is 10**5.
    Points are split evenly between clusters; the remainder of
    total_points / k is added to the last cluster.

    Params:
    k : Number of clusters (1 <= k <= d, because cluster i uses axis i)
    d : Dimension of every point
    total_points : Total number of points to generate

    Output:
    data : Array of shape (total_points, d), clusters stacked in label order
    labels : Integer array of length total_points holding the cluster index

    Raises:
    ValueError : If k < 1, k > d or total_points < 0
    """
    if k < 1:
        raise ValueError("Number of clusters (k) must be at least 1.")
    if k > d:
        # mu[i] below indexes coordinate i, so every cluster needs its own axis.
        raise ValueError("Number of clusters (k) cannot exceed the dimension (d).")
    if total_points < 0:
        raise ValueError("total_points cannot be negative.")

    base_points, extra_points = divmod(total_points, k)
    points_per_cluster = [base_points] * k
    points_per_cluster[-1] += extra_points
    print("Number of points in the last cluster:", points_per_cluster[-1])

    identity_cov = np.eye(d)  # shared by every cluster, so build it once
    clusters = []
    for i, cluster_size in enumerate(points_per_cluster):
        mu = np.zeros(d)
        mu[i] = 10**5
        clusters.append(
            np.random.multivariate_normal(mean=mu, cov=identity_cov, size=cluster_size)
        )

    data = np.vstack(clusters)
    labels = np.repeat(np.arange(k), points_per_cluster)

    return data, labels

# Generate the 10000-point, 7-cluster, 7-dimensional data set used below.
data, labels = sbm_data(7, 7, 10000)
# Sanity check on two neighbouring rows; the label names the rows actually compared.
print("Distance between data[101] and data[100]:", np.linalg.norm(data[101] - data[100]))
print("Labels:\n", labels)

# Project the data onto its first three principal components for plotting.
pca = PCA(n_components=3)
reduced_data = pca.fit_transform(data)

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
# One scatter call per cluster so each cluster gets its own legend entry.
for i in range(7):
    cluster_points = reduced_data[labels == i]
    ax.scatter(cluster_points[:, 0], cluster_points[:, 1], cluster_points[:, 2], label=f'Cluster {i}')

# Labels and title
ax.set_xlabel('PCA Component 1')
ax.set_ylabel('PCA Component 2')
ax.set_zlabel('PCA Component 3')
ax.set_title('3D PCA Visualization of 7D SBM Data Points with 7 Clusters')
ax.legend()
plt.show()


# 
# # Loading Distance Matrix

# In[9]:


# delta value of the precomputed matrix to load (it also appears in the file name).
delta = 0.3
tsne_path = "/home/pinki/sbm_10k/distance_matrix_10000_delta_0.3.npy"

# Loaded distance matrices, keyed by their delta value.
distance_matrices_tsne = {}
distance_matrices_tsne[delta] = np.load(tsne_path)


# # Strong Baseline

# In[10]:


def farthest_point_traversal_kcenter(dataset, k, max_iterations):
    """
    Average k-center cost of farthest-point traversal over several random starts.

    Each iteration picks a uniformly random first center, then repeatedly adds
    the point whose distance to its nearest chosen center is largest until k
    centers are chosen. The cost of an iteration is the largest distance from
    any point to its nearest center. The chosen center indices are printed
    after every iteration.

    Params:
    dataset : Sequence/array of n points (Euclidean distance is used)
    k : Number of centers (1 <= k <= n)
    max_iterations : Number of independent random restarts (>= 1)

    Output:
    Mean of the per-iteration maximum radii

    Raises:
    ValueError : If k > n, k < 1 or max_iterations < 1
    """
    n_points = len(dataset)
    if k > n_points:
        raise ValueError("Number of centers (k) cannot exceed the number of points in the dataset.")
    if k < 1:
        raise ValueError("Number of centers (k) must be at least 1.")
    if max_iterations < 1:
        raise ValueError("max_iterations must be at least 1.")

    # One row per point so distances to a center can be computed in a single call.
    points = np.asarray(dataset, dtype=float).reshape(n_points, -1)

    all_max_radii = []

    for _ in range(max_iterations):
        centers = [int(np.random.randint(n_points))]

        # min_distances[i] = distance from point i to its nearest chosen center.
        min_distances = np.linalg.norm(points - points[centers[0]], axis=1)

        for _ in range(1, k):
            next_center = int(np.argmax(min_distances))
            centers.append(next_center)
            # Only the newest center can shrink a point's nearest-center distance.
            np.minimum(
                min_distances,
                np.linalg.norm(points - points[next_center], axis=1),
                out=min_distances,
            )

        # With all k centers folded in, the largest entry is the covering radius.
        all_max_radii.append(float(min_distances.max()))
        print(centers)

    return sum(all_max_radii) / max_iterations


# # Cost of Strong Baseline

# In[13]:


# Average the farthest-point baseline over 20 random restarts with 7 centers.
max_iterations = 20
dataset = data
baseline_cost = farthest_point_traversal_kcenter(dataset, k=7, max_iterations=max_iterations)
print('Baseline Cost', baseline_cost)


# 
# # Starting Weak-Greedy Ball Carving

# In[16]:


def sample_initial_centers(S, radius, dataset):
    """
    Greedily cover the points in S with balls of the given radius.

    The first remaining point of S becomes a center, every other remaining
    point within `radius` (Euclidean distance) of it is attached to that
    center, and the process repeats on what is left.

    Params:
    S : List of point indices to be covered. It is consumed in place and is
        empty when the function returns; pass a copy to keep the original.
    radius : Threshold serving as radius
    dataset : Indexable collection of points; dataset[i] is the point with index i

    Output:
    centers : Dict mapping each center index to the list of indices it covers
              (the center itself is not included in its own list)
    no_so_queries : Number of points that were found within `radius` of a center
    """
    centers = {}
    no_so_queries = 0
    while S:
        center_i = S.pop(0)
        center_point = dataset[center_i]

        # Single pass: split the remaining points into covered / still uncovered
        # instead of calling S.remove() once per covered point.
        covered = []
        uncovered = []
        for point in S:
            if np.linalg.norm(dataset[point] - center_point) <= radius:
                covered.append(point)
            else:
                uncovered.append(point)

        no_so_queries += len(covered)
        centers[center_i] = covered
        S[:] = uncovered  # slice assignment keeps the caller's list in sync

    return centers, no_so_queries


# 
# # Proxy for distance between x and y d(x,y)

# In[19]:


def perturbed_closest_min(x, C_i, dataset, perturbed_distances):
    """
    Estimate the distance from x to the ball C_i using perturbed distances.

    Despite the name, the value returned is the median (upper median for an
    even number of candidates) of perturbed_distances[x, c] over c in C_i,
    not the minimum. Ties keep the order in which points appear in C_i.

    Params:
    x : Index of the query point
    C_i : Non-empty collection of point indices to compare against
    dataset : Unused; kept so existing callers do not need to change
    perturbed_distances : Matrix indexed as perturbed_distances[x, c]

    Output:
    median_distance : Median perturbed distance from x to the points of C_i
    corresponding_center : Index in C_i that attains that median
    no_queries : Number of perturbed distances read (one per point of C_i)

    Raises:
    ValueError : If C_i is empty
    """
    candidates = list(C_i)
    if not candidates:
        raise ValueError("C_i must contain at least one point.")

    distances = [(perturbed_distances[x, c], c) for c in candidates]
    no_queries = len(distances)

    # Sort on the distance only so equal distances keep their C_i order.
    distances.sort(key=lambda item: item[0])
    median_distance, corresponding_center = distances[no_queries // 2]

    return median_distance, corresponding_center, no_queries


# 
# # Weak-Greedy Ball Carving

# In[22]:


def k_center_with_oracle(dataset, k, R0, epsilon, perturbed_distances, delta, S_size):
    """
    Weak-greedy ball carving for k-center over a geometric sequence of radii.

    A sample of S_size point indices is drawn once (with replacement). For
    each radius R = R0 * (1 + epsilon)**l the routine repeatedly:
      1. finds the "complete" balls of the sample, i.e. sample points with at
         least int(log(|Y|) / p) sample points within true distance 2R,
      2. takes the largest complete ball and assigns to it every unclustered
         point whose median perturbed distance to the ball is at most 4R,
    until the unclustered set Y is empty, k clusters are used up, or at most
    S_size points remain, in which case the rest is covered directly with
    sample_initial_centers at radius 2R.

    Params:
    dataset : Indexable collection of n points (true Euclidean distances)
    k : Maximum number of centers
    R0 : Smallest radius tried
    epsilon : Multiplicative step between consecutive radii (> 0)
    perturbed_distances : Matrix indexed as perturbed_distances[x, c]
    delta : Used only through p = 0.5 - delta in the completeness threshold
    S_size : Number of sampled points

    Output:
    (R, no_so_edge_queries, no_wo_queries, no_so_point_queries) for the first
    radius at which every point is covered with at most k centers, or None if
    none of the int(log(n) / epsilon) radii succeeds.
    """
    n = len(dataset)
    p = 0.5 - delta
    no_so_edge_queries = 0
    no_wo_queries = 0
    no_so_point_queries = 0
    num_radius_levels = int(math.log(n) / epsilon)

    # The sample is drawn once and reused, unchanged, at every radius level.
    #S_size = int((k * math.log(len(Y))) / p)
    T = np.random.randint(0, n, size=S_size)
    no_so_point_queries += S_size

    for l in range(num_radius_levels):
        R = R0 * (1 + epsilon) ** l
        print("Radius:", R)
        cluster_centers = []
        Y = list(range(n))
        S = T

        while len(Y) > 0:
            # Never carve a (k+1)-th ball: with k centers in use and points
            # still uncovered, this radius has failed.
            if len(cluster_centers) >= k:
                break

            if len(Y) <= S_size:
                # Few points left: cover them directly. Count them first and
                # hand over a copy, because sample_initial_centers empties the
                # list it is given.
                no_so_point_queries += len(Y)
                centers_remaining, returned_so_queries = sample_initial_centers(list(Y), 2 * R, dataset)
                no_so_edge_queries += returned_so_queries

                if (len(centers_remaining) + len(cluster_centers)) <= k:
                    return R, no_so_edge_queries, no_wo_queries, no_so_point_queries
                break  # too many centers needed; Y stays non-empty -> next radius

            # Y does not change while the balls are built, so compute this once.
            completeness_threshold = int(math.log(len(Y)) / p)
            complete_balls = []
            for c in S:
                T_c = []
                for other in S:
                    if np.linalg.norm(dataset[c] - dataset[other]) <= 2 * R:
                        # NOTE(review): only pairs found within 2R are counted,
                        # not every distance evaluated - confirm this is intended.
                        no_so_edge_queries += 1
                        T_c.append(other)

                if len(T_c) >= completeness_threshold:
                    complete_balls.append((c, T_c))

            num_complete_balls = len(complete_balls)
            print('number of complete balls:', num_complete_balls)

            if num_complete_balls == 0:
                break

            # Largest ball wins; max() keeps the first one on ties, matching a
            # stable descending sort.
            center, ball = max(complete_balls, key=lambda item: len(item[1]))

            assigned_points = []
            for y in Y:
                dist, _, wo_queries = perturbed_closest_min(y, ball, dataset, perturbed_distances)
                no_wo_queries += wo_queries
                if dist <= 4 * R:
                    assigned_points.append(y)

            print('number of assigned points',len(assigned_points))

            cluster_centers.append(center)
            S = list(set(S) - set(assigned_points))
            Y = list(set(Y) - set(assigned_points))
            print('remaining points to be clustered', len(Y))

        if len(Y) == 0:
            return R, no_so_edge_queries, no_wo_queries, no_so_point_queries

    # No radius in the schedule covered every point with at most k centers.
    return None


# 
# # Cost of Weak Greedy Ball Carving

# In[35]:


# Experiment parameters for weak-greedy ball carving.
k = 7
R0 = 1.0
epsilon = 0.1
n = len(dataset)
log_n = math.log(n)
delta = 0.3
p = (0.5 - delta)
perturbed_distances = distance_matrices_tsne[delta]
S_size = k*int(log_n / (p ** 2))

# Number of independent runs; used for the loop, the averages and the summary
# line so the three can no longer disagree.
num_runs = 5
total_R = 0
total_so_edge_queries = 0
total_so_point_queries = 0

for i in range(num_runs):
    R, no_so_edge_queries, no_wo_queries, no_so_point_queries = k_center_with_oracle(dataset, k, R0, epsilon, perturbed_distances, delta, S_size)

    total_R += R
    total_so_edge_queries += no_so_edge_queries
    total_so_point_queries += no_so_point_queries

    # Print results of current iteration
    print(f"Iteration {i + 1}:")
    print(f"  - Cost (4 * R): {4 * R}")
    print(f"  - SO edge queries: {no_so_edge_queries}")
    print(f"  - SO point queries: {no_so_point_queries}")
    print()

# Averages after all iterations
avg_R = total_R / num_runs
avg_so_edge_queries = total_so_edge_queries / num_runs
avg_so_point_queries = total_so_point_queries / num_runs

print(f"Average over {num_runs} iterations:")
print(f"  - Avg cost of k-center (4 * R): {4 * avg_R}")
print(f"  - Avg SO edge queries: {avg_so_edge_queries}")
print(f"  - Avg SO point queries: {avg_so_point_queries}")


# In[ ]:




