import numpy as np
import torch
from torchmetrics import CalibrationError

from src.simplex_layers import SimplexLayer


def projection_simplex(V, z=1, axis=None):
    """
    Euclidean projection of V onto the simplex of total mass z:
        P(x; z) = argmin_{y >= 0, sum(y) = z} ||y - x||^2
    V: numpy array
        Indexed as a 2-D array when axis is 0 or 1; flattened when axis is None
    z: float or array
        If array, it must supply one value per projected vector
    axis: None or int
        axis=None: flatten V and project it as one vector (returns a 1-D array)
        axis=1: project every row V[i] with mass z[i]
        axis=0: project every column V[:, j] with mass z[j]
    """
    if axis == 0:
        # Columns become rows, are projected row-wise, then transposed back.
        return projection_simplex(V.T, z, axis=1).T

    if axis != 1:
        # Any other axis value: treat the whole array as one long vector.
        flat = V.ravel().reshape(1, -1)
        return projection_simplex(flat, z, axis=1).ravel()

    # Row-wise projection via the sort-and-threshold scheme.
    n_rows = len(V)
    row_ids = np.arange(n_rows)
    mass = np.ones(n_rows) * z

    # Each row sorted in descending order.
    sorted_desc = np.sort(V, axis=1)[:, ::-1]
    # Running sum of the sorted entries, shifted by the target mass.
    shifted_cumsum = np.cumsum(sorted_desc, axis=1) - mass[:, np.newaxis]
    ranks = np.arange(1, V.shape[1] + 1)

    # rho counts, per row, how many sorted entries stay strictly positive
    # after subtracting the candidate threshold for their rank.
    stays_positive = sorted_desc - shifted_cumsum / ranks > 0
    rho = np.count_nonzero(stays_positive, axis=1)

    # Per-row shift; subtracting it and clipping at zero gives the projection.
    theta = shifted_cumsum[row_ids, rho - 1] / rho
    return np.maximum(V - theta[:, np.newaxis], 0)


def compute_riesz_s_energy(simplex_points, d=2):
    """Return the Riesz energy of a point set and the log of its per-pair mean.

    simplex_points: 2-D numpy array, one point per row
    d: exponent applied to the pairwise Euclidean distances

    Returns a tuple (energy, log_energy):
        energy: sum of 1 / dist**d over every unordered pair of distinct rows,
            with distances below 1e-4 raised to 1e-4 and NaN terms dropped
        log_energy: log(energy) - log(number of pairs)
    """
    # Floor for the distances so coincident points do not divide by zero.
    min_dist = 1e-4

    # Index pairs (i, j) with i < j, so every mutual distance appears once.
    first, second = np.triu_indices(len(simplex_points), 1)
    deltas = simplex_points[first] - simplex_points[second]
    pair_dist = np.sqrt((deltas ** 2).sum(axis=1))
    pair_dist[pair_dist < min_dist] = min_dist

    inverse_powers = 1 / pair_dist ** d
    valid_terms = inverse_powers[~np.isnan(inverse_powers)]
    energy = valid_terms.sum()

    # The pair count includes pairs whose term was dropped as NaN.
    log_energy = np.log(energy) - np.log(len(pair_dist))
    return energy, log_energy


def get_label_dist(loader, num_classes):
    """Count how many samples of each class the loader's dataset contains.

    loader: object exposing a ``dataset`` attribute that iterates over
        ``(sample, label)`` pairs
    num_classes: length of the returned count vector

    Returns a float numpy array of length ``num_classes`` where entry ``c`` is
    the number of samples labelled ``c``. Classes that never occur keep a
    count of zero; an empty dataset yields an all-zero vector.

    Note: every item of the dataset is materialised once to read its label.
    """
    label_dist = np.zeros(num_classes)

    labels = [label for _, label in loader.dataset]
    if not labels:
        # Nothing to count; return zeros instead of failing on an empty dataset.
        return label_dist

    unique_labels, unique_counts = np.unique(labels, return_counts=True)
    for label, count in zip(unique_labels, unique_counts):
        label_dist[int(label)] = count
    return label_dist


