import os
import pickle
from typing import Union, List
import numpy as np
import spektral
import tensorflow as tf
from datasets.nb101_dataset import NasBench101Dataset, mask_padding_vertex_for_spec, mask_padding_vertex_for_model, \
    mask_for_spec
from datasets.nb201_dataset import NasBench201Dataset, OPS_by_IDX_201, ADJACENCY, ops_list_to_nb201_arch_str, \
    OP_PRIMITIVES_NB201
from datasets.nb101_dataset import OP_PRIMITIVES_NB101, NasBench101Dataset, pad_nb101_graph
from datasets.transformation import OnlyValidAccTransform, OnlyFinalAcc, LabelScale
from datasets.utils import train_valid_test_split_dataset, mask_graph_dataset
from models.GNN import GraphAutoencoderNVP
from nats_bench import create
import matplotlib.pyplot as plt
from spektral.data import BatchLoader
import matplotlib as mpl
from utils.tf_utils import to_undiredted_adj, set_global_determinism
from nb101_helper import largest_subgraph_0_to_6, get_spec_hash, all_topological_sorts, reorder_graph, one_hot_encode

# Module-level handles created once at import time.
# NOTE(review): 'tss' presumably selects the NATS-Bench topology search space — confirm against the
# nats_bench documentation. ``nb201api`` is not referenced by any other code visible in this file.
nb201api = create(None, 'tss', fast_mode=True, verbose=False)
# The two dataset objects below are what the query helpers in this file use for metric lookups.
# NOTE(review): end=0 presumably loads no graphs and only prepares the lookup tables — confirm in the
# dataset classes.
nb101_dataset = NasBench101Dataset(end=0)
nb201_dataset = NasBench201Dataset(end=0)


def query_acc_by_ops(ops: Union[list, np.ndarray], dataset_name, is_random=False, on='valid-accuracy') -> float:
    """
    Look up the stored accuracy of an architecture from ``nb201_dataset`` given its node operations.

    :param ops: either a sequence of operation indices (list or 1-D array), or an array with more
                than one dimension holding per-node class scores / one-hot rows, which is reduced to
                indices with argmax over the last axis
    :param dataset_name: key selecting the dataset entry inside the per-architecture metrics table
    :param is_random: unused by this function; kept so callers that pass it keep working
    :param on: 'valid-accuracy' or 'test-accuracy'
    :return: the stored accuracy as a Python float
    :raises ValueError: if ``on`` is neither 'valid-accuracy' nor 'test-accuracy'
    :raises AssertionError: if the first operation is not 'input' or the last is not 'output'
    """
    # Only score/one-hot matrices need argmax. A 1-D array is already a list of indices;
    # running argmax on it would collapse it to a single scalar and break the loop below.
    if isinstance(ops, np.ndarray) and ops.ndim > 1:
        ops_idx = np.argmax(ops, axis=-1)
    else:
        ops_idx = ops

    ops = [OPS_by_IDX_201[i] for i in ops_idx]
    assert ops[0] == 'input' and ops[-1] == 'output'
    arch_str = ops_list_to_nb201_arch_str(ops)

    # The metrics entry is indexed positionally: [1] is read for validation, [2] for test.
    if on == 'valid-accuracy':
        acc = nb201_dataset.hash_to_metrics[arch_str][dataset_name][1]
    elif on == 'test-accuracy':
        acc = nb201_dataset.hash_to_metrics[arch_str][dataset_name][2]
    else:
        raise ValueError('on should be valid-accuracy or test-accuracy')

    return float(acc)

