import math

import torch
import torch.distributed as dist
import torch.nn.functional as F
from tqdm import tqdm


@torch.no_grad()
def extract_patches(layer, X):
    """Unfold ``X`` into the sliding windows that ``layer`` operates on.

    The window geometry (kernel size, dilation, padding and stride) is read
    from ``layer``.

    Returns a 2-D tensor of shape ``(ch*ks*ks, batch*height*width)``: one row
    per entry of a flattened patch, one column per (sample, location) pair.
    """
    window = {
        "kernel_size": layer.kernel_size,
        "dilation": layer.dilation,
        "padding": layer.padding,
        "stride": layer.stride,
    }
    unfolded = F.unfold(X, **window)  # (batch, ch*ks*ks, height*width)
    patch_dim = unfolded.size(1)
    # Bring the patch dimension to the front, then merge batch and location.
    by_patch_entry = unfolded.transpose(0, 1)  # (ch*ks*ks, batch, height*width)
    return by_patch_entry.reshape(patch_dim, -1)


# from https://carstenschelp.github.io/2019/05/12/Online_Covariance_Algorithm_002.html
@torch.jit.script
def combine_cov(count: torch.Tensor, mean: torch.Tensor, cov: torch.Tensor,
                other_count: torch.Tensor, other_mean: torch.Tensor,
                other_cov: torch.Tensor) -> None:
    """Merge a second set of (count, mean, covariance) statistics in place.

    ``count``, ``mean`` and ``cov`` are overwritten with the statistics of the
    union of both sample sets; the ``other_*`` tensors are only read.  The
    merge weights are exact when both covariances are normalised by their
    sample count (population covariance).

    Args:
        count: 0-dim integer tensor holding the number of samples seen so far.
        mean: 1-D running mean, updated in place.
        cov: 2-D running covariance, updated in place.
        other_count: 0-dim integer tensor with the size of the other set.
        other_mean: mean of the other set, same shape as ``mean``.
        other_cov: covariance of the other set, same shape as ``cov``.
    """
    merged_count = count + other_count
    # Build the weights from ratios.  The cross-term weight is
    # count * other_count / merged_count**2; forming the integer product
    # count * other_count first silently wraps around int64 once both counts
    # reach a few billion samples, so it is expressed as a product of the two
    # floating-point ratios instead.
    self_weight = count / merged_count
    other_weight = other_count / merged_count

    # Must be computed before ``mean`` is updated below.
    mean_delta = other_mean - mean
    mean.add_(mean_delta * other_count / merged_count)

    cov.mul_(self_weight)
    cov.add_(
        other_cov * other_weight
        + torch.outer(mean_delta, mean_delta) * (self_weight * other_weight)
    )
    count.copy_(merged_count)


def _accumulate_patch_stats(layer_name, patches, counts, means, covs):
    """Fold one batch of patch vectors into the running statistics of a layer."""
    batch_count = torch.tensor(patches.size(1), dtype=torch.long)
    batch_mean = patches.mean(1)
    # combine_cov weights each covariance by count / merged_count, which is
    # only exact for covariances normalised by the sample count, so disable
    # torch.cov's default Bessel correction (correction=1).
    batch_cov = torch.cov(patches, correction=0)

    if means[layer_name] is None:
        # First batch for this layer.  Testing for None (rather than
        # ``count == 0`` on the CUDA tensor) avoids a host-device sync.
        counts[layer_name].copy_(batch_count)
        means[layer_name] = batch_mean
        covs[layer_name] = batch_cov
    else:
        combine_cov(
            counts[layer_name], means[layer_name], covs[layer_name],
            batch_count, batch_mean, batch_cov
        )


