import torch 
import torch.nn as nn 

from .envs import Trees, LikelihoodReward
from .flows import ForwardFlow, BackwardFlow, StateFlow, WarmupReward 

def create_gfn(client=None, num_leaves=None, vocab_size=None, hidden_dim=None, device='cpu', **kwargs):
    """Build the forward flow, the backward flow and (optionally) a client reward.

    Args:
        client: Integer client index used to seed that client's synthetic
            sites. When ``None`` no reward is built.
        num_leaves: Number of leaves, forwarded to ``ForwardFlow`` and ``Trees``
            and used as the number of columns of the sampled sites.
        vocab_size: Exclusive upper bound of the sampled site values and the
            length of the uniform vector ``pi``. Required when ``client`` is
            given.
        hidden_dim: Hidden dimension forwarded to ``ForwardFlow``.
        device: Device the forward flow is moved to.
        **kwargs: Accepted and ignored, so a shared configuration dictionary
            can be splatted into this factory.

    Returns:
        A ``(forward_flow, backward_flow, log_reward)`` tuple; ``log_reward``
        is ``None`` when ``client`` is ``None``.
    """
    forward_flow = ForwardFlow(hidden_dim=hidden_dim, num_leaves=num_leaves).to(device)
    # NOTE(review): the backward flow is not moved to `device`; presumably it
    # holds no parameters -- confirm.
    backward_flow = BackwardFlow()

    if client is None:
        return forward_flow, backward_flow, None

    # Draw the client's sites with a CPU generator. `torch.randint` allocates
    # on the CPU here, and a generator living on another device (e.g. CUDA)
    # makes the call fail with a device-mismatch error. A CPU generator also
    # keeps the sites of a given client identical whatever `device` is.
    g = torch.Generator()
    g.manual_seed((client + 1) * 42)
    sites = torch.randint(vocab_size, size=(64, num_leaves), generator=g)
    # Uniform probability vector over the vocabulary.
    pi = torch.ones((vocab_size,)) / vocab_size
    log_reward = LikelihoodReward(pi, sites, vocab_size)

    # Local environment factory for the warm-up; deliberately not named
    # `create_env` so it does not shadow the module-level factory.
    def make_env():
        return Trees(num_leaves=num_leaves, batch_size=128, log_reward=log_reward)

    warmup_reward = WarmupReward(forward_flow, warmup_epochs=200)
    warmup_reward(make_env)
    # Update the parameters of the rewards with the statistics gathered
    # during warm-up.
    log_reward.mu = warmup_reward.mu
    log_reward.std = warmup_reward.std

    return forward_flow, backward_flow, log_reward
 
def create_state_flow(num_leaves=None, hidden_dim=None, device='cpu', **kwargs):
    """Instantiate a ``StateFlow`` and move it to ``device``.

    Args:
        num_leaves: Forwarded to ``StateFlow`` as ``num_leaves``.
        hidden_dim: Forwarded to ``StateFlow`` as ``hidden_dim``.
        device: Target device passed to ``.to``.
        **kwargs: Accepted and ignored.

    Returns:
        The result of ``StateFlow(...).to(device)``.
    """
    return StateFlow(hidden_dim=hidden_dim, num_leaves=num_leaves).to(device)

def create_env(num_leaves=None, batch_size=None, log_reward=None, **kwargs):
    """Construct a ``Trees`` environment from the given settings.

    Args:
        num_leaves: Forwarded to ``Trees`` as ``num_leaves``.
        batch_size: Forwarded to ``Trees`` as ``batch_size``.
        log_reward: Forwarded to ``Trees`` as ``log_reward``.
        **kwargs: Accepted and ignored.

    Returns:
        A new ``Trees`` instance.
    """
    env_settings = {
        "num_leaves": num_leaves,
        "batch_size": batch_size,
        "log_reward": log_reward,
    }
    return Trees(**env_settings)

# NOTE: no `unique_smp(samples, newick, **kwargs)` helper is defined here, because
# the sample_massive_batch_tree procedure already returns a collection of uniquely defined trees.