from enum import Enum
from functools import partial
from logging import getLogger
import math
import os
from typing import List, Union, Callable, Tuple, Optional
import sys

import numpy as np
import torch
from torch import distributed as dist
from torch import nn
from torch.nn import functional as F

from einops import rearrange

from hadamard import get_hadamard_matrix

# Module-level logger, named after this module; used for the whitening diagnostics below.
_logger = getLogger(__name__)


class NormalMethod(Enum):
    """Identifiers for the available feature/loss normalisation strategies.

    Each member is a key of ``NORM_METHOD_MAP``, which maps it to the state
    class implementing the strategy. The integer values are not contiguous
    (2, 5 and 6 are unused).
    """
    NONE = 0  # -> LossFnStateBase (identity transforms)
    LOSS = 1  # -> LossNormState
    STANDARDIZE = 3  # -> StandardizeNormState
    ZCA_WHITEN = 4  # -> ZCAWhitenNormState
    HADAMARD_WHITEN = 7  # -> HadamardWhitenNormState
    PCA_WHITEN = 8  # -> PCAWhitenNormState
    WPCA_WHITEN = 9  # -> PCAWhitenNormState with weighted=True
    PHI_STANDARDIZE = 10  # -> PHIStandardization
    GLOBAL_STANDARDIZE = 11  # -> GlobalStandardizeNormState


# Default normalisation method used by the loss classes below when the caller
# does not pass ``norm_mode``.
_NORM_METHOD = NormalMethod.NONE


def cosine_similarity_loss(x: torch.Tensor, y: torch.Tensor, dim=-1, **kwargs):
    """Return the cosine distance ``1 - cos(x, y)`` taken along ``dim``.

    Args:
        x: First tensor.
        y: Second tensor, compared against ``x``.
        dim: Dimension along which the cosine similarity is computed.
        **kwargs: Accepted for signature compatibility; not used.

    Returns:
        A tensor with ``dim`` reduced away, holding one distance per position.
    """
    similarity = F.cosine_similarity(x, y, dim=dim)
    return 1 - similarity


