import torch
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.gnn import GNNPreMP
from torch_geometric.graphgym.models.layer import (new_layer_config,
                                                   BatchNorm1dNode)
from torch_geometric.graphgym.register import register_network

from torch.nn import Tanh
from  torch_geometric.nn.conv import GATConv, SAGEConv, GCNConv
from torch_geometric.nn.pool import global_mean_pool

def extract_common_tensor(data_batch, batch_num):
    """Build the dense, symmetrised ``common`` matrix of one graph in a batch.

    Args:
        data_batch: Batched data object exposing ``ptr`` (cumulative node
            offsets), ``common_index`` (a ``[2, M]`` tensor of global node
            index pairs) and ``common_val`` (one value per column of
            ``common_index``).
        batch_num (int): Position of the graph inside the batch, from 0 to
            ``batch_size - 1``.

    Returns:
        torch.Tensor: ``[N, N]`` tensor ``C + C^T``, where ``N`` is the number
        of nodes of the graph and ``C[i, j]`` holds the value stored for the
        (local) pair ``(i, j)``.
    """
    idx_min = data_batch.ptr[batch_num].item()
    idx_max = data_batch.ptr[batch_num + 1].item()
    num_nodes = idx_max - idx_min
    common_index = data_batch.common_index

    # Select each column once, and only when *both* endpoints belong to this
    # graph, so that the shifted indices below are always within [0, N).
    in_graph = ((common_index >= idx_min) & (common_index < idx_max)).all(dim=0)
    rows, cols = common_index[:, in_graph] - idx_min

    # Allocate next to the indices so indexing works on any device.
    Com_ij = torch.zeros((num_nodes, num_nodes), device=common_index.device)
    # ``+=`` (rather than plain assignment) keeps the buffer dtype whatever
    # the dtype of ``common_val`` is.
    Com_ij[rows, cols] += data_batch.common_val[in_graph]
    return Com_ij + Com_ij.transpose(0, 1)


def extract_adj_tensor(data_batch, batch_num):
    """Build the dense adjacency matrix of one graph in a batch.

    Args:
        data_batch: Batched data object exposing ``ptr`` (cumulative node
            offsets) and ``edge_index`` (a ``[2, E]`` tensor of global node
            indices).
        batch_num (int): Position of the graph inside the batch, from 0 to
            ``batch_size - 1``.

    Returns:
        torch.Tensor: ``[N, N]`` float tensor where entry ``(i, j)`` counts
        the edge columns ``(i, j)`` of the graph (1 for a simple graph) and
        ``N`` is the number of nodes given by ``ptr``, so trailing isolated
        nodes still get their (all-zero) row and column.
    """
    idx_min = data_batch.ptr[batch_num].item()
    idx_max = data_batch.ptr[batch_num + 1].item()
    num_nodes = idx_max - idx_min
    edge_index = data_batch.edge_index

    # Keep every edge of this graph exactly once: both endpoints must fall
    # inside the node range of the graph.
    in_graph = ((edge_index >= idx_min) & (edge_index < idx_max)).all(dim=0)
    src, dst = edge_index[:, in_graph] - idx_min

    # Scatter-add ones at the (src, dst) positions; the fixed N x N size keeps
    # the result aligned with the other per-graph dense matrices.
    adj = torch.zeros((num_nodes, num_nodes), device=edge_index.device)
    adj.index_put_((src, dst), adj.new_ones(src.numel()), accumulate=True)
    return adj


