import torch

from avalanche.models import avalanche_forward
from avalanche.benchmarks.datasets.multi_label_dataset.common_utils import *


class SupervisedProblem:
    """Mini-batch accessors, loss and forward pass for supervised training.

    The class defines no ``__init__``; it reads attributes it never sets
    itself (``mbatch``, ``mixup``, ``model``, ``mb_output``, ``_criterion``),
    so those must be supplied by a subclass or the host training object.
    """

    @property
    def mb_x(self):
        """Current mini-batch input (first element of ``self.mbatch``)."""
        return self.mbatch[0]

    @property
    def mb_y(self):
        """Current mini-batch target.

        Without mixup (``self.mixup`` falsy), ``self.mbatch[1]`` is passed as
        ``spec_onehot`` to ``make_common_onehot`` together with the task label
        of the mini-batch; every sample in the batch must share that single
        task label. With mixup, ``self.mbatch[-2]`` is returned unchanged
        (presumably the mixed targets produced by the mixup pipeline --
        confirm against the data loader).

        Raises:
            AssertionError: if the mini-batch mixes several task labels, or
                has fewer than 3 components.
        """
        if self.mixup:
            return self.mbatch[-2]

        # Sorted only so the error message is deterministic; with exactly one
        # label the ordering is irrelevant.
        task_ids = sorted(set(self.mb_task_id.tolist()))
        assert len(task_ids) == 1, (
            f"expected a single task label per mini-batch, got {task_ids}"
        )
        return make_common_onehot(spec_onehot=self.mbatch[1], task_label=task_ids[0])

    @property
    def mb_task_id(self):
        """Current mini-batch task labels (last element of ``self.mbatch``).

        Raises:
            AssertionError: if the mini-batch has fewer than 3 components.
        """
        assert len(self.mbatch) >= 3, (
            f"mini-batch must have at least 3 components, got {len(self.mbatch)}"
        )
        return self.mbatch[-1]

    def criterion(self):
        """Return ``self._criterion`` applied to the model output and targets."""
        return self._criterion(self.mb_output, self.mb_y)

    def forward(self):
        """Compute the model's output given the current mini-batch."""
        # NOTE(review): only the inputs are forwarded; ``self.mb_task_id`` is
        # available but not passed. Upstream Avalanche's ``avalanche_forward``
        # appears to take task labels as a third argument -- confirm this
        # fork's signature before changing the call.
        return avalanche_forward(self.model, self.mb_x)

    def _check_minibatch(self):
        """Check that the current mini-batch has at least 3 components."""
        assert len(self.mbatch) >= 3, (
            f"mini-batch must have at least 3 components, got {len(self.mbatch)}"
        )



def check_zeroorone(targets):
    """Return True if any column of ``targets`` holds exactly two distinct values.

    Columns (indexed along dimension 1) are examined left to right, and the
    scan stops at the first column that matches. Only the *number* of distinct
    values is checked, not which values they are.

    NOTE(review): the name suggests a 0/1 (binary) check, but e.g. a column
    containing only {0.3, 0.7} also returns True -- confirm intent.
    """
    num_columns = targets.shape[1]
    return any(
        len(torch.unique(targets[:, col])) == 2 for col in range(num_columns)
    )