# coding:utf-8
# solve one PDE
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

# Add the sub-folder's path to the module search path
subfolder_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'PSRPDE/NetGP'))
sys.path.append(subfolder_path)

# External packages
import numpy as np
import matplotlib.pyplot as plt
import torch
import config
import random
import copy
import time
import pickle
from sklearn.model_selection import train_test_split
import pandas as pd

# Internal code import
from physym import batch as Batch
from physym import reward
from learn import rnn
from learn import loss




# Seed
# seed = 0
# np.random.seed(seed)
# torch.manual_seed(seed)


def single_point_crossover(parent1, parent2,length):
    """Swap the tails of two token sequences at a random interior cut point.

    A cut ``point`` is drawn uniformly from ``[1, length - 1]``. The children are
    ``parent1[:point] ++ parent2[point:]`` and ``parent2[:point] ++ parent1[point:]``.
    ``length`` is typically the shorter of the two parents' effective lengths,
    so the cut falls inside both programs.

    When ``length < 2`` there is no interior cut point. In that case the parents
    are returned unchanged, as new arrays, instead of raising ``ValueError``
    from ``random.randint``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(child1, child2)``.
    """
    if length < 2:
        return np.array(parent1, copy=True), np.array(parent2, copy=True)
    # Randomly choose the crossover point.
    point = random.randint(1, length - 1)
    # Build the two children by exchanging tails.
    child1 = np.concatenate((parent1[:point], parent2[point:]))
    child2 = np.concatenate((parent2[:point], parent1[point:]))
    return child1, child2




def create_model(X,y,run_config):
    """Build the RNN policy cell and its optimizer.

    A ``Batch.Batch`` is constructed only to read ``obs_size`` and
    ``n_choices``, which set the cell's input and output widths. The cell is
    then created with ``run_config["cell_config"]`` as extra keyword arguments.

    Returns
    -------
    tuple
        ``(cell, optimizer)``, where ``optimizer`` is the result of
        ``run_config["learning_config"]["get_optimizer"](cell)``.
    """
    learning = run_config["learning_config"]
    probe = Batch.Batch(library_args=run_config["library_config"],
                        priors_config=run_config["priors_config"],
                        batch_size=learning["batch_size"],
                        max_time_step=learning["max_time_step"],
                        free_const_opti_args=run_config["free_const_opti_args"],
                        X=X,
                        y_target=y,
                        )
    cell = rnn.Cell(input_size=probe.obs_size,
                    output_size=probe.n_choices,
                    **run_config["cell_config"],
                    )
    optimizer = learning["get_optimizer"](cell)
    return cell, optimizer



