import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.linalg as linalg
import math

from globals import *
from agfunctions import *
from modules import *
from utils import *

class FCMNIST(nn.Module):
    """Fully connected MNIST network: 784 -> 400 -> 200 -> 100 -> 50 -> 10.

    Each layer is a project-provided linear module (chosen by ``grad_type``),
    followed by a LeakyReLU(0.01). The activation is also applied to the
    final layer's output, as in the original implementation.

    Args:
        grad_type: which linear-layer implementation to use:
            ``'backprop'`` -> ``bplinear``, ``'pseudo'`` -> ``pseudolinear``,
            ``'random'`` -> ``randomlinear``.

    Raises:
        ValueError: if ``grad_type`` is not one of the supported values.
    """

    # Feature widths from input to output; each consecutive pair is one layer.
    layer_sizes = (28 * 28, 400, 200, 100, 50, 10)
    _GRAD_TYPES = ('backprop', 'pseudo', 'random')

    def __init__(self, grad_type='backprop'):
        super().__init__()

        # Validate up front: previously an unknown value left self.linearity
        # unset and failed later with an unhelpful AttributeError.
        if grad_type not in self._GRAD_TYPES:
            raise ValueError(
                f"grad_type must be one of {self._GRAD_TYPES}, got {grad_type!r}"
            )

        self.linearity = {
            'backprop': bplinear,
            'pseudo': pseudolinear,
            'random': randomlinear,
        }[grad_type]

        # The layer constructors are called as linearity(out_features, in_features).
        # Registered before the nonlinearity to keep the original submodule order.
        self.layers = nn.ModuleList(
            self.linearity(n_out, n_in)
            for n_in, n_out in zip(self.layer_sizes, self.layer_sizes[1:])
        )
        self.nonlinearity = nn.LeakyReLU(0.01)

    def forward(self, x):
        """Apply every layer, each followed by the LeakyReLU activation.

        The activation output is fed into the next layer. The previous code
        discarded it, which collapsed the stack into a single linear map.
        """
        for layer in self.layers:
            x = self.nonlinearity(layer(x))
        return x

    def update_backwards(self):
        """No-op hook kept for API compatibility.

        Per-layer ``update_backwards`` calls are currently disabled.
        """
        return

class FCCIFAR(nn.Module):
    """Fully connected CIFAR network: 3072 -> 1000 -> 500 -> 100 -> 10.

    Each layer is a project-provided linear module (chosen by ``grad_type``),
    followed by a LeakyReLU(0.01). The activation is also applied to the
    final layer's output, as in the original implementation.

    Args:
        grad_type: which linear-layer implementation to use:
            ``'backprop'`` -> ``bplinear``, ``'pseudo'`` -> ``pseudolinear``,
            ``'random'`` -> ``randomlinear``.

    Raises:
        ValueError: if ``grad_type`` is not one of the supported values.
    """

    # Feature widths from input to output; each consecutive pair is one layer.
    layer_sizes = (32 * 32 * 3, 1000, 500, 100, 10)
    _GRAD_TYPES = ('backprop', 'pseudo', 'random')

    def __init__(self, grad_type='backprop'):
        super().__init__()

        # Validate up front: previously an unknown value left self.linearity
        # unset and failed later with an unhelpful AttributeError.
        if grad_type not in self._GRAD_TYPES:
            raise ValueError(
                f"grad_type must be one of {self._GRAD_TYPES}, got {grad_type!r}"
            )

        self.linearity = {
            'backprop': bplinear,
            'pseudo': pseudolinear,
            'random': randomlinear,
        }[grad_type]

        # The layer constructors are called as linearity(out_features, in_features).
        # Registered before the nonlinearity to keep the original submodule order.
        self.layers = nn.ModuleList(
            self.linearity(n_out, n_in)
            for n_in, n_out in zip(self.layer_sizes, self.layer_sizes[1:])
        )
        self.nonlinearity = nn.LeakyReLU(0.01)

    def forward(self, x):
        """Apply every layer, each followed by the LeakyReLU activation.

        The activation output is fed into the next layer. The previous code
        discarded it, which collapsed the stack into a single linear map.
        """
        for layer in self.layers:
            x = self.nonlinearity(layer(x))
        return x

    def update_backwards(self):
        """No-op hook kept for API compatibility.

        Per-layer ``update_backwards`` calls are currently disabled.
        """
        return