def inverse_from_acc(model: tf.keras.Model, num_sample_z: int, x_dim: int, z_dim: int, to_inv_acc,
                     noise_std=0., version=2, y_noise_std=0.003):
    """
    Invert target accuracies through the model and decode the result into architectures.

    Each target row is repeated ``num_sample_z`` times, jittered with Gaussian noise, concatenated
    behind a freshly sampled Gaussian ``z`` vector and passed to ``model.inverse``. The returned
    latent is reshaped, optionally perturbed, decoded with ``model.decode``, and the decoded
    operations / adjacency values are aggregated over the ``num_sample_z`` axis.

    :param model: model exposing ``inverse``, ``decode``, ``num_nodes``, ``num_ops`` and optionally
                  ``num_nvp`` (treated as 1 when the attribute is missing)
    :param num_sample_z: number of z samples drawn per target row
    :param x_dim: unused by this function; kept for interface compatibility
    :param z_dim: width of the sampled z vector
    :param to_inv_acc: target accuracies; its first dimension is taken as the batch size
                       (assumed to be of shape (batch, 1) — TODO confirm against callers)
    :param noise_std: stddev of the Gaussian noise added to the latent before decoding
    :param version: only 2 is implemented; 1 raises NotImplementedError
    :param y_noise_std: stddev of the Gaussian noise added to the repeated targets
                        (default keeps the previously hard-coded 0.003)
    :return: ``(ops_idx_list, adj_list)`` — a list of per-architecture operation-index lists and a
             list of binarized (0./1.) adjacency arrays
    :raises NotImplementedError: if ``version`` is 1
    :raises ValueError: if ``version`` is neither 1 nor 2
    """
    # Reject unsupported versions before spending time on the inverse pass.
    if version == 1:
        raise NotImplementedError
    if version != 2:
        raise ValueError('version')

    batch_size = int(tf.shape(to_inv_acc)[0])
    num_nvp = getattr(model, 'num_nvp', 1)

    y = tf.repeat(to_inv_acc, num_sample_z, axis=0)
    # Noise may be added to the targets; the order of the random draws below is kept
    # (target noise first, then z) so seeded runs consume the RNG in the same sequence.
    y = y + tf.random.normal(tf.shape(y), stddev=y_noise_std)

    z = tf.random.normal((batch_size * num_sample_z, z_dim))
    y = tf.concat([z, y], axis=-1)

    rev_latent = model.inverse(y)
    rev_latent = tf.reshape(rev_latent, (batch_size * num_nvp, model.num_nodes, -1))

    decode_noise = tf.random.normal(tf.shape(rev_latent), stddev=noise_std)
    _, adj, ops_cls, _ = model.decode(rev_latent + decode_noise)

    # Sum the operation scores over the sample axis (axis 1) so each node gets one vote vector.
    ops_cls = tf.reshape(ops_cls, (batch_size * num_nvp, num_sample_z, -1, model.num_ops))
    ops_vote = tf.reduce_sum(ops_cls, axis=1).numpy()

    # Average the adjacency values over the sample axis and threshold at 0.5.
    adj = tf.reshape(adj, (batch_size * num_nvp, num_sample_z, model.num_nodes, model.num_nodes))
    adj = tf.where(tf.reduce_mean(adj, axis=1) >= 0.5, x=1., y=0.).numpy()

    ops_idx_list = np.argmax(ops_vote, axis=-1).tolist()
    adj_list = list(adj)

    return ops_idx_list, adj_list

