import torch
import torch.nn.functional as F
from ..models.initializer import initialize_model
from .ERM import ERM
from .single_model_algorithm import SingleModelAlgorithm
from ..scheduler import LinearScheduleWithWarmupAndThreshold
from ...wilds.common.utils import split_into_groups, numel
from ..configs.supported import process_pseudolabels_functions
import copy
from ..utils import load, move_to, detach_and_clone, collate_list, concat_input


class PseudoLabel(SingleModelAlgorithm):
    """
    PseudoLabel.
    This is a vanilla pseudolabeling algorithm which updates the model per batch and incorporates a confidence threshold.

    On every training batch the model is run on the labeled examples and, when an
    unlabeled batch is supplied, on the unlabeled examples as well. The configured
    ``process_pseudolabels_function`` turns the unlabeled outputs into pseudolabels
    using ``self.confidence_threshold``. The loss on the retained pseudolabeled
    examples is added to the labeled loss, weighted by a scheduled lambda.

    Original paper:
        @inproceedings{lee2013pseudo,
            title={Pseudo-label: The simple and efficient semi-supervised learning method for deep neural networks},
            author={Lee, Dong-Hyun and others},
            booktitle={Workshop on challenges in representation learning, ICML},
            volume={3},
            number={2},
            pages={896},
            year={2013}
            }
    """
    def __init__(self, config, d_out, grouper, loss, metric, n_train_steps):
        """
        Args:
            - config: experiment config. This class reads device, self_training_lambda,
              pseudolabel_T2, self_training_threshold and process_pseudolabels_function.
            - d_out (int): output dimension passed to initialize_model
            - grouper: maps metadata to group indices
            - loss: loss object exposing .compute(y_pred, y_true, return_dict=...)
            - metric: evaluation metric
            - n_train_steps (int): total number of training steps; used to place the
              lambda schedule's threshold step
        """
        model = initialize_model(config, d_out=d_out)
        model = model.to(config.device)
        # initialize module
        super().__init__(
            config=config,
            model=model,
            grouper=grouper,
            loss=loss,
            metric=metric,
            n_train_steps=n_train_steps,
        )
        # algorithm hyperparameters
        # Weight on the pseudolabel loss. It is stepped every batch, with no warmup
        # (last_warmup_step=0). Its threshold step sits at a fraction
        # (pseudolabel_T2) of total training. NOTE(review): the exact ramp shape is
        # defined by LinearScheduleWithWarmupAndThreshold; confirm there.
        self.lambda_scheduler = LinearScheduleWithWarmupAndThreshold(
            max_value=config.self_training_lambda,
            step_every_batch=True, # step per batch
            last_warmup_step=0,
            threshold_step=config.pseudolabel_T2 * n_train_steps
        )
        self.schedulers.append(self.lambda_scheduler)
        # None: this scheduler is not tied to any tracked metric.
        self.scheduler_metric_names.append(None)
        self.confidence_threshold = config.self_training_threshold
        # When no function is configured the attribute is left unset on purpose.
        # process_batch reports a clear error only if unlabeled data is actually used.
        if config.process_pseudolabels_function is not None:
            self.process_pseudolabels_function = process_pseudolabels_functions[config.process_pseudolabels_function]
        # Additional logging
        self.logged_fields.append("pseudolabels_kept_frac")
        self.logged_fields.append("classification_loss")
        self.logged_fields.append("consistency_loss")

    def process_batch(self, batch, unlabeled_batch=None):
        """
        Overrides single_model_algorithm.process_batch().
        Args:
            - batch (tuple of Tensors): a batch of data yielded by data loaders
            - unlabeled_batch (tuple of Tensors or None): a batch of data yielded by unlabeled data loader
        Output:
            - results (dictionary): information about the batch
                - y_true (Tensor): ground truth labels for batch
                - g (Tensor): groups for batch
                - metadata (Tensor): metadata for batch
                - y_pred (Tensor): model output for batch
                - unlabeled_g (Tensor): groups for unlabeled batch
                - unlabeled_metadata (Tensor): metadata for unlabeled batch
                - unlabeled_y_pseudo (Tensor): pseudolabels on the unlabeled batch, already thresholded
                - unlabeled_y_pred (Tensor): model output on the unlabeled batch, already thresholded
                - pseudolabels_kept_frac: fraction of unlabeled examples kept as
                  pseudolabels, as reported by process_pseudolabels_function (0 when no
                  unlabeled batch is given)
        Raises:
            - ValueError: if an unlabeled batch is given but no
              process_pseudolabels_function was configured
        """
        # Labeled examples
        x, y_true, metadata = batch
        n_lab = len(metadata)
        x = move_to(x, self.device)
        y_true = move_to(y_true, self.device)
        g = move_to(self.grouper.metadata_to_group(metadata), self.device)

        # package the results
        results = {
            'g': g,
            'y_true': y_true,
            'metadata': metadata
        }

        if unlabeled_batch is not None:
            # Fail with a clear message instead of an opaque AttributeError when the
            # config did not select a pseudolabel processing function.
            process_pseudolabels = getattr(self, 'process_pseudolabels_function', None)
            if process_pseudolabels is None:
                raise ValueError(
                    "PseudoLabel requires config.process_pseudolabels_function to be set "
                    "when training with unlabeled data."
                )

            x_unlab, metadata_unlab = unlabeled_batch
            x_unlab = move_to(x_unlab, self.device)
            g_unlab = move_to(self.grouper.metadata_to_group(metadata_unlab), self.device)
            results['unlabeled_metadata'] = metadata_unlab
            results['unlabeled_g'] = g_unlab

            # Special case for models where we need to pass in y:
            # we handle these in two separate forward passes
            # and turn off training to avoid errors when y is None
            # Note: we have to specifically turn training in the model off
            # instead of using self.train, which would reset the log
            if self.model.needs_y:
                self.model.train(mode=False)
                try:
                    # Pass 1 only produces pseudolabels and a keep-mask, which become
                    # targets in pass 2. Without autograd, no graph is kept alive and
                    # the targets carry no gradient.
                    with torch.no_grad():
                        unlabeled_output = self.get_model_output(x_unlab, None)
                        _, unlabeled_y_pseudo, pseudolabels_kept_frac, mask = process_pseudolabels(
                            unlabeled_output,
                            self.confidence_threshold
                        )
                finally:
                    # Restore training mode even if pseudolabeling fails.
                    self.model.train(mode=True)
                x_unlab = x_unlab[mask]

                # Pass 2: joint forward on labeled + retained unlabeled examples with targets.
                # concat_input is used to match the non-needs_y branch below.
                outputs = self.get_model_output(
                    concat_input(x, x_unlab),
                    collate_list([y_true, unlabeled_y_pseudo]),
                )
                unlabeled_y_pred = outputs[n_lab:]
            else:
                x_cat = concat_input(x, x_unlab)
                outputs = self.get_model_output(x_cat, None)
                unlabeled_output = outputs[n_lab:]
                unlabeled_y_pred, unlabeled_y_pseudo, pseudolabels_kept_frac, _ = process_pseudolabels(
                    unlabeled_output,
                    self.confidence_threshold
                )

            results['y_pred'] = outputs[:n_lab]
            results['unlabeled_y_pred'] = unlabeled_y_pred
            results['unlabeled_y_pseudo'] = detach_and_clone(unlabeled_y_pseudo)
        else:
            results['y_pred'] = self.get_model_output(x, y_true)
            pseudolabels_kept_frac = 0

        self.save_metric_for_logging(
            results, "pseudolabels_kept_frac", pseudolabels_kept_frac
        )
        return results

    def objective(self, results):
        """
        Compute the training objective for one batch.

        Returns ``classification_loss + lambda * consistency_loss``, where:
        - ``classification_loss`` is the loss on the labeled examples;
        - ``consistency_loss`` is the loss on the pseudolabeled examples, multiplied
          by ``results['pseudolabels_kept_frac']``, or 0 without an unlabeled batch;
        - ``lambda`` is the current value of ``self.lambda_scheduler``.

        Both component losses are also saved into ``results`` for logging.
        """
        # Labeled loss
        classification_loss = self.loss.compute(
            results['y_pred'],
            results['y_true'],
            return_dict=False)
        # Pseudolabeled loss
        if 'unlabeled_y_pseudo' in results:
            loss_output = self.loss.compute(
                results['unlabeled_y_pred'],
                results['unlabeled_y_pseudo'],
                return_dict=False,
            )
            # Scaling by the kept fraction presumably turns a mean over kept
            # examples into a per-unlabeled-example average. This assumes
            # loss.compute averages; TODO confirm.
            consistency_loss = loss_output * results['pseudolabels_kept_frac']
        else:
            consistency_loss = 0

        # Add to results for additional logging
        self.save_metric_for_logging(
            results, "classification_loss", classification_loss
        )
        self.save_metric_for_logging(
            results, "consistency_loss", consistency_loss
        )

        return classification_loss + self.lambda_scheduler.value * consistency_loss
