import json
import torch
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from PIL import Image
import os
from tqdm.auto import tqdm
import numpy as np
from torchvision.utils import make_grid
from matplotlib import pyplot as plt
from matplotlib import image as mpimg
from utils import *
from attacks import *
from imagenet import imagenet_classes, dollar_street_classes
import seaborn as sns
from googletrans import Translator
from imagenet import imagenet_classes
from country_list import original_country_list, translated_country_list
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler, autocast
from sklearn.cluster import KMeans

# Dataset root (one sub-directory per ImageNet class) and output directory name.
imagenet_folder = "/home/ubuntu/datasets/imagenet/imagenet_images/"
output_folder = "results"

# Raw dataset without transforms; split-specific transforms are attached below.
imagenet_dataset = custom_imagenet_dataset(imagenet_folder, transform=None, country=False)
size = len(imagenet_dataset)

# 60% train / 20% val / remainder test; the last entry absorbs the rounding so
# the three lengths always sum to `size`.
split_length = [int(size * 0.6), int(size * 0.2)]
split_length.append(size - sum(split_length))
# NOTE(review): random_split is unseeded here, so the split differs between runs.
train_set, val_set, test_set = torch.utils.data.random_split(imagenet_dataset, split_length)

# Wrap the subsets with their transforms (val_set is rebound to the wrapped dataset).
training_set = Dataset_from_subset(train_set, data_transforms['train'])
val_set = Dataset_from_subset(val_set, data_transforms['val'])

train_dataloader = DataLoader(dataset=training_set, batch_size=32, shuffle=True, num_workers=8)
val_dataloader = DataLoader(dataset=val_set, batch_size=32, shuffle=False, num_workers=8)
# test_dataloader = DataLoader(test_set, batch_size=8, shuffle=False, num_workers=4)

# Pretrained ResNet-50 in eval mode with every parameter's requires_grad
# disabled, moved to the GPU.
device = "cuda"
model = models.resnet50(pretrained=True).eval()
model.requires_grad_(False)
model.to(device)

# Attack settings.
num_epochs = 20          # not referenced elsewhere in this script
stepsize = 0.1           # step size passed to PieAPP_GD_attack_batch
criterion = nn.CrossEntropyLoss().to(device)

best_accuracy = 0        # not referenced elsewhere in this script
perturbations = []       # not referenced elsewhere in this script

# Running counters for clean predictions.
correct, total = 0, 0

# Class sub-directories to process, and where per-class results are written.
class_folders = os.listdir(imagenet_folder)
results_path = "./images_kmeans_results/"

# For every class folder, this loop:
#   1. clusters the class's images with k-means (roughly 5 images per cluster);
#   2. picks one "anchor" image per cluster;
#   3. crafts an adversarial perturbation for the anchor with PieAPP_GD_attack_batch;
#   4. applies that same perturbation to every member of the cluster;
#   5. saves a {image name -> anchor image name} dict to <results_path>/<class index>.pt.
os.makedirs(results_path, exist_ok=True)

