import numpy as np
import pickle
import networkx as nx
from itertools import product
import string

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.data import InMemoryDataset, download_url
from torch_geometric.utils import from_networkx, to_networkx
from matrices import generate_ising_from_graph, sum_N





def kagome_site(x):
    """Site-selection predicate: keep ``x`` unless both coordinates are even.

    Args:
        x: indexable lattice coordinate; only ``x[0]`` and ``x[1]`` are read.

    Returns:
        True when ``x[0]`` is odd or ``x[1] % 2 == 1``, False otherwise.
    """
    if x[0] % 2 != 0:
        return True
    return x[1] % 2 == 1

def hexagonal_site(x):
    """Site-selection predicate: drop sites where ``2*x[0] + x[1]`` is a multiple of 3.

    Args:
        x: indexable lattice coordinate; only ``x[0]`` and ``x[1]`` are read.

    Returns:
        True when ``(2*x[0] + x[1]) % 3`` is non-zero.
    """
    remainder = (2 * x[0] + x[1]) % 3
    return remainder != 0

def custom_site(x):
    """Site-selection predicate combining row parity with a row-dependent column parity.

    A site is kept when any of the following holds:
      * ``x[1]`` is even;
      * ``x[1] % 4 == 1`` and ``3*x[1] + x[0]`` is even;
      * ``x[1] % 4 == 3`` and ``3*x[1] + x[0]`` is odd.

    Args:
        x: indexable lattice coordinate; only ``x[0]`` and ``x[1]`` are read.

    Returns:
        True when the site is kept.
    """
    row = x[1]
    if row % 2 == 0:
        return True
    parity = (3 * row + x[0]) % 2
    if row % 4 == 1:
        return parity == 0
    return parity == 1 and row % 4 == 3

def get_triangular_lattice(L=5):
    """Grow a triangular lattice outward from the origin in ``L`` rounds.

    Each round adds every missing lattice neighbour of the nodes already in
    the graph, then connects all pairs of existing nodes that sit at one of
    the six neighbour offsets.  Nodes are numbered in order of creation.

    Args:
        L: number of growth rounds (``L=0`` returns the single origin node).

    Returns:
        G: ``nx.Graph`` whose nodes carry ``coord`` (integer lattice tuple),
           ``r`` (2D position ``coord[0]*r1 + coord[1]*r2``) and ``color``
           (``(coord[0] - coord[1]) % 3``).
        index_coords: dict mapping lattice tuple -> node id.
        all_coords: array of the ``r`` positions in node-iteration order.
    """
    a1 = np.array([1, 0])
    a2 = np.array([1, np.sqrt(3)]) / 2
    steps = [(1, 0), (0, 1), (-1, 1), (-1, 0), (0, -1), (1, -1)]

    def position(coord):
        return coord[0]*a1 + coord[1]*a2

    origin = (0, 0)
    G = nx.Graph()
    G.add_node(0, coord=origin, r=position(origin), color=0)
    index_coords = {origin: 0}
    next_id = 1

    for _ in range(L):
        # Collect the not-yet-present neighbours of every current node.
        frontier = []
        for node in G.nodes:
            x, y = G.nodes[node]['coord']
            for dx, dy in steps:
                candidate = (x + dx, y + dy)
                if candidate not in index_coords:
                    frontier.append(candidate)

        # The frontier may list a site several times; create it only once.
        for candidate in frontier:
            if candidate in index_coords:
                continue
            G.add_node(next_id, coord=candidate, r=position(candidate),
                       color=(candidate[0] - candidate[1]) % 3)
            index_coords[candidate] = next_id
            next_id += 1

        # Link every pair of existing sites that are lattice neighbours.
        for (x, y), node in index_coords.items():
            for dx, dy in steps:
                other = index_coords.get((x + dx, y + dy))
                if other is not None:
                    G.add_edge(node, other)

    all_coords = np.array([G.nodes[node]['r'] for node in G.nodes])

    return G, index_coords, all_coords

def get_triangular_lattice_from_triangle(L=5):
    """Grow a triangular lattice outward from a seed triangle in ``L`` rounds.

    Same growth procedure as ``get_triangular_lattice`` but seeded with the
    three mutually adjacent sites ``(0, 0)``, ``(1, 0)`` and ``(1, -1)``.

    The seed nodes are coloured with the same ``(x - y) % 3`` rule as every
    other node (they used to be hard-coded to colour 0, which gave three
    adjacent nodes the same colour).

    Args:
        L: number of growth rounds (``L=0`` returns just the seed triangle).

    Returns:
        G: ``nx.Graph`` whose nodes carry ``coord`` (integer lattice tuple),
           ``r`` (2D position) and ``color`` (``(coord[0] - coord[1]) % 3``).
        index_coords: dict mapping lattice tuple -> node id.
        all_coords: array of the ``r`` positions in node-iteration order.
    """
    base_nodes = [(0, 0), (1, 0), (1, -1)]
    r1 = np.array([1, 0])
    r2 = np.array([1, np.sqrt(3)]) / 2
    neighbors = [(1, 0), (0, 1), (-1, 1), (-1, 0), (0, -1), (1, -1)]

    def add_site(graph, node_id, coord):
        graph.add_node(node_id, coord=coord, r=coord[0]*r1 + coord[1]*r2,
                       color=(coord[0] - coord[1]) % 3)

    G = nx.Graph()
    index_coords = {}
    for node_id, coord in enumerate(base_nodes):
        add_site(G, node_id, coord)
        index_coords[coord] = node_id
    G.add_edges_from([(0, 1), (0, 2), (1, 2)])
    node_count = len(base_nodes)

    for _ in range(L):
        # Collect the not-yet-present neighbours of every current node.
        nodes_to_add = []
        for node in G.nodes:
            x, y = G.nodes[node]['coord']
            for dx, dy in neighbors:
                new_coord = (x + dx, y + dy)
                if new_coord not in index_coords:
                    nodes_to_add.append(new_coord)

        # A site can be listed several times; create it only once.
        for new_coord in nodes_to_add:
            if new_coord not in index_coords:
                add_site(G, node_count, new_coord)
                index_coords[new_coord] = node_count
                node_count += 1

        # Link every pair of existing sites that are lattice neighbours.
        for (x, y), node in index_coords.items():
            for dx, dy in neighbors:
                if (x + dx, y + dy) in index_coords:
                    G.add_edge(node, index_coords[(x + dx, y + dy)])

    all_coords = np.array([G.nodes[node]['r'] for node in G.nodes])

    return G, index_coords, all_coords

def get_square_lattice(L=5):
    """Build an ``L x L`` open-boundary square lattice.

    Site ``(i, j)`` gets node id ``i + L*j`` and is linked to its left
    ``(i-1, j)`` and lower ``(i, j-1)`` neighbours when they exist.

    Args:
        L: linear size of the lattice.

    Returns:
        G: ``nx.Graph`` whose nodes carry ``coord`` (``(i, j)`` tuple), ``r``
           (2D position, equal to the coordinate) and ``color`` (always 0).
        index_coords: dict mapping ``(i, j)`` -> node id.  Empty when
           ``L == 0`` (it used to contain a phantom ``(0, 0)`` entry with no
           matching node).
        all_coords: array whose row ``k`` is the position of node id ``k``.
           Nodes are inserted column by column, so the previous
           insertion-ordered array did not line up with node ids.
    """
    r1 = np.array([1, 0])
    r2 = np.array([0, 1])
    G = nx.Graph()
    index_coords = {}

    for i in range(L):
        for j in range(L):
            node = i + L*j
            G.add_node(node, coord=(i, j), r=r1*i + r2*j, color=0)
            index_coords[(i, j)] = node
            if i > 0:
                G.add_edge(node, node - 1)
            if j > 0:
                G.add_edge(node, node - L)

    # Index rows by node id so callers can do all_coords[node_ids].
    all_coords = np.array([G.nodes[node]['r'] for node in range(L * L)])

    return G, index_coords, all_coords

