"""
Collection of methods used across files.
Data Generation originally by https://github.com/xunzheng/notears, with some adaptations.
"""
from sklearn.linear_model import LinearRegression
from scipy.special import expit as sigmoid
from collections import namedtuple
from pathlib import Path
from datetime import datetime
from scipy import linalg
import igraph as ig
import numpy as np
import pandas as pd
import pickle as pk
import os
import shutil
import random
import json
import types
import networkx as nx


def thresholds(x):
    """Build the per-algorithm threshold table.

    Only "randomregressIC" is thresholded at ``x``; both sortnregress
    variants are mapped to ``-np.inf``.
    """
    table = {"randomregressIC": x}
    for algorithm in ("sortnregressIC", "sortnregressIC_R2"):
        table[algorithm] = -np.inf
    return table


# Full experiment options: one record describing a complete experiment sweep.
Options = namedtuple(
    typename="Options",
    field_names=(
        "overwrite",
        "exp_name",
        "MEC",
        "thres",
        "thres_type",
        "vsb_function",
        "R2sb_function",
        "CEVsb_function",
        "base_dir",
        "n_repetitions",
        "graphs",
        # expected number of edges per node
        "edges",
        # 'dynamic' or 'fixed'
        "edge_types",
        "noise_distributions",
        "edge_weights",
        "n_nodes",
        "n_obs",
        "scaler",
    ),
)

# Experiment description: the settings that identify a single dataset.
# `random_seed` is kept as the last field (dataset_dirname drops the last entry).
DatasetParameters = namedtuple(
    typename="DatasetParameters",
    field_names=(
        "graph_type",
        "edge_type",
        "x",
        "noise_dist",
        "noise_sigma_dist",
        "noise_sigma_lims",
        "edge_weight_range",
        "n_nodes",
        "n_obs",
        "random_seed",
    ),
)

# Actual experiment: a generated dataset together with its derived statistics.
Dataset = namedtuple(
    typename="Dataset",
    field_names=(
        # a 'DatasetParameters' namedtuple
        "parameters",
        "W_true",
        "B_true",
        "data",
        "hash",
        "scaler",
        "scaling_factors",
        "sigma",
        # vsb
        "vars",
        "scaled_vars",
        "varsortability",
        # rsb
        "R2",
        "R2sortability",
        # cev-stb
        "CEVsortability",
    ),
)

# Noise distribution dtype
NoiseDistribution = namedtuple(
    typename="NoiseDistribution",
    field_names=(
        # noise distribution
        "noise_dist",
        # noise sigma distribution
        "noise_sigma_dist",
        # noise sigma upper and lower limit for uniform -> defines variance!
        "noise_sigma_lims",
    ),
)


def elnw(params):
    """Monte-Carlo estimate of E[ln|w|] for the edge-weight distribution.

    Weights are drawn as a random sign times a uniform draw from
    ``params.edge_weight_range``; the mean of ln|w| over 100000 draws is
    returned. Uses the global numpy RNG.
    """
    n_draws = 100000
    signs = np.random.choice([-1, 1], n_draws)
    magnitudes = np.random.uniform(*params.edge_weight_range, n_draws)
    log_abs_weights = np.log(np.abs(signs * magnitudes))
    return np.mean(log_abs_weights)


def vlnw(params):
    """Monte-Carlo estimate of Var[ln|w|] for the edge-weight distribution.

    Weights are drawn as a random sign times a uniform draw from
    ``params.edge_weight_range``; the variance (``np.var``) of ln|w| over
    100000 draws is returned. Uses the global numpy RNG.

    Args:
        params: object exposing ``edge_weight_range`` as a (low, high) pair.

    Returns:
        The sample variance of ln|w|.
    """

    def _draw_weights(n):
        # sign is drawn before the magnitude to keep the RNG call order stable
        return np.random.choice([-1, 1], n) * np.random.uniform(
            *params.edge_weight_range, n
        )

    sample = _draw_weights(100000)
    return np.var(np.log(np.abs(sample)))


