import torch 
import torch.nn as nn 
from copy import deepcopy 

def huber_loss(diff, alpha=3.):
    """Element-wise Huber loss of a residual tensor.

    The loss is quadratic, ``0.5 * diff ** 2``, where ``|diff| < alpha`` and
    linear, ``alpha * (|diff| - 0.5 * alpha)``, elsewhere. The two pieces meet
    at ``|diff| == alpha``.

    Args:
        diff: tensor of residuals (any shape).
        alpha: threshold between the quadratic and the linear regime.

    Returns:
        A tensor with the same shape as ``diff``; no reduction is applied.
    """
    abs_diff = diff.abs()
    quadratic = .5 * diff ** 2
    linear = alpha * (abs_diff - .5 * alpha)
    # The regime is chosen from the magnitude of the residual so that large
    # negative residuals are linearised exactly like large positive ones.
    # `torch.where` selects a branch instead of multiplying by a 0/1 mask,
    # which would produce NaN (0 * inf) for infinite residuals.
    return torch.where(abs_diff < alpha, quadratic, linear)

class GFlowNet(nn.Module): 

    def __init__(self, forward_flow, backward_flow, state_flow=None, criterion='tb', off_policy_rate=.5): 
        super(GFlowNet, self).__init__() 
        self.forward_flow = forward_flow 
        self.backward_flow = backward_flow

        self.criterion = criterion 
    
        self.state_flow = state_flow 

        if criterion == 'tb': 
            self.log_partition_function = nn.Parameter(torch.randn((1,)).squeeze(), requires_grad=True) 
        
        self.off_policy_rate = off_policy_rate 
    
    def forward(self, batch_state): 
        match self.criterion: 
            case 'tb': 
                loss = self._trajectory_balance(batch_state) 
            case 'cb': 
                loss = self._contrastive_balance_even(batch_state) 
            case 'db':
                loss = self._detailed_balance(batch_state) 
            case 'dbc': # DB when every state is complete  
                loss = self._detailed_balance_complete(batch_state) 
            case 'fl': 
                loss = self._forward_looking(batch_state) 
            case _: 
                raise ValueError(f'{self.criterion} should be either tb, cb, db, fl or dbc') 
        return loss 

    def _trajectory_balance(self, batch_state): 
        loss = torch.zeros((batch_state.batch_size,), requires_grad=True) 

        while (batch_state.stopped < 1).any(): 
            unif = torch.rand((1,)).item() 
            # Sample an action for each batch 
            out = self.forward_flow(batch_state, (unif <= self.off_policy_rate)) 
            actions, forward_log_prob = out[0], out[1] 
        
            # Apply the actions and validate them 
            mask = batch_state.apply(actions) 
            
            # Compute the backward and forward transition probabilities 
            backward_log_prob = self.backward_flow(batch_state, actions) 

            # Should check this in different environments 
            loss = loss + torch.where(mask, (forward_log_prob.squeeze() - backward_log_prob.squeeze()), 0.) 
            # print(mask, backward_log_prob) 
        loss = loss + (self.log_partition_function - batch_state.log_reward()) 
        return (loss * loss).mean() 

    def _contrastive_balance(self, batch_state): 
        loss = torch.zeros((batch_state.batch_size,), requires_grad=True) 

        for sgn in [-1, 1]: 
            batch_state_sgn = deepcopy(batch_state) 
            while (batch_state_sgn.stopped < 1).any(): 
                unif = torch.rand((1,)).item() 
                # Sample an action for each batch 
                out = self.forward_flow(batch_state_sgn, (unif <= self.off_policy_rate)) 
                actions, forward_log_prob = out[0], out[1] 
                # Apply the actions and validate them 
                batch_state_sgn.apply(actions) 
                
                # Compute the backward and forward transition probabilities 
                backward_log_prob = self.backward_flow(batch_state_sgn, actions) 
                # Should check this in different environments 
                loss = loss + sgn * (forward_log_prob.squeeze() - backward_log_prob.squeeze())
            loss = loss - sgn * (batch_state_sgn.log_reward()) 
        return (loss * loss).mean() 


    def _contrastive_balance_even(self, batch_state): 
        assert (batch_state.batch_size % 2) == 0, 'the batch size must be even for CB' 
        half_batch = batch_state.batch_size // 2 
        loss = torch.zeros((half_batch,), requires_grad=True) 
    
        batch_state_sgn = deepcopy(batch_state) 
        while (batch_state_sgn.stopped < 1).any(): 
            unif = torch.rand((1,)).item() 
            # Sample an action for each batch 
            out = self.forward_flow(batch_state_sgn, (unif <= self.off_policy_rate)) 
            actions, forward_log_prob = out[0], out[1] 
            # Apply the actions and validate them 
            batch_state_sgn.apply(actions) 
            
            # Compute the backward and forward transition probabilities 
            backward_log_prob = self.backward_flow(batch_state_sgn, actions) 
            # Should check this in different environments 
            loss = loss + (forward_log_prob[:half_batch] - backward_log_prob[:half_batch]) 
            loss = loss - (forward_log_prob[half_batch:] - backward_log_prob[half_batch:]) 
        
        rewards = batch_state_sgn.log_reward() 
        loss = loss - (rewards[:half_batch] - rewards[half_batch:]) 
        return (loss * loss).mean() 

    def _detailed_balance(self, batch_state): 
        loss = torch.tensor(0., requires_grad=True) 

        while (batch_state.stopped < 1).any(): 
            unif = torch.rand((1,)).item() 
            # Forward actions  
            out = self.forward_flow(batch_state, off_policy=(unif <= self.off_policy_rate))
            actions, forward_log_prob = out[0], out[1] 

            current_state_flow = self.state_flow(batch_state) 

            mask = batch_state.apply(actions) 
            if mask.sum() == 0: break 

            # Backward actions
            next_state_flow = self.state_flow(batch_state) 
            backward_log_prob = self.backward_flow(batch_state, actions) 
            
            indices = (mask == 1) 
            # Update the loss 
            loss = loss + huber_loss(forward_log_prob.squeeze() \
                            + current_state_flow.squeeze() \
                            - backward_log_prob.squeeze() \
                            - next_state_flow.squeeze())[indices].mean() 

        # When states are complete, the detailed balance condition becomes F(s) P_{F}(s_{f} | s) = R(s)
        loss = loss + huber_loss(next_state_flow.squeeze() - batch_state.log_reward()).mean() 
        return loss 
        
    def _detailed_balance_complete(self, batch_state): 
        loss = torch.tensor(0., requires_grad=True) 

        while (batch_state.stopped < 1).any(): 
            unif = torch.rand(size=(1,)).item()  
            # Sample next action 
            out = self.forward_flow(batch_state, (unif < self.off_policy_rate)) 
            actions, forward_log_prob, forward_stop_log_prob = out[0], out[1], out[2] 
            forward_reward = batch_state.log_reward()

            # Update the state 
            mask = batch_state.apply(actions) 

            if mask.sum() == 0: break 

            # Compute the backward and stop probabilities 
            backward_log_prob = self.backward_flow(batch_state, actions) 
            _, _, backward_stop_log_prob = self.forward_flow(batch_state) 
            backward_reward = batch_state.log_reward() 

            # Compute the loss 
            balance_lhs = (forward_log_prob + forward_reward + backward_stop_log_prob).squeeze() 
            balance_rhs = (backward_log_prob + backward_reward + forward_stop_log_prob).squeeze() 

            # As the loss is trajectory-decomposable, this may not be necessary 
            balance_lhs = balance_lhs[mask == 1] 
            balance_rhs = balance_rhs[mask == 1] 
        
            loss = loss + huber_loss(balance_lhs - balance_rhs).mean() 
    
        return loss
   

    def _forward_looking(self, batch_state): 
        loss = torch.tensor(0., requires_grad=True) 

        while (batch_state.stopped < 1).any(): 
            unif = torch.rand((1,)).item() 
            # Forward actions  
            out = self.forward_flow(batch_state, off_policy=(unif <= self.off_policy_rate))
            actions, forward_log_prob = out[0], out[1] 

            current_state_flow = self.state_flow(batch_state) + batch_state.log_reward()  

            mask = batch_state.apply(actions) 
            if mask.sum() == 0: break 

            # Backward actions
            next_state_flow = self.state_flow(batch_state) + batch_state.log_reward() 
            backward_log_prob = self.backward_flow(batch_state, actions) 
            
            indices = (mask == 1) 
            # Update the loss 
            loss = loss + huber_loss(forward_log_prob.squeeze() \
                            + current_state_flow.squeeze() \
                            - backward_log_prob.squeeze() \
                            - next_state_flow.squeeze())[indices].mean() 

        # When states are complete, the detailed balance condition becomes F(s) P_{F}(s_{f} | s) = R(s)
        loss = loss + huber_loss(next_state_flow.squeeze() - batch_state.log_reward()).mean() 
        return loss 
        
    @torch.no_grad() 
    def sample(self, batch_state): 
        while (batch_state.stopped < 1).any(): 
            out = self.forward_flow(batch_state, off_policy=False) 
            actions = out[0] 
            batch_state.apply(actions) 
        return batch_state  

