import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.linalg as linalg
import math

from globals import *
from agfunctions import *

class pseudolinear(nn.Module):
    """Fully connected layer followed by LeakyReLU(0.1), wired through ``pseudograd``.

    The weight ``W`` has shape ``(out_dim, in_dim)`` and is drawn from a normal
    distribution with std ``sqrt(2 / in_dim)``; the bias ``b`` starts at zero.
    ``B`` is the Moore-Penrose pseudoinverse of the *initial* ``W`` (shape
    ``(in_dim, out_dim)``) and is passed to ``pseudograd`` together with ``W``
    and ``b``. Presumably ``pseudograd`` uses ``B`` in the backward pass in
    place of the true weight transpose -- TODO confirm against agfunctions.

    Args:
        out_dim: number of output features.
        in_dim: number of input features.
    """

    def __init__(self, out_dim, in_dim):
        super().__init__()
        W = torch.empty((out_dim, in_dim), dtype=DTYPE, device=DEVICE).normal_(
            mean=0.0, std=math.sqrt(2 / in_dim)
        )
        b = torch.zeros(out_dim, dtype=DTYPE, device=DEVICE)
        # Computed from the plain (non-Parameter) tensor, so B carries no
        # autograd history; pinv does not modify W, so no defensive copy is needed.
        B = torch.linalg.pinv(W)
        self.W = nn.Parameter(W)
        # A registered buffer (rather than a plain attribute) follows the module
        # through .to()/.cuda()/.double(), so B always matches W's device/dtype.
        # persistent=False keeps state_dict keys unchanged so existing checkpoints
        # still load strictly. Consequence: B is NOT restored by load_state_dict;
        # it is the pinv of this instance's own initial W.
        self.register_buffer("B", B, persistent=False)
        self.b = nn.Parameter(b)
        self.grad = pseudograd.apply
        self.nonlinearity = nn.LeakyReLU(0.1)

    def forward(self, x):
        """Return ``LeakyReLU(0.1)`` applied to ``pseudograd(x, W, B, b)``.

        Assumes ``x`` has ``in_dim`` as its last dimension -- TODO confirm
        against pseudograd's forward.
        """
        h = self.grad(x, self.W, self.B, self.b)
        a = self.nonlinearity(h)
        return a

    def update_backwards(self):
        """No-op: ``B`` stays fixed at the pseudoinverse of the initial ``W``.

        Present so every layer class in this module exposes the same
        ``update_backwards`` hook. To make ``B`` track the current weights
        instead, recompute ``torch.linalg.pinv(self.W.detach())`` here.
        """
        return
    

class bplinear(nn.Module):
    """Fully connected layer followed by LeakyReLU(0.1), wired through ``bpgrad``.

    The weight ``W`` has shape ``(out_dim, in_dim)`` and is drawn from a normal
    distribution with std ``sqrt(2 / in_dim)``; the bias ``b`` starts at zero.
    Only ``x``, ``W`` and ``b`` are passed to ``bpgrad`` (presumably ordinary
    backprop -- TODO confirm against agfunctions).

    Args:
        out_dim: number of output features.
        in_dim: number of input features.
    """

    def __init__(self, out_dim, in_dim):
        super().__init__()
        W = torch.empty((out_dim, in_dim), dtype=DTYPE, device=DEVICE).normal_(
            mean=0.0, std=math.sqrt(2 / in_dim)
        )
        b = torch.zeros(out_dim, dtype=DTYPE, device=DEVICE)
        self.W = nn.Parameter(W)
        self.b = nn.Parameter(b)
        self.grad = bpgrad.apply
        self.nonlinearity = nn.LeakyReLU(0.1)

    @property
    def W_t(self):
        """Detached transpose of the *current* weight, shape ``(in_dim, out_dim)``.

        Derived on every access, so it always reflects ``W``'s present values,
        device and dtype; a view stored at construction time would stop tracking
        ``W`` once the module is moved/cast or ``W`` is reassigned. Shares
        storage with ``W`` and does not require grad. Not used by ``forward``.
        """
        return self.W.detach().transpose(0, 1)

    def forward(self, x):
        """Return ``LeakyReLU(0.1)`` applied to ``bpgrad(x, W, b)``.

        Assumes ``x`` has ``in_dim`` as its last dimension -- TODO confirm
        against bpgrad's forward.
        """
        h = self.grad(x, self.W, self.b)
        a = self.nonlinearity(h)
        return a

    def update_backwards(self):
        """No-op: nothing to refresh, ``W_t`` is always derived from ``W``.

        Present so every layer class in this module exposes the same
        ``update_backwards`` hook.
        """
        return


class randomlinear(nn.Module):
    """Fully connected layer followed by LeakyReLU(0.1), wired through ``randomgrad``.

    The weight ``W`` has shape ``(out_dim, in_dim)``; ``B`` is a separate random
    matrix of shape ``(in_dim, out_dim)``. Both are drawn from a normal
    distribution with std ``sqrt(2 / in_dim)``. The bias ``b`` starts at zero.
    ``B`` is passed to ``randomgrad`` together with ``W`` and ``b``. Presumably
    it replaces the weight transpose in the backward pass (feedback-alignment
    style) -- TODO confirm against agfunctions.

    Args:
        out_dim: number of output features.
        in_dim: number of input features.
    """

    def __init__(self, out_dim, in_dim):
        super().__init__()
        W = torch.empty((out_dim, in_dim), dtype=DTYPE, device=DEVICE).normal_(
            mean=0.0, std=math.sqrt(2 / in_dim)
        )
        # Same std formula as W; shape is W's transposed shape.
        B = torch.empty((in_dim, out_dim), dtype=DTYPE, device=DEVICE).normal_(
            mean=0.0, std=math.sqrt(2 / in_dim)
        )
        b = torch.zeros(out_dim, dtype=DTYPE, device=DEVICE)
        self.W = nn.Parameter(W)
        self.b = nn.Parameter(b)
        # B must stay fixed (see update_backwards). requires_grad=False keeps
        # autograd from ever populating B.grad, so optimizers built from
        # model.parameters() leave it untouched. It stays a Parameter so
        # state_dict keys and parameters() ordering are unchanged.
        self.B = nn.Parameter(B, requires_grad=False)
        self.grad = randomgrad.apply
        self.nonlinearity = nn.LeakyReLU(0.1)

    def forward(self, x):
        """Return ``LeakyReLU(0.1)`` applied to ``randomgrad(x, W, B, b)``.

        Assumes ``x`` has ``in_dim`` as its last dimension -- TODO confirm
        against randomgrad's forward.
        """
        h = self.grad(x, self.W, self.B, self.b)
        a = self.nonlinearity(h)
        return a

    def update_backwards(self):
        """No-op: the random backward matrix ``B`` is never updated."""
        return
    