import torch
from itertools import product
from functools import partial
import matplotlib.pyplot as plt
import numpy as np

from robustopt_torch.costs import eucl_norm_sq, np_eucl_norm_sq, np_eucl_norm_sq_grad
from robustopt_torch.DeconvolutionIFunc import DeconvolutionIFunc
from robustopt_torch.distributions import DiscreteDist, MixtureDist, MultivariateNormalDist
from robustopt_torch.plot_metrics import metricPlotter, particlePlotter, write_to_file, metric_callback
from robustopt_torch.PFDWolfe import PFDWolfe
from robustopt_torch.logging_utils import distribution_center_statistics, gradient_norm_statistics, grad_evals

def deconv_metric_callback(log_file, print_most_recent, met_plotter,
                           *metric_calc_funcs):
    """Assemble the metric callback used by the deconvolution experiments.

    The list of metric calculators handed to ``metric_callback`` consists of
    the caller-supplied ``metric_calc_funcs`` (in the given order), followed by
    ``distribution_center_statistics`` bound to ``stats = "median"`` and a
    recorder for the ``"Lambda"`` metric.

    Args:
        log_file: Bound to ``metric_callback`` as ``log_file``.
        print_most_recent: Bound to ``metric_callback`` as
            ``print_most_recent``.
        met_plotter: Bound to ``metric_callback`` as ``metric_plotter``.
        *metric_calc_funcs: Additional metric calculators, placed ahead of the
            two that are appended here.

    Returns:
        A ``functools.partial`` of ``metric_callback`` with the arguments above
        bound as keywords.
    """
    def record_lambda(iter_vars, metrics):
        # Store the (iteration, lambda) pair under the "Lambda" metric.
        point = (iter_vars["iteration"], iter_vars["lambda"])
        return metrics.append_to_metric("Lambda", point)

    median_center_stats = partial(distribution_center_statistics,
                                  stats = "median")
    calc_funcs = [*metric_calc_funcs, median_center_stats, record_lambda]

    return partial(metric_callback,
                   metric_plotter = met_plotter,
                   metric_calc_funcs = calc_funcs,
                   log_file = log_file,
                   print_most_recent = print_most_recent)

def obj_eval_for_validation(iterate, dconv_ifunc, num_validation = 1000):
    """Evaluate the deconvolution objective on fresh samples of ``iterate``.

    Args:
        iterate: Object exposing ``sample()`` and ``sample(x)``; it is both
            the sampler handed to ``dconv_ifunc.get_ifunc`` and the source of
            the validation points.
        dconv_ifunc: Object exposing ``get_ifunc(sampler)``, which returns the
            objective as a callable.
        num_validation: Argument passed to ``iterate.sample`` when drawing the
            validation points.

    Returns:
        A pair ``(value, objective)``: the mean of the objective over the
        validation points as a Python scalar, and the objective callable.
    """
    def draw(x = None):
        # Forward the argument only when one was supplied.
        if x is None:
            return iterate.sample()
        return iterate.sample(x)

    validation_obj = dconv_ifunc.get_ifunc(draw)
    validation_points = iterate.sample(num_validation)
    mean_value = validation_obj(validation_points).mean().item()
    return mean_value, validation_obj

def log_obj_and_grad_norm(iter_vars, metric_plotter, dconv_ifunc):
    """Record the objective value and gradient-norm statistics for an iterate.

    Args:
        iter_vars: Mapping with at least the keys ``"iterate"`` and
            ``"iteration"``; it is also forwarded unchanged to
            ``gradient_norm_statistics``.
        metric_plotter: Receives the ``"Objective Value"`` entry via
            ``append_to_metric`` and is forwarded to
            ``gradient_norm_statistics``.
        dconv_ifunc: Forwarded to ``obj_eval_for_validation`` to build the
            objective.
    """
    current_iterate = iter_vars["iterate"]
    obj_val, obj = obj_eval_for_validation(current_iterate, dconv_ifunc)

    # The value is logged against ``iteration + 1``.
    logged_iteration = iter_vars["iteration"] + 1
    metric_plotter.append_to_metric("Objective Value",
                                    (logged_iteration, obj_val))

    gradient_norm_statistics(iter_vars, metric_plotter,
                             stats = "mean", objective = obj)

