import torch
import torch.nn as nn

import warnings

from qtorch import RTYPE

class InputGen(nn.Module):
    '''
    Generates the inputs (logits and angles) for the QAS algorithms

    Attributes
    ----------
    use_hidden: bool
        Whether the logits are generated through matrix multiplication [True] or
        directly [False]
    thetas: torch.Tensor
        Shape [L, Q] - The angles of rotation for each gate in the architecture,
        initialized from the uniform distribution [-pi, pi]
    logit_recipe: list[torch.Tensor]
        The parameters the logits are built from: ``[logits]`` when
        use_hidden = False, ``[hidden_matrices, hidden_vectors]`` otherwise

    ** Only an attribute if use_hidden = False **
    logits: torch.Tensor
        Shape [L, Q, Q+3] - The unnormalized log probabilities for each gate in
        the architecture, initialized from the uniform distribution [-1, 1]

    ** Only an attribute if use_hidden = True **
    hidden_matrices: torch.Tensor
        Shape [L, Q, Q+3, K] - Hidden matrices to generate the logits
    hidden_vectors: torch.Tensor
        Shape [L, Q, K, 1] - Hidden vectors to generate the logits

    When use_hidden = True the logits are computed on every forward pass as
        logits[i,j,:] = hidden_matrices[i,j,:,:] @ hidden_vectors[i,j,:,0]

    Methods
    -------
    forward()
        Returns the angles of rotation and the logits
    '''
    def __init__(self, num_qubits:int, num_layers:int, num_hidden:int=0):
        '''
        Arguments
        ---------
        num_qubits: int
            The number of qubits in the architecture
        num_layers: int
            The number of layers in the architecture
        num_hidden: int, optional
            The dimension of the hidden layer used to generate the logits.
            If 0, the logits are generated directly. A negative value emits a
            warning and is treated like 0.
            Default=0
        '''
        super().__init__()
        self.thetas = nn.Parameter(torch.empty([num_layers,num_qubits],dtype=RTYPE))
        self.use_hidden = num_hidden > 0
        if not self.use_hidden:
            if num_hidden < 0:
                # stacklevel=2 reports the warning at the caller constructing InputGen
                warnings.warn('`num_hidden` < 0, assuming no hidden layer', stacklevel=2)
            self.logits = nn.Parameter(torch.empty([num_layers,num_qubits,num_qubits+3],dtype=RTYPE))
            nn.init.uniform_(self.logits,-1.0,1.0)
            self.logit_recipe = [self.logits]
        else:
            num_choices = num_qubits + 3
            self.hidden_matrices = nn.Parameter(torch.empty([num_layers,num_qubits,num_choices,num_hidden],dtype=RTYPE))
            self.hidden_vectors = nn.Parameter(torch.empty([num_layers,num_qubits,num_hidden,1],dtype=RTYPE))
            # Xavier/Glorot init is applied per trailing [Q+3, K] matrix and
            # per [K, 1] vector. Calling nn.init.xavier_uniform_ on the full
            # 4-D tensors would fold the layer/qubit batch dims into
            # fan_in/fan_out and shrink the init scale as num_layers and
            # num_qubits grow.
            self._xavier_uniform_batched_(self.hidden_matrices)
            self._xavier_uniform_batched_(self.hidden_vectors)
            self.logit_recipe = [self.hidden_matrices, self.hidden_vectors]

        # Initialized last so the RNG draw order matches the original layout.
        nn.init.uniform_(self.thetas, -torch.pi,torch.pi)

    @staticmethod
    def _xavier_uniform_batched_(tensor):
        '''
        Fill ``tensor`` in place with Xavier-uniform values (gain 1).

        The last two dims are treated as one [fan_out, fan_in] matrix, following
        the torch weight convention. All leading dims are an independent batch.
        Returns ``tensor``.
        '''
        fan_out, fan_in = tensor.shape[-2], tensor.shape[-1]
        bound = (6.0 / (fan_in + fan_out)) ** 0.5
        return nn.init.uniform_(tensor, -bound, bound)

    def forward(self):
        '''
        Returns
        -------
        thetas: torch.Tensor
            Shape [L, Q] - The angles of rotation
        logits: torch.Tensor
            Shape [L, Q, Q+3] - The unnormalized log probabilities. When
            use_hidden = True they are recomputed from the hidden parameters
            on each call.
        '''
        if not self.use_hidden:
            return self.thetas, self.logits
        # [L, Q, Q+3, K] @ [L, Q, K, 1] -> [L, Q, Q+3, 1] -> [L, Q, Q+3]
        logits = torch.matmul(self.hidden_matrices,self.hidden_vectors).squeeze(-1)
        return self.thetas, logits
