# Code modified from https://github.com/pytorchbearer/torchbearer

import torch
import torch.nn.functional as F
from torch.distributions.beta import Beta

import torchbearer
from torchbearer import Metric
from torchbearer.callbacks import Callback
from torchbearer import metrics as m
from torchbearer.metrics import default as d

@m.default_for_key('l1_loss')
@m.running_mean
@m.mean
class L1_loss(Metric):
    """Mean squared error metric. Computes the pixelwise squared error which is then averaged with decorators.
    Decorated with a mean and running_mean. Default for key: 'mse'.

    Args:
        pred_key (StateKey): The key in state which holds the predicted values
        target_key (StateKey): The key in state which holds the target values
    """

    def __init__(self, pred_key=torchbearer.Y_PRED, target_key=torchbearer.Y_TRUE):
        super().__init__('l1_loss')
        self.pred_key = pred_key
        self.target_key = target_key

    def process(self, *args):
        state = args[0]
        y_pred = state[self.pred_key]
        y_true = state[self.target_key]
        return F.l1_loss(y_pred, y_true).data


@m.running_mean
@m.mean
class MixupAcc_lam(m.AdvancedMetric):
    def __init__(self):
        m.super(MixupAcc_lam, self).__init__('mixup_acc')
        self.cat_acc = m.CategoricalAccuracy().root
        self.mse = L1_loss().root

    def process_train(self, *args):
        m.super(MixupAcc_lam, self).process_train(*args)
        state = args[0]

        pred1, pred2 = state[torchbearer.PREDICTION]
        target1 = state[torchbearer.Y_TRUE]
        target2 = target1[state[torchbearer.MIXUP_PERMUTATION]]
        _state = args[0].copy()
        _state[torchbearer.Y_TRUE] = target1
        _state[torchbearer.PREDICTION] = pred1
        acc1 = self.cat_acc.process(_state)

        _state = args[0].copy()
        _state[torchbearer.Y_TRUE] = target2
        _state[torchbearer.PREDICTION] = pred1
        acc2 = self.cat_acc.process(_state)

        _state = args[0].copy()
        _state[torchbearer.PREDICTION] = pred2
        _state[torchbearer.Y_TRUE] = state[torchbearer.MIXUP_LAMBDA] * torch.ones_like(pred2).to(pred2.device)
        mse = self.mse.process(_state)

        return (acc1 * state[torchbearer.MIXUP_LAMBDA] + acc2 * (1 - state[torchbearer.MIXUP_LAMBDA])) * 0.8 + 0.2 * mse

    def process_validate(self, *args):
        m.super(MixupAcc_lam, self).process_validate(*args)

        state = args[0]
        _state = args[0].copy()
        pred1, pred2 = state[torchbearer.PREDICTION]
        _state[torchbearer.PREDICTION] = pred1
        return self.cat_acc.process(_state)

    def reset(self, state):
        self.cat_acc.reset(state)
        self.mse.reset(state)

@m.running_mean
@m.mean
class MixupAcc(m.AdvancedMetric):
    def __init__(self):
        m.super(MixupAcc, self).__init__('mixup_acc')
        self.cat_acc = m.CategoricalAccuracy().root

    def process_train(self, *args):
        m.super(MixupAcc, self).process_train(*args)
        state = args[0]

        target1 = state[torchbearer.Y_TRUE]
        target2 = target1[state[torchbearer.MIXUP_PERMUTATION]]
        _state = args[0].copy()
        _state[torchbearer.Y_TRUE] = target1
        acc1 = self.cat_acc.process(_state)

        _state = args[0].copy()
        _state[torchbearer.Y_TRUE] = target2
        acc2 = self.cat_acc.process(_state)

        return acc1 * state[torchbearer.MIXUP_LAMBDA] + acc2 * (1 - state[torchbearer.MIXUP_LAMBDA])

    def process_validate(self, *args):
        m.super(MixupAcc, self).process_validate(*args)

        return self.cat_acc.process(*args)

    def reset(self, state):
        self.cat_acc.reset(state)

