import torch
from torch.nn import Linear
from torch_geometric.graphgym.register import (
    register_node_encoder,register_edge_encoder
)
from torch_geometric.utils import (
    add_remaining_self_loops,
    coalesce,
    to_dense_adj,
)

def extract_common_tensor(data_batch,batch_num):
    """Build the dense, symmetrised ``common`` matrix of one graph of a batch.

    Args:
        data_batch: batched graph object exposing ``ptr`` (cumulative node
            offsets), ``common_index`` (a ``[2, M]`` tensor of global node
            index pairs) and ``common_val`` (one value per pair).
        batch_num: position of the graph inside the batch, ranging from 0 to
            batch size - 1.

    Returns:
        A ``[N, N]`` float tensor ``C + C^T`` where ``N`` is the number of
        nodes of the selected graph and ``C[i, j]`` holds the value stored
        for the (local) pair ``(i, j)``. It lives on the device of
        ``common_val``.
    """
    lo, hi = data_batch.ptr[batch_num].item(), data_batch.ptr[batch_num + 1].item()
    num_nodes = hi - lo

    common_index = data_batch.common_index
    common_val = data_batch.common_val

    # Allocate next to the values so the function also works for batches that
    # were moved to an accelerator.
    com_ij = torch.zeros((num_nodes, num_nodes), device=common_val.device)

    # Keep a pair only when BOTH endpoints belong to this graph. Selecting a
    # column once (instead of once per matching endpoint) avoids duplicated
    # indices, and rejecting straddling pairs avoids negative local indices
    # that would silently wrap around.
    inside = ((common_index >= lo) & (common_index < hi)).all(dim=0)
    local_pairs = common_index[:, inside] - lo

    # Cast explicitly: indexed assignment requires matching dtypes.
    com_ij[local_pairs[0], local_pairs[1]] = common_val[inside].to(com_ij.dtype)
    return com_ij + com_ij.transpose(0, 1)


def extract_adj_tensor(data_batch,batch_num):
    """Return the dense adjacency matrix of one graph of a batch.

    Args:
        data_batch: batched graph object exposing ``ptr`` (cumulative node
            offsets) and ``edge_index`` (a ``[2, E]`` tensor of global node
            indices).
        batch_num: position of the graph inside the batch, ranging from 0 to
            batch size - 1.

    Returns:
        A ``[N, N]`` float tensor where ``N`` is the number of nodes of the
        selected graph; entry ``(i, j)`` counts the edges ``i -> j``.
    """
    lo, hi = data_batch.ptr[batch_num].item(), data_batch.ptr[batch_num + 1].item()

    edge_index = data_batch.edge_index
    # Select each edge exactly once, and only if both endpoints lie in this
    # graph, so no compensating 0.5 factor is needed afterwards.
    inside = ((edge_index >= lo) & (edge_index < hi)).all(dim=0)
    local_edges = edge_index[:, inside] - lo

    # Pin the size to the real node count: otherwise the matrix is sized by
    # the largest endpoint and shrinks when the last nodes are isolated.
    return to_dense_adj(local_edges, max_num_nodes=hi - lo)[0]


