import numpy as np
import math
import torch
import copy
import torch.multiprocessing as mp
from timeit import default_timer as timer

from robustopt_torch.costs import eucl_norm_sq_nb
from robustopt_torch.GenericSolver import GenericSolver
from robustopt_torch.mp_utils import *

def compute_grad_est(samp_num, lam, mu, optimizer, optimizer_args):
    """Monte Carlo estimate of the supergradient term used to choose ``lam``.

    Runs ``samp_num`` proximal solves. Each solve is started at, and
    regularized towards, a fresh draw from ``mu``. Returns the running mean of
    ``eucl_norm_sq_nb(solution, draw)`` over those solves.

    Args:
        samp_num: number of proximal solves to average over.
        lam: weight of the proximal regularizer.
        mu: object exposing ``sample()``.
        optimizer: object whose ``solve(**params)`` returns a solution.
        optimizer_args: base keyword arguments for ``solve``. The dict is
            shallow-copied, so the caller's dict is not modified.

    Returns:
        Tuple ``(estimate, solver_stats)``. ``estimate`` is a float, and is
        0.0 when ``samp_num`` is 0. ``solver_stats`` is the counter dict passed
        to every ``solve`` call.
    """
    params = optimizer_args.copy()

    # The regularizer closures read ``anchor`` at call time, so rebinding it
    # inside the loop re-centres them on each new draw. This first draw is
    # never used by a solve; it only gives the name a value up front.
    anchor = mu.sample()

    def _prox(x):
        return lam * eucl_norm_sq_nb(x, anchor)

    def _prox_grad(x):
        return lam * (x - anchor)

    params["regularizer"] = _prox
    params["regularizer_grad"] = _prox_grad
    params["solver_stats"] = dict(grad_evals=0, early_stops=0, ls_fail=0)

    running_mean = 0.0
    for count in range(1, samp_num + 1):
        anchor = mu.sample().squeeze()
        params["init_val"] = anchor
        solution = optimizer.solve(**params)
        # Incremental mean update: avoids keeping every sample around.
        running_mean += (eucl_norm_sq_nb(solution, anchor).item() - running_mean) / float(count)
    return running_mean, params["solver_stats"]

def mp_sampling(rank, results_queue, exit_queue, states, samp_nums, lam,
                serialized_mu, optimizer, optimizer_args):
    """Worker entry point: compute one process's share of the lambda estimate.

    Args:
        rank: worker index (first argument supplied by
            ``torch.multiprocessing.start_processes``). It selects this
            worker's RNG state from ``states`` and its sample count from
            ``samp_nums``.
        results_queue: queue receiving
            ``(rank, (estimate, generator_state, solver_stats))``.
        exit_queue: queue this worker blocks on after sending its result.
        lam: proximal weight forwarded to ``compute_grad_est``.
        serialized_mu: serialized measure, rebuilt with ``deserialize_mu``.
        optimizer: inner optimizer forwarded to ``compute_grad_est``.
        optimizer_args: solve kwargs. Its ``"objective"`` entry arrives
            serialized and is replaced by its deserialized form.
    """
    local_mu = deserialize_mu(serialized_mu)
    local_mu.generator = init_generator_for_subproc(states[rank])

    n_local = samp_nums[rank]

    optimizer_args["objective"] = deserialize_func(optimizer_args["objective"])

    grad_estimate, stats = compute_grad_est(n_local, lam, local_mu, optimizer,
                                            optimizer_args)
    payload = (grad_estimate, local_mu.generator.get_state(), stats)
    results_queue.put((rank, payload))
    # Poison-pill handshake with the parent. Presumably this keeps the process
    # (and any shared tensor storage it sent) alive until the parent has read
    # the result. NOTE(review): confirm the intent.
    exit_queue.get()