#=============================================================================================================
def eval_query_best(model: tf.keras.Model, dataset_name, x_dim: int, z_dim: int, query_amount=10, noise_scale=0.0,
                    version=2):
    """
    Query the model for architectures at a fixed high target accuracy and score them from the tables.

    ``query_amount`` copies of the target are inverted with ``inverse_from_acc``; each decoded
    architecture is looked up in the tabular benchmark. Any architecture whose lookup or masking
    raises is counted as invalid.

    :param model: model passed to ``inverse_from_acc``; ``model.num_ops`` sizes the one-hot rows
    :param dataset_name: 'nb101' selects the NAS-Bench-101 path; any other value is passed to
                         ``query_acc_by_ops`` as the dataset key
    :param x_dim: forwarded to ``inverse_from_acc``
    :param z_dim: forwarded to ``inverse_from_acc``
    :param query_amount: number of target copies to invert
    :param noise_scale: forwarded to ``inverse_from_acc`` as ``noise_std``
    :param version: forwarded to ``inverse_from_acc``
    :return: ``(invalid, mean_acc, best_acc, found_arch_list)``; mean and best are 0 when no
             architecture could be scored
    """
    # Highest validation accuracy on cifar10; change this to the maximum accuracy of the
    # dataset being queried.
    to_inv_acc = 0.9161
    to_inv = tf.repeat(tf.reshape(tf.constant(to_inv_acc), [-1, 1]), query_amount, axis=0)
    ops_idx_lis, adj_list = inverse_from_acc(model, num_sample_z=1, x_dim=x_dim, z_dim=z_dim,
                                             noise_std=noise_scale, to_inv_acc=to_inv, version=version)

    is_nb101 = dataset_name == 'nb101'
    y = []
    found_arch_list = []
    invalid = 0
    for ops_idx, adj in zip(ops_idx_lis, adj_list):
        try:
            if is_nb101:
                adj_for_spec, ops_idx_for_spec = mask_padding_vertex_for_spec(adj, ops_idx)
                acc = nb101_dataset.get_metrics(adj_for_spec, ops_idx_for_spec)[1]  # [1] for valid acc
            else:
                # Explicit check instead of ``assert`` so it still runs under ``python -O``.
                if not np.array_equal(adj, ADJACENCY):
                    raise ValueError('decoded adjacency does not match ADJACENCY')
                acc = query_acc_by_ops(ops_idx, dataset_name, is_random=False)

            # The accuracy is recorded before masking, so a masking failure below still leaves
            # it in ``y`` while also counting the architecture as invalid.
            y.append(acc)
            ops = np.eye(model.num_ops)[ops_idx]
            if is_nb101:
                adj, ops = mask_padding_vertex_for_model(adj, ops)

            if adj is not None and ops is not None:
                # NOTE(review): the nb101 label is stored without the extra list wrapper used on
                # the other path — kept as-is; confirm which shape consumers expect.
                label = np.array(acc) if is_nb101 else np.array([acc])
                found_arch_list.append({'x': ops.astype(np.float32),
                                        'a': adj.astype(np.float32),
                                        'y': label.astype(np.float32)})
        except Exception:
            # Best-effort: failed lookups/decodes are only counted. KeyboardInterrupt and
            # SystemExit are deliberately no longer swallowed.
            invalid += 1

    if not y:
        return invalid, 0, 0, found_arch_list

    return invalid, sum(y) / len(y), max(y), found_arch_list