def return_symmetrical_nodes(G, node_list):
    """Split the nodes of ``G`` into a point-symmetrised pattern and its complement.

    A node with ``coord[0] >= 0`` belongs to the pattern when its coordinate
    is the coordinate of some node in ``node_list``; a node with
    ``coord[0] < 0`` belongs to the pattern when its point reflection
    ``(-coord[0], -coord[1])`` is.

    Args:
        G: graph whose nodes carry a ``coord`` attribute.
        node_list: iterable of node ids of ``G`` defining the pattern.

    Returns:
        (pattern_nodes, other_nodes): two lists that together contain every
        node of ``G`` exactly once, each in node-iteration order.
    """
    pattern_coords = [G.nodes[member]['coord'] for member in node_list]
    selected, remainder = [], []
    for node in G.nodes():
        coord = G.nodes[node]['coord']
        # Nodes on the negative-x side are tested through their mirror image.
        probe = coord if coord[0] >= 0 else (-coord[0], -coord[1])
        bucket = selected if probe in pattern_coords else remainder
        bucket.append(node)
    assert len(remainder) + len(selected) == nx.number_of_nodes(G)
    return selected, remainder



def return_data_list(L=8, dataset_size=50, pattern_size=70, seed_pattern=0, seed_dataset=98, keep_pattern=.8, keep_other=.3):
    """Build a two-class list of noisy triangular-lattice subgraphs.

    For each class ``i`` in ``{0, 1}`` a random pattern of ``pattern_size``
    nodes with ``coord[0] >= 0`` is drawn (seed ``seed_pattern + 4*i``) and
    symmetrised with ``return_symmetrical_nodes``.  Each sample keeps a random
    fraction ``keep_pattern`` of the pattern nodes and ``keep_other`` of the
    remaining nodes (seed ``seed_dataset + 4*i``).

    Each returned object comes from ``from_networkx(subgraph, ['r'])`` with
      * ``x``: the first (up to) 30 eigenvector columns returned by ``eigh``
        on the subgraph Laplacian, as float32,
      * ``y``: the class label,
      * ``pos``: lattice positions of the kept nodes.

    Returns:
        list of ``2 * dataset_size`` data objects, class 0 first.
    """
    G, index_coords, coords = get_triangular_lattice(L)
    candidates = [node for node in G.nodes() if G.nodes[node]['coord'][0] >= 0]

    data_list = []

    for label in range(2):
        np.random.seed(seed_pattern + 4*label)
        pattern = np.sort(np.random.choice(candidates, size=pattern_size, replace=False))
        nodes_pattern, comp_nodes = return_symmetrical_nodes(G, pattern)
        n_keep_pattern = int(len(nodes_pattern)*keep_pattern)
        n_keep_other = int(len(comp_nodes)*keep_other)

        np.random.seed(seed_dataset + 4*label)
        for _ in range(dataset_size):
            # Draw order matters for reproducibility: pattern first, then complement.
            kept_pattern = np.random.choice(nodes_pattern, size=n_keep_pattern, replace=False)
            kept_other = np.random.choice(comp_nodes, size=n_keep_other, replace=False)
            noisy_nodes = np.sort(np.concatenate([kept_pattern, kept_other]))

            subgraph = nx.convert_node_labels_to_integers(G.subgraph(noisy_nodes), ordering='sorted')
            laplacian = nx.laplacian_matrix(subgraph).toarray()
            _, eigenvectors = np.linalg.eigh(laplacian)

            data = from_networkx(subgraph, ['r'])
            data.x = torch.tensor(eigenvectors[:, 0:30]).float()
            data.y = torch.tensor([label])
            data.pos = torch.tensor(coords[noisy_nodes])
            data_list.append(data)

    return data_list

def sample_positions(lattice_size=30, pattern_size=8, n_samples=4):
    """Draw up to ``n_samples`` well-separated anchor positions.

    200 candidate positions are drawn uniformly from
    ``[0, lattice_size - pattern_size)`` in both coordinates (always 200, so
    the global NumPy RNG stream does not depend on ``n_samples``).  Candidates
    are accepted greedily, in draw order, when their Chebyshev distance to
    every accepted position is strictly greater than ``pattern_size``.

    Args:
        lattice_size: linear size of the lattice.
        pattern_size: linear size of the pattern to be placed at each position.
        n_samples: maximum number of positions to return.

    Returns:
        list of at most ``n_samples`` length-2 integer arrays; fewer are
        returned when not enough compatible candidates were drawn, and an
        empty list when ``n_samples <= 0``.
    """
    positions = np.random.randint(lattice_size - pattern_size, size=(200, 2))
    if n_samples <= 0:
        return []

    accepted = [positions[0]]
    for candidate in positions[1:]:
        # Check the quota before accepting, so n_samples == 1 stops here too.
        if len(accepted) >= n_samples:
            break
        if all(np.max(np.abs(kept - candidate)) > pattern_size for kept in accepted):
            accepted.append(candidate)
    return accepted

def create_random_patterns_dataset(lattice_size=30, n_samples=200, pattern_size=8, pattern_density=.7,
                                    keep_other=.7, noise=0, n_patterns=4, seed_pattern=73, seed_dataset=12):
    """Build a two-class dataset of square-lattice occupancies with planted patterns.

    Two random patterns (one per class) are drawn inside a
    ``pattern_size x pattern_size`` window with seed ``seed_pattern``.  For
    every sample the background sites are kept independently with probability
    ``keep_other``; then, at each position from ``sample_positions``, the
    window is overwritten with the class pattern and every window site is
    flipped independently with probability ``noise``.

    Returns:
        dataset: ``(2*n_samples, n_sites)`` array of 0/1 occupancies.
        dataset_graphs: list of subgraph views of the lattice induced by the
            occupied sites of each sample.
        targets: array of labels, 0 for the first ``n_samples`` rows, 1 after.
        pattern_indicators: ``(2*n_samples, n_sites)`` array marking the sites
            covered by a planted window.
    """
    G, index_coords, coords = get_square_lattice(lattice_size)
    n_nodes = nx.number_of_nodes(G)
    total = 2 * n_samples

    dataset = np.zeros((total, len(index_coords)))
    pattern_indicators = np.zeros((total, len(index_coords)))
    dataset_graphs = []

    targets = np.ones(total)
    targets[:n_samples] = 0

    np.random.seed(seed_pattern)
    window_cells = pattern_size**2
    patterns = [
        np.random.choice(window_cells, size=int(pattern_density * window_cells), replace=False)
        for _ in range(2)
    ]

    np.random.seed(seed_dataset)

    for k, pattern in enumerate(patterns):
        for l in range(n_samples):
            row = l + n_samples*k

            positions = sample_positions(lattice_size=lattice_size, pattern_size=pattern_size, n_samples=n_patterns)
            occupancy = np.random.binomial(1, keep_other, size=n_nodes)

            for x, y in positions:
                # Node ids of the window anchored at (x, y), in increasing order.
                window = np.sort([index_coords[(x+i, y+j)]
                                  for j in range(pattern_size)
                                  for i in range(pattern_size)])
                occupancy[window] = 0
                occupancy[window[pattern]] = 1
                flips = np.random.binomial(1, noise, size=len(window))
                flipped = window[flips == 1]
                occupancy[flipped] = 1 - occupancy[flipped]
                pattern_indicators[row, window] = 1

            dataset[row] = occupancy
            dataset_graphs.append(G.subgraph(np.flatnonzero(occupancy == 1)))

    return dataset, dataset_graphs, targets, pattern_indicators


