import argparse
import pickle
import random
import numpy as np
from tensorflow.python.keras.callbacks import CSVLogger, EarlyStopping, LearningRateScheduler, Callback
from invertible_neural_networks.flow import MMD_multiscale
from models.GNN import GraphAutoencoder, GraphAutoencoderNVP, get_rank_weight
from models.TransformerAE import TransformerAutoencoderNVP
import tensorflow as tf
import os
from datasets.NASBench301 import NasBench301Dataset, OP_PRIMITIVES_NB301
from datasets.utils import train_valid_test_split_dataset, mask_graph_dataset, repeat_graph_dataset_element
from spektral.data import PackedBatchLoader
from evalGAE_nb301 import eval_query_target, eval_from_lat_v2
from utils.py_utils import get_logdir_and_logger
from spektral.data import Graph
from utils.tf_utils import set_global_determinism
import logging
import random
import gc
import warnings
from sklearn.cluster import KMeans
from copy import deepcopy
from tensorflow_addons.optimizers import AdamW
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.utils import shuffle
from Flow2FlowAgent import Flow2FlowAgent
import pandas as pd
import ast
warnings.filterwarnings("ignore", category=UserWarning, module='tensorflow.python.framework.indexed_slices')

def parse_args():
    """Parse the command-line options of the search script.

    Returns:
        argparse.Namespace: the parsed options.
    """

    def str2bool(value):
        # ``type=bool`` would turn every non-empty string (including
        # "False") into True, so boolean values are parsed explicitly.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--train_sample_amount', type=int, default=30,
                        help='Number of samples to train (default: 30)')
    parser.add_argument('--valid_sample_amount', type=int, default=10,
                        help='Number of samples to validate (default: 10)')
    parser.add_argument('--sample_amount', type=int, default=100)
    parser.add_argument('--query_budget', type=int, default=150)
    parser.add_argument('--switch_ensemble', type=int, default=1)
    parser.add_argument('--target', type=float, default=1.0)
    parser.add_argument('--max_without_improvement', type=int, default=5)
    parser.add_argument('--decrease_value', type=float, default=0.1)
    parser.add_argument('--hash_version', type=int, default=2)
    parser.add_argument('--undirected', type=str2bool, default=False)
    parser.add_argument('--top_k', type=int, default=5)

    # Paired on/off switches; set_defaults() fixes the value used when
    # neither flag of a pair is given.
    parser.add_argument('--finetune', action='store_true')
    parser.add_argument('--no_finetune', dest='finetune', action='store_false')
    parser.set_defaults(finetune=False)
    parser.add_argument('--retrain_finetune', action='store_true')
    parser.add_argument('--no_retrain_finetune', dest='retrain_finetune', action='store_false')
    parser.set_defaults(retrain_finetune=False)
    parser.add_argument('--rank_weight', action='store_true')
    parser.add_argument('--no_rank_weight', dest='rank_weight', action='store_false')
    parser.set_defaults(rank_weight=True)

    parser.add_argument('--seed', type=int, default=0)
    return parser.parse_args()


def cal_ops_adj_loss_for_graph(x_batch_train, ops_cls, adj_cls, reduction='auto', rank_weight=None):
    """Compute the operation and adjacency reconstruction losses of a batch.

    Args:
        x_batch_train: pair ``(ops_label, adj_label)`` holding the targets.
        ops_cls: predicted operation probabilities (``from_logits=False``).
        adj_cls: predicted adjacency probabilities (``from_logits=False``).
        reduction: Keras loss reduction mode passed to both loss objects.
        rank_weight: optional weights; when given, each loss is multiplied
            by it and summed.

    Returns:
        Tuple ``(ops_loss, adj_loss)``.
    """
    ops_target, adj_target = x_batch_train

    ops_criterion = tf.keras.losses.CategoricalCrossentropy(from_logits=False, reduction=reduction)
    adj_criterion = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False, reduction=reduction)
    losses = [ops_criterion(ops_target, ops_cls), adj_criterion(adj_target, adj_cls)]

    if reduction == 'none':
        # Unreduced losses keep a trailing axis; average it away first.
        losses = [tf.reduce_mean(loss, axis=-1) for loss in losses]

    if rank_weight is not None:
        losses = [tf.reduce_sum(tf.multiply(loss, rank_weight)) for loss in losses]

    ops_loss, adj_loss = losses
    return ops_loss, adj_loss

def MMD_per_data_point(x, y):
    """Average, over the rows of ``x``, of the MMD between that row and ``y``.

    For every index ``i`` the single row ``x[i]`` (kept as a
    ``[1, feature_dim]`` slice) is compared against the whole ``y`` batch
    with ``MMD_multiscale``; the per-row values are then averaged.

    x: Tensor of shape [batch_size, feature_dim], real data samples
    y: Tensor of shape [batch_size, feature_dim], generated data samples
    """
    # Insert an axis so that indexing row i yields a [1, feature_dim] batch.
    x_exp = tf.expand_dims(x, 1)  # Shape: [batch_size, 1, feature_dim]

    # One MMD value per row of x, each computed against the full y batch.
    mmd_values = tf.map_fn(lambda i: MMD_multiscale(x_exp[i], y), tf.range(tf.shape(x)[0]), dtype=tf.float32)

    # Return the average MMD across all data points
    return tf.reduce_mean(mmd_values)