class CorrelationMatrix:
    """Compute per-graph correlation features from encoder-predicted parameters.

    The encoder output ``P`` holds, for every graph of the batch, ``3 * k``
    values that are split into ``theta``, ``t`` and ``h`` (``k`` values each).

    Note that this is a plain Python object, not a ``torch.nn.Module``:
    ``Gnn_encoder`` is therefore not registered as a sub-module of whatever
    object holds a ``CorrelationMatrix``.

    Args:
        Gnn_encoder: Callable mapping a data batch to a ``[batch_size, 3 * k]``
            tensor.
        k (int): Number of ``(theta, t, h)`` triplets per graph.
        device: Device on which the correlations are computed.
    """

    def __init__(self,
                 Gnn_encoder: torch.nn.Module,
                 k: int,
                 device,
                 ) -> None:
        super().__init__()
        self.device = device
        self.Gnn_encoder = Gnn_encoder
        self.k = k

    def w_ij(self, Adj, theta, t):
        """Return ``cos(theta)^2 + sin(theta)^2 * exp(i * Adj * t)``.

        ``Adj`` is the adjacency matrix of a single (not batched) graph;
        ``theta`` and ``t`` are 2-D (``[k, 1]`` in this class), so the result
        broadcasts to ``[k, N, N]``.
        """
        return (torch.cos(theta[:, :, None]) ** 2
                + (torch.sin(theta[:, :, None]) ** 2)
                * torch.exp(Adj * t[:, :, None] * 1j)).to(self.device)

    def w_plus(self, Adj, com_ij, theta, t):
        """Return the ``w_plus`` factor for a single (not batched) graph.

        Computes ``(cos^2 + sin^2 * exp(2i * t)) ** com_ij`` multiplied by
        ``B ** (1 - Adj)``, where ``B`` is the product of
        ``cos^2 + sin^2 * exp(i * t)`` with its ``exp(-i * t)`` counterpart.
        """
        B = (torch.cos(theta) ** 2 + torch.sin(theta) ** 2 * torch.exp(1j * t)) \
            * (torch.cos(theta) ** 2 + torch.sin(theta) ** 2 * torch.exp(-1j * t))
        B = B[:, :, None]

        return torch.pow((torch.cos(theta[:, :, None]) ** 2
                          + (torch.sin(theta[:, :, None]) ** 2)
                          * torch.exp(2 * t[:, :, None] * 1j)),
                         com_ij).to(self.device) * (B ** (1 - Adj))

    def w_minus(self, Adj, theta, t):
        """Return the ``w_minus`` factor for a single (not batched) graph.

        NOTE(review): ``**`` binds tighter than ``*``, so only the
        ``exp(-i * t)`` factor is raised to ``(1 - Adj)`` here, whereas
        ``w_plus`` raises the whole product. Confirm the asymmetry is intended.
        """
        return (torch.cos(theta[:, :, None]) ** 2
                + torch.sin(theta[:, :, None]) ** 2 * torch.exp(1j * t[:, :, None])) \
            * (torch.cos(theta[:, :, None]) ** 2
               + torch.sin(theta[:, :, None]) ** 2 * torch.exp(-1j * t[:, :, None])) ** (1 - Adj)

    def _graph_correlations(self, Adj, com_ij, theta, t, h):
        """Return the real ``[k, N, N]`` correlation tensor of a single graph."""
        F = 4 * (torch.sin(theta) ** 4) * (torch.cos(theta) ** 4)
        W = self.w_ij(Adj, theta, t)

        # One complex value per (triplet, node): shape [k, N].
        rho = torch.exp(h * t * 1j) * torch.prod(W, 2)
        rho_i = rho[:, :, None]  # value of the row node, broadcast over columns
        rho_j = rho[:, None, :]  # value of the column node, broadcast over rows

        f1 = (rho_i + rho_j) * (1 - 1 / W)
        a = .5 * (1 - (torch.exp(Adj * t[:, :, None] * 1j)
                       / self.w_plus(Adj, com_ij, theta, t)))
        f2 = a * (rho_j * rho_i)
        f3 = .5 * (1 - (1 / self.w_minus(Adj, theta, t))) * (rho_j * torch.conj(rho_i))

        return F[:, :, None] * torch.real(f1 + f2 + f3)

    def compute_correlation_matrix_batched(self, data_batch):
        """Attach correlation features to ``data_batch`` and return it.

        For every graph of the batch the encoder output is split into
        ``theta``, ``t`` and ``h`` and a ``[k, N, N]`` correlation tensor is
        computed. Three attributes are then written on ``data_batch``:

        * ``qcorr``: ``[total_nodes, k]``, the diagonal (self) correlations;
        * ``qcorr_val``: ``[sum(N^2), k]``, all pair correlations, row-major;
        * ``qcorr_index``: ``[2, sum(N^2)]``, the matching global node pairs.

        Args:
            data_batch: Batched data object; must expose ``ptr`` plus whatever
                the encoder and the ``extract_*`` helpers read.

        Returns:
            The same ``data_batch`` object, with the attributes above set.
        """
        k = self.k
        P = self.Gnn_encoder(data_batch)
        X_corrs_list = []
        E_ij_corrs_list = []
        indexes = []
        for bn in range(data_batch.ptr.shape[0] - 1):
            params = P[bn]
            theta = params[:k].reshape(k, 1).to(self.device)
            t = params[k:2 * k].reshape(k, 1).to(self.device)
            h = params[2 * k:].reshape(k, 1).to(self.device)

            Adj = extract_adj_tensor(data_batch, bn).to(self.device)
            com_ij = extract_common_tensor(data_batch, bn).to(self.device)
            N = Adj.shape[0]

            corr = self._graph_correlations(Adj, com_ij, theta, t, h)

            # Row i of the first tensor is corr[:, i, i]; row i * N + j of the
            # second one is corr[:, i, j].
            X_corrs_list.append(corr.diagonal(dim1=1, dim2=2).transpose(0, 1))
            E_ij_corrs_list.append(corr.permute(1, 2, 0).reshape(N * N, k))

            # Global ids of the nodes of this graph, paired in the same
            # row-major (i, j) order as the pair correlations above.
            nodes = torch.arange(N) + data_batch.ptr[bn].item()
            indexes.append(
                torch.stack((nodes.repeat_interleave(N), nodes.repeat(N)), 0))

        data_batch.qcorr = torch.cat(X_corrs_list, 0).to(self.device)
        data_batch.qcorr_val = torch.cat(E_ij_corrs_list, 0).to(self.device)
        data_batch.qcorr_index = torch.cat(indexes, 1).to(self.device)

        return data_batch


