import argparse
import copy
import pickle
import random
import numpy as np
import tqdm
from tensorflow.python.keras.callbacks import CSVLogger, EarlyStopping, LearningRateScheduler, Callback
from invertible_neural_networks.flow import MMD_multiscale
from GNN import GraphAutoencoder, GraphAutoencoderNVP, get_rank_weight, MyModel
from models.TransformerAE import TransformerAutoencoderNVP
import tensorflow as tf
import os
from NASBenchNLP_Dateset import NASBenchNLPDataset
import nasbenchnlpn_node_data as nodedata
from datasets.utils import train_valid_test_split_dataset, mask_graph_dataset, repeat_graph_dataset_element
from spektral.data import PackedBatchLoader
from evalGAE_nbnlp import eval_query_target, eval_from_lat
from utils.py_utils import get_logdir_and_logger
from spektral.data import Graph
from utils.tf_utils import set_global_determinism
import logging
import random
import gc
import warnings
from sklearn.cluster import KMeans
from copy import deepcopy
from tensorflow_addons.optimizers import AdamW
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.utils import shuffle
from Flow2FlowAgent import Flow2FlowAgent
# Silence UserWarnings raised from TensorFlow's indexed_slices module
warnings.filterwarnings(action="ignore", category=UserWarning,
                        module='tensorflow.python.framework.indexed_slices')

def _str_to_bool(value):
    """Convert a command-line string such as "true"/"False"/"1"/"no" into a bool.

    argparse's ``type=bool`` treats every non-empty string (including "False") as True,
    so boolean-valued options need an explicit converter.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def parse_args(argv=None):
    """Parse the command-line options for the two-phase search script.

    :param argv: optional list of argument strings; ``None`` (default) reads ``sys.argv``.
    :return: ``argparse.Namespace`` with the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_sample_amount', type=int, default=50, help='Number of samples to train (default: 50)')
    parser.add_argument('--valid_sample_amount', type=int, default=10, help='Number of samples to validate (default: 10)')
    parser.add_argument('--sample_amount', type=int, default=100)
    parser.add_argument('--query_budget', type=int, default=150)
    parser.add_argument('--switch_ensemble', type=int, default=1)
    parser.add_argument('--target', type=float, default=1.0)
    parser.add_argument('--max_without_improvement', type=int, default=5)
    parser.add_argument('--decrease_value', type=float, default=0.1)
    parser.add_argument('--hash_version', type=int, default=2)
    # Use an explicit converter: type=bool would turn "--undirected False" into True
    parser.add_argument('--undirected', type=_str_to_bool, default=False)
    parser.add_argument('--top_k', type=int, default=5)
    parser.add_argument('--finetune', action='store_true')
    parser.add_argument('--no_finetune', dest='finetune', action='store_false')
    parser.set_defaults(finetune=False)
    parser.add_argument('--retrain_finetune', action='store_true')
    parser.add_argument('--no_retrain_finetune', dest='retrain_finetune', action='store_false')
    parser.set_defaults(retrain_finetune=False)
    parser.add_argument('--rank_weight', action='store_true')
    parser.add_argument('--no_rank_weight', dest='rank_weight', action='store_false')
    parser.set_defaults(rank_weight=True)
    parser.add_argument('--seed', type=int, default=0)
    return parser.parse_args(argv)


def cal_ops_adj_loss_for_graph(x_batch_train, ops_cls, adj_cls, reduction='auto', rank_weight=None):
    """Reconstruction losses for the node-operation and adjacency predictions.

    :param x_batch_train: pair ``(ops_label, adj_label)``. ``ops_label`` is scored with
        categorical cross-entropy and ``adj_label`` with sparse categorical cross-entropy,
        both on probabilities (``from_logits=False``). Presumably ``ops_label`` is one-hot
        and ``adj_label`` holds integer class ids; confirm against the dataset.
    :param ops_cls: predicted operation probabilities.
    :param adj_cls: predicted adjacency probabilities.
    :param reduction: Keras loss reduction. With ``'none'`` the per-element losses are
        additionally averaged over the last axis.
    :param rank_weight: optional weights. When given, each loss is multiplied by them and summed.
    :return: ``(ops_loss, adj_loss)``.
    """
    ops_true, adj_true = x_batch_train
    ops_criterion = tf.keras.losses.CategoricalCrossentropy(from_logits=False, reduction=reduction)
    adj_criterion = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False, reduction=reduction)
    ops_loss = ops_criterion(ops_true, ops_cls)
    adj_loss = adj_criterion(adj_true, adj_cls)

    if reduction == 'none':
        ops_loss, adj_loss = (tf.reduce_mean(term, axis=-1) for term in (ops_loss, adj_loss))

    if rank_weight is None:
        return ops_loss, adj_loss

    weighted_ops = tf.reduce_sum(tf.multiply(ops_loss, rank_weight))
    weighted_adj = tf.reduce_sum(tf.multiply(adj_loss, rank_weight))
    return weighted_ops, weighted_adj

def MMD_per_data_point(x, y):
    """
    Average, over the rows of ``x``, the MMD between that single row and the whole of ``y``.

    x: Tensor of shape [batch_size, feature_dim], real data samples
    y: Tensor of shape [batch_size, feature_dim], generated data samples

    Returns a scalar tensor equal to the mean over i of ``MMD_multiscale(x[i:i+1], y)``.
    """
    # Keep a leading axis of size 1 so each mapped element is a [1, feature_dim] "batch"
    x_rows = tf.expand_dims(x, 1)  # Shape: [batch_size, 1, feature_dim]

    # Map directly over the rows; fn_output_signature replaces the deprecated ``dtype`` argument
    mmd_values = tf.map_fn(lambda x_row: MMD_multiscale(x_row, y), x_rows,
                           fn_output_signature=tf.float32)

    # Return the average MMD across all data points
    return tf.reduce_mean(mmd_values)

