from tigramite.toymodels import structural_causal_processes as toys
import numpy as np
import matplotlib.pyplot as plt
import tigramite.plotting as tp


def lin_f(x):
    """Identity function: return ``x`` unchanged (a linear link function)."""
    return x


def links_to_cyclic_graph(links, tau_max=None):
    """Convert a dictionary of links to a string-valued graph array.

    Contemporaneous (lag-0) links that are present in both directions
    (``i --> j`` and ``j --> i``) are marked as bi-directed ``'<->'``
    edges, so cyclic contemporaneous structure can be represented.

    Parameters
    ----------
    links : dict
        Dictionary of form {0:[((0, -1), coeff, func), ...], 1:[...], ...}.
        Also format {0:[(0, -1), ...], 1:[...], ...} is allowed. In the
        first format, entries with ``coeff == 0`` are ignored.
    tau_max : int or None
        Maximum lag. If None, the maximum lag in links is used.

    Returns
    -------
    graph : numpy.ndarray of shape (N, N, tau_max+1), dtype '<U3'
        Edge-mark array, with ``''`` where there is no link. An entry
        ``(var, lag)`` listed in ``links[j]`` sets
        ``graph[var, j, abs(lag)] = '-->'``. Lag-0 entries are mirrored:
        a one-directional link also sets ``graph[j, var, 0] = '<--'``.
        A link present in both directions sets both entries to ``'<->'``.

    Raises
    ------
    ValueError
        If ``tau_max`` is smaller than the largest lag found in ``links``.
    """
    N = len(links)

    # Only the maximum lag is needed here.
    _, max_lag = toys._get_minmax_lag(links)

    # Set maximum lag
    if tau_max is None:
        tau_max = max_lag
    elif max_lag > tau_max:
        raise ValueError("tau_max is smaller than maximum lag = %d "
                         "found in links, use tau_max=None or larger "
                         "value" % max_lag)

    graph = np.zeros((N, N, tau_max + 1), dtype='<U3')

    # (child, parent) pairs of contemporaneous links. A set gives O(1)
    # reverse-direction lookups below instead of scanning a list.
    contemp = set()
    for j, link_list in links.items():
        for link_props in link_list:
            if len(link_props) > 2:
                (var, lag), coeff = link_props[0], link_props[1]
                if coeff == 0.:
                    # Zero-coefficient links are not part of the graph.
                    continue
            else:
                var, lag = link_props
            graph[var, j, abs(lag)] = "-->"
            if lag == 0:
                contemp.add((j, var))

    # Post-processing: mark bi-directed edges and symmetrize contemp edges.
    # A one-directional pair never writes a cell that a bi-directed pair
    # writes, so iteration order does not matter.
    for j, var in contemp:
        if (var, j) in contemp:
            graph[var, j, 0] = '<->'
            graph[j, var, 0] = '<->'
        else:
            graph[j, var, 0] = '<--'

    return graph


def links_to_cyclic_graph_skeleton(links, tau_max=None):
    """Convert a dictionary of links to an undirected (skeleton) graph array.

    Every link is written with the circle-circle mark ``'o-o'``, so only
    adjacencies are kept and orientations are discarded.

    Parameters
    ----------
    links : dict
        Dictionary of form {0:[((0, -1), coeff, func), ...], 1:[...], ...}.
        Also format {0:[(0, -1), ...], 1:[...], ...} is allowed. In the
        first format, entries with ``coeff == 0`` are ignored.
    tau_max : int or None
        Maximum lag. If None, the maximum lag in links is used.

    Returns
    -------
    graph : numpy.ndarray of shape (N, N, tau_max+1), dtype '<U3'
        Edge-mark array, with ``''`` where there is no link. An entry
        ``(var, lag)`` listed in ``links[j]`` sets
        ``graph[var, j, abs(lag)] = 'o-o'``. Lag-0 entries also set the
        mirrored cell ``graph[j, var, 0] = 'o-o'``.

    Raises
    ------
    ValueError
        If ``tau_max`` is smaller than the largest lag found in ``links``.
    """
    N = len(links)

    # Only the maximum lag is needed here.
    _, max_lag = toys._get_minmax_lag(links)

    # Set maximum lag
    if tau_max is None:
        tau_max = max_lag
    elif max_lag > tau_max:
        raise ValueError("tau_max is smaller than maximum lag = %d "
                         "found in links, use tau_max=None or larger "
                         "value" % max_lag)

    graph = np.zeros((N, N, tau_max + 1), dtype='<U3')
    for j, link_list in links.items():
        for link_props in link_list:
            if len(link_props) > 2:
                (var, lag), coeff = link_props[0], link_props[1]
                if coeff == 0.:
                    # Zero-coefficient links are not part of the graph.
                    continue
            else:
                var, lag = link_props
            graph[var, j, abs(lag)] = "o-o"
            if lag == 0:
                # Contemporaneous adjacencies are symmetric: mirror them.
                graph[j, var, 0] = "o-o"

    return graph



def plot_masked_results(masked_results, nb_regimes=2, var_names=None):
    """Plot the causal graph of each regime in one row of panels.

    Parameters
    ----------
    masked_results : sequence of dict
        One result dict per regime. Each must provide ``'val_matrix'`` and
        ``'graph'``, which are forwarded to ``tigramite.plotting.plot_graph``.
    nb_regimes : int
        Number of regime panels to draw. ``masked_results`` must have at
        least this many entries.
    var_names : list of str or None
        Variable names forwarded to ``plot_graph``.
    """
    # squeeze=False keeps ``axes`` 2-D even for nb_regimes == 1; otherwise
    # plt.subplots returns a single Axes object that cannot be indexed.
    fig, axes = plt.subplots(1, nb_regimes, figsize=(20, 5), squeeze=False)
    for r, ax in enumerate(axes[0]):
        tp.plot_graph(
            val_matrix=masked_results[r]['val_matrix'],
            graph=masked_results[r]['graph'],
            var_names=var_names,
            link_colorbar_label='cross-MCI',
            node_colorbar_label='auto-MCI',
            show_autodependency_lags=False,
            figsize=(10, 10),
            node_size=0.2,
            fig_ax=(fig, ax)
        )


