import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_dct as dct
import torchvision
import samplers
import math
import numpy as np
import importlib
import os
from tqdm import tqdm
from  network_unet import UNet
# Re-execute the already-imported `samplers` module.
# NOTE(review): presumably so edits to samplers.py are picked up in a
# long-lived interactive session -- confirm this is still needed.
importlib.reload(samplers)

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# L-layer MLP

class MLP(nn.Module):
    """Fully connected network with ELU activations between layers.

    The network maps ``input_dim`` features through one linear layer per
    entry of ``hidden_dims`` (each followed by ELU) and a final linear
    layer of width ``output_dim`` with no activation.  The flat output is
    then reshaped to ``(-1, *output_shape)``.

    Args:
        input_dim: Number of input features.
        output_dim: Number of output features before reshaping.
        hidden_dims: Widths of the hidden layers, in order.  May be empty,
            in which case the network is a single linear layer.
        output_shape: Per-sample shape the flat output is reshaped to.
            Defaults to ``(3, 80, 80)``, the shape that used to be
            hard-coded in ``forward``.  No check is made that it matches
            ``output_dim``; as before, a mismatch either changes the
            leading dimension or makes ``reshape`` raise.
    """

    def __init__(self, input_dim, output_dim, hidden_dims, output_shape=(3, 80, 80)):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dims = hidden_dims
        self.output_shape = tuple(output_shape)

        # Consecutive layer widths: input -> hidden ... -> output.  Building
        # from this list keeps the layer order (and state_dict keys) the same
        # as before while also supporting an empty ``hidden_dims``.
        widths = [input_dim, *hidden_dims, output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(widths, widths[1:])]
        )

    def forward(self, x):
        """Apply the network to ``x`` and reshape to ``(-1, *output_shape)``."""
        *hidden_layers, head = self.layers
        for layer in hidden_layers:
            x = F.elu(layer(x))
        # The last layer is left linear (no activation) before reshaping.
        return head(x).reshape(-1, *self.output_shape)
    
# Enc1 is a single linear input layer followed by an ELU activation

class Enc1(nn.Module):
    """First encoder stage: one linear layer followed by ELU.

    Args:
        input_dim: Number of input features.
        hidden_dim: Number of features produced by the linear layer.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.input_dim, self.hidden_dim = input_dim, hidden_dim
        self.linear = nn.Linear(input_dim, hidden_dim)

    def forward(self, x):
        """Return ``elu(linear(x))``."""
        pre_activation = self.linear(x)
        return F.elu(pre_activation)

# Enc2 is a single linear output layer (no activation)

class Enc2(nn.Module):
    """Second encoder stage: one linear layer with no activation.

    Args:
        hidden_dim: Number of input features.
        output_dim: Number of features produced by the linear layer.
    """

    def __init__(self, hidden_dim, output_dim):
        super().__init__()
        self.hidden_dim, self.output_dim = hidden_dim, output_dim
        self.linear = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return ``linear(x)``."""
        projected = self.linear(x)
        return projected

class TruncatedDCT(nn.Module):
    """2-D DCT followed by cropping to the leading coefficients.

    ``forward`` applies ``torch_dct.dct_2d`` to the input and keeps only the
    first ``num_coeffs`` entries along dimensions 2 and 3 (the top-left
    ``num_coeffs x num_coeffs`` corner of each coefficient map).

    Args:
        num_coeffs: Number of coefficients kept along each of the last two
            dimensions.
    """

    def __init__(self, num_coeffs):
        super().__init__()
        self.num_coeffs = num_coeffs

    def forward(self, x):
        """Return the cropped DCT coefficients of a 4-D input ``x``."""
        keep = self.num_coeffs
        coefficients = dct.dct_2d(x)
        return coefficients[:, :, :keep, :keep]

