import numpy as np
import time
import torch
import torch.nn as nn
from torch.autograd import Function

from ..tools.recorders import SplittingMethodStats


"""
------------------
Implicit Function
------------------
Provides forward/backward methods for the implicit graph fixed-point problem

    X = phi( W X A + B )

where B is the output of a bias function that depends on the model inputs U.

------------------
Reference
------------------
https://github.com/SwiftieH/IGNN

"""
class ImplicitFunction(nn.Module):
    """Solver for the implicit graph fixed-point equation ``X = phi(W X A + B)``.

    The forward pass iterates the map until the max-abs change between two
    successive iterates drops below ``tol`` (or the iteration budget runs
    out). The backward pass solves the adjoint fixed-point problem with the
    same iteration instead of differentiating through the forward loop.

    Args:
        record: When true, every solve logs spectral, timing and convergence
            diagnostics into ``self.stats`` and prints a short summary.
        tol: Stopping threshold on the infinity norm of the iterate change.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, record=True, tol=1e-6, **kwargs):
        super().__init__()
        self.record = record
        self.stats = SplittingMethodStats()
        self.tol = tol

    def forward(self, W, X_0, A, B, phi, fd_mitr=300, bw_mitr=300):
        """Return the fixed point of ``X = phi(W X A + B)`` with implicit gradients.

        Args:
            W: Weight matrix applied on the left of the state.
            X_0: Initial iterate, or ``None`` to start from ``B``.
            A: Matrix applied on the right of the state; it is multiplied via
                ``torch.spmm`` (presumably a sparse adjacency - confirm with
                the caller).
            B: Bias term added before the activation.
            phi: Activation callable; must also provide ``phi.derivative``.
            fd_mitr: Iteration cap for the forward solve.
            bw_mitr: Iteration cap for the backward (adjoint) solve.
        """
        return self.ImplicitFunc.apply(self, W, X_0, A, B, phi, fd_mitr, bw_mitr)

    def inn_pred(self, W, X, A, B, phi, mitr=300, trasposed_A=False, compute_dphi=False):
        """Run the fixed-point iteration ``X <- phi(W X A' + B)``.

        ``A'`` is ``A`` itself when ``trasposed_A`` is true and its transpose
        otherwise (the product is evaluated as ``spmm(A', (W X).T).T``).

        Args:
            W: Left weight matrix.
            X: Initial iterate.
            A: Right matrix, see above.
            B: Additive term inside the activation.
            phi: Callable applied to ``W X A' + B``. When ``compute_dphi`` is
                true it must expose ``derivative``.
            mitr: Maximum number of iterations. ``0`` performs no update and
                returns ``X`` unchanged.
            trasposed_A: Also selects which set of statistics (backward vs.
                forward) is updated when recording is enabled.
            compute_dphi: Whether to evaluate ``phi.derivative``.

        Returns:
            Tuple ``(X_new, err, None, dphi)``: the last computed iterate, the
            infinity norm of the last update (``inf`` if no iteration ran), a
            placeholder status that is always ``None``, and the derivative
            (``None`` unless ``compute_dphi``).
        """
        start = time.perf_counter()
        At = A if trasposed_A else torch.transpose(A, 0, 1)

        # Defaults cover mitr == 0, where the loop body never executes.
        X_new = X
        err = torch.tensor(float("inf"))
        last_iter = -1
        converged = False
        errs = []
        for last_iter in range(mitr):
            support = torch.spmm(At, (W @ X).T).T
            X_new = phi(support + B)
            err = torch.norm(X_new - X, np.inf)
            if self.record:
                # Per-iteration residual history consumed by the recorder.
                errs.append(err.item())
            if err < self.tol:
                converged = True
                break
            X = X_new

        dphi = None
        if compute_dphi:
            # Exact derivative instead of autograd. NOTE: on early exit X is
            # the iterate *before* X_new (within tol of it).
            dphi = phi.derivative(X)

        if self.record:
            # Measure before the diagnostics so eigvals cost is not counted.
            elapsed = time.perf_counter() - start
            self._record_stats(W, trasposed_A, elapsed, last_iter, errs, err, converged)

        return X_new, err, None, dphi

    def _record_stats(self, W, is_backward, elapsed, last_iter, errs, err, converged):
        """Store and print eigenvalue, timing and convergence diagnostics for one solve."""
        Weigs = torch.linalg.eigvals(W).abs()
        status = 'Converged' if converged else 'Not Converged'

        if is_backward:
            self.stats.bwd_lWmax += [torch.max(Weigs)]
            self.stats.bwd_lWmin += [torch.min(Weigs)]
            self.stats.bkwd_time.update(elapsed)
            # Zero-based index of the last iteration executed.
            self.stats.bkwd_iters.update(last_iter)
            self.stats.BERR.append(errs)
            print(f'Backward: lam_max: {self.stats.bwd_lWmax[-1].item()}\t lam_min: {self.stats.bwd_lWmin[-1].item()}')
            print("Backward:", last_iter, err.item(), status)
            return

        # Five largest eigenvalue magnitudes, largest first.
        top_eigs = [float(s.item()) for s in Weigs.sort(descending=True)[0][:5]]
        self.stats.fwd_lWmax += [torch.max(Weigs)]
        self.stats.fwd_lWmin += [torch.min(Weigs)]
        self.stats.fwd_time.update(elapsed)
        # Zero-based index of the last iteration executed.
        self.stats.fwd_iters.update(last_iter)
        self.stats.ERR.append(errs)
        print('Top 5 Eigs:', *top_eigs, sep=', ')
        print(f'Forward: lam_max: {self.stats.fwd_lWmax[-1].item()}\t lam_min: {self.stats.fwd_lWmin[-1].item()}')
        print("Forward:", last_iter, err.item(), status)

    class ImplicitFunc(Function):
        """Autograd function whose backward solves the adjoint fixed-point problem."""

        @staticmethod
        def forward(ctx, sp, W, X_0, A, B, phi, fd_mitr=300, bw_mitr=300):
            """Solve the forward fixed point and stash what backward needs."""
            ctx.splitter = sp
            # Plain int: no need to round-trip the cap through a saved tensor.
            ctx.bw_mitr = int(bw_mitr)
            X_0 = B if X_0 is None else X_0
            X, _, _, D = sp.inn_pred(W, X_0, A, B, phi, mitr=fd_mitr, compute_dphi=True)
            ctx.save_for_backward(W, X, A, D, X_0)
            return X

        @staticmethod
        def backward(ctx, *grad_outputs):
            """Return gradients for ``W`` and ``B`` via the adjoint iteration.

            With ``Z = W X A + B`` and ``D = phi'``, the gradient w.r.t. ``Z``
            solves ``G = D * (W^T G A^T + grad_x)``. ``A`` is treated as a
            constant: its gradient is zero when requested and ``None``
            otherwise.
            """
            sp = ctx.splitter
            W, X, A, D, X_0 = ctx.saved_tensors
            grad_x = grad_outputs[0]

            def dphi(G):
                # Linearised activation: elementwise scaling by phi'.
                return torch.mul(G, D)

            grad_z, _, _, _ = sp.inn_pred(
                W.T, X_0, A, grad_x, dphi, mitr=ctx.bw_mitr, trasposed_A=True
            )
            # dL/dW = G (X A)^T = G A^T X^T. The transpose on A matters for a
            # non-symmetric A; for a symmetric A it is a no-op.
            grad_W = grad_z @ torch.spmm(torch.transpose(A, 0, 1), X.T)
            grad_B = grad_z

            # Gradient-norm bookkeeping.
            sp.stats.dL += [grad_x.norm()]
            sp.stats.dG += [grad_z.norm()]
            sp.stats.dW += [grad_W.norm()]

            # Inputs: (sp, W, X_0, A, B, phi, fd_mitr, bw_mitr); A is index 3.
            grad_A = torch.zeros_like(A) if ctx.needs_input_grad[3] else None
            return None, grad_W, None, grad_A, grad_B, None, None, None