import argparse
import pickle
import random
import numpy as np
from tensorflow.python.keras.callbacks import CSVLogger, EarlyStopping
from datasets.nb101_dataset import OP_PRIMITIVES_NB101, NasBench101Dataset, mask_padding_vertex_for_spec
from datasets.transformation import ReshapeYTransform, OnlyValidAccTransform, OnlyFinalAcc, LabelScale
from invertible_neural_networks.flow import MMD_multiscale
from models.GNN import GraphAutoencoder, GraphAutoencoderEnsembleNVP, get_rank_weight
from models.TransformerAE import TransformerAutoencoderNVP
import tensorflow as tf
import os
from datasets.nb201_dataset import NasBench201Dataset, OP_PRIMITIVES_NB201, ADJACENCY
from datasets.utils import train_valid_test_split_dataset, mask_graph_dataset, arch_list_to_set, graph_to_str, \
    repeat_graph_dataset_element
from evalGAE import eval_query_best, nb101_dataset, query_tabular, query_tabular_accuracy, eval_from_lat, query_acc_by_ops
from trainGAE_two_phase import mask_for_model, graph_to_spec_graph, RandomArchGenerator
from utils.py_utils import get_logdir_and_logger
from spektral.data import Graph, PackedBatchLoader
from utils.tf_utils import to_undiredted_adj, set_global_determinism
import logging
from copy import deepcopy
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.utils import shuffle
from Flow2FlowAgent import Flow2FlowAgent
import pandas as pd
import ast
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module='tensorflow.python.framework.indexed_slices')

