import os
import time
from pathlib import Path

import hydra
import numpy as np
import omegaconf
import scipy.sparse as sp
import torch
import wandb
from joblib import Parallel, delayed
from torch_geometric import seed_everything
from tqdm import tqdm

from graphsmodel import PROJECT_ROOT
from graphsmodel.evaluation import (
    compute_data_shapley,
    compute_pc_winter,
    generate_maps,
    generate_permutations,
)
from graphsmodel.training import train_subset
from graphsmodel.utils import (
    convert_seconds_to_dhms,
    get_dataset,
    get_subset,
    sparse_tensor_to_scipy,
)


def train_subsets_datamodel(cfg, data, results_dir):
    """Train one model per random subset and save the stacked outputs.

    Draws ``cfg.train.n_subsets`` subsets with ``get_subset`` (sampling
    probability ``cfg.data.alpha``), trains a model on each of them in
    parallel with ``train_subset`` and writes the collected results to
    ``results_dir / "subsets_res.pt"``.

    The saved dict has the keys ``"subsets"`` (converted to a scipy sparse
    matrix by ``sparse_tensor_to_scipy``), ``"logits"``, ``"sub_logits"`` and
    ``"sub_y"`` (each a stacked torch tensor, one row per subset).

    Args:
        cfg: Experiment config. Reads ``train.n_subsets``,
            ``data.subset_mode``, ``data.alpha`` and ``data.subsets_seed``
            (plus whatever ``train_subset`` reads).
        data: Graph data object with ``train_mask``, ``val_mask``,
            ``test_mask`` and ``x`` attributes.
        results_dir: Output directory; must support ``/`` path joins.
    """
    n_subsets = cfg.train.n_subsets

    # Length of the valuation vector. ``count_nonzero`` is used instead of
    # ``len(mask.nonzero().squeeze())``: when a mask selects exactly one node,
    # squeeze() yields a 0-d tensor and len() raises a TypeError.
    subset_mode = cfg.data.subset_mode
    mode_to_mask = {
        "train": data.train_mask,
        "val": data.val_mask,
        "test": data.test_mask,
    }
    if subset_mode in mode_to_mask:
        n_samples = int(torch.count_nonzero(mode_to_mask[subset_mode]))
    else:
        # Any other mode values every node of the graph.
        n_samples = data.x.shape[0]

    alpha = cfg.data.alpha

    if cfg.data.subsets_seed is not None:
        seed_everything(cfg.data.subsets_seed)

    # The same ``seen_subsets`` set is handed to every get_subset call,
    # presumably so it can track/avoid duplicates — TODO confirm in get_subset.
    seen_subsets = set()
    subsets = [
        get_subset(seen_subsets=seen_subsets, d=n_samples, alpha=alpha)
        for _ in range(n_subsets)
    ]

    res = Parallel(n_jobs=-1)(
        delayed(train_subset)(
            subset_idx=subset_idx,
            subset=subset,
            cfg=cfg,
            data=data,
            logits_on_data=True,
        )
        for subset_idx, subset in enumerate(tqdm(subsets))
    )

    # Each train_subset call returns a 4-tuple; transpose into 4 sequences.
    subsets, sub_logits, sub_y, logits = zip(*res)
    subsets = sparse_tensor_to_scipy(torch.stack(subsets).to_sparse())
    sub_logits = torch.stack(sub_logits)
    sub_y = torch.stack(sub_y)
    logits = torch.stack(logits)

    # Save the results
    res = {
        "subsets": subsets,
        "logits": logits,
        "sub_logits": sub_logits,
        "sub_y": sub_y,
    }
    torch.save(res, results_dir / "subsets_res.pt", pickle_protocol=4)


