import torch 
import numpy as np 

from .envs import Set, LogRewardLinear 
from .flows import ForwardFlow, BackwardFlow, StateFlow 

@torch.no_grad()
def create_gfn(client=None, warehouse_size=None, hidden_dim=None, emb_dim=None, device=None, **kwargs):
    """Build the forward flow, backward flow and optional log-reward.

    Args:
        client: Optional client identifier. When given, a ``LogRewardLinear``
            is built for it; when ``None``, no reward is created.
        warehouse_size: Forwarded to ``ForwardFlow`` and ``LogRewardLinear``.
        hidden_dim: Forwarded to ``ForwardFlow``.
        emb_dim: Forwarded to ``ForwardFlow``.
        device: Forwarded to ``LogRewardLinear`` only.
        **kwargs: Accepted and ignored, so a shared config dict can be splatted in.

    Returns:
        Tuple ``(forward_flow, backward_flow, log_reward)``; ``log_reward`` is
        ``None`` when ``client`` is ``None``.
    """
    # Order of construction is kept as-is: module construction may draw from
    # the global RNG, so reordering could change initial weights.
    fwd = ForwardFlow(emb_dim=emb_dim, hidden_dim=hidden_dim, warehouse_size=warehouse_size)
    bwd = BackwardFlow()

    reward = None
    if client is not None:
        # NOTE(review): first argument looks like a per-client seed derived
        # from the client id — confirm against LogRewardLinear's signature.
        reward_key = (client + 1) * 42
        reward = LogRewardLinear(reward_key, warehouse_size, device=device)

    return fwd, bwd, reward

@torch.no_grad()
def unique_smp(samples):
    """Find the distinct rows of ``samples.sorted_state.values``.

    The values are moved to CPU and passed to ``np.unique`` along axis 0.

    Args:
        samples: Object exposing ``sorted_state.values`` as a tensor
            (presumably one row per sampled state — confirm with caller).

    Returns:
        Tuple ``(indices, counts)``. ``indices`` holds the position of the
        first occurrence of each unique row, and ``counts`` holds how many
        times that row appears. Both arrays follow ``np.unique``'s sorted
        row order.
    """
    rows = samples.sorted_state.values.cpu()
    uniq_result = np.unique(rows, axis=0, return_index=True, return_counts=True)
    first_positions = uniq_result[1]
    multiplicities = uniq_result[2]
    return first_positions, multiplicities

@torch.no_grad()
def create_env(log_reward=None, batch_size=None, set_size=None, warehouse_size=None, **kwargs):
    """Construct a ``Set`` environment.

    Args:
        log_reward: Passed to ``Set`` as the ``log_reward`` keyword.
        batch_size: Third positional argument to ``Set``.
        set_size: First positional argument to ``Set``.
        warehouse_size: Second positional argument to ``Set``.
        **kwargs: Accepted and ignored, so a shared config dict can be splatted in.

    Returns:
        The newly created ``Set`` instance.
    """
    # Positional order (set_size, warehouse_size, batch_size) differs from
    # this function's keyword order and must be preserved.
    env = Set(set_size, warehouse_size, batch_size, log_reward=log_reward)
    return env

@torch.no_grad()
def create_state_flow(emb_dim=None, hidden_dim=None, warehouse_size=None, **kwargs):
    """Construct a ``StateFlow`` from the given sizes.

    Args:
        emb_dim: Forwarded to ``StateFlow``.
        hidden_dim: Forwarded to ``StateFlow``.
        warehouse_size: Forwarded to ``StateFlow``.
        **kwargs: Accepted and ignored, so a shared config dict can be splatted in.

    Returns:
        The newly created ``StateFlow`` instance.
    """
    return StateFlow(
        emb_dim=emb_dim,
        hidden_dim=hidden_dim,
        warehouse_size=warehouse_size,
    )