import os

import numpy as np
from joblib import Parallel, delayed
from sklearn.linear_model import RidgeCV
from tqdm import tqdm

from graphsmodel.utils import (
    flatten_input,
    get_margin_incorrect_vectorized,
)


def fit_ridge(x, y):
    """Fit an intercept-free, cross-validated ridge regression of ``y`` on ``x``.

    The estimator is ``RidgeCV(alpha_per_target=True, fit_intercept=False)``.

    Args:
        x: Design matrix handed to ``RidgeCV.fit``.
        y: Regression target(s) handed to ``RidgeCV.fit``.

    Returns:
        A ``(coef_, intercept_)`` tuple read from the fitted estimator.
    """
    # NOTE(review): scikit-learn documents ``intercept_`` as 0.0 when
    # ``fit_intercept=False``; it is still returned so callers can unpack a pair.
    model = RidgeCV(fit_intercept=False, alpha_per_target=True).fit(x, y)
    return model.coef_, model.intercept_


def training_datamodel(x, y, to_keep):
    """Fit a ridge datamodel on a selection of the available rows.

    Args:
        x: Subsets (one row per sampled subset).
        y: Utility to predict (one entry per row of ``x``).
        to_keep: Index used to select the rows of both ``x`` and ``y`` that
            take part in the fit.

    Returns:
        Whatever ``fit_ridge`` returns for the selected rows, i.e. a
        ``(coefficients, intercept)`` pair.
    """
    return fit_ridge(x[to_keep], y[to_keep])


def train_datamodels_with_mask(weights, intercepts, x, y, mask, nodes, n_jobs=1):
    """Fit one datamodel per node, dropping the rows masked out for that node.

    For every ``node`` in ``nodes`` the target is ``y[:, node]`` and the rows
    used for the fit are those where ``mask[:, node]`` is False. Nodes are
    processed in batches of ``n_jobs`` so that only one batch of results is
    held in memory at a time.

    Args:
        weights: Array indexed by node; ``weights[node]`` is overwritten with
            the fitted coefficients. Modified in place.
        intercepts: Array indexed by node; ``intercepts[node]`` is overwritten
            with the fitted intercept. Modified in place.
        x: Design matrix, passed unchanged to every fit.
        y: Targets; column ``node`` is the target of that node's datamodel.
        mask: Boolean array of shape (n_subs, n_nodes); True marks rows that
            must be excluded when fitting the corresponding node.
        nodes: Sequence of node indices to fit (list, NumPy array or tensor).
        n_jobs: Number of parallel workers, which is also the batch size.
            Negative values are resolved as ``n_cpus + 1 + n_jobs`` (so ``-1``
            means all CPUs) and never drop below one worker.

    Returns:
        The same ``weights`` and ``intercepts`` objects, after the update.

    Raises:
        ValueError: If ``n_jobs`` is 0.
    """
    if n_jobs == 0:
        raise ValueError("n_jobs == 0 has no meaning; use a non-zero integer")
    if n_jobs < 0:
        # os.cpu_count() may return None; fall back to a single CPU. Clamp to
        # one worker so a very negative n_jobs cannot produce a non-positive
        # batch step, which would silently skip every node.
        n_cpus = os.cpu_count() or 1
        n_jobs = max(1, n_cpus + 1 + n_jobs)

    node_batches = [nodes[i : i + n_jobs] for i in range(0, len(nodes), n_jobs)]

    # A single progress bar counting nodes; the context manager closes it even
    # if one of the fits raises.
    with tqdm(total=len(nodes)) as pbar:
        for batch in node_batches:
            # Plain ints index the arrays the same way regardless of whether
            # `nodes` came in as a list, a NumPy array or a tensor.
            batch_nodes = [int(node) for node in batch]

            # process the nodes in batches to save memory
            results = Parallel(n_jobs=n_jobs)(
                delayed(training_datamodel)(
                    x=x,
                    y=y[:, node],
                    to_keep=(~mask[:, node]).nonzero()[0],
                )
                for node in batch_nodes
            )

            # after each batch, store the fitted parameters so that `results`
            # can be released before the next batch is computed
            for node, (coef, intercept) in zip(batch_nodes, results):
                weights[node] = coef
                intercepts[node] = intercept

            pbar.update(len(batch_nodes))

    return weights, intercepts


