import torch
import numpy as np 
import numpy.linalg as LA
import torch.nn as nn
from torch.nn import Linear, Conv2d, SELU
from torch import sigmoid
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F

class DictNet(torch.nn.Module):
    """Compose a fixed matrix ``A`` with a generator ``G``: ``forward(z) = A @ G(z)``.

    Args:
        A: matrix of shape ``(m, n)`` applied to every generated sample.
        G: callable mapping a batch ``z`` (``P`` rows) to outputs that can be
            reshaped to ``(P, n)``.
    """

    def __init__(self, A, G):
        super(DictNet, self).__init__()
        self.A = A
        self.G = G

    def forward(self, z):
        """Return ``A @ G(z)`` for each row of ``z`` as a ``(P, m, 1)`` tensor.

        ``P`` is ``z.shape[0]`` (``z`` is described as P by k).
        """
        A = self.A
        P = z.shape[0]
        n = A.shape[1]
        Gzs = self.G(z)
        # Turn each generated sample into an (n, 1) column so that matmul
        # broadcasts A over the batch: (m, n) @ (P, n, 1) -> (P, m, 1).
        # reshape (rather than view) also accepts non-contiguous G outputs.
        return torch.matmul(A, Gzs.reshape(P, n, 1))

class DictNet_withinput(torch.nn.Module):
    """Like ``DictNet``, but owns the latent batch ``z`` as a learnable parameter.

    ``forward()`` takes no input and returns ``A @ G(self.z)``.
    """

    def __init__(self, A, G, P, k=None, device=None):
        """
        Args:
            A: matrix of shape ``(m, n)``.
            G: callable mapping the ``(P, k)`` latent batch to outputs that can
                be reshaped to ``(P, n)``.
            P: number of latent vectors.
            k: latent dimension. If omitted, a module-level global ``k`` is
                used (the original code relied on that global implicitly).
            device: device for ``z``. If omitted, a module-level global
                ``device`` is used when present, otherwise the CPU.

        Raises:
            TypeError: if ``k`` is not given and no module-level ``k`` exists.
        """
        super(DictNet_withinput, self).__init__()
        # Backward compatibility: earlier versions read the free names ``k`` and
        # ``device`` from module globals (raising NameError when absent).
        module_globals = globals()
        if k is None:
            k = module_globals.get('k')
            if k is None:
                raise TypeError("DictNet_withinput requires the latent dimension 'k'")
        if device is None:
            device = module_globals.get('device', torch.device('cpu'))
        self.A = A
        self.G = G
        self.z = nn.Parameter(torch.zeros((P, k), device=device), requires_grad=True)
        self.P = P

    def forward(self):
        """Return ``A @ G(self.z)`` as a ``(P, m, 1)`` tensor."""
        n = self.A.shape[1]
        Gzs = self.G(self.z)
        # (m, n) @ (P, n, 1) broadcasts A over the batch.
        return torch.matmul(self.A, Gzs.reshape(self.P, n, 1))

class StackedNet(torch.nn.Module):
    """Generator ``G`` followed by a learnable bias-free linear map ``A: R^n -> R^m``.

    Used for autodiff-ing through both at once.

    Args:
        G: callable mapping a batch ``z`` (``P`` rows) to outputs that can be
            reshaped to ``(P, n)``.
        m: output dimension of ``A``.
        n: output dimension of ``G`` / input dimension of ``A``.
    """

    def __init__(self, G, m, n):
        super(StackedNet, self).__init__()
        self.A = Linear(n, m, bias=False)
        self.G = G
        self.m = m
        self.n = n

    def forward(self, z):
        """Return ``A(G(z))`` reshaped to ``(P, m, 1)`` with ``P = z.shape[0]``."""
        P = z.shape[0]
        Gzs = self.G(z)
        # Use the stored width; the previous code referenced an undefined global ``n``.
        return self.A(Gzs.reshape(P, self.n)).view(P, -1, 1)

class GenericStackedNet(torch.nn.Module):
    """Chain two callables so that ``forward(z)`` equals ``A(G(z))``.

    Lets autograd differentiate through the generator and the measurement
    operator in a single pass.
    """

    def __init__(self, A, G):
        super(GenericStackedNet, self).__init__()
        self.A, self.G = A, G

    def forward(self, z):
        # z: P by k
        generated = self.G(z)
        measured = self.A(generated)
        return measured