def parse_args():
    """Parse the command-line options of the script.

    Boolean options come in ``--flag`` / ``--no_flag`` pairs that write to the
    same destination; all of them default to ``False``.

    Returns:
        argparse.Namespace: the parsed options, read from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    # ``%(default)s`` keeps the help text in sync with the actual default value.
    parser.add_argument('--train_sample_amount', type=int, default=30,
                        help='Number of samples to train (default: %(default)s)')
    parser.add_argument('--valid_sample_amount', type=int, default=10,
                        help='Number of samples to validate (default: %(default)s)')
    parser.add_argument('--query_budget', type=int, default=192)
    parser.add_argument('--dataset', type=str, default='cifar10-valid', help='Could be nb101, cifar10-valid, cifar100, ImageNet16-120')
    parser.add_argument('--top_k', type=int, default=5)
    parser.add_argument('--finetune', action='store_true')
    parser.add_argument('--no_finetune', dest='finetune', action='store_false')
    parser.set_defaults(finetune=False)
    parser.add_argument('--retrain_finetune', action='store_true')
    parser.add_argument('--no_retrain_finetune', dest='retrain_finetune', action='store_false')
    parser.set_defaults(retrain_finetune=False)
    parser.add_argument('--rank_weight', action='store_true')
    parser.add_argument('--no_rank_weight', dest='rank_weight', action='store_false')
    parser.set_defaults(rank_weight=False)
    parser.add_argument('--random_sample', action='store_true')
    parser.set_defaults(random_sample=False)
    parser.add_argument('--seed', type=int, default=0)
    return parser.parse_args()


def cal_ops_adj_loss_for_graph(x_batch_train, ops_cls, adj_cls, reduction='auto', rank_weight=None):
    """Compute the reconstruction losses for operations and adjacency.

    Args:
        x_batch_train: pair ``(ops_label, adj_label)`` of ground-truth targets.
        ops_cls: predicted operation probabilities (not logits), scored with
            categorical cross-entropy against ``ops_label``.
        adj_cls: predicted adjacency probabilities (not logits), scored with
            sparse categorical cross-entropy against ``adj_label``.
        reduction: Keras loss reduction. With ``'none'`` the element-wise
            losses are averaged over their last axis.
        rank_weight: optional weights; when given, each loss is multiplied by
            them and summed into a scalar.

    Returns:
        tuple: ``(ops_loss, adj_loss)``.
    """
    ops_label, adj_label = x_batch_train

    ops_criterion = tf.keras.losses.CategoricalCrossentropy(from_logits=False, reduction=reduction)
    adj_criterion = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False, reduction=reduction)
    losses = [ops_criterion(ops_label, ops_cls), adj_criterion(adj_label, adj_cls)]

    if reduction == 'none':
        # Collapse the trailing axis of the unreduced losses.
        losses = [tf.reduce_mean(loss, axis=-1) for loss in losses]
    if rank_weight is not None:
        losses = [tf.reduce_sum(tf.multiply(loss, rank_weight)) for loss in losses]

    ops_loss, adj_loss = losses
    return ops_loss, adj_loss


class Trainer1(tf.keras.Model):
    """Phase-1 Keras trainer: fits the graph autoencoder on reconstruction + KL loss.

    The wrapped model is called with ``(ops, undirected_adjacency)`` and is
    expected to return ``(ops_cls, adj_cls, kl_loss, _)``. The optimised loss is
    ``ops_weight * ops_loss + adj_weight * adj_loss + kl_weight * kl_loss``.
    """

    def __init__(self, model: GraphAutoencoder):
        super(Trainer1, self).__init__()
        self.model = model
        # Weights of the individual terms of the total loss.
        self.ops_weight = 1
        self.adj_weight = 1
        self.kl_weight = 0.16

        self.loss_tracker = {
            'rec_loss': tf.keras.metrics.Mean(name="rec_loss"),
            'ops_loss': tf.keras.metrics.Mean(name="ops_loss"),
            'adj_loss': tf.keras.metrics.Mean(name="adj_loss"),
            'kl_loss': tf.keras.metrics.Mean(name="kl_loss")
        }

    def _compute_losses(self, data, training):
        """Run the model on one batch and return a dict with all loss terms."""
        x_batch, _ = data
        # The model consumes an undirected adjacency; the loss is still computed
        # against the original adjacency in ``x_batch``.
        undirected_x_batch = (x_batch[0], to_undiredted_adj(x_batch[1]))
        ops_cls, adj_cls, kl_loss, _ = self.model(undirected_x_batch, training=training)
        ops_loss, adj_loss = cal_ops_adj_loss_for_graph(x_batch, ops_cls, adj_cls)
        rec_loss = self.ops_weight * ops_loss + self.adj_weight * adj_loss + self.kl_weight * kl_loss
        return {'rec_loss': rec_loss, 'ops_loss': ops_loss, 'adj_loss': adj_loss, 'kl_loss': kl_loss}

    def _update_trackers(self, losses):
        """Feed the loss values to the running-mean metrics and return their results."""
        for key, value in losses.items():
            self.loss_tracker[key].update_state(value)
        return {key: value.result() for key, value in self.loss_tracker.items()}

    def train_step(self, data):
        """Perform one optimisation step on ``data`` and return the running metrics."""
        # The forward pass must run inside the tape so it can be differentiated.
        with tf.GradientTape() as tape:
            losses = self._compute_losses(data, training=True)

        grads = tape.gradient(losses['rec_loss'], self.model.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
        return self._update_trackers(losses)

    def test_step(self, data):
        """Evaluate one batch in inference mode and return the running metrics."""
        # Evaluation must not run the model in training mode, which matches
        # Trainer2.test_step in this file.
        losses = self._compute_losses(data, training=False)
        return self._update_trackers(losses)

    @property
    def metrics(self):
        """Metrics reset by Keras at the start of every epoch."""
        return list(self.loss_tracker.values())


class Trainer2(tf.keras.Model):
    """Phase-2 Keras trainer for the autoencoder + invertible predictor model.

    Every training step performs two optimizer updates:

    1. a forward update on ``w1 * reg_loss + w2 * latent_loss`` (plus the
       autoencoder reconstruction/KL loss when ``finetune`` is enabled);
    2. a backward update on ``w3 * rev_loss``, computed through
       ``model.inverse`` with the encoder and decoder frozen.

    Samples whose label contains NaN are excluded from the regression, latent
    and reverse losses through ``nan_mask``.
    """

    def __init__(self, model: TransformerAutoencoderNVP, x_dim, y_dim, z_dim, finetune=False, is_rank_weight=False):
        super(Trainer2, self).__init__()
        self.model = model
        self.x_dim = x_dim
        self.y_dim = y_dim
        self.z_dim = z_dim
        self.finetune = finetune
        self.is_rank_weight = is_rank_weight

        # For reg loss weight
        self.w1 = 5.
        # For latent loss weight
        self.w2 = 1.
        # For rev loss weight
        self.w3 = 10.

        # Rank weighting needs per-sample losses, hence no reduction in that case.
        if self.is_rank_weight:
            self.reduction = 'none'
        else:
            self.reduction = 'auto'

        self.reg_loss_fn = tf.keras.losses.MeanSquaredError(reduction=self.reduction)
        self.loss_latent = MMD_multiscale
        self.loss_backward = tf.keras.losses.MeanSquaredError(reduction=self.reduction)
        self.loss_tracker = {
            'total_loss': tf.keras.metrics.Mean(name="total_loss"),
            'reg_loss': tf.keras.metrics.Mean(name="reg_loss"),
            'latent_loss': tf.keras.metrics.Mean(name="latent_loss"),
            'rev_loss': tf.keras.metrics.Mean(name="rev_loss")
        }
        if self.finetune:
            self.loss_tracker.update({
                'rec_loss': tf.keras.metrics.Mean(name="rec_loss"),
                'ops_loss': tf.keras.metrics.Mean(name="ops_loss"),
                'adj_loss': tf.keras.metrics.Mean(name="adj_loss"),
                'kl_loss': tf.keras.metrics.Mean(name="kl_loss")
            })
            self.ops_weight = 1
            self.adj_weight = 1
            self.kl_weight = 0.16

    def cal_reg_and_latent_loss(self, y, z, y_out, nan_mask, rank_weight=None):
        """Return ``(reg_loss, latent_loss)`` for the forward direction.

        Args:
            y: labels, (batch_size, y_dim).
            z: latent samples paired with ``y``.
            y_out: forward output, (batch_size, num_nvp, y_dim + z_dim); the
                first ``z_dim`` entries of the last axis are read as z and the
                rest as y.
            nan_mask: boolean mask selecting the samples with a valid label.
            rank_weight: per-sample weights, used only when ``is_rank_weight``.
        """
        # y: (batch_size, y_dim)
        # y_out: (batch_size, num_nvp, y_dim + z_dim)
        # Repeat targets along a new axis 1 so they line up with every ensemble member.
        y = tf.boolean_mask(y, nan_mask)
        y = tf.expand_dims(y, axis=1)
        y = tf.repeat(y, repeats=tf.shape(y_out)[1], axis=1)
        z = tf.boolean_mask(z, nan_mask)
        z = tf.expand_dims(z, axis=1)
        z = tf.repeat(z, repeats=tf.shape(y_out)[1], axis=1)
        pred_y = tf.boolean_mask(y_out[:, :, self.z_dim:], nan_mask)
        pred_z = tf.boolean_mask(y_out[:, :, :self.z_dim], nan_mask)

        reg_loss = self.reg_loss_fn(y, pred_y)

        zy = tf.transpose(tf.concat([z, y], axis=-1), (1, 0, 2))  # (num_nvp, batch_size, y_dim + z_dim)
        pred_zy = tf.transpose(tf.concat([pred_z, pred_y], axis=-1), (1, 0, 2))  # (num_nvp, batch_size, y_dim + z_dim)
        latent_loss = tf.reduce_mean(tf.vectorized_map(lambda x: self.loss_latent(x[0], x[1]), (zy, pred_zy)))

        if self.is_rank_weight:
            # reg_loss (batch_size, num_nvp)
            reg_loss = tf.multiply(reg_loss, tf.broadcast_to(tf.expand_dims(rank_weight, -1), tf.shape(reg_loss)))
            reg_loss = tf.reduce_sum(tf.reduce_mean(reg_loss, axis=-1))
        return reg_loss, latent_loss

    def cal_rev_loss(self, undirected_x_batch_train, y, z, nan_mask, noise_scale, rank_weight=None, training=True):
        """Return the reverse loss between the encoding and the inverted ``[z, y]``.

        Args:
            undirected_x_batch_train: ``(ops, undirected_adjacency)`` batch.
            y: labels; Gaussian noise scaled by ``noise_scale`` is added.
            z: latent samples concatenated in front of ``y`` before inversion.
            nan_mask: boolean mask selecting the samples with a valid label.
            noise_scale: standard deviation of the noise added to ``y``.
            rank_weight: per-sample weights, used only when ``is_rank_weight``.
            training: mode used for the model call that produces the encoding.
                Defaults to ``True`` (the previous hard-coded value) so existing
                callers are unaffected.
        """
        y = tf.boolean_mask(y, nan_mask)
        z = tf.boolean_mask(z, nan_mask)
        y = y + noise_scale * tf.random.normal(shape=tf.shape(y), dtype=tf.float32)
        non_nan_x_batch = (tf.boolean_mask(undirected_x_batch_train[0], nan_mask),
                           tf.boolean_mask(undirected_x_batch_train[1], nan_mask))
        _, _, _, _, x_encoding = self.model(non_nan_x_batch, training=training)
        x_rev = self.model.inverse(tf.concat([z, y], axis=-1))  # (batch_size, num_nvp, x_dim)
        x_encoding = tf.expand_dims(x_encoding, axis=1)
        x_encoding = tf.repeat(x_encoding, repeats=tf.shape(x_rev)[1], axis=1)
        rev_loss = self.loss_backward(x_encoding, x_rev)
        if self.is_rank_weight:
            # rev_loss (batch_size, num_nvp)
            # The product must be assigned back; otherwise the weights are
            # silently ignored and the loss becomes an unweighted batch sum.
            rev_loss = tf.multiply(rev_loss, tf.broadcast_to(tf.expand_dims(rank_weight, -1), tf.shape(rev_loss)))
            rev_loss = tf.reduce_sum(tf.reduce_mean(rev_loss, axis=-1))
        return rev_loss

    def train_step(self, data):
        """Run the forward and the backward update on one batch; return running metrics."""
        if not self.finetune:
            # Without finetuning only the remaining (non-autoencoder) weights are trained.
            self.model.encoder.trainable = False
            self.model.decoder.trainable = False

        x_batch_train, y_batch_train = data
        undirected_x_batch_train = (x_batch_train[0], to_undiredted_adj(x_batch_train[1]))
        y = y_batch_train[:, -self.y_dim:]
        z = tf.random.normal([tf.shape(y_batch_train)[0], self.z_dim])
        # True for samples whose label has no NaN entry.
        nan_mask = ~tf.math.is_nan(tf.reduce_sum(y, axis=-1))
        # NOTE(review): rank_weight is computed on the NaN-filtered labels but is
        # also applied to the unfiltered ops/adj/kl losses when finetuning --
        # confirm batches never mix labelled and unlabelled samples in that mode.
        rank_weight = get_rank_weight(tf.boolean_mask(y, nan_mask)) if self.is_rank_weight else None

        # Forward loss and AE reconstruct loss. The forward pass runs inside the
        # tape so that it can be differentiated.
        with tf.GradientTape() as tape:
            ops_cls, adj_cls, kl_loss, y_out, x_encoding = self.model(undirected_x_batch_train, kl_reduction='none', training=True)

            # To avoid nan loss when batch size is small
            reg_loss, latent_loss = tf.cond(tf.reduce_any(nan_mask),
                                            lambda: self.cal_reg_and_latent_loss(y, z, y_out, nan_mask, rank_weight),
                                            lambda: (0., 0.))

            forward_loss = self.w1 * reg_loss + self.w2 * latent_loss
            rec_loss = 0.
            if self.finetune:
                ops_loss, adj_loss = cal_ops_adj_loss_for_graph(x_batch_train, ops_cls, adj_cls, self.reduction, rank_weight)
                if rank_weight is not None:
                    kl_loss = tf.reduce_sum(tf.multiply(kl_loss, rank_weight))
                else:
                    kl_loss = tf.reduce_mean(kl_loss)
                rec_loss = self.ops_weight * ops_loss + self.adj_weight * adj_loss + self.kl_weight * kl_loss
                forward_loss += rec_loss

        grads = tape.gradient(forward_loss, self.model.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))

        # Backward loss: the autoencoder is always frozen for this update.
        self.model.encoder.trainable = False
        self.model.decoder.trainable = False
        with tf.GradientTape() as tape:
            # To avoid nan loss when batch size is small
            rev_loss = tf.cond(tf.reduce_any(nan_mask),
                               lambda: self.cal_rev_loss(undirected_x_batch_train, y, z, nan_mask, 0.0001, rank_weight),
                               lambda: 0.)
            backward_loss = self.w3 * rev_loss

        grads = tape.gradient(backward_loss, self.model.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
        self.model.encoder.trainable = True
        self.model.decoder.trainable = True

        self.loss_tracker['total_loss'].update_state(forward_loss + backward_loss)
        self.loss_tracker['reg_loss'].update_state(reg_loss)
        self.loss_tracker['latent_loss'].update_state(latent_loss)
        self.loss_tracker['rev_loss'].update_state(rev_loss)
        if self.finetune:
            self.loss_tracker['rec_loss'].update_state(rec_loss)
            self.loss_tracker['ops_loss'].update_state(ops_loss)
            self.loss_tracker['adj_loss'].update_state(adj_loss)
            self.loss_tracker['kl_loss'].update_state(kl_loss)

        return {key: value.result() for key, value in self.loss_tracker.items()}

    def test_step(self, data):
        """Evaluate all losses on one batch in inference mode; return running metrics."""
        x_batch_train, y_batch_train = data
        undirected_x_batch_train = (x_batch_train[0], to_undiredted_adj(x_batch_train[1]))
        y = y_batch_train[:, -self.y_dim:]
        z = tf.random.normal(shape=[tf.shape(y_batch_train)[0], self.z_dim])
        # True for samples whose label has no NaN entry.
        nan_mask = ~tf.math.is_nan(tf.reduce_sum(y, axis=-1))
        rank_weight = get_rank_weight(tf.boolean_mask(y, nan_mask)) if self.is_rank_weight else None

        ops_cls, adj_cls, kl_loss, y_out, x_encoding = self.model(undirected_x_batch_train, kl_reduction='none', training=False)
        reg_loss, latent_loss = tf.cond(tf.reduce_any(nan_mask),
                                        lambda: self.cal_reg_and_latent_loss(y, z, y_out, nan_mask, rank_weight),
                                        lambda: (0., 0.))
        forward_loss = self.w1 * reg_loss + self.w2 * latent_loss
        # No label noise and no training-mode behaviour during evaluation.
        rev_loss = tf.cond(tf.reduce_any(nan_mask),
                           lambda: self.cal_rev_loss(undirected_x_batch_train, y, z, nan_mask, 0., rank_weight,
                                                     training=False),
                           lambda: 0.)
        backward_loss = self.w3 * rev_loss
        if self.finetune:
            ops_loss, adj_loss = cal_ops_adj_loss_for_graph(x_batch_train, ops_cls, adj_cls, self.reduction, rank_weight)
            if rank_weight is not None:
                kl_loss = tf.reduce_sum(tf.multiply(kl_loss, rank_weight))
            else:
                kl_loss = tf.reduce_mean(kl_loss)

            rec_loss = self.ops_weight * ops_loss + self.adj_weight * adj_loss + self.kl_weight * kl_loss
            forward_loss += rec_loss

        self.loss_tracker['total_loss'].update_state(forward_loss + backward_loss)
        self.loss_tracker['reg_loss'].update_state(reg_loss)
        self.loss_tracker['latent_loss'].update_state(latent_loss)
        self.loss_tracker['rev_loss'].update_state(rev_loss)
        if self.finetune:
            self.loss_tracker['rec_loss'].update_state(rec_loss)
            self.loss_tracker['ops_loss'].update_state(ops_loss)
            self.loss_tracker['adj_loss'].update_state(adj_loss)
            self.loss_tracker['kl_loss'].update_state(kl_loss)

        return {key: value.result() for key, value in self.loss_tracker.items()}

    @property
    def metrics(self):
        """Metrics reset by Keras at the start of every epoch."""
        return list(self.loss_tracker.values())


def train(phase: int, model, loader, train_epochs, logdir, callbacks=None, x_dim=None, y_dim=None,
          z_dim=None, finetune=False, learning_rate=1e-3, no_valid=False, is_rank_weight=False):
    """Build the trainer for ``phase``, fit it and save the model weights.

    Args:
        phase: 1 selects ``Trainer1``; 2 selects ``Trainer2``.
        model: the model wrapped by the trainer.
        loader: mapping with a ``'train'`` loader and, unless ``no_valid``, a
            ``'valid'`` loader; each must provide ``load()``.
        train_epochs: number of epochs passed to ``fit``.
        logdir: directory where ``modelGAE_weights_phase{phase}`` is written.
        callbacks: Keras callbacks passed to ``fit``.
        x_dim, y_dim, z_dim, finetune: forwarded to ``Trainer2`` (phase 2 only).
        learning_rate: Adam learning rate. For phase 2 with ``no_valid`` it is
            the initial value of a cosine decay over all training steps.
        no_valid: train without validation data.
        is_rank_weight: forwarded to ``Trainer2`` (phase 2 only).

    Returns:
        The fitted trainer.

    Raises:
        ValueError: if ``phase`` is neither 1 nor 2.
    """
    if phase == 1:
        trainer = Trainer1(model)
    elif phase == 2:
        trainer = Trainer2(model, x_dim, y_dim, z_dim, finetune, is_rank_weight)
    else:
        raise ValueError('phase should be 1 or 2')

    # Step counts are optional: loaders that do not expose them fall back to
    # the Keras defaults.
    try:
        if no_valid:
            kw = {'steps_per_epoch': loader['train'].steps_per_epoch}
        else:
            kw = {'validation_steps': loader['valid'].steps_per_epoch,
                  'steps_per_epoch': loader['train'].steps_per_epoch}
    except (AttributeError, KeyError, TypeError):
        kw = {}

    if phase == 2 and no_valid:
        learning_rate = tf.keras.optimizers.schedules.CosineDecay(learning_rate, train_epochs * loader['train'].steps_per_epoch)

    trainer.compile(optimizer=tf.keras.optimizers.Adam(learning_rate), run_eagerly=False)

    train_data = loader['train'].load()
    fit_kwargs = dict(epochs=train_epochs, callbacks=callbacks, **kw)
    if not no_valid:
        fit_kwargs['validation_data'] = loader['valid'].load()
    trainer.fit(train_data, **fit_kwargs)

    model.save_weights(os.path.join(logdir, f'modelGAE_weights_phase{phase}'))
    return trainer

def to_loader(datasets, batch_size: int, epochs: int, repeat: bool = True):
    """Wrap every dataset split in a ``PackedBatchLoader``.

    The input is deep-copied first, so the padding below never touches the
    caller's datasets.

    Args:
        datasets: mapping from split name to dataset; must contain ``'train'``.
        batch_size: batch size of every loader.
        epochs: number of training epochs; repeating loaders are created with
            ``epochs + 1``.
        repeat: when true, the train split is padded with randomly re-drawn
            graphs up to a multiple of ``batch_size``, and the ``'train'`` and
            ``'valid'`` loaders are shuffled and repeated. When false, every
            split gets a single-pass, unshuffled loader.

    Returns:
        dict: split name -> loader. With ``repeat`` enabled, splits other than
        ``'train'``, ``'valid'``, ``'test'`` and ``'train_1'`` are omitted.
    """
    loader = {}

    # Create a deep copy of the datasets to avoid modifying the original one
    datasets_copy = deepcopy(datasets)

    # Ensure 'train' dataset length is a multiple of batch_size
    train_graphs = datasets_copy['train'].graphs
    remainder = len(train_graphs) % batch_size
    if remainder != 0 and repeat:
        num_extra = batch_size - remainder
        extra_indices = np.random.choice(len(train_graphs), num_extra, replace=True)
        train_graphs.extend([train_graphs[i] for i in extra_indices])

    for key, value in datasets_copy.items():
        if repeat and key in ('train', 'valid'):
            loader[key] = PackedBatchLoader(value, batch_size=batch_size, shuffle=True, epochs=epochs + 1)
        elif key in ('test', 'train_1') or not repeat:
            loader[key] = PackedBatchLoader(value, batch_size=batch_size, shuffle=False, epochs=1)

    return loader

def get_new_archs_and_add_to_dataset(dataset_name, datasets, top_k, repeat,
                                     found_arch_list_set, visited, top_acc_list, top_test_acc_list):
    """
    Keep the ``top_k`` candidates with the highest ``'y'``, query their tabular
    accuracy and add the ones not yet in ``visited`` to the training data.

    The selected candidate dicts get their ``'y'`` overwritten in place with the
    queried validation accuracy. ``top_acc_list`` and ``top_test_acc_list`` are
    extended in place. New graphs are appended ``repeat`` times to
    ``datasets['train']`` and once to ``datasets['train_1']``.

    Returns ``(num_new_found, selected)``: the number of added architectures
    that were in neither the training nor the validation data, and the
    selected top-k candidate list.
    """
    logger = logging.getLogger(__name__)
    is_nb101 = dataset_name == 'nb101'
    num_new_found = 0

    # Rank by current 'y' and keep the best top_k, then replace 'y' by the
    # queried validation accuracy.
    found_arch_list_set = sorted(found_arch_list_set, key=lambda g: g['y'], reverse=True)[:top_k]
    acc_list = query_tabular(dataset_name, found_arch_list_set)
    top_acc_list.extend([entry['valid-accuracy'] for entry in acc_list])
    top_test_acc_list.extend([entry['test-accuracy'] for entry in acc_list])
    for idx, arch in enumerate(found_arch_list_set):
        arch['y'] = np.array([acc_list[idx]['valid-accuracy']])

    if len(top_acc_list) == 0:
        logger.info('Top acc list is [] in this run')
    else:
        logger.info('Top acc list: {}'.format(top_acc_list))
        logger.info('Top test acc list: {}'.format(top_test_acc_list))
        logger.info(f'Avg found acc {sum(top_acc_list) / len(top_acc_list)}')
        logger.info(f'Best found acc {max(top_acc_list)}')
        logger.info(f'Avg found test acc {sum(top_test_acc_list) / len(top_test_acc_list)}')
        logger.info(f'Best found test acc {max(top_test_acc_list)}')

    # Identifiers of the labelled validation graphs.
    def valid_key(graph):
        if is_nb101:
            return nb101_dataset.get_spec_hash(graph.a, np.argmax(graph.x, axis=-1))
        return graph_to_str(graph)

    valid_visited = {valid_key(g): g.y.tolist() for g in datasets['valid'].graphs if not np.isnan(g.y).any()}

    # Add top found architecture to training dataset
    for arch in found_arch_list_set:
        if is_nb101:
            a, x = mask_padding_vertex_for_spec(arch['a'], arch['x'])
            graph_str = nb101_dataset.get_spec_hash(a, np.argmax(x, axis=-1))
        else:
            graph_str = graph_to_str(arch)

        if graph_str in visited:
            continue

        if graph_str in valid_visited:
            logger.info(f'Data in valid but not in train {arch["y"].tolist()}')
        else:
            logger.info(f'Data not in train and not in valid {arch["y"].tolist()}')
            num_new_found += 1

        logger.info(f'Add to train {arch["x"].tolist()} {arch["a"].tolist()} {arch["y"].tolist()}')
        datasets['train'].graphs.extend([Graph(x=arch['x'], a=arch['a'], y=arch['y'])] * repeat)
        datasets['train_1'].graphs.extend([Graph(x=arch['x'], a=arch['a'], y=arch['y'])])

    return num_new_found, found_arch_list_set


# Module-level architecture sampler read by retrain() when random_sample is
# enabled. It starts out as None; presumably it is assigned elsewhere before
# that code path runs -- verify against the script entry point.
random_arch_generator: RandomArchGenerator = None


def sample_arch_candidates(model, dataset_name, x_dim, z_dim, visited, sample_amount=200):
    """Sample up to ``sample_amount`` distinct, unvisited architectures from the model.

    Candidates are drawn with ``eval_query_best`` at increasing noise levels.
    Each noise level is tried up to ``max_retry`` times before moving on to the
    next one; sampling stops as soon as enough candidates were collected.

    Args:
        model: model passed to ``eval_query_best``; its ``num_nvp`` attribute
            divides the per-call query amount.
        dataset_name: ``'nb101'`` selects spec-hash identifiers, anything else
            uses ``graph_to_str``.
        x_dim, z_dim: forwarded to ``eval_query_best``.
        visited: container of identifiers that must not be returned again.
        sample_amount: maximum number of architectures to return.

    Returns:
        list: the collected architectures; may hold fewer than
        ``sample_amount`` entries when every noise level is exhausted.
    """
    logger = logging.getLogger(__name__)
    found_arch_list_set = []
    # Identifiers accepted in earlier rounds. Only used for membership tests,
    # hence a set. Candidates dropped by the cap below stay recorded here.
    visited_arch = set()
    max_retry = 10
    # Noise level and the matching over-sampling factor, tried in order.
    noise_std_list = [0.0, 1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 5e-2, 0.1, 0.2]
    amount_scale_list = [1.0, 1.2, 1.4, 1.6, 1.8, 2.0, 2.2, 2.4, 2.8, 3.0, 3.0]

    for noise_std, amount_scale in zip(noise_std_list, amount_scale_list):
        if len(found_arch_list_set) >= sample_amount:
            break

        retry = 0
        while len(found_arch_list_set) < sample_amount and retry < max_retry:
            _, _, _, found_arch_list = eval_query_best(model, dataset_name, x_dim, z_dim,
                                                       noise_scale=noise_std,
                                                       query_amount=int(sample_amount * amount_scale) // model.num_nvp)
            if dataset_name == 'nb101':
                found_arch_list = list(map(mask_for_model, found_arch_list))
                found_arch_list = [arch for arch in found_arch_list
                                   if arch is not None and arch['x'] is not None]
                found_in_this_round = []
                for arch in found_arch_list:
                    a, x = mask_padding_vertex_for_spec(arch['a'], arch['x'])
                    spec_hash = nb101_dataset.get_spec_hash(a, np.argmax(x, axis=-1))
                    if spec_hash not in visited_arch and spec_hash not in visited:
                        found_in_this_round.append(arch)
                        visited_arch.add(spec_hash)
            else:
                unseen = []
                for arch in found_arch_list:
                    key = graph_to_str(arch)
                    if key not in visited_arch and key not in visited:
                        unseen.append(arch)
                # Duplicates within this round are removed by arch_list_to_set.
                found_in_this_round = arch_list_to_set(unseen)
                visited_arch.update(map(graph_to_str, found_in_this_round))

            # Never return more than requested: keep a random subset of this round.
            if len(found_in_this_round) + len(found_arch_list_set) > sample_amount:
                random.shuffle(found_in_this_round)
                found_in_this_round = found_in_this_round[: sample_amount - len(found_arch_list_set)]

            found_arch_list_set.extend(found_in_this_round)
            retry += 1

        logger.info(f'std scale {noise_std}, num sample {len(found_arch_list_set)}')

    return found_arch_list_set


def predict_arch_acc(found_arch_list_set, model, theta):
    """
    Score every architecture with the model's ensemble output and store the
    score in place under the ``'y'`` key.

    The score is the ``theta``-weighted sum, over the ensemble axis, of the
    last output channel. Nothing is written when the list is empty.
    """
    x = tf.stack([tf.constant(arch['x']) for arch in found_arch_list_set])
    a = tf.stack([tf.constant(arch['a']) for arch in found_arch_list_set])
    if tf.shape(x)[0] == 0:
        return

    _, _, _, reg, _ = model((x, to_undiredted_adj(a)), training=False)  # (batch, num_nvp, z_dim+y_dim)
    per_member = reg[:, :, -1]  # (batch, num_nvp)
    # (1, num_nvp) weights broadcast over the batch axis.
    weights = tf.expand_dims(theta, axis=0)
    scores = tf.reduce_sum(per_member * weights, axis=-1)  # (batch, )

    for idx, arch in enumerate(found_arch_list_set):
        # Assign the predicted accuracy to the architecture
        arch['y'] = scores[idx]


def retrain(trainer, datasets, dataset_name, batch_size, train_epochs, logdir, repeat, top_k=5, random_sample=False):
    """Run one search iteration: sample, rank, query, extend the data and refit.

    Steps:
        1. Collect up to 100 candidate architectures that are not in the
           labelled training data (from the model, or from the module-level
           ``random_arch_generator`` when ``random_sample`` is set).
        2. Score them with the ensemble, weighted by ``get_theta`` computed on
           ``datasets['train_1']``.
        3. Query the true accuracy of the ``top_k`` best and add the new ones
           to ``datasets['train']`` / ``datasets['train_1']`` (in place).
        4. Continue fitting ``trainer`` on the extended data with early
           stopping on ``val_total_loss``.

    Args:
        trainer: compiled trainer exposing ``model``, ``x_dim`` and ``z_dim``.
        datasets: mapping with ``'train'``, ``'train_1'`` and ``'valid'`` splits.
        dataset_name: ``'nb101'`` or another benchmark name.
        batch_size: batch size of the loaders.
        train_epochs: maximum number of epochs for the refit.
        logdir: directory for the CSV learning curve.
        repeat: number of copies of every new graph added to the train split.
        top_k: number of candidates whose true accuracy is queried.
        random_sample: draw candidates randomly instead of from the model.

    Returns:
        tuple: ``(top_acc_list, top_test_acc_list, top_arch_list_set, num_new_found)``.

    Raises:
        RuntimeError: if ``random_sample`` is set but ``random_arch_generator``
            has not been initialised.
    """
    logger = logging.getLogger(__name__)
    candidate_amount = 100

    # Identifiers of the architectures that already carry a label in the train split.
    if dataset_name == 'nb101':
        visited = {nb101_dataset.get_spec_hash(i.a, np.argmax(i.x, axis=-1)): i.y.tolist()
                   for i in list(map(graph_to_spec_graph, datasets['train'].graphs)) if not np.isnan(i.y).any()}
    else:
        visited = {graph_to_str(i): i.y.tolist() for i in datasets['train'].graphs if not np.isnan(i.y).any()}

    # Generate up to ``candidate_amount`` architectures
    if random_sample:
        if random_arch_generator is None:
            raise RuntimeError('random_sample=True requires random_arch_generator to be initialised')
        found_arch_list_set = random_arch_generator.random_sample(visited, sample_amount=candidate_amount)
    else:
        found_arch_list_set = sample_arch_candidates(trainer.model, dataset_name, trainer.x_dim, trainer.z_dim, visited,
                                                     sample_amount=candidate_amount)

    theta = get_theta(trainer.model, datasets['train_1'])

    logger.info(f'Theta: {theta.numpy().tolist()}')

    logger.info(f'Length of found_arch_list_set {len(found_arch_list_set)}')
    top_acc_list = []
    top_test_acc_list = []

    # Predict accuracy by INN (performance predictor)
    predict_arch_acc(found_arch_list_set, trainer.model, theta)
    found_arch_list_set = sorted(found_arch_list_set, key=lambda g: g['y'], reverse=True)

    num_new_found, top_arch_list_set = get_new_archs_and_add_to_dataset(dataset_name, datasets, top_k,
                                                                        repeat, found_arch_list_set, visited,
                                                                        top_acc_list, top_test_acc_list)

    logger.info(f'{datasets["train"]}')
    logger.info(f'{datasets["train_1"]}')

    loader = to_loader(datasets, batch_size, train_epochs, True)
    callbacks = [CSVLogger(os.path.join(logdir, "learning_curve_phase2_retrain.csv")),
                 EarlyStopping(monitor='val_total_loss', patience=10, restore_best_weights=True)
                 ]

    trainer.fit(loader['train'].load(),
                validation_data=loader['valid'].load(),
                epochs=train_epochs,
                callbacks=callbacks,
                steps_per_epoch=loader['train'].steps_per_epoch,
                validation_steps=loader['valid'].steps_per_epoch
                )

    return top_acc_list, top_test_acc_list, top_arch_list_set, num_new_found


def get_theta(model, dataset, eps: float = 1e-8):
    """Compute ensemble weights from each member's error on ``dataset``.

    The whole dataset is evaluated as one batch. For every ensemble member the
    mean squared error between the last output channel and the labels is
    computed; the weights are the normalised inverse errors, so they sum to 1.

    Args:
        model: model returning a 5-tuple whose fourth element has shape
            (n, num_nvp, z_dim + y_dim).
        dataset: labelled dataset; every ``y`` must squeeze to a single scalar.
        eps: added to the errors to avoid division by zero.

    Returns:
        tf.Tensor: weights of shape (num_nvp,).

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    n = len(dataset)
    if n == 0:
        raise ValueError('get_theta requires a non-empty dataset')

    loader = PackedBatchLoader(dataset, batch_size=n, epochs=1, shuffle=False)
    (x, a), _ = next(loader.load())
    a = to_undiredted_adj(a)
    y_true = tf.constant([float(np.squeeze(d.y)) for d in dataset], dtype=tf.float32)  # (n,)
    y_true = tf.expand_dims(y_true, 1)  # (n, 1)

    _, _, _, regs, _ = model((x, a), training=False)  # regs: (n, num_nvp, z_dim + y_dim)
    # Indexing already yields (n, num_nvp). Do not squeeze: with a single
    # ensemble member that would drop the member axis and make the subtraction
    # below broadcast to (n, n).
    preds = regs[:, :, -1]

    mse_per_model = tf.reduce_mean(tf.square(preds - y_true), axis=0)  # (num_nvp,)

    inv_mse = 1.0 / (mse_per_model + eps)
    theta = inv_mse / tf.reduce_sum(inv_mse)  # (num_nvp,)

    return theta

