
from pathlib import Path
import os
from time import time

from tqdm import tqdm

import numpy as np
import pandas as pd
import xarray as xr

import matplotlib.colors as colors
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, ConcatDataset, random_split
from torch.distributions import Categorical, Bernoulli
from torchmetrics import MetricCollection
from torchmetrics.regression import MeanSquaredError, MeanAbsoluteError, MeanAbsolutePercentageError

import h5py
import wandb

from phi.torch.flow import (  # SoftGeometryMask,; Sphere,; batch,; tensor,
    Box,
    CenteredGrid,
    Noise,
    StaggeredGrid,
    advect,
    diffuse,
    extrapolation,
    fluid,
    jit_compile,
    spatial,
    channel,
)
from phi.math import tensor as phi_tensor

import utils
from data.navier_stokes_dataset import NavierStokesDataset
from models.twod_unet import Unet
from models.twod_unet_cond import Unet as UnetCond
from neuralop.models import FNO
from models.resnet import resnet18, resnet18_cond, resnet34, resnet34_cond

def test(model_type: str, data_type: str, device: str, epoch: int) -> None:
    """Evaluate a trained RL actor that switches between the UNet surrogate and the phiflow solver.

    For every trajectory of the concatenated test+val split of the Navier-Stokes
    dataset, the actor is queried once every ``S`` timesteps and samples one
    Bernoulli action per step for the next ``S`` steps (0 -> call the UNet
    surrogate, 1 -> call the phiflow solver via ``call_sim``). Alongside the RL
    rollout, a UNet-only and an FNO-only rollout are unrolled from the same
    initial state. Then ``K`` random-policy rollouts are run that use the same
    number of simulator calls as the RL rollout, placed at random timesteps.

    Per-step MSEs (normalized and original scale), wall-clock timings and, for
    the trajectories in ``plot_idxs``, full original-scale fields are written to
    ``<result_path>/metrics.h5``. Plotted trajectories also get figures (via
    ``utils.save_traj_ns_custom``) and a video (via ``make_video``).

    Args:
        model_type: Not used by the visible code (the actor is always built
            with ``resnet34_cond``).
        data_type: Only used for the wandb group name and config.
        device: Passed to ``torch.device``; models and batch tensors are moved to it.
        epoch: Not used by the visible code.

    Raises:
        FileExistsError: If the result directory for ``wandb_name`` already
            exists (``mkdir(exist_ok=False)``).

    NOTE(review): batch size must stay 1 - several squeeze/indexing operations
    below assume a single trajectory per batch.
    """
    print("Starting test")
    dev = torch.device(device)
    # NOTE(review): lam and learning_rate are only used to build wandb_name below;
    # presumably they should match the hyperparameters of the loaded actor checkpoint.
    lam = 0.30
    learning_rate = 0.00001
    ST, FT = 8, 28 # 20 timesteps
    T = FT - ST
    # calculate dt for each timestep
    dt = 96.0 / 64.0
    buoyancy_factor = 0.5
    diffusion_coef = 0.01
    noise_level = 0.0
    K = 10 # number of times to run random baseline
    S = 4 # number of timesteps to select actions for
    wandb_group = f"test-rl-{data_type}-64x64"
    wandb_name = f"test-rl-multistep-{S}-reinforce-ns-64-lam-{lam}-lr-{learning_rate}-epoch-best-val-bs-1-ts-{T}-simple-baseline-final-advantage-penalty-1x-sim-penalty-1x-resnet34-best-testval-set-1"
    # wandb_mode = "online"
    wandb_mode = "disabled"
    config = {
        "data_type": data_type,
        "history_len": 1,
        "prediction_len": 1
        }
    run = wandb.init(project="diff-sci-ml", config=config, group=wandb_group, mode=wandb_mode, name=wandb_name)
    result_path = Path(f"../rl-test-results-ns/{wandb_name}")
    # exist_ok=False: refuse to overwrite results of a previous run with the same wandb_name
    result_path.mkdir(parents=True, exist_ok=False)
    # load dataset
    ds_path = "../data/ns_data_64_1000.h5"
    print(f"Loading dataset file {ds_path}")
    dataset = NavierStokesDataset(ds_path, provide_velocity=True)
    # fixed generator seed makes the 40/40/10/10 split deterministic across runs
    train_dataset, rl_dataset, val_dataset, test_dataset = random_split(dataset, [0.40, 0.40, 0.10, 0.10], generator=torch.Generator().manual_seed(42))
    # combine test and val datasets
    test_val_dataset = ConcatDataset([test_dataset, val_dataset])
    # dataset statistics for 64x64:
    smoke_mean = 0.7322910149367649
    smoke_std = 0.43668617344401106
    smoke_min = 4.5986561758581956e-07
    smoke_max = 2.121663808822632
    # NOTE(review): train_dataloader, rl_dataloader and test_dataloader are never iterated below;
    # only test_val_dataloader is used.
    train_dataloader = DataLoader(train_dataset, 1, False, pin_memory=True)
    rl_dataloader = DataLoader(rl_dataset, 1, False, pin_memory=True)
    # NOTE our batch size needs to be 1 here because the code below is written to test only one trajectory at a time
    test_dataloader = DataLoader(test_dataset, 1, False, pin_memory=True)
    test_val_dataloader = DataLoader(test_val_dataset, 1, False, pin_memory=True)
    # load unet surrogate model
    model = UnetCond(1, 0, 1, 0, 1, 1, 64, "gelu", ch_mults=[1, 2, 2], is_attn=[False, False, False]).to(dev)
    model = utils.load_model(f"./unetmod_cond_time_emb_shuffle_navier-stokes_64x64_model_checkpoint_49.pt", model)
    model.to(dev)
    model.eval()
    # load fno surrogate model
    # in_channels=2: the previous smoke state plus a constant time channel (see fno_time_map below)
    fno_model = FNO(n_modes=(27, 27), hidden_channels=64, in_channels=2, out_channels=1)
    fno_model = utils.load_model("./fno_time_static_shuffle_navier-stokes_64x64_model_checkpoint_49.pt", fno_model)
    fno_model = fno_model.to(dev)
    fno_model.eval()
    # load actor
    # num_classes=S: one logit per step of the S-step decision window (indexed by t-start_time below)
    actor_model = resnet34_cond(num_classes=S, param_conditioning=False).to(dev)
    model_path = "../train-rl-ns-results-64x64/train-rl-lam-0.3-multistep-4-reinforce-simple-baseline-ns-64-lr-1e-05-ts-20-final-advantage-penalty-1x-sim-penalty-1x-test-1/actor-0.3-reinforce-lr-1e-05_model_checkpoint_best.pt"
    print(f"Loading model path {model_path}")
    actor_model = utils.load_model(model_path, actor_model)
    actor_model.to(dev)
    actor_model = actor_model.eval()
    print("Models loaded")
    # per-trajectory metrics; one entry is appended per trajectory
    id_action_trajectories = []
    id_mses = []
    id_orig_mses = []
    id_surr_mses, id_surr_orig_mses = [], []
    id_fno_mses, id_fno_orig_mses = [], []
    id_sim_calls = []
    id_entropies = []
    id_log_odds = []
    id_final_mses, id_cumulative_mses = [], []
    # field buffers, only filled for trajectories in plot_idxs
    # NOTE(review): id_fno_preds is never filled nor written to metrics.h5
    id_gts, id_unet_preds, id_fno_preds, id_baseline_preds, id_rl_preds = [], [], [], [], []
    unet_times, fno_times, rl_times = [], [], []
    # wall-clock duration of every individual simulator call in the RL rollouts
    sim_call_times = []
    # random-baseline metrics indexed [trajectory, k, timestep - ST]
    all_baseline_mses = np.zeros((len(test_val_dataset), K, T))
    all_orig_baseline_mses = np.zeros((len(test_val_dataset), K, T))
    baseline_final_mses, baseline_cumulative_mses = [], []
    # 1.0 where the k-th random rollout called the simulator at that timestep, else 0.0
    baseline_actions = np.zeros((len(test_val_dataset), K, T))
    # indices to make plots for
    plot_idxs = [0, 1, 2, 3, 4, 5]
    with torch.no_grad():
        # right now batch size should remain 1 for testing
        for i, batch in enumerate(test_val_dataloader):
            print(f"Starting batch/trajectory {i}")
            unet_time, fno_time, rl_time = 0., 0., 0.
            # load data to devices
            sim_ids = batch["sim_id"].to(dev)
            # unsqueeze(2) adds a channel axis; presumably smoke becomes B x T x 1 x H x W (TODO confirm dataset layout)
            smoke = batch["smoke"].unsqueeze(2).to(dev)
            velocity_x = batch["velocity_x"].unsqueeze(2).to(dev)
            velocity_y = batch["velocity_y"].unsqueeze(2).to(dev)
            local_batch_size = smoke.shape[0]
            # standardize data
            smoke_norm = (smoke - smoke_mean) / smoke_std
            # NOTE(review): sim_ids, local_batch_size, smoke_norm_min and smoke_norm_max are not used afterwards
            smoke_norm_min = smoke_norm.min()
            smoke_norm_max = smoke_norm.max()
            # init smoke prediction (5-D after unsqueeze(1)); redundant - repeated identically below
            last_pred = smoke_norm[:, ST-1, :, :, :].unsqueeze(1)
            init_state = smoke_norm[:, ST-1, :, :, :].unsqueeze(1)
            # make output directory for plotting
            if i in plot_idxs:
                output_path = Path(f"{result_path}/traj-{i}")
                output_path.mkdir(parents=True, exist_ok=True)
            # NOTE(review): num_model_calls and num_sim_calls_ten are updated but never stored or reported
            num_sim_calls, num_model_calls = 0.0, 0.0
            num_sim_calls_ten = torch.tensor([0.0], device=dev)
            # NOTE(review): entropies and log_odds are never appended to, so the "entropies" and
            # "log_odds" datasets in metrics.h5 only contain empty per-trajectory lists
            action_trajectory, mses, orig_mses, surr_mses, surr_orig_mses, sim_calls, entropies, log_odds, fno_mses, fno_orig_mses = [], [], [], [], [], [], [], [], [], []
            # init smoke prediction
            last_pred = smoke_norm[:, ST-1, :, :, :].unsqueeze(1)
            init_state = smoke_norm[:, ST-1, :, :, :].unsqueeze(1)
            last_pred_surr = smoke_norm[:, ST-1, :, :, :].unsqueeze(1)
            # the FNO state has no time axis (4-D), unlike the 5-D UNet inputs
            last_pred_fno = smoke_norm[:, ST-1, :, :, :]
            start_time = ST
            # create tensors for plotting
            # CPU buffers of shape (T, H, W) holding original-scale fields
            gts = torch.zeros((T, smoke.shape[3], smoke.shape[4]))
            surr_preds = torch.zeros((T, smoke.shape[3], smoke.shape[4]))
            rl_preds = torch.zeros((T, smoke.shape[3], smoke.shape[4]))
            baseline_preds = torch.zeros((T, smoke.shape[3], smoke.shape[4]))
            # one actor decision every S steps
            # NOTE(review): int(T/S) truncates; if T is not a multiple of S the last T % S steps are
            # never rolled out here, while the random baseline below always covers ST..FT-1
            for s in range(int(T/S)):
                while last_pred.dim() < 5:
                    last_pred = last_pred.unsqueeze(0)
                # actor input: current state and initial state concatenated along dim 1
                rl_input = torch.cat((last_pred.squeeze(1), init_state.squeeze(1)), dim=1)
                # make time embedding
                time_tensor = torch.full((rl_input.shape[0],), start_time-1, device=dev)
                # call decision model with time embedding
                rl_logits = actor_model(rl_input, time_tensor)
                rl_probs = F.sigmoid(rl_logits)
                # squeeze() relies on batch size 1; presumably with S == 1 it would give a 0-d
                # tensor and the rl_action[t-start_time] indexing below would fail
                rl_dist = Bernoulli(probs=rl_probs.squeeze())
                # sample rl action
                # one action per step in [start_time, start_time+S): 0 -> UNet surrogate, 1 -> phiflow sim
                rl_action = rl_dist.sample()
                action_trajectory.append(rl_action.tolist())
                print(f"Running rl model on traj {i}")
                for t in range(start_time, start_time+S):
                    # always call surrogate to get metrics for surrogate only model
                    model_inputs = last_pred_surr
                    while model_inputs.dim() < 5:
                        model_inputs = model_inputs.unsqueeze(0)
                    model_gts = smoke_norm[:, t, :, :, :].unsqueeze(1)
                    call_time_start = time()
                    surr_output = call_model(t, model_inputs, model, dev, noise_level)
                    call_time_end = time()
                    unet_time += (call_time_end - call_time_start)
                    surr_mse = F.mse_loss(surr_output, model_gts)
                    surr_orig_mse = F.mse_loss((surr_output * smoke_std) + smoke_mean, (model_gts * smoke_std) + smoke_mean)
                    surr_mses.append(surr_mse.item())
                    surr_orig_mses.append(surr_orig_mse.item())
                    # set new last pred surr
                    last_pred_surr = surr_output
                    # call fno model
                    fno_model_inputs = last_pred_fno
                    # for now scale static time to be between 0 and 1 for fno
                    fno_time_map = torch.full((1, fno_model_inputs.shape[1], fno_model_inputs.shape[2], fno_model_inputs.shape[3]), t/FT, dtype=torch.float, device=dev)
                    fno_model_inputs = torch.cat((fno_model_inputs, fno_time_map), dim=1)
                    # noise free fno call
                    call_time_start = time()
                    fno_output = fno_model(fno_model_inputs)
                    call_time_end = time()
                    fno_time += (call_time_end - call_time_start)
                    fno_mse = F.mse_loss(fno_output, model_gts.squeeze(0))
                    fno_orig_mse = F.mse_loss((fno_output * smoke_std) + smoke_mean, (model_gts.squeeze(0) * smoke_std) + smoke_mean)
                    fno_mses.append(fno_mse.item())
                    fno_orig_mses.append(fno_orig_mse.item())
                    last_pred_fno = fno_output
                    # call surrogate model if required
                    if rl_action[t-start_time] == 0:
                        num_model_calls += 1
                        model_inputs = last_pred
                        while model_inputs.dim() < 5:
                            model_inputs = model_inputs.unsqueeze(0)
                        model_gts = smoke_norm[:, t, :, :, :].unsqueeze(1)
                        call_time_start = time()
                        output = call_model(t, model_inputs, model, dev, noise_level)
                        call_time_end = time()
                        rl_time += (call_time_end - call_time_start)
                        mse = F.mse_loss(output, model_gts)
                        orig_mse = F.mse_loss((output * smoke_std) + smoke_mean, (model_gts * smoke_std) + smoke_mean)
                        mses.append(mse.item())
                        orig_mses.append(orig_mse.item())
                        # set last pred to current pred
                        last_pred = output.detach()
                    elif rl_action[t-start_time] == 1:
                        num_sim_calls += 1
                        num_sim_calls_ten = num_sim_calls_ten + 1.0
                        # convert tensors to phiflow grids
                        # the simulator is fed original-scale (unnormalized) smoke
                        last_pred_unnorm = (last_pred.detach() * smoke_std + smoke_mean).clone()
                        # create sim inputs
                        smoke_grid = torch_to_phi_centered(last_pred_unnorm.squeeze())
                        # velocity is taken from the dataset at t-1; only smoke is carried over from
                        # previous predictions, and the velocity returned by call_sim is discarded
                        velocity_grid = torch_to_phi_staggered(velocity_x[:, t-1].squeeze(), velocity_y[:, t-1].squeeze())
                        call_time_start = time()
                        smoke_grid, velocity_grid = call_sim(smoke_grid, velocity_grid, dt, buoyancy_factor, diffusion_coef)
                        call_time_end = time()
                        rl_time += (call_time_end - call_time_start)
                        sim_call_times.append((call_time_end - call_time_start))
                        # convert phiflow grids back to tensors
                        sim_output_smoke = phi_centered_to_torch(smoke_grid).to(dev)
                        # normalize smoke sim output
                        sim_output_smoke = (sim_output_smoke - smoke_mean) / smoke_std
                        # sim_output_smoke is 2-D here, so the ground truth is squeezed to match
                        mse = F.mse_loss(sim_output_smoke, smoke_norm[:, t, :, ...].squeeze())
                        orig_mse = F.mse_loss(sim_output_smoke * smoke_std + smoke_mean, smoke[:, t, :, ...].squeeze())
                        mses.append(mse.item())
                        orig_mses.append(orig_mse.item())
                        # set last pred to current sim pred
                        # (2-D; the while loops above re-expand it to 5-D before the next model/actor call)
                        last_pred = sim_output_smoke.detach()
                    sim_calls.append(num_sim_calls)
                    # save data for plotting
                    if i in plot_idxs:
                        # original scale preds
                        gts[t-ST] = (smoke_norm[:, t, ...] * smoke_std + smoke_mean).squeeze()
                        surr_preds[t-ST] = (last_pred_surr * smoke_std + smoke_mean).squeeze()
                        rl_preds[t-ST] = (last_pred * smoke_std + smoke_mean).squeeze()
                start_time = start_time + S
            print(f"Sim called {num_sim_calls} times")
            # save rl model metrics
            id_action_trajectories.append(action_trajectory)
            id_mses.append(mses)
            id_orig_mses.append(orig_mses)
            id_surr_mses.append(surr_mses)
            id_surr_orig_mses.append(surr_orig_mses)
            id_fno_mses.append(fno_mses)
            id_fno_orig_mses.append(fno_orig_mses)
            id_sim_calls.append(sim_calls)
            id_final_mses.append([mses[-1]])
            id_cumulative_mses.append([np.sum(mses, axis=0)])
            id_entropies.append(entropies)
            id_log_odds.append(log_odds)
            unet_times.append(unet_time)
            fno_times.append(fno_time)
            rl_times.append(rl_time)
            # run the same traj with a random policy k times with same number of sim calls
            print(f"Running random baseline on traj {i}")
            final_mses, cum_mses = [], []
            for k in range(K):
                # print(f"k = {k}")
                baseline_mses = []
                # unsqueeze(0) rather than unsqueeze(1) as above: equivalent only because batch size is 1
                last_pred = smoke_norm[:, ST-1, :, :, :].unsqueeze(0)
                # last_pred = smoke_norm[:, ST-K:ST, :, :, :]
                # NOTE(review): the generator is unseeded, so the random baseline differs between runs
                sim_call_choices = np.random.default_rng().choice(range(ST, FT), int(num_sim_calls), replace=False)
                # print(f"Sim call choices: {sim_call_choices}")
                for t in range(ST, FT):
                    if t not in sim_call_choices:
                        # call model
                        model_inputs = last_pred
                        model_gts = smoke_norm[:, t, :, :, :].unsqueeze(0)
                        output = call_model(t, model_inputs, model, dev, noise_level)
                        mse = F.mse_loss(output, model_gts)
                        orig_mse = F.mse_loss((output * smoke_std) + smoke_mean, (model_gts * smoke_std) + smoke_mean)
                        baseline_mses.append(mse.item())
                        all_baseline_mses[i, k, t-ST] = mse.item()
                        all_orig_baseline_mses[i, k, t-ST] = orig_mse.item()
                        # set last pred to current pred
                        last_pred = output
                    elif t in sim_call_choices:
                        # call sim
                        # convert tensors to phiflow grids
                        last_pred_unnorm = last_pred.detach() * smoke_std + smoke_mean
                        # create sim inputs
                        smoke_grid = torch_to_phi_centered(last_pred_unnorm.squeeze())
                        velocity_grid = torch_to_phi_staggered(velocity_x[:, t-1].squeeze(), velocity_y[:, t-1].squeeze())
                        smoke_grid, velocity_grid = call_sim(smoke_grid, velocity_grid, dt, buoyancy_factor, diffusion_coef)
                        # convert phiflow grids back to tensors
                        sim_output_smoke = phi_centered_to_torch(smoke_grid).to(dev)
                        # normalize smoke sim output
                        sim_output_smoke = (sim_output_smoke - smoke_mean) / smoke_std
                        mse = F.mse_loss(sim_output_smoke, smoke_norm[:, t, :, ...].squeeze())
                        orig_mse = F.mse_loss(sim_output_smoke * smoke_std + smoke_mean, smoke[:, t, :, ...].squeeze())
                        baseline_mses.append(mse.item())
                        all_baseline_mses[i, k, t-ST] = mse.item()
                        all_orig_baseline_mses[i, k, t-ST] = orig_mse.item()
                        baseline_actions[i, k, t-ST] += 1.0
                        # set last pred to current sim pred
                        # (re-expanded from H x W to 5-D for the next call_model)
                        last_pred = sim_output_smoke[None, None, None, ...]
                    # save data for plotting
                    if i in plot_idxs and k == 0:
                        # original scale preds
                        baseline_preds[t-ST] = (last_pred * smoke_std + smoke_mean).squeeze()
                final_mses.append(baseline_mses[-1])
                cum_mses.append(np.sum(baseline_mses, axis=0))
            # save baseline mses
            # mean over the K random rollouts of their final-step and cumulative normalized MSE
            baseline_final_mses.append([np.mean(final_mses)])
            baseline_cumulative_mses.append([np.mean(cum_mses)])
            # save figures for videos
            # NOTE(review): this message is printed even when trajectory i is not plotted
            print("Saving figures")
            if i in plot_idxs:
                # plot for original scale data
                utils.save_traj_ns_custom(gts, surr_preds, rl_preds, baseline_preds, smoke_min, smoke_max, output_path)
                # print("Making video")
                make_video(output_path, output_path, i)
                id_gts.append(gts)
                id_unet_preds.append(surr_preds)
                id_baseline_preds.append(baseline_preds)
                id_rl_preds.append(rl_preds)

        # NOTE(review): the wandb run returned by wandb.init is never finished (no run.finish())
        print("Converting and saving metrics to hdf5")
        with h5py.File(result_path / "metrics.h5", "w") as hfile:
            rl_grp = hfile.create_group("rl")
            rl_grp.create_dataset("actions", data=id_action_trajectories)
            rl_grp.create_dataset("mses", data=id_mses)
            rl_grp.create_dataset("orig_mses", data=id_orig_mses)
            rl_grp.create_dataset("surr_mses", data=id_surr_mses)
            rl_grp.create_dataset("surr_orig_mses", data=id_surr_orig_mses)
            rl_grp.create_dataset("fno_mses", data=id_fno_mses)
            rl_grp.create_dataset("fno_orig_mses", data=id_fno_orig_mses)
            rl_grp.create_dataset("sim_calls", data=id_sim_calls)
            rl_grp.create_dataset("final_mses", data=id_final_mses)
            rl_grp.create_dataset("cumulative_mses", data=id_cumulative_mses)
            rl_grp.create_dataset("entropies", data=id_entropies)
            rl_grp.create_dataset("log_odds", data=id_log_odds)
            # the field datasets below only contain the trajectories listed in plot_idxs
            rl_grp.create_dataset("gts", data=id_gts)
            rl_grp.create_dataset("unet_preds", data=id_unet_preds)
            rl_grp.create_dataset("baseline_preds", data=id_baseline_preds)
            rl_grp.create_dataset("rl_preds", data=id_rl_preds)
            rl_grp.create_dataset("unet_times", data=unet_times)
            rl_grp.create_dataset("fno_times", data=fno_times)
            rl_grp.create_dataset("rl_times", data=rl_times)
            rl_grp.create_dataset("sim_call_times", data=sim_call_times)
            baseline_grp = hfile.create_group("baseline")
            baseline_grp.create_dataset("final_mses", data=baseline_final_mses)
            baseline_grp.create_dataset("cumulative_mses", data=baseline_cumulative_mses)
            baseline_grp.create_dataset("baseline_mses", data=all_baseline_mses)
            baseline_grp.create_dataset("baseline_orig_mses", data=all_orig_baseline_mses)
            baseline_grp.create_dataset("baseline_actions", data=baseline_actions)

