import copy
import os
import time

import hydra
import matplotlib.pyplot as plt
import numpy as np
import omegaconf
import torch
import wandb
from scipy.stats import spearmanr
from torch_geometric.seed import seed_everything

from graphsmodel import PROJECT_ROOT
from graphsmodel.evaluation import (
    compute_data_shapley,
    compute_loo,
    compute_pc_winter,
    generate_maps,
    generate_permutations,
    train_datamodels,
)
from graphsmodel.training import get_logits, node_level_subset_training, train_subset
from graphsmodel.utils import (
    convert_seconds_to_dhms,
    get_appropriate_edge_index,
    get_dataset,
    get_margin_incorrect_vectorized,
    get_model,
    get_optimizer,
    get_subset,
)
from graphsmodel.utils.visualization import correlation_plot, get_figsize


def _torch_load_trusted(path):
    """Load a results file written by this project with ``torch.save``.

    The cached result dicts contain non-tensor objects (NumPy arrays, Python
    lists/dicts and, presumably, SciPy sparse subset matrices). Since PyTorch
    2.6, ``torch.load`` defaults to ``weights_only=True`` and refuses to
    unpickle such objects, so full unpickling is requested explicitly.

    SECURITY: full unpickling can execute arbitrary code; only pass paths to
    files produced by this project's own scripts, never untrusted files.
    """
    return torch.load(path, weights_only=False)


def format_counterfactual_title(alpha, spearman_rho):
    """Build the two-line title of the counterfactual scatter plot.

    The first line is the mathtext ``alpha = <alpha>``. The second line reports
    the Spearman correlation rounded to two decimals. The line break must be a
    real newline character: inside a raw string the two characters ``\\n`` are
    kept verbatim and matplotlib would print them instead of breaking the line.
    """
    return rf"$\alpha = {alpha}$" + f"\nSpearman: {round(float(spearman_rho), 2)}"


def _compute_true_margins(cfg, dataset, data):
    """Train ``cfg.train.n_models`` seeded models on ``data`` and average their margins.

    Returns:
        ``(true_margins, true_test_acc, true_val_acc)``, where:

        - ``true_margins`` is the NumPy array returned by
          ``get_margin_incorrect_vectorized``, averaged over the seeds.
        - ``true_test_acc`` and ``true_val_acc`` are the fractions of
          test / validation nodes whose averaged margin is positive.
    """
    true_logits = []
    for model_seed in range(cfg.train.n_models):
        seed_everything(model_seed)
        model = get_model(cfg, data, num_classes=dataset.num_classes)
        if "_target_" in cfg.train.optimizer:
            # The optimizer config is an instantiable (Hydra `_target_`) spec,
            # so the model is trained iteratively with early stopping.
            optimizer = get_optimizer(cfg.train.optimizer, model)
            _, _, _, _, model_logits = node_level_subset_training(
                data=data,
                model=model,
                optimizer=optimizer,
                patience=cfg.train.patience,
                max_epochs=cfg.train.epochs,
            )
        else:
            # Otherwise the model exposes its own `fit` on the training mask.
            edge_index = get_appropriate_edge_index(data)
            model.fit(data.x, edge_index, data.y, data.train_mask)
            model_logits = get_logits(data, model)
        true_logits.append(model_logits)

    true_margins = (
        get_margin_incorrect_vectorized(torch.stack(true_logits).unsqueeze(0), data.y)
        .squeeze(0)
        .numpy()
        .mean(0)
    )
    correct = true_margins > 0
    true_test_acc = correct[data.test_mask].mean().item()
    true_val_acc = correct[data.val_mask].mean().item()
    return true_margins, true_test_acc, true_val_acc


