import sys, pathlib, time
import os, glob, json, copy
import tqdm, hashlib
#import networkx as nx
import numpy as np
from nasbenchx11.nas_bench_x11.api import load_ensemble
from nasbenchx11.naslib.predictors.utils.encodings_nlp import encode_adj
from nasbenchx11.naslib.search_spaces.nasbenchnlp.conversions import convert_recipe_to_compact, \
    convert_compact_to_recipe
from Generator import scores_to_adj, adj_to_scores
import torch
from torch_geometric.data import Data
from torch_geometric.data import DataLoader
from torch_geometric.utils import to_dense_adj, subgraph
import torch.nn.functional as F
import spektral.data
from spektral.data import Dataset
# import Settings

# from Generator import scores_to_adj, adj_to_scores

import nasbenchnlpn_node_data as nodedata
print("================================")
print(nodedata.PATH_NBNLP)
# NAS-Bench-X11 surrogate ensemble, loaded once at import time from the path
# configured in nasbenchnlpn_node_data.
nbnlp_surrogate_model = load_ensemble(nodedata.PATH_NBNLP)
# Node-type vocabulary: the list position is the integer op id stored in graph
# node features (looked up with NB_NLP_ops.index(...) / NB_NLP_ops[i]).
NB_NLP_ops = [
    'output',
    'input',
    'activation_sigm',
    'activation_tanh',
    'activation_leaky_relu',
    'elementwise_sum',
    'elementwise_prod',
    'linear',
    'blend',
]

