import torch 
import torch.nn as nn 

import torch_geometric as pyg 

from streaming_gfn.utils import ForwardPolicyMeta 

class ForwardPolicy(ForwardPolicyMeta):
    """Forward policy that embeds states with a simplified graph convolution.

    Node features go through a 3-hop ``SGConv`` layer, the resulting node
    embeddings are summed per state, and two heads read that embedding:
    an MLP producing ``num_leaves * (num_leaves - 1) // 2`` logits (the
    number of unordered pairs among ``num_leaves`` items) and a linear
    layer producing one scalar per state.
    """

    def __init__(self, hidden_dim, num_leaves, eps=.3, device='cpu'):
        """Build the graph encoder and the two output heads.

        Args:
            hidden_dim: Width of the graph embedding and of the MLP hidden layer.
            num_leaves: Number of leaves; fixes the input feature width
                (``num_leaves + 2``) and the number of output logits.
            eps: Forwarded unchanged to ``ForwardPolicyMeta``.
            device: Device the sub-modules are moved to; also forwarded to
                ``ForwardPolicyMeta``.
        """
        super().__init__(eps=eps, device=device)
        self.hidden_dim = hidden_dim
        self.num_leaves = num_leaves
        self.device = device

        num_actions = num_leaves * (num_leaves - 1) // 2

        # Sub-modules are created in the order gcn -> mlp -> mlp_flows so
        # that parameter initialisation under a fixed seed is unchanged.
        self.gcn = pyg.nn.SGConv(
            in_channels=num_leaves + 2,
            out_channels=hidden_dim,
            K=3,
        ).to(self.device)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, num_actions),
        ).to(self.device)
        self.mlp_flows = nn.Linear(hidden_dim, 1).to(self.device)

    def get_latent_emb(self, batch_state):
        """Return one ``hidden_dim``-sized embedding per state in the batch.

        Reads ``batch_state.expanded_data``, ``batch_state.edge_list_t()``
        and ``batch_state.batch_size``.
        """
        per_node = self.gcn(batch_state.expanded_data, batch_state.edge_list_t())
        # Regroup the flat node axis into (batch_size, nodes, hidden_dim);
        # assumes every state contributes the same number of nodes -- TODO confirm.
        per_state = per_node.reshape(batch_state.batch_size, -1, self.hidden_dim)
        return per_state.sum(dim=1)

    def get_pol(self, latent_emb, mask):
        """Return the masked action distribution and the scalar head output.

        Args:
            latent_emb: Embedding produced by :meth:`get_latent_emb`.
            mask: Multiplicative mask over the logits; entries equal to 1
                keep the logit, entries equal to 0 replace it with
                ``self.masked_value`` (not defined in this class, presumably
                inherited from ``ForwardPolicyMeta``).

        Returns:
            Tuple ``(pol, gflows)``: softmax over the last dimension of the
            masked logits, and the ``mlp_flows`` output with its trailing
            singleton dimension squeezed away.
        """
        raw_logits = self.mlp(latent_emb)
        masked_logits = mask * raw_logits + self.masked_value * (1 - mask)
        pol = masked_logits.softmax(dim=-1)
        gflows = self.mlp_flows(latent_emb).squeeze(dim=-1)
        return pol, gflows