def _compute_subset_size_results(cfg, data, nodes_to_select, n_samples, subsets_path):
    """Estimate every valuation method on growing prefixes of the pre-trained subsets.

    For each size ``n`` in ``subset_sizes``:

    - The first ``int(0.9 * n)`` subsets (and their margins) fit the datamodel
      and Banzhaf estimators.
    - The same count sets the budget of the permutation-based estimators
      (Data Shapley, PC-Winter).
    - The remaining subsets up to ``n`` are kept as a held-out set.

    Args:
        cfg: Experiment configuration, forwarded to the estimators.
        data: Graph data object with ``y``, ``edge_index`` and train/val/test masks.
        nodes_to_select: Indices of the valued nodes (used to slice PC-Winter values).
        n_samples: Number of valued nodes (``len(nodes_to_select)``).
        subsets_path: Path of ``subsets_res.pt`` written by ``train_subsets``.

    Returns:
        Dict with one list entry per size under each of these keys:
        ``dm_margins``, ``shap_margins``, ``banzhaf_margins``, ``pc_margins``,
        ``pc_perms_margins``, ``train_subsets``, ``test_subsets``,
        ``train_margins``, ``test_margins`` and ``n_trains``. It also holds
        ``subset_sizes`` itself.

    Raises:
        FileNotFoundError: if ``subsets_path`` does not exist.
    """
    if not subsets_path.exists():
        raise FileNotFoundError(
            "No subsets found. Run the `train_subsets` script first."
        )
    subset_res = _torch_load_trusted(subsets_path)
    subsets = subset_res["subsets"]
    logits = subset_res["logits"]
    margins = get_margin_incorrect_vectorized(logits, data.y).mean(1).numpy()

    results = {
        key: []
        for key in (
            "dm_margins",
            "shap_margins",
            "banzhaf_margins",
            "pc_margins",
            "pc_perms_margins",
            "train_subsets",
            "test_subsets",
            "train_margins",
            "test_margins",
            "n_trains",
        )
    }

    # NOTE(review): sizes larger than the number of available subsets are
    # silently truncated by the slicing below. Confirm that `train_subsets`
    # produces at least max(subset_sizes) subsets.
    subset_sizes = [1000, 2500, 5000, 10000, 25000, 50000]
    for n_subsets in subset_sizes:
        # DATAMODEL: 90% train / 10% held-out split of the first n_subsets.
        n_train_datamodel = int(0.9 * n_subsets)
        train_subsets = subsets[:n_train_datamodel]
        test_subsets = subsets[n_train_datamodel:n_subsets]
        train_margins = margins[:n_train_datamodel]
        test_margins = margins[n_train_datamodel:n_subsets]

        margins_dm, _ = train_datamodels(
            mode=cfg.data.subset_mode,
            x=train_subsets,
            y=train_margins,
            train_mask=data.train_mask,
            val_mask=data.val_mask,
            test_mask=data.test_mask,
            n_jobs=-1,
        )

        # SHAP: truncated permutations. The permutation count presumably makes
        # the number of trained models comparable to n_train_datamodel.
        p_trunc = 0.25
        n_perms = (
            np.round(n_train_datamodel / (n_samples * p_trunc)).astype(int).item()
        )
        perms = [np.random.permutation(n_samples) for _ in range(n_perms)]
        _, _, _, _, _, margins_shap, _ = compute_data_shapley(
            perms, n_samples, p_trunc, cfg, data, n_jobs=-1
        )

        # BANZHAF: mean margin over the subsets that contain each node minus
        # the mean over the subsets that do not (assuming 0/1 indicators).
        nodes_subsets_in = train_subsets.sum(0).A1
        nodes_subsets_notin = train_subsets.shape[0] - nodes_subsets_in
        margins_banzhaf = (
            np.array(
                (1 / nodes_subsets_in[:, None]) * (train_subsets.T @ train_margins)
                - (1 / nodes_subsets_notin[:, None])
                * np.array((np.array([1]) - train_subsets).T @ train_margins)
            )
            .squeeze()
            .T
        )

        # PC-WINTER: (0.5, 0.7) hop truncation is the regular variant and
        # (0.99, 0.99) the "perms" variant. Both run with the same budget.
        labeled_node_list = data.train_mask.nonzero(as_tuple=True)[0].tolist()
        labeled_to_player_map, _, _ = generate_maps(
            train_idx_list=labeled_node_list,
            edge_index=data.edge_index,
            num_nodes=data.num_nodes,
        )
        pc_variants = []
        for hop_1_ratio, hop_2_ratio in ((0.5, 0.7), (0.99, 0.99)):
            pc_perms = generate_permutations(
                n_subsets=n_train_datamodel,
                labeled_node_list=labeled_node_list,
                label_trunc_ratio=0.0,
                group_trunc_ratio_hop_1=hop_1_ratio,
                group_trunc_ratio_hop_2=hop_2_ratio,
                labeled_to_player_map=labeled_to_player_map,
            )
            _, _, _, _, variant_margins, _ = compute_pc_winter(
                pc_perms,
                n_samples=data.num_nodes,
                cfg=copy.deepcopy(cfg),
                data=data,
                n_jobs=-1,
            )
            # PC-Winter values cover all nodes; keep only the valued ones.
            pc_variants.append(variant_margins[:, nodes_to_select])
        margins_pc, margins_pc_perms = pc_variants

        results["dm_margins"].append(margins_dm)
        results["shap_margins"].append(margins_shap)
        results["banzhaf_margins"].append(margins_banzhaf)
        results["pc_margins"].append(margins_pc)
        results["pc_perms_margins"].append(margins_pc_perms)
        results["train_subsets"].append(train_subsets)
        results["test_subsets"].append(test_subsets)
        results["train_margins"].append(train_margins)
        results["test_margins"].append(test_margins)
        results["n_trains"].append(n_train_datamodel)

    results["subset_sizes"] = subset_sizes
    return results