class RMixup(Callback):
    """Perform mixup on the model inputs. Requires use of :meth:`MixupInputs.loss`, otherwise lambdas can be found in
    state under :attr:`.MIXUP_LAMBDA`. Model targets will be a tuple containing the original target and permuted target.

    .. note::

        The accuracy metric for mixup is different on training to deal with the different targets,
        but for validation it is exactly the categorical accuracy, despite being called "val_mixup_acc"

    Example: ::

        >>> from torchbearer import Trial
        >>> from torchbearer.callbacks import Mixup

        # Example Trial which does Mixup regularisation
        >>> mixup = Mixup(0.9)
        >>> trial = Trial(None, criterion=Mixup.mixup_loss, callbacks=[mixup], metrics=['acc'])

    Args:
        lam (float): Mixup inputs by fraction lam. If RANDOM, choose lambda from Beta(alpha, alpha). Else, lambda=lam
        alpha (float): The alpha value to use in the beta distribution.
    """
    RANDOM = -10.0

    def __init__(self, other_loader=None, alpha=1.0, lam=RANDOM, reformulate=False, mixout=False, lam_train=False):
        super(RMixup, self).__init__()
        self.alpha = alpha
        self.lam = lam
        self.reformulate = reformulate
        self.lam_train = lam_train
        self.distrib = Beta(self.alpha, self.alpha) if not reformulate else Beta(self.alpha + 1, self.alpha)
        self.other_loader = other_loader
        self.iterator = None
        self.mixout=mixout

    @staticmethod
    def mixup_loss(state):
        """The standard cross entropy loss formulated for mixup (weighted combination of `F.cross_entropy`).

        Args:
            state: The current :class:`Trial` state.
        """
        input, target = state[torchbearer.Y_PRED], state[torchbearer.Y_TRUE]

        if state[torchbearer.DATA] is torchbearer.TRAIN_DATA:
            y1, y2 = target
            return F.cross_entropy(input, y1) * state[torchbearer.MIXUP_LAMBDA] + F.cross_entropy(input, y2) * (1-state[torchbearer.MIXUP_LAMBDA])
        else:
            return F.cross_entropy(input, target)

    def on_start_epoch(self, state):
        self.iterator = iter(self.other_loader)

    def on_sample(self, state):
        if self.lam is RMixup.RANDOM:
            if self.alpha > 0:
                lam = self.distrib.sample()
            else:
                lam = 1.0
        else:
            lam = self.lam

        state[torchbearer.MIXUP_LAMBDA] = lam
        state[torchbearer.MIXUP_PERMUTATION] = torch.randperm(state[torchbearer.X].size(0))

        if self.reformulate:
            x, _ = next(self.iterator)
            x = x[:state[torchbearer.X].shape[0]]
            if x.shape != state[torchbearer.X].shape:
                x = F.pad(x,(2,2,2,2))
            if self.mixout:
                x = torch.zeros_like(x)
            state[torchbearer.X] = state[torchbearer.X] * state[torchbearer.MIXUP_LAMBDA] + \
                                   x.to(state[torchbearer.X].device) \
                                   * (1 - state[torchbearer.MIXUP_LAMBDA])
            state[torchbearer.MIXUP_LAMBDA] = 1
        else:
            state[torchbearer.X] = state[torchbearer.X] * state[torchbearer.MIXUP_LAMBDA] + \
                                   state[torchbearer.X][state[torchbearer.MIXUP_PERMUTATION],:] \
                                   * (1 - state[torchbearer.MIXUP_LAMBDA])

        # if self.lam_train:
        #     d.__loss_map__[RMixup().mixup_loss().__name__] = MixupAcc_lam
        # else:
        #     d.__loss_map__[RMixup().mixup_loss().__name__] = MixupAcc
        if self.lam_train:
            d.__loss_map__[RMixup.mixup_loss.__name__] = MixupAcc_lam
        else:
            d.__loss_map__[RMixup.mixup_loss.__name__] = MixupAcc
