import torch
import dgl
import torch.nn as nn
import torch.nn.init as init
from utils import masked_softmax, return_obs_on_state, QAOA_state, masked_softmax_batch
from pyqtorch.core.operation import X, Y, Z, H, Sd, S
import time

class QuantumAttention(nn.Module):
    """Attention matrix obtained by measuring observables on a QAOA state.

    Parameters
    ----------
    n_layers : int
        Length of the learnable ``times`` and ``pulses`` vectors handed to
        ``QAOA_state``.
    observables : dict, optional
        Mapping from the number of qubits ``N`` to the observables used for
        that size. It is only consulted by ``forward`` when no observables
        are passed explicitly.
    """

    def __init__(self, n_layers, observables=None):
        super(QuantumAttention, self).__init__()
        self.times = nn.Parameter(torch.empty((n_layers, )))
        self.pulses = nn.Parameter(torch.empty((n_layers, )))
        self.observables = observables
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw ``times`` from U(0, 2) and ``pulses`` from U(0, 0.3)."""
        init.uniform_(self.times, 0, 2.)
        init.uniform_(self.pulses, 0, .3)

    def update_observable_device(self, device):
        """Move every stored observable to ``device``.

        The dictionary is updated in place (it may be shared with other
        modules). Nothing happens when no observables were stored.
        """
        if self.observables is None:
            return
        for key, value in self.observables.items():
            self.observables[key] = value.to(device)

    def forward(self, N, ising_matrix, batch_size=1, observables=None):
        """Forward computation

        Parameters
        ----------
        N : int
            Number of qubits
        ising_matrix : torch.Tensor
            Ising matrix of size (2^N, batch_size)
        batch_size : int
            Batch size
        observables : optional
            Observables to measure. When ``None`` the entry ``N`` of the
            dictionary given at construction time is used.

        Returns
        -------
        torch.Tensor of shape (N, N, batch_size)
            Attention matrices

        Raises
        ------
        ValueError
            If no observables are passed and none were stored on the module.
        """
        # Resolve the observables first so a misconfiguration fails before
        # the (expensive) state preparation.
        if observables is None:
            if self.observables is None:
                raise ValueError(
                    "no observables available: pass `observables` to forward() "
                    "or provide them when constructing the module"
                )
            observables = self.observables[N]

        state = QAOA_state(N, ising_matrix,
                           self.times, self.pulses, batch_size=batch_size)
        matrix = return_obs_on_state(state, observables, N, batch_size)

        # Drop the state before asking CUDA to release cached blocks.
        del state
        torch.cuda.empty_cache()
        return matrix


class QuantumAttentionCorrelators(nn.Module):
    """Attention scores built from two-qubit correlators of a QAOA state.

    For every ordered pair ``(i, j)`` nine correlator channels are measured
    and combined linearly with the learnable weights ``gamma``::

        A_ij = sum_c gamma_c * C_ij^(c)

    Channel layout, assuming ``H``, ``S`` and ``Sd`` are the Hadamard, phase
    and inverse-phase gates and the prepared state is expressed in the Z
    basis (TODO confirm against pyqtorch / ``QAOA_state``):

    * 0 : Z_i Z_j (diagonal entries are set to 1)
    * 1 : X_i X_j
    * 2 : Y_i Y_j
    * 3 : X_i Z_j, 4 : X_i Y_j, 5 : Y_i Z_j -- only filled for pairs whose
      indices have opposite parity; same-parity pairs stay 0
    * 6-8 : channels 3-5 with ``i`` and ``j`` swapped

    Parameters
    ----------
    n_layers : int
        Length of the learnable ``times`` and ``pulses`` vectors handed to
        ``QAOA_state``.
    observables : dict, optional
        Stored on the module and moved by ``update_observable_device``;
        ``forward`` itself does not read it.
    """

    def __init__(self, n_layers, observables=None):
        super(QuantumAttentionCorrelators, self).__init__()
        self.times = nn.Parameter(torch.empty((n_layers, )))
        self.pulses = nn.Parameter(torch.empty((n_layers, )))
        self.gamma = nn.Parameter(torch.empty((1, 1, 9, 1)))

        self.observables = observables
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw ``times`` ~ U(0, 2), ``pulses`` ~ U(0, 0.3), ``gamma`` ~ U(-5, 5)."""
        init.uniform_(self.times, 0, 2.)
        init.uniform_(self.pulses, 0, .3)
        init.uniform_(self.gamma, -5, 5)

    def update_observable_device(self, device):
        """Move every stored observable to ``device`` (in place).

        Nothing happens when no observables were stored.
        """
        if self.observables is None:
            return
        for key, value in self.observables.items():
            self.observables[key] = value.to(device)

    @staticmethod
    def _distinct(i, j):
        """Select every off-diagonal pair."""
        return i != j

    @staticmethod
    def _even_odd(i, j):
        """Select pairs with ``i`` even and ``j`` odd."""
        return i % 2 == 0 and j % 2 == 1

    @staticmethod
    def _odd_even(i, j):
        """Select pairs with ``i`` odd and ``j`` even."""
        return i % 2 == 1 and j % 2 == 0

    @staticmethod
    def _add_pair_correlations(matrices, state, N, channel, select):
        """Add the two-qubit parity expectation of ``state`` into one channel.

        For each pair ``(i, j)`` accepted by ``select`` the squared amplitudes
        are summed over all other qubit axes and the signed combination
        ``p00 + p11 - p01 - p10`` is accumulated in
        ``matrices[i, j, channel]``. Qubit ``k`` is taken to live on axis
        ``N - k - 1`` of ``state``.
        """
        probabilities = torch.abs(state) ** 2
        for i in range(N):
            for j in range(N):
                if not select(i, j):
                    continue
                dims = [N - k - 1 for k in range(N) if k not in (i, j)]
                # With N == 2 there is nothing to trace out; torch.sum with an
                # empty dim list would reduce *every* axis, so skip it.
                reduction = torch.sum(probabilities, dims) if dims else probabilities
                matrices[i, j, channel] += (reduction[0, 0] + reduction[1, 1]
                                            - reduction[0, 1] - reduction[1, 0])

    def forward(self, N, ising_matrix, batch_size=1, observables=None, return_raw_matrices=False):
        """Forward computation

        Parameters
        ----------
        N : int
            Number of qubits
        ising_matrix : torch.Tensor
            Ising matrix of size ([2] * N, batch_size)
        batch_size : int
            Batch size
            default = 1
        observables : optional
            Accepted for interface compatibility; not used here.
        return_raw_matrices : bool
            When True, return the un-weighted correlators of shape
            (N, N, 9, batch_size) instead of the weighted sum.

        Returns
        -------
        torch.Tensor
        Attention matrix of size (N, N, batch_size); the trailing batch
        axis is squeezed away when ``batch_size == 1``.
        """
        state = QAOA_state(N, ising_matrix,
                           self.times, self.pulses, batch_size=batch_size)
        del ising_matrix
        torch.cuda.empty_cache()
        matrices = torch.zeros((N, N, 9, batch_size), device=state.device)

        # NOTE: the gate sequences below are applied cumulatively to `state`;
        # each step first undoes the previous basis change, so their order
        # must not be altered.

        # Channel 0: state as prepared; the diagonal is fixed to 1.
        self._add_pair_correlations(matrices, state, N, 0, self._distinct)
        for i in range(N):
            matrices[i, i, 0] += 1

        # Channel 1: H on every qubit.
        for i in range(N):
            state = H(state, qubits=[i], N_qubits=N)
        self._add_pair_correlations(matrices, state, N, 1, self._distinct)

        # Channel 2: undo H, then Sd followed by H on every qubit.
        for i in range(N):
            # recover original state
            state = H(state, qubits=[i], N_qubits=N)
            # change in Y basis
            state = Sd(state, qubits=[i], N_qubits=N)
            state = H(state, qubits=[i], N_qubits=N)
        self._add_pair_correlations(matrices, state, N, 2, self._distinct)

        # Channel 3, (even i, odd j): restore every qubit, H on even ones.
        for i in range(N):
            # recover original state
            state = H(state, qubits=[i], N_qubits=N)
            state = S(state, qubits=[i], N_qubits=N)
            if i % 2 == 0:
                state = H(state, qubits=[i], N_qubits=N)
        self._add_pair_correlations(matrices, state, N, 3, self._even_odd)

        # Channel 3, (odd i, even j): H everywhere swaps the two parities.
        for i in range(N):
            state = H(state, qubits=[i], N_qubits=N)
        self._add_pair_correlations(matrices, state, N, 3, self._odd_even)

        # Channel 4, (even i, odd j): odd qubits get Sd+H, even qubits get H.
        for i in range(N):
            if i % 2 == 1:
                state = H(state, qubits=[i], N_qubits=N)
                state = Sd(state, qubits=[i], N_qubits=N)
            state = H(state, qubits=[i], N_qubits=N)
        self._add_pair_correlations(matrices, state, N, 4, self._even_odd)

        # Channel 4, (odd i, even j): exchange the roles of the parities.
        for i in range(N):
            state = H(state, qubits=[i], N_qubits=N)
            if i % 2 == 1:
                state = S(state, qubits=[i], N_qubits=N)
            if i % 2 == 0:
                state = Sd(state, qubits=[i], N_qubits=N)
            state = H(state, qubits=[i], N_qubits=N)
        self._add_pair_correlations(matrices, state, N, 4, self._odd_even)

        # Channel 5, (even i, odd j): undo H on the odd qubits only.
        for i in range(N):
            if i % 2 == 1:
                state = H(state, qubits=[i], N_qubits=N)
        self._add_pair_correlations(matrices, state, N, 5, self._even_odd)

        # Channel 5, (odd i, even j): restore even qubits, Sd+H on odd ones.
        for i in range(N):
            if i % 2 == 0:
                state = H(state, qubits=[i], N_qubits=N)
                state = S(state, qubits=[i], N_qubits=N)
            else:
                state = Sd(state, qubits=[i], N_qubits=N)
                state = H(state, qubits=[i], N_qubits=N)
        self._add_pair_correlations(matrices, state, N, 5, self._odd_even)

        # Channels 6-8 are channels 3-5 with the pair indices swapped.
        matrices[:, :, 6:9] = torch.transpose(matrices[:, :, 3:6], 0, 1)

        if return_raw_matrices:
            return matrices

        matrix = torch.sum(matrices * self.gamma.expand(N, N, -1, batch_size), dim=2)
        if batch_size == 1:
            matrix = matrix.squeeze(2)

        # Drop the large intermediates before asking CUDA to release its cache.
        del state
        del matrices
        torch.cuda.empty_cache()
        return matrix