def prepare_model(num_nvp, nvp_config, latent_dim, num_layers, d_model, num_heads, dff, num_ops, num_nodes, num_adjs,
                  dropout_rate, eps_scale):
    """Instantiate and build the plain autoencoder and the ensemble-NVP model.

    Both models receive the same autoencoder hyper-parameters. Each one is
    called once on a random single-graph batch so that its weights are created
    before the caller loads or copies weights.

    Returns:
        ``(pretrained_model, model)``: the ``GraphAutoencoder`` and the
        ``GraphAutoencoderEnsembleNVP`` with ``num_nvp`` members.
    """
    autoencoder_kwargs = dict(latent_dim=latent_dim, num_layers=num_layers,
                              d_model=d_model, num_heads=num_heads,
                              dff=dff, num_ops=num_ops, num_nodes=num_nodes,
                              num_adjs=num_adjs, dropout_rate=dropout_rate, eps_scale=eps_scale)

    def _dummy_batch():
        # Fresh random (ops, adjacency) pair per build call, ops drawn first.
        ops = tf.random.normal(shape=(1, num_nodes, num_ops))
        adj = tf.random.normal(shape=(1, num_nodes, num_nodes))
        return ops, adj

    pretrained_model = GraphAutoencoder(**autoencoder_kwargs)
    pretrained_model(_dummy_batch())

    model = GraphAutoencoderEnsembleNVP(num_nvp, nvp_config=nvp_config, **autoencoder_kwargs)
    model(_dummy_batch())

    return pretrained_model, model

