import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
from typing import Tuple, List, Union, Sequence, Optional
import random
import math
from itertools import product, accumulate


def reorder_bucket_by_middle(layer_edges, rng=None):
    """Flatten per-layer edges and order them by a modular bucket key.

    Layers are walked from last to first (``depth`` 0 is the final layer).
    Each edge is keyed by an "anchor" node: the destination for even depths
    and the source for odd depths.  The key is ``anchor % num_buckets``.
    Ties are broken with ``rng.random()``.

    ``num_buckets`` is ``ceil(num_edges / bucket_size)`` (at least 1).
    ``bucket_size`` is ``int(r / 2 * widest_layer)`` with ``r`` drawn by
    ``rng.randint(1, 6)``.

    Args:
        layer_edges: Sequence of layers, each an iterable of ``(u, v)`` pairs.
        rng: Object exposing ``randint``/``random``; the ``random`` module is
            used when ``rng`` is falsy.

    Returns:
        A new list of ``(u, v)`` edges; ``layer_edges`` is not modified.

    Raises:
        ValueError: if ``layer_edges`` is empty (from ``max``).
    """
    if not rng:
        rng = random
    widest = max([len(edges) for edges in layer_edges])
    bucket_size = int(rng.randint(1, 6) / 2 * widest)

    flat = []
    for depth, edges in enumerate(reversed(layer_edges)):
        flat.extend((depth, src, dst) for src, dst in edges)

    num_buckets = max(1, math.ceil(len(flat) / max(1, bucket_size)))

    def bucket_key(triple):
        # The key is evaluated once per edge, in list order, so the
        # rng.random() tie-breakers are drawn in a fixed sequence.
        depth, src, dst = triple
        anchor = dst if depth % 2 == 0 else src
        return anchor % num_buckets, rng.random()

    ordered = sorted(flat, key=bucket_key)
    return [(src, dst) for _, src, dst in ordered]

def random_shuffle_layer(layer_edges, rng=None):
    """Concatenate all layers into one list and shuffle it uniformly.

    Args:
        layer_edges: Iterable of layers, each an iterable of edges.
        rng: Object with a ``shuffle`` method; defaults to the ``random``
            module when ``None``.

    Returns:
        A new shuffled list of every edge; the input layers are not modified.
    """
    shuffler = random if rng is None else rng
    pool = []
    for layer in layer_edges:
        pool.extend(layer)
    shuffler.shuffle(pool)
    return pool


def block_shuffle(layer_edges, rng=None):
    """Shuffle edges within each layer, then shuffle the order of the layers.

    Each layer's edges stay contiguous in the output.  Every random draw goes
    through ``rng``, so a seeded ``random.Random`` gives reproducible results.
    The original code drew the layer order from the global ``random`` module.

    Args:
        layer_edges: Iterable of layers, each an iterable of edges.
        rng: Object with ``shuffle``/``sample`` methods; defaults to the
            ``random`` module when ``None``.

    Returns:
        A new flat list of edges.  ``layer_edges`` and its inner lists are
        left untouched, which matches the other shuffle helpers.
    """
    rng = random if rng is None else rng
    shuffled_layers = []
    for edges in layer_edges:
        # Work on a copy so the caller's per-layer lists are not mutated.
        edges = list(edges)
        rng.shuffle(edges)
        shuffled_layers.append(edges)
    layer_order = rng.sample(range(len(shuffled_layers)), len(shuffled_layers))
    return [edge for idx in layer_order for edge in shuffled_layers[idx]]


def shuffle_edges(layer_edges, edge_shuffle_rule, rng=None):
    """Flatten and reorder ``layer_edges`` with the named strategy.

    Args:
        layer_edges: Sequence of layers, each a list of edges.
        edge_shuffle_rule: One of ``'random'`` (uniform shuffle),
            ``'bucketby'`` (modular bucket ordering) or ``'bylayer'``
            (shuffle within layers, then shuffle layer order).
        rng: Optional random source forwarded to the chosen strategy.

    Returns:
        The flat list of edges produced by the strategy.

    Raises:
        ValueError: if ``edge_shuffle_rule`` is not a known strategy.
    """
    # Thunks keep the lookup lazy: only the selected strategy is referenced.
    strategies = (
        ('random', lambda: random_shuffle_layer(layer_edges, rng)),
        ('bucketby', lambda: reorder_bucket_by_middle(layer_edges, rng)),
        ('bylayer', lambda: block_shuffle(layer_edges, rng)),
    )
    for name, run in strategies:
        if edge_shuffle_rule == name:
            return run()
    raise ValueError(f"Invalid edge shuffle rule: {edge_shuffle_rule}")