def generated_population(X,y,run_config, evaluate_function):
    """Search for a symbolic program using an RNN policy plus genetic operators.

    Each generation:
      1. samples ``batch_size`` programs token by token from the RNN, with
         tokens forbidden by the priors masked out;
      2. scores them with ``reward.RewardsComputer`` and trains the RNN with
         ``loss.loss_func`` on the top ``n_keep`` programs. On odd generations
         the previous generation's survivors are used as training targets;
      3. applies single-point crossover and point mutation to the elites and
         the previous survivors, then repairs every offspring token against
         the priors (disallowed tokens are re-drawn from the prior);
      4. scores RNN samples + repaired offspring + survivors together and keeps
         the top ``n_keep`` as the next generation's survivors.

    Parameters
    ----------
    X, y :
        Features and target, passed to ``Batch.Batch`` and the reward.
    run_config : dict
        Reads ``library_config``, ``priors_config``, ``free_const_opti_args``
        and ``learning_config`` (``batch_size``, ``max_time_step``,
        ``gamma_decay``, ``entropy_weight``, ``get_optimizer``).
    evaluate_function : callable or None
        Forwarded to the reward computer. When not None, it is also applied to
        each generation's best program and the result is recorded in the
        saved fitness history. When None, the best reward is recorded instead.

    Returns
    -------
    The best program seen across generations. It is returned early once a
    generation's best reward exceeds 0.999.

    Side effects
    ------------
    Writes ``{5: fitness_history}`` to ``PSR-2.pkl`` when all generations run.
    The early exit skips this write.
    """
    generation = 10
    # At least one elite is required (keep[0] is read below).
    n_keep = max(1, int(0.1 * config.BATCH_SIZE))
    cxpb = 0.6   # probability of crossing over each consecutive parent pair
    mutpb = 0.6  # probability of mutating each parent
    learning = run_config["learning_config"]
    model, optimizer = create_model(X, y, run_config)

    def new_batch(size):
        # Fresh Batch over the same library/priors/data; only the size varies.
        return Batch.Batch(library_args=run_config["library_config"],
                           priors_config=run_config["priors_config"],
                           batch_size=size,
                           max_time_step=learning["max_time_step"],
                           free_const_opti_args=run_config["free_const_opti_args"],
                           X=X,
                           y_target=y,
                           )

    def score(programs):
        # Rewards as a float array with NaN -> 0 (np.nan_to_num).
        rewards = reward.RewardsComputer(programs=programs,
                                         X=X,
                                         y_target=y,
                                         evaluate_function=evaluate_function,
                                         free_const_opti_args=run_config["free_const_opti_args"],
                                         )
        return np.nan_to_num(np.array(rewards))

    # Survivors of the previous generation: token matrix (time, n_keep),
    # their rewards and lengths. All None in generation 0.
    nextgen = None
    R_train_next = None
    lengths_next = None

    fitness_generation = []
    best_prog = None
    # Below any reward, so the first generation always sets best_prog.
    best_fitness = -np.inf
    for g in range(generation):
        # ----------------- 1. sample a fresh population from the RNN -----------------
        batch = new_batch(learning["batch_size"])
        states = model.get_zeros_initial_state(batch.batch_size)
        optimizer.zero_grad()

        logits = []
        actions = []
        for _ in range(config.MAX_LENGTH):
            observations = torch.tensor(batch.get_obs().astype(np.float32),
                                        requires_grad=False, )  # (batch_size, obs_size)
            output, states = model(input_tensor=observations, states=states)
            # Forbidden tokens have prior 0 -> log prior -inf -> sampling probability 0.
            prior = torch.tensor(batch.prior().astype(np.float32), requires_grad=False)
            logit = output + torch.log(prior)  # (batch_size, n_choices)
            action = torch.multinomial(torch.exp(logit), num_samples=1)[:, 0]
            logits.append(logit)
            actions.append(action)
            batch.programs.append(action.detach().cpu().numpy())

        logits = torch.stack(logits, dim=0)    # (max_time_step, batch_size, n_choices)
        actions = torch.stack(actions, dim=0)  # (max_time_step, batch_size)
        actions_array = actions.detach().cpu().numpy()

        # ----------------- 2. score and train the RNN on the elites -----------------
        R = score(batch.programs)
        keep = R.argsort()[::-1][:n_keep].copy()

        ideal_probs_array_train = np.eye(batch.n_choices)[actions_array[:, keep]]
        R_train = torch.tensor(R[keep], requires_grad=False)
        logits_train = logits[:, keep]
        lengths = batch.programs.n_lengths[keep]

        # Odd generations: train toward the previous generation's survivors instead.
        if nextgen is not None and g % 2:
            ideal_probs_array_train = np.eye(batch.n_choices)[nextgen]
            R_train = R_train_next
            lengths = lengths_next

        ideal_probs_train = torch.tensor(ideal_probs_array_train.astype(np.float32),
                                         requires_grad=False, )
        loss_val = loss.loss_func(logits_train=logits_train,
                                  ideal_probs_train=ideal_probs_train,
                                  R_train=R_train,
                                  baseline=R_train.min(),
                                  lengths=lengths,
                                  gamma_decay=learning["gamma_decay"],
                                  entropy_weight=learning["entropy_weight"], )
        if not model.is_lobotomized:
            loss_val.backward()
            optimizer.step()

        # ----------------- 3. crossover and mutation -----------------
        offspring = actions_array[:, keep]
        # Effective lengths aligned with the columns of `offspring`
        # (elite column j is program keep[j]; survivor columns use lengths_next).
        offspring_lengths = batch.programs.n_lengths[keep]
        if nextgen is not None:
            offspring = np.concatenate((offspring, nextgen), axis=1)
            offspring_lengths = np.concatenate((offspring_lengths, lengths_next))
        n_parents = offspring.shape[1]

        new_offspring = []
        for c1, c2 in zip(range(0, n_parents, 2), range(1, n_parents, 2)):
            if random.random() < cxpb:
                cut_range = int(min(offspring_lengths[c1], offspring_lengths[c2]))
                if cut_range < 2:
                    # No interior cut point exists when a parent has a single token.
                    continue
                child1, child2 = single_point_crossover(offspring[:, c1], offspring[:, c2], cut_range)
                new_offspring.append(child1.copy())
                new_offspring.append(child2.copy())

        # NOTE(review): assumes token ids < len(op_names) are operators and the
        # rest are operands, so a mutation keeps the token's class. Confirm
        # against the library's token ordering.
        split_token = len(run_config['library_config']["args_make_tokens"]["op_names"])
        for i in range(n_parents):
            if random.random() < mutpb:
                mutant = offspring[:, i].copy()
                mutation_point = random.randint(0, int(offspring_lengths[i]) - 1)
                if mutant[mutation_point] < split_token:
                    mutant[mutation_point] = random.randint(0, split_token - 1)
                else:
                    mutant[mutation_point] = random.randint(split_token, batch.library.n_choices - 1)
                new_offspring.append(mutant)

        # ----------------- 4. repair offspring against the priors -----------------
        # A proposed token is kept when the prior allows it at that step.
        # Otherwise a replacement is drawn from the prior.
        repaired_steps = []
        if new_offspring:
            new_offspring = np.array(new_offspring)  # (n_new, max_time_step)
            n_new = len(new_offspring)
            batch_new = new_batch(n_new)
            rows = np.arange(n_new)
            for i in range(config.MAX_LENGTH):
                prior_array = batch_new.prior().astype(np.float32)  # (n_new, n_choices)
                proposed = new_offspring[:, i]
                allowed = prior_array[rows, proposed].astype(bool)
                resampled = torch.multinomial(torch.tensor(prior_array), num_samples=1)[:, 0]
                action = np.where(allowed, proposed, resampled.numpy())
                repaired_steps.append(action)
                batch_new.programs.append(action)

        # ----------------- 5. merge and select the next generation -----------------
        pieces = [actions_array]
        if repaired_steps:
            pieces.append(np.stack(repaired_steps, axis=0))
        if nextgen is not None:
            pieces.append(nextgen)
        combined = np.concatenate(pieces, axis=1)  # (max_time_step, n_total)

        batch_combine = new_batch(combined.shape[1])
        for i in range(config.MAX_LENGTH):
            batch_combine.programs.append(combined[i, :])

        R = score(batch_combine.programs)
        keep = R.argsort()[::-1][:n_keep].copy()
        nextgen = combined[:, keep]
        R_train_next = torch.tensor(R[keep], requires_grad=False)
        lengths_next = batch_combine.programs.n_lengths[keep]

        # ----------------- 6. bookkeeping -----------------
        top = keep[0]
        gen_best = batch_combine.programs.get_prog(top)
        gen_best_R = R[top]
        print(gen_best.get_infix_pretty(do_simplify=True), gen_best_R)

        if evaluate_function is not None:
            # Ablation record: the evaluate_function output for this generation's best.
            y_pred, X_temp = gen_best.torch_exec(X, gen_best.tokens, gen_best.free_const_values)
            fitness_generation.append(evaluate_function(y_pred, X_temp).detach().numpy())
        else:
            # No evaluate_function supplied: record the best reward instead.
            fitness_generation.append(gen_best_R)

        if best_fitness < gen_best_R:
            best_fitness = gen_best_R
            best_prog = gen_best
            print("best fitness: ", best_fitness)
            print("best prog: ", best_prog.get_infix_pretty(do_simplify=True))
        if gen_best_R > 0.999:
            return best_prog

    with open('PSR-2.pkl', 'wb') as f:
        pickle.dump({5: fitness_generation}, f)

    return best_prog


