import argparse
import os

import numpy as np
import pandas as pd
import xgboost as xgb
from aion_eval.benchmarks.desiddpayne.dataset import DESIDDPayneDatasetModule
from astropy.table import Table
from tqdm.auto import tqdm


def fit_xgb(
    x_train,
    y_train,
    x_test,
    y_test,
    means,
    stds,
    n_rounds=50,
):
    """Train an XGBoost regressor and return it with rescaled test predictions.

    Args:
        x_train: Training features, passed to ``xgb.DMatrix``.
        y_train: Training labels, passed to ``xgb.DMatrix``. The predictions
            are mapped back with ``means``/``stds``, so the labels are
            presumably standardised by the caller -- this function does not
            check.
        x_test: Features to predict on.
        y_test: Labels attached to the test ``DMatrix``; they are not used
            for training or for computing the returned predictions.
        means: Offset added to the raw predictions.
        stds: Scale multiplied into the raw predictions.
        n_rounds: Number of boosting rounds.

    Returns:
        Tuple ``(tree, preds)``: the trained booster and
        ``tree.predict(x_test) * stds + means``.
    """
    dtrain = xgb.DMatrix(x_train, label=y_train)
    dtest = xgb.DMatrix(x_test, label=y_test)

    params = {
        # NOTE(review): requests GPU training; confirm the installed XGBoost
        # build and the target machine support this device string.
        "device": "gpu",
        "objective": "reg:squarederror",
        "max_depth": 6,
        "eta": 0.1,
        "eval_metric": "rmse",
    }

    # Train the model
    tree = xgb.train(params, dtrain, n_rounds)
    raw = tree.predict(dtest)

    # With a single target the prediction can come back 1-D, shape (n,).
    # Broadcasting that against 2-D per-target statistics such as (1, 1)
    # would yield (1, n) instead of (n, 1), so add the target axis first.
    # Multi-target (2-D) predictions and scalar statistics are untouched.
    if raw.ndim == 1 and max(np.ndim(means), np.ndim(stds)) > 1:
        raw = raw[:, None]

    preds = raw * stds + means
    return tree, preds


def atleast_2d(x):
    """Return ``x`` with trailing length-1 axes appended until ``x.ndim >= 2``.

    Inputs that already have two or more dimensions are returned unchanged.
    """
    if x.ndim >= 2:
        return x
    # One more trailing axis, then re-check the dimensionality.
    return atleast_2d(x[..., None])


def _stack_fields(data, fields, limit):
    """Concatenate the first ``limit`` entries of each field along axis 1."""
    return np.concatenate(
        [atleast_2d(data[k][:limit]) for k in fields], axis=1
    )  # b n


def main(args):
    """Fit one XGBoost baseline per config row and write its predictions.

    The config file is read as a ``|``-delimited CSV. Each row provides
    ``name``, ``input_fields`` (comma-separated keys into the dataset) and
    ``num_examples`` (``-1`` is replaced by a cap of 1_000_000_000). For every
    row the targets are standardised with the training-set mean and standard
    deviation, a model is fitted, and the rescaled predictions on the
    validation split are written to a FITS table in ``args.output_dir``.

    Args:
        args: Namespace with ``config_file``, ``data_path``, ``output_dir``
            and ``save_models`` attributes.
    """
    configs = pd.read_csv(args.config_file, delimiter="|")

    dm = DESIDDPayneDatasetModule(args.data_path, input_fields=[])
    dm.setup(None)

    # Make sure the destination exists before any training time is spent;
    # otherwise the first write below fails on a missing directory.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    # iterrows() is a generator, so tqdm needs the length to show progress.
    for _, row in tqdm(configs.iterrows(), total=len(configs)):
        name, input_fields, num_examples = (
            row["name"],
            row["input_fields"].split(","),
            row["num_examples"],
        )

        if num_examples == -1:
            num_examples = 1_000_000_000

        # The same limit is applied to the training and validation splits.
        x_train = _stack_fields(dm.train_data, input_fields, num_examples)
        x_test = _stack_fields(dm.val_data, input_fields, num_examples)
        y_train = _stack_fields(dm.train_data, dm.output_fields, num_examples)
        y_test = _stack_fields(dm.val_data, dm.output_fields, num_examples)

        means = np.mean(y_train, axis=0)[None, :]
        stds = np.std(y_train, axis=0)[None, :]
        # A constant target has zero spread; dividing by it would turn every
        # label of that column into NaN/inf. Use unit scale for such columns.
        stds[stds == 0] = 1.0

        y_train = (y_train - means) / stds
        y_test = (y_test - means) / stds

        tree, preds = fit_xgb(x_train, y_train, x_test, y_test, means, stds)

        # The stem uses the value after the -1 substitution above, so
        # "all examples" runs are tagged with 1000000000.
        stem = f"xgbbaseline_{name}_{num_examples}"

        # One table column per output field.
        table = Table(dict(zip(dm.output_fields, preds.T)))
        table.write(
            os.path.join(args.output_dir, f"{stem}.fits"),
            overwrite=True,
        )
        if args.save_models:
            tree.save_model(os.path.join(args.output_dir, f"{stem}.model"))


def _str2bool(value):
    """Convert a command-line string such as ``"false"`` or ``"1"`` to a bool.

    ``type=bool`` cannot be used for this: ``bool("False")`` is ``True``
    because any non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in {"1", "true", "t", "yes", "y", "on"}:
        return True
    # The empty string stays falsy, as it was under ``bool("")``.
    if text in {"", "0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_parser():
    """Build the command-line parser for the baseline script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, default="data")
    # Accepts "--save_models True/False" as before, and also a bare
    # "--save_models" flag, which means True.
    parser.add_argument(
        "--save_models", type=_str2bool, nargs="?", const=True, default=False
    )
    parser.add_argument(
        "--config_file", type=str, default="csv_runs/desiddpayne_baselines_v1.csv"
    )
    parser.add_argument(
        "--output_dir", type=str, default="data/analysis/desiddpayne/indomain_v1/"
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    main(args)