#=============================================================================================================
def eval_query_target(model: tf.keras.Model, dataset_name, x_dim: int, z_dim: int, query_amount=10, noise_scale=0.0,
                    version=2, target=1.0):
    """
    Query the model for architectures at a caller-chosen target accuracy and score them.

    ``query_amount`` copies of ``target`` are inverted with ``inverse_from_acc``; each decoded
    architecture is looked up in the tabular benchmark. Any architecture whose lookup or masking
    raises is reported with a printed 'invalid' and counted.

    :param model: model passed to ``inverse_from_acc``; ``model.num_ops`` sizes the one-hot rows
    :param dataset_name: 'nb101' selects the NAS-Bench-101 path; any other value is passed to
                         ``query_acc_by_ops`` as the dataset key
    :param x_dim: forwarded to ``inverse_from_acc``
    :param z_dim: forwarded to ``inverse_from_acc``
    :param query_amount: number of target copies to invert
    :param noise_scale: forwarded to ``inverse_from_acc`` as ``noise_std``
    :param version: forwarded to ``inverse_from_acc``
    :param target: accuracy to invert; converted with ``float``
    :return: ``(invalid, mean_acc, best_acc, found_arch_list)``; mean and best are 0 when no
             architecture could be scored
    """
    to_inv_acc = float(target)
    to_inv = tf.repeat(tf.reshape(tf.constant(to_inv_acc), [-1, 1]), query_amount, axis=0)
    ops_idx_lis, adj_list = inverse_from_acc(model, num_sample_z=1, x_dim=x_dim, z_dim=z_dim,
                                             noise_std=noise_scale, to_inv_acc=to_inv, version=version)

    is_nb101 = dataset_name == 'nb101'
    y = []
    found_arch_list = []
    invalid = 0
    for ops_idx, adj in zip(ops_idx_lis, adj_list):
        try:
            if is_nb101:
                adj_for_spec, ops_idx_for_spec = mask_padding_vertex_for_spec(adj, ops_idx)
                acc = nb101_dataset.get_metrics(adj_for_spec, ops_idx_for_spec)[1]  # [1] for valid acc
            else:
                # Explicit check instead of ``assert`` so it still runs under ``python -O``.
                if not np.array_equal(adj, ADJACENCY):
                    raise ValueError('decoded adjacency does not match ADJACENCY')
                acc = query_acc_by_ops(ops_idx, dataset_name, is_random=False)

            # The accuracy is recorded before masking, so a masking failure below still leaves
            # it in ``y`` while also counting the architecture as invalid.
            y.append(acc)
            ops = np.eye(model.num_ops)[ops_idx]
            if is_nb101:
                adj, ops = mask_padding_vertex_for_model(adj, ops)

            if adj is not None and ops is not None:
                found_arch_list.append({'x': ops.astype(np.float32),
                                        'a': adj.astype(np.float32),
                                        'y': np.array([acc]).astype(np.float32)})
        except Exception:
            # Best-effort: failed lookups/decodes are reported and counted. KeyboardInterrupt and
            # SystemExit are deliberately no longer swallowed.
            print('invalid')
            invalid += 1

    if not y:
        return invalid, 0, 0, found_arch_list

    return invalid, sum(y) / len(y), max(y), found_arch_list

def query_tabular(dataset_name: str, archs: Union[List, spektral.data.Dataset]):
    """
    Return validation and test accuracy for every architecture in ``archs``.

    :param dataset_name: 'nb101' uses ``nb101_dataset.get_metrics``; any other value is passed
                         to ``query_acc_by_ops`` as the dataset key
    :param archs: a list of dicts with keys 'a' and 'x', or a spektral dataset whose graphs are
                  first converted to such dicts
    :return: list of ``{'valid-accuracy': ..., 'test-accuracy': ...}`` dicts, one per architecture,
             in input order
    """
    if isinstance(archs, spektral.data.Dataset):
        archs = [{'a': graph.a, 'x': graph.x} for graph in archs]

    uses_nb101 = dataset_name == 'nb101'
    results = []
    for arch in archs:
        if uses_nb101:
            spec = mask_for_spec(arch)
            metrics = nb101_dataset.get_metrics(spec['a'], np.argmax(spec['x'], axis=-1))
            # Positional metrics: [1] is validation accuracy, [2] is test accuracy.
            valid_acc = float(metrics[1])
            test_acc = float(metrics[2])
        else:
            valid_acc = query_acc_by_ops(arch['x'], dataset_name)
            test_acc = query_acc_by_ops(arch['x'], dataset_name, on='test-accuracy')
        results.append({'valid-accuracy': valid_acc, 'test-accuracy': test_acc})

    return results

# Optionally load a pickled object from "spec_hash_dict.pkl" (resolved against the current working
# directory) at import time.
# NOTE(review): when the file is absent, ``spec_dict`` is never defined, so any later use of that
# name would raise NameError — confirm whether a default is wanted.
# NOTE(review): pickle.load can execute arbitrary code; only point this at a trusted file.
output_filename = "spec_hash_dict.pkl"
if os.path.exists(output_filename):
    with open(output_filename, "rb") as f:
        spec_dict = pickle.load(f)

