"""
Training and evaluation script for ordinary (single-network) neural operators,
rolled out autoregressively by integrating the network output with an ODE solver.
"""

import json
import torch
import torch.nn.functional as F
import torch.optim as optim
from torchdiffeq import odeint
import numpy as np
import yaml
import argparse
from datasets.loaders import get_pde_dataloader
from models.networks.fno import FNO2d
from models.networks.deeponet import DeepONet2d
from models.networks.oformer.oformer import OFormerUniform2d
from models.networks.cape import CAPE2d
from models.networks.dpot import DPOT2d
from models.networks.unet import UNet2d
from models.networks.cno import CNO2d
from models.networks.VCNeF.vcnef import VCNeF2d
from losses import *
from metrics import *
from optim import Adam as AdamFNO
from utils import *
import time
import os


def main(config):
    """Build, train, or evaluate a single neural-operator baseline.

    The network is selected by the module-level ``args.network`` (parsed in
    the ``__main__`` block), and ``args.is_train`` chooses between the
    training stage and the testing stage. Predictions are produced
    autoregressively. For each target time index, the model output is used
    as the right-hand side of an ODE and integrated over one time interval
    with ``torchdiffeq.odeint`` (method ``args.int_method``). The result is
    then appended to the sliding history window.

    Args:
        config: Parsed YAML configuration. Reads the ``expert``, ``data``,
            ``optim`` and ``output`` sections, plus ``plot`` when
            ``args.is_plot`` is set.

    Side effects:
        Training saves the ``state_dict`` of each of the last 10 epochs to
        ``<model_dir>/<epoch>.pt`` and plots the training-loss curve.
        Testing loads ``<model_dir>/<args.best_epoch>.pt`` and writes the
        per-case metrics to ``output/<output_file_name>_<args.test_type>.json``.
    """
    # device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Ablation study identifiers: a suffix built from the key hyper-parameters
    # of the selected network. It is appended to both the weight folder and the
    # output file name, so runs with different settings do not overwrite each other.
    # NOTE: `noise_scale` (CAPE2d / DPOT2d) is also read by the training
    # derivative function further below.
    ablate_idx = ""
    if args.network == "FNO2d":
        width = config["expert"]["FNO2d"]["width"]
        ablate_idx += f"_width{width}"
    if args.network == "CAPE2d":
        width = config["expert"]["CAPE2d"]["width"]
        noise_scale = config["expert"]["CAPE2d"]["noise_scale"]
        ablate_idx += f"_width{width}_noise{noise_scale}"
    if args.network == "DPOT2d":
        embed_dim = config["expert"]["DPOT2d"]["embed_dim"]
        depth = config["expert"]["DPOT2d"]["depth"]
        noise_scale = config["expert"]["DPOT2d"]["noise_scale"]
        ablate_idx += f"_emb{embed_dim}_depth{depth}_noise{noise_scale}"
    if args.network == "UNet2d":
        hidden_dim = config["expert"]["UNet2d"]["init_hidden_dim"]
        ablate_idx += f"_hid{hidden_dim}"
    if args.network == "CNO2d":
        ablate_idx += ""
    if args.network == "OFormerUniform2d":
        latent_channels = config["expert"]["OFormerUniform2d"]["latent_channels"]
        encoder_depth = config["expert"]["OFormerUniform2d"]["encoder_depth"]
        ablate_idx += f"_latent{latent_channels}_depth{encoder_depth}"
    if args.network == "DeepONet2d":
        hidden_dim = config["expert"]["DeepONet2d"]["hidden_dim"]
        ablate_idx += f"_hid{hidden_dim}"
    if args.network == "VCNeF2d":
        d_model = config["expert"]["VCNeF2d"]["d_model"]
        n_modulation_blocks = config["expert"]["VCNeF2d"]["n_modulation_blocks"]
        condition_on_pde_param = int(config["expert"]["VCNeF2d"]["condition_on_pde_param"])
        ablate_idx += f"_latent{d_model}_depth{n_modulation_blocks}_cond{condition_on_pde_param}"

    # create model folder ("MoE" in the configured path is replaced by the network name)
    model_dir = config["output"]["model_dir"].replace("MoE", args.network)
    model_dir += ablate_idx
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
    os.makedirs(model_dir, exist_ok=True)

    # load data
    data_cfg = config["data"]
    num_env_train = data_cfg["num_env_train"]
    train_loader, test_id_loader, test_ood_loader = get_pde_dataloader(cfg_data=data_cfg)
    ntrain = num_env_train*data_cfg["n_data_per_env_train"]
    # Test sizes are estimated as (#batches * batch size); the last batch may be smaller.
    ntest_id = len(test_id_loader)*data_cfg["batch_size_test"]
    ntest_ood = len(test_ood_loader)*data_cfg["batch_size_test"]
    print(f"ntrain={ntrain}, ntest_id={ntest_id}, ntest_ood={ntest_ood}")

    # create model
    exp_cfg = config["expert"]
    init_step = data_cfg["init_step"]  # length of the history window fed to the model
    num_var = data_cfg["num_var"]
    Nt = data_cfg["Nt"]
    comb_channels = init_step * num_var  # history is stacked along the channel axis
    if args.network == "FNO2d":
        fno2d_cfg = exp_cfg["FNO2d"]
        model = FNO2d(in_channels=comb_channels, out_channels=exp_cfg["out_channels"],
                      modes1=fno2d_cfg["modes1"], modes2=fno2d_cfg["modes2"],
                      width=fno2d_cfg["width"], spatial_size=fno2d_cfg["spatial_size"],
                      n_layers=fno2d_cfg["n_layers"], act=fno2d_cfg["act_type"],
                      padding=fno2d_cfg["padding"], weight_init=fno2d_cfg["weight_init"],
                      x_span=exp_cfg["x_span"], y_span=exp_cfg["y_span"]).to(device)
    elif args.network == "CAPE2d":
        cape2d_cfg = exp_cfg["CAPE2d"]
        model = CAPE2d(widening_factor=cape2d_cfg["widening_factor"], num_params=cape2d_cfg["num_params"],
                       in_channels=comb_channels, out_channels=cape2d_cfg["out_channels"],
                       width=cape2d_cfg["width"], modes1=cape2d_cfg["modes1"], modes2=cape2d_cfg["modes2"],
                       normed_dim=cape2d_cfg["normed_dim"]).to(device)
    elif args.network == "DPOT2d":
        dpot2d_cfg = exp_cfg["DPOT2d"]
        model = DPOT2d(img_size=dpot2d_cfg["img_size"], patch_size=dpot2d_cfg["patch_size"],
                       in_channels=dpot2d_cfg["in_channels"], out_channels=dpot2d_cfg["out_channels"],
                       in_timesteps=dpot2d_cfg["in_timesteps"], out_timesteps=dpot2d_cfg["out_timesteps"],
                       embed_dim=dpot2d_cfg["embed_dim"], depth=dpot2d_cfg["depth"],
                       modes=dpot2d_cfg["modes"], normalize=dpot2d_cfg["normalize"]).to(device)
    elif args.network == "UNet2d":
        unet2d_cfg = exp_cfg["UNet2d"]
        model = UNet2d(in_channels=comb_channels, out_channels=exp_cfg["out_channels"],
                       init_features=unet2d_cfg["init_hidden_dim"]).to(device)
    elif args.network == "CNO2d":
        cno2d_cfg = exp_cfg["CNO2d"]
        model = CNO2d(in_dim=comb_channels, out_dim=cno2d_cfg["exp_out_channels"],
                      size=cno2d_cfg["size"], N_layers=cno2d_cfg["n_layers"],
                      N_res=cno2d_cfg["N_res"], N_res_neck=cno2d_cfg["N_res_neck"],
                      channel_multiplier=cno2d_cfg["channel_multiplier"]).to(device)
    elif args.network == "OFormerUniform2d":
        ofu2d_config = exp_cfg["OFormerUniform2d"]
        model = OFormerUniform2d(in_channels=comb_channels, out_channels=ofu2d_config["exp_out_channels"],
                                 latent_channels=ofu2d_config["latent_channels"], encoder_emb_dim=ofu2d_config["encoder_emb_dim"],
                                 encoder_heads=ofu2d_config["encoder_heads"], encoder_depth=ofu2d_config["encoder_depth"],
                                 x_span=exp_cfg["x_span"], y_span=exp_cfg["y_span"]).to(device)
    elif args.network == "DeepONet2d":
        dpo2d_config = exp_cfg["DeepONet2d"]
        model = DeepONet2d(in_channels=comb_channels, out_channels=dpo2d_config["exp_out_channels"],
                           hidden_dim=dpo2d_config["hidden_dim"], x_size=dpo2d_config["x_size"],
                           y_size=dpo2d_config["y_size"], act_type=dpo2d_config["act_type"],
                           x_span=exp_cfg["x_span"], y_span=exp_cfg["y_span"]).to(device)
    elif args.network == "VCNeF2d":
        vcnef2d_config = exp_cfg["VCNeF2d"]
        model = VCNeF2d(num_channels=vcnef2d_config["num_channels"], env_dim=vcnef2d_config["env_dim"],
                        d_model=vcnef2d_config["d_model"], n_modulation_blocks=vcnef2d_config["n_modulation_blocks"],
                        condition_on_pde_param=vcnef2d_config["condition_on_pde_param"],
                        x_span=exp_cfg["x_span"], y_span=exp_cfg["y_span"]).to(device)
    else:
        raise NotImplementedError
    model_size = count_params(model)
    print(f"{args.network} size: {model_size}")

    output_file_name = f"{args.dataset}_{args.network}_env{num_env_train}_N{ntrain}"
    output_file_name += ablate_idx

    # optimization: FNO-style networks use the project's Adam, others use AdamW;
    # the learning rate is divided by 10 at 75% and again at 90% of the epochs.
    opt_cfg = config["optim"]
    n_epochs = opt_cfg["n_epochs"]
    if "FNO" in args.network or "CAPE" in args.network:
        optimizer = AdamFNO(model.parameters(), lr=opt_cfg["init_lr"], weight_decay=opt_cfg["weight_decay"])
    else:
        optimizer = optim.AdamW(model.parameters(), lr=opt_cfg["init_lr"], weight_decay=opt_cfg["weight_decay"])
    p1 = int(0.75 * n_epochs)
    p2 = int(0.9 * n_epochs)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[p1, p2], gamma=0.1)
    for name, param in model.named_parameters():
        print(name, param.requires_grad, param.size())

    # loss function (per-sample relative L2, reduced with a mean below)
    lp_loss = RelativeL2(reduction=False)

    # training stage
    if args.is_train:
        print("Training stage begins")
        train_error = []
        for epoch in range(n_epochs):
            model.train()
            train_loss = {"total": 0, "pred": 0}
            start_time = time.time()

            for i, data in enumerate(train_loader):
                x = data["state"].to(device)  # (B, Nx, Ny, C, Nt)
                c = data["context"].to(device)  # (B, env_dim)
                t = data["time"][0].to(device)  # (Nt, )

                # Right-hand side passed to odeint. Except for VCNeF2d, it ignores the
                # ODE state `y` and reads the history window `input_t` from the
                # enclosing scope. `input_t` is only updated after odeint returns, so
                # the derivative is taken from the history at the start of the step.
                def derivative_func(t, y):  # (B, Nx, Ny, C)
                    if args.network == "CAPE2d":
                        # training-time noise scaled by the norm over dims (1, 2, 4)
                        input_t_noisy = input_t + noise_scale*(torch.sum(input_t**2, dim=(1,2,4), keepdim=True)**0.5)*torch.randn_like(input_t)
                        y_hist = rearrange(input_t_noisy, 'b h w c t -> b (c t) h w')  # (B, C*init_step, Nx, Ny)
                        pred_t = model(y_hist, c).permute(0, 2, 3, 1)  # (B, Nx, Ny, C)
                    elif args.network == "DPOT2d":
                        input_t_noisy = input_t + noise_scale*(torch.sum(input_t**2, dim=(1,2,4), keepdim=True)**0.5)*torch.randn_like(input_t)
                        y_hist = rearrange(input_t_noisy, 'b h w c t -> b h w t c')  # (B, Nx, Ny, init_step, C)
                        pred_t = model(y_hist).squeeze(-2)  # (B, Nx, Ny, C)
                    elif args.network == "VCNeF2d":
                        y_hist = y.permute(0, 3, 1, 2)  # (B, C, Nx, Ny)
                        t_inp = t.unsqueeze(0)  # (1, )
                        pred_t = model(y_hist, c, t_inp)  # (B, Nx, Ny, C)
                    else:
                        y_hist = rearrange(input_t, 'b h w c t -> b (c t) h w')  # (B, C*init_step, Nx, Ny)
                        pred_t = model(y_hist).permute(0, 2, 3, 1)  # (B, Nx, Ny, C)
                    return pred_t

                preds_out, targets_out = [], []
                input_t = x[..., :init_step]  # (B, Nx, Ny, C, init_step)
                for t_eval in range(init_step, Nt):
                    # Predict in derivative space: integrate from the last known state
                    # over [t[t_eval-1], t[t_eval]]. (An alternative, not used here,
                    # is to predict the next state directly in data space.)
                    x0 = input_t[..., -1]  # (B, Nx, Ny, C)
                    res_t = odeint(derivative_func, y0=x0, t=t[t_eval-1:t_eval+1],
                                   method=args.int_method, options=dict())  # (2, B, Nx, Ny, C)
                    pred_t = res_t[-1].unsqueeze(-1)  # (B, Nx, Ny, C, 1)
                    target_t = x[..., t_eval].unsqueeze(-1)  # (B, Nx, Ny, C, 1)
                    # slide the window: drop the oldest step, append the prediction
                    input_t = torch.cat((input_t[..., 1:], pred_t), dim=-1)  # (B, Nx, Ny, C, init_step)

                    preds_out.append(pred_t)
                    targets_out.append(target_t)
                preds_out = rearrange(torch.cat(preds_out, dim=-1),
                                      'b h w c t -> (b t) h w c')  # (B*(Nt-init_step), Nx, Ny, C)
                targets_out = rearrange(torch.cat(targets_out, dim=-1),
                                        'b h w c t -> (b t) h w c')  # (B*(Nt-init_step), Nx, Ny, C)

                # calculate prediction loss
                batch_lp_loss = lp_loss(preds_out, targets_out)  # (B*(Nt-init_step), )
                pred_loss = torch.mean(batch_lp_loss)
                total_loss = pred_loss

                optimizer.zero_grad()
                # The graph is rebuilt for every batch, so it does not need to be
                # retained after backward (retaining it only raises peak memory).
                total_loss.backward()
                optimizer.step()
                train_loss["total"] += total_loss.item()
                train_loss["pred"] += pred_loss.item()
            end_time = time.time()
            train_loss["total"] /= len(train_loader)
            train_loss["pred"] /= len(train_loader)
            lr_scheduler.step()
            # keep only the checkpoints of the last 10 epochs
            if epoch >= n_epochs - 10:
                torch.save(model.state_dict(), f"{model_dir}/{epoch}.pt")
                print(f"Save {model_dir}/{epoch}.pt")
            print(f"[epoch {epoch}/{n_epochs}]: total_loss={train_loss['total']}, "
                  f"pred_loss={train_loss['pred']}, time={end_time-start_time}s")
            train_error.append(train_loss["total"])

        plot_train_loss(train_error, output_file_name)

    # Testing stage
    else:
        # Load weights; map_location lets GPU-saved checkpoints load on a CPU-only host
        best_epoch = args.best_epoch
        ckpt_file = f"{model_dir}/{best_epoch}.pt"
        model.load_state_dict(torch.load(ckpt_file, map_location=device))
        print(f"Load {ckpt_file}")

        model.eval()
        if args.test_type == "ID":
            test_loader = test_id_loader
        elif args.test_type == "OOD":
            test_loader = test_ood_loader
        else:
            raise NotImplementedError
        print(f"{args.test_type} testing stage begins")

        # Evaluation metrics, keyed by "case {batch index}"
        metrics = {"nMSE": {}, "fRMSE": {}}
        with torch.no_grad():
            for i, data in enumerate(test_loader):
                start_time = time.time()
                state = data["state"].to(device)
                context = data["context"].to(device)
                t = data["time"].to(device)

                # Same right-hand side as in training, but without the CAPE2d/DPOT2d
                # input noise. Like the training version, it reads `input_t` from the
                # enclosing scope (except for VCNeF2d).
                def derivative_func(t, y):  # (B, Nx, Ny, C)
                    if args.network == "CAPE2d":
                        y_hist = rearrange(input_t, 'b h w c t -> b (c t) h w')  # (B, C*init_step, Nx, Ny)
                        pred_t = model(y_hist, context).permute(0, 2, 3, 1)  # (B, Nx, Ny, C)
                    elif args.network == "DPOT2d":
                        y_hist = rearrange(input_t, 'b h w c t -> b h w t c')  # (B, Nx, Ny, init_step, C)
                        pred_t = model(y_hist).squeeze(-2)  # (B, Nx, Ny, C)
                    elif args.network == "VCNeF2d":
                        y_hist = y.permute(0, 3, 1, 2)  # (B, C, Nx, Ny)
                        t_inp = t.unsqueeze(0)  # (1, )
                        pred_t = model(y_hist, context, t_inp)  # (B, Nx, Ny, C)
                    else:
                        y_hist = rearrange(input_t, 'b h w c t -> b (c t) h w')  # (B, C*init_step, Nx, Ny)
                        pred_t = model(y_hist).permute(0, 2, 3, 1)  # (B, Nx, Ny, C)
                    return pred_t

                pred_x, targets = [], []
                input_t = state[..., :init_step]  # (B, Nx, Ny, C, init_step)
                t_step = t[0]  # (Nt, ) time grid of the first sample; loop-invariant
                for t_eval in range(init_step, Nt):
                    x0 = input_t[..., -1]  # (B, Nx, Ny, C)
                    res_t = odeint(derivative_func, y0=x0, t=t_step[t_eval-1:t_eval+1],
                                   method=args.int_method, options=dict())  # (2, B, Nx, Ny, C)
                    pred_t = res_t[-1].unsqueeze(-1)  # (B, Nx, Ny, C, 1)
                    target_t = state[..., t_eval].unsqueeze(-1)  # (B, Nx, Ny, C, 1)
                    input_t = torch.cat((input_t[..., 1:], pred_t), dim=-1)  # (B, Nx, Ny, C, init_step)

                    pred_x.append(pred_t)
                    targets.append(target_t)
                pred_x = torch.cat(pred_x, dim=-1)  # (B, Nx, Ny, C, Nt-init_step)
                targets = torch.cat(targets, dim=-1)  # (B, Nx, Ny, C, Nt-init_step)

                err_nmse = cal_nMSE(pred_x, targets)  # (B, )
                err_frmse = cal_fRMSE(pred_x, targets)  # (B, )
                # NOTE(review): `if np.isnan(...)` and `[0]` below only behave as
                # intended for batch_size_test == 1; with larger batches the truth
                # test is ambiguous and only the first sample would be recorded.
                if np.isnan(err_nmse) or np.isnan(err_frmse):
                    continue
                metrics["nMSE"][f"case {i}"] = float(err_nmse[0])
                metrics["fRMSE"][f"case {i}"] = float(err_frmse[0])
                end_time = time.time()

                # Plot trajectories (initial history + rollout) for the dataset's channels
                if args.is_plot:
                    pred_all = torch.cat((state[..., :init_step], pred_x), dim=-1)  # (B, Nx, Ny, C, Nt)
                    if "ns" in args.dataset:
                        plot_state_data(state, pred_all, t, channel=0, t_fraction=1, plt_cfg=config["plot"],
                                        ablate_idx=args.network, fig_name=f"w_{args.test_type}_{i}", is_naive=True)
                    elif "dr" in args.dataset or "bg" in args.dataset or "gs" in args.dataset:
                        plot_state_data(state, pred_all, t, channel=0, t_fraction=1, plt_cfg=config["plot"],
                                        ablate_idx=args.network, fig_name=f"u_{args.test_type}_{i}", is_naive=True)
                        plot_state_data(state, pred_all, t, channel=1, t_fraction=1, plt_cfg=config["plot"],
                                        ablate_idx=args.network, fig_name=f"v_{args.test_type}_{i}", is_naive=True)
                    elif "sw" in args.dataset:
                        plot_state_data(state, pred_all, t, channel=0, t_fraction=1, plt_cfg=config["plot"],
                                        ablate_idx=args.network, fig_name=f"h_{args.test_type}_{i}", is_naive=True)
                    elif "hc" in args.dataset:
                        plot_state_data(state, pred_all, t, channel=0, t_fraction=1, plt_cfg=config["plot"],
                                        ablate_idx=args.network, fig_name=f"u_{args.test_type}_{i}", is_naive=True)
                    elif "sst" in args.dataset:
                        plot_state_data(state, pred_all, t, channel=0, t_fraction=1, plt_cfg=config["plot"],
                                        ablate_idx=args.network, fig_name=f"t_{args.test_type}_{i}", is_naive=True)
                    else:
                        raise NotImplementedError

                print(f"{args.test_type} testing case {i}: env={context.detach().cpu().numpy()}, nMSE={err_nmse}, "
                      f"fRMSE={err_frmse}, time={end_time-start_time}s")
        avr_nMSE = np.mean(list(metrics["nMSE"].values()))
        avr_fRMSE = np.mean(list(metrics["fRMSE"].values()))
        std_nMSE = np.std(list(metrics["nMSE"].values()))
        std_fRMSE = np.std(list(metrics["fRMSE"].values()))
        print(f"{args.test_type} testing results: avr_nMSE={avr_nMSE}, avr_fRMSE={avr_fRMSE}")
        print(f"{args.test_type} testing results: std_nMSE={std_nMSE}, std_fRMSE={std_fRMSE}")
        # make sure the results folder exists before writing into it
        os.makedirs("output", exist_ok=True)
        with open(f"output/{output_file_name}_{args.test_type}.json", "w") as f:
            json.dump(metrics, f, indent=4)