class GFlowNetEnsemble(nn.Module): 
    
    def __init__(self, forward_flow, backward_flow, gflownets, off_policy_rate=.5, state_flow=None, criterion='tb', log_pool_p=None):
        super(GFlowNetEnsemble, self).__init__() 
        self.forward_flow = forward_flow 
        self.backward_flow = backward_flow 
        self.state_flow = state_flow 

        self.gflownets = gflownets 

        self.criterion = criterion 
        if self.criterion == 'db': 
            raise NotImplementedError('use a criterion different from the detailed balance condition') 

        self.off_policy_rate = off_policy_rate
        self.num_clients = len(self.gflownets) 

        # Disable the gradients for the clients' networks 
        for gflownet in self.gflownets: 
            gflownet.forward_flow.requires_grad_(False) 
            gflownet.backward_flow.requires_grad_(False) 
            gflownet.requires_grad_(False) 

        # Instantiate the parameters for logarithmically pooling the distributions 
        if log_pool_p is None: 
            self.log_pool_p = torch.ones((self.num_clients,)) 
        else: 
            self.log_pool_p = log_pool_p 

        if criterion == 'tb': 
            self.log_partition_function = nn.Parameter(torch.randn((1,)).squeeze(), requires_grad=True) 
            self.log_partition_function_pooled = torch.dot(
                self.log_pool_p, 
                torch.tensor([gflownet.log_partition_function for gflownet in self.gflownets])
            ) 
        
    def forward(self, batch_state): 
        match self.criterion: 
            case 'tb': 
                loss = self._trajectory_balance(batch_state) 
            case 'db': 
                loss = self._detailed_balance(batch_state) 
            case 'cb': 
                loss = self._contrastive_balance_even(batch_state) 
            case 'dbc': # Detailed balance when all states are complete 
                loss = self._detailed_balance_complete(batch_state) 
            case _: 
                raise ValueError(f'{self.criterion} should be either tb, db, cb or dbc') 
        return loss 
    
    def _trajectory_balance(self, batch_state): 
        loss = torch.zeros((batch_state.batch_size,)) 
        while (batch_state.stopped < 1.).any(): 
            unif = torch.rand((1,)).item() 
            
            # Forward flows 
            out = self.forward_flow(batch_state, (unif <= self.off_policy_rate)) 
            actions, forward_log_prob = out[0], out[1] 
            fpc = torch.zeros((batch_state.batch_size, len(self.gflownets))) 
            for i, gflownet in enumerate(gflownets): 
                out = gflownet.forward_flow(batch_state, actions=actions) 
                forward_log_prob_clients = out[1] 
                fpc[:, i] = forward_log_prob_clients 

            loss = loss - forward_log_prob.squeeze() 

            batch_state.apply(actions) 

            # Backward flows 
            backward_log_prob = self.backward_flow(batch_state, actions) 

            bpc = torch.zeros((batch_state.batch_size, len(self.gflownets))) 
            # Asynchronously compute the forward and backward flows for the clients 
            for i, gflownet in enumerate(self.gflownets): 
                backward_log_prob_clients = gflownet.backward_flow(batch_state, actions) 
                bpc[:, i] = backward_log_prob_clients 
            
            loss = loss + backward_log_prob.squeeze() 
            
            loss = loss + (fpc.sum(dim=1) - bpc.sum(dim=1))

        loss = loss + self.log_partition_function_pooled  
        loss = loss - self.log_partition_function 

        return (loss * loss).mean() 


    def _detailed_balance(self, batch_state): 
        raise NotImplementedError 

    def _contrastive_balance(self, batch_state): 
        loss = torch.zeros((batch_state.batch_size,)) 
        for sgn in [-1, 1]: 
            batch_state_sgn = deepcopy(batch_state)
            while (batch_state_sgn.stopped < 1.).any():  
                unif = torch.rand((1,)).item()

                # Forward flows  
                out = self.forward_flow(batch_state_sgn, (unif <= self.off_policy_rate)) 
                actions, forward_log_prob = out[0], out[1] 

                # Forward flows for the clients 
                fpc = torch.zeros((batch_state_sgn.batch_size, len(gflownets))) 
                for i, gflownet in enumerate(self.gflownets): 
                    out = gflownet.forward_flow(batch_state_sgn, actions=actions) 
                    forward_log_prob_clients = out[1] 
                    fpc[:, i] = sgn * forward_log_prob_clients 
                fpc = fpc.sum(dim=1) 
                
                batch_state_sgn.apply(actions) 

                backward_log_prob = self.backward_flow(batch_state_sgn, actions) 

                bpc = torch.zeros((batch_state_sgn.batch_size, len(gflownets))) 
                for i, gflownet in enumerate(self.gflownets): 
                    backward_log_prob_clients = gflownet.backward_flow(batch_state_sgn, actions) 
                    bpc[:, i] = sgn * backward_log_prob_clients.squeeze() 
                bpc = bpc.sum(dim=1) 
                
                # Backward flows 
                loss = loss - sgn * (forward_log_prob.squeeze() - backward_log_prob.squeeze()) 
                loss = loss + (fpc - bpc)
            
        return (loss * loss).mean() 

    def _contrastive_balance_even(self, batch_state): 
        assert (batch_state.batch_size % 2) == 0, 'batch size must be an even number' 
        half_batch = batch_state.batch_size // 2   

        loss = torch.zeros((half_batch,)) 
        batch_state_sgn = deepcopy(batch_state)
        while (batch_state_sgn.stopped < 1.).any():  
            unif = torch.rand((1,)).item()

            # Forward flows  
            out = self.forward_flow(batch_state_sgn, (unif <= self.off_policy_rate)) 
            actions, forward_log_prob = out[0], out[1] 

            # Forward flows for the clients 
            fpc = torch.zeros((half_batch, )) 
            for i, gflownet in enumerate(self.gflownets): 
                out = gflownet.forward_flow(batch_state_sgn, actions=actions) 
                forward_log_prob_clients = out[1] 
                fpc += forward_log_prob_clients[:half_batch] - forward_log_prob_clients[half_batch:]  
            
            batch_state_sgn.apply(actions) 

            backward_log_prob = self.backward_flow(batch_state_sgn, actions) 

            bpc = torch.zeros((half_batch, )) 
            for i, gflownet in enumerate(self.gflownets): 
                backward_log_prob_clients = gflownet.backward_flow(batch_state_sgn, actions) 
                bpc += backward_log_prob_clients[:half_batch] - backward_log_prob_clients[half_batch:]  

            # Backward flows 
            loss = loss - (forward_log_prob[:half_batch] - backward_log_prob[:half_batch]) 
            loss = loss + (forward_log_prob[half_batch:] - backward_log_prob[half_batch:])   
            loss = loss + (fpc - bpc)
        
        return (loss * loss).mean() 

    def _detailed_balance_complete(self, batch_state): 
        loss = torch.tensor(0., requires_grad=True) 

        while (batch_state.stopped < 1).any(): 
            unif = torch.rand(size=(1,)).item() 
            # Sample the next action 
            out = self.forward_flow(batch_state, (unif < self.off_policy_rate)) 
            action, forward_log_prob, forward_stop_log_prob = out[0], out[1], out[2] 
            # forward_reward = batch_state.log_reward() 

            # Compute the forward probabilities for the other GFlowNets 
            forward_log_prob_clients = torch.zeros((batch_state.batch_size,)) 
            forward_log_stop_prob_clients = torch.zeros((batch_state.batch_size,)) 
            for idx, gflownet in enumerate(self.gflownets): 
                out = gflownet.forward_flow(batch_state, off_policy=False, actions=action)  
                forward_log_prob_gfn, forward_stop_log_prob_gfn = out[1], out[2] 
                forward_log_prob_clients += self.log_pool_p[idx] * forward_log_prob_gfn 
                forward_log_stop_prob_clients += self.log_pool_p[idx] * forward_stop_log_prob_gfn 

            # Update the state 
            mask = batch_state.apply(action) 

            # Compute the backward and stop probabilities 
            backward_log_prob = self.backward_flow(batch_state, action) 
            _, _, backward_stop_log_prob = self.forward_flow(batch_state) 

            backward_log_prob_clients = torch.zeros((batch_state.batch_size,)) 
            backward_stop_log_prob_clients = torch.zeros((batch_state.batch_size,)) 
            for idx, gflownet in enumerate(self.gflownets): 
                backward_log_prob_gfn = gflownet.backward_flow(batch_state, action) 
                out = gflownet.forward_flow(batch_state, off_policy=False, actions=action) 
                backward_stop_log_prob_gfn = out[2] 
                backward_log_prob_clients += self.log_pool_p[idx] * backward_log_prob_gfn 
                backward_stop_log_prob_clients += self.log_pool_p[idx] * backward_stop_log_prob_gfn 
            
            # Compute the loss 
            balance_lhs = forward_log_prob_clients + backward_stop_log_prob_clients - \
                        forward_log_stop_prob_clients - backward_log_prob_clients
            balance_rhs = forward_log_prob + backward_stop_log_prob - \
                        forward_stop_log_prob - backward_log_prob 

            balance_lhs = balance_lhs[mask == 1] 
            balance_rhs = balance_rhs[mask == 1] 

            loss = loss + ((balance_lhs - balance_rhs) ** 2).mean()  
        
        return loss 

    @torch.no_grad() 
    def sample(self, batch_state): 
        while (batch_state.stopped < 1.).any(): 
            out = self.forward_flow(batch_state, off_policy=False) 
            actions = out[0] 
            batch_state.apply(actions) 
        return batch_state 