#!/usr/bin/env python
# coding: utf-8

# In[3]:


import pandas as pd
from sklearn.decomposition import TruncatedSVD
from sklearn.manifold import TSNE
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import random
import seaborn as sns 
import math
from math import log
from sklearn.decomposition import PCA
from mpl_toolkits.mplot3d import Axes3D
import sys
import os
import pickle
from tqdm.notebook import tqdm 
import logging
from concurrent.futures import ProcessPoolExecutor, as_completed
from concurrent.futures import ThreadPoolExecutor, as_completed


# 
# # Applying t-SNE and SVD to MNIST

# In[4]:


# Load the MNIST training CSV; column 0 is taken as the digit label and the
# remaining columns as the pixel values.
data = pd.read_csv('/home/pinki/mnist_train.csv')
labels = data.iloc[:, 0].values
images = data.iloc[:, 1:].values

# Min-max scale the pixel columns before running the decompositions.
scaler = MinMaxScaler()
images_scaled = scaler.fit_transform(images)

# 50-component truncated SVD of the scaled images.
# NOTE(review): images_svd_50d is not consumed by the t-SNE call below, which
# embeds images_scaled directly -- confirm whether that is intended.
svd = TruncatedSVD(n_components=50, random_state=42)
images_svd_50d = svd.fit_transform(images_scaled)

# 2-D t-SNE embedding used by the clustering experiments further down.
# NOTE(review): recent scikit-learn releases rename `n_iter` to `max_iter`;
# verify against the installed version.
tsne = TSNE(
    n_components=2,
    random_state=42,
    perplexity=30,
    learning_rate=200,
    n_iter=1000,
)
images_tsne_2d = tsne.fit_transform(images_scaled)

# Map every label to the list of row indices carrying that label.
label_indices = {}
for idx, label in enumerate(labels):
    label_indices.setdefault(label, []).append(idx)

# Scatter plot of the embedding, coloured by digit.
custom_palette = sns.color_palette("tab10", 10)
plt.figure(figsize=(8, 4))
sns.scatterplot(
    x=images_tsne_2d[:, 0],
    y=images_tsne_2d[:, 1],
    hue=labels,
    palette=custom_palette,
    s=60,
    alpha=0.7,
)
plt.title("2D t-SNE Visualization", fontsize=16)
plt.xlabel("t-SNE Component 1")
plt.ylabel("t-SNE Component 2")
plt.legend(title="Digit", bbox_to_anchor=(1.05, 1), loc='upper left')
plt.show()

# Preview the first five indices stored for each label.
print("Dictionary with labels as keys and indices of corresponding samples (preview):")
for key in label_indices:
    print("Label:", key, "Indices:", label_indices[key][:5], "...")


# 
# # Loading Distance Matrix

# In[5]:


# Precomputed distance matrix for the t-SNE embedding; it is used further down
# as the perturbed (weak-oracle) distances.
# NOTE(review): the file name says delta_0.3 but the matrix is stored under
# key 0.1, and the later lookup also uses 0.1 -- confirm the intended key.
tsne_path = "/home/pinki/mnist_matrices/distance_matrix_tsne_delta_0.3.npy"
distance_matrices_tsne = {0.1: np.load(tsne_path)}


# # Strong Baseline

# In[7]:


def farthest_point_traversal_kcenter(dataset, k, max_iterations):
    """Average k-center radius of farthest-point traversal over random restarts.

    Each restart picks a uniformly random first center, then repeatedly adds
    the point whose distance to its closest chosen center is largest until
    ``k`` centers are selected. The radius of a restart is the largest
    distance from any point to its closest center. The chosen center indices
    of every restart are printed.

    Args:
        dataset: Sequence of points (array-like, one point per row).
            Distances are Euclidean norms of the coordinate differences.
        k: Number of centers to select; must satisfy ``1 <= k <= len(dataset)``.
        max_iterations: Number of random restarts to average over (>= 1).

    Returns:
        The mean of the per-restart radii.

    Raises:
        ValueError: If ``k`` is out of range or ``max_iterations`` is < 1.
    """
    n_points = len(dataset)
    if k > n_points:
        raise ValueError("Number of centers (k) cannot exceed the number of points in the dataset.")
    if k < 1:
        raise ValueError("Number of centers (k) must be at least 1.")
    if max_iterations < 1:
        raise ValueError("max_iterations must be at least 1.")

    # One row per point; flattening keeps the norm identical for points that
    # are scalars or multi-dimensional arrays.
    points = np.asarray(dataset, dtype=float).reshape(n_points, -1)

    all_max_radii = []

    for _ in range(max_iterations):
        first_center_idx = int(np.random.randint(n_points))
        centers = [first_center_idx]

        # min_distances[i] = distance from point i to its closest center so far.
        min_distances = np.linalg.norm(points - points[first_center_idx], axis=1)

        for _ in range(1, k):
            # np.argmax returns the first maximum, matching a left-to-right scan.
            farthest_idx = int(np.argmax(min_distances))
            centers.append(farthest_idx)
            # Only the newly added center can lower a point's closest distance.
            np.minimum(
                min_distances,
                np.linalg.norm(points - points[farthest_idx], axis=1),
                out=min_distances,
            )

        all_max_radii.append(float(min_distances.max()))
        print(centers)

    return sum(all_max_radii) / max_iterations


