import torch 
import torch.nn as nn 
import numpy as np 
from copy import deepcopy 

def huber_loss(diff, delta=9.):
    """Return the mean Huber loss of the residuals ``diff``.

    Elementwise the loss is ``0.5 * diff**2`` when ``|diff| <= delta`` and
    ``delta * (|diff| - 0.5 * delta)`` otherwise. It is quadratic near zero
    and linear in the tails, and both the value and the first derivative are
    continuous at ``|diff| == delta``.

    Args:
        diff: tensor of residuals (any shape).
        delta: threshold between the quadratic and linear regimes.

    Returns:
        Scalar tensor: the loss averaged over every element of ``diff``.
    """
    abs_diff = diff.abs()
    quadratic = .5 * diff * diff
    # The linear branch must use |diff|, not diff**2. With diff**2 the tail
    # stays quadratic, and the function jumps at |diff| == delta whenever
    # delta != 1.
    linear = delta * (abs_diff - .5 * delta)
    return torch.where(abs_diff > delta, linear, quadratic).mean()

class GFlowNet(nn.Module): 

    def __init__(self, pf, pb, gamma_func=None, lamb_reg=0., lamb_subtb=.9, criterion='tb', device='cpu'): 
        super(GFlowNet, self).__init__() 
        self.pf = pf
        self.pb = pb 
        self.log_z = nn.Parameter(torch.randn((1,), dtype=torch.get_default_dtype(), device=device).squeeze(), requires_grad=True) 
        
        # For the trajectory-decomposable discriminatory objective 
        self.gamma_func = gamma_func 
        self.lamb_reg = lamb_reg # regularization 

        self.criterion = criterion 
        self.device = device 
        self.lamb_subtb = lamb_subtb 

    def mask_from_traj(self, batch_state, last_idx):
        return last_idx.view(-1, 1) > torch.arange(
            batch_state.max_trajectory_length, device=self.device
        ).view(1, -1) 
            
    def forward(self, batch_state, return_target=False, return_transition_loss=False): 
        if self.criterion == 'dbc': 
            return self._detailed_balance_comp(batch_state), None 

        loss_gamma = torch.nan  

        if self.criterion == 'fm': 
            return self._flow_matching(batch_state), loss_gamma   
        
        assert not (return_transition_loss and self.criterion != 'db')  

        if self.criterion in ['td', 'regdb']: 
            traj_stats, F_traj, last_idx, gamma = self._sample_traj(batch_state, return_gamma=True) 
        else: 
            traj_stats, F_traj, last_idx = self._sample_traj(batch_state, return_gamma=False) 
        
        match self.criterion: 
            case 'tb': 
                loss = self._trajectory_balance(batch_state, traj_stats, F_traj, last_idx) 
            case 'cb': 
                loss = self._contrastive_balance(batch_state, traj_stats, F_traj, last_idx) 
            case 'cbf': 
                loss = self._contrastive_balance_full(batch_state, traj_stats, F_traj, last_idx) 
            case 'db':
                loss = self._detailed_balance(batch_state, traj_stats, F_traj, last_idx, return_transition_loss)
                if isinstance(loss, tuple): 
                    loss, loss_traj = loss  
            case 'subtb': 
                loss = self._subtrajectory_balance(batch_state, traj_stats, F_traj, last_idx) 
            case 'td': 
                loss = self._td3(batch_state, traj_stats, F_traj, last_idx, gamma) 
            case 'regdb': 
                loss = self._detailed_balance(batch_state, traj_stats, F_traj, last_idx) 
                reg = self._refine(batch_state, traj_stats, F_traj, last_idx, gamma) 
                loss = loss + self.lamb_reg * reg 
            case _: 
                raise ValueError(f'{self.criterion} should be either tb, cb, db, or dbc') 

        if self.criterion in ['td', 'regdb']: 
            loss_gamma = (gamma.sum(dim=1) - F_traj[batch_state.batch_ids, last_idx].view(-1, 1)).pow(2) 
            loss_gamma = loss_gamma.mean() # (loss_gamma.sum(dim=1)).mean()    

        if return_target: 
            return loss, F_traj[batch_state.batch_ids, last_idx] 
        
        if return_transition_loss: 
            return loss, loss_traj, last_idx 
        return loss, loss_gamma  
    
    def _sample_traj(self, batch_state, return_gamma=False):
        # dim 0: pf, dim 1: pb, dim 2: pf_exp 
        traj_stats = torch.zeros((3, batch_state.batch_size, batch_state.max_trajectory_length), device=self.device) 
        F_traj = torch.zeros((batch_state.batch_size, batch_state.max_trajectory_length + 1), device=self.device) 

        # gamma 
        if return_gamma: 
            gamma = torch.zeros((batch_state.batch_size, batch_state.max_trajectory_length), device=self.device) 

        i = 0 
        last_idx = torch.zeros((batch_state.batch_size,), dtype=torch.long, device=self.device) 

        is_stopped = torch.zeros((batch_state.batch_size,), dtype=bool, device=self.device) 

        while (batch_state.stopped < 1).any(): 
            # Sample the actions 
            out = self.pf(batch_state) 
            actions, pf, F, sp = out[0], out[1], out[2], out[3]  

            # Apply the actions  
            batch_state_t = deepcopy(batch_state) 
            batch_state.apply(actions) 
            batch_state_tp1 = batch_state 

            # Corresponding backward actions  
            out = self.pb(batch_state, actions) 
            pb = out[1] 

            # Save values 
            traj_stats[0, ~is_stopped, i] = pf[~is_stopped] 
            traj_stats[1, ~is_stopped, i] = pb[~is_stopped]
            traj_stats[2, ~is_stopped, i] = sp[~is_stopped] # sampling policy 
            F_traj[~is_stopped, i] = F[~is_stopped] 
            if return_gamma: 
                gamma[~is_stopped, i] = self.gamma_func(batch_state_t, batch_state_tp1)[~is_stopped]
            # Check whether it already stopped 
            is_stopped = batch_state.stopped.bool()  
            i += 1 
            last_idx += (1 - batch_state.stopped).long()   
        
        F_traj[batch_state.batch_ids, last_idx + 1] = batch_state.log_reward() 
        if return_gamma: 
            return traj_stats, F_traj, last_idx + 1, gamma 
        return traj_stats, F_traj, last_idx + 1

    def _trajectory_balance(self, batch_state, traj_stats, F_traj, last_idx): 
        loss = (traj_stats[0] - traj_stats[1]).sum(dim=1) - F_traj[batch_state.batch_ids, last_idx] + self.log_z 
        return huber_loss(loss, delta=1.) # (loss*loss).mean()  

    def _detailed_balance(self, batch_state, traj_stats, F_traj, last_idx, return_transition_loss=False): 
        loss = (traj_stats[0] + F_traj[:, :-1] - traj_stats[1] - F_traj[:, 1:]).pow(2) 
        loss_avg = (
            loss.sum(dim=1) / last_idx
        ).mean() 
        if not return_transition_loss: 
            return loss_avg 
        else: 
            return loss_avg, loss 
        # return huber_loss(loss, delta=1.) # (loss*loss).mean() 
    
    def _subtrajectory_balance(self, batch_state, traj_stats, F_traj, last_idx): 
        max_traj_length = batch_state.max_trajectory_length 
        i, j = torch.triu_indices(max_traj_length + 1, max_traj_length + 1, offset=1, device=self.device) 
        
        traj_stats = torch.cat([torch.zeros((*traj_stats.shape[:-1], 1), device=self.device), traj_stats], dim=-1) 
        traj_stats = torch.cumsum(traj_stats, dim=-1) 
        pf = traj_stats[0, :, j] - traj_stats[0, :, i]  
        pb = traj_stats[1, :, j] - traj_stats[1, :, i] 
        loss = pf - pb + F_traj[:, i] - F_traj[:, j] 
        loss = loss * (last_idx.unsqueeze(1) >= j.unsqueeze(0)) 
        lamb = self.lamb_subtb ** (j - i).view(1, -1) * (last_idx.unsqueeze(1) >= j.unsqueeze(0)) 
        loss = ((loss * loss) * lamb).sum(dim=1) / lamb.sum(dim=1)   
        return loss.mean()  

    def _contrastive_balance_full(self, batch_state, traj_stats, F_traj, last_idx): 
        loss = (traj_stats[0] - traj_stats[1]).sum(dim=1) - F_traj[batch_state.batch_ids, last_idx]
        loss = (loss[:, None] - loss[None, :])  
        return huber_loss(loss, delta=1.) 
    
    def _contrastive_balance(self, batch_state, traj_stats, F_traj, last_idx): 
        assert (batch_state.batch_size % 2) == 0, 'batch size must be even' 
        half_batch = batch_state.batch_size // 2 
        loss = (traj_stats[0] - traj_stats[1]).sum(dim=1) - F_traj[batch_state.batch_ids, last_idx] 
        loss = loss[:half_batch] - loss[half_batch:] 
        return huber_loss(loss, delta=1.) # (loss*loss).mean()   

    def _td3(self, batch_state, traj_stats, F_traj, last_idx, gamma):
        loss = (F_traj[:, :-1] + traj_stats[0] - F_traj[:, 1:] - traj_stats[1]).pow(2)  

        mask = self.mask_from_traj(batch_state, last_idx).type(traj_stats.dtype) 
        gamma = (mask * gamma - 1e5 * (1 - mask)).softmax(dim=1)  
        loss = (gamma * loss).sum(dim=1) 
        return loss.mean()  
    
    def _refine(self, batch_state, traj_stats, F_traj, last_idx, gamma): 
        loss = (
            torch.detach(F_traj[:, :-1] - F_traj[:, 1:]) \
            + traj_stats[0] - traj_stats[1] 
        ).pow(2) 
        gamma = torch.softmax(gamma, dim=1) 
        loss = (gamma * loss).sum(dim=1)  
        return loss.mean() 

    @torch.no_grad() 
    def sample(self, batch_state, seed=None): 
        while (batch_state.stopped < 1).any(): 
            if seed is not None: self.pf.set_seed(seed) 
            out = self.pf(batch_state) 
            actions = out[0] 
            batch_state.apply(actions)
            if seed is not None: self.pf.unset_seed()  
        return batch_state  
    
    @torch.no_grad() 
    def marginal_prob(self, batch_state, copy_env=False): 
        # Use importance sampling to estimate the marginal probabilities
        if copy_env: 
            batch_state = deepcopy(batch_state) 
        forward_log_traj = torch.zeros((batch_state.batch_size, batch_state.max_trajectory_length), device=self.device) 
        backward_log_traj = torch.zeros((batch_state.batch_size, batch_state.max_trajectory_length), device=self.device) 

        idx = 0 

        is_initial = torch.zeros(batch_state.batch_size, device=self.device, dtype=bool) 
        while not is_initial.all():
            # Estimate the backward probabilities  
            back_out = self.pb(batch_state) 
            actions, backward_log_prob = back_out[0], back_out[1] 
            
            forward_actions = batch_state.backward(actions) 

            # Estimate the forward probabilities
            forward_out = self.pf(batch_state, actions=forward_actions) 
            forward_log_prob = forward_out[1] 

            forward_log_traj[~is_initial, idx] = forward_log_prob[~is_initial]  
            backward_log_traj[~is_initial, idx] = backward_log_prob[~is_initial] 

            is_initial = batch_state.is_initial.bool()        
            idx += 1

        marginal_log = (forward_log_traj - backward_log_traj).sum(dim=1) 
        return marginal_log 

    def _flow_matching(self, batch_state): 
        '''
        This method assumes that the trajectory length is continuous for each state. 
        '''
        assert hasattr(batch_state, 'get_children') and hasattr(batch_state, 'get_parents')  
        loss = torch.zeros((batch_state.batch_size,), device=self.device) 
        while (batch_state.stopped < 1).any(): 
            # Sample the actions 
            actions = self.pf(batch_state)[0]  
            
            batch_state.apply(actions) 
            
            # Evaluate \sum_{s \in Pa(s')} F(s')p_{F}(s|s') 
            f_parent_p_f = torch.zeros((batch_state.batch_size, batch_state.max_num_parents), device=self.device) 
            idx = 0 
            for parent, forward_actions in batch_state.get_parents(): 
                flows = self.pf.mlp_flows(parent.unique_input).squeeze(dim=1)  
                logits = self.pf(parent, actions=forward_actions)[1]  
                f_parent_p_f[:, idx] = flows + logits  
                idx += 1 

            # Evaluate \sum_{s \in Ch(s')} F(s') p_{B}(s | s') 
            idx = 0 
            f_child_p_b = torch.zeros((batch_state.batch_size, batch_state.max_num_parents), device=self.device) 
            for child, backward_actions in batch_state.get_children(return_actions=True): 
                flows = self.pf.mlp_flows(child.unique_input).squeeze(dim=1) 
                logits = self.pb(child, actions=backward_actions)[1] 
                f_child_p_b[:, idx] = (flows + logits).nan_to_num(neginf=-1e8)  
                idx += 1 

            loss = loss + (
                torch.logsumexp(f_parent_p_f, dim=1) - torch.logsumexp(f_child_p_b, dim=1) 
            ).pow(2) 

        loss = loss / batch_state.max_trajectory_length 
        return loss.mean() 
       
    def sample_many_backward(self, batch_states, num_trajectories): 
        marginal_log = torch.zeros((batch_states.batch_size, num_trajectories), device=self.device) 
        for idx in range(num_trajectories): 
            marginal_log[:, idx] = self.marginal_prob(batch_states, copy_env=True) 
        return marginal_log  

    class OffPolicyCtx: 

        def __init__(self, gflownet): 
            self.gflownet = gflownet 
        
        def __enter__(self): 
            self.curr_eps = self.gflownet.pf.eps 
            self.gflownet.pf.eps = 1. 
        
        def __exit__(self, *unused_args): 
            self.gflownet.pf.eps = self.curr_eps 
        
    class OnPolicyCtx: 

        def __init__(self, gflownet): 
            self.gflownet = gflownet 

        def __enter__(self): 
            self.curr_eps = self.gflownet.pf.eps 
            self.gflownet.pf.eps = 0. 

        def __exit__(self, *unused_args): 
            self.gflownet.pf.eps = self.curr_eps 

    def off_policy(self): 
        return self.OffPolicyCtx(self) 
    
    def on_policy(self): 
        return self.OnPolicyCtx(self) 