class Trainer1(tf.keras.Model):
    """Keras wrapper that trains a graph autoencoder on its reconstruction loss.

    The optimised objective is
    ``ops_weight * ops_loss + adj_weight * adj_loss + kl_weight * kl_loss``.
    Running means of all four terms are exposed through ``metrics``.
    """

    def __init__(self, model: GraphAutoencoder):
        super().__init__()
        self.model = model
        # Weights of the individual terms in the reconstruction objective.
        self.ops_weight = 1
        self.adj_weight = 1
        self.kl_weight = 0.16

        self.loss_tracker = {
            'rec_loss': tf.keras.metrics.Mean(name="rec_loss"),
            'ops_loss': tf.keras.metrics.Mean(name="ops_loss"),
            'adj_loss': tf.keras.metrics.Mean(name="adj_loss"),
            'kl_loss': tf.keras.metrics.Mean(name="kl_loss")
        }

    def _reconstruction_losses(self, x_batch, training):
        """Run the wrapped model and return ``(rec, ops, adj, kl)`` losses."""
        ops_cls, adj_cls, kl_loss, _ = self.model(x_batch, training=training)
        ops_loss, adj_loss = cal_ops_adj_loss_for_graph(x_batch, ops_cls, adj_cls)
        rec_loss = self.ops_weight * ops_loss + self.adj_weight * adj_loss + self.kl_weight * kl_loss
        return rec_loss, ops_loss, adj_loss, kl_loss

    def _record_losses(self, rec_loss, ops_loss, adj_loss, kl_loss):
        """Update the running means and return their current values."""
        self.loss_tracker['rec_loss'].update_state(rec_loss)
        self.loss_tracker['ops_loss'].update_state(ops_loss)
        self.loss_tracker['adj_loss'].update_state(adj_loss)
        self.loss_tracker['kl_loss'].update_state(kl_loss)
        return {key: tracker.result() for key, tracker in self.loss_tracker.items()}

    def train_step(self, data):
        """One optimisation step; the targets in ``data`` are ignored."""
        x_batch_train, _ = data

        # Record the forward pass so the reconstruction loss can be
        # differentiated with respect to the wrapped model's weights.
        with tf.GradientTape() as tape:
            losses = self._reconstruction_losses(x_batch_train, training=True)

        rec_loss = losses[0]
        grads = tape.gradient(rec_loss, self.model.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
        return self._record_losses(*losses)

    def test_step(self, data):
        """One evaluation step; the targets in ``data`` are ignored."""
        x_batch_train, _ = data
        # Evaluate in inference mode so validation metrics are not computed
        # with training-only behaviour (e.g. dropout) switched on.
        losses = self._reconstruction_losses(x_batch_train, training=False)
        return self._record_losses(*losses)

    @property
    def metrics(self):
        # Listing the trackers here lets Keras reset them between epochs.
        return list(self.loss_tracker.values())

class Trainer2(tf.keras.Model):
    """Keras training wrapper for the second phase (autoencoder + invertible net).

    Every training step applies two optimizer updates:

    1. a *forward* update on ``w1 * reg_loss + w2 * latent_loss`` (plus the
       reconstruction loss when ``finetune`` is set), and
    2. a *backward* update on ``w3 * rev_loss``, computed through
       ``model.inverse``.

    Samples whose target contains NaN are excluded from the regression,
    latent and reverse losses through a boolean mask.
    """

    def __init__(self, model: TransformerAutoencoderNVP, x_dim, y_dim, z_dim, finetune=False, is_rank_weight=True):
        """Store the wrapped model and set up loss functions and metric trackers.

        Args:
            model: the wrapped model; this class uses its ``encoder``,
                ``decoder`` and ``inverse`` attributes.
            x_dim: stored as ``self.x_dim``; not read inside this class.
            y_dim: number of trailing target columns.
            z_dim: number of latent columns sampled from a normal distribution.
            finetune: when True the reconstruction loss is added to the
                forward loss and tracked as well.
            is_rank_weight: when True the MSE losses are left unreduced so
                they can be weighted per sample with ``get_rank_weight``.
        """
        super(Trainer2, self).__init__()
        self.model = model
        self.x_dim = x_dim
        self.y_dim = y_dim
        self.z_dim = z_dim
        self.finetune = finetune
        self.is_rank_weight = is_rank_weight

        # For reg loss weight
        self.w1 = 5.
        # For latent loss weight
        self.w2 = 1.
        # For rev loss weight
        self.w3 = 10.

        # Per-sample weighting needs unreduced losses; otherwise let Keras
        # apply its default reduction.
        if self.is_rank_weight:
            self.reduction = 'none'
        else:
            self.reduction = 'auto'
        self.reg_loss_fn = tf.keras.losses.MeanSquaredError(reduction=self.reduction)
        self.loss_latent = MMD_multiscale
        self.loss_backward = tf.keras.losses.MeanSquaredError(reduction=self.reduction)
        self.loss_tracker = {
            'total_loss': tf.keras.metrics.Mean(name="total_loss"),
            'reg_loss': tf.keras.metrics.Mean(name="reg_loss"),
            'latent_loss': tf.keras.metrics.Mean(name="latent_loss"),
            'rev_loss': tf.keras.metrics.Mean(name="rev_loss")
        }
        # The reconstruction terms are only tracked (and their weights only
        # defined) when fine-tuning.
        if self.finetune:
            self.loss_tracker.update({
                'rec_loss': tf.keras.metrics.Mean(name="rec_loss"),
                'ops_loss': tf.keras.metrics.Mean(name="ops_loss"),
                'adj_loss': tf.keras.metrics.Mean(name="adj_loss"),
                'kl_loss': tf.keras.metrics.Mean(name="kl_loss")
            })
            self.ops_weight = 1
            self.adj_weight = 1
            self.kl_weight = 0.16

    def cal_reg_and_latent_loss(self, y, z, y_out, nan_mask, rank_weight=None, train=False):
        """Return ``(reg_loss, latent_loss)`` for the rows selected by ``nan_mask``.

        ``reg_loss`` is the MSE between ``y`` and ``y_out[:, z_dim:]``.
        ``latent_loss`` is ``MMD_multiscale`` between ``[z, y]`` and the
        concatenation of the first ``z_dim`` and last ``y_dim`` columns of
        ``y_out``. When rank weighting is enabled and ``train`` is True,
        ``reg_loss`` is multiplied by ``rank_weight`` and summed.
        """
        reg_loss = self.reg_loss_fn(tf.boolean_mask(y, nan_mask),
                                    tf.boolean_mask(y_out[:, self.z_dim:], nan_mask))
        latent_loss = self.loss_latent(tf.boolean_mask(tf.concat([z, y], axis=-1), nan_mask),
                                       tf.boolean_mask(
                                           tf.concat([y_out[:, :self.z_dim], y_out[:, -self.y_dim:]], axis=-1),
                                           nan_mask))  # * x_batch_train.shape[0]
        if self.is_rank_weight and train:
            # reg_loss (batch_size)
            reg_loss = tf.multiply(reg_loss, rank_weight)
            reg_loss = tf.reduce_sum(reg_loss)
        return reg_loss, latent_loss

    def cal_rev_loss(self, undirected_x_batch_train, y, z, nan_mask, noise_scale, rank_weight=None, train=False):
        """Return the reverse loss for the rows selected by ``nan_mask``.

        Gaussian noise scaled by ``noise_scale`` is added to the masked
        targets, ``[z, y]`` is mapped back through ``model.inverse`` and the
        result is compared (MSE) with the encoding the model produces for
        the masked inputs. When rank weighting is enabled and ``train`` is
        True the loss is multiplied by ``rank_weight`` and summed.
        """
        y = tf.boolean_mask(y, nan_mask)
        z = tf.boolean_mask(z, nan_mask)
        y = y + noise_scale * tf.random.normal(shape=tf.shape(y), dtype=tf.float32)
        non_nan_x_batch = (tf.boolean_mask(undirected_x_batch_train[0], nan_mask),
                           tf.boolean_mask(undirected_x_batch_train[1], nan_mask))
        # NOTE(review): the model is called with training=True here even when
        # this method is reached from test_step -- confirm this is intended.
        _, _, _, _, x_encoding = self.model(non_nan_x_batch, training=True)  # Logits for this minibatch
        x_rev = self.model.inverse(tf.concat([z, y], axis=-1))
        rev_loss = self.loss_backward(x_rev, x_encoding)  # * x_batch_train.shape[0]
        if self.is_rank_weight and train:
            # rev_loss (batch_size)
            rev_loss = tf.multiply(rev_loss, rank_weight)
            rev_loss = tf.reduce_sum(rev_loss)
        return rev_loss

    def train_step(self, data):
        """Run the forward update followed by the backward update for one batch.

        Returns:
            dict mapping each tracker name to its current running mean.
        """
        # Without fine-tuning the encoder/decoder are frozen for the whole step.
        # NOTE(review): the trainable flags are toggled from Python; with
        # run_eagerly=False the step is traced, so verify the toggling has the
        # intended effect on ``trainable_weights``.
        if not self.finetune:
            self.model.encoder.trainable = False
            self.model.decoder.trainable = False

        x_batch_train, y_batch_train = data
        # undirected_x_batch_train = (x_batch_train[0], to_undiredted_adj(x_batch_train[1]))
        y = y_batch_train[:, -self.y_dim:]
        z = tf.random.normal([tf.shape(y_batch_train)[0], self.z_dim])
        # True for rows whose target columns sum to a non-NaN value.
        nan_mask = tf.where(~tf.math.is_nan(tf.reduce_sum(y, axis=-1)), x=True, y=False)
        rank_weight = get_rank_weight(tf.boolean_mask(y, nan_mask)) if self.is_rank_weight else None

        with tf.GradientTape() as tape:
            ops_cls, adj_cls, kl_loss, y_out, x_encoding = self.model(x_batch_train, kl_reduction='none', training=True)

            # To avoid nan loss when batch size is small
            # (falls back to zero losses when no row has a usable target)
            reg_loss, latent_loss = tf.cond(tf.reduce_any(nan_mask),
                                            lambda: self.cal_reg_and_latent_loss(y, z, y_out, nan_mask, rank_weight, True),
                                            lambda: (0., 0.))
            forward_loss = self.w1 * reg_loss + self.w2 * latent_loss
            rec_loss = 0.
            if self.finetune:
                ops_loss, adj_loss = cal_ops_adj_loss_for_graph(x_batch_train, ops_cls, adj_cls, 'auto', None)
                '''
                if rank_weight is not None:
                    kl_loss = tf.reduce_sum(tf.multiply(kl_loss, rank_weight))
                else:
                    kl_loss = tf.reduce_mean(kl_loss)
                '''
                kl_loss = tf.reduce_mean(kl_loss)
                rec_loss = self.ops_weight * ops_loss + self.adj_weight * adj_loss + self.kl_weight * kl_loss
                forward_loss += rec_loss

        grads = tape.gradient(forward_loss, self.model.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
        # NOTE(review): this runs after the gradient step, so it only changes
        # the reported total_loss; latent_loss is already included above via
        # w2, and test_step has no matching addition -- confirm it is intended.
        forward_loss += latent_loss

        # Backward loss
        with tf.GradientTape() as tape:
            # The encoder/decoder are frozen for the backward update.
            self.model.encoder.trainable = False
            self.model.decoder.trainable = False
            # To avoid nan loss when batch size is small
            rev_loss = tf.cond(tf.reduce_any(nan_mask),
                               lambda: self.cal_rev_loss(x_batch_train, y, z, nan_mask, 0.0001, rank_weight, True),
                               lambda: 0.)
            # l2_loss = tf.add_n([self.regularizer(w) for w in self.model.trainable_weights])
            backward_loss = self.w3 * rev_loss

        grads = tape.gradient(backward_loss, self.model.trainable_weights)
        # grads = [tf.clip_by_norm(g, 1.) for g in grads]
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
        # Restore the flags so the next step starts from a known state.
        self.model.encoder.trainable = True
        self.model.decoder.trainable = True

        self.loss_tracker['total_loss'].update_state(forward_loss + backward_loss)
        self.loss_tracker['reg_loss'].update_state(reg_loss)
        self.loss_tracker['latent_loss'].update_state(latent_loss)
        self.loss_tracker['rev_loss'].update_state(rev_loss)
        # self.loss_tracker['l2_loss'].update_state(l2_loss)
        if self.finetune:
            self.loss_tracker['rec_loss'].update_state(rec_loss)
            self.loss_tracker['ops_loss'].update_state(ops_loss)
            self.loss_tracker['adj_loss'].update_state(adj_loss)
            self.loss_tracker['kl_loss'].update_state(kl_loss)

        return {key: value.result() for key, value in self.loss_tracker.items()}

    def test_step(self, data):
        """Compute the same loss terms as ``train_step`` without updating weights.

        ``cal_reg_and_latent_loss`` and ``cal_rev_loss`` are called with their
        default ``train=False``, so ``rank_weight`` is not applied here, and
        no noise is added to the targets (``noise_scale`` is 0).

        Returns:
            dict mapping each tracker name to its current running mean.
        """
        self.model.encoder.trainable = False
        self.model.decoder.trainable = False

        x_batch_train, y_batch_train = data
        # undirected_x_batch_train = (x_batch_train[0], to_undiredted_adj(x_batch_train[1]))
        y = y_batch_train[:, -self.y_dim:]
        z = tf.random.normal(shape=[tf.shape(y_batch_train)[0], self.z_dim])
        # True for rows whose target columns sum to a non-NaN value.
        nan_mask = tf.where(~tf.math.is_nan(tf.reduce_sum(y, axis=-1)), x=True, y=False)
        rank_weight = get_rank_weight(tf.boolean_mask(y, nan_mask)) if self.is_rank_weight else None

        ops_cls, adj_cls, kl_loss, y_out, x_encoding = self.model(x_batch_train, kl_reduction='none', training=False)
        reg_loss, latent_loss = tf.cond(tf.reduce_any(nan_mask),
                                        lambda: self.cal_reg_and_latent_loss(y, z, y_out, nan_mask, rank_weight),
                                        lambda: (0., 0.))
        forward_loss = self.w1 * reg_loss + self.w2 * latent_loss
        rev_loss = tf.cond(tf.reduce_any(nan_mask),
                           lambda: self.cal_rev_loss(x_batch_train, y, z, nan_mask, 0., rank_weight),
                           lambda: 0.)
        # l2_loss = tf.add_n([self.regularizer(w) for w in self.model.trainable_weights])
        # backward_loss = self.w3 * rev_loss + l2_loss
        backward_loss = self.w3 * rev_loss
        if self.finetune:
            ops_loss, adj_loss = cal_ops_adj_loss_for_graph(x_batch_train, ops_cls, adj_cls, 'auto', None)
            '''
            if rank_weight is not None:
                kl_loss = tf.reduce_sum(tf.multiply(kl_loss, rank_weight))
            else:
                kl_loss = tf.reduce_mean(kl_loss)
            '''
            kl_loss = tf.reduce_mean(kl_loss)
            rec_loss = self.ops_weight * ops_loss + self.adj_weight * adj_loss + self.kl_weight * kl_loss
            forward_loss += rec_loss

        self.loss_tracker['total_loss'].update_state(forward_loss + backward_loss)
        self.loss_tracker['reg_loss'].update_state(reg_loss)
        self.loss_tracker['latent_loss'].update_state(latent_loss)
        self.loss_tracker['rev_loss'].update_state(rev_loss)
        # self.loss_tracker['l2_loss'].update_state(l2_loss)
        if self.finetune:
            self.loss_tracker['rec_loss'].update_state(rec_loss)
            self.loss_tracker['ops_loss'].update_state(ops_loss)
            self.loss_tracker['adj_loss'].update_state(adj_loss)
            self.loss_tracker['kl_loss'].update_state(kl_loss)

        # Restore the flags changed at the top of this method.
        self.model.encoder.trainable = True
        self.model.decoder.trainable = True

        return {key: value.result() for key, value in self.loss_tracker.items()}

    @property
    def metrics(self):
        # Expose the trackers as this model's metrics.
        return [value for _, value in self.loss_tracker.items()]

def train(phase: int, model, loader, train_epochs, logdir, callbacks=None, x_dim=None, y_dim=None,
          z_dim=None, finetune=False, learning_rate=1e-3):
    """Fit ``model`` with the trainer of the requested phase and save its weights.

    Args:
        phase: 1 wraps the model in ``Trainer1``, 2 in ``Trainer2``.
        model: the model to train.
        loader: mapping with ``'train'`` and ``'valid'`` loaders exposing
            ``load()`` (and optionally ``steps_per_epoch``).
        train_epochs: number of epochs passed to ``fit``.
        logdir: directory that receives ``modelGAE_weights_phase{phase}``.
        callbacks: optional Keras callbacks.
        x_dim, y_dim, z_dim, finetune: forwarded to ``Trainer2`` (phase 2 only).
        learning_rate: learning rate of the Adam optimizer.

    Returns:
        The fitted trainer.

    Raises:
        ValueError: if ``phase`` is neither 1 nor 2.
    """
    if phase == 1:
        trainer = Trainer1(model)
    elif phase == 2:
        trainer = Trainer2(model, x_dim, y_dim, z_dim, finetune)
    else:
        raise ValueError('phase should be 1 or 2')

    # Step counts are optional: loaders that do not expose them simply let
    # Keras infer the epoch length. Only lookup failures are tolerated, so
    # KeyboardInterrupt / SystemExit are no longer swallowed.
    try:
        kw = {'validation_steps': loader['valid'].steps_per_epoch,
              'steps_per_epoch': loader['train'].steps_per_epoch}
    except (AttributeError, KeyError, TypeError):
        kw = {}

    trainer.compile(optimizer=tf.keras.optimizers.Adam(learning_rate), run_eagerly=False)
    trainer.fit(loader['train'].load(),
                validation_data=loader['valid'].load(),
                epochs=train_epochs,
                callbacks=callbacks,
                **kw)
    model.save_weights(os.path.join(logdir, f'modelGAE_weights_phase{phase}'))
    return trainer

def to_loader(datasets, batch_size: int, epochs: int, repeat: bool = True):
    """Build a ``PackedBatchLoader`` for each split in ``datasets``.

    The input mapping is deep-copied first, so padding never touches the
    caller's datasets.

    Args:
        datasets: mapping from split name to dataset; must contain ``'train'``.
        batch_size: batch size of every loader.
        epochs: number of training epochs; repeating loaders are created
            with ``epochs + 1``.
        repeat: when True, ``'train'`` is padded with randomly re-drawn graphs
            up to a multiple of ``batch_size`` and ``'train'``/``'valid'`` get
            shuffled multi-epoch loaders. When False every split gets a
            single-epoch, unshuffled loader.

    Returns:
        dict mapping split name to its loader. With ``repeat=True`` only the
        keys ``'train'``, ``'valid'``, ``'test'`` and ``'train_1'`` are included.
    """
    loader = {}

    # Create a deep copy of the datasets to avoid modifying the original one
    datasets_copy = deepcopy(datasets)

    # Ensure 'train' dataset length is a multiple of batch_size
    train_graphs = datasets_copy['train'].graphs
    remainder = len(train_graphs) % batch_size
    if remainder != 0 and repeat:
        num_extra = batch_size - remainder
        extra_indices = np.random.choice(len(train_graphs), num_extra, replace=True)
        extra_graphs = [train_graphs[i] for i in extra_indices]
        train_graphs.extend(extra_graphs)

    for key, value in datasets_copy.items():
        if key in ('train', 'valid') and repeat:
            loader[key] = PackedBatchLoader(value, batch_size=batch_size, shuffle=True, epochs=epochs + 1)
        elif key in ('test', 'train_1') or not repeat:
            loader[key] = PackedBatchLoader(value, batch_size=batch_size, shuffle=False, epochs=1)

    return loader

def get_theta(trainers_list, dataset, eps = 1e-8):
    """Compute inverse-MSE ensemble weights for the trainers on ``dataset``.

    Each trainer's model predicts the whole dataset in one batch; the last
    column of its regression output is taken as the prediction. A model's
    weight is ``1 / (mse + eps)`` normalised so that all weights sum to 1.

    Args:
        trainers_list: trainers whose ``model`` returns five outputs, the
            fourth being the regression output.
        dataset: graphs with scalar targets in ``.y``.
        eps: added to every MSE to avoid division by zero.

    Returns:
        Tensor of shape ``(num_models,)`` with the normalised weights.
    """
    n = len(dataset)
    loader = PackedBatchLoader(dataset, batch_size=n, epochs=1, shuffle=False)
    (x, a), _ = next(loader.load())

    # y_true shape (n,)
    y_true = tf.constant([float(np.squeeze(d.y)) for d in dataset], dtype=tf.float32)

    pred_mat = []
    for trainer in trainers_list:
        _, _, _, reg, _ = trainer.model((x, a), training=False)
        # Flatten to shape (n,). Unlike tf.squeeze this keeps the batch axis
        # when n == 1, so the stack along axis=1 below stays valid.
        pred_mat.append(tf.reshape(reg[:, -1], [-1]))

    pred_mat = tf.stack(pred_mat, axis=1)                      # (n, num_models)

    err_sq = tf.square(pred_mat - tf.expand_dims(y_true, 1))   # (n, num_models)
    mse_per_model = tf.reduce_mean(err_sq, axis=0)             # (num_models,)

    inv_mse = 1.0 / (mse_per_model + eps)                      # (num_models,)
    theta = inv_mse / tf.reduce_sum(inv_mse)                   # (num_models,)

    return theta

import logging
import random

def sample_arch_candidates(trainers_list, x_dim, z_dim, visited, sample_amount_per_trainer=20, cur_best=0.94, target=1, hash_version=0):
    """Collect up to ``sample_amount_per_trainer`` unvisited candidates per trainer.

    Each trainer's model is queried with ``eval_query_target``. The noise
    scale is escalated through a fixed schedule, with up to five retries per
    level, until every trainer has enough candidates or the schedule is
    exhausted.

    Args:
        trainers_list: trainers whose ``model`` is queried.
        x_dim, z_dim: forwarded to ``eval_query_target``.
        visited: hashes of architectures that must not be returned again.
        sample_amount_per_trainer: number of candidates wanted per trainer.
        cur_best: accepted for interface compatibility; not used here.
        target: target value forwarded to ``eval_query_target``.
        hash_version: forwarded to ``eval_query_target``.

    Returns:
        dict mapping trainer index to ``{arch_hash: arch_info}``.
    """
    logger = logging.getLogger(__name__)
    found_arch_list_set = {i: {} for i in range(len(trainers_list))}
    max_retry = 5
    noise_std_list = [1e-5, 1e-2, 2e-2, 3e-2, 5e-2]
    amount_scale_list = [1.5, 2, 2.5, 2.5, 5]
    last_std_idx = len(noise_std_list) - 1

    # Hash lookups happen for every sampled architecture; a set makes them
    # O(1) instead of scanning the caller's sequence each time.
    # NOTE(review): candidates are not de-duplicated across trainers -- the
    # set meant for that was never filled, so the same hash may show up
    # under several trainer indices. Confirm whether that is intended.
    visited_lookup = set(visited)

    def needs_more():
        return any(len(archs) < sample_amount_per_trainer for archs in found_arch_list_set.values())

    for std_idx, (noise_std, amount_scale) in enumerate(zip(noise_std_list, amount_scale_list)):
        if not needs_more():
            break

        query_amount = int(sample_amount_per_trainer * amount_scale)
        for retry in range(max_retry):
            if not needs_more():
                break

            for i, trainer in enumerate(trainers_list):
                if len(found_arch_list_set[i]) >= sample_amount_per_trainer:
                    continue

                if std_idx == last_std_idx and retry == 0:
                    print(f"Trainer {i}: Already found: {len(found_arch_list_set[i])} / {sample_amount_per_trainer}, cur query: {query_amount}")

                found_arch_list = eval_query_target(trainer.model, x_dim, z_dim,
                                                    noise_scale=noise_std,
                                                    query_amount=query_amount,
                                                    target=target,
                                                    hash_version=hash_version)

                for arch_hash in found_arch_list:
                    if arch_hash not in visited_lookup:
                        found_arch_list_set[i][arch_hash] = found_arch_list[arch_hash]

                # Trim any overshoot back to the requested amount at random.
                if len(found_arch_list_set[i]) > sample_amount_per_trainer:
                    keys = list(found_arch_list_set[i].keys())
                    selected_keys = random.sample(keys, sample_amount_per_trainer)
                    found_arch_list_set[i] = {key: found_arch_list_set[i][key] for key in selected_keys}

        logger.info(f'std scale {noise_std}, num samples: {[len(archs) for archs in found_arch_list_set.values()]}')

    return found_arch_list_set

def predict_arch_acc(found_arch_list_set, trainers_list, theta):
    """Predict each candidate's accuracy with the theta-weighted model ensemble.

    The prediction is written in place to ``found_arch_list_set[k]['y_pred']``;
    nothing is returned.

    Args:
        found_arch_list_set: dict mapping an architecture key to a dict with
            at least ``'x'`` and ``'a'`` entries.
        trainers_list: trainers whose ``model`` returns five outputs, the
            fourth being the regression output.
        theta: one weight per trainer, in the same order as ``trainers_list``.
    """
    idx_keys = list(found_arch_list_set.keys())
    if not idx_keys:
        return  # nothing to predict

    # ----------- pack the candidate architectures into one batch ----------
    x_batch = tf.stack([tf.constant(found_arch_list_set[k]['x']) for k in idx_keys])
    a_batch = tf.stack([tf.constant(found_arch_list_set[k]['a']) for k in idx_keys])

    # ----------- collect every model's prediction --------------------------
    pred_matrix = []
    for trainer in trainers_list:
        _, _, _, reg, _ = trainer.model((x_batch, a_batch), training=False)  # (batch, z_dim + y_dim)
        # Flatten to shape (batch,). Unlike tf.squeeze this keeps the batch
        # axis for a single candidate, so np.stack(axis=1) below stays valid.
        y_hat = tf.reshape(reg[:, -1], [-1])
        pred_matrix.append(y_hat.numpy())

    pred_matrix = np.stack(pred_matrix, axis=1)                  # (batch, num_models)

    # ----------- weighted average ------------------------------------------
    theta_np = np.asarray(theta, dtype=np.float32)               # (num_models,)
    y_hat = pred_matrix @ theta_np                               # (batch,)

    # ----------- write the results back ------------------------------------
    for k, y_pred in zip(idx_keys, y_hat):
        found_arch_list_set[k]['y_pred'] = y_pred

def sample_and_select(trainers_list, datasets, visited, logger, repeat, top_k=5, is_reset=False, cur_best=0.94, target=1, hash_version=0, sample_amount_per_trainer=20, global_arch_dict=None):
    """Sample candidates, rank them with the ensemble and add the best to the data.

    Candidates from every trainer are merged, scored with the theta-weighted
    ensemble prediction, and the ``top_k`` highest-scoring ones are appended
    to ``datasets['train_1']`` and to ``visited``. ``datasets['train']`` is
    then rebuilt from ``datasets['train_1']``.

    Args:
        trainers_list: ensemble trainers; the first one provides
            ``x_dim``/``z_dim``.
        datasets: mapping with a ``'train_1'`` dataset; mutated in place.
        visited: list of already queried architecture hashes; mutated in place.
        logger: logger receiving the summary of the selected architectures.
        repeat: forwarded to ``repeat_graph_dataset_element``.
        top_k: number of architectures to select.
        is_reset: accepted for interface compatibility; not used here.
        cur_best, target, hash_version, sample_amount_per_trainer: forwarded
            to ``sample_arch_candidates``.
        global_arch_dict: returned unchanged; a fresh dict is used when omitted.

    Returns:
        Tuple ``(acc_list, found_arch_list_set, num_new_found, visited,
        global_arch_dict)``.
    """
    # A shared ``{}`` default would leak state between calls.
    if global_arch_dict is None:
        global_arch_dict = {}

    # Generate architectures per trainer
    found_arch_list_set = sample_arch_candidates(trainers_list, trainers_list[0].x_dim, trainers_list[0].z_dim, visited,
                                                 sample_amount_per_trainer=sample_amount_per_trainer, cur_best=cur_best, target=target, hash_version=hash_version)

    # Merge every trainer's candidates; duplicate hashes collapse to one entry.
    combined_arch_list = {arch_hash: arch
                          for found_arch_list in found_arch_list_set.values()
                          for arch_hash, arch in found_arch_list.items()}

    theta = get_theta(trainers_list, datasets['train_1'])
    print("theta:", theta)

    # Predict accuracy by INN (performance predictor) on the combined architecture list
    predict_arch_acc(combined_arch_list, trainers_list, theta)

    # Keep the top_k architectures with the highest predicted accuracy.
    top_archs = dict(sorted(combined_arch_list.items(), key=lambda item: item[1]['y_pred'], reverse=True)[:top_k])

    acc_list = [model_info['y'].item() for model_info in top_archs.values()]
    pred_list = [model_info['y_pred'] for model_info in top_archs.values()]

    # Print top 10 real y values
    top_10_real_y_archs = sorted(combined_arch_list.values(), key=lambda model_info: model_info['y'], reverse=True)[:10]
    top_10_real_y_values = [model_info['y'].item() for model_info in top_10_real_y_archs]
    print("Top 10 real y values: ", top_10_real_y_values)

    if len(acc_list) != 0:
        logger.info('Top acc list: {}'.format(acc_list))
        logger.info(f'Avg found acc {sum(acc_list) / len(acc_list)}')
        logger.info(f'Best found acc {max(acc_list)}')
        logger.info(f'pred_acc list: {pred_list}')
    else:
        logger.info('Top acc list is [] in this run')

    sorted_acc_list = sorted(acc_list, reverse=True)

    num_new_found = 0
    for arch_hash, arch in top_archs.items():
        if arch_hash in visited:
            continue
        visited.append(arch_hash)
        num_new_found += 1

        # The worst selected accuracy is still echoed; both former branches
        # appended exactly the same graph, so the append is shared.
        if arch['y'].item() == sorted_acc_list[-1]:
            print(sorted_acc_list[-1])
        datasets['train_1'].graphs.append(Graph(x=arch['x'], a=arch['a'], y=arch['y']))

    datasets['train'] = repeat_graph_dataset_element(datasets['train_1'], repeat)

    return acc_list, found_arch_list_set, num_new_found, visited, global_arch_dict
def prepare_model(latent_dim, num_layers, d_model, num_heads, dff, num_ops, num_nodes, num_adjs,
                  dropout_rate, eps_scale):
    """Create a ``GraphAutoencoder`` and run one dummy batch through it.

    The dummy call uses random inputs of shape ``(1, num_nodes, num_ops)``
    and ``(1, num_nodes, num_nodes)`` (presumably to make Keras build the
    weights before they are loaded or copied).

    Returns:
        The constructed ``GraphAutoencoder``.
    """
    autoencoder_kwargs = dict(
        latent_dim=latent_dim,
        num_layers=num_layers,
        d_model=d_model,
        num_heads=num_heads,
        dff=dff,
        num_ops=num_ops,
        num_nodes=num_nodes,
        num_adjs=num_adjs,
        dropout_rate=dropout_rate,
        eps_scale=eps_scale,
    )
    pretrained_model = GraphAutoencoder(**autoencoder_kwargs)

    dummy_ops = tf.random.normal(shape=(1, num_nodes, num_ops))
    dummy_adj = tf.random.normal(shape=(1, num_nodes, num_nodes))
    pretrained_model((dummy_ops, dummy_adj))

    return pretrained_model

def reset_trainer(nvp_config, latent_dim, num_layers, d_model, num_heads, dff, num_ops, num_nodes, num_adjs,
                  dropout_rate, eps_scale, num_models, init_seed, logger, pretrained_model, logdir,
                  x_dim, y_dim, z_dim, retrain_finetune, is_rank_weight, datasets, batch_size, train_epochs, patience):
    """Build and train a fresh ensemble of ``num_models`` phase-2 trainers.

    Every model is a new ``GraphAutoencoderNVP`` whose encoder and decoder
    weights are copied from ``pretrained_model``. Each trainer is fitted with
    Adam (learning rate 1e-3) and early stopping on ``val_total_loss``, then
    re-compiled with Adam at 1e-4 and evaluated on the validation loader.

    NOTE(review): ``patience`` is accepted but the early-stopping patience is
    fixed at 30, and ``init_seed`` has no effect because per-model seeding is
    disabled -- confirm whether either should be wired in.

    Returns:
        list of the trained ``Trainer2`` instances.
    """
    def build_model():
        # Per-model seeding (tf.random.set_seed) is currently disabled.
        return GraphAutoencoderNVP(nvp_config=nvp_config, latent_dim=latent_dim, num_layers=num_layers,
                                   d_model=d_model, num_heads=num_heads,
                                   dff=dff, num_ops=num_ops, num_nodes=num_nodes,
                                   num_adjs=num_adjs, dropout_rate=dropout_rate, eps_scale=eps_scale)

    def create_callbacks(idx):
        return [
            CSVLogger(os.path.join(logdir, f"learning_curve_phase2_model_{idx}.csv")),
            EarlyStopping(monitor='val_total_loss', patience=30, restore_best_weights=True)
        ]

    models_list = [build_model() for _ in range(num_models)]
    input_ops = tf.random.normal(shape=(1, num_nodes, num_ops))
    input_adjs = tf.random.normal(shape=(1, num_nodes, num_nodes))
    for model in models_list:
        # One dummy forward pass before the pretrained weights are copied in.
        model((input_ops, input_adjs))
        model.encoder.set_weights(pretrained_model.encoder.get_weights())
        model.decoder.set_weights(pretrained_model.decoder.get_weights())

    trainers_list = [Trainer2(model, x_dim, y_dim, z_dim, finetune=retrain_finetune, is_rank_weight=is_rank_weight)
                     for model in models_list]

    for idx, trainer in enumerate(trainers_list):
        loader = to_loader(datasets, batch_size, train_epochs, True)
        callbacks = create_callbacks(idx)

        # Step counts are optional; only lookup failures are tolerated so
        # KeyboardInterrupt / SystemExit are no longer swallowed.
        try:
            kw = {'validation_steps': loader['valid'].steps_per_epoch,
                  'steps_per_epoch': loader['train'].steps_per_epoch}
        except (AttributeError, KeyError, TypeError):
            kw = {}

        # trainer.compile(optimizer=AdamW(learning_rate=0.001, weight_decay=0.01), run_eagerly=False)
        trainer.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), run_eagerly=False)
        trainer.fit(loader['train'].load(),
                    validation_data=loader['valid'].load(),
                    epochs=train_epochs,
                    callbacks=callbacks,
                    **kw,
                    verbose=0)

        # Re-compile with a fresh, lower-learning-rate optimizer before evaluating.
        trainer.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001), run_eagerly=False)
        results = trainer.evaluate(loader['valid'].load(), steps=loader['valid'].steps_per_epoch, verbose=0)
        logger.info(f"Model {idx} evaluation results: {str(dict(zip(trainer.metrics_names, results)))}")

        # Free per-model resources before training the next one.
        del loader
        tf.keras.backend.clear_session()
        gc.collect()

    return trainers_list

