import torch 
import torch.nn as nn 
import numpy as np 
from copy import deepcopy 

from sal.utils import marginal_dist, Environment  
from sal.gym import Hypergrid, Sequences, Set  

from typing import List, Tuple 

def huber_loss(diff, delta=9.):
    """Return the mean Huber loss of ``diff`` (residuals measured against zero).

    Elementwise the loss is ``0.5 * d**2`` when ``|d| <= delta`` and
    ``delta * (|d| - 0.5 * delta)`` otherwise. It is continuous at ``|d| == delta``
    and has gradient magnitude bounded by ``delta``. The result is averaged over
    all elements.

    Args:
        diff: tensor of residuals (any shape).
        delta: threshold between the quadratic and the linear regime.

    Returns:
        0-dim tensor holding the mean loss.
    """
    abs_diff = diff.abs()
    quadratic = 0.5 * diff * diff
    # Linear tail: grows with |d|, not d**2, so large residuals do not dominate.
    linear = delta * (abs_diff - 0.5 * delta)
    return torch.where(abs_diff <= delta, quadratic, linear).mean()

@torch.jit.script 
def _subtraj_loss(i: torch.Tensor, j: torch.Tensor, traj_stats: torch.Tensor, F_traj: torch.Tensor, last_idx: torch.Tensor): 
    traj_stats = torch.cumsum(traj_stats, dim=2) 
    pf = traj_stats[0, :, j] - traj_stats[0, :, i]  
    pb = traj_stats[1, :, j] - traj_stats[1, :, i] 
    loss = pf - pb + F_traj[:, i] - F_traj[:, j] 
    flag = (last_idx.unsqueeze(1) >= j.unsqueeze(0)).type(loss.dtype) 
    lamb = (.99 ** (j - i).view(1, -1)) * flag 
    loss = (loss.pow(2) * lamb).sum(dim=1) / lamb.sum(dim=1)   
    return loss.mean() 

