import torch
import pytest
from itertools import product
from functools import partial

import numpy as np

# from robustopt_torch.costs import eucl_norm_sq, np_eucl_norm_sq, np_eucl_norm_sq_grad
from robustopt_torch.MMDIFunc import MMDIFunc
from robustopt_torch.student_teacher import StudentTeacher, total_num_param, one_hidden_layer_network
from robustopt_torch.kernels import gaussian_kern
from robustopt_torch.distributions import DiscreteDist, MixtureDist
from robustopt_torch.plot_metrics import metricPlotter, particlePlotter, write_to_file, metric_callback
from robustopt_torch.PFDWolfe import PFDWolfe
from robustopt_torch.funcutils import flatteniter, pack_tensors

from  robustopt_torch.logging_utils import distribution_center_statistics, gradient_norm_statistics, grad_evals

from mmd_flow_ref.student_teacher_ref import quadexp
from mmd_flow_ref.trainer import Trainer
from mmd_flow_ref.networks import quadexp
import argparse
import os

# Utility functions for plotting
################################

def mmd_metric_callback(log_file, print_most_recent, met_plotter,
                        *metric_calc_funcs):
    """Build a metric callback that logs the given metrics plus MMD extras.

    The returned callable is ``metric_callback`` with its plotter, metric
    functions, log file and printing flag bound. After the metric functions
    passed in, three more are always appended (in this order): the median
    distribution-center statistic, the mean gradient-norm statistic, and a
    recorder for the solver's ``"lambda"`` value.

    Args:
        log_file: forwarded to ``metric_callback`` as ``log_file``.
        print_most_recent: forwarded to ``metric_callback``.
        met_plotter: plotter forwarded as ``metric_plotter``.
        *metric_calc_funcs: extra metric functions, run before the built-in ones.

    Returns:
        A ``functools.partial`` wrapping ``metric_callback``.
    """
    def record_lambda(iter_vars, metrics):
        # Store (iteration, lambda) pairs under the "Lambda" metric.
        return metrics.append_to_metric("Lambda", (iter_vars["iteration"],
                                                   iter_vars["lambda"]))

    calc_funcs = [*metric_calc_funcs,
                  partial(distribution_center_statistics, stats = "median"),
                  partial(gradient_norm_statistics, stats = "mean"),
                  record_lambda]

    return partial(metric_callback,
                   metric_plotter = met_plotter,
                   metric_calc_funcs = calc_funcs,
                   log_file = log_file,
                   print_most_recent = print_most_recent)

def obj_eval_for_validation(iterate, validation_ifunc):
    """Return ``validation_ifunc.get_mmd`` evaluated on the current iterate.

    ``get_mmd`` receives a sampler that accepts an optional argument,
    ignores it, and always returns ``iterate.vals`` (read at call time).
    """
    def fixed_sampler(x = None):
        # The argument is ignored: the full set of iterate values is returned.
        return iterate.vals

    return validation_ifunc.get_mmd(fixed_sampler)

def log_obj_metric(iter_vars, metric_plotter, validation_ifunc):
    """Append this iteration's validation objective to ``metric_plotter``.

    The value is stored under "Objective Value" at x-position
    ``iter_vars["iteration"] + 1``; position 0 is used by the driver for the
    starting iterate.
    """
    current_obj = obj_eval_for_validation(iter_vars["iterate"],
                                          validation_ifunc)
    step = iter_vars["iteration"] + 1
    metric_plotter.append_to_metric("Objective Value", (step, current_obj))

# Helper function to deal with getting params from a network
#############################################
def get_linear_params(linear_layer):
    """Return the parameters of a linear layer as a tuple.

    The result is ``(weight,)`` when ``linear_layer.bias`` is None and
    ``(weight, bias)`` otherwise.
    """
    if linear_layer.bias is None:
        return (linear_layer.weight,)
    return (linear_layer.weight, linear_layer.bias)