def train_subsets_loo(cfg, data, results_dir):
    """Train leave-one-out models and save their outputs.

    For every sample index ``i`` in ``range(n_samples)`` a model is trained
    via ``train_subset`` on a boolean mask that is ``True`` everywhere except
    at position ``i``. The stacked results are written to
    ``results_dir / "loo_res.pt"``.

    The saved dict has the keys ``"subsets"`` (a ``scipy.sparse.csr_matrix``
    of the masks), ``"logits"``, ``"sub_logits"`` and ``"sub_y"`` (stacked
    torch tensors, one row per left-out sample).

    Args:
        cfg: Experiment config. Reads ``data.subset_mode`` (plus whatever
            ``train_subset`` reads).
        data: Graph data object with ``train_mask``, ``val_mask``,
            ``test_mask`` and ``x`` attributes.
        results_dir: Output directory; must support ``/`` path joins.
    """
    # Length of the valuation vector. ``count_nonzero`` is used instead of
    # ``len(mask.nonzero().squeeze())``: when a mask selects exactly one node,
    # squeeze() yields a 0-d tensor and len() raises a TypeError.
    subset_mode = cfg.data.subset_mode
    mode_to_mask = {
        "train": data.train_mask,
        "val": data.val_mask,
        "test": data.test_mask,
    }
    if subset_mode in mode_to_mask:
        n_samples = int(torch.count_nonzero(mode_to_mask[subset_mode]))
    else:
        # Any other mode values every node of the graph.
        n_samples = data.x.shape[0]

    # Row i of this (n_samples, n_samples) mask drops exactly sample i.
    # NOTE: memory grows quadratically with n_samples.
    subsets = np.ones((n_samples, n_samples), dtype=bool)
    np.fill_diagonal(subsets, False)

    res = Parallel(n_jobs=-1)(
        delayed(train_subset)(
            subset_idx=subset_idx,
            subset=subset,
            cfg=cfg,
            data=data,
            logits_on_data=True,
        )
        for subset_idx, subset in enumerate(tqdm(subsets))
    )

    # Each train_subset call returns a 4-tuple; transpose into 4 sequences.
    subsets, sub_logits, sub_y, logits = zip(*res)
    subsets = sp.csr_matrix(np.stack(subsets))
    sub_logits = torch.stack(sub_logits)
    sub_y = torch.stack(sub_y)
    logits = torch.stack(logits)

    # Save the results
    res = {
        "subsets": subsets,
        "logits": logits,
        "sub_logits": sub_logits,
        "sub_y": sub_y,
    }
    torch.save(res, results_dir / "loo_res.pt", pickle_protocol=4)


def train_permutations_shapley(cfg, data, results_dir):
    """Estimate truncated Monte-Carlo data Shapley values and save the results.

    The number of random permutations is chosen so that
    ``n_perms * n_samples * p_trunc`` is approximately ``cfg.train.n_subsets``
    (the training budget). The permutations are evaluated by
    ``compute_data_shapley``, and its outputs are written to
    ``results_dir / "shapley_res.pt"``.

    Args:
        cfg: Experiment config. Reads ``train.n_subsets``, ``p_trunc``,
            ``data.subset_mode`` and ``data.subsets_seed``.
        data: Graph data object with ``train_mask``, ``val_mask``,
            ``test_mask`` and ``x`` attributes.
        results_dir: Output directory; must support ``/`` path joins.

    Raises:
        ValueError: If ``n_samples * p_trunc`` is not positive, or if the
            budget rounds down to zero permutations.
    """
    n_subsets = cfg.train.n_subsets

    # Length of the valuation vector. ``count_nonzero`` is used instead of
    # ``len(mask.nonzero().squeeze())``: when a mask selects exactly one node,
    # squeeze() yields a 0-d tensor and len() raises a TypeError.
    subset_mode = cfg.data.subset_mode
    mode_to_mask = {
        "train": data.train_mask,
        "val": data.val_mask,
        "test": data.test_mask,
    }
    if subset_mode in mode_to_mask:
        n_samples = int(torch.count_nonzero(mode_to_mask[subset_mode]))
    else:
        # Any other mode values every node of the graph.
        n_samples = data.x.shape[0]

    p_trunc = cfg.p_trunc
    samples_per_perm = n_samples * p_trunc
    if samples_per_perm <= 0:
        raise ValueError(
            f"Cannot derive the number of permutations: n_samples={n_samples} "
            f"and p_trunc={p_trunc} must both be positive."
        )
    n_perms = np.round(n_subsets / samples_per_perm).astype(int).item()
    if n_perms < 1:
        raise ValueError(
            f"n_subsets={n_subsets} is too small for n_samples={n_samples} and "
            f"p_trunc={p_trunc}: it yields {n_perms} permutations."
        )

    if cfg.data.subsets_seed is not None:
        seed_everything(cfg.data.subsets_seed)

    # Drawn from the global NumPy RNG, which seed_everything seeds above.
    perms = [np.random.permutation(n_samples) for _ in range(n_perms)]

    (
        perms,
        sub_test_accs_shap,
        sub_val_accs_shap,
        test_accs_shap,
        val_accs_shap,
        margins_shap,
        sub_margins_shap,
    ) = compute_data_shapley(perms, n_samples, p_trunc, cfg, data, n_jobs=-1)

    shapley_res = {
        "perms": perms,
        "p_trunc": p_trunc,
        "sub_test_accs_shap": sub_test_accs_shap,
        "sub_val_accs_shap": sub_val_accs_shap,
        "test_accs_shap": test_accs_shap,
        "val_accs_shap": val_accs_shap,
        "margins_shap": margins_shap,
        "sub_margins_shap": sub_margins_shap,
    }
    torch.save(shapley_res, results_dir / "shapley_res.pt", pickle_protocol=4)


