from itertools import combinations
import networkx as nx
import numpy as np
from numpy import linalg as LA
import torch
from torch_geometric.utils import to_networkx


def qrw_on_graph_from_pyg_graph3(A, duration, num_walkers):
    """
    Compute continuous-time quantum random walk probabilities on a graph.

    The walk Hamiltonian ``H`` is the unweighted adjacency matrix of the
    occupation graph: its nodes are the ways of placing ``num_walkers``
    indistinguishable walkers on distinct sites (enumerated in the order of
    ``itertools.combinations``), and two configurations are connected when a
    single walker hops along an edge of the base graph. States are evolved
    with ``exp(1j * duration * H)``.

    Note: despite the function name, ``A`` is an adjacency matrix, not a
    PyTorch Geometric data object.

    Args:
        A: square ``(n, n)`` adjacency matrix (anything ``np.asarray``
            accepts). Sites ``i < j`` are connected when the entry used for
            the pair is strictly positive; the lower-triangle entry
            ``A[j, i]`` is used when it is non-zero, ``A[i, j]`` otherwise.
            The diagonal is ignored.
        duration: evolution time, either a scalar or a list/tuple of scalars.
        num_walkers: number of walkers, 1 or 2.

    Returns:
        A ``torch.float`` tensor of shape ``(n, n)``, or a list of such
        tensors (one per entry, same order) when ``duration`` is a list or
        tuple.

        * ``num_walkers == 1``: entry ``[i, j]`` is the probability of
          finding the walker on site ``j`` after starting on site ``i``.
        * ``num_walkers == 2``: the walk starts from the uniform
          superposition of all two-walker configurations; entries ``[i, j]``
          and ``[j, i]`` both hold the probability of finding the walkers on
          sites ``i`` and ``j``. The diagonal is zero.

    Raises:
        ValueError: if ``num_walkers`` is not 1 or 2, or if ``A`` is not a
            square 2-D matrix.
    """
    if num_walkers == 2:
        two_walkers = True
    elif num_walkers == 1:
        two_walkers = False
    else:
        raise ValueError("Number of walkers must be 1 or 2")

    adjacency = np.asarray(A)
    if adjacency.ndim != 2 or adjacency.shape[0] != adjacency.shape[1]:
        raise ValueError(
            f"A must be a square 2-D adjacency matrix, got shape {adjacency.shape}"
        )
    n_sites = adjacency.shape[0]

    def hopping_mask(adj):
        """Return the symmetric boolean mask of site pairs joined by an edge."""
        # For a pair i < j the lower-triangle entry adj[j, i] wins when it is
        # non-zero; only strictly positive weights count as an edge.
        weights = np.where(adj.T != 0, adj.T, adj)
        upper = np.triu(weights > 0, k=1)  # k=1 drops self-loops
        return upper | upper.T

    def occupation_hamiltonian(hops, n_walkers):
        """Return the adjacency matrix of the ``n_walkers`` occupation graph."""
        states = list(combinations(range(n_sites), n_walkers))
        position = {state: k for k, state in enumerate(states)}
        neighbours = [np.flatnonzero(row).tolist() for row in hops]
        hamiltonian = np.zeros((len(states), len(states)))
        for k, state in enumerate(states):
            occupied = set(state)
            for src in state:
                for dst in neighbours[src]:
                    if dst in occupied:
                        # Two walkers can never share a site.
                        continue
                    moved = tuple(sorted((occupied - {src}) | {dst}))
                    hamiltonian[k, position[moved]] = 1.0
        return hamiltonian

    hamiltonian = occupation_hamiltonian(
        hopping_mask(adjacency), 2 if two_walkers else 1
    )
    dim = hamiltonian.shape[0]
    if dim:
        # One eigendecomposition serves every requested duration.
        energies, modes = LA.eigh(hamiltonian)
    else:
        energies, modes = np.zeros(0), np.zeros((0, 0))

    def probabilities(theta):
        """Return the (n, n) probability tensor for a single evolution time."""
        phases = np.exp(1j * theta * energies)
        if two_walkers:
            # Evolve the uniform superposition; its overall scale cancels in
            # the normalisation below.
            start = np.ones(dim)
            amplitudes = modes @ (phases * (modes.T @ start))
            weights = np.abs(amplitudes) ** 2
            weights = weights / weights.sum()
            # triu_indices enumerates pairs (i, j), i < j, in the same order
            # as the two-walker configurations.
            rows, cols = np.triu_indices(n_sites, k=1)
            matrix = np.zeros((n_sites, n_sites))
            matrix[rows, cols] = weights
            matrix[cols, rows] = weights
        else:
            propagator = (modes * phases) @ modes.T
            # Row i holds the distribution reached from basis state i, i.e.
            # column i of the propagator, renormalised against round-off.
            matrix = np.abs(propagator.T) ** 2
            matrix = matrix / matrix.sum(axis=1, keepdims=True)
        return torch.tensor(matrix, dtype=torch.float)  # n_nodes x n_nodes

    if isinstance(duration, (list, tuple)):
        return [probabilities(theta) for theta in duration]
    return probabilities(duration)