def torch_to_phi_centered(data):
    """Wrap a 2-D torch tensor as a phiflow ``CenteredGrid``.

    ``data`` is indexed ``[y, x]``; phiflow's spatial dims are declared as
    ``x, y``, so the tensor is transposed before conversion. The grid covers the
    ``[0, 32] x [0, 32]`` box and uses ``extrapolation.BOUNDARY``.
    """
    xy_values = data.transpose(1, 0)  # torch (y, x) -> phiflow (x, y)
    field_tensor = phi_tensor(xy_values, spatial('x,y'))
    domain = Box['x,y', 0 : 32.0, 0 : 32.0]
    return CenteredGrid(field_tensor, extrapolation.BOUNDARY, domain)

def phi_centered_to_torch(data):
    """Convert a phiflow ``CenteredGrid`` back into a 2-D torch tensor indexed ``[y, x]``.

    The grid values are read in ``x, y`` order and transposed so the result
    matches the layout expected by ``torch_to_phi_centered``.
    """
    xy_array = data.values.numpy('x,y')
    yx_array = xy_array.transpose(1, 0)  # back to (y, x) order
    return torch.from_numpy(yx_array)

def torch_to_phi_staggered(data_x, data_y):
    """Convert two 2-D torch velocity components into a phiflow ``StaggeredGrid``.

    Each component is padded by one sample along both axes by replicating its
    last column and then its last row. The padded components are transposed from
    the torch ``[y, x]`` layout (assumed, matching the centered-grid helpers) to
    phiflow's ``x, y`` order and stacked along a ``vector`` channel. The result is
    wrapped in a ``StaggeredGrid`` on the ``[0, 32] x [0, 32]`` box with
    ``extrapolation.ZERO``.

    Args:
        data_x: 2-D tensor holding the x velocity component.
        data_y: 2-D tensor holding the y velocity component, same shape as ``data_x``.

    Returns:
        The phiflow ``StaggeredGrid`` built from the padded components.
    """
    def _pad_last_row_and_col(component):
        # (H, W) -> (H + 1, W + 1): replicate the last column, then the last row.
        # Slicing with ``-1:`` keeps the sliced axis, so this is correct for
        # non-square grids too (reshaping by shape[1] only worked when H == W).
        component = torch.cat((component, component[:, -1:]), dim=1)
        return torch.cat((component, component[-1:, :]), dim=0)

    data_x = _pad_last_row_and_col(data_x)
    data_y = _pad_last_row_and_col(data_y)
    # stack tensors for the velocity vector field, transposed to phiflow's (x, y) order
    stacked = torch.stack((data_x.transpose(1, 0), data_y.transpose(1, 0)), dim=2)
    stacked_ten = phi_tensor(stacked, spatial('x,y'), channel('vector'))
    # create staggered grid
    phi_grid = StaggeredGrid(stacked_ten, extrapolation.ZERO, Box['x,y', 0 : 32.0, 0 : 32.0])
    return phi_grid

