import math
from collections import deque
from time import time
from pathlib import Path
from tqdm import tqdm
import numpy as np

import wandb

from PIL import Image
from matplotlib import pyplot as plt
import h5py

from scipy.stats import wasserstein_distance

import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset, RandomSampler, Subset
from torch.distributions import Categorical, Bernoulli
from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts
from torchvision.transforms.v2.functional import to_pil_image
from torchmetrics import MetricCollection
from torchmetrics.regression import MeanSquaredError, MeanAbsoluteError, MeanAbsolutePercentageError
from torchmetrics.functional import symmetric_mean_absolute_percentage_error

from phi.torch.flow import (  # SoftGeometryMask,; Sphere,; batch,; tensor,
    Box,
    CenteredGrid,
    Noise,
    StaggeredGrid,
    advect,
    diffuse,
    extrapolation,
    fluid,
    jit_compile,
    spatial,
    dual,
    channel,
)
from phi.math import reshaped_native
from phi.math import tensor as phi_tensor
from phi.math import stack as phi_stack

import utils
from data.navier_stokes_dataset import NavierStokesDataset
from models.twod_unet import Unet
from models.twod_unet_cond import Unet as UnetCond
from models.resnet import resnet34, resnet34_cond

def train():
    """Train the actor that chooses, per timestep, between simulator and surrogate.

    All settings are hard-coded below. Outline:

    * Each trajectory of the RL split is rolled out from the state at timestep
      ``ST - 1`` over timesteps ``ST .. FT - 1``. Every ``K`` timesteps the actor
      (``resnet34_cond``) sees the current prediction and the initial state and
      outputs ``K`` Bernoulli probabilities. Action 1 advances the state with one
      PhiFlow step (``call_sim``); action 0 advances it with the surrogate
      U-Net (``call_model``). The per-step reward is the negative MSE of the
      normalized smoke field against the dataset.
    * A baseline rollout makes the same number of simulator calls at uniformly
      random timesteps. Advantages are the policy rewards minus the baseline
      rewards. The final-step advantage is scaled by
      ``final_step_advantage_penalty``, and every step is penalized by
      ``|percent_sim - lam| * sim_usage_penalty``.
    * The actor is updated with undiscounted reward-to-go REINFORCE.
    * Epoch metrics, and the same metrics prefixed with ``val_`` for the
      validation split, are logged to wandb.
    * The best-validation-advantage checkpoint and periodic checkpoints are
      written under ``../train-rl-ns-results-64x64/<run name>/``.
    """
    data_type = "ns-64"
    device_str = "cuda"
    epochs = 30
    # The rollout indexes one action per timestep, so only batch size 1 is supported.
    batch_size = 1
    lam = 0.30  # target fraction of timesteps that call the simulator
    learning_rate = 0.00001
    noise_level = 0.0
    ST, FT = 8, 28  # rollout predicts timesteps ST..FT-1 (20 timesteps)
    T = FT - ST
    K = 4  # number of timesteps to select actions for per actor query
    dt = 96.0 / 64.0
    buoyancy_factor = 0.5
    diffusion_coef = 0.01
    sim_usage_penalty = 1.0
    final_step_advantage_penalty = 1.0
    wandb_mode = "online"
    # wandb_mode = "disabled"
    wandb_name = f"train-rl-lam-{lam}-multistep-{K}-reinforce-simple-baseline-ns-64-lr-{learning_rate}-ts-20-final-advantage-penalty-1x-sim-penalty-1x-sim-calls-test-1"

    # Fail fast on configurations the rollout code cannot handle.
    if T % K != 0:
        raise ValueError(f"rollout length T={T} must be a multiple of the action window K={K}")
    if batch_size != 1:
        raise ValueError(f"only batch_size=1 is supported, got {batch_size}")

    config = {
        "data_type": data_type,
        "history_len": 1,
        "prediction_len": 1
        }
    result_path = Path(f"../train-rl-ns-results-64x64/{wandb_name}/")
    result_path.mkdir(parents=True, exist_ok=True)
    wandb.init(project="diff-sci-ml", config=config, group=f"train-rl-{data_type}-64x64", mode=wandb_mode, name=wandb_name)
    dev = torch.device(device_str)

    # load dataset; the fixed seed keeps the split reproducible across runs
    ds_path = "../data/ns_data_64_1000.h5"
    print(f"Loading dataset file {ds_path}")
    dataset = NavierStokesDataset(ds_path, provide_velocity=True)
    _, rl_dataset, val_dataset, _ = random_split(dataset, [0.40, 0.40, 0.10, 0.10], generator=torch.Generator().manual_seed(42))

    actor_model = resnet34_cond(num_classes=K, param_conditioning=False).to(dev)
    actor_optimizer = Adam(actor_model.parameters(), lr=learning_rate)
    # load ml/surrogate model (frozen: only the actor is optimized)
    surr_model = UnetCond(1, 0, 1, 0, 1, 1, 64, "gelu", ch_mults=[1, 2, 2], is_attn=[False, False, False]).to(dev)
    surr_model = utils.load_model("./unetmod_cond_time_emb_shuffle_navier-stokes_64x64_model_checkpoint_49.pt", surr_model)
    surr_model.to(dev)
    surr_model.eval()

    # dataset statistics for 64x64 (min 4.5986561758581956e-07, max 2.121663808822632)
    smoke_mean = 0.7322910149367649
    smoke_std = 0.43668617344401106

    metric_names = ("reward", "percent_sim", "policy_loss", "advantage", "final_step_advantage", "advantage_final_penalty")

    def prepare_batch(batch):
        """Move a batch to ``dev``; return (smoke, velocity_x, velocity_y, smoke_norm)."""
        smoke = batch["smoke"].unsqueeze(2).to(dev)
        velocity_x = batch["velocity_x"].unsqueeze(2).to(dev)
        velocity_y = batch["velocity_y"].unsqueeze(2).to(dev)
        smoke_norm = (smoke - smoke_mean) / smoke_std
        return smoke, velocity_x, velocity_y, smoke_norm

    def surrogate_step(t, last_pred, smoke_norm):
        """Predict timestep ``t`` from ``last_pred`` with the surrogate; return (prediction, mse)."""
        model_inputs = last_pred
        while model_inputs.dim() < 5:
            model_inputs = model_inputs.unsqueeze(0)
        # The surrogate is frozen, so no autograd graph is needed for its forward pass.
        with torch.no_grad():
            output = call_model(t, model_inputs, surr_model, dev, noise_level)
        mse = F.mse_loss(output, smoke_norm[:, t].unsqueeze(1)).item()
        return output.detach(), mse

    def simulator_step(t, last_pred, smoke, velocity_x, velocity_y, smoke_norm):
        """Predict timestep ``t`` from ``last_pred`` with one PhiFlow step; return (prediction, mse).

        The solver runs on de-normalized smoke and the dataset velocity of
        timestep ``t - 1``. Its smoke output is re-normalized before scoring.
        """
        # CPU buffer with the shape and dtype of one timestep of the batch;
        # the 2-D solver output broadcasts into it.
        sim_smoke = torch.zeros_like(smoke[:, 0, ...], device='cpu')
        last_pred_unnorm = (last_pred.detach() * smoke_std + smoke_mean).clone()
        smoke_grid = torch_to_phi_centered(last_pred_unnorm.squeeze())
        velocity_grid = torch_to_phi_staggered(velocity_x[:, t-1].squeeze(), velocity_y[:, t-1].squeeze())
        smoke_grid, _ = call_sim(smoke_grid, velocity_grid, dt, buoyancy_factor, diffusion_coef)
        sim_smoke[:, ...] = phi_centered_to_torch(smoke_grid)
        sim_smoke = (sim_smoke.to(dev) - smoke_mean) / smoke_std
        mse = F.mse_loss(sim_smoke, smoke_norm[:, t, ...]).item()
        return sim_smoke.unsqueeze(1).detach(), mse

    def policy_rollout(smoke, velocity_x, velocity_y, smoke_norm, probs_accumulator):
        """Roll one trajectory forward, letting the actor choose simulator (1) or surrogate (0).

        Returns ``(rewards, log_probs, num_sim_calls)``:
        - ``rewards``: one negative normalized MSE per timestep.
        - ``log_probs``: the concatenated log-probabilities of the sampled actions.
        - ``num_sim_calls``: the number of simulator calls made.

        The actor's probabilities are also added into ``probs_accumulator``.
        """
        init_state = smoke_norm[:, ST-1].unsqueeze(1)
        last_pred = init_state
        rewards, log_probs = [], []
        num_sim_calls = 0.0
        for k in range(T // K):
            window_start = ST + k * K
            rl_input = torch.cat((last_pred.squeeze(1), init_state.squeeze(1)), dim=1)
            # call decision model with time embedding of the step before the window
            time_tensor = torch.full((rl_input.shape[0],), window_start - 1, device=dev)
            rl_probs = F.sigmoid(actor_model(rl_input, time_tensor)).squeeze()
            # Logging statistic only: detach so it does not keep every batch's graph alive.
            probs_accumulator[k*K:(k+1)*K] += rl_probs.detach()
            rl_dist = Bernoulli(probs=rl_probs)
            rl_action = rl_dist.sample()
            log_probs.append(rl_dist.log_prob(rl_action))
            for offset in range(K):
                t = window_start + offset
                if rl_action[offset] == 1:
                    num_sim_calls += 1
                    last_pred, mse = simulator_step(t, last_pred, smoke, velocity_x, velocity_y, smoke_norm)
                else:
                    last_pred, mse = surrogate_step(t, last_pred, smoke_norm)
                rewards.append(-mse)
        return rewards, torch.cat(log_probs, 0), num_sim_calls

    def baseline_rollout(smoke, velocity_x, velocity_y, smoke_norm, num_sim_calls):
        """Return per-step rewards of a rollout with ``num_sim_calls`` simulator calls at random timesteps."""
        last_pred = smoke_norm[:, ST-1].unsqueeze(1)
        sim_steps = set(np.random.default_rng().choice(range(ST, FT), int(num_sim_calls), replace=False).tolist())
        rewards = []
        with torch.no_grad():
            for t in range(ST, FT):
                if t in sim_steps:
                    last_pred, mse = simulator_step(t, last_pred, smoke, velocity_x, velocity_y, smoke_norm)
                else:
                    last_pred, mse = surrogate_step(t, last_pred, smoke_norm)
                rewards.append(-mse)
        return rewards

    def run_episode(batch, probs_accumulator):
        """Run the policy and baseline rollouts on one batch.

        Returns ``(policy_loss, metrics)``. ``policy_loss`` is the REINFORCE
        loss tensor. ``metrics`` holds the per-episode scalars, other than the
        loss, that are averaged into the epoch metrics.
        """
        smoke, velocity_x, velocity_y, smoke_norm = prepare_batch(batch)
        rewards, log_probs, num_sim_calls = policy_rollout(smoke, velocity_x, velocity_y, smoke_norm, probs_accumulator)
        percent_sim = num_sim_calls / T / smoke.shape[0]
        print("Starting baseline reward calculation")
        baseline_rewards = baseline_rollout(smoke, velocity_x, velocity_y, smoke_norm, num_sim_calls)
        # subtract baseline to get advantage
        advantages = [r - br for r, br in zip(rewards, baseline_rewards)]
        advantage = sum(advantages)
        final_step_advantage = advantages[-1]
        # penalize final step advantage more than others
        advantages[-1] *= final_step_advantage_penalty
        advantage_final_penalty = sum(advantages)
        # Penalize every step by how far THIS episode's sim usage is from lam.
        usage_penalty = np.abs(percent_sim - lam) * sim_usage_penalty
        final_rewards = [r - usage_penalty for r in advantages]
        # undiscounted reward-to-go
        returns, running = [], 0
        for r in reversed(final_rewards):
            running = running + r
            returns.append(running)
        returns.reverse()
        returns = torch.tensor(returns)
        if len(log_probs) != len(returns):
            raise RuntimeError(f"got {len(log_probs)} log-probs for {len(returns)} returns")
        policy_loss = sum(-log_prob * disc_return for log_prob, disc_return in zip(log_probs, returns))
        metrics = {
            "reward": sum(final_rewards),
            "percent_sim": percent_sim,
            "advantage": advantage,
            "final_step_advantage": final_step_advantage,
            "advantage_final_penalty": advantage_final_penalty,
        }
        return policy_loss, metrics

    def run_epoch(dataloader, desc, batch_label, optimizer=None):
        """Run every batch of ``dataloader``, stepping ``optimizer`` when given.

        Returns the per-batch averages of ``metric_names`` and the averaged
        actor probabilities per timestep.
        """
        totals = dict.fromkeys(metric_names, 0.0)
        probs_accumulator = torch.zeros(T, device=dev)
        for i, batch in enumerate(tqdm(dataloader, desc=desc, unit="batch")):
            print(f"{batch_label} batch {i}")
            policy_loss, metrics = run_episode(batch, probs_accumulator)
            if optimizer is not None:
                optimizer.zero_grad()
                policy_loss.backward()
                optimizer.step()
            metrics["policy_loss"] = policy_loss.item()
            for name, value in metrics.items():
                totals[name] += value
        num_batches = len(dataloader)
        averages = {name: total / num_batches for name, total in totals.items()}
        return averages, probs_accumulator / num_batches

    rl_dataloader = DataLoader(rl_dataset, batch_size, False, pin_memory=True)
    val_dataloader = DataLoader(val_dataset, 1, False, pin_memory=True)
    best_val_advantage = -math.inf
    for epoch in range(epochs):
        print(f"Starting epoch {epoch}")
        train_metrics, rl_probs_epoch = run_epoch(rl_dataloader, "RL Training", "Training", actor_optimizer)
        wandb.log({"epoch": epoch, **train_metrics}, step=epoch)
        print(f"epoch advantage: {train_metrics['advantage']}")
        # Bar chart of mean sim-call probabilities. Use a fresh figure each
        # epoch so bars from earlier epochs do not accumulate.
        fig, ax = plt.subplots()
        ax.bar(range(T), rl_probs_epoch.cpu().numpy())
        wandb.log({"policy_probs": fig}, step=epoch)
        plt.close(fig)
        # Validation loop
        with torch.no_grad():
            val_metrics, _ = run_epoch(val_dataloader, "Validation", "Validation")
        wandb.log({f"val_{name}": value for name, value in val_metrics.items()}, step=epoch)
        val_epoch_advantage = val_metrics["advantage"]
        if val_epoch_advantage > best_val_advantage:
            best_val_advantage = val_epoch_advantage
            print(f"Saving best models with val advantage {val_epoch_advantage}")
            utils.save_model_checkpoint(epoch, device_str, actor_model, actor_optimizer, result_path, f"actor-{lam}-reinforce-lr-{learning_rate}", suffix="best", overwrite=True)
        # save model every 5 epochs and on final epoch
        if epoch % 5 == 0 or (epoch == epochs-1):
            utils.save_model_checkpoint(epoch, device_str, actor_model, actor_optimizer, result_path, f"actor-{lam}-reinforce-lr-{learning_rate}")
    wandb.finish()

def convert_data_call_sim(sim_idx, t, smoke, velocity_x, velocity_y, smoke_mean, smoke_std, dt, buoyancy_factor, diffusion_coef):
    """Advance batch entry ``sim_idx`` from timestep ``t - 1`` by one solver step.

    Reads ``smoke``, ``velocity_x`` and ``velocity_y`` at ``[sim_idx, t - 1]``,
    converts them to PhiFlow grids and runs ``call_sim`` once.

    Returns:
        ``(smoke, velocity_x, velocity_y)`` as CPU tensors. Only the smoke is
        normalized with ``smoke_mean`` / ``smoke_std``; the velocities are
        returned as produced by the solver.
    """
    prev = t - 1
    grid_smoke, grid_velocity = call_sim(
        torch_to_phi_centered(smoke[sim_idx, prev].squeeze()),
        torch_to_phi_staggered(velocity_x[sim_idx, prev].squeeze(), velocity_y[sim_idx, prev].squeeze()),
        dt,
        buoyancy_factor,
        diffusion_coef,
    )
    normalized_smoke = (phi_centered_to_torch(grid_smoke) - smoke_mean) / smoke_std
    return (normalized_smoke, *phi_staggered_to_torch(grid_velocity))

def torch_to_phi_centered(data):
    """Wrap a 2-D torch tensor in (y, x) layout as a PhiFlow ``CenteredGrid``.

    The tensor is transposed to the (x, y) order declared by ``spatial('x,y')``.
    It is placed on a 32 x 32 box with ``extrapolation.BOUNDARY``.
    """
    domain = Box['x,y', 0 : 32.0, 0 : 32.0]
    xy_values = phi_tensor(data.transpose(1, 0), spatial('x,y'))
    return CenteredGrid(xy_values, extrapolation.BOUNDARY, domain)

def phi_centered_to_torch(data):
    """Convert a PhiFlow centered grid back to a 2-D torch tensor.

    The grid values are exported in (x, y) order and returned transposed,
    i.e. in (y, x) layout, as a CPU tensor built with ``torch.from_numpy``.
    """
    xy_values = data.values.numpy('x,y')
    return torch.from_numpy(xy_values).transpose(0, 1)

def torch_to_phi_staggered(data_x, data_y):
    """Wrap 2-D velocity components in (y, x) layout as a PhiFlow ``StaggeredGrid``.

    Each component is grown by one sample per axis: first its last column is
    repeated, then its last row, so the new corner takes the bottom-right
    value. ``phi_staggered_to_torch`` drops that extra row/column again.
    The padded components are then:

    * transposed to (x, y) order,
    * stacked along a ``vector`` channel (x first, then y),
    * placed on a 32 x 32 box with ``extrapolation.ZERO``.

    Slicing with ``[:, -1:]`` / ``[-1:, :]`` keeps the padding correct for
    non-square inputs too. The earlier ``reshape(shape[1], -1)`` form only
    worked when height == width.
    """
    def pad_last_col_and_row(component):
        # repeat the last column, then the last row of the widened tensor
        component = torch.cat((component, component[:, -1:]), 1)
        return torch.cat((component, component[-1:, :]), 0)

    data_x = pad_last_col_and_row(data_x)
    data_y = pad_last_col_and_row(data_y)
    # stack tensors for velocity vector field
    stacked = torch.stack((data_x.transpose(1, 0), data_y.transpose(1, 0)), dim=2)
    stacked_ten = phi_tensor(stacked, spatial('x,y'), channel('vector'))
    # create staggered grid
    phi_grid = StaggeredGrid(stacked_ten, extrapolation.ZERO, Box['x,y', 0 : 32.0, 0 : 32.0])
    return phi_grid

def phi_staggered_to_torch(data):
    """Convert a PhiFlow staggered velocity grid to a pair of torch tensors.

    The staggered values are exported in (x, y, vector) order. The trailing
    row and column are dropped, and each vector component is transposed to
    (y, x) layout.

    Returns:
        ``(field_x, field_y)`` CPU tensors built with ``torch.from_numpy``.
    """
    staggered = data.staggered_tensor().numpy('x,y,vector')
    trimmed = staggered[:-1, :-1]
    return tuple(torch.from_numpy(trimmed[..., axis].transpose(1, 0)) for axis in (0, 1))

# last_pred should be shaped B x T x C x H x W
def call_model(timestep: int, last_pred, model, dev, noise_level):
    """Run the time-conditioned surrogate ``model`` once at ``timestep``.

    ``timestep`` is repeated once per batch element (``last_pred.shape[0]``)
    and passed as the model's second argument. ``noise_level`` is accepted
    but currently unused.
    """
    batch = last_pred.shape[0]
    return model(last_pred, torch.full((batch,), timestep, device=dev))

@jit_compile
def call_sim(smoke_grid, velocity_grid, dt, buoyancy_factor, diffusion_coef):
    """Advance smoke and velocity by one PhiFlow step of size ``dt``.

    Steps, in order:

    1. Advect the smoke through ``velocity_grid`` (semi-Lagrangian).
    2. Build a buoyancy force from the advected smoke along the second vector
       component (presumably y), resampled onto the velocity grid.
    3. Self-advect the velocity and add ``dt`` times that force.
    4. Apply explicit diffusion with ``diffusion_coef``.
    5. Pass the velocity through ``fluid.make_incompressible``; its second
       return value is discarded.

    Returns:
        ``(smoke_grid, velocity_grid)`` after the step.
    """
    # inputs are already PhiFlow grids (built by torch_to_phi_centered / torch_to_phi_staggered)
    smoke_grid = advect.semi_lagrangian(smoke_grid, velocity_grid, dt) # train() passes dt = 96 / 64 = 1.5
    buoyancy_force = (smoke_grid * (0, buoyancy_factor)).at(velocity_grid)  # resamples smoke to velocity sample points
    velocity_grid = advect.semi_lagrangian(velocity_grid, velocity_grid, dt) + dt * buoyancy_force
    velocity_grid = diffuse.explicit(velocity_grid, diffusion_coef, dt)
    velocity_grid, _ = fluid.make_incompressible(velocity_grid)
    return smoke_grid, velocity_grid

if __name__ == "__main__":
    # Running this file as a script starts training with the settings hard-coded in train().
    train()