def particle_plotter_wrapper(iter_vars, particle_plotter, particle_filename,
                             metric_callback):
    """Store the current particles, write them out, then run the metric callback.

    Args:
        iter_vars: Mapping with at least the keys ``"iteration"`` and
            ``"iterate"``; the iterate must expose ``vals``.
        particle_plotter: Receives the particles via ``add_particles``; its
            ``particles`` attribute is what gets passed to ``write_to_file``.
        particle_filename: First argument passed to ``write_to_file``.
        metric_callback: Callable invoked last with ``iter_vars``.
    """
    iteration = iter_vars["iteration"]
    iterate = iter_vars["iterate"]

    # The particle set is labelled with ``iteration + 1``.
    label = f"Iteration {iteration + 1}"
    particle_plotter.add_particles(**{label : iterate.vals})

    write_to_file(particle_filename, particle_plotter.particles)
    metric_callback(iter_vars)



# Generating mixtures of Gaussians

# Generate a mixture of Gaussians whose components are centered at randomly
# shifted vertices of the hypercube
def hypercube_gaussian_mix(num_dim, num_means = None, mixture_vars = 0.4,
                           center_jitter_var = 3.0, scaling_factor = 10.0,
                           generator = None):
    """Build a Gaussian mixture centered near vertices of a scaled hypercube.

    Vertices of ``{0, 1}^num_dim`` are mapped to
    ``scaling_factor * vertex - scaling_factor / 2`` and then perturbed with
    Gaussian noise. Each perturbed vertex becomes the mean of one
    ``MultivariateNormalDist`` component.

    Args:
        num_dim: Dimension of the hypercube.
        num_means: Number of distinct vertices to draw at random. If ``None``,
            all ``2 ** num_dim`` vertices are used.
        mixture_vars: Scalar multiplying the identity matrix that is passed as
            the second argument of every ``MultivariateNormalDist`` component
            (presumably its covariance; confirm against that class).
        center_jitter_var: Scale of the jitter added to each center.
            NOTE(review): despite the name, this value is passed as the ``std``
            argument of ``Tensor.normal_``, so it acts as a standard deviation.
        scaling_factor: Edge length of the scaled hypercube.
        generator: Optional ``torch.Generator`` used for every random draw and
            forwarded to the component and mixture distributions.

    Returns:
        A ``MixtureDist`` built from the components.

    Raises:
        ValueError: If ``num_means`` exceeds ``2 ** num_dim``. The hypercube
            has no more distinct vertices than that, so the rejection loop
            below could never finish.
    """
    if num_means is not None:
        num_vertices = 2 ** num_dim
        if num_means > num_vertices:
            raise ValueError(
                f"num_means ({num_means}) cannot exceed the number of "
                f"hypercube vertices ({num_vertices}) for num_dim = {num_dim}")

        # Rejection-sample uniformly random vertices until enough distinct
        # ones have been collected.
        target_means = []
        while len(target_means) < num_means:
            vertex = torch.bernoulli(torch.ones(num_dim) / 2.0, generator = generator)
            if any((vertex == mean).all() for mean in target_means):
                continue
            target_means.append(vertex)
    else:
        # Enumerate every vertex of the hypercube.
        target_means = [torch.as_tensor(b_vec, dtype = torch.float) for b_vec in
                        product(range(2), repeat = num_dim)]

    # Scale the cube and center it on the origin.
    target_means = [scaling_factor * mean - scaling_factor / 2.0 for mean in
                    target_means]
    # Jitter each center; the clone only provides a tensor of the right shape
    # and dtype for ``normal_`` to overwrite.
    target_means = [t_mean + t_mean.detach().clone().normal_(0.0, center_jitter_var,
                                                             generator = generator)
                    for t_mean in target_means]
    target_components = [MultivariateNormalDist(t_mean,
                                                mixture_vars * torch.eye(num_dim),
                                                generator = generator)
                         for t_mean in target_means]
    return MixtureDist(*target_components, generator = generator)