class Trainer2(tf.keras.Model):
    """Keras wrapper that trains an invertible (NVP) model sitting on a graph autoencoder.

    ``train_step`` performs three optimizer updates per batch, in this order:
      1. on the latent (MMD) loss only;
      2. on ``w1 * reg_loss`` (plus the ops/adj/KL reconstruction loss when ``finetune``);
      3. on ``w3 * rev_loss`` with the encoder/decoder frozen.
    Unless ``finetune`` is set, the encoder/decoder are also frozen for steps 1 and 2.
    ``test_step`` computes the same losses without updating anything.

    NOTE(review): ``model`` is annotated as ``TransformerAutoencoderNVP`` but this file passes
    ``GraphAutoencoderNVP`` instances. What is actually used is ``model.encoder``,
    ``model.decoder``, ``model.inverse`` and a call that returns a 5-tuple.
    """
    def __init__(self, model: TransformerAutoencoderNVP, x_dim, y_dim, z_dim, finetune=False, is_rank_weight=True):
        super(Trainer2, self).__init__()
        self.model = model
        self.x_dim = x_dim
        self.y_dim = y_dim
        self.z_dim = z_dim
        self.finetune = finetune
        self.is_rank_weight = is_rank_weight

        # For reg loss weight
        self.w1 = 5.
        # For latent loss weight
        # NOTE: w2 is only applied in test_step; train_step uses the latent loss unweighted.
        self.w2 = 1.
        # For rev loss weight
        self.w3 = 10.

        # 'none' keeps per-sample losses so they can be multiplied by rank_weight;
        # 'auto' lets Keras reduce them to a scalar.
        if self.is_rank_weight:
            self.reduction = 'none'
        else:
            self.reduction = 'auto'
        self.reg_loss_fn = tf.keras.losses.MeanSquaredError(reduction=self.reduction)
        self.loss_latent = MMD_multiscale
        self.loss_backward = tf.keras.losses.MeanSquaredError(reduction=self.reduction)
        # One running-mean metric per reported loss; `metrics` exposes them to Keras
        self.loss_tracker = {
            'total_loss': tf.keras.metrics.Mean(name="total_loss"),
            'reg_loss': tf.keras.metrics.Mean(name="reg_loss"),
            'latent_loss': tf.keras.metrics.Mean(name="latent_loss"),
            'rev_loss': tf.keras.metrics.Mean(name="rev_loss")
        }
        if self.finetune:
            # Reconstruction terms are only tracked/optimised when fine-tuning the autoencoder
            self.loss_tracker.update({
                'rec_loss': tf.keras.metrics.Mean(name="rec_loss"),
                'ops_loss': tf.keras.metrics.Mean(name="ops_loss"),
                'adj_loss': tf.keras.metrics.Mean(name="adj_loss"),
                'kl_loss': tf.keras.metrics.Mean(name="kl_loss")
            })
            self.ops_weight = 1
            self.adj_weight = 1
            self.kl_weight = 0.16

    def cal_reg_and_latent_loss(self, y, z, y_out, nan_mask, rank_weight=None, train=False):
        """Return ``(reg_loss, latent_loss)`` over the rows selected by ``nan_mask``.

        ``reg_loss`` is the MSE between ``y`` and ``y_out[:, z_dim:]``. ``latent_loss`` is
        ``MMD_multiscale`` between ``concat([z, y])`` and
        ``concat([y_out[:, :z_dim], y_out[:, -y_dim:]])``.
        When both ``is_rank_weight`` and ``train`` are set, the per-sample ``reg_loss`` is
        weighted by ``rank_weight`` and summed. Otherwise it is returned as produced by
        ``reg_loss_fn`` (per-sample when the reduction is 'none').
        """
        # presumably y_out is laid out as [z-part | y-part] with width z_dim + y_dim; TODO confirm
        reg_loss = self.reg_loss_fn(tf.boolean_mask(y, nan_mask),
                                    tf.boolean_mask(y_out[:, self.z_dim:], nan_mask))
        latent_loss = self.loss_latent(tf.boolean_mask(tf.concat([z, y], axis=-1), nan_mask),
                                       tf.boolean_mask(
                                           tf.concat([y_out[:, :self.z_dim], y_out[:, -self.y_dim:]], axis=-1),
                                           nan_mask))  # * x_batch_train.shape[0]
        if self.is_rank_weight and train:
            # reg_loss (batch_size)
            reg_loss = tf.multiply(reg_loss, rank_weight)
            reg_loss = tf.reduce_sum(reg_loss)
        return reg_loss, latent_loss

    def cal_rev_loss(self, undirected_x_batch_train, y, z, nan_mask, noise_scale, rank_weight=None, train=False):
        """Reverse (inverse-pass) loss on the rows selected by ``nan_mask``.

        ``y`` is perturbed with Gaussian noise of scale ``noise_scale``. Then
        ``model.inverse(concat([z, y]))`` is compared by MSE with the encoding the model
        produces for the matching inputs. When both ``is_rank_weight`` and ``train`` are
        set, the per-sample loss is weighted by ``rank_weight`` and summed.
        """
        y = tf.boolean_mask(y, nan_mask)
        z = tf.boolean_mask(z, nan_mask)
        y = y + noise_scale * tf.random.normal(shape=tf.shape(y), dtype=tf.float32)
        non_nan_x_batch = (tf.boolean_mask(undirected_x_batch_train[0], nan_mask),
                           tf.boolean_mask(undirected_x_batch_train[1], nan_mask))
        # NOTE(review): training=True is used even when called from test_step; confirm intended.
        _, _, _, _, x_encoding = self.model(non_nan_x_batch, training=True)  # Logits for this minibatch
        x_rev = self.model.inverse(tf.concat([z, y], axis=-1))
        rev_loss = self.loss_backward(x_rev, x_encoding)  # * x_batch_train.shape[0]
        if self.is_rank_weight and train:
            # rev_loss (batch_size)
            rev_loss = tf.multiply(rev_loss, rank_weight)
            rev_loss = tf.reduce_sum(rev_loss)
        return rev_loss

    def train_step(self, data):
        """Run the three sequential updates (latent, forward, backward) and return the tracked metrics."""
        if not self.finetune:
            self.model.encoder.trainable = False
            self.model.decoder.trainable = False
        # Cast any int64 tensors in the input part of the batch to float32
        list_data = list(data)
        new_data = list(list_data[0])
        for i in range(len(new_data)):
          if new_data[i].dtype == tf.int64:
              new_data[i] = tf.cast(new_data[i], tf.float32)
          else:
              # NOTE(review): in a compiled train_step this print only fires while tracing
              print("The tensor is not int64.")
        #print("new_data",new_data)
        list_data[0] = tuple(new_data)
        data = tuple(list_data)
        x_batch_train, y_batch_train = data
        # undirected_x_batch_train = (x_batch_train[0], to_undiredted_adj(x_batch_train[1]))
        # Targets are the last y_dim columns; z is freshly sampled Gaussian noise per batch
        y = y_batch_train[:, -self.y_dim:]
        z = tf.random.normal([tf.shape(y_batch_train)[0], self.z_dim])
        # Rows whose target contains a NaN are excluded from every y-dependent loss
        nan_mask = tf.where(~tf.math.is_nan(tf.reduce_sum(y, axis=-1)), x=True, y=False)
        # get_rank_weight comes from GNN; presumably one weight per unmasked row; confirm
        rank_weight = get_rank_weight(tf.boolean_mask(y, nan_mask)) if self.is_rank_weight else None

        # Step 1: update on the latent (MMD) loss only
        with tf.GradientTape() as tape:
            ops_cls, adj_cls, kl_loss, y_out, x_encoding = self.model(x_batch_train, kl_reduction='none', training=True)

            # To avoid nan loss when batch size is small
            reg_loss, latent_loss = tf.cond(tf.reduce_any(nan_mask),
                                            lambda: self.cal_reg_and_latent_loss(y, z, y_out, nan_mask, rank_weight, True),
                                            lambda: (0., 0.))

        grads = tape.gradient(latent_loss, self.model.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))

        # Step 2: recompute with the updated weights and update on the weighted regression
        # loss (plus reconstruction losses when fine-tuning)
        with tf.GradientTape() as tape:
            ops_cls, adj_cls, kl_loss, y_out, x_encoding = self.model(x_batch_train, kl_reduction='none', training=True)

            # To avoid nan loss when batch size is small
            reg_loss, latent_loss = tf.cond(tf.reduce_any(nan_mask),
                                            lambda: self.cal_reg_and_latent_loss(y, z, y_out, nan_mask, rank_weight, True),
                                            lambda: (0., 0.))
            forward_loss = self.w1 * reg_loss
            rec_loss = 0. 
            if self.finetune:
                ops_loss, adj_loss = cal_ops_adj_loss_for_graph(x_batch_train, ops_cls, adj_cls, 'auto', None)
                '''
                if rank_weight is not None:
                    kl_loss = tf.reduce_sum(tf.multiply(kl_loss, rank_weight))
                else:
                    kl_loss = tf.reduce_mean(kl_loss)
                '''
                kl_loss = tf.reduce_mean(kl_loss)
                rec_loss = self.ops_weight * ops_loss + self.adj_weight * adj_loss + self.kl_weight * kl_loss
                forward_loss += rec_loss

        grads = tape.gradient(forward_loss, self.model.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
        # Added after the update only so that total_loss also reports the latent term
        forward_loss += latent_loss

        # Backward loss
        # Step 3: update on the reverse loss with the encoder/decoder frozen
        with tf.GradientTape() as tape:
            self.model.encoder.trainable = False
            self.model.decoder.trainable = False
            # To avoid nan loss when batch size is small
            rev_loss = tf.cond(tf.reduce_any(nan_mask),
                               lambda: self.cal_rev_loss(x_batch_train, y, z, nan_mask, 0.0001, rank_weight, True),
                               lambda: 0.)
            # l2_loss = tf.add_n([self.regularizer(w) for w in self.model.trainable_weights])
            backward_loss = self.w3 * rev_loss

        grads = tape.gradient(backward_loss, self.model.trainable_weights)
        # grads = [tf.clip_by_norm(g, 1.) for g in grads]
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
        # Restore trainability so later steps/evaluation see the original flags
        self.model.encoder.trainable = True
        self.model.decoder.trainable = True

        self.loss_tracker['total_loss'].update_state(forward_loss + backward_loss)
        self.loss_tracker['reg_loss'].update_state(reg_loss)
        self.loss_tracker['latent_loss'].update_state(latent_loss)
        self.loss_tracker['rev_loss'].update_state(rev_loss)
        # self.loss_tracker['l2_loss'].update_state(l2_loss)
        if self.finetune:
            self.loss_tracker['rec_loss'].update_state(rec_loss)
            self.loss_tracker['ops_loss'].update_state(ops_loss)
            self.loss_tracker['adj_loss'].update_state(adj_loss)
            self.loss_tracker['kl_loss'].update_state(kl_loss)

        return {key: value.result() for key, value in self.loss_tracker.items()}

    def test_step(self, data):
        """Compute the same losses as train_step without any update and return the tracked metrics.

        Because ``train`` is left False here, with ``is_rank_weight`` the reg/rev losses stay
        per-sample (reduction 'none') and the Mean trackers average them.
        """
        self.model.encoder.trainable = False
        self.model.decoder.trainable = False
        # Cast any int64 tensors in the input part of the batch to float32
        list_data = list(data)
        new_data = list(list_data[0])
        for i in range(len(new_data)):
          if new_data[i].dtype == tf.int64:
              new_data[i] = tf.cast(new_data[i], tf.float32)
          else:
              print("The tensor is not int64.")
        #print("new_data",new_data)
        list_data[0] = tuple(new_data)
        data = tuple(list_data)
        x_batch_train, y_batch_train = data
        # undirected_x_batch_train = (x_batch_train[0], to_undiredted_adj(x_batch_train[1]))
        y = y_batch_train[:, -self.y_dim:]
        z = tf.random.normal(shape=[tf.shape(y_batch_train)[0], self.z_dim])
        nan_mask = tf.where(~tf.math.is_nan(tf.reduce_sum(y, axis=-1)), x=True, y=False)
        rank_weight = get_rank_weight(tf.boolean_mask(y, nan_mask)) if self.is_rank_weight else None

        ops_cls, adj_cls, kl_loss, y_out, x_encoding = self.model(x_batch_train, kl_reduction='none', training=False)
        reg_loss, latent_loss = tf.cond(tf.reduce_any(nan_mask),
                                        lambda: self.cal_reg_and_latent_loss(y, z, y_out, nan_mask, rank_weight),
                                        lambda: (0., 0.))
        forward_loss = self.w1 * reg_loss + self.w2 * latent_loss
        # Evaluation uses zero noise on y for the reverse pass
        rev_loss = tf.cond(tf.reduce_any(nan_mask),
                           lambda: self.cal_rev_loss(x_batch_train, y, z, nan_mask, 0., rank_weight),
                           lambda: 0.)
        # l2_loss = tf.add_n([self.regularizer(w) for w in self.model.trainable_weights])
        # backward_loss = self.w3 * rev_loss + l2_loss
        backward_loss = self.w3 * rev_loss
        if self.finetune:
            ops_loss, adj_loss = cal_ops_adj_loss_for_graph(x_batch_train, ops_cls, adj_cls, 'auto', None)
            '''
            if rank_weight is not None:
                kl_loss = tf.reduce_sum(tf.multiply(kl_loss, rank_weight))
            else:
                kl_loss = tf.reduce_mean(kl_loss)
            '''
            kl_loss = tf.reduce_mean(kl_loss)
            rec_loss = self.ops_weight * ops_loss + self.adj_weight * adj_loss + self.kl_weight * kl_loss
            forward_loss += rec_loss

        self.loss_tracker['total_loss'].update_state(forward_loss + backward_loss)
        self.loss_tracker['reg_loss'].update_state(reg_loss)
        self.loss_tracker['latent_loss'].update_state(latent_loss)
        self.loss_tracker['rev_loss'].update_state(rev_loss)
        # self.loss_tracker['l2_loss'].update_state(l2_loss)
        if self.finetune:
            self.loss_tracker['rec_loss'].update_state(rec_loss)
            self.loss_tracker['ops_loss'].update_state(ops_loss)
            self.loss_tracker['adj_loss'].update_state(adj_loss)
            self.loss_tracker['kl_loss'].update_state(kl_loss)

        self.model.encoder.trainable = True
        self.model.decoder.trainable = True

        return {key: value.result() for key, value in self.loss_tracker.items()}

    @property
    def metrics(self):
        # Listing the trackers here lets Keras reset them at the start of each epoch
        return [value for _, value in self.loss_tracker.items()]

def to_loader(datasets, batch_size: int, epochs: int, repeat: bool):
    """Wrap each dataset split in a spektral ``PackedBatchLoader``.

    ``datasets`` is deep-copied first, so the caller's data is never modified. The copied
    ``'train'`` split is padded with randomly re-sampled graphs (with replacement) until
    its size is a multiple of ``batch_size``.

    Loader settings per key:
      * ``'train'`` / ``'valid'`` with ``repeat`` -> shuffled, ``epochs * 2`` passes;
      * ``'test'`` / ``'train_1'``, or any key when ``repeat`` is False -> unshuffled, 1 pass;
      * any other key when ``repeat`` is True is left out of the result.

    :param datasets: mapping of split name -> dataset exposing a ``graphs`` list;
        must contain ``'train'``.
    :param batch_size: batch size for every loader.
    :param epochs: number of training epochs; the repeating loaders get twice as many.
    :param repeat: whether train/valid loaders should be shuffled and repeated.
    :return: dict mapping split name -> loader.
    """
    loader = {}

    # Deep copy so the padding below never touches the caller's dataset
    datasets_copy = deepcopy(datasets)

    # Ensure 'train' dataset length is a multiple of batch_size
    train_graphs = datasets_copy['train'].graphs
    remainder = len(train_graphs) % batch_size
    if remainder != 0:
        num_extra = batch_size - remainder
        extra_indices = np.random.choice(len(train_graphs), num_extra, replace=True)
        train_graphs.extend([train_graphs[i] for i in extra_indices])

    for key, value in datasets_copy.items():
        if (key == 'train' or key == 'valid') and repeat:
            loader[key] = PackedBatchLoader(value, batch_size=batch_size, shuffle=True, epochs=epochs * 2)
        elif key == 'test' or key == 'train_1' or not repeat:
            loader[key] = PackedBatchLoader(value, batch_size=batch_size, shuffle=False, epochs=1)

    return loader

import logging
import random

def sample_arch_candidates(trainers_list, x_dim, z_dim, visited, sample_amount_per_trainer=20, cur_best=0.94, target=1, hash_version=0):
    """Query every trainer's inverse model for new, unvisited candidate architectures.

    For each noise level in ``noise_std_list`` (in increasing order), up to ``max_retry``
    rounds are run. In each round, every trainer that still has fewer than
    ``sample_amount_per_trainer`` candidates calls ``eval_query_target``. A returned hash is
    accepted only if it is not in ``visited`` and no trainer currently holds it. A trainer
    that overshoots its quota is trimmed to a random subset. The trimmed-away hashes are
    released so that another trainer can still claim them.

    :param cur_best: unused; kept for interface compatibility.
    :return: ``{trainer_index: {arch_hash: arch_info}}``.
    """
    logger = logging.getLogger(__name__)
    found_arch_list_set = {i: {} for i in range(len(trainers_list))}
    max_retry = 5
    noise_std_list = [1e-5, 1e-2, 2e-2, 3e-2, 5e-2]
    amount_scale_list = [1.5, 2, 2.5, 2.5, 5]
    # Hashes currently held by any trainer; used to avoid duplicates across trainers
    global_found_arch_list = set()

    def needs_more():
        return any(len(archs) < sample_amount_per_trainer for archs in found_arch_list_set.values())

    std_idx = 0
    while needs_more() and std_idx < len(noise_std_list):
        retry = 0
        while needs_more() and retry < max_retry:
            for i, trainer in enumerate(trainers_list):
                found = found_arch_list_set[i]
                if len(found) >= sample_amount_per_trainer:
                    continue

                query_amount = int(sample_amount_per_trainer * amount_scale_list[std_idx])
                if std_idx == len(noise_std_list) - 1 and retry == 0:
                    print(f"Trainer {i}: Already found: {len(found)} / {sample_amount_per_trainer}, cur query: {query_amount}")

                found_arch_list = eval_query_target(trainer.model, x_dim, z_dim,
                                                    noise_scale=noise_std_list[std_idx],
                                                    query_amount=query_amount,
                                                    target=target,
                                                    hash_version=hash_version)

                for arch_hash in found_arch_list:
                    if arch_hash in visited or arch_hash in global_found_arch_list:
                        continue
                    found[arch_hash] = found_arch_list[arch_hash]
                    global_found_arch_list.add(arch_hash)

                if len(found) > sample_amount_per_trainer:
                    selected_keys = random.sample(list(found.keys()), sample_amount_per_trainer)
                    selected = set(selected_keys)
                    # Release the discarded hashes; otherwise no trainer could ever use them
                    global_found_arch_list.difference_update(key for key in found if key not in selected)
                    found_arch_list_set[i] = {key: found[key] for key in selected_keys}

            retry += 1

        logger.info(f'std scale {noise_std_list[std_idx]}, num samples: {[len(archs) for archs in found_arch_list_set.values()]}')
        std_idx += 1

    return found_arch_list_set

# Note: eval_query_target (used above) is imported from evalGAE_nbnlp at the top of this file.



def predict_arch_acc(found_arch_list_set, trainers_list):
    """
    Predict accuracy by INN (performance predictor) and assign
    the predicted value to found_arch_list_set

    Every trainer's model is run on the stacked ``'x'`` / ``'a'`` arrays of all candidates.
    The last column of its regression output is taken as the prediction. The predictions
    are averaged across trainers and stored in place as ``found_arch_list_set[h]['y_pred']``.

    :param found_arch_list_set: ``{arch_hash: {'x': ..., 'a': ..., ...}}``; updated in place.
    :param trainers_list: trainers whose ``model`` acts as the performance predictor.
    :raises ValueError: if there are candidates but no trainers to score them.
    """
    # tf.stack([]) raises, so an empty candidate set has to short-circuit before stacking
    if not found_arch_list_set:
        return
    if not trainers_list:
        raise ValueError("predict_arch_acc needs at least one trainer to score candidates")

    # Fix the key order once so predictions map back to the right architectures
    arch_keys = list(found_arch_list_set)
    x = tf.stack([tf.constant(found_arch_list_set[key]['x']) for key in arch_keys])
    a = tf.stack([tf.constant(found_arch_list_set[key]['a']) for key in arch_keys])

    # Collect predictions from all models
    predictions = []
    for trainer in trainers_list:
        _, _, _, reg, _ = trainer.model((x, a), training=False)
        predictions.append(reg[:, -1].numpy())

    # Average predictions across all models
    avg_predictions = np.mean(predictions, axis=0)

    for key, pred in zip(arch_keys, avg_predictions):
        found_arch_list_set[key]['y_pred'] = pred

def sample_and_select(trainers_list, datasets, visited, logger, repeat, top_k=5, is_reset=False, cur_best=0.94, target=1, hash_version=0, sample_amount_per_trainer=20, global_arch_dict=None):
    """Sample candidates from every trainer, keep the ``top_k`` by predicted accuracy and
    add the new ones to the training set.

    Each newly selected architecture is appended to ``visited`` and added ``repeat`` times
    to ``datasets['train'].graphs``.

    :param is_reset: unused; kept for interface compatibility.
    :param global_arch_dict: returned unchanged. A fresh ``{}`` is used when omitted, so
        calls never share a mutable default.
    :return: ``(acc_list, found_arch_list_set, num_new_found, visited, global_arch_dict)``.
    """
    if global_arch_dict is None:
        global_arch_dict = {}

    # Generate architectures per trainer
    found_arch_list_set = sample_arch_candidates(trainers_list, trainers_list[0].x_dim, trainers_list[0].z_dim, visited,
                                                 sample_amount_per_trainer=sample_amount_per_trainer, cur_best=cur_best, target=target, hash_version=hash_version)

    # Combine all architectures from each trainer into a single dict
    combined_arch_list = {arch_hash: arch for found_arch_list in found_arch_list_set.values() for arch_hash, arch in found_arch_list.items()}

    # Predict accuracy by INN (performance predictor) on the combined architecture list
    predict_arch_acc(combined_arch_list, trainers_list)

    # Select the top_k architectures with the highest y_pred values
    top_archs = dict(sorted(combined_arch_list.items(), key=lambda item: item[1]['y_pred'], reverse=True)[:top_k])

    acc_list = [model_info['y'].item() for model_info in top_archs.values()]
    pred_list = [model_info['y_pred'] for model_info in top_archs.values()]

    # Print top 10 real y values (for diagnostics only)
    top_10_real_y_archs = dict(sorted(combined_arch_list.items(), key=lambda item: item[1]['y'], reverse=True)[:10])
    top_10_real_y_values = [model_info['y'].item() for model_info in top_10_real_y_archs.values()]
    print("Top 10 real y values: ", top_10_real_y_values)

    if len(acc_list) != 0:
        logger.info('Top acc list: {}'.format(acc_list))
        logger.info(f'Avg found acc {sum(acc_list) / len(acc_list)}')
        logger.info(f'Best found acc {max(acc_list)}')
        logger.info(f'pred_acc list: {pred_list}')
    else:
        logger.info('Top acc list is [] in this run')

    lowest_acc = min(acc_list) if acc_list else None

    num_new_found = 0
    for arch_hash, arch in top_archs.items():
        if arch_hash in visited:
            continue
        visited.append(arch_hash)
        num_new_found += 1

        # Both original branches added the graph identically; only the print differed
        if arch['y'].item() == lowest_acc:
            print(lowest_acc)
        datasets['train'].graphs.extend([Graph(x=arch['x'],
                                               a=arch['a'],
                                               y=arch['y'])] * repeat)

    return acc_list, found_arch_list_set, num_new_found, visited, global_arch_dict




def prepare_model(latent_dim, num_layers, d_model, num_heads, dff, num_ops, num_nodes, num_adjs,
                  dropout_rate, eps_scale):
    """Instantiate a ``GraphAutoencoder`` and build its weights with one dummy forward pass.

    :return: the built (untrained) autoencoder.
    """
    model_kwargs = dict(latent_dim=latent_dim, num_layers=num_layers,
                        d_model=d_model, num_heads=num_heads,
                        dff=dff, num_ops=num_ops, num_nodes=num_nodes,
                        num_adjs=num_adjs, dropout_rate=dropout_rate, eps_scale=eps_scale)
    autoencoder = GraphAutoencoder(**model_kwargs)

    # A random batch of one graph is enough to create every variable
    dummy_ops = tf.random.normal(shape=(1, num_nodes, num_ops))
    dummy_adj = tf.random.normal(shape=(1, num_nodes, num_nodes))
    autoencoder((dummy_ops, dummy_adj))
    return autoencoder

def reset_trainer(nvp_config, latent_dim, num_layers, d_model, num_heads, dff, num_ops, num_nodes, num_adjs,
                  dropout_rate, eps_scale, num_models, init_seed, logger, pretrained_model, logdir,
                  x_dim, y_dim, z_dim, retrain_finetune, is_rank_weight, datasets, batch_size, train_epochs, patience):
    """Create ``num_models`` fresh NVP trainers initialised from ``pretrained_model`` and fit each one.

    For each model:
      * its encoder/decoder weights are copied from ``pretrained_model``;
      * it is wrapped in ``Trainer2`` and fitted with Adam(lr=1e-3), using a CSV logger and
        early stopping on ``val_total_loss``;
      * it is re-compiled with Adam(lr=1e-4) (the optimizer kept for later training) and
        evaluated on ``datasets['valid']``.

    NOTE(review): ``patience`` is accepted but unused (EarlyStopping always uses 30). The
    seed passed to ``create_model`` is ignored because its ``set_seed`` call is disabled.

    :return: list of trained ``Trainer2`` instances, one per model.
    """
    def create_model(seed):
        # tf.random.set_seed(seed)
        return GraphAutoencoderNVP(nvp_config=nvp_config, latent_dim=latent_dim, num_layers=num_layers,
                                   d_model=d_model, num_heads=num_heads,
                                   dff=dff, num_ops=num_ops, num_nodes=num_nodes,
                                   num_adjs=num_adjs, dropout_rate=dropout_rate, eps_scale=eps_scale)

    def create_callbacks(logdir, idx):
        return [
            CSVLogger(os.path.join(logdir, f"learning_curve_phase2_model_{idx}.csv")),
            EarlyStopping(monitor='val_total_loss', patience=30, restore_best_weights=True)
        ]

    models_list = [create_model(init_seed + seed) for seed in range(num_models)]
    # One dummy forward pass builds the weights so the pretrained ones can be copied in
    input_ops = tf.random.normal(shape=(1, num_nodes, num_ops))
    input_adjs = tf.random.normal(shape=(1, num_nodes, num_nodes))
    for model in models_list:
        model((input_ops, input_adjs))
        model.encoder.set_weights(pretrained_model.encoder.get_weights())
        model.decoder.set_weights(pretrained_model.decoder.get_weights())

    trainers_list = [Trainer2(model, x_dim, y_dim, z_dim, finetune=retrain_finetune, is_rank_weight=is_rank_weight)
                     for model in models_list]
    for idx, trainer in enumerate(trainers_list):
        loader = to_loader(datasets, batch_size, train_epochs, True)
        callbacks = create_callbacks(logdir, idx)
        try:
            kw = {'validation_steps': loader['valid'].steps_per_epoch,
                  'steps_per_epoch': loader['train'].steps_per_epoch}
        except (KeyError, AttributeError, TypeError):
            # Best effort: let Keras infer the epoch length if the loaders cannot report it
            kw = {}

        # trainer.compile(optimizer=AdamW(learning_rate=0.001, weight_decay=0.01), run_eagerly=False)
        trainer.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), run_eagerly=False)
        trainer.fit(loader['train'].load(),
                    validation_data=loader['valid'].load(),
                    epochs=train_epochs,
                    callbacks=callbacks,
                    **kw,
                    verbose=0)
        # Lower learning rate for any subsequent retraining of this trainer
        trainer.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001), run_eagerly=False)
        results = trainer.evaluate(loader['valid'].load(), steps=loader['valid'].steps_per_epoch, verbose=0)
        logger.info(f"Model {idx} evaluation results: {str(dict(zip(trainer.metrics_names, results)))}")
        del loader
        tf.keras.backend.clear_session()
        gc.collect()
    return trainers_list