class Param_encoder(torch.nn.Module):
    """Graph encoder predicting ``3 * k`` parameters in ``[0, 1]`` per graph.

    A stack of graph convolutions is followed by a linear projection to
    ``3 * k`` channels, mean pooling over the nodes of each graph and a
    ``(1 + tanh) / 2`` squashing.

    Args:
        conv_layer: Callable ``conv_layer(in_channels, out_channels)``
            returning a module that is invoked as
            ``conv(x, edge_index, edge_attr)``.
        max_dim (int): Number of input channels of the first convolution.
        hidden_channels (int): Width of every convolution output.
        k (int): The projection outputs ``3 * k`` channels.
        num_layers_encoder (int): Number of convolutions applied in
            ``forward``. A first and a last convolution are always built, so
            two are created even when this is 1 (only the first is then used).
    """

    def __init__(self, conv_layer, max_dim,
                 hidden_channels, k, num_layers_encoder):
        super().__init__()
        self.num_layers = num_layers_encoder
        self.convs = torch.nn.ModuleList()
        self.lins = torch.nn.ModuleDict()
        self.conv_layer = conv_layer

        # Input layer, (num_layers_encoder - 2) middle layers, output layer.
        self.convs.append(conv_layer(max_dim, hidden_channels))
        for _ in range(1, num_layers_encoder - 1):
            self.convs.append(conv_layer(hidden_channels, hidden_channels))
        self.convs.append(conv_layer(hidden_channels, hidden_channels))

        # Fully qualified: ``Linear`` alone is not imported in this module.
        self.lin = torch.nn.Linear(hidden_channels, 3 * k)
        self.final_act = Tanh()

    def forward(self, data_batch):
        """Return one row of ``3 * k`` values in ``[0, 1]`` per graph.

        Reads ``dist`` (used as node features), ``edge_index``, ``edge_attr``
        and ``batch`` from ``data_batch``.
        """
        x = data_batch.dist.float()
        edge_index = data_batch.edge_index
        edge_attr = data_batch.edge_attr

        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index, edge_attr).float()

        x = self.lin(x.float())
        # tanh lies in [-1, 1]; shift and scale it to [0, 1].
        return (1 + self.final_act(global_mean_pool(x, data_batch.batch))) / 2


class FeatureEncoder(torch.nn.Module):
    """
    Node and edge feature encoder driven by the global GraphGym ``cfg``.

    Depending on ``cfg.dataset``, a node encoder and/or an edge encoder
    (each optionally followed by a batch-norm module) are registered as
    sub-modules; ``forward`` applies them in registration order.

    Args:
        dim_in (int): Input feature dimension
    """
    def __init__(self, dim_in):
        super().__init__()
        self.dim_in = dim_in
        dataset_cfg = cfg.dataset

        if dataset_cfg.node_encoder:
            # Integer node features are encoded through the registered encoder.
            node_encoder_cls = register.node_encoder_dict[
                dataset_cfg.node_encoder_name]
            self.node_encoder = node_encoder_cls(cfg.gnn.dim_inner)
            if dataset_cfg.node_encoder_bn:
                node_bn_cfg = new_layer_config(
                    cfg.gnn.dim_inner, -1, -1,
                    has_act=False, has_bias=False, cfg=cfg)
                self.node_encoder_bn = BatchNorm1dNode(node_bn_cfg)
            # Node features now live in the inner dimension.
            self.dim_in = cfg.gnn.dim_inner

        if dataset_cfg.edge_encoder:
            # The edge dimension is capped at 128 for PNA layers; this writes
            # the chosen value back into the global config.
            cfg.gnn.dim_edge = (min(128, cfg.gnn.dim_inner)
                                if 'PNA' in cfg.gt.layer_type
                                else cfg.gnn.dim_inner)
            # Integer edge features are encoded through the registered encoder.
            edge_encoder_cls = register.edge_encoder_dict[
                dataset_cfg.edge_encoder_name]
            self.edge_encoder = edge_encoder_cls(cfg.gnn.dim_edge)
            if dataset_cfg.edge_encoder_bn:
                edge_bn_cfg = new_layer_config(
                    cfg.gnn.dim_edge, -1, -1,
                    has_act=False, has_bias=False, cfg=cfg)
                self.edge_encoder_bn = BatchNorm1dNode(edge_bn_cfg)

    def forward(self, batch):
        """Pass ``batch`` through every registered sub-module in order."""
        out = batch
        for stage in self.children():
            out = stage(out)
        return out