class CorrelationMatrix:
    """Compute ``k`` correlation matrices per graph from encoder parameters.

    For every graph of a batch the encoder output is split into three blocks
    of parameters (``theta``, ``t`` and ``h``); the resulting ``[k, N, N]``
    correlation tensor is stored on the batch as node-level (diagonal) and
    pair-level features.
    """

    def __init__(self,
                 Gnn_encoder,
                 k : int,
                 device,
                 ) -> None:
        """Store the configuration.

        Args:
            Gnn_encoder: callable applied to the batch; ``out[b]`` must hold
                the parameters of graph ``b`` laid out as ``theta`` (first
                ``k`` entries), ``t`` (next ``k``) and ``h`` (remainder).
            k: number of correlation matrices computed per graph.
            device: device on which intermediate and output tensors are put.
        """
        super().__init__()
        self.device = device
        self.Gnn_encoder = Gnn_encoder
        self.k = k

    def w_ij(self, Adj, theta, t):
        """Return ``cos(theta)^2 + sin(theta)^2 * exp(i * Adj * t)``.

        ``Adj`` is the dense adjacency matrix of a single (not batched)
        graph; ``theta`` and ``t`` are 2-D and get a trailing axis so they
        broadcast against ``Adj``.
        """
        theta3 = theta[:, :, None]
        phase = torch.exp(Adj * t[:, :, None] * 1j)
        return (torch.cos(theta3) ** 2 + (torch.sin(theta3) ** 2) * phase).to(self.device)

    def w_plus(self, Adj, com_ij, theta, t):
        """Return ``(cos^2 + sin^2 e^{2it})^com_ij * B^(1 - Adj)``.

        ``B`` is the product ``(cos^2 + sin^2 e^{it}) (cos^2 + sin^2 e^{-it})``.
        ``Adj`` and ``com_ij`` belong to a single (not batched) graph.
        """
        cos2 = torch.cos(theta) ** 2
        sin2 = torch.sin(theta) ** 2
        B = (cos2 + sin2 * torch.exp(1j * t)) * (cos2 + sin2 * torch.exp(-1j * t))
        B = B[:, :, None]

        base = (torch.cos(theta[:, :, None]) ** 2
                + (torch.sin(theta[:, :, None]) ** 2) * torch.exp(2 * t[:, :, None] * 1j))
        return torch.pow(base, com_ij).to(self.device) * (B ** (1 - Adj))

    def w_minus(self, Adj, theta, t):
        """Return ``(cos^2 + sin^2 e^{it}) * (cos^2 + sin^2 e^{-it})^(1 - Adj)``.

        ``Adj`` is the dense adjacency matrix of a single (not batched) graph.
        """
        cos2 = torch.cos(theta[:, :, None]) ** 2
        sin2 = torch.sin(theta[:, :, None]) ** 2
        # NOTE(review): `**` binds tighter than `*`, so only the second factor
        # is raised to (1 - Adj), whereas w_plus exponentiates the product of
        # both factors. Kept as is -- confirm against the reference formula.
        return (cos2 + sin2 * torch.exp(1j * t[:, :, None])) \
            * (cos2 + sin2 * torch.exp(-1j * t[:, :, None])) ** (1 - Adj)

    def _pairwise_correlations(self, Adj, com_ij, theta, t, h):
        """Return the real ``[k, N, N]`` correlation tensor of a single graph."""
        F = (4 * (torch.sin(theta) ** 4) * (torch.cos(theta) ** 4)).to(self.device)
        W = self.w_ij(Adj, theta, t).to(self.device)

        # One complex value per (matrix, node): shape [k, N].
        rho_vect = torch.exp(h * t * 1j) * torch.prod(W, 2)
        rho_i = rho_vect[:, :, None]  # broadcasts to [k, i, j] = rho[k, i]
        rho_j = rho_vect[:, None, :]  # broadcasts to [k, i, j] = rho[k, j]

        f1 = (rho_i + rho_j) * (1 - 1 / W)
        a = .5 * (1 - (torch.exp(Adj * t[:, :, None] * 1j)
                       / self.w_plus(Adj, com_ij, theta, t)))
        f2 = a * (rho_j * rho_i)
        f3 = .5 * (1 - (1 / self.w_minus(Adj, theta, t))) * (rho_j * torch.conj(rho_i))

        return (F[:, :, None] * torch.real(f1 + f2 + f3)).to(self.device)

    def compute_correlation_matrix_batched(self, data_batch):
        """Attach the correlation features of every graph to ``data_batch``.

        Sets three attributes on ``data_batch`` and returns it:

        * ``qcorr``: ``[total_nodes, k]`` diagonal (self) correlations,
        * ``qcorr_val``: ``[sum_b N_b^2, k]`` correlations of all ordered
          node pairs, graph by graph, in row-major order,
        * ``qcorr_index``: ``[2, sum_b N_b^2]`` global node indices of those
          pairs.
        """
        P = self.Gnn_encoder(data_batch)
        k = self.k

        node_corrs = []
        pair_corrs = []
        pair_indexes = []
        for bn in range(data_batch.ptr.shape[0] - 1):
            params = P[bn]
            theta = params[:k].reshape(k, 1).to(self.device)
            t = params[k:2 * k].reshape(k, 1).to(self.device)
            h = params[2 * k:].reshape(k, 1).to(self.device)

            Adj = extract_adj_tensor(data_batch, bn).to(self.device)
            com_ij = extract_common_tensor(data_batch, bn).to(self.device)
            N = Adj.shape[0]

            corr = self._pairwise_correlations(Adj, com_ij, theta, t, h)

            # [N, k]: row i holds corr[:, i, i].
            node_corrs.append(torch.diagonal(corr, dim1=1, dim2=2).transpose(0, 1))
            # [N * N, k]: row i * N + j holds corr[:, i, j].
            pair_corrs.append(corr.permute(1, 2, 0).reshape(N * N, k))

            # Global indices of the pairs, in the same row-major order.
            node_ids = torch.arange(N, device=self.device) + int(data_batch.ptr[bn])
            pair_indexes.append(
                torch.stack((node_ids.repeat_interleave(N), node_ids.repeat(N)), 0)
            )

        data_batch.qcorr = torch.cat(node_corrs, 0).to(self.device)
        data_batch.qcorr_val = torch.cat(pair_corrs, 0).to(self.device)
        data_batch.qcorr_index = torch.cat(pair_indexes, 1).to(self.device)

        return data_batch



