import faiss
import numpy as np
import os
import samplers
import smoothed_cf_diffusion as scfd
import torch
import time
from tqdm import tqdm

# Global torch configuration: float32 as the default floating-point dtype and
# a fixed seed for torch's default random number generator.
torch.set_default_dtype(torch.float32)
torch.manual_seed(0)

# Device used when loading and holding tensors in this script.
device = torch.device("cpu")

def find_nearest_neighbors_torch(model_samples, train_samples, k=1):
    """Locate, for every model sample, its k closest training samples.

    Args:
        model_samples: Tensor of query points, one per row.
        train_samples: Tensor of reference points, one per row.
        k: Number of nearest neighbors to report per query point.

    Returns:
        A tuple ``(D, I)`` where ``D`` is the full pairwise distance matrix
        produced by ``torch.cdist`` (one row per model sample, one column per
        training sample) and ``I`` holds, per row, the column indices of the
        ``k`` smallest distances.
    """
    # Pairwise distances between every model sample and every training sample
    pairwise_dist = torch.cdist(model_samples, train_samples)
    # Column indices of the k smallest entries in each row
    neighbor_idx = pairwise_dist.topk(k, dim=1, largest=False).indices
    return pairwise_dist, neighbor_idx

def main():
    """Generate CelebA latent samples with the kNN-smoothed closed-form diffusion model.

    Steps:
      1. Load every latent file found in ``celeba_latent_dir`` and normalize
         the latents (subtract the mean, divide by the largest centered norm).
      2. Build an HNSW Faiss index (handed to ``scfd.KNNSmoothedScore``) and an
         exact flat-L2 index (used for the nearest-neighbor checks below).
      3. Generate samples in batches with ``scfd.advect_samples`` and time it.
      4. Drop generated samples that coincide with their nearest training
         sample.
      5. Save the denormalized kept samples, the hyperparameters (as JSON),
         and the denormalized nearest training neighbor of each kept sample.

    ``celeba_latent_dir`` and ``save_dir`` are empty placeholders.
    ``celeba_latent_dir`` must be set before running (``os.listdir('')``
    raises); an empty ``save_dir`` writes into the current working directory.
    """
    import json  # local import: only needed to write the hyperparameter file

    # ---- Paths (placeholders; fill in before running) ----
    celeba_latent_dir = ''
    save_dir = ''

    # ---- Hyperparameters (defined once so the run and the saved record agree) ----
    hnsw_m = 4                    # M parameter of the Faiss HNSW index
    ef_search = 128               # efSearch parameter of the Faiss HNSW index
    n_model_samples = 10000       # Number of samples to generate using the closed-form diffusion model
    samples_per_batch = 25        # Samples generated per call to advect_samples
    n_random_samples = 300
    num_nn = 10000
    smoothing_sigma = 0.025
    n_points = 2
    step_size = 1e-2
    final_time = 0.99             # Passed to advect_samples as T
    duplicate_sq_dist_tol = 1e-6  # Squared-distance threshold below which a sample counts as a training copy

    # Load the CelebA latents. Sorting the listing makes the training-set order
    # (and therefore a seeded run) independent of the filesystem's listing order.
    # NOTE(review): torch.load unpickles its input; only point this at trusted files.
    celeba_latents = [
        torch.load(os.path.join(celeba_latent_dir, file_name), map_location=device)
        for file_name in tqdm(sorted(os.listdir(celeba_latent_dir)))
    ]
    celeba_latents = torch.stack(celeba_latents).squeeze()

    # Normalize the latents: center on the mean, then scale by the largest centered norm
    latents = celeba_latents.detach().to(device)
    means = latents.mean(dim=0)
    max_norm_all = torch.max(torch.norm(latents - means, dim=1))
    latents_normalized = (latents - means) / max_norm_all

    training_samples = latents_normalized.to(device)
    space_dims = training_samples.shape[1]

    # Faiss works on NumPy arrays; convert once and reuse for both indices
    training_samples_np = training_samples.cpu().numpy()

    # Approximate (HNSW) index used by the score estimator
    index = faiss.IndexHNSWFlat(space_dims, hnsw_m)
    index.add(training_samples_np)
    index.hnsw.efSearch = ef_search

    # Exact index used to compare generated samples against the training set
    exact_index = faiss.IndexFlatL2(space_dims)
    exact_index.add(training_samples_np)

    # Create an empirical distribution object from the training samples
    train_sampler = samplers.EmpiricalDist(training_samples, device=device)

    knn_score = scfd.KNNSmoothedScore(
        space_dims=space_dims,
        target_sampler=train_sampler,
        n_random_samples=n_random_samples,
        smoothing_sigma=smoothing_sigma,
        n_points=n_points,
        index=index,
        num_nn=num_nn,
    )
    velocity_field = scfd.VelFromScore(knn_score)

    # Split the requested sample count into batches; a final smaller batch
    # covers any remainder so exactly n_model_samples are requested in total.
    full_batches, remainder = divmod(n_model_samples, samples_per_batch)
    batch_sizes = [samples_per_batch] * full_batches
    if remainder:
        batch_sizes.append(remainder)

    # Generate model samples batch by batch and time the process
    model_samples = []
    start_time = time.perf_counter()
    for batch_size in tqdm(batch_sizes):
        # The second return value is not used by this script
        model_samples_batch, _ = scfd.advect_samples(
            train_sampler,
            velocity_field,
            space_dims=space_dims,
            step_size=step_size,
            n_model_samples=batch_size,
            T=final_time,
            train_samples=training_samples,
            display=False,
        )
        model_samples.append(model_samples_batch)
    end_time = time.perf_counter()
    elapsed = end_time - start_time
    print('Time to generate model samples: ', elapsed)
    model_samples = torch.cat(model_samples, dim=0)

    # Search for the nearest training neighbor of every model sample.
    # IndexFlatL2 reports squared Euclidean distances, so the tolerance below
    # is applied to squared distance.
    model_samples_faiss = model_samples.detach().cpu().numpy()
    sq_dists, nn_indices = exact_index.search(model_samples_faiss, 1)

    # Filter out the model samples that are identical to their nearest neighbors.
    # Column indexing (rather than squeeze) keeps the mask 1-D even for a single sample.
    keep_np = sq_dists[:, 0] > duplicate_sq_dist_tol
    keep_mask = torch.from_numpy(keep_np).to(model_samples.device)
    model_samples_filtered = model_samples[keep_mask]

    # Print the number of filtered model samples and time per filtered sample
    n_filtered = model_samples_filtered.shape[0]
    print('Number of filtered model samples: ', n_filtered)
    # Avoid a ZeroDivisionError when every sample was filtered out
    time_per_sample = elapsed / n_filtered if n_filtered else float('inf')
    print('Time per filtered sample: ', time_per_sample)

    # Denormalize the filtered model samples
    model_samples_filtered_denormalized = model_samples_filtered * max_norm_all + means

    # Save the model samples (an empty save_dir means the current directory)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(model_samples_filtered_denormalized, os.path.join(save_dir, 'celeba_latent_samples_smoothing_sigma_0p025_npoints_2_M4.pt'))

    # Save all hyperparameters to a real JSON file (previously this was a
    # torch pickle with a .json extension)
    hyperparams = {
        'smoothing_sigma': smoothing_sigma,
        'n_points': n_points,
        'num_nn': num_nn,
        'n_random_samples': n_random_samples,
        'n_model_samples': n_model_samples,
        'T': final_time,
        'step_size': step_size,
        'samples_per_batch': samples_per_batch,
        'M': hnsw_m,
        'ef_search': ef_search,
    }
    hyperparams_path = os.path.join(save_dir, 'hyperparams_smoothing_sigma_0p025_npoints_2_M4.json')
    with open(hyperparams_path, 'w', encoding='utf-8') as f:
        json.dump(hyperparams, f, indent=2)

    # Nearest training neighbors of the kept samples: reuse the exact-search
    # result computed above instead of searching the same points again.
    nn_idx_filtered = torch.from_numpy(nn_indices[:, 0][keep_np]).to(training_samples.device)
    # Denormalize
    X_nn_filtered = training_samples[nn_idx_filtered] * max_norm_all + means

    # Save the nearest neighbors
    torch.save(X_nn_filtered, os.path.join(save_dir, 'celeba_latent_nn_smoothing_sigma_0p025_npoints_2_M4.pt'))

# Run the generation pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()