import torch.nn as nn
from src.model import get_backbone, get_projection
import torch.nn.functional as F
import math
import torch
import torch.distributed as dist

class BaseMethod(nn.Module):
    """Base class for self-supervised loss implementations.

    Holds the backbone and projection head built from ``cfg`` and provides
    helpers shared by concrete methods (trace loss, iterative normalization,
    normalized MSE). Subclasses must implement :meth:`forward`.
    """

    def __init__(self, cfg):
        """Build the backbone/projection and cache settings read from ``cfg``.

        Reads ``cfg.emb``, ``cfg.distributed`` and ``cfg.bs``; ``cfg`` is also
        handed to ``get_backbone`` and ``get_projection``.
        """
        super().__init__()
        self.backbone, self.out_size = get_backbone(cfg)
        self.projection = get_projection(self.out_size, cfg)
        self.emb_size = cfg.emb
        self.dist = cfg.distributed
        # Weight derived from the batch size: 0 at bs=8, +0.01 per doubling.
        self.trade_off = (math.log2(cfg.bs) - 3) * 0.01
        self.m = 0

    def INTL(self, x):  # Iterative Normalization with Trace loss
        """Return the trace loss of a 2-D tensor ``x``.

        For every row the sum of squared entries is divided by ``D - 1``
        (``D`` being the number of columns); the loss is the sum over rows of
        the squared deviation of that value from 1.
        """
        _, D = x.size()
        d = x.pow(2).sum(dim=1) / (D - 1)
        # Out-of-place ops: same value as the in-place chain, without
        # mutating an intermediate that autograd may track.
        return (d - 1).pow(2).sum()

    def IterNorm(self, x, iters=4):  # Iterative Normalization
        """Apply iterative normalization to a 2-D tensor ``x`` of shape (M, D).

        When ``self.dist`` is truthy, ``x`` is first gathered from all
        processes and concatenated along dim 0.

        Args:
            x: 2-D tensor.
            iters: number of iterations of the update
                ``P <- (3 P - P^3 @ sigma_norm) / 2``.

        Returns:
            Tensor with the same shape as the (possibly gathered) input,
            computed as ``(P / sqrt(trace)) @ x_centered``.
        """
        if self.dist:
            x = torch.cat(FullGatherLayer.apply(x), dim=0)
        M, D = x.size()  # x: m * d
        # Center every row by its own mean over dim 1.
        x = x - x.mean(dim=1).reshape(M, 1)
        sigma = (x @ x.T) / (D - 1)  # M x M second-moment matrix
        trace = sigma.diagonal().sum()
        sigma_norm = sigma / trace  # normalize sigma by its trace
        # Build the identity on the same device/dtype as the input instead of
        # hard-coding 'cuda', so CPU runs and non-default GPUs work.
        P = torch.eye(M, device=x.device, dtype=x.dtype)  # identity: m * m
        for _ in range(iters):
            P = 0.5 * (3 * P - torch.matrix_power(P, 3) @ sigma_norm)
        # `/` and `@` share precedence and bind left-to-right; the parentheses
        # only make the existing evaluation order explicit.
        return (P / trace.sqrt()) @ x

    def norm_mse(self, x0, x1):
        """Return ``2 - 2 * mean(<x0_n, x1_n>)`` for L2-normalized inputs.

        Both inputs are normalized with ``F.normalize`` (default arguments),
        multiplied element-wise, summed over the last dim and averaged.
        """
        x0 = F.normalize(x0)
        x1 = F.normalize(x1)
        return 2 - 2 * (x0 * x1).sum(dim=-1).mean()

    def forward(self, samples):
        """Compute the method's loss; must be overridden by subclasses."""
        raise NotImplementedError
    
class FullGatherLayer(torch.autograd.Function):
    """Differentiable all-gather.

    The forward pass collects ``x`` from every process; the backward pass
    combines the incoming gradients across processes with ``all_reduce`` and
    hands each process the slice that corresponds to its own rank.
    """

    @staticmethod
    def forward(ctx, x):
        """Return a tuple holding one gathered tensor per process."""
        world_size = dist.get_world_size()
        # Pre-allocate one receive buffer per process, shaped like ``x``.
        gathered = []
        for _ in range(world_size):
            gathered.append(torch.zeros_like(x))
        dist.all_gather(gathered, x)
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grads):
        """Return this rank's share of the cross-process reduced gradients."""
        stacked = torch.stack(grads)
        # In-place reduction with all_reduce's default op.
        dist.all_reduce(stacked)
        rank = dist.get_rank()
        return stacked[rank]