import faiss
import faiss.contrib.torch_utils
import importlib
import samplers
import smoothed_cf_diffusion as scfd
import torch
from tqdm import tqdm
import os
import ot

# Re-execute the local helper modules in place; presumably so edits to them are
# picked up when this file is re-run inside a long-lived interpreter/notebook.
for _module in (samplers, scfd):
    importlib.reload(_module)

# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Fixed global RNG seed and float32 as the default floating-point dtype.
torch.manual_seed(0)
torch.set_default_dtype(torch.float32)

def generate_scfdm(training_samples,
                   smoothing_sigma, # for SCFDM
                   n_points, # for SCFDM
                   n_model_samples=5000,
                   ):
    """Build the exact smoothed-score SCFDM velocity field for ``training_samples``.

    Args:
        training_samples: Tensor of training points indexed as
            ``(n_samples, space_dims)``; ``shape[0]`` is used as the number of
            target samples per iteration and ``shape[1]`` as the space dimension.
        smoothing_sigma: Scale applied to the fixed Gaussian noise; also passed
            to ``scfd.SmoothedScore``.
        n_points: Number of noise points per model sample (third dim of the
            fixed noise tensor).
        n_model_samples: Leading dim of the fixed noise tensor. Defaults to
            5000, the value previously hard-coded here.
            NOTE(review): presumably must equal the ``n_model_samples`` later
            passed to ``scfd.advect_samples`` — confirm.

    Returns:
        ``scfd.VelFromScore`` wrapping the smoothed score, moved to ``device``.
    """
    target_sampler = samplers.EmpiricalDist(training_samples, device=device)
    # Now construct SCFDM
    n_samples_per_iter = training_samples.shape[0]
    space_dims = training_samples.shape[1]
    # Noise is drawn once and then held fixed, shape
    # (n_model_samples, 1, n_points, space_dims), scaled by smoothing_sigma.
    fixed_noise = torch.randn(n_model_samples, 1, n_points, space_dims, device=device) * smoothing_sigma

    smoothed_score = scfd.SmoothedScore(space_dims=space_dims,
                                        target_sampler=target_sampler,
                                        n_samples=n_samples_per_iter,
                                        smoothing_sigma=smoothing_sigma,
                                        n_points=n_points,
                                        fixed_noise=fixed_noise,
                                        ).to(device)
    vel_field = scfd.VelFromScore(smoothed_score).to(device)

    return vel_field

def generate_knn_scfdm(training_samples,
                       index,
                       num_nn, # for KNN
                       n_random_samples,
                       smoothing_sigma, # for SCFDM
                       n_points, # for SCFDM
                       n_model_samples=5000,
                       ):
    """Build the kNN-approximated smoothed-score SCFDM velocity field.

    Args:
        training_samples: Tensor of training points indexed as
            ``(n_samples, space_dims)``; ``shape[1]`` is used as the space dimension.
        index: Search index over the training points, forwarded to
            ``scfd.KNNSmoothedScore`` (e.g. a faiss index built in ``main``).
        num_nn: Number of nearest neighbours forwarded to ``KNNSmoothedScore``.
        n_random_samples: Forwarded to ``KNNSmoothedScore``.
        smoothing_sigma: Scale applied to the fixed Gaussian noise; also passed
            to ``KNNSmoothedScore``.
        n_points: Number of noise points per model sample (third dim of the
            fixed noise tensor).
        n_model_samples: Leading dim of the fixed noise tensor. Defaults to
            5000, the value previously hard-coded here.
            NOTE(review): presumably must equal the ``n_model_samples`` later
            passed to ``scfd.advect_samples`` — confirm.

    Returns:
        ``scfd.VelFromScore`` wrapping the kNN smoothed score, moved to ``device``.
    """
    target_sampler = samplers.EmpiricalDist(training_samples, device=device)
    # Now construct SCFDM
    space_dims = training_samples.shape[1]
    # Noise is drawn once and then held fixed, shape
    # (n_model_samples, 1, n_points, space_dims), scaled by smoothing_sigma.
    fixed_noise = torch.randn(n_model_samples, 1, n_points, space_dims, device=device) * smoothing_sigma

    knn_smoothed_score = scfd.KNNSmoothedScore(space_dims=space_dims,
                                               target_sampler=target_sampler,
                                               n_random_samples=n_random_samples,
                                               smoothing_sigma=smoothing_sigma,
                                               n_points=n_points,
                                               fixed_noise=fixed_noise,
                                               index=index,
                                               num_nn=num_nn
                                               )
    vel_field = scfd.VelFromScore(knn_smoothed_score).to(device)

    return vel_field


