import copy
import os
import time

import hydra
import numpy as np
import omegaconf
import torch
import wandb
from torch_geometric.seed import seed_everything
from torch_geometric.utils import degree

from graphsmodel import PROJECT_ROOT
from graphsmodel.evaluation import compute_banzhaf, compute_datamodels, compute_loo
from graphsmodel.experiments.node_influence import (
    LeastInfluentialAddition,
    LeastInfluentialRemoval,
    MostInfluentialAddition,
    MostInfluentialRemoval,
)
from graphsmodel.training import get_logits, node_level_subset_training
from graphsmodel.utils import (
    convert_seconds_to_dhms,
    get_appropriate_edge_index,
    get_dataset,
    get_margin_incorrect_vectorized,
    get_model,
    get_optimizer,
    initialize_dict_results,
)


def _load_required_results(path, description):
    """Load a precomputed results file, failing with a precise message if absent.

    Args:
        path: ``pathlib.Path`` of the file to load with ``torch.load``.
        description: Human-readable name of the artifact, used in the error
            message so that the missing file can be identified.

    Returns:
        Whatever object ``torch.load`` deserializes from ``path``.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    if not path.exists():
        raise FileNotFoundError(
            f"No {description} found at {path}. "
            "Run the `train_subsets` script first."
        )
    # NOTE(review): torch.load unpickles the file, so it must only be pointed
    # at artifacts produced by this project's own scripts.
    return torch.load(path)


def _select_pc_winter(pc_res, nodes_to_select):
    """Restrict a PC-Winter results dict to the selected nodes.

    Accuracy entries are indexed along their first axis and margin entries
    along their second axis, exactly as the stored arrays are laid out.

    Returns:
        A 6-tuple ordered as ``(margins, sub_margins, test_accs, val_accs,
        sub_test_accs, sub_val_accs)``.
    """
    return (
        pc_res["margins_pc"][:, nodes_to_select],
        pc_res["sub_margins_pc"][:, nodes_to_select],
        pc_res["test_accs_pc"][nodes_to_select],
        pc_res["val_accs_pc"][nodes_to_select],
        pc_res["sub_test_accs_pc"][nodes_to_select],
        pc_res["sub_val_accs_pc"][nodes_to_select],
    )


def run_experiment(cfg, wandb_run):
    """Run the node-influence addition/removal experiments for every approach.

    The function
      1. trains ``cfg.train.n_models`` reference models (seeds ``0..n-1``) and
         derives reference margins and test/val accuracies from them;
      2. gathers the valuation results of each approach (datamodel, Banzhaf,
         LOO, Shapley, PC-Winter and PC-Winter permutations), loading them
         from disk or computing them from the stored subsets;
      3. runs the most/least influential removal and addition experiments,
         which receive ``results_dir`` as their output location.

    Args:
        cfg: Experiment configuration (the Hydra/OmegaConf config built in
            ``main``).
        wandb_run: Weights & Biases run handle or ``None``. It is accepted for
            interface compatibility and is not used inside this function.

    Raises:
        FileNotFoundError: If any required precomputed results file (subsets,
            LOO, Shapley, PC-Winter) is missing.
        AssertionError: If the margin arrays of the approaches differ in shape.
    """
    if cfg.train.deterministic:
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    dataset = get_dataset(cfg)
    data = dataset[0]

    # squeeze(-1) removes only the trailing index column produced by
    # nonzero(); a bare squeeze() would also collapse a single selected node
    # into a 0-d tensor, on which len() fails.
    training_nodes = data.train_mask.nonzero().squeeze(-1)
    val_nodes = data.val_mask.nonzero().squeeze(-1)
    test_nodes = data.test_mask.nonzero().squeeze(-1)

    if cfg.data.subset_mode == "train":
        nodes_to_select = training_nodes
    elif cfg.data.subset_mode == "val":
        nodes_to_select = val_nodes
    elif cfg.data.subset_mode == "test":
        nodes_to_select = test_nodes
    else:
        # Any other mode selects every node of the graph.
        nodes_to_select = torch.arange(data.num_nodes)

    # length of the valuation tensor to use
    n_samples = len(nodes_to_select)

    datamodel_alpha = f"alpha_{cfg.data.alpha}"
    results_dir = cfg.core.results_dir / datamodel_alpha / "nodes_influence"
    results_dir.mkdir(parents=True, exist_ok=True)

    # Precomputed artifacts live one, two and three levels above results_dir.
    alpha_dir = results_dir.parent
    shared_dir = alpha_dir.parent
    pc_winter_dir = shared_dir.parent

    print(results_dir)

    ####### Uncomment to double check if the linear regression approach approximates correctly the logistic regression #######
    # adj = pyg.utils.to_scipy_sparse_matrix(data.edge_index)

    # deg_sq = sp.diags(np.power(adj.sum(1).A1, -1/2))
    # lap = deg_sq @ adj @ deg_sq
    # diff_x = lap @ lap @ data.x.numpy()

    # model = linear_model.Ridge(alpha=1, fit_intercept=False)
    # model.fit(diff_x[data.train_mask], np.eye(7)[data.y[data.train_mask].numpy()])
    # print(f"sklearn accuracy: {(model.predict(diff_x[data.test_mask]).argmax(1) == data.y[data.test_mask].numpy()).mean()}")
    ##########################################################################################################################

    # Reference models trained with the full training signal, one per seed.
    true_logits = []
    for model_seed in range(cfg.train.n_models):
        seed_everything(model_seed)
        model = get_model(cfg, data, num_classes=dataset.num_classes)
        if "_target_" in cfg.train.optimizer:
            # Gradient-based training driven by the configured optimizer.
            optimizer = get_optimizer(cfg.train.optimizer, model)
            _, _, _, _, model_logits = node_level_subset_training(
                data=data,
                model=model,
                optimizer=optimizer,
                patience=cfg.train.patience,
                max_epochs=cfg.train.epochs,
            )
        else:
            # No optimizer target configured: the model is fitted through its
            # own `fit` method instead.
            model.fit(
                x=data.x,
                edge_index=get_appropriate_edge_index(data),
                y=data.y,
                train_mask=data.train_mask,
            )
            model_logits = get_logits(data, model)
        true_logits.append(model_logits)
    true_logits = torch.stack(true_logits)

    # Margins averaged over axis 0 after dropping the added leading axis
    # (presumably the model axis -- confirm against
    # get_margin_incorrect_vectorized).
    true_margins = (
        get_margin_incorrect_vectorized(true_logits.unsqueeze(0), data.y)
        .squeeze(0)
        .numpy()
        .mean(0)
    )
    # A positive averaged margin is counted as a correct prediction.
    correct = true_margins > 0
    true_test_acc = correct[data.test_mask].mean().item()
    true_val_acc = correct[data.val_mask].mean().item()

    print(f"True test accuracy: {true_test_acc}, True val accuracy: {true_val_acc}")

    # SUBSETS (required by both the datamodel and the Banzhaf estimators)
    subset_res = _load_required_results(alpha_dir / "subsets_res.pt", "subsets")
    subsets = subset_res["subsets"]
    logits = subset_res["logits"]
    sub_logits = subset_res["sub_logits"]
    sub_y = subset_res["sub_y"]

    # DATAMODEL
    datamodel_path = alpha_dir / "datamodel_res.pt"
    try:
        # Only the cache lookup is guarded, so an unrelated missing file can
        # no longer be mistaken for a missing datamodel cache.
        datamodel_res = torch.load(datamodel_path)
    except FileNotFoundError:
        print("No datamodel results found. Training the datamodels.")
        (
            margins_dm,
            test_accs_dm,
            val_accs_dm,
            sub_margins_dm,
            sub_test_accs_dm,
            sub_val_accs_dm,
        ) = compute_datamodels(
            subset_mode=cfg.data.subset_mode,
            subsets=subsets,
            logits=logits,
            sub_logits=sub_logits,
            sub_y=sub_y,
            data=data,
            n_jobs=cfg.n_jobs,
        )

        # Save the results
        datamodel_res = {
            "margins_dm": margins_dm,
            "test_accs_dm": test_accs_dm,
            "val_accs_dm": val_accs_dm,
            "sub_margins_dm": sub_margins_dm,
            "sub_test_accs_dm": sub_test_accs_dm,
            "sub_val_accs_dm": sub_val_accs_dm,
        }
        torch.save(datamodel_res, datamodel_path)
    else:
        print("Datamodel results found.")

    margins_dm = datamodel_res["margins_dm"]
    test_accs_dm = datamodel_res["test_accs_dm"]
    val_accs_dm = datamodel_res["val_accs_dm"]
    sub_margins_dm = datamodel_res["sub_margins_dm"]
    sub_test_accs_dm = datamodel_res["sub_test_accs_dm"]
    sub_val_accs_dm = datamodel_res["sub_val_accs_dm"]

    # BANZHAF
    (
        margins_banzhaf,
        test_accs_banzhaf,
        val_accs_banzhaf,
        sub_margins_banzhaf,
        sub_test_accs_banzhaf,
        sub_val_accs_banzhaf,
    ) = compute_banzhaf(
        subsets=subsets,
        logits=logits,
        sub_logits=sub_logits,
        sub_y=sub_y,
        data=data,
        subset_mode=cfg.data.subset_mode,
    )
    print("Banzhaf computed.")

    # LOO
    loo_res = _load_required_results(shared_dir / "loo_res.pt", "LOO results")
    (
        margins_loo,
        test_accs_loo,
        val_accs_loo,
        sub_margins_loo,
        sub_test_accs_loo,
        sub_val_accs_loo,
    ) = compute_loo(
        logits=loo_res["logits"],
        sub_logits=loo_res["sub_logits"],
        sub_y=loo_res["sub_y"],
        data=data,
        true_test_acc=true_test_acc,
        true_val_acc=true_val_acc,
        true_margins=true_margins,
    )
    print("LOO computed.")

    # SHAPLEY
    shapley_res = _load_required_results(
        shared_dir / "shapley_res.pt", "shapley results"
    )
    sub_test_accs_shap = shapley_res["sub_test_accs_shap"]
    sub_val_accs_shap = shapley_res["sub_val_accs_shap"]
    test_accs_shap = shapley_res["test_accs_shap"]
    val_accs_shap = shapley_res["val_accs_shap"]
    margins_shap = shapley_res["margins_shap"]
    sub_margins_shap = shapley_res["sub_margins_shap"]

    print("Shapley results found.")

    # PC-Winter
    pc_winter_res = _load_required_results(
        pc_winter_dir / "pc_winter_0_0.5_0.7_res.pt", "PC-Winter results"
    )
    (
        margins_pc,
        sub_margins_pc,
        test_accs_pc,
        val_accs_pc,
        sub_test_accs_pc,
        sub_val_accs_pc,
    ) = _select_pc_winter(pc_winter_res, nodes_to_select)

    print("PC-Winter results found.")

    # PC-Winter permutations
    pc_winter_perms_res = _load_required_results(
        pc_winter_dir / "pc_winter_0_0.99_0.99_res.pt",
        "PC-Winter permutation results",
    )
    (
        margins_perms_pc,
        sub_margins_perms_pc,
        test_accs_perms_pc,
        val_accs_perms_pc,
        sub_test_accs_perms_pc,
        sub_val_accs_perms_pc,
    ) = _select_pc_winter(pc_winter_perms_res, nodes_to_select)

    print("PC-Winter permutation results found.")

    assert (
        margins_dm.shape
        == margins_banzhaf.shape
        == margins_loo.shape
        == margins_shap.shape
        == margins_pc.shape
        == margins_perms_pc.shape
    ), "Margin arrays of the different approaches must share one shape."

    train_signal_value = cfg.train_signal_value

    # Blank out entry (nodes_to_select[i], i) of every sub-margin matrix, in
    # place. Presumably this removes each node's influence on itself --
    # confirm against the compute_* outputs.
    sample_positions = np.arange(n_samples)
    for sub_margins in (
        sub_margins_dm,
        sub_margins_banzhaf,
        sub_margins_loo,
        sub_margins_shap,
        sub_margins_pc,
        sub_margins_perms_pc,
    ):
        sub_margins[nodes_to_select, sample_positions] = np.nan

    approaches = {
        "datamodel": (
            margins_dm,
            sub_margins_dm,
            test_accs_dm,
            val_accs_dm,
            sub_test_accs_dm,
            sub_val_accs_dm,
        ),
        "banzhaf": (
            margins_banzhaf,
            sub_margins_banzhaf,
            test_accs_banzhaf,
            val_accs_banzhaf,
            sub_test_accs_banzhaf,
            sub_val_accs_banzhaf,
        ),
        "loo": (
            margins_loo,
            sub_margins_loo,
            test_accs_loo,
            val_accs_loo,
            sub_test_accs_loo,
            sub_val_accs_loo,
        ),
        "shap": (
            margins_shap,
            sub_margins_shap,
            test_accs_shap,
            val_accs_shap,
            sub_test_accs_shap,
            sub_val_accs_shap,
        ),
        "pc": (
            margins_pc,
            sub_margins_pc,
            test_accs_pc,
            val_accs_pc,
            sub_test_accs_pc,
            sub_val_accs_pc,
        ),
        "pc_perms": (
            margins_perms_pc,
            sub_margins_perms_pc,
            test_accs_perms_pc,
            val_accs_perms_pc,
            sub_test_accs_perms_pc,
            sub_val_accs_perms_pc,
        ),
    }

    margins_suffixes = ["margins_perf", "margins_test_perf", "margins_val_perf"]
    sub_margins_suffixes = [
        "sub_margins_perf",
        "sub_test_margins_perf",
        "sub_val_margins_perf",
    ]
    accs_suffixes = ["test_accs_perf", "val_accs_perf"]
    sub_accs_suffixes = ["sub_test_accs_perf", "sub_val_accs_perf"]

    def _initial_results(reverse):
        # Each call draws a fresh random permutation and a fresh degree
        # tensor, so the two initializations never share state.
        return initialize_dict_results(
            approaches=approaches,
            margins_suffixes=margins_suffixes,
            sub_margins_suffixes=sub_margins_suffixes,
            accs_suffixes=accs_suffixes,
            sub_accs_suffixes=sub_accs_suffixes,
            test_mask=data.test_mask,
            val_mask=data.val_mask,
            eval_mask=data.test_mask,
            reverse=reverse,
            rnd=np.random.permutation(n_samples),
            degree=degree(data.edge_index[0], num_nodes=data.num_nodes)[
                nodes_to_select
            ],
        )

    # Order matters: the "most" initialization consumes the RNG first.
    most_influent_initialization = _initial_results(reverse=True)
    least_influent_initialization = _initial_results(reverse=False)

    # With a train/val subset every selected node is ranked; otherwise removal
    # is capped at 500 nodes while addition may go up to the whole graph.
    restricted_mode = cfg.data.subset_mode in ["train", "val"]
    removal_top_k = n_samples if restricted_mode else 500
    addition_top_k = n_samples if restricted_mode else data.num_nodes

    experiments = (
        # MOST INFLUENTIAL REMOVAL
        (MostInfluentialRemoval, removal_top_k, most_influent_initialization),
        # MOST INFLUENTIAL ADDITION
        (MostInfluentialAddition, addition_top_k, most_influent_initialization),
        # LEAST INFLUENTIAL ADDITION
        (LeastInfluentialAddition, addition_top_k, least_influent_initialization),
        # LEAST INFLUENTIAL REMOVAL
        (LeastInfluentialRemoval, removal_top_k, least_influent_initialization),
    )
    for experiment_cls, top_k, initialization in experiments:
        experiment = experiment_cls(
            cfg=cfg,
            data=data,
            top_k=top_k,
            train_signal_value=train_signal_value,
            n_samples=n_samples,
            results_dir=results_dir,
        )
        # Every experiment works on its own copy of the initial results.
        experiment.run_experiment(results_dict=copy.deepcopy(initialization))


@hydra.main(
    version_base=None, config_path=str(PROJECT_ROOT / "conf"), config_name="experiment"
)
def main(cfg: omegaconf.DictConfig):
    """Hydra entry point: run the experiment and report CPU and wall time.

    A Weights & Biases run is opened (and finished after a successful
    experiment) only when the ``log`` config section contains ``"wandb"``.

    Args:
        cfg: Configuration composed by Hydra from the ``experiment`` config.
    """
    cpu_time_start = time.process_time()
    # perf_counter is monotonic, so the measured interval cannot be distorted
    # by system clock adjustments during a long run (time.time() can be).
    wall_time_start = time.perf_counter()

    config = omegaconf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=False)
    use_wandb = "wandb" in cfg.log
    run = wandb.init(project=cfg.log.project, config=config) if use_wandb else None

    run_experiment(cfg=cfg, wandb_run=run)

    if use_wandb:
        run.finish()

    cpu_time = time.process_time() - cpu_time_start
    wall_time = time.perf_counter() - wall_time_start

    cpu_days, cpu_hours, cpu_minutes, cpu_seconds = convert_seconds_to_dhms(cpu_time)
    wall_days, wall_hours, wall_minutes, wall_seconds = convert_seconds_to_dhms(
        wall_time
    )
    print(
        f"CPU time: {cpu_days} Days: {cpu_hours} Hours: {cpu_minutes} Min: {cpu_seconds:.2f} Sec"
    )
    print(
        f"Wall time: {wall_days} Days: {wall_hours} Hours: {wall_minutes} Min: {wall_seconds:.2f} Sec"
    )


# Script entry point: run the Hydra-decorated main() only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()
