
import re
import os
import random
import matplotlib.pyplot as plt
import numpy as np 
from itertools import count  
import time
import torch  
import torch.nn as nn  
import torch.nn.functional as F  
import torch.optim as optim  
from scipy.ndimage import gaussian_filter1d
from scipy.stats import pearsonr, spearmanr

from learn.constr2 import const
from learn.mygplearn.subexpr import expend_expr
from learn.newloss import lossop_func, lossvar_func, losstwovar_func
from learn.newprogram import Program

def make_hashable(obj):
    """Turn (possibly nested) lists into tuples so the value can be hashed.

    Only ``list`` instances are converted (recursively); every other value,
    including tuples that contain lists, is returned unchanged.
    """
    if not isinstance(obj, list):
        return obj
    return tuple(map(make_hashable, obj))

def in_str(preorder1, preorder, x_train_with_ones, y2_cache):
    """Return whether ``preorder1`` evaluates identically to any candidate in ``preorder``.

    ``preorder1`` and each candidate are evaluated with ``const`` on
    ``x_train_with_ones``. A match means every element-wise difference is
    exactly zero.

    Args:
        preorder1: Token list of the expression being tested.
        preorder: Sized iterable of candidates. A candidate may be a token
            list, an ``int`` (treated as a constant column shaped like
            ``x_train_with_ones[:, 0]``), or a single token (wrapped in a
            one-element list before evaluation).
        x_train_with_ones: 2-D NumPy array the expressions are evaluated on.
        y2_cache: Dict memoising candidate evaluations, keyed by
            ``make_hashable(candidate)`` and storing flat tuples. Mutated in
            place and also returned.

    Returns:
        ``(found, y2_cache)``, where ``found`` is True iff some candidate matched.
    """
    if len(preorder) == 0:
        return False, y2_cache
    # preorder1 does not depend on the candidate, so evaluate it only once.
    # Flatten it so that it is compared with the (flattened) candidate
    # values the same way on cache hits and misses.
    y1 = np.asarray(const(preorder1, x_train_with_ones)).flatten()
    for preorder2 in preorder:
        key = make_hashable(preorder2)
        if key in y2_cache:
            # Cached values are flat tuples; turn them back into arrays.
            y2 = np.asarray(y2_cache[key])
        else:
            if isinstance(preorder2, list):
                y2 = const(preorder2, x_train_with_ones)
            elif isinstance(preorder2, int):
                y2 = np.full_like(x_train_with_ones[:, 0], preorder2)
            else:
                y2 = const([preorder2], x_train_with_ones)
            y2 = np.asarray(y2).flatten()
            y2_cache[key] = tuple(y2)
        if np.all(y1 - y2 == 0):
            return True, y2_cache
    return False, y2_cache
    
# library_op = ['add', 'sub', 'mul', 'div', 'sin', 'cos', 'sig', 'log', 'sqrt', 'id']


# Magnitude threshold used by protected_div (denominator) and protected_log
# (argument), and to derive log_plateau.
EPSILON = 0.001
# Exponent used to compute exp_plateau.
EXP_THRESHOLD = 80.
# NOTE(review): INF is not referenced anywhere in the visible source.
INF = 1e6


def protected_div(x1, x2, eps=None):
    """Divide ``x1`` by ``x2`` element-wise, returning 1 where ``|x2| <= eps``.

    Args:
        x1: Numerator tensor.
        x2: Denominator tensor (broadcast against ``x1`` as in ``torch.divide``).
        eps: Denominator magnitude threshold. ``None`` (the default) uses the
            module-level ``EPSILON``.

    Returns:
        ``x1 / x2`` where ``|x2| > eps`` and ``1`` elsewhere, in the dtype of
        the quotient.
    """
    if eps is None:
        eps = EPSILON
    valid = torch.abs(x2) > eps
    # Replace near-zero denominators *before* dividing. torch.where still
    # back-propagates through the unselected branch, so an inf/nan quotient
    # there would become a NaN gradient (0 * inf) even though its forward
    # value is masked out.
    safe_x2 = torch.where(valid, x2, torch.ones_like(x2))
    quotient = torch.divide(x1, safe_x2)
    return torch.where(valid, quotient, torch.ones_like(quotient))


# exp(EXP_THRESHOLD).
# NOTE(review): not referenced anywhere in the visible source (protected_sig
# does not clamp with it).
exp_plateau = np.exp(EXP_THRESHOLD)


def protected_sig(x1):
    """Return the logistic sigmoid ``1 / (1 + exp(-x1))`` element-wise.

    Uses ``torch.sigmoid``, which is mathematically identical to the explicit
    formula but does not overflow ``exp`` for large negative inputs. With the
    explicit formula, ``exp(-x1)`` becomes inf there and the backward pass
    yields ``0 * inf = nan`` gradients.

    Integer inputs are promoted to the default floating dtype, as
    ``torch.exp`` did in the explicit formula.
    """
    x1 = torch.as_tensor(x1)
    if not (x1.is_floating_point() or x1.is_complex()):
        x1 = x1.to(torch.get_default_dtype())
    return torch.sigmoid(x1)



# log(|EPSILON|): the value protected_log returns where |x1| < EPSILON.
log_plateau = np.log(np.abs(EPSILON))
# float32 tensor copy of log_plateau, used as the torch.where fallback.
log_plateau_tensor = torch.tensor(log_plateau, dtype=torch.float32)


