import sys, pathlib
import os, json, tqdm
import torch
import numpy as np 
from torch_geometric.utils import to_dense_adj
import spektral.data
from spektral.data import Dataset
import torch.nn.functional as F

if __name__ == "__main__":
    # When this file is executed directly there is no parent package, so put
    # the package's parent directory on sys.path and register the package name
    # in __package__ so that package-relative imports can be resolved.
    if not __package__:
        DIR = pathlib.Path(__file__).resolve().parent.parent
        print(DIR)
        sys.path.insert(0, str(DIR.parent))
        __package__ = DIR.name


##############################################################################
#
#                              Dataset Code
#
##############################################################################

##############################################################################
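# JSON STRUCTURE (key names exactly as they appear in the data, including the
# misspelled 'optimized_hyperparamater_config' key that map_network_cell reads)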
##############################################################################
# ---final_metric_score
# ---optimized_hyperparamater_config
# ------ImageAugmentation:augment
# ------ImageAugmentation:autoaugment
# ------ImageAugmentation:cutout
# ------ImageAugmentation:cutout_holes
# ------ImageAugmentation:fastautoaugment
# ------NetworkSelectorDatasetInfo:darts:drop_path_prob
# ------NetworkSelectorDatasetInfo:network
# ------CreateImageDataLoader:batch_size
# ------ImageAugmentation:cutout_length
# ------LossModuleSelectorIndices:loss_module
# ------NetworkSelectorDatasetInfo:darts:auxiliary
# ------NetworkSelectorDatasetInfo:darts:edge_normal_0
# ------NetworkSelectorDatasetInfo:darts:edge_normal_1
# ------NetworkSelectorDatasetInfo:darts:edge_reduce_0
# ------NetworkSelectorDatasetInfo:darts:edge_reduce_1
# ------NetworkSelectorDatasetInfo:darts:init_channels
# ------NetworkSelectorDatasetInfo:darts:inputs_node_normal_3
# ------NetworkSelectorDatasetInfo:darts:inputs_node_normal_4
# ------NetworkSelectorDatasetInfo:darts:inputs_node_normal_5
# ------NetworkSelectorDatasetInfo:darts:inputs_node_reduce_3
# ------NetworkSelectorDatasetInfo:darts:inputs_node_reduce_4
# ------NetworkSelectorDatasetInfo:darts:inputs_node_reduce_5
# ------NetworkSelectorDatasetInfo:darts:layers
# ------OptimizerSelector:optimizer
# ------OptimizerSelector:sgd:learning_rate
# ------OptimizerSelector:sgd:momentum
# ------OptimizerSelector:sgd:weight_decay
# ------SimpleLearningrateSchedulerSelector:cosine_annealing:T_max
# ------SimpleLearningrateSchedulerSelector:cosine_annealing:eta_min
# ------SimpleLearningrateSchedulerSelector:lr_scheduler
# ------SimpleTrainNode:batch_loss_computation_technique
# ------SimpleTrainNode:mixup:alpha
# ------NetworkSelectorDatasetInfo:darts:edge_normal_11
# ------NetworkSelectorDatasetInfo:darts:edge_normal_13
# ------NetworkSelectorDatasetInfo:darts:edge_normal_2
# ------NetworkSelectorDatasetInfo:darts:edge_normal_3
# ------NetworkSelectorDatasetInfo:darts:edge_normal_5
# ------NetworkSelectorDatasetInfo:darts:edge_normal_6
# ------NetworkSelectorDatasetInfo:darts:edge_reduce_12
# ------NetworkSelectorDatasetInfo:darts:edge_reduce_13
# ------NetworkSelectorDatasetInfo:darts:edge_reduce_3
# ------NetworkSelectorDatasetInfo:darts:edge_reduce_4
# ------NetworkSelectorDatasetInfo:darts:edge_reduce_6
# ------NetworkSelectorDatasetInfo:darts:edge_reduce_7
# ---budget
# ---info
# ------train_loss
# ------train_accuracy
# ------train_cross_entropy
# ------val_accuracy
# ------val_cross_entropy
# ------epochs
# ------model_parameters
# ------learning_rate
# ------train_datapoints
# ------val_datapoints
# ------dataset_path
# ------dataset_id
# ------train_loss_final
# ------train_accuracy_final
# ------train_cross_entropy_final
# ------val_accuracy_final
# ------val_cross_entropy_final
# ---test_accuracy
# ---runtime
# ---optimizer
# ---learning_curves
# ------Train/train_loss
# ------Train/train_accuracy
# ------Train/train_cross_entropy
# ------Train/val_accuracy
# ------Train/val_cross_entropy

##############################################################################


# Edge operation names (without the structural node labels that
# OP_PRIMITIVES_NB301 adds in front of them).
OP_PRIMITIVES = [
    'identity',
    'max_pool_3x3',
    'avg_pool_3x3',
    'skip_connect',
    'sep_conv_3x3',
    'sep_conv_5x5',
    'dil_conv_3x3',
    'dil_conv_5x5',
]

# Node labels of the op-on-node cell graphs; a label's position in this list is
# its node-feature index. The first three are structural labels used by both
# the normal and the reduction cell, the rest are edge operations.
OP_PRIMITIVES_NB301 = [
    'output',        # 0: cell output (graph node 10 of each cell)
    'input',         # 1: first cell input (graph node 0 of each cell)
    'input_1',       # 2: second cell input, c_{k-1} (graph node 1 of each cell)
    'identity',      # 3
    'max_pool_3x3',  # 4
    'avg_pool_3x3',  # 5
    'skip_connect',  # 6
    'sep_conv_3x3',  # 7
    'sep_conv_5x5',  # 8
    'dil_conv_3x3',  # 9
    'dil_conv_5x5',  # 10
]

# One-hot vectors over the 8 OP_PRIMITIVES, keyed by position and by name.
OP_ONEHOT = {idx: np.eye(8)[idx] for idx in range(8)}
OP_ONEHOT_BY_PRIMITIVE = {name: OP_ONEHOT[idx] for idx, name in enumerate(OP_PRIMITIVES)}

def convert_matrix_ops_to_graph(matrix, ops):
    """Build a labelled spektral ``Graph`` from an adjacency matrix and node ops.

    Args:
        matrix: square dense adjacency matrix; its first dimension gives the
            number of nodes.
        ops: per-node ``OP_PRIMITIVES_NB301`` indices; node ``i`` gets a 1 in
            feature column ``ops[i]``.

    Returns:
        ``spektral.data.Graph`` with one-hot node features ``x`` (float, one
        column per ``OP_PRIMITIVES_NB301`` entry), float adjacency ``a`` and
        label ``y`` set to ``NasBench301Dataset.get_prediction(a, x)``.
    """
    n_nodes = matrix.shape[0]

    features = np.zeros((n_nodes, len(OP_PRIMITIVES_NB301)), dtype=float)
    for node in range(n_nodes):
        features[node][ops[node]] = 1

    adjacency = np.array(matrix).astype(float)
    label = NasBench301Dataset.get_prediction(adjacency, features)

    return spektral.data.Graph(x=features, a=adjacency, y=np.array([label]))

def sort_edge_index(edge_index):
    """Return ``edge_index`` with its columns ordered by (target, source).

    Args:
        edge_index: ``2 x E`` integer array; row 0 holds source node ids and
            row 1 target node ids.

    Returns:
        A ``2 x E`` array with the same columns, sorted primarily by target and
        secondarily by source. An empty (``2 x 0``) edge index is returned
        as an empty ``2 x 0`` array instead of raising.
    """
    # np.lexsort uses the LAST key as the primary key: target, then source.
    # Unlike encoding (target * num_nodes + source), it needs no np.max over
    # the edges, which raised ValueError on an empty edge index.
    perm = np.lexsort((edge_index[0], edge_index[1]))
    return edge_index[:, perm]

# Candidate DARTS edges of a 7-node cell: nodes 0/1 are the cell inputs, 2-5
# the intermediate nodes and 6 the output. Searchable edges go from any
# lower-numbered node into an intermediate node.
adj = 1 - np.tri(7)   # strict upper triangle: source < target
adj[:, 1] = 0         # input node 1 receives no edges
adj[:, 6] = 0         # output node 6 is not a searchable edge target
EDGE_LIST_ALL = sort_edge_index(np.array(np.nonzero(adj)))

# L[i] is the (source, target) pair of DARTS edge i -- ordered by target, then
# source, which is the numbering of the `edge_{cell}_{i}` config keys.
# L_inverse maps a pair back to its edge number.
L = [tuple(edge) for edge in EDGE_LIST_ALL.T]
L_inverse = {edge: pos for pos, edge in enumerate(L)}


# # Load NB301 Surrogate Model
# import os
from collections import namedtuple

import nasbench301 as nb

version = '0.9'

current_dir = os.path.dirname(os.path.abspath(__file__))
models_0_9_dir = os.path.join(current_dir, 'nb_models_0.9')
model_paths_0_9 = {
    model_name: os.path.join(models_0_9_dir, '{}_v0.9'.format(model_name))
    for model_name in ['xgb', 'gnn_gin', 'lgb_runtime']
}
models_1_0_dir = os.path.join(current_dir, 'nb_models_1.0')
model_paths_1_0 = {
    model_name: os.path.join(models_1_0_dir, '{}_v1.0'.format(model_name))
    for model_name in ['xgb', 'gnn_gin', 'lgb_runtime']
}
model_paths = model_paths_0_9 if version == '0.9' else model_paths_1_0

# If any expected model path is missing, download the models into current_dir.
# Note: If you would like to provide your own model locations, comment this out
if not all(os.path.exists(model) for model in model_paths.values()):
    nb.download_models(version=version, delete_zip=True,
                       download_dir=current_dir)

# Load the performance surrogate model
#NOTE: Loading the ensemble will set the seed to the same as used during training (logged in the model_configs.json)
#NOTE: Defaults to using the default model download path
print("==> Loading performance surrogate model...")
ensemble_dir_performance = model_paths['xgb']
print(ensemble_dir_performance)
from pathlib import Path
# rglob yields files in filesystem-dependent order; sort so every machine picks
# the same member directory below.
ensemble_member_dirs = sorted(
    os.path.dirname(filename)
    for filename in Path(ensemble_dir_performance).rglob('*surrogate_model.model')
)
if not ensemble_member_dirs:
    raise FileNotFoundError(
        "No '*surrogate_model.model' file found under {}".format(ensemble_dir_performance)
    )
# NOTE(review): only the first member directory is passed to load_ensemble, so
# presumably a single ensemble member is used -- confirm this is intended.
performance_model = nb.load_ensemble(ensemble_member_dirs[0])

MODULE_DIR = os.path.dirname(os.path.abspath(__file__))
DEFAULT_DATA_DIR = os.path.join(MODULE_DIR, "nasbench301")

class NasBench301Dataset(Dataset):
    """spektral ``Dataset`` of NAS-Bench-301 (DARTS search space) architectures.

    Every architecture becomes one 22-node graph. Each cell (normal first, then
    reduction) is an 11-node op-on-node graph built by
    ``map_network_cell_like``; the reduction cell's node ids are shifted by 11.
    Node features are one-hot over ``OP_PRIMITIVES_NB301``. The label ``y`` is
    the surrogate model's prediction divided by 100.
    """

    # Genotype record handed to the surrogate model.
    _GENOTYPE = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

    ##########################################################################
    def __init__(self, undirected = False, **kwargs):
        """
        Args:
            undirected: if True, every graph's adjacency matrix is symmetrised
                (after its label has been computed).
            **kwargs: forwarded to spektral's ``Dataset``.
        """
        # Set before super().__init__(): read() uses it, and the spektral base
        # constructor presumably calls read() -- confirm against spektral.
        self.undirected = undirected
        super().__init__(**kwargs)

    def read(self):
        """Convert every config listed in ``normal_cell_topologies.json``.

        The file lives in ``DEFAULT_DATA_DIR``. It is a JSON object whose
        values are lists of architecture configs; the lists are flattened in
        order.

        Returns:
            list of spektral ``Graph`` objects, one per config.
        """
        topo_file = os.path.join(DEFAULT_DATA_DIR, "normal_cell_topologies.json")
        with open(topo_file, "r") as handle:
            topologies = json.load(handle)
        configs = [item for group in topologies.values() for item in group]
        return [NasBench301Dataset.map_network(item, self.undirected)
                for item in tqdm.tqdm(configs)]

    ##########################################################################
    @staticmethod
    def map_network_cell(item, cell="normal"):
        """Extract one DARTS cell of a NAS-Bench-301 config as an edge list.

        Args:
            item: config dict, or a dict wrapping it under
                ``"optimized_hyperparamater_config"``. Each key
                ``NetworkSelectorDatasetInfo:darts:edge_{cell}_{i}`` that is
                present names the op on DARTS edge ``i`` (see ``L``).
            cell: ``"normal"`` or ``"reduce"``.

        Returns:
            (edge_index, edge_attr): ``2 x E`` tensor of (source, target) DARTS
            node ids, and the ``OP_PRIMITIVES_NB301`` index of each edge's op.
            Intermediate nodes 2-5 are additionally wired to node 6 with
            ``identity`` edges.
        """
        if "optimized_hyperparamater_config" in item:
            item = item["optimized_hyperparamater_config"]

        edge_u = []
        edge_v = []
        edge_attr = []
        for i, (u, v) in enumerate(L):
            key = f"NetworkSelectorDatasetInfo:darts:edge_{cell}_{i}"
            if key in item:
                edge_u.append(u)
                edge_v.append(v)
                edge_attr.append(OP_PRIMITIVES_NB301.index(item[key]))

        # Every intermediate node (2-5) feeds the cell output (node 6).
        edge_u += [2, 3, 4, 5]
        edge_v += [6] * 4
        edge_attr += [OP_PRIMITIVES_NB301.index("identity")] * 4

        return torch.tensor([edge_u, edge_v]), torch.tensor(edge_attr)

    ##########################################################################
    @staticmethod
    def map_network_cell_like(normal_cell):
        """Turn a DARTS cell edge list into an 11-node op-on-node graph.

        Graph node layout:
            0, 1: the two cell inputs;
            2..9: one node per DARTS edge -- intermediate DARTS node k (2..5)
                owns op nodes 2k-2 and 2k-1, for its lower- and higher-numbered
                input respectively;
            10: the cell output.

        Args:
            normal_cell: ``(edge_index, edge_attr)`` as returned by
                ``map_network_cell`` (used for both normal and reduce cells).
                Each DARTS node 3-5 must have at least two incoming edges.

        Returns:
            (edge_index, node_attr): ``2 x E`` edge index of the 11-node graph
            and the ``OP_PRIMITIVES_NB301`` index of each node.
        """
        edge_index = normal_cell[0]
        adj = torch.zeros(11, 11)

        def op_nodes(src):
            # Graph node(s) carrying DARTS node `src`'s output: the inputs map
            # to themselves, intermediate node s to its op nodes 2s-2, 2s-1
            # (the pairing transform_node_atts_to_darts_cell decodes).
            return [src] if src < 2 else [2 * src - 2, 2 * src - 1]

        # DARTS node 2 can only take inputs 0 and 1 (see L), so its op nodes
        # 2 and 3 are fed by them unconditionally.
        adj[0][2] = 1
        adj[1][3] = 1
        # Every op node feeds the cell output.
        adj[2:10, 10] = 1

        # DARTS nodes 3..5: connect both chosen inputs to the node's op nodes.
        for target in (3, 4, 5):
            sources = [edge_index[0][i].item()
                       for i in range(len(edge_index[0]))
                       if edge_index[1][i] == target]
            for slot, src in enumerate((sources[0], sources[1])):
                for node in op_nodes(src):
                    adj[node][2 * target - 2 + slot] = 1

        cell_edge_index = torch.nonzero(adj, as_tuple=False).T

        node_attr = [OP_PRIMITIVES_NB301.index('input'), OP_PRIMITIVES_NB301.index('input_1')]
        node_attr.extend(normal_cell[1][:8])
        node_attr.append(OP_PRIMITIVES_NB301.index('output'))

        return cell_edge_index, torch.tensor(node_attr)

    @staticmethod
    def merge_normal_reduced(node_attr_normal, node_attr_reduce, edge_index_normal, edge_index_reduce):
        """Concatenate the normal- and reduce-cell graphs into one graph.

        Reduce-cell node ids are shifted by 11 (the size of a cell graph from
        ``map_network_cell_like``). The input tensors are not modified.

        Returns:
            (node_attr, edge_index) of the merged graph.
        """
        node_attr = torch.cat((node_attr_normal, node_attr_reduce))
        edge_index = torch.cat((edge_index_normal, edge_index_reduce + 11), 1)
        return node_attr, edge_index

    @staticmethod
    def map_network(item, undirected):
        """Convert one NAS-Bench-301 config into a labelled spektral ``Graph``."""
        edge_index_normal, edge_attr_normal = NasBench301Dataset.map_network_cell(item, cell="normal")
        edge_index_reduce, edge_attr_reduce = NasBench301Dataset.map_network_cell(item, cell="reduce")
        edge_index_normal, node_attr_normal = NasBench301Dataset.map_network_cell_like((edge_index_normal, edge_attr_normal))
        edge_index_reduce, node_attr_reduce = NasBench301Dataset.map_network_cell_like((edge_index_reduce, edge_attr_reduce))

        x, edge_index = NasBench301Dataset.merge_normal_reduced(node_attr_normal, node_attr_reduce, edge_index_normal, edge_index_reduce)
        adj = to_dense_adj(edge_index)
        # to_dense_adj returns a batch dimension; take the single graph.
        graph = convert_matrix_ops_to_graph(adj.numpy()[0], x)
        if undirected:
            # Symmetrise; subtracting the diagonal avoids doubling self-loops.
            graph.a = graph.a + graph.a.T - np.diag(graph.a.diagonal())

        return graph

    ##########################################################################
    @staticmethod
    def transform_node_atts_to_darts_cell(matrix):
        """Collapse an 11x11 op-on-node cell adjacency into a 7x7 DARTS one.

        Inverse of the layout built by ``map_network_cell_like``. Graph nodes
        0, 1 and 10 map to DARTS nodes 0, 1 and 6, and op nodes (2k-2, 2k-1)
        map to DARTS node k. Entries of merged pairs are summed, so values can
        exceed 1; callers clip them back to 1.
        """
        def collapse(vec):
            # Merge each op-node pair of an 11-entry vector: 11 -> 7 entries.
            return [vec[0], vec[1]] + [vec[k] + vec[k + 1] for k in range(2, 10, 2)] + [vec[-1]]

        darts_adj = np.zeros((7, 7))
        # Columns of the two input nodes are copied without row merging.
        darts_adj[:, 0] = matrix[:7, 0]
        darts_adj[:, 1] = matrix[:7, 1]
        # Intermediate DARTS nodes 2..5 <- op-node column pairs (2,3)..(8,9).
        for darts_col, col in enumerate(range(2, 10, 2), start=2):
            darts_adj[:, darts_col] = collapse(np.sum(matrix[:, col:col + 2], 1))
        # DARTS output node 6 <- graph output node 10.
        darts_adj[:, -1] = collapse(matrix[:, -1])

        return darts_adj

    ##########################################################################
    @staticmethod
    def generate_genotype(matrix, ops):
        """Build the DARTS genotype list of one cell.

        Args:
            matrix: 7x7 DARTS adjacency clipped to 0/1; column k (2..5) marks
                the inputs of intermediate node k.
            ops: the 8 op names of graph op nodes 2..9, i.e. two per
                intermediate node, lower-numbered input first.

        Returns:
            list of 8 ``(op_name, input_node)`` tuples, two per node 2..5.
        """
        cell = []
        for k, node in enumerate(range(2, 6)):
            inputs = np.where(matrix[:, node] == 1)[0]
            # Node `node` owns ops[2k] and ops[2k+1] (two ops per node).
            cell.append((ops[2 * k], int(inputs[0])))
            cell.append((ops[2 * k + 1], int(inputs[1])))
        return cell

    @staticmethod
    def _decode_cells(adj, ops):
        """Decode a merged 22-node graph into (normal_cell, reduce_cell) genotypes.

        Args:
            adj: dense adjacency; nodes 0-10 are the normal cell, 11-21 the
                reduce cell. Only the upper triangle is used, so symmetrised
                matrices are accepted.
            ops: either a 2-D per-node score/one-hot matrix (argmax per row is
                taken) or a 1-D sequence of ``OP_PRIMITIVES_NB301`` indices.
        """
        adj = np.triu(adj)
        ops = np.asarray(ops)
        if ops.ndim == 2:
            ops = np.argmax(ops, axis=1)

        cells = []
        for cell_adj, cell_ops in ((adj[:11, :11], ops[:11]), (adj[11:, 11:], ops[11:])):
            op_names = [OP_PRIMITIVES_NB301[op] for op in cell_ops]
            darts_adj = NasBench301Dataset.transform_node_atts_to_darts_cell(cell_adj)
            darts_adj[darts_adj >= 2] = 1
            # Op nodes 2..9 carry the edge operations.
            cells.append(NasBench301Dataset.generate_genotype(darts_adj, op_names[2:10]))

        normal_cell, reduce_cell = cells
        return normal_cell, reduce_cell

    @staticmethod
    def get_prediction(adj, ops):
        """Surrogate-predicted accuracy (prediction / 100) for a merged graph.

        See ``_decode_cells`` for the expected ``adj`` / ``ops`` formats.
        """
        normal_cell, reduce_cell = NasBench301Dataset._decode_cells(adj, ops)
        return NasBench301Dataset.get_prediction_by_cell(normal_cell, reduce_cell)

    @staticmethod
    def get_prediction_by_cell(normal_cell, reduce_cell):
        """Surrogate-predicted accuracy (prediction / 100) for two genotype cells.

        Errors raised by the surrogate model propagate to the caller.
        """
        genotype_config = NasBench301Dataset._GENOTYPE(
            normal=normal_cell, normal_concat=[2, 3, 4, 5],
            reduce=reduce_cell, reduce_concat=[2, 3, 4, 5])
        prediction_genotype = performance_model.predict(
            config=genotype_config, representation="genotype", with_noise=False)
        return prediction_genotype / 100.0

    @staticmethod
    def get_hash_by_darts_cell(adj, ops, hash_version = 2):
        """Hash of the decoded genotype; see ``get_hash_and_cell``."""
        return NasBench301Dataset.get_hash_and_cell(adj, ops, hash_version)[0]

    @staticmethod
    def get_hash_and_cell(adj, ops, hash_version = 2):
        """Decode a merged graph and hash its genotype.

        Args:
            adj, ops: see ``_decode_cells``.
            hash_version: 0 hashes both cells, 1 only the normal cell, any
                other value only the reduce cell.

        Returns:
            (hash_value, normal_cell, reduced_cell). Python's ``hash`` of
            strings is salted per process (PYTHONHASHSEED), so the value is
            only comparable within one run.
        """
        normal_cell, reduced_cell = NasBench301Dataset._decode_cells(adj, ops)
        if hash_version == 0:
            immutable_data = tuple(normal_cell) + tuple(reduced_cell)
        elif hash_version == 1:
            immutable_data = tuple(normal_cell)
        else:
            immutable_data = tuple(reduced_cell)
        return hash(immutable_data), normal_cell, reduced_cell

##############################################################################
#
#                              Debugging
#
##############################################################################

if __name__ == "__main__":
    # Debug entry point: build the dataset with directed graphs (the default).
    ds = NasBench301Dataset(undirected=False)