def _ohem_reduce(loss: torch.Tensor):
    """Reduce a per-element loss with online hard example mining (OHEM).

    For every row (after flattening everything past dim 0) the result is the
    mean over a selected subset of elements:

    * the ``min(n // 4, 1024)`` largest losses (at least one), plus
    * ``min(n // 8, 512)`` further elements drawn at random without
      replacement, with probability proportional to ``sqrt(loss)`` so the
      random part is still biased toward harder examples.

    Args:
        loss: Per-element loss with a leading batch dimension. Values are
            expected to be non-negative because their square root is used as
            sampling weight.

    Returns:
        A 1-D tensor with one reduced loss per batch row.
    """
    loss = loss.flatten(1)
    num_elems = loss.shape[1]

    if num_elems == 0:
        # Nothing to mine; return zeros rather than 0 / 0.
        return loss.sum(dim=1)

    # Always keep at least one element so tiny inputs never divide by zero.
    num_worst = max(1, min(num_elems // 4, 1024))
    num_rand = min(num_elems // 8, 512)

    _, worst_idxs = torch.topk(loss, k=num_worst, dim=1, largest=True, sorted=False)

    mask = torch.zeros(loss.shape, dtype=torch.bool, device=loss.device)
    mask.scatter_(dim=1, index=worst_idxs, value=True)

    # torch.multinomial rejects num_samples == 0, so only sample when there is
    # a random budget (i.e. at least 8 elements per row).
    if num_rand > 0:
        # For random selection, still bias it toward harder examples
        probs = loss.detach().sqrt()
        # Don't re-sample the worst indices
        probs.scatter_(dim=1, index=worst_idxs, value=0.0)

        # The epsilon keeps every row a valid distribution even if all the
        # remaining losses are exactly zero.
        rand_idxs = torch.multinomial(probs + 1e-8, num_samples=num_rand, replacement=False)
        mask.scatter_(dim=1, index=rand_idxs, value=True)

    loss = torch.where(mask, loss, 0)

    num_valid = mask.sum(dim=1, dtype=torch.float32)

    return loss.sum(dim=1) / num_valid


def masked_sum(t: torch.Tensor, mask: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sum ``t`` over the positions selected by ``mask`` and count them.

    Args:
        t: Values to sum.
        mask: Boolean tensor; positions where it is False contribute zero.
        **kwargs: Forwarded to ``Tensor.sum`` for both the value sum and the
            mask count (e.g. ``dim``, ``keepdim``, ``dtype``).

    Returns:
        ``(s, ct)`` where ``s`` is the masked sum and ``ct`` the number of
        selected positions, reduced the same way. Unless ``dtype`` is given,
        ``ct`` is produced in the dtype of ``s``.
    """
    s = torch.where(mask, t, 0).sum(**kwargs)

    # Count in the same dtype as the sum so that ``s / ct`` needs no casting.
    count_kwargs = dict(kwargs)
    count_kwargs.setdefault('dtype', s.dtype)
    ct = mask.sum(**count_kwargs)

    return s, ct

def masked_mean(t: torch.Tensor, mask: torch.Tensor, **kwargs) -> torch.Tensor:
    """Average ``t`` over the positions where ``mask`` is True.

    ``kwargs`` are forwarded to ``masked_sum`` (and from there to
    ``Tensor.sum``). No guard is applied for an empty selection, so a
    reduction that selects nothing divides by zero.
    """
    total, count = masked_sum(t, mask, **kwargs)
    mean = total / count
    return mean


class LossFnStateBase(nn.Module):
    """Running per-channel statistics of teacher features; identity normaliser.

    The base class tracks how many times ``update`` ran, how many masked
    samples were seen, and the per-channel sum of the teacher features. All
    ``transform_*`` hooks are identity functions here; subclasses override
    them to implement an actual normalisation.
    """

    def __init__(self, name: str, feature_dim: int, ohem: bool):
        super().__init__()
        self.name = name
        self.feature_dim = feature_dim
        self.ohem = ohem
        # Process group used for collectives; None means the default group.
        self.dist_group: dist.ProcessGroup = None

        # Statistics are accumulated in float64 and stored in the state dict.
        initial_buffers = (
            ('fwd_count', torch.tensor(0, dtype=torch.float64)),
            ('num_samples', torch.tensor(0.0, dtype=torch.float64)),
            ('sample_sum', torch.zeros(feature_dim, dtype=torch.float64)),
        )
        for buffer_name, initial_value in initial_buffers:
            self.register_buffer(buffer_name, initial_value, persistent=True)

    def masked_mean(self, t: torch.Tensor, mask: torch.Tensor, **kwargs):
        """Masked mean of ``t``; ``mask`` gets a channel axis inserted at dim 1."""
        channel_mask = mask.unsqueeze(1).expand(-1, self.feature_dim, -1, -1)
        return masked_mean(t, channel_mask, **kwargs)

    def masked_sum(self, t: torch.Tensor, mask: torch.Tensor, **kwargs):
        """Masked sum of ``t`` plus the first entry of the reduced mask count.

        The mask is broadcast over dim 1 of ``t``, so every channel sees the
        same count; only entry 0 of the count is returned.
        """
        channel_mask = mask.unsqueeze(1).expand_as(t)
        total, count = masked_sum(t, channel_mask, **kwargs)
        return total, count[0]

    @property
    def expected_mean(self):
        """Per-channel running mean; zeros until at least one sample is seen."""
        running_mean = self.sample_sum / self.num_samples
        return torch.where(self.num_samples > 0, running_mean, 0)

    @torch.no_grad()
    def update(self, loss_fn_base: 'LossFnBase', teacher_features: torch.Tensor, loss_mask: torch.Tensor):
        """Fold one batch of teacher features into the running statistics.

        Sums are taken over dims (0, 2, 3) and, when torch.distributed is
        initialised, summed across ``dist_group`` before being accumulated.

        Returns:
            The updated ``expected_mean``.
        """
        self.fwd_count += 1

        batch_sum, batch_count = self.masked_sum(
            teacher_features, loss_mask, dim=(0, 2, 3), dtype=torch.float64,
        )

        if dist.is_initialized():
            for stat in (batch_sum, batch_count):
                dist.all_reduce(stat, op=dist.ReduceOp.SUM, group=self.dist_group)

        self.sample_sum += batch_sum
        self.num_samples += batch_count

        return self.expected_mean

    def transform_targets(self, teacher_features: torch.Tensor) -> torch.Tensor:
        """Identity: teacher features are used as-is."""
        return teacher_features

    def transform_student(self, student_features: torch.Tensor) -> torch.Tensor:
        """Identity: student features are returned unchanged."""
        return student_features

    def transform_loss(self, loss: torch.Tensor) -> torch.Tensor:
        """Identity: the loss is not re-weighted."""
        return loss

    def modify_linear(self, final: nn.Linear):
        """No-op hook; subclasses fold their de-normalisation into ``final``."""
        pass

    @torch.no_grad()
    def synchronize(self):
        """Broadcast the buffers from the process that is rank 0 in ``dist_group``.

        Does nothing when torch.distributed is not initialised.
        """
        if dist.is_initialized():
            source = self._global_rank_for_group_rank()
            if source >= 0:
                self._broadcast(source)

    def _global_rank_for_group_rank(self, target_rank: int = 0, reduction_group: dist.ProcessGroup = None):
        """Return the global rank of the process whose ``dist_group`` rank is ``target_rank``.

        Every process contributes its global rank if it is the target and -1
        otherwise; a MAX all-reduce over ``reduction_group`` yields the answer.
        Without an initialised process group ``target_rank`` is returned as-is.
        """
        if not dist.is_initialized():
            return target_rank

        is_target = dist.get_rank(self.dist_group) == target_rank
        candidate = dist.get_rank() if is_target else -1

        # Figure out which rank runs the broadcast
        src_rank = torch.tensor(candidate, dtype=torch.int32, device='cuda')
        dist.all_reduce(src_rank, op=dist.ReduceOp.MAX, group=reduction_group)
        return src_rank.item()

    def _broadcast(self, src_rank: int, group: dist.ProcessGroup = None):
        """Broadcast this class's buffers from ``src_rank`` over ``group``."""
        for buffer in (self.fwd_count, self.num_samples, self.sample_sum):
            dist.broadcast(buffer, src_rank, group=group)


class LossNormState(LossFnStateBase):
    """Re-weights the loss per channel by the inverse of a baseline loss.

    The baseline is the loss obtained by predicting the running mean of the
    teacher features for every position.
    """

    def __init__(self, name: str, feature_dim: int, ohem: bool):
        super().__init__(name, feature_dim, ohem)
        self.register_buffer(
            'expected_loss_sum',
            torch.zeros(feature_dim, dtype=torch.float64),
            persistent=True,
        )

    @property
    def expected_loss(self):
        """Per-channel mean baseline loss; ones until a sample has been seen."""
        mean_loss = self.expected_loss_sum / self.num_samples
        return torch.where(self.num_samples > 0, mean_loss, 1)

    @torch.no_grad()
    def update(self, loss_fn_base: 'LossFnBase', teacher_features: torch.Tensor, loss_mask: torch.Tensor):
        """Update the running mean, then accumulate the baseline loss for this batch."""
        running_mean = super().update(loss_fn_base, teacher_features, loss_mask)

        # Baseline prediction: the running mean broadcast to the feature shape.
        constant_pred = running_mean.reshape(1, -1, 1, 1).expand_as(teacher_features)
        baseline_loss = loss_fn_base.get_balanced_loss(constant_pred, teacher_features)

        batch_loss_sum, _ = self.masked_sum(
            baseline_loss, loss_mask, dim=(0, 2, 3), dtype=torch.float64,
        )

        if dist.is_initialized():
            dist.all_reduce(batch_loss_sum, op=dist.ReduceOp.SUM, group=self.dist_group)

        self.expected_loss_sum += batch_loss_sum

    def transform_loss(self, loss: torch.Tensor) -> torch.Tensor:
        """Scale ``loss`` per channel (dim 1) by ``1 / expected_loss``."""
        expected_loss = self.expected_loss

        # Some feature dimensions may literally be 0, so don't amplify the loss on them.
        safe_inverse = 1 / expected_loss.clamp_min(1e-5)
        factor = torch.where(expected_loss > 0, safe_inverse, 1.0)

        # Report the factor statistics every 100 forward passes on group rank 0
        # (or unconditionally in a non-distributed run).
        should_report = not dist.is_initialized() or dist.get_rank(self.dist_group) == 0
        if should_report and int(self.fwd_count.item()) % 100 == 0:
            print(f'Loss Norm - {self.name} - Mean: {factor.mean().item():.05f}, Min: {factor.amin().item():.05f}, Max: {factor.amax().item():.05f}', file=sys.stderr)

        return loss * factor.reshape(1, -1, 1, 1)

    def _broadcast(self, src_rank: int, group: dist.ProcessGroup = None):
        """Broadcast the base buffers plus ``expected_loss_sum``."""
        super()._broadcast(src_rank, group)
        dist.broadcast(self.expected_loss_sum, src_rank, group=group)


class StandardizeNormState(LossFnStateBase):
    """Per-channel standardisation: ``(x - mean) / std`` on teacher features.

    The variance is estimated from a running sum of squared deviations from
    the running mean; channels whose variance is not positive are left
    unscaled.
    """

    def __init__(self, name: str, feature_dim: int, ohem: bool):
        super().__init__(name, feature_dim, ohem)

        self.register_buffer('var_sum', torch.zeros(feature_dim, dtype=torch.float64), persistent=True)

    @property
    def variance(self):
        """Per-channel variance estimate, ``var_sum / (num_samples - 1)``."""
        return self.var_sum / (self.num_samples - 1)

    @torch.no_grad()
    def update(self, loss_fn_base: 'LossFnBase', teacher_features: torch.Tensor, loss_mask: torch.Tensor):
        """Update the running mean and accumulate squared deviations from it."""
        exp_mean = super().update(loss_fn_base, teacher_features, loss_mask)

        # Centre with the running mean (which already includes this batch).
        mc_feats = teacher_features - exp_mean.reshape(1, -1, 1, 1)

        var_sum, _ = self.masked_sum(mc_feats.pow(2), loss_mask, dim=(0, 2, 3), dtype=torch.float64)
        if dist.is_initialized():
            dist.all_reduce(var_sum, op=dist.ReduceOp.SUM, group=self.dist_group)
        self.var_sum += var_sum

    @torch.no_grad()
    def transform_targets(self, teacher_features: torch.Tensor) -> torch.Tensor:
        """Standardise teacher features per channel; result keeps the input dtype."""
        w = (teacher_features - self.expected_mean.reshape(1, -1, 1, 1))
        s = self.variance
        # Non-positive (or NaN) variance -> scale of 1, i.e. only mean-centre.
        s = torch.where(s > 0, s.rsqrt(), 1.0)

        w = w * s.reshape(1, -1, 1, 1)
        return w.to(teacher_features.dtype)

    @torch.no_grad()
    def transform_student(self, student_features: torch.Tensor) -> torch.Tensor:
        """Map standardised student features back: ``x * std + mean``."""
        exp_mean = self.expected_mean.reshape(1, -1, 1, 1)
        var = self.variance.reshape(1, -1, 1, 1)
        exp_std = torch.where(var > 0, var.sqrt(), 1.0)

        w = student_features * exp_std.to(student_features.dtype) + exp_mean.to(student_features.dtype)
        return w

    def modify_linear(self, final: nn.Linear):
        """Fold ``transform_student`` into the linear layer ``final`` in place.

        For ``y = W x + b`` the de-normalised output is ``y * std + mean``, so
        the weight rows and the bias are scaled by the standard deviation (not
        the variance) and the mean is added to the bias. The same
        zero-variance guard as ``transform_student`` is applied. When the
        layer has no bias the mean offset cannot be folded in.
        """
        var = self.variance
        std = torch.where(var > 0, var.sqrt(), 1.0)
        mean = self.expected_mean

        final.weight.data *= std.reshape(-1, 1).to(final.weight.dtype)
        if final.bias is not None:
            final.bias.data *= std.to(final.bias.dtype)

            final.bias.data += mean.to(final.bias.dtype)

    def _broadcast(self, src_rank: int, group: dist.ProcessGroup = None):
        """Broadcast the base buffers plus ``var_sum``."""
        super()._broadcast(src_rank, group)
        dist.broadcast(self.var_sum, src_rank, group=group)


class GlobalStandardizeNormState(StandardizeNormState):
    """Standardisation with a single scalar mean and variance for all channels."""

    @property
    def expected_mean(self):
        """Scalar mean: the average of the per-channel running means."""
        per_channel_mean = super().expected_mean
        return per_channel_mean.mean()

    @property
    def variance(self):
        """Scalar variance pooled over channels; 1 until a sample has been seen."""
        total_count = self.num_samples * self.feature_dim
        pooled = self.var_sum.sum() / (total_count - 1)
        return torch.where(total_count > 0, pooled, 1)

    @torch.no_grad()
    def transform_targets(self, teacher_features: torch.Tensor):
        """Return ``(x - mean) / std`` in the dtype of ``teacher_features``."""
        inv_std = self.variance.rsqrt()
        centered = teacher_features - self.expected_mean

        return (centered * inv_std).to(teacher_features.dtype)

    @torch.no_grad()
    def transform_student(self, student_features: torch.Tensor) -> torch.Tensor:
        """Return ``x * std + mean`` using the scalar statistics."""
        target_dtype = student_features.dtype
        scale = self.variance.sqrt().to(target_dtype)
        shift = self.expected_mean.to(target_dtype)

        return student_features * scale + shift

    def modify_linear(self, final: nn.Linear):
        """Fold the scalar de-normalisation into ``final`` in place."""
        shift = self.expected_mean
        scale = self.variance.sqrt()

        final.weight.data *= scale
        if final.bias is None:
            return
        final.bias.data *= scale
        final.bias.data += shift


class WhitenNormState(LossFnStateBase):
    """Base class for normalisers built from a running feature covariance.

    The state accumulates the covariance of the (mean-centred) teacher
    features and periodically asks the subclass, via ``_update_projections``,
    to refresh two ``feature_dim x feature_dim`` matrices:

    * ``whiten`` - applied to mean-centred teacher features, and
    * ``inv_whiten`` - applied to student features to map them back.

    Logging and cache writes happen only on rank 0 of ``dist_group``; in a
    run without an initialised process group the single process takes that
    role.
    """

    def __init__(self, name: str, feature_dim: int, ohem: bool, update_period: int = 100):
        super().__init__(name, feature_dim, ohem)
        self.update_period = update_period
        # Identity helper; non-persistent, so it is not part of the state dict.
        self.register_buffer('eye', torch.eye(feature_dim, dtype=torch.float64), persistent=False)
        self.register_buffer('inv_whiten', self.eye.clone(), persistent=True)
        self.register_buffer('whiten', self.eye.clone(), persistent=True)
        self.register_buffer('cov_sum', torch.zeros(feature_dim, feature_dim, dtype=torch.float64), persistent=True)

    @property
    def covariance(self):
        """Covariance estimate, ``cov_sum / (num_samples - 1)``."""
        return self.cov_sum / (self.num_samples - 1)

    @property
    def max_samples(self) -> int:
        """Forward count after which the statistics are frozen."""
        return 30 * self.update_period

    def _is_reporting_rank(self) -> bool:
        """Return True on the process that should log and write the cache.

        ``dist.get_rank`` raises when no process group has been initialised,
        so the non-distributed case is handled explicitly.
        """
        return not dist.is_initialized() or dist.get_rank(self.dist_group) == 0

    @torch.no_grad()
    @torch.autocast('cuda', enabled=False)
    def update(self, loss_fn_base: 'LossFnBase', teacher_features: torch.Tensor, loss_mask: torch.Tensor):
        """Accumulate statistics and refresh the projections every ``update_period`` steps."""
        fwd_count = int(self.fwd_count.item())

        if fwd_count == 0 and self._load_from_cache():
            return

        # Annoyingly, `eigh`, `svd`, and `eig` aren't stable for producing the eigenvectors,
        # which means that this method will consistently produce different rotations.
        # The good news is that once we get enough samples, we're pretty close to the expectation, and we can
        # stop re-estimating this.
        if fwd_count > self.max_samples:
            self.fwd_count += 1
            return

        # Increments ``fwd_count`` through the base-class update.
        self._update_samples(loss_fn_base, teacher_features, loss_mask)

        if fwd_count % self.update_period == 0:
            self._wrap_update_projections(fwd_count)
            self._calc_projection_error()

        if fwd_count == self.max_samples:
            self._save_cache()

    def _get_cache_path(self):
        """Return the cache file path for this state under the torch hub directory."""
        safe_name = self.name.replace('(', '_').replace(')', '_').replace(' ', '_').replace(',', '-')
        fname = f'{safe_name}.pth'
        cache_dir = os.path.join(torch.hub.get_dir(), 'evfm', 'fd_loss_states', 'whiten')
        cache_path = os.path.join(cache_dir, fname)
        return cache_path

    def _load_from_cache(self) -> bool:
        """Caching is currently disabled: always returns False."""
        # cache_path = self._get_cache_path()
        # if os.path.exists(cache_path):
        #     buffers = torch.load(cache_path, map_location='cpu')
        #     _logger.info(f'Loaded whitening state from cache: {cache_path}')
        #     for k, v in buffers.items():
        #         local_buff = getattr(self, k, None)
        #         if local_buff is not None:
        #             local_buff.copy_(v)

        #     # Recompute the dynamic buffers
        #     self._wrap_update_projections(self.fwd_count.item())
        #     return True
        return False

    def _save_cache(self):
        """Caching is currently disabled: does nothing on any rank."""
        if not self._is_reporting_rank():
            return

        # cache_path = self._get_cache_path()
        # cache_dir = os.path.dirname(cache_path)

        # _logger.info(f'Saving state to cache file: {cache_path}')

        # os.makedirs(cache_dir, exist_ok=True)
        # buffers = dict(self.named_buffers())
        # torch.save(buffers, cache_path)

    def _update_samples(self, loss_fn_base: 'LossFnBase', teacher_features: torch.Tensor, loss_mask: torch.Tensor):
        """Update the running mean and add this batch's scatter matrix to ``cov_sum``.

        Returns:
            ``(expected_mean, flat_feat)`` - the updated running mean and the
            mean-centred, masked features flattened to ``(b*h*w, c)``.
        """
        expected_mean = super().update(loss_fn_base, teacher_features, loss_mask)

        flat_feat = rearrange(teacher_features, 'b c h w -> (b h w) c')
        flat_mask = loss_mask.flatten().unsqueeze(1)

        flat_feat = flat_feat.double() - expected_mean.unsqueeze(0)  # Mean center

        # Masked-out rows become zero and therefore add nothing to the sum.
        flat_feat = torch.where(flat_mask, flat_feat, 0.0)

        cov = flat_feat.T @ flat_feat

        if dist.is_initialized():
            dist.all_reduce(cov, op=dist.ReduceOp.SUM, group=self.dist_group)

        self.cov_sum += cov

        return expected_mean, flat_feat

    def _wrap_update_projections(self, fwd_count: int):
        """Refresh the projections, log how much they moved, and sync them across ranks."""
        inv_whiten = self.inv_whiten.clone()
        whiten = self.whiten.clone()

        self._update_projections(fwd_count)

        if self._is_reporting_rank():
            # This allows us to measure how much the projections are changing
            # by measuring how close the new estimate is to reconstructing the
            # identity matrix given the old estimate.
            p2 = self.inv_whiten @ whiten - self.eye
            p3 = inv_whiten @ self.whiten - self.eye
            energy = (p2 + p3) / 2
            _logger.info(f'Rotation Change Energy: {energy.norm().item():.6f}')

        if dist.is_initialized():
            # Every rank adopts the matrices computed by group rank 0.
            group_rank_0_global_rank = self._global_rank_for_group_rank(reduction_group=self.dist_group)
            self._broadcast(group_rank_0_global_rank, self.dist_group)

    def _update_projections(self, fwd_count: int):
        """Recompute ``whiten`` / ``inv_whiten`` from the current covariance."""
        raise NotImplementedError("Subclasses must implement this!")

    @torch.autocast('cuda', enabled=False)
    def transform_targets(self, teacher_features: torch.Tensor) -> torch.Tensor:
        """Mean-centre teacher features and apply ``whiten`` along the channel axis."""
        b, c, h, w = teacher_features.shape

        flat_feat = rearrange(teacher_features, 'b c h w -> (b h w) c')

        flat_feat = flat_feat - self.expected_mean.unsqueeze(0)

        flat_white = flat_feat @ self.whiten.T

        teacher_features = rearrange(flat_white, '(b h w) c -> b c h w',
                                     b=b, c=c, h=h, w=w).to(teacher_features.dtype)

        # Every 50 forward passes, report how far the whitened batch covariance
        # is from the identity.
        if self._is_reporting_rank() and int(self.fwd_count.item()) % 50 == 0:
            whiten_error = (torch.cov(flat_white.T) - self.eye).abs().mean()
            _logger.info(f'Whiten Error ({self.name}): {whiten_error.item()}')

        return teacher_features

    @torch.no_grad()
    def transform_student(self, student_features: torch.Tensor) -> torch.Tensor:
        """Apply ``inv_whiten`` along the channel axis and add the running mean."""
        mean = self.expected_mean.to(student_features.dtype)
        inv_whiten = self.inv_whiten.to(student_features.dtype)

        b, c, h, w = student_features.shape

        flat_feat = rearrange(student_features, 'b c h w -> (b h w) c')

        flat_feat = flat_feat @ inv_whiten.T
        flat_feat = flat_feat + mean

        student_features = rearrange(flat_feat, '(b h w) c -> b c h w', b=b, c=c, h=h, w=w)

        return student_features

    def modify_linear(self, final: nn.Linear):
        """Fold ``transform_student`` into ``final``: ``W <- inv_whiten @ W``, ``b <- inv_whiten @ b + mean``."""
        _logger.info(f'De-normalizing linear layer! Method: {type(self).__name__}')
        m = self.expected_mean.to(final.weight.dtype)
        w = self.inv_whiten.to(final.weight.dtype)

        w2 = w @ final.weight
        final.weight.data.copy_(w2)

        if final.bias is not None:
            b2 = w @ final.bias.unsqueeze(1)
            final.bias.data.copy_(b2.squeeze(1))

            final.bias.data += m

    def _calc_projection_error(self):
        """Log statistics of the column norms of ``inv_whiten`` (reporting rank only)."""
        if not self._is_reporting_rank():
            return

        # Measure the magnitude error for each input
        norm = self.inv_whiten.norm(dim=0)

        min_val = norm.amin().item()
        max_val = norm.amax().item()
        val_range = max_val - min_val

        _logger.info(f'Projection Error Mag - Mean: {norm.mean().item():.4f}, Min: {min_val:.4f}, Max: {max_val:.4f}, Std: {norm.std().item():.4f}, Range: {val_range:.4f}')

    def _eig_decomp(self, cov: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Eigen-decompose ``cov`` after rescaling it by its median diagonal entry.

        Returns:
            ``(L, V, mask)``: eigenvalues (scaled back to the original units),
            eigenvectors as columns, and a boolean mask of the eigenvalues
            that are strictly positive.
        """
        # To deal with dead neurons
        cov = torch.where(cov != 0, cov, 1e-10 * self.eye)

        factor = 1 / cov.diag().median()
        cov = cov * factor

        # # L is the eigenvalue vector
        # # V is the eigenvector matrix, in column format
        L, V = torch.linalg.eigh(cov)

        # threshold = L.amax() * L.shape[0] * torch.finfo(L.dtype).eps
        threshold = 0
        mask = L > threshold

        L /= factor

        return L, V, mask

    def _broadcast(self, src_rank: int, group: dist.ProcessGroup = None):
        """Broadcast the base buffers plus ``inv_whiten``, ``whiten`` and ``cov_sum``."""
        super()._broadcast(src_rank, group)
        dist.broadcast(self.inv_whiten, src_rank, group=group)
        dist.broadcast(self.whiten, src_rank, group=group)
        dist.broadcast(self.cov_sum, src_rank, group=group)


class ZCAWhitenNormState(WhitenNormState):
    """ZCA whitening: ``whiten = V diag(L^-1/2) V^T`` and ``inv_whiten = V diag(L^1/2) V^T``."""

    def __init__(self, name: str, feature_dim: int, ohem: bool, update_period: int = 100):
        super().__init__(name, feature_dim, ohem, update_period)

    def _update_projections(self, fwd_count: int):
        """Rebuild both projections from the eigen-decomposition of the covariance.

        Eigen-directions rejected by the mask from ``_eig_decomp`` get a scale
        of zero in both matrices.

        Returns:
            The ``(L, V, mask)`` triple produced by ``_eig_decomp``.
        """
        eigvals, eigvecs, valid = self._eig_decomp(self.covariance)

        scale = torch.where(valid, eigvals.sqrt(), 0)
        inv_scale = torch.where(valid, eigvals.rsqrt(), 0)

        eigvecs_t = eigvecs.T
        self.inv_whiten.copy_((eigvecs @ torch.diag(scale)) @ eigvecs_t)
        self.whiten.copy_((eigvecs @ torch.diag(inv_scale)) @ eigvecs_t)

        return eigvals, eigvecs, valid


class PCAWhitenNormState(ZCAWhitenNormState):
    """PCA whitening, optionally weighting the loss per principal direction.

    The ZCA projections are rotated by ``V^T`` so whitened features are
    expressed in the eigenvector basis.
    """

    def __init__(self, name: str, feature_dim: int, ohem: bool, update_period: int = 100, weighted: bool = False):
        super().__init__(name, feature_dim, ohem, update_period)

        self.weighted = weighted
        if weighted:
            # Per-direction loss weights; filled in by ``_update_projections``.
            self.register_buffer('weights', torch.zeros(feature_dim, dtype=torch.float64))

    def _update_projections(self, fwd_count: int):
        """Compute the ZCA projections, rotate them into the PCA basis, refresh weights."""
        eigvals, eigvecs, valid = super()._update_projections(fwd_count)

        to_pca = eigvecs.T

        self.whiten.copy_(to_pca @ self.whiten)
        self.inv_whiten.copy_(self.inv_whiten @ to_pca.T)

        if self.weighted:
            # Weight each direction by sqrt(eigenvalue), normalised to mean 1.
            root = torch.where(eigvals > 0, eigvals.sqrt(), 1e-10)
            self.weights.copy_(root / root.mean())

        return eigvals, eigvecs, valid

    def transform_loss(self, loss: torch.Tensor) -> torch.Tensor:
        """Multiply the loss by the per-channel weights when ``weighted`` is set."""
        if not self.weighted:
            return loss
        return loss * self.weights.reshape(1, -1, 1, 1)

    def _broadcast(self, src_rank: int, group: dist.ProcessGroup = None):
        """Broadcast the parent buffers, plus ``weights`` when ``weighted`` is set."""
        super()._broadcast(src_rank, group)
        if not self.weighted:
            return
        dist.broadcast(self.weights, src_rank, group=group)



class HadamardWhitenNormState(ZCAWhitenNormState):
    """Whitening followed by a fixed rotation taken from ``get_hadamard_matrix``."""

    def __init__(self, name: str, feature_dim: int, ohem: bool, update_period: int = 100):
        super().__init__(name, feature_dim, ohem, update_period)

        H = get_hadamard_matrix(feature_dim)

        # Take rank 0's matrix on every process so all ranks share one rotation.
        if dist.is_initialized():
            dist.broadcast(H, src=0)
        self.register_buffer('rotation', H, persistent=True)

    def _update_projections(self, fwd_count: int):
        """Compute the ZCA projections and re-rotate them with the stored rotation.

        Returns:
            The ``(L, V, mask)`` triple from the parent implementation, for
            consistency with the other ``_update_projections`` overrides.
        """
        L, V, mask = super()._update_projections(fwd_count)

        # Undo the rotation by the eigenvectors (e.g. convert this to
        # PCA Whitening), and then rotate by the Hadamard
        # matrix so that the error is spread evenly across all feature
        # dimensions
        rotation = self.rotation @ V.T

        self.whiten.copy_(rotation @ self.whiten)
        self.inv_whiten.copy_(self.inv_whiten @ rotation.T)

        return L, V, mask

    def _broadcast(self, src_rank: int, group: dist.ProcessGroup = None):
        """Broadcast the parent buffers plus ``rotation``."""
        super()._broadcast(src_rank, group)
        dist.broadcast(self.rotation, src_rank, group=group)


class PHIStandardization(WhitenNormState):
    """Rotate into the eigenbasis, scale by one global factor, apply a fixed rotation.

    Unlike the whitening variants, every direction is divided by the same
    scalar: the square root of the mean (non-negative) eigenvalue.
    """

    def __init__(self, name: str, feature_dim: int, ohem: bool, update_period: int = 100):
        super().__init__(name, feature_dim, ohem, update_period)

        hadamard = get_hadamard_matrix(feature_dim)
        # Take rank 0's matrix on every process so all ranks share one rotation.
        if dist.is_initialized():
            dist.broadcast(hadamard, src=0)
        self.register_buffer('rotation', hadamard, persistent=True)

    def _update_projections(self, fwd_count: int):
        """Rebuild ``whiten`` / ``inv_whiten`` from the covariance eigenvectors.

        Returns:
            ``(L, V, mask)``: eigenvalues with negatives clamped to zero,
            eigenvectors as columns, and the mask of non-negative eigenvalues.
        """
        eigvals, eigvecs = torch.linalg.eigh(self.covariance)
        non_negative = eigvals >= 0
        eigvals = torch.where(non_negative, eigvals, 0)

        # One scalar scale shared by all directions, expressed as a scaled identity.
        scale = eigvals.mean().sqrt()
        inv_scale = 1 / scale
        scale_mat = self.eye * scale
        inv_scale_mat = self.eye * inv_scale

        hadamard: torch.Tensor = self.rotation

        new_whiten = hadamard @ inv_scale_mat @ eigvecs.T
        new_inv_whiten = eigvecs @ scale_mat @ hadamard.T

        self.inv_whiten.copy_(new_inv_whiten)
        self.whiten.copy_(new_whiten)

        return eigvals, eigvecs, non_negative

    def _broadcast(self, src_rank: int, group: dist.ProcessGroup = None):
        """Broadcast the parent buffers plus ``rotation``."""
        super()._broadcast(src_rank, group)
        dist.broadcast(self.rotation, src_rank, group=group)


# Maps a normalisation method to the factory that builds its state object.
# Every factory is called as ``factory(name, feature_dim, ohem)``.
# ``None`` is accepted as an alias for ``NormalMethod.NONE``.
NORM_METHOD_MAP = {
    None: LossFnStateBase,
    NormalMethod.NONE: LossFnStateBase,
    NormalMethod.LOSS: LossNormState,
    NormalMethod.STANDARDIZE: StandardizeNormState,
    NormalMethod.ZCA_WHITEN: ZCAWhitenNormState,
    NormalMethod.HADAMARD_WHITEN: HadamardWhitenNormState,
    NormalMethod.PCA_WHITEN: PCAWhitenNormState,
    # Weighted PCA reuses the PCA state with per-direction loss weights enabled.
    NormalMethod.WPCA_WHITEN: partial(PCAWhitenNormState, weighted=True),
    NormalMethod.PHI_STANDARDIZE: PHIStandardization,
    NormalMethod.GLOBAL_STANDARDIZE: GlobalStandardizeNormState,
}


class LossFnBase(nn.Module):
    """Base class for distillation losses computed against normalised teacher features.

    The normalisation itself lives in ``self.loss_state``, built from
    ``NORM_METHOD_MAP[norm_mode]``. Subclasses implement ``get_balanced_loss``
    (element-wise loss) and ``_compute_loss`` (the full per-sample loss).
    """

    def __init__(self, name: str, feature_dim: int, ohem: bool = False, norm_mode: NormalMethod = _NORM_METHOD):
        super().__init__()
        self.name = name
        self.feature_dim = feature_dim
        self.ohem = ohem

        self.loss_state = NORM_METHOD_MAP[norm_mode](name, feature_dim, ohem)

        # Extra buffers registered on the state so they travel with its state
        # dict; nothing in this class reads or updates them.
        self.loss_state.register_buffer('error_num_steps', torch.tensor(0, dtype=torch.int64))
        self.loss_state.register_buffer('error_num_samples', torch.tensor(0, dtype=torch.float64))
        self.loss_state.register_buffer('error_cov_sum', torch.zeros(feature_dim, feature_dim, dtype=torch.float64))

        self.loss_state.register_buffer('dn_error_cov_sum', torch.zeros(feature_dim, feature_dim, dtype=torch.float64))

    @property
    def dist_group(self):
        """Process group used for collectives; stored on ``loss_state``."""
        return self.loss_state.dist_group

    @dist_group.setter
    def dist_group(self, v: dist.ProcessGroup):
        self.loss_state.dist_group = v

    def get_balanced_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the element-wise loss between ``pred`` and ``target``."""
        raise NotImplementedError("Subclasses must implement this!")

    def update(self, teacher_features: torch.Tensor, loss_mask: torch.Tensor):
        """Feed one batch of teacher features into the normalisation state."""
        self.loss_state.update(self, teacher_features, loss_mask)

    def reduce_loss(self, loss: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor:
        """Reduce a per-position loss to one value per sample.

        Invalid positions are removed by multiplying with the mask, so a
        NaN/Inf at a masked-out position still propagates. With ``ohem`` the
        masked loss goes through ``_ohem_reduce``; otherwise the per-sample
        sum is divided by the average number of valid positions per sample
        (averaged over ``dist_group`` when one is set and torch.distributed is
        initialised), clamped to at least 1.
        """
        loss = loss.flatten(1)
        valid_mask = valid_mask.flatten(1)

        # loss = torch.where(valid_mask, loss, 0)
        loss = valid_mask * loss

        if self.ohem:
            loss = _ohem_reduce(loss)
        else:
            loss = loss.sum(dim=1)

            avg_valid = valid_mask.sum(dim=1, dtype=torch.float32).mean()
            if self.dist_group is not None and dist.is_initialized():
                dist.all_reduce(avg_valid, op=dist.ReduceOp.AVG, group=self.dist_group)

            loss = loss / avg_valid.clamp_min(1)

        return loss

    @torch.no_grad()
    def get_loss_components(self,
                            normalized_student_features: torch.Tensor,
                            denormalized_teacher_features: torch.Tensor,
                            normalized_teacher_features: torch.Tensor,
                            valid_mask: torch.Tensor):
        """Return diagnostic loss terms computed without gradients.

        Currently a single entry, ``corrected_loss``: the MSE between the
        de-normalised student features and the raw teacher features, averaged
        over channels and reduced with ``reduce_loss``.
        ``normalized_teacher_features`` is accepted but not used.
        """
        denormalized_student_features = self.loss_state.transform_student(normalized_student_features.detach())

        denormalized_loss = F.mse_loss(denormalized_student_features, denormalized_teacher_features, reduction='none')
        denormalized_loss = denormalized_loss.mean(dim=1)
        denormalized_loss = self.reduce_loss(denormalized_loss, valid_mask).mean()

        loss_components = dict(corrected_loss=denormalized_loss)

        return loss_components

    def forward(self, normalized_student_features: torch.Tensor,
                      denormalized_teacher_features: torch.Tensor,
                      sample_mask: torch.Tensor,
                      resize_fn) -> Tuple[torch.Tensor, dict]:
        """Update the normalisation state and compute the loss for one batch.

        Args:
            normalized_student_features: Student output in normalised space.
            denormalized_teacher_features: Raw teacher features.
            sample_mask: Mask passed to the state update to select the
                positions that contribute to the statistics.
            resize_fn: Callable invoked as
                ``resize_fn(student, (norm_teacher, raw_teacher))`` and
                expected to return
                ``(student, (norm_teacher, raw_teacher), valid_mask)``.

        Returns:
            ``(loss, loss_components)``: the loss from ``_compute_loss`` and
            the diagnostics dict from ``get_loss_components``.
        """
        self.update(denormalized_teacher_features, sample_mask)

        normalized_teacher_features = self.loss_state.transform_targets(denormalized_teacher_features)

        normalized_student_features, (normalized_teacher_features, denormalized_teacher_features), valid_mask = resize_fn(normalized_student_features, (normalized_teacher_features, denormalized_teacher_features))

        loss = self._compute_loss(normalized_student_features, normalized_teacher_features, valid_mask)

        loss_components = self.get_loss_components(normalized_student_features, denormalized_teacher_features,
                                                   normalized_teacher_features, valid_mask)

        return loss, loss_components

    def _compute_loss(self, normalized_student_features: torch.Tensor, normalized_teacher_features: torch.Tensor, valid_mask: torch.Tensor):
        """Return the per-sample training loss; implemented by subclasses."""
        raise NotImplementedError()


class DefaultLossFn(LossFnBase):
    """Fixed blend of cosine distance (0.9) and element-wise MSE (0.1)."""

    def __init__(self, name: str, feature_dim: int, beta: float = 1.0, ohem: bool = False, norm_mode: NormalMethod = _NORM_METHOD):
        super().__init__(name, feature_dim, ohem, norm_mode)
        # Stored but not read anywhere in this class.
        self.beta = beta

    def get_balanced_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Element-wise squared error between ``pred`` and ``target``."""
        squared_error = F.mse_loss(pred, target, reduction='none')
        return squared_error

    def _compute_loss(self, normalized_student_features: torch.Tensor, normalized_teacher_features: torch.Tensor, valid_mask: torch.Tensor):
        """Blend the channel-wise cosine distance with the re-weighted, channel-averaged MSE."""
        mse_term = self.get_balanced_loss(normalized_student_features, normalized_teacher_features)
        mse_term = self.loss_state.transform_loss(mse_term).mean(dim=1)

        cos_term = cosine_similarity_loss(normalized_student_features, normalized_teacher_features, dim=1)

        blended = 0.9 * cos_term + 0.1 * mse_term
        return self.reduce_loss(blended, valid_mask)


class BasicLossFn(LossFnBase):
    """Wraps an arbitrary ``loss_fn(pred, target, **kwargs)`` as a distillation loss."""

    def __init__(self, loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
                 name: str, feature_dim: int, ohem: bool = False, norm_mode: NormalMethod = _NORM_METHOD, **kwargs):
        super().__init__(name, feature_dim, ohem, norm_mode)
        self.loss_fn = loss_fn
        # Extra keyword arguments forwarded to ``loss_fn`` on every call.
        self.kwargs = kwargs

    def get_balanced_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Evaluate the wrapped loss function with the stored keyword arguments."""
        extra_args = self.kwargs
        return self.loss_fn(pred, target, **extra_args)

    def _compute_loss(self, normalized_student_features: torch.Tensor, normalized_teacher_features: torch.Tensor, valid_mask: torch.Tensor):
        """Compute, re-weight and reduce the wrapped loss.

        A 4-D loss is averaged over dim 1 before the reduction; any other
        rank is passed to ``reduce_loss`` unchanged.
        """
        raw = self.get_balanced_loss(normalized_student_features, normalized_teacher_features)
        weighted = self.loss_state.transform_loss(raw)

        per_position = weighted.mean(dim=1) if weighted.ndim == 4 else weighted
        return self.reduce_loss(per_position, valid_mask)