from tqdm import tqdm

import wandb

import PIL
from matplotlib import pyplot as plt

import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, ConcatDataset, random_split
from torchvision.transforms.v2.functional import to_pil_image
from torchvision.utils import save_image
from torchmetrics import MetricCollection
from torchmetrics.regression import MeanSquaredError, MeanAbsoluteError, MeanAbsolutePercentageError

import utils
from data.navier_stokes_dataset import NavierStokesDataset
from models.fno import FNO2d
from neuralop.models import FNO

def train():
    data_type = "navier-stokes"
    device_str = "cuda"
    model_type = "fno"
    # wandb_mode = "disabled"
    wandb_mode = "online"
    wandb_name = "train-fno-orig-time-emb-ds-ns-64x64-bs-8-1"
    config = {
        "data_type": data_type,
        "history_len": 1,
        "prediction_len": 1
        }
    run = wandb.init(project="diff-sci-ml", config=config, group=f"train-{data_type}", mode=wandb_mode, name=wandb_name)
    ST, FT = 8, 64
    T = FT-ST
    print("Starting training")
    dev = torch.device(device_str)
    ds_path = "../data/ns_data_64_1000.h5"
    print(f"Loading dataset file {ds_path}")
    dataset = NavierStokesDataset(ds_path)
    train_dataset, rl_dataset, val_dataset, test_dataset = random_split(dataset, [0.40, 0.40, 0.10, 0.10], generator=torch.Generator().manual_seed(42))
    large_dataset = ConcatDataset((train_dataset, rl_dataset))
    # create model with similar number of parameters as unet
    model = FNO(n_modes=(27, 27), hidden_channels=64, in_channels=2, out_channels=1)
    model = model.to(dev)
    model.train()
    optimizer = Adam(model.parameters(), lr=0.00001)
    # dataset statistics for 64x64:
    smoke_mean = 0.7322910149367649
    smoke_std = 0.43668617344401106
    smoke_min = 4.5986561758581956e-07
    smoke_max = 2.121663808822632
    batch_size = 8
    train_dataloader = DataLoader(train_dataset, batch_size, True, pin_memory=True)
    val_dataloader = DataLoader(val_dataset, batch_size, False, pin_memory=True)
    metrics = MetricCollection([
        MeanSquaredError(), MeanAbsoluteError(), MeanAbsolutePercentageError()
    ])
    train_metrics = metrics.clone(prefix="train_").to(dev)
    val_metrics = metrics.clone(prefix="val_").to(dev)
    total_mse = 0.0
    epochs = 50
    for epoch in tqdm(range(epochs), unit="epoch"):
        for i, batch in enumerate(tqdm(train_dataloader, unit="batch")):
            sim_ids = batch["sim_id"].to(dev)
            smoke = batch["smoke"].unsqueeze(2).to(dev)
            # standardize data
            smoke = (smoke - smoke_mean) / smoke_std
            # smoke shape [8, 64, 1, 128, 128]
            # make this range dynamic
            for t in range(8, 64):
                inputs = smoke[:, t-1, :, :, :]
                # for now scale static time to be between 0 and 1
                time_map = torch.full((batch_size, inputs.shape[1], inputs.shape[2], inputs.shape[3]), t/FT, dtype=torch.float, device=dev)
                inputs = torch.cat((inputs, time_map), dim=1)
                # set ground truth to next timestep smoke field
                gt = smoke[:, t, :, :, :]
                output = model(inputs)
                # compute loss, note that we unstandardize the data here to get back to the scale of the original data
                loss = F.mse_loss((output * smoke_std) + smoke_mean, (gt * smoke_std) + smoke_mean)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                total_mse += loss
                mets = train_metrics((output * smoke_std) + smoke_mean, (gt * smoke_std) + smoke_mean)
            batch_mets = train_metrics.compute()
            print(f"batch metrics: {batch_mets}")
        epoch_mets = train_metrics.compute()
        wandb.log({
            "epoch": epoch,
            "train_loss": total_mse/(len(train_dataloader)),
            "train_mse": epoch_mets['train_MeanSquaredError'],
            "train_mae": epoch_mets['train_MeanAbsoluteError'],
            "train_mape": epoch_mets['train_MeanAbsolutePercentageError'],
            })
        # reset metric states after each epoch
        train_metrics.reset()
        val_metrics.reset()
        total_mse = 0.0
    utils.save_model_checkpoint(epoch, device_str, model, optimizer, "./", f"fno_time_static_shuffle_{data_type}_64x64")
    wandb.finish()

if __name__ == "__main__":
    # Script entry point: training runs only when executed directly, not on import.
    train()