def _train_counterfactual_subsets(cfg, data, n_samples, alphas, n_per_alpha):
    """Train one model per fresh random subset for each ``alpha`` and record its margins.

    Subsets are drawn with ``get_subset(seen, d=n_samples, alpha=alpha)``. The
    recorded margins come from ``get_margin_incorrect_vectorized`` on the
    model's logits, averaged over the leading axis.

    Returns:
        ``{alpha: (subsets, margins)}``, both lists with ``n_per_alpha`` entries,
        keyed in the order of ``alphas``.
    """
    alphas_subsets = {counter_alpha: ([], []) for counter_alpha in alphas}
    for counter_alpha, (alpha_subsets, alpha_margins) in alphas_subsets.items():
        seen_subsets = set()
        for _ in range(n_per_alpha):
            subset = get_subset(seen_subsets, d=n_samples, alpha=counter_alpha).numpy()
            _, _, _, data_logits = train_subset(
                subset_idx=None,
                subset=subset,
                cfg=cfg,
                data=data,
                logits_on_data=True,
            )
            counter_margins = (
                get_margin_incorrect_vectorized(data_logits.unsqueeze(0), data.y)
                .squeeze(0)
                .numpy()
                .mean(0)
            )
            alpha_subsets.append(subset)
            alpha_margins.append(counter_margins)
    return alphas_subsets


