import importlib
import samplers
import smoothed_cf_diffusion as scfd
import torch
from tqdm import tqdm
import os
import ot

# NOTE(review): presumably reloaded so that edits to these local modules are
# picked up when this file is re-run in an interactive session — confirm.
importlib.reload(samplers)
importlib.reload(scfd)

# Target device for the tensors/modules this script explicitly places with
# `device=` or `.to(device)`: the first CUDA GPU when available, else the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Seed torch's global RNG so the random draws in this script are repeatable.
torch.manual_seed(0)
# Floating-point tensors created without an explicit dtype default to float64.
torch.set_default_dtype(torch.float64)

def generate_scfdm(training_samples,
                   smoothing_sigma,  # for SCFDM
                   n_points,  # for SCFDM
                   *,
                   n_model_samples=5000,
                   ):
    """Build an SCFDM velocity field from a set of training samples.

    Wraps ``training_samples`` in a ``samplers.EmpiricalDist``, draws a block
    of Gaussian ``fixed_noise`` scaled by ``smoothing_sigma``, builds an
    ``scfd.SmoothedScore`` from them and returns the velocity field derived
    from that score.

    Args:
        training_samples: training data, indexed as
            ``(n_samples, space_dims)`` via ``shape[0]`` / ``shape[1]``.
        smoothing_sigma: smoothing scale; multiplies ``fixed_noise`` and is
            forwarded to ``scfd.SmoothedScore``.
        n_points: forwarded to ``scfd.SmoothedScore`` and used as the third
            dimension of ``fixed_noise``.
        n_model_samples: first dimension of ``fixed_noise`` (default 5000).
            NOTE(review): presumably must match the number of samples later
            advected with this field — confirm against ``scfd``.

    Returns:
        An ``scfd.VelFromScore`` module wrapping the smoothed score, moved to
        the module-level ``device``.
    """
    target_sampler = samplers.EmpiricalDist(training_samples, device=device)

    # training_samples is indexed as (n_samples, space_dims).
    n_samples_per_iter = training_samples.shape[0]
    space_dims = training_samples.shape[1]
    # Noise drawn once here and handed to SmoothedScore (presumably reused
    # across score evaluations, hence "fixed" — confirm in scfd.SmoothedScore).
    fixed_noise = torch.randn(n_model_samples, 1, n_points, space_dims,
                              device=device) * smoothing_sigma

    smoothed_score = scfd.SmoothedScore(space_dims=space_dims,
                                        target_sampler=target_sampler,
                                        n_samples=n_samples_per_iter,
                                        smoothing_sigma=smoothing_sigma,
                                        n_points=n_points,
                                        fixed_noise=fixed_noise,
                                        ).to(device)
    return scfd.VelFromScore(smoothed_score).to(device)


def model_evaluation(model_samples_0,
                     model_samples_T
                     ):
    """Return the empirical 2-Wasserstein (W2) distance between two sample sets.

    Solves the exact optimal-transport problem (``ot.emd2``) between uniform
    empirical measures on the rows of the two inputs, with a squared-Euclidean
    ground cost, and returns the square root of the optimal cost.

    Args:
        model_samples_0: reference samples (torch tensor), one sample per row.
        model_samples_T: samples to compare (torch tensor), one sample per
            row; the row count may differ from ``model_samples_0``.

    Returns:
        The W2 distance, i.e. ``ot.emd2(...).sqrt()`` for these inputs.

    Raises:
        ValueError: if either sample set has no rows.
    """
    n_0 = model_samples_0.shape[0]
    n_T = model_samples_T.shape[0]
    if n_0 == 0 or n_T == 0:
        raise ValueError("model_evaluation needs at least one sample in each set")

    # Pairwise squared-Euclidean costs; sqrt of the optimal cost is W2.
    M = ot.dist(model_samples_0, model_samples_T, metric="sqeuclidean")
    # Uniform weights in the samples' own dtype/device rather than a forced
    # float32 on the global device: this avoids mixing precisions with the
    # cost matrix (the file's default dtype is float64) and device mismatches.
    a = torch.full((n_0,), 1.0 / n_0,
                   dtype=model_samples_0.dtype, device=model_samples_0.device)
    b = torch.full((n_T,), 1.0 / n_T,
                   dtype=model_samples_0.dtype, device=model_samples_0.device)
    w2_0_T = ot.emd2(a, b, M, numItermax=1000000).sqrt()

    return w2_0_T

# Run model_evaluation for a range of start times T.