def matching_dataset(hash, datasets):
    """Return the first entry of ``datasets`` whose ``hash`` attribute equals ``hash``.

    Raises ValueError (from ``list.index``) when no entry matches.
    """
    known_hashes = [candidate.hash for candidate in datasets]
    position = known_hashes.index(hash)
    return datasets[position]


def l2_loss(X, W):
    """Half the summed squared residual of ``X - X @ W``, divided by the number of rows of X."""
    residual = X - X @ W
    squared_error = (residual**2).sum()
    return 0.5 / X.shape[0] * squared_error


def dataset_parameters(dataset):
    """Join the string form of every entry of ``dataset.parameters`` with underscores."""
    return "_".join(map(str, list(dataset.parameters)))


def dataset_dirname(dataset):  # without random seed
    """Like the underscore-joined parameter string, but with the last entry left out."""
    all_but_last = list(dataset.parameters)[:-1]
    return "_".join(str(value) for value in all_but_last)


def set_random_seed(seed):
    """Seed Python's ``random`` module and numpy's global RNG with the same value."""
    for seed_rng in (random.seed, np.random.seed):
        seed_rng(seed)


def load_pk_files(path):
    """Load every ``*.pk`` pickle found recursively below ``path``.

    Args:
        path: directory to search (anything ``pathlib.Path`` accepts).

    Returns:
        list of the unpickled objects, in the order ``Path.rglob`` yields the files.
    """
    # NOTE(security): unpickling executes arbitrary code; only use on trusted files.
    pickle_paths = list(Path(path).rglob("*.pk"))
    loaded = []
    for pickle_path in pickle_paths:
        # context manager so the handle is closed even if unpickling fails
        with open(pickle_path, "rb") as handle:
            loaded.append(pk.load(handle))
    return loaded


def load_results(path):
    """Read a results CSV and turn its stringified arrays back into numpy arrays.

    The array columns hold strings whose first and last characters are
    stripped before the remainder is evaluated as ``np.<remainder>``; each
    parsed cell becomes a one-element list. The "R2" column is optional.
    "start_time" is parsed with the format "%Y-%m-%d %H:%M:%S.%f".
    """

    # NOTE(security): eval runs whatever the CSV contains; only load trusted files.
    def _parse_array_cell(cell):
        return [eval("np." + cell[1:-1])]

    def _parse_timestamp(stamp):
        return datetime.strptime(stamp, "%Y-%m-%d %H:%M:%S.%f")

    frame = pd.read_csv(path)
    for column in ("scaling_factors", "vars", "scaled_vars"):
        frame[column] = frame[column].apply(_parse_array_cell)
    try:
        frame["R2"] = frame["R2"].apply(_parse_array_cell)
    except KeyError:
        # results without an R2 column are accepted as-is
        pass
    for column in ("W_true", "W_est"):
        frame[column] = frame[column].apply(_parse_array_cell)
    frame["start_time"] = frame["start_time"].apply(_parse_timestamp)
    return frame


def create_folder(path, overwrite=False):
    """Make sure ``path`` exists as a directory.

    With ``overwrite=True`` an existing directory is removed (with its
    contents) and recreated empty; otherwise an existing directory is left
    untouched.
    """
    already_there = os.path.isdir(path)
    if overwrite and already_there:
        shutil.rmtree(path)
        already_there = False
    if not already_there:
        os.makedirs(path)


def standardize(data):
    """Divide ``data`` by its standard deviation taken along axis 0 (no mean-centering)."""
    scale = data.std(0, keepdims=True)
    return data / scale


class NpEncoder(json.JSONEncoder):
    """JSON encoder that also handles numpy scalars/arrays, plain functions and scaler objects."""

    def default(self, obj):
        """Convert ``obj`` to a JSON-serializable value, or defer to the base encoder."""
        if isinstance(obj, np.integer):
            converted = int(obj)
        elif isinstance(obj, np.floating):
            converted = float(obj)
        elif isinstance(obj, np.ndarray):
            converted = obj.tolist()
        elif isinstance(obj, types.FunctionType):
            # functions are stored by name only
            converted = obj.__name__
        elif "Scalers" in str(obj.__class__):
            # objects whose class string mentions "Scalers" are stored by class name
            converted = obj.__class__.__name__
        else:
            converted = super().default(obj)
        return converted