# Factory for the argparse parser that the reference implementation uses to
# configure its problem
def make_parser():
    """Build the argument parser used to configure the reference MMD-flow trainer.

    Callers can obtain the defaults with ``make_parser().parse_args([])`` and
    then overwrite individual fields.

    Returns:
        argparse.ArgumentParser: parser with all reference options registered.
    """

    def _str_to_bool(value):
        # ``type=bool`` is an argparse trap: bool("False") is True, so any
        # non-empty string would enable the option. Parse the text instead.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(
            f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description='MMD flow student-teacher reference training')

    # Optimizer parameters
    parser.add_argument('--lr', default=.1, type=float, help='learning rate')
    parser.add_argument('--batch_size', default = 100 ,type= int,  help='batch size')
    parser.add_argument('--total_epochs', default=10000, type=int, help='total number of epochs')
    parser.add_argument('--optimizer',  default = 'SGD' ,type= str,   help='Optimizer')
    parser.add_argument('--use_scheduler',   action='store_true',  help=' By default uses the ReduceLROnPlateau scheduler ')

    # Loss parameters
    parser.add_argument('--loss', default = 'mmd_noise_injection',type= str,  help='loss to optimize: mmd_noise_injection, mmd_diffusion, sobolev')
    parser.add_argument('--with_noise',   default = True ,type= _str_to_bool, help='to use noise injection set to true')
    parser.add_argument('--noise_level', default = 1. ,type= float,  help=' variance of the injected noise ')
    parser.add_argument('--noise_decay_freq', default = 1000 ,type= int,  help='decays the variance of the injected noise every "noise_decay_freq" epochs by a factor "noise_decay"')
    parser.add_argument('--noise_decay', default = 0.5 ,type= float,  help='factor for decreasing the variance of the injected noise')

    # Hardware parameters
    parser.add_argument('--device', default = 0 ,type= int,  help='gpu device, set -1 for cpu')
    parser.add_argument('--dtype', default = 'float32' ,type= str,  help='precision: single: float32 or double: float64')

    # Reproducibility parameters
    parser.add_argument('--seed', default = 1 ,type= int,  help='seed for the random number generator on pytorch')
    parser.add_argument('--log_dir', default = '',type= str,  help='log directory ')
    parser.add_argument('--log_name', default = 'mmd',type= str,  help='log name')
    parser.add_argument('--log_in_file', action='store_true',  help='to log output on a file')
    parser.add_argument('--metrics_log_file', default = 'reference_mmd_metrics.pickle',
                        type= str, help='file for logging metrics')

    # Network parameters
    parser.add_argument('--bias',   action='store_true',  help='set to include bias in the network parameters')
    parser.add_argument('--teacher_net', default = 'OneHidden' ,type= str,   help='teacher network')
    parser.add_argument('--student_net', default = 'NoisyOneHidden' ,type= str,   help='student network')
    parser.add_argument('--d_int', default = 50 ,type= int,  help='dim input data')
    parser.add_argument('--d_out', default = 1 ,type= int,  help='dim out feature')
    parser.add_argument('--H', default = 3  ,type= int,  help='num of hidden layers in the teacher network')
    parser.add_argument('--num_particles', default = 1000 ,type= int,  help='num_particles*H = number of hidden units in the student network ')

    # Initialization parameters
    parser.add_argument('--mean_student', default = 0.001  ,type= float,  help='mean initial value for the student weights')
    parser.add_argument('--std_student', default = 1.  ,type= float,  help='std initial value for the student weights')
    parser.add_argument('--mean_teacher', default = 0.  ,type= float,  help='mean initial value for the teacher weights')
    parser.add_argument('--std_teacher', default = 1.  ,type= float,  help='std initial value for the teacher weights')

    # Data parameters
    parser.add_argument('--input_data', default = 'Spherical' ,type= str,   help='input data distribution')
    parser.add_argument('--N_train', default = 1000 ,type= int,  help='num samples for training')
    parser.add_argument('--N_valid', default = 1000 ,type= int,  help='num samples for validation')

    parser.add_argument('--config',  default = '' ,type= str,     help='config file for non default parameters')

    return parser


