#!/usr/bin/env python
# coding: utf-8

# In[90]:


import inspect
import logging
import math
import os
import pickle
import random
import sys
from itertools import combinations
from math import log

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tqdm
from mpl_toolkits.mplot3d import Axes3D
from sklearn.decomposition import PCA
from sklearn.decomposition import TruncatedSVD
from sklearn.manifold import TSNE
from sklearn.preprocessing import MinMaxScaler


# # Applying SVD and t-SNE to MNIST

# In[ ]:


# Load MNIST from CSV: column 0 holds the digit label, the remaining columns the pixels.
# NOTE(review): read_csv infers the first row as a header by default; if this CSV
# has no header row, the first sample is silently consumed as column names — confirm.
data = pd.read_csv('mnist_train.csv')
labels = data.iloc[:, 0].values
images = data.iloc[:, 1:].values

# Rescale each pixel column to MinMaxScaler's default [0, 1] range.
scaler = MinMaxScaler()
images_scaled = scaler.fit_transform(images)

# Linear reduction to 50 components; t-SNE then runs on this compact representation.
svd = TruncatedSVD(n_components=50, random_state=42)
images_svd_50d = svd.fit_transform(images_scaled)

# scikit-learn 1.5 deprecated TSNE's ``n_iter`` in favour of ``max_iter`` and
# later releases removed it, so passing ``n_iter`` now raises TypeError.
# Use whichever keyword the installed version actually accepts.
_tsne_iter_kw = "max_iter" if "max_iter" in inspect.signature(TSNE).parameters else "n_iter"
tsne = TSNE(
    n_components=2,
    random_state=42,
    perplexity=30,
    learning_rate=200,
    **{_tsne_iter_kw: 1000},
)
images_tsne_2d = tsne.fit_transform(images_svd_50d)

# Group row positions by label: {label: [indices into `labels`, in ascending order]}.
label_indices = {}
for idx, label in enumerate(labels):
    label_indices.setdefault(label, []).append(idx)

# Visualise the 2-D t-SNE embedding, one tab10 colour per label.
plt.figure(figsize=(8, 4))
custom_palette = sns.color_palette("tab10", 10)

sns.scatterplot(
    x=images_tsne_2d[:, 0], y=images_tsne_2d[:, 1],
    hue=labels, palette=custom_palette,
    s=60, alpha=0.7,
)
plt.title("2D t-SNE Visualization", fontsize=16)
plt.xlabel("t-SNE Component 1")
plt.ylabel("t-SNE Component 2")
# Anchor the legend's upper-left corner just right of the axes, level with its top.
plt.legend(title="Digit", bbox_to_anchor=(1.05, 1), loc='upper left')
plt.show()

# Preview the label -> indices mapping: the first five indices for each label.
print("Dictionary with labels as keys and indices of corresponding samples (preview):")
for key in label_indices:
    print("Label:", key, "Indices:", label_indices[key][:5], "...")


# # Function to Generate Corrupted Distance Matrices

# In[ ]:


# Send INFO-and-above root-logger messages to a log file (relative to the working directory).
logging.basicConfig(level=logging.INFO, filename='distance_matrix_generation.log')

def generate_distance_matrix(dataset, labels, delta, label_indices):
    """Build a pairwise Euclidean distance matrix with label-aware corruption.

    Each off-diagonal entry ``(i, j)`` independently keeps its true distance
    with probability ``1 - delta``. Otherwise it is replaced by a misleading
    distance:

    * if ``i`` and ``j`` share a label: the distance from ``i`` to a random
      sample of a *different* label;
    * if their labels differ: the distance between two random samples of a
      third label, distinct from both.

    If no eligible replacement exists (too few labels, or the third label has
    fewer than two samples), the entry keeps its true distance and is not
    counted as perturbed.

    Entries ``(i, j)`` and ``(j, i)`` are drawn independently, so the result
    is generally NOT symmetric. Randomness comes from the global ``random``
    module state, so ``random.seed`` makes runs reproducible.

    Args:
        dataset: array-like of ``n`` samples, converted to float64. Each row is
            flattened to a vector; a 1-D input is treated as ``n`` scalars.
        labels: sequence of length ``n``; ``labels[k]`` is the label of row ``k``.
        delta: probability of corrupting each off-diagonal entry.
        label_indices: mapping from label to the list of row indices with that label.

    Returns:
        An ``(n, n)`` float64 ndarray with a zero diagonal. It needs
        ``8 * n**2`` bytes, so it is only practical for modest ``n``.
    """
    # Import the tqdm *function*. Calling the module, as the file-level
    # ``import tqdm`` would, raises TypeError. The progress bar is optional.
    try:
        from tqdm import tqdm as progress
    except ImportError:
        def progress(iterable, **_kwargs):
            return iterable

    # float64 avoids unsigned-integer wraparound (e.g. uint8 pixels) on subtraction.
    data = np.asarray(dataset, dtype=np.float64)
    size_data = data.shape[0]
    if size_data == 0:
        return np.zeros((0, 0))
    data = data.reshape(size_data, -1)

    perturbed_distances = np.zeros((size_data, size_data))
    num_perturbed = 0

    # Candidate replacement labels, hoisted out of the O(n^2) loop. Keeping the
    # mapping's key order preserves the sequence of random draws. Same-label
    # pairs need a label with >= 1 sample; cross-label pairs draw two distinct
    # samples and therefore need >= 2.
    choosable_labels = [key for key in label_indices if len(label_indices[key]) >= 1]
    pairable_labels = [key for key in label_indices if len(label_indices[key]) >= 2]

    for i in progress(range(size_data), mininterval=1, maxinterval=10, desc="Processing"):
        # All true distances from sample i, computed in one vectorised pass.
        row = np.linalg.norm(data - data[i], axis=1)
        label_i = labels[i]

        for j in range(size_data):
            if i == j:
                continue  # diagonal stays 0

            if random.uniform(0, 1) < 1 - delta:
                perturbed_distances[i, j] = row[j]
                continue

            replacement = None
            if label_i == labels[j]:
                # Same label: report i as far as a sample of some other label.
                other_labels = [key for key in choosable_labels if key != label_i]
                if other_labels:
                    chosen_label = random.choice(other_labels)
                    chosen_idx = random.choice(label_indices[chosen_label])
                    replacement = row[chosen_idx]
            else:
                # Different labels: substitute a distance between two samples of a third label.
                other_labels = [
                    key for key in pairable_labels if key != label_i and key != labels[j]
                ]
                if other_labels:
                    chosen_label = random.choice(other_labels)
                    first, second = random.sample(label_indices[chosen_label], 2)
                    replacement = np.linalg.norm(data[first] - data[second])

            if replacement is None:
                # No eligible replacement: keep the true distance rather than a spurious 0.
                perturbed_distances[i, j] = row[j]
            else:
                perturbed_distances[i, j] = replacement
                num_perturbed += 1

        # Log the progress periodically
        if i % 100 == 0:
            logging.info(f"Processed {i}/{size_data} rows.")

    # The diagonal is never perturbed, so normalise by the off-diagonal count only.
    off_diagonal = size_data * (size_data - 1)
    fraction_perturbed = num_perturbed / off_diagonal if off_diagonal else 0.0
    logging.info(f"Fraction of perturbed points: {fraction_perturbed}")

    return perturbed_distances


