import torch 
import numpy as np 

from .envs import Sequence, LogRewardLinear 
from .flows import ForwardFlow, BackwardFlow, StateFlow 

@torch.no_grad()
def create_gfn(client=None, max_size=None, vocab_size=None, emb_dim=None, hidden_dim=None, device=None, **kwargs):
    """Build the flow modules and, for a given client, its log-reward.

    Construction runs under ``torch.no_grad()``.

    Args:
        client: Optional client id. When given, a ``LogRewardLinear`` is
            created with seed ``(client + 1) * 42``, so each client id maps
            to its own deterministic seed. When ``None``, no reward is built.
        max_size: Forwarded to ``LogRewardLinear``.
        vocab_size: Forwarded to ``LogRewardLinear`` and ``ForwardFlow``.
        emb_dim: Forwarded to ``ForwardFlow``.
        hidden_dim: Forwarded to ``ForwardFlow``.
        device: Forwarded to ``LogRewardLinear`` only.
        **kwargs: Accepted and ignored, so a shared config dict can be splatted in.

    Returns:
        Tuple ``(forward_flow, backward_flow, log_reward)``. ``log_reward``
        is ``None`` when ``client`` is ``None``.
    """
    log_reward = None
    if client is not None:
        client_seed = (client + 1) * 42
        log_reward = LogRewardLinear(max_size, vocab_size, seed=client_seed, device=device)

    # Keep the original construction order (reward, forward, backward),
    # since module init may consume global RNG state.
    forward_flow = ForwardFlow(emb_dim=emb_dim, vocab_size=vocab_size, hidden_dim=hidden_dim)
    return forward_flow, BackwardFlow(), log_reward

@torch.no_grad()
def create_env(log_reward=None, batch_size=None, max_size=None, vocab_size=None, **kwargs):
    """Construct a ``Sequence`` environment from config values.

    Construction runs under ``torch.no_grad()``.

    Args:
        log_reward: Reward object passed through to ``Sequence`` (may be ``None``).
        batch_size: Passed through to ``Sequence``.
        max_size: Passed through to ``Sequence``.
        vocab_size: Passed through to ``Sequence``.
        **kwargs: Accepted and ignored, so a shared config dict can be splatted in.

    Returns:
        The newly built ``Sequence`` instance.
    """
    env_options = dict(
        max_size=max_size,
        vocab_size=vocab_size,
        batch_size=batch_size,
        log_reward=log_reward,
    )
    return Sequence(**env_options)

@torch.no_grad()
def unique_smp(samples):
    """Locate the distinct rows of ``samples.state``.

    Args:
        samples: Object with a ``state`` tensor. Rows along dim 0 are
            compared for uniqueness.

    Returns:
        Tuple ``(indices, counts)`` of numpy arrays, ordered by the unique
        rows in ``np.unique``'s sorted order. ``indices[i]`` is the position
        of the first occurrence of the i-th unique row in ``samples.state``.
        ``counts[i]`` is how many times that row occurs.
    """
    # Convert explicitly rather than handing the tensor to numpy. The implicit
    # Tensor.__array__ path calls .numpy(), which raises for tensors that
    # require grad (the enclosing no_grad does not clear that flag).
    # .cpu() handles tensors living on an accelerator.
    states = samples.state.detach().cpu().numpy()
    _, indices, counts = np.unique(states, axis=0, return_index=True, return_counts=True)
    return indices, counts

@torch.no_grad()
def create_state_flow(hidden_dim=None, emb_dim=None, vocab_size=None, **kwargs):
    """Construct a ``StateFlow`` module from config values.

    Construction runs under ``torch.no_grad()``.

    Args:
        hidden_dim: Passed through to ``StateFlow``.
        emb_dim: Passed through to ``StateFlow``.
        vocab_size: Passed through to ``StateFlow``.
        **kwargs: Accepted and ignored, so a shared config dict can be splatted in.

    Returns:
        The newly built ``StateFlow`` instance.
    """
    return StateFlow(
        hidden_dim=hidden_dim,
        emb_dim=emb_dim,
        vocab_size=vocab_size,
    )