def snapshot(opt):
    """Write the experiment options to ``<base_dir>/<exp_name>_params.json``.

    The base directory is created if missing (never wiped); the options are
    dumped via ``opt._asdict()`` with 4-space indentation using NpEncoder.
    """
    create_folder(opt.base_dir, overwrite=False)
    target = f"{opt.base_dir}/{opt.exp_name}_params.json"
    with open(target, "w") as out_file:
        json.dump(opt._asdict(), out_file, indent=4, cls=NpEncoder)


def is_edge_in_cycle(edge, cycle):
    """Tell whether the directed ``edge`` (source, target) is traversed by ``cycle``.

    ``cycle`` is a node sequence; consecutive nodes form edges and the last
    node links back to the first.
    """
    source, target = edge[0], edge[1]
    # edges between neighbouring positions of the sequence
    if any(a == source and b == target for a, b in zip(cycle, cycle[1:])):
        return True
    # closing edge from the last node back to the first
    return bool(cycle[-1] == source and cycle[0] == target)


def dagify_break_cycles(W):
    """Remove one edge from a (randomly chosen) shortest cycle of ``W``.

    Among the edges of the selected cycle, the one that lies on the largest
    number of simple cycles is zeroed out. Only a single edge is removed per
    call, so the result may still contain cycles. Uses the global numpy RNG
    for tie-breaking.

    Args:
        W: square adjacency matrix (non-zero entry = edge from row to column).

    Returns:
        A copy of ``W`` with at most one entry set to 0. If ``W`` has no
        cycles, an unmodified copy is returned.
    """
    W = np.copy(W)
    # identify cycles
    G = nx.from_numpy_array(W.copy(), create_using=nx.DiGraph)
    cycles = list(nx.simple_cycles(G))

    # nothing to break: previously this fell through to `cycles[None]`
    if not cycles:
        return W

    short_cycle = (None, np.inf)  # (index, length)

    # find a shortest cycle to break (random order => random tie-breaking)
    for k in np.random.permutation(len(cycles)):
        candidate = (k, len(cycles[k]))
        if candidate[1] < short_cycle[1]:
            short_cycle = candidate
        # cycles of length <= 2 are accepted immediately as short enough
        if short_cycle[1] <= 2:
            break

    # get the candidate edges to break (closing edge first, then consecutive pairs)
    cycle_to_break = cycles[short_cycle[0]]
    candidate_edges = [(cycle_to_break[-1], cycle_to_break[0])] + [
        (cycle_to_break[k], cycle_to_break[k + 1])
        for k in range(len(cycle_to_break) - 1)
    ]

    # break it at the most promising edge: the one shared by the most cycles
    edge_to_break = (None, -1)
    for edge in np.random.permutation(candidate_edges):
        number_of_cycles_broken = sum(is_edge_in_cycle(edge, c) for c in cycles)
        if number_of_cycles_broken > edge_to_break[1]:
            edge_to_break = (edge, number_of_cycles_broken)
    to_break = edge_to_break[0]
    W[to_break[0], to_break[1]] = 0
    return W


def stop_pycausalvm():
    """Detach from and shut down the JVM through ``javabridge``.

    ``javabridge`` is imported lazily inside the function. An
    ``AssertionError`` raised by the import or by either call is silently
    ignored; any other exception (e.g. ``ImportError``) propagates.
    NOTE(review): presumably the AssertionError covers the "no VM running"
    case — confirm against javabridge.
    """
    try:
        import javabridge

        javabridge.detach()
        javabridge.kill_vm()
    except AssertionError:
        # best-effort shutdown: nothing to do if javabridge asserts
        pass