def convert_graph_dataset_to_images(graphs_list, lattice_size=20):
    """Rasterise graphs into binary ``lattice_size x lattice_size`` images.

    Every node's ``coord`` attribute is used directly as an index into the
    image, and the indexed entry is set to 1.

    Args:
        graphs_list: iterable of graphs whose nodes carry a ``coord`` attribute.
        lattice_size: side length of each image.

    Returns:
        ``np.ndarray`` stacking one image per graph.
    """
    def rasterise(graph):
        canvas = np.zeros((lattice_size, lattice_size))
        for node in graph.nodes():
            canvas[graph.nodes[node]['coord']] = 1
        return canvas

    return np.array([rasterise(graph) for graph in graphs_list])

def add_branch(G, coords, position, bias, max_length=10, p=1):
    """Grow a two-node-wide strip ("ladder") out of an existing edge, in place.

    Starting from the two nodes in ``position``, up to ``max_length`` new
    pairs of nodes are appended, each one step further along ``bias``.  The
    new nodes get ids ``N + 2*i`` and ``N + 2*i + 1`` where ``N`` is the node
    count of ``G`` on entry (assumes ids ``0..N-1`` are the ones in use —
    TODO confirm for callers that relabel nodes).

    Args:
        G: graph to extend; nodes carry ``coord`` and optionally ``c``.
        coords: list of occupied coordinate tuples; appended to in place.
        position: pair of node ids forming the base of the strip.
        bias: step direction added to the base coordinates at each step
            (callers in this file pass unit lattice steps as ``np.ndarray``).
        max_length: maximum number of node pairs to add.
        p: probability of continuing at each step.

    Returns:
        None.  ``G`` and ``coords`` are modified in place.
    """
    N = nx.number_of_nodes(G)
    x1, y1 = G.nodes[position[0]]['coord']
    x2, y2 = G.nodes[position[1]]['coord']

    all_biases = [np.array([0, 1]), np.array([1, 0]), np.array([-1, 0]), np.array([0, -1])]


    for i in range(max_length):
        # The continuation coin is drawn before the collision test, so one
        # random number is consumed per iteration even when growth stops.
        random = np.random.binomial(1, p)
        stop = False
        for b in all_biases:
            # Skip b == -bias (the cell we are coming from); any other
            # occupied neighbour of either candidate cell blocks the growth.
            if np.sum((b + bias)**2) > .1:
                if tuple(np.array([x1, y1])+bias*(i+1)+b) in coords:
                    stop = True
                if tuple(np.array([x2, y2])+bias*(i+1)+b) in coords:
                    stop = True
        if stop:
            break

        if random:
            c0 = np.array([x1, y1])+bias*(i+1)
            c1 = np.array([x2, y2])+bias*(i+1)
            if 'c' in G.nodes[position[0]]:
                # Alternate the 0/1 label 'c' along the strip: the new first
                # node gets the opposite of the node it attaches to, and the
                # new second node gets the opposite of the new first node.
                if i==0:
                    c = G.nodes[position[0]]['c']
                else:
                    c = G.nodes[N+2*(i-1)]['c']
                G.add_node(N+2*i, coord=c0, c=1-c)
                G.add_node(N+2*i+1, coord=c1, c=c)
            else:
                G.add_node(N+2*i, coord=c0)
                G.add_node(N+2*i+1, coord=c1)
            coords.append(tuple(c0))
            coords.append(tuple(c1))
            # Rung between the two new nodes.
            G.add_edge(N+2*i, N+2*i+1)
            # Rails: attach to the base pair on the first step, to the
            # previously added pair afterwards.
            if i==0:
                G.add_edge(position[0], N + 2*i)
                G.add_edge(position[1], N + 2*i+1)
            else:
                G.add_edge(N + 2*(i-1), N + 2*i)
                G.add_edge(N + 2*(i-1)+1, N + 2*i+1)
        else:
            break


def generate_maze_graph(init_length=20, max_iter=100, max_nodes=150, p=.9, max_length=10):
    """Generate a random "maze" made of ladder-shaped branches.

    The graph starts as a straight ladder of ``init_length`` rungs grown
    along ``+x`` from the rung ``(0, 0)-(0, 1)``.  Then, up to ``max_iter``
    times, a node is drawn uniformly at random; if it has exactly three
    neighbours, a new branch is grown (via ``add_branch``) from the edge
    joining it to one of its two opposite neighbours, in the direction
    pointing away from its third neighbour.

    Args:
        init_length: number of rungs of the initial ladder.
        max_iter: number of random branching attempts.
        max_nodes: stop attempting once the graph has at least this many nodes.
        p: per-step continuation probability of a branch.
        max_length: maximum number of rungs per branch.

    Returns:
        ``nx.Graph`` whose nodes carry ``coord`` (integer array) and ``c`` (0/1).
    """
    G = nx.Graph()
    G.add_node(0, coord=np.array([0, 0]), c=0)
    G.add_node(1, coord=np.array([0, 1]), c=1)
    G.add_edge(0, 1)
    occupied = [(0, 0), (0, 1)]
    add_branch(G, occupied, position=[0, 1], bias=np.array([1, 0]), max_length=init_length-1, p=1)

    for _ in range(max_iter):
        base_node = np.random.choice(list(G.nodes()))
        neighbors = list(G.neighbors(base_node))
        if len(neighbors) == 3:
            here = G.nodes[base_node]['coord']
            offsets = [tuple(G.nodes[n]['coord'] - here) for n in neighbors]
            bias = 0
            pair_choice = []
            for neighbor, offset in zip(neighbors, offsets):
                if tuple(-np.array(offset)) in offsets:
                    # This neighbour has a partner on the opposite side.
                    pair_choice.append(neighbor)
                else:
                    # Unpaired neighbour: grow away from it.
                    bias = -np.array(offset)
            position = [np.random.choice(pair_choice), base_node]
            add_branch(G, occupied, position=position, bias=bias, p=p, max_length=max_length)
        if nx.number_of_nodes(G) >= max_nodes:
            break

    return G