def get_stats(net, num_points, network_type, network_arch):
    """Compute pairwise similarity statistics between the points stored in a net.

    Every ``torch.nn.ParameterList`` found in ``net.modules()`` is read as
    holding one parameter tensor per point (entries ``0 .. num_points - 1``
    are indexed). Squared norms, dot products and squared differences are
    accumulated across all such lists.

    net: module whose ``modules()`` are scanned for ``ParameterList`` instances
    num_points: number of points stored in each ``ParameterList``
    network_type: accepted for interface compatibility; not used here
    network_arch: architecture name; "PretrainedSqueezeNet" short-circuits

    Returns a tuple ``(cossim, l2)``:
        cossim: sum over point pairs (i, j), i < j, of
            dot(i, j)**2 / (||i||**2 * ||j||**2)
        l2: square root of the sum over all pairs of ||i - j||**2
    ``(0.0, 0.0)`` is returned when there is nothing to compare, i.e. fewer
    than two points or no ``ParameterList`` in the network.
    """
    # TODO this is the current behavior, is this correct?
    if network_arch == "PretrainedSqueezeNet":
        return 0, 0

    pairs = [(i, j) for i in range(num_points) for j in range(i + 1, num_points)]
    norms = dict.fromkeys(range(num_points), 0.0)
    numerators = dict.fromkeys(pairs, 0.0)
    difs = dict.fromkeys(pairs, 0.0)

    found_points = False
    for m in net.modules():
        if not isinstance(m, torch.nn.ParameterList):
            continue
        found_points = True
        for i in range(num_points):
            vi = m[i]
            norms[i] += vi.pow(2).sum()
            for j in range(i + 1, num_points):
                vj = m[j]
                numerators[(i, j)] += (vi * vj).sum()
                difs[(i, j)] += (vi - vj).pow(2).sum()

    # Without a pair of points, or without any parameters, the accumulators
    # are still plain Python numbers and the tensor maths below would fail.
    if not pairs or not found_points:
        return 0.0, 0.0

    cossim = 0
    l2 = 0
    for i, j in pairs:
        cossim += numerators[(i, j)].pow(2) / (norms[i] * norms[j])
        l2 += difs[(i, j)]
    return cossim.item(), l2.pow(0.5).item()


def test(
    net,
    valloader,
    client_alpha,
    num_classes,
    device: str,
    strategy_name: str,
):
    """Evaluate the network on every batch of the validation loader.

    net: model to evaluate; it is moved to ``device`` and put in eval mode
    valloader: iterable of ``(data, target)`` batches with ``len()`` and a
        ``dataset`` attribute
    client_alpha: alpha passed to ``set_net_alpha`` when ``strategy_name`` is
        "FLOCO". If ``np.array(client_alpha)`` has more than one dimension,
        one entry is picked at random with ``np.random.choice``.
    num_classes: number of classes for the calibration-error metric
    device: device on which the model and the batches are placed
    strategy_name: only the value "FLOCO" triggers the alpha selection

    Returns a dict with keys "val_loss", "val_acc" and "val_ece". All three
    are 0 when the loader has no batches.
    """
    if len(valloader) == 0:  # nothing to validate on
        return {"val_loss": 0, "val_acc": 0, "val_ece": 0}

    net = net.to(device=device).eval()
    ece_fn = CalibrationError(task="multiclass", num_classes=num_classes)
    criterion = torch.nn.CrossEntropyLoss()

    loss_sum = 0.0
    num_correct = 0
    batch_probs = []
    batch_labels = []
    with torch.no_grad():
        if strategy_name == "FLOCO":
            alpha = client_alpha
            if np.array(client_alpha).ndim > 1:
                alpha = client_alpha[np.random.choice(len(client_alpha))]
            set_net_alpha(net, alpha)

        for data, target in valloader:
            data = data.to(device=device)
            target = target.flatten().to(device=device)
            output = net(data)
            num_correct += torch.sum(output.argmax(dim=1) == target).item()
            batch_probs.append(torch.softmax(output, dim=1).detach().cpu())
            batch_labels.append(target.detach().cpu())
            loss_sum += criterion(output, target).item()

    # Softmax outputs and labels for the whole loader, in iteration order.
    probs = torch.cat(batch_probs)
    labels = torch.cat(batch_labels)

    return {
        # Mean of the per-batch mean losses: each batch carries equal weight,
        # regardless of how many samples it holds.
        "val_loss": loss_sum / len(valloader),
        "val_acc": num_correct / len(valloader.dataset),
        "val_ece": ece_fn(probs, labels)
    }


