import numpy as np
from scipy.optimize import minimize
import pandas as pd
import torch
import torchvision
from torchvision import transforms
from sklearn.metrics.pairwise import cosine_similarity
import csv
import argparse

# Command-line options for the DecGLOC value-update experiment.
parser = argparse.ArgumentParser(description='DecGLOC')

parser.add_argument('--l1', type=float, default=0.01,
                    help='weight of the term tying the optimised values to the original values')
parser.add_argument('--l2', type=float, default=10,
                    help='weight of the Euclidean-norm penalty on the optimised values')
parser.add_argument('--eps', type=float, default=1,
                    help='base epsilon: scales every per-sample epsilon and is embedded in the output file name')
parser.add_argument('--k', type=int, default=5,
                    help='number of nearest neighbours (by cosine similarity) kept per sample')

# NOTE: parsed at import time; this file is meant to be run as a script.
args = parser.parse_args()



def compute_neighbors(features, k):
    """Return pairwise cosine similarities and each sample's top-k neighbours.

    Args:
        features: Array or tensor whose first axis indexes samples. All
            remaining axes (e.g. channel/height/width of an image) are
            flattened into one feature vector per sample. Any rank >= 1 is
            accepted (previously exactly 4-D input was required).
        k: Number of neighbours to keep per sample.

    Returns:
        A tuple ``(similarities, neighbors)``:
        ``similarities`` is the ``(num_samples, num_samples)`` cosine-similarity
        matrix. ``neighbors`` is a ``(num_samples, min(k, num_samples))`` integer
        array of column indices into ``similarities``, ordered from most to
        least similar. A sample's own index is normally among its neighbours,
        because its self-similarity is the maximum value.
    """
    num_samples = features.shape[0]
    flat_features = features.reshape(num_samples, -1)
    similarities = cosine_similarity(flat_features)
    # argsort is ascending; the reversed slice takes the last k columns,
    # i.e. the k largest similarities, from most to least similar.
    neighbors = np.argsort(similarities, axis=1)[:, :-k - 1:-1]
    return similarities, neighbors

def map_to_original_indices(neighbors_indices, original_indices):
    """Translate positional neighbour indices into dataset ids.

    Each entry of ``neighbors_indices`` is a position into ``original_indices``.
    When ``original_indices`` is a 1-D array, the result has the shape of
    ``neighbors_indices``, with every position replaced by the id stored there.
    """
    # A single fancy-indexing lookup maps every entry at once.
    id_table = original_indices
    return id_table[neighbors_indices]

def calculate_neighbor_difference(neighbors1, neighbors2):
    """Count, per sample, the neighbours in ``neighbors1`` missing from ``neighbors2``.

    Row ``i`` of each argument is the neighbour list of sample ``i``. Rows are
    compared as sets, so order and repeats within a row are ignored.

    Args:
        neighbors1: Per-sample neighbour lists (the "before" view), e.g. a 2-D array.
        neighbors2: Per-sample neighbour lists (the "after" view); must have
            the same number of rows as ``neighbors1``.

    Returns:
        ``np.ndarray`` of counts, where entry ``i`` is
        ``len(set(neighbors1[i]) - set(neighbors2[i]))``.

    Raises:
        ValueError: If the two inputs have a different number of rows.
    """
    if len(neighbors1) != len(neighbors2):
        # Previously a shorter neighbors2 raised IndexError and a longer one
        # was silently truncated; fail loudly and consistently instead.
        raise ValueError(
            f"neighbors1 has {len(neighbors1)} rows but neighbors2 has {len(neighbors2)}"
        )
    return np.array(
        [len(set(before) - set(after)) for before, after in zip(neighbors1, neighbors2)]
    )



