import torch
from torch import nn
from torch.nn import functional as F

class VAE(nn.Module):
    """Variational autoencoder over one ``(D, r)`` matrix at a time.

    The input is flattened to ``D * r`` values. A one-hidden-layer MLP encoder
    emits ``2 * latent_dim * D`` numbers: the first half are means, the second
    half log-variances. A mirrored MLP decodes a flat latent sample back to
    ``(D, r)``.

    Args:
        D: Number of rows in the input matrix.
        r: Number of columns in the input matrix.
        latent_dim: Latent width per row; the flat latent has
            ``latent_dim * D`` entries.
        hidden: Hidden-layer width shared by encoder and decoder.
    """

    def __init__(self, D, r, latent_dim, hidden=768):
        super().__init__()
        self.D, self.r, self.latent_dim = D, r, latent_dim

        flat_in = D * r
        flat_latent = latent_dim * D

        # Keep the (Linear, ReLU, Linear) layout: state_dict keys are
        # ``encoder.0.*`` / ``encoder.2.*`` and would break if reindexed.
        self.encoder = nn.Sequential(
            nn.Linear(flat_in, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 2 * flat_latent),
        )
        self.decoder = nn.Sequential(
            nn.Linear(flat_latent, hidden),
            nn.ReLU(),
            nn.Linear(hidden, flat_in),
        )

    def encode(self, x):
        """Return ``(mu, logvar)``, each a flat tensor of ``latent_dim * D`` values.

        ``x`` may have any shape with exactly ``D * r`` elements.
        """
        stats = self.encoder(x.reshape(self.D * self.r))
        half = self.latent_dim * self.D
        mu, logvar = stats.split([half, half])
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps`` sampled from ``randn``."""
        sigma = logvar.mul(0.5).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Decode a flat latent vector into a ``(D, r)`` reconstruction."""
        flat = self.decoder(z)
        return flat.view(self.D, self.r)

    def forward(self, x):
        """Encode, sample, and decode ``x``.

        Returns:
            ``(reconstruction (D, r), z viewed as (D, latent_dim), mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        reconstruction = self.decode(z)
        return reconstruction, z.view(self.D, self.latent_dim), mu, logvar

def loss_function(recon_x, x, mu, logvar):
    """Return the VAE objective: reconstruction MSE plus a KL penalty.

    Both terms are mean-reduced. The MSE is averaged over all elements of the
    flattened tensors. The KL term for ``N(mu, exp(logvar))`` against
    ``N(0, 1)`` is averaged (not summed) over the latent entries.
    """
    reconstruction = F.mse_loss(recon_x.reshape(-1), x.reshape(-1), reduction='mean')
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl = -0.5 * kl_terms.mean()
    return reconstruction + kl


class AE(nn.Module):
    """Deterministic MLP autoencoder over one ``(D, r)`` matrix at a time.

    The input is flattened to ``D * r`` values and encoded to a flat code of
    ``latent_dim * D`` values by a one-hidden-layer MLP. A mirrored MLP decodes
    the code back to ``(D, r)``.

    Args:
        D: Number of rows in the input matrix.
        r: Number of columns in the input matrix.
        latent_dim: Code width per row.
        hidden: Hidden-layer width shared by encoder and decoder.
    """

    def __init__(self, D, r, latent_dim, hidden=768 * 2):
        super().__init__()
        self.D = D
        self.r = r
        self.latent_dim = latent_dim

        n_in, n_code = D * r, latent_dim * D
        # The Sequential indices (0 = first Linear, 2 = second Linear) define
        # the state_dict keys, so the layer layout is kept as-is.
        self.encoder = nn.Sequential(
            nn.Linear(n_in, hidden), nn.ReLU(), nn.Linear(hidden, n_code)
        )
        self.decoder = nn.Sequential(
            nn.Linear(n_code, hidden), nn.ReLU(), nn.Linear(hidden, n_in)
        )

    def encode(self, x):
        """Flatten ``x`` (any shape with ``D * r`` elements) and return its flat code."""
        return self.encoder(x.reshape(self.D * self.r))

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)``; not called by ``forward``."""
        scale = logvar.mul(0.5).exp()
        return mu + torch.randn_like(scale) * scale

    def decode(self, z):
        """Decode a flat code into a ``(D, r)`` reconstruction."""
        return self.decoder(z).view(self.D, self.r)

    def forward(self, x):
        """Return ``(reconstruction, code viewed as (D, latent_dim), code, code)``.

        The code is returned twice in the last two slots, the positions where a
        VAE forward returns ``mu`` and ``logvar``. Presumably this lets callers
        treat both models the same way; verify against the training code.
        """
        code = self.encode(x)
        return self.decode(code), code.view(self.D, self.latent_dim), code, code