def protected_log(x1, eps=None):
    """Return ``log(|x1|)`` element-wise, with a constant plateau near zero.

    Where ``|x1| >= eps`` the result is ``log(|x1|)``. Elsewhere it is
    ``log(|eps|)``; for the default ``eps`` this is ``log_plateau_tensor``.

    Args:
        x1: Input tensor.
        eps: Magnitude threshold. ``None`` (the default) uses the module-level
            ``EPSILON``.
    """
    if eps is None:
        eps, plateau = EPSILON, log_plateau_tensor
    else:
        plateau = torch.tensor(np.log(np.abs(eps)), dtype=torch.float32)
    magnitude = torch.abs(x1)
    valid = magnitude >= eps
    # Feed log a harmless 1 where the input is masked out. torch.where still
    # back-propagates through the unselected branch, and log's 1/x (inf at 0)
    # would otherwise produce 0 * inf = nan gradients.
    safe = torch.where(valid, magnitude, torch.ones_like(magnitude))
    return torch.where(valid, torch.log(safe), plateau.to(magnitude.device))


# Alias: protected_log already applies torch.abs before taking the log.
protected_logabs = protected_log


def protected_sqrt(x1):
    """Return ``sqrt(|x1|)`` element-wise.

    The forward result equals ``torch.sqrt(torch.abs(x1))``, including NaN
    propagation. Exact zeros are routed around ``sqrt`` so their gradient is
    0 instead of ``inf * 0 = nan``: sqrt' is infinite at 0 and abs' is 0
    there.
    """
    magnitude = torch.abs(x1)
    is_zero = magnitude == 0
    safe = torch.where(is_zero, torch.ones_like(magnitude), magnitude)
    return torch.where(is_zero, torch.zeros_like(safe), torch.sqrt(safe))


def protected_id(x1):
    """Identity primitive: hand back ``x1`` itself (no copy is made)."""
    result = x1
    return result


def is_constant_array(arr):
    """Return whether every element of ``arr`` equals its first element ``arr[0]``.

    The result is whatever ``np.all`` returns (a NumPy boolean).
    """
    first = arr[0]
    return np.all(arr == first)


def sample_feature(sample, y):
    """Summarise ``sample`` and its relation to ``y`` as a 5-element float tensor.

    The result is ``[mean, std, median, pearson_r, spearman_rho]``. Mean, std
    and median are taken from the squeezed ``sample`` tensor. The two
    correlations are computed with SciPy on detached NumPy copies in which NaN
    becomes 0 and +/-inf is clipped to the dtype's finite range. When either
    copy is constant (``is_constant_array``), the correlations are undefined
    and both are reported as 1.0 instead.
    """
    flat_sample = sample.squeeze()
    flat_y = y.squeeze()
    mean = torch.mean(flat_sample)
    spread = torch.std(flat_sample)
    middle = torch.median(flat_sample)

    def _finite_copy(tensor):
        # Detached NumPy copy with NaN -> 0 and infinities clipped to the
        # largest finite values of the array's dtype.
        arr = tensor.detach().numpy()
        limits = np.finfo(arr.dtype)
        return np.nan_to_num(arr, nan=0.0, posinf=limits.max, neginf=limits.min)

    sample_np = _finite_copy(flat_sample)
    y_np = _finite_copy(flat_y)
    if not (is_constant_array(sample_np) or is_constant_array(y_np)):
        pearson_corr = pearsonr(sample_np, y_np)[0]
        spearman_corr = spearmanr(sample_np, y_np)[0]
    else:
        pearson_corr = torch.full_like(mean, 1.0)
        spearman_corr = torch.full_like(mean, 1.0)
    return torch.tensor([mean, spread, middle, pearson_corr, spearman_corr], dtype=torch.float)


def mapping(mapping_set, str):
    """Resolve ``str`` through the dict ``mapping_set``, falling back to ``str`` itself.

    Returns ``mapping_set[str]`` when the key is present, otherwise ``str``
    unchanged.
    """
    return mapping_set.get(str, str)


def to_one_hot(action, size=10):
    """Return a default-dtype float tensor of length ``size`` with 1 at index ``action``.

    ``action`` is used as a tensor index, so an int or an integer scalar
    tensor both work.
    """
    encoded = torch.zeros(size)
    encoded[action] = 1
    return encoded

def generate_subprogram(op, var1, var2=None):
    """Build the token list for one construction step.

    Returns ``[op, var1, var2]`` for a binary step and ``[op, var1]`` when
    ``var2`` is omitted. Only ``None`` means "no second operand", so falsy
    operands such as ``0`` or ``''`` are kept rather than silently dropped.
    """
    if var2 is None:
        return [op, var1]
    return [op, var1, var2]




def calculate_histogram(x, y, xbins=10, ybins=10):
    """Return marginal and joint bin counts for paired samples ``x`` and ``y``.

    Each axis is split into equal-width bins over ``[min, max]``. Every bin
    is half-open ``[lo, hi)`` except the last, which keeps everything
    ``>= lo``, so the maximum is counted.

    Args:
        x, y: Arrays supporting ``.min()``/``.max()`` and element-wise
            comparison (NumPy arrays in this module), of equal length.
        xbins, ybins: Number of bins along each axis.

    Returns:
        ``(xhist, yhist, L)``: float arrays of shapes ``(xbins,)``,
        ``(ybins,)`` and ``(xbins, ybins)``. ``L[i, j]`` counts samples that
        fall in x-bin ``i`` and y-bin ``j``.
    """
    def _bin_masks(values, n_bins):
        # One boolean membership mask per bin. Edges are lo + k * width, so
        # adjacent bins share identical boundaries.
        lo, hi = values.min(), values.max()
        width = (hi - lo) / n_bins
        masks = []
        for k in range(n_bins):
            lower = lo + k * width
            if k == n_bins - 1:
                masks.append(values >= lower)
            else:
                masks.append(np.logical_and(values >= lower, values < lo + (k + 1) * width))
        return masks

    x_masks = _bin_masks(x, xbins)
    y_masks = _bin_masks(y, ybins)
    xhist = np.array([np.sum(mask) for mask in x_masks], dtype=float)
    yhist = np.array([np.sum(mask) for mask in y_masks], dtype=float)
    joint = np.zeros((xbins, ybins))
    for i, x_mask in enumerate(x_masks):
        for j, y_mask in enumerate(y_masks):
            joint[i, j] = np.sum(np.logical_and(x_mask, y_mask))
    return xhist, yhist, joint