def order_alignment_CountAll(W, scores, tol=0.0):
    """
    Order alignment in which every directed path counts, whatever its length.

    Paths of length 1 .. d-1 are counted via powers of the integer adjacency
    matrix, so parallel paths between the same pair are counted separately.
    Args:
        W: (d x d) matrix
        scores: (d) vector
        tol (optional): non-negative float
    Returns:
        Fraction of correctly ordered paths (ties within tol count one half)
    """
    assert tol >= 0.0, "tol must be non-negative"
    adjacency = (W != 0).astype(int)
    path_counts = adjacency.copy()

    score_row = scores.reshape(1, -1)
    # entry [i, j] is score j minus score i
    gaps = score_row - score_row.T
    not_reversed = gaps >= 0 - tol
    strictly_increasing = gaps > 0 + tol

    total_paths = 0
    agreeing_paths = 0
    for _ in range(len(adjacency) - 1):
        total_paths += path_counts.sum()
        # half a point for "not reversed", another half for "strictly increasing"
        agreeing_paths += (path_counts * not_reversed).sum() / 2
        agreeing_paths += (path_counts * strictly_increasing).sum() / 2
        path_counts = path_counts.dot(adjacency)
    return agreeing_paths / total_paths


def order_alignment_CountOne(W, scores, tol=0.0):
    """
    Order alignment over connected node pairs: a pair counts once if any
    directed path of length 1 .. d-1 links it, no matter how many paths exist.
    Args:
        W: (d x d) matrix
        scores: (d) vector
        tol (optional): non-negative float
    Returns:
        Fraction of correctly ordered connected pairs (ties within tol count one half)
    """
    assert tol >= 0.0, "tol must be non-negative"
    adjacency = W != 0

    score_row = scores.reshape(1, -1)
    # entry [i, j] is score j minus score i
    gaps = score_row - score_row.T

    # a pair is connected if any power 1 .. d-1 of the adjacency marks it
    powers = [
        np.linalg.matrix_power(adjacency, length)
        for length in range(1, len(adjacency))
    ]
    reachable = np.sum(powers, axis=0).astype(bool)

    total_pairs = reachable.sum()
    not_reversed = (reachable * (gaps >= 0 - tol)).sum() / 2
    strictly_increasing = (reachable * (gaps > 0 + tol)).sum() / 2
    return (not_reversed + strictly_increasing) / total_pairs


def order_alignment(W, scores, tol=0.0):
    """
    Measure how well the ordering induced by the scores agrees with the
    causal ordering induced by the (weighted) adjacency matrix W.
    See 10.48550/arXiv.2102.13647, Section 3.1
    and 10.48550/arXiv.2303.18211, Equation (3).
    Args:
        W: (d x d) matrix
        scores: (d) vector
        tol (optional): non-negative float
    Returns:
        Scalar measure of agreement between the orderings
    """
    assert tol >= 0.0, "tol must be non-negative"
    adjacency = W != 0
    # boolean matrix marking pairs joined by a path of the current length
    paths_of_length = adjacency.copy()

    # scores as a row vector; gaps[i, j] = score j - score i, hence
    #     * positive if score i < score j
    #     * zero if score i = score j
    #     * negative if score i > score j
    score_row = scores.reshape(1, -1)
    gaps = score_row - score_row.T
    not_reversed = gaps >= 0 - tol
    strictly_increasing = gaps > 0 + tol

    total_paths = 0
    agreeing_paths = 0
    for _ in range(len(adjacency) - 1):
        total_paths += paths_of_length.sum()
        # 1/2 per correctly ordered or unordered pair ...
        agreeing_paths += (paths_of_length * not_reversed).sum() / 2
        # ... and another 1/2 per correctly ordered pair
        agreeing_paths += (paths_of_length * strictly_increasing).sum() / 2
        paths_of_length = paths_of_length.dot(adjacency)
    return agreeing_paths / total_paths


