import copy
import os
import time

import hydra
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import omegaconf
import torch
import wandb
from scipy.stats import kendalltau
from torch_geometric.seed import seed_everything

from graphsmodel import PROJECT_ROOT
from graphsmodel.evaluation import compute_data_shapley, train_datamodels
from graphsmodel.evaluation.pc_winter import (
    compute_pc_winter,
    generate_maps,
    generate_permutations,
)
from graphsmodel.training import get_logits, node_level_subset_training
from graphsmodel.utils import (
    convert_seconds_to_dhms,
    get_appropriate_edge_index,
    get_dataset,
    get_margin_incorrect_vectorized,
    get_model,
    get_optimizer,
)
from graphsmodel.utils.visualization import get_figsize


def avg_rank_correlation(margins, ground_truth):
    """Return the mean and std of the row-wise Kendall's tau between two matrices.

    Row ``i`` of ``margins`` is rank-correlated with row ``i`` of
    ``ground_truth``. Rows whose tau is undefined (``kendalltau`` returns NaN,
    e.g. for a constant row) are ignored, so a single degenerate row does not
    turn the whole summary into NaN.

    Args:
        margins: 2-D array-like of valuations to evaluate.
        ground_truth: reference valuations with the same shape as ``margins``.

    Returns:
        Tuple ``(mean_tau, std_tau)`` over the rows with a defined tau
        (both NaN if no row has a defined tau).

    Raises:
        ValueError: if the two inputs do not have the same shape.
    """
    if margins.shape != ground_truth.shape:
        raise ValueError(
            f"Shape mismatch: margins {tuple(margins.shape)} vs "
            f"ground_truth {tuple(ground_truth.shape)}"
        )

    per_node_tau = []
    for pred_row, true_row in zip(margins, ground_truth):
        # Tuple unpacking works across scipy versions (namedtuple / SignificanceResult).
        tau, _ = kendalltau(pred_row, true_row)
        per_node_tau.append(tau)
    per_node_tau = np.asarray(per_node_tau, dtype=float)

    return np.nanmean(per_node_tau), np.nanstd(per_node_tau)


# Valuation methods compared in this experiment: (key in the cached results dict, plot label).
_METHODS = (
    ("dm_margins", "datamodel"),
    ("shap_margins", "shap"),
    ("banzhaf_margins", "banzhaf"),
    ("pc_margins", "pc-winter"),
    ("pc_perms_margins", "pc-winter perms"),
)

# Number of subsets used for each point of the curve. The last (largest) size is
# used as every method's reference ranking.
_SUBSET_SIZES = [1000, 2500, 5000, 10000, 25000, 50000]


def _load_trusted(path):
    """Load a ``.pt`` file produced locally by these experiment scripts.

    The files may hold non-tensor Python objects, which torch >= 2.6 refuses to
    unpickle under its default ``weights_only=True``. This unpickles arbitrary
    objects, so only use it on files written by this project.
    """
    try:
        return torch.load(path, weights_only=False)
    except TypeError:
        # torch < 1.13 does not accept the ``weights_only`` argument.
        return torch.load(path)


def _node_indices(mask):
    """Return the indices of the ``True`` entries of ``mask``.

    ``squeeze(-1)`` (instead of ``squeeze()``) keeps the result 1-D when exactly
    one entry is set; a 0-d tensor would make ``len()`` fail downstream.
    """
    return mask.nonzero().squeeze(-1)


def _select_nodes(cfg, data):
    """Return the node indices selected by ``cfg.data.subset_mode``.

    ``"train"``, ``"val"`` and ``"test"`` select the nodes of the matching mask;
    any other value selects every node of the graph.
    """
    mode = cfg.data.subset_mode
    if mode == "train":
        return _node_indices(data.train_mask)
    if mode == "val":
        return _node_indices(data.val_mask)
    if mode == "test":
        return _node_indices(data.test_mask)
    return torch.arange(data.num_nodes)