class NASBenchNLPDataset(Dataset):
    """Spektral ``Dataset`` of NAS-Bench-NLP recurrent cells.

    ``read()`` loads the architecture list from ``nlparchjsonpath`` (JSON) and
    turns each entry into a ``spektral.data.Graph``:

    * node features: ``arch_Operation_Matrix`` (node x property, one-hot),
    * adjacency:     ``arch_Adjacency_Matrix_for_node`` (node x node, 0/1),
    * label:         surrogate-model accuracy from ``set_y_data``.

    Each entry is expected to hold a ``'recepie'`` dict mapping node names to
    ``{'op': ..., 'input': [...]}``, plus an ``'h_new_node'`` dict (inside
    ``'recepie'`` or at the top level of the entry) that maps
    ``'h_new_0'``/``'h_new_1'`` onto node names in ``global_node``.

    The helpers are static. They are invoked through the class name, e.g.
    ``NASBenchNLPDataset.map_network(...)``, and ``@staticmethod`` also makes
    them safe to call on an instance.
    """

    # Raw architecture entries; read() rebinds this on the instance.
    arch_recipce = []
    # Not referenced by the code shown here; kept for external callers.
    new_arch_recipce = []
    json_arch_recipce = []
    nlparchjsonpath = nodedata.nlparchjsonpath
    print("nlparchjsonpath == ", nlparchjsonpath)
    dataset = 'ImageNet16-120'
    # Node name -> row/column index of the node-level matrices.
    global_node = nodedata.global_node
    # Property (op) name -> column index of the operation matrix.
    propertys = nodedata.propertys
    Match_data = []
    image_data = 'cifar10_valid_converged'
    # Largest compact adjacency size seen by map_network(..., check_max=True).
    arch_max_mode = 0
    x = []
    adj = []
    y = []

    # Op vocabulary of the NAS-Bench-X11 compact encoding; the tuple position is
    # the compact op id. It lines up with NB_NLP_ops shifted by one
    # (NB_NLP_ops[k + 1] is the same op as _COMPACT_OPS[k] for k >= 1).
    _COMPACT_OPS = ('in', 'activation_sigm', 'activation_tanh', 'activation_leaky_relu',
                    'elementwise_sum', 'elementwise_prod', 'linear', 'blend')

    def __init__(self, undirected=False, **kwargs):
        """Store ``undirected`` and defer to ``spektral.data.Dataset``.

        Args:
            undirected: stored on the instance; not read by this class itself.
            **kwargs: forwarded unchanged to ``spektral.data.Dataset.__init__``.
        """
        self.undirected = undirected
        super().__init__(**kwargs)

    def read(self):
        """Load the architecture JSON and build one spektral Graph per entry.

        Returns:
            list[spektral.data.Graph]: one graph per architecture, in file order.
        """
        with open(self.nlparchjsonpath, "r") as f:
            self.arch_recipce = json.load(f)

        # First pass: map_network(check_max=True) is called only for its side
        # effect of raising NASBenchNLPDataset.arch_max_mode to the largest size.
        for arch in self.arch_recipce:
            NASBenchNLPDataset.map_network(
                NASBenchNLPDataset._recipe_without_h_new_node(arch), check_max=True)
        print('====== set dataset ======')
        print("max node", NASBenchNLPDataset.arch_max_mode)

        output = []
        for arch in tqdm.tqdm(self.arch_recipce):
            recepie = NASBenchNLPDataset._recipe_without_h_new_node(arch)
            # set_y_data runs a surrogate-model prediction; evaluate it once and
            # reuse the value for both the log line and the graph label.
            y_value = NASBenchNLPDataset.set_y_data(recepie)
            print("NASBenchNLPDataset.set_y_data(recepie)", y_value)
            output.append(
                spektral.data.Graph(
                    x=np.array(NASBenchNLPDataset.arch_Operation_Matrix(arch)),
                    a=np.array(NASBenchNLPDataset.arch_Adjacency_Matrix_for_node(arch)),
                    y=np.array(y_value)))
        return output

    @staticmethod
    def _recipe_without_h_new_node(arch):
        """Return a deep copy of ``arch['recepie']`` with its ``'h_new_node'`` key removed.

        Raises:
            KeyError: if the recipe has no ``'h_new_node'`` entry.
        """
        recepie = copy.deepcopy(arch['recepie'])
        del recepie['h_new_node']
        return recepie

    @staticmethod
    def _h_new_mapping(arch):
        """Return the ``'h_new_*'`` -> node-name mapping, preferring the one inside the recipe."""
        recepie = arch["recepie"]
        return recepie['h_new_node'] if "h_new_node" in recepie else arch["h_new_node"]

    @staticmethod
    def _resolve_node(name, h_new_node):
        """Translate ``'h_new_0'``/``'h_new_1'`` through ``h_new_node``; other names pass through."""
        return h_new_node[name] if name in ("h_new_0", "h_new_1") else name

    @staticmethod
    def arch_Adjacency_Matrix_for_node(arch):
        """Build the node x node 0/1 adjacency matrix for ``arch``.

        ``matrix[src][dst] = 1`` for every ``src`` listed in ``dst``'s ``'input'``.
        ``'h_new_0'``/``'h_new_1'`` are first renamed through the architecture's
        ``h_new_node`` mapping, and rows/columns follow ``global_node``.

        Returns:
            list[list[int]]: square matrix of size ``len(global_node)``.
        """
        recepie = arch["recepie"]
        global_node = NASBenchNLPDataset.global_node
        resolve = NASBenchNLPDataset._resolve_node
        h_new_node = NASBenchNLPDataset._h_new_mapping(arch)
        len_node = len(global_node)
        node_matrix = [[0] * len_node for _ in range(len_node)]

        for key_recepie, value_recepie in recepie.items():
            if key_recepie == "h_new_node" or "input" not in value_recepie:
                continue
            for node_input in value_recepie['input']:
                node_matrix[global_node[resolve(node_input, h_new_node)]][
                    global_node[resolve(key_recepie, h_new_node)]] = 1
        return node_matrix

    @staticmethod
    def check_compact_match(item):
        """Return True if ``item``'s compact op tuple has at most 13 entries (last index <= 12)."""
        max_nodes = 12
        data = NASBenchNLPDataset.map_network(item)
        compact = NASBenchNLPDataset.generate_nx11_compact(data)
        return len(compact[1]) - 1 <= max_nodes

    @staticmethod
    def arch_Operation_Matrix(arch):
        """Build the node x property one-hot feature matrix for ``arch``.

        * A node with an ``'op'`` gets a 1 in that op's ``propertys`` column.
        * ``'x'``, ``'h_prev_0'`` and ``'h_prev_1'`` get a 1 in the ``'input'``
          column whenever some node lists them as an input.
        * Any row still all-zero gets a 1 in the ``'without'`` column.

        Returns:
            list[list[int]]: ``len(global_node)`` rows of ``len(propertys)`` columns.
        """
        recepie = arch["recepie"]
        propertys = NASBenchNLPDataset.propertys
        global_node = NASBenchNLPDataset.global_node
        resolve = NASBenchNLPDataset._resolve_node
        h_new_node = NASBenchNLPDataset._h_new_mapping(arch)
        operation_matrix = [[0] * len(propertys) for _ in range(len(global_node))]

        for key_recepie, value_recepie in recepie.items():
            if key_recepie == "h_new_node":
                continue
            if "op" in value_recepie:
                operation_matrix[global_node[resolve(key_recepie, h_new_node)]][
                    propertys[value_recepie["op"]]] = 1
            if "input" in value_recepie:
                for base_node in ("x", "h_prev_0", "h_prev_1"):
                    if base_node in value_recepie["input"]:
                        operation_matrix[global_node[base_node]][propertys["input"]] = 1

        for row in operation_matrix:
            if 1 not in row:
                row[propertys["without"]] = 1
        return operation_matrix

    @staticmethod
    def generate_nx11_compact(item):
        """Convert a graph ``Data`` item into the NAS-Bench-X11 compact tuple.

        Returns:
            tuple: ``(edges, ops, hidden)`` where

            * ``edges`` holds the ``(src, dst)`` pairs among nodes below the
              highest node index (edges touching that node are dropped);
            * ``ops`` is four leading 0s followed by the ``_COMPACT_OPS`` index
              of every non-input/non-output node of ``item.x``, in order;
            * ``hidden`` holds the indices of nodes with an edge into the
              highest-index node.
        """
        # edges: keep only the subgraph on nodes [0, max_ind).
        max_ind = torch.max(item.edge_index)
        subset = list(range(int(max_ind)))
        edge_index_sub = subgraph(subset, item.edge_index)[0]
        edges = tuple(tuple(edge.tolist()) for edge in edge_index_sub.T)

        # hidden: predecessors of the last node in the dense adjacency.
        adj = to_dense_adj(item.edge_index)[0]
        h = tuple(np.nonzero(adj.numpy()[:, -1])[0])

        # ops
        op_names = [NB_NLP_ops[i] for i in item.x]
        o = tuple([0, 0, 0, 0] + [NASBenchNLPDataset._COMPACT_OPS.index(name)
                                  for name in op_names if name not in ('output', 'input')])
        return (edges, o, h)

    @staticmethod
    def _compact_adjacency(item):
        """Build the dense adjacency of ``item``'s compact form.

        Nodes ``0..len(op)-1`` come from ``convert_recipe_to_compact``. An extra
        last node, index ``len(op)``, receives an edge from every hidden node.

        Returns:
            tuple: ``(adj, op)``, where ``adj`` is a ``(len(op)+1)``-square float
            tensor and ``op`` is the compact op sequence.
        """
        edges, op, hidden = convert_recipe_to_compact(item)
        output_idx = len(op) + 1
        adj = torch.zeros(output_idx, output_idx)
        for edge in edges:
            adj[edge[0], edge[1]] = 1
        for edge in hidden:
            adj[edge, output_idx - 1] = 1
        return adj, op

    @staticmethod
    def check_map_network_node(item):
        """Return True if ``item``'s compact graph (incl. output node) has at most 9 nodes."""
        adj, _ = NASBenchNLPDataset._compact_adjacency(item)
        return len(adj) <= 9

    @staticmethod
    def check_map_network_node9(item):
        """Return True if ``item``'s compact graph (incl. output node) has at most 9 nodes.

        NOTE(review): identical to ``check_map_network_node``, while
        ``check_map_network_node10`` tests for exactly 10 nodes. ``== 9`` may
        have been intended; confirm before relying on it.
        """
        adj, _ = NASBenchNLPDataset._compact_adjacency(item)
        return len(adj) <= 9

    @staticmethod
    def check_map_network_node10(item):
        """Return True if ``item``'s compact graph (incl. output node) has exactly 10 nodes."""
        adj, _ = NASBenchNLPDataset._compact_adjacency(item)
        return len(adj) == 10

    @staticmethod
    def show_map_network_node(item):
        """Return the node count of ``item``'s compact graph, including the output node."""
        adj, _ = NASBenchNLPDataset._compact_adjacency(item)
        return len(adj)

    @staticmethod
    def check_map_network_node_find(item):
        """Return True if ``item``'s compact graph (incl. output node) has at least 9 nodes."""
        adj, _ = NASBenchNLPDataset._compact_adjacency(item)
        return len(adj) >= 9

    @staticmethod
    def check_map_network_node_find_num(item, num):
        """Return True if ``item``'s compact graph (incl. output node) has exactly ``num`` nodes."""
        adj, _ = NASBenchNLPDataset._compact_adjacency(item)
        return len(adj) == num

    @staticmethod
    def check_map_network_node_find_more_num(item, num):
        """Return True if ``item``'s compact graph (incl. output node) has at least ``num`` nodes."""
        adj, _ = NASBenchNLPDataset._compact_adjacency(item)
        return len(adj) >= num

    @staticmethod
    def check_map_network_node_find_less_number(item, num):
        """Return True if ``item``'s compact graph (incl. output node) has at most ``num`` nodes."""
        adj, _ = NASBenchNLPDataset._compact_adjacency(item)
        return len(adj) <= num

    @staticmethod
    def remove_redundant_nodes(recepie):
        """Drop recipe nodes that do not feed ``'h_new_0'``/``'h_new_1'``. Mutates ``recepie``.

        The search runs breadth-first backwards along ``'input'`` lists,
        starting from the two new hidden states.

        Returns:
            set[str]: every node name reached, including names not present as
            recipe keys (e.g. ``'x'``, ``'h_prev_*'``).
        """
        from collections import deque

        queue = deque(f'h_new_{i}' for i in range(2))
        visited = set(queue)
        while queue:
            node = queue.popleft()
            if node in recepie:
                for parent in recepie[node]['input']:
                    if parent not in visited:
                        queue.append(parent)
                        visited.add(parent)

        for key in list(recepie):
            if key not in visited:
                del recepie[key]
        return visited

    @staticmethod
    def check_pass_arch(recepie):
        """Prune ``recepie`` in place and check the cell's structural constraints.

        Returns True only if both hold:

        * ``'x'``, ``'h_prev_0'`` and ``'h_prev_1'`` all reach the new hidden states;
        * neither ``'h_new_0'`` nor ``'h_new_1'`` takes a ``'h_prev_*'`` node as
          a direct input.
        """
        visited = NASBenchNLPDataset.remove_redundant_nodes(recepie)
        prev_hidden_nodes = {f'h_prev_{i}' for i in range(2)}
        base_nodes = ['x', 'h_prev_0', 'h_prev_1']
        inputs_reachable = all(node in visited for node in base_nodes)
        hidden_not_direct = not any(
            set(recepie[f'h_new_{i}']['input']) & prev_hidden_nodes for i in range(2))
        return inputs_reachable and hidden_not_direct

    @staticmethod
    def map_network(item, check_max=False):
        """Convert a recipe into a ``torch_geometric`` ``Data`` graph.

        Args:
            item: architecture recipe accepted by ``convert_recipe_to_compact``.
            check_max: when truthy, raise ``NASBenchNLPDataset.arch_max_mode`` to
                this graph's node count if it is larger.

        Returns:
            Data: ``edge_index``; ``x`` (``NB_NLP_ops`` ids per node);
            ``x_binary`` (one-hot of ``x`` brought to 26 entries); ``num_nodes``;
            ``scores`` (from ``adj_to_scores``); and ``y``, the concatenation of
            the flattened ``x_binary`` and ``scores``.
        """
        adj, op = NASBenchNLPDataset._compact_adjacency(item)
        edge_index = torch.nonzero(adj).T
        output_op = NB_NLP_ops.index('output')
        # Compact op ids are shifted by one so they index NB_NLP_ops
        # (compact 'in' == 0 -> NB_NLP_ops 'input' == 1); the extra node is 'output'.
        node_attr = torch.tensor([*[i + 1 for i in op], output_op])
        num_nodes = len(node_attr)

        # Bring the label vector to a fixed 26 entries, filling with the 'output' id.
        y_nodes = node_attr
        if num_nodes != 26:
            y_nodes = F.pad(y_nodes, pad=(0, 26 - num_nodes), value=output_op)
        x_binary = F.one_hot(y_nodes, num_classes=len(NB_NLP_ops))

        # NOTE(review): 325 == 26 * 25 / 2, presumably the number of pairwise
        # entries for the 26-node layout; confirm against adj_to_scores.
        scores = adj_to_scores(adj, l=325)
        y = torch.cat((x_binary.reshape(-1).float(), scores.float()))
        if check_max and len(adj) > NASBenchNLPDataset.arch_max_mode:
            NASBenchNLPDataset.arch_max_mode = len(adj)
        return Data(edge_index=edge_index.long(), x=node_attr, x_binary=x_binary, num_nodes=num_nodes, y=y,
                    scores=scores)

    @staticmethod
    def get_info_generated_graph(item, dataset=None):
        """Predict the validation accuracy of a graph with the NAS-Bench-X11 surrogate.

        Side effect: sets ``item.val_acc``.

        Args:
            item: ``Data`` graph as produced by ``map_network``.
            dataset: accepted for compatibility; not used.

        Returns:
            torch.FloatTensor: one-element tensor holding the last point of the
            predicted learning curve divided by 100. Returns ``0`` unchanged for
            items that already carry an ``acc`` attribute.
        """
        max_nodes = 12
        if hasattr(item, "acc"):
            # NOTE(review): already-evaluated items yield 0 rather than their
            # stored accuracy; confirm this is intended.
            return 0

        compact = NASBenchNLPDataset.generate_nx11_compact(item)
        arch = encode_adj(compact=compact, max_nodes=max_nodes, one_hot=False, accs=None)
        try:
            learning_curve = nbnlp_surrogate_model.predict(config=arch, representation='compact',
                                                           with_noise=False, search_space='nlp')
        except ValueError:
            # The surrogate rejected the plain encoding: re-encode with a fixed
            # placeholder learning curve (val loss 6.5 -> acc 93.5) and retry.
            val_losses = [6.5, 6.5, 6.5]
            accs = [100 - loss for loss in val_losses]
            arch = encode_adj(compact=compact, max_nodes=max_nodes, one_hot=False, accs=accs)
            learning_curve = nbnlp_surrogate_model.predict(config=arch, representation='compact',
                                                           with_noise=False, search_space='nlp')
        val_acc = torch.FloatTensor([learning_curve[-1] / 100.0])
        item.val_acc = val_acc
        return val_acc

    @staticmethod
    def convert_compact_to_recipe(compact):
        """Rebuild a recipe dict from an ``(edges, ops, hiddens)`` compact tuple.

        * Node indices 0-3 are ``'x'``, ``'h_prev_0'``, ``'h_prev_1'`` and
          ``'h_prev_2'``.
        * Higher indices become ``'h_new_<k>'`` when listed in ``hiddens``, and
          ``'node_<k>'`` otherwise.
        * The returned dict also carries ``'h_new_node'``, which assigns
          ``'h_new_0'``/``'h_new_1'`` consecutive ``'node_<k>'`` names following
          the last regular node.

        Raises:
            ValueError: if ``edges`` is empty.
        """
        nodes = ['x', 'h_prev_0', 'h_prev_1', 'h_prev_2']
        edges, ops, hiddens = compact
        max_node_idx = max(max(edge) for edge in edges)

        # Name every non-base node.
        reg_node_idx = 0
        hidden_node_idx = 0
        for i in range(len(nodes), max_node_idx + 1):
            if i not in hiddens:
                nodes.append(f'node_{reg_node_idx}')
                reg_node_idx += 1
            else:
                nodes.append(f'h_new_{hidden_node_idx}')
                hidden_node_idx += 1

        recipe = {}
        for i in range(4, len(nodes)):
            recipe[nodes[i]] = {
                'op': NASBenchNLPDataset._COMPACT_OPS[ops[i]],
                'input': [nodes[src] for src, dst in edges if dst == i],
            }

        h_new_node = {}
        next_idx = reg_node_idx
        for key in list(recipe):
            if key in ("h_new_0", "h_new_1"):
                h_new_node[key] = f'node_{next_idx}'
                next_idx += 1
        recipe["h_new_node"] = h_new_node
        return recipe

    @staticmethod
    def change_to_recepie(adj, ops, hash_version=2, retrain=False):
        """Convert node-level ``(adj, ops)`` matrices back into an architecture dict.

        Args:
            adj: node x node matrix; ``adj[i][j]`` truthy means node ``i`` feeds node ``j``.
            ops: node x property one-hot rows, or, when ``retrain`` is True, a
                sequence of property indices that is one-hot encoded first.
            hash_version: accepted for compatibility; not used.
            retrain: see ``ops``.

        Returns:
            dict: ``{'arch_id': ..., 'recepie': ..., 'h_new_node': ...}``. The
            last two rows (index >= 3) that carry a one-hot op whose final
            column is not set are renamed ``'h_new_1'`` then ``'h_new_0'``, and
            ``'h_new_node'`` maps those names back to the original node names.
        """
        node_names = list(nodedata.global_node.keys())
        prop_names = list(nodedata.propertys.keys())
        if retrain:
            n_props = len(nodedata.propertys)
            ops = [[1 if op_id == j else 0 for j in range(n_props)] for op_id in ops]

        # Scan from the last row down (stopping before row 3) and label the last
        # two op-bearing rows as the new hidden states. The final column is
        # presumably the 'without' property; confirm against nodedata.propertys.
        check_output_node = {}
        hnode = 1
        for i in range(len(ops) - 1, 2, -1):
            if 1 in ops[i] and ops[i][-1] != 1:
                check_output_node[node_names[i]] = "h_new_" + str(hnode)
                hnode -= 1
            if hnode < 0:
                break

        # Collect inputs column by column.
        matrix_to_dic = {}
        for i, row in enumerate(adj):
            if 1 not in row:
                continue
            for j, linked in enumerate(row):
                if linked:
                    matrix_to_dic.setdefault(node_names[j], {"input": []})["input"].append(node_names[i])

        # Attach ops.
        for i in range(3, len(ops)):
            row_ops = list(ops[i])
            if 1 in row_ops and row_ops[-1] != 1:
                matrix_to_dic.setdefault(node_names[i], {})["op"] = prop_names[row_ops.index(1)]

        # Rename the hidden-state nodes both as keys and inside input lists.
        new_matrix_to_dic = {}
        for key, value in matrix_to_dic.items():
            if "input" in value:
                value["input"] = [check_output_node.get(name, name) for name in value["input"]]
            new_matrix_to_dic[check_output_node.get(key, key)] = value

        h_new_node = {value: key for key, value in check_output_node.items()}
        arch = {}
        # NOTE(review): built-in hash() of a str is salted per process
        # (PYTHONHASHSEED), so arch_id is not stable across runs.
        arch["arch_id"] = hash(str(new_matrix_to_dic.items()))
        arch["recepie"] = new_matrix_to_dic
        arch["h_new_node"] = h_new_node
        return arch

    @staticmethod
    def set_y_data(item):
        """Return the surrogate-predicted accuracy for recipe ``item`` (see get_info_generated_graph)."""
        return NASBenchNLPDataset.get_info_generated_graph(NASBenchNLPDataset.map_network(item), "image_data")

if __name__ == "__main__":
    # Script entry: constructing the dataset with the default (directed)
    # setting; presumably spektral's Dataset.__init__ then calls read().
    NBND = NASBenchNLPDataset(undirected=False)