def str2bool(value):
    """Convert a command-line flag value to ``bool``.

    With ``type=bool``, argparse calls ``bool("False")``, which is ``True``
    for every non-empty string. Such flags therefore cannot be switched off
    from the command line. This converter accepts the usual spellings instead.

    Args:
        value: The raw command-line string, or an existing ``bool`` (such as
            a default value).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognized boolean
            spelling; argparse reports this as a usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


if __name__ == "__main__":
    # input args
    parser = argparse.ArgumentParser(description="NDE-MoE")
    # Boolean flags use str2bool so that e.g. `--is_train False` really selects
    # the testing stage (with type=bool it would silently stay True).
    parser.add_argument("--is_train", type=str2bool, default=True,
                        help="True: training stage; False: testing stage")
    parser.add_argument("--best_epoch", type=int, default=499,
                        help="checkpoint epoch loaded in the testing stage")

    parser.add_argument("--test_type", type=str, default="ID", help="options: ['ID', 'OOD']")
    parser.add_argument("--is_plot", type=str2bool, default=False)

    parser.add_argument("--config_file", type=str, default="dr2d_moe.yaml",
                        help="file under configs/, e.g. dr2d_moe.yaml, ns2d_moe.yaml, sw2d_moe.yaml, "
                             "hc2d_moe.yaml, bg2d_moe.yaml, gs2d_moe.yaml, sst2d_moe.yaml")
    parser.add_argument("--seed", type=int, default=2025)

    parser.add_argument("--dataset", type=str, default="dr2d",
                        help="options: ['dr2d', 'ns2d', 'sw2d', 'hc2d', 'bg2d', 'gs2d', 'sst2d']")

    parser.add_argument("--network", type=str, default="OFormerUniform2d",
                        help="options: ['FNO2d', 'DeepONet2d', 'OFormerUniform2d', 'UNet2d', "
                             "'CAPE2d', 'DPOT2d', 'CNO2d', 'VCNeF2d']")

    # rk4 will induce vastly more computation time than "euler"
    parser.add_argument("--int_method", type=str, default="euler", help="options: ['euler', 'midpoint', 'rk4']")
    # NOTE: not read by main() in this file
    parser.add_argument("--int_step_scale", type=float, default=1.0)

    args = parser.parse_args()
    print(args)

    # configuration
    cfg_path = "configs/" + args.config_file
    with open(cfg_path, "r") as f:
        config = yaml.safe_load(f)

    # fix seed
    fix_seed(args.seed)

    # modify hyper-parameters and run exp
    config["expert"]["UNet2d"]["init_hidden_dim"] = 48
    config["expert"]["OFormerUniform2d"]["latent_channels"] = 32
    config["expert"]["OFormerUniform2d"]["encoder_emb_dim"] = 16
    main(config)

    # Alternative training-set sizes (16 environments). These config edits have
    # no effect unless the corresponding `main(config)` call is uncommented.
    config["data"]["path_train"] = "datasets/data/dr2d/dr2d_train_env16_N4096_Nx64_Ny64_T21.pkl"
    config["data"]["num_env_train"] = 16
    config["data"]["n_data_per_env_train"] = 256
    config["output"]["model_dir"] = "weights/MoE/dr2d/env16_N4096_Nx64_Ny64_T21"
    # main(config)

    config["data"]["path_train"] = "datasets/data/dr2d/dr2d_train_env16_N2048_Nx64_Ny64_T21.pkl"
    config["data"]["num_env_train"] = 16
    config["data"]["n_data_per_env_train"] = 128
    config["output"]["model_dir"] = "weights/MoE/dr2d/env16_N2048_Nx64_Ny64_T21"
    # main(config)

    config["data"]["path_train"] = "datasets/data/dr2d/dr2d_train_env16_N1024_Nx64_Ny64_T21.pkl"
    config["data"]["num_env_train"] = 16
    config["data"]["n_data_per_env_train"] = 64
    config["output"]["model_dir"] = "weights/MoE/dr2d/env16_N1024_Nx64_Ny64_T21"
    # main(config)

    config["data"]["path_train"] = "datasets/data/dr2d/dr2d_train_env16_N512_Nx64_Ny64_T21.pkl"
    config["data"]["num_env_train"] = 16
    config["data"]["n_data_per_env_train"] = 32
    config["output"]["model_dir"] = "weights/MoE/dr2d/env16_N512_Nx64_Ny64_T21"
    # main(config)

    config["data"]["path_train"] = "datasets/data/dr2d/dr2d_train_env16_N256_Nx64_Ny64_T21.pkl"
    config["data"]["num_env_train"] = 16
    config["data"]["n_data_per_env_train"] = 16
    config["output"]["model_dir"] = "weights/MoE/dr2d/env16_N256_Nx64_Ny64_T21"
    # main(config)