import torch
from robustopt_torch.distributions import is_discrete
from robustopt_torch.costs import eucl_norm_sq
from robustopt_torch.funcutils import ensurelist

def grad_evals(iteration_variables, metric_plotter):
    """Append the current gradient-evaluation count to the plotter.

    Reads ``iteration_variables["iteration"]`` and
    ``iteration_variables["grad_evals"]`` and appends the pair
    ``(iteration, grad_evals)`` to the "Gradient evaluations" metric of
    ``metric_plotter``.
    """
    step = iteration_variables["iteration"]
    count = iteration_variables["grad_evals"]
    metric_plotter.append_to_metric("Gradient evaluations", (step, count))

def tensor_stats(tensor, stats = "all"):
    """Compute scalar summary statistics over all elements of ``tensor``.

    Args:
        tensor: a ``torch.Tensor``; every element is included (no dim
            reduction is specified).
        stats: ``"all"``, or a statistic name / collection of names taken
            from ``"mean"``, ``"median"``, ``"max"``, ``"min"``. Non-"all"
            values are normalized with ``ensurelist`` before validation.

    Returns:
        dict mapping each requested statistic name to a Python number, in
        the fixed order mean, median, max, min. If ``tensor`` has no
        elements, every requested statistic is ``float("nan")`` (torch's
        ``max``/``min`` would otherwise raise on an empty input).

    Raises:
        ValueError: if any requested statistic is not supported.

    Note:
        ``torch.median`` returns the lower of the two middle values when the
        element count is even; it does not average them.
    """
    # Insertion order here fixes the key order of the returned dict.
    reducers = {
        "mean": torch.mean,
        "median": torch.median,
        "max": torch.max,
        "min": torch.min,
    }
    if stats == "all":
        stats = list(reducers)
    requested = set(ensurelist(stats))
    unsupported = requested.difference(reducers)
    if unsupported:
        raise ValueError("Unsupported tensor statistic! "
                         f"Got {sorted(unsupported, key=str)}; "
                         f"supported: {list(reducers)}.")

    if tensor.numel() == 0:
        # No elements -> no meaningful statistic; report NaN uniformly
        # instead of crashing inside torch.max/torch.min.
        return {name: float("nan") for name in reducers if name in requested}

    return {name: reduce(tensor).item()
            for name, reduce in reducers.items() if name in requested}

def flattened_upper_tri(tensor, diag = 0):
    """Return the upper-triangular entries of the last two dims as a vector.

    The last two dimensions are treated as a (rows, cols) matrix; any
    leading dimensions are kept as batch dimensions. Entries are gathered in
    row-major order.

    Args:
        tensor: tensor with at least two dimensions.
        diag: diagonal offset, as ``offset`` in ``torch.triu_indices``:
            0 keeps the main diagonal, 1 starts strictly above it, and
            negative values also include sub-diagonals.

    Returns:
        Tensor of shape ``(*tensor.shape[:-2], k)``, where ``k`` is the
        number of selected entries (possibly 0).

    Raises:
        ValueError: if ``tensor`` has fewer than two dimensions.
    """
    if tensor.dim() < 2:
        raise ValueError("Tensor must be at least two dimensional")
    n_rows, n_cols = tensor.shape[-2:]
    # Build the indices on the input's device; CPU indices fed to
    # index_select on a CUDA tensor raise a device-mismatch error.
    rows, cols = torch.triu_indices(n_rows, n_cols, offset=diag,
                                    device=tensor.device)
    # Paired advanced indexing over the last two dims yields (..., k).
    return tensor[..., rows, cols]

