import torch
import numpy as np
import torch.nn.functional as F
import random

def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also sets cuDNN's ``deterministic`` flag to True and its ``benchmark``
    flag to False.

    Args:
        seed: Value passed unchanged to every seeding call.
    """
    # Same seeding order as before: stdlib, then NumPy, then torch (CPU).
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(seed)

    # CUDA generators are only seeded when a CUDA device is reported.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

class EMA:
    """Exponential moving average helper with a cosine-scheduled momentum.

    The effective momentum starts at ``beta`` when ``step == 0`` and reaches
    1.0 when ``step == total_steps``.
    """

    def __init__(self, beta, total_steps):
        # Base momentum used at step 0.
        self.beta = beta
        # Number of blending calls performed so far.
        self.step = 0
        # Step count at which the scheduled momentum reaches 1.0.
        self.total_steps = total_steps

    def update_average(self, old, new):
        """Return the moving average of ``old`` and ``new``.

        Args:
            old: Current average, or None when there is no average yet.
            new: Incoming value to blend in.

        Returns:
            ``new`` unchanged if ``old`` is None (the step counter is not
            advanced in that case); otherwise
            ``old * m + (1 - m) * new`` where ``m`` is the momentum
            scheduled for the current step.
        """
        if old is None:
            return new

        # Cosine schedule: m = 1 - (1 - beta) * (cos(pi * step / total) + 1) / 2.
        # NOTE: the expression is not clamped, so for step > total_steps the
        # cosine term grows again and m moves back down from 1.0.
        phase = np.pi * self.step / self.total_steps
        ramp = np.cos(phase) + 1
        momentum = 1 - (1 - self.beta) * ramp / 2.0

        # Every blending call advances the schedule by one.
        self.step += 1
        return old * momentum + (1 - momentum) * new

def update_moving_average(ema_updater, teacher, student):
    """Blend the student's parameters into the teacher's parameters.

    Parameters of ``teacher`` and ``student`` are paired positionally with
    ``zip`` (iteration stops at the shorter of the two), and each teacher
    parameter's ``.data`` is replaced with
    ``ema_updater.update_average(teacher_data, student_data)``.

    If ``ema_updater`` exposes a ``step`` counter that ``update_average``
    advances, the counter is rewound before each per-tensor call. Every
    tensor in one invocation is therefore blended with the same schedule
    position, and the counter ends up advanced as by a single
    ``update_average`` call instead of once per parameter tensor.

    Args:
        ema_updater: Object providing ``update_average(old, new)`` and,
            optionally, a ``step`` attribute.
        teacher: Object whose ``parameters()`` yields the tensors to update.
        student: Object whose ``parameters()`` yields the source tensors.
    """
    # Updaters without a ``step`` attribute are driven exactly as before.
    tracks_step = hasattr(ema_updater, "step")
    step_at_entry = ema_updater.step if tracks_step else None

    for t_param, s_param in zip(teacher.parameters(), student.parameters()):
        if tracks_step:
            # Undo the advance made by the previous tensor's call so the
            # schedule moves once per model update, not once per tensor.
            ema_updater.step = step_at_entry
        t_param.data = ema_updater.update_average(t_param.data, s_param.data)

def loss_fn(x, y):
    """Return ``2 - 2 * <x_hat, y_hat>`` taken over the last dimension.

    ``x_hat`` and ``y_hat`` are ``x`` and ``y`` L2-normalised along their
    last dimension; the elementwise product is summed over that dimension,
    so the result has one fewer dimension than the broadcast inputs.

    Args:
        x: First input tensor.
        y: Second input tensor, multiplied elementwise with ``x``.

    Returns:
        Tensor of per-vector losses: 0 when the normalised vectors coincide,
        4 when they point in opposite directions.
    """
    x_unit = F.normalize(x, p=2, dim=-1)
    y_unit = F.normalize(y, p=2, dim=-1)
    inner = (x_unit * y_unit).sum(dim=-1)
    return 2 - 2 * inner