def netgprun(X_train, y_train, feature_names, op_names, evaluate_function):
    """Run ``generated_population`` on one dataset and return its best program.

    The token library has the given operators, one input variable per entry of
    ``feature_names`` (id = position), and a single free constant ``c``. The
    run configuration is a deep copy of ``config.config0``, so the shared
    default is not modified.
    """
    print('netgp')

    X_train = torch.tensor(X_train)
    y_train = torch.tensor(y_train)

    token_args = {
        "op_names": op_names,
        "input_var_ids": {name: pos for pos, name in enumerate(feature_names)},
        "free_constants": {"c"},
    }
    run_config = copy.deepcopy(config.config0)
    run_config["library_config"] = {"args_make_tokens": token_args}

    return generated_population(X_train, y_train,
                                run_config=run_config,
                                evaluate_function=evaluate_function)




def discover_sub_equations(grouping_strategies, input_data, op_names, evaluate_function=None):
    """Fit one sub-equation per grouping strategy.

    For each strategy (a list of column names to hold constant), the rows of
    ``input_data`` are grouped by those columns and the largest group is kept.
    A program is then searched that maps the remaining (varying) feature
    columns, in DataFrame column order, to the ``'target'`` column. An empty
    strategy uses the whole dataset.

    Parameters
    ----------
    grouping_strategies : list of list of str
    input_data : pandas.DataFrame
        Must contain a ``'target'`` column; every other column is a feature.
    op_names : list
        Operator names for the token library.
    evaluate_function : callable or None
        Forwarded to ``generated_population``.

    Returns
    -------
    list
        The best program per strategy, in strategy order.
    """
    label = 'target'
    features_names = [x for x in input_data.columns.values if x != label]

    all_eqs = []
    for strategy in grouping_strategies:
        strategy_name = '_and_'.join(strategy)
        print(f"\n\n===== Applying Grouping Strategy: Keep [{strategy_name}] constant =====")

        if not strategy:
            # If no strategy, use the entire dataset.
            group_df = input_data
        else:
            grouped = input_data.groupby(strategy)
            print(f"Found {len(grouped)} groups for this strategy.")
            # Select only the rows of the largest group. ngroup() numbers groups
            # in the same order as size(), and argmax keeps idxmax's first-wins
            # tie-break. This avoids get_group(), whose key format for a
            # length-1 list of columns differs across pandas versions.
            largest_id = int(np.argmax(grouped.size().to_numpy()))
            group_df = input_data[grouped.ngroup() == largest_id]

        # Features that vary within the selected group, in column order.
        varying_feature_names = [f for f in features_names if f not in strategy]
        features = group_df[varying_feature_names].values.T
        target = group_df[label].values

        X_train = torch.tensor(features)
        y_train = torch.tensor(target)
        args_make_tokens = {
            # operations
            "op_names": op_names,
            # input variables
            "input_var_ids": {name: i for i, name in enumerate(varying_feature_names)},
            # free constants
            "free_constants": {"c",},
        }

        library_config = {"args_make_tokens": args_make_tokens, }
        # Deep-copy so the shared module-level config0 is not mutated between strategies.
        run_config = copy.deepcopy(config.config0)
        run_config.update({"library_config": library_config})
        best_prog = generated_population(X_train, y_train, run_config=run_config, evaluate_function=evaluate_function)
        if best_prog is None:
            print(f"No program found for strategy '{strategy_name}'.")
        else:
            print(f"Best program for strategy '{strategy_name}': {best_prog.get_infix_pretty(do_simplify=True)}")
        all_eqs.append(best_prog)

    return all_eqs
        
    