def reset_single_trainer(nvp_config, latent_dim, num_layers, d_model, num_heads, dff, num_ops, num_nodes, num_adjs,
                  dropout_rate, eps_scale, num_models, init_seed, logger, pretrained_model, logdir,
                  x_dim, y_dim, z_dim, retrain_finetune, is_rank_weight, datasets, batch_size, train_epochs, patience, trainers_list, break_idx):
    """Rebuild and retrain the trainers whose positions are listed in ``break_idx``.

    For every index of ``trainers_list`` that appears in ``break_idx`` a fresh
    ``GraphAutoencoderNVP`` is built, its encoder/decoder weights are copied from
    ``pretrained_model``, and it is wrapped in a new ``Trainer2`` and fitted on
    ``datasets``. The new trainer then replaces the old one in ``trainers_list``
    (in place), so callers holding the list see the reset ensemble member.

    Args:
        nvp_config, latent_dim, num_layers, d_model, num_heads, dff, num_ops,
        num_nodes, num_adjs, dropout_rate, eps_scale: forwarded unchanged to the
            ``GraphAutoencoderNVP`` constructor.
        num_models: accepted for signature parity with the caller; not used here.
        init_seed: base seed; ``init_seed + idx`` is passed to the model factory,
            which currently does not use it (seeding is commented out).
        logger: logger that receives the post-training evaluation metrics.
        pretrained_model: model whose ``encoder`` and ``decoder`` weights are copied.
        logdir: directory receiving ``learning_curve_phase2_model_{idx}.csv``.
        x_dim, y_dim, z_dim, retrain_finetune, is_rank_weight: forwarded to ``Trainer2``.
        datasets, batch_size, train_epochs: forwarded to ``to_loader`` / ``fit``.
        patience: accepted but not used; early stopping uses a fixed patience of 30.
        trainers_list: mutable list of trainers; entries listed in ``break_idx``
            are replaced in place.
        break_idx: collection of indices into ``trainers_list`` to reset.

    Returns:
        None. ``trainers_list`` is modified in place.
    """
    # Nothing to reset: avoid building dummy inputs or touching the list.
    if not break_idx:
        return
    reset_targets = set(break_idx)

    def create_model(seed):
        # NOTE: `seed` is currently unused; per-model seeding is disabled.
        # tf.random.set_seed(seed)
        return GraphAutoencoderNVP(nvp_config=nvp_config, latent_dim=latent_dim, num_layers=num_layers,
                                d_model=d_model, num_heads=num_heads,
                                dff=dff, num_ops=num_ops, num_nodes=num_nodes,
                                num_adjs=num_adjs, dropout_rate=dropout_rate, eps_scale=eps_scale)

    # Dummy batch used for one forward pass so the weights exist before set_weights().
    input_ops = tf.random.normal(shape=(1, num_nodes, num_ops))
    input_adjs = tf.random.normal(shape=(1, num_nodes, num_nodes))

    def create_callbacks(logdir, idx):
        return [
            CSVLogger(os.path.join(logdir, f"learning_curve_phase2_model_{idx}.csv")),
            EarlyStopping(monitor='val_total_loss', patience=30, restore_best_weights=True)
        ]

    for idx in range(len(trainers_list)):
        if idx not in reset_targets:
            continue

        print("reset trainer ", idx)
        loader = to_loader(datasets, batch_size, train_epochs, True)
        callbacks = create_callbacks(logdir, idx)
        try:
            kw = {'validation_steps': loader['valid'].steps_per_epoch,
                  'steps_per_epoch': loader['train'].steps_per_epoch}
        except (AttributeError, KeyError, TypeError):
            # Loader does not expose step counts; let Keras infer them.
            kw = {}

        model = create_model(init_seed + idx)
        model((input_ops, input_adjs))
        model.encoder.set_weights(pretrained_model.encoder.get_weights())
        model.decoder.set_weights(pretrained_model.decoder.get_weights())

        trainer = Trainer2(model, x_dim, y_dim, z_dim, finetune=retrain_finetune, is_rank_weight=is_rank_weight)
        # trainer.compile(optimizer=AdamW(learning_rate=0.001, weight_decay=0.01), run_eagerly=False)
        trainer.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), run_eagerly=False)
        trainer.fit(loader['train'].load(),
                    validation_data=loader['valid'].load(),
                    epochs=train_epochs,
                    callbacks=callbacks,
                    **kw,
                    verbose=0)
        # Recompile with a smaller learning rate; this is the optimizer any later
        # fit() on this trainer will use.
        trainer.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001), run_eagerly=False)
        results = trainer.evaluate(loader['valid'].load(), steps=loader['valid'].steps_per_epoch, verbose=0)
        logger.info(f"Model {idx} evaluation results: {str(dict(zip(trainer.metrics_names, results)))}")

        # Store the new trainer; without this the reset would be discarded because
        # rebinding a loop variable never updates the list.
        trainers_list[idx] = trainer

        del loader
        tf.keras.backend.clear_session()
        gc.collect()

