import matplotlib.pyplot as plt
import matplotlib as mpl
import torch
import tqdm
import numpy as np

from datetime import datetime

import argparse
from pathlib import Path
import os
import logging
from typing import List, Tuple, Dict
import warnings
import seaborn as sns
import pandas as pd

import pdPINN.model.siren
from pdPINN.util.data_gen import generate_radar_positions, \
    ParticleCollection
from pdPINN.util.plot_util import AnimatedImages, AnimatedScatter, plot_to_image
from pdPINN.util.vecfield_util import generate_vf_functions, generate_vf_functions_time
from pdPINN.util.system_util import memory

sns.set(style="whitegrid", font_scale=1.5, palette="rocket_r")

# Central settings for the synthetic data generation below.
config = {
    # seed used for the numpy and torch random number generators
    "seed": 6667,

    # number of time steps for the train / validation / test sets
    "num_time_steps": 11,
    "num_time_steps_val": 21,
    "num_time_steps_test": 101,
    # alternative setting that was tried: num_time_steps=21, num_time_steps_test=21

    # spatial and temporal bounds of the domain
    "xmin": -3., "xmax": 3,
    "ymin": -3, "ymax": 3,
    "zmin": 0, "zmax": 1.,
    "tmin": 0, "tmax": 3., "tmax_test": 3.,

    # number of grid points per horizontal axis of the test grid (alternative: 21)
    "img_dim": 41,
    "num_train": 13 ** 2,
    "num_val": 60,
    "num_altitude_layers": 21,

    # horizontal radius passed to the radar density computation
    "radius": 0.08,
    "rotation_only": False,
    "divergence_only": False,

    "num_landings": 2,
    # altitude cut-off used in the `z>threshold` queries
    "threshold": 0.3,

    # "num_density_samples": 5_000,
    "birds_per_flock": 30, "number_of_swarms": 50, "number_of_flocks": 50,

    # step size handed to the particle simulation
    "dt_eulersolve": 1e-2,

    # appended to the names of the output data / image directories
    "file_suffix": "_fixed_val",
}

# Module-level shortcuts for the domain bounds.
xmin, xmax = config["xmin"], config["xmax"]
ymin, ymax = config["ymin"], config["ymax"]
zmin, zmax = config["zmin"], config["zmax"]
tmin, tmax = config["tmin"], config["tmax"]
tmax_test = config["tmax_test"]