def _reference_margins(cfg, data, dataset):
    """Train ``cfg.train.n_models`` seeded models and return their mean margin per node."""
    true_logits = []
    for model_seed in range(cfg.train.n_models):
        seed_everything(model_seed)
        model = get_model(cfg, data, num_classes=dataset.num_classes)
        if "_target_" in cfg.train.optimizer:
            # Optimizer is instantiable from the config: use the project training loop.
            optimizer = get_optimizer(cfg.train.optimizer, model)
            _, _, _, _, model_logits = node_level_subset_training(
                data=data,
                model=model,
                optimizer=optimizer,
                patience=cfg.train.patience,
                max_epochs=cfg.train.epochs,
            )
        else:
            # Otherwise the model is trained through its own ``fit`` method.
            edge_index = get_appropriate_edge_index(data)
            model.fit(data.x, edge_index, data.y, data.train_mask)
            model_logits = get_logits(data, model)
        true_logits.append(model_logits)
    true_logits = torch.stack(true_logits)
    return (
        get_margin_incorrect_vectorized(true_logits.unsqueeze(0), data.y)
        .squeeze(0)
        .numpy()
        .mean(0)
    )


def _pc_winter_margins(
    cfg,
    data,
    nodes_to_select,
    n_subsets,
    labeled_node_list,
    labeled_to_player_map,
    hop_1_ratio,
    hop_2_ratio,
):
    """Run PC-Winter with the given group-truncation ratios.

    Returns the PC-Winter margins restricted to the columns ``nodes_to_select``.
    """
    perms = generate_permutations(
        n_subsets=n_subsets,
        labeled_node_list=labeled_node_list,
        label_trunc_ratio=0.0,
        group_trunc_ratio_hop_1=hop_1_ratio,
        group_trunc_ratio_hop_2=hop_2_ratio,
        labeled_to_player_map=labeled_to_player_map,
    )
    _, _, _, _, margins, _ = compute_pc_winter(
        perms,
        n_samples=data.num_nodes,
        # compute_pc_winter receives its own copy of the config.
        cfg=copy.deepcopy(cfg),
        data=data,
        n_jobs=-1,
    )
    return margins[:, nodes_to_select]


