import json
import os
import time

import faiss
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

import samplers
import smoothed_cf_diffusion as scfd

# Work in single precision by default and seed torch's global RNG so that
# repeated runs of the sampler are reproducible; all tensors live on the CPU.
torch.set_default_dtype(torch.float32)
torch.manual_seed(0)
device = torch.device("cpu")

def find_nearest_neighbors_torch(model_samples, train_samples, k=1):
    """Return pairwise distances and the indices of the k closest training points.

    Args:
        model_samples: Query points, e.g. advected model samples. Presumably
            a 2-D ``(M, D)`` tensor -- the ``dim=1`` top-k below assumes it.
        train_samples: Reference points, a ``(N, D)`` tensor.
        k: Number of neighbours to return for each query point.

    Returns:
        ``(dists, idx)``:

        - ``dists`` is the full Euclidean distance matrix from
          ``torch.cdist``, shape ``(M, N)``.
        - ``idx`` is an ``(M, k)`` tensor of row indices into
          ``train_samples``, ordered from nearest to farthest.
    """
    dists = torch.cdist(model_samples, train_samples)
    # topk with largest=False picks the k smallest distances, sorted ascending.
    _, idx = torch.topk(dists, k, dim=1, largest=False)
    return dists, idx

def _load_images(images_dir):
    """Load every file in ``images_dir`` as one row of a float32 tensor.

    Files are visited in sorted name order. ``os.listdir`` order is
    arbitrary, and a stable order keeps row indices (and therefore the saved
    nearest-neighbour results) reproducible across runs. Sub-directories are
    skipped. Pixel values are scaled from [0, 255] to [0, 1].

    Args:
        images_dir: Directory holding the training images.

    Returns:
        Tensor of shape ``(N, C*H*W)`` and dtype float32. Each row is one
        image flattened in channel-first (C, H, W) order. All images must
        share the same size and channel count for the final ``torch.cat``.

    Raises:
        FileNotFoundError: If ``images_dir`` contains no regular files.
    """
    rows = []
    for name in tqdm(sorted(os.listdir(images_dir))):
        path = os.path.join(images_dir, name)
        if not os.path.isfile(path):
            continue
        # ``with`` releases the file handle PIL keeps open after ``open``.
        with Image.open(path) as img:
            pixels = np.asarray(img, dtype=np.float32) / 255.0
        if pixels.ndim == 2:
            # Single-channel image: add an explicit channel axis.
            pixels = pixels[:, :, None]
        # (H, W, C) -> (C, H, W), then flatten into a single row.
        rows.append(torch.from_numpy(pixels).permute(2, 0, 1).reshape(1, -1))
    if not rows:
        raise FileNotFoundError(f"No image files found in {images_dir!r}")
    return torch.cat(rows, dim=0)


def _generate_samples(train_sampler, velocity_field, space_dims, training_samples,
                      n_model_samples, samples_per_batch, step_size, T):
    """Advect ``n_model_samples`` samples in batches and time the loop.

    If ``n_model_samples`` is not a multiple of ``samples_per_batch``, the
    final batch is smaller, so exactly ``n_model_samples`` samples are drawn.

    Returns:
        ``(samples, elapsed)``: the outputs of every ``scfd.advect_samples``
        call concatenated along dim 0, and the wall-clock seconds spent in
        the loop.
    """
    batches = []
    start = time.perf_counter()
    for offset in tqdm(range(0, n_model_samples, samples_per_batch)):
        batch_size = min(samples_per_batch, n_model_samples - offset)
        # advect_samples also returns a second value that is not used here.
        batch, _ = scfd.advect_samples(
            train_sampler,
            velocity_field,
            space_dims=space_dims,
            step_size=step_size,
            n_model_samples=batch_size,
            T=T,
            train_samples=training_samples,
            display=False,
        )
        batches.append(batch)
    elapsed = time.perf_counter() - start
    return torch.cat(batches, dim=0), elapsed


