
import torch
import torch.nn as nn


def _build_mlp(num_layers, input_dim, mlp_dim, output_dim, last_bn=False, dropout=True, dropout_p=0.1):
    """Assemble a bias-free MLP and return it as an ``nn.Sequential``.

    For ``num_layers`` linear layers the module order is::

        Linear -> LayerNorm -> ReLU -> [Dropout] -> Linear -> ... -> Linear [-> LayerNorm]

    Args:
        num_layers: number of ``nn.Linear`` layers; 0 (or less) yields an empty Sequential.
        input_dim: in-features of the first linear layer.
        mlp_dim: width of every hidden layer.
        output_dim: out-features of the final linear layer.
        last_bn: if True, append a ``LayerNorm`` after the final linear layer.
            Despite the historical name this is a LayerNorm (with learnable
            affine parameters), not a BatchNorm.
        dropout: if True, put ``nn.Dropout(dropout_p)`` in front of every linear
            layer except the first.
        dropout_p: dropout probability.
    """
    # widths[i] -> widths[i + 1] is the shape of the i-th linear layer.
    widths = [input_dim] + [mlp_dim] * (num_layers - 1) + [output_dim]
    final_idx = num_layers - 1

    layers = []
    for idx in range(num_layers):
        in_features, out_features = widths[idx], widths[idx + 1]

        if dropout and idx > 0:
            layers.append(nn.Dropout(dropout_p))
        layers.append(nn.Linear(in_features, out_features, bias=False))

        if idx != final_idx:
            # Hidden layers are normalised and activated.
            layers += [nn.LayerNorm([out_features]), nn.ReLU(inplace=True)]
        elif last_bn:
            # Optional normalisation on the output layer.
            layers.append(nn.LayerNorm([out_features]))

    return nn.Sequential(*layers)


def rescale_ks(ks, x):
    """Rescale per-cluster sizes ``ks`` so that they sum exactly to the batch size of ``x``.

    Each entry is scaled by ``batch_size / sum(ks)`` and truncated to an integer.
    The rows lost to truncation are then handed out one at a time, round-robin,
    starting from the first cluster.

    Args:
        ks: 1-D tensor (or sequence) of non-negative cluster sizes / weights.
        x: only ``x.shape[0]`` (the batch size) is used.

    Returns:
        A 1-D ``torch.long`` CPU tensor with the same length as ``ks``.

    Raises:
        ValueError: if ``ks`` is empty or sums to a non-positive value.
    """
    # Work on CPU once, instead of syncing per element when ks lives on GPU.
    ks = torch.as_tensor(ks).cpu()
    n_k = ks.shape[0]
    batch_size = x.shape[0]

    k_sum = ks.sum()
    if n_k == 0 or k_sum <= 0:
        raise ValueError('rescale_ks: ks must be non-empty with a positive sum, got {}'.format(ks))

    ratio = batch_size / k_sum
    # .long() truncates toward zero, exactly like int() did per element.
    scaled = (ks * ratio).long()

    # Distribute the truncation deficit round-robin: every cluster gets
    # `full_rounds` extra rows and the first `remainder` clusters one more.
    deficit = batch_size - int(scaled.sum())
    if deficit > 0:
        full_rounds, remainder = divmod(deficit, n_k)
        scaled += full_rounds
        scaled[:remainder] += 1

    return scaled

