import torch 
import torch.nn as nn 

from copy import deepcopy 

from sal.utils import ForwardPolicyMeta, BaseNN
from sal.gym.sets import state_to_node 
from sal.pac_utils import BayesianMLP, MLP, BayesianPolicyMeta 


class ForwardPolicy(ForwardPolicyMeta):
    """Forward policy with one network for action logits and one for state flows.

    Args:
        src_size: First argument of both `BaseNN` heads and last argument of
            the logit head (presumably input width and number of actions --
            confirm against `BaseNN`).
        hidden_dim: Second argument forwarded to both `BaseNN` heads.
        num_layers: Third argument forwarded to both `BaseNN` heads.
        eps: Forwarded unchanged to `ForwardPolicyMeta.__init__`.
        force_mask_idx: Optional column index (or indices) of the policy that
            `get_pol` overwrites when `compute_chi_squared_div` is set.
        compute_chi_squared_div: Enables the overwrite described above.
        device: Device the two heads are moved to.
    """

    def __init__(self, src_size, hidden_dim=128, num_layers=3, eps=.3, force_mask_idx=None,
                 compute_chi_squared_div=False, device='cpu'):
        super(ForwardPolicy, self).__init__(eps=eps, device=device)
        self.device = device
        # Head producing one logit per action.
        self.mlp_logit = BaseNN(src_size, hidden_dim, num_layers, src_size).to(self.device)
        # Head producing a single flow value per state.
        self.mlp_flows = BaseNN(src_size, hidden_dim, num_layers, 1).to(self.device)
        self.force_mask_idx = force_mask_idx
        self.compute_chi_squared_div = compute_chi_squared_div

    def get_latent_emb(self, batch_state, gflownets=None):
        """Compute action logits and flows for every state in the batch.

        Args:
            batch_state: Object exposing `unique_input`; when `gflownets` is
                given, `batch_size`, `cur_depth` and `max_depth` are read too.
            gflownets: Optional indexable collection of models exposing
                `.pf.mlp_logit` and `.pf.mlp_flows`. Rows whose current depth
                is not below the maximum depth are evaluated with the model
                selected by `state_to_node` instead of this policy's heads.

        Returns:
            Tuple `(logits, flows)`; the last dimension of `flows` is squeezed.
        """
        states = batch_state.unique_input

        # Without auxiliary models every row uses this policy's own heads, so
        # there is nothing to blend and no scratch buffers are needed.
        if gflownets is None:
            return self.mlp_logit(states), self.mlp_flows(states).squeeze(dim=-1)

        curr_indices = torch.ones((batch_state.batch_size,), device=self.device) * batch_state.cur_depth
        max_indices = torch.ones((batch_state.batch_size,), device=self.device) * batch_state.max_depth
        node_indices = state_to_node(batch_state, len(gflownets))

        # True -> row is handled by this policy; False -> delegated to a model
        # from `gflownets`.
        use_own = curr_indices < max_indices

        logit_all = self.mlp_logit(states)
        flows_all = self.mlp_flows(states)

        # Zero-initialised so that a delegated row whose node index matches no
        # entry of `gflownets` yields a deterministic value rather than
        # whatever happened to be in uninitialised memory.
        logit_model = torch.zeros_like(logit_all)
        flows_model = torch.zeros_like(flows_all)

        # Index-based access is kept so any container supporting
        # `gflownets[idx]` for idx in range(len(gflownets)) keeps working.
        for idx in range(len(gflownets)):
            node_mask = (node_indices == idx) & ~use_own
            if node_mask.any():
                node_states = states[node_mask]
                logit_model[node_mask] = gflownets[idx].pf.mlp_logit(node_states)
                flows_model[node_mask] = gflownets[idx].pf.mlp_flows(node_states)

        # Row-wise selection between own outputs and delegated outputs.
        row_selector = use_own.unsqueeze(-1)
        logit_lst = torch.where(row_selector, logit_all, logit_model)
        flows_lst = torch.where(row_selector, flows_all, flows_model)

        return logit_lst, flows_lst.squeeze(dim=-1)

    def get_pol(self, logit_flows, mask):
        """Turn masked logits into a probability distribution over actions.

        Args:
            logit_flows: Tuple `(logits, flows)` as returned by `get_latent_emb`.
            mask: Tensor multiplied with the logits; entries where it is 0
                receive `self.masked_value` (expected to be provided by the
                base class -- not visible in this file) before the softmax.

        Returns:
            Tuple `(pol, flows)`; `flows` is passed through untouched.
        """
        logit, flows = logit_flows
        pol = (mask * logit + self.masked_value * (1 - mask)).softmax(dim=-1)
        if self.force_mask_idx is not None and self.compute_chi_squared_div:
            # Softmax keeps its output for the backward pass; editing that
            # tensor in place would invalidate autograd, so work on a copy.
            pol = pol.clone()
            pol[:, self.force_mask_idx] = 1e-5
            # Renormalise so each row sums to one again.
            pol = pol / pol.sum(dim=1, keepdim=True)
        return pol, flows