def run_experiment(cfg, wandb_run):
    """Compare data-valuation methods by how well they predict subset-trained margins.

    All paths are relative to ``alpha_dir = cfg.core.results_dir / f"alpha_{cfg.data.alpha}"``.

    1. Train ``cfg.train.n_models`` reference models and average their margins.
    2. Using the subsets in ``alpha_dir/subsets_res.pt``, compute datamodel,
       Data Shapley, Banzhaf, PC-Winter and PC-Winter-perms values for several
       subset counts. Cached in
       ``alpha_dir/subset_size_influence/subset_size_influence_corr.pt``.
    3. Train counterfactual models on fresh random subsets for several ``alpha``
       values. Cached in ``alpha_dir/counterfactual/counterfactual_res.pt``.
    4. Compute leave-one-out values from ``cfg.core.results_dir/loo_res.pt`` and
       save two figures in ``alpha_dir/counterfactual``:
       ``correlation.png`` (predicted vs. true held-out margins per method) and
       ``counterfactual_correlation.png`` (datamodel predictions on the
       counterfactual subsets).

    Args:
        cfg: Hydra/OmegaConf experiment configuration.
        wandb_run: Active W&B run or ``None``. Not used by this function.

    Raises:
        FileNotFoundError: if ``subsets_res.pt`` is needed but missing, or if
            ``loo_res.pt`` is missing (both are produced by ``train_subsets``).
    """
    if cfg.train.deterministic:
        # cuBLAS needs a fixed workspace size for deterministic kernels once
        # deterministic algorithms are enforced.
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    dataset = get_dataset(cfg)
    data = dataset[0]

    # Nodes whose values are estimated; any unrecognised mode means "all nodes".
    if cfg.data.subset_mode == "train":
        nodes_to_select = data.train_mask.nonzero().squeeze()
    elif cfg.data.subset_mode == "val":
        nodes_to_select = data.val_mask.nonzero().squeeze()
    elif cfg.data.subset_mode == "test":
        nodes_to_select = data.test_mask.nonzero().squeeze()
    else:
        nodes_to_select = torch.arange(data.num_nodes)

    # Length of the valuation vectors (one value per valued node).
    n_samples = len(nodes_to_select)

    alpha_dir = cfg.core.results_dir / f"alpha_{cfg.data.alpha}"
    subsets_path = alpha_dir / "subsets_res.pt"

    ####### Uncomment to double check if the linear regression approach approximates correctly the logistic regression #######
    # (needs `torch_geometric as pyg`, `scipy.sparse as sp`, `sklearn.linear_model`)
    # adj = pyg.utils.to_scipy_sparse_matrix(data.edge_index)

    # deg_sq = sp.diags(np.power(adj.sum(1).A1, -1/2))
    # lap = deg_sq @ adj @ deg_sq
    # diff_x = lap @ lap @ data.x.numpy()

    # model = linear_model.Ridge(alpha=1, fit_intercept=False)
    # model.fit(diff_x[data.train_mask], np.eye(7)[data.y[data.train_mask].numpy()])
    # print(f"sklearn accuracy: {(model.predict(diff_x[data.test_mask]).argmax(1) == data.y[data.test_mask].numpy()).mean()}")
    ##########################################################################################################################

    true_margins, true_test_acc, true_val_acc = _compute_true_margins(
        cfg, dataset, data
    )
    print(f"True test accuracy: {true_test_acc}, True val accuracy: {true_val_acc}")

    # ---- Valuations for increasing numbers of training subsets (cached) ----
    size_dir = alpha_dir / "subset_size_influence"
    size_dir.mkdir(parents=True, exist_ok=True)
    size_cache = size_dir / "subset_size_influence_corr.pt"
    try:
        size_res = _torch_load_trusted(size_cache)
    except FileNotFoundError:
        size_res = _compute_subset_size_results(
            cfg, data, nodes_to_select, n_samples, subsets_path
        )
        torch.save(size_res, size_cache)

    # ---- Counterfactual models on fresh random subsets (cached) ----
    n_counterfactual = 100  # models trained per alpha
    counter_dir = alpha_dir / "counterfactual"
    counter_dir.mkdir(parents=True, exist_ok=True)
    counter_cache = counter_dir / "counterfactual_res.pt"
    try:
        counter_res = _torch_load_trusted(counter_cache)
    except FileNotFoundError:
        # The counterfactual models are trained from scratch, so the
        # pre-trained subsets in `subsets_res.pt` are not needed here.
        alphas_subsets = _train_counterfactual_subsets(
            cfg,
            data,
            n_samples,
            alphas=(0.1, 0.25, 0.5, 0.75, 0.9),
            n_per_alpha=n_counterfactual,
        )
        # The largest subset budget provides the values to evaluate.
        counter_res = {
            "margins_dm": size_res["dm_margins"][-1],
            "margins_shap": size_res["shap_margins"][-1],
            "margins_banzhaf": size_res["banzhaf_margins"][-1],
            "margins_pc": size_res["pc_margins"][-1],
            "margins_pc_perms": size_res["pc_perms_margins"][-1],
            "n_train_datamodel": size_res["n_trains"][-1],
            "test_subsets": size_res["test_subsets"][-1],
            "test_margins": size_res["test_margins"][-1],
            "alphas_subsets": alphas_subsets,
        }
        torch.save(counter_res, counter_cache)

    # ---- Leave-one-out values (produced by `train_subsets`) ----
    loo_path = alpha_dir.parent / "loo_res.pt"
    if not loo_path.exists():
        raise FileNotFoundError(
            "No LOO results found. Run the `train_subsets` script first."
        )
    loo_res = _torch_load_trusted(loo_path)
    margins_loo, _, _, _, _, _ = compute_loo(
        logits=loo_res["logits"],
        sub_logits=loo_res["sub_logits"],
        sub_y=loo_res["sub_y"],
        data=data,
        true_test_acc=true_test_acc,
        true_val_acc=true_val_acc,
        true_margins=true_margins,
    )

    # ---- Held-out correlation per method: linear prediction = subset @ values ----
    test_subsets = counter_res["test_subsets"]
    test_margins = counter_res["test_margins"]
    margins_dm = counter_res["margins_dm"]
    methods = [
        ("Datamodel", margins_dm),
        ("Shap", counter_res["margins_shap"]),
        ("Banzhaf", counter_res["margins_banzhaf"]),
        ("LOO", margins_loo),
        ("PC-Winter", counter_res["margins_pc"]),
        ("PC-Winter perms", counter_res["margins_pc_perms"]),
    ]
    fig, axes = plt.subplots(
        ncols=len(methods),
        nrows=1,
        figsize=get_figsize(ncols=len(methods), nrows=1),
        constrained_layout=True,
    )
    for ax, (method_name, values) in zip(axes, methods):
        correlation_plot(test_subsets @ values.T, test_margins, method_name, ax=ax)
    fig.savefig(counter_dir / "correlation.png", dpi=300, bbox_inches="tight")
    plt.close(fig)

    # ---- Counterfactual predictions of the datamodel ----
    fig, ax = plt.subplots(
        ncols=1, nrows=1, figsize=get_figsize(ncols=1, nrows=1), constrained_layout=True
    )
    for counter_alpha, (counter_subsets, counter_margins) in counter_res[
        "alphas_subsets"
    ].items():
        counter_subsets = np.array(counter_subsets)
        counter_margins = np.array(counter_margins)
        assert counter_margins.shape == (
            n_counterfactual,
            data.num_nodes,
        ) and counter_subsets.shape == (
            n_counterfactual,
            n_samples,
        )

        counter_preds = counter_subsets @ margins_dm.T
        ax.scatter(
            counter_preds.flatten(),
            counter_margins.flatten(),
            alpha=0.7,
            label=f"alpha={counter_alpha}",
        )
        if counter_alpha == cfg.data.alpha:
            spear = spearmanr(counter_preds.flatten(), counter_margins.flatten())
            ax.set_title(format_counterfactual_title(counter_alpha, spear[0]))

    ax.set_xlabel("Predicted margins")
    ax.set_ylabel("True margins")
    ax.legend()
    fig.savefig(
        counter_dir / "counterfactual_correlation.png", dpi=300, bbox_inches="tight"
    )
    plt.close(fig)