class LEDGFlowNet(GFlowNet):
    """GFlowNet trained with learned per-transition potentials ``phi``.

    ``forward`` returns two losses:

    * ``loss_led`` trains the policies and flows. It is a detailed-balance
      loss in which the detached potential is added to every residual.
    * ``loss_ls`` fits ``phi``. The mean potential over a random (dropout)
      subset of a trajectory's transitions is regressed onto
      ``-log_reward / length``.
    """

    def __init__(self, pf, pb, phi, gamma=.1, criterion='tb', device='cpu'):
        """
        Args:
            pf, pb, criterion, device: as in ``GFlowNet``.
            phi: ``phi(state_t, state_tp1)`` giving a per-transition potential.
            gamma: dropout rate for ``loss_ls``; each transition is kept with
                probability ``1 - gamma``.
        """
        super(LEDGFlowNet, self).__init__(pf, pb, criterion=criterion, device=device)
        self.phi = phi
        self.gamma = torch.tensor(gamma, device=device)
        self.bern = torch.distributions.Bernoulli(1 - self.gamma)

        # Change off-policy exploration rate
        self.pf.eps = 1e-2

    # Override _sample_traj
    def _sample_traj(self, batch_state):
        """Roll out ``pf`` until every element stops, recording ``phi`` per transition.

        Returns:
            traj_stats: ``(3, batch_size, max_trajectory_length)`` tensor.
                Rows hold the ``pf`` outputs, the ``pb`` outputs and the
                ``phi`` potentials, in that order.
            F_traj: ``(batch_size, max_trajectory_length + 1)`` flows, with
                ``log_reward()`` written at each terminal index.
            last_idx: number of recorded transitions per trajectory.

        Entries past a trajectory's end stay zero.
        """
        pf_traj = torch.zeros((batch_state.batch_size, batch_state.max_trajectory_length), device=self.device)
        pb_traj = torch.zeros((batch_state.batch_size, batch_state.max_trajectory_length), device=self.device)
        pt_traj = torch.zeros((batch_state.batch_size, batch_state.max_trajectory_length), device=self.device)

        F_traj = torch.zeros((batch_state.batch_size, batch_state.max_trajectory_length + 1), device=self.device)
        i = 0
        last_idx = torch.zeros((batch_state.batch_size,), dtype=torch.long, device=self.device)
        # Stopped-status before the current step; only active elements record.
        is_stopped = torch.zeros((batch_state.batch_size,), dtype=bool, device=self.device)

        while (batch_state.stopped < 1).any():
            # Sample the actions
            out = self.pf(batch_state)
            actions, pf, F = out[0], out[1], out[2]

            # Apply the actions; phi needs the pre-transition state too.
            batch_state_t = deepcopy(batch_state)
            batch_state.apply(actions)
            batch_state_tp1 = batch_state

            # Corresponding backward actions
            out = self.pb(batch_state, actions)
            pb = out[1]

            # Save values
            pf_traj[~is_stopped, i] = pf[~is_stopped]
            pb_traj[~is_stopped, i] = pb[~is_stopped]

            pt_traj[~is_stopped, i] = self.phi(batch_state_t, batch_state_tp1)[~is_stopped]

            F_traj[~is_stopped, i] = F[~is_stopped]

            # Check whether it already stopped
            is_stopped = batch_state_tp1.stopped.bool()
            i += 1
            last_idx += (1 - batch_state.stopped).long()

        F_traj[batch_state.batch_ids, last_idx + 1] = batch_state.log_reward()

        return torch.cat([
            pf_traj.unsqueeze(0),
            pb_traj.unsqueeze(0),
            pt_traj.unsqueeze(0)], dim=0), F_traj, last_idx + 1

    # Override forward
    def forward(self, batch_state, return_target=False):
        """Sample trajectories and return ``(loss_led, loss_ls)``.

        Only each trajectory's real (non-padded) transitions take part in
        either loss. If ``return_target`` is true, the terminal ``F_traj``
        values (the log-rewards) are returned as a third element.
        """
        traj_stats, F_traj, last_idx = self._sample_traj(batch_state)
        terminal = F_traj[batch_state.batch_ids, last_idx]

        # 1 for real transitions, 0 for padding past each trajectory's end.
        valid = self.mask_from_traj(batch_state, last_idx).to(traj_stats.dtype)

        # Dropout for learning the potentials. It is restricted to real
        # transitions so the zero potentials in padded slots do not dilute
        # the mean.
        keep = self.bern.sample(sample_shape=traj_stats[2].shape) * valid
        # Loss for learning energy decompositions
        loss_ls = (
            - terminal / last_idx
            - (keep * traj_stats[2]).sum(dim=1) / (keep.sum(dim=1) + 1e-8)
        ).pow(2).mean()

        # Loss for training the policies. Padded slots are masked; otherwise
        # the slot after the terminal transition adds a spurious
        # log_reward**2 term.
        residual = traj_stats[0] + F_traj[:, :-1] - F_traj[:, 1:] - traj_stats[1] + traj_stats[2].detach()
        loss_led = (residual.pow(2) * valid).sum(dim=1).mean()
        if return_target:
            return loss_led, loss_ls, terminal
        return loss_led, loss_ls
