import copy
from collections import OrderedDict

from tigramite.toymodels import structural_causal_processes as toys
import tigramite.plotting as tp
from matplotlib import pyplot as plt

import numpy as np
from utils import lin_f

from regime_model import RegimeModel



def add_regime_children_single_regime(links, regime_children, regime_indicator):
    """
    Attaches the regime indicator as a parent of every listed regime child.

    For each (variable, lag) pair in regime_children, the entry
    (regime_indicator, 0.5, "reg") is appended to links[variable]. The lag part of
    each pair is ignored. The links mapping is modified in place and also returned.

    Args:
    - links (dict): The causal links for each variable.
    - regime_children (list): (variable, lag) pairs whose variable gets the new parent.
    - regime_indicator (tuple): Index of the regime / context indicator and timelag.

    Returns:
    - dict: The same links object, with the regime-indicator links appended.
    """
    regime_link = (regime_indicator, 0.5, "reg")
    for child, _ignored_lag in regime_children:
        links[child].append(regime_link)
    return links


def add_regime_children_all_regimes(links, regime_children, regime_indicator):
    """
    Attaches the regime indicator as a parent of the regime children in every regime.

    For every regime index 0..len(links)-1 and every (variable, lag) pair in
    regime_children, the entry (regime_indicator, 0.5, 'regime') is appended to
    links[regime][variable]. The lag part of each pair is ignored. The structure is
    modified in place and also returned.

    Args:
    - links (dict or list): Per-regime link mappings indexed 0..nb_regimes-1; each
            maps a variable to its list of parent links.
    - regime_children (list of tuple): (variable, lag) pairs whose variable receives
            the regime indicator as an additional parent in each regime.
    - regime_indicator (tuple): (variable, lag) identifying the regime indicator.

    Returns:
        The same links object with the regime-indicator parent links appended.
    """
    regime_idx = 0
    while regime_idx < len(links):
        regime_links = links[regime_idx]
        for child, _ignored_lag in regime_children:
            regime_links[child].append((regime_indicator, 0.5, 'regime'))
        regime_idx += 1

    return links


def check_links(links):
    """
    Validates the links and returns a causal order of their contemporaneous graph.

    Only contemporaneous (lag 0) links between distinct variables become edges of the
    graph that is checked for cycles; lagged links and self-links are validated but
    do not constrain the order.

    Args:
    - links (dict): The causal links for each variable, keyed 0..N-1; each entry is
      ((parent, lag), coeff, ...).

    Returns:
    - list or None: The topological order produced by toys._Graph.topologicalSort if
      the contemporaneous graph is acyclic; otherwise, None.

    Raises:
    - ValueError: If a parent index is outside 0..N-1, a coefficient's type name does
      not contain 'float', or a lag is positive.
    """
    N = len(links.keys())
    contemp_dag = toys._Graph(N)
    for j in range(N):
        for link_props in links[j]:
            var, lag = link_props[0]
            coeff = link_props[1]
            if var not in range(N):
                raise ValueError("var must be in 0..{}.".format(N - 1))
            # Type-name check so that numpy floats (e.g. np.float64) are accepted too
            if 'float' not in str(type(coeff)):
                raise ValueError("coeff must be float.")
            if lag > 0:
                raise ValueError("lag must be non-positive int.")

            # Only contemporaneous links between distinct variables constrain the order
            if var != j and lag == 0:
                contemp_dag.addEdge(var, j)

    if contemp_dag.isCyclic() == 1:
        return None
    return contemp_dag.topologicalSort()