def _reduce_stats_across_ranks(count, mean, cov):
    """Tree-reduce the statistics onto rank 0, then broadcast them to all ranks.

    All three tensors are updated in place.
    """
    recv_count = count.new_empty(count.shape)
    recv_mean = mean.new_empty(mean.shape)
    recv_cov = cov.new_empty(cov.shape)

    # ``rank`` is halved after every stage, so it acts as the index of this
    # process among the ranks still taking part in the reduction.
    rank, world_size = dist.get_rank(), dist.get_world_size()
    gap = 2
    for _ in range(math.ceil(math.log2(world_size))):
        dst = rank // 2 * gap
        src = dst + gap // 2
        if rank % 2 == 0:
            # With a non-power-of-two world size the partner may not exist.
            if src < world_size:
                dist.recv(recv_count, src)
                dist.recv(recv_mean, src)
                dist.recv(recv_cov, src)
                combine_cov(
                    count, mean, cov,
                    recv_count, recv_mean, recv_cov
                )
        else:
            # Hand the partial result to the partner and leave the reduction.
            dist.send(count, dst)
            dist.send(mean, dst)
            dist.send(cov, dst)
            break
        rank //= 2
        gap *= 2

    dist.broadcast(count, 0)
    dist.broadcast(mean, 0)
    dist.broadcast(cov, 0)


@torch.no_grad()
def dataset_cov(train_loader, loader_instance,
                model, layers: list, layer_names: list,
                sample_portion=1.0):
    """Estimate the covariance of the input patches of selected layers.

    Batches are drawn from ``train_loader`` and moved to CUDA.  For each batch
    ``model.coroutine(X)`` is driven step by step; every step yields an
    ``(out, cur_layer)`` pair and receives ``out`` back.  Whenever
    ``cur_layer`` is one of ``layers``, the patches of ``out`` are extracted
    and merged into that layer's running count / mean / covariance.  When
    ``torch.distributed`` is initialised, the per-rank statistics are then
    merged across ranks so that every rank ends up with the same result;
    otherwise the local statistics are returned as they are.

    Args:
        train_loader: sized iterable yielding ``(X, y)`` batches; only ``X``
            is used.
        loader_instance: iterator over ``train_loader`` to continue from, or
            ``None`` to start a fresh one.  It is restarted when exhausted.
        model: object exposing ``coroutine(X)`` as described above.
        layers: layer objects to collect statistics for (matched by identity).
        layer_names: names for ``layers``, in the same order.
        sample_portion: fraction of ``len(train_loader)`` batches to process.

    Returns:
        A tuple ``(covs, ids, loader_instance)``: ``covs`` maps each layer name
        to its covariance matrix, ``ids`` maps ``id(layer)`` to the layer name
        for every layer that was encountered, and ``loader_instance`` is the
        (possibly restarted) iterator.

    Raises:
        RuntimeError: if one of ``layer_names`` never received any activations.
    """
    iters = math.ceil(len(train_loader) * sample_portion)

    counts = {
        name: torch.zeros((), device='cuda', dtype=torch.long)
        for name in layer_names
    }
    means = {name: None for name in layer_names}
    covs = {name: None for name in layer_names}
    ids = {}

    if loader_instance is None:
        loader_instance = iter(train_loader)
    for _ in tqdm(range(iters), desc="Calculating covariance"):
        try:
            X, _ = next(loader_instance)
        except StopIteration:
            # Loader exhausted: start another pass over the data.
            loader_instance = iter(train_loader)
            X, _ = next(loader_instance)
        X = X.cuda(non_blocking=True).float()

        coro = model.coroutine(X)
        out = None
        while True:
            try:
                out, cur_layer = coro.send(out)
            except StopIteration:
                break

            for layer, layer_name in zip(layers, layer_names):
                if cur_layer is layer:
                    ids[id(cur_layer)] = layer_name
                    _accumulate_patch_stats(
                        layer_name, extract_patches(layer, out),
                        counts, means, covs
                    )
                    break

    missing = [name for name in layer_names if means[name] is None]
    if missing:
        raise RuntimeError(
            f"No activations were collected for layers: {missing}"
        )

    # reduce
    if dist.is_available() and dist.is_initialized():
        print("Combining covariances")
        for layer_name in layer_names:
            _reduce_stats_across_ranks(
                counts[layer_name], means[layer_name], covs[layer_name]
            )

    return covs, ids, loader_instance