def reset_single_trainer(nvp_config, latent_dim, num_layers, d_model, num_heads, dff, num_ops, num_nodes, num_adjs,
                  dropout_rate, eps_scale, num_models, init_seed, logger, pretrained_model, logdir,
                  x_dim, y_dim, z_dim, retrain_finetune, is_rank_weight, datasets, batch_size, train_epochs, patience, trainers_list, break_idx):
    """Rebuild and retrain only the trainers whose index is in ``break_idx``.

    Each selected trainer is replaced by a fresh ``GraphAutoencoderNVP`` whose
    encoder/decoder weights are copied from ``pretrained_model``. The new trainer is fitted
    with Adam(lr=1e-3), re-compiled with Adam(lr=1e-4), evaluated on ``datasets['valid']``
    and written back into ``trainers_list`` at the same index (in place). Rebinding only
    the loop variable would silently discard it.

    NOTE(review): ``num_models`` and ``patience`` are accepted but unused; the seed passed
    to ``create_model`` is ignored because its ``set_seed`` call is disabled.

    :return: ``trainers_list`` (the same list object, updated in place).
    """
    def create_model(seed):
        # tf.random.set_seed(seed)
        return GraphAutoencoderNVP(nvp_config=nvp_config, latent_dim=latent_dim, num_layers=num_layers,
                                   d_model=d_model, num_heads=num_heads,
                                   dff=dff, num_ops=num_ops, num_nodes=num_nodes,
                                   num_adjs=num_adjs, dropout_rate=dropout_rate, eps_scale=eps_scale)

    def create_callbacks(logdir, idx):
        return [
            CSVLogger(os.path.join(logdir, f"learning_curve_phase2_model_{idx}.csv")),
            EarlyStopping(monitor='val_total_loss', patience=30, restore_best_weights=True)
        ]

    # Dummy inputs used to build each new model's weights before copying pretrained ones
    input_ops = tf.random.normal(shape=(1, num_nodes, num_ops))
    input_adjs = tf.random.normal(shape=(1, num_nodes, num_nodes))

    for idx in range(len(trainers_list)):
        if idx not in break_idx:
            continue
        print("reset trainer ", idx)
        loader = to_loader(datasets, batch_size, train_epochs, True)
        callbacks = create_callbacks(logdir, idx)
        try:
            kw = {'validation_steps': loader['valid'].steps_per_epoch,
                  'steps_per_epoch': loader['train'].steps_per_epoch}
        except (KeyError, AttributeError, TypeError):
            # Best effort: let Keras infer the epoch length if the loaders cannot report it
            kw = {}
        model = create_model(init_seed + idx)
        model((input_ops, input_adjs))
        model.encoder.set_weights(pretrained_model.encoder.get_weights())
        model.decoder.set_weights(pretrained_model.decoder.get_weights())
        trainer = Trainer2(model, x_dim, y_dim, z_dim, finetune=retrain_finetune, is_rank_weight=is_rank_weight)
        # trainer.compile(optimizer=AdamW(learning_rate=0.001, weight_decay=0.01), run_eagerly=False)
        trainer.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), run_eagerly=False)
        trainer.fit(loader['train'].load(),
                    validation_data=loader['valid'].load(),
                    epochs=train_epochs,
                    callbacks=callbacks,
                    **kw,
                    verbose=0)
        trainer.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001), run_eagerly=False)
        results = trainer.evaluate(loader['valid'].load(), steps=loader['valid'].steps_per_epoch, verbose=0)
        logger.info(f"Model {idx} evaluation results: {str(dict(zip(trainer.metrics_names, results)))}")
        # Store the replacement; without this the retrained trainer was lost
        trainers_list[idx] = trainer
        del loader
        tf.keras.backend.clear_session()
        gc.collect()
    return trainers_list

