import torch 
import torch.nn as nn 
import numpy as np 

from copy import copy, deepcopy 

def huber_loss(diff, delta=9.):
    """Return the mean Huber loss of the residuals ``diff``.

    Elementwise the loss is ``0.5 * diff**2`` where ``|diff| <= delta`` and
    ``delta * (|diff| - 0.5 * delta)`` where ``|diff| > delta``. The two
    pieces agree (value and slope) at ``|diff| == delta``, so the loss is
    continuous and grows only linearly for large residuals.

    Args:
        diff: tensor of residuals (any shape).
        delta: threshold between the quadratic and the linear regime.

    Returns:
        0-dim tensor holding the loss averaged over all elements.
    """
    diff_sq = diff * diff
    # Comparing squares is equivalent to |diff| > delta for delta >= 0.
    larger_than_delta = (diff_sq > delta ** 2).to(dtype=diff.dtype)
    return (
        # Linear branch must use |diff| (not diff**2) to meet the quadratic
        # branch continuously at the boundary.
        delta * (diff.abs() - .5 * delta) * larger_than_delta
        + .5 * diff_sq * (1 - larger_than_delta)
    ).mean()

class OnPolicyCtx:
    """Temporarily set ``gfn.pf.eps`` to ``0.`` for the duration of a ``with`` block.

    The previous value is saved (via ``copy``) on entry and written back on
    exit, whether the block finishes normally or raises. Exceptions are not
    suppressed and ``__enter__`` yields ``None``.
    """

    def __init__(self, gfn):
        self.gfn = gfn

    def __enter__(self):
        policy = self.gfn.pf
        # Remember the caller's setting so __exit__ can restore it.
        self.curr_eps = copy(policy.eps)
        policy.eps = 0.

    def __exit__(self, *exc_info):
        self.gfn.pf.eps = self.curr_eps

class OffPolicyCtx:
    """Temporarily set ``gfn.pf.eps`` to ``1.`` for the duration of a ``with`` block.

    The previous value is saved (via ``copy``) on entry and written back on
    exit, whether the block finishes normally or raises. Exceptions are not
    suppressed and ``__enter__`` yields ``None``.
    """

    def __init__(self, gfn):
        self.gfn = gfn

    def __enter__(self):
        policy = self.gfn.pf
        # Remember the caller's setting so __exit__ can restore it.
        self.curr_eps = copy(policy.eps)
        policy.eps = 1.

    def __exit__(self, *exc_info):
        self.gfn.pf.eps = self.curr_eps

class GFlowNet(nn.Module): 

    def __init__(self, pf, pb, n_traj, criterion='tb', device='cpu'): 
        super(GFlowNet, self).__init__() 
        self.pf = pf 
        self.pb = pb 
        self.n_traj = n_traj 
        self.criterion = criterion 
        self.device = device 

        self.log_z = nn.Parameter(torch.randn((1,), device=self.device).squeeze(), requires_grad=True) 

    def forward(self, batch_state):
        match self.criterion: 
            case 'tb': 
                return self._trajectory_balance(batch_state)  
            case 'kl': 
                return self._kl(batch_state) 

    def sample_traj(self, batch_state): 
        batch_size, traj_length = batch_state.batch_size, batch_state.max_trajectory_length 
        log_pf = torch.zeros((batch_size, traj_length), device=self.device)
        log_pb = torch.zeros((batch_size, traj_length), device=self.device) 

        stopped = batch_state.stopped == 1. 
        idx = 0 
        while not stopped.all(): 
            actions, log_pf_act = self.pf(batch_state) 
            batch_state.apply(actions) 
            log_pb_act = self.pb(batch_state, actions)[1] 
            log_pf[~stopped, idx] = log_pf_act[~stopped] 
            log_pb[~stopped, idx] = log_pb_act[~stopped] 
            
            idx += 1 
            stopped = (batch_state.stopped == 1.)  
        
        return log_pf.sum(dim=1), log_pb.sum(dim=1), batch_state.log_reward() 

    def sample_traj_iwae(self, batch_state): 
        batch_size, traj_length = batch_state.batch_size, batch_state.max_trajectory_length  
  
        log_pf_sample = torch.zeros((batch_size, traj_length), device=self.device) 
        stopped = (batch_state.stopped == 1)  
        idx = 0 
        while not stopped.all(): 
            actions, log_pf_act = self.pf(batch_state)
            batch_state.apply(actions) 
            log_pf_sample[~stopped, idx] = log_pf_act[~stopped]  
            stopped = batch_state.stopped.bool() 
            idx += 1 
        log_reward = batch_state.log_reward() 

        marginal_log = self.sample_many_backward(batch_state, self.n_traj) 

        return log_pf_sample.sum(dim=1), marginal_log, log_reward 

    def _trajectory_balance(self, batch_state): 
        log_pf, log_pb, log_reward = self.sample_traj(batch_state) 
        loss = (log_pf + self.log_z) - (log_pb + log_reward)
        return (loss * loss).mean() 
    
    def _kl(self, batch_state): 
        log_pf, log_pb, log_reward = self.sample_traj(batch_state) 
        return (log_pf - log_pb - log_reward), log_pf 

    @torch.no_grad() 
    def sample(self, batch_state, seed=None): 
        while (batch_state.stopped < 1).any(): 
            if seed is not None: self.pf.set_seed(seed) 
            out = self.pf(batch_state) 
            actions = out[0] 
            batch_state.apply(actions)
            if seed is not None: self.pf.unset_seed()  
        return batch_state  

    def marginal_prob(self, batch_state, copy_env=False): 
        # Use importance sampling to estimate the marginal probabilities
        if copy_env: 
            batch_state = deepcopy(batch_state) 
        forward_log_traj = torch.zeros((batch_state.batch_size, batch_state.max_trajectory_length), device=self.device) 
        backward_log_traj = torch.zeros((batch_state.batch_size, batch_state.max_trajectory_length), device=self.device) 

        idx = 0 

        is_initial = torch.zeros(batch_state.batch_size, device=self.device, dtype=bool) 
        while not is_initial.all():
            # Estimate the backward probabilities  
            back_out = self.pb(batch_state) 
            actions, backward_log_prob = back_out[0], back_out[1] 
            
            forward_actions = batch_state.backward(actions) 

            # Estimate the forward probabilities
            forward_out = self.pf(batch_state, actions=forward_actions) 
            forward_log_prob = forward_out[1] 

            forward_log_traj[~is_initial, idx] = forward_log_prob[~is_initial]  
            backward_log_traj[~is_initial, idx] = backward_log_prob[~is_initial] 

            is_initial = batch_state.is_initial.bool()        
            idx += 1

        marginal_log = (forward_log_traj - backward_log_traj).sum(dim=1) 
        return marginal_log 

    def sample_many_backward(self, batch_states, num_trajectories): 
        marginal_log = torch.zeros((batch_states.batch_size, num_trajectories), device=self.device) 
        for idx in range(num_trajectories): 
            marginal_log[:, idx] = self.marginal_prob(batch_states, copy_env=True) 
        return marginal_log  
    
    def on_policy(self): 
        return OnPolicyCtx(self) 

    def off_policy(self): 
        return OffPolicyCtx(self) 