class QuantumAttentionHead(nn.Module):
    """Single quantum-attention head operating on (batched) DGL graphs.

    The node features are concatenated with an attention-weighted
    aggregation of the features and projected by a linear layer.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    n_att_layers : int
        Forwarded to ``QuantumAttentionCorrelators`` as its ``n_layers``.
    observables : optional
        Forwarded to ``QuantumAttentionCorrelators``.
    apply_softmax : bool
        When True the masked scores are scaled by the learnable scalar ``a``
        and passed through ``masked_softmax``.
    only_neighbors : bool
        When True the mask is the dense graph adjacency; otherwise only the
        diagonal is masked out.
    """
    def __init__(self, in_feat, out_feat, n_att_layers, observables=None, apply_softmax=False, only_neighbors=False):
        super(QuantumAttentionHead, self).__init__()
        # Sub-modules / parameters are registered in this order on purpose:
        # attention, scale, projection.
        self.attention = QuantumAttentionCorrelators(n_att_layers, observables)
        self.a = nn.Parameter(torch.empty((1, )))
        # Projects the concatenation [own features, aggregated features].
        self.linear = nn.Linear(in_feat * 2, out_feat)
        self.apply_softmax = apply_softmax
        self.only_neighbors = only_neighbors
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the scale ``a`` ~ U(-10, 10) and both sub-modules."""
        init.uniform_(self.a, -10., 10.)
        self.attention.reset_parameters()
        self.linear.reset_parameters()

    def update_observable_device(self, device):
        """Delegate to the attention module's ``update_observable_device``."""
        self.attention.update_observable_device(device)

    def _mask_scores(self, scores, mask):
        """Apply ``mask`` to ``scores`` and, if enabled, the scaled softmax."""
        mask = mask.to(scores.device)
        masked = scores * mask
        if not self.apply_softmax:
            return masked
        return masked_softmax(self.a * masked, mask, dim=1)

    def _block_transition(self, g, blocks):
        """Assemble per-graph score blocks into one masked block-diagonal matrix."""
        joint = torch.block_diag(*blocks)
        if self.only_neighbors:
            mask = g.adj().to_dense()
        else:
            mask = torch.block_diag(*[1 - torch.eye(block.shape[0]) for block in blocks])
        return self._mask_scores(joint, mask)

    def forward(self, g, h, ising_matrices, unbatch=True, batch_size=1, precomputed_attention=None, observables=None):
        """Forward computation

        Parameters
        ----------
        g : Graph
            The input graph.
        h : Tensor
            The input node feature.
        ising_matrices :
            Ising matrices; iterated per graph when ``unbatch`` is True,
            otherwise passed to the attention module as a whole.
        unbatch : bool
            When True, attention is computed graph by graph on
            ``dgl.unbatch(g)``; otherwise in a single batched call that uses
            the node count of the first graph.
        batch_size : int
            Batch size used by the batched (``unbatch=False``) path.
        precomputed_attention: list of Tensor
            The attention matrices precomputed
        observables : optional
            Forwarded to the attention module (one entry per graph when
            ``unbatch`` is True).

        Returns
        -------
        Tensor
            ``linear(cat([h, T @ h], dim=1))`` where ``T`` is the masked
            block-diagonal attention matrix.
        """
        with g.local_scope():
            g.ndata['h'] = h
            if precomputed_attention is not None:
                trans_matrix = self._block_transition(g, precomputed_attention)
            elif unbatch:
                if observables is None:
                    observables = [None] * len(ising_matrices)
                per_graph = []
                for ising, graph, obs in zip(ising_matrices, dgl.unbatch(g), observables):
                    scores = self.attention(graph.num_nodes(), ising, batch_size=1, observables=obs)
                    if self.only_neighbors:
                        mask = graph.adj().to_dense()
                    else:
                        mask = 1 - torch.eye(scores.shape[0])
                    per_graph.append(self._mask_scores(scores, mask))
                trans_matrix = torch.block_diag(*per_graph)
            else:
                n_nodes = int(g.batch_num_nodes()[0])
                stacked = self.attention(n_nodes, ising_matrices, batch_size=batch_size, observables=observables)
                if batch_size > 1:
                    blocks = [stacked[:, :, b] for b in range(batch_size)]
                else:
                    blocks = [stacked]
                trans_matrix = self._block_transition(g, blocks)
            aggregated = torch.matmul(trans_matrix, h)
            return self.linear(torch.cat([h, aggregated], dim=1))