def retrain(trainers_list, datasets, batch_size, train_epochs, logdir, logger, loss_threshold=2.5):
    """Continue training every trainer on ``datasets`` and flag the ones that did not converge.

    Each trainer is fitted with early stopping on ``val_total_loss`` and then evaluated on
    ``datasets['valid']``. Only the first trainer prints per-epoch progress.

    :param loss_threshold: trainers whose validation ``total_loss`` exceeds this value are
        reported in the returned list (default 2.5).
    :return: list of indices of trainers whose validation total loss exceeds ``loss_threshold``.
    """
    def create_callbacks(logdir, idx):
        return [
            CSVLogger(os.path.join(logdir, f"learning_curve_phase2_retrain_{idx}.csv")),
            EarlyStopping(monitor='val_total_loss', patience=30, restore_best_weights=True)
        ]

    break_idx = []
    for idx, trainer in enumerate(trainers_list):
        loader = to_loader(datasets, batch_size, train_epochs, True)
        callbacks = create_callbacks(logdir, idx)

        # verbose=2 (one line per epoch) for the first trainer only; others stay silent
        verbose = 2 if idx == 0 else 0
        trainer.fit(loader['train'].load(),
                    validation_data=loader['valid'].load(),
                    epochs=train_epochs,
                    callbacks=callbacks,
                    steps_per_epoch=loader['train'].steps_per_epoch,
                    validation_steps=loader['valid'].steps_per_epoch,
                    verbose=verbose)

        results = trainer.evaluate(loader['valid'].load(), steps=loader['valid'].steps_per_epoch, verbose=0)
        metrics = dict(zip(trainer.metrics_names, results))
        if metrics["total_loss"] > loss_threshold:
            break_idx.append(idx)
        logger.info(f"Model {idx} evaluation results: {str(metrics)}")
        del loader, callbacks
        tf.keras.backend.clear_session()
        gc.collect()
    return break_idx

