import torch 
import torch.nn as nn 
import itertools 
import copy 
import networkx as nx 

from gfn.utils import Environment 

class LogReward(nn.Module):
    """Log-reward lookup table: returns ``log(pi[i])`` for each index ``i`` in a batch.

    The indices are read from ``batch_state.indices``. In ``RegGraph``, for
    example, these are the forward actions recorded by ``apply``.

    Args:
        pi: sequence of (positive) rewards, one per index.
        device: torch device on which the reward table is stored.
    """

    def __init__(self, pi=(1., .5), device='cpu'):
        super().__init__()
        # A tuple default avoids the shared-mutable-default pitfall. Any
        # sequence accepted by torch.tensor (lists included) still works.
        self.pi = torch.tensor(pi, device=device)

    def forward(self, batch_state):
        """Return ``log(pi)`` gathered at ``batch_state.indices``."""
        return torch.log(self.pi[batch_state.indices])

class RegGraph(Environment):
    """One-step environment: a fixed parent graph with exactly two possible children.

    The parent and its left/right children are read from
    ``graph_{idx}_p.gml``, ``graph_{idx}_child_l.gml`` and
    ``graph_{idx}_child_r.gml`` inside ``source_dir``. A forward step picks one
    child (action 0 = left, 1 = right) and marks the trajectory as stopped. A
    backward step returns to the parent.

    Args:
        source_dir: directory containing the GML files.
        idx: index used in the GML file names.
        batch_size: number of parallel trajectories.
        log_reward: log-reward module forwarded to ``Environment``.
        device: torch device passed to ``Environment``; tensors are created on ``self.device``.

    Raises:
        ValueError: if either child does not differ from the parent by exactly
            two new adjacency entries (one undirected edge).
    """

    def __init__(self, source_dir, idx, batch_size, log_reward, device='cpu'):
        # NOTE(review): the meaning of the second argument (1) is defined by
        # gfn.utils.Environment -- presumably the number of steps; confirm.
        super().__init__(batch_size, 1, log_reward, device)
        self.source_path = source_dir

        self.p = self.to_amat(f'{source_dir}/graph_{idx}_p.gml')
        self.child_l = self.to_amat(f'{source_dir}/graph_{idx}_child_l.gml')
        self.child_r = self.to_amat(f'{source_dir}/graph_{idx}_child_r.gml')

        # Adjacency entries each child adds on top of the parent, flattened as
        # [row0, col0, row1, col1, ...].
        added = [
            torch.argwhere(child - self.p == 1.).flatten()
            for child in (self.child_l, self.child_r)
        ]
        # Check each child before stacking. Mismatched lengths would otherwise
        # fail inside vstack with an opaque shape error, and an `assert` would
        # be stripped under `python -O`.
        if any(edge.shape != (4,) for edge in added):
            raise ValueError('each child must have exactly 1 extra edge wrt parent')
        # Keep only the first (row, col) entry of each added edge. Assuming an
        # undirected (symmetric) adjacency, the second entry is its mirror.
        self.actions = torch.vstack(added)[:, :2]

        # Shape (2, n, n): indexing with an action tensor selects the children.
        self.children = torch.stack([self.child_l, self.child_r], dim=0)

        self.num_nodes = self.p.shape[-1]
        # Constant all-ones node features, one scalar per node.
        self.x = torch.ones((self.num_nodes, 1), device=self.device)

        # Every trajectory starts at the parent graph.
        self.state = self.p.repeat(self.batch_size, 1, 1)
        # Forward actions from the last `apply`; None while at the parent.
        self.indices = None

    @torch.no_grad()
    def get_children(self):
        """Yield ``(action_idx, env)`` pairs, one per child graph.

        Each ``env`` is a deep copy of this environment whose state is that
        child's adjacency matrix repeated over the batch. Only ``state`` is
        modified on the copy.
        """
        for action_idx, child_adj in enumerate(self.children):
            child = copy.deepcopy(self)
            child.state = child_adj.repeat(self.batch_size, 1, 1)
            yield action_idx, child

    def to_amat(self, gml_f):
        """Read a GML file and return its dense adjacency matrix as a tensor on ``self.device``."""
        nx_graph = nx.read_gml(gml_f)
        # NOTE(review): rows/columns follow nx_graph.nodes() order, so the three
        # GML files are assumed to list nodes in the same order. Edge weights
        # (if present) become the matrix entries -- confirm the files are unweighted.
        return torch.tensor(
            nx.adjacency_matrix(nx_graph).todense(), device=self.device
        )

    @torch.no_grad()
    def apply(self, actions):
        """Move each trajectory to the child chosen by ``actions`` (0 = left, 1 = right) and stop it."""
        self.state = self.children[actions]
        self.stopped[:] = 1
        self.is_initial[:] = 0
        # Remember the forward actions so `backward` can hand them back.
        self.indices = actions

    @torch.no_grad()
    def backward(self, actions):
        """Return every trajectory to the parent graph.

        The parent is the unique predecessor of both children, so only the
        shape of ``actions`` matters (it fixes the batch layout), not its values.

        Returns:
            A copy of the forward actions recorded by the last ``apply``, or ``None``.
        """
        # Index the single-entry parent stack with zeros. Indexing it with the
        # raw actions raised IndexError for any action equal to 1, because the
        # leading dimension has size 1.
        parent_idx = torch.zeros_like(torch.as_tensor(actions), dtype=torch.long)
        self.state = self.p.unsqueeze(0)[parent_idx]
        self.stopped[:] = 0
        self.is_initial[:] = 1
        indices = copy.deepcopy(self.indices)
        self.indices = None
        return indices

    @property
    @torch.no_grad()
    def unique_input(self):
        """Current state with each batch element's adjacency matrix flattened to one row."""
        return self.state.flatten(start_dim=1)

    @property
    @torch.no_grad()
    def data(self):
        """Node features for the whole batch, stacked into shape ``(num_nodes * batch_size, 1)``."""
        return self.x.repeat(self.batch_size, 1, 1).view(self.num_nodes * self.batch_size, 1)

    @property
    @torch.no_grad()
    def edge_index(self):
        """Edge list of the batched graph as a ``(2, E)`` tensor of node ids.

        Each node id is offset by ``num_nodes * batch_index``, so the graphs in
        the batch occupy disjoint id ranges (presumably for a PyG-style batched
        GNN -- confirm with the consumer). Assumes ``state`` is ``(batch, n, n)``.
        """
        indices = torch.argwhere(self.state == 1.)
        return (self.num_nodes * indices[:, 0].view(-1, 1) + indices[:, 1:]).t()