def twist_nodes(G, positions):
    """Swap one neighbour of ``v0`` with one neighbour of ``v1``, in place.

    Picks ``n0``, the first neighbour of ``v0`` other than ``v1``, and ``n1``,
    the first neighbour of ``v1`` that is neither ``v0`` nor ``n0`` and is not
    adjacent to ``n0``.  The edges ``n0-v0`` and ``n1-v1`` are then replaced
    by ``n0-v1`` and ``n1-v0``.

    Args:
        G: graph to modify.
        positions: pair ``(v0, v1)`` of node ids.

    Returns:
        ``G`` (the same object, modified in place).

    Raises:
        ValueError: if no suitable ``n0`` or ``n1`` exists.  The search used
            to fall through to the last neighbour examined in that case and
            rewire the wrong edges.
    """
    v0, v1 = positions

    n0 = next((v for v in G.neighbors(v0) if v != v1), None)
    if n0 is None:
        raise ValueError(f"node {v0!r} has no neighbour other than {v1!r}; cannot twist")

    n1 = next(
        (v for v in G.neighbors(v1)
         if v != v0 and v != n0 and not G.has_edge(v, n0)),
        None,
    )
    if n1 is None:
        raise ValueError(
            f"node {v1!r} has no neighbour that can be swapped with {n0!r}; cannot twist"
        )

    G.remove_edge(n0, v0)
    G.remove_edge(n1, v1)
    G.add_edge(n0, v1)
    G.add_edge(n1, v0)
    return G


def find_nodes_cut(G, cut_size_min=10, cut_size_max=20):
    """Find adjacent degree-3 node pairs whose removal cuts off a mid-sized piece.

    A pair ``[node, n]`` is a candidate when both nodes have exactly three
    neighbours and ``node`` has no neighbour on the side opposite to ``n``.
    The pair is kept when, after removing both nodes, the smaller of the
    first two connected components reported by networkx has between
    ``cut_size_min`` and ``cut_size_max`` nodes (inclusive).

    Pairs whose removal leaves fewer than two components are skipped; they
    used to raise ``IndexError``.

    Args:
        G: graph whose nodes carry a ``coord`` array attribute.
        cut_size_min: minimum accepted size of the smaller component.
        cut_size_max: maximum accepted size of the smaller component.

    Returns:
        list of ``[node, n]`` pairs.  A pair is not added if its reverse is
        already in the list.
    """
    nodes_to_select = []
    for node in G.nodes():
        neighbors = list(G.neighbors(node))
        if len(neighbors) != 3:
            continue
        c = G.nodes[node]['coord']
        all_biases = [tuple(G.nodes[n]['coord'] - c) for n in neighbors]
        for n in neighbors:
            cn = G.nodes[n]['coord']
            # Keep only the neighbour with nothing on the opposite side.
            if tuple(c - cn) in all_biases:
                continue
            if len(G[n]) != 3:
                continue
            if [n, node] in nodes_to_select:
                continue
            pair = [node, n]
            G2 = G.copy()
            G2.remove_nodes_from(pair)
            components = list(nx.connected_components(G2))
            if len(components) < 2:
                # Removing the pair does not disconnect the graph.
                continue
            min_cut = min(len(components[0]), len(components[1]))
            if cut_size_min <= min_cut <= cut_size_max:
                nodes_to_select.append(pair)
    return nodes_to_select


def generate_maze_graph_with_cut(init_length=20, max_iter=100, max_nodes=150, cut_size_min=10, cut_size_max=20, n_twist=1, p=.9, max_length=10):
    """Generate a maze graph, twist it at random cut pairs, and enumerate its largest independent sets.

    Steps:
      1. Build a maze with ``generate_maze_graph``.
      2. Find candidate cut pairs with ``find_nodes_cut`` and apply
         ``twist_nodes`` to ``n_twist`` of them chosen at random (fewer if
         there are not enough candidates).
      3. Split the graph into small "twist" subgraphs (each twisted pair and
         its neighbours) and the connected components that remain once those
         are removed.
      4. Enumerate combinations of: every independent set of each twist
         subgraph, and for each remaining component either its ``c`` labelling
         or the complement.  Keep the combinations that are independent sets
         of the whole graph and have maximum cardinality.

    Node ids are used as array indices below, so this relies on the nodes
    being labelled ``0..N-1``.

    NOTE(review): if no cut pair is found (``n_twist`` clamped to 0) the
    enumeration calls ``np.concatenate`` on an empty list, which looks like it
    would raise — confirm whether callers guarantee at least one twist.

    Returns:
        G: the twisted graph.
        gs_list: list of 0/1 float arrays of length ``N``, one per
            maximum-size configuration found.
    """
    G = generate_maze_graph(init_length, max_iter, max_nodes, p=p, max_length=max_length)
    nodes_to_select = find_nodes_cut(G, cut_size_min, cut_size_max)
    twists = []
    # Never request more twists than there are candidate pairs.
    if n_twist > len(nodes_to_select):
        n_twist = len(nodes_to_select)
    for i in np.random.choice(len(nodes_to_select), replace=False, size=n_twist):
        twist_nodes(G, nodes_to_select[i])
        twists.append(nodes_to_select[i])
    
    # Induced subgraph on each twisted pair plus all of their (post-twist) neighbours.
    subgraphs_twists = [G.subgraph([twist[0], twist[1]] + list(G.neighbors(twist[0])) + list(G.neighbors(twist[1]))).copy() for twist in twists]
    # Connected pieces left after removing every twist subgraph, refined one
    # twist at a time.
    subgraphs_no_twists = []
    for i, subgraph in enumerate(subgraphs_twists):
        if i==0:
            G2 = G.copy()
            G2.remove_nodes_from(subgraph.nodes())
            components = list(nx.connected_components(G2))
            for component in components:
                subgraphs_no_twists.append(G2.subgraph(list(component)).copy())
        else:
            new_subgraphs = []
            for subgraph_no_twist in subgraphs_no_twists:
                G2 = subgraph_no_twist.copy()
                G2.remove_nodes_from(subgraph.nodes())
                components = list(nx.connected_components(G2))
                for component in components:
                    new_subgraphs.append(subgraph_no_twist.subgraph(list(component)).copy())
            subgraphs_no_twists = new_subgraphs

    # Two candidate states per untouched piece: the stored 'c' labelling and
    # its complement.
    gs_no_twists = []
    for subgraph in subgraphs_no_twists:
        gs = np.array([subgraph.nodes[node]['c'] for node in subgraph.nodes()])
        gs_no_twists.append([gs.copy(), 1-gs.copy()])
    # print(gs_no_twists)
    # Flat array of the node ids of all twist subgraphs, in the same order as
    # the concatenated twist bitstrings built below.
    nodes_twists = []
    for subgraph in subgraphs_twists:
        nodes_twists += list(subgraph.nodes())
    nodes_twists = np.array(nodes_twists)
    bits_no_twists = [[0, 1] for _ in subgraphs_no_twists]

    # Brute-force every 0/1 assignment of each twist subgraph and keep the
    # ones in which no edge has both endpoints set to 1.
    bitstring_twists = []
    for subgraph in subgraphs_twists:
        l = []
        node_to_index = dict()
        for i, node in enumerate(list(subgraph.nodes())):
            node_to_index[node] = i
        for bitstring in product(*[[0, 1] for _ in range(nx.number_of_nodes(subgraph))]):
            select = True
            for edge in subgraph.edges():
                if bitstring[node_to_index[edge[0]]] + bitstring[node_to_index[edge[1]]] >= 1.5:
                    select = False
            if select:
                l.append(np.array(list(bitstring)).copy())
        bitstring_twists.append(l)
    max_set = 0

    # Edges that are not listed (with the same orientation) among the edges
    # of any untouched piece still need an explicit independence check.
    edges_to_check = []
    for edge in G.edges():
        check = True
        for subgraph in subgraphs_no_twists:
            if edge in list(subgraph.edges()):
                check=False
        if check:
            edges_to_check.append(edge)
    gs_list = []

    for bitstring in product(*bitstring_twists):
        for bitstring_no_twist in product(*bits_no_twists):
            # Assemble the full configuration from the per-piece choices.
            state_test = np.zeros(nx.number_of_nodes(G))
            state_test[nodes_twists] = np.concatenate(list(bitstring))
            for j, bit in enumerate(bitstring_no_twist):
                nodes = np.array(list(subgraphs_no_twists[j].nodes()))
                state_test[nodes] = gs_no_twists[j][bit].copy()
            ind_set = True
            for edge in edges_to_check:
                if state_test[edge[0]] + state_test[edge[1]] - 1 > .5:
                    ind_set=False
            if ind_set:
                # Strictly larger set: restart the list; same size: append.
                if np.sum(state_test) - max_set > .5:
                    gs_list = []
                    gs_list.append(state_test)
                    max_set = np.sum(state_test)
                elif np.abs(np.sum(state_test) - max_set) < 1e-6:
                    gs_list.append(state_test)
                else:
                    pass


    return G, gs_list