class SBGFlowNet(GFlowNet): 

    def forward(self, batch_state, previous_model): 
        match self.criterion: 
            case 'tb': 
                return self._streaming_balance(batch_state, previous_model) 
            case 'kl': 
                return self._streaming_div(batch_state, previous_model) 
            case _: 
                raise Exception(f'criterion: {self.criterion}') 

    def sample_traj(self, batch_state, previous_model): 
        batch_size, traj_length = batch_state.batch_size, batch_state.max_trajectory_length 
        log_pf_t = torch.zeros((batch_size, traj_length), device=self.device)
        log_pb_t = torch.zeros((batch_size, traj_length), device=self.device) 
        log_pf_tp1 = log_pf_t.clone() 
        log_pb_tp1 = log_pb_t.clone() 

        stopped = batch_state.stopped == 1. 
        idx = 0 
        while not stopped.all(): 
            actions, log_pf_act_tp1 = self.pf(batch_state) 
            with torch.no_grad(): 
                log_pf_act_t = previous_model.pf(batch_state, actions=actions)[1]  
            batch_state.apply(actions) 
            log_pb_act_tp1 = self.pb(batch_state, actions)[1] 
            with torch.no_grad(): 
                log_pb_act_t = previous_model.pb(batch_state, actions)[1]  

            log_pf_tp1[~stopped, idx] = log_pf_act_tp1[~stopped] 
            log_pb_tp1[~stopped, idx] = log_pb_act_tp1[~stopped] 
            log_pf_t[~stopped, idx] = log_pf_act_t[~stopped] 
            log_pb_t[~stopped, idx] = log_pb_act_t[~stopped] 

            idx += 1 
            stopped = (batch_state.stopped == 1.)  
        
        return (
            log_pf_t.sum(dim=1), 
            log_pb_t.sum(dim=1), 
            log_pf_tp1.sum(dim=1), 
            log_pb_tp1.sum(dim=1),  
            batch_state.log_reward() 
        ) 

    def _streaming_balance(self, batch_state, previous_model): 
        log_pf_t, log_pb_t, log_pf_tp1, log_pb_tp1, log_rewards = self.sample_traj(batch_state, previous_model) 
        loss = (
            self.log_z + log_pf_tp1 + log_pb_t - \
            (previous_model.log_z + log_pf_t + log_pb_tp1 + log_rewards) 
        )
        return loss.pow(2).mean() 

    def _streaming_div(self, batch_state, previous_model): 
        log_pf_t, log_pb_t, log_pf_tp1, log_pb_tp1, log_rewards = self.sample_traj(batch_state, previous_model) 
        # Assumption 
        assert torch.isclose(log_pb_tp1, log_pb_t).all()  
        loss = log_pf_tp1 - (log_pf_t + log_rewards) 
        return loss, log_pf_tp1 