def fpar(v: int) -> bool:
    """Return ``True`` when ``v`` is even, ``False`` otherwise (always a bool)."""
    return not (v % 2)


def get_hard_rule(_path_nodes, node_rule_name):
    """Pick a single parity rule ('mod0' or 'mod1') for the whole path.

    The parity bit is taken from the first node of the last path layer (even
    -> True).  If ``node_rule_name`` contains ``'x'``, it is XOR-ed with the
    parity of the first node of layer 0.  A name containing ``'0'`` maps
    True -> 'mod0'.  Otherwise a name containing ``'1'`` maps True -> 'mod1'.

    Args:
        _path_nodes: Mapping/sequence indexable by ``0 .. len-1`` whose entries
            are lists of node ids.
        node_rule_name: Rule name such as ``'hard0'``, ``'hardx1'``.

    Returns:
        ``'mod0'`` or ``'mod1'``.

    Raises:
        ValueError: if the name contains neither ``'0'`` nor ``'1'``.
    """
    first = _path_nodes[0][0]
    last = _path_nodes[len(_path_nodes) - 1][0]
    parity = fpar(last)
    if 'x' in node_rule_name:
        parity = fpar(first) ^ parity

    if '0' in node_rule_name:
        pick_even = parity
    elif '1' in node_rule_name:
        pick_even = not parity
    else:
        raise ValueError(f"Invalid node rule name: {node_rule_name}")
    return 'mod0' if pick_even else 'mod1'

def get_shard_rule(_path_nodes, node_rule_name):
    """Assign alternating parity rules to every multi-node path layer.

    The starting parity comes from ``get_hard_rule`` with ``'hard0'`` (when
    ``node_rule_name`` contains ``'0'``) or ``'hard1'``.  Those names contain
    no ``'x'``, so only the last layer's first node is used.  Successive
    layers with more than one node then alternate between ``'mod0'`` and
    ``'mod1'``.  Single-node layers get ``'random'``.

    Args:
        _path_nodes: Dict of layer index -> list of node ids.  Its keys are
            used as indices into a list of length ``len(_path_nodes)``, so they
            are expected to be ``0 .. len-1``.
        node_rule_name: Rule name, e.g. ``'shard0'`` or ``'shard1'``.

    Returns:
        A list of rule names, one per path layer.
    """
    base_rule = 'hard0' if '0' in node_rule_name else 'hard1'
    first_parity = 1 if '1' in get_hard_rule(_path_nodes, base_rule) else 0
    rules = ['random'] * len(_path_nodes)
    branching = [idx for idx, nodes in _path_nodes.items() if len(nodes) > 1]
    for offset, idx in enumerate(branching):
        rules[idx] = f'mod{(first_parity + offset) % 2}'
    return rules


def get_vhard_rule(_path_nodes, node_rule_name):
    """Assign per-layer parity rules that depend on every node on the path.

    ``parity`` is the number of even node ids across all path layers, mod 2.
    Each layer with more than one node gets ``'mod{b}'``, where ``b`` is
    ``fpar(layer_idx) XOR parity``.  All other layers get ``'random'``.

    Args:
        _path_nodes: Dict of layer index -> list of node ids, indexed by
            ``0 .. len-1``.
        node_rule_name: Accepted for signature parity with the other rule
            builders; not used here.

    Returns:
        A list of rule names, one per path layer.
    """
    even_count = sum(fpar(node) for nodes in _path_nodes.values() for node in nodes)
    parity = even_count % 2
    rules = ['random'] * len(_path_nodes)
    for layer_idx in range(len(_path_nodes)):
        if len(_path_nodes[layer_idx]) > 1:
            rules[layer_idx] = f'mod{int(fpar(layer_idx) ^ parity)}'
    return rules