def generate_maze_graph_with_cut_and_hexagons(init_length=20, max_iter=100, max_nodes=150, cut_size_min=10, cut_size_max=20, n_twist=1, p=.9, max_length=10, n_hexagon=5):
    """Generate a maze graph, delete some cut-pair edges, then twist some cut pairs.

    Steps:
      1. Build a maze with ``generate_maze_graph``.
      2. Find cut pairs whose smaller side has between 4 and ``N - 4`` nodes
         and remove the edge joining ``n_hexagon`` of them, chosen at random.
      3. On the resulting graph, find cut pairs with sizes in
         ``[cut_size_min, cut_size_max]`` and apply ``twist_nodes`` to
         ``n_twist`` of them, chosen at random.

    Both ``n_hexagon`` and ``n_twist`` are clamped to the number of available
    candidates.  Only ``n_twist`` used to be clamped, so asking for more
    hexagons than there were candidates raised ``ValueError`` inside
    ``np.random.choice``.

    Returns:
        The modified ``nx.Graph``.
    """
    G = generate_maze_graph(init_length, max_iter, max_nodes, p=p, max_length=max_length)

    nodes_to_select_hexagon = find_nodes_cut(G, 4, nx.number_of_nodes(G)-4)
    n_hexagon = min(n_hexagon, len(nodes_to_select_hexagon))
    indices_hexagon = np.random.choice(len(nodes_to_select_hexagon), replace=False, size=n_hexagon)
    for i in indices_hexagon:
        # Each candidate is a [u, v] pair, used here as the edge to delete.
        G.remove_edges_from([nodes_to_select_hexagon[i]])

    nodes_to_select_twist = find_nodes_cut(G, cut_size_min, cut_size_max)
    n_twist = min(n_twist, len(nodes_to_select_twist))
    indices_twist = np.random.choice(len(nodes_to_select_twist), replace=False, size=n_twist)
    for i in indices_twist:
        twist_nodes(G, nodes_to_select_twist[i])

    return G

# def concatenate_graphs(G1, G2, e_g1, e_g2):
#     u = G1.nodes[e_g1[0]]['coord'] - G1.nodes[e_g1[1]]['coord']
#     v = G2.nodes[e_g2[0]]['coord'] - G2.nodes[e_g2[1]]['coord']
#     trans = G1.nodes[e_g1[0]]['coord'] - G2.nodes[e_g2[0]]['coord']
#     d = np.sign(np.linalg.det(np.concatenate([v.reshape((-1, 1)), u.reshape((-1, 1))], 1)))
#     theta = np.arccos(np.dot(v, u)) * d
#     rotation = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
#     for node in G2.nodes():
#         G2.nodes[node]['coord'] = np.dot(rotation, G2.nodes[node]['coord']) + trans + np.array([-u[1], u[0]])
#     G = nx.union(G1.copy(), G2.copy(), rename=('', 'L'))
#     G.add_edges_from([(f'{e_g1[0]}', f'L{e_g2[0]}'), (f'{e_g1[1]}', f'L{e_g2[1]}')])
#     return nx.convert_node_labels_to_integers(G)

def concatenate_graphs(G1, G2, e_g1, e_g2):
    """Glue ``G2`` onto ``G1`` along a pair of edges and return the merged graph.

    ``G2`` is rotated so that its edge ``e_g2`` is parallel to ``G1``'s edge
    ``e_g1``, then translated so that ``e_g2[0]`` lands one perpendicular step
    ``(-u[1], u[0])`` away from ``e_g1[0]`` (``u`` being the ``e_g1`` edge
    vector).  The two graphs are united and the matching edge endpoints are
    connected.

    Changes from the previous implementation:
      * the rotation angle comes from ``arctan2(cross, dot)``, which also
        handles anti-parallel edges (previously rotated by 0) and non-unit
        edge vectors (previously ``arccos`` could return NaN);
      * the rotation pivots on ``e_g2[0]``, so the placement is correct even
        when that node is not at the origin (identical result when it is);
      * ``G2`` is no longer modified; the transformed coordinates are written
        to a copy.

    Args:
        G1, G2: graphs whose nodes carry a 2D ``coord`` array attribute.
        e_g1: pair of node ids of ``G1``.
        e_g2: pair of node ids of ``G2``.

    Returns:
        New ``nx.Graph`` with integer labels; ``G1``'s nodes come first, in
        their original order, followed by ``G2``'s.
    """
    anchor = G1.nodes[e_g1[0]]['coord']
    pivot = G2.nodes[e_g2[0]]['coord']
    u = anchor - G1.nodes[e_g1[1]]['coord']
    v = pivot - G2.nodes[e_g2[1]]['coord']

    # Signed angle that rotates v onto u.
    theta = np.arctan2(v[0]*u[1] - v[1]*u[0], np.dot(v, u))
    rotation = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
    offset = np.array([-u[1], u[0]])

    placed = G2.copy()  # node attribute dicts are copied, so G2 stays untouched
    for node in placed.nodes():
        placed.nodes[node]['coord'] = np.dot(rotation, placed.nodes[node]['coord'] - pivot) + anchor + offset

    G = nx.union(G1.copy(), placed, rename=('', 'L'))
    G.add_edges_from([(f'{e_g1[0]}', f'L{e_g2[0]}'), (f'{e_g1[1]}', f'L{e_g2[1]}')])
    return nx.convert_node_labels_to_integers(G)