@register_node_encoder('Qcorr')
class LinearQuantumCorrelationNodeEncoder(torch.nn.Module):
    """Linear node encoder for the correlation features stored on a batch.

    The feature read from ``batch[pe_name]`` is projected with a linear
    layer, optionally normalised, and added to ``batch.x`` (or used as
    ``batch.x`` when the batch has no node features).
    """

    def __init__(self, emb_dim, out_dim, use_bias=False, batchnorm=False, layernorm=False, pe_name="qcorr"):
        """Build the projection and the optional normalisation layers.

        Args:
            emb_dim: input feature size; equals ``k``, the number of
                correlation matrices provided in the input.
            out_dim: output feature size.
            use_bias: whether the linear layer has a bias term.
            batchnorm: apply ``BatchNorm1d`` after the projection.
            layernorm: apply ``LayerNorm`` after the projection (and after
                the batch norm when both are enabled).
            pe_name: key of the batch attribute holding the features.
        """
        super().__init__()
        self.batchnorm = batchnorm
        self.layernorm = layernorm
        self.name = pe_name

        # `torch.nn` is used explicitly: the module never imports a bare `nn`.
        self.fc = torch.nn.Linear(emb_dim, out_dim, bias=use_bias)
        torch.nn.init.xavier_uniform_(self.fc.weight)

        if self.batchnorm:
            self.bn = torch.nn.BatchNorm1d(out_dim)
        if self.layernorm:
            self.ln = torch.nn.LayerNorm(out_dim)

    def forward(self, batch):
        """Encode ``batch[self.name]`` into ``batch.x`` and return the batch."""
        encoded = self.fc(batch[f"{self.name}"])

        if self.batchnorm:
            encoded = self.bn(encoded)

        if self.layernorm:
            encoded = self.ln(encoded)

        # Add to existing node features when present, otherwise create them.
        if "x" in batch:
            batch.x = batch.x + encoded
        else:
            batch.x = encoded

        return batch

