import torch


def get_task_matrices(y, query_set_mask, task_mask, scatter_mask):
    """Split batched labels into one query-label matrix per task.

    Different types of tasks may be batched together, so the returned
    matrices can have different shapes.

    For every distinct value in ``task_mask`` the entries belonging to that
    task are selected. The row width ``n_classes`` is then derived as
    (number of entries in the task) // (number of distinct ``scatter_mask``
    values among those entries), and the task's entries whose
    ``query_set_mask`` equals 1 are reshaped to ``(-1, n_classes)``.

    NOTE(review): the indexing below suggests all four arguments are 1-D
    tensors of equal length produced by the batching function -- confirm
    against the caller.

    :param y: tensor of labels, indexed along its first dimension.
    :param query_set_mask: tensor; only entries equal to 1 are kept.
    :param task_mask: tensor of task ids, one per entry.
    :param scatter_mask: tensor of group ids; the number of distinct values
        within a task determines ``n_classes`` (presumably one id per
        example -- TODO confirm).
    :return: list with one tensor per distinct task id, in the order produced
        by ``torch.unique(task_mask)``; each tensor has ``n_classes`` columns.
    :raises AssertionError: if the number of selected labels of a task is not
        a multiple of that task's ``n_classes``.
    """
    output = []
    for task_id in torch.unique(task_mask):
        # Positions (along the first dimension) of the entries of this task.
        filt_idx = torch.where(task_mask == task_id)[0]

        # Entries per scatter group; integer floor division avoids the
        # float round trip of int(a / b).
        n_groups = len(torch.unique(scatter_mask[filt_idx]))
        n_classes = len(filt_idx) // n_groups

        # Indices are relative to the task-filtered view, hence the
        # two-step indexing y[filt_idx][qm].
        qm = torch.where(query_set_mask[filt_idx] == 1)[0]
        y_filt = y[filt_idx][qm]

        assert len(y_filt) % n_classes == 0, (
            f"task {task_id.item()}: {len(y_filt)} query labels cannot be "
            f"arranged in rows of {n_classes}"
        )
        output.append(y_filt.reshape(-1, n_classes))
    return output

