
import numpy as np
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.nn.functional as F


# class iBOTLoss(nn.Module):
#     def __init__(self, out_dim, warmup_teacher_temp, teacher_temp, 
#                  warmup_teacher_temp_epochs, nepochs, ngcrops=2, student_temp=0.1, 
#                  center_momentum=0.9, lambda1=1.0):
#         super().__init__()
#         self.student_temp = student_temp
#         self.center_momentum = center_momentum
#         self.ngcrops = ngcrops
#         self.register_buffer("center", torch.zeros(1, out_dim))
#         self.lambda1 = lambda1

#         # We warm up the teacher temperature because too high a
#         # temperature makes training unstable at the beginning.
#         self.teacher_temp_schedule = np.concatenate((
#             np.linspace(warmup_teacher_temp,
#                         teacher_temp, warmup_teacher_temp_epochs),
#             np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
#         ))

#     def forward(self, student_output, teacher_output, epoch):
#         """
#         Cross-entropy between softmax outputs of the teacher and student networks.
#         """
#         student_cls = student_output
#         teacher_cls = teacher_output

#         # [CLS] and patch for global patches
#         student_cls = student_cls / self.student_temp
#         student_cls_c = student_cls.chunk(self.ngcrops)
        
#         # teacher centering and sharpening
#         temp = self.teacher_temp_schedule[epoch]
#         teacher_cls_c = F.softmax((teacher_cls - self.center) / temp, dim=-1)
#         teacher_cls_c = teacher_cls_c.detach().chunk(self.ngcrops)

#         total_loss1, n_loss_terms1 = 0, 0
#         for q in range(len(teacher_cls_c)):
#             for v in range(len(student_cls_c)):
#                 if not(v == q):
#                     loss1 = torch.sum(-teacher_cls_c[q] * F.log_softmax(student_cls_c[v], dim=-1), dim=-1)
#                     total_loss1 += loss1.mean()
#                     n_loss_terms1 += 1
            
#         total_loss1 = torch.tensor(total_loss1) / n_loss_terms1 * self.lambda1
#         self.update_center(teacher_cls)                  
#         return total_loss1

#     @torch.no_grad()
#     def update_center(self, teacher_cls):
#         """
#         Update center used for teacher output.
#         """
#         cls_center = torch.sum(teacher_cls, dim=0, keepdim=True)
#         dist.all_reduce(cls_center)
#         cls_center = cls_center / (len(teacher_cls) * dist.get_world_size())
#         self.center = self.center * self.center_momentum + cls_center * (1 - self.center_momentum)