def model_evaluation(model_samples_exact,
                     model_samples_knn
                     ):
    """Return the W2 distance between exact-SCFDM and kNN-SCFDM sample clouds.

    Each cloud gets uniform weights. The exact optimal-transport cost under a
    squared-Euclidean ground cost is solved with ``ot.emd2``, and its square
    root (W2) is returned.

    Args:
        model_samples_exact: Torch tensor of samples from the exact SCFDM, one
            point per row.
        model_samples_knn: Torch tensor of samples from the kNN SCFDM, one
            point per row.

    Returns:
        Scalar tensor holding the W2 distance.

    Raises:
        ValueError: If either sample set is empty (uniform weights would be 0/0).
    """
    n_exact = model_samples_exact.shape[0]
    n_knn = model_samples_knn.shape[0]
    if n_exact == 0 or n_knn == 0:
        raise ValueError("model_evaluation needs at least one sample in each set")

    # Pairwise squared-Euclidean cost matrix, shape (n_exact, n_knn).
    M = ot.dist(model_samples_exact, model_samples_knn, metric="sqeuclidean")
    # Uniform weights created directly on the cost matrix's device/dtype, so they
    # always match M regardless of where the samples live.
    a = torch.full((n_exact,), 1.0 / n_exact, dtype=M.dtype, device=M.device)
    b = torch.full((n_knn,), 1.0 / n_knn, dtype=M.dtype, device=M.device)
    # emd2 returns the optimal transport cost, i.e. W2^2 for a squared-Euclidean
    # cost; sqrt gives W2. A large iteration cap helps the solver converge on
    # thousands of points.
    w2_knn = ot.emd2(a, b, M, numItermax=1000000).sqrt()

    return w2_knn

# Construct method that runs model_evaluation for range of num_nn

def model_evaluation_over_num_nn(training_samples,
                                 sigma,
                                 n_random_samples,
                                 num_nn_list,
                                 index,
                                 num_repeat=10,
                                 n_points=100,
                                 ):
    """Average W2(exact SCFDM, kNN SCFDM) for each neighbour count in ``num_nn_list``.

    One set of exact-SCFDM model samples is generated as the reference. Then,
    for every ``num_nn``, a fresh kNN-SCFDM is built and advected ``num_repeat``
    times, and the W2 distances to the reference are averaged.

    Args:
        training_samples: Tensor of training points indexed as
            ``(n_samples, space_dims)``.
        sigma: Smoothing sigma for both SCFDM variants.
        n_random_samples: Forwarded to ``generate_knn_scfdm``.
        num_nn_list: Iterable of neighbour counts to sweep.
        index: Search index forwarded to ``generate_knn_scfdm``.
        num_repeat: Number of independent kNN models averaged per ``num_nn``
            (previously hard-coded to 10).
        n_points: ``n_points`` passed to both SCFDM constructors (previously
            hard-coded to 100).

    Returns:
        1-D float32 CPU tensor of length ``len(num_nn_list)`` with averaged W2 values.

    Raises:
        ValueError: If ``num_repeat`` is less than 1.
    """
    if num_repeat < 1:
        raise ValueError(f"num_repeat must be >= 1, got {num_repeat}")

    space_dims = training_samples.shape[1]
    step_size = 1e-2
    # NOTE(review): presumably must match the fixed-noise batch size used by
    # generate_scfdm / generate_knn_scfdm (5000 by default) — confirm.
    n_model_samples = 5000

    # Same device argument as the samplers built in generate_scfdm / generate_knn_scfdm.
    train_sampler = samplers.EmpiricalDist(training_samples, device=device)

    # Reference samples from the exact SCFDM, shared by every num_nn below.
    vel_field_exact = generate_scfdm(training_samples,
                                     sigma,
                                     n_points=n_points
                                     )
    model_samples_exact, _ = scfd.advect_samples(train_sampler,
                                                 vel_field_exact,
                                                 space_dims=space_dims,
                                                 step_size=step_size,
                                                 n_model_samples=n_model_samples,
                                                 display=False
                                                 )

    w2_over_num_nn = []
    # Loop over the neighbour counts
    for num_nn in tqdm(num_nn_list):
        print("num_nn: ", str(num_nn))
        w2_total = 0.0
        for _ in range(num_repeat):
            vel_field_knn = generate_knn_scfdm(training_samples,
                                               index,
                                               num_nn,
                                               n_random_samples,
                                               sigma,
                                               n_points=n_points
                                               )
            model_samples_knn, _ = scfd.advect_samples(train_sampler,
                                                       vel_field_knn,
                                                       space_dims=space_dims,
                                                       step_size=step_size,
                                                       n_model_samples=n_model_samples,
                                                       display=False
                                                       )
            # float() pulls the scalar to the host, so the running sum is a
            # plain Python float rather than a (possibly GPU) tensor.
            w2_knn = float(model_evaluation(model_samples_exact, model_samples_knn))
            print("w2_knn: ", str(w2_knn))
            w2_total += w2_knn

        w2_knn_avg = w2_total / num_repeat
        print("Average w2_knn: ", str(w2_knn_avg))
        w2_over_num_nn.append(w2_knn_avg)

    # Build the result once from floats; wrapping each tensor in torch.tensor()
    # triggered PyTorch's copy-construct warning and kept results on the device.
    return torch.tensor(w2_over_num_nn, dtype=torch.float32)