def objective(values_array, values_delete, S,  eps_per_sample, eps_mean,  neighbors_indices, l1, l2):
    """Objective minimised over ``values_array`` (written as ``v`` below).

    Returns ``obj1 + obj2 + obj3`` where

    * ``obj1 = sum_i sum_{j in neighbors_indices[i]} S[i, j] * (v[i] - v[j])**2``
      for ``i < len(v)`` -- a neighbour-graph smoothness term;
    * ``obj2 = l1 * sum_i (eps_per_sample[i] / eps_mean) * (values_delete[i] - v[i])**2``
      for ``i < len(values_delete) - 1`` (the last entry does not contribute);
    * ``obj3 = l2 * ||v||_2`` -- the plain (un-squared) Euclidean norm.

    Args:
        values_array: 1-D array of values being optimised.
        values_delete: 1-D reference values compared against ``values_array``.
        S: 2-D array of pairwise weights, indexed ``S[i, j]``.
        eps_per_sample: Per-sample weights; at least ``len(values_delete) - 1`` long.
        eps_mean: Scalar that normalises ``eps_per_sample``.
        neighbors_indices: Rectangular integer array with at least
            ``len(values_array)`` rows; row ``i`` lists the neighbours of sample ``i``.
        l1: Weight of ``obj2``.
        l2: Weight of ``obj3``.

    Returns:
        The scalar objective value.
    """
    values = np.asarray(values_array)
    n = values.shape[0]

    # obj1: gather every (i, j) neighbour pair with one fancy-indexing step
    # instead of a Python double loop. scipy's minimize() with no jac calls
    # this function roughly len(values)+1 times per finite-difference gradient.
    nbrs = np.asarray(neighbors_indices)[:n]
    rows = np.arange(n)[:, np.newaxis]
    obj1 = np.sum(np.asarray(S)[rows, nbrs] * (values[:, np.newaxis] - values[nbrs]) ** 2)

    # obj2: only the first len(values_delete) - 1 samples contribute.
    m = len(values_delete[:-1])
    reference = np.asarray(values_delete)[:m]
    weights = np.asarray(eps_per_sample)[:m] / eps_mean
    obj2 = l1 * np.sum(weights * (reference - values[:m]) ** 2)

    obj3 = l2 * np.linalg.norm(values)
    return obj1 + obj2 + obj3



# Load the per-sample data values to be updated. The CSV supplies an
# 'indices' column (used below to index the CIFAR-10 training set) and a
# 'data_values' column aligned row-for-row with it.
# NOTE(review): the path is an empty placeholder and must be filled in before running.
original_data_values = pd.read_csv("")

original_index = original_data_values['indices'].values
original_values = original_data_values['data_values'].values

index_array, values_array = np.array(original_index), np.array(original_values)


# Convert each image to a tensor, then map every RGB channel from [0, 1]
# to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)

# Fetch the image tensor and label of every listed sample in CSV order, so
# position i here lines up with index_array[i] and values_array[i].
samples = [trainset[idx] for idx in index_array]
features = [image for image, _ in samples]
labels = [target for _, target in samples]

features_tensor = torch.stack(features)
labels_tensor = torch.tensor(labels)


# Simulate deleting one sample: draw one of the listed dataset ids uniformly
# at random (the RNG is not seeded, so the choice varies between runs) and
# locate its position in index_array.
random_delete_index = np.random.choice(index_array)
delete_location = np.where(index_array == random_delete_index)[0]

k = args.k

# k-nearest neighbours on the full set ("wanquan" = full), reported as dataset ids.
cosine_similarity_wanquan, neighbors_indices = compute_neighbors(features_tensor, k)
mapped_indices = map_to_original_indices(neighbors_indices, index_array)
print(mapped_indices)

# Build the reduced set without the chosen sample.
# NOTE: .item() requires the chosen id to occur exactly once in index_array.
drop_at = delete_location.item()
new_features = features[:drop_at] + features[drop_at + 1:]
features_delete_tensor = torch.stack(new_features)

new_labels = labels[:drop_at] + labels[drop_at + 1:]
labels_delete_tensor = torch.tensor(new_labels)

index_delete = index_array[index_array != random_delete_index]

# k-nearest neighbours recomputed on the reduced set, reported as dataset ids.
cosine_similarity_new, neighbors_indices_new = compute_neighbors(features_delete_tensor, k)
mapped_indices_new = map_to_original_indices(neighbors_indices_new, index_delete)
print(mapped_indices_new)


