import torch
import dgl
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import networkx as nx
import numpy as np

from pyqtorch.core.operation import RX

def obs_ZZ(N, i=0, j=0, device='cpu', type_n=False):
    """Return the diagonal of the two-site product O_i O_j on ``N`` qubits.

    Both single-site operators are diagonal in the computational basis, so only
    their diagonals are stored and combined with Kronecker products; qubit 0 is
    the left-most (most significant) factor.

    Args:
        N: number of qubits; the result has length ``2**N``.
        i, j: qubit indices of the two factors.
        device: device of the returned tensor.
        type_n: if False, O = Z = diag(1, -1); if True, O = n = diag(0, 1).

    Returns:
        1-D float tensor of length ``2**N``.
    """
    if i == j and not type_n:
        # Z_i Z_i = I. For the number operator n_i n_i = n_i, which the general
        # product below already yields (one n factor at position i), so it must
        # not take this shortcut.
        return torch.ones(2**N).to(device)

    identity = torch.ones(2).to(device)
    single = torch.tensor([0, 1.] if type_n else [1., -1]).to(device)

    op_list = [single if k in (i, j) else identity for k in range(N)]
    operator = op_list[0]
    for op in op_list[1:]:
        operator = torch.kron(operator, op)
    return operator

def generate_ising_Ham_from_graph_torch_nx(graph, precomputed_zz, process_edge=None):
    """Build the diagonal of sum_{(u, v) in edges} w_uv Z_u Z_v for a networkx graph.

    Args:
        graph: networkx-style graph; ``graph.edges.data()`` must yield
            ``(u, v, data_dict)`` triples. Node labels are used directly as
            qubit indices, so they presumably must be 0..N-1 — confirm.
        precomputed_zz: mapping ``N -> {(i, j): diag(Z_i Z_j)}``. A pair is
            looked up as ``(u, v)`` first, then ``(v, u)``. If None, each term
            is computed on the fly with ``obs_ZZ`` (as ``return_observables_torch``
            does) instead of being silently skipped.
        process_edge: optional callable mapping ``data_dict['attr']`` to the
            edge weight; only used when the edge's data dict is non-empty.
            Otherwise the weight is 1.

    Returns:
        1-D float tensor of length ``2**N`` on the CPU.
    """
    N = graph.number_of_nodes()
    H = torch.zeros(2**N)

    for u, v, data in graph.edges.data():
        edge_weight = 1
        if data and process_edge is not None:
            edge_weight = process_edge(data['attr'])

        if precomputed_zz is None:
            term = obs_ZZ(N, u, v)
        else:
            table = precomputed_zz[N]
            # Tables typically store each unordered pair in one orientation only.
            key = (u, v) if (u, v) in table else (v, u)
            term = table[key]
        H += edge_weight * term

    return H

def generate_ising_Ham_from_graph_torch(graph, precomputed_zz, process_edge=None):
    """Build the diagonal of sum_{{u, v}} Z_u Z_v over the distinct edges of a graph.

    ``graph.all_edges()`` is expected to return ``(src, dst)`` index tensors
    (presumably a DGL graph). A bidirected graph lists every edge twice, so
    each unordered pair ``{u, v}`` contributes exactly once. The orientation
    seen first is used for the table lookup.

    Args:
        graph: graph exposing ``number_of_nodes()`` and ``all_edges()``.
        precomputed_zz: mapping ``N -> {(i, j): diag(Z_i Z_j)}``, looked up as
            ``(u, v)`` then ``(v, u)``. If None, each term is computed with
            ``obs_ZZ`` (as ``return_observables_torch`` does) instead of being
            silently skipped.
        process_edge: accepted for interface parity with
            ``generate_ising_Ham_from_graph_torch_nx`` but currently unused;
            every edge has weight 1.

    Returns:
        1-D float tensor of length ``2**N`` on the CPU.
    """
    N = graph.number_of_nodes()
    H = torch.zeros(2**N)
    seen_pairs = set()

    src, dst = graph.all_edges()
    for u, v in zip(src.tolist(), dst.tolist()):
        pair = (min(u, v), max(u, v))
        if pair in seen_pairs:
            continue
        seen_pairs.add(pair)

        if precomputed_zz is None:
            term = obs_ZZ(N, u, v)
        else:
            table = precomputed_zz[N]
            term = table[(u, v)] if (u, v) in table else table[(v, u)]
        H += term
    return H