@torch.no_grad()
def kmax_sinkhorn(x, ks, n_iters=20, rescale_k=True, epsilon=0.05):
    """Sinkhorn-style soft assignment whose column scaling is driven by each column's k-th largest entry.

    Each iteration divides column ``j`` of ``Q = exp(x / epsilon)`` by a blend of
    its ``ks[j]``-th largest value and its column total. It then normalises the
    rows. A final row normalisation makes every row sum to 1.

    Args:
        x: score matrix whose first dimension is the batch. There is one column
            per entry of ``ks``. Presumably (batch, n_clusters) similarity scores
            -- TODO confirm.
        ks: 1-D per-column target counts. After optional rescaling they are used
            as a ``gather`` index (``ks - 1``) into a ``topk`` over the rows, so
            each must lie in ``[1, x.shape[0]]``.
        n_iters: number of column/row normalisation rounds.
        rescale_k: if True, pass ``ks`` through ``rescale_ks`` so it sums to the
            batch size.
        epsilon: temperature used in ``exp(x / epsilon)``. The default 0.05 is
            the previously hard-coded value.

    Returns:
        Tensor shaped like ``x`` whose rows each sum to 1.

    NOTE(review): ``exp(x / epsilon)`` overflows float32 once ``x / epsilon`` exceeds ~88
    (x > ~4.4 at the default epsilon, and far sooner in float16). Callers presumably
    pass bounded scores such as cosine similarities -- confirm.
    """
    Q = torch.exp(x / epsilon)
    bsz = Q.shape[0]
    if rescale_k:
        ks = rescale_ks(ks, x)

    max_k = int(ks.max())
    ks = ks.to(x.device)
    # w = 1 - 0.1 / k: mostly the k-th largest value, plus a 0.1/k share of the column total.
    z_weight = ((ks - 0.1) / ks)

    for _ in range(n_iters):
        # z[0, j] is the ks[j]-th largest entry of column j.
        col_sorted, _ = torch.topk(Q, dim=0, k=max_k)
        z = torch.gather(col_sorted, dim=0, index=(ks - 1).unsqueeze(0))
        col_sum = torch.sum(Q, dim=0, keepdim=True)
        new_z = z_weight * z + (1 - z_weight) * col_sum
        Q /= new_z
        Q /= bsz

        sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
        Q /= sum_of_rows
        Q /= bsz

    Q *= bsz
    sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
    Q /= sum_of_rows

    return Q


@torch.no_grad()
def hybrid_sinkhorn(x, ks, n_iters, rescale_k=True, alpha=0.1, epsilon=0.05):
    """Sinkhorn-style soft assignment blending the k-th largest entry with the mean mass per expected member.

    Each iteration divides column ``j`` of ``Q = exp(x / epsilon)`` by
    ``(1 - alpha) * kth_largest_j + alpha * col_sum_j / ks[j]``. It then normalises
    the rows. A final row normalisation makes every row sum to 1.

    Args:
        x: score matrix whose first dimension is the batch. There is one column
            per entry of ``ks``. Presumably (batch, n_clusters) similarity scores
            -- TODO confirm.
        ks: 1-D per-column target counts. After optional rescaling they are used
            as a ``gather`` index (``ks - 1``) into a ``topk`` over the rows, so
            each must lie in ``[1, x.shape[0]]``.
        n_iters: number of column/row normalisation rounds.
        rescale_k: if True, pass ``ks`` through ``rescale_ks`` so it sums to the
            batch size.
        alpha: weight given to the per-member column mass (``col_sum / ks``).
        epsilon: temperature used in ``exp(x / epsilon)``. The default 0.05 is
            the previously hard-coded value.

    Returns:
        Tensor shaped like ``x`` whose rows each sum to 1.

    NOTE(review): ``exp(x / epsilon)`` overflows float32 once ``x / epsilon`` exceeds ~88
    (x > ~4.4 at the default epsilon, and far sooner in float16). Callers presumably
    pass bounded scores such as cosine similarities -- confirm.
    """
    Q = torch.exp(x / epsilon)
    bsz = Q.shape[0]
    if rescale_k:
        ks = rescale_ks(ks, x)
    max_k = int(ks.max())
    ks = ks.to(x.device)
    # Elementwise this is 1 - alpha (kept in tensor form, matching ks).
    z_weight = ((ks - (alpha * ks)) / ks)

    for _ in range(n_iters):
        # z[0, j] is the ks[j]-th largest entry of column j.
        col_sorted, _ = torch.topk(Q, dim=0, k=max_k)
        z = torch.gather(col_sorted, dim=0, index=(ks - 1).unsqueeze(0))
        # Average column mass per expected member of the cluster.
        col_sum = torch.sum(Q, dim=0, keepdim=True)
        col_sum = col_sum / ks
        new_z = z_weight * z + (1 - z_weight) * col_sum
        Q /= new_z
        Q /= bsz

        sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
        Q /= sum_of_rows
        Q /= bsz

    Q *= bsz
    sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
    Q /= sum_of_rows

    return Q


def make_linear_sinkhorn_scalar(n_classes, max_scalar, min_scalar=1):
    """Return ``n_classes`` scalars decreasing linearly from ``max_scalar`` to ``min_scalar``.

    Index 0 receives ``max_scalar`` and index ``n_classes - 1`` receives
    ``min_scalar``. A single class receives ``max_scalar``, the start of the ramp.
    ``n_classes == 0`` yields an empty tensor.

    Returns:
        1-D float tensor of length ``n_classes``.
    """
    if n_classes == 1:
        # With one class there are zero steps; avoid the 0/0 in the ramp formula.
        return torch.tensor([float(max_scalar)])

    n_steps = n_classes - 1
    scalars = [
        (idx / n_steps) * min_scalar + ((n_steps - idx) / n_steps) * max_scalar
        for idx in range(n_classes)
    ]
    return torch.tensor(scalars)