class TruncatedIDCT(nn.Module):
    """Zero-pad truncated DCT coefficients and apply the inverse 2-D DCT.

    The input is expected to have shape ``(B, C, num_coeffs, num_coeffs)``.
    It is padded with zeros along dimensions 2 and 3 up to
    ``original_size`` and then passed to ``torch_dct.idct_2d``.

    Args:
        num_coeffs: Size of the truncated coefficient block along each of
            the last two dimensions.
        original_size: Size of the padded block (and of the output) along
            each of the last two dimensions.
    """

    def __init__(self, num_coeffs, original_size):
        super().__init__()
        self.num_coeffs = num_coeffs
        self.original_size = original_size

    def _zero_pad(self, coeffs):
        """Pad ``coeffs`` with zeros to ``original_size`` on dims 2 and 3."""
        batch, channels = coeffs.shape[0], coeffs.shape[1]
        missing = self.original_size - self.num_coeffs
        # new_zeros allocates on the input's own device and dtype, so the
        # module does not depend on a module-level ``device`` global and does
        # not mix devices/dtypes when concatenating.
        bottom = coeffs.new_zeros(batch, channels, missing, self.num_coeffs)
        padded_rows = torch.cat([coeffs, bottom], dim=2)
        right = coeffs.new_zeros(batch, channels, self.original_size, missing)
        return torch.cat([padded_rows, right], dim=3)

    def forward(self, X_dct_trunc):
        """Return the inverse 2-D DCT of the zero-padded coefficients."""
        return dct.idct_2d(self._zero_pad(X_dct_trunc))

def reparameterize(mean, logvar):
    """Sample ``mean + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

    Args:
        mean: Tensor of means.
        logvar: Tensor of log-variances, broadcastable with ``mean``.

    Returns:
        A random tensor drawn with the given mean and standard deviation
        ``exp(0.5 * logvar)``; the noise has the shape/dtype/device of
        ``logvar``.
    """
    sigma = (0.5 * logvar).exp()
    noise = torch.randn_like(sigma)
    return mean + noise * sigma

