import importlib
import samplers
import smoothed_cf_diffusion as scfd
import torch
from tqdm import tqdm
import os
import ot
from entropy_estimators import continuous

# Re-execute the project modules imported above. Presumably this is so that
# edits made to them during an interactive session are picked up -- TODO confirm.
importlib.reload(samplers)
importlib.reload(scfd)

# Use the first CUDA GPU when PyTorch reports one, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Fix the seed of PyTorch's random number generator.
torch.manual_seed(0)
# Make float32 the default dtype for newly created floating-point tensors.
torch.set_default_dtype(torch.float32)

def generate_scfdm(training_samples,
                   smoothing_sigma, # for SCFDM
                   n_model_samples,
                   n_points # for SCFDM
                   ):
    """Build an `scfd.VelFromScore` velocity field from training samples.

    An empirical-distribution sampler over `training_samples` and a block of
    Gaussian noise scaled by `smoothing_sigma` are handed to
    `scfd.SmoothedScore`; the resulting score module is wrapped in
    `scfd.VelFromScore`. Both modules are moved to the module-level `device`.

    Args:
        training_samples: tensor indexed as (sample, dimension); only
            `shape[0]` and `shape[1]` are read here.
        smoothing_sigma: scale applied to the noise and forwarded to
            `scfd.SmoothedScore`.
        n_model_samples: size of the leading axis of the noise tensor.
        n_points: size of the third axis of the noise tensor, also forwarded
            to `scfd.SmoothedScore`.

    Returns:
        The `scfd.VelFromScore` instance, on `device`.
    """
    # The sampler is created before any noise is drawn so that the order of
    # operations (and hence any RNG consumption) stays fixed.
    empirical_sampler = samplers.EmpiricalDist(training_samples, device=device)

    n_train = training_samples.shape[0]
    dims = training_samples.shape[1]

    # Noise of shape (n_model_samples, 1, n_points, dims), drawn once here and
    # passed to the score module as `fixed_noise`.
    noise_shape = (n_model_samples, 1, n_points, dims)
    scaled_noise = torch.randn(noise_shape, device=device) * smoothing_sigma

    score_kwargs = dict(
        space_dims=dims,
        target_sampler=empirical_sampler,
        n_samples=n_train,
        smoothing_sigma=smoothing_sigma,
        n_points=n_points,
        fixed_noise=scaled_noise,
    )
    score_module = scfd.SmoothedScore(**score_kwargs).to(device)

    return scfd.VelFromScore(score_module).to(device)


def _w2_distance(source_samples, target_samples):
    """Return the W2 distance between two uniformly weighted sample sets.

    The cost matrix holds pairwise squared-Euclidean distances, so the square
    root of the optimal transport cost returned by `ot.emd2` is the
    2-Wasserstein distance.
    """
    cost = ot.dist(source_samples, target_samples, metric="sqeuclidean").to(device)
    n_source = source_samples.shape[0]
    n_target = target_samples.shape[0]
    # Uniform weights, created directly on the target device instead of being
    # allocated on the CPU and moved afterwards.
    source_weights = torch.ones(n_source, dtype=torch.float32, device=device) / n_source
    target_weights = torch.ones(n_target, dtype=torch.float32, device=device) / n_target
    return ot.emd2(source_weights, target_weights, cost, numItermax=1000000).sqrt()


def model_evaluation(model_samples,
                     manifold_samples,
                     training_samples
                     ):
    """Score generated samples against reference and training samples.

    Three quantities are computed and printed:
      * the W2 distance between `model_samples` and `manifold_samples`,
      * the W2 distance between `training_samples` and `model_samples`,
      * the entropy estimate `continuous.get_h(..., k=1, min_dist=1e-6)` of
        `model_samples`.

    Args:
        model_samples: tensor of generated samples, indexed as
            (sample, dimension).
        manifold_samples: tensor of reference samples.
        training_samples: tensor of training samples.

    Returns:
        Tuple `(w2_model_manifold, w2_training_model, entropy_model_samples)`.
        The two W2 values are whatever `ot.emd2(...).sqrt()` yields; the
        entropy is whatever `continuous.get_h` returns.
    """
    # Measure W2 distance between model samples and manifold samples
    w2_model_manifold = _w2_distance(model_samples, manifold_samples)
    print("W2 between manifold samples and model samples:", str(w2_model_manifold))

    # Measure W2 distance between model samples and training samples
    w2_training_model = _w2_distance(training_samples, model_samples)
    print("W2 between training samples and model samples:", str(w2_training_model))

    # Measure entropy of model samples. `detach()` is required because
    # `Tensor.numpy()` raises on tensors that carry autograd history; it is a
    # no-op for tensors that do not.
    entropy_model_samples = continuous.get_h(model_samples.detach().cpu().numpy(), k=1, min_dist=1e-6)
    print("Entropy of model samples:", str(entropy_model_samples))

    return w2_model_manifold, w2_training_model, entropy_model_samples

# Run model_evaluation over a range of smoothing sigmas

def _as_float_tensor(value):
    """Return `value` as a detached float32 tensor.

    `torch.as_tensor` is used instead of `torch.tensor` because the latter
    emits a copy-construct warning when handed an existing tensor.
    """
    return torch.as_tensor(value).detach().float()