class ForwardPolicyMLP(ForwardPolicyMeta):
    """Forward policy that encodes the flattened adjacency matrix with an MLP.

    A two-layer MLP trunk maps ``num_nodes * num_nodes`` input features
    (``num_nodes = 2 * num_leaves - 1``) to a ``hidden_dim`` embedding. Two
    linear heads read that embedding: ``mlp_logit`` emits
    ``num_leaves * (num_leaves - 1) // 2`` logits and ``mlp_flows`` emits one
    scalar per state.
    """

    def __init__(self, hidden_dim, num_leaves, eps=.3, device='cpu'):
        """Build the MLP trunk and the two output heads.

        Args:
            hidden_dim: Width of the trunk's hidden and output layers.
            num_leaves: Number of leaves; fixes the input width and the
                number of output logits.
            eps: Forwarded unchanged to ``ForwardPolicyMeta``.
            device: Device the sub-modules are moved to; also forwarded to
                ``ForwardPolicyMeta``.
        """
        # Forward `device` to the base class as ForwardPolicy does, so the
        # base class is not left configured with its default device.
        super(ForwardPolicyMLP, self).__init__(eps=eps, device=device)
        self.hidden_dim = hidden_dim
        self.num_leaves = num_leaves
        self.device = device
        self.num_nodes = 2 * self.num_leaves - 1
        self.mlp = nn.Sequential(
            nn.Linear((self.num_nodes * self.num_nodes), hidden_dim), nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU()
        ).to(self.device)
        self.mlp_logit = nn.Linear(hidden_dim, (num_leaves * (num_leaves - 1) // 2)).to(self.device)
        self.mlp_flows = nn.Linear(hidden_dim, 1).to(self.device)

    def get_latent_emb(self, batch_state):
        """Return the trunk embedding of ``batch_state.adjacency_matrix``.

        The matrix is flattened from dimension 1 onward before entering the
        trunk; presumably it is shaped ``(batch, num_nodes, num_nodes)`` so
        that the flattened width matches the first linear layer -- TODO confirm.
        """
        return self.mlp(batch_state.adjacency_matrix.flatten(start_dim=1))

    def get_pol(self, latent_emb, mask):
        """Return the masked action distribution and the scalar head output.

        Args:
            latent_emb: Embedding produced by :meth:`get_latent_emb`.
            mask: Multiplicative mask over the logits; entries equal to 1
                keep the logit, entries equal to 0 replace it with
                ``self.masked_value`` (not defined in this class, presumably
                inherited from ``ForwardPolicyMeta``).

        Returns:
            Tuple ``(pol, gflows)``: softmax over the last dimension of the
            masked logits, and the ``mlp_flows`` output with its trailing
            singleton dimension squeezed away.
        """
        logits = self.mlp_logit(latent_emb)
        logits = logits * mask + (1 - mask) * self.masked_value
        pol = logits.softmax(dim=-1)
        gflows = self.mlp_flows(latent_emb).squeeze(dim=-1)
        return pol, gflows

class ForwardPolicyMHA(ForwardPolicyMeta):
    """Forward policy built on a multi-head attention layer.

    The flattened adjacency matrix is projected to query/key/value vectors
    of width ``hidden_dim * num_heads``, passed through
    ``nn.MultiheadAttention``, and reduced by a three-layer MLP to
    ``num_leaves * (num_leaves - 1) // 2`` logits. Unlike the sibling
    policies, this class has no scalar head: :meth:`get_latent_emb` already
    returns logits and :meth:`get_pol` returns ``None`` in the second slot.
    """

    def __init__(self, hidden_dim, num_leaves, num_heads=8, eps=.3, device='cpu'):
        """Build the QKV projection, the attention layer and the logit MLP.

        Args:
            hidden_dim: Per-head width; the attention embedding size is
                ``hidden_dim * num_heads``.
            num_leaves: Number of leaves; fixes the input width
                (``(2 * num_leaves - 1) ** 2``) and the number of logits.
            num_heads: Number of attention heads.
            eps: Forwarded unchanged to ``ForwardPolicyMeta``.
            device: Device the sub-modules are moved to; also forwarded to
                ``ForwardPolicyMeta``.
        """
        # Forward `device` to the base class as ForwardPolicy does, so the
        # base class is not left configured with its default device.
        super(ForwardPolicyMHA, self).__init__(eps=eps, device=device)
        self.hidden_dim = hidden_dim
        # Stored for parity with the other forward policies in this file.
        self.num_leaves = num_leaves
        self.device = device
        self.num_nodes = 2 * num_leaves - 1
        self.num_heads = num_heads

        # Single linear layer producing q, k and v concatenated along dim 1.
        self.qkv = nn.Linear(
            self.num_nodes * self.num_nodes,
            3 * hidden_dim * self.num_heads
        ).to(self.device)

        self.mha = nn.MultiheadAttention(
            self.hidden_dim * self.num_heads,
            self.num_heads,
            batch_first=True
        ).to(self.device)

        self.mlp = nn.Sequential(
            nn.Linear(self.num_heads * self.hidden_dim, self.num_heads * self.hidden_dim // 2),
            nn.LeakyReLU(),
            nn.Linear(self.num_heads * self.hidden_dim // 2, self.num_heads * self.hidden_dim // 4),
            nn.LeakyReLU(),
            nn.Linear(self.num_heads * self.hidden_dim // 4, (num_leaves * (num_leaves - 1) // 2))
        ).to(self.device)

    def get_latent_emb(self, batch_state):
        """Return unmasked logits computed from ``batch_state.adjacency_matrix``.

        Despite the name, the returned tensor is the output of the final
        MLP, i.e. what :meth:`get_pol` consumes as ``logits``.
        """
        input_nn = batch_state.adjacency_matrix.flatten(start_dim=1)
        # Split the joint projection into three equal chunks: q, k, v.
        q, k, v = torch.chunk(
            self.qkv(input_nn), chunks=3, dim=1
        )
        # NOTE(review): q, k and v are 2-D here. nn.MultiheadAttention
        # interprets 2-D inputs as a single unbatched sequence, so if dim 0
        # is the batch axis, states in the same batch attend to each other.
        # Confirm this is intended before changing it -- inserting a
        # length-1 sequence axis would alter model outputs.
        out, _ = self.mha(q, k, v, need_weights=False)

        return self.mlp(out)

    def get_pol(self, logits, mask):
        """Return the masked action distribution and ``None``.

        Args:
            logits: Output of :meth:`get_latent_emb`.
            mask: Multiplicative mask over the logits; entries equal to 1
                keep the logit, entries equal to 0 replace it with
                ``self.masked_value`` (not defined in this class, presumably
                inherited from ``ForwardPolicyMeta``).

        Returns:
            Tuple ``(pol, None)``. The second element exists only so the
            return shape matches the other forward policies.
        """
        logits = logits * mask + (1 - mask) * self.masked_value
        # returns None for retrocompatibility
        return logits.softmax(dim=-1), None

class BackwardPolicy(nn.Module):
    """Parameter-free backward policy that is uniform over unmasked actions."""

    # Logit assigned to masked-out actions; after the softmax their
    # probability underflows to zero.
    masked_value = -1e5

    def __init__(self, device='cpu'):
        """Store ``device``; the module has no learnable parameters."""
        super(BackwardPolicy, self).__init__()
        self.device = device

    def forward(self, batch_state, actions=None):
        """Pick backward actions and return their negative log-probabilities.

        Args:
            batch_state: Object exposing ``backward_mask``, ``num_leaves``,
                ``batch_ids`` and, when ``actions`` is given, ``batch_size``
                and ``next_node``.
            actions: If ``None``, one action per state is sampled from the
                uniform distribution over unmasked entries. Otherwise the
                action index is derived from ``batch_state.next_node``.
                NOTE: the values stored in ``actions`` are never read; the
                argument only selects between the two modes.

        Returns:
            Tuple ``(actions, neg_log_prob)`` where ``neg_log_prob`` is
            ``-log`` of the uniform probability of each chosen action.
        """
        # Only the columns from `num_leaves` onward are candidate actions.
        uniform_pol = torch.where(batch_state.backward_mask[:, batch_state.num_leaves:] == 1., 1., self.masked_value)
        uniform_pol = uniform_pol.softmax(dim=-1)
        if actions is None:
            actions = torch.multinomial(uniform_pol, num_samples=1, replacement=True)
            actions = actions.squeeze(dim=-1)
        else:
            # Allocate on the policy's device: a default (CPU) tensor would
            # fail to combine with a `next_node` tensor on another device and
            # would return actions on a different device than the sampling
            # branch does.
            actions = torch.ones(
                (batch_state.batch_size,),
                dtype=torch.long,
                device=uniform_pol.device,
            ) * (batch_state.next_node - batch_state.num_leaves - 1)
        return actions, - torch.log(uniform_pol[batch_state.batch_ids, actions])