def r2coef(X):
    """
    R^2 of each variable given all the other variables.

    Primary route: 1 - 1 / diag(inverse correlation matrix).
    If the inversion raises LinAlgError, each variable is instead regressed
    on all the others with a LinearRegression fit.
    Args:
        X: (d x n) array
    Returns:
        (d) array of R^2 values
    """
    try:
        inverse_corr = linalg.inv(np.corrcoef(X))
    except linalg.LinAlgError:
        # singular correlation matrix: use the regression fallback below
        inverse_corr = None
    if inverse_corr is not None:
        return 1 - 1 / np.diag(inverse_corr)

    n_vars = X.shape[0]
    samples = X.T  # rows are observations from here on
    regressor = LinearRegression()
    r2_values = np.zeros(n_vars)
    for target in range(n_vars):
        others = np.arange(n_vars) != target
        regressor.fit(samples[:, others], samples[:, target])
        r2_values[target] = regressor.score(samples[:, others], samples[:, target])
    return r2_values


def var_sortability(X, W, tol=0.0, alignment="Original"):
    """Order alignment between W and the variances of X's columns (``np.var`` over axis 0).

    ``alignment`` picks the variant: "Original", "CountOne" or "CountAll";
    anything else raises NotImplementedError.
    """
    scores = np.var(X, axis=0)
    variants = (
        ("Original", order_alignment),
        ("CountOne", order_alignment_CountOne),
        ("CountAll", order_alignment_CountAll),
    )
    for label, align in variants:
        if alignment == label:
            return align(W, scores, tol=tol)
    raise NotImplementedError


def r2_sortability(X, W, tol=0.0, alignment="Original"):
    """Order alignment between W and the R^2 scores from ``r2coef(X.T)``.

    ``alignment`` picks the variant: "Original", "CountOne" or "CountAll";
    anything else raises NotImplementedError.
    """
    scores = r2coef(X.T)
    variants = (
        ("Original", order_alignment),
        ("CountOne", order_alignment_CountOne),
        ("CountAll", order_alignment_CountAll),
    )
    for label, align in variants:
        if alignment == label:
            return align(W, scores, tol=tol)
    raise NotImplementedError


def cev_sortability(X, W, tol=0.0, alignment="Original"):
    """Order alignment between W and per-node R^2 from regressing on the parents.

    For every column k of X, the parents are the rows with a non-zero entry
    in W[:, k]; the score is the R^2 of a LinearRegression of the node on
    its parents, and stays 0 for nodes without parents. ``alignment`` picks
    the variant: "Original", "CountOne" or "CountAll"; anything else raises
    NotImplementedError.
    """
    n_vars = X.shape[1]
    scores = np.zeros((1, n_vars))
    regressor = LinearRegression()
    for node in range(n_vars):
        parent_mask = W[:, node] != 0
        if not np.sum(parent_mask) > 0:
            continue  # parentless node keeps score 0
        regressor.fit(X[:, parent_mask], X[:, node])
        scores[0, node] = regressor.score(X[:, parent_mask], X[:, node])
    variants = (
        ("Original", order_alignment),
        ("CountOne", order_alignment_CountOne),
        ("CountAll", order_alignment_CountAll),
    )
    for label, align in variants:
        if alignment == label:
            return align(W, scores, tol=tol)
    raise NotImplementedError


def shd_cpdag(B_true, B_est):
    """
    Structural Hamming distance between two CPDAGs.

    Both inputs are [d, d] matrices with entries {0, 1}
    where bidirected edges i <-> j are coded by placing a 1 at [i,j] and [j,i].

    Args:
        B_true: [d, d] ground-truth CPDAG matrix
        B_est: [d, d] estimated CPDAG matrix
    Returns:
        Number of node pairs whose edge state (no edge/->/<-/<->) differs
    Raises:
        ValueError: if either input contains an entry other than 0 or 1
    """
    # reject as soon as ANY entry is outside {0, 1}; the previous
    # `~np.isin(...).any()` negated the result of `.any()` and therefore
    # only fired when no entry at all was valid
    if not np.isin(B_true, [0, 1]).all() or not np.isin(B_est, [0, 1]).all():
        raise ValueError("Both inputs should be CPDAG matrices with 0 and 1 only")

    d = B_true.shape[0]
    # code lower and upper triangle differently
    W = np.tril(np.ones((d, d)), -1) * 1.0 + np.triu(np.ones((d, d)), 1) * 2.0
    # if we fold the weighted B matrices up, i.e. consider
    # W * B + (W * B).T
    # then 0, 1, 2, 3 code no edge/->/<-/<->
    WB_true = np.triu((W * B_true) + (W * B_true).T, 1)
    WB_est = np.triu((W * B_est) + (W * B_est).T, 1)

    # return count of mistakes, where mistaking any of
    # no edge/->/<-/<->
    # by one of the remaining three options counts as 1 mistake
    return (WB_true != WB_est).sum()