# this limits RAM usage to 90% and stops the program instead of freezing your whole workstation.
@memory(0.9)
def advection():
    """
    Generate the synthetic data sets and the accompanying plots.

    Steps, in order:
        1. seed numpy and torch with ``config["seed"]``,
        2. build the "landing" vector field functions and a ``ParticleCollection``,
        3. compute densities on a regular grid (test) and at radar positions (train / validation),
        4. derive the 1D, 2D and thresholded 2D data sets from the 3D ones,
        5. write all data sets as parquet files,
        6. render the plots / animations.

    Relies on the module-level names ``DATA_1D``, ``DATA_2D``, ``DATA_2D_thr``, ``DATA_3D`` and
    ``PLOT_PATH``, which are only assigned in the ``__main__`` section of this file.

    NOTE(review): the order of the statements matters for reproducibility, because the global numpy
    RNG is seeded once at the top and consumed afterwards (e.g. by ``gen_1d_data``; presumably also
    by the sampling helpers -- TODO confirm).
    """
    np.random.seed(config["seed"])
    torch.random.manual_seed(config["seed"])

    # landing_periodicity = config["num_landings"] / (config["tmax"] - 0.5)
    landing_periodicity = config["num_landings"] / (config["tmax"])
    vf_funs = generate_vf_functions(name="landing", landing_periodicity=landing_periodicity,
                                    rotation_only=config["rotation_only"], divergence_only=config["divergence_only"])
    print("sampling and simulating particles...")
    # simulate up to the later of the train and test end times
    particle_collection = ParticleCollection(birds_per_flock=config["birds_per_flock"],
                                             number_of_swarms=config["number_of_swarms"],
                                             number_of_flocks=config["number_of_flocks"],
                                             vf_funs=vf_funs, tmax=max(tmax_test, tmax),
                                             dt=config["dt_eulersolve"])  # sample from density field
    # idx = np.random.randint(0, particle_collection[2.5].shape[0], 5_000)
    # plot_parts = particle_collection[2.5][idx,:]
    #
    # fig = plt.figure()
    # ax = fig.add_subplot(projection='3d')
    # ax.scatter(*plot_parts.T, marker='+')

    # time steps for each data set
    tsteps_train = np.linspace(0, tmax, config["num_time_steps"], endpoint=True)
    tsteps_val = np.linspace(0, tmax, config["num_time_steps_val"], endpoint=True)
    tsteps_test = np.linspace(0, tmax_test, config["num_time_steps_test"], endpoint=True)

    # different altitude bins for each dataset
    # (the test bins start 0.1 below zmin and use 4 more layers than the train / validation bins)
    altitudes = np.linspace(zmin, zmax, config["num_altitude_layers"])
    altitudes_test = np.linspace(zmin - 0.1, zmax, config["num_altitude_layers"] + 4)

    print("count particles at radar positions...")
    train_radar_xyz = generate_radar_positions(config["num_altitude_layers"], altitudes, config["num_train"],
                                               extent=(xmin, xmax, ymin, ymax), noise=0.06)
    # validation radars use an extent that is enlarged by 0.5 in y only
    val_radar_xyz = generate_radar_positions(config["num_altitude_layers"], altitudes, config["num_val"],
                                             extent=(xmin,# - .5,
                                                     xmax,# + .5,
                                                     ymin - .5, ymax + .5),
                                             uniform_samples=True,
                                             noise=0.08)

    # horizontal axes of the test grid; padded by 0.5 on every side
    xx = np.linspace(xmin - .5, xmax + .5, config["img_dim"])
    yy = np.linspace(ymin - .5, ymax + .5, config["img_dim"])
    # XYZ = np.meshgrid(xx,
    #                   yy,
    #                   altitudes_test)
    # test_data_xyz = np.stack(XYZ, -1).reshape(-1, len(XYZ))

    # one frame per time step, concatenated into a single data frame per data set
    df_test_3d = pd.concat(particle_collection.density_on_grid(t, grid_list=[xx, yy, altitudes_test])
                           for t in tsteps_test)

    # np.unique(..., axis=0) removes duplicated rows from the radar position arrays
    df_train_3d = pd.concat(particle_collection.density_at_radars(t,
                                                                  radar_positions=np.unique(train_radar_xyz, axis=0),
                                                                  altitude_bins=altitudes, radius_xy=config["radius"])
                            for t in tsteps_train)
    df_val_3d = pd.concat(particle_collection.density_at_radars(t,
                                                                radar_positions=np.unique(val_radar_xyz, axis=0),
                                                                altitude_bins=altitudes, radius_xy=config["radius"])
                          for t in tsteps_val)
    # chained comparison: only true if the NaN count of all three data frames is zero
    if not df_train_3d.isnull().sum().sum() == df_val_3d.isnull().sum().sum() == df_test_3d.isnull().sum().sum() == 0:
        warnings.warn("Some nan entries found.")
    print("generate data..")
    # gen_1d_data returns (val, test, train); all three are derived from the 3D test grid
    df_val_1d, df_test_1d, df_train_1d = gen_1d_data(df_test_3d)
    # both 2D generators return (test, train, val)
    df_test_2d, df_train_2d, df_val_2d = gen_2d_data(df_test_3d, df_train_3d, df_val_3d)
    df_test_2d_thr, df_train_2d_thr, df_val_2d_thr = gen_2d_thresholded_data(df_test_3d, df_train_3d, df_val_3d)

    print("save data...")
    os.makedirs(DATA_1D, exist_ok=True)
    df_train_1d.to_parquet(DATA_1D / "train_1d.snappy")
    df_val_1d.to_parquet(DATA_1D / "eval_1d.snappy")
    df_test_1d.to_parquet(DATA_1D / "test_1d.snappy")

    os.makedirs(DATA_2D, exist_ok=True)
    df_train_2d.to_parquet(DATA_2D / "train_2d.snappy")
    df_val_2d.to_parquet(DATA_2D / "eval_2d.snappy")
    df_test_2d.to_parquet(DATA_2D / "test_2d.snappy")

    os.makedirs(DATA_2D_thr, exist_ok=True)
    df_train_2d_thr.to_parquet(DATA_2D_thr / "train_2d_thr.snappy")
    df_val_2d_thr.to_parquet(DATA_2D_thr / "eval_2d_thr.snappy")
    df_test_2d_thr.to_parquet(DATA_2D_thr / "test_2d_thr.snappy")

    os.makedirs(DATA_3D, exist_ok=True)
    df_train_3d.to_parquet(DATA_3D / "train_3d.snappy")
    df_val_3d.to_parquet(DATA_3D / "eval_3d.snappy")
    df_test_3d.to_parquet(DATA_3D / "test_3d.snappy")

    print("generate plots...")

    plot_1d_data(df_test_1d, df_train_1d)
    plot_2d_data(df_test_2d, df_train_2d)

    # xy_solution = ode_result.sol(tsteps_test).reshape(-1, 2, len(tsteps_test))
    # first two columns of the collection at each test time step, stacked along a new last axis
    xy_solution = np.stack([particle_collection[i][:, :2] for i in tsteps_test], -1)
    plot_2d_scatter(xy_solution, tsteps_test)

    plot_2d_data(df_test_2d_thr, df_train_2d_thr, filename="data_2d_thr.mp4")

    # plt.close()
    # fig = plt.figure()
    # ax = fig.add_subplot(projection='3d')
    #
    # cmap = plt.cm.RdYlGn
    # cmap.set_bad('white', alpha=0)

    # tmp = df_test_3d.query("t==0")
    # mask = lambda arr: np.ma.masked_where(tmp.mass <= 50, arr)
    # ax.scatter(mask(tmp.x.values), mask(tmp.y.values), mask(tmp.z.values), cmap=cmap)
    # plt.show()
    # plt.close()

    plot_3D_1Dproj(df_test_1d, df_train_3d)


