# coding:UTF-8
# @Time: 2024/8/26 13:48
# @Author: Lulu Cao
# @File: reward.py
# @Software: PyCharm
import warnings

import numpy as np
import torch as torch
import physym.execute as exec
import physym.evaluate as eva
import torch.nn.functional as F
import physym.batch as Batch
from physym import program

# Switches for parallel execution during program evaluation. The trailing comments give the
# authors' rule of thumb for when parallelism pays off.
# NOTE(review): neither flag appears to change evaluation behavior within this file -- confirm
# before relying on them.
USE_PARALLEL_EXE        = False  # Only worth it if n_samples > 1e6
USE_PARALLEL_OPTI_CONST = True   # Only worth it if batch_size > 1k

def SquashedNRMSE(y_target, y_pred):
    """
    Squashed normalized root-mean-squared-error reward.

    The RMSE between prediction and target is normalized by the standard deviation
    of the target, and then squashed into (0, 1] via ``1 / (1 + NRMSE)``.
    A perfect fit gives 1.

    Parameters
    ----------
    y_target : torch.tensor of shape (?,) of float
        Target output data.
    y_pred   : torch.tensor of shape (?,) of float
        Predicted data.
    Returns
    -------
    reward : torch.tensor float
        Reward encoding prediction vs target discrepancy in [0,1].
    """
    # Root of the mean squared residual.
    rmse = torch.mean((y_pred - y_target) ** 2).sqrt()
    # Scale by the target's spread (torch's default std estimator).
    nrmse = (1 / y_target.std()) * rmse
    return 1 / (1 + nrmse)

def SquashedNRMSE_to_R2(reward):
    """
    Convert a SquashedNRMSE reward back into an R2 score.

    Since reward = 1 / (1 + NRMSE), we have NRMSE = 1/reward - 1. Substituting into
    R2 = 1 - NRMSE**2 gives R2 = 2/reward - (1/reward)**2.

    Parameters
    ----------
    reward : torch.tensor float
        Reward encoding prediction vs target discrepancy in [0,1].
    Returns
    -------
    R2 : torch.tensor float
        R2 score.
    """
    inverse = 1 / reward
    return 2 * inverse - inverse ** 2

def msedi(y_true, y_pred, failure_value=100):
    """
    Mean Squared Error of Differences (MSEDI).

    This fitness measure is minimized when the empirical (finite-difference)
    derivatives of the target and of the prediction agree:

        loss = mean( ((y_true[i+1] - y_true[i]) - (y_pred[i+1] - y_pred[i]))**2 )

    Inputs with more than one dimension are flattened before differencing.

    Parameters
    ----------
    y_true : array_like or torch.tensor
        True target values.
    y_pred : array_like or torch.tensor
        Model predictions. Must have the same shape as ``y_true``.
        torch tensors are detached and moved to CPU before conversion.
    failure_value : float, optional
        Value returned when the computation fails (default 100, the historical value).

    Returns
    -------
    float
        The MSEDI value. If it cannot be computed (e.g. the shapes do not match),
        a warning is emitted and ``failure_value`` is returned instead.
    """
    def _to_numpy(values):
        # torch tensors may require grad or live on GPU: detach and move to CPU first.
        if hasattr(values, "detach"):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    try:
        y_true = _to_numpy(y_true)
        y_pred = _to_numpy(y_pred)

        # Shapes must match so that the differences are aligned element-wise.
        if y_true.shape != y_pred.shape:
            raise ValueError("y_true 和 y_pred 的形状必须相同。")

        # Flatten column vectors / multi-dimensional inputs.
        if y_true.ndim > 1:
            y_true = y_true.flatten()
            y_pred = y_pred.flatten()

        # Differences between consecutive elements (empirical derivatives).
        diff_true = np.diff(y_true)
        diff_pred = np.diff(y_pred)

        return np.mean(np.square(diff_true - diff_pred))
    except Exception as err:
        # Best-effort: callers treat a failed MSEDI as a large penalty rather than crashing.
        warnings.warn(f"MSEDI calculation failed ({err!r}), returning {failure_value}.")
        return failure_value