def train_pc_winter(cfg, data, results_dir):
    """Run the PC-Winter valuation experiment and save its outputs.

    Builds the labeled-node-to-player map with ``generate_maps``, draws
    ``cfg.train.n_subsets`` truncated permutations with
    ``generate_permutations`` and evaluates them with ``compute_pc_winter``.

    NOTE(review): unlike the other experiments, the result file is written to
    ``results_dir.parent`` (named after the three truncation ratios), not into
    ``results_dir`` itself — confirm this is intentional.

    Args:
        cfg: Experiment config. Reads ``train.n_subsets``,
            ``label_trunc_ratio``, ``group_trunc_ratio_hop_1``,
            ``group_trunc_ratio_hop_2`` and ``data.subsets_seed``.
        data: Graph data object with ``num_nodes``, ``train_mask`` and
            ``edge_index`` attributes.
        results_dir: Path-like object; its parent receives the result file.
    """
    num_nodes = data.num_nodes
    budget = cfg.train.n_subsets
    labeled_nodes = data.train_mask.nonzero(as_tuple=True)[0].tolist()

    player_map, _, _ = generate_maps(
        train_idx_list=labeled_nodes,
        edge_index=data.edge_index,
        num_nodes=num_nodes,
    )

    seed = cfg.data.subsets_seed
    if seed is not None:
        seed_everything(seed)

    label_ratio = cfg.label_trunc_ratio
    hop1_ratio = cfg.group_trunc_ratio_hop_1
    hop2_ratio = cfg.group_trunc_ratio_hop_2

    perms = generate_permutations(
        n_subsets=budget,
        labeled_node_list=labeled_nodes,
        label_trunc_ratio=label_ratio,
        group_trunc_ratio_hop_1=hop1_ratio,
        group_trunc_ratio_hop_2=hop2_ratio,
        labeled_to_player_map=player_map,
    )

    outputs = compute_pc_winter(
        perms,
        n_samples=num_nodes,
        cfg=cfg,
        data=data,
        n_jobs=-1,
    )
    (
        sub_test_accs_pc,
        sub_val_accs_pc,
        test_accs_pc,
        val_accs_pc,
        margins_pc,
        sub_margins_pc,
    ) = outputs

    pc_winter_res = dict(
        perms=perms,
        label_trunc_ratio=label_ratio,
        group_trunc_ratio_hop_1=hop1_ratio,
        group_trunc_ratio_hop_2=hop2_ratio,
        sub_test_accs_pc=sub_test_accs_pc,
        sub_val_accs_pc=sub_val_accs_pc,
        test_accs_pc=test_accs_pc,
        val_accs_pc=val_accs_pc,
        margins_pc=margins_pc,
        sub_margins_pc=sub_margins_pc,
    )

    out_name = f"pc_winter_{cfg.label_trunc_ratio}_{cfg.group_trunc_ratio_hop_1}_{cfg.group_trunc_ratio_hop_2}_res.pt"
    torch.save(pc_winter_res, results_dir.parent / out_name, pickle_protocol=4)


