import numpy as np
import tensorflow as tf
import time,copy,hashlib
from NASBenchNLP_Dateset import NASBenchNLPDataset
import copy

def inverse_from_acc(model: tf.keras.Model, num_sample_z: int, x_dim: int, z_dim: int, to_inv_acc,
                     noise_std=0., version=2):
    """Run the model's inverse pass on target accuracies and decode the result.

    Each row of ``to_inv_acc`` is repeated ``num_sample_z`` times, perturbed
    with Gaussian noise (stddev 0.003), prefixed with a random normal vector
    of width ``z_dim`` and passed to ``model.inverse``.  The returned latent is
    reshaped to ``(batch_size * num_nvp, model.num_nodes, -1)``, perturbed with
    ``noise_std`` Gaussian noise and decoded with ``model.decode``.  Operation
    scores are summed over the sample axis and arg-maxed; adjacency values are
    averaged over the sample axis and thresholded at 0.5.

    Args:
        model: Object exposing ``inverse``, ``decode``, ``num_nodes`` and
            ``num_ops``.  ``num_nvp`` is optional and defaults to 1.
        num_sample_z: Number of times each target row is repeated.
        x_dim: Unused by the supported code path (``version == 2``).
        z_dim: Width of the random normal vector concatenated in front of
            each target.
        to_inv_acc: Tensor of targets; its first dimension is the batch size.
            Presumably shaped ``(batch_size, 1)`` -- TODO confirm with callers.
        noise_std: Stddev of the Gaussian noise added to the latent right
            before decoding.
        version: Only ``2`` is supported.

    Returns:
        A tuple ``(ops_idx_list, adj_list, rev_latent)``: per-architecture
        operation indices as nested Python lists, per-architecture binarized
        adjacency arrays (values ``0.`` / ``1.``), and the reshaped latent
        tensor (without the decode-time noise).

    Raises:
        NotImplementedError: If ``version == 1``.
        ValueError: If ``version`` is neither 1 nor 2.
    """
    batch_size = int(tf.shape(to_inv_acc)[0])
    # Only a missing attribute should trigger the fallback; any other error
    # raised while reading ``num_nvp`` is a real problem and must surface.
    num_nvp = getattr(model, 'num_nvp', 1)

    y = tf.repeat(to_inv_acc, num_sample_z, axis=0)
    # Small fixed jitter on the targets so repeated rows are not identical.
    y = y + tf.random.normal(tf.shape(y), stddev=0.003)
    # Prepend the sampled z vector to every (jittered) target row.
    z = tf.random.normal((batch_size * num_sample_z, z_dim), stddev=1)
    y = tf.concat([z, y], axis=-1)

    rev_latent = model.inverse(y)
    if version == 1:
        raise NotImplementedError
    elif version == 2:
        rev_latent = tf.reshape(rev_latent, (batch_size * num_nvp, model.num_nodes, -1))
    else:
        raise ValueError('version')

    decode_noise = tf.random.normal(tf.shape(rev_latent), stddev=noise_std)
    _, adj, ops_cls, _ = model.decode(rev_latent + decode_noise)

    # Vote across the sample axis (axis 1): sum the operation scores ...
    ops_cls = tf.reshape(ops_cls, (batch_size * num_nvp, num_sample_z, -1, model.num_ops))
    ops_vote = tf.reduce_sum(ops_cls, axis=1).numpy()

    # ... and keep an edge when its mean value over the samples is >= 0.5.
    adj = tf.reshape(adj, (batch_size * num_nvp, num_sample_z, model.num_nodes, model.num_nodes))
    adj = tf.where(tf.reduce_mean(adj, axis=1) >= 0.5, x=1., y=0.).numpy()

    ops_idx_list = [np.argmax(i, axis=-1).tolist() for i in ops_vote]
    adj_list = list(adj)

    return ops_idx_list, adj_list, rev_latent