def plot_3D_1Dproj(df_test_1d, df_train_3d):
    """
    Scatter plot of the 3D training observations, projected onto the (time, altitude) plane.

    Only rows above ``config['threshold']`` are used. They are grouped by ``t`` and ``z``; mass and
    density are summed, ``u`` and ``v`` are averaged over rows with non-zero density. The figure is
    written to ``PLOT_PATH / "vertical_mass_over_time_3D.png"``.

    Args:
        df_test_1d: 1D test data; its ``t`` / ``z`` columns define the axis limits and the
            altitude spacing used to place the threshold line.
        df_train_3d: 3D training data with the columns ``t``, ``z``, ``mass``, ``density``, ``u``, ``v``.
    """
    fig, ax = plt.subplots(1, 1, figsize=(10, 4))

    observed = df_train_3d.query(f"z>{config['threshold']}").copy()
    # velocities of empty cells should not enter the mean
    observed.loc[observed["density"] == 0, ["u", "v"]] = np.nan
    # "mean" skips NaN like np.nanmean did, without the FutureWarning pandas raises for numpy callables
    aggregation = {"mass": "sum",
                   "density": "sum",
                   "u": "mean",
                   "v": "mean"}
    observed = observed.groupby(["t", "z"]).aggregate(aggregation).reset_index()

    # spacing between the last two distinct altitudes of the test data
    zdiff = np.diff(df_test_1d.z.unique())[-1]

    sc = ax.scatter(observed.t, observed.z, c=np.sqrt(observed.density), marker='o', cmap="Blues",
                    linewidth=.5, edgecolor='black')
    # line half an altitude step below the lowest observed altitude
    ax.axhline(observed.z.min() - zdiff / 2, color="red", label="vertical threshold")
    ax.set_xlabel("time")
    ax.set_ylabel("altitude")
    ax.set_xlim(df_test_1d.t.min(), df_test_1d.t.max())
    ax.set_ylim(df_test_1d.z.min(), df_test_1d.z.max())
    ax.set_title("Observed Data")
    cb = plt.colorbar(sc, ax=ax)
    cb.set_label("Density")
    plt.tight_layout()
    fig.savefig(PLOT_PATH / "vertical_mass_over_time_3D.png")
    # close this figure explicitly instead of whichever figure happens to be current
    plt.close(fig)


