import torch 
import torch.nn as nn 
import tqdm 
import os 
import numpy as np 

from gflownet import GFlowNet, GFlowNetEnsemble 
from functools import partial 

# Environment-variable values that explicitly mean "off"; compared case-insensitively.
_FALSY_ENV_VALUES = frozenset({'', '0', 'false', 'no', 'off'})


def _env_flag(name, default=False):
    """Read a boolean flag from the environment variable ``name``.

    Returns ``default`` when the variable is unset, ``False`` for explicitly
    false-looking values ('', '0', 'false', 'no', 'off'; case-insensitive,
    surrounding whitespace ignored) and ``True`` for anything else.
    """
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() not in _FALSY_ENV_VALUES


# Default for every progress bar in this module; export DISABLE_TQDM=1 to silence them.
# Parsed to a bool so that e.g. DISABLE_TQDM=0 really keeps the bars enabled.
DISABLE_TQDM = _env_flag('DISABLE_TQDM')

class LogRewardPool:
    """Combine several log-reward callables by summing their outputs.

    Summing log-rewards corresponds to multiplying the underlying rewards.
    Every callable is evaluated on the same batch state.
    """

    def __init__(self, log_rewards):
        # Iterable of callables mapping a batch state to log-rewards; presumably each
        # returns one value per batch element — TODO confirm against create_gfn.
        self.log_rewards = log_rewards

    @torch.no_grad()
    def __call__(self, batch_state):
        """Return the elementwise sum of all log-rewards for ``batch_state``.

        The sum starts from a zero vector of length ``batch_state.batch_size``, so an
        empty pool (or scalar-valued rewards) still yields one value per batch element.
        """
        total = torch.zeros((batch_state.batch_size,))
        for log_reward in self.log_rewards:
            value = torch.as_tensor(log_reward(batch_state))
            # Out-of-place add on the reward's device: adding in place into the CPU
            # float32 accumulator fails for (non-scalar) CUDA rewards and silently
            # downcasts float64 rewards.
            total = total.to(value.device) + value
        return total

def train(gflownet, epochs, create_env, lr, disable_tqdm=DISABLE_TQDM, optimizer=None, use_lr_scheduler=False):
    """Optimize ``gflownet`` for ``epochs`` steps, using a fresh environment at each step.

    Args:
        gflownet: module whose call on an environment returns a scalar loss tensor.
        epochs: number of optimization steps (also the scheduler's ``total_iters``).
        create_env: zero-argument factory returning a new environment for each step.
        lr: learning rate of the default AdamW optimizer; unused when ``optimizer`` is given.
        disable_tqdm: hide the progress bar.
        optimizer: optional pre-built optimizer; defaults to AdamW over ``gflownet.parameters()``.
        use_lr_scheduler: wrap the optimizer in ``PolynomialLR(power=1.)``, i.e. a
            linear decay over ``epochs`` scheduler steps.

    Returns:
        The same ``gflownet`` object, trained in place.
    """
    if optimizer is None:
        optimizer = torch.optim.AdamW(gflownet.parameters(), lr=lr)
    scheduler = None
    if use_lr_scheduler:
        scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=epochs, power=1.)

    pbar = tqdm.tqdm(range(epochs), disable=disable_tqdm)
    for _ in pbar:
        optimizer.zero_grad()
        env = create_env()
        loss = gflownet(env)
        # Skip NaN/Inf losses so a single bad batch does not corrupt the parameters.
        if not torch.isfinite(loss):
            print(loss)
            continue
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        # .item() so the bar shows a plain number rather than the tensor repr with grad_fn.
        pbar.set_postfix(loss=loss.item())

    return gflownet

