import torch
from torch import nn 

from torch_geometric.graphgym.register import (
    register_node_encoder,register_edge_encoder
)

from torch.nn import Linear
from torch_geometric.utils import to_dense_adj
import numpy as np 


class CorrelationMatrix:
    """Per-graph correlation features computed one graph at a time.

    For every graph of a batch, ``k`` parameter triples ``(theta, t, h)`` are
    read from ``P`` and turned into a ``(k, N, N)`` real-valued correlation
    tensor, where ``N`` is the number of nodes of that graph. The diagonal is
    stored as node features and the full matrix as node-pair features.
    """

    def __init__(self,
                 P: torch.Tensor,
                 k: int,
                 device,
                 ) -> None:
        """Store the parameters.

        Args:
            P: parameter tensor; ``P[bn]`` is read as ``3 * k`` values for
                graph ``bn``: the first ``k`` are theta, the next ``k`` are
                t and the remaining ones are h.
            k: number of ``(theta, t, h)`` triples per graph.
            device: device on which intermediate and output tensors live.
        """
        super().__init__()
        self.device = device
        self.P = P
        self.k = k

    @staticmethod
    def _graph_bounds(data_batch, batch_num):
        """Return the ``[start, stop)`` node-index range of graph ``batch_num``."""
        return (data_batch.ptr[batch_num].item(),
                data_batch.ptr[batch_num + 1].item())

    def extract_common_tensor(self, data_batch, batch_num):
        """Build the dense, symmetrised ``common`` matrix of one graph.

        Args:
            data_batch: batch exposing ``ptr``, ``common_index`` (2 x E) and
                ``common_val`` (E).
            batch_num: graph position in the batch, from 0 to BS - 1.

        Returns:
            An ``(N, N)`` tensor ``C + C^T`` where ``C`` holds ``common_val``
            at the (locally re-indexed) positions given by ``common_index``.
        """
        idx_min, idx_max = self._graph_bounds(data_batch, batch_num)
        n_nodes = idx_max - idx_min
        Com_ij = torch.zeros((n_nodes, n_nodes)).to(self.device)

        # Keep the pairs whose two endpoints both belong to this graph.
        index = data_batch.common_index
        in_graph = ((index >= idx_min) & (index < idx_max)).all(dim=0)
        pairs = index[:, in_graph]

        # Advanced-index "+=" does not accumulate repeated positions, so this
        # effectively writes the value of each listed pair.
        Com_ij[pairs[0] - idx_min, pairs[1] - idx_min] += data_batch.common_val[in_graph]

        # NOTE(review): diagonal entries (and pairs listed in both
        # orientations) are doubled by this sum — confirm that is intended.
        return Com_ij + Com_ij.transpose(0, 1)

    def extract_adj_tensor(self, data_batch, batch_num):
        """Build the dense adjacency matrix of one graph.

        Args:
            data_batch: batch exposing ``ptr`` and ``edge_index`` (2 x E).
            batch_num: graph position in the batch, from 0 to BS - 1.

        Returns:
            An ``(N, N)`` tensor on ``self.device`` with ``N`` the number of
            nodes of the graph; entry ``(i, j)`` counts the edges ``i -> j``.
        """
        idx_min, idx_max = self._graph_bounds(data_batch, batch_num)

        # Select each edge of this graph exactly once. Selecting columns with
        # ``torch.where(...)[1]`` would list every edge twice (once per
        # endpoint row), which then has to be compensated by a 0.5 factor.
        edges = data_batch.edge_index
        in_graph = ((edges >= idx_min) & (edges < idx_max)).all(dim=0)
        local_edges = edges[:, in_graph] - idx_min

        # ``max_num_nodes`` keeps trailing isolated nodes in the matrix so its
        # size always matches the one produced by ``extract_common_tensor``.
        adj = to_dense_adj(local_edges, max_num_nodes=idx_max - idx_min)[0]
        return adj.to(self.device)

    def w_ij(self, Adj, theta, t):
        """Return ``cos^2(theta) + sin^2(theta) * exp(i * Adj * t)``.

        ``Adj`` is the adjacency matrix of a single (not batched) graph;
        ``theta`` and ``t`` are 2-D and get a trailing axis for broadcasting.
        """
        cos_sq = torch.cos(theta[:, :, None]) ** 2
        sin_sq = torch.sin(theta[:, :, None]) ** 2
        phase = torch.exp(Adj * t[:, :, None] * 1j)
        return (cos_sq + sin_sq * phase).to(self.device)

    def w_plus(self, Adj, com_ij, theta, t):
        """Return ``(cos^2 + sin^2 e^{2it})^com_ij * B^(1 - Adj)``.

        ``B = (cos^2 + sin^2 e^{it}) * (cos^2 + sin^2 e^{-it})``. ``Adj`` and
        ``com_ij`` are matrices of a single (not batched) graph.
        """
        cos_sq = torch.cos(theta) ** 2
        sin_sq = torch.sin(theta) ** 2

        B = (cos_sq + sin_sq * torch.exp(1j * t)) \
            * (cos_sq + sin_sq * torch.exp(-1j * t))
        B = B[:, :, None]

        base = cos_sq[:, :, None] + sin_sq[:, :, None] * torch.exp(2 * t[:, :, None] * 1j)
        return torch.pow(base, com_ij).to(self.device) * (B ** (1 - Adj))

    def w_minus(self, Adj, theta, t):
        """Return ``(cos^2 + sin^2 e^{it}) * (cos^2 + sin^2 e^{-it})^(1 - Adj)``.

        ``Adj`` is the adjacency matrix of a single (not batched) graph.
        """
        cos_sq = torch.cos(theta[:, :, None]) ** 2
        sin_sq = torch.sin(theta[:, :, None]) ** 2
        forward = cos_sq + sin_sq * torch.exp(1j * t[:, :, None])
        backward = cos_sq + sin_sq * torch.exp(-1j * t[:, :, None])
        # NOTE(review): the exponent applies to the second factor only
        # (``**`` binds tighter than ``*``), unlike ``w_plus`` where the whole
        # product B is exponentiated — confirm this asymmetry is intended.
        return forward * (backward ** (1 - Adj))

    def compute_correlation_matrix_batched(self, data_batch):
        """Compute the correlation features of every graph in ``data_batch``.

        Sets three attributes on ``data_batch`` and returns it:

        * ``qcorr``: ``(total_nodes, k)`` diagonal entries of each graph's
          correlation tensor;
        * ``qcorr_val``: ``(sum of N^2, k)`` all entries, row-major per graph;
        * ``qcorr_index``: ``(2, sum of N^2)`` global ``(i, j)`` node indices
          matching the rows of ``qcorr_val``.
        """
        k = self.k
        ptr = data_batch.ptr

        X_corrs_list = []
        E_ij_corrs_list = []
        indexes = []
        for bn in range(ptr.shape[0] - 1):
            # TODO: check the validity of the range of t and h.
            params = self.P[bn]
            theta = 2 * np.pi * params[:k].reshape(k, 1).to(self.device)
            t = 10 * params[k:2 * k].reshape(k, 1).to(self.device)
            h = 100 * params[2 * k:].reshape(k, 1).to(self.device)

            Adj = self.extract_adj_tensor(data_batch, bn)
            N = Adj.shape[0]
            com_ij = self.extract_common_tensor(data_batch, bn)

            F = (4 * (torch.sin(theta) ** 4) * (torch.cos(theta) ** 4)).to(self.device)
            W = self.w_ij(Adj, theta, t).to(self.device)

            # rho has shape (k, N); the two views below broadcast it along
            # the first / second node axis instead of materialising copies.
            rho = torch.exp(h * t * 1j) * torch.prod(W, 2)
            rho_i = rho[:, :, None]
            rho_j = rho[:, None, :]

            f1 = (rho_i + rho_j) * (1 - 1 / W)
            a = .5 * (1 - (torch.exp(Adj * t[:, :, None] * 1j)
                           / self.w_plus(Adj, com_ij, theta, t)))
            f2 = a * (rho_j * rho_i)
            f3 = .5 * (1 - (1 / self.w_minus(Adj, theta, t))) * (rho_j * torch.conj(rho_i))

            corr = F[:, :, None] * torch.real(f1 + f2 + f3)  # (k, N, N)

            # Node features: diagonal of each of the k matrices -> (N, k).
            X_corrs_list.append(torch.diagonal(corr, dim1=1, dim2=2).transpose(0, 1))
            # Pair features: row-major flattening of each matrix -> (N*N, k).
            E_ij_corrs_list.append(corr.reshape(k, N * N).transpose(0, 1))

            # Global indices (i, j) in the same row-major order.
            offset = ptr[bn].item()
            nodes = torch.arange(offset, offset + N, device=self.device)
            indexes.append(torch.stack((nodes.repeat_interleave(N), nodes.repeat(N)), 0))

        data_batch.qcorr = torch.cat(X_corrs_list, 0).to(self.device)
        data_batch.qcorr_val = torch.cat(E_ij_corrs_list, 0).to(self.device)
        data_batch.qcorr_index = torch.cat(indexes, 1).to(self.device)

        return data_batch