def generate_ising_matrices_torch(G_list, process_edge=None, precomputed_zz=None, device='cpu'):
    """Return one Ising-Hamiltonian diagonal per graph, each moved to ``device``.

    Every diagonal is built by ``generate_ising_Ham_from_graph_torch``;
    ``process_edge`` and ``precomputed_zz`` are forwarded unchanged. After each
    graph the CUDA caching allocator is asked to release unused blocks.
    """
    def _build_on_device(graph):
        # The host-side tensor is dropped as soon as the device copy exists.
        on_device = generate_ising_Ham_from_graph_torch(
            graph, process_edge=process_edge, precomputed_zz=precomputed_zz
        ).to(device)
        torch.cuda.empty_cache()
        return on_device

    return [_build_on_device(graph) for graph in G_list]

def compute_all_ising_matrices(graphs_list, NN_matrices):
    """Stack the Ising diagonals of each batch of graphs column-wise.

    Args:
        graphs_list: iterable of graph batches.
        NN_matrices: precomputed ZZ table, forwarded as ``precomputed_zz``.

    Returns:
        List with one tensor per batch, where column ``b`` is the flattened
        diagonal of graph ``b``. All graphs in a batch must yield diagonals of
        equal length (i.e. the same node count) for the stacking to succeed.
    """
    def _stack_batch(batch):
        diagonals = generate_ising_matrices_torch(batch, precomputed_zz=NN_matrices)
        return torch.stack([diag.reshape(-1) for diag in diagonals], dim=1)

    return [_stack_batch(batch) for batch in graphs_list]

def return_observables_torch(max_N=10, device='cpu', precomputed_zz=None):
    """Build symmetric tables of ZZ diagonals for every qubit count 2..max_N.

    Args:
        max_N: largest qubit count (inclusive); values below 2 give ``{}``.
        device: device the tables are placed on.
        precomputed_zz: optional mapping ``n -> {(i, j): diag}`` holding the
            upper-triangle pairs ``i <= j``; when None, ``obs_ZZ`` is used.

    Returns:
        dict ``n -> tensor`` of shape ``(n, n, 2**n)`` with entry ``[i, j]``
        equal to entry ``[j, i]``.
    """
    observables = {}
    for n_qubits in range(2, max_N + 1):
        table = torch.zeros((n_qubits, n_qubits, 2**n_qubits), device=device)
        upper_pairs = ((a, b) for a in range(n_qubits) for b in range(a, n_qubits))
        for a, b in upper_pairs:
            if precomputed_zz is None:
                diag = obs_ZZ(n_qubits, a, b)
            else:
                diag = precomputed_zz[n_qubits][(a, b)]
            table[a, b] = diag.to(device)
            # Mirror into the lower triangle (copies values).
            table[b, a] = table[a, b]
        observables[n_qubits] = table
        del table
        torch.cuda.empty_cache()
    return observables

def update_x(state, X):
    """Apply the same 2x2 gate ``X`` to every qubit axis of ``state``.

    ``state`` has one size-2 axis per qubit followed by a trailing batch axis,
    so the qubit count is ``state.dim() - 1``. Returns a tensor of the same shape.
    """
    n_qubits = state.dim() - 1
    for axis in range(n_qubits):
        # tensordot contracts X's input index with this qubit axis and puts X's
        # output index first; move it back into the qubit's original slot.
        state = torch.movedim(torch.tensordot(X, state, dims=([1], [axis])), 0, axis)
    return state


def update_ising_matrix(psi, ising_matrix, tf=.3):
    """Evolve ``psi`` for time ``tf`` under a diagonal Hamiltonian.

    ``ising_matrix`` holds the Hamiltonian's diagonal, so the propagator
    exp(-i * tf * H) is an elementwise phase multiplied onto ``psi``
    (standard broadcasting rules apply).
    """
    propagator = torch.exp(-1j * tf * ising_matrix)
    return propagator * psi


def return_exp_x(t, dev='cpu'):
    """Return the 2x2 single-qubit rotation exp(-i t X) = cos(t) I - i sin(t) X.

    Args:
        t: rotation parameter, as a tensor or a plain Python number (numbers
            are converted with ``torch.as_tensor``; ``torch.cos`` alone
            rejects them).
        dev: device for the constant I and X matrices.

    Returns:
        Complex 2x2 tensor.
    """
    t = torch.as_tensor(t)
    identity = torch.eye(2).to(dev)
    pauli_x = torch.tensor([[0, 1], [1, 0]]).to(dev)
    return identity * torch.cos(t) - 1j * pauli_x * torch.sin(t)