def plot_2d_data(df_test_2d, df_train_2d, filename="data_2d.mp4"):
    """
    Save an animation of the 2D data to ``PLOT_PATH / filename``.

    The test frame is extended (on a copy) by a ``density_sqrt`` column, which is the variable
    handed to ``AnimatedImages`` as ``density_varname``. The image extent is the bounding box of
    the test ``x`` / ``y`` values.

    Args:
        df_test_2d: test data with the columns ``x``, ``y`` and ``density``.
        df_train_2d: training data, passed through to ``AnimatedImages``.
        filename: name of the output file inside ``PLOT_PATH``.
    """
    # assign() returns a new frame, so the caller's data frame stays untouched
    test_with_sqrt = df_test_2d.assign(density_sqrt=np.sqrt(df_test_2d["density"]))

    x_values, y_values = df_test_2d.x, df_test_2d.y
    bounding_box = [x_values.min(), x_values.max(), y_values.min(), y_values.max()]

    animation = AnimatedImages(df_train=df_train_2d,
                               config=config,
                               image_extent=bounding_box,
                               df_test=test_with_sqrt,
                               fps=6,
                               density_varname="density_sqrt")
    animation.save(PLOT_PATH / filename)


def plot_2d_scatter(xy_solution, tsteps_test):
    """
    Save a scatter animation of the particle positions to ``PLOT_PATH / "data_2d_scatter.mp4"``.

    Args:
        xy_solution: particle positions, passed to ``AnimatedScatter`` as ``solutions_y``.
        tsteps_test: time steps, passed to ``AnimatedScatter`` as ``timesteps``.
    """
    AnimatedScatter(
        solutions_y=xy_solution,
        timesteps=tsteps_test,
        fps=6,
    ).save(PLOT_PATH / "data_2d_scatter.mp4")


def gen_2d_data(df_test, df_train, df_val):
    """
    Collapse the 3D data sets to 2D by grouping over ``x``, ``y`` and ``t``.

    Per group, ``mass`` and ``density`` are summed, and ``u``, ``v``, ``u_atradar`` and
    ``v_atradar`` are averaged with ``density`` as weight. Groups whose weights sum to zero
    therefore end up with NaN velocities. The inputs are copied and not modified.

    Args:
        df_test: 3D test data.
        df_train: 3D training data.
        df_val: 3D validation data.
            All need the columns ``x``, ``y``, ``t``, ``mass``, ``density``, ``u``, ``v``,
            ``u_atradar`` and ``v_atradar``.

    Returns: tuple ``(df_test_2d, df_train_2d, df_val_2d)`` -- note the order.
    """
    df_train = df_train.copy().reset_index()
    df_val = df_val.copy().reset_index()
    df_test = df_test.copy().reset_index()

    for frame in (df_train, df_val, df_test):
        frame.loc[frame["density"] == 0, ["u_atradar", "v_atradar"]] = 0.  # np.nan

        # add a little so that not all weights in np.average are zero.
        # Every frame uses its OWN density here; previously the test and validation frames were
        # filled from df_train["density"] by mistake.
        # NOTE(review): "w_density" is not consumed below (weight_col is "density") -- confirm
        # whether the weighted average was meant to use it.
        frame["w_density"] = frame["density"] + 1e-10

    def weighted_average(df: pd.DataFrame, data_col: List[str], weight_col: str, by_cols: List[str],
                         other_aggregations: Dict[str, str]) -> pd.DataFrame:
        """
        Calculate weighted average over list of columns data_cols, grouped by columns by_cols,
        with weights in column weight_col.
        The other_aggregations are a dictionary that is fed to df.agg.

        Avoids copying dataframes, much faster than df.agg with custom functions.

        https://stackoverflow.com/questions/10951341/pandas-dataframe-aggregate-function-using-multiple-columns/44683506#44683506
        Args:
            df (pd.Dataframe): data to aggregate; helper columns are added temporarily and removed again.
            data_col (): names of the columns to average.
            weight_col (): name of the column holding the weights.
            by_cols (): names of the columns to group by.
            other_aggregations (): column name -> aggregation, applied to the groups as-is.

        Returns: new grouped pandas dataframe.

        """
        g = df.groupby(by_cols)
        result = g.aggregate(other_aggregations)
        for col in data_col:
            col_dw = f'_data_times_weight_{col}'
            col_w = f'_weight_where_notnull_{col}'

            # weights of rows with a missing value are zeroed, so they do not enter the denominator
            df[col_dw] = df[col] * df[weight_col]
            df[col_w] = df[weight_col] * pd.notnull(df[col])
            result[col] = (g[col_dw].sum() / g[col_w].sum()).values
            del df[col_dw], df[col_w]

        return result

    df_train_2d, df_val_2d, df_test_2d = [weighted_average(df=df,
                                                           data_col=["u", "v", "u_atradar", "v_atradar"],
                                                           weight_col="density",
                                                           by_cols=["x", "y", "t"],
                                                           other_aggregations={"mass": "sum",
                                                                               "density": "sum"}).reset_index()
                                          for df in [df_train, df_val, df_test]]

    return df_test_2d, df_train_2d, df_val_2d


