#!/bin/python3
import sys

sys.path.append("..")
import os
from pathlib import Path

import torch
from hyperdiffusion import FrozenUnet, HyperDiffusion
from torch.utils.data import DataLoader, Subset
from torchvision.transforms import transforms
from tqdm import tqdm

from era5 import Era5

# Directory containing the ERA5 *.npy files produced by create_era5_data.py.
# It can be overridden by passing the directory as the first command-line argument.
PATH = ""

if __name__ == "__main__":
    # Log the default (CPU) RNG seed of this run.
    print("SEED:", torch.random.initial_seed())
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    write_dir = Path("weights")
    write_dir.mkdir(exist_ok=True, parents=True)

    # Experiment settings
    N = 1000  # number of samples taken from the start of the dataset for training
    B = 32  # batch size
    T = 100  # number of diffusion timesteps (also the range timesteps are sampled from)
    R = 128  # resolution: images are resized to R x R
    num_epochs = 400
    # Passed to HyperDiffusion as n_params; presumably must match the backbone's
    # parameter count — NOTE(review): confirm against FrozenUnet.
    n_params = 2440241

    # Prefer a directory given on the command line; otherwise fall back to PATH.
    data_dir = sys.argv[1] if len(sys.argv) > 1 else PATH

    tfm = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize((R, R), antialias=True),
        ]
    )
    dset = Era5(data_dir, tfm)
    train_set = Subset(dset, range(N))
    loader = DataLoader(train_set, batch_size=B, shuffle=True)

    backbone = FrozenUnet(
        dim=16, dim_mults=(1, 2, 4, 8), channels=1, self_condition=True, device=device
    )
    model = HyperDiffusion(backbone, image_size=R, timesteps=T, n_params=n_params).to(
        device
    )
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    pbar = tqdm(range(num_epochs))
    for _ in pbar:
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)

            # One timestep drawn uniformly from [0, T) for each sample in the batch.
            t = torch.randint(0, T, (len(x),)).to(device)
            # A fresh Gaussian vector of length model.in_dim is drawn per batch.
            # NOTE(review): its role is defined inside HyperDiffusion.p_losses — confirm.
            in_vec = torch.randn(model.in_dim, device=device)
            loss = model.p_losses(x, t, y, in_vec)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Displays the loss of the most recent batch, not an epoch average.
            pbar.set_description(f"Loss: {loss.item():.3f}")

    # Save the trained weights once, after all epochs have finished.
    torch.save(model.state_dict(), write_dir / "hyperddpm.pt")