def get_nodes_subsets_mask(mask, subsets):
    """Expand a node mask into a per-subset array and fill it from ``subsets``.

    ``mask`` is tiled ``subsets.shape[0]`` times with ``mask.repeat(n, 1)``;
    every position that is True in the tiled mask is then overwritten with the
    values produced by ``flatten_input(subsets)``.

    Args:
        mask: Tensor mask over nodes (must support ``.repeat`` and ``.numpy``).
        subsets: Object with a ``.shape``; its first dimension is the number
            of subsets. NOTE(review): presumably one column per True entry of
            ``mask`` -- confirm against ``flatten_input``.

    Returns:
        The filled, tiled mask converted to a NumPy array.
    """
    n_rows = subsets.shape[0]
    tiled = mask.repeat(n_rows, 1)
    # Build the selector as a separate tensor before assigning into `tiled`.
    selected = tiled == True
    tiled[selected] = flatten_input(subsets)
    return tiled.numpy()


def _fit_split_with_mask(weights, intercepts, x, y, split_mask, split_subsets, n_jobs):
    """Fit per-node datamodels for the nodes of one split, with row masking."""
    split_nodes = split_mask.nonzero(as_tuple=True)[0].numpy()
    ext_mask = get_nodes_subsets_mask(split_mask, split_subsets)
    return train_datamodels_with_mask(
        weights, intercepts, x, y, ext_mask, split_nodes, n_jobs=n_jobs
    )