def split_datasets_by_op(datasets, op_index=3):
    """Partition graphs by whether column ``op_index`` of ``graph.x`` is used.

    A graph counts as containing the operation when the sum of that column of
    its ``x`` matrix is strictly positive.

    Returns:
        ``(contains_op, not_contains_op)``: two lists holding the original
        graph objects, each in input order.
    """
    with_op, without_op = [], []
    for g in datasets:
        uses_op = g.x[:, op_index].sum() > 0
        bucket = with_op if uses_op else without_op
        bucket.append(g)
    return with_op, without_op

def pad_vectors_to_same_dim(vectors_a, vectors_b):
    """Pad the shorter set of vectors with Gaussian noise to a common length.

    The dimensionality of each list is read from its first element. Every
    vector of the lower-dimensional list is extended in place with samples
    from N(0, 0.15) and cast to float32. When both dimensions already match,
    nothing is modified.

    Returns:
        The same two list objects, ``(vectors_a, vectors_b)``.
    """
    dim_a = vectors_a[0].shape[0]
    dim_b = vectors_b[0].shape[0]
    print("dim_a",dim_a)
    print("dim_b",dim_b)

    def _extend_in_place(vectors, missing):
        # One independent noise draw per vector, in list order.
        for pos, vec in enumerate(vectors):
            noise = np.random.normal(loc=0, scale=0.15, size=(missing,))
            vectors[pos] = np.concatenate([vec, noise]).astype(np.float32)

    if dim_a > dim_b:
        _extend_in_place(vectors_b, dim_a - dim_b)
    elif dim_b > dim_a:
        _extend_in_place(vectors_a, dim_b - dim_a)

    return vectors_a, vectors_b

