import torch 
import torch.nn as nn 

from sal.utils import ForwardPolicyMeta, BaseNN 
from sal.gym.hypergrids import Hypergrid, state_to_node 

class ForwardPolicy(ForwardPolicyMeta):
    """Forward policy made of two `BaseNN` heads evaluated on the batch state.

    `mlp_logit` is built with `dim + 1` as its last constructor argument and
    `mlp_flows` with `1` (presumably the output widths -- confirm against
    `BaseNN`). Optionally, rows of the batch can be delegated to the heads of
    other models (`gflownets[idx].pf`) instead of this policy's own heads.
    """

    def __init__(self, dim, hidden_dim, num_layers, device='cpu', eps=.3):
        """Build the two heads and move them to `device`.

        Args:
            dim: first argument given to both `BaseNN` heads; `mlp_logit`
                additionally receives `dim + 1` as its last argument.
            hidden_dim: second argument given to both `BaseNN` heads.
            num_layers: third argument given to both `BaseNN` heads.
            device: device the heads are moved to; also forwarded to
                `ForwardPolicyMeta.__init__`.
            eps: forwarded unchanged to `ForwardPolicyMeta.__init__`.
        """
        super().__init__(eps=eps, device=device)
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.device = device
        self.num_layers = num_layers

        self.mlp_logit = BaseNN(self.dim, self.hidden_dim, self.num_layers, self.dim + 1).to(self.device)
        self.mlp_flows = BaseNN(self.dim, self.hidden_dim, self.num_layers, 1).to(self.device)

    def get_latent_emb(self, batch_state: 'Hypergrid', gflownets=None):
        """Compute logits and flows for every state in the batch.

        Without `gflownets`, every row uses this policy's own heads. With
        `gflownets`, rows whose `cur_depth` is below `max_depth` still use the
        own heads, while the remaining rows are routed to
        `gflownets[idx].pf`, where `idx` is the row's entry in
        `state_to_node(batch_state, len(gflownets))`.

        Args:
            batch_state: batch providing `state`, `cur_depth` and, when
                `gflownets` is given, `batch_size` and `max_depth`.
            gflownets: optional container supporting `len()` and integer
                indexing; each element exposes `.pf.mlp_logit` and
                `.pf.mlp_flows`.

        Returns:
            Tuple `(logits, flows)`: the logits after `torch.nan_to_num`, and
            the flows with their last dimension squeezed.

        Raises:
            AssertionError: if a node index is not smaller than
                `len(gflownets)` (only checked for a non-empty `gflownets`).
        """
        states = batch_state.state.to(torch.get_default_dtype())
        curr_indices = batch_state.cur_depth

        # `use_own` is True for rows handled by this policy's own heads.
        if gflownets is None:
            node_indices = None
            use_own = torch.ones_like(curr_indices, dtype=torch.bool)
        else:
            node_indices = state_to_node(batch_state, len(gflownets))
            max_indices = torch.ones((batch_state.batch_size,), device=self.device) * batch_state.max_depth
            use_own = curr_indices < max_indices

        # The own heads are evaluated on all states; `torch.where` below
        # discards the rows that were delegated.
        logit_all = self.mlp_logit(states)
        flows_all = self.mlp_flows(states)

        # Zero-initialised (rather than `empty_like`) so that a delegated row
        # not claimed by any sub-model yields a defined value instead of
        # arbitrary memory contents.
        logit_model = torch.zeros_like(logit_all)
        flows_model = torch.zeros_like(flows_all)

        if gflownets is not None:
            num_models = len(gflownets)
            # Loop-invariant check, evaluated once. It is skipped for an
            # empty container, where the loop below does not run either.
            if num_models > 0:
                assert node_indices.max() < num_models, \
                    "node index out of range for the provided gflownets"
            delegated = ~use_own
            for idx in range(num_models):
                node_mask = (node_indices == idx) & delegated
                if node_mask.any():
                    sub_policy = gflownets[idx].pf
                    logit_model[node_mask] = sub_policy.mlp_logit(states[node_mask])
                    flows_model[node_mask] = sub_policy.mlp_flows(states[node_mask])

        # Own-head outputs where `use_own` is True, delegated outputs elsewhere.
        logit_lst = torch.where(use_own.unsqueeze(-1), logit_all, logit_model)
        flows_lst = torch.where(use_own.unsqueeze(-1), flows_all, flows_model)

        return (
            torch.nan_to_num(logit_lst), flows_lst.squeeze(dim=-1)
        )

    def get_pol(self, logits_flows, mask):
        """Turn masked logits into a probability distribution.

        Entries where `mask` is 0 have their logit replaced by
        `self.masked_value` (not defined in this class; expected to come from
        `ForwardPolicyMeta`) before the softmax over the last dimension.

        Args:
            logits_flows: `(logits, flows)` pair as returned by
                `get_latent_emb`.
            mask: tensor multiplied element-wise with the logits. A boolean
                mask is accepted and converted to the logits' dtype.

        Returns:
            Tuple `(pol, flows)` with `pol = softmax(masked logits, dim=-1)`
            and `flows` passed through unchanged.
        """
        logits, flows = logits_flows
        # `1 - mask` is rejected by PyTorch for bool tensors, so cast those to
        # the logits' dtype; masks of any other dtype are left untouched.
        if torch.is_tensor(mask) and mask.dtype == torch.bool:
            mask = mask.to(logits.dtype)
        logits = logits * mask + (1 - mask) * self.masked_value
        pol = torch.softmax(logits, dim=-1)
        return pol, flows

class BackwardPolicy(nn.Module):
    """Parameter-free backward policy, uniform over the entries of `backward_mask`.

    The value `batch_state.dim` is used as the action index for rows that
    take no backward move (named the "stop" action in this code).
    """

    # Class-level constant; not referenced by the methods of this class.
    masked_value = -1e5

    def __init__(self, device='cpu'):
        """Remember the device on which `forward` allocates its outputs."""
        super().__init__()
        self.device = device

    def forward(self, batch_state: 'Hypergrid', actions=None):
        """Sample (or score) backward actions for a batch.

        Args:
            batch_state: batch providing `backward_mask`, `batch_size`, `dim`,
                `stopped` and `batch_ids`.
            actions: optional tensor of action indices to score. When omitted,
                actions are sampled from the normalised mask.

        Returns:
            Tuple `(actions, log_pol)`. `log_pol` holds the log-probability of
            each chosen action and is 0 wherever the action equals
            `batch_state.dim`.
        """
        allowed = batch_state.backward_mask
        # Row-normalise the mask. A row summing to zero becomes NaN (0 / 0),
        # which is what identifies an initial state below.
        probs = allowed / allowed.sum(dim=1, keepdims=True)
        stop_index = batch_state.dim

        if actions is None:
            at_initial = torch.isnan(probs).all(dim=1)
            has_stopped = batch_state.stopped == 1.
            # Initial or already-stopped rows are assigned the stop index;
            # every other row draws one action from its distribution.
            forced_stop = at_initial | has_stopped
            can_sample = ~forced_stop

            actions = torch.zeros((batch_state.batch_size,), device=self.device, dtype=torch.long)
            drawn = torch.multinomial(probs[can_sample], num_samples=1, replacement=True)
            actions[can_sample] = drawn.squeeze(dim=1)
            actions[forced_stop] = stop_index

        moved = actions != stop_index
        log_pol = torch.zeros((batch_state.batch_size,), device=self.device)
        rows = batch_state.batch_ids[moved]
        cols = actions[moved]
        log_pol[moved] = torch.log(probs[rows, cols])
        return actions, log_pol
