import matplotlib.pyplot as plt
import torch
from torch.optim.lr_scheduler import LambdaLR
import numpy as np

import argparse
import os
import logging
from collections import OrderedDict

import seaborn as sns
import pandas as pd
import matplotlib as mpl
from pathlib import Path

from pdPINN.util.plot_util import AnimatedImages
from pdPINN.model.model_utiities import ST
from pdPINN.model.siren_3d import MassCons3d
from pdPINN.model.siren import batch_predict
from create_data import config
from pprint import pformat
from typing import * 

from pdPINN.model.siren import dataloader_from_np

# from pdPINN.util.system_util import AtomicOpen
import fasteners

from itertools import chain
import csv

from datetime import datetime

from pdPINN.util.system_util import save_json


def main(model_settings, path_settings) -> dict:
    """Run one 3-D experiment end to end and append its results to the shared CSV.

    Steps: configure root logging (ERROR when ``model_settings['silent']``,
    otherwise INFO), log both settings dicts, seed NumPy and torch, load the
    data splits, train/evaluate via ``train_model``, persist the result dict
    with ``save_json`` and finally append one row to
    ``path_settings["results_csv"]``.

    :param model_settings: run configuration (reads at least ``silent`` and ``seed``;
        forwarded to ``train_model`` and ``save_json``).
    :param path_settings: paths; ``results_csv`` must be a ``pathlib.Path`` (its
        ``.parent`` is created) plus the data paths read by ``load_dataframes``.
    :return: the combined loss/settings dict produced by ``train_model``.
    """
    log_format = {"format": "%(asctime)s \t %(message)s",
                  "datefmt": '%m-%d %H:%M'}
    level = logging.ERROR if model_settings['silent'] else logging.INFO
    logging.basicConfig(level=level, **log_format)

    logging.info("\n" + pformat(model_settings))
    logging.info("\n" + pformat(path_settings))

    np.random.seed(model_settings['seed'])
    torch.random.manual_seed(model_settings['seed'])

    results_csv = path_settings["results_csv"]
    os.makedirs(results_csv.parent, exist_ok=True)
    # NOTE: the lock file path is relative, i.e. resolved against the current
    # working directory; runs started from different directories do not share it.
    lock = fasteners.InterProcessLock("3d_lock.file")

    df_train, df_val, df_test = load_dataframes(path_settings)
    loss_dict_combined = train_model(df_train, df_val, df_test,
                                     model_settings=model_settings, path_settings=path_settings)

    save_json(loss_dict_combined, model_settings, prefix="d3_")

    with lock:
        # Decide about the header *while holding the lock*: checking before
        # acquiring it lets two concurrent runs both see "no file" and both
        # write a header. An empty file (e.g. from a crashed run) also needs one.
        write_header = (not os.path.isfile(results_csv)
                        or os.path.getsize(results_csv) == 0)
        # newline='' is required by the csv module to avoid blank lines on Windows.
        with open(results_csv, 'a', newline='') as f_object:
            dict_writer = csv.DictWriter(f_object, fieldnames=loss_dict_combined.keys())
            if write_header:
                dict_writer.writeheader()
            dict_writer.writerow(loss_dict_combined)

    return loss_dict_combined