class EncodedDataset(Dataset):
    """Torch dataset pairing source vectors with target vectors.

    The two collections are shuffled jointly once at construction time, so
    element ``i`` of one stays paired with element ``i`` of the other.
    Each item is a dict with float32 tensors under ``'src'`` and ``'tgt'``.
    """

    def __init__(self, vectors_a, vectors_b):
        shuffled_a, shuffled_b = shuffle(vectors_a, vectors_b)
        self.vectors_a = shuffled_a
        self.vectors_b = shuffled_b

    def __len__(self):
        # Length is defined by the source side.
        return len(self.vectors_a)

    def __getitem__(self, idx):
        source = torch.tensor(self.vectors_a[idx], dtype=torch.float32)
        target = torch.tensor(self.vectors_b[idx], dtype=torch.float32)
        return {'src': source, 'tgt': target}

# Operation names indexed by the integer codes found in the CSV ``spec`` column
# (convert_csv_spec_to_nb201_arch looks each code up in this list).
OP_NAMES = ["Identity", "Zero", "ReLUConvBN3x3", "ReLUConvBN1x1", "AvgPool1x1"]

# Translates each CSV operation name into the NAS-Bench-201 operation string.
# NOTE(review): "AvgPool1x1" maps to "avg_pool_3x3"; presumably the CSV name is
# only a label for the same pooling op -- confirm against the CSV producer.
csv_to_nb201_op = {
    "Identity":      "skip_connect",
    "Zero":          "none",
    "ReLUConvBN3x3": "nor_conv_3x3",
    "ReLUConvBN1x1": "nor_conv_1x1",
    "AvgPool1x1":    "avg_pool_3x3",
}