class GFlowNet(nn.Module): 

    def __init__(self, pf, pb, lamb_subtb=.9, criterion='tb', device='cpu'): 
        super(GFlowNet, self).__init__() 
        self.pf = pf
        self.pb = pb 
        self.log_z = nn.Parameter(torch.randn((1,), dtype=torch.get_default_dtype(), device=device).squeeze(), requires_grad=True) 
        
        self.criterion = criterion 
        self.device = device 
        self.lamb_subtb = lamb_subtb 

    def mask_from_traj(self, batch_state, last_idx):
        return last_idx.view(-1, 1) > torch.arange(
            batch_state.max_trajectory_length, device=self.device
        ).view(1, -1) 
            
    def forward(self, batch_state, return_target=False, return_trajectories=False, **kwargs): 

        if self.criterion == 'fm': 
            return self._flow_matching(batch_state), None   

        out = self._sample_traj(
            batch_state, return_gamma=False, return_trajectories=return_trajectories, **kwargs
        )

        trajectories = None  
        if return_trajectories:
            traj_stats, F_traj, last_idx, trajectories = out 
            loss = self.evaluate_loss_on_transitions(batch_state, traj_stats, F_traj, last_idx) 
        else:
            traj_stats, F_traj, last_idx = out 
            loss = self.evaluate_loss_on_transitions(batch_state, traj_stats, F_traj, last_idx) 
        
        if return_target: 
            return loss, F_traj[batch_state.batch_ids, last_idx] 
        if return_trajectories: 
            return loss, trajectories 
        return loss, None 
    
    def evaluate_loss_on_transitions(
        self, batch_state, traj_stats, F_traj, last_idx
    ): 
        match self.criterion: 
            case 'tb': 
                loss = self._trajectory_balance(batch_state, traj_stats, F_traj, last_idx) 
            case 'cb': 
                loss = self._contrastive_balance(batch_state, traj_stats, F_traj, last_idx) 
            case 'cbf': 
                loss = self._contrastive_balance_full(batch_state, traj_stats, F_traj, last_idx) 
            case 'db':
                loss = self._detailed_balance(batch_state, traj_stats, F_traj, last_idx)
            case 'subtb': 
                loss = self._subtrajectory_balance(batch_state, traj_stats, F_traj, last_idx) 
            case 'atb': 
                loss = self._amortized_trajectory_balance(batch_state, traj_stats, F_traj, last_idx) 
            case _: 
                raise ValueError(f'{self.criterion} should be either tb, cb, db, or dbc') 
        return loss 

    def _sample_traj(self, batch_state, return_gamma=False, return_trajectories=False, **kwargs):
        # dim 0: pf, dim 1: pb, dim 2: pf_exp 
        traj_stats = torch.zeros((3, batch_state.batch_size, batch_state.max_trajectory_length), device=self.device) 
        F_traj = torch.zeros((batch_state.batch_size, batch_state.max_trajectory_length + 1), device=self.device) 
        trajectories = list() 
    
        # gamma 
        if return_gamma: 
            gamma = torch.zeros((batch_state.batch_size, batch_state.max_trajectory_length), device=self.device) 

        i = 0 
        last_idx = torch.zeros((batch_state.batch_size,), dtype=torch.long, device=self.device) 

        is_stopped = torch.zeros((batch_state.batch_size,), dtype=bool, device=self.device) 

        for _ in range(batch_state.max_trajectory_length): 
            if (batch_state.stopped == 1).all(): break 
            # Sample the actions 
            out = self.pf(batch_state, **kwargs) 
            actions, pf, F, sp = out[0], out[1], out[2], out[3]  

            # Apply the actions  
            if not return_trajectories:
                batch_state_t = deepcopy(batch_state) 
                batch_state.apply(actions) 
                batch_state_tp1 = batch_state 
            else: 
                states_t = deepcopy(batch_state) 
                batch_state.apply(actions) 
                states_tp1 = deepcopy(batch_state)
                trajectories.append(
                    (states_t, actions, states_tp1) 
                ) 
            # Corresponding backward actions  
            out = self.pb(batch_state, actions) 
            pb = out[1] 

            # Save values 
            traj_stats[0, ~is_stopped, i] = pf[~is_stopped] 
            traj_stats[1, ~is_stopped, i] = pb[~is_stopped]
            traj_stats[2, ~is_stopped, i] = sp[~is_stopped] # sampling policy 
            F_traj[~is_stopped, i] = F[~is_stopped] 
            if return_gamma: 
                gamma[~is_stopped, i] = self.gamma_func(batch_state_t, batch_state_tp1)[~is_stopped]
            # Check whether it already stopped 
            is_stopped = batch_state.stopped.bool()  
            i += 1 
            last_idx += (1 - batch_state.stopped.long())   
        
        F_traj[batch_state.batch_ids, last_idx + 1] = batch_state.log_reward(**kwargs) 
        if return_gamma: 
            return traj_stats, F_traj, last_idx + 1, gamma 
        if return_trajectories:  
            return traj_stats, F_traj, last_idx + 1, trajectories 
        return traj_stats, F_traj, last_idx + 1

    def _trajectory_balance(self, batch_state, traj_stats, F_traj, last_idx): 
        loss = (traj_stats[0] - traj_stats[1]).sum(dim=1) - F_traj[batch_state.batch_ids, last_idx] + self.log_z 
        return huber_loss(loss, delta=1.) # (loss*loss).mean()  

    def _amortized_trajectory_balance(self, batch_state, traj_stats, F_traj, last_idx): 
        loss = (
            (traj_stats[0] - traj_stats[1]).sum(dim=1) - \
                F_traj[batch_state.batch_ids, last_idx] + F_traj[batch_state.batch_ids, 0] 
        )
        return huber_loss(loss, delta=1.) 
    
    def _detailed_balance(self, batch_state, traj_stats, F_traj, last_idx, return_transition_loss=False): 
        loss = (traj_stats[0] + F_traj[:, :-1] - traj_stats[1] - F_traj[:, 1:]).pow(2) 
        loss_avg = (
            loss.sum(dim=1) / last_idx
        ).mean() 
        if not return_transition_loss: 
            return loss_avg 
        else: 
            return loss_avg, loss 
        # return huber_loss(loss, delta=1.) # (loss*loss).mean() 
    
    # pass 

    def _subtrajectory_balance(self, batch_state, traj_stats, F_traj, last_idx): 
        max_traj_length = batch_state.max_trajectory_length 
        traj_stats = torch.cat([torch.zeros((*traj_stats.shape[:-1], 1), device=traj_stats.device), traj_stats], dim=2) 
        i, j = torch.triu_indices(max_traj_length + 1, max_traj_length + 1, offset=1, device=traj_stats.device) 
        return _subtraj_loss(i, j, traj_stats, F_traj, last_idx) 
        
    def _contrastive_balance_full(self, batch_state, traj_stats, F_traj, last_idx): 
        loss = (traj_stats[0] - traj_stats[1]).sum(dim=1) - F_traj[batch_state.batch_ids, last_idx]
        loss = (loss[:, None] - loss[None, :])  
        return huber_loss(loss, delta=1.) 
    
    def _contrastive_balance(self, batch_state, traj_stats, F_traj, last_idx): 
        assert (batch_state.batch_size % 2) == 0, 'batch size must be even' 
        half_batch = batch_state.batch_size // 2 
        loss = (traj_stats[0] - traj_stats[1]).sum(dim=1) - F_traj[batch_state.batch_ids, last_idx] 
        loss = loss[:half_batch] - loss[half_batch:] 
        return huber_loss(loss, delta=1.) # (loss*loss).mean()   

    @torch.no_grad() 
    def sample(self, batch_state, seed=None, **kwargs): 
        while (batch_state.stopped < 1).any(): 
            if seed is not None: self.pf.set_seed(seed) 
            out = self.pf(batch_state, **kwargs) 
            actions = out[0] 
            batch_state.apply(actions)
            if seed is not None: self.pf.unset_seed()  
        return batch_state  
    
    @torch.no_grad() 
    def marginal_prob(self, batch_state, copy_env=False, perturb_params=False): 
        # Use importance sampling to estimate the marginal probabilities
        if copy_env: 
            batch_state = deepcopy(batch_state) 
        forward_log_traj = torch.zeros((batch_state.batch_size, batch_state.max_trajectory_length), device=self.device) 
        backward_log_traj = torch.zeros((batch_state.batch_size, batch_state.max_trajectory_length), device=self.device) 

        idx = 0 

        is_initial = torch.zeros(batch_state.batch_size, device=self.device, dtype=bool) 
        while not is_initial.all():
            # Estimate the backward probabilities  
            back_out = self.pb(batch_state) 
            actions, backward_log_prob = back_out[0], back_out[1] 
            
            forward_actions = batch_state.backward(actions) 

            # Estimate the forward probabilities
            forward_out = self.pf(batch_state, actions=forward_actions, perturb_params=perturb_params) 
            forward_log_prob = forward_out[1] 

            forward_log_traj[~is_initial, idx] = forward_log_prob[~is_initial]  
            backward_log_traj[~is_initial, idx] = backward_log_prob[~is_initial] 

            is_initial = batch_state.is_initial.bool()        
            idx += 1

        marginal_log = (forward_log_traj - backward_log_traj).sum(dim=1) 
        return marginal_log 

    def _flow_matching(self, batch_state): 
        '''
        This method assumes that the trajectory length is continuous for each state. 
        '''
        assert hasattr(batch_state, 'get_children') and hasattr(batch_state, 'get_parents')  
        loss = torch.zeros((batch_state.batch_size,), device=self.device) 
        while (batch_state.stopped < 1).any(): 
            # Sample the actions 
            actions = self.pf(batch_state)[0]  
            
            batch_state.apply(actions) 
            
            # Evaluate \sum_{s \in Pa(s')} F(s')p_{F}(s|s') 
            f_parent_p_f = torch.zeros((batch_state.batch_size, batch_state.max_num_parents), device=self.device) 
            idx = 0 
            for parent, forward_actions in batch_state.get_parents(): 
                flows = self.pf.mlp_flows(parent.unique_input).squeeze(dim=1)  
                logits = self.pf(parent, actions=forward_actions)[1]  
                f_parent_p_f[:, idx] = flows + logits  
                idx += 1 

            # Evaluate \sum_{s \in Ch(s')} F(s') p_{B}(s | s') 
            idx = 0 
            f_child_p_b = torch.zeros((batch_state.batch_size, batch_state.max_num_parents), device=self.device) 
            for child, backward_actions in batch_state.get_children(return_actions=True): 
                flows = self.pf.mlp_flows(child.unique_input).squeeze(dim=1) 
                logits = self.pb(child, actions=backward_actions)[1] 
                f_child_p_b[:, idx] = (flows + logits).nan_to_num(neginf=-1e8)  
                idx += 1 

            loss = loss + (
                torch.logsumexp(f_parent_p_f, dim=1) - torch.logsumexp(f_child_p_b, dim=1) 
            ).pow(2) 

        loss = loss / batch_state.max_trajectory_length 
        return loss.mean() 
    
    def sample_many_backward(self, batch_states, num_trajectories, perturb_params=False): 
        marginal_log = torch.zeros((batch_states.batch_size, num_trajectories), device=self.device) 
        for idx in range(num_trajectories): 
            marginal_log[:, idx] = self.marginal_prob(batch_states, copy_env=True, perturb_params=perturb_params) 
        return marginal_log  
    
    def evaluate_loss_on_trajectories(
        self, trajectories: List[Tuple[Environment, torch.Tensor, Environment]]
    ): 
        # Evaluate the neural network at each transition
        (sample, _, _) = trajectories[-1] 
        batch_size = sample.batch_size 
        max_trajectory_length = sample.max_trajectory_length 
        traj_stats = torch.zeros((2, batch_size, max_trajectory_length), device=self.device) 
        F_traj = torch.zeros((batch_size, max_trajectory_length + 1), device=self.device) 
        last_idx = torch.zeros((batch_size,), dtype=torch.long, device=self.device) 

        state_idx = 0 
        for (state_t, actions, state_tp1) in trajectories: # trajectories is a list of batches of states
            out = self.pf(state_t, actions=actions)  
            log_pf, log_F = out[1], out[2]  
            out_bkc = self.pb(state_tp1, actions=actions)
            log_pb = out_bkc[1] 

            traj_stats[0, state_tp1.stopped != 1, state_idx] = log_pf[state_tp1.stopped != 1] 
            traj_stats[1, state_tp1.stopped != 1, state_idx] = log_pb[state_tp1.stopped != 1] 
            F_traj[state_tp1.stopped != 1, state_idx] = log_F[state_tp1.stopped != 1] 
            
            state_idx += 1 

            last_idx += (state_tp1.stopped != 1)  

        F_traj[state_tp1.batch_ids, last_idx] = state_tp1.log_reward() 
        # hacky; rewards should be stored and not redundantly computed 
        loss = self.evaluate_loss_on_transitions(
            state_tp1, traj_stats, F_traj, last_idx
        )
        return loss 

    @torch.no_grad() 
    def evaluate_fcs_on_trajectories(
        self, trajectories, bucket_size, num_samples=32, num_trajectories=16, perturb_params=False
    ): 
        (_, _, terminal_states) = trajectories[-1] 
        running_results = 0. 
        for _ in range(num_samples): 
            samples = terminal_states.get(
                np.random.choice(terminal_states.batch_size, size=bucket_size, replace=False)
            ) 
            # Compute the learned marginal probability 
            marginal_log = self.sample_many_backward(samples, 
                                                     num_trajectories=num_trajectories, perturb_params=perturb_params) 
            # Compute the estimated marginal probability 
            log_reward = samples.log_reward() 

            marginal_prob, target_prob = marginal_dist(
                samples, marginal_log, log_reward, dim=0
            )
            running_results += (marginal_prob - target_prob).abs().sum() / 2  

        return running_results / num_samples 
    
    class OffPolicyCtx: 

        def __init__(self, gflownet): 
            self.gflownet = gflownet 
        
        def __enter__(self): 
            self.curr_eps = self.gflownet.pf.eps 
            self.gflownet.pf.eps = 1. 
        
        def __exit__(self, *unused_args): 
            self.gflownet.pf.eps = self.curr_eps 
        
    class OnPolicyCtx: 

        def __init__(self, gflownet): 
            self.gflownet = gflownet 

        def __enter__(self): 
            self.curr_eps = self.gflownet.pf.eps 
            self.gflownet.pf.eps = 0. 

        def __exit__(self, *unused_args): 
            self.gflownet.pf.eps = self.curr_eps 

    def off_policy(self): 
        return self.OffPolicyCtx(self) 
    
    def on_policy(self): 
        return self.OnPolicyCtx(self) 