# https://datascience.stackexchange.com/questions/96271/logcoshloss-on-pytorch
def log_cosh_loss(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    """Return the sum over all elements of ``log(cosh(y_pred - y_true))``.

    Uses the identity ``log(cosh(r)) = r + softplus(-2r) - log(2)``, which
    avoids evaluating ``cosh`` directly on large residuals.
    """
    residual = y_pred - y_true
    per_element = residual + F.softplus(-2.0 * residual) - math.log(2.0)
    return per_element.sum()

# Define VAE loss with log cosh reconstruction loss

def log_cosh_vae_loss(rec, X, mu, logvar, reg_strength=0.00025):
    """VAE objective: scaled log-cosh reconstruction term plus weighted KL.

    The reconstruction term is ``log_cosh_loss`` evaluated on ``rec`` and
    ``X`` both multiplied by 100, then scaled by ``0.5 / (100 * B)`` where
    ``B = X.shape[0]``.  The KL term is
    ``-0.5 * sum(1 + logvar - mu**2 - exp(logvar))``; it is summed over every
    element and is *not* divided by the batch size.

    Args:
        rec: Reconstruction produced by the decoder.
        X: Target tensor; its first dimension is used as the batch size.
        mu: Latent means.
        logvar: Latent log-variances.
        reg_strength: Weight applied to the KL term.

    Returns:
        Scalar tensor ``reconstruction + reg_strength * kl``.
    """
    sharpness = 100
    batch_size = X.shape[0]

    # Same left-to-right scalar product as (1/alpha) * (1/B) * 0.5.
    scale = (1 / sharpness) * (1 / batch_size) * 0.5
    reconstruction = scale * log_cosh_loss(sharpness * rec, sharpness * X)

    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * torch.sum(kl_terms)

    return reconstruction + reg_strength * kl_divergence

def main():
    """Train the DCT-domain VAE on CelebA and checkpoint after every epoch.

    Images are resized to 256x256, transformed with a truncated 2-D DCT
    (80x80 coefficients per channel), flattened and encoded by a two-layer
    MLP into a 700-dimensional vector that is split into mean and
    log-variance halves.  A latent sample is decoded by an MLP, a
    zero-padded inverse DCT and a UNet, and the model is optimised with
    ``log_cosh_vae_loss`` using AdamW.

    The checkpoint (model/optimizer state and the per-step loss history) is
    written to ``checkpoint_last_epoch.pt`` inside ``checkpoint_dir``; when
    ``checkpoint_dir`` is empty the file goes to the current working
    directory.
    """
    # Load CelebA, resized to 256x256 and converted to tensors.
    to_tensor = torchvision.transforms.ToTensor()
    downsize = torchvision.transforms.Resize((256, 256))
    composed_transform = torchvision.transforms.Compose([downsize, to_tensor])
    celeba_dir = ''
    trainset = torchvision.datasets.CelebA(root=celeba_dir, split='train', download=True, transform=composed_transform)
    batch_size = 16
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False, num_workers=2)

    # Model dimensions: the encoder sees the flattened 3 x 80 x 80 block of
    # truncated DCT coefficients.
    D = 3*80*80
    input_dim = D
    latent_dim = 700
    hidden_dim = 10000

    num_coeffs = 80

    trunc_dct = TruncatedDCT(num_coeffs=num_coeffs)
    trunc_idct = TruncatedIDCT(num_coeffs=num_coeffs, original_size=256)

    enc1 = Enc1(input_dim, hidden_dim).to(device)
    enc2 = Enc2(hidden_dim, latent_dim).to(device)
    output_unet = UNet(in_nc=3, out_nc=3).to(device)
    enc = nn.Sequential(enc1, enc2).to(device)
    # The decoder consumes only the sampled latent (half of the encoder
    # output; the other half is the log-variance).
    dec = nn.Sequential(MLP(latent_dim//2,input_dim,[hidden_dim]), trunc_idct, output_unet).to(device)

    print(enc)
    print(dec)

    reg_strength = 0.0010
    lr = 1e-4
    start_epoch = 0
    num_epochs = 100

    optimizer = torch.optim.AdamW(list(enc1.parameters()) + list(enc2.parameters()) + list(dec.parameters()), lr=lr, weight_decay=0)

    # Per-step loss history; saved with every checkpoint.
    losses = []

    # Directory to save checkpoints
    checkpoint_dir = ''
    # os.makedirs('') raises FileNotFoundError, so only create the directory
    # when one is actually configured.  exist_ok avoids the check-then-create
    # race of a separate os.path.exists() test.
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    # os.path.join('', name) is just `name` (current directory), whereas
    # '' + '/name' would point at the filesystem root.
    checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint_last_epoch.pt')

    # Train the encoder-decoder with the log-cosh reconstruction + KL loss.
    for epoch in tqdm(range(start_epoch, num_epochs)):
        print('Epoch:', epoch)
        for i, (X, _) in enumerate(trainloader):
            X = X.to(device)
            X_dct = trunc_dct(X)
            X_dct_flat = X_dct.reshape(X.shape[0], D)
            latent = enc(X_dct_flat)
            # First half of the encoder output is the mean, second half the
            # log-variance.
            mu, logvar = latent[:, :latent_dim//2], latent[:, latent_dim//2:]
            # clamp logvar to prevent numerical blowup when we take exp
            logvar = torch.clamp(logvar, -10, 10)
            z = reparameterize(mu, logvar)
            rec = dec(z)
            loss = log_cosh_vae_loss(rec, X, mu, logvar, reg_strength)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Read the scalar once; it is used for both history and logging.
            loss_value = loss.item()
            losses.append(loss_value)
            if i % 100 == 0:
                print('Epoch: %d, Iter: %d, Loss: %.4f' % (epoch, i, loss_value))

        # Save checkpoint (every epoch; raise the modulus to save less often).
        if epoch % 1 == 0:
            print('Saving checkpoint')
            torch.save({
                'epoch': epoch,
                'enc1_state_dict': enc1.state_dict(),
                'enc2_state_dict': enc2.state_dict(),
                'dec_state_dict': dec.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'losses': losses,
                }, checkpoint_path)
        
# Run training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()