def sid(W, W_est):
    """The first argument must be the ground-truth graph.

    Both matrices are reduced to their non-zero pattern before being passed
    to ``r_sid.structIntervDist``; the first element of the result's "sid"
    entry is returned.
    """
    # NOTE(review): `r_sid` is not defined or imported in the visible part of
    # this file — presumably an R package handle (e.g. via rpy2) set up
    # elsewhere; confirm, otherwise this raises NameError when called.
    res = r_sid.structIntervDist(W != 0, W_est != 0)
    return res.rx2("sid")[0]


def is_dag(W):
    """Return whether the graph given by the weighted adjacency matrix W is a DAG."""
    adjacency_rows = W.tolist()
    graph = ig.Graph.Weighted_Adjacency(adjacency_rows)
    return graph.is_dag()


def simulate_parameter(B, w_ranges=((-2.0, -0.5), (0.5, 2.0))):
    """Draw random SEM edge weights for a DAG.

    Every entry is assigned one of the weight ranges uniformly at random and
    receives a uniform draw from that range, multiplied by the matching
    entry of B. Uses the global numpy RNG.

    Args:
        B (np.ndarray): [d, d] binary adj matrix of DAG
        w_ranges (tuple): disjoint weight ranges

    Returns:
        W (np.ndarray): [d, d] weighted adj matrix of DAG
    """
    shape = B.shape
    weights = np.zeros(shape)
    # index of the weight range used for each entry
    range_choice = np.random.randint(len(w_ranges), size=shape)
    for range_idx, bounds in enumerate(w_ranges):
        low, high = bounds
        draws = np.random.uniform(low=low, high=high, size=shape)
        weights += B * (range_choice == range_idx) * draws
    return weights


def simulate_dag(opt):
    """Simulate random DAG with some expected number of edges.

    Uses the global numpy RNG (and igraph's generators for ER/SF/BP).

    Args:
        opt: object providing
            n_nodes (int): num of nodes d
            x: expected number of edges per node; the target edge count is
                int(x * n_nodes)
            graph_type (str): ER, SF, BP, fc, chain
            edge_type (str): 'fixed' or 'dynamic'; 'dynamic' is only
                implemented for ER (chain ignores it)

    Returns:
        B (np.ndarray): [d, d] binary adj matrix of DAG, randomly permuted

    Raises:
        ValueError: unknown graph type, or unknown edge type for ER
        AssertionError: edge_type is not 'fixed' for a graph type other
            than ER/chain, or the generated graph is not a DAG
    """
    d = opt.n_nodes
    s0 = int(opt.x * opt.n_nodes)

    def _random_permutation(M):
        # np.random.permutation permutes first axis only
        P = np.random.permutation(np.eye(M.shape[0]))
        return P.T @ M @ P

    def _random_acyclic_orientation(B_und):
        # keeping only the strict lower triangle of a permuted undirected
        # graph orients every edge consistently, hence acyclically
        return np.tril(_random_permutation(B_und), k=-1)

    def _graph_to_adjmat(G):
        return np.array(G.get_adjacency().data)

    if opt.graph_type not in ["ER", "chain"]:
        assert opt.edge_type == "fixed"

    if opt.graph_type == "ER":
        if opt.edge_type == "fixed":
            # Erdos-Renyi with exactly s0 edges
            s0 = min(s0, int((d**2 - d) / 2))  # fully connect where too many edges
            G_und = ig.Graph.Erdos_Renyi(n=d, m=s0)
            B_und = _graph_to_adjmat(G_und)
        elif opt.edge_type == "dynamic":
            # each possible edge is included independently with probability p
            p = min(2 * opt.x / (d - 1), 1)
            edges = np.random.choice(a=[0, 1], size=int((d * d - d) / 2), p=[1 - p, p])
            B_und = np.zeros(shape=(d, d))
            ind = np.triu_indices_from(B_und, k=1)
            B_und[ind] = edges
            B_und += B_und.T
        else:
            # previously fell through to an UnboundLocalError on B_und
            raise ValueError("unknown edge type")
        B = _random_acyclic_orientation(B_und)
    elif opt.graph_type == "SF":
        # Scale-free, Barabasi-Albert
        G = ig.Graph.Barabasi(n=d, m=int(round(s0 / d)), directed=True)
        B = _graph_to_adjmat(G)
    elif opt.graph_type == "BP":
        # Bipartite, Sec 4.1 of (Gu, Fu, Zhou, 2018)
        top = int(0.2 * d)
        G = ig.Graph.Random_Bipartite(top, d - top, m=s0, directed=True, neimode=ig.OUT)
        B = _graph_to_adjmat(G)
    elif opt.graph_type == "fc":
        # fully connected: every pair i < j gets an edge
        B = np.triu(np.ones(shape=(d, d)), k=1)
    elif opt.graph_type == "chain":
        # single path 0 -> 1 -> ... -> d-1
        B = np.diag(np.ones(d - 1), 1)
    else:
        raise ValueError("unknown graph type")
    # relabel nodes so the causal order is not given away by the indices
    B_perm = _random_permutation(B)
    assert ig.Graph.Adjacency(B_perm.tolist()).is_dag()
    return B_perm