def construct_special_pattern_dataset(dataset_size=1000, seed=902):
    """Build a two-class dataset of graphs made of a 6-node triangular piece with two mazes attached.

    Every sample glues two random maze graphs onto a fixed 6-node
    triangular-lattice piece with ``concatenate_graphs``.  The first maze is
    always attached at the piece's edge ``(1, 2)``; the second is attached at
    edge ``(4, 3)`` for the first ``dataset_size`` samples (label 1) and at
    edge ``(3, 0)`` for the last ``dataset_size`` samples (label 0).

    For every sample a dict of bitstrings is also produced by combining the
    lowest-energy states of the triangular piece with each maze's ``c``
    labelling or its complement, keeping the combinations compatible across
    the two gluing edges.

    NOTE(review): ``generate_ising_from_graph`` and ``sum_N`` come from the
    project's ``matrices`` module; the code below assumes the result is a
    torch-like object exposing ``.numpy()`` and, after flattening, one energy
    per computational-basis state — confirm against that module.

    Returns:
        dataset_graphs: list of ``2 * dataset_size`` graphs.
        targets: int array, 1 for the first half and 0 for the second half.
        ground_states: list of dicts whose keys are bitstrings (values are 1).
    """
    triang_piece = nx.Graph()
    triang_piece.add_nodes_from([(0, {'coord': np.array([0., 0.])}), (1, {'coord': np.array([1., 0.])}),
                                    (2, {'coord': np.array([2., 0.])}), (3, {'coord': np.array([-1., np.sqrt(3)])/2}),
                                    (4, {'coord': np.array([1., np.sqrt(3)])/2}), (5, {'coord': np.array([3., np.sqrt(3)])/2})])
    edges = [(0, 1), (1, 2), (3, 4), (4, 5), (0, 3), (0, 4), (1, 4), (1, 5), (2, 5)]
    triang_piece.add_edges_from(edges)
    coords = [triang_piece.nodes[node]['coord'] for node in triang_piece.nodes()]

    N = nx.number_of_nodes(triang_piece)

    ising = (generate_ising_from_graph(triang_piece, type_ising='N') - .1 * sum_N(N)).numpy().reshape((-1,)).real

    # Indices of all states within 1e-4 of the minimum energy.
    min_energy = np.min(ising)
    ground_states = np.arange(len(ising))[np.abs(ising - min_energy)<=1e-4]

    # Store each such index as an N-character binary string.
    gs_triang = dict()
    for s in ground_states:
        bitstring = bin(s)[2::].zfill(N)
        gs_triang[bitstring] = 1

    dataset_graphs = []
    ground_states = []
    e_triang = (1, 2)
    e_g = (0, 1)
    np.random.seed(seed)
    # First class: second maze attached at edge (4, 3) of the piece.
    for _ in range(dataset_size):
        gs = dict()
        g1 = generate_maze_graph(10, max_iter=20, max_nodes=200)
        g2 = generate_maze_graph(10, max_iter=20, max_nodes=200)
        G = concatenate_graphs(triang_piece, g1, e_triang, e_g)
        G = concatenate_graphs(G, g2, (4, 3), e_g)
        dataset_graphs.append(G)
        gs1 = np.array([g1.nodes[node]['c'] for node in g1.nodes()])
        gs2 = np.array([g2.nodes[node]['c'] for node in g2.nodes()])

        gs1_list = [gs1, 1 - gs1]
        gs2_list = [gs2, 1 - gs2]

        for state in gs_triang.keys():
            # Reverse the string so that int_state[k] is the character for
            # node k (presumably the Ising helper puts node 0 in the
            # least-significant bit — verify).
            int_state = np.array(list(state)[::-1]).astype(int)
            for state_1 in gs1_list:
                # The two glued nodes on each linking edge must not both be 1.
                if (state_1[0] + int_state[1] <= 1) and (state_1[1] + int_state[2] <= 1):
                    for state_2 in gs2_list:
                        if (state_2[0] + int_state[4] <= 1) and (state_2[1] + int_state[3] <= 1):
                            state_to_add = "".join(list(np.concatenate([int_state, state_1, state_2]).astype(str))[::-1])
                            gs[state_to_add] = 1
        ground_states.append(gs)
    

    # Second class: same construction, second maze attached at edge (3, 0).
    for _ in range(dataset_size):
        gs = dict()
        g1 = generate_maze_graph(10, max_iter=20, max_nodes=200)
        g2 = generate_maze_graph(10, max_iter=20, max_nodes=200)
        G = concatenate_graphs(triang_piece, g1, e_triang, e_g)
        G = concatenate_graphs(G, g2, (3, 0), e_g)
        dataset_graphs.append(G)
        gs1 = np.array([g1.nodes[node]['c'] for node in g1.nodes()])
        gs2 = np.array([g2.nodes[node]['c'] for node in g2.nodes()])

        gs1_list = [gs1, 1 - gs1]
        gs2_list = [gs2, 1 - gs2]

        for state in gs_triang.keys():
            int_state = np.array(list(state)[::-1]).astype(int)
            for state_1 in gs1_list:
                if (state_1[0] + int_state[1] <= 1) and (state_1[1] + int_state[2] <= 1):
                    for state_2 in gs2_list:
                        if (state_2[0] + int_state[3] <= 1) and (state_2[1] + int_state[0] <= 1):
                            state_to_add = "".join(list(np.concatenate([int_state, state_1, state_2]).astype(str))[::-1])
                            gs[state_to_add] = 1
        ground_states.append(gs)

    # The first half of the dataset is labelled 1, the second half 0.
    targets = np.zeros(2*dataset_size).astype(int)
    targets[0:dataset_size] = 1
    return dataset_graphs, targets, ground_states