# Python's MRO linearizes base classes left to right when resolving methods under multiple inheritance.
class BayesianPolicy(BayesianPolicyMeta):
    """Policy whose logit network is a `BayesianMLP` built from a prior and a posterior MLP.

    Args:
        src_size: Input size of both networks and output size of the logit
            networks (argument order follows `MLP(in, out, hidden_sizes)` as
            used below -- confirm against `MLP`).
        hidden_dim: Width of each hidden layer.
        num_layers: `num_layers - 1` hidden layers are created per MLP.
        eps: Forwarded unchanged to `BayesianPolicyMeta.__init__`.
        device: Device the networks are moved to.
        force_mask_idx: Optional action index (or indices) that `get_pol`
            masks out while `self.sample_dataset_mode` is truthy.
        **bayesian_kwargs: Forwarded to both `BayesianPolicyMeta.__init__` and
            `BayesianMLP`.
    """

    def __init__(self, src_size, hidden_dim=128,
                 num_layers=2, eps=.3, device='cpu',
                 force_mask_idx=None, **bayesian_kwargs):
        super(BayesianPolicy, self).__init__(eps=eps, device=device, **bayesian_kwargs)
        self.src_size = src_size
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.force_mask_idx = force_mask_idx

        # This should be properly experimented with and assessed.
        # Training has two stages: first a prior is learned by minimizing a
        # given objective; then a posterior is learned by minimizing the bound.
        hidden_sizes = [hidden_dim] * (num_layers - 1)
        self.mlp_logit_prior = MLP(src_size, src_size, hidden_sizes).to(device)
        # The posterior starts as an independent copy of the prior.
        self.mlp_logit_posterior = deepcopy(self.mlp_logit_prior)

        # `self.device` is expected to be set by the base-class constructor.
        self.mlp_logit = BayesianMLP(
            self.mlp_logit_prior, self.mlp_logit_posterior, device=device, **bayesian_kwargs
        ).to(self.device)
        self.mlp_flows = MLP(src_size, 1, [hidden_dim] * (num_layers - 1)).to(device)

    def get_latent_emb(self, batch_state):
        """Return `(logits, flows)` for `batch_state.unique_input`.

        Dimension 1 of the flow output is squeezed.
        """
        states = batch_state.unique_input
        return (
            self.mlp_logit(states),
            self.mlp_flows(states).squeeze(dim=1)
        )

    def get_pol(self, logit_flows, mask):
        """Turn masked logits into a probability distribution over actions.

        Args:
            logit_flows: Tuple `(logits, flows)` as returned by `get_latent_emb`.
            mask: Tensor multiplied with the logits; entries where it is 0
                receive `self.masked_value` before the softmax.

        Returns:
            Tuple `(pol, flows)`; `flows` is passed through untouched.
        """
        logit, flows = logit_flows
        # Compare against None explicitly: a plain truthiness test would skip
        # index 0 and would raise on a multi-element tensor of indices.
        if self.force_mask_idx is not None and self.sample_dataset_mode:
            # Forcefully mask actions.
            # NOTE(review): this writes into the caller's `mask` tensor in
            # place; kept as-is because callers may rely on it -- confirm.
            mask[:, self.force_mask_idx] = 0.
        pol = (mask * logit + self.masked_value * (1 - mask)).softmax(dim=-1)
        return pol, flows