def create_combo_graph(
    layers: int,
    width: int,
    m: int,
    **kwargs,
) -> dict:
    """Generate a layered search problem and describe it as a plain dict.

    Node ids are distinct integers sampled from ``range(m)``.  The graph is a
    chain ``source -> layer 0 -> ... -> layer (layers-1) -> goal``.  Each
    intermediate layer holds ``width`` nodes.  Only its sampled "active"
    nodes have outgoing edges, and they connect to every node of the next
    layer.

    Args:
        layers: Number of intermediate layers.
        width: Number of nodes per intermediate layer.
        m: Size of the node-id universe; ids are sampled from ``range(m)``.
        **kwargs: Optional settings.  All of them are also copied into the
            returned dict.
            num_pass_layers / num_pass_width: ``num_pass_layers`` random
                layers get ``num_pass_width`` active nodes; the rest get 1.
            pass_per_layer: Explicit active-node counts, one per layer.
            p: Active fraction of ``width`` for every layer (default 0.5).
                Used when neither option above is given.
            node_rule: Per-layer selection policy.  Either a key of the rule
                table ('random', 'max', 'min', 'mod0', 'mod1') or a name
                containing 'shard', 'vhard' or 'hard' (plus '0'/'1'; plain
                'hard' names may also contain 'x').
            pad_layers, pad_edges, pad_style: Distractor edges.
                ``pad_style='append'`` adds ``pad_edges`` random edges inside
                ``pad_layers`` existing layers.  ``'k1layer'`` appends
                ``pad_layers`` star-shaped layers of ``width`` edges each and
                ignores ``pad_edges``.
            edge_shuffle_rule: 'random' (default), 'bucketby' or 'bylayer'.

    Returns:
        dict with keys 'source', 'goal', 'edge_list', 'layer_edges',
        'path_nodes', 'policy_nodes', 'terminal_nodes', 'path_length',
        'edge_shuffle_rule', 'num_paths', 'layers', 'width', 'm', 'paths',
        plus every entry of ``kwargs``.

    Raises:
        ValueError: if an explicit ``pass_per_layer`` does not have exactly
            ``layers`` entries; from ``random.sample`` when the requested node
            counts cannot be drawn from ``range(m)``; from ``shuffle_edges``
            for an unknown ``edge_shuffle_rule``.
        KeyError: for a ``node_rule`` that is neither a table key nor a
            'shard'/'vhard'/'hard' name.
    """
    # --- Number of active ("pass") nodes in each intermediate layer ---------
    if kwargs.get('num_pass_layers', None) is not None:
        num_pass_layers = kwargs['num_pass_layers']
        num_pass_width = kwargs['num_pass_width']
        layer_indices = random.sample(range(layers), num_pass_layers)
        pass_per_layer = [num_pass_width if i in layer_indices else 1 for i in range(layers)]
    elif kwargs.get('pass_per_layer', None) is None:
        p = kwargs.get('p', 0.5)
        pass_per_layer = [max(1, int(np.round(p * width))) for _ in range(layers)]
    else:
        # Caller-supplied counts.  Without this branch the local was never
        # bound and the function crashed with UnboundLocalError.
        pass_per_layer = list(kwargs['pass_per_layer'])
    if len(pass_per_layer) != layers:
        raise ValueError(
            f"pass_per_layer must have one entry per layer ({layers}), "
            f"got {len(pass_per_layer)}"
        )

    # Selection policies applied to each path layer's candidate nodes.
    # NOTE(review): 'mod0'/'mod1' raise ValueError (random.sample on an empty
    # list) when a multi-node layer has no node of the requested parity.  That
    # can happen when pass_per_layer is uneven, because a layer then samples
    # k nodes from a larger odd/even pool.  Confirm whether callers retry.
    node_rule_dict = {
        'random': lambda nodes: nodes,
        'max': lambda nodes: [max(nodes)],
        'min': lambda nodes: [min(nodes)],
        'mod0': lambda nodes: random.sample([node for node in nodes if node % 2 == 0], 1) if len(nodes) > 1 else nodes,
        'mod1': lambda nodes: random.sample([node for node in nodes if node % 2 == 1], 1) if len(nodes) > 1 else nodes,
    }
    node_rule_name = kwargs.get('node_rule', 'random')

    # --- Node ids ------------------------------------------------------------
    total_nodes = layers * width + 2  # every layer node plus source and goal
    # Each layer reserves k_per_type odd and k_per_type even ids.  Parity-based
    # rules draw their multi-node active sets from this reserved pool.
    k_per_type = (max(pass_per_layer) + 1) // 2
    odds = random.sample([i for i in range(m) if i % 2], k_per_type * layers)
    evens = random.sample([i for i in range(m) if not i % 2], k_per_type * layers)
    # Build the exclusion set once.  Testing ``i not in odds + evens`` rebuilt
    # and scanned a list for every candidate id.
    reserved = set(odds) | set(evens)
    remaining = random.sample(
        [i for i in range(m) if i not in reserved],
        total_nodes - len(odds) - len(evens),
    )
    start_node, goal_node = int(remaining[0]), int(remaining[1])
    filler_per_layer = width - 2 * k_per_type  # non-reserved ids per layer

    layer_nodes = []
    active_nodes = []
    for i in range(layers):
        layer_odds = odds[i * k_per_type:(i + 1) * k_per_type]
        layer_evens = evens[i * k_per_type:(i + 1) * k_per_type]
        filler = remaining[2 + i * filler_per_layer:2 + (i + 1) * filler_per_layer]
        layer_nodes.append(layer_odds + layer_evens + filler)
        k = pass_per_layer[i]
        if 'mod' in node_rule_name or 'hard' in node_rule_name:
            # Parity-aware rules: multi-node active sets come only from the
            # reserved odd/even ids of this layer.
            active_nodes.append(random.sample(layer_odds + layer_evens, k) if k > 1 else random.sample(layer_nodes[-1], k))
        else:
            active_nodes.append(random.sample(layer_nodes[-1], k))
    layer_nodes.append([goal_node])

    # --- Edges ---------------------------------------------------------------
    # layer_edges[0] holds source -> active nodes of layer 0.
    # layer_edges[j + 1] holds active nodes of layer j -> every node of
    # layer j + 1; the last intermediate layer connects to [goal_node].
    layer_edges = [[(start_node, next_node) for next_node in active_nodes[0]]]
    layer_edges += [[(prev_node, next_node) for prev_node in active_nodes[layer] for next_node in layer_nodes[layer + 1]] for layer in range(layers)]

    # --- Optional distractor edges built from ids unused so far -------------
    num_pad_layers = kwargs.get('pad_layers', 0)
    num_edges_per_layer = kwargs.get('pad_edges', 0)
    pad_style = kwargs.get('pad_style', None)
    remaining_nodes = list(set(range(m)) - set(odds + evens + remaining))
    if pad_style == 'append':
        pad_nodes = random.sample(remaining_nodes, num_edges_per_layer * num_pad_layers)
        pad_layer_idxs = random.sample(range(layers), num_pad_layers)
        for pad_layer_idx in pad_layer_idxs:
            pad_edges = [(random.sample(pad_nodes, 1)[0], random.sample(pad_nodes, 1)[0]) for _ in range(num_edges_per_layer)]
            layer_edges[pad_layer_idx + 1] += pad_edges
    elif pad_style == 'k1layer':
        # Star-shaped pad layers always have ``width`` edges; ``pad_edges``
        # is not used in this style.
        num_edges_per_layer = width
        nodes_per_layer = num_edges_per_layer + 1
        pad_nodes = random.sample(remaining_nodes, nodes_per_layer * num_pad_layers)
        for pad_layer_idx in range(num_pad_layers):
            pad_layer_nodes = pad_nodes[pad_layer_idx * nodes_per_layer:(pad_layer_idx + 1) * nodes_per_layer]
            parent_node = pad_layer_nodes[0]
            child_nodes = pad_layer_nodes[1:]
            pad_edges = [(parent_node, child_node) for child_node in child_nodes]
            layer_edges.append(pad_edges)

    edge_shuffle_rule = kwargs.get('edge_shuffle_rule', 'random')
    flattened_edges = shuffle_edges(layer_edges, edge_shuffle_rule)

    # --- Policy over the path layers (source, intermediates, goal) ----------
    path_nodes = {0: [start_node], **{idx + 1: active_nodes[idx] for idx in range(layers)}, layers + 1: [goal_node]}
    # 'shard' and 'vhard' both contain 'hard', so they must be tested first.
    if 'shard' in node_rule_name:
        _node_rules = get_shard_rule(path_nodes, node_rule_name)
        node_rules = [node_rule_dict[rule] for rule in _node_rules]
    elif 'vhard' in node_rule_name:
        _node_rules = get_vhard_rule(path_nodes, node_rule_name)
        node_rules = [node_rule_dict[rule] for rule in _node_rules]
    elif 'hard' in node_rule_name:
        _rule = get_hard_rule(path_nodes, node_rule_name)
        node_rules = [node_rule_dict[_rule] for _ in range(layers + 2)]
    else:
        node_rules = [node_rule_dict[node_rule_name] for _ in range(layers + 2)]
    policy_nodes = {i: node_rules[i](nodes) for i, nodes in path_nodes.items()}

    g_dict = {
        'source': start_node,
        'goal': goal_node,
        'edge_list': flattened_edges,
        'layer_edges': layer_edges,
        'path_nodes': path_nodes,
        'policy_nodes': policy_nodes,
        'terminal_nodes': [goal_node],
        'path_length': layers + 2,
        'edge_shuffle_rule': edge_shuffle_rule,
        # math.prod stays in arbitrary-precision ints.  np.prod computes in a
        # fixed-width integer dtype and silently wraps for deep/wide graphs.
        'num_paths': math.prod(len(nodes) for nodes in policy_nodes.values()),
        'layers': layers,
        'width': width,
        'm': m,
        'paths': [random.sample(policy_nodes[i], 1)[0] for i in range(layers + 2)],
        **kwargs,
    }
    return g_dict