class A_matrix(torch.nn.Module):
    """Bias-free linear map ``R^n -> R^m`` with an optional output nonlinearity.

    Args:
        m: output dimension.
        n: input dimension.
        nonlinearity: ``None`` (purely linear), ``'SELU'`` or ``'sigmoid'``.
    """

    def __init__(self, m, n, nonlinearity=None):
        super(A_matrix, self).__init__()
        self.A = Linear(n, m, bias=False)
        self.m = m
        self.n = n
        self.nonlinearity = nonlinearity

    def forward(self, zs):
        """Apply the linear map (then the nonlinearity, if any) to ``zs``.

        Raises:
            ValueError: if ``self.nonlinearity`` is not a supported value.
        """
        out = self.A(zs)
        # Strings are compared with ==, not `is`: identity of equal strings is
        # an interpreter detail and fails for strings built at runtime.
        if self.nonlinearity is None:
            return out
        if self.nonlinearity == 'SELU':
            return F.selu(out)
        if self.nonlinearity == 'sigmoid':
            # torch.sigmoid is a function; the old `sigmoid()(x)` raised TypeError.
            return torch.sigmoid(out)
        raise ValueError(f"unsupported nonlinearity: {self.nonlinearity!r}")

class A_lowrank_matrix(torch.nn.Module):
    """Bias-free linear map ``R^n -> R^r -> R^m`` (rank at most ``r``) with an optional nonlinearity.

    Args:
        m: output dimension.
        n: input dimension.
        r: inner dimension of the factorisation ``A2 @ A1``.
        nonlinearity: ``None`` (purely linear), ``'SELU'`` or ``'sigmoid'``.
    """

    def __init__(self, m, n, r, nonlinearity=None):
        super(A_lowrank_matrix, self).__init__()
        self.A1 = Linear(n, r, bias=False)
        self.A2 = Linear(r, m, bias=False)
        self.m = m
        self.n = n
        self.r = r
        self.nonlinearity = nonlinearity

    def forward(self, zs):
        """Apply ``A2(A1(zs))`` (then the nonlinearity, if any).

        Raises:
            ValueError: if ``self.nonlinearity`` is not a supported value.
        """
        out = self.A2(self.A1(zs))
        # Compare with ==, not `is` (string identity is not guaranteed).
        if self.nonlinearity is None:
            return out
        if self.nonlinearity == 'SELU':
            return F.selu(out)
        if self.nonlinearity == 'sigmoid':
            # torch.sigmoid is a function; the old `sigmoid()(x)` raised TypeError.
            return torch.sigmoid(out)
        raise ValueError(f"unsupported nonlinearity: {self.nonlinearity!r}")

class A_2dconv(torch.nn.Module):
    """Single-channel, bias-free 2-D convolution applied to flattened square images.

    Note: without a nonlinearity this is a linear layer.

    Args:
        n: flattened input size; must be a perfect square ``sidelen ** 2``.
        nonlinearity: ``None`` (linear), ``'SELU'`` or ``'sigmoid'``.
    """

    def __init__(self, n, nonlinearity=None):
        super(A_2dconv, self).__init__()
        self.n = n
        sidelen = int(np.sqrt(n))
        # Odd kernel size (sidelen, or sidelen-1 when sidelen is even) with
        # zero padding ceil(sidelen/2)-1 keeps the output at sidelen x sidelen.
        if sidelen % 2 == 0:
            kernel_size = sidelen - 1
        else:
            kernel_size = sidelen
        # needs periodic boundary conditions?
        self.convlayer = Conv2d(1, 1, stride=1, kernel_size=kernel_size,
                                padding=int(np.ceil(sidelen / 2)) - 1, bias=False)
        self.nonlinearity = nonlinearity

    def forward(self, Gzs):
        """Convolve a ``(P, n)`` batch and return it flattened back to ``(P, n)``.

        Raises:
            ValueError: if ``self.nonlinearity`` is not a supported value.
        """
        P = Gzs.shape[0]
        n = self.n
        sidelen = int(np.sqrt(n))
        # P by 1 by sidelen by sidelen in and out
        out = self.convlayer(Gzs.view(P, 1, sidelen, sidelen)).view(P, n)
        # Compare with ==, not `is` (string identity is not guaranteed).
        if self.nonlinearity is None:
            return out
        if self.nonlinearity == 'SELU':
            return F.selu(out)
        if self.nonlinearity == 'sigmoid':
            # torch.sigmoid is a function; the old `sigmoid()(x)` raised TypeError.
            return torch.sigmoid(out)
        raise ValueError(f"unsupported nonlinearity: {self.nonlinearity!r}")