# Pad two groups of vectors to a common dimension
def pad_vectors_to_same_dim(vectors_a, vectors_b):
    """
    Compare the dimensions of two groups of vectors and pad the smaller group.

    The dimension of each group is read from its first vector (assumes every vector in a
    group has the same length; TODO confirm at call sites). Each vector of the
    lower-dimensional group is extended with Gaussian noise (mean 0, std 0.15) and cast to
    float32. The lists are modified in place and also returned. If either list is empty
    there is nothing to compare against, so both are returned unchanged.

    :param vectors_a: first group of vectors (list of 1-D numpy arrays)
    :param vectors_b: second group of vectors (list of 1-D numpy arrays)
    :return: ``(vectors_a, vectors_b)`` after padding
    """
    if len(vectors_a) == 0 or len(vectors_b) == 0:
        return vectors_a, vectors_b

    def _pad_in_place(vectors, extra):
        # Append `extra` noise entries to each vector, replacing it in the list
        for i in range(len(vectors)):
            padding = np.random.normal(loc=0, scale=0.15, size=(extra,))
            vectors[i] = np.concatenate([vectors[i], padding]).astype(np.float32)

    dim_a = vectors_a[0].shape[0]
    dim_b = vectors_b[0].shape[0]
    print("dim_a",dim_a)
    print("dim_b",dim_b)

    if dim_a > dim_b:
        # Pad group B
        _pad_in_place(vectors_b, dim_a - dim_b)
    elif dim_b > dim_a:
        # Pad group A
        _pad_in_place(vectors_a, dim_b - dim_a)

    return vectors_a, vectors_b