def load_dataframes(path_settings: dict) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Read the train/val/test parquet splits and perturb the training targets.

    Every split gets an extra ``sqrt_density`` column (= sqrt of ``density``).
    The training split is then restricted to ``z < 0.7`` and noised using the
    global NumPy RNG (so the result depends on the caller's ``np.random.seed``):

    - ``sqrt_density``: resampled as N(value, std(sqrt_density) / 20), except
      entries that were exactly 0, which are reset to 0;
    - ``u``, ``v``, ``w``: additive N(0, column std / 20) noise.

    NOTE(review): small non-zero sqrt_density values can become negative after
    the noise is applied; confirm that is acceptable downstream.

    :param path_settings: mapping with the parquet paths under the keys
        ``"df_train"``, ``"df_val"`` and ``"df_test"``.
    :return: ``(df_train, df_val, df_test)``; only the training frame is noised.
    """
    def read_split(key):
        frame = pd.read_parquet(path_settings[key])
        return frame.assign(sqrt_density=lambda tmp_df: np.sqrt(tmp_df["density"]))

    df_train, df_val, df_test = map(read_split, ("df_train", "df_val", "df_test"))

    # Train only on the lower part of the domain.
    df_train = df_train.query("z<0.7").copy()

    # Noise divisor: noise std is the column std divided by this.
    noise_divisor = 20

    was_zero = df_train["sqrt_density"] == 0.
    sqrt_density = df_train.sqrt_density
    df_train["sqrt_density"] = np.random.normal(sqrt_density,
                                                sqrt_density.std() / noise_divisor,
                                                sqrt_density.shape)
    # Keep exact zeros (empty space) noise-free.
    df_train.loc[was_zero, "sqrt_density"] = 0

    # Order matters for reproducibility: u, then v, then w draw from the RNG.
    for component in ("u", "v", "w"):
        column = df_train[component]
        df_train[component] += np.random.normal(0, column.std() / noise_divisor, column.shape)

    return df_train, df_val, df_test

    


def train_model(df_train: pd.DataFrame, df_val: pd.DataFrame, df_test: pd.DataFrame, model_settings: Dict, path_settings: Dict):
    """Train a ``MassCons3d`` model, evaluate it on every split and optionally plot.

    Training runs ``model_settings["num_its"]`` epochs of Adam whose learning
    rate is multiplied by 0.99 every epoch. A ``KeyboardInterrupt`` ends training
    early; the partially trained model is still evaluated, and the interruption
    is recorded under ``"return_state"``.

    Side effect: ``model_settings["device"]`` is overwritten with ``"cpu"`` when
    ``"cuda"`` is requested but unavailable.

    :param df_train: training frame. Each frame must provide the feature columns
        t, x, y, z, the target columns sqrt_density, u, v, w, and the
        u_atradar/v_atradar/w_atradar columns read by ``get_loss_dict``.
    :param df_val: validation frame (same columns).
    :param df_test: test frame (same columns).
    :param model_settings: run configuration; also forwarded as keyword
        arguments to ``MassCons3d``.
    :param path_settings: paths; only ``"plots"`` is read here.
    :return: OrderedDict of per-split normalised errors (keys suffixed with
        ``_train``/``_val``/``_test``), every model setting as a string,
        ``"return_state"`` and ``"datetime"``.
    """
    loss_weights = dict(
        nlogprob_density=1,
        velocity=10 * model_settings["weight_velocity_loss"],
        pde_weight=model_settings["pde_weight"],
    )
    # Drop loss terms that are switched off for this run.
    if model_settings['mass_only']:
        del loss_weights["velocity"]
    if model_settings["sampling_method"] == ST.none:
        del loss_weights["pde_weight"]

    if model_settings["device"] == "cuda" and not torch.cuda.is_available():
        model_settings["device"] = "cpu"

    np.random.seed(model_settings['seed'])
    torch.random.manual_seed(model_settings['seed'])

    # Population std per training column; used to normalise the logged training losses.
    stds_train = dict(df_train.std(ddof=0).items())

    target_vars = ["sqrt_density", "u", "v", "w"]
    feature_vars = ["t", "x", "y", "z"]

    X_train, X_val, X_test = [df.loc[:, feature_vars].values.astype(np.float32)
                              for df in [df_train, df_val, df_test]]
    Y_train, Y_val, Y_test = [df.loc[:, target_vars].values.astype(np.float32)
                              for df in [df_train, df_val, df_test]]

    # Batch size slightly above N/5, so one training epoch has at most five minibatches.
    mb_size = X_train.shape[0] // 5 + 1

    train_dataloader = dataloader_from_np([X_train, Y_train], mb_size, shuffle=True)
    train_dataloader_noshuffle = dataloader_from_np([X_train, Y_train], mb_size, shuffle=False)
    # Validation is evaluated as a single full batch.
    val_dataloader = dataloader_from_np([X_val, Y_val], X_val.shape[0], shuffle=False)
    test_dataloader = dataloader_from_np([X_test, Y_test], mb_size, shuffle=False)

    m = MassCons3d(X=X_train, Y=Y_train, **model_settings)
    m.to(model_settings["device"])
    m._compile(X_train)

    optim = torch.optim.Adam(m.parameters(), lr=model_settings["lr"])
    scheduler = LambdaLR(optim, lambda ep: 0.99 ** ep)

    if model_settings["sampling_method"] != ST.none:
        # Initial collocation sample: uniform stays uniform, every other method starts from ST.gaussian.
        m.do_sample(model_settings["n_samples_constraints"],
                    sampling_method=ST.uniform if model_settings["sampling_method"] == ST.uniform else ST.gaussian,
                    first=True)

    # Bind before the loop so the final evaluation (epochs=epoch) also works when
    # no epoch runs (num_its <= 0); previously that raised NameError.
    epoch = -1
    try:
        for epoch in range(model_settings["num_its"]):

            for x_, target in train_dataloader:
                x_, target = x_.to(m.device), target.to(m.device)

                optim.zero_grad(set_to_none=True)

                loss_mb, loss_dict_mb = m.training_loss(x_,
                                                        target=target,
                                                        weight_dict=loss_weights,
                                                        n_samples_constraints=model_settings["n_samples_constraints"]
                                                        )
                loss_mb.backward()
                optim.step()

            if epoch % 30 == 0 and epoch > 0:
                # NOTE(review): unlike the initial sample above, this periodic
                # resample is not skipped for ST.none -- confirm do_sample is a
                # harmless no-op in that case.
                m.do_sample(model_settings["n_samples_constraints"], model_settings["sampling_method"])

                if model_settings["sampling_method"] not in [ST.uniform, ST.none,
                                                             ST.importance_sampling,
                                                             ST.it_pdpinn] and m.fraction_mcmc < .8:
                    # Gradually raise the MCMC fraction, capped at 0.8.
                    m.fraction_mcmc = min(.8, m.fraction_mcmc + 0.2)
                    logging.info(f"{m.fraction_mcmc:.2f}")
                    logging.info(f"acceptance rate: {m.sampler.acceptance_rate * 100:.1f}%")

            if (epoch + 1) % 10 == 0 or epoch == 0:
                # Training numbers come from the last minibatch of this epoch.
                log_errors(epoch, loss_dict_mb, stds_train)

                with torch.no_grad():
                    loss_dict_val = get_loss_dict(df_val, m, val_dataloader)
                    log_test_errors(loss_dict_val)

            scheduler.step()
            optim.zero_grad(set_to_none=True)
        return_state = "SUCCESS"
    except KeyboardInterrupt as exc:
        return_state = repr(exc)
        print(return_state)
        print(f"Manually interrupted training at epoch {epoch}")

    with torch.no_grad():
        loss_dict_train = get_loss_dict(df_train, m, train_dataloader_noshuffle, postfix="_train", epochs=epoch)
        loss_dict_val = get_loss_dict(df_val, m, val_dataloader, postfix="_val")
        loss_dict_test = get_loss_dict(df_test, m, test_dataloader, postfix="_test")
        loss_dict_combined = OrderedDict(chain(loss_dict_train.items(),
                                               loss_dict_val.items(),
                                               loss_dict_test.items(),
                                               [(key, str(item)) for key, item in model_settings.items()])
                                         )
        log_test_errors(loss_dict_train, postfix="_train")
        log_test_errors(loss_dict_val, postfix="_val")
        log_test_errors(loss_dict_test, postfix="_test")
    loss_dict_combined["return_state"] = return_state
    loss_dict_combined["datetime"] = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")

    if not model_settings["no_plots"]:
        logging.info("Plotting..")
        sns.set(style="whitegrid", font_scale=1.5, palette="rocket_r")
        os.makedirs(path_settings["plots"], exist_ok=True)

        df_train = predict_to_df(df_train, m, train_dataloader_noshuffle)
        df_test = predict_to_df(df_test, m, test_dataloader)

        # File-name suffix identifying the residual-adaptive-refinement variant.
        rar_text = ("_rar" if model_settings["rar"] else "") + ("_otrar" if model_settings["ot_rar"] else "")

        plot_results(df_test, df_train, path_settings["plots"], rar_text)

    return loss_dict_combined


def get_loss_dict(df, model, dataloader, postfix="", **kwargs):
    """Return variance-normalised squared errors of ``model``'s predictions on ``df``.

    Each error is ``mean((prediction - target) ** 2) / std(target, ddof=0) ** 2``:

    - ``density``: squeezed sqrt-density prediction vs ``df["sqrt_density"]``;
    - ``u`` / ``v`` / ``w``: velocity prediction columns 0/1/2 vs
      ``u_atradar`` / ``v_atradar`` / ``w_atradar``.

    Extra keyword arguments are copied into the result unchanged. Every key
    (including the extra ones) gets ``postfix`` appended.

    :param df: frame whose rows line up with the order ``dataloader`` yields.
    :param model: model passed to ``batch_predict``.
    :param dataloader: non-shuffling loader over ``df``'s features.
    :param postfix: suffix appended to every key, e.g. ``"_val"``.
    :return: OrderedDict ``density, u, v, w, *kwargs`` (all suffixed).
    """
    velocity_hat, sqrt_density_hat, _, _ = batch_predict(model, dataloader)

    def normalised_mse(prediction, column):
        reference = df[column]
        return ((prediction - reference.values) ** 2).mean() / reference.std(ddof=0) ** 2

    unsuffixed = OrderedDict(
        density=normalised_mse(np.squeeze(sqrt_density_hat), "sqrt_density"),
        u=normalised_mse(velocity_hat[:, 0], "u_atradar"),
        v=normalised_mse(velocity_hat[:, 1], "v_atradar"),
        w=normalised_mse(velocity_hat[:, 2], "w_atradar"),
        **kwargs
    )

    return OrderedDict((name + postfix, value) for name, value in unsuffixed.items())


def log_test_errors(loss_dict, postfix=""):
    """Log the u, v, w and density errors stored under ``<name><postfix>`` at INFO level.

    :param loss_dict: mapping as returned by ``get_loss_dict`` with the same ``postfix``.
    :param postfix: key suffix; also used as the line's label.
    """
    u_err, v_err, w_err, density_err = (loss_dict[name + postfix]
                                        for name in ("u", "v", "w", "density"))
    message = (f"{postfix}:\t u:\t{u_err:.3f}, "
               f"v:\t{v_err:.3f}, "
               f"w:\t{w_err:.3f}, "
               f"\tDensity:\t{density_err:.3f}")
    logging.info(message)


def log_errors(epoch, loss_dict, stds_train):
    """Log the latest training losses for ``epoch`` at INFO level.

    ``u``, ``v``, ``w`` and ``density`` are divided by the squared training std
    of the matching column (density uses ``sqrt_density``). ``velocity`` is
    logged as-is, and ``sinks_and_sources_deviation`` is shown as 0 when absent.

    :param epoch: epoch index shown in the header line.
    :param loss_dict: per-term losses from ``training_loss``.
    :param stds_train: column -> training standard deviation.
    """
    def scaled(loss_key, std_key):
        return loss_dict[loss_key] / stds_train[std_key] ** 2

    logging.info(f"\nEpoch: {epoch}")
    pieces = (
        "Train:\t ",
        f"u:\t{scaled('u', 'u'):.3f}, ",
        f"v:\t{scaled('v', 'v'):.3f}, ",
        f"w:\t{scaled('w', 'w'):.3f}, ",
        f"\tDensity:\t{scaled('density', 'sqrt_density'):.3f}, ",
        f"\t\tVelocity :\t{loss_dict['velocity']:.2f}, ",
        f"\tMass-Conservation: {loss_dict.get('sinks_and_sources_deviation', 0.):.3f},",
    )
    logging.info("".join(pieces))


def plot_results(df_test, df_train, plot_path, rar_text, model_settings: Optional[dict] = None):
    """Write the density projection image and the predicted-mass animation.

    Outputs under ``plot_path``:

    - ``density_1d_projection_<sampling_method><rar_text>.png``: ground-truth
      (top) and predicted (bottom) density averaged per (t, z), shown as a
      z-by-t image; both panels share the ground-truth colour range.
    - ``predicted_mass_<sampling_method><rar_text>.mp4``: ``AnimatedImages``
      built from the test predictions aggregated per (x, y, t)
      (summed ``density_hat``, mean predicted velocity).

    :param df_test: test frame with prediction columns (``density_hat``,
        ``u_hat``, ``v_hat``, ``w_hat``) as produced by ``predict_to_df``.
    :param df_train: training frame, forwarded to ``AnimatedImages``.
    :param plot_path: output directory (``pathlib.Path``; joined with ``/``).
    :param rar_text: suffix appended to the output file names.
    :param model_settings: run settings; ``"sampling_method"`` and ``"ot_rar"``
        are read. Defaults to the module-level ``model_settings`` that the
        script entry point defines, which keeps old call sites working.
    :raises NameError: if ``model_settings`` is neither passed nor defined at
        module level (e.g. the module was imported rather than run).
    """
    if model_settings is None:
        # Historical behaviour: read the global defined under ``__main__``.
        model_settings = globals().get("model_settings")
        if model_settings is None:
            raise NameError("plot_results() needs model_settings: pass it explicitly "
                            "when this module is imported instead of run as a script")
    sampling_method = model_settings["sampling_method"]

    df_test_copy = df_test.copy()
    # NOTE(review): this masks the ground-truth u/v columns, but the aggregation
    # below only reads the *_hat columns, so it currently has no effect on any
    # plot. Confirm whether u_hat/v_hat(/w_hat) were meant.
    df_test_copy.loc[df_test_copy["density"] == 0, ["u", "v"]] = np.nan
    # Collapse z: total predicted density and mean predicted velocity per (x, y, t).
    # ("mean" skips NaN, exactly like the np.nanmean previously passed here.)
    df_visualization = df_test_copy.groupby(["x", "y", "t"]).aggregate({"density_hat": "sum",
                                                                        "u_hat": "mean",
                                                                        "v_hat": "mean",
                                                                        "w_hat": "mean",
                                                                        }).reset_index()
    df_visualization.loc[:, "mass"] = np.sqrt(df_visualization["density_hat"])
    df_visualization.loc[:, "u"] = df_visualization["u_hat"]
    df_visualization.loc[:, "v"] = df_visualization["v_hat"]
    df_visualization.loc[:, "w"] = df_visualization["w_hat"]

    # Average over (x, y) for the (t, z) projection.
    df_test_grouped = df_test.groupby(["t", "z"]).mean().reset_index()
    with sns.axes_style("whitegrid"):
        _, axs = plt.subplots(2, 1, figsize=(15, 10))

        grid_density = df_test_grouped.pivot(index="z", columns="t", values="density")
        grid_density_hat = df_test_grouped.pivot(index="z", columns="t", values="density_hat")
        imshow_settings_1d = dict(
            origin='lower', cmap="Blues",
            vmin=df_test_grouped.density.min(),
            vmax=df_test_grouped.density.max(),
            extent=[df_test_grouped.t.min(), df_test_grouped.t.max(),
                    df_test_grouped.z.min(), df_test_grouped.z.max()])
        im = axs[0].imshow(grid_density, **imshow_settings_1d)
        im2 = axs[1].imshow(grid_density_hat, **imshow_settings_1d)
        plt.colorbar(im, ax=axs[0])
        plt.colorbar(im2, ax=axs[1])
        for ax in axs:
            ax.set_ylabel("z")
            ax.set_xlabel("t")

        axs[0].set_title("Groundtruth")
        if sampling_method == ST.mh_pdpinn:
            axs[1].set_title("pdPINN")

        if sampling_method == ST.uniform and model_settings["ot_rar"]:
            axs[1].set_title("OT-RAR")
        plt.tight_layout()
        plt.savefig(plot_path / f"density_1d_projection_{sampling_method}{rar_text}.png")
        plt.close()

    ani = AnimatedImages(df_train=df_train,
                         config=config,
                         image_extent=[df_visualization.x.min(), df_visualization.x.max(),
                                       df_visualization.y.min(), df_visualization.y.max()],
                         df_test=df_visualization,
                         scale=15,
                         fps=6
                         )
    ani.save(plot_path / f"predicted_mass_{sampling_method}{rar_text}.mp4")
    logging.info("finished mp4.")


def predict_to_df(df, m, dataloader):
    """Return a copy of ``df`` with the model's predictions added as columns.

    Added columns: ``u_hat``, ``v_hat``, ``w_hat`` (last-axis components 0/1/2
    of the velocity prediction), ``log_density_hat`` and ``density_hat``
    (the squared sqrt-density prediction). ``df`` itself is not modified.

    NOTE(review): despite its name, ``log_density_hat`` stores the raw
    sqrt-density prediction, not a log. Check consumers before renaming it.

    :param df: frame whose rows line up with the order ``dataloader`` yields.
    :param m: model passed to ``batch_predict``.
    :param dataloader: non-shuffling loader over ``df``'s features.
    """
    result = df.copy()
    velocity_hat, sqrt_density_hat, _, _ = batch_predict(model=m, dataloader=dataloader)
    for axis, column in enumerate(("u_hat", "v_hat", "w_hat")):
        result.loc[:, column] = velocity_hat[..., axis]
    result.loc[:, "log_density_hat"] = sqrt_density_hat
    result.loc[:, "density_hat"] = sqrt_density_hat ** 2
    return result


if __name__ == '__main__':
    # Command-line entry point: parse flags, assemble the path/model settings and
    # run a single experiment.
    # NOTE: ``model_settings`` below is intentionally bound at module level --
    # plot_results() looks this name up in the module globals.
    parser = argparse.ArgumentParser()
    parser.add_argument("--experiment", default="", type=str)
    # No nargs='?': with it (and no const) a bare "--device" silently produced None.
    parser.add_argument("--device", help="Device to use",
                        choices=('cpu', 'cuda'),
                        default="cuda")
    parser.add_argument("--silent", action="store_true")
    parser.add_argument("--normalize-loss", action="store_true")
    parser.add_argument("--seed", default=1234, type=int)
    parser.add_argument("--sine-frequency", default=5, type=float)

    parser.add_argument("--no-plots", action="store_true")
    parser.add_argument("--mass-only", action="store_true")
    # Parsed for CLI compatibility; not read anywhere in this file.
    parser.add_argument("--regularize-hessian", action="store_true")

    parser.add_argument("--sampling-method",
                        choices=([str(e) for e in ST]),
                        default=str(ST.none))

    parser.add_argument("--n-samples", default=int(2 ** 12), dest="n_samples_constraints", type=int)
    parser.add_argument("--pde-weight", default=2e2, type=float)
    parser.add_argument("--rar", action="store_true")
    parser.add_argument("--ot-rar", action="store_true")

    cargs = vars(parser.parse_args())

    exp = "fluid_3d"

    PROJECT_PATH = Path(__file__).parent
    DATA_DIR = PROJECT_PATH / "data"
    DATA_3D = DATA_DIR / "data_3d_fixed_val"

    path_settings = dict(
        file=__file__,

        df_train=DATA_3D / "train_3d.snappy",
        df_val=DATA_3D / "eval_3d.snappy",
        df_test=DATA_3D / "test_3d.snappy",

        results_csv=PROJECT_PATH / "notebooks/experiments_evaluation/results_3d.csv",
        plots=PROJECT_PATH / f"images/{exp}"
    )

    model_settings = OrderedDict(
        experiment=cargs["experiment"],
        sampling_method=ST[cargs["sampling_method"]],
        seed=cargs["seed"],
        device=cargs["device"],
        silent=cargs["silent"],
        mass_only=cargs["mass_only"],

        hidden_features=256,
        vf_hidden_features=256,
        density_hidden_layers=6,
        vf_hidden_layers=3,

        num_its=300,
        nonlinearity="sine",

        sine_frequency=cargs["sine_frequency"],
        n_samples_constraints=cargs["n_samples_constraints"],
        lr=1e-4,
        pde_weight=cargs["pde_weight"],
        normalize_loss=cargs["normalize_loss"],
        # Key spelling is part of the interface: forwarded as a kwarg to MassCons3d.
        include_boundary_coditions=False,
        weight_velocity_loss=4,

        # Domain bounds.
        tmin=-.5, tmax=3.5,
        xmin=-3, xmax=3,
        ymin=-3, ymax=3,
        zmin=-.125, zmax=1.0,

        dataset=str(path_settings["df_train"].parent),
        rar=cargs["rar"],
        ot_rar=cargs["ot_rar"],
        no_plots=cargs["no_plots"],
        run_num=0
    )

    main(model_settings, path_settings)