def retrain(trainers_list, datasets, batch_size, train_epochs, logdir, logger):
    """Continue training every trainer and report which ones ended with a high loss.

    Each trainer is fitted on a freshly built loader for ``datasets`` and then
    evaluated on the validation split. Trainers are presumably already compiled
    by the caller -- no ``compile`` call happens here.

    Args:
        trainers_list: iterable of trainer objects exposing ``fit``, ``evaluate``
            and ``metrics_names``.
        datasets, batch_size, train_epochs: forwarded to ``to_loader``;
            ``train_epochs`` is also the ``epochs`` argument of ``fit``.
        logdir: directory receiving ``learning_curve_phase2_retrain_{idx}.csv``.
        logger: logger that receives each trainer's evaluation metrics.

    Returns:
        List of indices whose validation ``total_loss`` is greater than 2.5.
    """
    failed_indices = []

    for idx, trainer in enumerate(trainers_list):
        loader = to_loader(datasets, batch_size, train_epochs, True)
        callbacks = [
            CSVLogger(os.path.join(logdir, f"learning_curve_phase2_retrain_{idx}.csv")),
            EarlyStopping(monitor='val_total_loss', patience=20, restore_best_weights=True),
        ]

        # Only the first trainer prints per-epoch progress.
        silent = 2 if idx == 0 else 0

        trainer.fit(
            loader['train'].load(),
            validation_data=loader['valid'].load(),
            epochs=train_epochs,
            callbacks=callbacks,
            steps_per_epoch=loader['train'].steps_per_epoch,
            validation_steps=loader['valid'].steps_per_epoch,
            verbose=silent,
        )

        results = trainer.evaluate(loader['valid'].load(), steps=loader['valid'].steps_per_epoch, verbose=0)
        metrics = dict(zip(trainer.metrics_names, results))
        if metrics["total_loss"] > 2.5:
            failed_indices.append(idx)
        logger.info(f"Model {idx} evaluation results: {str(metrics)}")

        # Release per-trainer resources before moving on to the next one.
        del loader, callbacks
        tf.keras.backend.clear_session()
        gc.collect()

    return failed_indices