class QuantumAttentionHeadCustom(nn.Module):
    """Graph convolution module  model adding Quantum weights usage

    Works on dense, batched tensors instead of DGL graphs.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    n_att_layers : int
        Number of layers in the quantum dynamics.
    observables : optional
        Forwarded to ``QuantumAttentionCorrelators``.
    apply_softmax : bool
        Whether to apply softmax on the attention matrix.
    only_neighbors : bool
        Whether to use only neighbors in the attention matrix.
        Equivalent to the strict message passing.
    """
    def __init__(self, in_feat, out_feat, n_att_layers, observables=None, apply_softmax=False, only_neighbors=False):
        super(QuantumAttentionHeadCustom, self).__init__()
        self.attention = QuantumAttentionCorrelators(n_att_layers, observables)
        # Learnable scale applied to the scores before the softmax.
        self.a = nn.Parameter(torch.empty((1, )))
        # Projects the concatenation [own features, aggregated features].
        self.linear = nn.Linear(in_feat * 2, out_feat)
        self.apply_softmax = apply_softmax
        self.only_neighbors = only_neighbors
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the scale ``a`` ~ U(-10, 10) and both sub-modules."""
        init.uniform_(self.a, -10., 10.)
        self.attention.reset_parameters()
        self.linear.reset_parameters()

    def update_observable_device(self, device):
        """Delegate to the attention module's ``update_observable_device``."""
        self.attention.update_observable_device(device)

    def forward(self, h, ising_matrices, adjacency = None, batch_size=1, precomputed_attention=None, observables=None):
        """Forward computation

        Parameters
        ----------
        h : Tensor of shape (batch_size, N, in_feat)
            The input node feature.
        ising_matrices : Tensor of shape (2**N, batch_size)
            The input ising matrices.
            Only read when precomputed_attention is None.
        adjacency : Tensor of shape (batch_size, N, N)
            The input adjacency matrices.
            Required when the head was built with only_neighbors=True.
        batch_size : int
            The batch size.
            default 1
        precomputed_attention: Tensor of shape (batch_size, N, N)
            The attention matrices precomputed
            default None
        observables : optional
            Forwarded to the attention module.

        Returns
        -------
        Tensor of shape (batch_size, N, out_feat)

        Raises
        ------
        ValueError
            If the head uses only_neighbors=True and no adjacency is given.
        """
        N = h.shape[1]
        if self.only_neighbors:
            if adjacency is None:
                raise ValueError(
                    "adjacency must be provided when only_neighbors=True"
                )
            mask = adjacency
        else:
            # Everything but self-connections is allowed.
            mask = (1 - torch.eye(N)).unsqueeze(0).expand(batch_size, -1, -1)

        if precomputed_attention is not None:
            trans_matrix = precomputed_attention
        else:
            trans_matrix = self.attention(N, ising_matrices, batch_size=batch_size, observables=observables)
            if batch_size > 1:
                # (N, N, batch) -> (batch, N, N)
                trans_matrix = trans_matrix.permute(2, 0, 1)
            else:
                # The attention module squeezes the batch axis for a single
                # sample; add it back in front.
                trans_matrix = trans_matrix.unsqueeze(0)

        mask = mask.to(trans_matrix.device)
        trans_matrix = trans_matrix * mask
        if self.apply_softmax:
            trans_matrix = self.a * trans_matrix
            trans_matrix = masked_softmax_batch(trans_matrix, mask, dim=1)

        # Batched matrix product: aggregate node features per sample.
        h_N = torch.einsum('ijk,ikl->ijl', trans_matrix, h)
        h_total = torch.cat([h, h_N], dim=2)
        return self.linear(h_total)


