import torch
import torch.nn as nn
import torch.nn.functional as F


class DINOHead(nn.Module):
    """Projection head: an MLP with GELU activations, optionally L2-normalized.

    The MLP is ``Linear -> GELU`` for every hidden width, followed by a final
    ``Linear`` to ``out_dim`` with no activation.

    Args:
        in_dim: Input feature dimension (size of the last axis of ``x``).
        out_dim: Output dimension.
        mlp_dims: Hidden layer widths, given as any sequence of ints (list,
            tuple, ...). ``None`` or an empty sequence gives a single hidden
            layer of width ``2 * in_dim``.
        norm: If True, L2-normalize the output along the last dimension.
    """

    def __init__(self, in_dim: int, out_dim: int, mlp_dims=None, norm: bool = True) -> None:
        super().__init__()
        # Convert to a list so tuples and other sequences work with the list
        # concatenation below. A falsy value (None or empty) keeps the default.
        hidden = list(mlp_dims or [in_dim * 2])
        dims = [in_dim, *hidden, out_dim]
        layers = []
        # Every adjacent pair except the last one gets Linear + GELU.
        for d_in, d_out in zip(dims[:-2], dims[1:-1]):
            layers.append(nn.Linear(d_in, d_out))
            layers.append(nn.GELU())
        # Final projection, no activation.
        layers.append(nn.Linear(dims[-2], dims[-1]))
        self.mlp = nn.Sequential(*layers)
        self.norm = norm

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP, then (if ``norm``) L2-normalize over the last dim."""
        y = self.mlp(x)
        if self.norm:
            y = F.normalize(y, dim=-1)
        return y


class MultiViewDINOLoss(nn.Module):
    """Symmetric two-view DINO cross-entropy loss with teacher centering.

    The student's view-1 output is trained against the teacher's view-2
    distribution, and vice versa. The two cross-entropies are averaged.

    Teacher outputs are shifted by a running ``center`` and sharpened by
    ``teacher_temp``. Student outputs are scaled by ``student_temp``.
    Temperatures are clamped to at least 1e-6 to avoid division by zero.

    Args:
        teacher_temp: Temperature for the centered teacher outputs.
        student_temp: Temperature for the student outputs.
        center_momentum: EMA weight kept on the previous center at each
            update. The new batch statistic gets ``1 - center_momentum``.
    """

    def __init__(self, teacher_temp: float, student_temp: float, center_momentum: float) -> None:
        super().__init__()
        self.teacher_temp = teacher_temp
        self.student_temp = student_temp
        self.center_m = center_momentum
        # Placeholder that is replaced by a (1, D) tensor on the first forward,
        # once the output width D is known. persistent=False keeps it out of
        # the state_dict.
        self.register_buffer("center", torch.zeros(1), persistent=False)

    def forward(self, s1: torch.Tensor, s2: torch.Tensor, t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor:
        """Compute the loss for two views, then update the teacher center.

        Args:
            s1, s2: Student outputs for view 1 and view 2. The last dim is the
                output width D.
            t1, t2: Teacher outputs for view 1 and view 2, with the same last
                dim D.

        Returns:
            Scalar tensor: the mean of the two cross-view cross-entropies.
        """
        if self.center.ndim == 1:
            # Lazily materialize the center now that D, device and dtype are known.
            self.center = torch.zeros(1, t1.shape[-1], device=t1.device, dtype=t1.dtype)
        t_temp = max(self.teacher_temp, 1e-6)
        s_temp = max(self.student_temp, 1e-6)
        # The teacher side is a fixed target: no gradients flow through it.
        pt1 = F.softmax((t1 - self.center) / t_temp, dim=-1).detach()
        pt2 = F.softmax((t2 - self.center) / t_temp, dim=-1).detach()
        ps1 = F.log_softmax(s1 / s_temp, dim=-1)
        ps2 = F.log_softmax(s2 / s_temp, dim=-1)
        # Cross-view pairing: student view 1 vs teacher view 2, and vice versa.
        l1 = -(pt2 * ps1).sum(dim=-1).mean()
        l2 = -(pt1 * ps2).sum(dim=-1).mean()
        self._update_center(t1, t2)
        return (l1 + l2) / 2

    @torch.no_grad()
    def _update_center(self, t1: torch.Tensor, t2: torch.Tensor) -> None:
        """EMA-update ``center`` with the mean of the raw teacher outputs.

        The center is subtracted from the teacher outputs before the softmax.
        It must therefore be tracked in that same space, not in
        softmax-probability space. Statistics are averaged over every leading
        dim, so the center keeps shape (1, D).
        """
        dim = t1.shape[-1]
        mean1 = t1.reshape(-1, dim).mean(dim=0, keepdim=True)
        mean2 = t2.reshape(-1, dim).mean(dim=0, keepdim=True)
        batch_center = (mean1 + mean2) / 2
        self.center = self.center * self.center_m + batch_center * (1 - self.center_m)