# Construct method that runs model_evaluation_over_num_nn for range of n_random_samples

def model_evaluation_over_n_random_samples(training_samples,
                                           sigma,
                                           n_random_samples_list,
                                           num_nn_list,
                                           index
                                           ):
    """Sweep ``model_evaluation_over_num_nn`` across several ``n_random_samples`` values.

    Returns:
        Tensor built by stacking, along a new leading dim, the result of
        ``model_evaluation_over_num_nn`` for each entry of ``n_random_samples_list``.
    """
    rows = []
    for n_rand in tqdm(n_random_samples_list):
        print("n_random_samples: ", str(n_rand))
        # One row of the result grid: W2 over every num_nn for this n_rand.
        row = model_evaluation_over_num_nn(training_samples, sigma, n_rand, num_nn_list, index)
        rows.append(row)

    return torch.stack(rows, dim=0)
        

def main():
    """Run the exact-vs-kNN SCFDM W2 sweep on each toy dataset and save the results.

    For each name in ``dataset_names`` this function:

    - draws ``n_train`` training points from ``samplers.Toy2DDist``;
    - builds an exact faiss L2 index over those points;
    - runs ``model_evaluation_over_n_random_samples`` over the
      ``n_random_samples`` x ``num_nn`` grid;
    - saves the result tensor and both grid axes as ``.pt`` files under
      ``results/knn_exact``.
    """
    #dataset_names = ['helix', '8gaussians', 'pinwheel', '2spirals', 'checkerboard', 'rings', 'circles', 'moons']
    dataset_names = ['checkerboard']
    num_nn_list = [1, 5, 10, 15, 20, 25, 50, 100]
    n_random_samples_list = [1, 5, 10, 15, 20, 25, 50, 100]
    sigma = 0.3
    n_train = 500  # training set size
    out_dir = os.path.join('results', 'knn_exact')
    os.makedirs(out_dir, exist_ok=True)

    for dataset_name in dataset_names:
        print("Dataset: ", dataset_name)
        # Draw training samples from the toy distribution
        toy_sampler = samplers.Toy2DDist(dist_name=dataset_name)
        training_samples = toy_sampler.sample(n_train).to(device)
        space_dims = training_samples.shape[1]

        # Exact (brute-force) L2 search index over the training points.
        index = faiss.IndexFlatL2(space_dims)
        index.add(training_samples.cpu().numpy())
        # Alternative approximate HNSW index; swap in for the exact one above:
        # M = 4
        # index = faiss.IndexHNSWFlat(space_dims, M)
        # index.add(training_samples.cpu().numpy())
        # index.hnsw.efSearch = 128

        w2_over_n_random_samples = model_evaluation_over_n_random_samples(training_samples,
                                                                          sigma,
                                                                          n_random_samples_list,
                                                                          num_nn_list,
                                                                          index
                                                                          )

        # Save the result grid and its two axes as .pt files.
        prefix = os.path.join(out_dir, dataset_name)
        torch.save(w2_over_n_random_samples, prefix + '_w2_over_n_random_samples.pt')
        torch.save(torch.tensor(n_random_samples_list), prefix + '_n_random_samples.pt')
        torch.save(torch.tensor(num_nn_list), prefix + '_num_nn.pt')

if __name__ == "__main__":
    main()