# # Strong Baseline Cost

# In[6]:


# Strong-baseline cost: value returned by the farthest-point traversal for
# 10 centers on the 2-D t-SNE embedding, using `max_iterations` restarts.
dataset = images_tsne_2d
max_iterations = 10
baseline_cost = farthest_point_traversal_kcenter(
    dataset, k=10, max_iterations=max_iterations
)
print('Baseline Cost', baseline_cost)


# 
# # Starting Weak-Greedy Ball Carving

# In[8]:


def sample_initial_centers(S, radius, dataset):
    """Greedily cover the points in ``S`` with balls of the given radius.

    The first remaining point becomes a center; every other remaining point
    whose Euclidean distance to it is at most ``radius`` is assigned to that
    center and removed. This repeats until no point is left.

    Params:
    S : List of point indices to be covered. It is emptied in place, so
        pass a copy if the caller still needs the list afterwards.
    radius : Threshold serving as the ball radius.
    dataset : Indexable collection mapping a point index to its coordinates.

    Output:
    centers : Dict mapping each center index to the list of points of S it
        covers (the center itself is not included in its own list).
    no_so_queries : Number of distance checks that fell within ``radius``.
    """
    centers = {}
    no_so_queries = 0

    # Work on a snapshot and rebuild the uncovered list on every pass instead
    # of calling list.pop(0) / list.remove(), which are O(len(S)) each.
    remaining = list(S)
    while remaining:
        center_i = remaining[0]
        covered = []
        uncovered = []
        for point in remaining[1:]:
            if np.linalg.norm(dataset[point] - dataset[center_i]) <= radius:
                # NOTE(review): only in-range checks are counted; if every
                # distance evaluation is meant to be a strong-oracle query,
                # this increment belongs outside the `if` -- confirm.
                no_so_queries += 1
                covered.append(point)
            else:
                uncovered.append(point)
        centers[center_i] = covered
        remaining = uncovered

    # Preserve the in-place contract: S is empty once all points are covered.
    S.clear()

    return centers, no_so_queries


# 
# # Proxy for distance between x and y d(x,y)

# In[9]:


