from functools import reduce
import torch
import numpy as np
from torch.autograd.functional import jacobian
import torch.nn as nn

from .tools.regularization import jac_loss_estimate, jac_loss_exact

class MonotoneImplicitGraph(nn.Module):
    """Implicit graph layer whose output ``z`` is produced by ``solver``.

    ``z`` is presumably the fixed point of
    ``z = nonlin_module(lin_module.multiply(z) + lin_module.bias(x))``
    (TODO confirm against the solver implementation). The linear and
    nonlinear parts are kept on the module so that the layer map can be
    re-evaluated at ``z`` to estimate a Jacobian regularization term.

    Args:
        lin_module: Linear part of the layer. ``forward`` calls its
            ``multiply(z)`` and ``bias(x)`` methods.
        nonlin_module: Callable applied to ``multiply(z) + bias(x)``.
        solver: Callable ``solver(x, **kwargs) -> z``. It must expose a
            ``frd_itr`` attribute after each call.
        **kwargs: Accepted for interface compatibility; currently unused.
    """

    def __init__(self, lin_module, nonlin_module, solver, **kwargs):
        super().__init__()
        self.lin_module = lin_module
        self.nonlin_module = nonlin_module
        self.solver = solver

    def forward(self, x, *args, compute_jac_loss=False, **kwargs):
        """Solve for ``z`` and optionally estimate a Jacobian loss.

        Args:
            x: Layer input, passed to ``solver`` and to ``lin_module.bias``.
            *args: Accepted for interface compatibility; ignored.
            compute_jac_loss: If true, re-apply the layer map at ``z`` and
                return ``jac_loss_estimate(f, z)`` as the second output.
            **kwargs: Forwarded unchanged to ``solver``.

        Returns:
            Tuple ``(z, jac_loss)``. ``jac_loss`` is ``None`` unless
            ``compute_jac_loss`` is true.
        """
        z = self.solver(x, **kwargs)
        # Copy the solver's ``frd_itr`` (presumably its forward-iteration
        # count) onto the module so callers can inspect it after forward().
        self.frd_itr = self.solver.frd_itr
        if not compute_jac_loss:
            return z, None

        # ``z`` must track gradients so the estimate can differentiate
        # the re-evaluated layer map ``f`` with respect to it.
        z.requires_grad_(True)
        f = self.nonlin_module(
            self.lin_module.multiply(z) + self.lin_module.bias(x)
        )
        return z, jac_loss_estimate(f, z)