@torch.no_grad()
def sample_massive_batch(gflownet, create_env, num_batches, log_reward=None, disable_tqdm=DISABLE_TQDM):
    """Draw ``num_batches`` batches from ``gflownet`` and merge them into one environment.

    Args:
        gflownet: object whose ``sample(env)`` returns the sampled environment.
        create_env: zero-argument factory for a fresh environment (one per batch).
        num_batches: total number of batches to draw; at least one is always drawn.
        log_reward: optional callable scoring a sampled environment; when ``None``,
            each environment's own ``log_reward()`` is used.
        disable_tqdm: hide the progress bar.

    Returns:
        ``(env, rewards)``: the merged environment and the log-rewards of every
        batch, concatenated in sampling order.
    """
    def _batch_log_reward(batch_env):
        # An explicit reward callable takes precedence over the environment's own reward.
        return batch_env.log_reward() if log_reward is None else log_reward(batch_env)

    env = gflownet.sample(create_env())
    reward_chunks = [_batch_log_reward(env)]

    assert hasattr(env, 'merge'), f'{env} must have the method `merge`'

    # The first batch was drawn above, so only num_batches - 1 more are needed. This matches
    # sample_massive_batch_tree and the num_batches * batch_size sample budget in multiclient.
    for _ in tqdm.tqdm(range(num_batches - 1), disable=disable_tqdm):
        env_i = gflownet.sample(create_env())
        reward_chunks.append(_batch_log_reward(env_i))
        env.merge(env_i)

    # Concatenate once instead of re-copying the accumulated rewards at every batch.
    return env, torch.hstack(reward_chunks)

@torch.no_grad()
def sample_massive_batch_tree(gflownet, create_env, num_batches, disable_tqdm=DISABLE_TQDM):
    """Draw ``num_batches`` batches of trees from ``gflownet`` and merge them into one environment.

    Runs without autograd, like ``sample_massive_batch``; sampling here only needs values.

    Args:
        gflownet: object whose ``sample(env)`` returns the sampled environment.
        create_env: zero-argument factory for a fresh environment (one per batch).
        num_batches: total number of batches to draw; at least one is always drawn.
        disable_tqdm: hide the progress bar.

    Returns:
        ``(env, rewards, newick)``: the merged environment, the concatenated
        ``log_reward()`` values, and the concatenated ``to_newick()`` outputs of all
        batches, in sampling order.
    """
    env = gflownet.sample(create_env())
    reward_chunks = [env.log_reward()]
    newick_chunks = [env.to_newick()]

    for _ in tqdm.tqdm(range(num_batches - 1), disable=disable_tqdm):
        env_i = gflownet.sample(create_env())
        env.merge(env_i)
        reward_chunks.append(env_i.log_reward())
        newick_chunks.append(env_i.to_newick())

    # Concatenate once rather than re-copying the growing arrays at every batch.
    return env, torch.hstack(reward_chunks), np.concatenate(newick_chunks)

def _empirical_distribution(gflownet, create_env_p, unique_smp, num_batches, is_phylogeny):
    """Sample ``gflownet`` at scale and summarize the distinct samples it produced.

    Returns a dict with ``counts`` (occurrences of each distinct sample) and ``rewards``
    (the log-reward of one representative per distinct sample), both as Python lists.
    """
    if not is_phylogeny:
        samples, log_rewards = sample_massive_batch(gflownet, create_env=create_env_p, num_batches=num_batches)
        # Presumably returns (representative indices, counts) per distinct sample — confirm.
        indices, counts = unique_smp(samples)
    else:
        # Trees are deduplicated by their Newick string; np.unique yields first-occurrence indices.
        _, log_rewards, newick = sample_massive_batch_tree(gflownet, create_env=create_env_p, num_batches=num_batches)
        _, indices, counts = np.unique(newick, return_index=True, return_counts=True)
    return {'counts': counts.tolist(), 'rewards': log_rewards[indices].cpu().tolist()}


