import torch
import torch.nn as nn

from typing import Type

from examples.qas.input_gen import InputGen
from examples.qas.qdarts import QDARTS
from examples.qas.rhodarts import RhoDARTS

class QAS(nn.Module):
    """Couple an ``InputGen`` with a search model and run them together.

    NOTE(review): "QAS" presumably means quantum architecture search, judging
    by the ``examples.qas`` package name -- confirm.

    Args:
        n: Forwarded to ``InputGen(n, m, K)`` and as the first positional
            argument of ``model``.
        m: Forwarded to ``InputGen(n, m, K)`` and as the second positional
            argument of ``model``.
        K: Forwarded to ``InputGen(n, m, K)`` only.
        model: Search-model class (``QDARTS`` or ``RhoDARTS``), instantiated
            as ``model(n, m, *args, **kwargs)``.
        *args: Extra positional arguments passed through to ``model``.
        **kwargs: Extra keyword arguments passed through to ``model``.
    """

    def __init__(
        self,
        n: int,
        m: int,
        K: int,
        model: Type[QDARTS | RhoDARTS],
        *args,
        **kwargs,
    ):
        super().__init__()
        # Submodules are registered input_gen first, then search; this fixes
        # their order in ``children()`` / ``parameters()``.
        self.input_gen = InputGen(n, m, K)
        self.search = model(n, m, *args, **kwargs)

    def forward(
        self,
        softmax_temperature: float = 1.0,
        skipValidation: bool = True,
    ):
        """Draw generator outputs, feed them to the search model, return results.

        Args:
            softmax_temperature: Passed to the search model and also used to
                divide ``logits`` before the softmax that yields ``probs``.
                A non-positive value is not guarded against here.
            skipValidation: Passed through unchanged to the search model.

        Returns:
            tuple: ``(qs, angles, probs)``.
                ``qs`` is whatever the search model returns.
                ``angles`` is the generator's angle output, unchanged.
                ``probs`` is the softmax over the last dimension of
                ``logits / softmax_temperature``.
        """
        gen_angles, gen_logits = self.input_gen()
        # The search model takes (logits, angles, ...), i.e. the reverse of
        # the order in which the generator returns them.
        search_out = self.search(
            gen_logits, gen_angles, softmax_temperature, skipValidation
        )
        scaled_logits = gen_logits / softmax_temperature
        return search_out, gen_angles, scaled_logits.softmax(dim=-1)
        