def _compute_valuations(cfg, data, nodes_to_select, subsets_path):
    """Compute every method's valuations for each size in ``_SUBSET_SIZES``.

    Returns:
        Dict with one list per key of ``_METHODS`` (one entry per subset size,
        in ``_SUBSET_SIZES`` order) plus ``"subset_sizes"``.

    Raises:
        FileNotFoundError: if ``subsets_path`` does not exist.
    """
    if not subsets_path.exists():
        raise FileNotFoundError(
            "No subsets found. Run the `train_subsets` script first."
        )
    subset_res = _load_trusted(subsets_path)
    subsets = subset_res["subsets"]
    logits = subset_res["logits"]
    margins = get_margin_incorrect_vectorized(logits, data.y).mean(1).numpy()

    # Length of the valuation vector (number of selected nodes).
    n_samples = len(nodes_to_select)
    p_trunc = 0.25  # truncation ratio handed to compute_data_shapley

    valuations = {key: [] for key, _ in _METHODS}
    for n_subsets in _SUBSET_SIZES:
        train_subsets = subsets[:n_subsets]
        train_margins = margins[:n_subsets]

        # DATAMODEL
        margins_dm, _ = train_datamodels(
            mode=cfg.data.subset_mode,
            x=train_subsets,
            y=train_margins,
            train_mask=data.train_mask,
            val_mask=data.val_mask,
            test_mask=data.test_mask,
            n_jobs=-1,
        )

        # SHAP: the permutation count presumably keeps the evaluation budget near
        # n_subsets. Clamp to >= 1: rounding can otherwise yield zero permutations.
        n_perms = max(1, int(np.round(n_subsets / (n_samples * p_trunc))))
        perms = [np.random.permutation(n_samples) for _ in range(n_perms)]
        (
            _,
            _,
            _,
            _,
            _,
            margins_shap,
            _,
        ) = compute_data_shapley(perms, n_samples, p_trunc, cfg, data, n_jobs=-1)

        # BANZHAF: mean margin over subsets containing each node minus mean margin
        # over subsets without it (assumes ``subsets`` is a 0/1 membership matrix
        # with one row per subset -- TODO confirm).
        nodes_subsets_in = train_subsets.sum(0).A1
        nodes_subsets_notin = train_subsets.shape[0] - nodes_subsets_in
        margins_banzhaf = (
            np.array(
                (1 / nodes_subsets_in[:, None]) * (train_subsets.T @ train_margins)
                - (1 / nodes_subsets_notin[:, None])
                * np.array((np.array([1]) - train_subsets).T @ train_margins)
            )
            .squeeze()
            .T
        )

        # PC-WINTER: the maps are rebuilt for every size and shared by both
        # variants, in case generate_permutations mutates them (not verified here).
        labeled_node_list = data.train_mask.nonzero(as_tuple=True)[0].tolist()
        labeled_to_player_map, _, _ = generate_maps(
            train_idx_list=labeled_node_list,
            edge_index=data.edge_index,
            num_nodes=data.num_nodes,
        )
        margins_pc = _pc_winter_margins(
            cfg,
            data,
            nodes_to_select,
            n_subsets,
            labeled_node_list,
            labeled_to_player_map,
            hop_1_ratio=0.5,
            hop_2_ratio=0.7,
        )

        # PC-WINTER PERMS: group-truncation ratio 0.99 at both hops.
        margins_pc_perms = _pc_winter_margins(
            cfg,
            data,
            nodes_to_select,
            n_subsets,
            labeled_node_list,
            labeled_to_player_map,
            hop_1_ratio=0.99,
            hop_2_ratio=0.99,
        )

        valuations["dm_margins"].append(margins_dm)
        valuations["shap_margins"].append(margins_shap)
        valuations["banzhaf_margins"].append(margins_banzhaf)
        valuations["pc_margins"].append(margins_pc)
        valuations["pc_perms_margins"].append(margins_pc_perms)

    valuations["subset_sizes"] = list(_SUBSET_SIZES)
    return valuations


def _plot_rank_correlations(valuations, out_path):
    """Plot, per method, mean +/- std Kendall's tau against its largest-size valuation."""
    subset_sizes = valuations["subset_sizes"]

    curves = []
    for key, label in _METHODS:
        per_size = valuations[key]
        ground_truth = per_size[-1]
        taus = [avg_rank_correlation(margin, ground_truth) for margin in per_size]
        tau_mean = np.array([t[0] for t in taus])
        tau_std = np.array([t[1] for t in taus])
        curves.append((label, tau_mean, tau_std))

    # Scoped rc settings so the tiny font size does not leak into later plots.
    with mpl.rc_context({"font.size": 4}):
        fig, ax = plt.subplots(
            figsize=get_figsize(nrows=1, ncols=1), constrained_layout=True
        )
        for label, tau_mean, tau_std in curves:
            ax.plot(subset_sizes, tau_mean, label=label)
            ax.fill_between(
                subset_sizes, tau_mean - tau_std, tau_mean + tau_std, alpha=0.3
            )

        ax.set_xticks(subset_sizes)
        ax.tick_params(axis="x", labelrotation=90)
        ax.set_xlabel("# of subsets")
        ax.set_ylabel("Kendall's tau")
        ax.legend()

        fig.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.close(fig)