# Integer id of every node label (the two terminals plus the five operations).
# convert_csv_spec_to_nb201_arch uses the id as the one-hot column in np.eye(7).
# NOTE(review): presumably mirrors the ordering of OP_PRIMITIVES_NB201 -- verify.
NB201_op_to_index = {
    'input':        1,
    'output':       0,
    'nor_conv_1x1': 2,
    'nor_conv_3x3': 3,
    'avg_pool_3x3': 4,
    'skip_connect': 5,
    'none':         6
}

def convert_csv_spec_to_nb201_arch(csv_path, dataset_name):
    """Turn the ``spec`` column of a CSV file into labelled NAS-Bench-201 graphs.

    Each ``spec`` cell is a comma-separated list of six integer operation
    codes, optionally wrapped in parentheses. Rows that do not split into
    exactly six tokens are skipped. For every remaining row the codes are
    translated into node ids (with ``input`` prepended and ``output``
    appended), the accuracy is looked up with ``query_acc_by_ops`` and a
    ``Graph`` is built from the one-hot operations, the fixed ``ADJACENCY``
    matrix and the accuracy.

    Args:
        csv_path: Path of the CSV file; it must contain a ``spec`` column.
        dataset_name: Passed through to ``query_acc_by_ops``.

    Returns:
        List of ``Graph`` objects, in CSV row order, with float32 ``x`` of
        shape (8, 7), float32 ``a`` and a one-element float32 ``y``.
    """
    frame = pd.read_csv(csv_path)
    graphs = []

    for _, record in frame.iterrows():
        spec = record['spec'].strip()
        wrapped = spec.startswith("(") and spec.endswith(")")
        if wrapped:
            spec = spec[1:-1]

        fields = [piece.strip() for piece in spec.split(',')]
        if len(fields) != 6:
            continue

        # Parse every code first so a malformed token fails before any lookup.
        op_codes = list(map(int, fields))                       # ex: [0,2,4,4,1,2]
        mid_ops = [csv_to_nb201_op[OP_NAMES[code]] for code in op_codes]
        # ex: ["skip_connect", "nor_conv_3x3", "avg_pool_3x3", ...]

        # Terminal nodes surround the six operations: input first, output last.
        node_ids = [NB201_op_to_index[label] for label in ['input', *mid_ops, 'output']]

        accuracy = query_acc_by_ops(node_ids, dataset_name, is_random=False)
        one_hot_ops = np.eye(7)[node_ids]

        graphs.append(Graph(x=one_hot_ops.astype(np.float32),
                            a=ADJACENCY.astype(np.float32),
                            y=np.array([accuracy], dtype=np.float32)))

    return graphs

