import torch
import torch.nn as nn

import torch.distributed as dist
from modules.gather import GatherLayer


def off_diagonal(x):
    """Return the off-diagonal entries of the square matrix ``x`` as a 1-D tensor.

    Entries are listed in row-major order with the diagonal skipped, e.g. for a
    3x3 matrix: x[0,1], x[0,2], x[1,0], x[1,2], x[2,0], x[2,1]. Depending on the
    matrix size and layout the result may share storage with ``x`` or be a copy.

    Raises:
        AssertionError: if ``x`` is not square.
    """
    size, width = x.shape
    assert size == width
    # Skip the leading x[0, 0]; the remaining n*n - 1 entries fold into
    # (n - 1) rows of length n + 1 whose last column holds exactly the
    # diagonal entries x[1, 1] ... x[n-1, n-1], which are sliced away.
    folded = x.flatten()[1:].view(size - 1, size + 1)
    return folded[:, :-1].flatten()


class BarlowTwinsLoss(torch.nn.Module):
    """Barlow Twins redundancy-reduction loss between two batches of embeddings.

    Both inputs are batch-normalised (``BatchNorm1d`` without affine
    parameters) and their cross-correlation matrix ``c`` is formed. The loss is
    ``sum_i (c_ii - 1)^2 + lambda_param * sum_{i != j} c_ij^2``.

    An alternative (not used here) is to gather ``z_a``/``z_b`` from all
    processes with ``GatherLayer`` and standardise them with mean/std instead of
    BatchNorm.

    Args:
        device: stored as ``self.device``; not read by ``forward``.
        world_size: number of processes. The cross-correlation matrix is divided
            by ``N * world_size`` and, when ``world_size > 1``, summed across
            processes with ``all_reduce``.
        lambda_param: weight of the off-diagonal term.
        proj_dim: feature dimension of the embeddings (size of the BatchNorm).
    """

    def __init__(self, device, world_size, lambda_param=5e-3, proj_dim=512):
        super(BarlowTwinsLoss, self).__init__()
        self.lambda_param = lambda_param
        self.device = device
        self.world_size = world_size
        self.bn = nn.BatchNorm1d(proj_dim, affine=False)
        # The former ``proj_dim > 4048`` check assigned 1 in both branches, so
        # the scale is a constant. Kept as an attribute for compatibility;
        # ``forward`` does not read it.
        self.scale_loss = 1

    def forward(self, z_a: torch.Tensor, z_b: torch.Tensor):
        """Return the scalar Barlow Twins loss.

        Args:
            z_a: embeddings of the first view, shape ``(N, proj_dim)``.
            z_b: embeddings of the second view, shape ``(N, proj_dim)``.
        """
        N, _ = z_a.shape
        # Empirical cross-correlation matrix, (proj_dim, proj_dim).
        c = self.bn(z_a).T @ self.bn(z_b)

        # Divide by the global batch size so that summing the per-process
        # matrices yields the cross-correlation over all processes.
        # NOTE(review): assumes every process has the same batch size N.
        c.div_(N * self.world_size)
        if self.world_size > 1:
            # NOTE(review): torch.distributed.all_reduce is not tracked by
            # autograd, so gradients flow only through this process's share of c.
            torch.distributed.all_reduce(c)

        # The in-place ops on the diagonal view modify c. This is safe because
        # off_diagonal() excludes the diagonal entries.
        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
        off_diag = off_diagonal(c).pow_(2).sum()
        loss = on_diag + self.lambda_param * off_diag
        return loss