def federated_gflownets(
    create_gfn,
    create_env,
    unique_smp,
    num_clients,
    epochs=int(2e3),
    num_batches=int(1e2),
    batch_size_train=512,
    batch_size_eval=512,
    lr=1e-3,
    create_env_args=None,
    flow_args=None,
    is_phylogeny=False
):
    """Train one GFlowNet per client, combine them into a global ensemble, and sample both.

    Args:
        create_gfn: factory returning ``(forward_flow, backward_flow, log_reward)``.
        create_env: environment factory, called with ``batch_size`` (and optionally ``log_reward``).
        unique_smp: deduplicates non-phylogeny samples into ``(indices, counts)``.
        num_clients: number of local models; must be at least 1.
        epochs: training steps for each local model and for the global model.
        num_batches: batches drawn per model at evaluation time.
        batch_size_train: environment batch size during training.
        batch_size_eval: environment batch size during evaluation sampling.
        lr: learning rate passed to ``train``.
        create_env_args: extra keyword arguments for ``create_gfn`` and ``create_env``.
        flow_args: extra keyword arguments for ``create_gfn``.
        is_phylogeny: deduplicate samples by Newick string instead of ``unique_smp``.

    Returns:
        ``(per_client, global_dist)``: one ``{'counts', 'rewards'}`` dict per local model
        and one for the global ensemble.

    Raises:
        ValueError: if ``num_clients`` is smaller than 1.
    """
    create_env_args = {} if create_env_args is None else create_env_args
    flow_args = {} if flow_args is None else flow_args
    if num_clients < 1:
        raise ValueError(f'num_clients must be at least 1, got {num_clients}')

    gflownets = []
    local_log_rewards = []

    # Train each client's GFlowNet on that client's own reward.
    for client in range(num_clients):
        forward_flow, backward_flow, log_reward = create_gfn(client=client, **create_env_args, **flow_args)
        gflownet = GFlowNet(forward_flow, backward_flow, off_policy_rate=.5, criterion='cb')
        create_env_p = partial(create_env, batch_size=batch_size_train, log_reward=log_reward, **create_env_args)
        train(gflownet, epochs=epochs, create_env=create_env_p, lr=lr)

        gflownets.append(gflownet)
        local_log_rewards.append(log_reward)

    # Train the global ensemble on top of the trained local models.
    log_reward_ensemble = LogRewardPool(local_log_rewards)
    # NOTE(review): the global flows are built with the last client's index (the value the
    # loop variable ends on); confirm this is what create_gfn intends for the global model.
    forward_flow, backward_flow, _ = create_gfn(client=num_clients - 1, **create_env_args, **flow_args)
    gfn = GFlowNetEnsemble(forward_flow, backward_flow, gflownets=gflownets, off_policy_rate=.5, criterion='cb')
    # NOTE(review): no log_reward override here (unlike multiclient), so the training env uses
    # create_env's default; log_reward_ensemble is only used for evaluation below.
    create_env_p = partial(create_env, batch_size=batch_size_train, **create_env_args)
    train(gfn, epochs=epochs, create_env=create_env_p, lr=lr)

    # Empirical distribution of each local model, scored with that client's reward.
    empirical_dist_per_client = []
    for gflownet, log_reward in zip(gflownets, local_log_rewards):
        gflownet.eval()
        create_env_p = partial(create_env, batch_size=batch_size_eval, log_reward=log_reward, **create_env_args)
        empirical_dist_per_client.append(
            _empirical_distribution(gflownet, create_env_p, unique_smp, num_batches, is_phylogeny)
        )

    # Empirical distribution of the global model, scored with the pooled reward.
    gfn.eval()
    create_env_p = partial(create_env, batch_size=batch_size_eval, log_reward=log_reward_ensemble, **create_env_args)
    global_dist = _empirical_distribution(gfn, create_env_p, unique_smp, num_batches, is_phylogeny)

    return empirical_dist_per_client, global_dist

def _l1_error(counts, log_rewards, indices):
    """Per-sample gap between empirical frequencies and the reward-proportional target.

    The target is ``exp(log_rewards[indices])`` normalized over the distinct samples only.
    Summing the absolute value of the result gives the aggregated L1 distance.
    """
    selected = log_rewards[indices]
    target = (selected - torch.logsumexp(selected, dim=0)).exp().cpu().numpy()
    return counts / counts.sum() - target