def eval_from_lat(model: tf.keras.Model, rev_latent, batch_size, num_nodes, latent_dim, dataset_name, hash_version=0):
    """
    Decode latent codes into architectures and keep those that can be scored from the tables.

    :param model: model exposing ``decode``, ``num_nodes``, ``num_ops`` and optionally ``num_nvp``
                  (treated as 1 when the attribute is missing)
    :param rev_latent: latent codes; reshaped to ``(batch_size, num_nodes, latent_dim)``
    :param batch_size: first dimension used for the reshape
    :param num_nodes: second dimension used for the reshape
    :param latent_dim: third dimension used for the reshape
    :param dataset_name: 'nb101' selects the NAS-Bench-101 path; any other value is passed to
                         ``query_acc_by_ops`` as the dataset key
    :param hash_version: unused by this function; kept for interface compatibility
    :return: ``(found_arch_list, record)`` where each found entry is a dict with keys
             'x', 'a', 'y', 'latent', 'origin_index' and 'hash', and ``record`` lists the
             positions (within the decoded batch) of the kept architectures
    """
    rev_latent = tf.reshape(rev_latent, (batch_size, num_nodes, latent_dim))
    num_nvp = getattr(model, 'num_nvp', 1)

    _, adj, ops_cls, _ = model.decode(rev_latent)

    # Axis 1 is a singleton in both reshapes, so the sum / mean over it leaves the values unchanged.
    ops_cls = tf.reshape(ops_cls, (batch_size * num_nvp, 1, -1, model.num_ops))
    ops_vote = tf.reduce_sum(ops_cls, axis=1).numpy()

    adj = tf.reshape(adj, (batch_size * num_nvp, 1, model.num_nodes, model.num_nodes))
    adj = tf.where(tf.reduce_mean(adj, axis=1) >= 0.5, x=1., y=0.).numpy()  # binarize at 0.5

    ops_idx_list = np.argmax(ops_vote, axis=-1).tolist()
    adj_list = list(adj)

    is_nb101 = dataset_name == 'nb101'
    found_arch_list = []
    record = []
    # Set instead of list: membership is checked once per decoded architecture.
    seen_hashes = set()

    for idx, (ops_idx, adj, encoding) in enumerate(zip(ops_idx_list, adj_list, rev_latent)):
        try:
            if is_nb101:
                adj_for_spec, ops_idx_for_spec = mask_padding_vertex_for_spec(adj, ops_idx)
                acc = nb101_dataset.get_metrics(adj_for_spec, ops_idx_for_spec)[1]  # [1] for valid acc
                graph_str = nb101_dataset.get_spec_hash(adj_for_spec, ops_idx_for_spec)
                adj, ops = mask_padding_vertex_for_model(adj, np.eye(model.num_ops)[ops_idx])

                # Keep only architectures that survived masking, have a non-zero entry in column 3
                # of the operation matrix (NOTE(review): meaning of op index 3 not visible here —
                # confirm), and whose hash has not been seen in this call.
                if adj is not None and ops is not None and sum(ops[:, 3]) > 0 and graph_str not in seen_hashes:
                    # 'a' stores the spec-masked adjacency, 'x' the model-masked operations.
                    found_arch_list.append({'x': ops.astype(np.float32),
                                            'a': adj_for_spec.astype(np.float32),
                                            'y': np.array([acc]).astype(np.float32),
                                            'latent': encoding,
                                            'origin_index': idx,
                                            'hash': graph_str})
                    record.append(idx)
                    seen_hashes.add(graph_str)
            else:
                # Explicit check instead of ``assert`` so it still runs under ``python -O``.
                if not np.array_equal(adj, ADJACENCY):
                    raise ValueError('decoded adjacency does not match ADJACENCY')
                acc = query_acc_by_ops(ops_idx, dataset_name, is_random=False)
                ops = np.eye(model.num_ops)[ops_idx]
                ops_name = [OPS_by_IDX_201[i] for i in ops_idx]
                found_arch_list.append({'x': ops.astype(np.float32),
                                        'a': adj.astype(np.float32),
                                        'y': np.array([acc]).astype(np.float32),
                                        'latent': encoding,
                                        'origin_index': idx,
                                        'hash': ops_list_to_nb201_arch_str(ops_name)})
                record.append(idx)
        except Exception:
            # Best-effort: architectures that cannot be decoded or looked up are skipped.
            # KeyboardInterrupt and SystemExit are deliberately no longer swallowed.
            continue

    return found_arch_list, record