class A_MLP(torch.nn.Module):
    """Two-layer MLP ``n -> 2n -> m`` with SELU after each layer.

    Args:
        m: output dimension.
        n: input dimension (the hidden width is ``2 * n``).
    """

    def __init__(self, m, n):
        super(A_MLP, self).__init__()
        self.m = m
        self.n = n
        numhidden = n * 2
        self.l1 = Linear(n, numhidden)
        self.l2 = Linear(numhidden, m)

    def forward(self, Gzs):
        """Map a ``(P, n)`` batch to ``(P, m)``."""
        # Apply SELU functionally. The old `torch.nn.SELU(x)` constructed a
        # module (treating x as its `inplace` flag) instead of transforming x,
        # so `self.l2(...)` then failed.
        hidden = F.selu(self.l1(Gzs))
        return F.selu(self.l2(hidden))

class G_sparse(torch.nn.Module):
    # ISSUE: this is not continuous w.r.t z! backprop does not make so much sense...
    """Write ``+a`` or ``-a`` (random sign, equal probability) at given indices of an ``n``-vector.

    Args:
        a: magnitude written at each selected position.
        n: output dimension.
        device: device on which the output tensor is allocated.
    """

    def __init__(self, a, n, device=torch.device('cpu')):
        super(G_sparse, self).__init__()
        self.a = a
        self.n = n
        self.device = device

    def forward(self, z):
        """Return a ``(P, n)`` float32 tensor with ``+/-a`` at the indices listed in ``z``.

        ``z`` is a ``(P, k)`` array whose entries are indices between 0 and
        ``n - 1`` inclusive. Each sign is drawn independently with
        ``torch.rand``; if an index repeats within a row, the last draw wins.
        The output does not depend differentiably on ``z`` (this might not be
        Lipschitz).
        """
        n = self.n
        P, k = z.shape[0], z.shape[1]
        out = torch.zeros((P, n), dtype=torch.float32, device=self.device)
        # A sequential loop (rather than vectorised indexing) keeps the RNG
        # draw order and last-write-wins behaviour for duplicate indices.
        for i in range(P):
            for j in range(k):
                fac = -1 if torch.rand(1) < 0.5 else 1
                out[i, z[i, j]] = float(fac * self.a)
        # (A leftover debug print of the dtype on every call was removed.)
        return out

class G_bipartite(torch.nn.Module):  # this is going to have no learnable parameters: issue?
    """Fixed generator defined by a bipartite adjacency matrix between outputs and latents.

    Each output coordinate is the sum of the squared latent entries it is
    connected to (``adj_mat @ z**2``), optionally followed by an elementwise map.

    Args:
        adj_mat: adjacency matrix, n by k.
        nonlinear_func: ``'square_and_sum'`` (default),
            ``'square_and_sum_and_square'``, ``'square_and_sum_and_relu'`` or
            ``'square_and_sum_and_sigmoid'``.
        device: accepted for interface compatibility; not used here.
    """

    def __init__(self, adj_mat, nonlinear_func='square_and_sum', device=torch.device('cpu')):
        super(G_bipartite, self).__init__()
        self.adj_mat = adj_mat  # n by k
        self.nonlinear_func = nonlinear_func
        self.k = adj_mat.shape[1]

    def forward(self, z):
        """Apply the generator to a ``(P, k)`` batch and return a ``(P, n)`` tensor.

        Raises:
            ValueError: for an unsupported ``nonlinear_func`` (this used to
                return ``None`` silently).
        """
        # (n, k) @ (P, k, 1) broadcasts to (P, n, 1), then the last dim is dropped.
        summed = torch.matmul(self.adj_mat, z.unsqueeze(-1) ** 2).squeeze(-1)
        func = self.nonlinear_func
        if func == 'square_and_sum':
            return summed
        if func == 'square_and_sum_and_square':
            return summed ** 2
        if func == 'square_and_sum_and_relu':
            return torch.relu(summed)
        if func == 'square_and_sum_and_sigmoid':
            return torch.sigmoid(summed)
        raise ValueError(f"unsupported nonlinear_func: {func!r}")