for class_name in class_folders:
    class_dir = os.path.join(imagenet_folder, class_name)
    # Skip hidden entries (e.g. ".DS_Store") and stray files that are not class folders.
    if class_name.startswith('.') or not os.path.isdir(class_dir):
        continue
    class_idx = imagenet_dataset.class_to_idx[class_name]
    label = torch.tensor(class_idx).to(device)
    result_file = os.path.join(results_path, str(class_idx) + ".pt")
    # Resume support: a class whose result file already exists is skipped.
    if os.path.exists(result_file):
        continue

    images = os.listdir(class_dir)
    # Only classes with more than 10 images are clustered; this also
    # guarantees at least 2 clusters below.
    if len(images) <= 10:
        continue
    # Accumulate (like `correct`) so correct / total is an overall clean accuracy.
    total += len(images)

    name_label_dict = {}
    names = []
    image_tensor = []
    for image_file in images:
        # NOTE: everything before the first "." is used as the image name.
        names.append(image_file.split(".")[0])
        # `with` releases the file handle.
        # convert("RGB") turns grayscale, palette, RGBA and CMYK images into
        # exactly 3 channels before the transform is applied.
        with Image.open(os.path.join(class_dir, image_file)) as pil_img:
            image_tensor.append(data_transforms['val'](pil_img.convert("RGB")))
    names = np.array(names)
    image_tensor = torch.stack(image_tensor)
    # Per-image shape produced by data_transforms['val'] (e.g. (3, H, W)).
    # It is used to undo the flattening below instead of hard-coding the size.
    img_shape = tuple(image_tensor.shape[1:])

    # Clean predictions for the whole class, computed once and reused for
    # anchor selection.
    pred = torch.argmax(model(norm(image_tensor.to(device))), dim=1)
    correct += float((pred == label).sum().item())
    class_pred = pred.cpu().numpy()

    # k-means runs on flattened pixel vectors.
    image_tensor = image_tensor.reshape(image_tensor.shape[0], -1)
    num_clusters = image_tensor.shape[0] // 5   # roughly 5 images per cluster
    kmeans = KMeans(n_clusters=num_clusters, random_state=0).fit(image_tensor.numpy())
    labels = kmeans.labels_
    center = kmeans.cluster_centers_
    # Number of cluster members still predicted as `label` after the anchor's
    # perturbation is added. It is accumulated per class but not reported here.
    correct_adv = 0

    for k in range(num_clusters):
        mask = labels == k
        if not mask.any():
            # k-means can leave a cluster empty (e.g. many duplicate images).
            continue
        members = image_tensor[mask]
        member_names = names[mask]
        member_pred = class_pred[mask]

        # Members ranked by Euclidean distance to the centroid, nearest first.
        ranking = np.argsort(np.linalg.norm(members.numpy() - center[k], axis=1))
        # The anchor is the nearest member the model classifies correctly.
        # If no member is classified correctly, fall back to the nearest member.
        correct_ranked = ranking[member_pred[ranking] == class_idx]
        center_image = correct_ranked[0] if len(correct_ranked) > 0 else ranking[0]
        represent_image = members[center_image].unsqueeze(0).to(device)

        attack = PieAPP_GD_attack_batch(
            model, criterion, stepsize,
            represent_image.reshape(1, *img_shape), label.unsqueeze(0), "", "",
        ).detach()
        perturbation = attack.reshape(1, -1) - represent_image
        # Restore eval mode in case the attack changed the model's mode.
        model.eval()

        # Add the anchor's perturbation to every member and classify them in one batch.
        adv_batch = (members.to(device) + perturbation).reshape(-1, *img_shape)
        adv_pred = torch.argmax(model(norm(adv_batch)), dim=1)
        correct_adv += float((adv_pred == label).sum().item())

        # Store plain `str` (not numpy.str_) so the saved file stays loadable
        # with torch.load(weights_only=True).
        anchor_name = str(member_names[center_image])
        for member_name in member_names:
            name_label_dict[str(member_name)] = anchor_name

    torch.save(name_label_dict, result_file)
               
        
# Collect the distinct anchor (cluster-representative) image names of every
# processed class into one list, and save it to <results_path>/anchor_set.pt.
anchor_set = []
# Iterate over the dataset's own class indices (the same values used to name
# the per-class result files) rather than a hard-coded count of 1000.
for class_idx in sorted(set(imagenet_dataset.class_to_idx.values())):
    class_file = os.path.join(results_path, str(class_idx) + ".pt")
    # Classes that were not clustered (10 or fewer images) have no result file.
    if os.path.isfile(class_file):
        # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True,
        # which rejects pickled numpy.str_ values. Files must hold plain str.
        name_to_anchor = torch.load(class_file)
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # output order is reproducible. Set iteration order of strings is
        # hash-randomized per process.
        anchor_set += list(dict.fromkeys(name_to_anchor.values()))

print(anchor_set, len(anchor_set))
torch.save(anchor_set, os.path.join(results_path, "anchor_set.pt"))


 