from pickle import NONE
from typing import Tuple
import matplotlib.pyplot as plt
import torch
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter
import numpy as np

from datetime import datetime

import argparse
from pathlib import Path
import os
import logging
from itertools import chain
import csv

import seaborn as sns
import pandas as pd
from collections import OrderedDict

from pdPINN.model.siren import batch_predict, dataloader_from_np
from pdPINN.util.plot_util import AnimatedImages
from pdPINN.model.siren_2d import MassCons2d
from pdPINN.util.system_util import save_json
from pdPINN.model.model_utiities import ST
from pprint import pprint, pformat

from create_data import config
import fasteners


def load_dataframes(path_settings: dict) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Load the train/val/test splits and prepare them for training.

    For every split this reads the parquet file at ``path_settings[<split>]``,
    adds a ``sqrt_density`` column (square root of ``density``) and replaces
    missing ``u``/``v`` values with 0.  The training split additionally gets
    Gaussian noise (std = 1/20 of the column's own std) added to
    ``sqrt_density``, ``u`` and ``v``; the noisy ``sqrt_density`` is clipped
    at 0.  Noise comes from the global NumPy RNG, so the result depends on
    ``np.random.seed``.

    Args:
        path_settings: mapping with keys ``"df_train"``, ``"df_val"`` and
            ``"df_test"`` pointing at parquet files.

    Returns:
        ``(df_train, df_val, df_test)``.
    """
    frames = []
    for split in ("df_train", "df_val", "df_test"):
        frame = pd.read_parquet(path_settings[split])
        frame = frame.assign(sqrt_density=np.sqrt(frame["density"]))
        for column in ("u", "v"):
            frame.loc[frame[column].isnull(), column] = 0
        frames.append(frame)
    df_train, df_val, df_test = frames

    # Only the training targets are perturbed. The draw order (sqrt_density,
    # then u, then v) matters for reproducibility under a fixed seed.
    rho = df_train["sqrt_density"]
    noisy_rho = np.random.normal(rho, rho.std() / 20, rho.shape)
    df_train["sqrt_density"] = np.clip(noisy_rho, a_min=0, a_max=None)

    for column in ("u", "v"):
        values = df_train[column]
        df_train[column] += np.random.normal(0, values.std() / 20, values.shape)

    return df_train, df_val, df_test


def main(model_settings: dict, path_settings: dict):
    """Train a ``MassCons2d`` model, evaluate it, persist the metrics and optionally plot.

    Args:
        model_settings: hyper-parameters and run options. They are forwarded as
            keyword arguments to ``MassCons2d`` and stringified into the results
            row. ``model_settings["device"]`` is overwritten in place with
            ``"cpu"`` when CUDA is requested but unavailable.
        path_settings: paths for the ``df_train``/``df_val``/``df_test`` parquet
            files, ``results_csv`` and (optionally) ``plots``.

    Returns:
        OrderedDict with the train/val/test metrics, the stringified model
        settings, ``return_state`` and ``datetime``.
    """
    # setup
    logging.info("\n" + pformat(model_settings))
    logging.info("\n" + pformat(path_settings))

    np.random.seed(model_settings["seed"])
    torch.random.manual_seed(model_settings["seed"])

    loss_weights = dict(
        nlogprob_density=0.0005,
        velocity=200.,
        pde=model_settings["pde_weight"]
    )

    # load data
    df_train, df_val, df_test = load_dataframes(path_settings)

    # set target vals and create dataloaders
    target_vars = ["sqrt_density", "u", "v"]
    feature_vars = ["t", "x", "y"]

    X_train, X_val, X_test = [df.loc[:, feature_vars].values.astype(np.float32)
                              for df in [df_train, df_val, df_test]]
    Y_train, Y_val, Y_test = [df.loc[:, target_vars].values.astype(np.float32)
                              for df in [df_train, df_val, df_test]]
    # Full-batch training: a single minibatch holds the whole training set.
    mb_size = X_train.shape[0]

    train_dataloader = dataloader_from_np(
        [X_train, Y_train], mb_size, shuffle=True)
    # Unshuffled copy: get_loss_dict compares predictions positionally with
    # the rows of df_train.
    train_dataloader_noshuffle = dataloader_from_np(
        [X_train, Y_train], mb_size, shuffle=False)
    val_dataloader = dataloader_from_np(
        [X_val, Y_val], X_val.shape[0] // 5, shuffle=False)
    test_dataloader = dataloader_from_np(
        [X_test, Y_test], X_test.shape[0] // 5, shuffle=False)

    # remove unneeded terms in loss
    if model_settings["mass_only"]:
        del loss_weights["velocity"]
    if model_settings["sampling_method"] == ST.none:
        del loss_weights["pde"]

    # Fall back to CPU. This mutates the caller's dict, so the saved settings
    # reflect the device that was actually used.
    if model_settings["device"] == "cuda" and not torch.cuda.is_available():
        model_settings["device"] = "cpu"

    # setup model
    m = MassCons2d(X=X_train, Y=Y_train,
                   **model_settings)
    m.to(model_settings["device"])
    m._compile(X_train)

    # setup optimizer and scheduler (learning rate decays by 0.99 per scheduler step)
    optim = torch.optim.Adam(lr=8e-4, params=m.parameters())
    scheduler = LambdaLR(optim, lambda ep: 0.99 ** ep)
    # Initial draw of constraint points: uniform if requested, otherwise dirichlet.
    m.do_sample(model_settings["n_samples_constraints"],
                sampling_method=ST.uniform if model_settings[
                                                  "sampling_method"] == ST.uniform else ST.dirichlet,
                first=True
                )
    m.fraction_mcmc = 0.01
    total_step = 0

    # Bound before the loop so the interrupt handler and the final evaluation
    # never hit an unbound name (e.g. num_its == 0, or an interrupt before the
    # first epoch starts). -1 means "no epoch started".
    epoch = -1
    try:
        for epoch in range(model_settings["num_its"]):
            for x_, vf_true in train_dataloader:
                x_, vf_true = x_.to(m.device), vf_true.to(m.device)
                optim.zero_grad()

                # Every 10th epoch (except epoch 0) redraw the constraint points.
                if epoch % 10 == 0 and epoch > 0:
                    m.do_sample(
                        model_settings["n_samples_constraints"], model_settings["sampling_method"])

                    if model_settings["sampling_method"] not in [ST.uniform, ST.none, ST.importance_sampling,
                                                                 ST.it_pdpinn] and m.fraction_mcmc < .8:
                        m.fraction_mcmc = min(.8, m.fraction_mcmc + 0.05)
                        logging.info(f"{m.fraction_mcmc:.2f}")
                        # sampling_method is an ST member (built via ST[...] from the CLI),
                        # so comparing it directly with a string never matches. Compare
                        # its string form, normalising '_' to '-' so both spellings work.
                        method_name = str(model_settings["sampling_method"]).replace("_", "-")
                        if method_name not in ["it-mcmc"]:
                            logging.info(
                                f"acceptance rate: {m.sampler.acceptance_rate * 100:.1f}%")

                total_loss, loss_dict_train_step = m.training_loss(x_,
                                                                   target=vf_true,
                                                                   weight_dict=loss_weights,
                                                                   )
                total_loss.backward()
                optim.step()

            if (epoch + 1) % 50 == 0 or epoch == 0:
                logging.info(f"\nEpoch: {epoch}")
                # Losses are normalised by the training-column variance.
                # NOTE(review): assumes training_loss always reports a 'pde' term,
                # even when it was dropped from loss_weights -- confirm.
                logging.info(f"Train.: "
                             f"u:\t{loss_dict_train_step['u'] / df_train['u'].var():.3f}, "
                             f"v:\t{loss_dict_train_step['v'] / df_train['v'].var():.3f}, "
                             f"\tSqrtDensity: {loss_dict_train_step['sqrt_density'] / df_train['sqrt_density'].var():.3f}, "
                             f"\tMass-Conservation: {loss_dict_train_step['pde']:.3f},"
                             )

                with torch.no_grad():
                    m.eval()
                    loss_dict_val_epoch = batch_loss(m, val_dataloader)
                    logging.info(f"Val.: VF\t{loss_dict_val_epoch['u'] / df_val['u'].var():.3f}, "
                                 f"VF\t{loss_dict_val_epoch['v'] / df_val['v'].var():.3f}, "
                                 f"\tSqrtDensity: {loss_dict_val_epoch['sqrt_density'] / df_val['sqrt_density'].var():.3f}"
                                 )
                    m.train()

            scheduler.step()
            total_step += 1
            optim.zero_grad()
        return_state = "SUCCESS"
    except KeyboardInterrupt as exc:
        # Keep going after Ctrl-C so partial results are still evaluated and saved.
        return_state = repr(exc)
        print(f"Manually interrupted training at epoch {epoch}")
    m.eval()

    # save results to file
    with torch.no_grad():
        loss_dict_train = get_loss_dict(
            df_train, m, train_dataloader_noshuffle, postfix="_train", epochs=epoch)
        loss_dict_val = get_loss_dict(
            df_val, m, val_dataloader, postfix="_val")
        loss_dict_test = get_loss_dict(
            df_test, m, test_dataloader, postfix="_test")
        # Settings are stringified so every value is CSV/JSON friendly.
        loss_dict_combined = OrderedDict(chain(loss_dict_train.items(),
                                               loss_dict_val.items(),
                                               loss_dict_test.items(),
                                               [(key, str(item)) for key, item in model_settings.items()]
                                               )
                                         )

    loss_dict_combined["return_state"] = return_state
    loss_dict_combined["datetime"] = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")

    save_results(loss_dict_combined, path_settings["results_csv"])
    save_json(loss_dict_combined, model_settings, prefix="d2_")

    if path_settings.get("plots", None) is not None and not model_settings["no_plots"]:
        pprint(loss_dict_combined)
        sns.set(style="whitegrid", font_scale=1.2, palette="rocket_r")
        plot_stuff(m, test_dataloader, df_train, df_test, path_settings["plots"])
    return loss_dict_combined


def plot_stuff(m: MassCons2d, test_dataloader, df_train: pd.DataFrame, df_test: pd.DataFrame, plot_path: str,
               model_settings=None, exp=None):
    """Save diagnostic plots of the model's test-set predictions.

    Writes two files into ``plot_path``:

    * ``predicted_mass_<sampling_method><suffix>.mp4``: an ``AnimatedImages``
      animation of ``sqrt_density_hat`` with the ``u_hat``/``v_hat`` field.
    * ``density_over_time_<sampling_method><suffix>.png``: a seaborn line plot
      of ``density`` and ``density_hat`` against ``t``.

    ``<suffix>`` is ``_rar`` and/or ``_otrar`` depending on
    ``model_settings["rar"]`` / ``model_settings["ot_rar"]``.

    Args:
        m: trained model passed to ``batch_predict``.
        test_dataloader: loader over the test split. Predictions are attached
            to ``df_test`` positionally, so it presumably must be unshuffled and
            in ``df_test`` row order (TODO confirm).
        df_train: training data forwarded to ``AnimatedImages``.
        df_test: test data. The caller's frame is not modified; a copy with
            prediction columns is built.
        plot_path: output directory (``str`` or ``Path``), created if missing.
        model_settings: run settings providing ``rar``, ``ot_rar`` and
            ``sampling_method``. Defaults to the module-level ``model_settings``
            defined by the ``__main__`` block.
        exp: experiment name used in the plot title. Defaults to the
            module-level ``exp``, or else the last component of ``plot_path``.

    Raises:
        ValueError: if no ``model_settings`` is passed and none exists at
            module level (e.g. when called from another module).
    """
    # The path joins below use `/`, which fails on a plain str despite the annotation.
    plot_path = Path(plot_path)
    os.makedirs(plot_path, exist_ok=True)

    # Historically this function read these as globals set under `__main__`;
    # keep that as the fallback so existing callers behave the same.
    if model_settings is None:
        model_settings = globals().get("model_settings")
    if model_settings is None:
        raise ValueError("plot_stuff needs `model_settings`; pass it explicitly "
                         "when not running this file as a script")
    if exp is None:
        exp = globals().get("exp", plot_path.name)

    velocity_hat, sqrt_density_hat, _, _ = batch_predict(m, test_dataloader)

    df_test = df_test.assign(
        u_hat=velocity_hat[..., 0],
        v_hat=velocity_hat[..., 1],
        sqrt_density_hat=sqrt_density_hat,
        density_hat=sqrt_density_hat ** 2,
    )

    rar_text = ("_rar" if model_settings["rar"] else "") + ("_otrar" if model_settings["ot_rar"] else "")
    method = model_settings['sampling_method']

    ani = AnimatedImages(df_train=df_train,
                         config=config,
                         image_extent=[df_test.x.min(), df_test.x.max(),
                                       df_test.y.min(), df_test.y.max()],
                         df_test=df_test,
                         scale=15,
                         fps=6,
                         density_varname="sqrt_density_hat",
                         u_varname="u_hat",
                         v_varname="v_hat"
                         )
    ani.save(plot_path / f"predicted_mass_{method}{rar_text}.mp4")

    long_df = pd.melt(df_test.assign(id=df_test.index), id_vars=['id', 't'],
                      value_vars=['density', 'density_hat'])
    # Draw into a fresh figure rather than whatever figure is current (e.g. one
    # left open by the animation), and close exactly that figure afterwards.
    fig = plt.figure()
    sns.lineplot(data=long_df, x="t", y="value", hue="variable")
    plt.title(f"Total mass over time. Experiment: {exp}")
    plt.savefig(plot_path / f"density_over_time_{method}{rar_text}.png")
    plt.close(fig)


def save_results(loss_dict_combined, RESULT_CSV):
    """Append one row of results to a CSV file, writing a header for a new file.

    The write is guarded by a ``fasteners`` inter-process lock so concurrent
    runs do not interleave rows.

    Args:
        loss_dict_combined: mapping of column name to value for one run. Its
            keys are used as the CSV field names.
        RESULT_CSV: path of the results CSV. Missing parent directories are
            created.
    """
    result_csv = Path(RESULT_CSV)
    os.makedirs(result_csv.parent, exist_ok=True)
    # NOTE(review): the lock file is relative to the current working directory,
    # so runs started from different directories do not exclude each other.
    # Kept unchanged for compatibility with already-running jobs.
    lock = fasteners.InterProcessLock("2d_lock.file")
    with lock:
        # Decide about the header *inside* the lock. Checking beforehand lets two
        # concurrent runs both see "no file" and both write a header. An existing
        # but empty file also still needs its header.
        needs_header = not result_csv.is_file() or result_csv.stat().st_size == 0
        # newline="" is required by the csv module to avoid blank lines on Windows.
        with open(result_csv, 'a', newline="") as f_object:
            dict_writer = csv.DictWriter(
                f_object, fieldnames=list(loss_dict_combined.keys()))
            if needs_header:
                dict_writer.writeheader()
            dict_writer.writerow(loss_dict_combined)


def get_loss_dict(df, model, dataloader, postfix="", **kwargs):
    """Compute variance-normalised squared errors of the model on one data split.

    Each metric is ``mean((prediction - reference) ** 2) / var(reference)``,
    using the population variance (``ddof=0``):

    * ``density``: ``sqrt_density`` prediction vs ``df["sqrt_density"]``;
    * ``u`` / ``v``: velocity components vs ``df["u_atradar"]`` / ``df["v_atradar"]``.

    Extra ``kwargs`` are passed through as additional entries. Every key gets
    ``postfix`` appended.

    Args:
        df: reference data. Compared positionally with the predictions, so
            ``dataloader`` presumably must yield rows in ``df`` order (TODO confirm).
        model: model passed to ``batch_predict``.
        dataloader: loader over the same rows as ``df``.
        postfix: suffix appended to every key (e.g. ``"_train"``).

    Returns:
        OrderedDict of metric name + postfix to value.
    """
    def _normalised_mse(prediction, column):
        # MSE against df[column], scaled by that column's population variance.
        reference = df[column]
        squared_error = (prediction - reference.values) ** 2
        return squared_error.mean() / reference.std(ddof=0) ** 2

    velocity_hat, sqrt_density_hat, _, _ = batch_predict(model, dataloader)
    metrics = OrderedDict(
        density=_normalised_mse(np.squeeze(sqrt_density_hat), "sqrt_density"),
        u=_normalised_mse(velocity_hat[:, 0], "u_atradar"),
        v=_normalised_mse(velocity_hat[:, 1], "v_atradar"),
        **kwargs
    )
    return OrderedDict((name + postfix, value) for name, value in metrics.items())


def batch_loss(model, dataloader):
    """Average a model's per-batch reconstruction losses over a dataloader.

    Each ``(inputs, targets)`` batch is moved to ``model.device`` and passed to
    ``model.reconstruction_loss(..., return_numpy=True)``. Batches that return
    an empty loss mapping are skipped. The remaining mappings become the rows
    of a DataFrame (one column per loss term), which is averaged column-wise.

    Returns:
        pd.Series of mean loss per term.
    """
    per_batch_losses = (
        model.reconstruction_loss(inputs.to(model.device),
                                  target=targets.to(model.device),
                                  return_numpy=True)
        for inputs, targets in dataloader
    )
    non_empty = [losses for losses in per_batch_losses if len(losses) >= 1]
    return pd.DataFrame(non_empty).mean()


if __name__ == '__main__':
    # ---- command-line interface -------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument("--nonlinearity", help="Nonlinearity to use", nargs='?',
                        choices=('sine', 'relu', 'sigmoid', 'tanh', 'selu', 'softplus', 'elu'),
                        default="sine")
    parser.add_argument("--device", help="Device to use", nargs='?',
                        choices=('cpu', 'cuda'),
                        default="cuda")
    for flag in ("-silent", "--rar", "--ot-rar"):
        parser.add_argument(flag, action="store_true")

    parser.add_argument("--seed", default=1, type=int)
    parser.add_argument("--sine-frequency", default=12, type=float)

    # Single-dash spellings are kept for CLI compatibility.
    # NOTE: -regularize-hessian is parsed but not forwarded to model_settings.
    for flag in ("-no-plots", "-mass-only", "-regularize-hessian"):
        parser.add_argument(flag, action="store_true")

    # options for experiments
    parser.add_argument("--experiment", default="", type=str)
    parser.add_argument("--sampling-method",
                        choices=[str(member) for member in ST],
                        default=str(ST.none))
    parser.add_argument("--n-samples", default=2 ** 6,
                        dest="n_samples_constraints", type=int)
    # NOTE: --mcmc-noiselevel is parsed but not forwarded to model_settings.
    parser.add_argument("--mcmc-noiselevel", default=1e-2,
                        dest="mcmc_noiselevel", type=float)
    parser.add_argument("--pde-weight", default=1.1, type=float)
    args = parser.parse_args()

    logging.basicConfig(level=logging.ERROR if args.silent else logging.INFO,
                        format="%(asctime)s \t %(message)s",
                        datefmt='%m-%d %H:%M:%S')

    # ---- paths --------------------------------------------------------------
    # `exp` and `model_settings` are module-level on purpose: plot_stuff reads them.
    exp = "fluid_2d"

    PROJECT_PATH = Path(__file__).parent
    PLOT_PATH = PROJECT_PATH / f"images/{exp}"
    os.makedirs(PLOT_PATH, exist_ok=True)

    DATA_DIR = PROJECT_PATH / "data"
    DATA_2D = DATA_DIR / "data_2d_fixed_val"  # alternative dataset: DATA_DIR / "data_2d"
    DATA_2D_thr = DATA_DIR / "data_2d_thr"  # currently unused
    RESULT_CSV = PROJECT_PATH / "notebooks/experiments_evaluation/results_2d.csv"

    raw_paths = dict(
        df_train=DATA_2D / "train_2d.snappy",
        df_val=DATA_2D / "eval_2d.snappy",
        df_test=DATA_2D / "test_2d.snappy",
        results_csv=RESULT_CSV,
        plots=PLOT_PATH,
    )
    path_settings = {name: location.resolve() for name, location in raw_paths.items()}

    # ---- model / run settings ------------------------------------------------
    # The CLI accepts '-' where the ST member names use '_'.
    chosen_method = ST[args.sampling_method.replace("-", "_")]

    # Key order matters: it determines the column order of the results CSV.
    model_settings = dict(
        file=__file__,
        experiment=args.experiment,
        sampling_method=chosen_method,
        seed=args.seed,
        n_samples_constraints=args.n_samples_constraints,

        no_background_loss=chosen_method != ST.none,
        mass_only=args.mass_only,

        hidden_features=256,
        vf_hidden_features=64,
        density_hidden_layers=2,
        vf_hidden_layers=1,

        num_its=500,
        nonlinearity=args.nonlinearity,
        sine_frequency=args.sine_frequency,

        tmin=-.5, tmax=3.5,
        xmin=-4, xmax=4,
        ymin=-4, ymax=4,

        rar=args.rar,

        device=args.device,
        pde_weight=args.pde_weight,
        ot_rar=args.ot_rar,

        no_plots=args.no_plots,
        run_num=0,
    )

    main(model_settings, path_settings)