class CorrelationMatrixBatched:
    """Correlation features computed for all graphs of a batch at once.

    Per-graph matrices are zero-padded to a common size
    ``data_batch.dist.shape[1]`` and stacked, so that the formulas can be
    evaluated on ``(BS, k, N, N)`` tensors.
    """

    def __init__(self,
                 P: torch.Tensor,
                 k: int,
                 device,
                 ) -> None:
        """Split ``P`` into theta, t and h and reshape them for broadcasting.

        Args:
            P: 2-D parameter tensor; columns ``[0, k)`` are theta,
                ``[k, 2k)`` are t and ``[2k, ...)`` are h.
            k: number of ``(theta, t, h)`` triples per graph.
            device: device on which the computation is carried out.
        """
        super().__init__()
        self.device = device
        self.k = k

        theta = P[:, :k].to(device)
        t = P[:, k:2 * k].to(device)
        h = P[:, 2 * k:].to(device)

        # Two trailing singleton axes: (rows, cols) -> (rows, cols, 1, 1).
        self.h = h[:, :, None, None]
        self.theta = theta[:, :, None, None]
        self.t = t[:, :, None, None]

    @staticmethod
    def _graph_bounds(data_batch):
        """Return the ``(start, stop)`` node range of every graph as Python ints."""
        ptr = data_batch.ptr.tolist()
        return list(zip(ptr[:-1], ptr[1:]))

    @staticmethod
    def _pad_square(matrix, size):
        """Zero-pad ``matrix`` on the bottom and right up to ``size`` x ``size``."""
        extra = size - matrix.shape[0]
        return nn.functional.pad(matrix, (0, extra, 0, extra), "constant", 0.0)

    def extract_common_tensor_batched(self, data_batch):
        """Stack the padded, symmetrised ``common`` matrix of every graph.

        Returns:
            A ``(BS, max_dim, max_dim)`` tensor. For each graph the matrix is
            ``C + C^T`` with the diagonal of the transposed term zeroed, so
            diagonal entries are counted once.
        """
        # batch_num ranges from 0 to BS - 1.
        # assumes data_batch.dist.shape[1] is the padded node count — TODO confirm
        max_dim = data_batch.dist.shape[1]
        index = data_batch.common_index
        padded = []
        for idx_min, idx_max in self._graph_bounds(data_batch):
            n_nodes = idx_max - idx_min
            Com_ij = torch.zeros((n_nodes, n_nodes)).to(self.device)

            # Keep the pairs whose two endpoints both belong to this graph.
            in_graph = ((index >= idx_min) & (index < idx_max)).all(dim=0)
            pairs = index[:, in_graph]
            # Advanced-index "+=" does not accumulate repeated positions.
            Com_ij[pairs[0] - idx_min, pairs[1] - idx_min] += data_batch.common_val[in_graph]

            tri_sup = Com_ij.transpose(0, 1).clone().fill_diagonal_(0)
            padded.append(self._pad_square(Com_ij + tri_sup, max_dim))
        return torch.stack(padded, 0)

    def extract_adj_batched(self, data_batch):
        """Stack the padded dense adjacency matrix of every graph.

        Returns:
            A ``(BS, max_dim, max_dim)`` tensor whose entry ``(b, i, j)``
            counts the edges ``i -> j`` of graph ``b``.
        """
        max_dim = data_batch.dist.shape[1]
        edges = data_batch.edge_index
        padded = []
        for idx_min, idx_max in self._graph_bounds(data_batch):
            # Select each edge of the graph exactly once. Selecting columns
            # with ``torch.where(...)[1]`` would list every edge twice and
            # require a compensating 0.5 factor.
            in_graph = ((edges >= idx_min) & (edges < idx_max)).all(dim=0)
            A = to_dense_adj(edges[:, in_graph] - idx_min,
                             max_num_nodes=idx_max - idx_min)[0]
            padded.append(self._pad_square(A, max_dim))
        return torch.stack(padded, 0)

    def extract_deg_batched_row(self, data_batch):
        """Stack padded matrices ``D`` with ``D[i, j] = deg[j]`` for every graph."""
        max_dim = data_batch.dist.shape[1]
        padded = []
        for idx_min, idx_max in self._graph_bounds(data_batch):
            deg_vect = data_batch.deg[idx_min:idx_max]
            deg_2D = deg_vect.repeat(deg_vect.shape[0], 1)
            padded.append(self._pad_square(deg_2D, max_dim))
        return torch.stack(padded, 0)

    def extract_deg_batched_col(self, data_batch):
        """Stack padded matrices ``D`` with ``D[i, j] = deg[i]`` for every graph."""
        max_dim = data_batch.dist.shape[1]
        padded = []
        for idx_min, idx_max in self._graph_bounds(data_batch):
            deg_vect = data_batch.deg[idx_min:idx_max]
            deg_2D = deg_vect.view(-1, 1).repeat(1, deg_vect.shape[0])
            padded.append(self._pad_square(deg_2D, max_dim))
        return torch.stack(padded, 0)

    def duplicate_to_k(self, tensor_3D, data_batch):
        """Repeat a ``(BS, N, N)`` tensor ``k`` times into ``(BS, k, N, N)``."""
        N = data_batch.dist.shape[1]
        BS = data_batch.ptr.shape[0] - 1
        return torch.repeat_interleave(tensor_3D, self.k, 0).reshape(BS, self.k, N, N)

    def compute_n1_n2(self, data_batch):
        """Evaluate the pair term ``pre_fact * Re(f1 + f2 + f3)``.

        Returns:
            A real ``(BS, k, N, N)`` tensor built from the padded degree,
            adjacency and ``common`` matrices of the batch.
        """
        cos_sq = torch.cos(self.theta) ** 2
        sin_sq = torch.sin(self.theta) ** 2

        i_t = cos_sq + sin_sq * torch.exp(1j * self.t).to(self.device)
        minus_it = cos_sq + sin_sq * torch.exp(-1j * self.t).to(self.device)
        two_i_t = cos_sq + sin_sq * torch.exp(2j * self.t).to(self.device)
        pre_fact = 4 * (torch.sin(self.theta) ** 4) * (torch.cos(self.theta) ** 4).to(self.device)

        deg_row = self.duplicate_to_k(self.extract_deg_batched_row(data_batch), data_batch).to(self.device)
        deg_col = self.duplicate_to_k(self.extract_deg_batched_col(data_batch), data_batch).to(self.device)
        Adj = self.duplicate_to_k(self.extract_adj_batched(data_batch), data_batch).to(self.device)
        com = self.duplicate_to_k(self.extract_common_tensor_batched(data_batch), data_batch).to(self.device)

        f1 = 1 - torch.exp(1j * self.h * self.t) * (i_t ** (deg_row - Adj) + i_t ** (deg_col - Adj))
        f2 = (0.5 * torch.exp((Adj + 2 * self.h) * self.t * 1j)
              * (two_i_t ** com)
              * (i_t ** (deg_row + deg_col - 2 * com - 2 * Adj)))
        f3 = 0.5 * (i_t ** (deg_col - com - Adj)) * (minus_it ** (deg_row - com - Adj))

        return pre_fact * torch.real(f1 + f2 + f3)

    def compute_n(self, data_batch):
        """Evaluate the product of the two single-node terms.

        Returns:
            A real ``(BS, k, N, N)`` tensor ``n_row * n_col`` where each
            factor is ``pre_fact * Re(1 - e^{iht} * i_t^deg)``.
        """
        pre_fact = (2 * (torch.sin(self.theta) ** 2) * (torch.cos(self.theta) ** 2)).to(self.device)
        i_h_t = torch.exp(1j * self.h * self.t).to(self.device)
        i_t = (torch.cos(self.theta) ** 2
               + (torch.sin(self.theta) ** 2) * torch.exp(1j * self.t)).to(self.device)

        deg_row = self.duplicate_to_k(self.extract_deg_batched_row(data_batch), data_batch).to(self.device)
        deg_col = self.duplicate_to_k(self.extract_deg_batched_col(data_batch), data_batch).to(self.device)

        n_row = pre_fact * torch.real(1 - i_h_t * (i_t ** deg_row))
        n_col = pre_fact * torch.real(1 - i_h_t * (i_t ** deg_col))
        return n_row * n_col

    def compute_correlation(self, data_batch):
        """Compute the correlation features of every graph in ``data_batch``.

        The correlation is ``compute_n1_n2 - compute_n``; padding is stripped
        before the results are stored. Sets three attributes on
        ``data_batch`` and returns it:

        * ``qcorr``: ``(total_nodes, k)`` diagonal entries per graph;
        * ``qcorr_val``: ``(sum of n^2, k)`` all entries, row-major per graph;
        * ``qcorr_index``: ``(2, sum of n^2)`` global ``(i, j)`` node indices
          matching the rows of ``qcorr_val``, placed on ``self.device`` like
          the two feature tensors.
        """
        correlation = self.compute_n1_n2(data_batch) - self.compute_n(data_batch)

        node_feats = []
        pair_feats = []
        pair_index = []
        for bn, (idx_min, idx_max) in enumerate(self._graph_bounds(data_batch)):
            n = idx_max - idx_min
            graph_corr = correlation[bn]  # (k, N, N), possibly padded

            # Diagonal of the unpadded part of each of the k matrices -> (n, k).
            diag = torch.diagonal(graph_corr, dim1=-2, dim2=-1)[:, :n]
            node_feats.append(diag.transpose(0, 1))

            # Row-major flattening of the unpadded part -> (n*n, k).
            pair_feats.append(graph_corr[:, :n, :n].reshape(self.k, n * n).transpose(0, 1))

            # Global indices (i, j) in the same row-major order.
            nodes = torch.arange(idx_min, idx_max, device=self.device)
            pair_index.append(torch.stack((nodes.repeat_interleave(n), nodes.repeat(n)), 0))

        data_batch.qcorr = torch.cat(node_feats, 0)
        data_batch.qcorr_val = torch.cat(pair_feats, 0)
        data_batch.qcorr_index = torch.cat(pair_index, 1)

        return data_batch
        
        
        
        
        
        