def semantic_similarity(output, target):
    """Estimate the mutual information between ``output`` and ``target``.

    Both axes use ``int(sqrt(N))`` equal-width bins (``calculate_histogram``).
    The result is ``sum(p_xy * log(p_xy / (p_x p_y + 1e-6) + 1e-6))``, where
    the ``1e-6`` terms keep the ratio and the log finite for empty bins.
    """
    n_samples = len(output)
    n_bins = int(np.sqrt(n_samples))
    x_counts, y_counts, joint_counts = calculate_histogram(output, target, xbins=n_bins, ybins=n_bins)
    p_x = x_counts / n_samples
    p_y = y_counts / n_samples
    p_xy = joint_counts / n_samples
    independent = np.outer(p_x, p_y) + 1e-6
    return np.sum(p_xy * np.log(p_xy / independent + 1e-6))


def mutate(x, y):
    """Score candidate values ``x`` against target ``y`` with ``semantic_similarity``.

    Both inputs are converted to NumPy first. Tensors are detached and moved
    to CPU, because ``Tensor.numpy()`` raises for tensors that require grad or
    live on an accelerator. Non-tensor inputs go through ``np.asarray``.
    """
    def _as_numpy(value):
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    return semantic_similarity(_as_numpy(x), _as_numpy(y))

def is_sincos(list):
    """Return True if ``'sin'`` or ``'cos'`` is ``in`` the given container.

    For a token list this is element membership; for a plain string it is a
    substring test.
    """
    return any(token in list for token in ('sin', 'cos'))

def is_all_zeros_or_all_ones(tensor):
    """Return a truthy 0-d bool tensor if ``tensor`` is entirely 0 or entirely 1.

    Empty tensors count as all-zero (``torch.all`` of nothing is True).
    """
    return torch.all(tensor == 0) or torch.all(tensor == 1)