def main(seed, dataset_name, train_sample_amount, valid_sample_amount, query_budget,
         top_k, finetune, retrain_finetune, is_rank_weight, random_sample, num_couples, n_couple_layer, n_hid_layer, n_hid_dim):
    """Run the two-phase autoencoder + ensemble-NVP search and report the best accuracies.

    Phase 1 trains the plain graph autoencoder on the unlabelled architectures
    (monitoring ``val_rec_loss``). Phase 2 trains the ensemble-NVP model on an
    initial labelled set and then repeatedly calls ``retrain`` until the query
    budget is exhausted, the run limit is hit, or the dataset's target
    accuracy is exceeded.

    Args:
        seed: Random seed; also used as the split and masking seed.
        dataset_name: NAS-Bench-201 dataset name. ``'nb101'`` is not
            supported by this script and raises ``NotImplementedError``.
        train_sample_amount: Number of training labels requested from
            ``mask_graph_dataset`` and counted as already queried.
        valid_sample_amount: Same, for the validation set.
        query_budget: Maximum number of architecture queries.
        top_k: Forwarded to ``retrain``.
        finetune: Forwarded to ``train`` for phase 2.
        retrain_finetune: Forwarded to ``Trainer2``.
        is_rank_weight: Forwarded to ``Trainer2``.
        random_sample: Forwarded to ``retrain``.
        num_couples, n_couple_layer, n_hid_layer, n_hid_dim: Stored in the
            NVP configuration dict handed to ``prepare_model``.

    Returns:
        ``(best_valid_acc, best_test_acc, record_top)`` where ``record_top``
        holds, per retrain run, the top-5 validation/test accuracies keyed by
        the number of queries made so far.

    Raises:
        NotImplementedError: If ``dataset_name`` is ``'nb101'``.
    """
    logdir, logger = get_logdir_and_logger(os.path.join(f'{train_sample_amount}_{valid_sample_amount}_{query_budget}_top{top_k}_finetune{finetune}_rfinetune{retrain_finetune}_rank{is_rank_weight}_ensemble_{num_couples}NN_{n_couple_layer}*{n_hid_layer}*{n_hid_dim}',
                                                        dataset_name), f'trainGAE_ensemble_{seed}.log')
    random_seed = seed
    set_global_determinism(random_seed)

    is_only_validation_data = True
    train_phase = [0, 1]

    pretrained_weight = 'logs/phase1_nb201_CE_64/modelGAE_weights_phase1'

    eps_scale = 0.05  # 0.1
    d_model = 32
    dropout_rate = 0.0
    dff = 256
    num_layers = 3
    num_heads = 3

    latent_dim = 16

    global random_arch_generator
    random_arch_generator = RandomArchGenerator(dataset_name, is_only_validation_data)

    cache_path = f'datasets/NasBench201Dataset_{dataset_name}.cache'

    if dataset_name != 'nb101':
        num_ops = len(OP_PRIMITIVES_NB201)  # 7
        num_nodes = 8
        num_adjs = num_nodes ** 2
        label_epochs = 200
        if os.path.exists(cache_path):
            # NOTE: pickle is only safe on trusted, locally produced cache files.
            with open(cache_path, 'rb') as f:
                datasets = pickle.load(f)
        else:
            datasets = NasBench201Dataset(start=0, end=15624, dataset=dataset_name, hp=str(label_epochs), seed=False)
            with open(cache_path, 'wb') as f:
                pickle.dump(datasets, f)
        # Collapse every label to a single scalar taken from row 1, last column.
        for graph in datasets.graphs:
            graph.y = np.array([[graph.y[1, -1]]]).astype(np.float32)

        csv_file = "./init_arch/top30_architectures_by_nwot_nb201_noconv33.csv"
        arch_list = convert_csv_spec_to_nb201_arch(csv_file, dataset_name)
    else:
        # The nb101 setup was never written here; without this guard the code
        # below failed with an UnboundLocalError on `datasets`.
        raise NotImplementedError("dataset 'nb101' is not supported by this script")

    datasets = train_valid_test_split_dataset(datasets,
                                              ratio=[0.8, 0.1, 0.1],
                                              shuffle=True,
                                              shuffle_seed=random_seed)

    x_dim = latent_dim * num_nodes  # 128
    y_dim = 1
    z_dim = x_dim - 1  # 127
    tot_dim = y_dim + z_dim  # 128, equal to x_dim so no padding is needed

    num_nvp = 10
    nvp_config = {
        'n_couple_layer': n_couple_layer,
        'n_hid_layer': n_hid_layer,
        'n_hid_dim': n_hid_dim,
        'name': 'NVP',
        'num_couples': num_couples,
        'inp_dim': tot_dim
    }

    pretrained_model, model = prepare_model(num_nvp, nvp_config, latent_dim, num_layers, d_model, num_heads, dff,
                                            num_ops, num_nodes, num_adjs, dropout_rate, eps_scale)
    model.summary(print_fn=logger.info)

    tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)

    if train_phase[0]:
        logger.info('Train phase 1')
        train_epochs = 500
        patience = 100
        batch_size = 64
        loader = to_loader(datasets, batch_size, train_epochs, True)
        callbacks = [CSVLogger(os.path.join(logdir, "learning_curve_phase1.csv")),
                     tf.keras.callbacks.ReduceLROnPlateau(monitor='val_rec_loss', factor=0.1, patience=patience // 2,
                                                          verbose=1, min_lr=1e-5),
                     tensorboard_callback,
                     EarlyStopping(monitor='val_rec_loss', patience=patience, restore_best_weights=True)]
        trainer = train(1, pretrained_model, loader, train_epochs, logdir, callbacks)
        results = trainer.evaluate(loader['test'].load(), steps=loader['test'].steps_per_epoch)
        logger.info(f'{dict(zip(trainer.metrics_names, results))}')
    else:
        pretrained_model.load_weights(pretrained_weight)

    # Load AE weights from pretrained model
    model.encoder.set_weights(pretrained_model.encoder.get_weights())
    model.decoder.set_weights(pretrained_model.decoder.get_weights())

    global_top_acc_list = []
    global_top_test_acc_list = []
    global_top_arch_list = []
    record_top = {'valid': [], 'test': []}

    # Reload the cached dataset: the labels of the first copy were rewritten
    # in place above, while phase 2 applies its own label transforms below.
    with open(cache_path, 'rb') as f:
        datasets = pickle.load(f)
    datasets = train_valid_test_split_dataset(datasets,
                                              ratio=[0.8, 0.1, 0.1],
                                              shuffle=True,
                                              shuffle_seed=random_seed)

    for key in datasets:
        if is_only_validation_data:
            datasets[key].apply(OnlyValidAccTransform())
            datasets[key].apply(OnlyFinalAcc())
            if dataset_name != 'nb101':
                datasets[key].apply(LabelScale(scale=0.01))
        else:
            datasets[key].apply(ReshapeYTransform())

    if train_phase[1]:
        batch_size = 64
        train_epochs = 500
        retrain_epochs = 50
        patience = 50
        repeat_label = 20
        now_queried = train_sample_amount + valid_sample_amount
        logger.info('Train phase 2')
        datasets['train_1'] = mask_graph_dataset(datasets['train'], train_sample_amount, 1, random_seed=random_seed)
        datasets['valid_1'] = mask_graph_dataset(datasets['valid'], valid_sample_amount, 1, random_seed=random_seed)
        datasets['train_1'].filter(lambda g: not np.isnan(g.y))
        datasets['valid_1'].filter(lambda g: not np.isnan(g.y))
        # The randomly masked training graphs are discarded and replaced by
        # the architectures read from the CSV file.
        # NOTE(review): the seed count is hard-coded to 30 while `now_queried`
        # assumes `train_sample_amount`; they agree only for the default of 30.
        datasets['train_1'].graphs = []
        for arch in arch_list[:30]:
            print(arch['y'])
            datasets['train_1'].graphs.append(Graph(x=arch['x'],a=arch['a'],y=arch['y']))

        # Add initial data to records
        acc_list = query_tabular(dataset_name, datasets['train_1'])
        global_top_acc_list.extend([i['valid-accuracy'] for i in acc_list])
        global_top_test_acc_list.extend([i['test-accuracy'] for i in acc_list])
        acc_list = query_tabular(dataset_name, datasets['valid_1'])
        global_top_acc_list.extend([i['valid-accuracy'] for i in acc_list])
        global_top_test_acc_list.extend([i['test-accuracy'] for i in acc_list])

        datasets['train'] = repeat_graph_dataset_element(datasets['train_1'], repeat_label)
        datasets['valid'] = repeat_graph_dataset_element(datasets['valid_1'], repeat_label)

        loader = to_loader(datasets, batch_size, train_epochs, True)
        callbacks = [CSVLogger(os.path.join(logdir, f"learning_curve_phase2.csv")),
                     tf.keras.callbacks.ReduceLROnPlateau(monitor='val_total_loss', factor=0.1,
                                                          patience=patience // 2, verbose=1, min_lr=1e-5),
                     EarlyStopping(monitor='val_total_loss', patience=patience, restore_best_weights=True)
                     ]
        trainer = train(2, model, loader, train_epochs, logdir, callbacks,
                        x_dim=x_dim, y_dim=y_dim, z_dim=z_dim, finetune=finetune, learning_rate=1e-3, no_valid=False)
        results = trainer.evaluate(loader['test'].load(), steps=loader['test'].steps_per_epoch)
        logger.info(str(dict(zip(trainer.metrics_names, results))))

        # Recreate Trainer for retrain
        trainer = Trainer2(model, x_dim, y_dim, z_dim, finetune=retrain_finetune, is_rank_weight=is_rank_weight)
        trainer.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), run_eagerly=False)

        # Accuracy thresholds that end the search early. A list holds a
        # (validation, test) pair; a scalar is a validation threshold only.
        target_acc = {'cifar10-valid': 0.9160, 'cifar100': 0.7349, 'ImageNet16-120': [0.4673, 0.4731],
                      'nb101': [0.9505, 0.94317]}

        run = 0
        while now_queried < query_budget and run <= 450:
            logger.info('')
            logger.info(f'Retrain run {run}')
            top_acc_list, top_test_acc_list, top_arch_list, num_new_found = retrain(trainer, datasets, dataset_name,
                                                                                    batch_size,
                                                                                    retrain_epochs, logdir,
                                                                                    repeat_label, top_k, random_sample)

            global_top_acc_list += top_acc_list
            global_top_test_acc_list += top_test_acc_list
            global_top_arch_list += top_arch_list
            now_queried += num_new_found
            logger.info(f'Now queried: {now_queried}')

            if now_queried > query_budget:
                # Over budget: keep only the first `query_budget` entries.
                global_top_test_acc_list = global_top_test_acc_list[: query_budget]
                global_top_acc_list = global_top_acc_list[: query_budget]
                global_top_arch_list = global_top_arch_list[: query_budget]
                break
            run += 1
            record_top['valid'].append({now_queried: sorted(global_top_acc_list, reverse=True)[:5]})
            record_top['test'].append({now_queried: sorted(global_top_test_acc_list, reverse=True)[:5]})

            logger.info(f'History top 5 acc: {sorted(global_top_acc_list, reverse=True)[:5]}')
            logger.info(f'History top 5 test acc: {sorted(global_top_test_acc_list, reverse=True)[:5]}')

            if dataset_name in ['ImageNet16-120', 'nb101']:
                if max(global_top_acc_list) > target_acc.get(dataset_name, 1.0)[0] and max(global_top_test_acc_list) > \
                        target_acc.get(dataset_name, 1.0)[1]:
                    logger.info(f'Find optimal query amount {now_queried}')
                    break
            else:
                if max(global_top_acc_list) > target_acc.get(dataset_name, 1.0):
                    logger.info(f'Find optimal query amount {now_queried}')
                    break

    else:
        model.load_weights(pretrained_weight)

    logger.info('Final result')
    logger.info(f'Best found acc {max(global_top_acc_list)}')
    logger.info(f'Best test acc {max(global_top_test_acc_list)}')
    return max(global_top_acc_list), max(global_top_test_acc_list), record_top


if __name__ == '__main__':
    # Script entry point: seed, dataset and sample/budget sizes come from the
    # command line; the remaining search settings are fixed here.
    args = parse_args()
    #os.environ['CUDA_VISIBLE_DEVICES'] = ''
    fixed_settings = dict(top_k=5, finetune=False, retrain_finetune=False, is_rank_weight=True,
                          random_sample=False, num_couples=2,
                          n_couple_layer=4, n_hid_layer=5, n_hid_dim=256)
    main(args.seed, args.dataset, args.train_sample_amount, args.valid_sample_amount, args.query_budget,
         **fixed_settings)