def server_test_approx(server_round, net, testloader, writer, num_classes, cfg, folds=None, strategy_config=None) -> None:
    """Evaluate the network on the test loader and log the metrics.

    When ``cfg.rule.num_points > 1`` the network is first set to the centre of
    the simplex (uniform alphas) via ``set_net_alpha``; endpoints are not
    evaluated.

    server_round: step value passed to every ``writer.add_scalar`` call
    net: model to evaluate; it is moved to ``cfg.device`` and put in eval mode
    testloader: iterable of ``(data, target)`` batches with ``len()`` and a
        ``dataset`` attribute
    writer: object exposing ``add_scalar(tag, value, step)``
    num_classes: number of classes for the calibration-error metric
    cfg: configuration; reads ``cfg.device``, ``cfg.rule.num_points``,
        ``cfg.strategy.network_type`` and ``cfg.dataset_model.network_arch``
    folds: accepted for interface compatibility; not used here
    strategy_config: accepted for interface compatibility; not used here

    Logged tags: "test/center_loss", "test/center_acc", "test/center_ece",
    "test/ensembled_test_acc", and for more than one point also "test/norm"
    and "test/cossim".
    """
    device = cfg.device
    num_points = cfg.rule.num_points

    net = net.to(device=device).eval()
    ece_fn = CalibrationError(task="multiclass", num_classes=num_classes)
    criterion = torch.nn.CrossEntropyLoss()

    # Ignore unnecessary endpoint evaluation: only the simplex centre is used.
    if num_points > 1:
        simplex_center = np.ones(num_points) / num_points
        set_net_alpha(net, simplex_center)

    accumulated_loss = 0.0
    correct_samples = 0
    pred_vectors = []
    label_vectors = []
    with torch.no_grad():
        for data, target in testloader:
            data = data.to(device)
            target = target.flatten().to(device)
            output = net(data)
            accumulated_loss += criterion(output, target).item()
            # index of the largest logit is the predicted class
            pred = output.argmax(dim=1)
            correct_samples += torch.sum(pred == target).item()
            pred_vectors.append(torch.softmax(output, dim=1).detach().cpu())
            label_vectors.append(target.detach().cpu())

    # Mean of the per-batch mean losses: each batch carries equal weight.
    test_loss = accumulated_loss / len(testloader)
    test_acc = correct_samples / len(testloader.dataset)
    test_ece = ece_fn(torch.cat(pred_vectors), torch.cat(label_vectors))

    writer.add_scalar("test/center_loss", test_loss, server_round)
    writer.add_scalar("test/center_acc", test_acc, server_round)
    writer.add_scalar("test/center_ece", test_ece, server_round)

    if num_points == 1:
        # A single model has no ensemble; log its accuracy under the ensemble
        # tag so runs can be compared.
        writer.add_scalar("test/ensembled_test_acc", test_acc, server_round)

    if num_points > 1:
        cossim, l2 = get_stats(net, num_points, cfg.strategy.network_type, cfg.dataset_model.network_arch)
        writer.add_scalar("test/norm", l2, server_round)
        writer.add_scalar("test/cossim", cossim, server_round)
        # No ensemble is evaluated here; a constant 0.0 is logged for the tag.
        ensembled_test_acc = 0.0
        writer.add_scalar("test/ensembled_test_acc", ensembled_test_acc, server_round)


def set_net_alpha(net, alphas: tuple[float, ...]):
    """Call ``set_alphas(alphas)`` on every ``SimplexLayer`` inside ``net``.

    net: object whose ``modules()`` yields the layers to inspect
    alphas: value handed unchanged to each layer's ``set_alphas``
    """
    simplex_layers = (
        layer for layer in net.modules() if isinstance(layer, SimplexLayer)
    )
    for layer in simplex_layers:
        layer.set_alphas(alphas)