def make_G_bipartite(adj_mat, nonlinear_func):
    """Build a closure ``G(z)`` that computes ``outer(adj_mat @ z**2)`` row-wise.

    Args:
        adj_mat: adjacency matrix applied to the squared latent vector.
        nonlinear_func: selects ``outer``:
            ``'square_and_sum'`` (identity), ``'square_and_sum_and_square'``,
            ``'square_and_sum_and_relu'`` or ``'square_and_sum_and_sigmoid'``.

    Returns:
        A function mapping ``z`` to ``outer(adj_mat @ z.unsqueeze(-1) ** 2)``
        with the trailing singleton dimension removed.

    Raises:
        ValueError: for an unrecognised ``nonlinear_func`` (previously ``None``
            was returned silently).
    """
    outer_funcs = {
        'square_and_sum': lambda s: s,
        'square_and_sum_and_square': lambda s: s ** 2,
        'square_and_sum_and_relu': torch.relu,
        'square_and_sum_and_sigmoid': torch.sigmoid,
    }
    outer = outer_funcs.get(nonlinear_func)
    if outer is None:
        raise ValueError(f"unsupported nonlinear_func: {nonlinear_func!r}")

    def G(z):
        # outer is elementwise, so applying it before the squeeze matches the
        # original per-option lambdas.
        return outer(torch.matmul(adj_mat, z.unsqueeze(-1) ** 2)).squeeze(-1)

    return G

def pointwisesquare(z):
    """Return ``z`` raised element-wise to the second power."""
    squared = z ** 2
    return squared

def copy(z, func):
    """Concatenate ``z`` with three separate evaluations of ``func(z)`` along dim 1.

    ``func`` is invoked three times (not once and reused), so a stochastic
    ``func`` contributes three independent blocks.
    """
    pieces = [z]
    for _ in range(3):
        pieces.append(func(z))
    return torch.cat(pieces, dim=1)

def copylist(z, funclist):
    """Concatenate ``z`` with ``f(z)`` for every ``f`` in ``funclist`` along dim 1.

    Functions are applied in order. An empty ``funclist`` yields a copy of
    ``z``; previously ``torch.cat`` was handed an empty tuple and raised.
    """
    return torch.cat([z] + [f(z) for f in funclist], dim=1)

def make_G_redundant(func):
    """Return a generator ``Gfunc(z) -> copy(z, func)``.

    The result is ``z`` followed by three ``func(z)`` blocks along dim 1.
    """
    def Gfunc(z):
        # delegate to the module-level ``copy`` helper
        redundant = copy(z, func)
        return redundant
    return Gfunc

def make_G_redundant_list(funclist, inpfx=None):
    """Return a generator that appends ``f(x)`` for each ``f`` in ``funclist`` to its input.

    When ``inpfx`` is given, the input ``x`` is ``inpfx(z)``; otherwise it is
    ``z`` itself. The work is delegated to ``copylist``.
    """
    if inpfx is not None:
        def Gfunc(z):
            return copylist(inpfx(z), funclist)
        return Gfunc

    def Gfunc(z):
        return copylist(z, funclist)
    return Gfunc

def make_G_sparse(n):
    """Return ``func(z)`` mapping each row of ``z`` to an ``n``-dim 0/1 indicator vector.

    All entries of the ``(P, k)`` batch ``z`` are min-max rescaled using the
    batch-wide min and max into ``[0, 1)`` and bucketed into one of ``n``
    bins. Output row ``i`` has a 1 in every bin hit by row ``i`` of ``z``
    (a bin hit twice still holds a single 1).
    """
    # assume z iid unif random on [0,1]
    def func(z):
        P, k = z.shape[0], z.shape[1]
        # The zeros come from a zero matmul with z, which keeps the output
        # connected to z's autograd graph (with identically zero gradient).
        # Matching z's dtype/device prevents a dtype-mismatch error for
        # e.g. float64 inputs.
        zeromat = torch.zeros(P, n, k, dtype=z.dtype, device=z.device)
        temp = torch.matmul(zeromat, z.unsqueeze(-1)).squeeze(-1)
        z = z - torch.min(z)
        z = z / (torch.max(z) + 1e-6)
        # When the value range is large, the 1e-6 is lost to rounding, so the
        # maximum rescales to exactly 1.0 and floor(z * n) == n. Clamp it into
        # the last bin.
        inds = torch.floor(z * n).long().clamp(max=n - 1)
        rowinds = torch.arange(P).view(P, 1).repeat(1, k)
        temp[rowinds, inds] = 1
        return temp
    return func