def _strip_timeseries_axis(ax, T):
    """Hide all spines except the left one, remove x ticks, set x range [0, T]."""
    for loc, spine in ax.spines.items():
        if loc != 'left':
            spine.set_color("none")
    ax.xaxis.set_ticks([])
    ax.set_xlim(0., T)


def plot_RPCMCI_results(pd_dataframe, vac_values, results):
    """Plot input time series, inferred regimes and per-regime causal graphs.

    The mosaic figure contains, from top to bottom:
    - one full-width panel per column of ``pd_dataframe``;
    - a panel with the inferred regime at each time step;
    - a panel with ``vac_values`` ("Expert regime GT");
    - one causal-graph panel per regime, side by side.

    The figure is displayed with ``plt.show()``.

    Parameters
    ----------
    pd_dataframe : pandas.DataFrame
        Numeric data of shape (T, N). Column names label the panels and the
        graph nodes. Entries equal to 999. are replaced by NaN before
        plotting; the caller's dataframe is not modified.
    vac_values : array_like of length T
        Expert regime labels plotted for comparison. Entries equal to 999.
        are replaced by NaN on a copy before plotting.
    results : dict
        Must contain two keys:
        - ``'regimes'``: an array whose argmax over axis 0 gives the regime
          index per time step (presumably shape (n_regimes, T); confirm
          against the RPCMCI output).
        - ``'causal_results'``: one dict per regime with ``'graph'`` and
          ``'val_matrix'`` entries for ``tigramite.plotting.plot_graph``.
    """
    # Float copy: NaN cannot be stored in an integer array, and the caller's
    # data must stay untouched. Done once, not once per plotted variable.
    data = np.array(pd_dataframe.to_numpy(), dtype=float)
    data[data == 999.] = np.nan
    T, N = data.shape
    datatime = np.arange(T)
    regimes = results['regimes'].argmax(axis=0)
    n_regimes = len(results['causal_results'])
    var_names = pd_dataframe.columns

    # Mosaic layout: N + 2 full-width time-series rows, followed by N rows
    # shared by the per-regime graph panels (repeated labels merge cells).
    mosaic = [['data %s' % j] * n_regimes for j in range(N + 2)]
    mosaic += [['graph %s' % i for i in range(n_regimes)] for _ in range(N)]

    fig, axs = plt.subplot_mosaic(mosaic=mosaic, figsize=(20, 10))

    # Input time series, one per panel.
    for j in range(N):
        ax = axs['data %s' % j]
        ax.axhline(0., color='grey')
        ax.plot(datatime, data[:, j])
        _strip_timeseries_axis(ax, T)
        ax.set_ylabel(var_names[j])

    # Regime variable
    ax = axs['data %s' % N]
    ax.plot(datatime, regimes.astype('int'), lw=3, color='black')
    _strip_timeseries_axis(ax, T)
    ax.yaxis.set_ticks(range(n_regimes))
    ax.set_ylabel("Regime")

    # Expert regime GT. np.array always copies, and unlike comparing a plain
    # list to 999. it yields an element-wise mask.
    expert = np.array(vac_values, dtype=float)
    expert[expert == 999.] = np.nan
    ax = axs['data %s' % (N + 1)]
    ax.plot(datatime, expert, lw=3, color='black')
    _strip_timeseries_axis(ax, T)
    ax.yaxis.set_ticks(range(n_regimes))
    ax.set_ylabel("Expert regime GT")

    # Causal graphs for each regime; the colorbar is drawn only once.
    for w in range(n_regimes):
        ax = axs['graph %s' % w]
        tp.plot_graph(graph=results['causal_results'][w]['graph'],
                      val_matrix=results['causal_results'][w]['val_matrix'],
                      show_colorbar=(w == 0),
                      var_names=var_names,
                      fig_ax=(fig, ax))
        ax.set_title("Regime %d" % w, pad=-4)

    fig.subplots_adjust(hspace=0.6)
    plt.show()



def count_persistence(data, regime_indicator):
    """Return the mean and maximum length of runs of constant values.

    The column ``data[:, regime_indicator[0]]`` is scanned along axis 0 and
    split into maximal runs of consecutive equal values (run-length
    encoding). NaN never compares equal, so each NaN forms its own run.

    Parameters
    ----------
    data : array_like, 2-D
        Data whose rows along axis 0 are consecutive samples.
    regime_indicator : sequence
        Its first element is the column index of ``data`` to analyse; it
        must select a single column.

    Returns
    -------
    mean_run : numpy.float64
        Mean length of the runs.
    max_run : numpy integer
        Length of the longest run.

    Raises
    ------
    ValueError
        If the selected column is not 1-D or has no samples.
    """
    series = np.asarray(data)[:, regime_indicator[0]]
    if series.ndim != 1:
        raise ValueError("regime_indicator[0] must select a single column "
                         "of data")
    if series.shape[0] == 0:
        raise ValueError("cannot compute persistence of an empty series")

    # Each position whose value differs from its predecessor starts a new run.
    run_starts = np.flatnonzero(series[1:] != series[:-1]) + 1
    boundaries = np.concatenate(([0], run_starts, [series.shape[0]]))
    run_lengths = np.diff(boundaries)

    return run_lengths.mean(), run_lengths.max()