class ForwardPolicyLA(ForwardPolicy):
    """Forward policy that scores each child from a (parent, child) embedding pair.

    Args:
        src_size: First argument of the embedding and flow networks.
        hidden_dim: Hidden width; the logit head consumes `2 * hidden_dim`
            features (parent and child embeddings side by side).
        num_layers: Layer count forwarded to the embedding and flow networks.
        eps: Forwarded to `ForwardPolicy.__init__`.
        device: Device the networks are moved to.
    """

    def __init__(self, src_size, hidden_dim=128, num_layers=3, eps=.3, device='cpu'):
        # The parent constructor builds its own heads with its default sizes;
        # `mlp_logit` and `mlp_flows` are replaced right below.
        super(ForwardPolicyLA, self).__init__(src_size, eps=eps, device=device)
        self.mlp_emb = BaseNN(src_size, hidden_dim, num_layers).to(self.device)
        self.mlp_logit = BaseNN(2*hidden_dim, 2*hidden_dim, 1, 1).to(self.device)
        self.mlp_flows = BaseNN(src_size, hidden_dim, num_layers, 1).to(self.device)

    def get_latent_emb(self, batch_state, gflownets=None):
        """Compute one logit per child state plus the flow of the parent state.

        Args:
            batch_state: Object exposing `unique_input` and `get_children()`;
                every child must expose `unique_input` as well.
            gflownets: Present only so the signature matches
                `ForwardPolicy.get_latent_emb`. Delegating to auxiliary models
                is not implemented for this policy.

        Returns:
            Tuple `(logits, flows)`; the last dimension of each is squeezed.

        Raises:
            NotImplementedError: If `gflownets` is not None.
        """
        if gflownets is not None:
            raise NotImplementedError(
                'ForwardPolicyLA.get_latent_emb does not support `gflownets`'
            )

        parent_emb = self.mlp_emb(batch_state.unique_input)
        # One concatenated (parent, child) embedding per child state.
        pair_embs = [
            torch.hstack([parent_emb, self.mlp_emb(child.unique_input)])
            for child in batch_state.get_children()
        ]
        # Stack the per-child embeddings along a new dimension 1.
        nn_input = torch.stack(pair_embs, dim=1)
        return (
            self.mlp_logit(nn_input).squeeze(dim=-1),
            self.mlp_flows(batch_state.unique_input).squeeze(dim=-1)
        )

class BackwardPolicy(nn.Module):
    """Parameter-free backward policy driven by the state's backward mask."""

    # Logit assigned to entries where the backward mask is 0, so that they
    # receive (numerically) zero probability after the softmax.
    masked_value = -1e5

    def __init__(self, device='cpu'):
        super(BackwardPolicy, self).__init__()
        self.device = device

    def forward(self, batch_state, actions=None):
        """Sample (or score) backward actions.

        Args:
            batch_state: Object exposing `backward_mask`, `unique_input` and
                `batch_ids`.
            actions: Optional tensor of action indices to score. When omitted,
                one action per row is sampled from the policy.

        Returns:
            Tuple `(actions, log_probs)` where `log_probs` holds the log
            probability of each row's action.
        """
        mask = batch_state.backward_mask
        logits = mask * batch_state.unique_input + self.masked_value * (1 - mask)
        uniform_pol = logits.softmax(dim=-1)
        if actions is None:
            sampled = torch.multinomial(uniform_pol, num_samples=1, replacement=True)
            # Drop only the sample dimension; a bare `squeeze()` would also
            # remove the batch dimension when the batch holds a single state.
            actions = sampled.squeeze(dim=-1)
        return actions, torch.log(uniform_pol[batch_state.batch_ids, actions])