def clean_regime_children(regime_children):
    """
    Cleans the list of regime children by removing duplicates and those present in every regime.

    A child that occurs in every per-regime list is dropped. The remaining children are
    flattened and de-duplicated, keeping the order of first appearance.

    Parameters:
    - regime_children (list): One list of (hashable) children per regime.

    Returns:
    - list: Cleaned list of regime children.
    """
    # Flatten the per-regime lists (a comprehension avoids the quadratic sum(lists, []))
    flat_children = [child for children_i in regime_children for child in children_i]

    # Remove children that appear in each regime. Membership must be tested per
    # individual child; iterating over regime_children itself would test whole lists.
    in_every_regime = [
        child for child in flat_children
        if all(child in children_i for children_i in regime_children)
    ]

    kept = [child for child in flat_children if child not in in_every_regime]
    return list(OrderedDict.fromkeys(kept))

def unionize_links(links, regime_children=None, nb_regimes=2, regime_indicator=None):
    """
    Combines the links from multiple regimes into a union graph.

    For each variable, the parent links of regimes 0..nb_regimes-1 are concatenated and
    exact duplicates removed, keeping the order of first appearance so that the result
    is deterministic across runs.

    Parameters:
    - links (list or dict): Per-regime link mappings indexed 0..nb_regimes-1.
    - regime_children (list): List of regime children to include in the union graph.
    - nb_regimes (int): Number of regimes.
    - regime_indicator (tuple):  Index of the regime / context indicator.

    Returns:
    - dict: The union graph combining links from all regimes. The input links are not
      modified.
    """
    # links are assumed not to differ in their mechanisms between regimes
    N = len(links[0])
    union_graph = {}
    for var in range(N):
        merged = [link for i in range(nb_regimes) for link in links[i][var]]
        # dict.fromkeys de-duplicates while preserving order; a set would order the
        # entries by hash, which is id-based for the mechanism functions and so varies
        # from run to run.
        union_graph[var] = list(dict.fromkeys(merged))

    # add regime children (the per-variable lists above are fresh, so appending to
    # them cannot affect the input links)
    if regime_children is not None:
        union_graph = add_regime_children_single_regime(union_graph, regime_children, regime_indicator)
    return union_graph

def process_graph_entries(entries):
    """
    Merges the per-regime entries of one graph cell into a union entry.

    Empty strings (no link) are ignored. If the remaining entries all agree, that
    entry is returned. Opposite directed links ('<--' and '-->') merge to '<->'.
    Any other mix that contains 'o-o' yields 'x-x' if 'x-x' is also present and
    'o-o' otherwise. Every remaining case yields ''.

    Parameters:
    - entries (list): List of graph entries to process.

    Returns:
    - str: Processed value for the graph entries.
    """
    present = [entry for entry in entries if entry != '']
    if not present:
        return ''

    first = present[0]
    if all(entry == first for entry in present):
        return first

    distinct = set(present)
    if distinct == {'<--', '-->'}:
        return '<->'
    if 'o-o' in distinct:
        return 'x-x' if 'x-x' in distinct else 'o-o'
    return ''

def unionize_graphs(graphs, nb_regimes):
    """
    Builds the union of the first nb_regimes graphs, cell by cell.

    Every cell (i, j, lag) of the result is the merge, via process_graph_entries, of
    that cell across graphs[0..nb_regimes-1].

    Parameters:
    - graphs (list): 3-D string arrays, one per regime; graphs[0] fixes the shape.
    - nb_regimes (int): Number of regimes (graphs) to combine.

    Returns:
    - np.ndarray: The unionized graph with dtype '<U3'.
    """
    # The second axis length drives both variable loops (graphs are square in N).
    _, n_vars, tau_max = graphs[0].shape
    union_graph = np.zeros(graphs[0].shape, dtype='<U3')
    for i, j, lag in np.ndindex(n_vars, n_vars, tau_max):
        cell_entries = [graphs[k][i, j, lag] for k in range(nb_regimes)]
        union_graph[i, j, lag] = process_graph_entries(cell_entries)

    return union_graph

