import torch
import torch.nn as nn

class InputGenMicro(nn.Module):
    """Hold trainable angle and logit tensors and return them from ``forward``.

    The module takes no inputs: ``forward()`` returns ``(thetas, logits)``.

    * ``thetas`` has shape ``[num_subcircuits, num_layers, num_subcircuit_qubits]``
      and is initialised uniformly in ``[-pi, pi]`` (presumably rotation
      angles -- confirm against the consumer).
    * ``logits`` has shape
      ``[num_layers, num_subcircuit_qubits, num_subcircuit_qubits + 3]``.
      What the trailing ``+ 3`` entries stand for is not visible here --
      TODO confirm against the caller.

    Args:
        num_subcircuit_qubits: Size of the trailing dim of ``thetas`` and of
            the second dim of ``logits``; the last logit dim is this + 3.
        num_subcircuits: Size of the leading dim of ``thetas``.
        num_layers: Size of the layer dim of ``thetas`` and ``logits``.
        num_hidden: If 0 (default), ``logits`` is itself a trainable
            parameter initialised uniformly in ``[-1, 1]``. If positive, each
            length-``num_subcircuit_qubits + 3`` logit vector is instead
            computed as the product of a trainable
            ``[num_subcircuit_qubits + 3, num_hidden]`` matrix and a
            trainable ``[num_hidden, 1]`` vector.

    Raises:
        ValueError: If ``num_hidden`` is negative (or NaN).
    """

    def __init__(self, num_subcircuit_qubits: int,
                 num_subcircuits: int, num_layers: int, num_hidden: int = 0) -> None:
        super().__init__()
        # Explicit check instead of ``assert`` so it survives ``python -O``;
        # written as ``not >= 0`` so NaN is rejected as well.
        if not num_hidden >= 0:
            raise ValueError(f"num_hidden must be >= 0, got {num_hidden!r}")

        self.thetas = nn.Parameter(torch.empty([num_subcircuits,
                                                num_layers,
                                                num_subcircuit_qubits],
                                               dtype=torch.float32))
        nn.init.uniform_(self.thetas, -torch.pi, torch.pi)

        self.use_hidden = num_hidden > 0
        if not self.use_hidden:
            self.logits = nn.Parameter(torch.empty([num_layers,
                                                    num_subcircuit_qubits,
                                                    num_subcircuit_qubits + 3],
                                                   dtype=torch.float32))
            nn.init.uniform_(self.logits, -1.0, 1.0)
            # Plain list (not a registered container): the Parameters inside
            # are already registered via the attributes above. It holds
            # references to the same Parameter objects, so in-place updates
            # (e.g. optimizer steps) are visible through it. Not read in this
            # class -- presumably consumed by external code; verify.
            self.logit_recipe = [self.logits]
        else:
            self.hidden_matrices = nn.Parameter(torch.empty([num_layers,
                                                             num_subcircuit_qubits,
                                                             num_subcircuit_qubits + 3,
                                                             num_hidden],
                                                            dtype=torch.float32))
            self.hidden_vectors = nn.Parameter(torch.empty(
                [num_layers, num_subcircuit_qubits, num_hidden, 1],
                dtype=torch.float32))
            # NOTE(review): for tensors with more than 2 dims, xavier_uniform_
            # uses conv-style fans: fan_in = size(1) * prod(size(2:)),
            # fan_out = size(0) * prod(size(2:)). These are not the per-matmul
            # fans (num_hidden / num_subcircuit_qubits + 3), so the resulting
            # scale may be smaller than intended -- confirm.
            nn.init.xavier_uniform_(self.hidden_matrices)
            nn.init.xavier_uniform_(self.hidden_vectors)
            # Same plain-list convention as in the non-hidden branch.
            self.logit_recipe = [self.hidden_matrices, self.hidden_vectors]

    def forward(self):
        """Return ``(thetas, logits)``.

        ``logits`` has shape
        ``[num_layers, num_subcircuit_qubits, num_subcircuit_qubits + 3]`` in
        both modes. Without hidden factors, the Parameter objects themselves
        are returned (not copies). With hidden factors, ``logits`` is
        recomputed on every call, so gradients flow to both factors.
        """
        if not self.use_hidden:
            return self.thetas, self.logits
        # [L, Q, Q+3, H] @ [L, Q, H, 1] -> [L, Q, Q+3, 1]; drop the trailing 1.
        logits = torch.matmul(self.hidden_matrices,
                              self.hidden_vectors).squeeze(-1)
        return self.thetas, logits