class MultiHeadQuantumAttention(nn.Module):
    """Several ``QuantumAttentionHead`` modules run side by side.

    The per-head outputs are concatenated along the feature axis (dim 1).

    Parameters
    ----------
    in_feat : int
        Input feature size of every head.
    out_head : int
        Output feature size of every head.
    n_heads : int
        Number of heads.
    n_att_layers : int
        Forwarded to every head.
    observables : optional
        Forwarded to every head (the same object is shared by all heads).
    apply_softmax : bool
        Forwarded to every head.
    only_neighbors : bool
        Forwarded to every head.
    """

    def __init__(self, in_feat, out_head, n_heads, n_att_layers, observables=None, apply_softmax=False, only_neighbors=False):
        super(MultiHeadQuantumAttention, self).__init__()
        self.heads = nn.ModuleList([
            QuantumAttentionHead(in_feat, out_head, n_att_layers, observables, apply_softmax, only_neighbors)
            for _ in range(n_heads)
        ])
        self.reset_parameters()

    def forward(self, g, h, ising_matrices, unbatch=True, batch_size=1, observables=None):
        """Run every head on the same inputs and concatenate the results.

        Parameters
        ----------
        g : Graph
            The input graph.
        h : Tensor
            The input node feature.
        ising_matrices :
            Ising matrices, forwarded unchanged to every head.
        unbatch : bool
            Forwarded to every head.
        batch_size : int
            Forwarded to every head; the heads use it on their batched
            (``unbatch=False``) path. default 1
        observables : optional
            Forwarded to every head. default None

        Returns
        -------
        Tensor
            Head outputs concatenated along dim 1.
        """
        outputs = [
            head(g, h, ising_matrices, unbatch, batch_size=batch_size, observables=observables)
            for head in self.heads
        ]
        return torch.cat(outputs, dim=1)

    def reset_parameters(self):
        """Re-initialise the parameters of every head."""
        for head in self.heads:
            head.reset_parameters()

    def update_observable_device(self, device):
        """Ask every head to move its observables to ``device``."""
        for head in self.heads:
            head.update_observable_device(device)

