import torch
import torch.nn as nn

from torch.autograd import Variable

from neuralfaults.impute_models.rits import RITS


class BRITS(nn.Module):
    """Bidirectional imputation model built from two ``RITS`` sub-models.

    ``rits_f`` runs on the inputs as given; ``rits_b`` runs on the inputs
    flipped along dim 2. The backward imputations are flipped back with
    ``reverse`` (dim 1 by default), averaged with the forward imputations,
    and a consistency penalty between the two directions is added to the
    sum of their losses.

    Args:
        input_dim: Forwarded to both ``RITS`` constructors.
        rnn_hid_size: Hidden size forwarded to both ``RITS`` constructors.
        consistency_weight: Multiplier applied to the mean absolute difference
            between forward and backward imputations. Defaults to ``1e-1``.
    """

    # Class-level fallback: instances that never ran this __init__ (e.g. whole
    # models pickled before the attribute existed) still resolve to the
    # original hard-coded weight.
    consistency_weight = 1e-1

    def __init__(self, input_dim, rnn_hid_size=64, consistency_weight=1e-1):
        super().__init__()

        self.rnn_hid_size = rnn_hid_size
        self.consistency_weight = consistency_weight

        self.rits_f = RITS(input_dim, self.rnn_hid_size)
        self.rits_b = RITS(input_dim, self.rnn_hid_size)

    def forward(self, inp, mask, delta):
        """Run both directions and merge their results.

        Args:
            inp, mask, delta: Passed unchanged to ``rits_f`` and flipped
                along dim 2 before being passed to ``rits_b``.

        Returns:
            ``(loss, imputations)`` as produced by ``merge_ret``.
        """
        loss_f, imputations_f = self.rits_f(inp, mask, delta)
        # NOTE(review): inputs are flipped along dim 2 here, but ``reverse``
        # un-flips the outputs along dim 1. This is only consistent if RITS
        # returns imputations with the flipped axis at dim 1 — confirm against
        # RITS. Also, if ``delta`` holds time-since-last-observation gaps,
        # merely flipping it does not give the backward-direction gaps; verify.
        loss_b, imputations_b = self.rits_b(inp.flip(2), mask.flip(2), delta.flip(2))
        imputations_b = self.reverse(imputations_b)

        return self.merge_ret(loss_f, imputations_f, loss_b, imputations_b)

    def merge_ret(self, loss_f, imputations_f, loss_b, imputations_b):
        """Combine forward and backward results.

        Returns:
            ``(loss, imputations)``: ``loss`` is ``loss_f + loss_b`` plus the
            consistency loss; ``imputations`` is the element-wise mean of the
            forward and backward imputations.
        """
        loss_c = self.get_consistency_loss(imputations_f, imputations_b)

        loss = loss_f + loss_b + loss_c

        imputations = (imputations_f + imputations_b) / 2

        return loss, imputations

    def get_consistency_loss(self, pred_f, pred_b):
        """Return ``mean(|pred_f - pred_b|) * consistency_weight``."""
        return torch.abs(pred_f - pred_b).mean() * self.consistency_weight

    def reverse(self, imputations, dim=1):
        """Return ``imputations`` reversed along ``dim`` (default 1).

        ``Tensor.flip`` keeps the result on the input's own device and remains
        differentiable, so CPU tensors work even when CUDA is available.
        """
        return imputations.flip(dim)