def distribution_center_statistics(iteration_variables, metric_plotter,
                                   stats = "all"):
    """Record statistics of distances between the atoms of a discrete iterate.

    Computes ``eucl_norm_sq(iterate.vals, iterate.vals)``, keeps the entries
    strictly above the diagonal, summarizes them with ``tensor_stats`` and
    appends each statistic as ``"<Stat> center distance"`` at the current
    iteration.

    Args:
        iteration_variables: mapping with at least ``"iterate"`` and
            ``"iteration"``.
        metric_plotter: object providing ``append_to_metric(name, point)``.
        stats: forwarded to ``tensor_stats``.

    Raises:
        ValueError: if the iterate is not a discrete distribution.
    """
    iterate = iteration_variables["iterate"]
    if not is_discrete(iterate):
        raise ValueError("Iterate is not a discrete distribution!")
    # This is observation only: don't build an autograd graph even if the
    # atom locations require grad.
    with torch.no_grad():
        # NOTE(review): assumes eucl_norm_sq returns the (n, n) matrix of
        # pairwise *squared* distances; the metric label says "distance".
        pairwise = eucl_norm_sq(iterate.vals, iterate.vals)
        # diag=1 skips self-pairs; for a symmetric matrix this keeps each
        # unordered pair exactly once.
        center_dists = flattened_upper_tri(pairwise, diag = 1)
    center_stats = tensor_stats(center_dists, stats)
    iteration = iteration_variables["iteration"]
    for stat, val in center_stats.items():
        metric_plotter.append_to_metric(f"{stat.capitalize()} center distance",
                                        (iteration, val))

def collect_gradient_norm_stats(samples, objective, stats):
    """Summarize squared gradient norms of ``objective`` at ``samples``.

    ``objective`` is evaluated on a detached copy of ``samples``. Its output
    is summed and differentiated with respect to that copy only. The squared
    norm is taken over the last dimension of the gradient, and the result is
    passed to ``tensor_stats``. When each output entry depends only on its
    own sample (presumably the intended use — confirm), each row of the
    gradient is that sample's own gradient.

    Args:
        samples: tensor of evaluation points; it is not modified.
        objective: callable mapping the points to a tensor.
        stats: forwarded to ``tensor_stats``.

    Returns:
        dict produced by ``tensor_stats`` over the squared gradient norms.
    """
    points = samples.detach().clone().requires_grad_(True)
    # enable_grad keeps this working when the caller is inside no_grad().
    with torch.enable_grad():
        total = objective(points).sum()
        # autograd.grad differentiates w.r.t. `points` only. Unlike
        # .backward(), it does not accumulate into the .grad of other
        # leaves (e.g. model parameters) the objective depends on.
        (grad,) = torch.autograd.grad(total, points)
    grad_sq_norms = grad.square().sum(-1)
    return tensor_stats(grad_sq_norms, stats)

def gradient_norm_statistics(iteration_variables, metric_plotter, stats = "all",
                             objective = None, num_samp = float("inf")):
    """Record statistics of (squared) gradient norms at the current iterate.

    Evaluation points are the atoms ``iterate.vals`` when ``num_samp`` is
    infinite (the default). Otherwise they are ``iterate.sample(num_samp)``.
    Each statistic from ``collect_gradient_norm_stats`` is appended as
    ``"<Stat> gradient norm"`` at the current iteration.

    Args:
        iteration_variables: mapping with ``"iterate"`` and ``"iteration"``,
            plus ``"objective"`` when ``objective`` is not given.
        metric_plotter: object providing ``append_to_metric(name, point)``.
        stats: forwarded to ``tensor_stats``.
        objective: callable to differentiate; defaults to
            ``iteration_variables["objective"]``.
        num_samp: number of samples to draw, or ``inf`` to use the atoms.

    Raises:
        ValueError: if ``num_samp`` is infinite and the iterate is not
            discrete.
    """
    iterate = iteration_variables["iterate"]
    use_atoms = num_samp == float("inf")
    if use_atoms and not is_discrete(iterate):
        raise ValueError("Iterate is not a discrete distribution! "
                         "Sample number must be specified.")
    fn = iteration_variables["objective"] if objective is None else objective

    points = iterate.vals if use_atoms else iterate.sample(num_samp)
    summary = collect_gradient_norm_stats(points, fn, stats)

    step = iteration_variables["iteration"]
    for name, value in summary.items():
        metric_plotter.append_to_metric(f"{name.capitalize()} gradient norm",
                                        (step, value))