# # Saving Corrupted Matrices

# In[ ]:


# Corrupt the 2-D t-SNE distances at the baseline noise level and persist them.
delta = 0.1
print("Generating distance matrix for images_tsne_2d...")
distance_matrix_tsne = generate_distance_matrix(
    images_tsne_2d,
    labels,
    delta,
    label_indices,
)
# NOTE: this file has no "_delta_" suffix; the loading cell special-cases delta=0.1.
tsne_path = '/home/pinki/miniconda3/distance_matrix_tsne.npy'
np.save(file=tsne_path, arr=distance_matrix_tsne)
print(f"Distance matrix for t-SNE saved to {tsne_path}")


# In[ ]:


# Corrupt the 50-D SVD distances at each noise level; one .npy file per delta.
deltas = [0.1, 0.2, 0.3]
svd_base_path = '/home/pinki/miniconda3/distance_matrix_svd'

for delta in deltas:
    print(f"Generating distance matrix for images_svd_50d with delta={delta}...")
    distance_matrix_svd = generate_distance_matrix(
        images_svd_50d, labels, delta, label_indices
    )
    # The file name encodes the noise level, e.g. distance_matrix_svd_delta_0.2.npy
    svd_path = f"{svd_base_path}_delta_{delta}.npy"
    np.save(file=svd_path, arr=distance_matrix_svd)
    print(f"Distance matrix for delta={delta} saved to {svd_path}")


# In[ ]:


# Remaining t-SNE noise levels (delta=0.1 is produced by an earlier cell).
deltas = [0.2, 0.3]
tsne_base_path = '/home/pinki/miniconda3/distance_matrix_tsne'

for delta in deltas:
    print(f"Generating distance matrix for images_tsne_2d with delta={delta}...")
    distance_matrix_tsne = generate_distance_matrix(
        images_tsne_2d, labels, delta, label_indices
    )
    # The file name encodes the noise level, e.g. distance_matrix_tsne_delta_0.3.npy
    tsne_path = f"{tsne_base_path}_delta_{delta}.npy"
    np.save(file=tsne_path, arr=distance_matrix_tsne)
    print(f"Distance matrix for delta={delta} saved to {tsne_path}")


# # Loading the Corrupted Matrices

# In[ ]:


# Reload every saved corrupted distance matrix, keyed by its delta.
tsne_base_path = '/home/pinki/miniconda3/distance_matrix_tsne'
svd_base_path = '/home/pinki/miniconda3/distance_matrix_svd'

deltas_tsne = [0.1, 0.2, 0.3]
deltas_svd = [0.1, 0.2, 0.3]

distance_matrices_tsne = {}
distance_matrices_svd = {}

for delta in deltas_tsne:
    if delta == 0.1:
        # The delta=0.1 t-SNE matrix was saved without a "_delta_" suffix.
        tsne_path = f"{tsne_base_path}.npy"
    else:
        tsne_path = f"{tsne_base_path}_delta_{delta}.npy"
    distance_matrices_tsne[delta] = np.load(tsne_path)
    print(f"Loaded distance matrix for t-SNE with delta={delta} from {tsne_path}")

# Every SVD matrix carries its delta in the file name.
for delta in deltas_svd:
    svd_path = f"{svd_base_path}_delta_{delta}.npy"
    distance_matrices_svd[delta] = np.load(svd_path)
    print(f"Loaded distance matrix for SVD with delta={delta} from {svd_path}")




