from abc import ABC, abstractmethod

import torch


class Processor(ABC):
    """
    Abstract base for objects that transform one input into one output.

    Intended for pre-processing model inputs and post-processing model
    outputs. Concrete subclasses must implement :meth:`process`.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def process(self, inp):
        """
        Transform ``inp`` and return the result.

        Parameters
        ----------
        inp : any python object
            input to be transformed

        Returns
        -------
        any python object
            transformed input
        """


class SequentialProcessor(Processor):
    """
    Pipeline of processors applied one after another.

    Parameters
    ----------
    processors : list[Processor]
        processors to run, in order; the output of each one becomes the
        input of the next. The output of the last one is returned, or the
        input itself when the list is empty.
    """

    def __init__(self, processors):
        super().__init__()
        self.processors = processors

    def process(self, inp):
        value = inp
        for stage in self.processors:
            value = stage.process(value)
        return value


class LambdaProcessor(Processor):
    """
    Apply an arbitrary callable to the input.

    Useful, for example, when a non-linearity should be applied to raw
    model outputs.

    Parameters
    ----------
    fun : callable
        called as ``fun(inp)``; its return value is the processed output

    Examples
    --------
    non_lin_processor = LambdaProcessor(fun=torch.sigmoid)
    non_lin_processor = LambdaProcessor(fun=lambda x: torch.softmax(x, dim=1))
    """

    def __init__(self, fun):
        super().__init__()
        self.fun = fun

    def process(self, inp):
        return self.fun(inp)


class IdentityProcessor(Processor):
    """
    No-op processor: returns the very same object it receives.
    """

    def process(self, inp):
        result = inp
        return result


class ToNumpyProcessor(Processor):
    """
    Convert a torch tensor to a numpy array.

    The tensor is detached from the autograd graph and moved to the CPU
    before conversion, because ``Tensor.numpy()`` refuses tensors that
    require grad (typical for model outputs computed outside
    ``torch.no_grad()``) as well as tensors that live on another device.
    """

    def process(self, inp):
        return inp.detach().cpu().numpy()


class ArgmaxProcessor(Processor):
    """
    Replace the input by the indices of its maxima along one axis.

    Parameters
    ----------
    axis : int, default 1
        axis forwarded as ``inp.argmax(axis=axis)``
    """

    def __init__(self, axis=1):
        super().__init__()
        self.axis = axis

    def process(self, inp):
        reduce_axis = self.axis
        indices = inp.argmax(axis=reduce_axis)
        return indices


class AddLRFlipProcessor(Processor):
    """
    Append a mirrored copy of the input batch (flip test-time augmentation).

    The tensor stored under ``key`` is flipped along ``dim``, and the flipped
    copy is concatenated after the original along dim 0. A batch of size N
    therefore becomes 2N items: the N originals followed by their N flipped
    counterparts, in the same order.

    Parameters
    ----------
    dim : int, default 3
        axis to mirror. The default is the last axis of a 4-D tensor;
        presumably (N, C, H, W), i.e. the width axis — confirm for your data.
    key : hashable, default "X"
        entry of the input mapping that holds the batch tensor

    Notes
    -----
    The input mapping is modified in place (``inp[key]`` is replaced) and is
    also returned.
    """

    def __init__(self, dim=3, key="X"):
        super().__init__()
        self.dim = dim
        self.key = key

    def process(self, inp):
        batch = inp[self.key]
        flipped = torch.flip(batch, [self.dim])
        # originals first, flipped copies second, stacked along the batch axis
        inp[self.key] = torch.cat((batch, flipped), 0)
        return inp


class AverageLRFlipPostProcessor(Processor):
    """
    Average predictions for a batch and its mirrored copy.

    Expects the first half of the batch (dim 0) to hold predictions for the
    original samples and the second half to hold predictions for the same
    samples mirrored along ``dim``, in the same order. The second half is
    flipped back along ``dim`` and averaged element-wise with the first half,
    so the result has half the batch size of the input.

    Parameters
    ----------
    dim : int, default 3
        axis along which the second half of the batch was mirrored

    Raises
    ------
    ValueError
        if the batch size is odd, since it cannot be split into an original
        half and a flipped half.
    """

    def __init__(self, dim=3):
        super().__init__()
        self.dim = dim

    def process(self, inp):
        n_total = inp.shape[0]
        # Odd sizes must be rejected explicitly: for batch sizes 1 and 3 the
        # unequal halves (0 vs 1, 1 vs 2) broadcast silently and produce a
        # wrongly-shaped result instead of an error.
        if n_total % 2:
            raise ValueError(
                "expected an even batch size (originals followed by flipped "
                f"copies), got {n_total}"
            )
        n_half = n_total // 2

        inp_orig = inp[:n_half]
        # undo the mirroring so both halves are spatially aligned again
        inp_lr = torch.flip(inp[n_half:], [self.dim])

        # average predictions
        return (inp_lr + inp_orig) / 2