#=================================================================
def query_tabular_accuracy(dataset_name, archs):
    """
    Return validation and test accuracy for every entry of a keyed architecture collection.

    :param dataset_name: 'nb101' uses ``nb101_dataset.get_metrics``; any other value is passed
                         to ``query_acc_by_ops`` as the dataset key
    :param archs: collection iterated by key, where ``archs[key]`` is a dict with keys 'a' and 'x'
    :return: list of ``{'valid-accuracy': ..., 'test-accuracy': ...}`` dicts in iteration order
    """
    uses_nb101 = dataset_name == 'nb101'
    results = []
    for key in archs:
        entry = archs[key]
        if uses_nb101:
            spec = mask_for_spec(entry)
            # The adjacency comes from the unmasked entry, cropped to the number of rows that
            # the masked operation matrix has.
            node_count = spec['x'].shape[0]
            cropped_adj = entry['a'][:node_count, :node_count]
            metrics = nb101_dataset.get_metrics(cropped_adj, np.argmax(spec['x'], axis=-1))
            valid_acc = float(metrics[1])
            test_acc = float(metrics[2])
        else:
            valid_acc = query_acc_by_ops(entry['x'], dataset_name)
            test_acc = query_acc_by_ops(entry['x'], dataset_name, on='test-accuracy')
        results.append({'valid-accuracy': valid_acc, 'test-accuracy': test_acc})

    return results

#=================================================================