def run_experiment(cfg, wandb_run):
    """Measure how valuation rankings change with the number of training subsets.

    Trains reference models and prints their test/val accuracy. Then, for each
    subset size in ``_SUBSET_SIZES``, it computes valuations with several
    methods: datamodels, Shapley, Banzhaf, and two PC-Winter variants. Each
    valuation is compared (Kendall's tau, row-wise) with the same method's
    valuation at the largest size. Valuations are cached in
    ``subset_size_influence_rank.pt``, and the plot is written to
    ``subset_size_influence_rank.png`` in the results directory.

    Args:
        cfg: Hydra experiment config.
        wandb_run: active W&B run or ``None``; currently not used by this
            function (results are only written to disk).

    Raises:
        FileNotFoundError: if neither the cached valuations nor the
            ``subsets_res.pt`` file from ``train_subsets`` exist.
    """
    if cfg.train.deterministic:
        # cuBLAS needs this workspace config for deterministic results when
        # deterministic algorithms are enforced.
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    dataset = get_dataset(cfg)
    data = dataset[0]

    nodes_to_select = _select_nodes(cfg, data)

    datamodel_alpha = f"alpha_{cfg.data.alpha}"
    results_dir = cfg.core.results_dir / datamodel_alpha / "subset_size_influence"
    results_dir.mkdir(parents=True, exist_ok=True)

    ####### Uncomment to double check if the linear regression approach approximates correctly the logistic regression #######
    # adj = pyg.utils.to_scipy_sparse_matrix(data.edge_index)

    # deg_sq = sp.diags(np.power(adj.sum(1).A1, -1/2))
    # lap = deg_sq @ adj @ deg_sq
    # diff_x = lap @ lap @ data.x.numpy()

    # model = linear_model.Ridge(alpha=1, fit_intercept=False)
    # model.fit(diff_x[data.train_mask], np.eye(7)[data.y[data.train_mask].numpy()])  # assumes 7 classes
    # print(f"sklearn accuracy: {(model.predict(diff_x[data.test_mask]).argmax(1) == data.y[data.test_mask].numpy()).mean()}")
    ##########################################################################################################################

    true_margins = _reference_margins(cfg, data, dataset)
    # A positive margin counts as a correct prediction.
    true_test_acc = (true_margins > 0)[data.test_mask].mean().item()
    true_val_acc = (true_margins > 0)[data.val_mask].mean().item()

    print(f"True test accuracy: {true_test_acc}, True val accuracy: {true_val_acc}")

    cache_path = results_dir / "subset_size_influence_rank.pt"
    try:
        res = _load_trusted(cache_path)
    except FileNotFoundError:
        res = None
    # Compute outside the ``except`` so a missing-subsets error is not chained
    # to the (expected) missing-cache error.
    if res is None:
        res = _compute_valuations(
            cfg, data, nodes_to_select, results_dir.parent / "subsets_res.pt"
        )
        torch.save(res, cache_path)

    _plot_rank_correlations(res, results_dir / "subset_size_influence_rank.png")


@hydra.main(
    version_base=None, config_path=str(PROJECT_ROOT / "conf"), config_name="experiment"
)
def main(cfg: omegaconf.DictConfig):
    """Hydra entry point: run the experiment and report CPU / wall-clock time.

    If the ``log`` config node contains a ``wandb`` entry, a W&B run is started
    with the fully resolved config before the experiment and finished after it.
    """
    start_cpu = time.process_time()
    start_wall = time.time()

    resolved_cfg = omegaconf.OmegaConf.to_container(
        cfg, resolve=True, throw_on_missing=False
    )
    run = (
        wandb.init(project=cfg.log.project, config=resolved_cfg)
        if "wandb" in cfg.log
        else None
    )
    run_experiment(cfg=cfg, wandb_run=run)

    if "wandb" in cfg.log:
        run.finish()

    elapsed_cpu = time.process_time() - start_cpu
    elapsed_wall = time.time() - start_wall

    c_days, c_hours, c_mins, c_secs = convert_seconds_to_dhms(elapsed_cpu)
    w_days, w_hours, w_mins, w_secs = convert_seconds_to_dhms(elapsed_wall)
    print(f"CPU time: {c_days} Days: {c_hours} Hours: {c_mins} Min: {c_secs:.2f} Sec")
    print(
        f"Wall time: {w_days} Days: {w_hours} Hours: {w_mins} Min: {w_secs:.2f} Sec"
    )


if __name__ == "__main__":
    # ``hydra.main`` builds ``cfg`` from the config files and CLI overrides.
    main()