def process_intersection_graph_entries(entries):
    """
    Processes entries in the graph for the PC-B method.

    A link survives the intersection only if every graph has one (no entry is '').
    Agreeing entries are kept. Opposite directed links ('<--' and '-->') merge to
    '<->'. A mix containing 'o-o' yields 'x-x' if 'x-x' is also present and 'o-o'
    otherwise. Any other conflict yields ''. This mirrors process_graph_entries.

    Parameters:
    - entries (list): List of graph entries to process.

    Returns:
    - str: Processed value for the graph entries (at most 3 characters, so it fits the
      '<U3' graph arrays without truncation).
    """
    # len() rather than truthiness so numpy arrays of entries are accepted too
    if len(entries) == 0 or any(entry == '' for entry in entries):
        return ''

    # Check if all entries in the list are the same
    first = entries[0]
    if all(entry == first for entry in entries):
        return first

    distinct = set(entries)
    # '<->' (not '<-->'): a 4-character value would be truncated to '<--' when
    # stored in the '<U3' result array.
    if distinct == {'<--', '-->'}:
        return '<->'

    if 'o-o' in distinct:
        return 'x-x' if 'x-x' in distinct else 'o-o'

    return ''
        
def intersect_graphs(graphs, regime_indicator):
    """
    Intersects the regime graphs to the pooled data graphs - the implementation of the B-PCMCI method.

    Every cell (i, j, lag) that does not involve the regime indicator is combined across
    all graphs with process_intersection_graph_entries. The row and the column of the
    regime indicator at its lag are instead copied verbatim from graphs[1].

    Parameters:
    - graphs (list): 3-D string arrays; graphs[0] fixes the shape (unpacked as
      N, N, tau_max). NOTE(review): graphs[1] is hard-coded as the source of the regime
      indicator's row/column — presumably a specific (pooled?) graph; confirm with caller.
    - regime_indicator (tuple): Regime indicator as (variable index, lag). It must be a
      tuple: a list never compares equal to (i, lag), so its cells would be intersected
      too. NOTE(review): the lag is compared with the non-negative lag *index* of the
      array, so an indicator stored with a negative lag (e.g. (j, -2)) never matches.
      Confirm the caller passes a non-negative lag.

    Returns:
    - np.ndarray: The intersected graph with dtype '<U3' (longer strings are truncated).
    """
    # Second N shadows the first, i.e. N == graphs[0].shape[1] (graphs assumed square)
    N, N, tau_max = graphs[0].shape
    intersection_graph = np.zeros(graphs[0].shape, dtype='<U3')
    pos_contemp = []  # unused
    for i in range(N):
        for j in range(N):
            for lag in range(tau_max):
                # Cells not touching the regime indicator: intersect across all graphs
                if (i, lag) != regime_indicator and (j, lag) != regime_indicator:
                    graph_entries_list = [graphs[k][i,j,lag] for k in range(len(graphs))]
                    intersection_graph[i, j, lag] = process_intersection_graph_entries(graph_entries_list)
                # Regime indicator row / column at its lag: copied whole from graphs[1]
                # (re-executed for every matching cell; the copies are identical)
                if i == regime_indicator[0] and lag == regime_indicator[1]:
                    intersection_graph[regime_indicator[0], :, lag] = graphs[1][regime_indicator[0], :, lag]
                if j == regime_indicator[0] and lag == regime_indicator[1]:
                    intersection_graph[:, regime_indicator[0], lag] = graphs[1][:, regime_indicator[0], lag]

    return intersection_graph