def _ladder_with_random_crossings(length):
    """Return a straight ladder of ``length`` rungs with random diagonal crossings."""
    G = generate_maze_graph(init_length=length, max_iter=0)
    assert nx.number_of_nodes(G) == 2*length
    length_sequences = np.random.randint(1, 6, size=500)
    G.add_edge(2, 5)
    index = 2
    for j, l in enumerate(length_sequences):
        # Gaps alternate between even and odd multiples of 2 every four crossings.
        if (j//4) % 2 == 0:
            index += 2*(2*l)
        else:
            index += 2*(2*l + 1)
        if index + 3 > nx.number_of_nodes(G) - 1:
            break
        G.add_edge(index, index + 3)
    # A crossing must never create a node beyond the ladder.
    assert nx.number_of_nodes(G) == 2*length
    return G


def generate_dataset_ladder_with_cross(dataset_size=100, length=200, seed=98):
    """Generate ``2 * dataset_size`` ladders decorated with random crossing edges.

    Each graph is a straight ladder of ``length`` rungs to which diagonal
    edges ``(index, index + 3)`` are added at random spacings, starting with
    ``(2, 5)``.

    NOTE(review): both halves of the dataset are produced by exactly the same
    procedure; only the labels differ.  If the two classes are meant to differ
    structurally, the second generation loop was never specialised.

    The per-graph debug ``print`` has been removed.

    Args:
        dataset_size: number of graphs per label.
        length: number of rungs of every ladder.
        seed: seed for the global NumPy RNG.

    Returns:
        dataset_graphs: list of ``2 * dataset_size`` graphs.
        targets: int array, 1 for the first half and 0 for the second half.
    """
    np.random.seed(seed)

    dataset_graphs = [
        _ladder_with_random_crossings(length).copy()
        for _ in range(2 * dataset_size)
    ]

    targets = np.zeros(2*dataset_size).astype(int)
    targets[0:dataset_size] = 1
    return dataset_graphs, targets

def generate_dataset_maze_graph(dataset_size=100, seed=23):
    """Generate a two-class dataset of twisted maze graphs and their bitstring sets.

    The first ``dataset_size`` graphs (label 1) are twisted at cut pairs whose
    smaller side has 100-200 nodes; the last ``dataset_size`` graphs (label 0)
    at cut pairs whose smaller side has 10-15 nodes.  All other generation
    parameters are shared.

    Args:
        dataset_size: number of graphs per label.
        seed: seed for the global NumPy RNG.

    Returns:
        dataset_graphs: list of ``2 * dataset_size`` graphs.
        targets: int array of labels.
        gs_dict_list: one dict per graph whose keys are the configurations
            returned by ``generate_maze_graph_with_cut`` written as 0/1
            strings (values are 1).
    """
    np.random.seed(seed)
    cut_ranges = [(100, 200), (10, 15)]

    dataset_graphs, gs_dict_list = [], []
    for cut_min, cut_max in cut_ranges:
        for _ in range(dataset_size):
            G, gs_list = generate_maze_graph_with_cut(
                init_length=20, max_iter=500, max_nodes=400, p=.998, n_twist=2,
                cut_size_min=cut_min, cut_size_max=cut_max, max_length=40)
            gs_dict_list.append(
                {''.join(gs.astype(int).astype(str)): 1 for gs in gs_list}
            )
            dataset_graphs.append(G.copy())

    targets = np.zeros(2*dataset_size).astype(int)
    targets[:dataset_size] = 1
    return dataset_graphs, targets, gs_dict_list

def split_dataset(data_list, train_size=.8, val_size=.1, test_size=.1, seed=65):
    """Randomly split ``data_list`` into train / validation / test lists.

    A permutation of the indices is drawn with the global NumPy RNG seeded by
    ``seed``; the first ``int(train_size * n)`` indices go to train, the next
    ``int(val_size * n)`` to validation and the remainder to test.  Within
    each split the items keep their original relative order.

    Each item is assigned through a per-index lookup table instead of an
    ``in``-test against an index array, so the split runs in linear time.

    Args:
        data_list: sized iterable of samples.
        train_size, val_size, test_size: fractions that must sum to 1.
        seed: seed for the global NumPy RNG.

    Returns:
        (data_train, data_val, data_test) as three lists.
    """
    assert np.abs(train_size + val_size + test_size - 1) <= 1e-12
    np.random.seed(seed)
    n_items = len(data_list)
    permutation = np.random.choice(n_items, size=n_items, replace=False)
    num_train = int(train_size*n_items)
    num_val = int(val_size*n_items)

    # assignment[i] is 0 (train), 1 (val) or 2 (test) for item i.
    assignment = np.full(n_items, 2)
    assignment[permutation[:num_train]] = 0
    assignment[permutation[num_train:num_train + num_val]] = 1

    data_train, data_val, data_test = [], [], []
    splits = (data_train, data_val, data_test)
    for data, which in zip(data_list, assignment):
        splits[which].append(data)
    return data_train, data_val, data_test


def corr_obj(
        x: str,
        prob: float
        ) -> np.ndarray:
    """
    Return the weighted correlation matrix of a single sampled bitstring.

    The bitstring is read in reverse order, each bit ``b`` is mapped to the
    spin ``2*b - 1`` and the outer product of the spin vector with itself is
    scaled by ``prob``.

    Args:
        x: str
           bitstring made of '0' / '1' characters,
        prob: float
              weight of the sample (counts / N_samples).
    Returns:
        np.ndarray
            ``(len(x), len(x))`` matrix ``prob * s s^T``.
    """
    bits = np.array(list(x[::-1]), dtype=int)
    spins = (2 * bits - 1).reshape((-1, 1))
    return (spins @ spins.T) * prob

        

def get_bitstrs2corr(
        N: int,
        sampled: dict
        ) -> np.ndarray:
    """
    Accumulate the correlation matrix of a set of sampled bitstrings.

    Every bitstring contributes ``corr_obj(bitstring, count / total)`` where
    ``total`` is the sum of all counts.

    Arg:
        N: int
           size of the (N, N) matrix, i.e. the length of the bitstrings,
        sampled: dict
                 maps each sampled bitstring to its count.
    Returns:
        corr: np.ndarray
              correlation matrix (all zeros when ``sampled`` is empty).
    """
    total = sum(sampled.values())
    corr = np.zeros((N, N))
    for bitstring in sampled:
        corr += corr_obj(bitstring, sampled[bitstring] / total)
    return corr

def generate_ladder_lock_crossing(length=10, n_crossings=2):
    """Build a straight ladder with ``n_crossings`` random diagonal edges.

    Crossing ``k`` (drawn without replacement from ``1 .. length//2 - 1``)
    adds the edge ``(4*k, 4*k + 3)``.  Afterwards the node labels ``c`` are
    checked: if some edge joins two nodes labelled 1, every label is flipped,
    and the result is asserted to have no such edge.

    Args:
        length: number of rungs of the ladder.
        n_crossings: number of diagonal edges to add.

    Returns:
        ``nx.Graph`` whose ``c`` labels mark an independent set.
    """
    crossing_indices = np.random.choice(np.arange(1, length//2), replace=False, size=n_crossings)
    G = generate_maze_graph(init_length=length, max_iter=0)
    G.add_edges_from((4*k, 4*k + 3) for k in crossing_indices)

    def labels_are_independent():
        # No edge may have both endpoints labelled 1.
        return not any(G.nodes[a]['c'] + G.nodes[b]['c'] >= 1.1 for a, b in G.edges())

    if not labels_are_independent():
        for node in G.nodes():
            G.nodes[node]['c'] = 1 - G.nodes[node]['c']

    assert labels_are_independent()
    return G

def generate_graph_frustrated(length=3):
    """Build a ladder with one diagonal at each end and list hand-built 0/1 configurations.

    The graph is the straight ladder of ``length`` rungs returned by
    ``generate_maze_graph(length, max_iter=0)`` plus the diagonal edges
    ``(0, 3)`` and ``(2*length - 4, 2*length - 1)``.

    The configurations are written down explicitly through index arithmetic.
    NOTE(review): they are presumably the ground states (maximum independent
    sets) of this graph, as the variable names suggest, but nothing here
    checks that — verify before relying on it, in particular for small or
    even ``length``.

    Args:
        length: number of rungs of the ladder.

    Returns:
        G: the graph.
        np.ndarray of shape ``(n_states, 2*length)`` with one configuration
        per row.
    """
    G = generate_maze_graph(length, max_iter=0)
    G.add_edges_from([(0, 3), (2*length-4, 2*length-1)])
    ground_states_frustration_array_list = []

    # Configuration 1: nodes 2k and 2k-1 for odd k in [1, length-1).
    gs = np.zeros(2*length).astype(int)
    gs[2*np.arange(1, length-1, 2)] = 1
    gs[2*np.arange(1, length-1, 2)-1] = 1
    ground_states_frustration_array_list.append(gs.copy())

    # Configurations 2 and 3: a similar pattern with the last rung holding
    # the extra node, either on node 2*length-1 or on node 2*length-2.
    gs = np.zeros(2*length).astype(int)
    gs[2*np.arange(1, length-2, 2)] = 1
    gs[2*np.arange(1, length-1, 2)-1] = 1
    gs[2*length - 1] = 1
    ground_states_frustration_array_list.append(gs.copy())
    gs[2*length - 1] = 0
    gs[2*length - 2] = 1
    ground_states_frustration_array_list.append(gs.copy())

    # Configuration 4: the pattern of configuration 1 shifted by one rung.
    gs = np.zeros(2*length).astype(int)
    gs[2*np.arange(1, length-1, 2)+2] = 1
    gs[2*np.arange(1, length-1, 2)+1] = 1
    ground_states_frustration_array_list.append(gs.copy())

    # Configurations 5 and 6: shifted pattern with the first rung holding the
    # extra node, either on node 0 or on node 1.
    gs = np.zeros(2*length).astype(int)
    gs[2*np.arange(1, length-1, 2)+2] = 1
    gs[2*np.arange(1, length-1, 2)[1::]+1] = 1

    gs[0] = 1
    ground_states_frustration_array_list.append(gs.copy())
    gs[0] = 0
    gs[1] = 1
    ground_states_frustration_array_list.append(gs.copy())

    # One more configuration per i in [0, length-4): one alternating pattern
    # up to around rung i and another from around rung i+3 onwards.
    for i in range(length-4):
        gs = np.zeros(2*length).astype(int)
        gs[2*np.arange(0, i+2, 2)+1] = 1
        gs[2*np.arange(0, i+1, 2)+2] = 1
        gs[2*np.arange(i+3+i%2, length-1, 2)+1] = 1
        gs[2*np.arange(i+3-i%2, length-1, 2)+2] = 1
        ground_states_frustration_array_list.append(gs.copy())


    return G, np.array(ground_states_frustration_array_list)

def generate_graph_from_sequence(sequence=[1, 0, 1], lengths=[6, 6, 6], crossings=None, gs_list=None):
    """Chain several ladder segments into one graph and enumerate compatible state combinations.

    Segment types (entries of ``sequence``):
      * ``0``: plain ladder (``generate_maze_graph``); candidate states are
        its ``c`` labelling and the complement.
      * ``1``: ladder with crossings (``generate_ladder_lock_crossing``); the
        single candidate state is its ``c`` labelling.
      * anything else: frustrated ladder (``generate_graph_frustrated``) with
        the states that function returns.

    Consecutive segments are joined by two edges: the last two nodes of one
    segment to the first two nodes of the next.

    The per-segment crossing count ``c`` is now actually passed to
    ``generate_ladder_lock_crossing``.  It used to be computed and then
    ignored in favour of ``int(l/10)*2``.  A value of 0 still selects that
    default, and the count is capped at the number of available crossing
    positions (``l//2 - 1``) so the draw without replacement cannot fail.

    NOTE(review): a type-1 segment that ends up with zero crossings is still
    recorded with a single candidate state, although its complement labelling
    would be valid too — confirm whether that matters for short segments.

    Args:
        sequence: segment types, at most 26 entries (segments are prefixed
            with upper-case letters when merged).
        lengths: number of rungs of each segment.
        crossings: number of crossings per segment (used by type-1 segments);
            ``None`` means 0 everywhere.
        gs_list: precomputed list of per-segment state-index tuples; when
            ``None`` all compatible combinations are enumerated.

    Returns:
        G_union: merged graph with integer labels; every node carries a
            ``gs`` list holding its value in each combination of ``gs_list``.
        gs_list: the list of state-index tuples used.
    """
    ABC = string.ascii_uppercase
    graph_sequence = []
    n_ground_states = []
    ground_states_try = []
    nodes_list = [0]
    if crossings is None:
        crossings = np.zeros_like(sequence).astype(int)

    for t, l, c in zip(sequence, lengths, crossings):
        if t == 0:
            graph = generate_maze_graph(l, 0)
            gs = np.array([graph.nodes[node]['c'] for node in graph.nodes()])
            ground_states_try.append([gs.copy(), 1-gs.copy()])
            n_ground_states.append(2)
        elif t == 1:
            if c == 0:
                c = int(l/10)*2
            # Crossing positions are drawn from 1 .. l//2 - 1 without replacement.
            c = min(int(c), max(l//2 - 1, 0))
            graph = generate_ladder_lock_crossing(l, c)
            gs = np.array([graph.nodes[node]['c'] for node in graph.nodes()])
            ground_states_try.append([gs.copy()])
            n_ground_states.append(1)
        else:
            graph, gs = generate_graph_frustrated(l)
            n_ground_states.append(len(gs))
            ground_states_try.append(gs.copy())

        # Shift the segment to the right of everything placed so far.
        shift = np.array([1, 0]) * (np.sum(nodes_list)//2 + 1)
        for node in graph.nodes():
            graph.nodes[node]['coord'] = graph.nodes[node]['coord'] + shift
        nodes_list.append(nx.number_of_nodes(graph))
        graph_sequence.append(graph.copy())

    # Link the last rung of each segment to the first rung of the next one.
    edges_links = []
    for graph in graph_sequence[0:-1]:
        N = nx.number_of_nodes(graph)
        edges_links.append(((N-2, N-1), (0, 1)))

    if gs_list is None:
        gs_list = []
        iterators = [np.arange(count) for count in n_ground_states]
        for seq in product(*iterators):
            keep = True
            for i, (e1, e2) in enumerate(edges_links):
                left = ground_states_try[i][seq[i]]
                right = ground_states_try[i+1][seq[i+1]]
                # Both ends of a linking edge must not be 1 at the same time.
                if left[e1[0]] + right[e2[0]] >= 1.1 or left[e1[1]] + right[e2[1]] >= 1.1:
                    keep = False
                    break
            if keep:
                gs_list.append(seq)

    for i, graph in enumerate(graph_sequence):
        for node in graph.nodes():
            graph.nodes[node]['gs'] = [ground_states_try[i][gs[i]][node] for gs in gs_list]

    G_union = nx.union_all(graph_sequence, rename=tuple(list(ABC[0:len(graph_sequence)])))
    for i, (e1, e2) in enumerate(edges_links):
        G_union.add_edge(f'{ABC[i]}{e1[0]}', f'{ABC[i+1]}{e2[0]}')
        G_union.add_edge(f'{ABC[i]}{e1[1]}', f'{ABC[i+1]}{e2[1]}')

    G_union = nx.convert_node_labels_to_integers(G_union)

    return G_union, gs_list


def generate_dataset_from_sequence(dataset_size=100, seed=23, compute_gs=False):
    """Generate a two-class dataset of seven-segment chained ladder graphs.

    Every graph uses the segment sequence ``[1, 2, 1, 2, 0, 2, 1]`` with
    random segment lengths ``2 * randint(5, 26)`` and random crossing counts.
    The classes differ only in which segments get an odd length (one extra
    rung): segments ``1, 3, 6`` for label 0 and segments ``0, 1, 2, 3, 6``
    for label 1.

    This function used to pass a ``compute_gs=`` keyword that
    ``generate_graph_from_sequence`` does not accept, so every call raised
    ``TypeError``.  The flag is now mapped onto that function's ``gs_list``
    argument: an empty list skips the enumeration of state combinations,
    ``None`` requests it.

    Args:
        dataset_size: number of graphs per label.
        seed: seed for the global NumPy RNG.
        compute_gs: when True the state combinations are enumerated and stored
            in the nodes' ``gs`` attribute; when False those lists are left
            empty.  (Assumed meaning of the flag — the enumeration is not
            returned either way.)

    Returns:
        graph_list: list of ``2 * dataset_size`` graphs, label 0 first.
        targets: int array, 0 for the first half and 1 for the second half.
    """
    np.random.seed(seed)
    sequence = [1, 2, 1, 2, 0, 2, 1]
    odd_segments_per_class = ([1, 3, 6], [0, 1, 2, 3, 6])

    graph_list = []
    for odd_segments in odd_segments_per_class:
        for _ in range(dataset_size):
            lengths = np.random.randint(5, 26, size=7)
            n_twists = np.random.randint(2, 9, size=7)
            n_twists = np.minimum(lengths, n_twists)
            lengths = 2 * lengths
            lengths[odd_segments] += 1
            G, _ = generate_graph_from_sequence(
                sequence, lengths, n_twists,
                gs_list=None if compute_gs else [])
            graph_list.append(G.copy())

    targets = np.zeros(2*dataset_size).astype(int)
    targets[dataset_size: 2*dataset_size] = 1
    return graph_list, targets