def run(input_data,op_names,evaluate_function):
    """Two-level search: per-strategy sub-equations, then a combining equation.

    1. Picks grouping strategies from the feature names (hard-coded for
       ``x1``/``x2``/``x3``) and fits one sub-equation per strategy with
       ``discover_sub_equations``. When strategies exist, the sub-searches get
       no ``evaluate_function``. When there are none, the whole dataset is fit
       once and ``evaluate_function`` is forwarded.
    2. Evaluates sub-equation ``i`` on the full dataset to build feature
       ``f{i+1}``.
    3. Calls ``hierarchical_symbolic_regression`` with one token library over
       the ``f`` features and one over the original features.

    Returns
    -------
    tuple
        ``(final_prog, final_higher_prog)`` from
        ``hierarchical_symbolic_regression``.
    """
    label = 'target'
    features_names = [col for col in input_data.columns.values if col != label]

    # Columns held constant by each strategy.
    if "x3" in features_names:
        grouping_strategies = [['x2', 'x3'], ['x1', 'x3'], ['x1', 'x2']]
    elif "x2" in features_names:
        grouping_strategies = [['x1'], ['x2']]
    else:
        grouping_strategies = []

    # No strategy -> a single empty strategy, i.e. fit on the whole dataset.
    strategies = grouping_strategies if grouping_strategies else [[]]
    sub_eval = None if grouping_strategies else evaluate_function
    all_eqs = discover_sub_equations(strategies, input_data, op_names, sub_eval)

    print("\n\n===== Combining discovered equations to find the final PDE =====")

    # Each sub-equation must see its varying columns in the same order used
    # while it was searched.
    f_eqs = {}
    f_features = {}
    for idx, eq in enumerate(all_eqs):
        key = f'f{idx+1}'
        f_eqs[key] = eq
        held_constant = strategies[idx]
        cols = [name for name in features_names if name not in held_constant]
        eq_inputs = torch.tensor(input_data[cols].values.T, dtype=torch.float32)
        f_features[key] = eq.execute(eq_inputs).detach().numpy().flatten()

    # Higher-level dataset: features f1, f2, ... and the original target.
    final_features_df = pd.DataFrame(f_features)
    final_target = input_data[label].values

    X_final = torch.tensor(final_features_df.values.T, dtype=torch.float32)
    y_final = torch.tensor(final_target, dtype=torch.float32)
    final_feature_names = list(final_features_df.columns)

    print(f"New dataset created with features: {final_feature_names}")
    print(f"Shape of new features X_final: {X_final.shape}")
    print(f"Shape of new target y_final: {y_final.shape}")

    def library_run_config(var_names):
        # Deep copy of config0 with a token library over `var_names` (id = position).
        tokens = {
            "op_names": op_names,
            "input_var_ids": {name: pos for pos, name in enumerate(var_names)},
            "free_constants": {"c"},
        }
        cfg = copy.deepcopy(config.config0)
        cfg["library_config"] = {"args_make_tokens": tokens}
        return cfg

    run_config_final = library_run_config(final_feature_names)   # over f1, f2, ...
    run_config_original = library_run_config(features_names)     # over the raw features

    print("\n--- Running final symbolic regression to find relationship between f1, f2, f3, and t ---")

    X_original = torch.tensor(input_data[features_names].values.T, dtype=torch.float32)
    y_original = torch.tensor(input_data[label].values, dtype=torch.float32)

    final_prog, final_higher_prog = hierarchical_symbolic_regression(
        X_higher=X_final,
        y_higher=y_final,
        run_config_higher=run_config_final,
        X_original=X_original,
        y_original=y_original,
        run_config_original=run_config_original,
        f_eqs=f_eqs,
        evaluate_function=evaluate_function,
    )

    print("\n\n===== Final Result =====")
    print(f"Discovered relationship between sub-equations: {final_prog.get_infix_pretty(do_simplify=True)}")

    return final_prog, final_higher_prog