def perturbed_closest_min(x, C_i, dataset, perturbed_distances):
    """Return the median perturbed distance from ``x`` to the points of ``C_i``.

    Despite the name, the value returned is the median (the upper median for
    an even number of points), not the minimum, of the looked-up distances.

    Args:
        x: Index of the query point.
        C_i: Iterable of point indices to compare ``x`` against.
        dataset: Unused; kept so existing call sites keep working.
        perturbed_distances: Lookup supporting ``perturbed_distances[x, c]``.

    Returns:
        Tuple ``(median_distance, corresponding_center, no_queries)`` where
        ``corresponding_center`` is the element of ``C_i`` attaining the
        median and ``no_queries`` is the number of lookups performed (one per
        element of ``C_i``).

    Raises:
        ValueError: If ``C_i`` is empty, since no median exists.
    """
    distances = [(perturbed_distances[x, c], c) for c in C_i]
    if not distances:
        raise ValueError("C_i must contain at least one point.")
    no_queries = len(distances)

    # Stable sort on the distance only, so ties keep their order in C_i.
    distances.sort(key=lambda item: item[0])
    median_distance, corresponding_center = distances[len(distances) // 2]

    return median_distance, corresponding_center, no_queries


# 
# # Weak-Greedy Ball Carving (k-center with weak-strong oracle)

# In[10]:


def k_center_with_oracle(dataset, k, R0, epsilon, perturbed_distances, delta, S_size):
    """Weak-greedy ball carving for k-center using weak and strong distances.

    Radii ``R = R0 * (1 + epsilon) ** l`` are tried in increasing order. For
    each radius the routine repeatedly

    * builds, for every sampled point, the ball of sampled points within
      ``2 * R`` (exact Euclidean distances on ``dataset``),
    * keeps the balls holding at least ``int(log(len(Y)) / (0.5 - delta))``
      points and picks the largest one,
    * removes every remaining point whose median perturbed distance to that
      ball (via ``perturbed_closest_min``) is at most ``4 * R``.

    Once at most ``S_size`` points remain they are covered exactly with
    ``sample_initial_centers``. The first radius at which everything is
    covered with at most ``k`` centers is returned. Progress is printed.

    Relies on ``sample_initial_centers`` and ``perturbed_closest_min`` defined
    in this file.

    Args:
        dataset: Indexable collection of point coordinates.
        k: Maximum number of centers.
        R0: Smallest radius tried.
        epsilon: Growth rate of the radius schedule; must be positive.
        perturbed_distances: Lookup supporting ``perturbed_distances[i, j]``.
        delta: Noise parameter; must be smaller than 0.5.
        S_size: Number of point indices sampled (with replacement).

    Returns:
        Tuple ``(R, no_so_edge_queries, no_wo_queries, no_so_point_queries)``
        for the first successful radius ``R``.

    Raises:
        ValueError: If ``epsilon <= 0`` or ``delta >= 0.5``.
        RuntimeError: If no radius in the schedule succeeds.
    """
    if epsilon <= 0:
        raise ValueError("epsilon must be positive.")
    p = 0.5 - delta
    if p <= 0:
        raise ValueError("delta must be smaller than 0.5.")

    n = len(dataset)
    no_so_edge_queries = 0
    no_wo_queries = 0
    no_so_point_queries = 0
    num_radii = int(math.log(n) / epsilon)

    # The same sample is reused as the starting S for every radius.
    sampled = np.random.randint(0, n, size=S_size)
    no_so_point_queries += S_size

    for l in range(num_radii):
        R = R0 * (1 + epsilon) ** l
        print("Radius:", R)
        cluster_centers = []
        Y = list(range(n))
        S = sampled

        while len(Y) > 0:
            # Points are still uncovered, so another center would be needed;
            # with k centers already placed this radius has failed.
            if len(cluster_centers) >= k:
                break

            if len(Y) <= S_size:
                # Count before the call and hand over a copy:
                # sample_initial_centers empties the list it receives, which
                # would otherwise zero the count and make a failed radius
                # look fully covered below.
                no_so_point_queries += len(Y)
                centers_remaining, returned_so_queries = sample_initial_centers(
                    list(Y), 2 * R, dataset
                )
                no_so_edge_queries += returned_so_queries

                if len(centers_remaining) + len(cluster_centers) <= k:
                    return R, no_so_edge_queries, no_wo_queries, no_so_point_queries
                break  # too many centers needed; try the next radius

            ball_threshold = int(math.log(len(Y)) / p)
            complete_balls = []
            for c in S:
                T_c = []
                for other in S:
                    if np.linalg.norm(dataset[c] - dataset[other]) <= 2 * R:
                        # NOTE(review): only in-range pairs are counted; if
                        # every distance evaluation is a strong-oracle edge
                        # query, move this increment out of the `if`.
                        no_so_edge_queries += 1
                        T_c.append(other)

                if len(T_c) >= ball_threshold:
                    complete_balls.append((c, T_c))

            num_complete_balls = len(complete_balls)
            print('number of complete balls:', num_complete_balls)

            if num_complete_balls == 0:
                break

            # max() returns the first largest ball, like a stable descending sort.
            center, ball = max(complete_balls, key=lambda item: len(item[1]))

            assigned_points = []
            for y in Y:
                dist, _, wo_queries = perturbed_closest_min(y, ball, dataset, perturbed_distances)
                no_wo_queries += wo_queries
                if dist <= 4 * R:
                    assigned_points.append(y)

            print('number of assigned points', len(assigned_points))

            cluster_centers.append(center)
            assigned = set(assigned_points)
            S = list(set(S) - assigned)
            Y = list(set(Y) - assigned)
            print('remaining points to be clustered', len(Y))

        if len(Y) == 0:
            return R, no_so_edge_queries, no_wo_queries, no_so_point_queries

    raise RuntimeError(
        "No radius in the schedule covered all points with at most k centers."
    )


# 
# # Cost of Weak Greedy Ball Carving

# In[21]:


# Experiment configuration for weak-greedy ball carving on the t-SNE embedding.
k = 10
R0 = 1.0
epsilon = 0.1
dataset = images_tsne_2d
n = len(dataset)
log_n = math.log(n)
delta = 0.3
p = (0.5 - delta)
# NOTE(review): the matrix is looked up under key 0.1 although delta is 0.3;
# confirm this is the matrix that belongs to this delta.
perturbed_distances = distance_matrices_tsne[0.1]
S_size = k*int(log_n / (p ** 2))

# Number of independent runs averaged below; keeps the loop, the averages and
# the summary message in agreement.
num_runs = 5
total_R = 0
total_so_edge_queries = 0
total_so_point_queries = 0

for i in range(num_runs):
    R, no_so_edge_queries, no_wo_queries, no_so_point_queries = k_center_with_oracle(dataset, k, R0, epsilon, perturbed_distances, delta, S_size)

    total_R += R
    total_so_edge_queries += no_so_edge_queries
    total_so_point_queries += no_so_point_queries

    # Print results of current iteration
    print(f"Iteration {i + 1}:")
    print(f"  - Cost (4 * R): {4 * R}")
    print(f"  - SO edge queries: {no_so_edge_queries}")
    print(f"  - SO point queries: {no_so_point_queries}")
    print()

# Averages after all iterations
avg_R = total_R / num_runs
avg_so_edge_queries = total_so_edge_queries / num_runs
avg_so_point_queries = total_so_point_queries / num_runs

# The message previously claimed 10 iterations although only 5 are run.
print(f"Average over {num_runs} iterations:")
print(f"  - Avg cost of k-center (4 * R): {4 * avg_R}")
print(f"  - Avg SO edge queries: {avg_so_edge_queries}")
print(f"  - Avg SO point queries: {avg_so_point_queries}")


# In[ ]:




