#!/usr/bin/env python
# coding: utf-8

# In[3]:


import pandas as pd
from sklearn.decomposition import TruncatedSVD
from sklearn.manifold import TSNE
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import random
import seaborn as sns 
import math
from math import log
from sklearn.decomposition import PCA
from mpl_toolkits.mplot3d import Axes3D
import sys
import os
import pickle
from tqdm.notebook import tqdm 
import logging



# # SBM Data Generation 

# In[2]:


def sbm_data(k, d, total_points, separation=10**5):
    """Sample ``total_points`` Gaussian points split across ``k`` separated clusters.

    Cluster ``i`` is drawn from N(mu_i, I_d), where mu_i is the zero vector
    with ``separation`` in coordinate ``i``. Points are split as evenly as
    possible; the remainder ``total_points % k`` is added to the last
    cluster. Samples come from NumPy's global RNG, so ``np.random.seed``
    makes the output reproducible.

    Args:
        k: Number of clusters. Must satisfy ``1 <= k <= d`` because each
            cluster mean occupies its own coordinate axis.
        d: Dimensionality of each point.
        total_points: Total number of points to generate.
        separation: Value placed on the active coordinate of each cluster
            mean (default ``10**5``).

    Returns:
        ``(data, labels)``: ``data`` has shape ``(total_points, d)`` with rows
        grouped by cluster in order 0..k-1; ``labels`` is the matching
        integer cluster index for each row.

    Raises:
        ValueError: If ``k < 1`` or ``k > d``.
    """
    # Without this check, k > d fails later with an opaque IndexError on
    # mu[i], and k == 0 fails with ZeroDivisionError.
    if not 1 <= k <= d:
        raise ValueError(f"k must satisfy 1 <= k <= d (got k={k}, d={d})")

    base_points = total_points // k
    extra_points = total_points % k
    points_per_cluster = [base_points] * k
    points_per_cluster[-1] += extra_points  # leftover points go to the last cluster
    print("Number of points in the last cluster:", points_per_cluster[-1])

    clusters = []
    for i in range(k):
        mu = np.zeros(d)
        mu[i] = separation
        points = np.random.multivariate_normal(mean=mu, cov=np.eye(d), size=points_per_cluster[i])
        clusters.append(points)

    data = np.vstack(clusters)
    labels = np.repeat(np.arange(k), points_per_cluster)

    return data, labels


data, labels = sbm_data(7, 7, 10000)


# Sanity check on one pair of points. With 10000 points over 7 clusters,
# indices 100 and 101 both fall in cluster 0, so this is a within-cluster distance.
print("Distance between data[101] and data[100]:", np.linalg.norm(data[101] - data[100]))

print("Labels:\n", labels)

# Project the points onto their first three principal components for plotting.
pca = PCA(n_components=3)
reduced_data = pca.fit_transform(data)

# Derive the cluster count from the labels instead of hard-coding it, so the
# plot stays correct if the sbm_data arguments above change. sbm_data labels
# clusters 0..k-1, so range(n_clusters) enumerates every cluster.
n_clusters = len(np.unique(labels))

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
for i in range(n_clusters):
    cluster_points = reduced_data[labels == i]
    ax.scatter(cluster_points[:, 0], cluster_points[:, 1], cluster_points[:, 2], label=f'Cluster {i}')

# Labels and title
ax.set_xlabel('PCA Component 1')
ax.set_ylabel('PCA Component 2')
ax.set_zlabel('PCA Component 3')
ax.set_title(f'3D PCA Visualization of {data.shape[1]}D SBM Data Points with {n_clusters} Clusters')
ax.legend()
plt.show()


# In[ ]:


# Generating corrupted distance matrices for SBM data


# In[ ]:


def generate_distance_matrix(dataset, labels, delta, show_progress=True):
    """Build an n x n Euclidean distance matrix with randomly corrupted entries.

    Each off-diagonal entry (i, j) independently keeps the true Euclidean
    distance with probability ``1 - delta``. Otherwise it is replaced in a way
    that inverts the cluster structure: pairs with the same label get
    ``10**5`` and pairs with different labels get ``1``. The diagonal is
    always 0. Entries (i, j) and (j, i) are corrupted independently, so the
    result is generally not symmetric.

    Rows are computed one at a time with vectorized NumPy operations. This
    avoids n**2 Python-level iterations while needing only O(n * d) scratch
    memory beyond the output matrix. Random draws come from NumPy's global
    RNG, so ``np.random.seed`` makes the output reproducible.

    Args:
        dataset: Array-like of shape (n, d).
        labels: Array-like of length n with each point's cluster label.
        delta: Corruption probability for each off-diagonal entry.
        show_progress: If True, wrap the row loop in a ``tqdm`` progress bar.

    Returns:
        float64 array of shape (n, n).
    """
    dataset = np.asarray(dataset, dtype=float)
    # asarray so `labels == labels[i]` is an elementwise comparison even when
    # a plain list is passed in.
    labels = np.asarray(labels)
    size_data = dataset.shape[0]
    perturbed_distances = np.zeros((size_data, size_data))

    num_perturbed = 0

    # `tqdm` is the progress-bar class itself (imported at file level), not
    # the module, so it is called directly.
    rows = tqdm(range(size_data)) if show_progress else range(size_data)
    for i in rows:
        # True distances from point i to every point, in one vectorized call.
        row = np.linalg.norm(dataset - dataset[i], axis=1)

        # An entry is kept iff u < 1 - delta with u ~ U[0, 1); corrupted otherwise.
        corrupted = np.random.random_sample(size_data) >= 1 - delta
        corrupted[i] = False  # the diagonal is never perturbed
        same_label = labels == labels[i]

        row[corrupted & same_label] = 10**5
        row[corrupted & ~same_label] = 1
        row[i] = 0.0
        perturbed_distances[i] = row
        num_perturbed += int(np.count_nonzero(corrupted))

    total_elements = size_data * size_data
    # Guard the empty-dataset case instead of dividing by zero.
    fraction_perturbed = num_perturbed / total_elements if total_elements else 0.0

    print("Fraction of perturbed points: ", fraction_perturbed)

    return perturbed_distances


# # Saving the generated matrices for future use

# In[ ]:


total_points_list = [10000, 20000, 50000, 100000]
delta_values = [0.1, 0.2, 0.3]
k = 7
d = 7

# NOTE(review): outputs are written into the miniconda install directory;
# consider a dedicated data directory.
save_dir = "/home/pinki/miniconda3/"
os.makedirs(save_dir, exist_ok=True)

for total_points in total_points_list:
    print(f"Generating and saving dataset for {total_points} points...")
    data, labels = sbm_data(k, d, total_points)
    np.save(os.path.join(save_dir, f"dataset_{total_points}.npy"), data)

    for delta in delta_values:
        print(f"Generating distance matrix for delta = {delta}...")
        distance_matrix = generate_distance_matrix(data, labels, delta)
        np.save(os.path.join(save_dir, f"distance_matrix_{total_points}_delta_{delta}.npy"), distance_matrix)
        # Each matrix holds total_points**2 float64 values (80 GB at 100000
        # points). Drop the reference now. Otherwise the previous matrix stays
        # alive while the next one is allocated, and the last one stays pinned
        # at module level after the loop.
        del distance_matrix

print("Files saved successfully.")


# # Loading the distance matrices

# In[ ]:


# Reload the datasets and corrupted distance matrices written by the
# generation cell. Results are keyed as datasets[total_points] and
# distance_matrices[total_points][delta].
# NOTE(review): np.load reads each matrix fully into memory (total_points**2
# entries each); consider mmap_mode="r" if memory is tight.
save_dir = "/home/pinki/miniconda3/"
total_points_list = [10000, 20000, 50000, 100000]
delta_values = [0.1, 0.2, 0.3]

datasets = {}
distance_matrices = {}

for total_points in total_points_list:
    print(f"Loading dataset for {total_points} points...")
    dataset_filename = os.path.join(save_dir, f"dataset_{total_points}.npy")
    data = np.load(dataset_filename)
    datasets[total_points] = data
    
    distance_matrices[total_points] = {}
    for delta in delta_values:
        if total_points == 100000 and delta != 0.1:
            continue  # For 100000 points, load only the delta = 0.1 matrix
        
        print(f"Loading distance matrix for delta = {delta} and {total_points} points...")
        distance_matrix_filename = os.path.join(save_dir, f"distance_matrix_{total_points}_delta_{delta}.npy")
        distance_matrix = np.load(distance_matrix_filename)
        distance_matrices[total_points][delta] = distance_matrix

print("Files loaded successfully.")