def query_by_encoding(model: tf.keras.Model, x_dim: int, z_dim: int, x_encoding, visited, query_amount=5, noise_scale=0.03, hash_version=2):
    """Decode ``x_encoding`` into operation indices and adjacency matrices.

    The encoding is reshaped to ``(query_amount, model.num_nodes, -1)``,
    perturbed with Gaussian noise (stddev 0.01) and passed to
    ``model.decode``.  Operation scores are arg-maxed and adjacency values are
    thresholded at 0.5.

    NOTE(review): the decoded lists are computed but never returned, so this
    function always returns ``None``.  ``x_dim``, ``z_dim``, ``visited``,
    ``noise_scale`` and ``hash_version`` are not used.  The function looks
    unfinished -- confirm the intended return value before relying on it.
    """
    found_arch_list = {}
    latent = tf.reshape(x_encoding, (query_amount, model.num_nodes, -1))
    jitter = tf.random.normal(tf.shape(latent), stddev=0.01)
    _, adj_scores, op_scores, _ = model.decode(latent + jitter)

    # A singleton axis 1 is inserted and then reduced away again, so the
    # sum / mean below leave the values unchanged.
    op_scores = tf.reshape(op_scores, (query_amount, 1, -1, model.num_ops))
    ops_vote = tf.reduce_sum(op_scores, axis=1).numpy()

    adj_scores = tf.reshape(adj_scores, (query_amount, 1, model.num_nodes, model.num_nodes))
    adj_binary = tf.where(tf.reduce_mean(adj_scores, axis=1) >= 0.5, x=1., y=0.).numpy()

    ops_idx_list = np.argmax(ops_vote, axis=-1).tolist()
    adj_list = list(adj_binary)


def eval_query_target(model: tf.keras.Model, x_dim: int, z_dim: int, query_amount=10, noise_scale=0.0,
                    version=2, target=0.68, hash_version=0):
    """Sample architectures for a target value and split them into passed / rejected.

    ``target`` is repeated ``query_amount`` times and inverted through
    ``inverse_from_acc`` (one z sample per query).  Each decoded architecture
    is converted to a recipe, passed through ``map_network`` (with
    ``check_max=True``), ``generate_nx11_compact`` and
    ``convert_compact_to_recipe``, and then checked with ``check_pass_arch``.

    Args:
        model: Model forwarded to ``inverse_from_acc``.
        x_dim: Forwarded to ``inverse_from_acc``.
        z_dim: Forwarded to ``inverse_from_acc``.
        query_amount: Number of copies of ``target`` to invert.
        noise_scale: Forwarded to ``inverse_from_acc`` as ``noise_std``.
        version: Forwarded to ``inverse_from_acc``.
        target: Value to invert; converted with ``float``.
        hash_version: Unused.

    Returns:
        A tuple ``(found_arch_list, bag_arch, total_found_arch)``:
        ``found_arch_list`` maps ``hash(str(arch.items()))`` to a dict with
        keys ``'x'``, ``'a'``, ``'y'`` (float32 arrays) and ``'latent'``;
        ``bag_arch`` lists the recipes (without ``"h_new_node"``) that failed
        ``check_pass_arch``; ``total_found_arch`` is the number of decoded
        architectures before any filtering.
    """
    found_arch_list = {}
    to_inv_acc = float(target)
    to_inv = tf.repeat(tf.reshape(tf.constant(to_inv_acc), [-1, 1]), query_amount, axis=0)
    ops_idx_lis, adj_list, rev_latent = inverse_from_acc(model, num_sample_z=1, x_dim=x_dim, z_dim=z_dim,
                                             noise_std=noise_scale, to_inv_acc=to_inv, version=version)
    invalid = 0  # Count of architectures that raised during conversion (debug aid only).
    bag_arch = []
    total_found_arch = len(adj_list)
    for ops_idx, adj, encoding in zip(ops_idx_lis, adj_list, rev_latent):
        try:
            arch = NASBenchNLPDataset.change_to_recepie(adj, ops_idx, 2, True)
            recepie = copy.deepcopy(arch['recepie'])
            item = NASBenchNLPDataset.map_network(recepie, check_max=True)
            compact = NASBenchNLPDataset.generate_nx11_compact(item)  # arranges the graph (translated from the original comment)
            recepie = NASBenchNLPDataset.convert_compact_to_recipe(compact)
            # NOTE(review): the key is built from ``arch`` as first decoded,
            # not from the rebuilt recipe as the sibling functions do --
            # confirm this is intended.
            hash_number = hash(str(arch.items()))
            recepies = copy.deepcopy(recepie)
            del recepies["h_new_node"]
            if not NASBenchNLPDataset.check_pass_arch(recepies):
                bag_arch.append(recepies)
                continue

            recepies["h_new_node"] = recepie["h_new_node"]
            arch['recepie'] = copy.deepcopy(recepies)
            # ``set_y_data`` receives a copy without "h_new_node".
            recepiess = copy.deepcopy(recepies)
            del recepiess["h_new_node"]
            found_arch_list[hash_number] = {
                'x': np.array(NASBenchNLPDataset.arch_Operation_Matrix(arch)).astype(np.float32),
                'a': np.array(NASBenchNLPDataset.arch_Adjacency_Matrix_for_node(arch)).astype(np.float32),
                'y': np.array(NASBenchNLPDataset.set_y_data(recepiess)).astype(np.float32),
                'latent': encoding,
            }
        except Exception:
            # Best effort: skip architectures that cannot be converted, but
            # let KeyboardInterrupt / SystemExit propagate.
            invalid += 1
    return found_arch_list, bag_arch, total_found_arch

