from __future__ import annotations

import pickle
from pathlib import Path
import os

import numpy as np
import torch
from fno.fno import FNO1d, FNO2d, FNO3d
from fno.utils_1d_ks_baseline import FNODatasetMult
from metrics_aux import metrics
from torch import nn
from tqdm import tqdm
import wandb
import random
import pdb
import gc
import h5py
import matplotlib
import matplotlib.pyplot as plt

# Compute device for the model and tensors: CUDA when a GPU is visible, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def _build_fno(dimensions, num_channels, width, modes, initial_step):
    """Instantiate the FNO variant matching a data batch of rank ``dimensions``.

    The batch layout is taken to be ``(batch, *space, time, channel)`` (so the
    spatial dimensionality is ``dimensions - 3``): rank 4 selects ``FNO1d``,
    rank 5 ``FNO2d`` and rank 6 ``FNO3d``. The model is moved to ``device``.

    Raises:
        ValueError: for any other rank. Previously the caller's ``model`` was
            simply never bound and failed later with an ``UnboundLocalError``.
    """
    if dimensions == 4:
        model = FNO1d(
            num_channels=num_channels,
            width=width,
            modes=modes,
            initial_step=initial_step,
        )
    elif dimensions == 5:
        model = FNO2d(
            num_channels=num_channels,
            width=width,
            modes1=modes,
            modes2=modes,
            initial_step=initial_step,
        )
    elif dimensions == 6:
        model = FNO3d(
            num_channels=num_channels,
            width=width,
            modes1=modes,
            modes2=modes,
            modes3=modes,
            initial_step=initial_step,
        )
    else:
        raise ValueError(
            f"Unsupported data rank {dimensions}: expected 4 (1D), 5 (2D) or 6 (3D)."
        )
    return model.to(device)


def _plot_spacetime(field, domain_length, t_extent, title, out_path):
    """Render a 2D ``(time, x)`` array as a space-time heat map and save it.

    The figure is saved *before* ``plt.show()``: a blocking ``show()`` closes
    the figure, so saving afterwards wrote a new, empty figure. The figure is
    closed at the end so repeated calls do not accumulate open figures.
    """
    fig, ax = plt.subplots(figsize=(9, 4.2))
    ax.imshow(
        field,
        extent=[0, domain_length, 0, t_extent],
        aspect="auto",
        origin="lower",
    )
    ax.set_xlabel("x")
    ax.set_ylabel("time")
    ax.set_title(title)
    fig.tight_layout()
    fig.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.show()
    plt.close(fig)


