"""
The code allocates Wiener Indices by pickle / pt files
"""

import torch
from torch_geometric.data import Data
import math
import copy
import pickle
import os.path
from tqdm import tqdm


def split_list(l, n):
    """Split a sliceable sequence into consecutive chunks of length ``n``.

    The last chunk may be shorter than ``n``. Chunks are produced lazily.

    Args:
        l: any sequence supporting ``len`` and slicing (list, tuple, str, ...).
        n: chunk length; must be a positive integer.

    Returns:
        A generator over the chunks, in order.

    Raises:
        ValueError: if ``n`` < 1. Without this check a zero step fails only
            lazily (on first iteration) and a negative step silently yields
            nothing.
    """
    if n < 1:
        raise ValueError("chunk size n must be >= 1, got {}".format(n))
    return (l[i:i + n] for i in range(0, len(l), n))

 ##### TOPG index system #####
class WienerBook:
    """Load cached Wiener-index allocations and attach them to graph datasets.

    Artifacts are read from / written to files named
    ``{path}{dataset_name}_k{k}_{kind}_{idx}.{ext}``:

    * ``indexes`` (.pickle) -- the key -> group dictionary kept in
      ``self.indexes``;
    * ``queries`` (.pt)     -- a tensor of per-node query keys with two
      columns, the first of which is used as the node indegree;
    * ``results`` (.pickle) -- the flat per-node group numbers.

    ``path`` is used as a plain string prefix (concatenated, not joined), so
    it should normally end with a path separator.

    Security note: the ``.pickle`` files are read with ``pickle.load``, which
    can execute arbitrary code; only load files you produced yourself.
    """

    def __init__(self, k, dataset_name, workload=1000, path='./wiener_files/'):
        self.k = k
        self.dataset_name = dataset_name
        # Not read by any method of this class; kept for API compatibility.
        self.workload = workload
        # Key -> group id. Keys are presumably (indegree, wiener index)
        # tuples -- TODO confirm; (0, 0) is reserved for isolated nodes.
        self.indexes = {(0, 0): 0}
        # True while new keys may be registered (training); see freeze/unfreeze.
        self.update = True
        self.path = path
        print("ALLOCATION: quantization parameter allocated with wiener index and indegree")

    def freeze_quantization_parameters(self):
        """Freeze the index dictionary: new keys are not registered at inference."""
        self.update = False

    def unfreeze_quantization_parameters(self):
        """Allow new keys to be registered again (training mode)."""
        self.update = True

    def set_path(self, path):
        """Change the cache prefix, e.g. per fold in 10-fold cross validation."""
        self.path = path

    # ----- file helpers -------------------------------------------------

    def _file_name(self, kind, idx, ext):
        """Return the cache file name for artifact ``kind`` of split ``idx``."""
        return '{}_k{}_{}_{}.{}'.format(self.dataset_name, self.k, kind, idx, ext)

    def _file_path(self, kind, idx, ext):
        """Return the full cache path; ``self.path`` is used as a prefix."""
        return self.path + self._file_name(kind, idx, ext)

    @staticmethod
    def _prepare_for_write(fpath):
        """Create the parent directory of ``fpath`` and announce an overwrite."""
        parent = os.path.dirname(fpath)
        if parent:
            # Saving into a not-yet-existing directory would otherwise fail.
            os.makedirs(parent, exist_ok=True)
        if os.path.isfile(fpath):
            print("There already is a file, updating into current version.")

    def _require_query_keys(self, idx):
        """Load the query-key tensor for ``idx`` or raise FileNotFoundError."""
        have_legacy, query_sets = self.load_query_keys(idx=idx)
        if not have_legacy:
            raise FileNotFoundError("QUERY KEYS NOT available in path: {} (expected {})".format(
                self.path, self._file_name('queries', idx, 'pt')))
        return query_sets

    # ----- load / save ----------------------------------------------------

    def load_wiener_indexes(self):
        """Load the training index dictionary into ``self.indexes``.

        Returns:
            True if the cache file existed and was loaded; False otherwise,
            in which case ``self.indexes`` is left unchanged.
        """
        fpath = self._file_path('indexes', 'train', 'pickle')
        if not os.path.isfile(fpath):
            return False
        with open(fpath, 'rb') as f:
            self.indexes = pickle.load(f)
        return True

    def save_wiener_indexes(self, idx):
        """Pickle ``self.indexes``; only the "train" split may be saved."""
        assert idx == "train"
        fpath = self._file_path('indexes', idx, 'pickle')
        self._prepare_for_write(fpath)
        with open(fpath, 'wb') as f:
            pickle.dump(self.indexes, f, pickle.HIGHEST_PROTOCOL)
        print("file: {} - successfully saved".format(self._file_name('indexes', idx, 'pickle')))

    def load_query_keys(self, idx):
        """Return ``(True, tensor)`` if the query-key file exists, else ``(False, None)``."""
        fpath = self._file_path('queries', idx, 'pt')
        if not os.path.isfile(fpath):
            return False, None
        return True, torch.load(fpath)

    def save_query_keys(self, query_set, idx):
        """Save the query-key tensor for split ``idx`` with ``torch.save``."""
        fpath = self._file_path('queries', idx, 'pt')
        self._prepare_for_write(fpath)
        torch.save(query_set, fpath)
        print("file: {} - successfully saved".format(self._file_name('queries', idx, 'pt')))

    def load_query_results(self, idx):
        """Return ``(True, results)`` if the results file exists, else ``(False, None)``."""
        fpath = self._file_path('results', idx, 'pickle')
        if not os.path.isfile(fpath):
            return False, None
        with open(fpath, 'rb') as f:
            return True, pickle.load(f)

    def save_query_results(self, results, idx):
        """Pickle the per-node group numbers for split ``idx``."""
        fpath = self._file_path('results', idx, 'pickle')
        self._prepare_for_write(fpath)
        with open(fpath, 'wb') as f:
            pickle.dump(results, f, pickle.HIGHEST_PROTOCOL)
        print("file: {} - successfully saved".format(self._file_name('results', idx, 'pickle')))

    # ----- allocation -----------------------------------------------------

    def allocate_indexes(self, query_sets, idx):
        """Load the cached per-node group numbers for split ``idx``.

        Used by both the graph-level and the node-level pipelines.

        Args:
            query_sets: unused; kept for API compatibility.
            idx: split suffix of the results file (e.g. "test", "whole").

        Returns:
            The unpickled results sequence.

        Raises:
            FileNotFoundError: if no results file exists for ``idx``.
        """
        print("VAL/TEST: Allocating WienerBook ... ")
        have_legacy, results = self.load_query_results(idx=idx)
        if not have_legacy:
            # Previously this fell through to ``-1 not in None`` (TypeError).
            raise FileNotFoundError("QUERY RESULTS NOT available in path: {} (expected {})".format(
                self.path, self._file_name('results', idx, 'pickle')))
        print("VAL/TEST: WienerBook allocation done!")
        # -1 presumably marks a node without an allocated group -- TODO confirm.
        assert -1 not in results, "query results contain unallocated (-1) entries"
        return results

    def load_indexes(self, graph_dataset, idx="train"):
        """Load query keys (and, for "train", cached results) for split ``idx``.

        Args:
            graph_dataset: unused; kept for API compatibility.
            idx: split suffix; must be "train" while the book is unfrozen.

        Returns:
            ``(results, query_sets, indegree)``: ``results`` is the cached
            train results (``None`` for other splits or if the train results
            file is missing); the other two are Python lists.

        Raises:
            FileNotFoundError: if the query-key file for ``idx`` is missing.
        """
        if self.update:
            assert idx == "train"
        print("Creating WienerBook queries...")
        query_sets = self._require_query_keys(idx)
        self.load_wiener_indexes()
        whole_indegree, _ = query_sets.transpose(0, 1)  # first column = indegree
        whole_results = None
        if idx == "train":
            _, whole_results = self.load_query_results(idx=idx)
        return whole_results, query_sets.tolist(), whole_indegree.tolist()

    @staticmethod
    def _attach_groups(dataset, results, indegree, progress=None):
        """Attach ``group_num`` / ``indegree`` tensors to each graph of ``dataset``.

        ``results`` and ``indegree`` are flat per-node lists laid out graph by
        graph in dataset order; each graph consumes ``g.num_nodes`` entries.
        The objects returned by ``dataset[i]`` are modified and collected.
        """
        processed_dataset = []
        offset = 0
        for i in range(len(dataset)):
            g = dataset[i]
            end = offset + g.num_nodes
            g.group_num = torch.tensor(results[offset:end])
            g.indegree = torch.tensor(indegree[offset:end])
            offset = end
            processed_dataset.append(g)
            if progress is not None:
                progress.update(1)
        # Every cached entry must have been consumed by exactly one graph.
        assert offset == len(results), "node count mismatch: {} nodes vs {} results".format(
            offset, len(results))
        return processed_dataset

    def get_wiener_processed_dataset_graph_level(self, dataset, train=True, idx='train'):
        """Return the graphs of ``dataset`` annotated with cached group numbers.

        Args:
            dataset: indexable collection of graphs exposing ``num_nodes``.
            train: must equal ``self.update`` (True only while unfrozen).
            idx: split suffix of the cache files.

        Raises:
            FileNotFoundError: if a required cache file is missing.
        """
        assert train == self.update

        results, queries, indegree = self.load_indexes(dataset, idx=idx)
        if not train:
            results = self.allocate_indexes(queries, idx=idx)
        if results is None:
            raise FileNotFoundError("QUERY RESULTS NOT available in path: {} (expected {})".format(
                self.path, self._file_name('results', idx, 'pickle')))

        print("Final Data Processing ... ")
        # Start at 0 so the bar ends at exactly len(dataset); close it on exit.
        with tqdm(total=len(dataset)) as t:
            processed_dataset = self._attach_groups(dataset, results, indegree, progress=t)
        print("WienerBook Processing : ALL DONE")
        return processed_dataset

    def load_indexes_for_node_level(self, graph_dataset, train_idx=None, idx="train"):
        """Load the "whole" query keys for node-level tasks.

        Args:
            graph_dataset: unused; kept for API compatibility.
            train_idx: unused here; per the original note it should contain
                True/False per node.
            idx: must be "train" while the book is unfrozen.

        Returns:
            ``(indegree, query_sets)`` as Python lists.

        Raises:
            FileNotFoundError: if the "whole" query-key file is missing.
        """
        if self.update:
            assert idx == "train"
        print("Creating WienerBook queries...")
        query_sets = self._require_query_keys("whole")
        self.load_wiener_indexes()
        whole_indegree, _ = query_sets.transpose(0, 1)  # first column = indegree
        return whole_indegree.tolist(), query_sets.tolist()

    def get_wiener_processed_dataset_node_level(self, dataset, train_idx=None, idx="train"):
        """Return the graphs of ``dataset`` annotated with the "whole" cached groups.

        Raises:
            FileNotFoundError: if a required cache file is missing.
        """
        assert train_idx is not None

        indegree, query_sets = self.load_indexes_for_node_level(dataset, train_idx=train_idx, idx=idx)
        results = self.allocate_indexes(query_sets, idx="whole")

        print("Final Data Processing ... ")
        processed_dataset = self._attach_groups(dataset, results, indegree)
        print("WienerBook Processing : ALL DONE")
        return processed_dataset