def eval_query_target_node_one(model: tf.keras.Model, x_dim: int, z_dim: int, query_amount=10, noise_scale=0.0,
                    version=2, target=0.68, hash_version=0):
    """Sample architectures for a target value, keeping those that pass both checks.

    ``target`` is repeated ``query_amount`` times and inverted through
    ``inverse_from_acc`` (one z sample per query).  Every decoded architecture
    is rebuilt via ``map_network`` -> ``generate_nx11_compact`` ->
    ``convert_compact_to_recipe`` and kept only when both
    ``check_map_network_node_find_num(recipe, 10)`` and ``check_pass_arch``
    hold for the recipe without its ``"h_new_node"`` entry.

    Exceptions raised while converting an architecture are NOT caught here;
    they propagate to the caller.  ``hash_version`` is unused.

    Returns:
        Dict mapping ``hash(str(recipe.items()))`` of the rebuilt recipe to a
        dict with keys ``'x'``, ``'a'``, ``'y'`` (float32 arrays) and
        ``'latent'``.
    """
    target_column = tf.reshape(tf.constant(float(target)), [-1, 1])
    targets = tf.repeat(target_column, query_amount, axis=0)
    decoded_ops, decoded_adjs, latents = inverse_from_acc(
        model, num_sample_z=1, x_dim=x_dim, z_dim=z_dim,
        noise_std=noise_scale, to_inv_acc=targets, version=version)

    dataset = NASBenchNLPDataset
    results = {}
    for ops_idx, adj, encoding in zip(decoded_ops, decoded_adjs, latents):
        arch = dataset.change_to_recepie(adj, ops_idx, 2, True)
        mapped = dataset.map_network(copy.deepcopy(arch['recepie']))
        rebuilt = dataset.convert_compact_to_recipe(dataset.generate_nx11_compact(mapped))
        key = hash(str(rebuilt.items()))

        # The checks run on a copy that lacks "h_new_node".
        candidate = copy.deepcopy(rebuilt)
        del candidate["h_new_node"]
        if not (dataset.check_map_network_node_find_num(candidate, 10)
                and dataset.check_pass_arch(candidate)):
            continue

        # Restore the entry before storing the recipe back on the architecture.
        candidate["h_new_node"] = rebuilt["h_new_node"]
        arch['recepie'] = copy.deepcopy(candidate)

        results[key] = {
            'x': np.array(dataset.arch_Operation_Matrix(arch)).astype(np.float32),
            'a': np.array(dataset.arch_Adjacency_Matrix_for_node(arch)).astype(np.float32),
            'y': np.array(dataset.set_y_data(candidate)).astype(np.float32),
            'latent': encoding,
        }
    return results

