from tqdm import tqdm

import wandb

import PIL
from matplotlib import pyplot as plt

import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, ConcatDataset, random_split
from torchvision.transforms.v2.functional import to_pil_image
from torchvision.utils import save_image
from torchmetrics import MetricCollection
from torchmetrics.regression import MeanSquaredError, MeanAbsoluteError, MeanAbsolutePercentageError

import utils
from data.navier_stokes_dataset import NavierStokesDataset
from models.twod_unet import Unet
from models.twod_unet_cond import Unet as UnetCond

def train():
    """Train a 2D U-Net to predict the next Navier-Stokes smoke frame.

    The run is configured entirely by the constants at the top of this
    function. For every training batch the model takes one optimizer step
    per timestep ``t`` in ``range(8, 64)``, mapping the standardized smoke
    field at ``t - 1`` to the field at ``t``. The loss and the torchmetrics
    collection are evaluated on un-standardized values. After each epoch the
    mean step loss and the aggregated metrics are logged to Weights & Biases;
    a single checkpoint is written after the final epoch.
    """
    data_type = "navier-stokes"
    device_str = "cuda"
    model_type = "unet_cond"
    # Switch to "disabled" to run without reporting to Weights & Biases.
    wandb_mode = "online"
    wandb_name = "train-unet-time-emb-ds-ns-64x64-bs-8-1"
    config = {
        "data_type": data_type,
        "history_len": 1,
        "prediction_len": 1
        }
    wandb.init(project="diff-sci-ml", config=config, group=f"train-{data_type}", mode=wandb_mode, name=wandb_name)
    print("Starting training")
    dev = torch.device(device_str)
    ds_path = "../data/ns_data_64_1000.h5"
    print(f"Loading dataset file {ds_path}")
    dataset = NavierStokesDataset(ds_path)
    # Fixed seed keeps the 40/40/10/10 split reproducible between runs.
    # Only the train split is consumed below; the others are carved out so
    # the partition stays stable for other scripts (presumably — verify).
    train_dataset, rl_dataset, val_dataset, test_dataset = random_split(
        dataset, [0.40, 0.40, 0.10, 0.10], generator=torch.Generator().manual_seed(42)
    )

    # create model
    # "*_cond" variants take the timestep as a second forward() argument;
    # "*attn*" variants enable attention at every resolution level.
    # Any unrecognised model_type falls back to the plain, attention-free Unet.
    uses_time_cond = model_type in {"unet_cond", "unet_attn_cond"}
    uses_attention = model_type in {"unet_attn_cond", "unet_attn"}
    model_cls = UnetCond if uses_time_cond else Unet
    model = model_cls(
        1, 0, 1, 0, 1, 1, 64, "gelu", ch_mults=[1, 2, 2], is_attn=[uses_attention] * 3
    ).to(dev)
    optimizer = Adam(model.parameters(), lr=0.0001)

    # dataset statistics for 64x64 (min/max are recorded for reference only):
    #   smoke_min = 4.5986561758581956e-07
    #   smoke_max = 2.121663808822632
    smoke_mean = 0.7322910149367649
    smoke_std = 0.43668617344401106

    train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, pin_memory=True)
    # TODO: no validation pass is run yet; the loader and metrics are prepared for it.
    val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=False, pin_memory=True)
    metrics = MetricCollection([
        MeanSquaredError(), MeanAbsoluteError(), MeanAbsolutePercentageError()
    ])
    train_metrics = metrics.clone(prefix="train_").to(dev)
    val_metrics = metrics.clone(prefix="val_").to(dev)

    epochs = 50
    for epoch in tqdm(range(epochs), unit="epoch"):
        # Accumulate the loss as a plain Python float so no autograd history
        # is kept alive across optimizer steps.
        total_mse = 0.0
        num_steps = 0
        for batch in tqdm(train_dataloader, unit="batch"):
            smoke = batch["smoke"].unsqueeze(2).to(dev)
            # standardize data
            smoke = (smoke - smoke_mean) / smoke_std
            # smoke is presumably [B, T, 1, H, W] after the unsqueeze — TODO confirm
            # TODO: make this range dynamic
            for t in range(8, 64):
                inputs = smoke[:, t - 1, :, :, :].unsqueeze(1)
                # set ground truth to next timestep smoke field
                gt = smoke[:, t, :, :, :].unsqueeze(1)
                if uses_time_cond:
                    timestep = torch.full((inputs.shape[0],), t, device=dev)
                    output = model(inputs, timestep)
                else:
                    # input should be B x T x C x H x W
                    output = model(inputs)
                # Undo the standardization once so that the loss and the
                # metrics are both expressed in the scale of the original data.
                pred = (output * smoke_std) + smoke_mean
                target = (gt * smoke_std) + smoke_mean
                loss = F.mse_loss(pred, target)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                total_mse += loss.item()
                num_steps += 1
                # Metrics only need values, not gradients.
                train_metrics(pred.detach(), target)
            # compute() aggregates everything seen since the last reset(),
            # i.e. this is the running value for the epoch so far.
            batch_mets = train_metrics.compute()
            print(f"batch metrics: {batch_mets}")
        epoch_mets = train_metrics.compute()
        wandb.log({
            "epoch": epoch,
            # Mean over optimizer steps (one per timestep per batch);
            # max() guards against an empty dataloader.
            "train_loss": total_mse / max(num_steps, 1),
            "train_mse": epoch_mets['train_MeanSquaredError'],
            "train_mae": epoch_mets['train_MeanAbsoluteError'],
            "train_mape": epoch_mets['train_MeanAbsolutePercentageError'],
            })
        # reset metric states after each epoch
        train_metrics.reset()
        val_metrics.reset()
    utils.save_model_checkpoint(epoch, device_str, model, optimizer, "./", f"unetmod_cond_time_emb_shuffle_800_{data_type}_64x64")
    wandb.finish()

# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    train()
