"""
Pseudo Labeling (PL) and FixMatch
"""
import logging

import torch
from torch.nn.utils.convert_parameters import parameters_to_vector, vector_to_parameters

from algorithms.ocl import OCLAlgorithm
from algorithms.alg_utils import build_pseudo_labels, build_pseudo_dataset, ConcatDataset
from utils import key_is_none, attr_is_none



def PseudoLabeling(model, config, supervised_alg):
    """
    Pseudo Labeling (PL)
    Can be combined with any supervised OCL algorithm
      1. Virtual update: ERM on the labeled batch (as in NBO)
      2. Pseudo label the unlabeled batch with the current model
      3. Recover the model weights, and run supervised_alg on labeled + PLed batches

    model          - Model passed to supervised_alg's constructor
    config         - Configuration object (see below)
    supervised_alg - A supervised OCL algorithm class

    Returns an instance of a class that subclasses both supervised_alg and OCLAlgorithm.

    Config:
    Same as the config required for supervised_alg
    config.epochs           - Epochs of the virtual ERM update in step 1
    config.epochs_unlabeled - Value config.epochs is set to while supervised_alg trains
                              in step 3 (restored afterwards)
    config.fixmatch         - Set this to True to use FixMatch
                              (NOTE(review): not read in this function; presumably used by
                              build_pseudo_labels or supervised_alg — confirm)
    config.gamma            - Use only gamma of the unlabeled samples (with the highest
                              confidence); 1 when unset

    Feedback keys read by train():
    batch_unlabeled  - Unlabeled dataset; when absent (per key_is_none) supervised_alg.train
                       runs unchanged
    batch_labeled    - Optional labeled dataset
    transform_weak   - Optional transform set on batch_unlabeled before pseudo labeling
    transform_strong - Optional transform set on batch_unlabeled before fine-tuning
    """
    logger = logging.getLogger(__name__)

    class PL(supervised_alg, OCLAlgorithm):
        """supervised_alg extended with a pseudo-labeling step on unlabeled data."""

        def __init__(self, model, config):
            super().__init__(model, config)
            self.gamma = 1 if attr_is_none(config, 'gamma') else config.gamma

        def train(self, t: int, feedback: dict) -> None:
            """
            Run one training step at time t.

            Falls back to supervised_alg.train when there is no unlabeled batch.
            Otherwise supervised_alg.train receives a shallow copy of feedback whose
            'batch_labeled' is the labeled + pseudo-labeled dataset; the caller's dict
            itself is left unmodified.
            """
            if key_is_none(feedback, 'batch_unlabeled'):
                # No unlabeled data. Use supervised alg
                super().train(t, feedback)
                return

            # 1. Pre PL: Virtual update with ERM (like NBO).
            # Snapshot the parameters so the virtual update can be undone in step 3.
            # NOTE(review): only parameters are restored; buffers (e.g. BatchNorm running
            # stats) and any optimizer state touched by self.erm are not — confirm intended.
            w0 = parameters_to_vector(self.model.parameters()).detach()
            if not key_is_none(feedback, 'batch_labeled'):
                self.erm(self.config.epochs, feedback['batch_labeled'])

            # 2. PL: label the unlabeled batch with the virtually-updated model
            batch_unlabeled = feedback['batch_unlabeled']
            if not key_is_none(feedback, 'transform_weak'):
                batch_unlabeled.transform = feedback['transform_weak']
            pseudo_labels, confidence = build_pseudo_labels(batch_unlabeled, self.model, self.config)
            dataset_pl = build_pseudo_dataset(batch_unlabeled, pseudo_labels, confidence, self.gamma)

            # 3. Post PL: Recover model and then fine-tune
            if logger.isEnabledFor(logging.DEBUG):
                # Magnitude of the virtual update; skipped entirely unless debugging.
                w = parameters_to_vector(self.model.parameters()).detach()
                logger.debug('check: %s', torch.linalg.norm(w - w0).item())
            vector_to_parameters(w0, self.model.parameters())

            if not key_is_none(feedback, 'transform_strong'):
                # dataset_pl was built from batch_unlabeled; presumably it reads this
                # transform lazily so fine-tuning sees strong augmentations — TODO confirm
                batch_unlabeled.transform = feedback['transform_strong']
            if not key_is_none(feedback, 'batch_labeled'):
                dataset_pl = ConcatDataset([feedback['batch_labeled'], dataset_pl])

            # Give supervised_alg a copy so the caller's feedback keeps its original batch.
            feedback_pl = dict(feedback)
            feedback_pl['batch_labeled'] = dataset_pl  # Change to labeled + unlabeled

            # Temporarily switch epochs; restore it even if training raises so the
            # shared config is never left in the unlabeled setting.
            epochs = self.config.epochs
            self.config.epochs = self.config.epochs_unlabeled
            try:
                super().train(t, feedback_pl)
            finally:
                self.config.epochs = epochs

    return PL(model, config)

