import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_dct as dct
import torchvision
import os
from  network_unet import UNet

# Run on the first CUDA GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# L-layer MLP

class MLP(nn.Module):
    """Fully connected network with ELU between layers and a reshaped output.

    The layer widths are ``input_dim -> hidden_dims[0] -> ... ->
    hidden_dims[-1] -> output_dim``. ELU is applied after every layer except
    the last one, whose output is reshaped to ``(-1, *output_shape)``.

    Args:
        input_dim: Number of input features of the first linear layer.
        output_dim: Number of output features of the last linear layer.
        hidden_dims: Sequence of hidden layer widths, in order. May be empty,
            in which case the network is a single linear layer.
        output_shape: Per-sample shape the final output is reshaped to.
            Defaults to ``(3, 80, 80)``, the shape that used to be hard-coded.
    """

    def __init__(self, input_dim, output_dim, hidden_dims, output_shape=(3, 80, 80)):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dims = hidden_dims
        self.output_shape = tuple(output_shape)

        # Consecutive pairs of widths define the linear layers. The attribute
        # name and layer order are kept so existing state dicts
        # ("layers.<i>.weight" / "layers.<i>.bias") still load.
        widths = [input_dim, *hidden_dims, output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])]
        )

    def forward(self, x):
        """Apply all layers to ``x`` and reshape to ``(-1, *output_shape)``."""
        for layer in self.layers[:-1]:
            x = F.elu(layer(x))
        # No activation after the final layer.
        return self.layers[-1](x).reshape(-1, *self.output_shape)
    
# Enc1 is a single linear input layer followed by an ELU activation