def _regime_indicator_candidates(links_base, max_lag, regime_endo, contemp_context):
    """
    Collects the (variable, lag) pairs eligible to serve as the regime indicator.

    Parameters:
    - links_base (dict): Base links mapping each variable to its list of
      ((parent, lag), coeff, func) entries.
    - max_lag (int): Maximum lag in the model.
    - regime_endo (bool): If True, collect endogenous candidates, i.e. variables with at
      least one lagged non-self parent (and, with contemp_context, also at least one
      contemporaneous one). Otherwise collect exogenous candidates, i.e. variables
      without any non-self parent.
    - contemp_context (bool): Restricts exogenous candidates to lag 0.

    Returns:
    - list: Candidate (variable, lag) tuples in the order they were found.
    """
    candidates = []
    for j, links in links_base.items():
        # collect (parent, lag) for non-self parents of j
        nonself_parents = [(p, lag) for ((p, lag), _, _) in links if p != j]

        if regime_endo:
            lagged_parents = {p for (p, lag) in nonself_parents if lag < 0}
            if contemp_context:
                cont_parents = {p for (p, lag) in nonself_parents if lag == 0}
                eligible = len(cont_parents) >= 1 and len(lagged_parents) >= 1
            else:
                eligible = len(lagged_parents) >= 1
            if eligible:
                # The drawn lag is unused, but the draw advances the global NumPy RNG;
                # it is kept so that the later regime-indicator choice stays reproducible.
                np.random.choice(np.arange(max_lag + 1))
                candidates.append((j, 0))
        elif len(nonself_parents) == 0:
            if contemp_context:
                candidates.append((j, 0))
            else:
                candidates.extend((j, -lag) for lag in range(max_lag + 1))
    return candidates