def run_experiment(cfg, wandb_run):
    """Load the dataset and dispatch to the configured valuation experiment.

    Supported ``cfg.experiment.type`` values are ``"datamodel"``, ``"loo"``,
    ``"shapley"`` and ``"pc_winter"``. The results directory is created if
    it does not exist. For datamodels, results go into an
    ``alpha_<cfg.data.alpha>`` subdirectory.

    Args:
        cfg: Experiment config. Reads ``train.deterministic``,
            ``core.results_dir``, ``experiment.type`` and ``data.alpha``
            (plus whatever the selected experiment reads).
        wandb_run: Active wandb run or ``None``. Not used here; kept so the
            call signature stays stable.

    Raises:
        ValueError: If ``cfg.experiment.type`` is not a supported type.
    """
    experiment_type = cfg.experiment.type
    # Fail fast, before the (potentially expensive) dataset load.
    if experiment_type not in {"datamodel", "loo", "shapley", "pc_winter"}:
        raise ValueError(f"Unknown experiment type: {cfg.experiment.type}")

    if cfg.train.deterministic:
        # cuBLAS workspace setting that PyTorch asks for when deterministic
        # algorithms are enabled on CUDA.
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Coerce to Path: a plain string from the config would break the ``/``
    # joins below and in the train_* helpers.
    results_dir = Path(cfg.core.results_dir)
    print(results_dir)
    # Previously only the datamodel branch created its directory, so the
    # torch.save calls of the other experiments could fail on a fresh path.
    results_dir.mkdir(parents=True, exist_ok=True)

    dataset = get_dataset(cfg)
    data = dataset[0]

    if experiment_type == "datamodel":
        print("Training datamodel")

        subset_sampling_probability = cfg.data.alpha
        results_dir = results_dir / f"alpha_{subset_sampling_probability}"
        results_dir.mkdir(parents=True, exist_ok=True)

        train_subsets_datamodel(cfg, data, results_dir)
    elif experiment_type == "loo":
        print("Training leave-one-out")
        train_subsets_loo(cfg, data, results_dir)
    elif experiment_type == "shapley":
        print("Training shapley")
        train_permutations_shapley(cfg, data, results_dir)
    else:  # "pc_winter" (validated above)
        print("Training pc_winter")
        train_pc_winter(cfg, data, results_dir)


@hydra.main(
    version_base=None, config_path=str(PROJECT_ROOT / "conf"), config_name="training"
)
def main(cfg: omegaconf.DictConfig):
    """Hydra entry point: run the configured experiment and report elapsed time.

    When ``"wandb"`` is a key of ``cfg.log``, a wandb run is started in
    project ``cfg.log.project`` with the fully resolved config and finished
    after the experiment returns. CPU (process) time and wall-clock time for
    the whole call are printed at the end.
    """
    cpu_time_start, wall_time_start = time.process_time(), time.time()

    # Resolve interpolations into a plain container (this is what wandb logs).
    config = omegaconf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=False)
    log_to_wandb = "wandb" in cfg.log
    run = wandb.init(project=cfg.log.project, config=config) if log_to_wandb else None

    run_experiment(cfg=cfg, wandb_run=run)

    if log_to_wandb:
        run.finish()

    cpu_time, wall_time = (
        time.process_time() - cpu_time_start,
        time.time() - wall_time_start,
    )

    cpu_days, cpu_hours, cpu_minutes, cpu_seconds = convert_seconds_to_dhms(cpu_time)
    wall_days, wall_hours, wall_minutes, wall_seconds = convert_seconds_to_dhms(wall_time)

    print(f"CPU time: {cpu_days} Days: {cpu_hours} Hours: {cpu_minutes} Min: {cpu_seconds:.2f} Sec")
    print(f"Wall time: {wall_days} Days: {wall_hours} Hours: {wall_minutes} Min: {wall_seconds:.2f} Sec")

if __name__ == "__main__":
    # main() takes no arguments here: the @hydra.main decorator builds `cfg`
    # from the "training" config under PROJECT_ROOT/conf (plus CLI overrides).
    main()