def gen_2d_thresholded_data(df_test, df_train, df_val):
    """
    Collapse the 3D data sets to 2D, using only rows above the altitude threshold.

    For every data set: ``u`` / ``v`` of rows with zero density are treated as missing, rows with
    ``z`` not above ``config['threshold']`` are dropped, and the remainder is grouped by ``x``,
    ``y`` and ``t``. ``mass`` and ``density`` are summed; ``u``, ``v``, ``u_atradar`` and
    ``v_atradar`` are averaged (missing values are skipped). Groups without any valid ``u`` get
    ``u = v = 0``. The inputs are copied and not modified.

    Args:
        df_test: 3D test data.
        df_train: 3D training data.
        df_val: 3D validation data.

    Returns: tuple ``(df_test_2d, df_train_2d, df_val_2d)`` -- note the order.
    """
    # "mean" skips NaN exactly like the np.nanmean callable did, but does not trigger the
    # FutureWarning that pandas emits when numpy callables are passed to aggregate.
    aggregation = {"mass": "sum",
                   "density": "sum",
                   "u": "mean",
                   "v": "mean",
                   "u_atradar": "mean",
                   "v_atradar": "mean"}
    above_threshold = f"z>{config['threshold']}"

    def collapse(df):
        """Apply the threshold and the group-wise aggregation to a single data set."""
        df = df.copy()
        df.loc[df["density"] == 0, ["u", "v"]] = np.nan
        flat = df.query(above_threshold).groupby(["x", "y", "t"]).aggregate(aggregation).reset_index()
        # only `u` is inspected; `v` is overwritten together with it
        flat.loc[flat.u.isnull(), ["u", "v"]] = 0.
        return flat

    return collapse(df_test), collapse(df_train), collapse(df_val)