def generate_regime_model(N, density, child_seeds, nb_regimes, max_lag, nb_changed_links, 
                          remove_only=False, cycles_only=False,
                          dep_coeffs=None, auto_coeffs=None, contemp_fraction=0.0,
                          regime_autocorr=0.7, regime_endo=True, regime_indicator=None, contemp_context=False):
    """
    Generates a regime model with the specified parameters.

    A base graph is drawn with tigramite, a regime indicator is chosen (unless given),
    and nb_regimes - 1 modified regimes are derived via get_links_i. The model is
    accepted if all regimes differ and the base graph with its regime links is
    contemporaneously acyclic. In addition, the union graph must be acyclic, or, when
    cycles_only, cyclic. Up to 10 trials are made.

    Parameters:
    - N (int): Number of variables in the model.
    - density (float): Density of the links in the model.
    - child_seeds (list): List of seeds; [-2] and [-1] seed the base graph and noise,
      [1..nb_regimes-1] seed the per-regime edits.
    - nb_regimes (int): Number of regimes.
    - max_lag (int): Maximum lag in the model.
    - nb_changed_links (int): Number of links to change.
    - remove_only (bool): If True, only remove links; otherwise, add and flip links as well.
    - cycles_only (bool): If True, only generate models with cycles (forced to False
      when max_lag > 0).
    - dep_coeffs (list, optional): Coefficients for dependencies;
      defaults to [-0.1, -0.2, 0.1, 0.2].
    - auto_coeffs (list, optional): Coefficients for self-dependencies; defaults to
      [-0.1, 0.1]. Does not apply to the non-timeseries case.
    - contemp_fraction (float): Fraction of contemporaneous links in the model (forced
      to 1 when max_lag == 0).
    - regime_autocorr (float): Autocorrelation coefficient for the regime indicator.
    - regime_endo (bool): If True, the regime indicator is endogenous; otherwise, it is exogenous.
    - regime_indicator (tuple, optional): (variable, lag) of the regime indicator; chosen
      at random among eligible variables when None.
    - contemp_context (bool): If True, the regime indicator must have at least one
      contemporaneous parent (endogenous case) or is placed at lag 0 (exogenous case).

    Returns:
    - tuple: (links_base, links_joint, links_with_regime_children, regime_children,
      causal_order, regime_indicator), or six Nones if no valid model was found.
    """
    failure = (None, None, None, None, None, None)
    max_trials = 10

    # None sentinels instead of mutable list defaults (same default values)
    if dep_coeffs is None:
        dep_coeffs = [-0.1, -0.2, 0.1, 0.2]
    if auto_coeffs is None:
        auto_coeffs = [-0.1, 0.1]

    L = int(density * (N * (N - 1) / 2))

    if max_lag > 0:
        cycles_only = False

    if max_lag == 0:
        contemp_fraction = 1.

    for _trial in range(max_trials):
        # generate basic link dictionary that then is modified per regime
        base_seed = np.random.MT19937(child_seeds[-2])
        noise_seed = np.random.MT19937(child_seeds[-1])
        links_base, _ = toys.generate_structural_causal_process(N=N, L=L, contemp_fraction=contemp_fraction, max_lag=max_lag,
                                                                dependency_coeffs=dep_coeffs,
                                                                auto_coeffs=auto_coeffs,
                                                                noise_seed=noise_seed,
                                                                seed=base_seed)
        if links_base is None:
            # Counts as a trial; previously this retried forever with identical seeds.
            continue

        if regime_indicator is None:
            candidates = _regime_indicator_candidates(links_base, max_lag, regime_endo, contemp_context)
            if candidates == []:
                return failure
            # select regime
            regime_indicator = candidates[np.random.choice(list(range(len(candidates))))]

        # NOTE(review): overwrites the first link of the regime indicator, presumably its
        # auto-dependency; raises IndexError if it has no links — confirm for max_lag == 0.
        links_base[regime_indicator[0]][0] = ((regime_indicator[0], -1), regime_autocorr, lin_f)

        links_joint = {0: links_base}
        regime_children_by_regime = {}
        causal_order = []
        duplicate_regime = False

        for i in range(1, nb_regimes):
            links_i, regime_children_i, causal_order_i = get_links_i(links_base, i, child_seeds, regime_indicator,
                                                                    max_lag, nb_changed_links, 
                                                                    remove_only=remove_only)
            if causal_order_i is None:
                print('regime not valid')
                return failure

            if links_i == links_base:
                print('regime equal to base')
                return failure

            if any(links_i == existing for existing in links_joint.values()):
                duplicate_regime = True
            links_joint[i] = links_i
            regime_children_by_regime[i] = regime_children_i
            causal_order.append(causal_order_i)

        if duplicate_regime:
            # identical regimes: retry (get_links_i draws fresh seeds on each call)
            continue

        # The leading [] stands for the base regime, which has no regime children
        regime_children = clean_regime_children([[]] + list(regime_children_by_regime.values()))

        links_with_regime_children = add_regime_children_single_regime(copy.deepcopy(links_base), regime_children,
                                                                    regime_indicator)
        causal_order_base = check_links(links_with_regime_children)
        if causal_order_base is None:
            print('no base without cycles found, redraw base')
            continue

        causal_order.insert(0, causal_order_base)

        # check whether the union of all regimes has a (contemporaneous) cycle
        union_graph = unionize_links(links_joint, regime_children, nb_regimes=nb_regimes, regime_indicator=regime_indicator)
        union_is_acyclic = check_links(union_graph) is not None

        if not cycles_only and not union_is_acyclic:
            print('union has cycles')
            return failure
        if cycles_only and union_is_acyclic:
            print('union does not have cycles')
            return failure

        links_with_regime_children = add_regime_children_all_regimes(copy.deepcopy(links_joint), regime_children,
                                                                    regime_indicator)
        return links_base, links_joint, links_with_regime_children, regime_children, causal_order, regime_indicator

    # All trials exhausted; previously this fell through and crashed on None links.
    return failure

def generate_regime_data(T, links_joint, nb_regimes, child_seeds, regime_indicator, max_lag, causal_order, regime_thresholds=None, imbalance_factor=None):
    """
    Generates data based on the regime model.

    Parameters:
    - T (int): Number of time steps for the data.
    - links_joint (dict): Links for each variable in the joint regime model.
    - nb_regimes (int): Number of regimes (not used by this function).
    - child_seeds (list): List of seeds; child_seeds[-3] seeds the data generator.
    - regime_indicator (tuple): Regime indicator; its first element is the column index.
    - max_lag (int): Maximum lag in the model.
    - causal_order (list): Causal order for the variables.
    - regime_thresholds (list, optional): Thresholds for defining regimes.
    - imbalance_factor (float, optional): Factor to adjust the imbalance when computing thresholds.

    Returns:
    - tuple: (data, data_type), where data_type is an int array shaped like data that
      is 1 in the regime indicator's column and 0 elsewhere; (None, None) if the model
      produced no data.
    """
    model = RegimeModel(
        links=links_joint,
        noises=None,
        seed=np.random.MT19937(child_seeds[-3]),
        causal_order=causal_order,
        regime_thresholds=regime_thresholds,
        extreme_indicator=regime_indicator,
        max_lag=max_lag,
        imbalance_factor=imbalance_factor,
    )

    data = model.generate_data(T=T)
    if data is None:
        return None, None

    # Mark the regime indicator's column in the per-entry type mask
    data_type = np.zeros(data.shape, dtype=int)
    data_type[:, regime_indicator[0]] = 1
    return data, data_type

