import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.function import once_differentiable
import torch.linalg as linalg

from globals import *
from rankone_update import *
from utils import *
    
    
class pseudograd(torch.autograd.Function):
    """Linear layer ``x @ W.T + b`` whose backward sends the error through pinv(W).

    Forward is an ordinary affine map. Backward replaces the exact input
    gradient ``d_out @ W`` with ``d_out @ pinv(W).T``. ``W`` receives the usual
    ``d_out.T @ x`` gradient. ``B`` and ``b`` receive no gradient.

    Note: the ``B`` argument is accepted for interface compatibility, but the
    backward pass never reads it; it always recomputes ``pinv(W)``.

    NOTE(review): ``dW = d_out.T @ x`` only matches ``W``'s shape when ``x`` is
    (batch, in_features) and ``W`` is (out_features, in_features). Confirm
    that callers never pass higher-rank ``x``.
    """

    @staticmethod
    def forward(ctx, x, W, B, b):
        """Return ``x @ W.T + b``; raise ``Exception`` if the result has NaNs."""
        # Only x and W are needed by backward; B is ignored there.
        ctx.save_for_backward(x, W)
        a = x @ W.transpose(0, 1) + b
        if torch.isnan(a).any():
            raise Exception("a : {}".format(a))
        return a

    @staticmethod
    def backward(ctx, d_out):
        """Return ``(dx, dW, None, None)``, skipping gradients nobody needs."""
        x, W = ctx.saved_tensors
        dx = dW = None
        if ctx.needs_input_grad[0]:
            # pinv is SVD-based and costly, so compute it only when x needs a gradient.
            # pinv(W) is (in, out), so its transpose maps d_out back to x's space.
            dx = d_out @ torch.linalg.pinv(W).transpose(0, 1)
        if ctx.needs_input_grad[1]:
            dW = d_out.transpose(0, 1) @ x
        # B and b intentionally get no gradient (unchanged behavior).
        return dx, dW, None, None
    

class bpgrad(torch.autograd.Function):
    """Linear layer ``x @ W.T + b`` with an explicit, exact backprop backward.

    Backward returns ``d_out @ W`` for ``x`` and ``d_out.T @ x`` for ``W``.
    The bias ``b`` receives no gradient (``None``), so it is not updated by
    this function's backward.

    Both passes raise a plain ``Exception`` when a NaN is detected:
    - forward checks the output;
    - backward checks ``d_out``, ``x`` and ``dW``, in that order.

    NOTE(review): ``d_out.T @ x`` only matches ``W``'s shape when ``x`` is
    (batch, in_features). Confirm that callers never pass higher-rank ``x``.
    """

    @staticmethod
    def forward(ctx, x, W, b):
        """Return ``x @ W.T + b``; raise ``Exception`` if the result has NaNs."""
        ctx.save_for_backward(x, W)
        output = x @ W.transpose(0, 1) + b
        if torch.isnan(output).any():
            raise Exception("a : {}".format(output))
        return output

    @staticmethod
    def backward(ctx, d_out):
        """Return ``(dx, dW, None)``; raise ``Exception`` on NaN inputs or dW."""
        saved_x, saved_W = ctx.saved_tensors
        grad_x = d_out @ saved_W
        grad_W = d_out.transpose(0, 1) @ saved_x
        # Same checks, same order, same messages as a sequence of ifs.
        checks = (
            ("d_out : {}", d_out),
            ("x : {}", saved_x),
            ("dW : {}", grad_W),
        )
        for template, tensor in checks:
            if torch.isnan(tensor).any():
                raise Exception(template.format(tensor))
        return grad_x, grad_W, None


class randomgrad(torch.autograd.Function):
    """Linear layer ``x @ W.T + b`` whose backward uses a supplied matrix ``B``.

    Backward computes ``d_out @ B.T`` for ``x`` instead of ``d_out @ W``.
    For ``dx`` to match ``x``, ``B`` must be shaped (in_features, out_features).
    NOTE(review): presumably ``B`` is a fixed random feedback matrix (feedback
    alignment); confirm with callers.

    ``W`` gets ``d_out.T @ x``. ``b`` gets ``d_out``, which autograd reduces to
    ``b``'s shape. ``B`` gets no gradient. Gradients are only computed for
    inputs that need them, so a non-tensor ``b`` (e.g. a Python float) works.

    NOTE(review): assumes ``x`` is (batch, in_features); the ``dW`` formula is
    2-D only.
    """

    @staticmethod
    def forward(ctx, x, W, B, b):
        """Return ``x @ W.T + b``."""
        # W is not needed by backward; B stands in for it.
        ctx.save_for_backward(x, B)
        return x @ W.transpose(0, 1) + b

    @staticmethod
    def backward(ctx, d_out):
        """Return ``(dx, dW, None, db)`` for the inputs that require grad."""
        x, B = ctx.saved_tensors
        dx = dW = db = None
        if ctx.needs_input_grad[0]:
            dx = d_out @ B.transpose(0, 1)
        if ctx.needs_input_grad[1]:
            # Batch-summed outer products d_out[i] x[i]^T, done as one matmul
            # instead of materialising a (batch, out, in) intermediate.
            dW = d_out.transpose(0, 1) @ x
        if ctx.needs_input_grad[3]:
            # Returned unreduced; autograd sums it down to b's broadcast shape.
            db = d_out
        return dx, dW, None, db