def gen_1d_data(df_test) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Build the 1D (altitude over time) data sets from the 3D test grid.

    The input is grouped by ``z`` and ``t`` (``mass`` and ``density`` summed, ``w`` averaged) and
    extended by the columns ``sqrt_density``, ``log1p_density`` and a 1-based altitude index
    ``z_idx``. A random mask drawn from the global numpy RNG marks about 20% of the rows:

        * train: marked rows with ``z`` above ``config['threshold']``,
        * validation: all unmarked rows,
        * test: all rows.

    Missing ``w`` values are replaced by zero in all three sets.

    Args:
        df_test: 3D test data with the columns ``z``, ``t``, ``mass``, ``density`` and ``w``.

    Returns: tuple ``(df_val, df_test, df_train)`` -- note the order.
    """
    profile = (df_test
               .groupby(["z", "t"])
               .aggregate({"mass": "sum", "density": "sum", "w": "mean"})
               .reset_index())
    profile["sqrt_density"] = np.sqrt(profile["density"])
    profile["log1p_density"] = np.log1p(profile["density"])

    # 1-based position of every altitude within the distinct altitudes (order of first appearance)
    levels = profile["z"].unique().tolist()
    level_rank = {level: position for position, level in enumerate(levels, start=1)}
    z_indices = [level_rank[z] for z in profile["z"].values]
    assert np.allclose([levels[i - 1] for i in z_indices], profile["z"].values)
    profile["z_idx"] = z_indices

    # single draw from the global RNG; the mask is applied positionally to the sorted frame
    in_train = np.random.rand(len(profile)) < 0.2
    ordered = profile.sort_values(by=["z_idx", "t"])

    train_part = ordered[in_train].query(f"z>{config['threshold']}")
    val_part = ordered[~in_train]

    df_val, df_full, df_train = (part.assign(w=part["w"].fillna(0.))
                                 for part in (val_part, ordered, train_part))

    return df_val, df_full, df_train


def plot_1d_data(df_test, df_train):
    """
    Plot the 1D ground truth and the observed training data over time and altitude.

    The upper panel shows the test density as an image with the vertical flux
    (``w * density ** 2``) drawn as arrows; the lower panel shows the training observations as a
    scatter plot. The figure is written to ``PLOT_PATH / "vertical_mass_over_time.png"``.

    The flux is kept in a local variable, so ``df_test`` is no longer extended by a ``flux``
    column as a side effect of plotting.

    Args:
        df_test: 1D test data with the columns ``t``, ``z``, ``density`` and ``w``.
            NOTE(review): the reshape below assumes one row per (z, t) pair, ordered by z and
            then t -- confirm against the caller.
        df_train: 1D training data with the columns ``t``, ``z`` and ``density``.
    """
    # spacing between the last two distinct altitudes / time steps
    zdiff = np.diff(df_test.z.unique())[-1]
    tdiff = np.diff(df_test.t.unique())[-1]
    num_time_steps = len(df_test.t.unique())
    density_true_img = df_test.density.values.reshape((-1, num_time_steps))
    imshow_settings = dict(
        origin='lower',
        cmap="Blues",
        # pad by half a cell so that the pixels are centred on the sample positions
        extent=[df_test.t.min() - tdiff / 2, df_test.t.max() + tdiff / 2,
                df_test.z.min() - zdiff / 2, df_test.z.max() + zdiff / 2],
        aspect='auto', interpolation='none', resample=False,
        alpha=0.7)
    quiver_settings = dict(angles='xy',
                           scale_units='xy',
                           scale=1e9)
    # local Series instead of a new column on the caller's data frame
    flux = df_test["w"] * df_test["density"] ** 2
    density_norm = mpl.colors.Normalize(vmin=df_test.density.min(), vmax=df_test.density.max())

    fig, axs = plt.subplots(2, 1, figsize=(10, 8))
    axs = axs.flatten()
    im = axs[0].imshow(density_true_img, norm=density_norm, **imshow_settings)
    axs[0].axhline(df_train.z.min() - zdiff / 2, color="red", label="vertical threshold")
    cb = plt.colorbar(im, ax=axs[0])
    cb.set_label("Density")
    # arrows have no horizontal component, only the vertical flux
    axs[0].quiver(df_test.t, df_test.z, np.zeros_like(flux.values), flux, **quiver_settings)
    axs[0].set_xlabel("time")
    axs[0].set_ylabel("altitude")
    axs[0].set_title("Groundtruth")

    sc = axs[1].scatter(df_train.t, df_train.z, c=df_train.density, marker='o', cmap="Blues", linewidth=.5,
                        edgecolor='black')
    axs[1].axhline(df_train.z.min() - zdiff / 2, color="red", label="vertical threshold")
    axs[1].set_xlabel("time")
    axs[1].set_ylabel("altitude")
    # same axis limits as the image above
    axs[1].set_xlim(imshow_settings["extent"][0:2])
    axs[1].set_ylim(imshow_settings["extent"][2:])
    axs[1].set_title("Observed Data")
    cb = plt.colorbar(sc, ax=axs[1])
    cb.set_label("Density")

    plt.tight_layout()
    fig.savefig(PLOT_PATH / "vertical_mass_over_time.png")

    plt.close(fig)


if __name__ == '__main__':
    # Script entry point: parse the command line, set up the output locations that the functions
    # above read as module-level names, then run the data generation.
    parser = argparse.ArgumentParser()
    # The default is the configured seed, so a run without flags behaves as before.
    parser.add_argument("--seed", default=config["seed"], type=int)

    # NOTE(review): this flag is parsed but not consumed by the code visible in this file.
    parser.add_argument("-no-plots", action="store_true")
    args = parser.parse_args()

    # advection() seeds numpy and torch from config["seed"]; without this assignment the
    # --seed option was silently ignored.
    config["seed"] = args.seed

    logger = logging.getLogger('experiment_mass_changing')
    PROJECT_PATH = Path(__file__).parent

    # one output directory per data set, all tagged with the configured suffix
    DATA_DIR = PROJECT_PATH / f"data"
    DATA_1D = DATA_DIR / f"data_1d{config['file_suffix']}"
    DATA_2D = DATA_DIR / f"data_2d{config['file_suffix']}"
    DATA_2D_thr = DATA_DIR / f"data_2d_thr{config['file_suffix']}"
    DATA_3D = DATA_DIR / f"data_3d{config['file_suffix']}"
    settings = {"format": "%(asctime)s \t %(message)s",
                "datefmt": '%m-%d %H:%M'}

    PLOT_PATH = PROJECT_PATH / f"images/data{config['file_suffix']}"
    os.makedirs(PLOT_PATH, exist_ok=True)
    NOW = datetime.now(tz=None).strftime("%Y_%m_%d__%H_%M_%S")
    CACHE_PATH = PROJECT_PATH / f".cache"

    advection()