def mp_mu_update(rank, results_queue, exit_queue, states, batches, lam,
                 serialized_mu, optimizer, optimizer_args):
    """Worker entry point: update one contiguous slice of ``mu``'s points.

    Args:
        rank: worker index (first argument supplied by
            ``torch.multiprocessing.start_processes``). It selects this
            worker's RNG state from ``states`` and its ``(begin, end)`` slice
            from ``batches``.
        results_queue: queue receiving
            ``(rank, (updated_slice, generator_state, solver_stats))``.
        exit_queue: queue this worker blocks on after sending its result.
        lam: proximal weight forwarded to ``update_mu``.
        serialized_mu: serialized measure, rebuilt with ``deserialize_mu``.
        optimizer: inner optimizer forwarded to ``update_mu``.
        optimizer_args: solve kwargs. Its ``"objective"`` entry arrives
            serialized and is replaced by its deserialized form.
    """
    local_mu = deserialize_mu(serialized_mu)
    local_mu.generator = init_generator_for_subproc(states[rank])

    lo, hi = batches[rank]

    optimizer_args["objective"] = deserialize_func(optimizer_args["objective"])

    # Work on a detached copy of this worker's slice; update_mu overwrites
    # its entries in place and the copy is sent back to the parent.
    chunk = local_mu[lo:hi].detach().clone()
    stats = update_mu(lam, chunk, optimizer, optimizer_args)
    results_queue.put((rank, (chunk, local_mu.generator.get_state(), stats)))
    # Poison-pill handshake: wait until the parent signals it has collected
    # every worker's result before exiting.
    exit_queue.get()

def update_mu(lam, mu_vals, optimizer, optimizer_args):
    """Replace every entry of ``mu_vals`` in place with its proximal update.

    For each index ``k``, runs ``optimizer.solve`` starting at ``mu_vals[k]``.
    The regularizer ``lam * eucl_norm_sq_nb(x, mu_vals[k])`` anchors the solve
    at that starting point. The result is written back to ``mu_vals[k]``.

    Args:
        lam: weight of the proximal regularizer.
        mu_vals: indexable, item-assignable sequence of points. It is
            modified in place.
        optimizer: object whose ``solve(**params)`` returns the new point.
        optimizer_args: base keyword arguments for ``solve``. The dict is
            shallow-copied, so the caller's dict is not modified.

    Returns:
        The solver-stats counter dict passed to every ``solve`` call.

    Raises:
        ValueError: if ``mu_vals`` is empty.
    """
    params = optimizer_args.copy()
    if len(mu_vals) < 1:
        raise ValueError("Error, mu must contain at least one sample")

    # Closures read ``center`` at call time, so rebinding it per point
    # re-anchors the regularizer.
    center = mu_vals[0]

    def _prox(x):
        return lam * eucl_norm_sq_nb(x, center)

    def _prox_grad(x):
        return lam * (x - center)

    params["regularizer"] = _prox
    params["regularizer_grad"] = _prox_grad
    params["solver_stats"] = dict(grad_evals=0, early_stops=0, ls_fail=0)

    n_points = len(mu_vals)
    for k in range(n_points):
        # NOTE(review): for tensors, ``mu_vals[k]`` is a view. If the solver
        # mutated ``init_val`` in place, the anchor would move with it.
        # Confirm that GenericSolver does not do this.
        center = mu_vals[k]
        params["init_val"] = center
        mu_vals[k] = optimizer.solve(**params)

    return params["solver_stats"]

