import hydra
import numpy as np
import torch
from loguru import logger
from omegaconf import DictConfig

from utils import initialize_and_load_model

# Select the compute device once at import time: the first CUDA device when
# torch reports one as available, the CPU otherwise.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")


# Tag embedded in model A's checkpoint name for each mode. Model A is stored as
# "{strategy}_permuted_{tag}_{mode}_{cfg.model_a}" and model B as
# "{mode}_{cfg.model_b}". The insertion order of this dict is the order in
# which modes are checked.
_MODEL_A_TAG_BY_MODE = {
    "non_oblivious": "flip",
    "oblivious": "orderflip",
    "decision_list": "noop",
    "decision_list_no_binary": "noop",
}


def _model_a_name(strategy: str, mode: str, model_a: str) -> str:
    """Return the checkpoint name of model A for ``strategy`` and ``mode``.

    Raises:
        NotImplementedError: if ``mode`` has no entry in ``_MODEL_A_TAG_BY_MODE``.
    """
    try:
        tag = _MODEL_A_TAG_BY_MODE[mode]
    except KeyError:
        raise NotImplementedError(mode) from None
    return f"{strategy}_permuted_{tag}_{mode}_{model_a}"


def _load_model(name: str, cfg: DictConfig, mode: str):
    """Load checkpoint ``./output/model/<name>`` and move it to ``device``.

    The remaining ``initialize_and_load_model`` arguments are taken from
    ``cfg`` and ``mode`` unchanged.
    """
    return initialize_and_load_model(
        f"./output/model/{name}",
        cfg.dataset.input_dim,
        cfg.dataset.output_dim,
        cfg.depth,
        cfg.alpha,
        cfg.n_tree,
        mode,
    ).to(device)


def _check_state_dicts_match(model_a, model_b, mode: str, strategy: str) -> None:
    """Raise ``AssertionError`` unless both models' state dicts agree.

    Entries are paired by key rather than by position, so extra, missing, or
    reordered entries cannot slip through. Shapes are compared explicitly
    because ``torch.isclose`` would otherwise broadcast mismatched shapes.
    Values are compared with ``torch.isclose`` default tolerances.

    ``AssertionError`` is raised explicitly rather than via ``assert`` so
    the check is not stripped when Python runs with ``-O``.
    """
    state_a = model_a.state_dict()
    state_b = model_b.state_dict()
    if state_a.keys() != state_b.keys():
        raise AssertionError(
            f"State dict keys do not match ({mode}, {strategy}): "
            f"only in A: {sorted(state_a.keys() - state_b.keys())}, "
            f"only in B: {sorted(state_b.keys() - state_a.keys())}"
        )
    for key, value_a in state_a.items():
        value_b = state_b[key]
        if value_a.shape != value_b.shape:
            raise AssertionError(
                f"Parameter shapes do not match ({mode}, {strategy}, {key}): "
                f"{tuple(value_a.shape)} vs {tuple(value_b.shape)}"
            )
        close = torch.isclose(value_a, value_b)
        if not close.all():
            n_mismatched = int((~close).sum())
            raise AssertionError(
                f"Parameters do not match ({mode}, {strategy}, {key}): "
                f"{n_mismatched}/{close.numel()} elements differ"
            )


@hydra.main(config_path="./config/", config_name="config", version_base=None)
def main(cfg: DictConfig):
    """Check that every permuted model A has the same state dict as model B.

    For each strategy ("wm", "am") and each mode in ``_MODEL_A_TAG_BY_MODE``,
    load model A ("{strategy}_permuted_{tag}_{mode}_{cfg.model_a}") and
    model B ("{mode}_{cfg.model_b}") from ``./output/model/``. Every
    state-dict entry must match. Success is logged only after all pairs
    have been checked.

    Raises:
        ValueError: if ``cfg.seed_a != cfg.seed_b``.
        AssertionError: on the first state-dict key, shape, or value mismatch.
    """
    # Explicit raise instead of ``assert`` so the precondition survives ``-O``.
    if cfg.seed_a != cfg.seed_b:
        raise ValueError(
            f"seed_a ({cfg.seed_a}) must equal seed_b ({cfg.seed_b})"
        )
    for strategy in ["wm", "am"]:
        for mode in _MODEL_A_TAG_BY_MODE:
            logger.info(f"strategy: {strategy}, mode: {mode}")
            model_a = _load_model(
                _model_a_name(strategy, mode, cfg.model_a), cfg, mode
            )
            model_b = _load_model(f"{mode}_{cfg.model_b}", cfg, mode)
            _check_state_dicts_match(model_a, model_b, mode, strategy)

    logger.info("All model parameters successfully match")


if __name__ == "__main__":
    # ``main`` is called with no arguments: ``@hydra.main`` composes the config
    # named "config" from ./config/ and passes it in as ``cfg``.
    main()