def edit_once(links_base, rand_var, system_vars, regime_indicator, max_lag, method, random_state):
    """
    Applies one random edit ('remove' or 'add') to a copy of the links.

    'remove' deletes one random link whose parent is neither the regime indicator nor
    the child itself, from a random non-indicator variable (rand_var is not used).
    'add' gives rand_var a new parent drawn from system_vars at a random lag in
    0..max_lag, with a non-zero coefficient from {-0.5, ..., -0.1, 0.1, ..., 0.4}.
    Any other method (e.g. 'flip') is not implemented and reports failure, so callers
    never count an unchanged graph as an edit.

    Parameters:
    - links_base (dict): Base links to be edited (not modified).
    - rand_var (int): Variable that receives the new parent for 'add'.
    - system_vars (list): List of system variables (candidate parents for 'add').
    - regime_indicator (tuple): Specifies which variable and at what lag is considered as extreme.
    - max_lag (int): Maximum lag in the model.
    - method (str): Method to edit the links ('remove', 'add').
    - random_state (np.random.Generator): Random state for generating random values.

    Returns:
    - tuple: (updated links, [(changed variable, 0)]), or (None, None) if the edit is not
      possible or the method is unsupported.
    """
    links_i = copy.deepcopy(links_base)
    new_regime_children = []

    if method == 'remove':
        links_to_pick = dict()
        for var in range(len(links_i)):
            if var == regime_indicator[0]:
                continue
            # Exclude links coming from the regime indicator and self-links
            # (autocorrelations are never removed).
            links_of_var = [link for link in links_i[var]
                            if link[0][0] != regime_indicator[0] and link[0][0] != var]
            if len(links_of_var) > 0:
                links_to_pick[var] = links_of_var

        if len(links_to_pick) == 0:
            return None, None

        vars_to_pick = list(links_to_pick.keys())
        rand_child = random_state.choice(vars_to_pick)
        chosen_link_no = random_state.choice([l for l in range(len(links_to_pick[rand_child]))])
        chosen_link = links_to_pick[rand_child][chosen_link_no]
        links_i[rand_child].remove(chosen_link)
        new_regime_children.append((rand_child, 0))

    elif method == 'add':
        rand_parent = system_vars[random_state.integers(len(system_vars))]
        rand_lag = random_state.integers(max_lag + 1)

        already_parent = (rand_parent, -rand_lag) in [item[0] for item in links_i[rand_var]]
        contemp_self_loop = rand_lag == 0 and rand_parent == rand_var
        if already_parent or contemp_self_loop:
            return None, None

        # np.arange(-0.5, 0.5, 0.1) holds ~5.6e-17 instead of 0; rounding and dropping
        # zero ensures an added link actually carries an effect.
        coeff_grid = np.round(np.arange(-0.5, 0.5, step=0.1), 1)
        rand_coeff = random_state.choice(coeff_grid[coeff_grid != 0])
        links_i[rand_var].append(((rand_parent, -rand_lag), rand_coeff, lin_f))
        new_regime_children.append((rand_var, 0))

    else:
        # Unsupported method (e.g. 'flip'): report failure instead of a no-op "edit"
        return None, None

    return links_i, new_regime_children