def RewardsComputer(programs,
                    X,
                    y_target,
                    evaluate_function = None,
                    free_const_opti_args = None,
                    zero_out_duplicates = False,
                    keep_lowest_complexity_duplicate = False,
                    parallel_mode = False,
                    n_cpus = None,
                    progress_bar = False,
                    data_loss_scale = 0.001,
                    ):
    """
    Computes rewards of programs on X data against target y_target, using torch to execute them.

    Free constants of the batch are optimized first. Then, for every program i, the reward is

        reward_i = 1 / (1 + loss_pde_i + mse(y_pred_i, y_target) / data_loss_scale)

    where loss_pde_i is ``evaluate_function(y_pred_i, X_temp_i)`` when an evaluate_function
    is given, and ``msedi(y_target, y_pred_i)`` otherwise.

    Parameters
    ----------
    programs : Program.VectProgram
        Programs contained in batch to evaluate.
    X : torch.tensor of shape (n_dim, ?,) of float
        Values of the input variables of the problem with n_dim = nb of input variables.
    y_target : torch.tensor of shape (?,) of float
        Values of the target symbolic function on input variables contained in X.
    evaluate_function : callable or None, optional
        Called as ``evaluate_function(y_pred, X_temp)``, where ``X_temp`` is the second value
        returned by ``prog.torch_exec``. Must return a torch tensor loss. If None (default),
        ``msedi`` is used instead.
    free_const_opti_args : dict or None, optional
        Arguments forwarded to ``programs.batch_optimize_constants`` for free-constant optimization.
    zero_out_duplicates : bool
        Duplicate removal is not implemented; True only triggers a warning.
    keep_lowest_complexity_duplicate : bool
        Currently ignored (only meaningful together with duplicate removal).
    parallel_mode : bool
        Currently ignored.
    n_cpus : int or None
        Currently ignored.
    progress_bar : bool
        Currently ignored.
    data_loss_scale : float, optional
        Positive divisor applied to the data MSE before it enters the reward (default 0.001).

    Returns
    -------
    rewards : list of float
        One reward per program (numpy scalars), in batch order.

    Raises
    ------
    ValueError
        If data_loss_scale is not strictly positive.
    """
    if data_loss_scale <= 0:
        raise ValueError("data_loss_scale must be strictly positive.")
    if zero_out_duplicates:
        warnings.warn("zero_out_duplicates is not implemented in RewardsComputer and is ignored.")

    # mask : True means the program is valid (its reward should be kept). By default all programs are valid.
    mask_valid = np.full(shape=programs.batch_size, fill_value=True, dtype=bool)             # (batch_size,)

    # ----- FREE CONST OPTIMIZATION -----
    # NOTE(review): n_free_const looks like a count, so `> -1` is always true and optimization always runs.
    # Kept as-is because batch_optimize_constants may have work to do beyond library.n_free_const -- confirm.
    if programs.library.n_free_const > -1:
        programs.batch_optimize_constants(X        = X,
                                          y_target = y_target,
                                          free_const_opti_args = free_const_opti_args,
                                          mask                 = mask_valid,)

    # ----- REWARDS -----
    rewards = []
    for i in range(programs.batch_size):
        prog = programs.get_prog(i, skeleton=True)
        y_pred, X_temp = prog.torch_exec(X, prog.tokens, prog.free_const_values)
        # .cpu() so that GPU tensors can be converted to numpy.
        loss_data = F.mse_loss(y_pred, y_target).detach().cpu().numpy()
        if evaluate_function is None:
            loss_pde = msedi(y_target, y_pred)
        else:
            loss_pde = evaluate_function(y_pred, X_temp).detach().cpu().numpy()
        rewards.append(1 / (1 + loss_pde + loss_data / data_loss_scale))

    return rewards





def make_RewardsComputer(reward_function     = SquashedNRMSE,
                         zero_out_unphysical = False,
                         zero_out_duplicates = False,
                         keep_lowest_complexity_duplicate = False,
                         # Parallel related
                         parallel_mode = True,
                         n_cpus        = None,
                         evaluate_function = None,
                         ):
    """
    Helper function to make a custom reward computing function bound to fixed settings.

    Parameters
    ----------
    reward_function : callable
        Accepted for backward compatibility. RewardsComputer does not take a reward function, so this is
        currently NOT used.
    zero_out_unphysical : bool
        Accepted for backward compatibility; currently NOT used.
    zero_out_duplicates : bool
        Forwarded to RewardsComputer.
    keep_lowest_complexity_duplicate : bool
        Forwarded to RewardsComputer.
    parallel_mode : bool
        Requested parallel mode. It is checked against exec.ParallelExeAvailability and disabled, with a
        warning, if unavailable. The resulting value is forwarded to RewardsComputer.
    n_cpus : int or None
        Forwarded to RewardsComputer.
    evaluate_function : callable or None, optional
        Physics-based loss forwarded to RewardsComputer (None -> msedi is used).

    Returns
    -------
    rewards_computer : callable
         Function ``rewards_computer(programs, X, y_target, free_const_opti_args)`` that returns the list of
         rewards computed by RewardsComputer.
    """
    # Check that parallel execution is available on this system
    recommended_config = exec.ParallelExeAvailability()
    is_parallel_mode_available_on_system = recommended_config["parallel_mode"]
    # If not available and parallel_mode was still instructed warn and disable
    if not is_parallel_mode_available_on_system and parallel_mode:
        exec.ParallelExeAvailability(verbose=True) # prints explanation
        warnings.warn("Parallel mode is not available on this system, switching to non parallel mode.")
        parallel_mode = False

    def rewards_computer(programs, X, y_target, free_const_opti_args):
        # Only forward keywords RewardsComputer accepts: passing reward_function / zero_out_unphysical
        # used to raise TypeError on every call.
        return RewardsComputer(programs = programs,
                               X        = X,
                               y_target = y_target,
                               evaluate_function    = evaluate_function,
                               free_const_opti_args = free_const_opti_args,
                               # Frozen args
                               zero_out_duplicates = zero_out_duplicates,
                               keep_lowest_complexity_duplicate = keep_lowest_complexity_duplicate,
                               # Parallel related
                               parallel_mode = parallel_mode,
                               n_cpus        = n_cpus,
                               )

    return rewards_computer