def run_training(
    if_training,
    continue_training,
    rollout_test,
    num_workers,
    modes,
    width,
    initial_step,
    t_train,
    num_channels,
    batch_size,
    epochs,
    train_subsample,
    learning_rate,
    scheduler_step,
    scheduler_gamma,
    model_update,
    FNO_model_flmn,
    plot,
    channel_plot,
    x_min,
    x_max,
    y_min,
    y_max,
    t_min,
    t_max,
    base_path="../data/",
    training_type="single",
    scheduler='cosine',
):
    """Build the FNO data loaders and model; when ``if_training`` is false,
    run checkpointed inference on the validation split.

    The rank of a validation batch selects the model (4 -> ``FNO1d``,
    5 -> ``FNO2d``, 6 -> ``FNO3d``). In inference mode the weights are read
    from ``f"{FNO_model_flmn}_FNO.pt"`` (key ``"model_state_dict"``). Each
    validation batch is then rolled forward to ``t_train`` frames. For 1D data,
    the first sample of every batch is saved as a space-time PNG.

    Args:
        if_training: the inference path below runs only when this is false.
        num_workers, batch_size: forwarded to both ``DataLoader``s.
        modes, width, num_channels, initial_step: FNO hyper-parameters;
            ``initial_step`` is also the number of warm-start frames copied
            from ground truth.
        t_train: number of frames to produce, clipped to the data's time axis.
        train_subsample, rollout_test, base_path: forwarded to
            ``FNODatasetMult``.
        FNO_model_flmn: checkpoint file stem (``"_FNO.pt"`` is appended).

    The remaining parameters (``continue_training``, ``epochs``,
    ``learning_rate``, ``scheduler_step``, ``scheduler_gamma``,
    ``model_update``, ``plot``, ``channel_plot``, ``x_min`` ... ``t_max``,
    ``training_type``, ``scheduler``) are not referenced by the inference path.
    """
    ################################################################
    # load data
    ################################################################

    # Checkpoint file stem; the weights file is "<stem>.pt".
    model_name = FNO_model_flmn + "_FNO"

    train_data = FNODatasetMult(
        saved_folder=base_path,
        train_subsample=train_subsample,
        rollout_test=rollout_test,
    )
    val_data = FNODatasetMult(
        if_test=True,
        saved_folder=base_path,
        rollout_test=rollout_test,
    )

    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, num_workers=num_workers, shuffle=True
    )
    val_loader = torch.utils.data.DataLoader(
        val_data, batch_size=batch_size, num_workers=num_workers, shuffle=False
    )

    print("length of training loader:",len(train_loader), "length of test loader:", len(val_loader))
    print("Device:", device)

    ################################################################
    # training and evaluation
    ################################################################

    # Peek at one validation batch to infer the data rank / dimensionality.
    _, _data, _ = next(iter(val_loader))
    dimensions = len(_data.shape)
    model = _build_fno(dimensions, num_channels, width, modes, initial_step)

    # Never roll out past the data's time axis (second-to-last dimension).
    t_train = min(t_train, _data.shape[-2])
    model_path = model_name + ".pt"

    output_dir = "../data_gen/result_plot"

    if not if_training:
        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint["model_state_dict"])
        model.eval()

        # The space-time plot only makes sense for 1D (rank-4) batches; for
        # 2D/3D the old permute(1, 0, 2) raised a RuntimeError.
        plot_spacetime = dimensions == 4
        if not plot_spacetime:
            print("Space-time plot skipped: only supported for 1D data.")
        # Hard-coded x-extent of the plot; presumably the KS domain length
        # 32 * 2*pi -- confirm against the data generator.
        Ldom = 2 * np.pi * 32

        with torch.no_grad():
            for i, (xx, yy, grid) in enumerate(tqdm(val_loader, desc="Inferring")):
                # send to device
                xx, yy = xx.to(device), yy.to(device)
                grid = grid.to(device)
                # Warm start: the first `initial_step` output frames are copied
                # from ground truth. `xx` is assumed to hold exactly those
                # frames -- TODO confirm against FNODatasetMult.
                pred = yy[..., :initial_step, :]

                # Produce frames initial_step .. t_train-1, one per model call.
                for t in range(initial_step, t_train):
                    im_primary = model(xx, grid)  # prediction for frame t
                    pred = torch.cat((pred, im_primary), dim=-2)
                    # Slide the input window one frame forward by appending the
                    # ground-truth frame t that was just predicted (teacher
                    # forcing). The previous code appended frame
                    # t - initial_step, which lags the window and scrambles its
                    # time order. Append `im_primary` instead for a
                    # free-running rollout.
                    xx = torch.cat((xx[..., 1:, :], yy[..., t:t + 1, :]), dim=-2)

                if plot_spacetime:
                    # First sample of the batch, (x, t, c) -> (t, x[, c]).
                    # Indexing (unlike squeeze(0)) stays valid for batch_size > 1.
                    pred_np = pred[0].permute(1, 0, 2).squeeze(-1).cpu().numpy()
                    # One file per batch so later batches no longer overwrite
                    # earlier ones; a single-batch run keeps the original name.
                    fig_path = (
                        "spacetime_bl.png"
                        if len(val_loader) == 1
                        else f"spacetime_bl_{i:03d}.png"
                    )
                    _plot_spacetime(pred_np, Ldom, t_train - 1, "baseline", fig_path)

                # NOTE(review): draft for writing one HDF5 file per sample into
                # `output_dir`; `seed_str` is undefined here, fix before enabling.
                # out_path = os.path.join(output_dir, f"2D_DR_pred_trj_sample{i:03d}.h5")
                # with h5py.File(out_path, "w") as f:
                #     f.create_dataset(f"{seed_str}/data",
                #     data=pred_np,
                #     dtype="float32",
                #     compression="lzf",)

        print("Inference complete!")


   