if __name__ == '__main__':
    # Script entry point: load (or build and cache) the dataset, restore a trained model, then
    # produce three scatter plots:
    #   1. inverse_<split>.png   - queried accuracy vs. table accuracy of the decoded architecture
    #   2. regresion_<split>.png - predicted accuracy vs. label accuracy
    #   3. decending.png         - table accuracy for a sweep of queried accuracies
    mpl.rcParams['figure.dpi'] = 300
    random_seed = 0
    set_global_determinism(random_seed)
    #random.seed(random_seed)
    #tf.random.set_seed(random_seed)

    dataset_name = 'cifar10-valid'
    plot_on_slit = 'train'  # train, valid, test
    model_weight = 'logs/for_tsne/20230603-141711/modelGAE_weights_retrain'

    # NOTE: the cache files are pickles; pickle.load can execute arbitrary code, so only use
    # cache files produced by this project.
    if dataset_name == 'nb101':
        num_ops = len(OP_PRIMITIVES_NB101)  # 5
        num_nodes = 7
        num_adjs = num_nodes ** 2
        if os.path.exists('datasets/NasBench101Dataset.cache'):
            # ``with`` closes the cache file; the previous bare open() leaked the handle.
            with open('datasets/NasBench101Dataset.cache', 'rb') as f:
                datasets = pickle.load(f)
        else:
            datasets = NasBench101Dataset(start=0, end=423623)
            with open('datasets/NasBench101Dataset.cache', 'wb') as f:
                pickle.dump(datasets, f)
        datasets = train_valid_test_split_dataset(datasets,
                                                  ratio=[0.8, 0.1, 0.1],
                                                  shuffle=True,
                                                  shuffle_seed=random_seed)
    else:
        # 15624
        num_ops = len(OP_PRIMITIVES_NB201)  # 7
        num_nodes = 8
        num_adjs = num_nodes ** 2
        label_epochs = 200
        if os.path.exists(f'datasets/NasBench201Dataset_{dataset_name}.cache'):
            with open(f'datasets/NasBench201Dataset_{dataset_name}.cache', 'rb') as f:
                datasets = pickle.load(f)
        else:
            datasets = NasBench201Dataset(start=0, end=15624, dataset=dataset_name, hp=str(label_epochs), seed=False)
            with open(f'datasets/NasBench201Dataset_{dataset_name}.cache', 'wb') as f:
                pickle.dump(datasets, f)
        datasets = train_valid_test_split_dataset(datasets,
                                                  ratio=[0.8, 0.1, 0.1],
                                                  shuffle=True,
                                                  shuffle_seed=random_seed)

    for key in datasets:
        datasets[key].apply(OnlyValidAccTransform())
        datasets[key].apply(OnlyFinalAcc())
        if dataset_name != 'nb101':
            datasets[key].apply(LabelScale(scale=0.01))

    datasets['train'] = mask_graph_dataset(datasets['train'], 350, 1, random_seed=random_seed)
    datasets['valid'] = mask_graph_dataset(datasets['valid'], 50, 1, random_seed=random_seed)
    # Drop graphs whose label is NaN.
    datasets['train'].filter(lambda g: not np.isnan(g.y))
    datasets['valid'].filter(lambda g: not np.isnan(g.y))

    d_model = 32
    dropout_rate = 0.0
    dff = 256
    num_layers = 3
    num_heads = 3

    # Dimension bookkeeping; the values in the comments are for the default
    # 'cifar10-valid' configuration above (num_nodes = 8).
    latent_dim = 16
    x_dim = latent_dim * num_nodes  # 128
    y_dim = 1  # 1
    z_dim = x_dim - 1  # 127
    # z_dim = latent_dim * 4 - 1
    tot_dim = y_dim + z_dim  # 128
    pad_dim = tot_dim - x_dim  # 0

    nvp_config = {
        'n_couple_layer': 4,
        'n_hid_layer': 4,
        'n_hid_dim': 128,
        'name': 'NVP',
        'num_couples': 2,
        'inp_dim': tot_dim
    }

    model = GraphAutoencoderNVP(nvp_config=nvp_config, latent_dim=latent_dim, num_layers=num_layers,
                                d_model=d_model, num_heads=num_heads,
                                dff=dff, num_ops=num_ops, num_nodes=num_nodes,
                                num_adjs=num_adjs, dropout_rate=dropout_rate, eps_scale=0.)
    # Call the model once on random input before loading the weights.
    model((tf.random.normal(shape=(1, num_nodes, num_ops)), tf.random.normal(shape=(1, num_nodes, num_nodes))))
    model.load_weights(model_weight)

    # Eval inverse
    x = []
    y = []
    invalid = 0
    loader = BatchLoader(datasets[plot_on_slit], batch_size=512, epochs=1)
    for _, label_acc in loader:
        # The last label column is used as the accuracy to invert.
        ops_idx_lis, adj_list = inverse_from_acc(model, num_sample_z=1, x_dim=x_dim, z_dim=z_dim,
                                                 to_inv_acc=label_acc[:, -1:].astype(np.float32))

        for ops_idx, adj, query_acc in zip(ops_idx_lis, adj_list, label_acc[:, -1]):
            try:
                if dataset_name != 'nb101':
                    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
                    if not np.array_equal(adj, ADJACENCY):
                        raise ValueError('decoded adjacency does not match ADJACENCY')
                    acc = query_acc_by_ops(ops_idx, dataset_name, is_random=False)
                else:
                    adj_for_spec, ops_idx_for_spec = mask_padding_vertex_for_spec(adj, ops_idx)
                    acc = nb101_dataset.get_metrics(adj_for_spec, ops_idx_for_spec)[1]  # [1] for valid acc

                x.append(float(query_acc))
                y.append(acc)
            except Exception:
                # Best-effort: undecodable architectures are reported and counted only.
                print('invalid')
                invalid += 1

    print('Number of invalid decode', invalid)
    fig, ax = plt.subplots()
    ax.axline((0, 0), slope=1, linewidth=0.2, color='black')
    ax.set_xlabel('Query Accuracy')
    ax.set_ylabel('True Accuracy')
    plt.scatter(x, y, s=[1] * len(x))
    plt.xlim(0., 1.)
    plt.ylim(0., 1.)
    plt.savefig(f'inverse_{plot_on_slit}.png')
    plt.cla()

    # Eval regression
    loader = BatchLoader(datasets[plot_on_slit], batch_size=256, epochs=1)
    x = []
    y = []
    for arch, label_acc in loader:
        arch = (arch[0], to_undiredted_adj(arch[1]))
        ops_cls, adj_cls, kl_loss, reg, flat_encoding = model(arch)

        # Last column of the regression output vs. last label column.
        for true_acc, query_acc in zip(label_acc[:, -1], reg[:, -1]):
            x.append(float(query_acc))
            y.append(float(true_acc))

    # NOTE: ``invalid`` is not updated in this section; this repeats the count printed above.
    print('Number of invalid decode', invalid)
    fig, ax = plt.subplots()
    ax.axline((0, 0), slope=1, linewidth=0.2, color='black')
    ax.set_xlabel('Predicted Accuracy')
    ax.set_ylabel('True Accuracy')
    plt.scatter(x, y, s=[1] * len(x))
    plt.xlim(0., 1.)
    plt.ylim(0., 1.)
    plt.savefig(f'regresion_{plot_on_slit}.png')
    plt.cla()

    # Eval decending
    x = []
    y = []
    invalid = 0
    to_inv_acc = 0.00
    to_inv = []
    to_inv_repeat = 1
    # Sweep of query accuracies from 0 upwards in steps of 0.005 while <= 1.0. The step is
    # accumulated in floating point, which decides whether the final value is included.
    while to_inv_acc <= 1.0:
        to_inv.append(to_inv_acc)
        to_inv_acc += 0.005

    to_inv = tf.repeat(tf.reshape(tf.constant(to_inv), [-1, 1]), to_inv_repeat, axis=0)
    ops_idx_lis, adj_list = inverse_from_acc(model, num_sample_z=1, x_dim=x_dim, z_dim=z_dim, to_inv_acc=to_inv)
    for ops_idx, adj, query_acc in zip(ops_idx_lis, adj_list, to_inv[:, -1]):
        try:
            if dataset_name != 'nb101':
                # Explicit check instead of ``assert`` so it still runs under ``python -O``.
                if not np.array_equal(adj, ADJACENCY):
                    raise ValueError('decoded adjacency does not match ADJACENCY')
                acc = query_acc_by_ops(ops_idx, dataset_name, is_random=False)
            else:
                adj_for_spec, ops_idx_for_spec = mask_padding_vertex_for_spec(adj, ops_idx)
                acc = nb101_dataset.get_metrics(adj_for_spec, ops_idx_for_spec)[1]  # [1] for valid acc

            print(acc)
            # Store a plain float (as the inverse section does) rather than a scalar tensor.
            x.append(float(query_acc))
            y.append(acc)
        except Exception:
            # Best-effort: undecodable architectures are reported and counted only.
            print('invalid')
            invalid += 1

    print('Number of invalid decode', invalid)
    fig, ax = plt.subplots()
    ax.axline((0, 0), slope=1, linewidth=0.2, color='black')
    ax.set_xlabel('Query Accuracy')
    ax.set_ylabel('True Accuracy')
    plt.scatter(x, y, s=[1] * len(x))
    plt.xlim(0., 1.)
    plt.ylim(0., 1.)
    plt.savefig('decending.png')
    plt.cla()

    # invalid, avg_acc, best_acc, _ = eval_query_best(model, dataset, x_dim, z_dim)
    # print('Number of invalid decode', invalid)
    # print('Avg found acc', avg_acc)
    # print('Best found acc', best_acc)