def make_G_sparse_bad(n):
    """Return ``func(z)``: an ``n``-dim 0/1 indicator per row, without rescaling.

    Each entry of ``z`` (shape ``(P, k)``) is bucketed directly as
    ``floor(z * n)``, so entries must lie in ``[0, 1)``. The output is a fresh
    zero tensor and is not connected to ``z``'s autograd graph.
    """
    # assume z iid unif random on [0,1]
    def func(z):
        # k comes from z itself; previously it was an undefined global name.
        P, k = z.shape[0], z.shape[1]
        # Allocate on z's device so the index tensors and target agree.
        temp = torch.zeros(P, n, device=z.device)
        inds = torch.floor(z * n).long()
        rowinds = torch.arange(P, device=z.device).view(P, 1).repeat(1, k)
        temp[rowinds, inds] = 1
        return temp
    return func

class G_bipartite2(torch.nn.Module):  # this is going to have no learnable parameters: issue?
    """Fixed bipartite generator ``outer(adj_mat @ z**2)`` (no learnable parameters).

    Args:
        adj_mat: adjacency matrix, n by k.
        nonlinear_func: ``'square_and_sum'`` (default),
            ``'square_and_sum_and_square'``, ``'square_and_sum_and_relu'`` or
            ``'square_and_sum_and_sigmoid'``.
        device: accepted for interface compatibility; not used here.
    """

    def __init__(self, adj_mat, nonlinear_func='square_and_sum', device=torch.device('cpu')):
        # The old `super(G_bipartite, self)` named an unrelated class, so
        # constructing G_bipartite2 always raised TypeError.
        super(G_bipartite2, self).__init__()
        self.adj_mat = adj_mat  # n by k
        self.nonlinear_func = nonlinear_func

    def forward(self, z):
        """Apply the generator to a ``(P, k)`` batch and return a ``(P, n)`` tensor.

        Raises:
            ValueError: for an unsupported ``nonlinear_func`` (this used to
                return ``None`` silently).
        """
        # (n, k) @ (P, k, 1) broadcasts to (P, n, 1), then the last dim is dropped.
        summed = torch.matmul(self.adj_mat, z.unsqueeze(-1) ** 2).squeeze(-1)
        func = self.nonlinear_func
        if func == 'square_and_sum':
            return summed
        if func == 'square_and_sum_and_square':
            return summed ** 2
        if func == 'square_and_sum_and_relu':
            return torch.relu(summed)
        if func == 'square_and_sum_and_sigmoid':
            return torch.sigmoid(summed)
        raise ValueError(f"unsupported nonlinear_func: {func!r}")

# From https://github.com/lyeoni/pytorch-mnist-VAE/blob/master/pytorch-mnist-VAE.ipynb
class aVAE(nn.Module):
    """Fully connected variational autoencoder.

    Encoder: ``x_dim -> h_dim1 -> h_dim2 -> (mu, log_var)``, each of size ``z_dim``.
    Decoder: ``z_dim -> h_dim2 -> h_dim1 -> x_dim``, optionally followed by a sigmoid.

    Args:
        x_dim: number of input features per sample (e.g. 784 for flattened 28x28 images).
        h_dim1: width of the first/last hidden layer.
        h_dim2: width of the second/second-to-last hidden layer.
        z_dim: latent dimension.
        do_sigmoid: if truthy, squash decoder outputs into (0, 1) with a sigmoid.
    """

    def __init__(self, x_dim, h_dim1, h_dim2, z_dim, do_sigmoid):
        super(aVAE, self).__init__()

        # encoder part
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)
        self.fc32 = nn.Linear(h_dim2, z_dim)
        # decoder part
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)

        self.do_sigmoid = do_sigmoid

    def encoder(self, x):
        """Return ``(mu, log_var)`` for a ``(batch, x_dim)`` input."""
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc31(h), self.fc32(h)  # mu, log_var

    def sampling(self, mu, log_var):
        """Reparameterisation trick: return ``mu + exp(0.5 * log_var) * eps`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)  # return z sample

    def decoder(self, z):
        """Map latent samples back to ``x_dim`` features."""
        h = F.relu(self.fc4(z))
        h = F.relu(self.fc5(h))
        h = self.fc6(h)
        if self.do_sigmoid:
            # torch.sigmoid replaces the deprecated F.sigmoid (same values).
            h = torch.sigmoid(h)
        return h

    def forward(self, x):
        """Return ``(reconstruction, mu, log_var)`` for a batch ``x``.

        ``x`` is flattened to ``(-1, x_dim)``. The width used to be hard-coded
        to 784, which broke every model built with a different ``x_dim``.
        """
        x_dim = self.fc1.in_features
        mu, log_var = self.encoder(x.reshape(-1, x_dim))
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var