# Sampling from Gaussian kernels
################################

def gauss_kde(vals, generator = None, *, band = 1.0):
    """Perturb ``vals`` with independent Gaussian noise scaled by ``band``.

    Args:
        vals: Tensor of values to perturb.
        generator: Optional ``torch.Generator`` used to draw the noise.
        band: Kernel bandwidth, i.e. the factor multiplying the standard
            normal noise. The default of ``1.0`` matches the previously
            hard-coded value.

    Returns:
        A tensor equal to ``vals + band * noise``, where ``noise`` is standard
        normal with the same shape as ``vals``.
    """
    noise = torch.normal(torch.zeros_like(vals), generator = generator)
    return vals + band * noise

# Final deconvolution examples
def test_dconv_solve_mix_kde():
    """Run the 2-D Gaussian-mixture deconvolution example with ``PFDWolfe``.

    The target is an empirical sample from ``hypercube_gaussian_mix``. The
    initial iterate is an empirical sample from a ``MultivariateNormalDist``
    with mean ``10`` in every coordinate and identity ``iterate_sigma``.
    Target particles, per-iteration particles and metrics are passed to
    ``write_to_file`` / the metric callback under the
    ``pfd_deconvolution_*.pickle`` filenames defined below.

    A single seeded ``torch.Generator`` is passed to every distribution built
    here. The torch default dtype is switched to ``torch.float`` for the run
    and restored afterwards, even if the run raises.
    """
    # The default dtype is process-wide state, so it is restored in
    # ``finally``; otherwise a failure anywhere below would leak
    # ``torch.float`` into whatever runs next.
    curr_dt = torch.get_default_dtype()
    torch.set_default_dtype(torch.float)
    try:
        # Initialize random number generation
        rng = torch.Generator()
        rng.manual_seed(13172549)

        # Problem parameters
        dims = 2
        eps = 0.2
        target_size = 100
        iterate_size = 200
        target_filename = "pfd_deconvolution_target_particles.pickle"
        metric_filename = "pfd_deconvolution_metrics_kde.pickle"
        particle_filename = "pfd_deconvolution_particles.pickle"

        # Initialize the target mixture of gaussians
        target_dist = hypercube_gaussian_mix(dims, generator = rng)
        empirical_target_dist = DiscreteDist(target_dist.sample((target_size, )),
                                             generator = rng)

        # Write out the target particles distribution
        target_particles = particlePlotter(Reference = (empirical_target_dist.vals,
                                                        empirical_target_dist.weights))
        write_to_file(target_filename, target_particles.particles)

        # Create the iterate distribution. The mean is tied to ``dims`` rather
        # than a hard-coded two-element list, and the seeded generator is
        # passed so the initial particles are reproducible like every other
        # draw in this function.
        iterate_mean = torch.full((dims,), 10.0)
        iterate_sigma = torch.eye(dims)
        iterate_dist = MultivariateNormalDist(iterate_mean, iterate_sigma,
                                              generator = rng)

        empirical_iterate_dist = DiscreteDist(iterate_dist.sample((iterate_size,)),
                                              generator = rng,
                                              sampling_kernel = gauss_kde)

        # Initialize the deconvolution ifunc
        dconv_ifunc = DeconvolutionIFunc(eucl_norm_sq, eps,
                                         empirical_target_dist.vals,
                                         empirical_target_dist.weights)

        # Common solver parameters
        num_iter = 3000
        step_size = 40.0
        g_samp = 1000
        p_checks = 20
        dconv_ifunc.set_solve_params(num_iter = num_iter, grad_samp = g_samp)
        dconv_ifunc.set_solver_config(optimizer_params = {"lr" : step_size},
                                      progress_checks = p_checks,
                                      lr_schedule_mode = None)

        # Metric plotters
        log_obj = partial(log_obj_and_grad_norm, dconv_ifunc = dconv_ifunc)
        met_plot = metricPlotter()
        metric_cb = deconv_metric_callback(metric_filename, True, met_plot,
                                           log_obj, grad_evals)

        # Particle plotters
        particle_plot = particlePlotter()
        plotters_callback = partial(particle_plotter_wrapper,
                                    particle_plotter = particle_plot,
                                    particle_filename = particle_filename,
                                    metric_callback = metric_cb)

        # Add initial particle plot and objective metric
        particle_plot.add_particles(**{"Iteration 0" : empirical_iterate_dist.vals})
        write_to_file(particle_filename, particle_plot.particles)
        obj_val_iteration0, _ = obj_eval_for_validation(empirical_iterate_dist,
                                                        dconv_ifunc)
        met_plot.append_to_metric("Objective Value", (0, obj_val_iteration0))

        # Config and run solver
        solver_config = {"num_iter" : 30,
                         "outer_solve_params" : {"bisection_tol" : 1e-3,
                                                 "subgrad_tol" : 1e-3,
                                                 "num_samp" : 30},
                         "inner_solve_params" : {"num_iter" : 5000,
                                                 "lr" : 1.0,
                                                 "min_lr" : 1e-6,
                                                 "verbose" : True,
                                                 "stopping_threshold" : 1e-6},
                         "plot_callback_ifunc" : plotters_callback,
                         "print_n_iters" : 1,
                         "verbose" : True}
        pfd_solver = PFDWolfe(**solver_config)

        problem_config = {"ifunc_factory" : dconv_ifunc.get_ifunc,
                          "mu" : empirical_iterate_dist,
                          "delta" : 0.5, # stepsize for frank-wolfe
                          "lam_bounds" : (1e-6, 60.0),
                          "alpha": 1, # scaling factor for the uniform strategy
                          "uniform_strategy": False}
        pfd_solver.run(**problem_config)
    finally:
        # Reset the torch default datatype so other tests keep working
        torch.set_default_dtype(curr_dt)


