"""OCLFeedback Objects"""

import numpy as np
import torch.nn as nn

from datasets.pds_dataset import PDSSubDataset
from utils import evaluate_model

class OCLFeedback:
    """
    Prototype of an OCL feedback protocol.

    ``t`` is the index of the incoming data batch; ``t == 0`` denotes the
    first batch. The feedback class is also responsible for evaluating the
    model: the default ``__call__`` evaluates ``model`` on every batch in
    ``batches_eval`` and stores the result under the key
    ``'performance_{i}'``, where ``i`` is the position of that batch in
    ``batches_eval``. For t > 0, a feedback is expected to contain these
    ``performance_{i}`` entries. Performance is measured with
    ``eval_metric`` (e.g. acc, f1).

    Subclasses should document their protocol using the sections below.

    Config:
        Which attributes of ``config`` the protocol reads.

    Requires:
        Additional arguments required to call this feedback.

    Returns:
        Which entries are included in the returned feedback dict.
    """
    def __init__(self, config, eval_metric):
        # Both values are forwarded unchanged to ``evaluate_model``.
        self.config = config
        self.eval_metric = eval_metric

    def __call__(self, t: int, batch_train: PDSSubDataset, batches_eval: list[PDSSubDataset],
                 model: nn.Module, *args, **kwargs) -> dict:
        """Evaluate ``model`` on each evaluation batch.

        Args:
            t: Index of the current data batch (unused by the default).
            batch_train: Current training batch (unused by the default).
            batches_eval: Batches to evaluate ``model`` on.
            model: Model to evaluate.
            *args, **kwargs: Ignored by the default implementation.

        Returns:
            dict mapping ``'performance_{i}'`` to the first value returned
            by ``evaluate_model`` for ``batches_eval[i]``.
        """
        result = {}
        for i, batch_eval in enumerate(batches_eval):
            # evaluate_model returns a pair; only the first element is kept.
            performance, _ = evaluate_model(model, batch_eval, self.config, self.eval_metric)
            result[f'performance_{i}'] = performance
        return result


class RandomLabelFeedback(OCLFeedback):
    """
    RLF: Random Label Feedback.

    Reveals labels for a uniformly random fraction ``alpha`` of each new
    training batch; the rest of the batch is returned unlabeled.

    Config:
        config.alpha    - Fraction of data to be labeled, in [0, 1].

    Returns:
        batch_labeled   - PDSSubDataset or None. A labeled batch.
        batch_unlabeled - PDSSubDataset or None. An unlabeled batch.
        performance_{i} - (t > 0 only) Performance on batches_eval[i], as
                          computed by OCLFeedback.__call__.

    At t == 0 the whole training batch is returned as labeled,
    batch_unlabeled is None, and the model is not evaluated.
    """
    def __call__(self, t: int, batch_train: PDSSubDataset, batches_eval: list[PDSSubDataset],
                 model: nn.Module, *args, **kwargs) -> dict:
        """Split ``batch_train`` into labeled/unlabeled parts and evaluate ``model``.

        Raises:
            ValueError: if ``t > 0`` and ``config.alpha`` is not in [0, 1].
        """
        if t == 0:
            # First batch: fully labeled, no evaluation. Include both keys so
            # the returned dict always matches the documented contract.
            return {'batch_labeled': batch_train, 'batch_unlabeled': None}

        alpha = self.config.alpha
        # Validate before the (potentially costly) evaluation. Out-of-range
        # values would otherwise be silently mishandled: a negative alpha
        # makes idx[:m] drop elements from the end of the permutation.
        if not 0 <= alpha <= 1:
            raise ValueError(f"config.alpha must be in [0, 1], got {alpha!r}")

        fb = super().__call__(t, batch_train, batches_eval, model, *args, **kwargs)
        if alpha == 1:
            batch_labeled = batch_train
            batch_unlabeled = None
        elif alpha == 0:
            batch_labeled = None
            batch_unlabeled = batch_train.get_subset(labeled=False)
        else:
            n = len(batch_train)
            # int() truncates, so the labeled part is floor(alpha * n) samples.
            m = int(alpha * n)
            idx = np.random.permutation(n)
            batch_labeled = batch_train.get_subset(indices=idx[:m], labeled=True)
            batch_unlabeled = batch_train.get_subset(indices=idx[m:], labeled=False)

        fb['batch_labeled'] = batch_labeled
        fb['batch_unlabeled'] = batch_unlabeled
        return fb

