import copy
import os
import time
from pathlib import Path

import hydra
import numpy as np
import omegaconf
import torch
import wandb
from joblib import Parallel, delayed
from torch_geometric.seed import seed_everything
from torch_geometric.utils import degree
from tqdm import tqdm

from graphsmodel import PROJECT_ROOT
from graphsmodel.evaluation.data_shapley import compute_data_shapley
from graphsmodel.evaluation.datamodel import fit_ridge
from graphsmodel.evaluation.pc_winter import (
    compute_pc_winter,
    generate_maps,
    generate_permutations,
)
from graphsmodel.training import get_logits, node_level_subset_training, train_subset
from graphsmodel.utils import (
    convert_seconds_to_dhms,
    get_appropriate_edge_index,
    get_dataset,
    get_margin_incorrect_vectorized,
    get_model,
    get_optimizer,
    get_subset,
    sparse_tensor_to_scipy,
)


def _train_full_models(cfg, data, num_classes):
    """Train ``cfg.train.n_models`` models on ``data`` and stack their logits.

    Model ``i`` is seeded with ``seed_everything(i)`` right before it is built,
    so any random draw made by the caller afterwards starts from the state
    left behind by the last seeded run.

    Args:
        cfg: Experiment config; reads ``train.n_models``, ``train.optimizer``,
            ``train.patience`` and ``train.epochs``.
        data: Graph data object exposing ``x``, ``y`` and ``train_mask``.
        num_classes: Number of classes forwarded to ``get_model``.

    Returns:
        The per-model logits stacked along a new leading dimension.
    """
    all_logits = []
    for model_seed in range(cfg.train.n_models):
        seed_everything(model_seed)
        model = get_model(cfg, data, num_classes=num_classes)
        if "_target_" in cfg.train.optimizer:
            # Gradient-trained model: the optimizer is built from the config.
            optimizer = get_optimizer(cfg.train.optimizer, model)
            _, _, _, _, model_logits = node_level_subset_training(
                data=data,
                model=model,
                optimizer=optimizer,
                patience=cfg.train.patience,
                max_epochs=cfg.train.epochs,
            )
        else:
            # No optimizer target configured: the model exposes its own `fit`.
            edge_index = get_appropriate_edge_index(data)
            model.fit(data.x, edge_index, data.y, data.train_mask)
            model_logits = get_logits(data, model)
        all_logits.append(model_logits)
    return torch.stack(all_logits)


def _train_on_subsets(cfg, data, subsets):
    """Run ``train_subset`` in parallel for every subset and stack the outputs.

    Args:
        cfg: Experiment config forwarded to ``train_subset``.
        data: Graph data object forwarded to ``train_subset``.
        subsets: Iterable of subsets, one training run per element.

    Returns:
        Tuple ``(subsets, sub_logits, sub_y, logits)`` where ``subsets`` is the
        stacked subset tensor converted through ``sparse_tensor_to_scipy`` and
        the remaining entries are the stacked per-run tensors.
    """
    results = Parallel(n_jobs=-1)(
        delayed(train_subset)(
            subset_idx=subset_idx,
            subset=subset,
            cfg=cfg,
            data=data,
            logits_on_data=True,
        )
        for subset_idx, subset in enumerate(tqdm(subsets))
    )
    masks, sub_logits, sub_y, logits = zip(*results)
    return (
        sparse_tensor_to_scipy(torch.stack(masks).to_sparse()),
        torch.stack(sub_logits),
        torch.stack(sub_y),
        torch.stack(logits),
    )


def _poisoned_results_filename(high_logits, high_degree):
    """Return the result file stem encoding the poisoning configuration."""
    stem = ("high_" if high_logits else "low_") + "logit_poisoned_rank"
    if high_degree is None:
        return "rnd_" + stem
    return ("high_degree_" if high_degree else "low_degree_") + stem