class BarlowTwinsTemporalLoss(torch.nn.Module):
    """Barlow Twins loss for sequences of embeddings.

    ``forward`` takes two views indexed as ``(T, B, D)`` (time-steps, batch,
    features; ``D`` must match ``proj_dim``). Two modes are supported:

    * ``cross_temporal=False``: the time and batch axes are merged and a single
      Barlow Twins loss is computed between the two views.
    * ``cross_temporal=True``: the time-steps of both views are stacked and the
      Barlow Twins loss is averaged over all pairs of distinct stacked entries.
      With ``simplified_loss=True`` only the first and last time-step of each
      view are used (4 stacked entries, 12 ordered pairs).

    Args:
        device: stored as ``self.device``; not read by ``forward``.
        world_size: number of processes. Each cross-correlation matrix is divided
            by ``N * world_size`` and, when ``world_size > 1``, summed across
            processes with ``all_reduce``.
        cross_temporal: select the cross-temporal mode described above.
        simplified_loss: in cross-temporal mode, use only the first/last steps.
        lambda_param: weight of the off-diagonal term.
        proj_dim: feature dimension of the embeddings (size of the BatchNorm).
    """

    def __init__(self, device, world_size, cross_temporal=False, simplified_loss=False, lambda_param=5e-3, proj_dim=512):
        super(BarlowTwinsTemporalLoss, self).__init__()
        self.lambda_param = lambda_param
        self.device = device
        self.world_size = world_size
        self.cross_temporal = cross_temporal
        self.simplified_loss = simplified_loss
        self.bn = nn.BatchNorm1d(proj_dim, affine=False)
        # The former ``proj_dim > 4048`` check assigned 1 in both branches, so
        # the scale is a constant; it multiplies the cross-temporal terms.
        self.scale_loss = 1

    def _correlation_loss(self, za_norm, zb_norm):
        """Return ``(on_diag, off_diag)`` for one pair of batch-normalised embeddings.

        Both inputs are ``(N, D)``. The cross-correlation matrix is divided by
        ``N * world_size`` and, with several processes, summed via ``all_reduce``.
        """
        n, _ = za_norm.shape
        c = za_norm.T @ zb_norm
        # NOTE(review): assumes every process has the same batch size n.
        c.div_(n * self.world_size)
        if self.world_size > 1:
            # NOTE(review): all_reduce is not tracked by autograd; gradients flow
            # only through this process's share of c.
            torch.distributed.all_reduce(c)
        # The in-place ops on the diagonal view modify c. This is safe because
        # off_diagonal() excludes the diagonal entries.
        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
        off_diag = off_diagonal(c).pow_(2).sum()
        return on_diag, off_diag

    def _time_steps(self, z_a, z_b):
        """Stack the time-steps compared in cross-temporal mode.

        Returns the stacked tensor ``(K, B, D)`` and the number of ordered pairs
        of distinct entries, ``K * (K - 1)``, which is used to average the loss.
        """
        if self.simplified_loss:
            # First and last time-step of each view. torch.stack keeps the
            # inputs' dtype and device and stays on the autograd graph.
            z = torch.stack((z_a[0], z_a[-1], z_b[0], z_b[-1]))
        else:
            z = torch.cat((z_a, z_b), dim=0)
        k = z.shape[0]
        return z, k * (k - 1)

    def forward(self, z_a: torch.Tensor, z_b: torch.Tensor):
        """Return the scalar loss for the views ``z_a`` and ``z_b`` (each ``(T, B, D)``)."""
        if not self.cross_temporal:
            # Treat every (time-step, sample) as an independent batch element.
            on_diag, off_diag = self._correlation_loss(
                self.bn(z_a.flatten(0, 1)), self.bn(z_b.flatten(0, 1))
            )
            return on_diag + self.lambda_param * off_diag

        z, num_ordered_pairs = self._time_steps(z_a, z_b)
        # Normalise each stacked time-step once rather than once per pair. In
        # training mode the BatchNorm running statistics are therefore updated
        # once per time-step per call.
        z_norm = [self.bn(z_t) for z_t in z]
        loss = 0
        for i in range(len(z_norm)):
            for j in range(i + 1, len(z_norm)):
                on_diag, off_diag = self._correlation_loss(z_norm[i], z_norm[j])
                # Swapping i and j transposes c, which leaves both terms
                # unchanged, so each unordered pair counts for both ordered pairs.
                loss = loss + 2 * self.scale_loss * (on_diag + self.lambda_param * off_diag)
        return loss / num_ordered_pairs