def phi_staggered_to_torch(data):
    """Convert a phiflow ``StaggeredGrid`` back into two 2-D torch tensors.

    The staggered tensor is read in ``x, y, vector`` order, and the last sample
    along both spatial axes is dropped (presumably the extra row/column that
    ``torch_to_phi_staggered`` pads on). Each vector component is then transposed
    back to ``[y, x]`` order.

    Returns:
        Tuple ``(field_x, field_y)`` of tensors created with ``torch.from_numpy``.
    """
    components = data.staggered_tensor().numpy('x,y,vector')
    trimmed = components[:-1, :-1]
    field_x = torch.from_numpy(trimmed[..., 0].transpose(1, 0))
    field_y = torch.from_numpy(trimmed[..., 1].transpose(1, 0))
    return field_x, field_y

def call_model(timestep: int, last_pred, model, dev, noise_level):
    """Run the time-conditioned surrogate ``model`` for a single step.

    Args:
        timestep: Integer time index, broadcast to one value per batch sample.
        last_pred: Model input; expected to be shaped ``B x T x C x H x W``.
        model: Callable invoked as ``model(last_pred, timesteps)``.
        dev: Device on which the timestep tensor is created.
        noise_level: Accepted for interface compatibility; not used.

    Returns:
        Whatever ``model`` returns. No ``torch.no_grad`` is applied here; callers
        are responsible for disabling autograd if desired.
    """
    num_samples = last_pred.shape[0]
    time_ids = torch.full((num_samples,), timestep, device=dev)
    return model(last_pred, time_ids)