def test_mmd_flow_final(random_seed = None,
                        reference_metric_file = "reference_mmd_metrics.pickle",
                        pfd_metric_file = "pfd_mmd_metrics.pickle"):
    """Run PFD-Wolfe on the MMD-flow student/teacher problem of the reference code.

    The reference ``Trainer`` is constructed first so that its teacher
    parameters and its training/validation inputs can be copied into the PFD
    formulation. Two assertions check that the copied teacher reproduces the
    reference outputs before the PFD solver is run.

    Args:
        random_seed: seed used for both the reference trainer (``args.seed``)
            and the PFD generator. ``None`` falls back to a fixed seed.
        reference_metric_file: base path for the reference metrics file;
            ``"_noise_injection"`` is inserted before the extension.
        pfd_metric_file: log file passed to the PFD metric callback.
    """
    # One effective seed for both implementations. Use ``is None`` rather
    # than truthiness so that 0 remains a valid seed.
    seed = 220294320 if random_seed is None else random_seed
    rng = torch.Generator()
    rng.manual_seed(seed)

    # Problem parameters
    in_sz = 50
    hidden_sz = 3
    out_sz = 1
    nonlin = quadexp()
    bias = False
    num_data_points = 100
    num_particles = 20

    # Problem parameters for reference solution
    args = make_parser().parse_args([])
    args.device = -1
    args.seed = seed
    args.bias = bias

    args.d_int = in_sz
    args.H = hidden_sz
    args.d_out = out_sz
    args.num_particles = num_particles

    args.batch_size = args.N_train = args.N_valid = num_data_points
    args.dtype = "float32"
    args.log_in_file = False

    # First, setup MMD with noise injection
    args.loss = "mmd_noise_injection"
    args.total_epochs = 20000
    log_root, log_ext = os.path.splitext(reference_metric_file)
    args.metrics_log_file = f"{log_root}_noise_injection{log_ext}"

    # Create reference network; its teacher and data are read below.
    noise_injec_network = Trainer(args)

    # Copy teacher parameters in reference implementation to pfd implementation
    # as a single flat tensor.
    teacher = noise_injec_network.teacherNet
    teacher_layers = (teacher.linear1, teacher.linear2)
    teacher_params = list(flatteniter((get_linear_params(layer)
                                       for layer in teacher_layers),
                                      keep_tensors = True))
    flattened_teacher_params = pack_tensors(param.data for param in
                                            teacher_params).detach().clone()

    # Training and validation inputs of the reference problem
    data = noise_injec_network.data_train.dataset.X.detach().clone()
    valid_data = noise_injec_network.data_valid.dataset.X.detach().clone()

    # Check that the teacher network has been properly initialized for pfd iterations
    test_output = one_hidden_layer_network(data, flattened_teacher_params,
                                           (in_sz, hidden_sz, bias),
                                           (hidden_sz, out_sz, bias),
                                           nonlin)
    teacher_ref_output = teacher(noise_injec_network.data_train.dataset.X).detach().clone()
    assert torch.allclose(teacher_ref_output, test_output)

    # Set up the iterate: num_particles rows, one flattened parameter vector each
    empirical_iterate_params = torch.normal(args.mean_student *
                                            torch.ones(num_particles,
                                                       flattened_teacher_params.numel()),
                                            std = args.std_student, generator =
                                            rng)
    empirical_iterate_dist = DiscreteDist(empirical_iterate_params, generator =
                                          rng)

    # Set up training and validation samplers and influence functions.
    # The samplers ignore their argument and always return the full tensors.
    teacher_sampler = lambda x = None : flattened_teacher_params
    data_sampler = lambda x = None : data
    valid_data_sampler = lambda x = None : valid_data
    s_and_t_ifunc = StudentTeacher(data_sampler, teacher_sampler, in_sz,
                                   hidden_sz, out_sz, nonlin, bias = bias)
    s_and_t_valid = StudentTeacher(valid_data_sampler, teacher_sampler, in_sz,
                                   hidden_sz, out_sz, nonlin, bias = bias)

    # Check that teacher output on the validation set is close
    test_valid_output = s_and_t_valid.network(valid_data_sampler(),
                                              flattened_teacher_params)
    assert torch.allclose(test_valid_output, noise_injec_network.data_valid.dataset.Y)

    # Set up the metric plotters
    log_obj = partial(log_obj_metric, validation_ifunc = s_and_t_valid)
    met_plot = metricPlotter()
    metric_cb = mmd_metric_callback(pfd_metric_file, True, met_plot, log_obj,
                                    grad_evals)

    # Record the objective of the starting iterate at step 0
    met_plot.append_to_metric("Objective Value", (0,
                                                  obj_eval_for_validation(empirical_iterate_dist,
                                                                          s_and_t_valid)))
    solver_config = {"num_iter" : 70,
                     "outer_solve_params" : {"bisection_tol" : 1e-1,
                                             "subgrad_tol" : 1e-1,
                                             "num_samp" : 30},
                     "inner_solve_params" : {"num_iter" : 5000,
                                             "lr" : 1.0,
                                             "min_lr" : 1e-8,
                                             "verbose" : True,
                                             "stopping_threshold" : 1e-6},
                     "plot_callback_ifunc" : metric_cb,
                     "print_n_iters" : 1,
                     "verbose" : True}
    pfd_solver = PFDWolfe(**solver_config)

    problem_config = {"ifunc_factory" : s_and_t_ifunc.get_ifunc,
                      "mu" : empirical_iterate_dist,
                      "delta" : 0.5,
                      "lam_bounds" : (1e-6, 15.0),
                      "alpha":0.05,
                      "uniform_strategy": True}
    pfd_solver.run(**problem_config)

    # NOTE: disabled follow-up — train the reference noise-injection network
    # and a Sobolev variant for comparison.
    # # Train noise injection
    # noise_injec_network.train()
    #
    # # Setup the Sobolev network
    # args.loss = "sobolev"
    # args.total_epochs = 2300
    # log_file, log_ext = os.path.splitext(reference_metric_file)
    # args.metrics_log_file = "".join([log_file, "_sobolev", log_ext])
    #
    # # Train the sobolev network
    # sobolev_network = Trainer(args)
    # sobolev_network.train()