def train_datamodels(mode, x, y, train_mask, val_mask, test_mask, n_jobs=1):
    """Fit one datamodel per node, according to the subset sampling ``mode``.

    Modes:
        ``"train"`` / ``"val"`` / ``"test"``: nodes of the two *other* splits
            are fitted jointly with ``fit_ridge`` on all rows of ``x``; nodes
            of the named split are fitted one by one with
            ``train_datamodels_with_mask``, using the mask built by
            ``get_nodes_subsets_mask(split_mask, x)``.
        ``"mixed"``: every split is fitted node by node, using the mask built
            from the columns of ``x`` selected by that split's mask.

    Args:
        mode: One of ``"train"``, ``"val"``, ``"test"`` or ``"mixed"``.
        x: 2-D array of subsets, shape (n_subs, sub_size).
        y: 2-D array of targets, shape (n_subs, n_nodes).
        train_mask: Tensor mask selecting the training nodes.
        val_mask: Tensor mask selecting the validation nodes.
        test_mask: Tensor mask selecting the test nodes.
        n_jobs: Parallelism forwarded to every per-node fitting pass.

    Returns:
        ``(weights, intercepts)`` with shapes (n_nodes, sub_size) and
        (n_nodes,).

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    _, sub_size = x.shape
    _, n_nodes = y.shape

    weights = np.zeros((n_nodes, sub_size))
    intercepts = np.zeros(n_nodes)

    # Insertion order (train, val, test) is the order in which splits are fitted.
    split_masks = {"train": train_mask, "val": val_mask, "test": test_mask}

    if mode in ("train", "val", "test"):
        # Splits other than `mode` need no row masking: fit them in one shot.
        for split_name, split_mask in split_masks.items():
            if split_name == mode:
                continue
            weights[split_mask], intercepts[split_mask] = fit_ridge(
                x, y[:, split_mask]
            )

        # The split named by `mode` is fitted node by node with masking.
        weights, intercepts = _fit_split_with_mask(
            weights, intercepts, x, y, split_masks[mode], x, n_jobs
        )
    elif mode == "mixed":
        # Every split is fitted node by node, each with a mask built from its
        # own columns of x.
        for split_mask in split_masks.values():
            weights, intercepts = _fit_split_with_mask(
                weights, intercepts, x, y, split_mask, x[:, split_mask], n_jobs
            )
    else:
        raise ValueError("Invalid mode specified")

    return weights, intercepts


def compute_datamodels(subset_mode, subsets, logits, sub_logits, sub_y, data, n_jobs=1):
    """Fit datamodels for node margins and split accuracies.

    Two families of datamodels are fitted on ``subsets``: one from ``logits``
    (scored against ``data.y``) and one from ``sub_logits`` (scored against
    ``sub_y``). In both cases margins come from
    ``get_margin_incorrect_vectorized(...)`` averaged over axis 1, and a node
    counts as correct when its margin is positive.

    Args:
        subset_mode: Mode forwarded to ``train_datamodels``.
        subsets: 2-D array of subsets used as regression inputs.
        logits: Logits scored against ``data.y``.
        sub_logits: Logits scored against ``sub_y``; the resulting margins may
            contain NaN entries, which are excluded from the fits.
        sub_y: Labels matching ``sub_logits``.
        data: Object exposing ``y``, ``train_mask``, ``val_mask`` and
            ``test_mask`` tensors.
        n_jobs: Parallelism forwarded to every per-node fitting pass.

    Returns:
        Tuple ``(margins_dm, test_accs_dm, val_accs_dm, sub_margins_dm,
        sub_test_accs_dm, sub_val_accs_dm)`` of fitted coefficients.
    """
    margins = get_margin_incorrect_vectorized(logits, data.y).mean(1).numpy()
    correct = margins > 0
    test_accs = correct[:, data.test_mask].mean(-1)
    val_accs = correct[:, data.val_mask].mean(-1)

    margins_dm, _ = train_datamodels(
        mode=subset_mode,
        x=subsets,
        y=margins,
        train_mask=data.train_mask,
        val_mask=data.val_mask,
        test_mask=data.test_mask,
        n_jobs=n_jobs,
    )

    test_accs_dm, _ = fit_ridge(subsets, test_accs)
    val_accs_dm, _ = fit_ridge(subsets, val_accs)

    sub_margins = get_margin_incorrect_vectorized(sub_logits, sub_y).mean(1).numpy()
    # Masked array so that invalid (NaN/inf) margins do not count in the means.
    sub_margins_ma = np.ma.masked_invalid(sub_margins)
    sub_test_margins = sub_margins_ma[:, data.test_mask]
    sub_val_margins = sub_margins_ma[:, data.val_mask]
    sub_test_accs = (sub_test_margins > 0).mean(-1)
    sub_val_accs = (sub_val_margins > 0).mean(-1)

    n_nodes = sub_margins.shape[1]
    sub_size = subsets.shape[1]
    sub_margins_dm = np.zeros((n_nodes, sub_size))
    # Rows with a NaN margin for a node are left out of that node's fit.
    nan_mask = np.isnan(sub_margins)
    # The intercepts are not part of the result; one scratch buffer suffices.
    scratch_intercepts = np.zeros(n_nodes)

    # Fill the weights for the training, validation and test nodes in turn.
    for split_mask in (data.train_mask, data.val_mask, data.test_mask):
        split_nodes = split_mask.nonzero(as_tuple=True)[0].numpy()
        sub_margins_dm, _ = train_datamodels_with_mask(
            weights=sub_margins_dm,
            intercepts=scratch_intercepts,
            x=subsets,
            y=sub_margins,
            mask=nan_mask,
            nodes=split_nodes,
            n_jobs=n_jobs,
        )

    # Datamodels for the (val and test) accuracies of the subgraphs
    sub_test_accs_dm, _ = fit_ridge(subsets, sub_test_accs)
    sub_val_accs_dm, _ = fit_ridge(subsets, sub_val_accs)

    return (
        margins_dm,
        test_accs_dm,
        val_accs_dm,
        sub_margins_dm,
        sub_test_accs_dm,
        sub_val_accs_dm,
    )
