import collections
import itertools
import torch
import time
import numpy as np
import pickle
import torch.nn.functional as F
import networkx as nx
import random
from torch_geometric.data import Data, Batch

class CliqueDataset(object):
    def __init__(self, conf, mode='train', printer=None):
        """Copy the relevant settings off ``conf`` and load the dataset.

        Args:
            conf: configuration object. The attributes read here are
                ``conf.dataset.{name, path, feature_type, data_type,
                alignment_type, max_node_set_size, max_edge_set_size}`` and
                ``conf.training.{batch_size, device}``.
            mode: split name; it is used to build the pickle file names
                read by ``load_graphs``.
            printer: optional callable used for progress messages.
        """
        dataset_cfg = conf.dataset
        training_cfg = conf.training

        # Caller-supplied bookkeeping.
        self.mode = mode
        self.print = printer

        # Where the data lives and how it should be interpreted.
        self.dataset = dataset_cfg.name
        self.path = dataset_cfg.path
        self.feature_type = dataset_cfg.feature_type
        self.data_type = dataset_cfg.data_type
        self.alignment_type = dataset_cfg.alignment_type

        # Size limits; max_node_set_size is the padded adjacency side length.
        self.max_node_set_size = dataset_cfg.max_node_set_size
        self.max_edge_set_size = dataset_cfg.max_edge_set_size

        # Training-side settings.
        self.batch_size = training_cfg.batch_size
        self.device = training_cfg.device

        # Everything above must be set before loading: load_graphs reads it.
        self.load_graphs()
        
    def load_graphs(self):
        """
        self.corpus_graphs : list of nx graphs
        self.ground_truth : list of max_clique grount truth values
        self.query_graphs : list of complete nx graphs (fixed template)
        """
        # Load graphs and ground_truths
        if self.print:
            self.print(f'Loading graphs from {self.path}/{self.dataset}')
        self.corpus_graphs = pickle.load(open(f'{self.path}/{self.dataset}/graphs/{self.mode}.pkl', 'rb'))
        self.ground_truth = pickle.load(open(f'{self.path}/{self.dataset}/truth/{self.mode}.pkl', 'rb'))
        assert len(self.corpus_graphs) == len(self.ground_truth)
        print("Pkl load done")
        self.corpus_graph_node_sizes = [x.number_of_nodes() for x in self.corpus_graphs]
        self.corpus_graph_edge_sizes = [x.number_of_edges() for x in self.corpus_graphs]
        
        if self.dataset in ['aids_mcis', 'ptc_mm_mcis']:
            self.mcis_graph_pairs = pickle.load(open(f'{self.path}/{self.dataset}/tuple_of_graphs/{self.mode}.pkl', 'rb'))
            self.set1_graphs, self.set2_graphs = self.mcis_graph_pairs
            self.set1_node_sizes = [x.number_of_nodes() for x in self.set1_graphs]
            self.set2_node_sizes = [x.number_of_nodes() for x in self.set2_graphs]
            self.set1_edge_sizes = [x.number_of_edges() for x in self.set1_graphs]
            self.set2_edge_sizes = [x.number_of_edges() for x in self.set2_graphs]
            self.set1_adj_list = self.fetch_subgraph_adjacency_info(self.set1_graphs)
            self.set2_adj_list = self.fetch_subgraph_adjacency_info(self.set2_graphs)


        if self.data_type == 'gmn':
            self.max_query_template_size = 2 * int(max(self.ground_truth))
            self.query_graphs =  [nx.complete_graph(x) for x in range(2,self.max_query_template_size)]
            self.packed_query_graphs = self._pack_batch(self.query_graphs)
            self.query_graph_node_sizes = [x.number_of_nodes() for x in self.query_graphs]
            self.query_graph_edge_sizes = [x.number_of_edges() for x in self.query_graphs]

            self.query_adj_list = self.fetch_subgraph_adjacency_info(self.query_graphs)

        self.corpus_adj_list = self.fetch_subgraph_adjacency_info(self.corpus_graphs) 
        print("going to preprocess")
        self.preprocess_subgraphs_to_pyG_data() 

        self.list_all_truth = []
        for id in range(len(self.corpus_graphs)):
            self.list_all_truth.append(
                (id, self.ground_truth[id])
            )

    def create_pyG_data_object(self, g):
        """Convert one graph into a PyG ``Data`` object with constant features.

        Every node gets the single feature value 1.0. Each edge yielded by
        ``g.edges`` is stored in both orientations, so ``edge_index`` has
        shape ``(2, 2 * number_of_edges)``; a graph without edges yields an
        empty ``(2, 0)`` index instead of raising.

        Args:
            g: graph exposing ``number_of_nodes()`` and ``edges``.
                NOTE(review): edge endpoints are used directly as node
                indices, so nodes are presumably labelled 0..n-1 - confirm.

        Returns:
            Tuple ``(data, number_of_nodes)``.

        Raises:
            AssertionError: if ``self.feature_type`` is not ``"One"``.
        """
        assert self.feature_type == "One"
        num_nodes = g.number_of_nodes()
        x1 = torch.ones((num_nodes, 1), device=self.device, dtype=torch.float32)

        # reshape(-1, 2) keeps the array 2-D when the graph has no edges;
        # without it the axis=1 concatenation below fails on a 1-D array.
        edges = np.array(g.edges, dtype=np.int64).reshape(-1, 2).T
        # Swapping the two rows gives the reverse orientation of every edge.
        edges_flip = np.flip(edges, axis=0)
        edge_index = torch.tensor(
            np.concatenate([edges, edges_flip], axis=1),
            device=self.device,
            dtype=torch.long,
        )

        # TODO: save sizes and whatnot as per mode - node/edge
        return Data(x=x1, edge_index=edge_index), num_nodes

    def preprocess_subgraphs_to_pyG_data(self):
        """Convert every corpus graph with ``create_pyG_data_object``.

        Attributes set:
            self.num_features : always 1
            self.graph_data_list : first element returned for each graph
            self.graph_size_list : second element returned for each graph

        Both lists follow the order of ``self.corpus_graphs``.
        """
        assert self.feature_type == "One"
        self.num_features = 1

        # One (data, size) pair per graph, split into two parallel lists.
        converted = [
            self.create_pyG_data_object(graph) for graph in self.corpus_graphs
        ]
        self.graph_data_list = [pair[0] for pair in converted]
        self.graph_size_list = [pair[1] for pair in converted]
    
    def fetch_subgraph_adjacency_info(self, glist):
        """Return one padded dense adjacency tensor per graph in ``glist``.

        Each adjacency matrix becomes a float32 tensor on ``self.device``,
        extended on the right and bottom by
        ``self.max_node_set_size - <current size>`` along each dimension.
        NOTE(review): for a graph larger than ``max_node_set_size`` that
        width is negative - confirm callers never pass one.

        Args:
            glist: sequence of graphs accepted by ``nx.adjacency_matrix``.

        Returns:
            List of tensors, in the order of ``glist``.
        """
        def _padded_adjacency(graph):
            dense = torch.tensor(
                nx.adjacency_matrix(graph).todense(),
                device=self.device,
                dtype=torch.float32,
            )
            n_rows, n_cols = dense.shape[0], dense.shape[1]
            side = self.max_node_set_size
            # F.pad widths are given for the last dimension first:
            # (left, right) for columns, then (top, bottom) for rows.
            return F.pad(dense, pad=(0, side - n_cols, 0, side - n_rows))

        return [_padded_adjacency(graph) for graph in glist]
            
    def create_batches(self, shuffle, input_list=None):
        """Split the items into consecutive batches of ``self.batch_size``.

        The batches are stored in ``self.batches`` and their count in
        ``self.num_batches``; the last batch may be shorter.

        Args:
            shuffle: set to True when training, False during eval (if
                batching is needed during eval). Shuffling is done on a
                copy, so neither ``input_list`` nor ``self.list_all_truth``
                is reordered.
            input_list: items to batch; defaults to ``self.list_all_truth``.

        Returns:
            The number of batches created.
        """
        list_all = self.list_all_truth if input_list is None else input_list

        if shuffle:
            # Copy before shuffling: random.shuffle works in place and would
            # otherwise reorder the list owned by the caller.
            list_all = list(list_all)
            random.shuffle(list_all)

        self.batches = [
            list_all[start : start + self.batch_size]
            for start in range(0, len(list_all), self.batch_size)
        ]
        self.num_batches = len(self.batches)

        return self.num_batches


    def _pack_batch(self, graphs): 
        """Pack a batch of graphs into a single `GraphData` instance.
        Args:
            graphs: a list of generated networkx graphs.
        Returns:
            graph_data: a `GraphData` instance, with node and edge indices properly
            shifted.
        """
        from_idx = []
        to_idx = []
        graph_idx = []

        n_total_nodes = 0
        n_total_edges = 0
        for i, g in enumerate(graphs):
            n_nodes = g.number_of_nodes()
            n_edges = g.number_of_edges()
            edges = np.array(g.edges(), dtype=np.int32)
            # shift the node indices for the edges
            from_idx.append(edges[:, 0] + n_total_nodes)
            to_idx.append(edges[:, 1] + n_total_nodes)
            graph_idx.append(np.ones(n_nodes, dtype=np.int32) * i)

            n_total_nodes += n_nodes
            n_total_edges += n_edges

        GraphData = collections.namedtuple(
            "GraphData",
            [
                "from_idx",
                "to_idx",
                "node_features",
                "edge_features",
                "graph_idx",
                "n_graphs",
            ],
        )


        return GraphData(
                    from_idx = torch.tensor(np.concatenate(from_idx, axis=0), dtype=torch.int64, device=self.device),
                    to_idx = torch.tensor(np.concatenate(to_idx, axis=0), dtype=torch.int64, device=self.device),
                    graph_idx = torch.tensor(np.concatenate(graph_idx, axis=0), dtype=torch.int64, device=self.device),
                    n_graphs = len(graphs),
                    node_features = torch.ones(n_total_nodes, 1, dtype=torch.float, device=self.device),
                    edge_features = torch.ones(n_total_edges, 1, dtype=torch.float, device=self.device)
                )

    def fetch_batched_data_by_id(self, i):
        """
        returns
        batch_corpus_graphs  : graph node, edge info
        all_sizes : this is required to create padding tensors for
                    batching variable size graphs
        batch_target    : maxclique ground truth values 
        """
        if i >= self.num_batches:
            raise IndexError('Fetched index is greater than number of batches')
        
        batch = self.batches[i]
        a, b = zip(*batch)
        corpus_idx_list = list(a)
        score = list(b)

        if self.data_type == 'gmn':
            g1 = [self.corpus_graphs[i] for i in corpus_idx_list]
        else:
            g1 = [self.graph_data_list[i] for i in corpus_idx_list]


        batch_corpus_node_size = [self.corpus_graph_node_sizes[i] for i in a]
        batch_corpus_edge_size = [self.corpus_graph_edge_sizes[i] for i in a]
        batch_corpus_adj = [self.corpus_adj_list[i] for i in a]

        if self.data_type == 'gmn':
            batch_corpus_graphs = self._pack_batch(g1)
        else:
            batch_corpus_graphs = Batch.from_data_list(g1)
    
        batch_target = torch.tensor(score, dtype=torch.float, device=self.device)
        return batch_corpus_graphs, batch_corpus_node_size, batch_corpus_edge_size, batch_target, batch_corpus_adj
    
   
    def fetch_batched_tuple_data_by_id(self, i):
        if i >= self.num_batches:
            raise IndexError('Fetched index is greater than number of batches')
        
        batch = self.batches[i]
        a, b = zip(*batch)
        idx_list = list(a)
        score = list(b)
                
        if self.data_type == 'gmn':
            g1 = [self.set1_graphs[i] for i in idx_list]
            g2 = [self.set2_graphs[i] for i in idx_list]
            g1_node_sizes = torch.tensor([self.set1_node_sizes[i] for i in idx_list], device=self.device)
            g2_node_sizes = torch.tensor([self.set2_node_sizes[i] for i in idx_list], device=self.device)
            g1_adj = [self.set1_adj_list[i] for i in idx_list]
            g2_adj = [self.set2_adj_list[i] for i in idx_list]

            all_data = self._pack_batch(
                list(itertools.chain.from_iterable(zip(g1, g2)))
            )
        else:
           raise NotImplementedError("Use GMN")
       
    
        return (
            all_data,
            list(zip(g1_node_sizes, g2_node_sizes)),
            torch.tensor(score, dtype=torch.float, device=self.device),
            list(zip(g1_adj, g2_adj))
        )
            
            
                
