class HopfieldSoftmaxSplitOODDetector(nn.Module):
    """Out-of-distribution scorer built on softmax-pooled sequence features.

    Each head owns ``n_queries`` *averaging* query vectors and ``n_queries``
    *linear* query vectors. Similarities between the averaging queries and the
    input tokens define softmax pooling weights (inverse temperature ``beta``);
    the similarities of the linear queries are pooled with those weights into
    one scalar feature per head. With ``share_params=True`` both roles use the
    same parameters; otherwise the head count is halved to keep the number of
    parameters comparable.

    Features are centred with a mean computed from the current batch (training
    mode) or accumulated via :meth:`partial_fit_mean` (eval mode). ``forward``
    returns the per-sample mean of the squared centred features, or the centred
    features themselves when ``return_features`` is set.

    :param feature_dim: Size of the last input dimension.
    :param n_heads: Number of heads (halved when ``share_params`` is False).
    :param n_queries: Query vectors per head.
    :param beta: Softmax inverse temperature. ``None`` selects
        ``1 / sqrt(feature_dim)``; ``0`` selects uniform (masked mean) pooling.
    :param return_features: If True, ``forward`` returns the ``(B, H)`` centred features.
    :param similarity: ``'dot'`` or ``'euclidean'`` (negative half squared distance).
    :param normalization: If True, subtract the mean of :meth:`normalizer` from the score.
    :param init_averaging_params: Draw the separate averaging params from N(0, 1)
        instead of leaving them at zero.
    :param init_std: Std of the linear-param init, or ``'feature_dim'`` for
        ``1 / sqrt(feature_dim)``.
    :param share_params: Use the linear params as averaging params.
    :param learn_betas: Make the per-head betas trainable; averaging params are then
        L2-normalised per head before use.
    """

    def __init__(
        self,
        feature_dim,
        n_heads,
        n_queries=1,
        beta=1.,
        return_features=False,
        similarity='dot',
        normalization=False,
        init_averaging_params=False,
        init_std=1.,
        share_params=False,
        learn_betas=False,
    ):
        super().__init__()
        self.feature_dim = feature_dim
        self.share_params = share_params
        self.learn_betas = learn_betas
        # Separate averaging params double the parameter count, so halve the
        # number of heads to match the shared variant's parameter budget.
        self.n_heads = n_heads if share_params else n_heads // 2
        assert self.n_heads >= 0
        self.n_queries = n_queries
        if beta is None:
            beta = 1 / torch.sqrt(torch.tensor(float(feature_dim)))
        # Keep beta as a Python float so that ``self.beta == 0`` is a plain bool.
        self.beta = float(beta)
        betas = torch.full([self.n_heads], self.beta, dtype=torch.float)
        if self.learn_betas:
            self.betas = nn.Parameter(betas)
        else:
            self.register_buffer('betas', betas)
        self.return_features = return_features
        assert similarity in ('dot', 'euclidean')
        self.similarity = similarity
        self.normalization = normalization
        if init_std == 'feature_dim':
            self.init_std = 1. / torch.sqrt(torch.tensor(feature_dim, dtype=torch.float))
        else:
            self.init_std = init_std

        if init_averaging_params and share_params:
            raise ValueError('Cannot initialize averaging parameters if parameters are shared.')
        if share_params:
            self._averaging_params = None
        else:
            self._averaging_params = nn.Parameter(torch.zeros(self.n_heads, n_queries, feature_dim))  # H, Q, F
        self.linear_params = nn.Parameter(torch.zeros(self.n_heads, n_queries, feature_dim))  # H, Q, F
        self.init_averaging_params = init_averaging_params
        self._init_weights()
        # Running statistics maintained by partial_fit_mean; -inf means "nothing fitted yet".
        self.register_buffer('_mean', torch.zeros(self.n_heads))
        self.register_buffer('_cum_log_Z', torch.full([self.n_heads], float('-inf')))

    @property
    def averaging_params(self):
        """Averaging query parameters ``(H, Q, F)``; the linear params when shared."""
        if self.share_params:
            return self.linear_params
        return self._averaging_params

    def _effective_averaging_params(self):
        """Averaging params as used for pooling (L2-normalised per head when betas are learned)."""
        params = self.averaging_params
        if self.learn_betas:
            params = torch.nn.functional.normalize(params.flatten(-2, -1), dim=-1).reshape(params.shape)
        return params

    def _init_weights(self):
        """Draw linear params from N(0, init_std) and, if requested, averaging params from N(0, 1)."""
        self.linear_params.data.normal_(mean=0., std=self.init_std)
        if self.init_averaging_params and self._averaging_params is not None:
            self._averaging_params.data.normal_(mean=0., std=1.)

    def reset_mean(self):
        """Discard the accumulated mean so it can be re-fitted from scratch."""
        # *_like keeps the buffers' current dtype and device.
        self._mean = torch.zeros_like(self._mean)
        self._cum_log_Z = torch.full_like(self._cum_log_Z, float('-inf'))

    def _masked_softmax_pooling(self, averaging_sims, linear_sims, masks):
        """Pool ``linear_sims`` with softmax weights derived from ``averaging_sims``.

        :param averaging_sims: ``(H, B, Q, S)`` similarities defining the pooling weights.
        :param linear_sims: ``(H, B, Q, S)`` similarities that are pooled.
        :param masks: ``(B, S)``; nonzero entries mark valid tokens.
        :return: ``(pooled, log_Z)``, both ``(H, B)``. For ``beta > 0`` the weights are
            ``softmax(beta * averaging_sims)`` over all unmasked (query, token) pairs and
            ``log_Z = logsumexp(beta * averaging_sims) / beta``. For ``beta == 0`` the
            weights are uniform over those pairs and ``log_Z = log(#unmasked tokens)``.
        """
        H, B, Q, S = averaging_sims.shape
        assert list(masks.shape) == [B, S]
        masks = masks.reshape(1, B, 1, S).bool()

        if not self.learn_betas and self.beta == 0:
            # beta -> 0 limit: uniform weights 1 / (Q * n) over all unmasked (q, s) pairs.
            n = masks.sum(dim=-1).reshape(B).to(linear_sims.dtype)  # B
            masked = torch.where(masks, linear_sims, torch.zeros_like(linear_sims))  # H, B, Q, S
            pooled = masked.sum(dim=(-2, -1)) / (Q * n)  # H, B
            # Any per-sample constant offset cancels in the downstream weighting;
            # expand to (H, B) so shapes match the beta > 0 branch.
            log_Z = torch.log(n).unsqueeze(0).expand(H, B)
            return pooled, log_Z

        betas = self.betas.reshape(-1, 1, 1, 1)  # H, 1, 1, 1
        scaled = torch.where(masks, betas * averaging_sims, torch.full_like(averaging_sims, -torch.inf))  # H, B, Q, S
        log_Z = 1 / betas * torch.logsumexp(scaled, dim=(-2, -1), keepdim=True)  # H, B, 1, 1
        p = torch.exp(scaled - betas * log_Z)
        pooled = torch.einsum('hbqs,hbqs->hb', linear_sims, p)
        return pooled, log_Z.flatten(-3, -1)  # H, B

    def _pairwise_similarity(self, x, y):
        """
        Pairwise similarity of two tensors with rank 3.

        :param x: shape ``(A, I, feature_dim)``, e.g. ``(n_heads, n_queries, feature_dim)``
        :param y: shape ``(B, J, feature_dim)``, e.g. ``(batch_size, sequence_len, feature_dim)``

        :return: Pairwise similarities; shape ``(A, B, I, J)``, e.g. ``(n_heads, batch_size, n_queries, sequence_len)``
        """
        if self.similarity == 'dot':
            return torch.einsum('aif,bjf->abij', x, y)
        elif self.similarity == 'euclidean':
            # -1/2 * ||x - y||^2 expanded as -1/2 ||x||^2 + x.y - 1/2 ||y||^2.
            return -1/2 * torch.einsum('aif,aif->ai', x, x).reshape([x.shape[0], 1, x.shape[1], 1]) + torch.einsum('aif,bjf->abij', x, y) - 1/2 * torch.einsum('bjf,bjf->bj', y, y).reshape([1, y.shape[0], 1, y.shape[1]])

    @torch.no_grad()
    def partial_fit_mean(self, x, mask=None):
        """Fold a batch into the running, partition-function-weighted feature mean.

        Calling this on consecutive chunks of a dataset yields the same mean as
        one call on the concatenated data.

        :param x: ``(B, S, feature_dim)`` input.
        :param mask: ``(B, S)`` validity mask; ``None`` treats every token as valid.
        """
        if mask is None:
            mask = torch.ones(x.shape[:-1], device=x.device, dtype=torch.int)
        averaging_sims = self._pairwise_similarity(self._effective_averaging_params(), x)
        linear_sims = self._pairwise_similarity(self.linear_params, x)
        features, log_Z = self._masked_softmax_pooling(averaging_sims, linear_sims, mask)

        batch_mean, batch_log_Z = self._compute_batch_mean(features, log_Z)
        batch_log_Z = batch_log_Z.flatten(-2, -1)  # H
        if not self.learn_betas and self.beta == 0:
            p = torch.sigmoid(self._cum_log_Z - batch_log_Z)
            cum_log_Z = torch.logaddexp(self._cum_log_Z, batch_log_Z)
        else:
            # _cum_log_Z is kept in the same 1/beta-scaled units as batch_log_Z,
            # so both the mixing weight and the update rescale by beta.
            p = torch.sigmoid(self.betas * (self._cum_log_Z - batch_log_Z))
            cum_log_Z = torch.logaddexp(self.betas * self._cum_log_Z, self.betas * batch_log_Z) / self.betas
        self._mean = p * self._mean + (1 - p) * batch_mean
        self._cum_log_Z = cum_log_Z

    def _compute_batch_mean(self, features, log_Z):
        """Partition-function-weighted mean of ``features`` over the batch axis.

        :param features: ``(H, B)`` pooled features.
        :param log_Z: ``(H, B)`` per-sample log partition values from the pooling.
        :return: ``(mean, full_log_Z)`` with shapes ``(H,)`` and ``(H, 1)``.
        """
        if not self.learn_betas and self.beta == 0:
            p = torch.softmax(log_Z, dim=-1)
            mean = torch.einsum('hb,hb->h', features, p)
            full_log_Z = torch.logsumexp(log_Z, dim=-1, keepdim=True)
            return mean, full_log_Z

        full_log_Z = 1/self.betas.unsqueeze(-1) * torch.logsumexp(self.betas.unsqueeze(-1) * log_Z, dim=-1, keepdim=True)
        log_p = log_Z - full_log_Z
        p = torch.exp(self.betas.unsqueeze(-1) * log_p)
        mean = torch.einsum('hb,hb->h', features, p)
        return mean, full_log_Z

    @torch.no_grad()
    def orthogonalize_unstable(self):
        """Orthogonalize the parameters using the Gram-Schmidt process. This implementation is numerically unstable."""
        assert self.linear_params.shape[1] == 1, 'Orthogonalization only implemented for n_queries=1'
        parameters = self.linear_params.squeeze(1)  # H, F
        orthogonal_parameters = parameters.clone()
        for p, parameter in enumerate(parameters):
            params_before = orthogonal_parameters[:p]
            # Classical Gram-Schmidt: project the original vector onto earlier orthogonal ones.
            weight = torch.einsum('f,nf->n', parameter, params_before) / torch.einsum('nf,nf->n', params_before, params_before)
            weighted_params_before = weight.reshape(-1, 1) * params_before
            orthogonal_parameters[p] = parameter - torch.sum(weighted_params_before, dim=0)
        self.linear_params.data = orthogonal_parameters.unsqueeze(1)

    @torch.no_grad()
    def orthogonalize(self):
        """Orthogonalize the parameters using the QR decomposition.

        Produces the (unnormalised) Gram-Schmidt vectors of the flattened
        per-head parameters. Heads beyond the flattened dimension are set to zero.
        """
        param_shape = self.linear_params.shape

        parameters = self.linear_params.flatten(1, 2).float()  # H, Q x F
        H, D = parameters.shape

        q, r = torch.linalg.qr(parameters.T, mode="reduced")  # (D, k), (k, H)

        k = min(D, H)
        # q_i * r_ii is the i-th unnormalised Gram-Schmidt vector (sign-convention independent).
        u_t = q[:, :k] * torch.diagonal(r, offset=0, dim1=-2, dim2=-1)[:k]  # D, k

        # Set the extra heads to 0.
        if H > k:
            pad = torch.zeros(D, H - k, dtype=u_t.dtype, device=u_t.device)
            u_t = torch.cat([u_t, pad], dim=1)  # D, H

        self.linear_params.data = u_t.T.reshape(param_shape).to(self.linear_params.dtype)  # H, Q, F

    def normalizer(self):
        """Per-head log squared L2 norm of the linear parameters, shape ``(H,)``."""
        return torch.log(torch.einsum('hqf,hqf->h', self.linear_params, self.linear_params))

    @property
    @torch.no_grad()
    def mean(self):
        """Fitted feature mean ``(H,)``; raises ``ValueError`` if no head has been fitted."""
        if torch.any(torch.isinf(self._cum_log_Z)):
            raise ValueError('No mean fitted. When in eval mode, please fit the mean first by supplying your data to partial_fit_mean')
        return self._mean

    def features(self, x, mask, use_for_mean=None):
        """Centred pooled features.

        :param x: ``(B, S, feature_dim)`` input.
        :param mask: ``(B, S)`` validity mask or ``None`` (all valid).
        :param use_for_mean: ``(B,)`` selector of samples used for the batch mean in
            training mode; ``None`` uses all samples.
        :return: ``(B, H)`` features minus the batch mean (training) or fitted mean (eval).
        """
        if len(x.shape) != 3:
            raise ValueError('Input tensor must have shape (batch_size, sequence_len, feature_dim)')

        B, S, F = x.shape
        H, Q, _ = self.averaging_params.shape

        if F != self.feature_dim:
            raise ValueError(f'Feature dimension mismatch: {x.shape[-1]} != {self.feature_dim}')
        if mask is not None and mask.shape != (B, S):
            raise ValueError(f'Mask shape must match input tensor shape excluding the last dimension, got {mask.shape} and {x.shape[:-1]}')

        if mask is None:
            mask = torch.ones(x.shape[:-1], device=x.device, dtype=torch.int)

        if use_for_mean is None:
            # use all values in x for mean computation
            use_for_mean = torch.ones(B, dtype=torch.bool, device=x.device)
        use_for_mean = torch.as_tensor(use_for_mean, device=x.device).bool()

        # Compute pairwise similarity between parameters and x.
        averaging_sims = self._pairwise_similarity(self._effective_averaging_params(), x)  # H, B, Q, S
        linear_sims = self._pairwise_similarity(self.linear_params, x)   # H, B, Q, S
        assert averaging_sims.shape == (H, B, Q, S)
        assert linear_sims.shape == (H, B, Q, S)

        # Compute the softmax-pooled sequence representations
        features, log_Z = self._masked_softmax_pooling(averaging_sims, linear_sims, mask)

        if self.training:
            mean, _ = self._compute_batch_mean(features[:, use_for_mean], log_Z[:, use_for_mean])
            mean = mean.unsqueeze(-1)  # H, 1
        else:
            mean = self.mean.unsqueeze(1)  # H, 1

        centered_features = features - mean  # H, B

        return centered_features.transpose(0, 1)  # B, H

    def forward(self, x, mask=None, use_for_mean=None):
        """OOD score per sample ``(B,)``: mean squared centred feature (optionally normalised)."""
        centered_features = self.features(x, mask, use_for_mean=use_for_mean)

        if self.return_features:
            return centered_features

        mean_squared_error = torch.mean(centered_features ** 2, dim=1)  # B

        if self.normalization:
            mean_squared_error = mean_squared_error - torch.mean(self.normalizer())

        return mean_squared_error