@torch.no_grad()
def my_scaling_sinkhorn(x, per_cluster_scalar, n_iters=3, random_add=False, amplification=None, epsilon=0.05):
    """Sinkhorn normalisation that steers column masses toward ``per_cluster_scalar``.

    The per-column targets are computed as follows:

    - multiply ``per_cluster_scalar`` elementwise by ``amplification`` (if given);
    - rescale the result to sum to the batch size.

    Each iteration then does two steps:

    - scale column ``j`` of ``Q = exp(x / epsilon)`` to total ``target_j / bsz``;
    - normalise every row to ``1 / bsz``.

    The result is multiplied by ``bsz``. After at least one iteration every row
    therefore sums to 1.

    Args:
        x: score matrix whose first dimension is the batch and whose columns
            match ``per_cluster_scalar``. Presumably (batch, n_clusters) -- TODO confirm.
        per_cluster_scalar: 1-D relative target mass per column (any numeric
            dtype). It is NOT modified.
        n_iters: number of column/row normalisation rounds.
        random_add: if True, add ``U[0, 0.25)`` noise before the row step on
            every 10th iteration (0, 10, 20, ...).
        amplification: optional 1-D multiplier applied to ``per_cluster_scalar``.
        epsilon: temperature used in ``exp(x / epsilon)``. The default 0.05 is
            the previously hard-coded value.

    Returns:
        Tensor shaped like ``x``.

    NOTE(review): ``exp(x / epsilon)`` overflows float32 once ``x / epsilon`` exceeds ~88
    (x > ~4.4 at the default epsilon, and far sooner in float16) -- confirm inputs are bounded.
    """
    Q = torch.exp(x / epsilon)
    bsz = Q.shape[0]

    # Scale out-of-place. ``.to()`` returns the caller's own tensor when it is
    # already on x.device, so the former in-place ``*=`` silently mutated the
    # argument: amplification compounded across calls, and integer tensors
    # raised on the float multiply.
    target = per_cluster_scalar.to(x.device)
    if amplification is not None:
        target = target * amplification.to(x.device)
    target = target * (bsz / target.sum())

    Q /= torch.sum(Q)
    for it in range(n_iters):
        # Column step: column j ends up summing to target[j] / bsz.
        col_sum = torch.sum(Q, dim=0, keepdim=True)
        col_sum /= target
        Q /= col_sum
        Q /= bsz

        if random_add and it % 10 == 0:
            Q += torch.rand_like(Q) * 0.25

        # Row step: every row ends up summing to 1 / bsz.
        sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
        Q /= sum_of_rows
        Q /= bsz

    Q *= bsz

    return Q


@torch.no_grad()
def concat_all_gather(x):
    """Gather ``x`` from every rank and concatenate the results along dim 0.

    Ranks may hold different numbers of rows. Each rank pads its tensor to the
    largest row count so that ``all_gather`` sees identical shapes. The padding
    is sliced off again using the gathered per-rank counts. Rows are ordered by
    rank (the order of ``all_gather``'s output list).

    This is a collective operation: every rank in the default process group must
    call it. No gradient flows through the result.

    Args:
        x: tensor with at least one dimension; all ranks must agree on
            ``x.shape[1:]`` and dtype.

    Returns:
        Tensor of shape ``(sum of per-rank rows, *x.shape[1:])``.
    """
    world_size = torch.distributed.get_world_size()

    # Share every rank's row count.
    local_n = torch.tensor([x.shape[0]], device=x.device)
    n_list = [torch.zeros_like(local_n) for _ in range(world_size)]
    torch.distributed.all_gather(n_list, local_n)
    counts = torch.cat(n_list, dim=0).tolist()

    # Pad to the largest count, and to at least one row so the collective is
    # never issued on empty tensors.
    max_size = max(max(counts), 1)
    if x.shape[0] < max_size:
        padding = x.new_zeros((max_size - x.shape[0], *x.shape[1:]))
        x = torch.cat([x, padding], dim=0)
    x = x.contiguous()

    x_list = [torch.zeros_like(x) for _ in range(world_size)]
    torch.distributed.all_gather(x_list, x)

    # Drop each rank's padding using the known counts (no extra mask collective).
    return torch.cat([chunk[:n] for chunk, n in zip(x_list, counts)], dim=0)