def multiclient(
    create_gfn,
    create_var,
    create_env,
    create_var_prod,
    samples_from_state,
    unique_smp,
    num_clients,
    epochs=int(2e3),
    num_batches=int(1e2),
    batch_size_train=512,
    batch_size_sampling=int(1e4),
    lr=1e-3,
    create_env_args=None,
    flow_args=None
):
    """Compare pooled-GFlowNet, variational-product, and centralized training against the pooled reward.

    Each client trains a local GFlowNet and fits a variational approximation to its
    samples. Three global models are then built and sampled:
    a ``GFlowNetEnsemble`` over the local GFlowNets, the product of the variational
    approximations, and a GFlowNet trained directly on the pooled reward.

    Args:
        create_env_args: extra keyword arguments for the factories (``None`` means none).
        flow_args: extra keyword arguments for ``create_gfn`` (``None`` means none).

    Returns:
        Dict with the three models, their per-sample L1 errors (``*_l1``), the
        aggregated L1 distances (``*_l1agg``), and the sampled log-rewards (``*_log_rewards``).
    """
    create_env_args = {} if create_env_args is None else create_env_args
    flow_args = {} if flow_args is None else flow_args

    gflownets = []
    var_apprx = []
    rewards = []

    for client in range(num_clients):
        fflow, bflow, log_reward = create_gfn(client, **create_env_args, **flow_args)
        gfn = GFlowNet(fflow, bflow, off_policy_rate=.5, criterion='cb')
        create_env_p = partial(create_env, log_reward=log_reward, batch_size=batch_size_train, **create_env_args)
        train(gfn, epochs=epochs, create_env=create_env_p, lr=lr)
        gflownets.append(gfn)

        # Sample from the local model and fit a variational approximation to those samples.
        gfn.eval()
        create_env_p = partial(create_env, log_reward=log_reward, batch_size=batch_size_sampling, **create_env_args)
        samples, _ = sample_massive_batch(gfn, create_env=create_env_p, num_batches=num_batches)

        var = create_var(**create_env_args)
        var.fit(samples)
        var_apprx.append(var)

        rewards.append(log_reward)

    # Sum of the clients' log-rewards (product of rewards); shared by all three global models.
    log_reward = LogRewardPool(rewards)

    # Pooled GFlowNet built on top of the local models.
    fflow, bflow, _ = create_gfn(**create_env_args, **flow_args)
    gfn = GFlowNetEnsemble(fflow, bflow, gflownets=gflownets, off_policy_rate=.5, criterion='cb')
    create_env_p = partial(create_env, log_reward=log_reward, batch_size=batch_size_train, **create_env_args)
    train(gfn, epochs=epochs, create_env=create_env_p, lr=lr)

    gfn.eval()
    create_env_p = partial(create_env, log_reward=log_reward, batch_size=batch_size_sampling, **create_env_args)
    samples_gfn, log_rewards_gfn = sample_massive_batch(gfn, create_env=create_env_p, num_batches=num_batches)
    indices_gfn, counts_gfn = unique_smp(samples_gfn)

    # Product of the variational approximations, sampled with the same sample budget.
    var_apprx_prod = create_var_prod(var_apprx)
    samples_var = var_apprx_prod.sample(num_samples=num_batches * batch_size_sampling)
    samples_var = samples_from_state(samples_var, **create_env_args)
    log_rewards_var = log_reward(samples_var)
    indices_var, counts_var = unique_smp(samples_var)

    # Centralized GFlowNet trained directly on the pooled reward.
    fflow, bflow, _ = create_gfn(**create_env_args, **flow_args)
    gfn_central = GFlowNet(fflow, bflow, off_policy_rate=.5, criterion='cb')
    create_env_p = partial(create_env, log_reward=log_reward, batch_size=batch_size_train, **create_env_args)
    train(gfn_central, epochs=epochs, create_env=create_env_p, lr=lr)

    gfn_central.eval()
    create_env_p = partial(create_env, log_reward=log_reward, batch_size=batch_size_sampling, **create_env_args)
    samples_gfn_central, log_rewards_gfn_central = sample_massive_batch(gfn_central, create_env=create_env_p, num_batches=num_batches)
    indices_gfn_central, counts_gfn_central = unique_smp(samples_gfn_central)

    # L1 error of each global model's empirical distribution against the pooled target.
    l1_gfn = _l1_error(counts_gfn, log_rewards_gfn, indices_gfn)
    l1_var = _l1_error(counts_var, log_rewards_var, indices_var)
    l1_gfn_central = _l1_error(counts_gfn_central, log_rewards_gfn_central, indices_gfn_central)

    return {
        'gfn': gfn,
        'var': var_apprx_prod,
        'gfn_central': gfn_central,
        'gfn_l1': l1_gfn,
        'var_l1': l1_var,
        'gfn_central_l1': l1_gfn_central,
        'gfn_l1agg': np.abs(l1_gfn).sum(),
        'var_l1agg': np.abs(l1_var).sum(),
        'gfn_central_l1agg': np.abs(l1_gfn_central).sum(),
        'gfn_log_rewards': log_rewards_gfn,
        'var_log_rewards': log_rewards_var,
        'gfn_central_log_rewards': log_rewards_gfn_central
    }