#test_dconv_solve_mix_kde()
def test_dconv_solve_mix_kde_high_dim(random_seed,metric_filename,delta,alpha):
    """Run the 64-dimensional Gaussian-mixture deconvolution with ``PFDWolfe``.

    The target is an empirical sample from a 7-component
    ``hypercube_gaussian_mix``. The initial iterate is an empirical sample
    from a ``MultivariateNormalDist`` with zero mean and identity
    ``iterate_sigma``. A single seeded ``torch.Generator`` is passed to every
    distribution built here.

    The torch default dtype is switched to ``torch.float`` for the run and
    restored afterwards, even if the run raises.

    Args:
        random_seed: Seed for the ``torch.Generator``. ``None`` selects the
            built-in default seed ``220294320``; any other value, including
            ``0``, is used as given.
        metric_filename: Passed to ``deconv_metric_callback`` as its log file.
        delta: Forwarded to ``PFDWolfe.run`` as ``delta`` (the 2-D example in
            this file labels it the Frank-Wolfe step size).
        alpha: Forwarded to ``PFDWolfe.run`` as ``alpha`` (the 2-D example in
            this file labels it the scaling factor for the uniform strategy).
    """
    # The default dtype is process-wide state, so it is restored in
    # ``finally``; otherwise a failure anywhere below would leak
    # ``torch.float`` into whatever runs next.
    curr_dt = torch.get_default_dtype()
    torch.set_default_dtype(torch.float)
    try:
        # Initialize random number generation. Test against ``None`` rather
        # than truthiness so that a seed of 0 is honored instead of being
        # silently replaced by the default seed.
        rng = torch.Generator()
        rng.manual_seed(random_seed if random_seed is not None else 220294320)

        # Problem parameters
        dims = 64
        num_means = 7
        eps = 0.2
        target_size = 50
        iterate_size = 200

        # Initialize the target mixture of gaussians
        target_dist = hypercube_gaussian_mix(dims, num_means = num_means,
                                             generator = rng)
        empirical_target_dist = DiscreteDist(target_dist.sample((target_size, )),
                                             generator = rng)

        # Create the iterate distribution. The seeded generator is passed so
        # the initial particles are reproducible for a given ``random_seed``.
        iterate_mean = torch.zeros(dims)
        iterate_sigma = torch.eye(dims)
        iterate_dist = MultivariateNormalDist(iterate_mean, iterate_sigma,
                                              generator = rng)

        empirical_iterate_dist = DiscreteDist(iterate_dist.sample((iterate_size,)),
                                              generator = rng,
                                              sampling_kernel = gauss_kde)

        # Initialize the deconvolution ifunc
        dconv_ifunc = DeconvolutionIFunc(eucl_norm_sq, eps,
                                         empirical_target_dist.vals,
                                         empirical_target_dist.weights)

        # Common solver parameters
        num_iter = 500
        step_size = 40.0
        g_samp = 1000
        p_checks = 20
        dconv_ifunc.set_solve_params(num_iter = num_iter, grad_samp = g_samp)
        dconv_ifunc.set_solver_config(optimizer_params = {"lr" : step_size},
                                      progress_checks = p_checks,
                                      lr_schedule_mode = None)

        # Metric plotters
        log_obj = partial(log_obj_and_grad_norm, dconv_ifunc = dconv_ifunc)
        met_plot = metricPlotter()
        metric_cb = deconv_metric_callback(metric_filename, True, met_plot,
                                           log_obj, grad_evals)

        # Add initial objective metric
        obj_val_iteration0, _ = obj_eval_for_validation(empirical_iterate_dist,
                                                        dconv_ifunc)
        met_plot.append_to_metric("Objective Value", (0, obj_val_iteration0))

        # Config and run solver
        solver_config = {"num_iter" : 70,
                         "outer_solve_params" : {"bisection_tol" : 1e-3,
                                                 "subgrad_tol" : 1e-3,
                                                 "num_samp" : 30},
                         "inner_solve_params" : {"num_iter" : 5000,
                                                 "lr" : 1.0,
                                                 "min_lr" : 1e-8,
                                                 "verbose" : True,
                                                 "stopping_threshold" : 1e-6},
                         "plot_callback_ifunc" : metric_cb,
                         "print_n_iters" : 1,
                         "verbose" : True}
        pfd_solver = PFDWolfe(**solver_config)

        problem_config = {"ifunc_factory" : dconv_ifunc.get_ifunc,
                          "mu" : empirical_iterate_dist,
                          "delta" : delta,
                          "lam_bounds" : (1e-6, 60.0),
                          "alpha":alpha,
                          "uniform_strategy": True}
        pfd_solver.run(**problem_config)
    finally:
        # Reset the torch default datatype so other tests keep working
        torch.set_default_dtype(curr_dt)







