import numpy as np
import torch
from typing import NamedTuple, List


class BiasClusterEmbeddings(NamedTuple):
    """Embeddings of one cluster, split by whether the samples were classified correctly.

    Both fields are consumed by ``BiasEmbeddingsCalculator``, which unpacks
    their ``.shape`` into exactly two values. Each field must therefore be
    2-D, with rows as individual embeddings and the same number of columns
    in both fields.
    """

    # Embeddings of the cluster's correctly classified samples, one per row.
    correctly_classified: np.ndarray | torch.Tensor
    # Embeddings of the cluster's incorrectly classified samples, one per row.
    incorrectly_classified: np.ndarray | torch.Tensor


class BiasEmbeddingsCalculator:
    def _average_all_pairs(self, correct: np.ndarray | torch.Tensor, incorrect: np.ndarray | torch.Tensor):
        N, dim_A = correct.shape
        M, dim_B = incorrect.shape

        assert dim_A == dim_B

        sum_of_averages = np.zeros(dim_A)

        for i in range(N):
            for j in range(M):
                sum_of_averages += (correct[i] + incorrect[j]) / 2

        return sum_of_averages / (N * M)
    
    def calculate_bias_embeddings(self, clustered_embeddings: List[BiasClusterEmbeddings]) -> np.ndarray:
        avg_emb_for_bias = []

        for emb_cluster in clustered_embeddings:
            avg_emb = self._average_all_pairs(
                emb_cluster.correctly_classified,
                emb_cluster.incorrectly_classified
            )
            avg_emb_for_bias.append(avg_emb)
        
        avg_emb_for_bias = np.vstack(avg_emb_for_bias)

        avg_of_avgs = np.average(avg_emb_for_bias, axis=0)

        bias_embeddings = []

        for avg_emb in avg_emb_for_bias:
            bias_embeddings.append(avg_emb - avg_of_avgs)
        
        return np.vstack(bias_embeddings)