def test_mmd_flow_final_multiple():
    """Run ``test_mmd_flow_final`` for several seeded replications.

    Seeds are drawn from a fixed generator, so the set of replications is
    reproducible. Each replication writes its own numbered metric files. The
    default torch dtype is forced to float32 for the run and restored
    afterwards, even if a replication raises.
    """
    prev_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float)
    try:
        # Initialize random number generation
        rng = torch.Generator()
        rng.manual_seed(130595092)

        num_replications = 10
        seeds = torch.randint(high = 100000000, size = (num_replications, ),
                              generator = rng)

        # Number replications from 1 so log messages match the file names.
        for repl, seed in enumerate(seeds, start = 1):
            ref_fname = f"reference_mmd_flow_metrics_cons_replication_{repl}.pickle"
            pfd_fname = f"pfd_mmd_flow_metrics_cons5_replication_{repl}.pickle"
            print(f"Executing replication {repl}")
            test_mmd_flow_final(seed.item(), ref_fname, pfd_fname)
    finally:
        # Reset the torch default datatype so other tests are unaffected.
        torch.set_default_dtype(prev_dtype)


if __name__ == "__main__":
    # Run the replications only when executed as a script. pytest collects
    # test_mmd_flow_final_multiple by itself; an unguarded call would run
    # every replication at import (collection) time as well.
    test_mmd_flow_final_multiple()