# Make Python's `ssl` module verify certificates against the operating system's
# trust store (truststore). This deliberately runs before every other import,
# presumably so libraries imported below already see the patched ssl behaviour.
import truststore

truststore.inject_into_ssl()

import argparse
import copy
import datetime
import itertools
import pickle
import traceback

import git
from easy_tpp.config_factory import Config
from easy_tpp.runner import Runner
from easy_tpp.utils import create_folder, save_yaml_config, set_seed


def _git_info():
    """Return ``(commit_sha, branch_name, is_dirty)`` for the enclosing git repo.

    The repository is searched from the current directory upwards. Any field that
    cannot be determined is reported as the string ``"NoneFound"``; a detached
    HEAD only loses the branch name, not the commit hash or the dirty flag.
    """
    git_commit = git_branch = git_is_dirty = "NoneFound"
    try:
        repo = git.Repo(search_parent_directories=True)
        git_commit = repo.head.object.hexsha
        git_is_dirty = repo.is_dirty()
    except Exception:
        print("Failed to grab git info...")
        return git_commit, git_branch, git_is_dirty

    try:
        # Keep only the branch *name*: a GitPython Head object holds a reference to
        # the Repo, which would otherwise end up inside the results pickle.
        git_branch = repo.active_branch.name
    except TypeError:
        # GitPython raises TypeError from active_branch when HEAD is detached.
        pass
    return git_commit, git_branch, git_is_dirty


def _iter_grid(grid):
    """Yield every combination of the hyper-parameter ``grid`` (Cartesian product).

    Args:
        grid: mapping from model-config attribute name to a list of candidate values.

    Yields:
        Tuples of ``(name, value)`` pairs, one pair per key of ``grid`` in insertion
        order; the first key varies slowest.
    """
    names = list(grid)
    for values in itertools.product(*(grid[name] for name in names)):
        yield tuple(zip(names, values))


def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path``, overwriting any existing file."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)


def _print_results(results):
    """Print every entry of ``results`` (metadata entries and per-model results)."""
    for key, val in results.items():
        print(f"Model {key}:")
        print(val)


def main():
    """Grid-search model hyper-parameters, training one Runner per combination.

    For each combination the model config is updated in place, saved as
    ``model_{i}.yaml`` under ``saved_model_dir``, and trained/evaluated with
    ``Runner.build_from_config(config).run(save_model_id=f"_{i}")``. Runs that
    raise are reported and skipped. All results (plus date/git metadata) are
    pickled to ``saved_model_dir + "_results.pkl"`` after every run and at the end.
    The run with the highest ``best_valid_ll`` is pickled to
    ``saved_model_dir + "_best_results.pkl"``. Ctrl-C stops the sweep early but
    still writes what was collected so far.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--config_dir",
        type=str,
        required=False,
        default="/content/hhp-iclr2026/hyper-hawkes-process/EasyTPP/examples/configs/exp_config_taxi.yaml",
        help="Dir of configuration yaml to train and evaluate the model.",
    )

    parser.add_argument(
        "--experiment_id", type=str, required=False, default="HHP_train", help="Experiment id in the config file."
    )

    args = parser.parse_args()
    config = Config.build_from_yaml_file(args.config_dir, experiment_id=args.experiment_id)

    # Grab some git information.  # TODO: sync git info to server
    git_commit, git_branch, git_is_dirty = _git_info()

    results = {
        "date": datetime.date.today(),
        "gitcommit": git_commit,
        "git_branch": git_branch,
        "git_is_dirty": git_is_dirty,
    }

    save_path = config.base_config.specs["saved_model_dir"]
    create_folder(save_path)
    # NOTE(review): the pickle names are built by appending to save_path, so they
    # land inside the folder only when saved_model_dir ends with a path separator.
    results_file = save_path + "_results.pkl"
    best_results_file = save_path + "_best_results.pkl"

    best_valid_ll = float("-inf")
    best_valid_res = {}

    # Search spaces used for each model variant:
    # HP:
    #     "non_latent": [True,]

    # LHP:
    #     "hidden_size": [8, 16, 32, 64, 128],
    #     "complex_latent": [False, True],

    # HHP:
    #     "hidden_size": [8, 16, 32, 64, 128],
    #     "complex_latent": [False, True],

    #     "full_A": [False, True],

    #     "predict_A": [False, True],
    #     "predict_B": [False, True],
    #     "predict_C": [False, True],
    #     "predict_S": [False, True],

    #     "hyper_type": ['GRU', 'LSTM', 'Transformer'],
    #     "hyper_size": [8, 16, 32, 64, 128],
    #     "hyper_layers": [1, 2, 3],

    grid = {
        "loss_integral_num_sample_per_step": [
            100,
        ],
        "use_mc_samples": [
            True,
        ],
        "full_C": [
            True,
        ],
        "non_latent": [
            False,
        ],
        "hidden_size": [128, 256],  # 32, 64, 128, 256], # [4, 8, 16, 32, 64, 128],
        "complex_latent": [False, True],  # True],
        "full_A": [
            False,
        ],
        "predict_A": [
            True,
        ],
        "predict_B": [False, True],  # , True],
        "predict_C": [
            False,
        ],
        "predict_S": [True],
        "hyper_type": [
            "GRU",
        ],
        "hyper_size": [16, 32, 64],  # 4, 8, 16,],
        "hyper_layers": [2],
        "normalize_A": [
            False,
        ],
        "normalize_B": [
            False,
        ],
        "orthogonal_A": [
            True,
        ],
        "num_rotations": [8],
    }

    try:
        for i, options in enumerate(_iter_grid(grid)):
            for k, v in options:
                setattr(config.model_config, k, v)
            # `config` is mutated in place on every iteration. Snapshot it now;
            # otherwise every stored result (including the best one) would point
            # at the same object and show the last combination's settings.
            config_snapshot = copy.deepcopy(config)

            print(f"Start training model with {options}")
            save_yaml_config(f"{save_path}/model_{i}.yaml", config)
            set_seed(config.trainer_config.seed)
            model_runner = Runner.build_from_config(config)

            try:
                res = model_runner.run(save_model_id=f"_{i}")
            except Exception:
                # KeyboardInterrupt is not an Exception subclass, so Ctrl-C still
                # reaches the outer handler and ends the sweep.
                print(f"Encountered error. Skipping {options}")
                traceback.print_exc()
                continue

            # save configs for model i
            res["model_id"] = i
            res["config"] = config_snapshot
            res["params"] = dict(options)
            results[i] = res
            if res["best_valid_ll"] > best_valid_ll:
                best_valid_ll = res["best_valid_ll"]
                best_valid_res = res

            # Persist after every run so a crash loses at most the current run.
            _dump_pickle(results, results_file)
            _print_results(results)

            print(f"Best valid ll so far: {best_valid_ll}")
            print(best_valid_res)

            print("===" * 10)
            print("\n" * 5)
            print("===" * 10)

    except KeyboardInterrupt:
        # Deliberate: Ctrl-C ends the sweep early but still saves the results below.
        print("Interrupted; saving results collected so far.")

    _dump_pickle(results, results_file)
    _print_results(results)
    print("Experiment finished")

    _dump_pickle(best_valid_res, best_results_file)

    print(f"Global best validation ll: {best_valid_ll}")
    print(best_valid_res)


if __name__ == "__main__":
    main()