class MultiHeadQuantumAttentionCustom(nn.Module):
    """Several ``QuantumAttentionHeadCustom`` modules run side by side.

    The per-head outputs are concatenated along the last (feature) axis.
    """

    def __init__(self, in_feat, out_head, n_heads, n_att_layers, observables=None, apply_softmax=False, only_neighbors=False):
        super(MultiHeadQuantumAttentionCustom, self).__init__()
        # Build the heads one after another, then register them as a list
        # of sub-modules.
        head_modules = []
        for _ in range(n_heads):
            head_modules.append(
                QuantumAttentionHeadCustom(in_feat, out_head, n_att_layers, observables, apply_softmax, only_neighbors)
            )
        self.heads = nn.ModuleList(head_modules)
        self.n_heads = n_heads
        self.reset_parameters()

    def forward(self, h, ising_matrices, adjacency = None, batch_size=1, precomputed_attention=None, observables=None):
        '''
        Run every head on the same inputs and concatenate along dim 2.

        precomputed_attention of shape (batch_size, n_heads, n_nodes, n_nodes);
        head ``i`` receives the slice ``precomputed_attention[:, i]``, or
        ``None`` when no precomputed attention is given.
        '''
        per_head = []
        for index, head in enumerate(self.heads):
            head_attention = None if precomputed_attention is None else precomputed_attention[:, index]
            per_head.append(
                head(h, ising_matrices, adjacency, batch_size, head_attention, observables)
            )
        return torch.cat(per_head, dim=2)

    def reset_parameters(self):
        """Re-initialise the parameters of every head."""
        for head in self.heads:
            head.reset_parameters()
    
        