def count_accuracy(B_true, B_est):
    """Compute various accuracy metrics for B_est.

    true positive = predicted association exists in condition in correct direction
    reverse = predicted association exists in condition in opposite direction
    false positive = predicted association does not exist in condition

    Args:
        B_true (np.ndarray): [d, d] ground truth graph, {0, 1}
        B_est (np.ndarray): [d, d] estimate, {0, 1}; must be a DAG
            (CPDAG estimates containing -1 are rejected)

    Returns:
        dict with
        fdr: (reverse + false positive) / prediction positive
        tpr: (true positive) / condition positive
        fpr: (reverse + false positive) / condition negative
        shd: undirected extra + undirected missing + reverse
        nnz: prediction positive

    Raises:
        ValueError: if B_est contains -1, contains values other than 0/1,
            or is not a DAG
    """
    if (B_est == -1).any():  # cpdag
        raise ValueError("We do not want this utils function to act on CPDAGs")
    if not ((B_est == 0) | (B_est == 1)).all():
        raise ValueError("B_est should take value in {0,1}")
    if not is_dag(B_est):
        raise ValueError("B_est should be a DAG")
    d = B_true.shape[0]
    # linear index of nonzeros
    pred = np.flatnonzero(B_est == 1)
    cond = np.flatnonzero(B_true)
    cond_reversed = np.flatnonzero(B_true.T)
    cond_skeleton = np.concatenate([cond, cond_reversed])
    # true pos: predicted edges that exist with the same direction
    true_pos = np.intersect1d(pred, cond, assume_unique=True)
    # false pos: predicted edges absent from the true skeleton
    false_pos = np.setdiff1d(pred, cond_skeleton, assume_unique=True)
    # reverse: predicted edges whose opposite direction is in the truth
    extra = np.setdiff1d(pred, cond, assume_unique=True)
    reverse = np.intersect1d(extra, cond_reversed, assume_unique=True)
    # compute ratio (denominators are clamped to 1 to avoid division by zero)
    pred_size = len(pred)
    cond_neg_size = 0.5 * d * (d - 1) - len(cond)
    fdr = float(len(reverse) + len(false_pos)) / max(pred_size, 1)
    tpr = float(len(true_pos)) / max(len(cond), 1)
    fpr = float(len(reverse) + len(false_pos)) / max(cond_neg_size, 1)
    # structural hamming distance, computed on the skeletons (lower triangle)
    pred_lower = np.flatnonzero(np.tril(B_est + B_est.T))
    cond_lower = np.flatnonzero(np.tril(B_true + B_true.T))
    extra_lower = np.setdiff1d(pred_lower, cond_lower, assume_unique=True)
    missing_lower = np.setdiff1d(cond_lower, pred_lower, assume_unique=True)
    shd = len(extra_lower) + len(missing_lower) + len(reverse)
    return {"fdr": fdr, "tpr": tpr, "fpr": fpr, "shd": shd, "nnz": pred_size}