def main():
    """Generate butterfly images with a smoothed closed-form diffusion model.

    Steps:
      1. Load ``train_butterflies/`` and flatten each image. Centre the set
         and rescale it so every training point lies inside the unit ball.
      2. Advect ``n_model_samples`` samples through the velocity field
         built from ``scfd.SmoothedScore``.
      3. Drop samples whose nearest training image is (near-)identical,
         i.e. copies of the training data.
      4. Save the kept samples, their nearest training neighbours (both
         mapped back to pixel space) and the hyperparameters (as JSON) under
         ``butterflies_samples/``.

    Paths are relative to the current working directory.
    """
    butterflies_dir = 'train_butterflies/'
    save_dir = 'butterflies_samples'

    smoothing_sigma = 0.1
    n_points = 2
    n_model_samples = 1000  # number of samples to generate
    samples_per_batch = 1   # samples drawn per advect_samples call
    step_size = 1e-2
    T = 0.98
    # A sample counts as a copy of its nearest training image if that
    # image is within this distance. faiss IndexFlatL2 reports squared L2
    # distances, so this threshold is on squared distance.
    copy_tol = 1e-6

    butterflies = _load_images(butterflies_dir)

    # Centre the data and scale it so every training point lies in the unit ball.
    means = butterflies.mean(dim=0)
    max_norm_all = torch.norm(butterflies - means, dim=1).max()
    training_samples = ((butterflies - means) / max_norm_all).to(device)
    space_dims = training_samples.shape[1]

    # Exact (brute-force) L2 index over the normalised training set.
    exact_index = faiss.IndexFlatL2(space_dims)
    exact_index.add(training_samples.cpu().numpy())

    # Empirical distribution over the training samples.
    train_sampler = samplers.EmpiricalDist(training_samples, device=device)
    # The dataset is small, so the score is computed on every training sample.
    n_samples_per_iter = training_samples.shape[0]
    smoothed_score = scfd.SmoothedScore(
        space_dims=space_dims,
        target_sampler=train_sampler,
        n_samples=n_samples_per_iter,
        smoothing_sigma=smoothing_sigma,
        n_points=n_points,
    )
    velocity_field = scfd.VelFromScore(smoothed_score)

    model_samples, elapsed = _generate_samples(
        train_sampler, velocity_field, space_dims, training_samples,
        n_model_samples, samples_per_batch, step_size, T,
    )
    print('Time to generate model samples: ', elapsed)

    # A single exact search yields both the copy filter and the neighbours
    # to save. Column 0 is taken (not squeeze()) so that a single result
    # still indexes as a 1-D mask/row list.
    sq_dists, nn_idx = exact_index.search(model_samples.detach().cpu().numpy(), 1)
    keep = sq_dists[:, 0] > copy_tol
    model_samples_filtered = model_samples[torch.from_numpy(keep).to(model_samples.device)]
    n_kept = model_samples_filtered.shape[0]

    print('Number of filtered model samples: ', n_kept)
    # Guard against ZeroDivisionError when every sample was a copy.
    print('Time per filtered sample: ', elapsed / n_kept if n_kept else float('inf'))

    # Map normalised vectors back to pixel space. Move the statistics onto
    # the device of the tensors they are combined with.
    means_dev = means.to(training_samples.device)
    model_samples_filtered_denormalized = model_samples_filtered * max_norm_all + means_dev
    nn_rows = torch.from_numpy(nn_idx[keep, 0]).to(training_samples.device)
    X_nn_filtered = training_samples[nn_rows] * max_norm_all + means_dev

    # File-name suffix derived from the hyperparameters (0.1 -> "0p1").
    tag = f'smoothing_sigma_{smoothing_sigma}_npoints_{n_points}'.replace('.', 'p')
    os.makedirs(save_dir, exist_ok=True)

    torch.save(model_samples_filtered_denormalized,
               os.path.join(save_dir, f'butterflies_samples_{tag}.pt'))

    # Save all hyperparameters as real JSON (torch.save would write a pickle).
    hyperparams = {
        'smoothing_sigma': smoothing_sigma,
        'n_points': n_points,
        'n_samples_per_iter': n_samples_per_iter,
        'n_model_samples': n_model_samples,
        'T': T,
        'step_size': step_size,
    }
    with open(os.path.join(save_dir, f'hyperparams_{tag}.json'), 'w', encoding='utf-8') as f:
        json.dump(hyperparams, f, indent=2)

    # Save the (denormalised) nearest training neighbours of the kept samples.
    torch.save(X_nn_filtered, os.path.join(save_dir, f'butterflies_nn_{tag}.pt'))

if __name__ == "__main__":
    main()