# Custom dataset pairing domain-A vectors with domain-B vectors
class EncodedDataset(Dataset):
    """PyTorch dataset yielding aligned ``{'src', 'tgt'}`` float32 tensor pairs.

    The two input sequences are shuffled jointly once, at construction time, so pair ``i``
    stays aligned after shuffling.
    """

    def __init__(self, vectors_a, vectors_b):
        """
        :param vectors_a: vectors from domain A (returned under key 'src').
        :param vectors_b: vectors from domain B (returned under key 'tgt').
        """
        shuffled_a, shuffled_b = shuffle(vectors_a, vectors_b)
        self.vectors_a = shuffled_a
        self.vectors_b = shuffled_b

    def __len__(self):
        return len(self.vectors_a)

    def __getitem__(self, idx):
        src = torch.tensor(self.vectors_a[idx], dtype=torch.float32)
        tgt = torch.tensor(self.vectors_b[idx], dtype=torch.float32)
        return {'src': src, 'tgt': tgt}

import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

import random

import re
import ast
import os


# Lookup table from each operation name in nodedata.nlp_operations to its
# position in that sequence. If a name repeats, the last position wins.
op2idx = dict((operation_name, index) for index, operation_name in enumerate(nodedata.nlp_operations))

def _load_cached_nlp_dataset(cache_path='datasets/NasBenchNLPDataset_alignflow.cache'):
    """Load the NAS-Bench-NLP dataset from a pickle cache, building and caching it on a miss.

    NOTE: the cache is unpickled without validation. Only point this at a
    trusted local file.
    """
    if os.path.exists(cache_path):
        with open(cache_path, 'rb') as f:
            return pickle.load(f)
    dataset = NASBenchNLPDataset()
    with open(cache_path, 'wb') as f:
        pickle.dump(dataset, f)
    return dataset


def _scalar_label(y):
    """Return a single-element label (python number, numpy scalar or size-1 array) as a float."""
    return float(np.asarray(y).item())


def _select_top_test_graphs(graphs, limit=2000, max_nodes=10, min_acc=0.85, logger=None):
    """Select the best small architectures to serve as the phase-1 test set.

    Keeps deep copies of the graphs whose recipe passes
    ``NASBenchNLPDataset.check_map_network_node_find_less_number(recipe, max_nodes)``
    and whose label exceeds ``min_acc``. They are sorted by label in
    descending order (ties keep dataset order) and at most ``limit`` are
    returned.

    :raises ValueError: if no graph qualifies.
    """
    candidates = []
    num_rejected = 0
    for g in tqdm.tqdm(graphs):
        recipe = NASBenchNLPDataset.change_to_recepie(g.a, g.x)
        if NASBenchNLPDataset.check_map_network_node_find_less_number(recipe['recepie'], max_nodes):
            if _scalar_label(g.y) > min_acc:
                candidates.append(copy.deepcopy(g))
        else:
            num_rejected += 1
    if not candidates:
        raise ValueError(f"no architecture with label > {min_acc} passed the node check; "
                         "cannot build the phase-1 test set")
    # sort() stays stable with reverse=True. This matches the previous
    # strict-'>' insertion sort, without the duplicated first element it used to add.
    candidates.sort(key=lambda g: _scalar_label(g.y), reverse=True)
    if logger is not None:
        logger.info(f'Test-set selection: {len(candidates)} candidates, {num_rejected} rejected by node check')
        if len(candidates) < limit:
            logger.warning(f'Only {len(candidates)} candidates available (requested {limit})')
    return candidates[:limit]


