import copy
import os
import time
from pathlib import Path

import hydra
import omegaconf
import torch
import wandb

from graphsmodel import PROJECT_ROOT
from graphsmodel.experiments.node_influence import MostInfluentialRemoval
from graphsmodel.utils import (
    convert_seconds_to_dhms,
)
from graphsmodel.utils.dataset_utils import get_dataset


def run_experiment(cfg, wandb_run):
    """Run the most-influential-node removal experiment from stored SGC results.

    Loads the influence results saved under
    ``<cfg.core.storage_dir>/sgc/.../nodes_influence/<experiment_name>.pt``,
    copies the ``"argsrt"`` and ``"mask"`` entries of every
    ``{approach}_margins_{eval_metric}_perf`` key into a fresh dict with an
    empty ``"perf"`` list, and passes that dict to
    ``MostInfluentialRemoval.run_experiment``.

    Args:
        cfg: Experiment configuration. Fields read here:
            ``train.deterministic``, ``experiment_name``, ``dataset.name``,
            ``train.model._target_``, ``data.data_seed``,
            ``data.subset_mode``, ``core.storage_dir``,
            ``task.induced_subgraph`` and ``train_signal_value``.
        wandb_run: Not used inside this function; kept so existing callers
            that pass it keep working.

    Returns:
        None.
    """
    if cfg.train.deterministic:
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    experiment_name = cfg.experiment_name
    dataset_name = str(cfg.dataset.name).replace("-", "_")
    model_name = str(cfg.train.model._target_).split(".")[-1]

    # Directory component shared by the output and the SGC input locations.
    dataset_dir = f"{dataset_name.lower()}_seed_{cfg.data.data_seed}"

    results_dir = PROJECT_ROOT / "storage" / "transferability" / dataset_dir
    results_dir.mkdir(parents=True, exist_ok=True)

    # Config values are usually plain strings, for which `str / str` raises
    # TypeError; wrapping in Path makes the `/` joins valid and is a no-op
    # when the value already is a Path.
    sgc_results_dir = (
        Path(cfg.core.storage_dir)
        / "sgc"
        / dataset_dir
        / ("induced" if cfg.task.induced_subgraph else "unlabeled")
        / cfg.data.subset_mode
        / "alpha_0.1"
        / "nodes_influence"
    )

    sgc_influence_results = torch.load(sgc_results_dir / f"{experiment_name}.pt")

    dataset = get_dataset(cfg)
    data = dataset[0]

    # Count selected nodes directly instead of len(mask.nonzero().squeeze()):
    # squeeze() turns a single selected node into a 0-d tensor, on which
    # len() raises.
    n_train = int(torch.count_nonzero(data.train_mask))
    n_val = int(torch.count_nonzero(data.val_mask))
    n_test = int(torch.count_nonzero(data.test_mask))
    n_nodes = data.x.shape[0]

    # length of the valuation tensor to use; any subset mode other than
    # train/val/test falls back to the total number of nodes
    n_samples = {"train": n_train, "val": n_val, "test": n_test}.get(
        cfg.data.subset_mode, n_nodes
    )

    # Carry over the stored orderings and masks; "perf" starts empty and is
    # presumably filled by MostInfluentialRemoval.run_experiment (TODO confirm).
    model_transfer_results = {}
    for approach in ("datamodel", "banzhaf"):
        for eval_metric in ("test", "val"):
            key = f"{approach}_margins_{eval_metric}_perf"
            sgc_entry = sgc_influence_results[key]
            model_transfer_results[key] = {
                "argsrt": copy.deepcopy(sgc_entry["argsrt"]),
                "mask": copy.deepcopy(sgc_entry["mask"]),
                "perf": [],
            }

    most_influential_removal = MostInfluentialRemoval(
        cfg=cfg,
        data=data,
        top_k=500,
        train_signal_value=cfg.train_signal_value,
        n_samples=n_samples,
        results_dir=results_dir,
        filename_prefix=f"{model_name}_transfer",
    )

    most_influential_removal.run_experiment(results_dict=model_transfer_results)


@hydra.main(
    version_base=None, config_path=str(PROJECT_ROOT / "conf"), config_name="experiment"
)
def main(cfg: omegaconf.DictConfig):
    """Hydra entry point: run the experiment and report elapsed time.

    Starts a wandb run when ``"wandb"`` is present in ``cfg.log``, calls
    ``run_experiment``, finishes the wandb run if one was started, and prints
    the CPU and wall-clock time spent.

    Args:
        cfg: Configuration composed by Hydra from ``conf/experiment``.
    """
    started_cpu = time.process_time()
    started_wall = time.time()

    # Resolved plain-container copy of the config, used as the wandb run config.
    config = omegaconf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=False)

    use_wandb = "wandb" in cfg.log
    run = wandb.init(project=cfg.log.project, config=config) if use_wandb else None

    run_experiment(cfg=cfg, wandb_run=run)

    if use_wandb:
        run.finish()

    cpu_time = time.process_time() - started_cpu
    wall_time = time.time() - started_wall

    # Split each duration into days / hours / minutes / seconds for display.
    cpu_days, cpu_hours, cpu_minutes, cpu_seconds = convert_seconds_to_dhms(cpu_time)
    wall_days, wall_hours, wall_minutes, wall_seconds = convert_seconds_to_dhms(
        wall_time
    )

    print(
        f"CPU time: {cpu_days} Days: {cpu_hours} Hours: {cpu_minutes} Min: {cpu_seconds:.2f} Sec"
    )
    print(
        f"Wall time: {wall_days} Days: {wall_hours} Hours: {wall_minutes} Min: {wall_seconds:.2f} Sec"
    )


# Invoke the Hydra-decorated entry point only when this file is run as a script.
if __name__ == "__main__":
    main()
