import os
from time import time

from tqdm import trange
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.utils as vutils

from VAE.model import VAE
from env_list import env_list

# Let CUDA matmuls and cuDNN kernels use TF32.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Shared autocast context (CUDA, float16) reused for training and validation.
ctx = torch.amp.autocast(device_type='cuda', dtype=torch.float16)

def transform(imgs, device):
    """Turn a 4-D NumPy image batch into a normalised float32 tensor.

    The array is copied to ``device``, its last axis is moved to position 1
    (channels-last -> channels-first for the image batches used in this
    file), and values are rescaled with ``(x / 255 - 0.5) / 0.5`` so that
    0 maps to -1 and 255 maps to 1.

    Args:
        imgs: ``np.ndarray`` with four dimensions.
        device: Target torch device for the resulting tensor.

    Returns:
        A contiguous ``torch.float32`` tensor on ``device``.
    """
    assert isinstance(imgs, np.ndarray)
    batch = torch.tensor(imgs, device=device)
    batch = batch.permute(0, 3, 1, 2).contiguous()
    scaled = batch / 255.
    centred = (scaled - 0.5) / 0.5
    return centred.float()

def kl_loss_f(mu, logvar):
    """Return the KL term ``-0.5 * mean(1 + logvar - mu^2 - exp(logvar))``.

    The mean is taken over every element of the inputs, so the result is a
    scalar tensor.
    """
    per_element = 1 + logvar - mu.pow(2) - logvar.exp()
    return -0.5 * per_element.mean()

def build_model(save_dir, device):
    """Create a fresh VAE with its optimizer, AMP gradient scaler and loss log.

    No checkpoint is read here: the model always starts from its initial
    weights and the loss log starts empty.

    Args:
        save_dir: Currently unused; kept so existing callers keep working.
        device: Device the model is moved to.

    Returns:
        A 5-tuple ``(model, optimizer, scaler, None, loss_log)``. The fourth
        slot is always ``None`` and is retained only to keep the tuple shape
        callers unpack.
    """
    model = VAE(
        channel_in=3,
        ch=64,
        blocks=(1, 2, 4, 8),
        latent_channels=32,
        deep_model=True,
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scaler = torch.cuda.amp.GradScaler()
    loss_log = []
    return model, optimizer, scaler, None, loss_log

def save_model(ckpt_name):
    """Write the current training state to ``save_dir/ckpt_name``.

    Reads the module-level names ``loss_log``, ``model``, ``optimizer`` and
    ``save_dir``, which this file defines in its ``__main__`` section, so it
    can only be called after they exist.

    Args:
        ckpt_name: File name of the checkpoint inside ``save_dir``.
    """
    state = {
        "losses": loss_log,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    destination = os.path.join(save_dir, ckpt_name)
    torch.save(state, destination)

if __name__ == '__main__':
    # Script entry point: train the VAE on frames drawn from every environment
    # in env_list, validate every 1000 steps, checkpoint on improvement and
    # stop after 100 consecutive validations without improvement.
    #
    # The names model, optimizer, loss_log and save_dir must stay at module
    # level because save_model() reads them as globals.
    device = 'cuda'
    save_dir = "VAE/out"
    os.makedirs(save_dir, exist_ok=True)

    data_dir = "../data/{}"

    n_train, n_val = 90000, 10000
    obs_shape = (128, 128, 3)

    # NOTE(review): np.memmap maps raw bytes from offset 0. If these files
    # were written with np.save, the .npy header would be read as pixel
    # data -- confirm how the files were produced.
    train_data = {
        env_name: np.memmap(
            os.path.join(data_dir.format(env_name), 'observations_train.npy'),
            dtype=np.uint8, mode='r', shape=(n_train, *obs_shape),
        )
        for env_name in env_list
    }
    val_data = {
        env_name: np.memmap(
            os.path.join(data_dir.format(env_name), 'observations_val.npy'),
            dtype=np.uint8, mode='r', shape=(n_val, *obs_shape),
        )
        for env_name in env_list
    }

    model, optimizer, scaler, _, loss_log = build_model(save_dir, device)

    # Time-derived seed: the last nine characters of the formatted timestamp,
    # reversed, so the fastest-changing digits come first.
    seed = int(f'{time():.10f}'[-9:][::-1])
    generator = np.random.RandomState(seed)

    # Plain Python counter: comparing a CUDA tensor every step would force a
    # device synchronisation per iteration.
    early_stop = 0
    best_val_loss = 1e9
    iter_total = 10000000
    val_loss = 1e9
    for step in trange(iter_total, desc='VAE training', leave=False):

        # One random frame per environment forms the batch. The uint8 data is
        # passed through as-is; transform() produces float32 either way, so
        # widening to int64 first only inflated the host-to-device copy.
        images = np.stack([train_data[env_name][generator.randint(n_train)] for env_name in env_list])
        assert images.shape[2] == 128
        images = transform(images, device)  # B, 3, H, W scaled to [-1, 1]

        # Mixed precision: only the forward pass and the loss run under
        # autocast.
        with ctx:
            recon_img, mu, log_var = model(images)
            kl_loss = kl_loss_f(mu, log_var)
            mse_loss = F.mse_loss(recon_img, images)
            loss = kl_loss + mse_loss

        # Backward, clipping and the optimizer step run outside autocast.
        # Gradients are unscaled before clipping so the threshold applies to
        # their true magnitudes.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
        scaler.step(optimizer)
        scaler.update()

        if (step + 1) % 1000 == 0:
            model.eval()
            val_losses = []
            with torch.no_grad():
                for _ in range(100):
                    val_images = np.stack([val_data[env_name][generator.randint(n_val)] for env_name in env_list])
                    val_images = transform(val_images, device)
                    with ctx:
                        recon_img, mu, log_var = model(val_images)
                        val_losses.append(F.mse_loss(recon_img, val_images).item())
            # Stored as a built-in float so the checkpointed loss log does
            # not contain NumPy scalar objects.
            val_loss = float(np.mean(val_losses))

            loss_log.append(val_loss)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                save_model('ckpt.pt')
                # Reconstructions stacked above their inputs (along dim 2)
                # for the last validation batch.
                img_cat = torch.cat((recon_img.cpu(), val_images.cpu()), 2).float()
                vutils.save_image(img_cat, "{}/recon_images_val.png".format(save_dir), normalize=True)
                early_stop = 0
            else:
                early_stop += 1
            model.train()

            # The counter only changes during validation, so the stop check
            # lives here instead of running every training step.
            if early_stop >= 100:
                break