import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from types import SimpleNamespace
from criteria.proxynca import Criterion as Crit
from itertools import chain, combinations

"""================================================================================================="""
ALLOWED_MINING_OPS  = None
REQUIRES_BATCHMINER = False
REQUIRES_OPTIM      = True


class Criterion(Crit):
    """ProxyNCA applied to the power set of the labels.

    Each label row is collapsed into one integer by weighting column ``j``
    with ``2**j`` and summing. The parent ProxyNCA criterion then runs on
    those integers with ``2**opt.n_classes`` proxies. For 0/1 multi-hot rows
    the integer is a unique index in ``[0, 2**opt.n_classes)``, i.e. one
    proxy per label subset.
    """

    def __init__(self, opt):
        """
        Args:
            opt: Namespace containing all relevant parameters. Reads
                ``opt.n_classes`` and ``opt.embed_dim``. A shallow copy with
                ``n_classes`` replaced by ``2**opt.n_classes`` is passed to the
                parent constructor; ``opt`` itself is not modified.

        Raises:
            ValueError: if ``2**opt.n_classes`` exceeds ``opt.embed_dim``.
        """
        n_proxies = 2 ** opt.n_classes
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if n_proxies > opt.embed_dim:
            raise ValueError(f"Number of proxies ({n_proxies}) is larger than embed dim ({opt.embed_dim})!")

        # Shallow copy so the caller's ``opt`` keeps its original n_classes.
        new_opt = SimpleNamespace(**vars(opt))
        new_opt.n_classes = n_proxies

        super(Criterion, self).__init__(new_opt)

        # Binary place values [1, 2, 4, ...]. ``labels @ self.mat`` turns a
        # multi-hot row into its subset index. Built after the parent
        # constructor so it does not rely on assigning attributes before
        # base-class initialisation.
        self.mat = 2 ** torch.arange(opt.n_classes, dtype=torch.long)

        self.name       = 'multiproxyncapowerset'

    def forward(self, batch, labels, **kwargs):
        """Encode each label row as its power-set index and defer to ProxyNCA.

        Args:
            batch: Embeddings, passed unchanged to the parent ``forward``.
            labels: Tensor whose last dimension has ``opt.n_classes`` entries.
                Presumably 0/1 multi-hot — confirm against the caller. Entries
                are cast to integer, so fractional values would be truncated.
            **kwargs: Forwarded to the parent ``forward``.

        Returns:
            Whatever the parent ProxyNCA ``forward`` returns.
        """
        # ``self.mat`` is a plain attribute, not a registered buffer. If the
        # parent is an nn.Module, ``.to()``/``.cuda()`` on the criterion will
        # not move it, so follow the labels' device instead.
        place_values = self.mat.to(labels.device)
        # Matmul requires matching dtypes. Casting lets float/bool multi-hot
        # labels work instead of raising.
        new_labels = labels.long() @ place_values
        return super().forward(batch, new_labels, **kwargs)
