import os
import pandas as pd
import torch
import yaml
import matplotlib.pyplot as plt
import numpy as np
from sbsep.util import get_data_random
from sbsep.config import load_yaml

if __name__ == "__main__":
    plot = True

    fig_folder = "./figs"

    fig_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), fig_folder)

    conf_folder = "../conf"
    conf_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), conf_folder)

    npoints = 10000
    (
        coordinates,
        yd,
        _,
        _,
    ) = get_data_random(npoints)
    xlima, xlimb = -5.0, 5.0
    pgrid = torch.linspace(xlima, xlimb, 11)
    table = []

    A = 1.0 / (yd.mean() * (xlimb - xlima))
    print(A)
    for pa, pb in zip(pgrid, pgrid[1:]):
        mask = (pa < coordinates[:, 0]) & (coordinates[:, 0] <= pb)
        coord_batch = coordinates[mask, 1]
        fbatch = yd[mask]
        if sum(mask) > 0:
            table += [(0.5 * (pa + pb).item(), A * torch.mean(fbatch).item())]
    arr = np.array(table)
    df = pd.DataFrame(arr, columns=["x", "f"])
    y = load_yaml(os.path.join(conf_folder, "normal_gen.yaml"))

    y["table"]["xs"] = df.x.tolist()
    y["table"]["ys"] = df.f.tolist()

    with open(os.path.join(conf_folder, "norm_gen.yaml"), "w") as outfile:
        yaml.dump(y, outfile, default_flow_style=False, indent=4)

    if plot:
        fig, ax = plt.subplots(figsize=(5, 5))
        ax.plot(df.x, df.f, alpha=0.7)
        plt.savefig(os.path.join(fig_folder, "norm_integrated.pdf"))