def model_evaluation_over_T(training_samples,
                            sigma,
                            n_times,
                            *,
                            num_repeat=5,
                            n_points=100,
                            ):
    """Measure how SCFDM samples change when advection starts at time T.

    Builds one SCFDM velocity field for ``sigma``. For each T in
    ``linspace(0, 1, n_times)`` with the final point 1.0 dropped, it advects
    ``num_repeat`` batches of 5000 model samples starting at T. Each batch is
    compared, via :func:`model_evaluation`, with a reference batch advected
    from T = 0, and the W2 distances are averaged per T.

    Args:
        training_samples: training data, indexed as ``(n, space_dims)``.
        sigma: smoothing sigma forwarded to :func:`generate_scfdm`.
        n_times: number of grid points on [0, 1]. The last point (1.0) is
            dropped, so at least 2 are required.
        num_repeat: number of sample batches averaged per T (default 5).
        n_points: ``n_points`` forwarded to :func:`generate_scfdm`
            (default 100).

    Returns:
        ``(w2, t_tensor)``. ``w2`` is a float32 tensor stacking the averaged
        W2 distance for each T, and ``t_tensor`` holds the T values
        (``n_times - 1`` entries).

    Raises:
        ValueError: if ``n_times < 2`` or ``num_repeat < 1``.
    """
    if n_times < 2:
        raise ValueError(
            f"n_times must be >= 2 (got {n_times}): the grid point T = 1.0 is dropped"
        )
    if num_repeat < 1:
        raise ValueError(f"num_repeat must be >= 1 (got {num_repeat})")

    # Start times T on [0, 1); the first grid point is exactly 0.0.
    t_tensor = torch.linspace(0.0, 1.0, n_times)[:-1]

    # The velocity field does not depend on T, so build it once. Using one
    # model for every T means the reported W2 reflects the start time only,
    # not also the fresh random fixed_noise drawn on each rebuild.
    vel_field = generate_scfdm(training_samples, sigma, n_points=n_points)
    # NOTE(review): generate_scfdm builds its EmpiricalDist with device=device;
    # this one relies on EmpiricalDist's default device — confirm intended.
    train_sampler = samplers.EmpiricalDist(training_samples)

    w2_0_T_list = []
    for time in tqdm(t_tensor):
        t = float(time)
        w2_0_T_avg = 0.0
        for _ in range(num_repeat):
            model_samples, _true = scfd.advect_samples(train_sampler,
                                                       vel_field,
                                                       space_dims=training_samples.shape[1],
                                                       step_size=1e-2,
                                                       n_model_samples=5000,
                                                       display=False,
                                                       T=t,
                                                       train_samples=training_samples,
                                                       )
            if t == 0:
                # The first grid point sets the reference batch. Each T = 0
                # batch is compared with itself, so that entry is 0 by
                # construction; later T use the last T = 0 batch.
                model_samples_0 = model_samples

            w2_0_T = model_evaluation(model_samples_0, model_samples)
            print("W2 distance between model samples starting at 0 and model samples starting at T =", str(t), ":", str(w2_0_T))
            w2_0_T_avg += w2_0_T
        w2_0_T_avg /= num_repeat
        print("Average W2 distance between model samples starting at 0 and model samples starting at T =", str(t), ":", str(w2_0_T_avg))

        # as_tensor avoids the copy-construct warning torch.tensor(tensor) emits.
        w2_0_T_list.append(torch.as_tensor(w2_0_T_avg).detach().float())

    return torch.stack(w2_0_T_list, dim=0), t_tensor

# Run model_evaluation_over_T for a range of smoothing sigmas.

def model_evaluation_over_sigma(training_samples,
                                n_sigma,
                                n_times
                                ):
    """Sweep :func:`model_evaluation_over_T` over a grid of smoothing sigmas.

    Args:
        training_samples: training data forwarded unchanged to
            :func:`model_evaluation_over_T`.
        n_sigma: number of sigmas in ``linspace(0, 1, n_sigma)``; must be >= 1.
        n_times: number of start-time grid points, forwarded to
            :func:`model_evaluation_over_T`.

    Returns:
        ``(w2, sigma_tensor, t_tensor)``. ``w2`` stacks the per-sigma results
        of :func:`model_evaluation_over_T` (one row per sigma). ``sigma_tensor``
        holds the sigmas, and ``t_tensor`` holds the start times (identical for
        every sigma, since ``n_times`` is fixed).

    Raises:
        ValueError: if ``n_sigma < 1``.
    """
    if n_sigma < 1:
        raise ValueError(f"n_sigma must be >= 1 (got {n_sigma})")

    # Sigma grid on [0, 1], endpoints included.
    sigma_tensor = torch.linspace(0.0, 1.0, n_sigma)

    w2_sigma_list = []
    t_tensor = None
    for sigma in tqdm(sigma_tensor):
        print("sigma: ", str(float(sigma)))
        w2_0_T_tensor, t_tensor = model_evaluation_over_T(training_samples,
                                                          float(sigma),
                                                          n_times
                                                          )
        w2_sigma_list.append(w2_0_T_tensor)

    return torch.stack(w2_sigma_list, dim=0), sigma_tensor, t_tensor

        

def main():
    """Run the start-time/sigma sweep for each dataset and save the results.

    For every dataset name, draws 500 training samples from
    ``samplers.Toy2DDist`` and moves them to ``device``. It then evaluates
    :func:`model_evaluation_over_sigma` and saves the three returned tensors
    as ``.pt`` files under ``results/starting_at_T``.
    """
    # Other datasets previously swept here: 'helix', '8gaussians', 'pinwheel',
    # '2spirals', 'checkerboard', 'rings', 'circles', 'moons'.
    dataset_names = ['checkerboard']
    n_times = 51
    n_sigma = 6
    n_train = 500  # training set size
    out_dir = os.path.join('results', 'starting_at_T')
    os.makedirs(out_dir, exist_ok=True)

    for dataset_name in dataset_names:
        print("Dataset: ", dataset_name)
        toy_sampler = samplers.Toy2DDist(dist_name=dataset_name)
        training_samples = toy_sampler.sample(n_train).to(device)

        w2_sigma_T_tensor, sigma_tensor, t_tensor = model_evaluation_over_sigma(
            training_samples,
            n_sigma,
            n_times,
        )

        # One .pt file per returned tensor, named "<dataset><suffix>".
        outputs = {
            '_w2_sigma_T.pt': w2_sigma_T_tensor,
            '_sigma_tensor.pt': sigma_tensor,
            '_t_tensor.pt': t_tensor,
        }
        for suffix, tensor in outputs.items():
            torch.save(tensor, os.path.join(out_dir, dataset_name + suffix))

if __name__ == "__main__":
    # Run the sweep only when executed as a script, not when imported.
    main()