def simulate_linear_sem(W, n, sem_type, noise_scale, harmonize):
    """Simulate samples from linear SEM with specified type of noise.

    Uses the global numpy RNG. Nodes are sampled in a topological order of W.

    Args:
        W (np.ndarray): [d, d] weighted adj matrix of DAG.
            NOTE: when ``harmonize`` is truthy, W is rescaled IN PLACE.
        n (int): num of samples, n=inf mimics population risk
        sem_type (str): gauss, exp, gumbel, uniform, logistic, poisson
        noise_scale (np.ndarray): scale parameter of additive noise; None
            means all ones, a scalar is broadcast to all d nodes
        harmonize: if truthy, divide each node's incoming weights by the
            norm of (incoming weights, 1) before sampling that node

    Returns:
        For finite n: tuple (X, W) with X (np.ndarray) the [n, d] sample
        matrix and W the (possibly rescaled) weight matrix.
        For n=inf (gauss only): just X, a [d, d] matrix.

    Raises:
        ValueError: noise_scale has the wrong length, W is not a DAG,
            sem_type is unknown, or n=inf with a non-gauss sem_type
    """

    def _simulate_single_equation(X, w, scale):
        """X: [n, num of parents], w: [num of parents], x: [n]"""
        if sem_type == "gauss":
            z = np.random.normal(scale=scale, size=n)
            x = X @ w + z
        elif sem_type == "exp":
            z = np.random.exponential(scale=scale, size=n)
            x = X @ w + z
        elif sem_type == "gumbel":
            a = np.sqrt(6) / np.pi * scale
            z = np.random.gumbel(scale=a, size=n)
            x = X @ w + z
        elif sem_type == "uniform":
            a = scale * np.sqrt(3)
            z = np.random.uniform(low=-a, high=a, size=n)
            x = X @ w + z
        elif sem_type == "logistic":
            # binary outcome; `scale` is not used for this type
            x = np.random.binomial(1, sigmoid(X @ w)) * 1.0
        elif sem_type == "poisson":
            # count outcome; `scale` is not used for this type
            x = np.random.poisson(np.exp(X @ w)) * 1.0
        else:
            raise ValueError("unknown sem type")
        return x

    d = W.shape[0]
    # normalise noise_scale to a length-d vector
    if noise_scale is None:
        scale_vec = np.ones(d)
    elif np.isscalar(noise_scale):
        scale_vec = noise_scale * np.ones(d)
    else:
        if len(noise_scale) != d:
            raise ValueError("noise scale must be a scalar or has length d")
        scale_vec = noise_scale
    if not is_dag(W):
        raise ValueError("W must be a DAG")
    if np.isinf(n):  # population risk for linear gauss SEM
        if sem_type == "gauss":
            # make 1/d X'X = true cov
            X = np.sqrt(d) * np.diag(scale_vec) @ np.linalg.inv(np.eye(d) - W)
            return X
        else:
            raise ValueError("population risk not available")
    # empirical risk
    G = ig.Graph.Weighted_Adjacency(W.tolist())
    ordered_vertices = G.topological_sorting()
    assert len(ordered_vertices) == d
    X = np.zeros([n, d])
    # parents are always sampled before their children
    for j in ordered_vertices:
        parents = G.neighbors(j, mode=ig.IN)
        scale = scale_vec[j]
        if harmonize:
            # uniform rescaling
            c = np.linalg.norm(np.concatenate((W[parents, j], np.array([1]))))
            # in-place update: the caller's W is modified as well
            W[parents, j] /= c
        X[:, j] = _simulate_single_equation(X[:, parents], W[parents, j], scale)
    return X, W