# Task 1: 2D Gaussian Deconvolution
# NOTE: this call sits at module top level with no ``__main__`` guard, so the
# full 2-D experiment runs whenever this file is executed or imported.
test_dconv_solve_mix_kde()


## Task 2: High-dimensional Gaussian Deconvolution - 10 runs with random initialization
# curr_dt = torch.get_default_dtype()
# torch.set_default_dtype(torch.float)
# rng = torch.Generator()
# rng.manual_seed(220294320)
# num_replications = 10
# alpha = 1
# delta = 0.5
# seeds = torch.randint(high = 100000000, size = (num_replications, ), generator = rng)
# m_filenames = [f"deconv_high_dim_{i+1}.pickle" for i in range(num_replications)]
#
# for repl, params in enumerate(zip(seeds, m_filenames)):
#     print(f"Executing replication {repl}")
#     seed, fname = params
#     test_dconv_solve_mix_kde_high_dim(seed.item(),fname,delta,alpha)
#     # Boilerplate to reset the torch default datatype and make other tests work
#     torch.set_default_dtype(curr_dt)


## Task 3: hyperparameter sensitivity
# curr_dt = torch.get_default_dtype()
# torch.set_default_dtype(torch.float)
# # Initialize random number generation
# rng = torch.Generator()
# rng.manual_seed(220294320)
# alpha = 1
# delta_param = [15,20,25,30,35,40]
# delta = 0.5
# alpha_param = [0.2,0.5,1,2,5,10]
# for item in alpha_param:
#     m_filename = 'deconv_high_dim_constant_alpha_'+str(item)+'.pickle'
#     random_seed = None
#     test_dconv_solve_mix_kde_high_dim(random_seed, m_filename, delta, item)
