#!/bin/python3
import sys

sys.path.append("..")

import os
from pathlib import Path

import torch
from torch.utils.data import DataLoader, Subset
from torchvision.transforms import transforms
from tqdm import tqdm

from denoising_diffusion_pytorch.denoising_diffusion_pytorch import (
    DropoutDiffusion,
    DropoutUnet,
)
from luna import Luna

# Path to the directory containing the LUNA dataset *.npy files generated by
# create_luna_dataset.py. Either set it here, or override it without editing
# this file by exporting the LUNA_PATH environment variable. The value
# (default: empty string) is passed to Luna unchanged.
PATH = os.environ.get("LUNA_PATH", "")

if __name__ == "__main__":
    # Print the default RNG seed so a run can be reproduced later.
    print("SEED:", torch.random.initial_seed())
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    write_dir = Path("weights")
    write_dir.mkdir(exist_ok=True, parents=True)

    # Experiment settings
    N = 1000  # train on the first N items of the dataset
    B = 32  # batch size
    T = 100  # number of diffusion timesteps (t is sampled from [0, T))
    R = 128  # images are resized to R x R
    num_epochs = 400

    tfm = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize((R, R), antialias=True),
        ]
    )
    dset = Luna(PATH, tfm)
    train_set = Subset(dset, range(N))
    loader = DataLoader(train_set, batch_size=B, shuffle=True)

    backbone = DropoutUnet(
        dim=16,
        dim_mults=(1, 2, 4, 8),
        channels=1,
        self_condition=True,
        dropout_rate=0.3,
    )
    model = DropoutDiffusion(backbone, image_size=R, timesteps=T).to(device)
    # NOTE(review): presumably initializes the RNG used for dropout sampling
    # inside DropoutDiffusion; confirm against its implementation.
    model.init_rng()
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    pbar = tqdm(range(num_epochs))
    for _ in pbar:
        # NOTE(review): y is passed to p_losses as its third argument;
        # presumably a conditioning input. Confirm its meaning in Luna.
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)

            # Draw one timestep per example directly on the target device,
            # avoiding a host-to-device copy on every step.
            t = torch.randint(0, T, (len(x),), device=device)
            loss = model.p_losses(x, t, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Shows the loss of the most recent batch only.
            pbar.set_description(f"Loss: {loss.item():.3f}")

    # Weights are written once, after the final epoch.
    torch.save(model.state_dict(), write_dir / "mcdropout.pt")