class HierarchicalRewardsComputer:
    """
    Reward evaluator for programs written over high-level features.

    High-level programs (e.g. built from features f1, f2) are rewritten into the
    original low-level token vocabulary (e.g. x1, x2, t). Every feature token is
    replaced by the token sequence of its defining sub-program. Each rewritten
    program is then scored with RewardsComputer, which also runs free-constant
    optimization. Rewards are computed eagerly when the object is built and are
    exposed through indexing and len().
    """

    def __init__(self,
                 programs,
                 X_original,
                 y_original,
                 run_config_original,
                 original_library,
                 f_eqs,
                 evaluate_function,
                 free_const_opti_args,
                 ):
        """
        Parameters
        ----------
        programs : physym.program.VectPrograms
            Batch of high-level programs to score.
        X_original : torch.tensor
            Low-level input data (e.g. x1, x2, t).
        y_original : torch.tensor
            Low-level target data.
        run_config_original : dict
            Run configuration; its "library_config" and "priors_config" entries are used to
            build the low-level evaluation batch.
        original_library : physym.library.Library
            Low-level library; its ``lib_choosable_name_to_idx`` maps token names to indices.
        f_eqs : dict
            Maps a high-level feature name (str) to the low-level program defining it.
        evaluate_function : callable or None
            Physics-based loss forwarded to RewardsComputer.
        free_const_opti_args : dict
            Arguments for constant optimization.
        """
        # Low-level problem description.
        self.X_original = X_original
        self.y_original = y_original
        self.run_config_original = run_config_original
        self.original_library = original_library
        # NOTE(review): assumes these indices match the library built from
        # run_config_original["library_config"] in _compute_rewards -- confirm.
        self.original_lib_name_to_idx = original_library.lib_choosable_name_to_idx

        # High-level batch and evaluation settings.
        self.programs_higher = programs
        self.f_eqs = f_eqs
        self.evaluate_function = evaluate_function
        self.free_const_opti_args = free_const_opti_args

        # Eager evaluation: must run after every attribute above is set.
        self.rewards = self._compute_rewards()

    def __getitem__(self, i):
        return self.rewards[i]

    def __len__(self):
        return len(self.rewards)

    def _expand_program(self, prog_higher):
        """
        Rewrite one high-level program as a list of low-level library indices.

        Feature tokens (names found in ``f_eqs``) are replaced by the tokens of their
        sub-program. Any other token is looked up by name. Returns None as soon as a
        feature has a falsy sub-program or a name is missing from the low-level library.
        """
        name_to_idx = self.original_lib_name_to_idx
        expanded = []
        for tok in prog_higher.tokens:
            if tok.name in self.f_eqs:
                sub_prog = self.f_eqs.get(tok.name)
                if not sub_prog:
                    return None
                # Lazy so that lookups stop at the first unknown sub-token.
                names = (sub_tok.name for sub_tok in sub_prog.tokens)
            else:
                names = (tok.name,)
            for name in names:
                idx = name_to_idx.get(name)
                if idx is None:
                    return None
                expanded.append(idx)
        return expanded

    def _compute_rewards(self):
        """
        Expand and score every high-level program of the batch.

        Programs that cannot be expanded receive -inf. Each other program is replayed
        into a fresh single-program Batch and scored with RewardsComputer.
        """
        rewards = []
        for prog_idx in range(self.programs_higher.batch_size):
            expanded = self._expand_program(self.programs_higher.get_prog(prog_idx))
            if expanded is None:
                rewards.append(-np.inf)
                continue

            single_batch = Batch.Batch(
                library_args=self.run_config_original["library_config"],
                priors_config=self.run_config_original["priors_config"],
                batch_size=1,
                max_time_step=len(expanded),
                free_const_opti_args=self.free_const_opti_args,
                X=self.X_original,
                y_target=self.y_original,
            )
            # One (1,)-shaped action per time step: column vector of indices.
            token_column = np.array(expanded, dtype=np.int64).reshape(-1, 1)
            for step_action in token_column:
                single_batch.programs.append(step_action)

            # RewardsComputer handles constant optimization and the reward itself.
            batch_rewards = RewardsComputer(
                programs=single_batch.programs,
                X=self.X_original,
                y_target=self.y_original,
                evaluate_function=self.evaluate_function,
                free_const_opti_args=self.free_const_opti_args,
            )
            rewards.append(batch_rewards[0])

        return np.array(rewards)