def _encode_split(model, split_loader):
    """Run ``model`` over every batch of ``split_loader`` and return flattened per-sample encodings.

    Assumes each batch is ``(inputs, targets)`` and that the 4th model output
    is the latent encoding, matching how the pretrained autoencoders are
    called elsewhere in this script. Each sample's encoding is flattened to a
    1-D numpy vector.
    """
    encodings = []
    for inputs, _targets in split_loader.load():
        _, _, _, x_encoding = model(inputs, training=True)
        encodings.extend(tf.reshape(x_encoding, [x_encoding.shape[0], -1]).numpy())
    return encodings


def main(seed, train_sample_amount, valid_sample_amount, query_budget, top_k, finetune, retrain_finetune, is_rank_weight, target, undirected, hash_version, max_without_improvement, decrease_value, sample_amount, switch_ensemble):
    """Run the two-phase Flow2Flow + NVP architecture search on NAS-Bench-NLP.

    Phase 1 (``train_phase[1]``):
      * splits the dataset into domain A / domain B with
        ``check_map_network_node_find_less_number(..., 10)``;
      * encodes both domains with the pretrained graph autoencoders;
      * trains (or loads) a Flow2Flow mapping between the latent domains;
      * maps the best domain-A test architectures into domain B and decodes
        them with ``eval_from_lat``, keeping the 30 best decoded architectures.

    Phase 2 (``train_phase[2]``):
      * reloads and re-splits the dataset;
      * seeds the training set with ``train_sample_amount`` labelled graphs
        plus the phase-1 architectures;
      * alternates ``sample_and_select`` / ``retrain`` until ``query_budget``
        queries are spent or ``run`` exceeds 100.

    :param seed: seed for ``set_global_determinism`` and the dataset splits.
    :param train_sample_amount: number of labelled training graphs in phase 2.
    :param valid_sample_amount: number of labelled validation graphs in phase 2.
    :param query_budget: maximum number of queried architectures.
    :param top_k: forwarded to ``sample_and_select``.
    :param finetune: only used in the log directory name here.
    :param retrain_finetune: forwarded to ``reset_trainer`` / ``reset_single_trainer``.
    :param is_rank_weight: forwarded to ``reset_trainer`` / ``reset_single_trainer``.
    :param target: forwarded to ``sample_and_select``.
    :param undirected: currently selects the same dataset cache either way.
    :param hash_version: hash scheme for architecture ids.
    :param max_without_improvement: currently unused in this function.
    :param decrease_value: currently unused in this function.
    :param sample_amount: currently unused in this function.
    :param switch_ensemble: currently unused in this function.
    :return: ``(best accuracy found, record_top, max_idx)``. ``record_top['valid']``
        holds, per run, the top-5 accuracies keyed by the query count.
        ``max_idx`` is the query count at which the best accuracy was found.
    """
    logdir, logger = get_logdir_and_logger(
        os.path.join(f'{train_sample_amount}_{valid_sample_amount}_{query_budget}_finetune{finetune}_rfinetune{retrain_finetune}_rank{is_rank_weight}',
                     "CIFAR-10"), f'trainGAE_two_phase_{seed}.log')
    random_seed = seed
    set_global_determinism(random_seed)

    train_phase = [0, 1, 1]  # 0 = skip phase, 1 = run phase
    pretrained_weight =   'logs/phase1_nbnlp_CE_64_directed_mix/modelGAE_weights_phase1'
    pretrained_weight_A = 'logs/phase1_nbnlp_CE_64_directed_mix/modelGAE_weights_phase1'
    pretrained_weight_B = 'logs/phase1_nbnlp_CE_64_directed_mix/modelGAE_weights_phase1'

    eps_scale = 0.05  # 0.1
    d_model = 32
    dropout_rate = 0.0
    dff = 256
    num_layers = 3
    num_heads = 3

    latent_dim = 52

    num_ops = len(nodedata.propertys)  # 11
    num_nodes = len(nodedata.global_node)
    num_adjs = num_nodes ** 2

    # NOTE(review): both values of `undirected` previously loaded the same
    # cache file; kept that way here.
    datasets = _load_cached_nlp_dataset()
    # The train split is used as domain A and the valid split as domain B.
    datasets = train_valid_test_split_dataset(datasets, ratio=[1.0, 0.0, 0.0], shuffle=True, shuffle_seed=random_seed)

    x_dim = latent_dim * num_nodes
    y_dim = 1
    z_dim = x_dim - 1
    tot_dim = y_dim + z_dim

    nvp_config = {
        'n_couple_layer': 4,
        'n_hid_layer': 5,
        'n_hid_dim': 128,
        'name': 'NVP',
        'num_couples': 2,
        'inp_dim': tot_dim
    }

    nvp_config_single = {
        'n_couple_layer': 4,
        'n_hid_layer': 5,
        'n_hid_dim': 128,
        'name': 'NVP',
        'num_couples': 2,
        'inp_dim': tot_dim
    }

    pretrained_model = prepare_model(latent_dim, num_layers, d_model, num_heads, dff,
                                     num_ops, num_nodes, num_adjs, dropout_rate, eps_scale)
    pretrained_model.load_weights(pretrained_weight)

    pretrained_model_A = prepare_model(latent_dim, num_layers, d_model, num_heads, dff,
                                       num_ops, num_nodes, num_adjs, dropout_rate, eps_scale)
    pretrained_model_A.load_weights(pretrained_weight_A)

    pretrained_model_B = prepare_model(latent_dim, num_layers, d_model, num_heads, dff,
                                       num_ops, num_nodes, num_adjs, dropout_rate, eps_scale)
    pretrained_model_B.load_weights(pretrained_weight_B)

    # Filled by phase 1; phase 2 adds these architectures to its training set.
    top_30_archs = {}

    if train_phase[1]:
        # Split into domain A / domain B.
        graphs_a, graphs_b = [], []
        print("len train", len(datasets["train"]))
        for g in datasets["train"].graphs:
            recipe = NASBenchNLPDataset.change_to_recepie(g.a, g.x)
            # NOTE(review): the whole recipe dict is passed here, while the
            # test-set selection passes recipe['recepie']; confirm which form
            # check_map_network_node_find_less_number expects.
            in_domain_a = NASBenchNLPDataset.check_map_network_node_find_less_number(recipe, 10)
            bucket = graphs_a if in_domain_a else graphs_b
            bucket.append(Graph(x=g.x, a=g.a, y=g.y))
        datasets["train"].graphs = graphs_a
        datasets["valid"].graphs = graphs_b
        print("len train2", len(datasets["train"]))
        print("len train2", len(datasets["valid"]))

        # The best domain-A architectures form the test set to be translated.
        test_graphs = _select_top_test_graphs(datasets['train'].graphs, limit=2000, logger=logger)
        print("==============================finish sorted====================================")
        datasets["test"].graphs = test_graphs

        loader = to_loader(datasets, 50, 1, False)

        x_encoding_list_A = _encode_split(pretrained_model_A, loader['train'])
        x_encoding_list_B = _encode_split(pretrained_model_B, loader['valid'])

        # Align dimensionality by padding.
        x_encoding_list_A, x_encoding_list_B = pad_vectors_to_same_dim(x_encoding_list_A, x_encoding_list_B)

        x_encoding_list_test = _encode_split(pretrained_model_A, loader['test'])
        x_encoding_list_A, x_encoding_list_test = pad_vectors_to_same_dim(x_encoding_list_A, x_encoding_list_test)

        # Build the PyTorch dataset.
        dataset = EncodedDataset(x_encoding_list_A, x_encoding_list_B)
        train_size = int(0.8 * len(dataset))
        val_size = len(dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

        weight_path = "/home/engine/Documents/surrogate-201-main/flow2flow_weights_mix_16.pth"

        # Initialise the Flow2Flow agent; input_dim is the (padded) vector size.
        input_dim = len(x_encoding_list_A[0])
        agent = Flow2FlowAgent(input_dim=input_dim)

        if os.path.exists(weight_path):
            # Reuse existing weights.
            print(f"檢測到已存在的模型權重檔案：{weight_path}")
            agent.load_model(weight_path)
        else:
            # Train from scratch and save.
            print("未檢測到已存在的模型權重，開始訓練 Flow2Flow 模型...")
            agent.fit(train_dataset, val_dataset, epochs=100, batch_size=32)
            agent.save_model(weight_path)
            print(f"訓練完成，模型權重已儲存至：{weight_path}")

        # Smoke-test the generation in both directions.
        print("測試生成功能...")
        b_sample = torch.tensor(x_encoding_list_B[0]).unsqueeze(0)
        a_from_b = agent.genAfromB(b_sample)
        print("原始b向量", b_sample)
        print("從 B 生成的 A 向量:", a_from_b)

        a_sample = torch.tensor(x_encoding_list_A[0]).unsqueeze(0)
        b_from_a = agent.genBfromA(a_sample)
        print("原始a向量", a_sample)
        print("從 A 生成的 B 向量:", b_from_a)

        print("將測試集的normal cell轉换為 Domain B 向量...")
        test_vectors = torch.tensor(x_encoding_list_test, dtype=torch.float32)
        domain_b_vectors = agent.genBfromA(test_vectors)
        found_arch_list, origin_index = eval_from_lat(pretrained_model_B, domain_b_vectors, len(domain_b_vectors), 22, 16, hash_version=hash_version)

        # Map each decoded architecture hash to the index of the test graph it came from.
        hash_to_index = {hash_key: item['origin_index'] for hash_key, item in found_arch_list.items()
                         if item['origin_index'] in origin_index}

        # Filter to traceable architectures and rank them by their source test
        # graph's label, descending. The previous code indexed an always-empty list.
        ranked_archs = sorted(
            (kv for kv in found_arch_list.items() if kv[0] in hash_to_index),
            key=lambda kv: _scalar_label(test_graphs[hash_to_index[kv[0]]].y),
            reverse=True,
        )
        top_30_archs = dict(ranked_archs[:30])
        for arch in top_30_archs.values():
            print(arch['y'])

    datasets = _load_cached_nlp_dataset()
    datasets = train_valid_test_split_dataset(datasets, ratio=[0.8, 0.1, 0.1], shuffle=True, shuffle_seed=random_seed)

    global_top_acc_list = []
    global_top_arch_list = []
    record_top = {'valid': [], 'test': []}
    max_idx = 0

    if train_phase[2]:
        batch_size = 32
        train_epochs = 500
        retrain_epochs = 200
        patience = 50
        repeat_label = 5
        now_queried = train_sample_amount + valid_sample_amount
        logger.info('Train phase 2')
        datasets['train_1'] = mask_graph_dataset(datasets['train'], train_sample_amount, 1, random_seed=random_seed)
        train_graphs = datasets['train_1'].graphs
        for arch in top_30_archs.values():
            train_graphs.append(Graph(x=arch['x'], a=arch['a'], y=arch['y']))
        datasets['train_1'].graphs = train_graphs
        datasets['valid_1'] = mask_graph_dataset(datasets['valid'], valid_sample_amount, 1, random_seed=random_seed)
        datasets['train_1'].filter(lambda g: not np.isnan(g.y))
        datasets['valid_1'].filter(lambda g: not np.isnan(g.y))
        datasets['train'] = repeat_graph_dataset_element(datasets['train_1'], repeat_label)
        datasets['valid'] = repeat_graph_dataset_element(datasets['valid_1'], 3)

        visited_hash = []
        global_arch_dict = {}
        # Record the labelled graphs of both splits. The valid split
        # previously hashed the train graphs by mistake.
        for split in ('train_1', 'valid_1'):
            for g in datasets[split].graphs:
                global_top_acc_list.append(g.y)
                visited_hash.append(NASBenchNLPDataset.change_to_recepie(g.a, g.x, hash_version)['arch_id'])

        trainers_list = reset_trainer(nvp_config_single, latent_dim, num_layers, d_model, num_heads, dff, num_ops, num_nodes, num_adjs,
                                      dropout_rate, eps_scale, 10, random_seed, logger, pretrained_model, logdir,
                                      x_dim, y_dim, z_dim, retrain_finetune, is_rank_weight, datasets, batch_size, train_epochs, patience)

        run = 0
        history_best = max(global_top_acc_list)
        max_idx = now_queried
        without_improvement_count = 0
        is_reset = False

        sample_amount_per_trainer = 30

        while now_queried < query_budget and run <= 100:
            logger.info('')
            logger.info(f'Retrain run {run}')
            logger.info(f'target: {target}, without_improvement_count: {without_improvement_count}')

            top_acc_list, top_arch_list, num_new_found, visited_hash, global_arch_dict = sample_and_select(
                trainers_list, datasets, visited_hash, logger, repeat_label, top_k, is_reset,
                cur_best=history_best, target=target, hash_version=hash_version,
                sample_amount_per_trainer=sample_amount_per_trainer, global_arch_dict=global_arch_dict)

            now_queried += num_new_found
            logger.info('Now queried: ' + str(now_queried))
            is_reset = False
            if max(top_acc_list) > history_best:
                history_best = max(top_acc_list)
                max_idx = now_queried
                logger.info(f'Find new best {history_best}')
                without_improvement_count = 0
            else:
                without_improvement_count += 1

            break_idx = retrain(trainers_list, datasets, batch_size, retrain_epochs, logdir, logger)
            # Re-initialise the trainers whose validation total_loss exceeded
            # retrain()'s threshold, once enough architectures have been visited.
            if len(break_idx) != 0 and len(visited_hash) >= 45:
                reset_single_trainer(nvp_config, latent_dim, num_layers, d_model, num_heads, dff, num_ops, num_nodes, num_adjs,
                                     dropout_rate, eps_scale, 10, random_seed, logger, pretrained_model, logdir,
                                     x_dim, y_dim, z_dim, retrain_finetune, is_rank_weight, datasets, batch_size, train_epochs, patience, trainers_list, break_idx)

            # NOTE(review): the results of a run that overshoots the budget are
            # not added to global_top_acc_list, although history_best/max_idx
            # may already include them. Confirm this is intended.
            if now_queried > query_budget:
                break

            global_top_acc_list += top_acc_list
            global_top_arch_list += top_arch_list
            run += 1

            record_top['valid'].append({now_queried: sorted(global_top_acc_list, reverse=True)[:5]})

            logger.info(f'History top 5 acc: {sorted(global_top_acc_list, reverse=True)[:5]}')

    logger.info('Final result')
    logger.info(f'Best found acc {max(global_top_acc_list)}')
    logger.info(f'Best found acc times: {max_idx}')
    return max(global_top_acc_list), record_top, max_idx


if __name__ == '__main__':
    args = parse_args()
    main(args.seed, args.train_sample_amount, args.valid_sample_amount, args.query_budget,
         args.top_k, args.finetune, args.retrain_finetune, args.rank_weight, target=args.target, undirected=args.undirected, hash_version=args.hash_version,
         max_without_improvement=args.max_without_improvement, decrease_value=args.decrease_value, sample_amount=args.sample_amount, switch_ensemble=args.switch_ensemble)