def learner(
        X_train,
        y_train,
        mu,
        std,
        nnvar,
        optimizervar,
        nnop,
        optimizerop,
        nntwovar,
        optimizertwovar,
        n_epochs,
        risk_factor,
        gamma_decay,
        entropy_weight,
        stop_reward,
        lib,
        batch_news,
        batch_size,
        max_time_step
        
):
    """Train three policy networks that build programs step by step.

    Outline of what the code below does:
      * For each outer ``epoch`` and each of ``batch_news[epoch]`` inner
        iterations, ``batch_size`` programs are built over
        ``max_time_step - 1`` steps. At each step:
          - ``nnvar`` samples a first operand;
          - ``nnop`` samples an operator index into ``lib``;
          - ``nntwovar`` samples a second operand, which is used only when
            the operator index is < 4.
        Steps are applied with ``Program.build_program``. The next state is
        ``sample_feature`` of the residual ``y_train - program.actual_value``.
      * The ``int(risk_factor * batch_size)`` highest-reward programs are
        kept. Their actions, rewards and baseline (lowest kept reward) go to
        ``lossvar_func`` / ``lossop_func`` / ``losstwovar_func`` (defined
        elsewhere), and each network takes one optimizer step.
      * Between epochs, sub-expressions of kept programs are harvested as
        candidates when they:
          - mention an input-variable name;
          - are not all-zero/all-one;
          - are not equivalent (``in_str``) to something already known.
        ``expend_expr`` (defined elsewhere) expands the candidates. Up to
        ``liblenwithoutid`` of them, ranked by ``mutate`` (histogram
        mutual-information estimate against ``y_train``), are promoted to new
        operand names: their infix strings, resolved via ``mapping_set``.
        ``max_time_step`` shrinks by one per epoch.

    Args:
        X_train: Input features as a torch tensor (``.numpy()`` is called on it).
        y_train: Target torch tensor.
        mu, std: Forwarded unchanged to ``Program.const_optimize``.
        nnvar, optimizervar: First-operand policy network and its optimizer.
        nnop, optimizerop: Operator policy network and its optimizer.
        nntwovar, optimizertwovar: Second-operand policy network and its optimizer.
        n_epochs: Number of outer epochs; also sizes the operand action space
            (``vector_size``).
        risk_factor: Fraction of each batch kept for the update.
        gamma_decay, entropy_weight: Forwarded to the loss functions.
        stop_reward: Training stops once the best program's ``r2`` exceeds it.
        lib: Operator names; an operator's position in ``lib`` is its action index.
        batch_news: Per-epoch number of inner iterations (indexed by ``epoch``).
        batch_size: Programs sampled per inner iteration.
        max_time_step: Programs get ``max_time_step - 1`` steps in the first epoch.

    Returns:
        NOTE(review): the return shape differs by exit path; callers must
        handle all of them.
          - ``const_optimize``'s raw 3-tuple as soon as its second value equals 1.0.
          - ``(program, r2, infix)`` when ``stop_reward`` is exceeded.
          - ``(program, r2)`` when too few candidates remain.
          - The bare ``new_best_program`` after the final epoch.
    """
    library_op = lib
    # Operator count excluding the identity op 'id'. This many new operand
    # slots are unlocked per epoch (ones_counts) and this many terminals are
    # promoted at the end of each epoch.
    liblenwithoutid = len(lib)
    if 'id' in lib:
        liblenwithoutid = liblenwithoutid - 1

    # Memo of candidate-expression evaluations, threaded through in_str and
    # expend_expr.
    y2_cache = {}
    

    # Inputs with a leading column of ones. Only the first 10 rows are used
    # below, for the expression-equivalence checks (in_str / expend_expr).
    x_train_numpy = X_train.numpy()
    ones_column = np.ones((x_train_numpy.shape[0], 1))
    x_train_with_ones = np.hstack((ones_column, x_train_numpy))

    # Operand-action masks. In epoch e only the first
    # X_train.shape[1] + 1 + liblenwithoutid * e operand indices may be
    # sampled: the base names plus the terminals unlocked so far. The other
    # slots of the vector_size-wide action space get prior 0.
    # NOTE(review): exactly six masks are built, so as written this needs
    # n_epochs == 6.
    #   - With fewer epochs, np.zeros receives a negative size (when
    #     liblenwithoutid > 0).
    #   - With more, prior_arrays[epoch] is out of range at epoch 6 unless
    #     training stopped earlier.
    ones_counts = []
    tmp_ones_count = X_train.shape[1] + 1
    ones_counts.append(tmp_ones_count)
    ones_counts.append(tmp_ones_count + liblenwithoutid)
    ones_counts.append(tmp_ones_count + liblenwithoutid * 2)
    ones_counts.append(tmp_ones_count + liblenwithoutid * 3)
    ones_counts.append(tmp_ones_count + liblenwithoutid * 4)
    ones_counts.append(tmp_ones_count + liblenwithoutid * 5)
    vector_size = tmp_ones_count + liblenwithoutid * (n_epochs-1)

    prior_arrays = np.array([np.concatenate((np.ones(count), np.zeros(vector_size - count))) for count in ones_counts])
    # Per-iteration best/mean rewards; recorded but not returned.
    reward_final = []
    reward_average = []

    # Promoted terminal: infix string -> flattened token list (read via mapping()).
    mapping_set = {} 

    # Best program seen so far (by reward), plus its r2 and infix string.
    new_best = 0
    new_best_program = ""
    new_best_r2 = 0
    new_best_infix = ""

    X_train_shape1 = X_train.shape[1]
    # Base operand names x0..x{d}, where d = X_train.shape[1]. This gives
    # d + 1 names, matching the columns of x_train_with_ones; presumably x0
    # is the ones column (TODO confirm).
    stringlist = ['x' + str(i) for i in range(X_train_shape1 + 1)]
    # NOTE(review): names x1..x{d+1}, shifted by one relative to stringlist.
    # Used below in a substring test for "mentions an input variable";
    # confirm the intended naming.
    stringlist_save = ['x' + str(i+1) for i in range(X_train_shape1 + 1)]
    savelist = stringlist.copy()
    # Growing operand vocabulary: base names plus promoted terminals.
    stringlist_copy = stringlist.copy()


    for epoch in range(n_epochs):
        if new_best_r2 > stop_reward:
            break

        batch_new = batch_news[epoch] 
        # This epoch's operand vocabulary; frozen while sampling.
        stringlist = stringlist_copy.copy()

        # NOTE(review): sub_length is incremented below but never read.
        sub_length = len(stringlist)

        # Harvested candidates, kept as three parallel lists: token list,
        # flattened raw tokens, and evaluated value.
        candidate_programs = []
        candidate_programs_values = []
        candidate_programs_raw = []
        # Programs from the last inner iteration; passed to expend_expr.
        save_programs = []
        for epoch_new in range(batch_new):
           

            # Number of top-reward programs kept for the policy update.
            n_keep = int(risk_factor * batch_size)  

            optimizervar.zero_grad()
            optimizerop.zero_grad()
            optimizertwovar.zero_grad()

            # Per-step records for the three networks. masktwovar marks the
            # steps whose operator used the second operand.
            logitsvar = [] 
            actionsvar = []
            logitsop = []
            actionsop = []
            logitstwovar = []  
            actionstwovar = []
            masktwovar = []

            programs = []
            rewards = []
            sub_programs = []
            sub_programs_values = []
            y_state_value = y_train

            for population in range(batch_size):
                program = Program()  
                programs.append(program)  
                sub_program = []
                sub_programs.append(sub_program)
                sub_programs_value = []
                sub_programs_values.append(sub_programs_value)
            # Initial state: features of y_train against itself, shared by every program.
            y_state = sample_feature(y_state_value, y_train)

            y_state = y_state.expand(batch_size, y_state.shape[0]).detach()  
           
            for i in range(max_time_step - 1):
                

                # --- First operand: nnvar logits plus log-prior. log(0) = -inf
                # gives zero probability to slots not unlocked this epoch.
                outlogit = nnvar(y_state)
                prior_array = np.tile(prior_arrays[epoch], (batch_size, 1))
                epsilon = 0  
                prior_array[prior_array == 0] = epsilon
                prior = torch.tensor(prior_array, requires_grad=False)  
                logprior = torch.log(prior)  
                logit = outlogit + logprior
                actionone = torch.multinomial(torch.exp(logit), num_samples=1)[:, 0]
                logitsvar.append(logit)
                actionsvar.append(actionone)


        
                # --- Operator: forbid operator indices 4 and 5 when the chosen
                # first operand (after mapping) contains 'sin'/'cos'.
                # NOTE(review): assumes len(lib) == 10 with sin/cos at indices
                # 4/5, as in the commented library_op list near the top of the
                # module; confirm.
                batch_actioneone_present = []
                prior_array = []
                for cnt in range(batch_size):
                    array_of_ones = np.ones(10)
                    if is_sincos(mapping(mapping_set, stringlist[actionone[cnt]])):
                        array_of_ones[4] = 0
                        array_of_ones[5] = 0
                    prior_array.append(array_of_ones)
                prior_array = np.array(prior_array)
                # One-hot of each program's first operand, fed to nnop and nntwovar.
                for cnt in range(batch_size):
                    actioneone_present = to_one_hot(actionone[cnt], vector_size)
                    batch_actioneone_present.append(actioneone_present.unsqueeze(0))
                batch_actioneone_present = torch.cat(batch_actioneone_present, dim=0).detach()
                output = nnop(y_state, batch_actioneone_present)
                epsilon = 0  
                prior_array[prior_array == 0] = epsilon
                prior = torch.tensor(prior_array, requires_grad=False) 
                logprior = torch.log(prior)  
                logit = output + logprior
                option = torch.multinomial(torch.exp(logit), num_samples=1)[:, 0]
                # NOTE(review): stores the raw `output`, not the prior-adjusted
                # `logit` actually sampled from (unlike logitsvar/logitstwovar).
                # Confirm this is intended.
                logitsop.append(output)
                actionsop.append(option)


                # Operators with index < 4 are binary, i.e. use the second operand.
                mask = (option < 4).int()  
           
                # --- Second operand: conditioned on the state, the first-operand
                # one-hot and the operator one-hot, with the same epoch mask as
                # the first operand. It is sampled for every program; the mask
                # records where it is actually used.
                batch_one_hot = []
                for cnt in range(batch_size):
                    actioneone_present = to_one_hot(option[cnt], len(lib))
                    batch_one_hot.append(actioneone_present.unsqueeze(0))
                batch_one_hot = torch.cat(batch_one_hot, dim=0).detach()
                outlogit = nntwovar(y_state, batch_actioneone_present, batch_one_hot)
                prior_array = np.tile(prior_arrays[epoch], (batch_size, 1))
                epsilon = 0  
                prior_array[prior_array == 0] = epsilon
                prior = torch.tensor(prior_array, requires_grad=False) 
                logprior = torch.log(prior)  
                logit = outlogit + logprior
                actiontwo = torch.multinomial(torch.exp(logit), num_samples=1)[:, 0]
                logitstwovar.append(logit)
                actionstwovar.append(actiontwo)
                masktwovar.append(mask)


                # Apply the sampled step to every program and build the next state.
                y_state_tmp = []
                for cnt in range(batch_size):
                    if mask[cnt]:
                        programs[cnt].build_program(library_op[option[cnt]],
                                                    mapping(mapping_set, stringlist[actionone[cnt]]),
                                                    mapping(mapping_set, stringlist[actiontwo[cnt]]))
                        sub_programs[cnt].append(
                            generate_subprogram(library_op[option[cnt]], stringlist[actionone[cnt]],
                                                stringlist[actiontwo[cnt]]))

                    else:
                        programs[cnt].build_program(library_op[option[cnt]],
                                                    mapping(mapping_set, stringlist[actionone[cnt]]))
                        sub_programs[cnt].append(
                            generate_subprogram(library_op[option[cnt]], stringlist[actionone[cnt]]))


                    programs[cnt].get_infix_str()
                    a1,a2,a3 = programs[cnt].const_optimize(X_train, y_train,mu,std)
                    # NOTE(review): a second value of exactly 1.0 (presumably a
                    # perfect fit) returns const_optimize's tuple immediately,
                    # skipping all remaining training.
                    if a2 == 1.0:
                        return a1,a2,a3
       
                    # Next state: features of the residual y_train - current program value.
                    y_state_value_cnt = y_train - programs[cnt].actual_value
                    y_state_tmp.append(sample_feature(y_state_value_cnt, y_train).unsqueeze(0))
                    # After the final step, record sub-expression values and the reward.
                    if i == max_time_step - 2:
                        sub_programs_values_return = programs[cnt].const_sub(X_train,max_time_step-1)
                        for value in sub_programs_values_return:
                            sub_programs_values[cnt].append(value)
                        rewards.append(programs[cnt].reward)
                y_state = torch.cat(y_state_tmp, dim=0)
                # Replace non-finite features with 1.0 so the networks see finite input.
                y_state = torch.where(torch.isinf(y_state) | torch.isnan(y_state), torch.tensor(1.0), y_state).detach()

            # Keep the final inner iteration's programs for expend_expr.
            if epoch_new == batch_new-1:
                for save_cnt in range(batch_size):
                    save_programs.append(programs[save_cnt].actual_program)



    
            # Stack per-step records: dim 0 = construction step, dim 1 = program.
            logitsvar = torch.stack(logitsvar, dim=0) 
            actionsvar = torch.stack(actionsvar, dim=0) 
            logitsop = torch.stack(logitsop, dim=0)  
            actionsop = torch.stack(actionsop, dim=0) 
            logitstwovar = torch.stack(logitstwovar, dim=0) 
            actionstwovar = torch.stack(actionstwovar, dim=0)  
            masktwovar = torch.stack(masktwovar, dim=0)  
            actionsvar_array = actionsvar.detach().cpu().numpy()  
            actionsop_array = actionsop.detach().cpu().numpy()
            actionstwovar_array = actionstwovar.detach().cpu().numpy()  

            # Keep the n_keep highest rewards. argsort is ascending, so keep[-1]
            # is this batch's best; the baseline is the lowest kept reward.
            # NOTE(review): n_keep == 0 leaves R_train empty and .min() raises.
            R = np.array(rewards) 
            keep = R.argsort()[batch_size - n_keep:].copy() 
            R_train = torch.tensor(R[keep], requires_grad=False)  
            R_lim = R_train.min()
            baseline = R_lim

            actionsvar_array_train = actionsvar_array[:, keep]  
            actionsop_array_train = actionsop_array[:, keep]
            actionstwovar_array_train = actionstwovar_array[:, keep]

            # Update nnvar: targets are one-hot encodings of the kept first operands.
            ideal_probs_arrayvar_train = F.one_hot(torch.from_numpy(actionsvar_array_train),
                                                   num_classes=vector_size)  

            ideal_probs_arrayvar_train = ideal_probs_arrayvar_train.detach().cpu().numpy()
            ideal_probs_vartrain = torch.tensor(ideal_probs_arrayvar_train.astype(np.float32), requires_grad=False, )
            logitsvar_train = logitsvar[:,keep] 
            lengths = [max_time_step - 1] * n_keep

            lossvar_val = lossvar_func(logits_train=logitsvar_train,
                                       ideal_probs_train=ideal_probs_vartrain,
                                       R_train=R_train,
                                       baseline=baseline,
                                       lengths=lengths,
                                       gamma_decay=gamma_decay,
                                       entropy_weight=entropy_weight, )
            lossvar_val.backward()
            optimizervar.step()
            # Update nnop: targets are one-hot encodings of the kept operators.
            ideal_probs_arrayop_train = F.one_hot(torch.from_numpy(actionsop_array_train),
                                                  num_classes=len(lib)) 
            ideal_probs_arrayop_train = ideal_probs_arrayop_train.detach().cpu().numpy()
            ideal_probs_optrain = torch.tensor(ideal_probs_arrayop_train.astype(np.float32),
                                               requires_grad=False, )
            logitsop_train = logitsop[:,keep]  
            lengths = [max_time_step - 1] * n_keep
            lossop_val = lossop_func(logits_train=logitsop_train,
                                     ideal_probs_train=ideal_probs_optrain,
                                     R_train=R_train,
                                     baseline=baseline,
                                     lengths=lengths,
                                     gamma_decay=gamma_decay,
                                     entropy_weight=entropy_weight, )
            lossop_val.backward()
            optimizerop.step()
            # Update nntwovar. additional_mask carries masktwovar, which is 1
            # where the operator was binary (how it is applied is up to
            # losstwovar_func).
            ideal_probs_arraytwovar_train = F.one_hot(torch.from_numpy(actionstwovar_array_train),
                                                      num_classes=vector_size)  
            ideal_probs_arraytwovar_train = ideal_probs_arraytwovar_train.detach().cpu().numpy()
            ideal_probs_twovartrain = torch.tensor(ideal_probs_arraytwovar_train.astype(np.float32),
                                                   requires_grad=False, )
            logitstwovar_train = logitstwovar[:, keep]  
            masktwovar_train = masktwovar[:, keep]

            losstwovar_val = losstwovar_func(logits_train=logitstwovar_train,
                                             ideal_probs_train=ideal_probs_twovartrain,
                                             R_train=R_train,
                                             baseline=baseline,
                                             lengths=lengths,
                                             gamma_decay=gamma_decay,
                                             entropy_weight=entropy_weight,
                                             additional_mask=masktwovar_train)
            losstwovar_val.backward()
            optimizertwovar.step()

            # Track the overall best program (keep[-1] is this batch's best).
            if R[keep[-1]] > new_best:
                new_best = R[keep[-1]]
                new_best_program = programs[keep[-1]].actual_program
                new_best_r2 = programs[keep[-1]].r2
                new_best_infix = programs[keep[-1]].infix_str
                print("*********************BESTNEW*********************")
                print("--------------------NNEpoch %i/%i----------------" % (epoch, n_epochs))
                print("==================Epoch %i/%i====================" % (epoch_new, batch_new))
                print("OVERALLBEST:", new_best)
                print("OVERALLBESTprogram:", new_best_program)
                print("OVERALLBESTr2:", new_best_r2)



            reward_final.append(R[keep[-1]])
            reward_average.append(np.mean(R))

            if new_best_r2 > stop_reward:
                return new_best_program, new_best_r2,new_best_infix
                # NOTE(review): unreachable; it follows the return above.
                break
            # No later epoch uses candidates, so skip harvesting in the last one.
            if epoch == n_epochs-1:
                continue
            # Harvest candidate terminals from the kept programs' sub-expressions.
            for keep_index in keep:
                keep_program = sub_programs[keep_index]
                keep_program_value = sub_programs_values[keep_index]
                for sub_keep_program, sub_keep_program_value in zip(keep_program, keep_program_value):
                    # sub_flag: some token contains a stringlist_save name as a
                    # substring. NOTE(review): this is a substring test, so
                    # e.g. 'x1' also matches 'x10' and promoted infix strings.
                    sub_flag = False
                    for sub in sub_keep_program:
                        if any(item in sub for item in stringlist_save):
                            sub_flag = True
                            break
                    if sub_flag:
                        # Ignore sub-expressions whose value is all zeros or all ones.
                        if not is_all_zeros_or_all_ones(sub_keep_program_value):
                            # Expand promoted names into their token lists.
                            tmp_list2 = []
                            ll = sub_keep_program.copy()
                            for tmp_list_cnt in range(len(ll)):
                                tmptmp = mapping(mapping_set, ll[tmp_list_cnt])
                                if isinstance(tmptmp, list):
                                    tmp_list2 += tmptmp
                                else:
                                    tmp_list2.append(tmptmp)
                            # Reject if equivalent (on the first 10 rows) to an
                            # already-collected candidate...
                            flag1, y2_cache = in_str(tmp_list2.copy(), candidate_programs_raw.copy(),
                                                     x_train_with_ones[:10,:], y2_cache)
                            if not flag1:
                                # ...or to a promoted terminal, the constants
                                # 1/0, or a base operand name.
                                flag2, y2_cache = in_str(tmp_list2.copy(),
                                                         list(mapping_set.values()) + [1, 0] + savelist.copy(),
                                                         x_train_with_ones[:10,:], y2_cache)
                                if not flag2:
                                        candidate_programs.append(sub_keep_program)
                                        candidate_programs_raw.append(tmp_list2.copy())
                                        candidate_programs_values.append(sub_keep_program_value)
                                



        # The last epoch does not promote new terminals.
        if epoch == n_epochs-1:
            continue
        candidate_programs_new, candidate_programs_values_new,y2_cache = expend_expr(X_train, y_train, save_programs,
                                                                            len(candidate_programs), candidate_programs,
                                                                            candidate_programs_values, epoch + 2,candidate_programs_raw.copy(),list(mapping_set.values())+[1,0]+savelist.copy(),x_train_with_ones[:10,:],y2_cache)
        candidate_programs = candidate_programs_new
        candidate_programs_values = candidate_programs_values_new

        # Score each candidate with mutate() against y_train. Too few
        # candidates ends training early.
        # NOTE(review): this exit returns a 2-tuple, unlike the other exits.
        mutate_arrays = []
        if len(candidate_programs) < liblenwithoutid:
            return new_best_program, new_best_r2
        for candidate_programs_value in candidate_programs_values:
            mutate_arrays.append(mutate(candidate_programs_value, y_train))
        M = np.array(mutate_arrays)

        # Odd epochs: promote the highest-scoring candidates whose infix string
        # is not already an operand name, up to liblenwithoutid of them.
        # NOTE(review): if fewer than liblenwithoutid get promoted, the next
        # epoch's prior still unlocks liblenwithoutid new slots. Sampled
        # indices can then exceed len(stringlist).
        if epoch % 2 !=0 :
            keep_mutate = M.argsort()[::-1].copy()
            tmp_cnt = 0
            for keep_mutate_index in keep_mutate:
                program_tmp = Program()
                program_tmp.program = candidate_programs[keep_mutate_index]
                program_tmp.get_infix_str()
                if tmp_cnt == liblenwithoutid:
                    break
                if program_tmp.infix_str in stringlist_copy:
                    continue

                tmp_cnt = tmp_cnt + 1

                sub_length = sub_length + 1
                tmp_list = []
                for tmp_list_cnt in range(len(candidate_programs[keep_mutate_index])):
                    tmptmp = mapping(mapping_set, candidate_programs[keep_mutate_index][tmp_list_cnt])
                    if isinstance(tmptmp, list):
                        tmp_list += tmptmp
                    else:
                        tmp_list.append(tmptmp)
                mapping_set[program_tmp.infix_str] = tmp_list
                stringlist_copy.append(program_tmp.infix_str)
        else:
            # Even epochs: walk the score-ranked candidates and promote at most
            # one per leading operator per pass, resetting the per-operator
            # flags after each pass, until liblenwithoutid are promoted.
            # Skipped: sin/cos whose remaining tokens include 'sin'/'cos', and
            # sig whose second token is 'log'.
            # NOTE(review): only these nine operators are counted. If the
            # ranked list runs out of eligible, not-yet-promoted candidates,
            # this while loop never terminates.
            keep_mutate = M.argsort()[::-1].copy()
            add0 = 0
            sub0 = 0
            mul0 = 0
            div0 = 0
            sin0 = 0
            cos0 = 0
            log0 = 0
            sqrt0 = 0
            sig0 = 0
            tmp_cnt = 0

            while tmp_cnt != liblenwithoutid:
                for keep_mutate_index in keep_mutate:
                    program_tmp = Program()
                    program_tmp.program = candidate_programs[keep_mutate_index]
                    program_tmp.get_infix_str()
                    if tmp_cnt == liblenwithoutid:
                        break
                    if program_tmp.infix_str in stringlist_copy:
                        continue

                    if candidate_programs[keep_mutate_index][0] == 'add' and add0 == 0:
                        add0 = 1
                        tmp_cnt = tmp_cnt + 1


                        sub_length = sub_length + 1
                        tmp_list = []
                        for tmp_list_cnt in range(len(candidate_programs[keep_mutate_index])):
                            tmptmp = mapping(mapping_set, candidate_programs[keep_mutate_index][tmp_list_cnt])
                            if isinstance(tmptmp, list):
                                tmp_list += tmptmp
                            else:
                                tmp_list.append(tmptmp)
                        mapping_set[program_tmp.infix_str] = tmp_list
                        stringlist_copy.append(program_tmp.infix_str)


                    elif candidate_programs[keep_mutate_index][0] == 'sub' and sub0 == 0:
                        sub0 = 1
                        tmp_cnt = tmp_cnt + 1


                        sub_length = sub_length + 1
                        tmp_list = []
                        for tmp_list_cnt in range(len(candidate_programs[keep_mutate_index])):
                            tmptmp = mapping(mapping_set, candidate_programs[keep_mutate_index][tmp_list_cnt])
                            if isinstance(tmptmp, list):
                                tmp_list += tmptmp
                            else:
                                tmp_list.append(tmptmp)
                        mapping_set[program_tmp.infix_str] = tmp_list
                        stringlist_copy.append(program_tmp.infix_str)

                    elif candidate_programs[keep_mutate_index][0] == 'mul' and mul0 == 0:
                        mul0 = 1
                        tmp_cnt = tmp_cnt + 1


                        sub_length = sub_length + 1
                        tmp_list = []
                        for tmp_list_cnt in range(len(candidate_programs[keep_mutate_index])):
                            tmptmp = mapping(mapping_set, candidate_programs[keep_mutate_index][tmp_list_cnt])
                            if isinstance(tmptmp, list):
                                tmp_list += tmptmp
                            else:
                                tmp_list.append(tmptmp)
                        mapping_set[program_tmp.infix_str] = tmp_list
                        stringlist_copy.append(program_tmp.infix_str)

                    elif candidate_programs[keep_mutate_index][0] == 'div' and div0 == 0:
                        div0 = 1
                        tmp_cnt = tmp_cnt + 1


                        sub_length = sub_length + 1
                        tmp_list = []
                        for tmp_list_cnt in range(len(candidate_programs[keep_mutate_index])):
                            tmptmp = mapping(mapping_set, candidate_programs[keep_mutate_index][tmp_list_cnt])
                            if isinstance(tmptmp, list):
                                tmp_list += tmptmp
                            else:
                                tmp_list.append(tmptmp)
                        mapping_set[program_tmp.infix_str] = tmp_list
                        stringlist_copy.append(program_tmp.infix_str)

                    elif candidate_programs[keep_mutate_index][0] == 'sin' and sin0 == 0:
                        # Skip nested trig: sin applied to something containing sin/cos.
                        if 'cos' in candidate_programs[keep_mutate_index][1:] or 'sin' in candidate_programs[keep_mutate_index][1:]:
                            continue
                        sin0 = 1
                        tmp_cnt = tmp_cnt + 1


                        sub_length = sub_length + 1
                        tmp_list = []
                        for tmp_list_cnt in range(len(candidate_programs[keep_mutate_index])):
                            tmptmp = mapping(mapping_set, candidate_programs[keep_mutate_index][tmp_list_cnt])
                            if isinstance(tmptmp, list):
                                tmp_list += tmptmp
                            else:
                                tmp_list.append(tmptmp)
                        mapping_set[program_tmp.infix_str] = tmp_list
                        stringlist_copy.append(program_tmp.infix_str)

                    elif candidate_programs[keep_mutate_index][0] == 'cos' and cos0 == 0:
                        # Skip nested trig: cos applied to something containing sin/cos.
                        if 'cos' in candidate_programs[keep_mutate_index][1:] or 'sin' in candidate_programs[keep_mutate_index][1:]:
                            continue
                        cos0 = 1
                        tmp_cnt = tmp_cnt + 1


                        sub_length = sub_length + 1
                        tmp_list = []
                        for tmp_list_cnt in range(len(candidate_programs[keep_mutate_index])):
                            tmptmp = mapping(mapping_set, candidate_programs[keep_mutate_index][tmp_list_cnt])
                            if isinstance(tmptmp, list):
                                tmp_list += tmptmp
                            else:
                                tmp_list.append(tmptmp)
                        mapping_set[program_tmp.infix_str] = tmp_list
                        stringlist_copy.append(program_tmp.infix_str)

                    elif candidate_programs[keep_mutate_index][0] == 'log' and log0 == 0:
                        log0 = 1
                        tmp_cnt = tmp_cnt + 1


                        sub_length = sub_length + 1
                        tmp_list = []
                        for tmp_list_cnt in range(len(candidate_programs[keep_mutate_index])):
                            tmptmp = mapping(mapping_set, candidate_programs[keep_mutate_index][tmp_list_cnt])
                            if isinstance(tmptmp, list):
                                tmp_list += tmptmp
                            else:
                                tmp_list.append(tmptmp)
                        mapping_set[program_tmp.infix_str] = tmp_list
                        stringlist_copy.append(program_tmp.infix_str)

                    elif candidate_programs[keep_mutate_index][0] == 'sqrt' and sqrt0 == 0:
                        sqrt0 = 1
                        tmp_cnt = tmp_cnt + 1


                        sub_length = sub_length + 1
                        tmp_list = []
                        for tmp_list_cnt in range(len(candidate_programs[keep_mutate_index])):
                            tmptmp = mapping(mapping_set, candidate_programs[keep_mutate_index][tmp_list_cnt])
                            if isinstance(tmptmp, list):
                                tmp_list += tmptmp
                            else:
                                tmp_list.append(tmptmp)
                        mapping_set[program_tmp.infix_str] = tmp_list
                        stringlist_copy.append(program_tmp.infix_str)

                    elif candidate_programs[keep_mutate_index][0] == 'sig' and sig0 == 0:
                        # Skip sig whose second token is 'log'.
                        if candidate_programs[keep_mutate_index][1] == 'log':
                            continue
                        sig0 = 1
                        tmp_cnt = tmp_cnt + 1

                        sub_length = sub_length + 1
                        tmp_list = []
                        for tmp_list_cnt in range(len(candidate_programs[keep_mutate_index])):
                            tmptmp = mapping(mapping_set, candidate_programs[keep_mutate_index][tmp_list_cnt])
                            if isinstance(tmptmp, list):
                                tmp_list += tmptmp
                            else:
                                tmp_list.append(tmptmp)
                        mapping_set[program_tmp.infix_str] = tmp_list
                        stringlist_copy.append(program_tmp.infix_str)


                # Pass finished: allow each operator type to be promoted again.
                add0 = 0
                sub0 = 0
                mul0 = 0
                div0 = 0
                sin0 = 0
                cos0 = 0
                log0 = 0
                sqrt0 = 0
                sig0 = 0

        # Programs get one construction step fewer in each subsequent epoch.
        max_time_step = max_time_step - 1
    # NOTE(review): this exit returns only the program (no r2/infix).
    return new_best_program
















































































































































