# Neighbour churn: for every remaining sample, how many of its neighbour ids
# before the deletion are absent from its neighbour ids after it.
kept_mapped_indices = np.delete(mapped_indices, delete_location.item(), axis=0)
difference_per_sample = calculate_neighbor_difference(kept_mapped_indices, mapped_indices_new)
print(difference_per_sample)

# Per-sample epsilon = (full size / reduced size) * (1 + churn / 10) * base eps.
# NOTE(review): churn is divided by a hard-coded 10 rather than k (default 5),
# so with k < 10 the churn factor can never reach 2 -- confirm the intended normaliser.
size_ratio = len(index_array) / len(index_delete)
eps_per_sample = size_ratio * (1 + difference_per_sample / 10) * args.eps
print(eps_per_sample)

# Normaliser for the per-sample epsilons inside the objective.
eps_mean = np.mean(eps_per_sample)

# Signed similarity on the full set: +cosine for same-label pairs,
# -cosine for pairs whose labels differ.
labels_array = np.array(labels)
same_label = labels_array[:, np.newaxis] == labels_array[np.newaxis, :]
equal_labels = same_label.astype(int)
label_sign = 2 * equal_labels - 1
Sim = cosine_similarity_wanquan * label_sign

# Similarity-weighted average of neighbour values: sum(Sim * value) / sum(Sim).
# ("fenzi" / "fenmu" are pinyin for numerator / denominator.)
# NOTE(review): new_beta is not read anywhere later in this script.
# NOTE(review): neighbors_indices_new[:,-1] takes the last column (the k-th
# neighbour) of every row, i.e. one index per remaining sample. If the intent
# was the neighbours of a single sample (e.g. the last one), that would be
# neighbors_indices_new[-1, :] -- confirm.
# NOTE(review): neighbors_indices_new holds positions in the reduced set
# (after deletion), while Sim and values_array are indexed by the full set,
# so positions at or past delete_location point at a different sample.
# NOTE(review): the loop variable rebinds the module-level k (from args.k).
# Sim is signed, so sum_fenmu can be zero or negative.
sum_fenzi = 0
sum_fenmu = 0
for k in neighbors_indices_new[:,-1]:
    sum_fenzi = sum_fenzi + Sim[-1,k] * values_array[k]
    sum_fenmu = sum_fenmu + Sim[-1,k]
new_beta = sum_fenzi/sum_fenmu
# Deletion variant: the reference vector passed to objective() is the
# unmodified full values_array, which is also the starting point.
values_delete = values_array

# No jac is supplied, so SciPy's default method (BFGS here, since there are
# no bounds or constraints) estimates gradients by finite differences.
# NOTE(review): eps_per_sample has one entry per *remaining* sample
# (len(index_delete)), but objective() pairs eps_per_sample[i] with
# values_array[i] from the *full* set, so weights at or past delete_location
# are shifted by one position -- confirm this alignment is intended.
optimizer_args = (values_delete, Sim, eps_per_sample, eps_mean, neighbors_indices, args.l1, args.l2)
result = minimize(objective, values_array, args=optimizer_args)



optimized_beta = result.x
print("Optimized beta values:", optimized_beta)
print(optimized_beta.shape)

# argparse does not run `type` on non-string defaults, so the --eps default
# stays the int 1 while an explicit --eps is parsed to float. str(args.eps)
# therefore produced "1" for the default but "1.0" for "--eps 1". Render
# integral values without ".0" so the same epsilon always maps to the same
# file name, and the default file name is unchanged.
eps_value = float(args.eps)
eps_tag = str(int(eps_value)) if eps_value.is_integer() else str(eps_value)

# One optimised value per row, in a single column with no header.
with open("cifar10-decgloc" + eps_tag + ".csv", 'w', newline='') as csvfile:
    csvwriter = csv.writer(csvfile)
    csvwriter.writerows([item] for item in optimized_beta)

print("Optimized objective value:", result.fun)




