import torch
import torch.nn as nn

from typing import Type

from examples.qas.input_gen_micro import InputGenMicro
from examples.qas.qdarts_micro import QDARTSMicro
from examples.qas.rhodarts_micro import RhoDARTSMicro

class QASMicro(nn.Module):
    """Pair a trainable input generator with a micro-level search model.

    On each forward pass the generator (``InputGenMicro``) yields a pair
    ``(angles, logits)``. Both values are handed to the search model. The
    logits are also converted into a temperature-scaled softmax distribution,
    which is returned alongside the search output.

    Submodules are registered in a fixed order: ``input_gen`` first, then
    ``search``.
    """

    def __init__(
        self,
        total_qubits: int,
        num_subcircuit_qubits: int,
        super_circuit_structure: torch.Tensor,
        num_layers: int,
        num_hidden: int,
        model: Type[QDARTSMicro | RhoDARTSMicro],
        psi0: torch.Tensor | None = None,
        **kwargs,
    ):
        """Build the input generator and instantiate the search model.

        Args:
            total_qubits: Forwarded as the first positional argument to
                ``model``.
            num_subcircuit_qubits: Forwarded to both ``InputGenMicro`` and
                ``model``.
            super_circuit_structure: Its leading dimension (``shape[0]``) is
                used as the number of subcircuits for ``InputGenMicro``. It
                is also forwarded unchanged to ``model``.
            num_layers: Forwarded to both ``InputGenMicro`` and ``model``.
            num_hidden: Forwarded to ``InputGenMicro`` only.
            model: Search-model class (``QDARTSMicro`` or ``RhoDARTSMicro``).
                It is instantiated as ``model(total_qubits,
                num_subcircuit_qubits, num_layers, super_circuit_structure,
                psi0=psi0, **kwargs)``.
            psi0: Optional tensor passed to ``model`` as the ``psi0`` keyword.
                Presumably an initial state; confirm against the model class.
            **kwargs: Extra keyword arguments forwarded verbatim to ``model``.
        """
        super().__init__()

        subcircuit_count = super_circuit_structure.shape[0]

        # Keep this construction order. The generator is created and
        # registered before the search model, which fixes parameter ordering
        # and the order of any random initialisation.
        self.input_gen = InputGenMicro(
            num_subcircuit_qubits,
            subcircuit_count,
            num_layers,
            num_hidden,
        )

        # The search model takes num_layers *before* the structure tensor.
        positional = (
            total_qubits,
            num_subcircuit_qubits,
            num_layers,
            super_circuit_structure,
        )
        self.search = model(*positional, psi0=psi0, **kwargs)

    def forward(
        self,
        softmax_temperature: float = 1.0,
        skipValidation: bool = True,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run one generator-then-search pass.

        Args:
            softmax_temperature: Divisor applied to the logits before the
                softmax. It is also forwarded positionally to the search
                model. Must be non-zero; dividing by zero yields inf/nan.
            skipValidation: Forwarded positionally to the search model.

        Returns:
            A 3-tuple ``(search_output, angles, probs)``:

            - ``search_output`` is whatever ``self.search`` returns.
            - ``angles`` is the generator's first output.
            - ``probs`` is ``softmax(logits / softmax_temperature)`` taken
              over the last dimension.
        """
        gen_angles, gen_logits = self.input_gen()

        # Note the argument order: the search model takes logits before angles.
        search_out = self.search(
            gen_logits, gen_angles, softmax_temperature, skipValidation
        )

        scaled = gen_logits / softmax_temperature
        distribution = scaled.softmax(dim=-1)
        return search_out, gen_angles, distribution