@register_network('GritTransformerQuantum')
class GritTransformerQuantum(torch.nn.Module):
    '''
        The Quantum-positional encoding version of GritTransformer.

        Args:
            dim_in (int): Input feature dimension.
            dim_out (int): Output dimension handed to the prediction head.

        Raises:
            ValueError: If ``cfg.gnn.param_encoder_layer`` is not ``'GAT'``.
    '''

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.encoder = FeatureEncoder(dim_in)
        dim_in = self.encoder.dim_in

        self.ablation = False

        if cfg.posenc_Quantum.enable:
            # here no concatenation is done
            self.quantum_abs_encoder = register.node_encoder_dict["Qcorr"]\
                (cfg.posenc_Quantum.ksteps, cfg.gnn.dim_inner)
            rel_pe_dim = cfg.posenc_Quantum.ksteps
            self.rrwp_rel_encoder = register.edge_encoder_dict["Qcorr"] \
                (rel_pe_dim, cfg.gnn.dim_edge,
                 pad_to_full_graph=cfg.gt.attn.full_attn,
                 add_node_attr_as_self_loop=False,
                 fill_value=0.
                 )

        if cfg.gnn.layers_pre_mp > 0:
            self.pre_mp = GNNPreMP(
                dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp)
            dim_in = cfg.gnn.dim_inner

        assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in, \
            "The inner and hidden dims must match."

        # The layer type is fixed here rather than read from cfg.gt.layer_type.
        global_model_type = "GritTransformerQuantum"

        TransformerLayer = register.layer_dict.get(global_model_type)

        layers = [
            TransformerLayer(
                in_dim=cfg.gt.dim_hidden,
                out_dim=cfg.gt.dim_hidden,
                num_heads=cfg.gt.n_heads,
                dropout=cfg.gt.dropout,
                act=cfg.gnn.act,
                attn_dropout=cfg.gt.attn_dropout,
                layer_norm=cfg.gt.layer_norm,
                batch_norm=cfg.gt.batch_norm,
                residual=True,
                norm_e=cfg.gt.attn.norm_e,
                O_e=cfg.gt.attn.O_e,
                cfg=cfg.gt,
            )
            for _ in range(cfg.gt.layers)
        ]

        self.layers = torch.nn.Sequential(*layers)

        # Fail with an explicit message instead of an UnboundLocalError when
        # the configured parameter-encoder layer is not handled.
        if cfg.gnn.param_encoder_layer == 'GAT':
            conv_layer = GATConv
        else:
            raise ValueError(
                "Unsupported cfg.gnn.param_encoder_layer "
                f"{cfg.gnn.param_encoder_layer!r}; only 'GAT' is supported.")

        param_encoder = Param_encoder(conv_layer, cfg.dataset.max_dim, cfg.posenc_Quantum.param_hidden,
                                      cfg.posenc_Quantum.ksteps, cfg.posenc_Quantum.param_num_layer)

        # NOTE(review): CorrelationMatrix is a plain object (not an nn.Module),
        # so ``param_encoder`` is not registered on this model: its weights are
        # absent from ``self.parameters()`` / ``state_dict()`` and do not follow
        # ``self.to(...)``. The device is also hard-coded to "cpu". Confirm
        # both are intended before changing them, since ``forward`` iterates
        # over ``self.children()``.
        self.correlation = CorrelationMatrix(param_encoder, cfg.posenc_Quantum.ksteps, "cpu")

        GNNHead = register.head_dict[cfg.gnn.head]
        self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out)

    def forward(self, batch):
        """Add the correlation features, then apply every registered
        sub-module to ``batch`` in registration order."""
        batch = self.correlation.compute_correlation_matrix_batched(batch)
        for module in self.children():
            batch = module(batch)

        return batch







