import torch
import torch.nn as nn

from models.ledn_model import LEDN
from utils.losses import mse_loss


class LEDN_Mock(nn.Module):
    """Training wrapper that pairs an ``LEDN`` network with its loss.

    ``forward`` runs the wrapped network and returns ``(preds, loss)`` so a
    training loop can back-propagate ``loss`` directly.

    Loss composition:
      * MSE between the network's main prediction and ``targets``;
      * plus, for every auxiliary ("layered") prediction the network returns,
        ``weight * L1 + weight * MSE`` against the same ``targets``.

    Args:
        config: Configuration object; ``config.model.use_dba``,
            ``config.model.use_idf`` and ``config.model.use_moe`` are
            forwarded to ``LEDN`` unchanged.
        layer_weights: Optional per-layer weights for the auxiliary losses,
            in the order the network returns its layered predictions.
            ``None`` (default) weights every layer with ``1.0`` regardless of
            how many layers the network produces.
    """

    def __init__(self, config, layer_weights=None):
        super().__init__()
        self.net = LEDN(img_channels=4, depth_channel=1, n_classes=3, bilinear=True,
                        use_dba=config.model.use_dba, use_idf=config.model.use_idf,
                        use_moe=config.model.use_moe)
        self.loss_fn1 = nn.L1Loss(reduction='mean')
        self.loss_fn2 = nn.MSELoss(reduction='mean')
        # Copied into an immutable tuple so later mutation of the caller's
        # sequence cannot silently change the training objective.
        self.layer_weights = (None if layer_weights is None
                              else tuple(float(w) for w in layer_weights))

    def _weights_for(self, num_layers):
        """Return one auxiliary-loss weight for each of ``num_layers`` outputs.

        Raises:
            ValueError: if explicit ``layer_weights`` has fewer entries than
                there are layered predictions.
        """
        if self.layer_weights is None:
            return (1.0,) * num_layers
        if len(self.layer_weights) < num_layers:
            raise ValueError(
                f"layer_weights has {len(self.layer_weights)} entries but the "
                f"network returned {num_layers} layered predictions")
        return self.layer_weights[:num_layers]

    def forward(self, inputs, targets, depths):
        """Run the network and compute the training loss.

        Args:
            inputs: First input to ``self.net`` (presumably the image; ``LEDN``
                is built with ``img_channels=4`` — confirm against caller).
            targets: Ground truth compared against every prediction.
            depths: Second input to ``self.net`` (``LEDN`` is built with
                ``depth_channel=1``).

        Returns:
            ``(preds, loss)``. ``preds`` is the LAST layered prediction when
            the network returns a non-empty layered list, otherwise its main
            prediction. ``loss`` is a scalar tensor (both criteria use
            ``reduction='mean'``).
        """
        preds, layered_preds_list = self.net(inputs, depths)
        loss = self.loss_fn2(preds, targets)
        # Truthiness (not ``is not None``) so an empty list falls back to the
        # main prediction instead of failing on ``[-1]`` below.
        if layered_preds_list:
            weights = self._weights_for(len(layered_preds_list))
            for weight, layered_preds in zip(weights, layered_preds_list):
                # Out-of-place accumulation: avoids mutating a tensor that is
                # already part of the autograd graph.
                loss = loss + self.loss_fn1(layered_preds, targets) * weight
                loss = loss + self.loss_fn2(layered_preds, targets) * weight
            # The deepest auxiliary output is what callers receive as the
            # prediction when layered outputs exist.
            preds = layered_preds_list[-1]
        return preds, loss