def main(seed, train_sample_amount, valid_sample_amount, query_budget, top_k, finetune, retrain_finetune, is_rank_weight, target, undirected, hash_version, max_without_improvement, decrease_value, sample_amount, switch_ensemble):
    """Run the two-phase NAS-Bench-301 search and return the best accuracy found.

    Phase 1 (autoencoder pre-training) is disabled by the hard-coded
    ``train_phase`` flags; the pretrained weights are loaded from disk instead.
    Phase 2 builds an ensemble of trainers, then alternates between
    ``sample_and_select`` (query new architectures) and ``retrain`` until
    ``query_budget`` queries were spent or more than 100 rounds have run.

    Args:
        seed: global random seed; also used for dataset shuffling and masking.
        train_sample_amount: number passed to ``mask_graph_dataset`` for the train split.
        valid_sample_amount: number passed to ``mask_graph_dataset`` for the valid split.
        query_budget: upper bound on the query counter ``now_queried``.
        top_k: forwarded to ``sample_and_select``.
        finetune: only used in the log directory name.
        retrain_finetune: forwarded to the trainers as ``finetune``.
        is_rank_weight: forwarded to the trainers.
        target: forwarded to ``sample_and_select``.
        undirected: selects the undirected or directed dataset (and cache file).
        hash_version: forwarded to the architecture hashing helper.
        max_without_improvement, decrease_value, sample_amount, switch_ensemble:
            accepted for CLI compatibility; not used inside this function.

    Returns:
        Tuple ``(best_acc, record_top, max_idx)``: the maximum of all collected
        accuracies, the per-round top-5 history, and the query count at which
        the best accuracy was found.
    """
    logdir, logger = get_logdir_and_logger(
        os.path.join(f'{train_sample_amount}_{valid_sample_amount}_{query_budget}_finetune{finetune}_rfinetune{retrain_finetune}_rank{is_rank_weight}',
                     "CIFAR-10"), f'trainGAE_two_phase_{seed}.log')
    random_seed = seed
    set_global_determinism(random_seed)

    train_phase = [0, 1]  # 0 not train, 1 train
    pretrained_weight = 'logs/phase1_nb301_CE_64_directed/modelGAE_weights_phase1'

    eps_scale = 0.05  # 0.1
    d_model = 32
    dropout_rate = 0.0
    dff = 256
    num_layers = 3
    num_heads = 3

    latent_dim = 16

    num_ops = len(OP_PRIMITIVES_NB301)  # presumably 11 -- see OP_PRIMITIVES_NB301
    num_nodes = 22
    num_adjs = num_nodes ** 2

    # The raw dataset is cached on disk, one cache file per graph direction.
    if undirected:
        cache_path = 'datasets/NasBench301Dataset.cache'
    else:
        cache_path = 'datasets/NasBench301Dataset_directed.cache'

    if os.path.exists(cache_path):
        # NOTE: pickle is only safe for trusted files; this cache is written by
        # this script itself. Do not point it at externally supplied data.
        with open(cache_path, 'rb') as f:
            datasets = pickle.load(f)
    else:
        datasets = NasBench301Dataset(undirected=undirected)
        with open(cache_path, 'wb') as f:
            pickle.dump(datasets, f)

    # The split must run for both directions; everything below indexes
    # datasets['train'] / datasets['valid'].
    datasets = train_valid_test_split_dataset(datasets,
                                              ratio=[0.8, 0.1, 0.1],
                                              shuffle=True,
                                              shuffle_seed=random_seed)

    x_dim = latent_dim * num_nodes  # 16 * 22 = 352
    y_dim = 1
    z_dim = x_dim - 1  # 351
    # z_dim = latent_dim * 4 - 1
    tot_dim = y_dim + z_dim  # 352
    # pad_dim = tot_dim - x_dim

    nvp_config = {
        'n_couple_layer': 4,
        'n_hid_layer': 5,
        'n_hid_dim': 256,
        'name': 'NVP',
        'num_couples': 2,
        'inp_dim': tot_dim
    }

    nvp_config_single = {
        'n_couple_layer': 4,
        'n_hid_layer': 5,
        'n_hid_dim': 256,
        'name': 'NVP',
        'num_couples': 2,
        'inp_dim': tot_dim
    }

    pretrained_model = prepare_model(latent_dim, num_layers, d_model, num_heads, dff,
                                     num_ops, num_nodes, num_adjs, dropout_rate, eps_scale)

    tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)

    if train_phase[0]:
        logger.info('Train phase 1')
        train_epochs = 500
        patience = 50
        batch_size = 32
        # Disabled experiment kept for reference:
        # pretrained_datasets = copy.deepcopy(datasets)
        # pretrained_datasets['train'] = mask_graph_dataset(pretrained_datasets['train'], int(42362 * 0.9), 1, random_seed=random_seed)
        # pretrained_datasets['valid'] = mask_graph_dataset(pretrained_datasets['valid'], int(42362 * 0.1), 1, random_seed=random_seed)
        # pretrained_datasets['train'].filter(lambda g: not np.isnan(g.y))
        # pretrained_datasets['valid'].filter(lambda g: not np.isnan(g.y))
        loader = to_loader(datasets, batch_size, train_epochs, True)
        callbacks = [CSVLogger(os.path.join(logdir, "learning_curve_phase1.csv")),
                     tf.keras.callbacks.ReduceLROnPlateau(monitor='val_rec_loss', factor=0.1, patience=patience // 2,
                                                          verbose=1, min_lr=1e-5),
                     tensorboard_callback,
                     EarlyStopping(monitor='val_rec_loss', patience=patience, restore_best_weights=True)]
        trainer = train(1, pretrained_model, loader, train_epochs, logdir, callbacks)
        results = trainer.evaluate(loader['test'].load(), steps=loader['test'].steps_per_epoch)
        logger.info(f'{dict(zip(trainer.metrics_names, results))}')

    pretrained_model.load_weights(pretrained_weight)

    global_top_acc_list = []
    global_top_arch_list = []
    record_top = {'valid': [], 'test': []}

    if train_phase[1]:
        batch_size = 32
        train_epochs = 500
        retrain_epochs = 100
        patience = 50
        repeat_label = 5
        now_queried = train_sample_amount + valid_sample_amount
        logger.info('Train phase 2')
        datasets['train_1'] = mask_graph_dataset(datasets['train'], train_sample_amount, 1, random_seed=random_seed)
        datasets['valid_1'] = mask_graph_dataset(datasets['valid'], valid_sample_amount, 1, random_seed=random_seed)
        # Keep only graphs that still carry a label.
        datasets['train_1'].filter(lambda g: not np.isnan(g.y))
        datasets['valid_1'].filter(lambda g: not np.isnan(g.y))
        datasets['train'] = repeat_graph_dataset_element(datasets['train_1'], repeat_label)
        datasets['valid'] = repeat_graph_dataset_element(datasets['valid_1'], 3)

        # Seed the history with the initially labelled architectures.
        visited_hash = []
        global_arch_dict = {}
        for split_name in ('train_1', 'valid_1'):
            split = datasets[split_name]
            for i in range(len(split)):
                graph = split[i]
                global_top_acc_list.append(graph.y)
                visited_hash.append(NasBench301Dataset.get_hash_by_darts_cell(graph.a,
                                                                              graph.x,
                                                                              hash_version=hash_version))

        trainers_list = reset_trainer(nvp_config_single, latent_dim, num_layers, d_model, num_heads, dff, num_ops, num_nodes, num_adjs,
                                      dropout_rate, eps_scale, 10, random_seed, logger, pretrained_model, logdir,
                                      x_dim, y_dim, z_dim, retrain_finetune, is_rank_weight, datasets, batch_size, train_epochs, patience)

        run = 0
        history_best = max(global_top_acc_list)
        max_idx = now_queried
        without_improvement_count = 0
        is_reset = False

        sample_amount_per_trainer = 10

        while now_queried < query_budget and run <= 100:
            logger.info('')
            logger.info(f'Retrain run {run}')
            logger.info(f'target: {target}, without_improvement_count: {without_improvement_count}')

            top_acc_list, top_arch_list, num_new_found, visited_hash, global_arch_dict = sample_and_select(
                trainers_list, datasets, visited_hash,
                logger, repeat_label, top_k, is_reset,
                cur_best=history_best, target=target, hash_version=hash_version,
                sample_amount_per_trainer=sample_amount_per_trainer, global_arch_dict=global_arch_dict)

            now_queried += num_new_found
            logger.info('Now queried: ' + str(now_queried))
            is_reset = False

            # A round may yield nothing new; treat that as "no improvement"
            # instead of calling max() on an empty sequence.
            if len(top_acc_list) > 0 and max(top_acc_list) > history_best:
                history_best = max(top_acc_list)
                max_idx = now_queried
                logger.info(f'Find new best {history_best}')
                without_improvement_count = 0
            else:
                without_improvement_count += 1

            # Record this round BEFORE the budget check so the final round's
            # results are part of the returned maximum and of record_top.
            global_top_acc_list += top_acc_list
            global_top_arch_list += top_arch_list
            top5 = sorted(global_top_acc_list, reverse=True)[:5]
            record_top['valid'].append({now_queried: top5})
            logger.info(f'History top 5 acc: {top5}')

            if now_queried >= query_budget:
                break

            break_idx = retrain(trainers_list, datasets, batch_size, retrain_epochs, logdir, logger)
            if len(break_idx) != 0 and len(visited_hash) >= 45:
                reset_single_trainer(nvp_config, latent_dim, num_layers, d_model, num_heads, dff, num_ops, num_nodes, num_adjs,
                                     dropout_rate, eps_scale, 10, random_seed, logger, pretrained_model, logdir,
                                     x_dim, y_dim, z_dim, retrain_finetune, is_rank_weight, datasets, batch_size, train_epochs, patience, trainers_list, break_idx)

            run += 1

        with open("nb301_found_archs.pkl", "wb") as f:
            pickle.dump(datasets['train_1'], f)

    logger.info('Final result')
    logger.info(f'Best found acc {max(global_top_acc_list)}')
    logger.info(f'Best found acc times: {max_idx}')
    return max(global_top_acc_list), record_top, max_idx


if __name__ == '__main__':
    # Script entry point: parse the command line and forward every option to main().
    args = parse_args()
    main(
        seed=args.seed,
        train_sample_amount=args.train_sample_amount,
        valid_sample_amount=args.valid_sample_amount,
        query_budget=args.query_budget,
        top_k=args.top_k,
        finetune=args.finetune,
        retrain_finetune=args.retrain_finetune,
        is_rank_weight=args.rank_weight,
        target=args.target,
        undirected=args.undirected,
        hash_version=args.hash_version,
        max_without_improvement=args.max_without_improvement,
        decrease_value=args.decrease_value,
        sample_amount=args.sample_amount,
        switch_ensemble=args.switch_ensemble,
    )