@jit_compile
def call_sim(smoke_grid, velocity_grid, dt, buoyancy_factor, diffusion_coef):
    """Advance the smoke/velocity state by one phiflow solver step of size ``dt``.

    Steps: semi-Lagrangian advection of the smoke by the incoming velocity; a
    buoyancy force ``smoke * (0, buoyancy_factor)`` resampled onto the velocity
    grid; self-advection of the incoming velocity plus ``dt * buoyancy``;
    explicit diffusion with ``diffusion_coef``; and finally a projection with
    ``fluid.make_incompressible``.

    Returns:
        Tuple ``(advected smoke grid, updated velocity grid)``.
    """
    advected_smoke = advect.semi_lagrangian(smoke_grid, velocity_grid, dt)
    # vector (0, buoyancy_factor): no first component; .at() resamples smoke to velocity sample points
    buoyancy = (advected_smoke * (0, buoyancy_factor)).at(velocity_grid)
    new_velocity = advect.semi_lagrangian(velocity_grid, velocity_grid, dt) + dt * buoyancy
    new_velocity = diffuse.explicit(new_velocity, diffusion_coef, dt)
    new_velocity, _ = fluid.make_incompressible(new_velocity)
    return advected_smoke, new_velocity

def make_video(input_path, output_path, trajectory_num):
    """Assemble ``grid-%d.png`` frames from ``input_path`` into an mp4 with ffmpeg.

    Writes ``<output_path>/traj-<trajectory_num>.mp4`` at 2 frames per second,
    using the h264 codec with a ``scale=1240:-2,format=yuv420p`` filter.

    Video creation is best-effort and never raises. If ffmpeg cannot be started,
    or exits with a non-zero status, a message is printed instead.

    The command is passed as an argument list (no shell), so paths containing
    spaces or shell metacharacters reach ffmpeg verbatim.
    """
    import subprocess

    cmd = [
        "ffmpeg",
        "-framerate", "2",
        "-i", f"{input_path}/grid-%d.png",
        "-vcodec", "h264",
        "-vf", "scale=1240:-2,format=yuv420p",
        f"{output_path}/traj-{trajectory_num}.mp4",
    ]
    try:
        completed = subprocess.run(cmd, check=False)
    except OSError as err:
        # e.g. ffmpeg not installed / not on PATH; keep the evaluation running
        print(f"make_video: could not run ffmpeg ({err}); skipping video for trajectory {trajectory_num}")
        return
    if completed.returncode != 0:
        print(f"make_video: ffmpeg exited with status {completed.returncode} for trajectory {trajectory_num}")

if __name__ == "__main__":
    # default evaluation run: resnet34 actor on the 64x64 Navier-Stokes data, on CUDA
    test(model_type="resnet34", data_type="ns-64", device="cuda", epoch=30)