class SALGFlowNet(GFlowNet):
    """GFlowNet variant trained with amortized TB (or sub-TB when ``criterion == 'subtb'``).

    It samples until every element stops (no fixed step budget) and provides
    helpers to estimate terminal-state marginals from backward trajectories.
    """

    @staticmethod
    def _state_to_node_fn(env):
        """Return the ``state_to_node`` function for ``env``'s environment type.

        Raises:
            ValueError: if ``env`` is not a Hypergrid, Sequences or Set.
        """
        if isinstance(env, Hypergrid):
            from sal.gym.hypergrids import state_to_node
        elif isinstance(env, Sequences):
            from sal.gym.sequences import state_to_node
        elif isinstance(env, Set):
            from sal.gym.sets import state_to_node
        else:
            raise ValueError(f'unsupported environment type: {type(env).__name__}')
        return state_to_node

    def loss(self, batch_state, traj_stats, F_traj, last_idx):
        """Sub-TB if ``criterion == 'subtb'``, otherwise amortized TB."""
        if self.criterion == 'subtb':
            return self._subtrajectory_balance(batch_state, traj_stats, F_traj, last_idx)
        # By default, use ATB
        return self._amortized_trajectory_balance(batch_state, traj_stats, F_traj, last_idx)

    def forward(self, batch_state, return_trajectories=False, **kwargs):
        """Sample from ``batch_state`` (in place) and return the loss, plus the trajectories if requested."""
        out = self.sample_traj(batch_state, return_trajectories, **kwargs)
        # sample_traj returns 3 items, or 4 when trajectories are requested.
        traj_stats, F_traj, last_idx = out[:3]
        loss = self.loss(batch_state, traj_stats, F_traj, last_idx)
        if return_trajectories:
            return loss, out[3]
        return loss

    def sample_traj(self, batch_state, return_trajectories=False, **kwargs):
        """Roll out ``pf`` until every element stops; record pf/pb/sampling stats and log-flows.

        Returns ``(traj_stats, F_traj, last_idx + 1)`` and, if ``return_trajectories``,
        also the list of ``(state_t, actions, state_tp1)`` snapshots. ``log_reward``
        is written at ``F_traj[:, last_idx + 1]``.
        """
        # dim 0: pf, dim 1: pb, dim 2: pf_exp
        traj_stats = torch.zeros((3, batch_state.batch_size, batch_state.max_trajectory_length), device=self.device)
        F_traj = torch.zeros((batch_state.batch_size, batch_state.max_trajectory_length + 1), device=self.device)
        trajectories = list()

        i = 0
        last_idx = torch.zeros((batch_state.batch_size,), dtype=torch.long, device=self.device)

        # Stopped flags as of before the current step.
        is_stopped = torch.zeros((batch_state.batch_size,), dtype=bool, device=self.device)

        while (batch_state.stopped < 1.).any():
            # NOTE: no fixed step budget here; the step index must stay below
            # max_trajectory_length or the writes below go out of range.
            out = self.pf(batch_state, **kwargs)
            actions, pf, F, sp = out[0], out[1], out[2], out[3]

            # Apply the actions
            if return_trajectories:
                states_t = deepcopy(batch_state)
                batch_state.apply(actions)
                states_tp1 = deepcopy(batch_state)

                trajectories.append(
                    (states_t, actions, states_tp1)
                )
            else:
                batch_state.apply(actions)

            # Corresponding backward actions
            pb = self.pb(batch_state, actions)[1]

            # Save values for elements still running before this step.
            active = ~is_stopped
            traj_stats[0, active, i] = pf[active]
            traj_stats[1, active, i] = pb[active]
            traj_stats[2, active, i] = sp[active]  # sampling policy
            F_traj[active, i] = F[active]

            # Check whether it already stopped
            is_stopped = batch_state.stopped.bool()
            i += 1
            last_idx += (1 - batch_state.stopped.long())

        F_traj[batch_state.batch_ids, last_idx + 1] = batch_state.log_reward(**kwargs)
        if return_trajectories:
            return traj_stats, F_traj, last_idx + 1, trajectories
        else:
            return traj_stats, F_traj, last_idx + 1

    def _amortized_trajectory_balance(self, batch_state, traj_stats, F_traj, last_idx):
        """Amortized TB: ``log_z`` replaced by the log-flow at position 0."""
        loss = (
            (traj_stats[0] - traj_stats[1]).sum(dim=1) - \
                F_traj[batch_state.batch_ids, last_idx] + F_traj[batch_state.batch_ids, 0]
        )
        return huber_loss(loss, delta=1.)

    @torch.no_grad()
    def sample(self, batch_state, seed=None, **kwargs):
        """Sample ``batch_state`` (in place) to completion; return it with each element's summed pf log-prob.

        Steps taken after an element has stopped do not contribute to its sum.
        If the batch is already fully stopped, the sums are zero.
        ``seed`` is accepted for interface compatibility but unused.
        """
        total_log_pf = torch.zeros((batch_state.batch_size,), device=self.device)
        while (batch_state.stopped < 1).any():
            already_stopped = batch_state.stopped.bool().view(-1)
            out = self.pf(batch_state, **kwargs)
            batch_state.apply(out[0])
            step_log_pf = out[1].view(-1)
            total_log_pf = total_log_pf + torch.where(
                already_stopped, torch.zeros_like(step_log_pf), step_log_pf
            )
        return batch_state, total_log_pf

    def sample_many_backward(self, batch_states, num_trajectories, **kwargs):
        """Return a ``(batch, num_trajectories)`` tensor of independent `marginal_prob` samples (inputs are copied)."""
        marginal_log = torch.zeros((batch_states.batch_size, num_trajectories), device=self.device)
        for idx in range(num_trajectories):
            marginal_log[:, idx] = self.marginal_prob(batch_states, copy_env=True, **kwargs)
        return marginal_log

    def sample_backward_traj(self, batch_state):
        """Walk ``batch_state`` back to initial states (in place) using ``pb``.

        Returns three per-step lists:
        - ``(state after the backward step, forward actions that redo it)``;
        - the ``pb`` log-probs;
        - the "already initial" masks taken before each step.
        """
        traj = list()
        traj_pb = list()
        traj_mask = list()
        is_initial = torch.zeros(batch_state.batch_size, device=self.device, dtype=bool)
        while not is_initial.all():
            back_out = self.pb(batch_state)
            backward_actions = back_out[0]
            forward_actions = batch_state.backward(backward_actions)
            state_t = deepcopy(batch_state)
            traj.append(
                (state_t, forward_actions)
            )
            traj_pb.append(back_out[1])
            traj_mask.append(is_initial)
            is_initial = batch_state.is_initial.bool()
        return traj, traj_pb, traj_mask

    @torch.no_grad()
    def evaluate_marginal_on_backward_traj(self, batch_state, num_trajectories=32, normalize=True, **kwargs):
        """Importance-sampling estimate of each terminal state's marginal using ``num_trajectories`` backward paths.

        Returns ``log`` of the mean over paths of ``p_F / p_B``. If ``normalize``, it
        returns the exponentiated values renormalized to sum to 1 over the batch.
        If ``kwargs['gflownets']`` is truthy, ``node_indices`` are set on every state
        before ``pf`` is evaluated.
        """
        log_pf = torch.zeros((batch_state.batch_size, num_trajectories), device=self.device)
        log_pb = torch.zeros((batch_state.batch_size, num_trajectories), device=self.device)
        gflownets = kwargs.get('gflownets')

        for traj_idx in range(num_trajectories):
            terminal_states = deepcopy(batch_state)
            traj, traj_pb, traj_mask = self.sample_backward_traj(terminal_states)
            # Replay forward: from the initial state towards the terminal state.
            traj, traj_pb, traj_mask = traj[::-1], traj_pb[::-1], traj_mask[::-1]
            state_running = deepcopy(traj[0][0])
            assert state_running.is_initial.all()
            state_to_node = self._state_to_node_fn(state_running) if gflownets else None
            for idx, (state_t, forward_actions) in enumerate(traj):
                if state_to_node is not None:
                    state_t.node_indices = state_to_node(state_running, len(gflownets))
                log_pf_tran = self.pf(state_t, actions=forward_actions, **kwargs)[1]
                state_running.apply(forward_actions)
                valid = ~traj_mask[idx]
                log_pf[valid, traj_idx] += log_pf_tran[valid]
                log_pb[valid, traj_idx] += traj_pb[idx][valid]

        log_pt = torch.logsumexp(log_pf - log_pb, dim=1) - np.log(num_trajectories)
        if normalize:
            pt = (log_pt - torch.logsumexp(log_pt, dim=0)).exp()
            return pt
        return log_pt

    def evaluate_loss_on_trajectories(
        self, trajectories: List[Tuple[Environment, torch.Tensor, Environment]], **kwargs
    ):
        """Re-score stored ``(state_t, actions, state_tp1)`` transitions with the amortized TB loss.

        Only transitions whose next state is not stopped are recorded; ``log_reward``
        is written at ``F_traj[:, last_idx + 1]``.
        """
        (sample, _, _) = trajectories[-1]
        batch_size = sample.batch_size
        max_trajectory_length = sample.max_trajectory_length
        traj_stats = torch.zeros((2, batch_size, max_trajectory_length), device=self.device)
        F_traj = torch.zeros((batch_size, max_trajectory_length + 1), device=self.device)
        last_idx = torch.zeros((batch_size,), dtype=torch.long, device=self.device)

        # trajectories is a list of batches of states, one entry per step
        for state_idx, (state_t, actions, state_tp1) in enumerate(trajectories):
            out = self.pf(state_t, actions=actions, **kwargs)
            log_pf, log_F = out[1], out[2]
            log_pb = self.pb(state_tp1, actions=actions)[1]

            active = state_tp1.stopped != 1
            traj_stats[0, active, state_idx] = log_pf[active]
            traj_stats[1, active, state_idx] = log_pb[active]
            F_traj[active, state_idx] = log_F[active]

            last_idx += active

        # hacky; rewards should be stored and not redundantly computed
        F_traj[state_tp1.batch_ids, last_idx + 1] = state_tp1.log_reward(**kwargs)
        loss = self._amortized_trajectory_balance(
            state_tp1, traj_stats, F_traj, last_idx + 1
        )
        return loss

