import torch 
import torch.nn as nn 
import torch_geometric as pyg 

from gfn.utils import ForwardPolicyMeta, BaseNN 

class ForwardPolicy(ForwardPolicyMeta):
    """Edge-selection forward policy driven by a stack of ``SGConv`` layers.

    Node features are propagated through ``num_layers`` ``pyg.nn.SGConv``
    layers (K=1). Each candidate action (a pair of node indices) is scored
    from the sum of its two endpoint embeddings. A separate linear head maps
    the summed node embeddings to a scalar per batch element, returned as
    ``flows``.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, eps=.3, device='cpu'):
        # eps and device are handled by ForwardPolicyMeta; the meaning of eps
        # is defined there (presumably an exploration rate — confirm).
        super(ForwardPolicy, self).__init__(eps=eps, device=device)
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # The first layer consumes raw node features; every later layer maps
        # hidden_dim -> hidden_dim.
        # NOTE(review): layers are chained with no activation in between, and
        # SGConv is itself linear, so the stack may reduce to a single linear
        # map. Confirm this is intended.
        self.model = nn.ModuleList([
            pyg.nn.SGConv(
                in_channels=hidden_dim if depth > 0 else input_dim,
                out_channels=hidden_dim,
                K=1,
            ).to(self.device)
            for depth in range(self.num_layers)
        ])
        # Per-action scoring head and per-state flow head.
        self.mlp_logit = nn.Linear(self.hidden_dim, 1).to(self.device)
        self.mlp_flows = nn.Linear(self.hidden_dim, 1).to(self.device)

    def get_graph_emb(self, batch_state):
        """Run the SGConv stack and return node embeddings.

        Returns:
            Tensor reshaped to ``(batch_state.batch_size, batch_state.num_nodes,
            hidden_dim)``. This assumes ``batch_state.data`` holds the node
            features of all graphs flattened together (TODO confirm).
        """
        node_emb = batch_state.data
        for conv in self.model:
            node_emb = conv(node_emb, batch_state.edge_index)
        shape = (batch_state.batch_size, batch_state.num_nodes, self.hidden_dim)
        return node_emb.view(*shape)

    def get_latent_emb(self, batch_state):
        """Return ``(graph_emb, edges_emb)`` for the batch.

        ``graph_emb`` is the node embeddings summed over the node axis.
        ``edges_emb[:, a]`` is the sum of the embeddings of the two endpoint
        nodes listed in ``batch_state.actions[a]``. This assumes ``actions`` is
        an integer tensor of shape ``(num_actions, 2)``; verify against the
        state class.
        """
        node_emb = self.get_graph_emb(batch_state)
        src = batch_state.actions[:, 0]
        dst = batch_state.actions[:, 1]
        edge_emb = node_emb[:, src] + node_emb[:, dst]
        return node_emb.sum(dim=1), edge_emb

    def get_pol(self, latent_emb, *args, **kwargs):
        """Map latent embeddings to ``(pol, flows)``.

        ``pol`` is a softmax of the per-action scores over dim 1 (the action
        axis). ``flows`` is the flow head's scalar output per batch element.
        Extra positional and keyword arguments are accepted and ignored.
        """
        graph_emb, edges_emb = latent_emb
        scores = self.mlp_logit(edges_emb).squeeze(dim=-1)
        pol = scores.softmax(dim=1)
        flows = self.mlp_flows(graph_emb).squeeze(dim=-1)
        return pol, flows

class ForwardPolicyLA(ForwardPolicy):
    """Forward policy that scores each action by looking one step ahead.

    ("LA" presumably stands for look-ahead.) For every child state yielded by
    ``batch_state.get_children()``, the action's edge embedding is computed in
    the parent graph and in the child graph. The two are concatenated and
    scored by ``mlp_logit``, which therefore takes ``2 * hidden_dim`` inputs.

    ``get_pol`` is inherited unchanged from ``ForwardPolicy``.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, eps=.3, device='cpu'):
        # Forward the caller's eps. It was previously hard-coded to .3, which
        # silently ignored any user-supplied value.
        super(ForwardPolicyLA, self).__init__(
            input_dim, hidden_dim, num_layers, eps=eps, device=device)
        # Replace the parent's scoring head: the input is [parent_edge, child_edge].
        self.mlp_logit = nn.Linear(2 * self.hidden_dim, 1).to(self.device)

    def get_latent_emb(self, batch_state):
        """Embed the state and every candidate action with one-step look-ahead.

        For each ``(action_idx, child)`` from ``batch_state.get_children()``:

        1. Read the endpoint pair ``batch_state.actions[action_idx]``.
        2. Take the edge embedding (the sum of the two endpoint node
           embeddings) in the parent graph and in ``child``.
        3. Concatenate the two on the feature axis.

        Returns:
            ``(graph_emb, edges_emb)``:

            - ``graph_emb`` is the parent's node embeddings summed over the
              node axis.
            - ``edges_emb`` stacks one ``2 * hidden_dim`` vector per child
              along dim 1, in the order the children are yielded.
        """
        actions = batch_state.actions
        parent_emb = self.get_graph_emb(batch_state)

        edges_emb = []
        for action_idx, child in batch_state.get_children():
            action = actions[action_idx]
            child_emb = self.get_graph_emb(child)

            parent_edge_emb = parent_emb[:, action[0]] + parent_emb[:, action[1]]
            child_edge_emb = child_emb[:, action[0]] + child_emb[:, action[1]]
            # Each edge embedding is 2-D (batch, hidden_dim). Join them on the
            # feature axis (equivalent to the previous hstack for 2-D inputs),
            # then add an action axis so the per-child rows can be stacked.
            edges_emb.append(
                torch.cat([parent_edge_emb, child_edge_emb], dim=-1).unsqueeze(1)
            )
        # NOTE(review): the column order here must match the action ordering the
        # caller uses to interpret the policy. Confirm that get_children() yields
        # every action, in index order.
        edges_emb = torch.cat(edges_emb, dim=1)
        return parent_emb.sum(dim=1), edges_emb

class BackwardPolicy(nn.Module):
    """Trivial backward policy: always action 0, always zero logits.

    It is presumably used where every state has exactly one parent, so the
    backward step is deterministic. TODO confirm against the sampler.
    """

    def __init__(self, device='cpu'):
        super(BackwardPolicy, self).__init__()
        self.device = device

    def forward(self, batch_state, actions=None):
        """Return ``(actions, logits)`` of zeros, one entry per batch element.

        Args:
            batch_state: any object exposing ``batch_size``.
            actions: accepted for interface compatibility; ignored.

        Returns:
            A pair of 1-D tensors of length ``batch_state.batch_size`` on
            ``self.device``:

            - int64 zeros (the chosen action indices);
            - zeros of the default floating dtype (the logits).
        """
        n = batch_state.batch_size
        zero_actions = torch.zeros((n,), dtype=torch.long, device=self.device)
        zero_logits = torch.zeros((n,), device=self.device)
        return zero_actions, zero_logits