def QAOA_state(N, ising_matrix, times, pulses, batch_size=1, omega=2):
    """Simulate a layered QAOA-style circuit starting from the uniform superposition.

    Each layer multiplies the state by the diagonal phase exp(-1j * t * H)
    (``update_ising_matrix``) and then calls ``RX`` with angle
    ``p * pi * omega`` on every qubit. Layers are taken from
    ``zip(times, pulses)``, so the shorter sequence sets the depth.

    Args:
        N: number of qubits.
        ising_matrix: Ising diagonal(s), reshaped to ``[2] * N + [batch_size]``,
            so it must hold ``2**N * batch_size`` elements.
        times: per-layer evolution times.
        pulses: per-layer pulse amplitudes.
        batch_size: number of states simulated side by side.
        omega: scale factor in the RX angle.

    Returns:
        The state, starting in layout ``[2] * N + [batch_size]`` (assumes
        ``RX`` preserves that layout — TODO confirm against pyqtorch).
    """
    dev = ising_matrix.device
    layout = [2] * N + [batch_size]
    state = (torch.ones((2**N, batch_size)).to(dev) / np.sqrt(2**N)).reshape(layout)
    ising_matrix = ising_matrix.reshape(layout)

    for t, p in zip(times, pulses):
        state = update_ising_matrix(state, ising_matrix, t)
        angle = p * np.pi * omega
        for qubit in range(N):
            state = RX(angle, state, qubits=[qubit], N_qubits=N)

    # Drop the local reshaped view before releasing cached CUDA blocks.
    del ising_matrix
    torch.cuda.empty_cache()

    return state


def return_obs_on_state(state, observable, N, batch_size=1):
    """Expectation values of a stack of diagonal observables on a batch of states.

    Computes ``out[i, j, b] = sum_k |state[k, b]|**2 * observable[i, j, k]``.

    Args:
        state: amplitudes as ``(2**N, B)``; the qubit-tensor layout
            ``[2] * N + [B]`` (batch last) and a single 1-D ``(2**N,)`` state
            are also accepted and flattened to ``(2**N, B)``.
        observable: diagonals of shape ``(N, N, 2**N)`` (broadcast over
            ``batch_size``) or already ``(N, N, 2**N, B)``. Moved to
            ``state.device`` if needed.
        N: number of qubits (kept for interface compatibility; shapes are read
            from the tensors).
        batch_size: batch size used to broadcast a 3-D observable; when 1,
            the batch axis is squeezed from the result if it has size 1.

    Returns:
        Real tensor of shape ``(N, N, B)``, or ``(N, N)`` when squeezed.
    """
    if observable.device != state.device:
        observable = observable.to(state.device)

    probs = torch.abs(state) ** 2
    if probs.dim() == 1:
        probs = probs.unsqueeze(-1)
    elif probs.dim() > 2:
        # Qubit-tensor layout [2]*N + [B]: flatten the qubit axes, keep batch last.
        probs = probs.reshape(-1, probs.shape[-1])

    if observable.dim() == 3:
        # expand() is a zero-copy view; it replaces the original repeat().
        observable = observable.unsqueeze(3).expand(-1, -1, -1, batch_size)

    # Broadcasting (1, 1, 2**N, B) against (N, N, 2**N, B) avoids materialising
    # N*N copies of the probabilities.
    matrix = torch.sum(probs[None, None] * observable, dim=2)
    del probs
    del observable
    torch.cuda.empty_cache()
    if batch_size == 1:
        matrix = matrix.squeeze(2)

    return matrix


def masked_softmax(a, mask, dim=1):
    """Softmax of ``a`` along ``dim`` restricted to entries where ``mask`` is non-zero.

    Computes ``exp(a) * mask`` normalised by its sum along ``dim``. The 1e-15
    term keeps fully masked slices at 0 instead of NaN. The denominator is
    broadcast with ``keepdim=True``, so any shape and any ``dim`` work (the
    previous ``repeat`` only handled square inputs with ``dim=1``).

    NOTE(review): there is no max-subtraction, so very large logits can
    overflow ``exp``.
    """
    weights = torch.exp(a) * mask
    return weights / (torch.sum(weights, dim=dim, keepdim=True) + 1e-15)

def masked_softmax_batch(a, mask, dim=1):
    """Batched masked softmax over the last axis (axis 2) of a 3-D tensor.

    Computes ``exp(a) * mask`` normalised by its sum along axis 2, with a
    1e-15 term keeping fully masked rows at 0. The denominator is broadcast
    with ``keepdim=True``, so non-square ``(B, M, K)`` inputs work (the previous
    ``repeat`` required ``M == K``).

    ``dim`` is ignored; normalisation is always over axis 2. It is kept only
    so existing callers do not break.

    NOTE(review): there is no max-subtraction, so very large logits can
    overflow ``exp``.
    """
    weights = torch.exp(a) * mask
    return weights / (torch.sum(weights, dim=2, keepdim=True) + 1e-15)