class EPGFlowNet(GFlowNet):
    """GFlowNet trained to agree with the product of several client models.

    The loss is the batch variance of
    ``sum_clients sum_t (log p_F - log p_B) - sum_t (log p_F - log p_B)_self``
    over shared trajectories.
    """

    def set_num_sal_models(self, num_models):
        """Set the value passed to ``state_to_node`` when computing node indices."""
        self._num_models = num_models

    @property
    def num_models(self):
        """Value set by `set_num_sal_models` (AttributeError if never set)."""
        return self._num_models

    @staticmethod
    def _state_to_node_fn(env):
        """Return the ``state_to_node`` function for ``env``'s environment type.

        Raises:
            ValueError: if ``env`` is not a Hypergrid, Sequences or Set.
        """
        if isinstance(env, Hypergrid):
            from sal.gym.hypergrids import state_to_node
        elif isinstance(env, Sequences):
            from sal.gym.sequences import state_to_node
        elif isinstance(env, Set):
            from sal.gym.sets import state_to_node
        else:
            raise ValueError(f'unsupported environment type: {type(env).__name__}')
        return state_to_node

    def sample_traj(self, initial_states: Environment, **kwargs):
        """Roll out ``pf`` from ``initial_states`` (in place) until every element stops.

        Returns the list of ``(state_t, actions, state_tp1)`` snapshots, one per step.
        When ``kwargs['gflownets']`` is not None, ``node_indices`` is refreshed on the
        state before each ``pf`` call.
        """
        trajectories = list()
        use_node_indices = kwargs.get('gflownets') is not None
        while (initial_states.stopped < 1).any():
            if use_node_indices:
                state_to_node = self._state_to_node_fn(initial_states)
                initial_states.node_indices = state_to_node(initial_states, self.num_models)

            out = self.pf(initial_states, **kwargs)
            actions = out[0]
            state_t = deepcopy(initial_states)
            initial_states.apply(actions)
            trajectories.append(
                (state_t, actions, deepcopy(initial_states))
            )
        return trajectories

    def forward(
            self,
            trajectories: List[Tuple[Environment, torch.Tensor, Environment]],
            client_sal_global: List[GFlowNet],
            client_sal_local: List[List[GFlowNet]]
        ):
        """Alias for `evaluate_loss_on_trajectories`."""
        return self.evaluate_loss_on_trajectories(trajectories, client_sal_global, client_sal_local)

    def evaluate_loss_on_trajectories(self,
                                      trajectories: List[Tuple[Environment, torch.Tensor, Environment]],
                                      client_sal_global: List[GFlowNet],
                                      client_sal_local: List[List[GFlowNet]]
                        ):
        """Return the batch variance of the clients' summed log-ratio minus this model's log-ratio.

        Client ``c``'s forward policy is queried with ``gflownets=client_sal_local[c]``.
        Only transitions whose next state is not stopped are counted.
        """
        batch_size = trajectories[-1][-1].batch_size
        # Slot k < len(clients) is client k; the last slot is this model.
        log_pf = torch.zeros((len(client_sal_global) + 1, batch_size, len(trajectories)), device=self.device)
        log_pb = torch.zeros((len(client_sal_global) + 1, batch_size, len(trajectories)), device=self.device)
        for traj_idx, (state_t, actions, state_tp1) in enumerate(trajectories):
            active = ~(state_tp1.stopped == 1)
            for client_idx, client_gfn in enumerate(client_sal_global):
                local_gfn = client_sal_local[client_idx]
                log_pf[client_idx, active, traj_idx] = client_gfn.pf(state_t, actions=actions, gflownets=local_gfn)[1][active]
                log_pb[client_idx, active, traj_idx] = client_gfn.pb(state_tp1, actions=actions)[1][active]
            log_pf[-1, active, traj_idx] = self.pf(state_t, actions=actions)[1][active]
            log_pb[-1, active, traj_idx] = self.pb(state_tp1, actions=actions)[1][active]

        loss = (log_pf[:-1] - log_pb[:-1]).sum(dim=(0, 2)) + (log_pb[-1] - log_pf[-1]).sum(dim=1)
        return loss.var()