def add_item(item, items):
    """
    Appends an item to a list unless the list already contains it.

    The list is modified in place and also returned.

    Parameters:
    - item: The item to be added.
    - items (list): The list to add the item to.

    Returns:
    - list: The same list object, now containing item.
    """
    if item in items:
        return items
    items.append(item)
    return items

def get_links_i(links_base, regime, child_seeds, regime_indicator, max_lag,
                nb_changed_links=1, remove_only=False):
    """
    Gets the links for a specific regime by modifying the base links.

    Up to 5 independent attempts are made, each with its own generator seeded from
    child_seeds[regime]. An attempt applies nb_changed_links edits via edit_once. Each
    edit must change something, must touch a variable not edited before in this
    attempt, and must keep the edited links acyclic in their contemporaneous part,
    including the regime-indicator links to every edited variable. Each edit gets up to
    10 tries.

    Parameters:
    - links_base (dict): Base links for the model (not modified).
    - regime (int): The specific regime to modify links for.
    - child_seeds (list): List of seed sequences; child_seeds[regime] is spawned from.
    - regime_indicator (tuple): Regime indicator
    - max_lag (int): Maximum lag in the model.
    - nb_changed_links (int): Number of links to change.
    - remove_only (bool): If True, only remove links; otherwise, add and flip links as well.

    Returns:
    - tuple: (modified links, regime children as (var, 0) pairs, causal order of the
      modified links including their regime-indicator links), or (None, None, None) if
      no attempt succeeds. The causal order is None when nb_changed_links is 0.
    """
    # nb_changed_links is the minimum difference between graphs,
    # maximal difference btw. two regime-specific graphs is nb_regimes*nb_changed_links
    # Spawning a fixed 10000 children keeps the SeedSequence spawn counter (and hence
    # the seeds of later spawns on the same sequence) unchanged.
    trial_seeds = child_seeds[regime].spawn(10000)

    max_attempts = 5
    N = len(links_base)
    all_methods = ["remove"] if remove_only else ["remove", "add", "flip"]
    system_vars = [var for var in range(N) if var != regime_indicator[0]]
    if not system_vars:
        return None, None, None

    for attempt in range(max_attempts):
        random_state = np.random.default_rng(trial_seeds[attempt])
        links_i = copy.deepcopy(links_base)
        regime_children = []
        causal_order_i = None
        success = True

        for _ in range(nb_changed_links):
            success_edit = False
            for _ in range(10):  # try up to 10 times to get one valid edit
                rand_var = system_vars[random_state.integers(len(system_vars))]

                # only do possible operations; start from a fresh list on every try so
                # restrictions for one variable do not leak into later tries
                methods = list(all_methods)
                if len(links_i[rand_var]) == 0 and "remove" in methods:
                    methods.remove("remove")
                parents = [item for item in links_i[rand_var] if item[0] != regime_indicator]
                if len(parents) == 0 and "flip" in methods:
                    methods.remove("flip")
                if not methods:
                    continue

                rand_method = methods[random_state.integers(len(methods))]
                updated_links, new_children = edit_once(
                    links_i, rand_var, system_vars, regime_indicator, max_lag, rand_method, random_state)

                # Skip failed or no-op edits, and variables already edited in this
                # attempt (re-editing could undo an earlier change).
                if updated_links is None or not new_children:
                    continue
                if any(child in regime_children for child in new_children):
                    continue

                # Check acyclicity with the regime links to ALL edited variables
                candidate_links = add_regime_children_single_regime(
                    copy.deepcopy(updated_links), regime_children + new_children, regime_indicator)
                candidate_order = check_links(candidate_links)

                if candidate_order is not None:
                    links_i = updated_links
                    regime_children += new_children
                    causal_order_i = candidate_order
                    success_edit = True
                    break

            if not success_edit:
                success = False
                break

        if success:
            return links_i, regime_children, causal_order_i

    return None, None, None
    