def loss_function_1(recon_x, x):
    """Return the summed squared error between two tensors with equal element counts."""
    flat_pred = recon_x.reshape(-1)
    flat_target = x.reshape(-1)
    return F.mse_loss(flat_pred, flat_target, reduction='sum')


def loss_function_2(recon_x, x):
    """Return the mean absolute error between two tensors with equal element counts.

    Both tensors are flattened first, so any shapes with the same number of
    elements are accepted.
    """
    # Use the stateless functional form rather than building a throwaway
    # nn.L1Loss module on every call; the default 'mean' reduction matches.
    return F.l1_loss(recon_x.reshape(-1), x.reshape(-1))


class AEW(nn.Module):
    """Linear autoencoder that projects each row of a ``(D, r)`` input.

    Encoding maps ``(D, r) -> (D, hidden)`` with ``en_w1`` of shape
    ``(r, hidden)``. Decoding maps back with ``de_w1`` (same shape,
    transposed).

    Note: ``bias`` is accepted but unused. ``self.dropout`` is constructed
    but never applied in ``forward``.
    """

    def __init__(self, D, r, hidden=8, bias=True, dropout=0.5):
        super().__init__()
        self.D = D
        self.r = r

        # Registration order (en_w1, then de_w1) fixes the RNG consumption
        # order in init_params, so it is kept.
        self.en_w1 = nn.Parameter(torch.empty(r, hidden, dtype=torch.float32))
        self.de_w1 = nn.Parameter(torch.empty(r, hidden, dtype=torch.float32))

        self.dropout = nn.Dropout(p=dropout)

        self.init_params()

    def init_params(self):
        """Xavier-uniform init for every 2-D parameter; zero every other parameter."""
        for weight in self.parameters():
            if weight.dim() != 2:
                nn.init.constant_(weight, 0.0)
                continue
            nn.init.xavier_uniform_(weight)

    def encode(self, x):
        """Project ``x`` of shape ``(D, r)`` to ``(D, hidden)``."""
        projected = torch.einsum('dr,rh->dh', x, self.en_w1)
        return projected

    def decode(self, z):
        """Map a ``(D, hidden)`` code back to ``(D, r)``."""
        restored = torch.einsum('dh,rh->dr', z, self.de_w1)
        return restored

    def forward(self, x):
        """Return ``(reconstruction, code, None, None)``."""
        code = self.encode(x)
        reconstruction = self.decode(code)
        return reconstruction, code, None, None


class AEWW(nn.Module):
    """Bilinear autoencoder that compresses both axes of a ``(D, r)`` input.

    Encoding first mixes rows ``(D, r) -> (c, r)`` via ``en_w1`` ``(D, c)``,
    then columns ``(c, r) -> (c, hidden)`` via ``en_w2`` ``(r, hidden)``.
    Decoding reverses this with ``de_w1`` ``(r, hidden)`` and ``de_w2``
    ``(D, c)``.

    Note: ``bias`` is accepted but unused. ``self.dropout`` is constructed
    but never applied in ``forward``.
    """

    def __init__(self, D, r, hidden=8, bias=True, dropout=0.5, c=8):
        super().__init__()
        self.D = D
        self.r = r

        # Registration order determines the order init_params draws random
        # numbers, so it matches the original: en_w1, en_w2, de_w1, de_w2.
        self.en_w1 = nn.Parameter(torch.empty(D, c, dtype=torch.float32))
        self.en_w2 = nn.Parameter(torch.empty(r, hidden, dtype=torch.float32))
        self.de_w1 = nn.Parameter(torch.empty(r, hidden, dtype=torch.float32))
        self.de_w2 = nn.Parameter(torch.empty(D, c, dtype=torch.float32))
        self.dropout = nn.Dropout(p=dropout)

        self.init_params()

    def init_params(self):
        """Xavier-uniform init for every 2-D parameter; zero every other parameter."""
        for weight in self.parameters():
            if weight.dim() != 2:
                nn.init.constant_(weight, 0.0)
                continue
            nn.init.xavier_uniform_(weight)

    def encode(self, x):
        """Compress ``x`` of shape ``(D, r)`` to a ``(c, hidden)`` code."""
        row_mixed = torch.einsum('dr,dc->cr', x, self.en_w1)
        return torch.einsum('cr,rh->ch', row_mixed, self.en_w2)

    def decode(self, z):
        """Expand a ``(c, hidden)`` code back to ``(D, r)``."""
        col_restored = torch.einsum('ch,rh->cr', z, self.de_w1)
        return torch.einsum('cr,dc->dr', col_restored, self.de_w2)

    def forward(self, x):
        """Return ``(reconstruction, code, None, None)``."""
        code = self.encode(x)
        reconstruction = self.decode(code)
        return reconstruction, code, None, None