def eval_query_target_node_two(model: tf.keras.Model, x_dim: int, z_dim: int, query_amount=10, noise_scale=0.0,
                    version=2, target=0.68, hash_version=0):
    """Sample architectures for a target value, keeping those that pass both checks.

    ``target`` is repeated ``query_amount`` times and inverted through
    ``inverse_from_acc`` (one z sample per query).  Every decoded architecture
    is rebuilt via ``map_network`` -> ``generate_nx11_compact`` ->
    ``convert_compact_to_recipe`` and kept only when both
    ``check_map_network_node_find_more_num(recipe, 11)`` and
    ``check_pass_arch`` hold for the recipe without its ``"h_new_node"`` entry.
    Architectures whose conversion raises are skipped.

    Args:
        model: Model forwarded to ``inverse_from_acc``.
        x_dim: Forwarded to ``inverse_from_acc``.
        z_dim: Forwarded to ``inverse_from_acc``.
        query_amount: Number of copies of ``target`` to invert.
        noise_scale: Forwarded to ``inverse_from_acc`` as ``noise_std``.
        version: Forwarded to ``inverse_from_acc``.
        target: Value to invert; converted with ``float``.
        hash_version: Unused.

    Returns:
        Dict mapping ``hash(str(recipe.items()))`` of the rebuilt recipe to a
        dict with keys ``'x'``, ``'a'``, ``'y'`` (float32 arrays) and
        ``'latent'``.
    """
    found_arch_list = {}
    to_inv_acc = float(target)
    to_inv = tf.repeat(tf.reshape(tf.constant(to_inv_acc), [-1, 1]), query_amount, axis=0)
    ops_idx_lis, adj_list, rev_latent = inverse_from_acc(model, num_sample_z=1, x_dim=x_dim, z_dim=z_dim,
                                             noise_std=noise_scale, to_inv_acc=to_inv, version=version)
    invalid = 0  # Count of architectures that raised during conversion (debug aid only).
    for ops_idx, adj, encoding in zip(ops_idx_lis, adj_list, rev_latent):
        try:
            arch = NASBenchNLPDataset.change_to_recepie(adj, ops_idx, 2, True)
            recepie = copy.deepcopy(arch['recepie'])
            item = NASBenchNLPDataset.map_network(recepie)
            compact = NASBenchNLPDataset.generate_nx11_compact(item)
            recepie = NASBenchNLPDataset.convert_compact_to_recipe(compact)
            hash_number = hash(str(recepie.items()))
            # The checks run on a copy that lacks "h_new_node".
            recepies = copy.deepcopy(recepie)
            del recepies["h_new_node"]

            if (not NASBenchNLPDataset.check_map_network_node_find_more_num(recepies, 11)) or (
                    not NASBenchNLPDataset.check_pass_arch(recepies)):
                continue
            # Restore the entry before storing the recipe back on the architecture.
            recepies["h_new_node"] = recepie["h_new_node"]
            arch['recepie'] = copy.deepcopy(recepies)

            found_arch_list[hash_number] = {
                'x': np.array(NASBenchNLPDataset.arch_Operation_Matrix(arch)).astype(np.float32),
                'a': np.array(NASBenchNLPDataset.arch_Adjacency_Matrix_for_node(arch)).astype(np.float32),
                'y': np.array(NASBenchNLPDataset.set_y_data(recepies)).astype(np.float32),
                'latent': encoding,
            }
        except Exception:
            # Best effort: skip architectures that cannot be converted, but
            # let KeyboardInterrupt / SystemExit propagate.
            invalid += 1
    return found_arch_list

