import os
import time
from pathlib import Path

import hydra
import omegaconf
import torch
import wandb

from graphsmodel import PROJECT_ROOT
from graphsmodel.evaluation.banzhaf import compute_banzhaf
from graphsmodel.experiments.node_influence import MostInfluentialRemoval
from graphsmodel.utils import (
    convert_seconds_to_dhms,
    get_dataset,
)


# Seeds of the subset collections to evaluate; each seed selects its own
# ``subsets_seed_<seed>`` results directory.
SUBSETS_SEEDS = (
    1,
    12,
    123,
    1234,
    12345,
    123456,
    1234567,
    12345678,
    123456789,
    1234567890,
)


def _count_selected(mask):
    """Return the number of non-zero entries of ``mask`` as a Python int.

    Counting directly avoids ``len(mask.nonzero().squeeze())``, which raises
    ``TypeError`` when exactly one entry is selected because the squeeze
    collapses the index tensor to zero dimensions.
    """
    return int(torch.count_nonzero(mask))


def _results_dir_for(cfg, model_name, dataset_name, subsets_seed):
    """Build the results directory for one subsets seed.

    ``cfg.core.storage_dir`` is wrapped in ``Path`` so that both string and
    ``Path`` config values are accepted (a plain string does not support the
    ``/`` operator).
    """
    return (
        Path(cfg.core.storage_dir)
        / model_name.lower()
        / f"{dataset_name.lower()}_seed_{cfg.data.data_seed}"
        / f"subsets_seed_{subsets_seed}"
        / ("induced" if cfg.task.induced_subgraph else "unlabeled")
        / cfg.data.subset_mode
        / f"alpha_{cfg.data.alpha}"
    )


def run_experiment(cfg, wandb_run):
    """Run the most-influential-removal experiment for every subsets seed.

    For each seed in ``SUBSETS_SEEDS`` the function loads the stored
    datamodel, subset, Shapley and PC-Winter results, recomputes the Banzhaf
    margins, ranks the entries of every margin source on the test and
    validation masks, and hands the rankings to ``MostInfluentialRemoval``.

    Args:
        cfg: Experiment configuration. Reads ``train.deterministic``,
            ``train.model._target_``, ``dataset.name``, ``core.storage_dir``,
            ``data.data_seed``, ``data.subset_mode``, ``data.alpha`` and
            ``task.induced_subgraph``; it is also forwarded to
            ``get_dataset`` and ``MostInfluentialRemoval``.
        wandb_run: Run handle supplied by the caller (or ``None``). It is
            currently not used inside this function.

    Raises:
        FileNotFoundError: Presumably raised by ``torch.load`` when one of
            the expected result files is missing for a seed.
    """
    if cfg.train.deterministic:
        # The cuBLAS workspace setting is exported before deterministic
        # algorithms are switched on.
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    dataset_name = str(cfg.dataset.name).replace("-", "_")
    model_name = str(cfg.train.model._target_).split(".")[-1]

    dataset = get_dataset(cfg)
    data = dataset[0]

    # Length of the valuation tensor to use: the size of the configured node
    # subset, falling back to all nodes for any other subset mode.
    subset_sizes = {
        "train": _count_selected(data.train_mask),
        "val": _count_selected(data.val_mask),
        "test": _count_selected(data.test_mask),
    }
    n_samples = subset_sizes.get(cfg.data.subset_mode, data.x.shape[0])

    # Masks used to rank the margins; identical for every seed.
    masks = [("test_perf", data.test_mask), ("val_perf", data.val_mask)]

    for subsets_seed in SUBSETS_SEEDS:
        results_dir = _results_dir_for(cfg, model_name, dataset_name, subsets_seed)
        results_dir.mkdir(parents=True, exist_ok=True)

        print(f"Running influence experiment for {results_dir}")

        # NOTE: torch.load unpickles its input; only point this at trusted
        # result files.
        datamodel_res = torch.load(results_dir / "datamodel_res.pt")
        margins_dm = datamodel_res["margins_dm"]

        subset_res = torch.load(results_dir / "subsets_res.pt")

        # Only the first returned value (the Banzhaf margins) is needed here.
        margins_banzhaf, *_ = compute_banzhaf(
            subsets=subset_res["subsets"],
            logits=subset_res["logits"],
            sub_logits=subset_res["sub_logits"],
            sub_y=subset_res["sub_y"],
            data=data,
        )

        # Shapley results live one level up (shared across alpha values).
        shapley_res = torch.load(results_dir.parent / "shapley_res.pt")
        margins_shap = shapley_res["margins_shap"]

        # PC-Winter results live two levels up (shared across subset modes).
        pc_winter_perms_res = torch.load(
            results_dir.parent.parent / "pc_winter_0_0.99_0.99_res.pt"
        )
        margins_pc = pc_winter_perms_res["margins_pc"]

        data_sources = [
            ("datamodel_margins", margins_dm),
            ("banzhaf_margins", margins_banzhaf),
            ("shap_margins", margins_shap),
            ("pc_margins", margins_pc),
        ]

        results = {}
        for source_name, margins in data_sources:
            for mask_name, mask in masks:
                results[f"{source_name}_{mask_name}"] = {
                    # Descending order of the mean margin over the masked rows.
                    # NOTE(review): ``[::-1]`` needs negative-step slicing,
                    # which torch tensors do not support; presumably the
                    # margins are NumPy arrays -- confirm.
                    "argsrt": margins[mask].mean(0).argsort()[::-1],
                    "perf": [],
                    # NOTE(review): the test mask is stored for every entry,
                    # including the ones ranked on the validation mask;
                    # confirm this is intentional.
                    "mask": data.test_mask,
                }

        most_influential_removal = MostInfluentialRemoval(
            cfg=cfg,
            data=data,
            top_k=500,
            train_signal_value=True,
            n_samples=n_samples,
            results_dir=results_dir,
        )

        most_influential_removal.run_experiment(results_dict=results)


@hydra.main(
    version_base=None, config_path=str(PROJECT_ROOT / "conf"), config_name="experiment"
)
def main(cfg: omegaconf.DictConfig):
    """Hydra entry point: run the experiment and report elapsed time.

    A wandb run is opened only when the ``log`` config section contains a
    ``wandb`` key; it is passed to ``run_experiment`` and finished afterwards.
    CPU and wall-clock durations are printed once the experiment returns.

    Args:
        cfg: Configuration composed by Hydra from the ``experiment`` config.
    """
    started_cpu = time.process_time()
    started_wall = time.time()

    resolved_cfg = omegaconf.OmegaConf.to_container(
        cfg, resolve=True, throw_on_missing=False
    )
    use_wandb = "wandb" in cfg.log
    run = (
        wandb.init(project=cfg.log.project, config=resolved_cfg)
        if use_wandb
        else None
    )
    run_experiment(cfg=cfg, wandb_run=run)

    if use_wandb:
        run.finish()

    # CPU time is sampled before wall time, and both are converted before
    # anything is printed.
    elapsed = (
        ("CPU", time.process_time() - started_cpu),
        ("Wall", time.time() - started_wall),
    )
    breakdowns = [
        (label, convert_seconds_to_dhms(duration)) for label, duration in elapsed
    ]
    for label, (days, hours, minutes, seconds) in breakdowns:
        print(
            f"{label} time: {days} Days: {hours} Hours: {minutes} Min: {seconds:.2f} Sec"
        )


# Script entry point: invoke the Hydra-decorated ``main`` only when this file
# is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