class iBOTLoss(nn.Module):
    """Cross-view [CLS] self-distillation loss (DINO/iBOT style).

    Teacher outputs are centered with a running ``center`` buffer and sharpened
    with a per-epoch temperature taken from ``teacher_temp_schedule``. Student
    outputs are divided by ``student_temp``. The loss is the cross-entropy
    between every teacher global-crop chunk and every student crop chunk,
    averaged over all pairs and scaled by ``lambda1``.

    Only the first-loss settings are used by ``forward``/``update_center``.
    The following are built but not read by this class:
    ``center2``, ``center_momentum2``, ``lambda2`` and ``teacher_temp2_schedule``.
    ``teacher_temp2_schedule`` is built from ``warmup_teacher_temp2``,
    ``teacher_temp2`` and ``mim_start_epoch``. These are kept so the
    constructor signature and the registered buffers (``state_dict`` keys)
    stay compatible.

    Args:
        out_dim: size of the last dimension of the projection outputs.
        ngcrops: number of global crops; the teacher output is split into
            this many chunks.
        nlcrops: number of local crops; the student output is split into
            ``ngcrops + nlcrops`` chunks.
        warmup_teacher_temp: teacher temperature at epoch 0.
        teacher_temp: teacher temperature after the warm-up.
        warmup_teacher_temp2: initial value of the (unused) second schedule.
        teacher_temp2: final value of the (unused) second schedule.
        warmup_teacher_temp_epochs: number of epochs the temperatures are
            linearly interpolated over.
        nepochs: total number of epochs; length of the schedules.
        student_temp: divisor applied to the student outputs.
        center_momentum: EMA momentum for ``center``.
        center_momentum2: momentum for the (unused) ``center2`` buffer.
        lambda1: weight of the returned loss.
        lambda2: weight of a second loss term (unused here).
        mim_start_epoch: epochs held at ``warmup_teacher_temp2`` before the
            second schedule warms up (affects only the unused schedule).
    """

    def __init__(self, out_dim, ngcrops=2, nlcrops=0, warmup_teacher_temp=0.04, 
                 teacher_temp=0.04, warmup_teacher_temp2=0.04, teacher_temp2=0.07, 
                 warmup_teacher_temp_epochs=30, nepochs=300, student_temp=0.1, 
                 center_momentum=0.9, center_momentum2=0.9,
                 lambda1=1.0, lambda2=1.0, mim_start_epoch=0):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.center_momentum2 = center_momentum2
        self.ngcrops = ngcrops
        self.nlcrops = nlcrops
        self.ncrops = ngcrops + nlcrops
        self.register_buffer("center", torch.zeros(1, out_dim))
        # Not used by forward(); kept so state_dict keys stay compatible.
        self.register_buffer("center2", torch.zeros(1, 1, out_dim))
        self.lambda1 = lambda1
        self.lambda2 = lambda2

        # We warm up the teacher temperature because too high a temperature
        # makes training unstable at the beginning.
        self.teacher_temp_schedule = np.concatenate((
            np.linspace(warmup_teacher_temp,
                        teacher_temp, warmup_teacher_temp_epochs),
            np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
        ))
        # Second schedule: optionally held at warmup_teacher_temp2 for
        # mim_start_epoch epochs, then warmed up. Not read by forward().
        self.teacher_temp2_schedule = np.concatenate((
            np.linspace(warmup_teacher_temp2,
                        teacher_temp2, warmup_teacher_temp_epochs),
            np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp2
        )) if mim_start_epoch == 0 else np.concatenate((
            np.ones(mim_start_epoch) * warmup_teacher_temp2,
            np.linspace(warmup_teacher_temp2,
                        teacher_temp2, warmup_teacher_temp_epochs),
            np.ones(nepochs - warmup_teacher_temp_epochs - mim_start_epoch) * teacher_temp2
        ))

    def forward(self, student_output, teacher_output, epoch):
        """Cross-entropy between softmax outputs of the teacher and student networks.

        Args:
            student_output: student logits for all crops concatenated along
                dim 0; split into ``ncrops`` equal chunks. Presumably shaped
                (ncrops * batch, out_dim) — TODO confirm against the caller.
            teacher_output: teacher logits for the global crops concatenated
                along dim 0; split into ``ngcrops`` chunks.
            epoch: index into ``teacher_temp_schedule``.

        Returns:
            The mean cross-entropy over all (teacher chunk, student chunk)
            pairs, multiplied by ``lambda1``. It stays attached to the
            student's autograd graph.

        Side effect:
            Updates ``self.center`` from ``teacher_output``.
        """
        # Student sharpening, then split per crop.
        student_cls_c = (student_output / self.student_temp).chunk(self.ncrops)

        # Teacher centering and sharpening; the teacher targets carry no gradient.
        temp = self.teacher_temp_schedule[epoch]
        teacher_probs = F.softmax((teacher_output - self.center) / temp, dim=-1)
        teacher_cls_c = teacher_probs.detach().chunk(self.ngcrops)

        total_loss1, n_loss_terms1 = 0, 0
        # NOTE(review): same-view pairs (student chunk index == teacher chunk
        # index) are included; DINO-style losses usually skip them. Confirm
        # this is intended.
        for teacher_chunk in teacher_cls_c:
            for student_chunk in student_cls_c:
                loss1 = torch.sum(-teacher_chunk * F.log_softmax(student_chunk, dim=-1), dim=-1)
                total_loss1 += loss1.mean()
                n_loss_terms1 += 1

        # Plain tensor arithmetic keeps the loss on the autograd graph.
        # Do not wrap it in torch.tensor(...): that copies and detaches the
        # value, so backward() would never reach the student.
        total_loss1 = total_loss1 / n_loss_terms1 * self.lambda1
        self.update_center(teacher_output)
        return total_loss1

    @torch.no_grad()
    def update_center(self, teacher_cls):
        """Update the EMA ``center`` with the batch mean of ``teacher_cls``.

        If a ``torch.distributed`` process group is initialized, the per-rank
        sums are all-reduced and the mean is taken over every rank's batch.
        Otherwise the local batch mean is used, so the loss also works in
        single-process runs.
        """
        cls_center = torch.sum(teacher_cls, dim=0, keepdim=True)
        world_size = 1
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(cls_center)
            world_size = dist.get_world_size()
        cls_center = cls_center / (len(teacher_cls) * world_size)
        self.center = self.center * self.center_momentum + cls_center * (1 - self.center_momentum)