def eval_from_lat(model: tf.keras.Model, rev_latent, batch_size, num_nodes, latent_dim, hash_version=0):
    """Decode given latent vectors into architectures and keep the valid ones.

    ``rev_latent`` is reshaped to ``(batch_size, num_nodes, latent_dim)`` and
    decoded with ``model.decode``.  Operation scores are arg-maxed and
    adjacency values are thresholded at 0.5.  Each decoded architecture is
    rebuilt via ``map_network`` -> ``generate_nx11_compact`` ->
    ``convert_compact_to_recipe`` and kept only if ``check_pass_arch`` accepts
    the recipe without its ``"h_new_node"`` entry.

    Args:
        model: Object exposing ``decode``, ``num_nodes`` and ``num_ops``;
            ``num_nvp`` is optional and defaults to 1.
        rev_latent: Latent vectors to decode; must be reshapeable to
            ``(batch_size, num_nodes, latent_dim)``.
        batch_size: First dimension used for the reshape.
        num_nodes: Second dimension used for the reshape.
        latent_dim: Last dimension used for the reshape.
        hash_version: Unused.

    Returns:
        A tuple ``(found_arch_list, record)``.  ``found_arch_list`` is a list
        of dicts with keys ``'x'``, ``'a'``, ``'y'`` (float32 arrays),
        ``'latent'``, ``'origin_index'`` (position of the architecture in the
        decoded batch) and ``'hash'``.  ``record`` lists the
        ``origin_index`` of every kept architecture, in order.
    """
    # Sample count as seen before the reshape; only used in the summary print.
    num_samples = rev_latent.shape[0]
    rev_latent = tf.reshape(rev_latent, (batch_size, num_nodes, latent_dim))

    try:
        num_nvp = model.num_nvp
    except AttributeError:
        num_nvp = 1

    # Decode the latent representation
    print("rev latent", tf.shape(rev_latent))
    _, adj, ops_cls, _ = model.decode(rev_latent)

    # A singleton axis 1 is inserted and then reduced away again, so the
    # sum / mean below leave the values unchanged.
    ops_cls = tf.reshape(ops_cls, (batch_size * num_nvp, 1, -1, model.num_ops))
    ops_vote = tf.reduce_sum(ops_cls, axis=1).numpy()

    adj = tf.reshape(adj, (batch_size * num_nvp, 1, model.num_nodes, model.num_nodes))
    adj = tf.where(tf.reduce_mean(adj, axis=1) >= 0.5, x=1., y=0.).numpy()  # Binarize adjacency matrices

    # Extract operation indices and adjacency lists
    ops_idx_list = [np.argmax(i, axis=-1).tolist() for i in ops_vote]
    adj_list = list(adj)

    found_arch_list = []
    invalid = 0
    record = []

    # ``enumerate`` keeps ``idx`` aligned with the decoded batch on every
    # path, including the early ``continue`` for rejected architectures.
    for idx, (ops_idx, adj, encoding) in enumerate(zip(ops_idx_list, adj_list, rev_latent)):
        try:
            arch = NASBenchNLPDataset.change_to_recepie(adj, ops_idx, 2, True)
            recepie = copy.deepcopy(arch['recepie'])
            item = NASBenchNLPDataset.map_network(recepie)
            compact = NASBenchNLPDataset.generate_nx11_compact(item)
            recepie = NASBenchNLPDataset.convert_compact_to_recipe(compact)
            hash_number = hash(str(recepie.items()))
            recepies = copy.deepcopy(recepie)
            del recepies["h_new_node"]
            if not NASBenchNLPDataset.check_pass_arch(recepies):
                invalid += 1
                continue

            # Store the architecture information.
            found_arch_list.append({
                'x': np.array(NASBenchNLPDataset.arch_Operation_Matrix(arch)).astype(np.float32),
                'a': np.array(NASBenchNLPDataset.arch_Adjacency_Matrix_for_node(arch)).astype(np.float32),
                'y': np.array(NASBenchNLPDataset.set_y_data(recepies)).astype(np.float32),
                'latent': encoding,
                'origin_index': idx,
                'hash': hash_number,
            })
            record.append(idx)
        except Exception:
            # Skip architectures that cannot be converted, but let
            # KeyboardInterrupt / SystemExit propagate.
            invalid += 1

    # Report how many architectures were rejected or failed to convert.
    print(f"Invalid architectures: {invalid}/{num_samples}")

    return found_arch_list, record