@hydra.main(
    version_base=None, config_path=str(PROJECT_ROOT / "conf"), config_name="experiment"
)
def main(cfg: omegaconf.DictConfig):
    """Hydra entry point for the experiment.

    Opens a Weights & Biases run only when ``"wandb"`` is listed in ``cfg.log``.
    The run is initialised with the fully resolved config and finished once
    :func:`run_experiment` returns. Finally, the CPU and wall-clock time spent
    are printed.
    """
    cpu_start, wall_start = time.process_time(), time.time()

    resolved_cfg = omegaconf.OmegaConf.to_container(
        cfg, resolve=True, throw_on_missing=False
    )
    log_to_wandb = "wandb" in cfg.log
    run = (
        wandb.init(project=cfg.log.project, config=resolved_cfg)
        if log_to_wandb
        else None
    )
    run_experiment(cfg=cfg, wandb_run=run)
    if log_to_wandb:
        run.finish()

    # Read both clocks before any formatting work.
    elapsed = (
        ("CPU", time.process_time() - cpu_start),
        ("Wall", time.time() - wall_start),
    )
    breakdowns = [(label, convert_seconds_to_dhms(total)) for label, total in elapsed]
    for label, (days, hours, minutes, seconds) in breakdowns:
        print(
            f"{label} time: {days} Days: {hours} Hours: {minutes} Min: {seconds:.2f} Sec"
        )


if __name__ == "__main__":
    main()