def hierarchical_symbolic_regression(X_higher, y_higher, run_config_higher, X_original, y_original, run_config_original, f_eqs, evaluate_function):
    """
    Performs hierarchical symbolic regression using a genetic algorithm.
    1. It evolves a population of programs (expressions) that are defined in terms of high-level features (e.g., f1, f2).
    2. For each high-level program, the reward is calculated by expanding it into its full low-level representation
       (in terms of x1, x2, t, etc.) and then evaluating its fitness, including constant optimization.
    3. The best-performing expanded program is returned.
    """
    print("\n--- Running Hierarchical Symbolic Regression ---")
    
    # Create the original library once to pass to the reward computer
    original_library = Batch.Batch(
        library_args=run_config_original["library_config"],
        priors_config=run_config_original["priors_config"],
        batch_size=1,
        max_time_step=1,
        free_const_opti_args=run_config_original["free_const_opti_args"],
        X=X_original,
        y_target=y_original,
    ).library

    generation = 5
    n_keep = int(0.1 * config.BATCH_SIZE)
    cxpb = 0.6
    mutpb = 0.6
    model,optimizer = create_model(X_higher, y_higher, run_config_higher)

    nextgen = None


    fitness_generation = []
    best_prog = None
    best_fitness = 0
    for g in range(generation):
        #sum_len = 0
        fitness_temp = 0
        fitness_count = 0
        # -----------------新种群---------------------
        batch = Batch.Batch(library_args=run_config_higher["library_config"],
                            priors_config=run_config_higher["priors_config"],
                            batch_size=run_config_higher["learning_config"]["batch_size"],
                            max_time_step=run_config_higher["learning_config"]["max_time_step"],
                            free_const_opti_args=run_config_higher["free_const_opti_args"],
                            X=X_higher,
                            y_target=y_higher,
                            )
        batch_size = batch.batch_size

        # Initial RNN cell input
        states = model.get_zeros_initial_state(batch_size)
        # Optimizer reset
        optimizer.zero_grad()


        # Candidates
        logits = []
        actions = []



        for i in range(config.MAX_LENGTH):
            # ------------ OBSERVATIONS ------------
            # (embedding output)
            observations = torch.tensor(batch.get_obs().astype(np.float32),
                                        requires_grad=False, )  # (batch_size, obs_size)

            # ------------ MODEL ------------

            # Giving up-to-date observations
            output, states = model(input_tensor=observations,
                                   # (batch_size, output_size), (n_layers, 2, batch_size, hidden_size)
                                   states=states)

            # Getting raw prob distribution for action n°i
            outlogit = output

            # ------------ PRIOR -----------
            # (embedding output)
            prior_array = batch.prior().astype(np.float32)  # (batch_size, output_size)
            # 0 protection so there is always something to sample
            epsilon = 0  # 1e-14 #1e0*np.finfo(np.float32).eps
            prior_array[prior_array == 0] = epsilon

            # To log
            prior = torch.tensor(prior_array, requires_grad=False)  # (batch_size, output_size)
            logprior = torch.log(prior)  # (batch_size, output_size)

            # ------------ SAMPLING ------------

            logit = outlogit + logprior  # (batch_size, output_size)
            action = torch.multinomial(torch.exp(logit),  # (batch_size,)
                                       num_samples=1)[:, 0]

            # ------------ ACTION ------------

            # Saving action n°i
            logits.append(logit)
            actions.append(action)

            # Informing embedding of new action
            # (embedding input)
            batch.programs.append(action.detach().cpu().numpy())

        # -------------------------------------------------
        # ------------------ CANDIDATES  ------------------
        # -------------------------------------------------

        # Keeping prob distribution history for backpropagation
        logits = torch.stack(logits, dim=0)  # (max_time_step, batch_size, n_choices, )
        actions = torch.stack(actions, dim=0)  # (max_time_step, batch_size,)

        # Programs as numpy array for black box reward computation
        actions_array = actions.detach().cpu().numpy()

        #sum_len += batch.programs.n_lengths.sum()
        R = reward.HierarchicalRewardsComputer(
            programs=batch.programs,
            X_original=X_original,
            y_original=y_original,
            run_config_original=run_config_original,
            original_library=original_library,
            f_eqs=f_eqs,
            evaluate_function=evaluate_function,
            free_const_opti_args=run_config_higher["free_const_opti_args"]
        )
        

        R = np.array(R)
        R = np.nan_to_num(R)
        fitness_temp+=(1/R-1).sum()
        fitness_count+=len(R)

        keep = R.argsort()[::-1][0:n_keep].copy()

        # ----------------- Train batch : black box part (NUMPY) -----------------
        # Elite candidates
        actions_array_train = copy.deepcopy(actions_array[:, keep])

        # Elite candidates as one-hot target probs
        ideal_probs_array_train = np.eye(batch.n_choices)[actions_array_train]
        R_train = torch.tensor(R[keep], requires_grad=False)  # (n_keep,)

        # Elite candidates pred logprobs
        logits_train = logits[:, keep]
        # Lengths of programs
        lengths = batch.programs.n_lengths[keep]

        if nextgen is not None and g % 2:
            ideal_probs_array_train = np.eye(batch.n_choices)[nextgen]
            R_train = R_train_next
            lengths = lengths_next
        # Elite candidates rewards

        R_lim = R_train.min()

        # Elite candidates as one-hot in torch
        # (non-differentiable tensors)
        ideal_probs_train = torch.tensor(  # (max_time_step, n_keep, n_choices,)
            ideal_probs_array_train.astype(np.float32),
            requires_grad=False, )

        baseline = R_lim

        # Loss
        loss_val = loss.loss_func(logits_train=logits_train,
                                  ideal_probs_train=ideal_probs_train,
                                  R_train=R_train,
                                  baseline=baseline,
                                  lengths=lengths,
                                  gamma_decay=run_config_higher["learning_config"]["gamma_decay"],
                                  entropy_weight=run_config_higher["learning_config"]["entropy_weight"], )


        # BACKPROPAGATION
        # -------------------------------------------------
        if model.is_lobotomized:
            pass
        else:
            loss_val.backward()
            optimizer.step()


        # -----------------交叉、变异算子---------------------
        offspring = copy.deepcopy(actions_array[:,keep])
        keep_length = batch.programs.n_lengths

        if nextgen is not None:
            offspring = np.concatenate((offspring,copy.deepcopy(nextgen)),axis=1)

        new_offspring = []
        # 对选出的个体进行交叉和变异
        for child1_i, child2_i in zip(range(0,offspring.shape[1],2),range(1,offspring.shape[1],2)):
            if random.random() < cxpb:
                child1,child2 = offspring[:, child1_i].T, offspring[:, child2_i].T
                templength = min(keep_length[child1_i],keep_length[child2_i])
                temp_child1,temp_child2 = single_point_crossover(child1,child2,templength)
                new_offspring.append(temp_child1.copy())
                new_offspring.append(temp_child2.copy())

        split_token = len(run_config_higher['library_config']["args_make_tokens"]["op_names"])
        for i in range(len(offspring[0])):
            if random.random() < mutpb:
                mutant=offspring[:,i].copy()
                # 随机选择mutant的一个点，变异成另外一个
                mutation_point = random.randint(0, keep_length[i]-1)
                if mutant[mutation_point]<split_token:
                    mutant[mutation_point] = random.randint(0,split_token-1)
                else:
                    mutant[mutation_point] = random.randint(split_token, batch.library.n_choices-1)
                new_offspring.append(mutant)

        # -----------------修正交叉、变异算子---------------------
        new_offspring = np.array(new_offspring)
        #mask_need_action = np.full(shape=(new_offspring.shape), fill_value=True, dtype=bool)
        batch_new = Batch.Batch(library_args=run_config_higher["library_config"],
                            priors_config=run_config_higher["priors_config"],
                            batch_size=len(new_offspring),
                            max_time_step=run_config_higher["learning_config"]["max_time_step"],
                            free_const_opti_args=run_config_higher["free_const_opti_args"],
                            X=X_higher,
                            y_target=y_higher,
                            )
        # Candidates
        logits_new = []
        actions_new = []

        for i in range(config.MAX_LENGTH):
            # ------------ PRIOR -----------
            # (embedding output)
            prior_array = batch_new.prior().astype(np.float32)  # (batch_size, output_size)
            action2 = new_offspring[:, i]
            prior_action = [bool(prior_array[idx][act]) for idx, act in enumerate(action2)]
            prior_action = list(map(bool, prior_action))
            # prior_action变为bool

            #mask_need_action = np.logical_and(prior_action, action2)


            # 0 protection so there is always something to sample
            epsilon = 0  # 1e-14 #1e0*np.finfo(np.float32).eps
            prior_array[prior_array == 0] = epsilon
            # To log
            prior = torch.tensor(prior_array, requires_grad=False)  # (batch_size, output_size)
            logprior = torch.log(prior)  # (batch_size, output_size)
            action1 = torch.multinomial(torch.exp(logprior),  # (batch_size,)
                                       num_samples=1)[:, 0]




            # ------------ SAMPLING ------------

            action = np.where(prior_action, action2, action1)

            # ------------ ACTION ------------

            # Saving action n°i
            logits_new.append(logit)
            actions_new.append(action)

            # Informing embedding of new action
            # (embedding input)
            batch_new.programs.append(action)


        # -----------------合并、选择下一代---------------------
        if nextgen is not None:
            actions_new = np.concatenate((copy.deepcopy(actions_new),copy.deepcopy(nextgen)),axis=1)
        actions_new = torch.tensor(actions_new)
        actions_new = torch.cat((copy.deepcopy(actions), copy.deepcopy(actions_new)), dim=1)
        actions_new = actions_new.detach().cpu().numpy()
        batch_combine = Batch.Batch(library_args=run_config_higher["library_config"],
                                priors_config=run_config_higher["priors_config"],
                                batch_size=actions_new.shape[1],
                                max_time_step=run_config_higher["learning_config"]["max_time_step"],
                                free_const_opti_args=run_config_higher["free_const_opti_args"],
                                X=X_higher,
                                y_target=y_higher,
                                )

        for i in range(config.MAX_LENGTH):
            action = actions_new[i,:]
            batch_combine.programs.append(action)

        R = reward.HierarchicalRewardsComputer(
            programs=batch_combine.programs,
            X_original=X_original,
            y_original=y_original,
            run_config_original=run_config_original,
            original_library=original_library,
            f_eqs=f_eqs,
            evaluate_function=evaluate_function,
            free_const_opti_args=run_config_higher["free_const_opti_args"]
        )
        R = np.array(R)
        R = np.nan_to_num(R)
        fitness_temp+=(1/R-1).sum()
        fitness_count+=len(R)

        keep = R.argsort()[::-1][0:n_keep].copy()
        nextgen = copy.deepcopy(actions_new[:, keep])
        R_train_next = torch.tensor(R[keep], requires_grad=False)  # (n_keep,)
        lengths_next = batch_combine.programs.n_lengths[keep]
        #sum_len += batch_combine.programs.n_lengths.sum()



        # -------------------------


        hall_of_fame = [batch_combine.programs.get_prog(keep[0])]
        hall_of_fame[0].free_const_values = torch.tensor(np.ones(batch_combine.programs.n_free_const_occurrences[keep[0]]))
        # For display, we show the high-level program
        expr_str = hall_of_fame[0].get_infix_pretty(do_simplify=True)
        print(expr_str,R[keep][0])
        fitness_generation.append(R[keep][0])
        
        if best_fitness < R[keep][0]:
            best_fitness = R[keep][0]
            best_prog = hall_of_fame[0]
            print("best fitness: ",best_fitness)
            # Also show the expanded program for the best one
            # This part is for display and debugging, not essential for the algorithm
            try:
                # Create a dummy HierarchicalRewardsComputer to access the expand method
                expander = reward.HierarchicalRewardsComputer(
                    programs=batch_combine.programs, 
                    X_original=X_original, 
                    y_original=y_original, 
                    run_config_original=run_config_original, 
                    original_library=original_library,
                    f_eqs=f_eqs, 
                    evaluate_function=evaluate_function, 
                    free_const_opti_args=run_config_higher["free_const_opti_args"])
                expanded_tokens = expander._expand_program(best_prog)
                if expanded_tokens:
                    expanded_tokens = np.array(expanded_tokens, dtype=np.int64).reshape(-1, 1)
                    batch_final = Batch.Batch(
                    library_args=run_config_original["library_config"],
                    priors_config=run_config_original["priors_config"],
                    batch_size=1,
                    max_time_step=len(expanded_tokens),
                    free_const_opti_args=run_config_higher["free_const_opti_args"],
                    X=X_original,
                    y_target=y_original,
                    )
                    for i in range(len(expanded_tokens)):
                        action = expanded_tokens[i,:]
                        batch_final.programs.append(action)
                    _ = reward.RewardsComputer(
                        programs=batch_final.programs,
                        X=X_original,
                        y_target=y_original,
                        evaluate_function=evaluate_function,
                        free_const_opti_args=run_config_original["free_const_opti_args"],
                    )
                    expanded_prog_obj = batch_final.programs.get_prog(0)
                    print(f"best prog expanded: {expanded_prog_obj.get_infix_pretty(do_simplify=True)}")
                else:
                    print("best prog expanded: Could not expand.")
            except Exception as e:
                print(f"Could not expand best program for display: {e}")

        if R[keep][0]>0.999:
            return expanded_prog_obj,best_prog
            
    with open('NetGP.pkl', 'wb') as f:
        pickle.dump({5:fitness_generation}, f)  # 保存
    
    return expanded_prog_obj, best_prog