class PFDWolfe:
    """Frank-Wolfe outer loop that moves the points of a measure ``mu``.

    Configuration is passed as keyword arguments and stored in ``self.params``:
        num_iter: number of outer iterations (default 100).
        outer_solve_params: dict of outer-loop options read by ``run``:
            ``bisection_tol``, ``subgrad_tol``, ``num_samp``, ``processes``,
            ``adapt_delta``, ``adapt_delta_params``, ``adapt_delta_bisec_fail``.
        inner_solve_params: keyword arguments forwarded to the inner
            optimizer's ``solve``. ``lr`` defaults to 1.0; ``min_lr`` is
            optional.
        inner_optimizer_params: keyword arguments for ``GenericSolver``.
            ``optimizer`` defaults to ``"gd"``.
        print_n_iters: reporting / timing window in iterations (default 10).
        verbose: print progress and warnings (default True).
        plot_callback, plot_callback_ifunc: optional callbacks invoked every
            ``print_n_iters`` iterations.
    """

    def __init__(self, **params):
        """Store solver parameters, filling in defaults for missing keys."""
        self.params = {"num_iter" : 100,
                       "outer_solve_params" : {},
                       "inner_solve_params" : {},
                       "inner_optimizer_params" : {},
                       "print_n_iters" : 10,
                       "verbose" : True}
        self.params.update(params)

    def set_params(self, **params):
        """Convenience method to set parameters."""
        self.params.update(params)

    def run(self, ifunc_factory, mu, delta, lam_bounds, alpha, uniform_strategy):
        """Run the outer loop, updating ``mu.vals`` in place.

        Each iteration does three things:
        (1) it builds the objective ``ifunc`` from ``ifunc_factory``;
        (2) it chooses ``lam``, either by bisection on
            ``compute_grad_est(...) - delta`` or, when ``uniform_strategy`` is
            True, by ``lam = alpha / sqrt(delta)``;
        (3) it moves every point of ``mu`` with a ``lam``-weighted proximal
            solve.

        Args:
            ifunc_factory: callable that receives a sampler
                ``f(x=None) -> mu.sample(x)`` and returns the objective used
                by the inner solves.
            mu: measure object exposing ``sample``, ``vals`` and ``generator``
                (``generator`` is used only when ``processes > 1``).
            delta: budget that the sampled estimate is compared against. It
                may be lowered when delta adaptation is enabled.
            lam_bounds: ``(low, high)`` bracket for ``lam``.
            alpha: constant of the uniform rule ``lam = alpha / sqrt(delta)``.
            uniform_strategy: when ``True``, skip bisection and do a single
                estimate per iteration with the uniform rule.

        Raises:
            ValueError: if ``lam_bounds[0] > lam_bounds[1]``.
        """
        if lam_bounds[0] > lam_bounds[1]:
            raise ValueError("Interval for Frank-Wolfe lambda is invalid.")

        # Shallow copy of inner solver parameters with a default learning rate
        # if none provided.
        inner_solve_params = {"lr" : 1.0}
        inner_solve_params.update(self.params["inner_solve_params"])

        # Rolling buffers of per-iteration timings, averaged in the report.
        print_num = self.params["print_n_iters"]
        ifunc_times = torch.zeros(print_num)
        ma_times = torch.zeros(print_num)
        mu_update_times = torch.zeros(print_num)

        inner_optimizer = self._create_inner_optimizer()

        # Empty dict => delta is never adapted.
        adapt_delta_params = self._configure_adapt_delta_params()

        bisection_tol = self.params["outer_solve_params"].get("bisection_tol", 1e-3)
        subgrad_tol = self.params["outer_solve_params"].get("subgrad_tol", 1e-4)
        num_samp = self.params["outer_solve_params"].get("num_samp", 10)
        processes = self.params["outer_solve_params"].get("processes", 1)
        if processes > 1:
            smp = mp.get_context("spawn")
            results_queue1, results_queue2, vals_queue, weights_queue, args_queue, \
                keywords_queue, exit_queue1, exit_queue2 = [smp.SimpleQueue()
                                                            for _ in range(8)]
            proc_samp_size = split_samples(num_samp, processes)
            proc_mu_batches = batch_indexes(len(mu.vals), processes)
            proc_states = generate_states(processes, mu.generator)

        aggregate_solver_stats = {"grad_evals" : 0, "early_stops" : 0, "ls_fail" : 0}
        lam = (lam_bounds[0] + lam_bounds[1]) / 2.0

        # Outer loop of the Frank-Wolfe framework.
        for i in range(self.params["num_iter"]):
            # Build the objective (influence / potential function).
            ifunc_stime = timer()
            ifunc = ifunc_factory(lambda x=None : mu.sample(x))
            ifunc_etime = timer()
            inner_solve_params["objective"] = ifunc
            ifunc_times[i % print_num] = ifunc_etime - ifunc_stime

            # Adapt delta to the local gradient if requested.
            if adapt_delta_params:
                delta = self._adapt_delta(adapt_delta_params, ifunc, mu, delta)
            lam_l, lam_r = lam_bounds[0], lam_bounds[1]
            final_grad = float("inf")

            sgd_stime = timer()
            bisection_iterations = 0
            # Bisection on lam: shrink [lam_l, lam_r] until it is narrower than
            # bisection_tol or the supergradient is within subgrad_tol of zero.
            # lam carries over from the previous outer iteration as a warm start.
            while lam_r - lam_l > bisection_tol and abs(final_grad) > subgrad_tol:
                if lam == lam_l or lam == lam_r:
                    lam = (lam_l + lam_r) / 2.0
                if lam - lam_bounds[0] < min(2.0 * bisection_tol, 1e-2 *(lam_bounds[1] -lam_bounds[0])):
                    print("Warning: lambda is very close to left endpoint!")
                    print(f"Lambda: {lam}")
                    if self.params["outer_solve_params"].get("adapt_delta_bisec_fail", False):
                        # Recovery: shrink delta (at least halve it) and
                        # restart bisection over the full bracket.
                        fail_delta_params = {"grad_norm_samp" : 50, "monotone" :
                                             True, "delta_min" : 1e-3,
                                             "adapt_thresh" : 0.9}
                        delta_old = delta
                        delta = self._adapt_delta(fail_delta_params, ifunc, mu, delta)
                        if delta >= delta_old:
                            delta = 0.5 * delta_old
                        if delta < fail_delta_params["delta_min"]:
                            print("Further progress can't be made")
                            break
                        lam_l, lam_r = lam_bounds[0], lam_bounds[1]
                        lam = (lam_l + lam_r) / 2.0
                if uniform_strategy == True:
                    lam = alpha/math.sqrt(delta)
                # Adapt the step size for smoothness of the proximal problem.
                adjusted_params = self._adjust_lr_for_smoothness(lam, inner_solve_params)
                # Estimate the supergradient w.r.t. lam.
                if processes > 1:
                    serial_mu = serialize_mu(mu, processes, vals_queue, weights_queue)
                    adjusted_params["objective"] = \
                    serialize_func(adjusted_params["objective"], processes,
                                   args_queue, keywords_queue)
                    ctx = mp.start_processes(mp_sampling, args=(results_queue1,
                                                                exit_queue1,
                                                                proc_states,
                                                                proc_samp_size, lam,
                                                                serial_mu,
                                                                inner_optimizer,
                                                                adjusted_params),
                                             nprocs=processes, join = False,
                                             start_method='spawn')
                    results_lam = [results_queue1.get() for _ in range(processes)]
                    for _ in range(processes):
                        exit_queue1.put(1)
                    # ProcessContext.join() returns False while some workers
                    # are still running; loop until all have been joined so
                    # every worker is reaped and its errors are raised.
                    while not ctx.join():
                        pass

                    proc_lams = []
                    for rank, res in results_lam:
                        p_lam, p_state, p_solver_stats  = res
                        proc_lams.append((rank, p_lam))
                        proc_states[rank] = p_state
                        self._merge_solver_stats(aggregate_solver_stats,
                                                 p_solver_stats)

                    final_estimate = calculate_mean_across_procs(proc_samp_size,
                                                                 proc_lams)
                else:
                    final_estimate, solver_stats = compute_grad_est(num_samp, lam,
                                                                    mu, inner_optimizer,
                                                                    adjusted_params)
                    self._merge_solver_stats(aggregate_solver_stats, solver_stats)

                final_grad = final_estimate - delta
                if final_grad < 0.0:
                    lam_r = lam
                else:
                    lam_l = lam
                bisection_iterations = bisection_iterations + 1
                if uniform_strategy == True:
                    break
            sgd_etime = timer()

            ma_times[i % print_num] = sgd_etime - sgd_stime

            adjusted_params = self._adjust_lr_for_smoothness(lam, inner_solve_params)

            # Move every point of mu with the chosen lam.
            mu_stime = timer()

            if processes > 1:
                adjusted_params["objective"] = \
                serialize_func(adjusted_params["objective"], processes,
                               args_queue, keywords_queue)
                serial_mu = serialize_mu(mu, processes, vals_queue, weights_queue)
                ctx = mp.start_processes(mp_mu_update, args=(results_queue2,
                                                             exit_queue2,
                                                             proc_states,
                                                             proc_mu_batches,
                                                             lam, serial_mu,
                                                             inner_optimizer,
                                                             adjusted_params),
                                         nprocs=processes,
                                         join = False,
                                         start_method='spawn')
                results_mu = [results_queue2.get() for _ in range(processes)]
                for _ in range(processes):
                    exit_queue2.put(1)
                # Join every worker (see the note in the sampling branch).
                while not ctx.join():
                    pass
                for p_id, res in results_mu:
                    p_mu_vals, p_state, p_solver_stats = res
                    proc_states[p_id] = p_state
                    p_beg, p_end = proc_mu_batches[p_id]
                    mu.vals[p_beg:p_end] = p_mu_vals
                    self._merge_solver_stats(aggregate_solver_stats, p_solver_stats)
            else:
                solver_stats = update_mu(lam, mu.vals, inner_optimizer, adjusted_params)
                self._merge_solver_stats(aggregate_solver_stats, solver_stats)

            mu_etime = timer()
            mu_update_times[i % print_num] = mu_etime - mu_stime

            if (i+1) % print_num == 0:
                if "plot_callback" in self.params:
                    self.params["plot_callback"](i, mu)
                if "plot_callback_ifunc" in self.params:
                    self.params["plot_callback_ifunc"]({"iteration" : i,
                                                        "iterate" : mu,
                                                        "objective" : ifunc,
                                                        "lambda" : lam,
                                                        "delta" : delta,
                                                        "early_stops" : aggregate_solver_stats["early_stops"],
                                                        "grad_evals" : aggregate_solver_stats["grad_evals"]})
                if self.params["verbose"]:
                    print("Completed iteration: {}".format(i+1))
                    print("Average time for ifunc: {}".format(ifunc_times.mean()))
                    print("Average time for MA: {}".format(ma_times.mean()))
                    print("Average time for mu update: {}".format(mu_update_times.mean()))
                    print("Number of line search fails: {}".format(aggregate_solver_stats["ls_fail"]))
                    print("Number of bisection: {}".format(bisection_iterations))

    @staticmethod
    def _merge_solver_stats(aggregate, stats):
        """Add each counter in ``stats`` into ``aggregate`` for keys it tracks."""
        for key, value in stats.items():
            if key in aggregate:
                aggregate[key] += value

    def _configure_adapt_delta_params(self):
        """Return the delta-adaptation settings, or ``{}`` when disabled.

        Adaptation is enabled by ``outer_solve_params["adapt_delta"]`` or by
        supplying ``outer_solve_params["adapt_delta_params"]``. Supplied
        params override the defaults below.
        """
        adapt_delta = self.params["outer_solve_params"].get("adapt_delta", False)
        passed_delta_params = self.params["outer_solve_params"].get("adapt_delta_params", None)
        if not (adapt_delta or passed_delta_params is not None):
            return {}

        adapt_delta_params = {"monotone" : True,
                              "adapt_thresh" : 0.8,
                              "grad_norm_samp" : 250,
                              "delta_min" : 1e-3}
        # Explicit params are optional when only ``adapt_delta`` is set;
        # dict.update(None) would raise TypeError.
        if passed_delta_params is not None:
            adapt_delta_params.update(passed_delta_params)
        return adapt_delta_params

    def _adapt_delta(self, adelta_params, ifunc, mu, delta):
        """Lower (or reset) ``delta`` from a sampled squared-gradient estimate.

        Draws ``grad_norm_samp`` points from ``mu`` and backpropagates
        ``ifunc`` at each one. It then averages the per-point squared gradient
        norm, summed over the last dimension (assumed to be the point
        coordinates; TODO confirm). The new delta is
        ``adapt_thresh * estimate``. If ``monotone`` is set, it is instead
        ``min(delta, adapt_thresh * estimate)``. The result is clamped from
        below at ``delta_min``.

        Returns:
            The new delta as a Python float (assuming ``delta`` is a float).
        """
        grad_samples = mu.sample(adelta_params["grad_norm_samp"]).requires_grad_(True)
        # Each backward() accumulates into the row of grad_samples.grad that
        # corresponds to ``samp``.
        for samp in grad_samples:
            ifunc(samp).backward()
        # .item() keeps delta a plain float instead of a 0-d tensor.
        est_grad_norm = grad_samples.grad.square().sum(-1).mean().item()
        if adelta_params["monotone"]:
            delta = min(delta,
                        adelta_params["adapt_thresh"] * est_grad_norm)
        else:
            delta = adelta_params["adapt_thresh"] * est_grad_norm

        if delta < adelta_params["delta_min"]:
            if self.params["verbose"]:
                print("Warning, delta in outer loop of Frank-Wolfe has " \
                      "become smaller than : {}".format(adelta_params["delta_min"]))
            delta = adelta_params["delta_min"]
        return delta

    def _create_inner_optimizer(self):
        """Build the inner ``GenericSolver`` (``optimizer="gd"`` unless overridden)."""
        inner_optimizer_params = {"optimizer" : "gd"}
        inner_optimizer_params.update(self.params["inner_optimizer_params"])
        return GenericSolver(**inner_optimizer_params)

    def _adjust_lr_for_smoothness(self, lam, params):
        """Return a copy of ``params`` with ``lr`` scaled by ``1 / max(lam, 1)``.

        A larger ``lam`` gives a stiffer proximal term, so the step shrinks
        accordingly. The result is clamped from below at ``params["min_lr"]``
        when that key is present. ``params`` itself is not modified.
        """
        old_lr = params["lr"]
        new_lr = old_lr / max(lam, 1.0)
        min_lr = params.get("min_lr", float("-inf"))
        if new_lr < min_lr:
            print("Warning: inner solver learning rate, adjusted for " \
                  "smoothness, is less rate than the minimum learning rate")
            new_lr = min_lr
        return {**params, "lr" : new_lr}
