from abc import ABC, abstractmethod
from typing import List, Dict, Any

import torch

class ResetArgsMixin:
    """Helpers that temporarily override settings on the host object and put them back.

    Every ``_reset_*`` method stashes the current value(s) in ``last_*``
    attributes before overriding them; the matching ``_restore_*`` method
    writes the stashed value(s) back. The host class must supply whichever of
    ``config``, ``model``, ``unwrap_model_func``, ``parallel_output``,
    ``post_process`` and ``output_layer`` the methods it calls read.
    """

    def _reset_activation_checkpointing_args(self):
        """Stash the three ``recompute_*`` config fields on ``self``, then set each to ``None``."""
        fields = ("recompute_granularity", "recompute_method", "recompute_num_layers")
        # Snapshot every field before clearing any of them.
        for field in fields:
            setattr(self, "last_" + field, getattr(self.config, field))
        for field in fields:
            setattr(self.config, field, None)

    def _restore_activation_checkpointing_args(self):
        """Write the stashed ``recompute_*`` values back onto ``self.config``."""
        for field in ("recompute_granularity", "recompute_method", "recompute_num_layers"):
            setattr(self.config, field, getattr(self, "last_" + field))

    def _reset_sequence_parallelism_args(self):
        """Set ``sequence_parallel`` to ``False`` on the config and on every module that has it.

        The previous values are kept in ``last_sequence_parallel`` (on ``self``
        for the config value, on each module for its own value). Modules are
        only visited when ``self.model`` is not ``None``.
        """
        self.last_sequence_parallel = self.config.sequence_parallel
        self.config.sequence_parallel = False
        if self.model is not None:
            for unwrapped in self.unwrap_model_func(self.model):
                for module in unwrapped.modules():
                    if not hasattr(module, "sequence_parallel"):
                        continue
                    # A leftover stash means an earlier reset was never restored.
                    assert not hasattr(module, 'last_sequence_parallel')
                    module.last_sequence_parallel = module.sequence_parallel
                    module.sequence_parallel = False

    def _restore_sequence_parallelism_args(self):
        """Undo ``_reset_sequence_parallelism_args`` on the config and on each module.

        The per-module ``last_sequence_parallel`` stash is removed once it has
        been written back.
        """
        self.config.sequence_parallel = self.last_sequence_parallel
        if self.model is not None:
            for unwrapped in self.unwrap_model_func(self.model):
                for module in unwrapped.modules():
                    if not hasattr(module, "sequence_parallel"):
                        continue
                    module.sequence_parallel = module.last_sequence_parallel
                    del module.last_sequence_parallel

    def _reset_context_parallelism_args(self):
        """Do nothing; placeholder counterpart to the other ``_reset_*`` helpers."""

    def _restore_context_parallelism_args(self):
        """Do nothing; placeholder counterpart to the other ``_restore_*`` helpers."""

    def _reset_parallel_output(self):
        """Stash ``parallel_output``, set it to ``False`` and, if ``post_process``, gather output."""
        self.last_parallel_output = self.parallel_output
        self.parallel_output = False
        if not self.post_process:
            return
        self.output_layer.gather_output = True

    def _restore_parallel_output(self):
        """Put ``parallel_output`` back; if ``post_process``, set ``gather_output`` to its negation."""
        previous = self.last_parallel_output
        self.parallel_output = previous
        if not self.post_process:
            return
        # gather_output is derived from the stashed flag, not stashed itself.
        self.output_layer.gather_output = not previous


class SupervisedInterface(ABC):
    """Abstract interface for a model that exposes per-step training hooks and a loss function.

    Subclasses must implement ``prepare_for_training_step``,
    ``finish_training_step`` and ``get_loss_and_metrics``. The validation-step
    hooks are optional and raise ``NotImplementedError`` unless overridden.
    """

    @abstractmethod
    def prepare_for_training_step(self):
        """Hook to implement in subclasses (by its name, run before a training step — confirm with callers)."""

    @abstractmethod
    def finish_training_step(self):
        """Hook to implement in subclasses (by its name, run after a training step — confirm with callers)."""

    def prepare_for_validation_step(self):
        """Optional hook; raises ``NotImplementedError`` unless a subclass overrides it."""
        raise NotImplementedError('not implemented yet')

    def finish_validation_step(self):
        """Optional hook; raises ``NotImplementedError`` unless a subclass overrides it."""
        raise NotImplementedError('not implemented yet')

    @abstractmethod
    def get_loss_and_metrics(self, batch, forward_only):
        """Compute the loss and metrics for one batch.

        ``batch`` holds micro_batch_size * (number of microbatches) inputs.
        When ``forward_only`` is False, the implementation is expected to call
        ``loss.backward`` and populate the gradients.

        NOTE: the returned metrics must live on the CPU and be replicated
        across all ranks.
        """


class Inferrable(ABC):
    """Abstract interface for models that can be run in inference mode.

    What "inference" produces depends on the model: a language model should
    run generation, while a reward model or critic should return its
    numerical values.
    """

    @abstractmethod
    def prepare_for_inference(self):
        """Set up whatever the model needs before inference runs."""

    @abstractmethod
    def finish_inference(self):
        """Restore whatever was changed for inference once it is done."""

    @abstractmethod
    def infer(self, *args, **kwargs):
        """Run inference and produce the model's outputs.

        For a reward model this means getting the rewards out; see the class
        docstring for what other model types are expected to return.
        """


class CriticModelInterface(SupervisedInterface, Inferrable):
    """Interface that combines ``SupervisedInterface`` and ``Inferrable``.

    The three hooks defined here are not abstract: each has a default body
    that does nothing and returns ``None``, so subclasses override only the
    ones they need.
    """

    def prepare_for_training(self):
        """Default no-op hook; returns ``None``."""

    def finish_training(self):
        """Default no-op hook; returns ``None``."""

    def infer_rm_critic(self, *args, **kwargs):
        """Default no-op; accepts any arguments and returns ``None``."""


class AlignableGenerativeInterface(SupervisedInterface, Inferrable):
    """Interface that combines ``SupervisedInterface`` and ``Inferrable``.

    Unlike the per-step hooks inherited from the bases, ``prepare_for_training``
    and ``finish_training`` are declared abstract here, so concrete subclasses
    must implement both.
    """

    @abstractmethod
    def prepare_for_training(self):
        """Hook to implement in subclasses (by its name, run before training — confirm with callers)."""

    @abstractmethod
    def finish_training(self):
        """Hook to implement in subclasses (by its name, run after training — confirm with callers)."""

    def get_ref_policy_logprobs(
        self,
        rollout_batches: List[Dict[str, List[Any]]],
    ) -> List[List[torch.Tensor]]:
        """Default implementation: does nothing and returns ``None``.

        NOTE(review): going by the name and the annotations, subclasses
        presumably override this to return reference-policy log-probabilities
        for each rollout batch — confirm against the implementations.
        """