@register_edge_encoder('Qcorr')
class LinearQuantumCorrelationEdgeEncoder(torch.nn.Module):
    """Linear edge encoder for the pairwise correlation features of a batch.

    ``batch.qcorr_val`` is projected with a linear layer and merged into the
    edge attributes at the positions given by ``batch.qcorr_index``;
    optionally the result is padded to the index set returned by
    ``full_edge_index`` and normalised.
    """

    def __init__(self, emb_dim, out_dim, batchnorm=False, layernorm=False, use_bias=False,
                 pad_to_full_graph=True, fill_value=0.,
                 add_node_attr_as_self_loop=False,
                 overwrite_old_attr=False):
        """Build the projection, padding buffer and optional norm layers.

        Args:
            emb_dim: input feature size of ``qcorr_val``.
            out_dim: output edge feature size.
            batchnorm: apply ``BatchNorm1d`` to the merged edge attributes.
            layernorm: apply ``LayerNorm`` to the merged edge attributes.
            use_bias: whether the linear layer has a bias term.
            pad_to_full_graph: pad the output with ``fill_value`` rows on the
                indices returned by ``full_edge_index``.
            fill_value: value used for the padding rows.
            add_node_attr_as_self_loop: stored on the module; not used by
                ``forward``.
            overwrite_old_attr: drop the existing edges/attributes and keep
                only the encoded correlation features.
        """
        super().__init__()
        # note: batchnorm/layernorm might ruin some properties of pe on providing shortest-path distance info
        self.emb_dim = emb_dim
        self.out_dim = out_dim
        self.add_node_attr_as_self_loop = add_node_attr_as_self_loop
        self.overwrite_old_attr = overwrite_old_attr # remove the old edge-attr

        self.batchnorm = batchnorm
        self.layernorm = layernorm

        # `torch.nn` is used explicitly: the module never imports a bare `nn`.
        self.fc = torch.nn.Linear(emb_dim, out_dim, bias=use_bias)
        torch.nn.init.xavier_uniform_(self.fc.weight)
        self.pad_to_full_graph = pad_to_full_graph
        # Keep the value actually used for the padding buffer so that
        # __repr__ reports it faithfully.
        self.fill_value = fill_value

        padding = torch.ones(1, out_dim, dtype=torch.float) * fill_value
        self.register_buffer("padding", padding)

        if self.batchnorm:
            self.bn = torch.nn.BatchNorm1d(out_dim)

        if self.layernorm:
            self.ln = torch.nn.LayerNorm(out_dim)

    def forward(self, batch):
        """Merge the encoded correlations into the edges and return the batch.

        Overwrites ``batch.edge_index`` and ``batch.edge_attr``.
        """
        qcorr_index = batch.qcorr_index
        qcorr_val = self.fc(batch.qcorr_val)
        edge_index = batch.edge_index
        edge_attr = batch.edge_attr

        if edge_attr is None:
            # Zero attributes for the existing edges, created on the same
            # device and with the same dtype as the encoded values.
            edge_attr = qcorr_val.new_zeros(edge_index.size(1), qcorr_val.size(1))

        if self.overwrite_old_attr:
            out_idx, out_val = qcorr_index, qcorr_val
        else:
            edge_index, edge_attr = add_remaining_self_loops(
                edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.
            )
            # Sum the attributes of entries sharing the same (row, col).
            out_idx, out_val = coalesce(
                torch.cat([edge_index, qcorr_index], dim=1),
                torch.cat([edge_attr, qcorr_val], dim=0),
                num_nodes=batch.num_nodes,
                reduce="add",
            )

        if self.pad_to_full_graph:
            # NOTE(review): `full_edge_index` is neither imported nor defined
            # in the visible part of this file -- confirm it is provided
            # elsewhere in the module.
            edge_index_full = full_edge_index(out_idx, batch=batch.batch)
            edge_attr_pad = self.padding.repeat(edge_index_full.size(1), 1)
            # zero padding to fully-connected graphs
            out_idx = torch.cat([out_idx, edge_index_full], dim=1)
            out_val = torch.cat([out_val, edge_attr_pad], dim=0)
            out_idx, out_val = coalesce(
                out_idx, out_val, num_nodes=batch.num_nodes, reduce="add"
            )

        if self.batchnorm:
            out_val = self.bn(out_val)

        if self.layernorm:
            out_val = self.ln(out_val)

        batch.edge_index, batch.edge_attr = out_idx, out_val
        return batch

    def __repr__(self):
        return f"{self.__class__.__name__}" \
               f"(pad_to_full_graph={self.pad_to_full_graph}," \
               f"fill_value={self.fill_value}," \
               f"{self.fc.__repr__()})"