class Enc1(nn.Module):
    """First encoder stage: one linear layer followed by ELU.

    Args:
        input_dim: Number of input features of the linear layer.
        hidden_dim: Number of output features of the linear layer.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.input_dim, self.hidden_dim = input_dim, hidden_dim
        self.linear = nn.Linear(in_features=input_dim, out_features=hidden_dim)

    def forward(self, x):
        """Return ``ELU(linear(x))``."""
        projected = self.linear(x)
        return F.elu(projected)

# Enc2 is a single linear output layer (no activation)

class Enc2(nn.Module):
    """Second encoder stage: one linear layer with no activation.

    Args:
        hidden_dim: Number of input features of the linear layer.
        output_dim: Number of output features of the linear layer.
    """

    def __init__(self, hidden_dim, output_dim):
        super().__init__()
        self.hidden_dim, self.output_dim = hidden_dim, output_dim
        self.linear = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, x):
        """Return ``linear(x)`` unchanged by any activation."""
        projected = self.linear(x)
        return projected

class TruncatedDCT(nn.Module):
    """2-D DCT that keeps only the leading block of coefficients.

    Args:
        num_coeffs: Number of coefficients kept along each of dims 2 and 3.
    """

    def __init__(self, num_coeffs):
        super().__init__()
        self.num_coeffs = num_coeffs

    def forward(self, x):
        """Return ``dct_2d(x)`` sliced to the first ``num_coeffs`` entries of dims 2 and 3.

        The slicing indexes four dimensions, so ``x`` must be 4-D.
        """
        keep = self.num_coeffs
        coeffs = dct.dct_2d(x)
        return coeffs[:, :, :keep, :keep]

class TruncatedIDCT(nn.Module):
    """Zero-pad a truncated block of DCT coefficients and apply the inverse 2-D DCT.

    The input's last two dimensions must both equal ``num_coeffs``. They are
    padded with zeros at the end of each dimension up to ``original_size``
    before ``idct_2d`` is applied.

    Args:
        num_coeffs: Size of the last two dimensions of the incoming coefficients.
        original_size: Size of the last two dimensions after zero padding.

    Raises:
        ValueError: If ``num_coeffs`` is larger than ``original_size``.
    """

    def __init__(self, num_coeffs, original_size):
        super().__init__()
        if num_coeffs > original_size:
            raise ValueError(
                f"num_coeffs ({num_coeffs}) must not exceed original_size ({original_size})"
            )
        self.num_coeffs = num_coeffs
        self.original_size = original_size

    def forward(self, X_dct_trunc):
        """Return the inverse 2-D DCT of ``X_dct_trunc`` zero-padded to ``original_size``.

        Raises:
            ValueError: If the last two dimensions of ``X_dct_trunc`` are not
                both ``num_coeffs``.
        """
        keep = self.num_coeffs
        if tuple(X_dct_trunc.shape[-2:]) != (keep, keep):
            raise ValueError(
                f"expected trailing dimensions ({keep}, {keep}), "
                f"got {tuple(X_dct_trunc.shape[-2:])}"
            )
        extra = self.original_size - keep
        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        # Padding the input itself keeps the zeros on the input's device and
        # dtype instead of relying on a module-level device global.
        padded = F.pad(X_dct_trunc, (0, extra, 0, extra))
        return dct.idct_2d(padded)

def _decode_each(decoder, latents):
    """Decode ``latents`` one row at a time and concatenate the outputs along dim 0."""
    # no_grad: this is inference only, so no autograd graph is needed.
    with torch.no_grad():
        return torch.cat([decoder(latent.unsqueeze(0)) for latent in latents])


def main():
    """Decode stored latent samples and their nearest neighbours, then save them.

    Builds the encoder and decoder modules, restores their weights from a
    checkpoint, decodes every row of the samples tensor and of the
    nearest-neighbour tensor, and writes the results both as ``.pt`` tensors
    and as one PNG per decoded row.

    The checkpoint, input and output paths are empty placeholders that must be
    filled in before running.

    Raises:
        ValueError: If the nearest-neighbour tensor has fewer rows than the
            samples tensor.
    """
    # Model dimensions
    D = 3*80*80
    input_dim = D
    latent_dim = 700
    hidden_dim = 10000

    num_coeffs = 80

    trunc_idct = TruncatedIDCT(num_coeffs=num_coeffs, original_size=256)

    # enc1 / enc2 are not used for decoding below; they are still built so
    # their architecture is printed and their checkpoint entries are checked.
    enc1 = Enc1(input_dim, hidden_dim).to(device)
    enc2 = Enc2(hidden_dim, latent_dim).to(device)
    output_unet = UNet(in_nc=3, out_nc=3).to(device)
    dec = nn.Sequential(MLP(latent_dim,input_dim,[hidden_dim]), trunc_idct, output_unet).to(device)

    print(enc1)
    print(enc2)
    print(dec)

    # Load checkpoint
    checkpt_fname = ''
    checkpt = torch.load(checkpt_fname, map_location=device)

    enc1.load_state_dict(checkpt['enc1_state_dict'])
    enc2.load_state_dict(checkpt['enc2_state_dict'])
    dec.load_state_dict(checkpt['dec_state_dict'])

    # Free the checkpoint's copy of the weights before decoding.
    del checkpt

    # Put any train/eval-dependent layers (e.g. dropout, batch norm) into
    # evaluation mode before running inference.
    dec.eval()

    # Load CFDM samples and nearest neighbors; map_location matches the
    # checkpoint load so tensors saved on another device still load here.
    samples_path = ''
    samples_tensor = torch.load(samples_path, map_location=device).to(device)
    nn_path = ''
    nn_tensor = torch.load(nn_path, map_location=device).to(device)

    num_samples = samples_tensor.shape[0]
    if nn_tensor.shape[0] < num_samples:
        raise ValueError(
            f"nearest-neighbor tensor has {nn_tensor.shape[0]} rows, "
            f"expected at least {num_samples}"
        )

    # Decode one latent at a time rather than as a single batch. Only the
    # first num_samples neighbours are decoded.
    decoded_samples = _decode_each(dec, samples_tensor)
    decoded_samples_nn = _decode_each(dec, nn_tensor[:num_samples])

    # Save decoded samples
    save_dir = ''
    os.makedirs(save_dir, exist_ok=True)
    torch.save(decoded_samples, os.path.join(save_dir, "decoded_samples.pt"))
    torch.save(decoded_samples_nn, os.path.join(save_dir, "decoded_samples_nn.pt"))

    # Also save samples and NNs as images
    save_dir_images = ''
    os.makedirs(save_dir_images, exist_ok=True)
    for i, (image, image_nn) in enumerate(zip(decoded_samples, decoded_samples_nn)):
        torchvision.utils.save_image(image, os.path.join(save_dir_images, f"decoded_sample_{i}.png"))
        torchvision.utils.save_image(image_nn, os.path.join(save_dir_images, f"decoded_sample_nn_{i}.png"))

# Run the decoding script only when this file is executed directly, not on import.
if __name__ == '__main__':
    main()