def model_evaluation_over_sigma(training_samples,
                                manifold_samples,
                                n_sigma,
                                *,
                                n_points=10,
                                num_repeat=10,
                                step_size=1e-2
                                ):
    """Evaluate the model for `n_sigma` smoothing values spaced evenly in [0, 1].

    For each sigma a velocity field is built with `generate_scfdm`, samples
    are produced `num_repeat` times with `scfd.advect_samples`, each batch is
    scored with `model_evaluation`, and the three metrics are averaged over
    the repeats. Per-repeat and averaged metrics are printed.

    Args:
        training_samples: tensor of training samples, indexed as
            (sample, dimension).
        manifold_samples: tensor of reference samples; its first dimension
            sets how many model samples are generated.
        n_sigma: number of sigma values produced by `torch.linspace(0, 1, n_sigma)`.
        n_points: forwarded to `generate_scfdm` (previously hard-coded to 10).
        num_repeat: number of sampling/evaluation repeats averaged per sigma
            (previously hard-coded to 10).
        step_size: forwarded to `scfd.advect_samples` (previously hard-coded
            to 1e-2).

    Returns:
        Tuple `(w2_model_manifold, w2_training_model, entropy_model_samples,
        sigma_tensor)`; the first three are the per-sigma averages stacked
        along dimension 0, the last is the tensor of sigma values.

    Raises:
        ValueError: if `num_repeat` is smaller than 1.
    """
    if num_repeat < 1:
        raise ValueError("num_repeat must be at least 1")

    # Construct list of sigmas
    sigma_tensor = torch.linspace(0.0, 1.0, n_sigma)

    # Construct lists to store results
    w2_model_manifold_list = []
    w2_training_model_list = []
    entropy_model_samples_list = []

    # NOTE(review): unlike the sampler in generate_scfdm, this one is built
    # without `device=device` -- confirm that is intended.
    train_sampler = samplers.EmpiricalDist(training_samples)

    # Loop-invariant quantities.
    n_model_samples = manifold_samples.shape[0]
    space_dims = training_samples.shape[1]

    # Loop over sigmas
    for sigma in tqdm(sigma_tensor):
        sigma_value = float(sigma)
        print("sigma:", sigma_value)
        # Generate SCFDM
        vel_field = generate_scfdm(training_samples,
                                   sigma_value,
                                   n_model_samples=n_model_samples,
                                   n_points=n_points
                                   )

        w2_model_manifold_sum = 0.0
        w2_training_model_sum = 0.0
        entropy_model_samples_sum = 0.0
        for _ in range(num_repeat):
            # Generate model samples; the second return value is not used here.
            model_samples, _ = scfd.advect_samples(train_sampler,
                                                   vel_field,
                                                   space_dims=space_dims,
                                                   step_size=step_size,
                                                   n_model_samples=n_model_samples,
                                                   display=False
                                                   )

            # Evaluate model
            w2_model_manifold, w2_training_model, entropy_model_samples = model_evaluation(
                model_samples,
                manifold_samples,
                training_samples
            )
            w2_model_manifold_sum = w2_model_manifold_sum + w2_model_manifold
            w2_training_model_sum = w2_training_model_sum + w2_training_model
            entropy_model_samples_sum = entropy_model_samples_sum + entropy_model_samples

        w2_model_manifold_avg = w2_model_manifold_sum / num_repeat
        w2_training_model_avg = w2_training_model_sum / num_repeat
        entropy_model_samples_avg = entropy_model_samples_sum / num_repeat
        print("Average W2 between manifold samples and model samples:", str(w2_model_manifold_avg))
        print("Average W2 between training samples and model samples:", str(w2_training_model_avg))
        print("Average entropy of model samples:", str(entropy_model_samples_avg))

        # Append results to lists
        w2_model_manifold_list.append(_as_float_tensor(w2_model_manifold_avg))
        w2_training_model_list.append(_as_float_tensor(w2_training_model_avg))
        entropy_model_samples_list.append(_as_float_tensor(entropy_model_samples_avg))

    return (torch.stack(w2_model_manifold_list, dim=0),
            torch.stack(w2_training_model_list, dim=0),
            torch.stack(entropy_model_samples_list, dim=0),
            sigma_tensor)

        

def main():
    """Run the sigma sweep for each configured dataset and save the results.

    For every dataset name, manifold and training samples are drawn from
    `samplers.Toy2DDist`, `model_evaluation_over_sigma` is run on them, and
    the four returned tensors are written as `.pt` files under
    `results/2d_generalization/`.
    """
    # Alternative dataset list (disabled):
    # ['helix', '8gaussians', 'pinwheel', '2spirals', 'checkerboard', 'rings', 'circles', 'moons']
    dataset_names = ['checkerboard']
    n_sigma = 51
    n_manifold = 5000  # for manifold approximation
    n_train = 500  # training set

    os.makedirs('results/2d_generalization', exist_ok=True)

    for dataset_name in dataset_names:
        print("Dataset:", dataset_name)

        # Draw the manifold samples first, then the training samples.
        toy_sampler = samplers.Toy2DDist(dist_name=dataset_name)
        manifold_samples = toy_sampler.sample(n_manifold).to(device)
        training_samples = toy_sampler.sample(n_train).to(device)

        # Evaluate model over sigma
        (w2_model_manifold,
         w2_training_model,
         entropy_model_samples,
         sigma_tensor) = model_evaluation_over_sigma(training_samples,
                                                     manifold_samples,
                                                     n_sigma
                                                     )

        # Save results as .pt files, one per returned tensor.
        path_prefix = 'results/2d_generalization/' + dataset_name
        outputs = (
            (w2_model_manifold, '_w2_model_manifold.pt'),
            (w2_training_model, '_w2_training_model.pt'),
            (entropy_model_samples, '_entropy_model_samples.pt'),
            (sigma_tensor, '_sigma_tensor.pt'),
        )
        for tensor, suffix in outputs:
            torch.save(tensor, path_prefix + suffix)

# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()