def run_experiment(cfg, wandb_run):
    """Poison 10% of the training labels and value nodes on the poisoned graph.

    Steps, in order:
      1. Train ``cfg.train.n_models`` models on the clean data.
      2. Pick the training nodes to poison (random, highest degree, or a random
         draw from the lower-degree half, depending on ``cfg.high_degree``).
      3. Relabel those nodes with the incorrect class that most models rank
         highest (``cfg.high_logits`` true) or lowest (false).
      4. On the poisoned data compute datamodel (ridge), Banzhaf,
         leave-one-out, Data Shapley and PC-Winter margins.
      5. Save everything with ``torch.save`` under
         ``<storage_dir>/memorization/...``.

    Args:
        cfg: Hydra/OmegaConf experiment config. Fields read here include
            ``train.deterministic``, ``train.n_models``, ``train.n_subsets``,
            ``data.subset_mode``, ``data.data_seed``, ``data.subsets_seed``,
            ``data.alpha``, ``dataset.name``, ``core.storage_dir``,
            ``high_degree`` (True | False | None) and ``high_logits``.
        wandb_run: Accepted for interface compatibility; not used in this
            function.

    Raises:
        AssertionError: If the poisoned copy fails the sanity checks (labels
            outside the poisoned set changed, a poisoned label did not change,
            or a split mask differs from the original).
    """
    if cfg.train.deterministic:
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    dataset = get_dataset(cfg)
    data = dataset[0]

    # `nonzero(as_tuple=True)[0]` always yields a 1-D index tensor, whereas
    # `.nonzero().squeeze()` collapses a single-node split to 0-d and breaks len().
    split_masks = {
        "train": data.train_mask,
        "val": data.val_mask,
        "test": data.test_mask,
    }
    if cfg.data.subset_mode in split_masks:
        nodes_to_select = split_masks[cfg.data.subset_mode].nonzero(as_tuple=True)[0]
    else:
        nodes_to_select = torch.arange(data.num_nodes)

    # length of the valuation tensor to use
    n_samples = len(nodes_to_select)

    dataset_name = str(cfg.dataset.name).replace("-", "_")
    results_dir = (
        Path(cfg.core.storage_dir)
        / "memorization"
        / f"{dataset_name.lower()}_seed_{cfg.data.data_seed}"
    )
    # The subsets-seed level only exists when a seed is configured.
    if cfg.data.subsets_seed is not None:
        results_dir = results_dir / f"subsets_seed_{cfg.data.subsets_seed}"
    results_dir = results_dir / f"alpha_{cfg.data.alpha}"
    results_dir.mkdir(parents=True, exist_ok=True)

    true_logits = _train_full_models(cfg, data, dataset.num_classes)

    training_idx = data.train_mask.nonzero(as_tuple=True)[0]
    budget = int(0.1 * len(training_idx))

    high_degree = cfg.high_degree  # True | False | None
    high_logits = cfg.high_logits  # True | False

    # `rnd_training` holds positions *within* the training split, not node ids.
    if high_degree is not None:
        deg = degree(data.edge_index[0], num_nodes=data.num_nodes)
        train_deg = deg[data.train_mask].numpy()
        if high_degree:
            rnd_training = train_deg.argsort()[::-1][:budget]
        else:
            rnd_training = np.random.permutation(
                train_deg.argsort()[: len(train_deg) // 2]
            )[:budget]
        assert np.all(
            deg[training_idx.numpy()[rnd_training]].numpy() == train_deg[rnd_training]
        )
    else:
        rnd_training = np.random.permutation(len(training_idx))[:budget]

    training_idx_to_poison = training_idx.numpy()[rnd_training]

    not_poisoned_idx = torch.ones_like(data.train_mask, dtype=torch.bool)
    not_poisoned_idx[training_idx_to_poison] = 0

    poisoned_data = copy.deepcopy(data)

    data_y_expanded = data.y.unsqueeze(-1).expand(cfg.train.n_models, data.num_nodes, 1)

    # Mask the true class in both branches so that the selected class is
    # always an *incorrect* one; the per-node choice is the mode across models.
    other_logits = true_logits.clone()
    if high_logits:
        ### Highest incorrect class
        other_logits.scatter_(-1, data_y_expanded, float("-inf"))
        incorrect_class = other_logits.argmax(dim=-1).mode(dim=0).values
    else:
        ### Lowest incorrect class
        # Without the +inf mask the true class itself could be the argmin,
        # which would leave the node's label unchanged.
        other_logits.scatter_(-1, data_y_expanded, float("inf"))
        incorrect_class = other_logits.argmin(dim=-1).mode(dim=0).values

    poisoned_data.y[training_idx_to_poison] = incorrect_class[training_idx_to_poison]

    # double check that the poisoned data is different from the original data
    # but the train/val/test split is the same
    poisoned_idx = ~not_poisoned_idx
    assert (poisoned_data.y[not_poisoned_idx] == data.y[not_poisoned_idx]).all()
    assert (poisoned_data.y[poisoned_idx] != data.y[poisoned_idx]).all()
    assert (
        poisoned_data.train_mask.nonzero(as_tuple=True)[0]
        == data.train_mask.nonzero(as_tuple=True)[0]
    ).all()
    assert (
        poisoned_data.val_mask.nonzero(as_tuple=True)[0]
        == data.val_mask.nonzero(as_tuple=True)[0]
    ).all()
    assert (
        poisoned_data.test_mask.nonzero(as_tuple=True)[0]
        == data.test_mask.nonzero(as_tuple=True)[0]
    ).all()

    #### DATAMODEL
    seen_subsets = set()
    alpha = cfg.data.alpha
    subsets = [
        get_subset(seen_subsets=seen_subsets, d=n_samples, alpha=alpha)
        for _ in range(cfg.train.n_subsets)
    ]

    subsets, sub_logits, sub_y, logits = _train_on_subsets(cfg, poisoned_data, subsets)

    poisoned_subsets_res = {
        "subsets": subsets,
        "logits": logits,
        "sub_logits": sub_logits,
        "sub_y": sub_y,
    }

    margins = get_margin_incorrect_vectorized(logits, poisoned_data.y).mean(1).numpy()
    margins_dm, _ = fit_ridge(subsets, margins)

    poisoned_datamodel_res = {
        "margins_dm": margins_dm,
    }

    #### Banzhaf
    # Mean margin over subsets containing a node minus mean margin over the
    # subsets that do not (assuming `subsets` rows are 0/1 membership rows).
    # NOTE(review): a node that is in every subset or in none gives a zero
    # denominator here — confirm this cannot happen for the configured alpha.
    nodes_subsets_in = subsets.sum(0).A1
    nodes_subsets_notin = subsets.shape[0] - nodes_subsets_in

    margins_banzhaf = (
        np.array(
            1 / nodes_subsets_in[:, None] * (subsets.T @ margins)
            - 1
            / nodes_subsets_notin[:, None]
            * np.array((np.array([1]) - subsets).T @ margins)
        )
        .squeeze()
        .T
    )

    poisoned_banzhaf_res = {
        "margins_banzhaf": margins_banzhaf,
    }

    #### LOO
    # Reference margins from models trained on the full poisoned training set.
    pois_true_logits = _train_full_models(cfg, poisoned_data, dataset.num_classes)
    pois_true_margins = (
        get_margin_incorrect_vectorized(pois_true_logits.unsqueeze(0), poisoned_data.y)
        .squeeze(0)
        .numpy()
        .mean(0)
    )

    # Row i keeps every sample except sample i.
    loo_subsets = torch.ones(n_samples, n_samples, dtype=torch.bool)
    loo_subsets[torch.arange(n_samples), torch.arange(n_samples)] = False

    loo_subsets, loo_sub_logits, loo_sub_y, loo_logits = _train_on_subsets(
        cfg, poisoned_data, loo_subsets
    )

    margins_loo = (
        get_margin_incorrect_vectorized(loo_logits, poisoned_data.y).mean(1).numpy()
    )
    margins_loo = (pois_true_margins - margins_loo).T

    poisoned_loo_res = {
        "subsets": loo_subsets,
        "logits": loo_logits,
        "sub_logits": loo_sub_logits,
        "sub_y": loo_sub_y,
        "margins_loo": margins_loo,
    }

    #### SHAPLEY
    p_trunc = 0.25
    n_perms = np.round(cfg.train.n_subsets / (n_samples * p_trunc)).astype(int).item()

    perms = [np.random.permutation(n_samples) for _ in range(n_perms)]

    (
        perms,
        _,
        _,
        _,
        _,
        margins_shap,
        _,
    ) = compute_data_shapley(
        perms=perms,
        n_samples=n_samples,
        p_trunc=p_trunc,
        cfg=cfg,
        data=poisoned_data,
        n_jobs=-1,
    )

    poisoned_shapley_res = {
        "perms": perms,
        "p_trunc": p_trunc,
        "margins_shap": margins_shap,
    }

    #### PC-Winter
    labeled_node_list = poisoned_data.train_mask.nonzero(as_tuple=True)[0].tolist()

    labeled_to_player_map, _, _ = generate_maps(
        train_idx_list=labeled_node_list,
        edge_index=poisoned_data.edge_index,
        num_nodes=poisoned_data.num_nodes,
    )

    label_trunc_ratio = 0
    group_trunc_ratio_hop_1 = 0.99
    group_trunc_ratio_hop_2 = 0.99

    pc_perms = generate_permutations(
        n_subsets=cfg.train.n_subsets,
        labeled_node_list=labeled_node_list,
        label_trunc_ratio=label_trunc_ratio,
        group_trunc_ratio_hop_1=group_trunc_ratio_hop_1,
        group_trunc_ratio_hop_2=group_trunc_ratio_hop_2,
        labeled_to_player_map=labeled_to_player_map,
    )

    (
        _,
        _,
        _,
        _,
        margins_pc,
        _,
    ) = compute_pc_winter(
        pc_perms,
        n_samples=poisoned_data.num_nodes,
        cfg=cfg,
        data=poisoned_data,
        n_jobs=-1,
    )

    poisoned_pc_winter_res = {
        # Store the PC-Winter permutations, not the Data Shapley ones.
        "perms": pc_perms,
        "label_trunc_ratio": label_trunc_ratio,
        "group_trunc_ratio_hop_1": group_trunc_ratio_hop_1,
        "group_trunc_ratio_hop_2": group_trunc_ratio_hop_2,
        "margins_pc": margins_pc,
    }

    to_save = {
        "budget": budget,
        "rnd_training": rnd_training,
        "training_idx_to_poison": training_idx_to_poison,
        "poisoned_data": poisoned_data,
        "subsets_res": poisoned_subsets_res,
        "datamodel_res": poisoned_datamodel_res,
        "banzhaf_res": poisoned_banzhaf_res,
        "loo_res": poisoned_loo_res,
        "shapley_res": poisoned_shapley_res,
        "pc_winter_res": poisoned_pc_winter_res,
    }

    filename = _poisoned_results_filename(high_logits, high_degree)
    torch.save(to_save, results_dir / f"{filename}.pt")


@hydra.main(
    version_base=None, config_path=str(PROJECT_ROOT / "conf"), config_name="experiment"
)
def main(cfg: omegaconf.DictConfig):
    """Hydra entry point: run the experiment and report CPU and wall time.

    A Weights & Biases run is opened (and finished afterwards) only when the
    ``log`` section of the config contains a ``wandb`` key; otherwise ``None``
    is handed to ``run_experiment``.

    Args:
        cfg: Composed Hydra configuration for the experiment.
    """
    started_cpu = time.process_time()
    started_wall = time.time()

    resolved_cfg = omegaconf.OmegaConf.to_container(
        cfg, resolve=True, throw_on_missing=False
    )

    use_wandb = "wandb" in cfg.log
    run = (
        wandb.init(project=cfg.log.project, config=resolved_cfg) if use_wandb else None
    )

    run_experiment(cfg=cfg, wandb_run=run)

    if use_wandb:
        run.finish()

    # Both clocks are read after the W&B run is closed, so the reported
    # durations include the time spent finishing it.
    elapsed_cpu = time.process_time() - started_cpu
    elapsed_wall = time.time() - started_wall

    cpu_days, cpu_hours, cpu_minutes, cpu_seconds = convert_seconds_to_dhms(
        elapsed_cpu
    )
    wall_days, wall_hours, wall_minutes, wall_seconds = convert_seconds_to_dhms(
        elapsed_wall
    )

    print(
        f"CPU time: {cpu_days} Days: {cpu_hours} Hours: {cpu_minutes} Min: {cpu_seconds:.2f} Sec"
    )
    print(
        f"Wall time: {wall_days} Days: {wall_hours} Hours: {wall_minutes} Min: {wall_seconds:.2f} Sec"
    )


# Script entry point: `main` takes its `cfg` argument from the `hydra.main`
# decorator, so it is called here without arguments.
if __name__ == "__main__":
    main()
