from collections import OrderedDict
from dataclasses import dataclass
from typing import Dict
from typing import Optional
from typing import Tuple
from typing import Union
from abc import ABCMeta
import numpy as np
from scipy.special import softmax
from sklearn.base import BaseEstimator
from sklearn.base import ClassifierMixin
from sklearn.base import clone
from sklearn.base import is_classifier
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPRegressor
from sklearn.utils import check_random_state
from sklearn.utils import check_scalar
import torch
from sklearn.preprocessing import PolynomialFeatures
import torch.nn as nn
from torch.nn.functional import mse_loss
import torch.optim as optim
from tqdm import tqdm
from obp.utils import check_array
from obp.utils import check_bandit_feedback_inputs
from obp.utils import check_tensor
from sklearn.linear_model import LinearRegression
from obp.ope import RegressionModel
from scipy.stats import truncnorm
from obp.policy import BaseOfflinePolicyLearner
from sklearn.neighbors import KernelDensity
import torch
import torch.nn as nn
import torch.optim as optim
from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args
from torch.distributions import Categorical, Normal, MixtureSameFamily
from sklearn.pipeline import make_pipeline
from dataclasses import dataclass, field
import numpy as np
import math
import time
import matplotlib.pyplot as plt
import scipy.stats as stats
import scipy.special
from skopt.space import Real
from skopt.utils import use_named_args
from torch.distributions import Categorical, Normal, MixtureSameFamily
from sklearn.pipeline import make_pipeline
from dataclasses import dataclass, field
import numpy as np
import math
import time
from scipy.optimize import curve_fit, minimize
import matplotlib.pyplot as plt
import scipy.stats as stats
import scipy.special
from obp.dataset import linear_behavior_policy
from utils.datasets_no_n_user import s_sum_function

@dataclass
class GammaOptimizer3:
    """Select the mixing weight ``gamma`` used when training ``OurLearner``.

    For every candidate ``gamma`` an ``OurLearner`` is fitted with ``beta=gamma``
    on (bootstrapped) training rounds, and its action distribution on held-out
    rounds is scored by :meth:`predict_policy_value` using ``self.beta``.  A
    golden-section search over ``gamma`` maximises that score, and
    :meth:`optimize` averages the optimum over several random train/test splits.

    ``QFuncEstimator``, ``QFuncEstimatorWithS``, ``OurLearner`` are referenced
    but not imported at the top of the file; presumably they are defined
    elsewhere in this module.
    """

    n_actions: int
    context: np.ndarray
    action: np.ndarray
    surrogate_reward: np.ndarray
    reward: np.ndarray
    # Entries equal to 1 mark rounds whose primary reward is treated as observed.
    obs_list: np.ndarray
    # Forwarded to ``s_sum_function`` when ``bootstrap`` recomputes ``s_sum``.
    alpha_noise: float
    # Observation probability; observed rounds are weighted by 1 / p_o.
    p_o: float
    s_sum: Optional[np.ndarray] = None
    pscore: Optional[np.ndarray] = None
    # Weight on the side (``s_sum``) objective used when scoring candidates.
    beta: Optional[float] = None
    add_noise: float = 0.1  # not read by any method of this class
    len_list: int = 1
    dim_context: Optional[int] = None  # overwritten from ``context`` in __post_init__
    s_dim: Optional[int] = None
    off_policy_objective: Optional[str] = None
    # The fields below (together with the ones above) are forwarded to OurLearner.
    policy_reg_param: float = 0.0
    var_reg_param: float = 0.0
    hidden_layer_size: Tuple[int, ...] = (100,)
    activation: str = "relu"
    solver: str = "adam"
    alpha: float = 0.0001
    batch_size: Union[int, str] = "auto"
    learning_rate_init: float = 0.0001
    max_iter: int = 200
    shuffle: bool = True
    # Also seeds the train/test splits in ``optimize`` and ``self.random_``.
    random_state: Optional[int] = None
    tol: float = 1e-4
    momentum: float = 0.9
    nesterovs_momentum: bool = True
    early_stopping: bool = False
    validation_fraction: float = 0.1
    beta_1: float = 0.9
    beta_2: float = 0.999
    epsilon: float = 1e-8
    n_iter_no_change: int = 10
    q_func_estimator_hyperparams: Optional[Dict] = None
    q_func_with_s_hyperparams: Optional[Dict] = None

    # ------------------------------------------------------------------ setup
    def __post_init__(self):
        """Derive ``dim_context`` and build the regression models and RNG."""
        self.dim_context = self.context.shape[1] if self.context is not None else None
        # "train" estimators are used by predict_policy_value(test=False), the
        # "_test" ones by test=True.  The construction order is kept stable in
        # case estimator constructors draw from a global RNG.
        self.q_func_estimator = self._new_q_func_estimator()
        self.f_func_estimator = self._new_q_func_estimator()
        self.q_func_with_s_estimator = self._new_q_func_with_s_estimator()
        self.q_func_estimator_test = self._new_q_func_estimator()
        self.f_func_estimator_test = self._new_q_func_estimator()
        self.q_func_with_s_estimator_test = self._new_q_func_with_s_estimator()
        self.random_ = check_random_state(self.random_state)

    def _new_q_func_estimator(self):
        """Build a ``QFuncEstimator`` from a *copy* of the user hyperparameters."""
        params = dict(self.q_func_estimator_hyperparams or {})
        params["n_actions"] = self.n_actions
        params["dim_context"] = self.dim_context
        return QFuncEstimator(**params)

    def _new_q_func_with_s_estimator(self):
        """Build a ``QFuncEstimatorWithS`` from a *copy* of the user hyperparameters."""
        params = dict(self.q_func_with_s_hyperparams or {})
        params["n_actions"] = self.n_actions
        params["dim_context"] = self.dim_context
        params["s_dim"] = self.s_dim
        return QFuncEstimatorWithS(**params)

    def _make_learner(self):
        """Return an unfitted ``OurLearner`` configured with this optimizer's settings."""
        return OurLearner(
            n_actions=self.n_actions,
            len_list=self.len_list,
            dim_context=self.dim_context,
            s_dim=self.s_dim,
            off_policy_objective=self.off_policy_objective,
            policy_reg_param=self.policy_reg_param,
            var_reg_param=self.var_reg_param,
            hidden_layer_size=self.hidden_layer_size,
            activation=self.activation,
            solver=self.solver,
            alpha=self.alpha,
            batch_size=self.batch_size,
            learning_rate_init=self.learning_rate_init,
            max_iter=self.max_iter,
            shuffle=self.shuffle,
            random_state=self.random_state,
            tol=self.tol,
            momentum=self.momentum,
            nesterovs_momentum=self.nesterovs_momentum,
            early_stopping=self.early_stopping,
            validation_fraction=self.validation_fraction,
            beta_1=self.beta_1,
            beta_2=self.beta_2,
            epsilon=self.epsilon,
            n_iter_no_change=self.n_iter_no_change,
        )

    # ------------------------------------------------------------- data utils
    def split_data(
        self,
        num_train: int,
        context: np.ndarray,
        action: np.ndarray,
        surrogate_reward: np.ndarray,
        reward: np.ndarray,
        obs_list: np.ndarray,
        s_sum: np.ndarray,
        pscore: np.ndarray,
        action_dist: Optional[np.ndarray] = None,
        random_state: Optional[int] = None,
    ):
        """Randomly split per-round arrays into train and test dictionaries.

        Args:
            num_train: number of rounds placed in the training split.
            context, action, surrogate_reward, reward, obs_list, s_sum, pscore:
                arrays indexed by round along their first axis.
            action_dist: optional per-round array split the same way.
            random_state: seed for ``np.random.default_rng`` used to shuffle.

        Returns:
            ``(train_data, test_data)``. Each dict holds the subset arrays plus
            ``obs_context`` / ``obs_action`` / ``obs_reward`` restricted to rounds
            with ``obs_list == 1``, and ``action_dist`` when it was given.

        Raises:
            ValueError: if ``num_train`` exceeds the number of rounds.
        """
        n_rounds = context.shape[0]
        if num_train > n_rounds:
            raise ValueError("num_train cannot be greater than the total number of samples")

        indices = np.arange(n_rounds)
        np.random.default_rng(random_state).shuffle(indices)
        train_indices, test_indices = indices[:num_train], indices[num_train:]

        per_round = {
            "context": context,
            "action": action,
            "surrogate_reward": surrogate_reward,
            "reward": reward,
            "obs_list": obs_list,
            "s_sum": s_sum,
            "pscore": pscore,
        }
        splits = []
        for idx in (train_indices, test_indices):
            part = {key: value[idx] for key, value in per_round.items()}
            observed = part["obs_list"] == 1
            part["obs_context"] = part["context"][observed]
            part["obs_action"] = part["action"][observed]
            part["obs_reward"] = part["reward"][observed]
            if action_dist is not None:
                part["action_dist"] = action_dist[idx]
            splits.append(part)
        train_data, test_data = splits
        return train_data, test_data

    def bootstrap(self, train_data: dict, train_size: int):
        """Resample ``train_data`` with replacement up to ``train_size`` rounds.

        If the split already has at least ``train_size`` rounds it is returned
        unchanged. Otherwise a full bootstrap sample of ``train_size`` rounds is
        drawn with ``self.random_``, ``s_sum`` is recomputed from the resampled
        surrogate rewards via ``s_sum_function``, and the ``obs_*`` views are
        rebuilt.
        """
        current_size = train_data["context"].shape[0]
        if current_size >= train_size:
            return train_data

        # Draw train_size indices (a full resample, not only the missing rounds).
        bootstrap_indices = self.random_.choice(current_size, size=train_size, replace=True)
        keys = ("context", "surrogate_reward", "reward", "action", "obs_list", "pscore")
        if "action_dist" in train_data:
            keys += ("action_dist",)
        bootstrapped_data = {key: train_data[key][bootstrap_indices] for key in keys}

        # NOTE(review): the same random_state is passed on every call, so the
        # noise drawn by s_sum_function is presumably identical across calls.
        s_sum_result = s_sum_function(
            surrogate_rewards=bootstrapped_data["surrogate_reward"],
            alpha_noise=self.alpha_noise,
            random_state=self.random_state,
        )
        # s_sum_function may return a tuple; its first element is the s_sum array.
        bootstrapped_data["s_sum"] = (
            s_sum_result[0] if isinstance(s_sum_result, tuple) else s_sum_result
        )

        obs_mask = bootstrapped_data["obs_list"] == 1
        bootstrapped_data["obs_context"] = bootstrapped_data["context"][obs_mask]
        bootstrapped_data["obs_action"] = bootstrapped_data["action"][obs_mask]
        bootstrapped_data["obs_reward"] = bootstrapped_data["reward"][obs_mask]
        return bootstrapped_data

    # ------------------------------------------------------ value estimation
    def _policy_value_on(self, data: dict, beta: float, action_dist, test: bool) -> float:
        """Call :meth:`predict_policy_value` on a split dict from ``split_data``/``bootstrap``."""
        return self.predict_policy_value(
            context=data["context"],
            action=data["action"],
            surrogate_reward=data["surrogate_reward"],
            reward=data["reward"],
            obs_list=data["obs_list"],
            pscore=data["pscore"],
            p_o=self.p_o,
            s_sum=data["s_sum"],
            beta=beta,
            action_dist=action_dist,
            test=test,
        )

    def get_variance(
        self,
        beta: float,
        train_data: dict,
        test_data: dict,
        n_gam: int,
        train_action_dist: np.ndarray,
        test_action_dist: np.ndarray,
    ):
        """Return a per-gamma discrepancy between train- and test-side estimates.

        For each gamma on the grid ``linspace(0, 1, n_gam + 1)``, the policy
        value is estimated on ``train_data`` (train estimators) and
        ``test_data`` (``_test`` estimators). The returned list holds
        ``|train/avg - test/avg|`` with ``avg`` their mean. ``beta`` is accepted
        for interface compatibility but is not used.
        """
        std = []
        for gamma in np.linspace(0.0, 1.0, n_gam + 1):
            learning_value = self._policy_value_on(train_data, gamma, train_action_dist, test=False)
            held_out_value = self._policy_value_on(test_data, gamma, test_action_dist, test=True)
            avg_value = (learning_value + held_out_value) / 2
            std.append(abs((learning_value / avg_value) - (held_out_value / avg_value)))
        return std

    def predict_policy_value(
        self,
        context: np.ndarray,
        action: np.ndarray,
        surrogate_reward: np.ndarray,
        reward: np.ndarray,
        obs_list: np.ndarray,
        pscore: np.ndarray,
        p_o: float,
        s_sum: np.ndarray,
        beta: float,
        action_dist: np.ndarray,
        test: bool,
    ):
        """Estimate the value of ``action_dist`` on the given rounds.

        With ``iw = action_dist[i, a_i] / pscore_i`` and ``pw = obs_list_i / p_o``:

        * primary term: ``iw * (q_s[a] - q[a]) + pw * iw * (r - q_s[a])
          + sum_a q[a] * pi(a|x)``, where ``q`` and ``q_s`` come from the
          Q-function estimators without and with the surrogate reward;
        * side term: ``iw * (s_sum - f[a]) + sum_a f[a] * pi(a|x)``.

        ``p_o != 0`` and ``0 < beta < 1`` gives ``beta * side + (1 - beta) * primary``.
        ``p_o != 0`` and ``beta == 0`` gives the primary term only.
        Every other case (``p_o == 0``, ``beta >= 1``, ``beta < 0``) gives the side term only.

        Args:
            action_dist: squeezed on its last axis, so ``(n, n_actions, 1)``
                arrays are accepted.
            test: use the ``*_test`` estimators when True, the train ones otherwise.

        Returns:
            The mean of the per-round estimate as a Python float.
        """
        context = torch.tensor(context, dtype=torch.float32)
        action = torch.tensor(action, dtype=torch.long)
        surrogate_reward = torch.tensor(surrogate_reward, dtype=torch.float32)
        reward = torch.tensor(reward, dtype=torch.float32)
        obs_list = torch.tensor(obs_list, dtype=torch.float32)
        pscore = torch.tensor(pscore, dtype=torch.float32)
        s_sum = torch.tensor(s_sum, dtype=torch.float32)
        action_dist = torch.squeeze(torch.tensor(action_dist, dtype=torch.float32), -1)

        if test:
            q_est = self.q_func_estimator_test
            q_s_est = self.q_func_with_s_estimator_test
            f_est = self.f_func_estimator_test
        else:
            q_est = self.q_func_estimator
            q_s_est = self.q_func_with_s_estimator
            f_est = self.f_func_estimator

        rows = torch.arange(action.shape[0], dtype=torch.long)
        iw = action_dist[rows, action] / pscore

        def primary_term():
            q_hat = q_est.predict(context)
            q_hat_with_s = q_s_est.predict(context, surrogate_reward)
            q_factual = q_hat[rows, action]
            q_s_factual = q_hat_with_s[rows, action]
            pw = obs_list / p_o
            return (
                iw * (q_s_factual - q_factual)
                + pw * iw * (reward - q_s_factual)
            ) + torch.sum(q_hat * action_dist, dim=1)

        def side_term():
            f_hat = f_est.predict(context)
            return iw * (s_sum - f_hat[rows, action]) + torch.sum(f_hat * action_dist, dim=1)

        # 0 <= beta < 1 covers exactly the "beta == 0" and "0 < beta < 1" cases.
        if p_o != 0.0 and 0 <= beta < 1:
            primary = primary_term()
            if beta == 0:
                return primary.mean().item()
            return (beta * side_term() + (1 - beta) * primary).mean().item()
        return side_term().mean().item()

    # -------------------------------------------------- estimated-value search
    def evaluate_gamma(self, gamma: float, train_data: dict, test_data: dict):
        """Fit a learner with ``beta=gamma`` and score it on ``test_data``.

        Returns:
            ``(-value, 0)`` where ``value`` is :meth:`predict_policy_value` with
            ``beta=self.beta`` and the ``_test`` estimators; it is negated so the
            minimising search maximises it.
        """
        learner = self._make_learner()
        learner.fit(
            context=train_data["context"],
            action=train_data["action"],
            surrogate_reward=train_data["surrogate_reward"],
            reward=train_data["reward"],
            obs_list=train_data["obs_list"],
            p_o=self.p_o,
            s_sum=train_data["s_sum"],
            pscore=train_data["pscore"],
            beta=gamma,
        )
        test_action_dist = learner.predict(context=test_data["context"])
        value = self._policy_value_on(test_data, self.beta, test_action_dist, test=True)
        return -value, 0

    @staticmethod
    def _golden_section_trace(func, lower, upper, tol, max_iter):
        """Golden-section minimisation of ``func(x)[0]`` on ``[lower, upper]``.

        ``func`` must return a tuple whose first element is the objective.

        Returns:
            ``(history, best_x, min_value, iter_count)``. ``history`` lists
            ``(x, func_result)`` for the current ``(c, d)`` pair after the
            initial evaluation and after each step, so carried-over points
            appear more than once. ``best_x`` / ``min_value`` track the strict
            running minimum, with ties on the first pair resolved towards ``d``.
        """
        gr = (math.sqrt(5) + 1) / 2
        a, b = lower, upper
        c = b - (b - a) / gr
        d = a + (b - a) / gr
        res_c = func(c)
        res_d = func(d)
        history = [(c, res_c), (d, res_d)]
        min_value = min(res_c[0], res_d[0])
        best_x = c if res_c[0] < res_d[0] else d

        iter_count = 0
        while abs(c - d) > tol and iter_count < max_iter:
            if res_c[0] < res_d[0]:
                # Minimum lies in [a, d]: old c becomes the new d.
                b, d, res_d = d, c, res_c
                c = b - (b - a) / gr
                res_c = func(c)
            else:
                # Minimum lies in [c, b]: old d becomes the new c.
                a, c, res_c = c, d, res_d
                d = a + (b - a) / gr
                res_d = func(d)
            history.extend([(c, res_c), (d, res_d)])
            if res_c[0] < min_value:
                min_value, best_x = res_c[0], c
            if res_d[0] < min_value:
                min_value, best_x = res_d[0], d
            iter_count += 1
        return history, best_x, min_value, iter_count

    def golden_section_search(self, func, initial_guess, tol=1e-5, max_iter=20):
        """Minimise ``func(gamma)[0]`` over ``gamma`` in ``[0, 1]``.

        ``initial_guess`` is accepted for interface compatibility but unused;
        the search always brackets ``[0, 1]``.

        Returns:
            ``(best_gamma, variance_list)`` where ``variance_list`` collects
            ``func(gamma)[1]`` for every recorded ``(c, d)`` pair.
        """
        history, best_gamma, min_value, iter_count = self._golden_section_trace(
            func, 0, 1, tol, max_iter
        )
        variance_list = [result[1] for _, result in history]
        print(f"Optimized gamma after {iter_count} iterations: {best_gamma}, Minimum function value: {min_value}")
        return best_gamma, variance_list

    def optimize_gamma(self, train_data, test_data, initial_guess: float, n_calls: int = 20):
        """Golden-section search of :meth:`evaluate_gamma`; returns ``(gamma, variance_list)``."""
        def evaluate(gamma):
            value, variance = self.evaluate_gamma(gamma, train_data, test_data)
            print(f"Evaluating gamma={gamma}: Value={value}")
            return value, variance

        print(f"Starting optimization with initial guess: {initial_guess}")
        return self.golden_section_search(evaluate, initial_guess, tol=1e-5, max_iter=n_calls)

    def optimize(self, num_tries: int = 5, num_gamma: int = 4, n_gam: int = 200):
        """Estimate the best gamma, averaged over ``num_tries`` random splits.

        Each try uses 60% of the rounds for training, bootstrapped back to the
        full number of rounds, and the remaining rounds as the test split. The
        ``_test`` regression models are fitted on that test split, and a
        golden-section search with ``num_gamma`` iterations is run.

        Args:
            num_tries: number of random train/test splits.
            num_gamma: ``max_iter`` of each golden-section search.
            n_gam: stored on ``self.n_gam``; not otherwise used here.

        Returns:
            Mean of the per-split optima, clipped to ``[self.beta, 1]``.
        """
        self.n_gam = n_gam
        data_size = int((self.context.shape[0] / 5) * 3)
        initial_guess = min(1, self.beta + 0.3)
        optimized_gamma = []

        for i in range(num_tries):
            # Distinct, reproducible split per try; unseeded when random_state is None.
            split_seed = None if self.random_state is None else self.random_state + 500 * i
            train_data, test_data = self.split_data(
                num_train=data_size,
                context=self.context,
                action=self.action,
                surrogate_reward=self.surrogate_reward,
                reward=self.reward,
                obs_list=self.obs_list,
                s_sum=self.s_sum,
                pscore=self.pscore,
                random_state=split_seed,
            )
            train_data = self.bootstrap(train_data=train_data, train_size=self.context.shape[0])
            if self.p_o != 0.0 and self.beta < 1.0:
                self.q_func_estimator_test.fit(
                    context=test_data["obs_context"],
                    action=test_data["obs_action"],
                    reward=test_data["obs_reward"],
                )
                self.q_func_with_s_estimator_test.fit(
                    context=test_data["context"],
                    action=test_data["action"],
                    reward=test_data["reward"],
                    surrogate_reward=test_data["surrogate_reward"],
                    obs_list=test_data["obs_list"],
                )
            # NOTE(review): f_func_estimator is used as the s_sum control variate
            # in predict_policy_value but is fitted on `reward` here — confirm intent.
            self.f_func_estimator_test.fit(
                context=test_data["context"],
                action=test_data["action"],
                reward=test_data["reward"],
            )

            optimal_gamma, _ = self.optimize_gamma(
                train_data=train_data,
                test_data=test_data,
                initial_guess=initial_guess,
                n_calls=num_gamma,
            )
            optimized_gamma.append(optimal_gamma)
            initial_guess = optimal_gamma

        predicted_gamma = np.mean(optimized_gamma)
        predicted_gamma = min(1, max(self.beta, predicted_gamma))
        print(f"Optimal at added training data: {predicted_gamma}")
        return predicted_gamma

    # --------------------------------------------------- ground-truth search
    def optimize_gamma_true(self, train_data, test_data, initial_guess: float, n_calls: int = 20):
        """Golden-section search of :meth:`evaluate_gamma_true`.

        Returns ``(gamma, negated_value, primary_value, surrogate_value)``.
        """
        def evaluate_true(gamma):
            value, primary_value, surrogate_value = self.evaluate_gamma_true(gamma, train_data, test_data)
            print(f"Evaluating gamma={gamma}: Value={value}")
            return value, primary_value, surrogate_value

        print(f"Starting optimization with initial guess: {initial_guess}")
        return self.golden_section_search_true(evaluate_true, initial_guess, tol=1e-5, max_iter=n_calls)

    def optimize_true(self, train_data, test_data, num_gamma: int = 10):
        """Search gamma against ground-truth rewards.

        Returns:
            ``(gamma, true_value, primary_value, surrogate_value)`` where
            ``true_value`` is un-negated (higher is better).
        """
        optimal_gamma, true_value, primary_value, surrogate_value = self.optimize_gamma_true(
            train_data,
            test_data,
            initial_guess=min(self.beta + 0.3, 1),
            n_calls=num_gamma,
        )
        # evaluate_gamma_true negates the objective for minimisation; flip it back.
        return optimal_gamma, -true_value, primary_value, surrogate_value

    def golden_section_search_true(self, func, initial_guess, tol=1e-5, max_iter=20):
        """Minimise ``func(gamma)[0]`` over ``[max(beta - 0.4, 0), 1]``.

        ``func`` returns ``(value, primary, surrogate)``. ``initial_guess`` is
        unused. The best point is the first argmin over all recorded
        evaluations.

        Raises:
            ValueError: if that minimum disagrees with the running minimum,
                e.g. when an evaluation returned NaN.
        """
        lower = max(self.beta - 0.4, 0)
        history, _, min_value, iter_count = self._golden_section_trace(
            func, lower, 1, tol, max_iter
        )
        true_vals = [result[0] for _, result in history]
        best_idx = int(np.argmin(true_vals))
        best_value = np.min(true_vals)
        best_gamma, (_, best_primary, best_surrogate) = history[best_idx]
        print(f"Optimized gamma after {iter_count} iterations: {best_gamma}, Minimum function value: {min_value}")
        if best_value != min_value:
            raise ValueError(f"Best value is not {min_value} but {best_value}")
        return best_gamma, best_value, best_primary, best_surrogate

    def evaluate_gamma_true(self, gamma: float, train_data: dict, test_data: dict):
        """Fit a learner with ``beta=gamma`` and score it with ground-truth tables.

        ``train_data`` / ``test_data`` use plural keys (``contexts``, ``actions``,
        ``rewards``, ...) plus ``all_q_x_a_f`` and ``f_sum`` on the test side.

        Returns:
            ``(-mixed_value, primary_value, surrogate_value)``, where the mixed
            value uses ``self.beta``.
        """
        learner = self._make_learner()
        learner.fit(
            context=train_data["contexts"],
            action=train_data["actions"],
            surrogate_reward=train_data["surrogate_rewards"],
            reward=train_data["rewards"],
            obs_list=train_data["obs_list"],
            p_o=self.p_o,
            s_sum=train_data["s_sum"],
            pscore=train_data["pscores"],
            beta=gamma,
        )
        action_dist = learner.predict(context=test_data["contexts"])
        mixed_value = self.calc_true_policy_value(
            expected_reward=test_data["all_q_x_a_f"],
            expected_surrogate_reward=test_data["f_sum"],
            action_dist=action_dist,
            beta=self.beta,
        )
        primary_value = self.calc_ground_truth_policy_value(
            expected_reward=test_data["all_q_x_a_f"],
            action_dist=action_dist,
        )
        surrogate_value = self.calc_ground_truth_policy_value(
            expected_reward=test_data["f_sum"],
            action_dist=action_dist,
        )
        return -mixed_value, primary_value, surrogate_value

    @staticmethod
    def _normalized_policy_value(expected_reward: np.ndarray, action_dist: np.ndarray):
        """Return ``(V(pi) - V(uniform)) / (V(greedy) - V(uniform))``.

        ``expected_reward`` is indexed ``[round, action]``, and
        ``action_dist[:, :, 0]`` supplies the per-round weights. If every action
        has the same expected reward the denominator is zero, and NumPy yields
        nan/inf with a RuntimeWarning.
        """
        average_max_reward = np.mean(np.max(expected_reward, axis=1))
        policy_reward = np.average(expected_reward, weights=action_dist[:, :, 0], axis=1).mean()
        uniform_reward = np.mean(expected_reward)
        return (policy_reward - uniform_reward) / (average_max_reward - uniform_reward)

    def calc_true_policy_value(
        self, expected_reward: np.ndarray, expected_surrogate_reward: np.ndarray, action_dist: np.ndarray, beta: float
    ) -> float:
        """Normalised value of ``action_dist`` on ``(1 - beta) * reward + beta * surrogate``."""
        full_expected_reward = (1 - beta) * expected_reward + (beta * expected_surrogate_reward)
        return self._normalized_policy_value(full_expected_reward, action_dist)

    def calc_ground_truth_policy_value(
        self, expected_reward: np.ndarray, action_dist: np.ndarray
    ) -> float:
        """Validated, normalised value of ``action_dist`` on ``expected_reward``.

        Raises:
            ValueError: if the first two dimensions of the inputs disagree, or if
                ``check_array`` rejects their rank (2-D and 3-D respectively).
        """
        check_array(array=expected_reward, name="expected_reward", expected_dim=2)
        check_array(array=action_dist, name="action_dist", expected_dim=3)
        if expected_reward.shape[0] != action_dist.shape[0]:
            raise ValueError(
                "Expected `expected_reward.shape[0] = action_dist.shape[0]`, but found it False"
            )
        if expected_reward.shape[1] != action_dist.shape[1]:
            raise ValueError(
                "Expected `expected_reward.shape[1] = action_dist.shape[1]`, but found it False"
            )
        return self._normalized_policy_value(expected_reward, action_dist)
    
@dataclass
class GammaOptimizer2:
    """Choose the surrogate mixing weight ``gamma`` by a bias/variance grid search.

    ``optimize`` works in three stages:

    1. Learn a policy once, with ``gamma = beta``, on a tenth of the logged data.
    2. Repeatedly split the remaining data in half.
    3. Estimate the policy's value for every candidate ``gamma`` on the grid
       ``np.linspace(0, 1, n_gam)``.

    The candidate minimizing an estimated MSE (variance + bias**2) is
    returned, never below ``beta``.

    NOTE(review): ``q_func_estimator``, ``f_func_estimator`` and
    ``q_func_with_s_estimator`` are constructed in ``__post_init__``, but no
    ``fit`` call on them is visible in this class. Confirm that their
    ``predict`` is meaningful in that state.
    """

    n_actions: int
    context: np.ndarray
    action: np.ndarray
    surrogate_reward: np.ndarray
    reward: np.ndarray
    obs_list: np.ndarray
    p_o: float
    s_sum: np.ndarray = None
    pscore: np.ndarray = None
    beta: float = None
    len_list: int = 1
    dim_context: Optional[int] = None
    s_dim: Optional[int] = None
    off_policy_objective: Optional[str] = None
    policy_reg_param: float = 0.0
    var_reg_param: float = 0.0
    hidden_layer_size: Tuple[int, ...] = (100,)
    activation: str = "relu"
    solver: str = "adam"
    alpha: float = 0.0001
    batch_size: Union[int, str] = "auto"
    learning_rate_init: float = 0.0001
    max_iter: int = 200
    shuffle: bool = True
    random_state: Optional[int] = None
    tol: float = 1e-4
    momentum: float = 0.9
    nesterovs_momentum: bool = True
    early_stopping: bool = False
    validation_fraction: float = 0.1
    beta_1: float = 0.9
    beta_2: float = 0.999
    epsilon: float = 1e-8
    n_iter_no_change: int = 10
    q_func_estimator_hyperparams: Optional[Dict] = None
    q_func_with_s_hyperparams: Optional[Dict] = None

    def __post_init__(self):
        """Derive ``dim_context``, seed the splitting RNG and build the reward models."""
        self.dim_context = self.context.shape[1] if self.context is not None else None
        self.random_gen = np.random.RandomState(self.random_state)

        # Complete *copies* of the user-supplied hyperparameter dicts so the
        # caller's dict is never mutated. The completed copies are stored back
        # on ``self`` because ``learn_policy`` forwards them to ``OurLearner``.
        if self.q_func_estimator_hyperparams is not None:
            self.q_func_estimator_hyperparams = {
                **self.q_func_estimator_hyperparams,
                "n_actions": self.n_actions,
                "dim_context": self.dim_context,
            }
            q_kwargs = self.q_func_estimator_hyperparams
        else:
            q_kwargs = {"n_actions": self.n_actions, "dim_context": self.dim_context}
        # Two independent estimators from the same hyperparameters: ``q`` is
        # the baseline for the primary-reward term and ``f`` for the ``s_sum``
        # term in ``predict_policy_value``.
        self.q_func_estimator = QFuncEstimator(**q_kwargs)
        self.f_func_estimator = QFuncEstimator(**q_kwargs)

        if self.q_func_with_s_hyperparams is not None:
            self.q_func_with_s_hyperparams = {
                **self.q_func_with_s_hyperparams,
                "n_actions": self.n_actions,
                "dim_context": self.dim_context,
                "s_dim": self.s_dim,
            }
            qs_kwargs = self.q_func_with_s_hyperparams
        else:
            qs_kwargs = {
                "n_actions": self.n_actions,
                "dim_context": self.dim_context,
                "s_dim": self.s_dim,
            }
        self.q_func_with_s_estimator = QFuncEstimatorWithS(**qs_kwargs)

    def _logged_kwargs(self, data: dict) -> dict:
        """Return the logged-data keyword arguments shared by fit/value calls."""
        return {
            "context": data["context"],
            "action": data["action"],
            "surrogate_reward": data["surrogate_reward"],
            "reward": data["reward"],
            "obs_list": data["obs_list"],
            "pscore": data["pscore"],
            "p_o": self.p_o,
            "s_sum": data["s_sum"],
        }

    def _split_dict(self, source: dict, num_train: int, action_dist: Optional[np.ndarray] = None):
        """Call ``split_data`` on the arrays stored under the standard keys of ``source``."""
        return self.split_data(
            num_train=num_train,
            context=source["context"],
            action=source["action"],
            surrogate_reward=source["surrogate_reward"],
            reward=source["reward"],
            obs_list=source["obs_list"],
            s_sum=source["s_sum"],
            pscore=source["pscore"],
            action_dist=action_dist,
        )

    def split_data(
        self, num_train: int, context: np.ndarray, action: np.ndarray, surrogate_reward: np.ndarray, reward: np.ndarray, obs_list: np.ndarray, s_sum: np.ndarray, pscore: np.ndarray, action_dist: Optional[np.ndarray] = None
    ):
        """Randomly split the logged arrays into train/test dictionaries.

        Rows are shuffled with ``self.random_gen``. The first ``num_train``
        shuffled rows go to train and the rest to test. All arrays are assumed
        to share their first dimension with ``context``.

        Returns
        -------
        (train_data, test_data): Tuple[dict, dict]
            Keys ``context``, ``action``, ``surrogate_reward``, ``reward``,
            ``obs_list``, ``s_sum``, ``pscore``, plus ``action_dist`` when given.

        Raises
        ------
        ValueError
            If ``num_train`` is negative or exceeds the number of rows.
        """
        n_rounds = context.shape[0]
        if num_train < 0:
            raise ValueError("num_train must be non-negative")
        if num_train > n_rounds:
            raise ValueError("num_train cannot be greater than the total number of samples")

        indices = np.arange(n_rounds)
        self.random_gen.shuffle(indices)
        train_indices, test_indices = indices[:num_train], indices[num_train:]

        arrays = {
            "context": context,
            "action": action,
            "surrogate_reward": surrogate_reward,
            "reward": reward,
            "obs_list": obs_list,
            "s_sum": s_sum,
            "pscore": pscore,
        }
        if action_dist is not None:
            arrays["action_dist"] = action_dist

        train_data = {key: value[train_indices] for key, value in arrays.items()}
        test_data = {key: value[test_indices] for key, value in arrays.items()}
        return train_data, test_data

    def predict_policy_value(
        self,
        context: np.ndarray,
        action: np.ndarray,
        surrogate_reward: np.ndarray,
        reward: np.ndarray,
        obs_list: np.ndarray,
        pscore: np.ndarray,
        p_o: float,
        s_sum: np.ndarray,
        beta: float,
        action_dist: np.ndarray
    ):
        """Estimate the value of ``action_dist`` on the given logged data.

        Let ``w_i = action_dist[i, a_i] / pscore_i``. Per round, the two terms are:

        * primary: ``w_i * (q_s(x_i, s_i)[a_i] - q(x_i)[a_i])
          + (obs_i / p_o) * w_i * (r_i - q_s(x_i, s_i)[a_i])
          + sum_a action_dist[i, a] * q(x_i)[a]``
        * side: ``w_i * (s_sum_i - f(x_i)[a_i]) + sum_a action_dist[i, a] * f(x_i)[a]``

        The returned value is the mean over rounds of one of the following:

        * ``beta * side + (1 - beta) * primary`` when ``p_o != 0`` and ``0 < beta < 1``;
        * the primary term alone when ``p_o != 0`` and ``beta == 0``;
        * the side term alone otherwise.

        ``action_dist`` is indexed ``[round, action]`` after dropping a
        trailing singleton axis.

        Returns
        -------
        float
        """
        context = torch.tensor(context, dtype=torch.float32)
        action = torch.tensor(action, dtype=torch.long)
        surrogate_reward = torch.tensor(surrogate_reward, dtype=torch.float32)
        reward = torch.tensor(reward, dtype=torch.float32)
        obs_list = torch.tensor(obs_list, dtype=torch.float32)
        pscore = torch.tensor(pscore, dtype=torch.float32)
        s_sum = torch.tensor(s_sum, dtype=torch.float32)
        action_dist = torch.squeeze(torch.tensor(action_dist, dtype=torch.float32), -1)

        idx_tensor = torch.arange(action.shape[0], dtype=torch.long)
        # Importance weight of the logged action under the evaluated policy.
        iw = action_dist[idx_tensor, action] / pscore

        def primary_value():
            q_hat = self.q_func_estimator.predict(context)
            q_hat_with_s = self.q_func_with_s_estimator.predict(context, surrogate_reward)
            q_hat_factual = q_hat[idx_tensor, action]
            q_hat_factual_with_s = q_hat_with_s[idx_tensor, action]
            # Reweight rounds whose primary reward was observed by 1 / p_o.
            pw = obs_list / p_o
            value = iw * (q_hat_factual_with_s - q_hat_factual)
            value += pw * iw * (reward - q_hat_factual_with_s)
            value += torch.sum(q_hat * action_dist, dim=1)
            return value

        def side_value():
            f_hat = self.f_func_estimator.predict(context)
            f_hat_factual = f_hat[idx_tensor, action]
            value = iw * (s_sum - f_hat_factual)
            value += torch.sum(f_hat * action_dist, dim=1)
            return value

        if p_o != 0.0 and 0 < beta < 1:
            policy_val = beta * side_value() + (1 - beta) * primary_value()
        elif p_o != 0.0 and beta == 0:
            policy_val = primary_value()
        else:
            policy_val = side_value()

        return policy_val.mean().item()

    def learn_policy(self, gamma:float, train_data:dict, test_data:dict, n_gam:int):
        """Fit ``OurLearner`` with ``beta=gamma`` on ``train_data`` and return
        its predicted action distribution on ``test_data["context"]``.

        ``n_gam`` is accepted for interface compatibility and is unused.
        """
        ours_gamma = OurLearner(
            n_actions=self.n_actions,
            len_list=self.len_list,
            dim_context=self.dim_context,
            s_dim=self.s_dim,
            off_policy_objective=self.off_policy_objective,
            policy_reg_param=self.policy_reg_param,
            var_reg_param=self.var_reg_param,
            hidden_layer_size=self.hidden_layer_size,
            activation=self.activation,
            solver=self.solver,
            alpha=self.alpha,
            batch_size=self.batch_size,
            learning_rate_init=self.learning_rate_init,
            max_iter=self.max_iter,
            shuffle=self.shuffle,
            random_state=self.random_state,
            tol=self.tol,
            momentum=self.momentum,
            nesterovs_momentum=self.nesterovs_momentum,
            early_stopping=self.early_stopping,
            validation_fraction=self.validation_fraction,
            beta_1=self.beta_1,
            beta_2=self.beta_2,
            epsilon=self.epsilon,
            n_iter_no_change=self.n_iter_no_change,
            q_func_estimator_hyperparams=self.q_func_estimator_hyperparams,
            q_func_with_s_hyperparams=self.q_func_with_s_hyperparams,
        )
        ours_gamma.fit(**self._logged_kwargs(train_data), beta=gamma)
        return ours_gamma.predict(context=test_data["context"])

    def evaluate_gamma(self, beta:float, train_data:dict, test_data:dict, gammas:np.ndarray):
        """Estimate the policy value on both splits for every candidate gamma.

        Both dictionaries must contain an ``action_dist`` entry.

        Returns
        -------
        (gammas_a, gammas_b, biases): Tuple[list, list, list]
            Values on the train split, values on the test split, and, per
            gamma, the mean absolute gap to the ``beta``-weighted value on the
            two splits.
        """
        def value(data: dict, weight: float) -> float:
            return self.predict_policy_value(
                **self._logged_kwargs(data), beta=weight, action_dist=data["action_dist"]
            )

        # Targets are computed with the ``beta`` argument (the caller passes
        # ``self.beta``); previously the argument was ignored.
        beta_value_test = value(test_data, beta)
        beta_value_train = value(train_data, beta)

        gammas_a, gammas_b, biases = [], [], []
        for gamma in gammas:
            train_value = value(train_data, gamma)
            test_value = value(test_data, gamma)
            gammas_a.append(train_value)
            gammas_b.append(test_value)
            biases.append((abs(train_value - beta_value_train) + abs(test_value - beta_value_test)) / 2)
        return gammas_a, gammas_b, biases

    def optimize(self, n_gam: int=200, num_tries: int=100):
        """Return the grid gamma (never below ``beta``) minimizing the estimated MSE.

        Parameters
        ----------
        n_gam: int
            Number of grid points in ``np.linspace(0, 1, n_gam)``.
        num_tries: int
            Number of random half/half splits used to estimate variance and bias.
        """
        full_data = {
            "context": self.context,
            "action": self.action,
            "surrogate_reward": self.surrogate_reward,
            "reward": self.reward,
            "obs_list": self.obs_list,
            "s_sum": self.s_sum,
            "pscore": self.pscore,
        }
        learn_size = self.context.shape[0] // 10
        learn_data, data = self._split_dict(full_data, learn_size)
        action_dist = self.learn_policy(gamma=self.beta, train_data=learn_data, test_data=data, n_gam=n_gam)

        gammas = np.linspace(0, 1, n_gam)
        data_size = (self.context.shape[0] - learn_size) // 2
        gamma_a_list, gamma_b_list, bias_list = [], [], []
        for _ in range(num_tries):
            train_data, test_data = self._split_dict(data, data_size, action_dist=action_dist)
            gamma_a, gamma_b, biases = self.evaluate_gamma(
                beta=self.beta, train_data=train_data, test_data=test_data, gammas=gammas
            )
            gamma_a_list.append(gamma_a)
            gamma_b_list.append(gamma_b)
            bias_list.append(biases)

        variance = np.mean((np.array(gamma_a_list) - np.array(gamma_b_list)) ** 2, axis=0)
        bias = np.mean(np.array(bias_list), axis=0)
        print(f"\nBeta value: {self.beta}")

        mse = variance + np.square(bias)
        # Report the grid value itself: gammas[i] == i / (n_gam - 1), not i / n_gam.
        print(f"Min MSE gamma: {gammas[np.argmin(mse)]}")

        # NOTE(review): ``variance`` is the mean squared gap between two
        # half-split estimates. Halving it presumably gives a per-estimate
        # variance, and 10/9 looks like a sample-size correction; confirm.
        true_variance = (variance / 2) * (10 / 9)
        true_mse = true_variance + np.square(bias)
        predicted_gamma = max(self.beta, float(gammas[np.argmin(true_mse)]))
        print(f"Optimized Gamma is {predicted_gamma}")

        return predicted_gamma


@dataclass
class GammaOptimizer:
    """Choose the surrogate mixing weight ``gamma`` by repeated golden-section search.

    Each try of ``optimize`` works as follows:

    1. Split the data in half (train/test).
    2. Split the train half twice more into quarters.
    3. Run a golden-section search for ``gamma`` on each quarter and on the
       full train half.
    4. Combine the two results as ``full + (full - mean(quarters))``,
       clipped to ``[beta, 1]``.

    The final estimate is the mean of these per-try values.
    """

    n_actions: int
    context: np.ndarray
    action: np.ndarray
    surrogate_reward: np.ndarray
    reward: np.ndarray
    obs_list: np.ndarray
    p_o: float
    s_sum: np.ndarray = None
    pscore: np.ndarray = None
    beta: float = None
    len_list: int = 1
    dim_context: Optional[int] = None
    s_dim: Optional[int] = None
    off_policy_objective: Optional[str] = None
    policy_reg_param: float = 0.0
    var_reg_param: float = 0.0
    hidden_layer_size: Tuple[int, ...] = (100,)
    activation: str = "relu"
    solver: str = "adam"
    alpha: float = 0.0001
    batch_size: Union[int, str] = "auto"
    learning_rate_init: float = 0.0001
    max_iter: int = 200
    shuffle: bool = True
    random_state: Optional[int] = None
    tol: float = 1e-4
    momentum: float = 0.9
    nesterovs_momentum: bool = True
    early_stopping: bool = False
    validation_fraction: float = 0.1
    beta_1: float = 0.9
    beta_2: float = 0.999
    epsilon: float = 1e-8
    n_iter_no_change: int = 10
    q_func_estimator_hyperparams: Optional[Dict] = None
    q_func_with_s_hyperparams: Optional[Dict] = None

    def __post_init__(self):
        """Derive ``dim_context`` and build the reward models used for evaluation."""
        self.dim_context = self.context.shape[1] if self.context is not None else None

        # Complete *copies* of the user-supplied hyperparameter dicts so the
        # caller's dict is never mutated. The completed copies are stored back
        # on ``self`` because ``evaluate_gamma`` forwards them to ``OurLearner``.
        if self.q_func_estimator_hyperparams is not None:
            self.q_func_estimator_hyperparams = {
                **self.q_func_estimator_hyperparams,
                "n_actions": self.n_actions,
                "dim_context": self.dim_context,
            }
            q_kwargs = self.q_func_estimator_hyperparams
        else:
            q_kwargs = {"n_actions": self.n_actions, "dim_context": self.dim_context}
        # Two independent estimators from the same hyperparameters: ``q`` is
        # the baseline for the primary-reward term and ``f`` for the ``s_sum``
        # term in ``predict_policy_value``.
        self.q_func_estimator = QFuncEstimator(**q_kwargs)
        self.f_func_estimator = QFuncEstimator(**q_kwargs)

        if self.q_func_with_s_hyperparams is not None:
            self.q_func_with_s_hyperparams = {
                **self.q_func_with_s_hyperparams,
                "n_actions": self.n_actions,
                "dim_context": self.dim_context,
                "s_dim": self.s_dim,
            }
            qs_kwargs = self.q_func_with_s_hyperparams
        else:
            qs_kwargs = {
                "n_actions": self.n_actions,
                "dim_context": self.dim_context,
                "s_dim": self.s_dim,
            }
        self.q_func_with_s_estimator = QFuncEstimatorWithS(**qs_kwargs)

    def generate_size(self, n:int, num_tries: int, random_state:int, mean:float=0.6, std_dev:float=0.7):
        """Draw ``num_tries`` integer sample sizes very close to ``n / 2``.

        Values come from a normal distribution truncated to
        ``[0.4999 * n, 0.5001 * n]``, and ``mean`` / ``std_dev`` are expressed
        as fractions of that interval's width. When ``num_tries == 1`` the
        deterministic ``[int(0.5 * n)]`` is returned.
        """
        if num_tries == 1:
            return np.array([int(0.5 * n)])

        random_ = np.random.RandomState(random_state)
        lower_bound, upper_bound = 0.4999 * n, 0.5001 * n
        width = upper_bound - lower_bound
        mean_scaled = lower_bound + mean * width
        std_dev_scaled = std_dev * width
        a = (lower_bound - mean_scaled) / std_dev_scaled
        b = (upper_bound - mean_scaled) / std_dev_scaled

        # Draw with the seeded generator so ``random_state`` controls the result.
        truncated_normal_values = truncnorm(a, b, loc=mean_scaled, scale=std_dev_scaled).rvs(
            num_tries, random_state=random_
        )
        return np.round(truncated_normal_values).astype(int)

    def split_data(
        self, num_train: int, context: np.ndarray, action: np.ndarray, surrogate_reward: np.ndarray, reward: np.ndarray, obs_list: np.ndarray, s_sum: np.ndarray, pscore: np.ndarray
    ):
        """Randomly split the logged arrays into train/test dictionaries.

        Rows are shuffled with NumPy's *global* RNG, so results depend on
        ``np.random.seed``. The first ``num_train`` shuffled rows go to train
        and the rest to test.

        Returns
        -------
        (train_data, test_data): Tuple[dict, dict]
            Keys ``context``, ``action``, ``surrogate_reward``, ``reward``,
            ``obs_list``, ``s_sum``, ``pscore``.

        Raises
        ------
        ValueError
            If ``num_train`` is negative or exceeds the number of rows.
        """
        n_rounds = context.shape[0]
        if num_train < 0:
            raise ValueError("num_train must be non-negative")
        if num_train > n_rounds:
            raise ValueError("num_train cannot be greater than the total number of samples")

        indices = np.arange(n_rounds)
        np.random.shuffle(indices)
        train_indices, test_indices = indices[:num_train], indices[num_train:]

        arrays = {
            "context": context,
            "action": action,
            "surrogate_reward": surrogate_reward,
            "reward": reward,
            "obs_list": obs_list,
            "s_sum": s_sum,
            "pscore": pscore,
        }
        train_data = {key: value[train_indices] for key, value in arrays.items()}
        test_data = {key: value[test_indices] for key, value in arrays.items()}
        return train_data, test_data

    def _split_dict(self, source: dict, num_train: int):
        """Call ``split_data`` on the arrays stored under the standard keys of ``source``."""
        return self.split_data(
            num_train=num_train,
            context=source["context"],
            action=source["action"],
            surrogate_reward=source["surrogate_reward"],
            reward=source["reward"],
            obs_list=source["obs_list"],
            s_sum=source["s_sum"],
            pscore=source["pscore"],
        )

    def _logged_kwargs(self, data: dict) -> dict:
        """Return the logged-data keyword arguments shared by fit/value calls."""
        return {
            "context": data["context"],
            "action": data["action"],
            "surrogate_reward": data["surrogate_reward"],
            "reward": data["reward"],
            "obs_list": data["obs_list"],
            "pscore": data["pscore"],
            "p_o": self.p_o,
            "s_sum": data["s_sum"],
        }

    def predict_policy_value(
        self,
        context: np.ndarray,
        action: np.ndarray,
        surrogate_reward: np.ndarray,
        reward: np.ndarray,
        obs_list: np.ndarray,
        pscore: np.ndarray,
        p_o: float,
        s_sum: np.ndarray,
        beta: float,
        action_dist: np.ndarray
    ):
        """Estimate the value of ``action_dist`` on the given logged data.

        Let ``w_i = action_dist[i, a_i] / pscore_i``. Per round, the two terms are:

        * primary: ``w_i * (q_s(x_i, s_i)[a_i] - q(x_i)[a_i])
          + (obs_i / p_o) * w_i * (r_i - q_s(x_i, s_i)[a_i])
          + sum_a action_dist[i, a] * q(x_i)[a]``
        * side: ``w_i * (s_sum_i - f(x_i)[a_i]) + sum_a action_dist[i, a] * f(x_i)[a]``

        The returned value is the mean over rounds of one of the following:

        * ``beta * side + (1 - beta) * primary`` when ``p_o != 0`` and ``0 < beta < 1``;
        * the primary term alone when ``p_o != 0`` and ``beta == 0``;
        * the side term alone otherwise.

        ``action_dist`` is indexed ``[round, action]`` after dropping a
        trailing singleton axis.

        NOTE(review): no ``fit`` call on ``self.q_func_estimator`` /
        ``self.f_func_estimator`` / ``self.q_func_with_s_estimator`` is visible
        in this class; confirm they are fitted before use.
        """
        context = torch.tensor(context, dtype=torch.float32)
        action = torch.tensor(action, dtype=torch.long)
        surrogate_reward = torch.tensor(surrogate_reward, dtype=torch.float32)
        reward = torch.tensor(reward, dtype=torch.float32)
        obs_list = torch.tensor(obs_list, dtype=torch.float32)
        pscore = torch.tensor(pscore, dtype=torch.float32)
        s_sum = torch.tensor(s_sum, dtype=torch.float32)
        action_dist = torch.squeeze(torch.tensor(action_dist, dtype=torch.float32), -1)

        idx_tensor = torch.arange(action.shape[0], dtype=torch.long)
        # Importance weight of the logged action under the evaluated policy.
        iw = action_dist[idx_tensor, action] / pscore

        def primary_value():
            q_hat = self.q_func_estimator.predict(context)
            q_hat_with_s = self.q_func_with_s_estimator.predict(context, surrogate_reward)
            q_hat_factual = q_hat[idx_tensor, action]
            q_hat_factual_with_s = q_hat_with_s[idx_tensor, action]
            # Reweight rounds whose primary reward was observed by 1 / p_o.
            pw = obs_list / p_o
            value = iw * (q_hat_factual_with_s - q_hat_factual)
            value += pw * iw * (reward - q_hat_factual_with_s)
            value += torch.sum(q_hat * action_dist, dim=1)
            return value

        def side_value():
            f_hat = self.f_func_estimator.predict(context)
            f_hat_factual = f_hat[idx_tensor, action]
            value = iw * (s_sum - f_hat_factual)
            value += torch.sum(f_hat * action_dist, dim=1)
            return value

        if p_o != 0.0 and 0 < beta < 1:
            policy_val = beta * side_value() + (1 - beta) * primary_value()
        elif p_o != 0.0 and beta == 0:
            policy_val = primary_value()
        else:
            policy_val = side_value()

        return policy_val.mean().item()

    def evaluate_gamma(self, gamma:float, train_data:dict, test_data:dict):
        """Train ``OurLearner`` with ``beta=gamma`` on ``train_data`` and score it.

        Returns
        -------
        (neg_value, gamma_variance, bias): Tuple[float, float, float]
            * ``neg_value`` is the negated test-split value at ``self.beta``
              (negated so that smaller is better for the minimizing search).
            * ``gamma_variance`` is the absolute train/test gap of the value
              at ``gamma``.
            * ``bias`` averages the ``gamma``-vs-``beta`` gaps on the two splits.
        """
        ours_gamma = OurLearner(
            n_actions=self.n_actions,
            len_list=self.len_list,
            dim_context=self.dim_context,
            s_dim=self.s_dim,
            off_policy_objective=self.off_policy_objective,
            policy_reg_param=self.policy_reg_param,
            var_reg_param=self.var_reg_param,
            hidden_layer_size=self.hidden_layer_size,
            activation=self.activation,
            solver=self.solver,
            alpha=self.alpha,
            batch_size=self.batch_size,
            learning_rate_init=self.learning_rate_init,
            max_iter=self.max_iter,
            shuffle=self.shuffle,
            random_state=self.random_state,
            tol=self.tol,
            momentum=self.momentum,
            nesterovs_momentum=self.nesterovs_momentum,
            early_stopping=self.early_stopping,
            validation_fraction=self.validation_fraction,
            beta_1=self.beta_1,
            beta_2=self.beta_2,
            epsilon=self.epsilon,
            n_iter_no_change=self.n_iter_no_change,
            q_func_estimator_hyperparams=self.q_func_estimator_hyperparams,
            q_func_with_s_hyperparams=self.q_func_with_s_hyperparams,
        )
        ours_gamma.fit(**self._logged_kwargs(train_data), beta=gamma)

        test_action_dist = ours_gamma.predict(context=test_data["context"])
        ours_gamma_value = self.predict_policy_value(
            **self._logged_kwargs(test_data), beta=self.beta, action_dist=test_action_dist
        )
        # Train-split values use the learner's own estimators (``predict_own``).
        train_action_dist = ours_gamma.predict(context=train_data["context"])
        learning_beta_value = ours_gamma.predict_own(
            **self._logged_kwargs(train_data), beta=self.beta, action_dist=train_action_dist
        )
        learning_gamma_value = ours_gamma.predict_own(
            **self._logged_kwargs(train_data), beta=gamma, action_dist=train_action_dist
        )
        diff_gamma_value = self.predict_policy_value(
            **self._logged_kwargs(test_data), beta=gamma, action_dist=test_action_dist
        )

        gamma_variance = abs(learning_gamma_value - diff_gamma_value)
        beta_variance = abs(learning_beta_value - ours_gamma_value)
        learning_bias = abs(learning_beta_value - learning_gamma_value)
        diff_bias = abs(ours_gamma_value - diff_gamma_value)
        bias = (learning_bias + diff_bias) / 2
        print(f"Gamma={gamma}, Gamma_variance={gamma_variance}, Beta_variance={beta_variance}, Bias={bias}")
        return -ours_gamma_value, gamma_variance, bias

    def golden_section_search(self, func, initial_guess, tol=1e-5, max_iter=20):
        """Minimize ``func(gamma)[0]`` by golden-section search on ``[a, 1]``.

        ``func`` must return ``(value, variance, bias)``. The search starts on
        ``a = max(self.beta, initial_guess)`` and stops when the two probe
        points are within ``tol`` or after ``max_iter`` iterations.

        NOTE(review): because the lower bound includes ``initial_guess``, the
        search can never return a gamma below it; confirm this is intended.

        Returns
        -------
        (best_gamma, variance_list, bias_list)
            The best probed gamma and the variance/bias of every probe, in
            evaluation order (two entries per step).
        """
        gr = (math.sqrt(5) + 1) / 2
        a = max(self.beta, initial_guess)
        b = 1

        c = b - (b - a) / gr
        d = a + (b - a) / gr

        variance_list, bias_list = [], []

        fc, fc_var, fc_bias = func(c)
        fd, fd_var, fd_bias = func(d)
        variance_list.extend([fc_var, fd_var])
        bias_list.extend([fc_bias, fd_bias])

        min_value = min(fc, fd)
        best_gamma = c if fc < fd else d

        iter_count = 0
        while abs(c - d) > tol and iter_count < max_iter:
            if fc < fd:
                # Minimum lies in [a, d]: reuse c as the new d.
                b = d
                d = c
                c = b - (b - a) / gr
                fd, fd_var, fd_bias = fc, fc_var, fc_bias
                fc, fc_var, fc_bias = func(c)
            else:
                # Minimum lies in [c, b]: reuse d as the new c.
                a = c
                c = d
                d = a + (b - a) / gr
                fc, fc_var, fc_bias = fd, fd_var, fd_bias
                fd, fd_var, fd_bias = func(d)

            variance_list.extend([fc_var, fd_var])
            bias_list.extend([fc_bias, fd_bias])

            if fc < min_value:
                min_value = fc
                best_gamma = c
            if fd < min_value:
                min_value = fd
                best_gamma = d

            iter_count += 1

        print(f"Optimized gamma after {iter_count} iterations: {best_gamma}, Minimum function value: {min_value}")
        return best_gamma, variance_list, bias_list

    def optimize_gamma(self, train_data, test_data, initial_guess: float, n_calls: int = 20):
        """Run ``golden_section_search`` with ``evaluate_gamma`` on the given split."""
        def evaluate(gamma):
            value, variance, bias = self.evaluate_gamma(gamma, train_data, test_data)
            print(f"Evaluating gamma={gamma}: Value={value}")
            return value, variance, bias

        print(f"Starting optimization with initial guess: {initial_guess}")
        optimal_gamma, variance_list, bias_list = self.golden_section_search(evaluate, initial_guess, tol=1e-5, max_iter=n_calls)
        return optimal_gamma, variance_list, bias_list

    def optimize(self, num_tries: int = 4, num_gamma: int = 4):
        """Return the averaged, extrapolated gamma estimate clipped to ``[0, 1]``.

        Parameters
        ----------
        num_tries: int
            Number of random train/test splits.
        num_gamma: int
            Golden-section iteration budget. The first try uses
            ``num_gamma + 3`` iterations and later tries use ``num_gamma``.
        """
        full_data = {
            "context": self.context,
            "action": self.action,
            "surrogate_reward": self.surrogate_reward,
            "reward": self.reward,
            "obs_list": self.obs_list,
            "s_sum": self.s_sum,
            "pscore": self.pscore,
        }
        data_size = self.context.shape[0] // 2
        half_gamma = []
        result_gamma = []
        initial_guess = min(1, self.beta + 0.3)
        half_initial_guess = min(1, self.beta + 0.3)
        n_calls = num_gamma + 3

        for _ in range(num_tries):
            train_data, test_data = self._split_dict(full_data, data_size)
            # Two independent random halvings of the training split.
            train_a_data, train_b_data = self._split_dict(train_data, data_size // 2)
            train_c_data, train_d_data = self._split_dict(train_data, data_size // 2)

            half_size_gammas = [
                self.optimize_gamma(train_data=part, test_data=test_data, initial_guess=half_initial_guess, n_calls=n_calls)[0]
                for part in (train_a_data, train_b_data, train_c_data, train_d_data)
            ]
            half_gamma.append(sum(half_size_gammas) / 4)
            half_initial_guess = half_gamma[-1]

            optimal_gamma, _, _ = self.optimize_gamma(
                train_data=train_data, test_data=test_data, initial_guess=initial_guess,
                n_calls=n_calls
            )
            # Extrapolate from the half-size to the full-size estimate and
            # clip to [beta, 1]. Previously this value was computed but never
            # recorded, so ``result_gamma`` stayed empty and the method
            # always returned 0.
            res_gam = min(1, max(self.beta, (optimal_gamma - half_gamma[-1]) + optimal_gamma))
            result_gamma.append(res_gam)

            initial_guess = optimal_gamma
            n_calls = num_gamma

        predicted_gamma = np.mean(result_gamma)
        predicted_gamma = min(1, max(0, predicted_gamma))
        print(f"Optimized gamma values: {predicted_gamma}")

        return predicted_gamma

@dataclass
class OurLearner2:
    """Off-policy learner parameterized by a neural network.

    Parameters
    -----------
    n_actions: int
        Number of actions.
        
    dim_context: int
        Number of dimensions of context vectors.
        
    s_dim: int
        Number of dimensions of surrogate reward.

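    off_policy_objective: str
        An OPE estimator used to estimate the policy gradient.
        Must be one of 'ours', 'sdr', 'bdr', 'sdr-both', or 'ours-gamma'
        ('ours' and 'ours-gamma' are accepted by the constructor but cannot currently be used for policy-gradient estimation).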

    policy_reg_param: float, default=0.0
        A hyperparameter controlling the policy regularization, i.e., :math:`\\lambda_{pol}`.

    var_reg_param: float, default=0.0
        A hyperparameter controlling the variance regularization, i.e., :math:`\\lambda_{var}`.

    hidden_layer_size: Tuple[int, ...], default = (100,)
        The i-th element specifies the size of the i-th layer.

    activation: str, default='relu'
        Activation function.
        Must be one of the following:

        - 'identity', the identity function, :math:`f(x) = x`.
        - 'logistic', the sigmoid function, :math:`f(x) = \\frac{1}{1 + \\exp(-x)}`.
        - 'tanh', the hyperbolic tangent function, :math:`f(x) = \\frac{\\exp(x) - \\exp(-x)}{\\exp(x) + \\exp(-x)}`.
        - 'relu', the rectified linear unit function, :math:`f(x) = \\max(0, x)`.
        - 'elu', the exponential linear unit function.

    solver: str, default='adam'
        Optimizer of the neural network.
        Must be one of the followings:

        - 'sgd', Stochastic Gradient Descent.
        - 'adam', Adam (Kingma and Ba 2014).
        - 'adagrad', Adagrad (Duchi et al. 2011).

    alpha: float, default=0.0001
        L2 penalty (passed to the optimizer as ``weight_decay``).

    batch_size: Union[int, str], default="auto"
        Batch size for SGD, Adagrad, and Adam.
        If "auto", the minimum of 200 and the number of samples is used.
        If integer, must be positive.

    learning_rate_init: float, default=0.0001
        Initial learning rate for SGD, Adagrad, and Adam. Must be > 0.

    max_iter: int, default=200
        Number of epochs for SGD, Adagrad, and Adam.

    shuffle: bool, default=True
        Whether to shuffle samples in SGD and Adam.

    random_state: Optional[int], default=None
        Controls the random seed.

    tol: float, default=1e-4
        Tolerance for training.
        When the training loss is not improved at least `tol' for `n_iter_no_change' consecutive iterations,
        training is stopped.

    momentum: float, default=0.9
        Momentum for SGD.
        Must be in the range of [0., 1.].

    nesterovs_momentum: bool, default=True
        Whether to use Nesterovs momentum.

    early_stopping: bool, default=False
        Whether to use early stopping for SGD, Adagrad, and Adam.
        If set to true, `validation_fraction' of training data is used as validation data,
        and training is stopped when the validation loss is not improved at least `tol' for `n_iter_no_change' consecutive iterations.

    validation_fraction: float, default=0.1
        Fraction of validation data when early stopping is used.
        Must be in the range of (0., 1.].

    beta_1: float, default=0.9
        Coefficient used for computing running average of gradient for Adam.
        Must be in the range of [0., 1.].

    beta_2: float, default=0.999
        Coefficient used for computing running average of the square of gradient for Adam.
        Must be in the range of [0., 1.].

    epsilon: float, default=1e-8
        Term for numerical stability in Adam.

    n_iter_no_change: int, default=10
        Maximum number of not improving epochs when early stopping is used.

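    q_func_estimator: default=None
        Fitted estimator of the q function, i.e., :math:`\\hat{q}(x,a)`.
        Called as ``predict(context=...)``; the result is indexed as ``[sample, action]``.

    q_func_with_s_estimator: default=None
        Fitted estimator of the q function given surrogate rewards, i.e., :math:`\\hat{q}(x,a,s)`.
        Called as ``predict(context=..., surrogate_reward=...)``; the result is indexed as ``[sample, action]``.

    f_func_estimator: default=None
        Fitted estimator :math:`\\hat{f}(x,a)` used as the baseline for ``s_sum`` in the 'sdr-both' objective.
        Called as ``predict(context=...)``; the result is indexed as ``[sample, action]``.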
    """
    n_actions: int
    len_list: int = 1   
    dim_context: Optional[int] = None
    s_dim: Optional[int] = None
    off_policy_objective: Optional[str] = None
    policy_reg_param: float = 0.0
    var_reg_param: float = 0.0
    hidden_layer_size: Tuple[int, ...] = (100,)
    activation: str = "relu"
    solver: str = "adam"
    alpha: float = 0.0001
    batch_size: Union[int, str] = "auto"
    learning_rate_init: float = 0.0001
    max_iter: int = 200
    shuffle: bool = True
    random_state: Optional[int] = None
    tol: float = 1e-4
    momentum: float = 0.9
    nesterovs_momentum: bool = True
    early_stopping: bool = False
    validation_fraction: float = 0.1
    beta_1: float = 0.9
    beta_2: float = 0.999
    epsilon: float = 1e-8
    n_iter_no_change: int = 10
    q_func_estimator: Optional[torch.nn.Module] = field(default=None, repr=False)
    q_func_with_s_estimator: Optional[torch.nn.Module] = field(default=None, repr=False)
    f_func_estimator: Optional[torch.nn.Module] = field(default=None, repr=False)
    
    def __post_init__(self) -> None:
        """Validate the hyperparameters and build the policy network ``self.nn_model``.

        The network is an ``nn.Sequential`` that maps a ``dim_context``-dimensional
        context through one ``nn.Linear`` + activation pair per entry of
        ``hidden_layer_size``, then a final ``nn.Linear`` to ``n_actions`` outputs,
        followed by a softmax over actions (``dim=1``).

        When ``random_state`` is given, ``self.random_`` is set and the global torch
        seed is fixed *before* the network weights are initialised, so the initial
        weights are reproducible.

        Raises
        ------
        ValueError
            If a hyperparameter has an unsupported value. ``check_scalar`` may also
            raise ``TypeError`` for scalars of the wrong type.
        """
        check_scalar(self.n_actions, "n_actions", int, min_val=1)
        check_scalar(self.dim_context, "dim_context", int, min_val=1)
        check_scalar(self.s_dim, "s_dim", int, min_val=1)

        supported_objectives = ("ours", "sdr", "bdr", "sdr-both", "ours-gamma")
        if self.off_policy_objective not in supported_objectives:
            raise ValueError(
                "`off_policy_objective` must be one of 'ours', 'sdr', 'bdr', 'sdr-both', "
                f"or 'ours-gamma', but {self.off_policy_objective} is given"
            )

        check_scalar(
            self.policy_reg_param,
            "policy_reg_param",
            (int, float),
            min_val=0.0,
        )
        check_scalar(
            self.var_reg_param,
            "var_reg_param",
            (int, float),
            min_val=0.0,
        )

        if not isinstance(self.hidden_layer_size, tuple) or any(
            not isinstance(h, int) or h <= 0 for h in self.hidden_layer_size
        ):
            raise ValueError(
                f"`hidden_layer_size` must be a tuple of positive integers, but {self.hidden_layer_size} is given"
            )

        if self.solver not in ("adagrad", "sgd", "adam"):
            raise ValueError(
                f"`solver` must be one of 'adam', 'adagrad', or 'sgd', but {self.solver} is given"
            )

        check_scalar(self.alpha, "alpha", float, min_val=0.0)

        if self.batch_size != "auto" and (
            not isinstance(self.batch_size, int) or self.batch_size <= 0
        ):
            raise ValueError(
                f"`batch_size` must be a positive integer or 'auto', but {self.batch_size} is given"
            )

        check_scalar(self.learning_rate_init, "learning_rate_init", float)
        if self.learning_rate_init <= 0.0:
            raise ValueError(
                f"`learning_rate_init`= {self.learning_rate_init}, must be > 0.0"
            )

        check_scalar(self.max_iter, "max_iter", int, min_val=1)

        if not isinstance(self.shuffle, bool):
            raise ValueError(f"`shuffle` must be a bool, but {self.shuffle} is given")

        check_scalar(self.tol, "tol", float)
        if self.tol <= 0.0:
            raise ValueError(f"`tol`= {self.tol}, must be > 0.0")

        check_scalar(self.momentum, "momentum", float, min_val=0.0, max_val=1.0)

        if not isinstance(self.nesterovs_momentum, bool):
            raise ValueError(
                f"`nesterovs_momentum` must be a bool, but {self.nesterovs_momentum} is given"
            )

        if not isinstance(self.early_stopping, bool):
            raise ValueError(
                f"`early_stopping` must be a bool, but {self.early_stopping} is given"
            )

        check_scalar(
            self.validation_fraction, "validation_fraction", float, max_val=1.0
        )
        if self.validation_fraction <= 0.0:
            raise ValueError(
                f"`validation_fraction`= {self.validation_fraction}, must be > 0.0"
            )

        check_scalar(self.beta_1, "beta_1", float, min_val=0.0, max_val=1.0)
        check_scalar(self.beta_2, "beta_2", float, min_val=0.0, max_val=1.0)
        check_scalar(self.epsilon, "epsilon", float, min_val=0.0)
        check_scalar(self.n_iter_no_change, "n_iter_no_change", int, min_val=1)

        # Seed before any nn.Linear is created so weight initialisation is reproducible.
        if self.random_state is not None:
            self.random_ = check_random_state(self.random_state)
            torch.manual_seed(self.random_state)

        activation_layers = {
            "identity": nn.Identity,
            "logistic": nn.Sigmoid,
            "tanh": nn.Tanh,
            "relu": nn.ReLU,
            "elu": nn.ELU,
        }
        activation_layer = activation_layers.get(self.activation)
        if activation_layer is None:
            raise ValueError(
                "`activation` must be one of 'identity', 'logistic', 'tanh', 'relu', or 'elu'"
                f", but {self.activation} is given"
            )

        layer_list = []
        input_size = self.dim_context
        for i, h in enumerate(self.hidden_layer_size):
            layer_list.append((f"l{i}", nn.Linear(input_size, h)))
            layer_list.append((f"a{i}", activation_layer()))
            input_size = h
        layer_list.append(("output", nn.Linear(input_size, self.n_actions)))
        # Softmax over the action axis turns the outputs into a policy pi(a|x).
        layer_list.append(("softmax", nn.Softmax(dim=1)))

        self.nn_model = nn.Sequential(OrderedDict(layer_list))
        

    def _create_train_data_for_opl(
        self,
        context: np.ndarray,
        action: np.ndarray,
        surrogate_reward: np.ndarray,
        reward: np.ndarray,
        obs_list: np.ndarray,
        pscore: np.ndarray,
        s_sum: Optional[np.ndarray] = None,
        # pi_b: np.ndarray,
        # f_x_a: np.ndarray,
        # position: np.ndarray,
        **kwargs,
    ) -> Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.DataLoader]]:
        """Create training data for off-policy learning.

        Parameters
        -----------
        context: array-like, shape (n_rounds, dim_context)
            Context vectors observed for each data, i.e., :math:`x_i`.

        action: array-like, shape (n_rounds,)
            Actions sampled by the logging/behavior policy for each data in logged bandit data, i.e., :math:`a_i`.

        surrogate_reward: array-like, shape (n_rounds,)
            Surrogate rewards observed for each data in logged bandit data, i.e., :math:`s_i`.

        reward: array-like, shape (n_rounds,)
            Rewards observed for each data in logged bandit data, i.e., :math:`r_i`.

        pscore: array-like, shape (n_rounds,), default=None
            Action choice probabilities of the logging/behavior policy (propensity scores), i.e., :math:`\\pi_b(a_i|x_i)`.

        position: array-like, shape (n_rounds,), default=None
            Indices to differentiate positions in a recommendation interface where the actions are presented.
            If None, a learner assumes that only a single action is chosen for each data.

        Returns
        --------
        (training_data_loader, validation_data_loader): Tuple[DataLoader, Optional[DataLoader]]
            Training and validation data loaders in PyTorch

        """
        if self.batch_size == "auto":
            batch_size_ = min(200, context.shape[0])
        else:
            check_scalar(self.batch_size, "batch_size", int, min_val=1)
            batch_size_ = self.batch_size
        context = context.astype('float32')
        if not np.issubdtype(s_sum.dtype, np.number):
            raise ValueError("s_sum contains non-numeric elements")
        s_sum = s_sum.astype('float32')
        if s_sum is None:
            dataset = NNPolicyDataset(
                torch.from_numpy(context).float(),
                torch.from_numpy(action).long(),
                torch.from_numpy(surrogate_reward).float(),
                torch.from_numpy(reward).float(),
                torch.from_numpy(obs_list).float(),
                torch.from_numpy(pscore).float(),)
        else:
            dataset = NNPolicyDataset_with_fs(
                torch.from_numpy(context).float(),
                torch.from_numpy(action).long(),
                torch.from_numpy(surrogate_reward).float(),
                torch.from_numpy(s_sum).float(),
                torch.from_numpy(reward).float(),
                torch.from_numpy(obs_list).float(),
                torch.from_numpy(pscore).float(),)


        if self.early_stopping:
            if context.shape[0] <= 1:
                raise ValueError(
                    f"the number of samples is too small ({context.shape[0]}) to create validation data"
                )

            validation_size = max(int(context.shape[0] * self.validation_fraction), 1)
            training_size = context.shape[0] - validation_size
            training_dataset, validation_dataset = torch.utils.data.random_split(
                dataset, [training_size, validation_size]
            )
            training_data_loader = torch.utils.data.DataLoader(
                training_dataset,
                batch_size=batch_size_,
                shuffle=self.shuffle,
            )
            validation_data_loader = torch.utils.data.DataLoader(
                validation_dataset,
                batch_size=batch_size_,
                shuffle=self.shuffle,
            )

            return training_data_loader, validation_data_loader

        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size_,
            shuffle=self.shuffle,
        )

        return data_loader, None

    def fit(
        self,
        context: np.ndarray,
        action: np.ndarray,
        surrogate_reward: np.ndarray,
        reward: np.ndarray,
        obs_list: np.ndarray,
        p_o: float,
        s_sum: Optional[np.ndarray] = None,
        pscore: Optional[np.ndarray] = None,
        beta: Optional[float] = None,
    ) -> None:
        """Fits an offline bandit policy on the given logged bandit data.

        Note
        ----------
        Given the training data :math:`\\mathcal{D}`, this policy maximizes the following objective function:

        .. math::

            \\hat{V}(\\pi_\\theta; \\mathcal{D}) - \\alpha \\Omega(\\theta)

        where :math:`\\hat{V}` is an OPE estimator and :math:`\\alpha \\Omega(\\theta)` is a regularization term.
        Concretely, each mini-batch minimizes
        ``-mean(g) + policy_reg_param * log(mean(iw)) + var_reg_param * var(g)``,
        where ``g`` is the output of ``_estimate_policy_gradient``.

        Training stops early once the (training or, with ``early_stopping``, validation)
        mini-batch loss has failed to decrease by at least ``tol`` for
        ``n_iter_no_change`` consecutive mini-batches.

        Parameters
        -----------
        context: array-like, shape (n_rounds, dim_context)
            Context vectors observed for each data, i.e., :math:`x_i`.

        action: array-like, shape (n_rounds,)
            Actions sampled by the logging/behavior policy for each data in logged bandit data, i.e., :math:`a_i`.

        surrogate_reward: array-like
            Surrogate rewards observed for each data in logged bandit data, i.e., :math:`s_i`.

        reward: array-like, shape (n_rounds,)
            Rewards observed for each data in logged bandit data, i.e., :math:`r_i`.

        obs_list: array-like, shape (n_rounds,)
            Per-sample reward-observation weights; the estimators reweight them as ``obs_list / p_o``.

        p_o: float
            Observation probability forwarded to ``_estimate_policy_gradient``.

        s_sum: array-like, default=None
            Extra per-sample target; required (and only used) when ``off_policy_objective == 'sdr-both'``.

        pscore: array-like, shape (n_rounds,), default=None
            Action choice probabilities of the logging/behavior policy (propensity scores), i.e., :math:`\\pi_b(a_i|x_i)`.
            If None, a uniform logging policy ``1 / n_actions`` is assumed.

        beta: float, default=None
            Mixing weight forwarded to ``_estimate_policy_gradient`` for the 'sdr-both' objective.

        Raises
        ------
        ValueError
            If ``context`` has the wrong number of columns, or ``s_sum`` is missing for 'sdr-both'.
        NotImplementedError
            If ``solver`` is not supported.
        """
        if context.shape[1] != self.dim_context:
            raise ValueError(
                "Expected `context.shape[1] == self.dim_context`, but found it False"
            )
        # Only 'sdr-both' consumes s_sum; other objectives build the dataset without it.
        use_s_sum = self.off_policy_objective == "sdr-both"
        if use_s_sum and s_sum is None:
            raise ValueError(
                "`s_sum` must be given when `off_policy_objective` is 'sdr-both'"
            )
        if pscore is None:
            pscore = np.ones_like(action) / self.n_actions

        optimizer = self._make_policy_optimizer()
        training_data_loader, validation_data_loader = self._create_train_data_for_opl(
            context=context,
            action=action,
            surrogate_reward=surrogate_reward,
            reward=reward,
            obs_list=obs_list,
            pscore=pscore,
            s_sum=s_sum if use_s_sum else None,
        )

        n_not_improving_training = 0
        previous_training_loss = None
        n_not_improving_validation = 0
        previous_validation_loss = None
        stop_training = False
        for _ in tqdm(np.arange(self.max_iter), desc="policy learning"):
            self.nn_model.train()
            for batch in training_data_loader:
                optimizer.zero_grad()
                loss = self._minibatch_policy_loss(
                    batch, p_o=p_o, beta=beta, use_s_sum=use_s_sum
                )
                loss.backward()
                optimizer.step()

                loss_value = loss.item()
                if previous_training_loss is not None:
                    # The loss is minimized, so "improving" means decreasing by at least `tol`.
                    if previous_training_loss - loss_value < self.tol:
                        n_not_improving_training += 1
                    else:
                        n_not_improving_training = 0
                previous_training_loss = loss_value
                if n_not_improving_training >= self.n_iter_no_change:
                    stop_training = True
                    break

            if not stop_training and self.early_stopping:
                self.nn_model.eval()
                # Validation only monitors the loss, so no autograd graph is needed.
                with torch.no_grad():
                    for batch in validation_data_loader:
                        loss_value = self._minibatch_policy_loss(
                            batch, p_o=p_o, beta=beta, use_s_sum=use_s_sum
                        ).item()
                        if previous_validation_loss is not None:
                            if previous_validation_loss - loss_value < self.tol:
                                n_not_improving_validation += 1
                            else:
                                n_not_improving_validation = 0
                        previous_validation_loss = loss_value
                        if n_not_improving_validation >= self.n_iter_no_change:
                            stop_training = True
                            break

            if stop_training:
                break

    def _make_policy_optimizer(self) -> optim.Optimizer:
        """Build the torch optimizer selected by ``self.solver`` for ``self.nn_model``."""
        if self.solver == "sgd":
            return optim.SGD(
                self.nn_model.parameters(),
                lr=self.learning_rate_init,
                momentum=self.momentum,
                weight_decay=self.alpha,
                nesterov=self.nesterovs_momentum,
            )
        if self.solver == "adagrad":
            return optim.Adagrad(
                self.nn_model.parameters(),
                lr=self.learning_rate_init,
                eps=self.epsilon,
                weight_decay=self.alpha,
            )
        if self.solver == "adam":
            return optim.Adam(
                self.nn_model.parameters(),
                lr=self.learning_rate_init,
                betas=(self.beta_1, self.beta_2),
                eps=self.epsilon,
                weight_decay=self.alpha,
            )
        raise NotImplementedError(
            "`solver` must be one of 'adam', 'adagrad', or 'sgd'"
        )

    def _minibatch_policy_loss(
        self,
        batch,
        p_o: float,
        beta: Optional[float],
        use_s_sum: bool,
    ) -> torch.Tensor:
        """Return the regularized negative estimated policy value for one mini-batch.

        ``batch`` is one item of the loaders built by ``_create_train_data_for_opl``:
        ``(x, a, s, s_sum, r, o, p)`` when ``use_s_sum`` is true, else ``(x, a, s, r, o, p)``.
        """
        if use_s_sum:
            x, a, s, s_sum_batch, r, o, p = batch
            extra_kwargs = {"s_sum": s_sum_batch, "beta": beta}
        else:
            x, a, s, r, o, p = batch
            extra_kwargs = {}
        # The network outputs (batch, n_actions); add a trailing len_list axis of size 1.
        pi = self.nn_model(x).unsqueeze(-1)
        policy_grad_arr = self._estimate_policy_gradient(
            context=x,
            reward=r,
            action=a,
            surrogate_reward=s,
            pscore=p,
            obs_list=o,
            p_o=p_o,
            action_dist=pi,
            **extra_kwargs,
        )
        policy_constraint = self._estimate_policy_constraint(
            action=a,
            pscore=p,
            action_dist=pi,
        )
        loss = -policy_grad_arr.mean()
        loss += self.policy_reg_param * policy_constraint
        loss += self.var_reg_param * torch.var(policy_grad_arr)
        return loss

    #####修正点 - 一番修正必要
    def _estimate_policy_gradient(
        self,
        context: torch.Tensor,
        action: torch.Tensor,
        surrogate_reward: torch.Tensor,
        reward: torch.Tensor,
        obs_list: torch.Tensor,
        pscore: torch.Tensor,
        p_o: float,
        action_dist: torch.Tensor,
        s_sum: Optional[torch.Tensor] = None,
        beta: Optional[float] = None,
    ) -> torch.Tensor:
        """Estimate the policy gradient.

        Parameters
        -----------
        context: array-like, shape (batch_size, dim_context)
            Context vectors observed for each data, i.e., :math:`x_i`.

        action: array-like, shape (batch_size,)
            Actions sampled by the logging/behavior policy for each data in logged bandit data, i.e., :math:`a_i`.

        surrogate_reward: array-like, shape (batch_size,)
            Surrogate rewards observed for each data in logged bandit data, i.e., :math:`s_i`.
        
        reward: array-like, shape (batch_size,)
            Rewards observed for each data in logged bandit data, i.e., :math:`r_i`.

        pscore: array-like, shape (batch_size,), default=None
            Action choice probabilities of the logging/behavior policy (propensity scores), i.e., :math:`\\pi_b(a_i|x_i)`.

        action_dist: array-like, shape (batch_size, n_actions, len_list)
            Action choice probabilities of the evaluation policy (can be deterministic), i.e., :math:`\\pi_e(a_i|x_i)`.

        Returns
        ----------
        estimated_policy_grad_arr: array-like, shape (batch_size,)
            Rewards of each data estimated by an OPE estimator.

        """
        current_pi = action_dist[:, :, 0].detach()
        log_prob = torch.log(action_dist[:, :, 0])
        idx_tensor = torch.arange(action.shape[0], dtype=torch.long)
        # obs_context = context[obs_list==1]
        # obs_action = action[obs_list==1]
        # obs_reward = reward[obs_list==1]
        # obs_surrogate_reward = surrogate_reward[obs_list==1]
        # obs_action_dist = action_dist[obs_list==1]
        # obs_pscore = pscore[obs_list==1]
        # obs_idx_tensor = torch.arange(obs_action.shape[0], dtype=torch.long)
        # obs_=obs_list[obs_list==1]

        if self.off_policy_objective == "ours":
            raise ValueError(
                "ours is not ready"
            )
            q_hat = self.q_func_estimator.predict(
                context=context,
            )
            q_hat_factual = q_hat[idx_tensor, action]
            iw = current_pi[idx_tensor, action] / pscore
            pw = obs_list/p_o
            estimated_policy_grad_arr = pw * iw * (reward - q_hat_factual)
            estimated_policy_grad_arr *= log_prob[idx_tensor, action]
            estimated_policy_grad_arr += torch.sum(q_hat * current_pi * log_prob, dim=1)
        elif self.off_policy_objective == "sdr":
            q_hat = self.q_func_estimator.predict(
                context=context,
            )
            q_hat_with_s = self.q_func_with_s_estimator.predict(
                context=context,
                surrogate_reward = surrogate_reward,
            )
            q_hat_factual = q_hat[idx_tensor, action]
            q_hat_factual_with_s = q_hat_with_s[idx_tensor, action]
            iw = current_pi[idx_tensor, action] / pscore
            if p_o!=0.0:
                pw = obs_list/p_o
                obs_estimated_policy_grad_arr = pw * iw * (reward - q_hat_factual_with_s)
                obs_estimated_policy_grad_arr *= log_prob[idx_tensor, action]
                estimated_policy_grad_arr = iw * (q_hat_factual_with_s - q_hat_factual)
                estimated_policy_grad_arr *= log_prob[idx_tensor, action]
                estimated_policy_grad_arr += obs_estimated_policy_grad_arr
                estimated_policy_grad_arr += torch.sum(q_hat * current_pi * log_prob, dim=1)
            else:
                estimated_policy_grad_arr = iw * (q_hat_factual_with_s - q_hat_factual)
                estimated_policy_grad_arr *= log_prob[idx_tensor, action]
                estimated_policy_grad_arr += torch.sum(q_hat * current_pi * log_prob, dim=1)
        
        elif self.off_policy_objective == "sdr-both":
            if p_o!=0.0 and beta<1 and beta>0:
                q_hat = self.q_func_estimator.predict(
                    context=context,
                )
                q_hat_with_s = self.q_func_with_s_estimator.predict(
                    context=context,
                    surrogate_reward = surrogate_reward,
                )
                q_hat_factual = q_hat[idx_tensor, action]
                q_hat_factual_with_s = q_hat_with_s[idx_tensor, action]
                iw = current_pi[idx_tensor, action] / pscore
                pw = obs_list/p_o
                obs_estimated_policy_grad_arr = pw * iw * (reward - q_hat_factual_with_s)
                obs_estimated_policy_grad_arr *= log_prob[idx_tensor, action]
                estimated_policy_grad_arr = iw * (q_hat_factual_with_s - q_hat_factual)
                estimated_policy_grad_arr *= log_prob[idx_tensor, action]
                estimated_policy_grad_arr += obs_estimated_policy_grad_arr
                estimated_policy_grad_arr += torch.sum(q_hat * current_pi * log_prob, dim=1)
                
                f_hat = self.f_func_estimator.predict(
                    context=context,
                )
                f_hat_factual = f_hat[idx_tensor, action]
                side_estimated_policy_grad_arr = iw * (s_sum-f_hat_factual)
                side_estimated_policy_grad_arr *= log_prob[idx_tensor, action]
                side_estimated_policy_grad_arr += torch.sum(f_hat * current_pi * log_prob, dim=1)
                
                estimated_policy_grad_arr = beta*side_estimated_policy_grad_arr + (1-beta)*estimated_policy_grad_arr
            elif p_o!=0.0 and beta==0:
                q_hat = self.q_func_estimator.predict(
                    context=context,
                )
                q_hat_with_s = self.q_func_with_s_estimator.predict(
                    context=context,
                    surrogate_reward = surrogate_reward,
                )
                q_hat_factual = q_hat[idx_tensor, action]
                q_hat_factual_with_s = q_hat_with_s[idx_tensor, action]
                iw = current_pi[idx_tensor, action] / pscore
                pw = obs_list/p_o
                obs_estimated_policy_grad_arr = pw * iw * (reward - q_hat_factual_with_s)
                obs_estimated_policy_grad_arr *= log_prob[idx_tensor, action]
                estimated_policy_grad_arr = iw * (q_hat_factual_with_s - q_hat_factual)
                estimated_policy_grad_arr *= log_prob[idx_tensor, action]
                estimated_policy_grad_arr += obs_estimated_policy_grad_arr
                estimated_policy_grad_arr += torch.sum(q_hat * current_pi * log_prob, dim=1)
            else:
                iw = current_pi[idx_tensor, action] / pscore
                f_hat = self.f_func_estimator.predict(
                    context=context,
                )
                f_hat_factual = f_hat[idx_tensor, action]
                side_estimated_policy_grad_arr = iw * (s_sum-f_hat_factual)
                side_estimated_policy_grad_arr *= log_prob[idx_tensor, action]
                side_estimated_policy_grad_arr += torch.sum(f_hat * current_pi * log_prob, dim=1)
                estimated_policy_grad_arr = side_estimated_policy_grad_arr
            
        elif self.off_policy_objective == "bdr":
            q_hat = self.q_func_estimator.predict(
                context=context,
            )
            q_hat_factual = q_hat[idx_tensor, action]
            iw = current_pi[idx_tensor, action] / pscore
            pw = obs_list/p_o
            estimated_policy_grad_arr = pw * iw * (reward - q_hat_factual)
            estimated_policy_grad_arr *= log_prob[idx_tensor, action]
            estimated_policy_grad_arr += torch.sum(q_hat * current_pi * log_prob, dim=1)
            
        return estimated_policy_grad_arr

    def _estimate_policy_constraint(
        self,
        action: torch.Tensor,
        pscore: torch.Tensor,
        action_dist: torch.Tensor,
    ) -> torch.Tensor:
        """Return the log of the mean importance weight of the evaluation policy.

        Parameters
        -----------
        action: Tensor, shape (n_rounds,)
            Actions chosen by the logging/behavior policy, i.e., :math:`a_i`.

        pscore: Tensor, shape (n_rounds,)
            Propensity scores of the logging/behavior policy, i.e., :math:`\\pi_b(a_i|x_i)`.

        action_dist: Tensor, shape (n_rounds, n_actions, len_list)
            Action choice probabilities of the evaluation policy, i.e., :math:`\\pi_e(a|x_i)`.
            Only the first slot of the last axis is read.

        Returns
        -----------
        constraint: Tensor, shape ()
            :math:`\\log \\frac{1}{n} \\sum_i \\pi_e(a_i|x_i) / \\pi_b(a_i|x_i)`.

        """
        rows = torch.arange(action.shape[0], dtype=torch.long)
        chosen_prob = action_dist[rows, action, 0]
        mean_weight = (chosen_prob / pscore).mean()
        return torch.log(mean_weight)

    def predict_dist(self, context: np.ndarray, temp: float = 1.0) -> np.ndarray:
        """Return the policy's action distribution reshaped by a softmax temperature.

        Per ``predict_proba``, ``nn_model`` already ends in a softmax, so its outputs
        are probabilities. The temperature is therefore applied to the
        log-probabilities, i.e. :math:`\\pi_{temp}(a|x) \\propto \\pi_\\theta(a|x)^{1/temp}`.
        With ``temp=1`` this reproduces the ``predict_proba`` distribution up to
        floating-point rounding. Applying a second softmax directly to probabilities
        would instead pull every distribution toward uniform.

        Parameters
        -----------
        context: array-like, shape (n_rounds_of_new_data, dim_context)
            Context vectors for new data.
        temp: float, default=1.0
            Positive temperature. Higher temperatures make the distribution more uniform.
            Lower temperatures concentrate it on the most likely actions.

        Returns
        -----------
        action_dist: array-like, shape (n_rounds_of_new_data, n_actions, 1)
            Temperature-adjusted action choice probabilities.

        Raises
        -----------
        ValueError
            If ``context`` does not have ``dim_context`` columns or ``temp`` is not positive.

        """
        check_array(array=context, name="context", expected_dim=2)
        if context.shape[1] != self.dim_context:
            raise ValueError(
                "Expected `context.shape[1] == self.dim_context`, but found it False"
            )
        # `not temp > 0` also rejects NaN.
        if not temp > 0.0:
            raise ValueError(f"`temp`= {temp}, must be > 0.0")

        self.nn_model.eval()
        x = torch.from_numpy(context).float()
        probs = self.nn_model(x).detach().numpy()
        n = context.shape[0]

        # Clip before taking the log so probabilities that underflowed to 0 stay finite.
        log_probs = np.log(np.clip(probs, np.finfo(probs.dtype).tiny, None))
        action_dist = scipy.special.softmax(log_probs / temp, axis=1)
        return action_dist.reshape(n, self.n_actions, 1)
    
    def predict(self, context: np.ndarray) -> np.ndarray:
        """Return a greedy one-hot action choice for each row of new data.

        Note
        --------
        Because every row is handled independently, the returned action set can contain
        duplicate items. For a non-repetitive action set, use the `sample_action` method.

        Parameters
        -----------
        context: array-like, shape (n_rounds_of_new_data, dim_context)
            Context vectors for new data.

        Returns
        -----------
        action_dist: array-like, shape (n_rounds_of_new_data, n_actions, 1)
            One-hot encoding of the arg-max action of the network output for each row.

        """
        check_array(array=context, name="context", expected_dim=2)
        if context.shape[1] != self.dim_context:
            raise ValueError(
                "Expected `context.shape[1] == self.dim_context`, but found it False"
            )

        self.nn_model.eval()
        scores = self.nn_model(torch.from_numpy(context).float()).detach().numpy()
        best_action = scores.argmax(axis=1)

        n_rounds = context.shape[0]
        one_hot = np.zeros((n_rounds, self.n_actions, 1))
        one_hot[np.arange(n_rounds), best_action, 0] = 1
        return one_hot
    
    def predict_own(
        self,
        context: np.ndarray,
        action: np.ndarray,
        surrogate_reward: np.ndarray,
        reward: np.ndarray,
        obs_list: np.ndarray,
        pscore: np.ndarray,
        p_o: float,
        beta: float,
        s_sum: Optional[np.ndarray] = None,
        action_dist: Optional[np.ndarray] = None,
    ) -> float:
        """Estimate the policy value of ``action_dist`` from logged data.

        The estimator is selected by ``p_o`` and ``beta``:

        - ``p_o != 0`` and ``0 < beta < 1``: ``beta * side + (1 - beta) * dr``.
        - ``p_o != 0`` and ``beta == 0``: ``dr`` only. ``s_sum`` and ``f_func_estimator``
          are not used.
        - any other case: ``side`` only, which requires ``s_sum``.

        Per round :math:`i`, the terms are defined as follows. Let
        :math:`w_i = \\pi(a_i|x_i) / \\pi_b(a_i|x_i)` and :math:`o_i` = ``obs_list[i]``.

        - ``dr`` = :math:`(o_i / p_o) w_i (r_i - \\hat{q}(x_i,a_i,s_i)) + w_i (\\hat{q}(x_i,a_i,s_i) - \\hat{q}(x_i,a_i)) + \\sum_a \\pi(a|x_i) \\hat{q}(x_i,a)`
        - ``side`` = :math:`w_i (S_i - \\hat{f}(x_i,a_i)) + \\sum_a \\pi(a|x_i) \\hat{f}(x_i,a)`,
          where :math:`S_i` = ``s_sum[i]``.

        Parameters
        -----------
        context: array-like, shape (n_rounds, dim_context)
            Context vectors, i.e., :math:`x_i`.
        action: array-like, shape (n_rounds,)
            Actions chosen by the logging/behavior policy, i.e., :math:`a_i`.
        surrogate_reward: array-like
            Surrogate rewards. They are passed unchanged to ``q_func_with_s_estimator.predict``.
        reward: array-like, shape (n_rounds,)
            Rewards, i.e., :math:`r_i`. Rounds with ``obs_list == 0`` do not contribute.
        obs_list: array-like, shape (n_rounds,)
            Reward-observation indicator :math:`o_i`.
        pscore: array-like, shape (n_rounds,)
            Propensity scores of the logging policy, i.e., :math:`\\pi_b(a_i|x_i)`.
        p_o: float
            Observation probability used to reweight observed rounds (``obs_list / p_o``).
        beta: float
            Mixing weight between the ``side`` and ``dr`` terms.
        s_sum: array-like, shape (n_rounds,), default=None
            Target of the ``side`` term. Required unless ``p_o != 0`` and ``beta == 0``.
        action_dist: array-like, shape (n_rounds, n_actions, 1), default=None
            Action choice probabilities of the policy being evaluated. The trailing axis
            is squeezed. If None, ``self.predict_proba(context)`` is used.

        Returns
        -----------
        policy_value: float
            Mean of the per-round estimates.

        Raises
        -----------
        ValueError
            If the ``side`` term is needed but ``s_sum`` is None.

        """
        if action_dist is None:
            # Default to evaluating the current policy itself.
            action_dist = self.predict_proba(context)

        context = torch.tensor(context, dtype=torch.float32)
        action = torch.tensor(action, dtype=torch.long)
        surrogate_reward = torch.tensor(surrogate_reward, dtype=torch.float32)
        reward = torch.tensor(reward, dtype=torch.float32)
        obs_list = torch.tensor(obs_list, dtype=torch.float32)
        pscore = torch.tensor(pscore, dtype=torch.float32)
        # s_sum is optional: the dr-only branch never reads it.
        if s_sum is not None:
            s_sum = torch.tensor(s_sum, dtype=torch.float32)
        action_dist = torch.squeeze(torch.tensor(action_dist, dtype=torch.float32), -1)

        idx_tensor = torch.arange(action.shape[0], dtype=torch.long)
        iw = action_dist[idx_tensor, action] / pscore

        def dr_value() -> torch.Tensor:
            # Doubly-robust term built from q(x,a) and q(x,a,s) with observation reweighting.
            q_hat = self.q_func_estimator.predict(context)
            q_hat_with_s = self.q_func_with_s_estimator.predict(context, surrogate_reward)
            q_hat_factual = q_hat[idx_tensor, action]
            q_hat_factual_with_s = q_hat_with_s[idx_tensor, action]
            pw = obs_list / p_o
            obs_policy_val = pw * iw * (reward - q_hat_factual_with_s)
            value = iw * (q_hat_factual_with_s - q_hat_factual)
            value = value + obs_policy_val
            return value + torch.sum(q_hat * action_dist, dim=1)

        def side_value() -> torch.Tensor:
            # Doubly-robust term on s_sum using the f(x,a) estimator.
            if s_sum is None:
                raise ValueError(
                    "`s_sum` must be given unless `p_o != 0` and `beta == 0`"
                )
            f_hat = self.f_func_estimator.predict(context)
            f_hat_factual = f_hat[idx_tensor, action]
            value = iw * (s_sum - f_hat_factual)
            return value + torch.sum(f_hat * action_dist, dim=1)

        if p_o != 0.0 and 0 < beta < 1:
            dr = dr_value()
            policy_val = beta * side_value() + (1 - beta) * dr
        elif p_o != 0.0 and beta == 0:
            policy_val = dr_value()
        else:
            policy_val = side_value()

        return policy_val.mean().item()
        

    def sample_action(
        self,
        context: np.ndarray,
        tau: Union[int, float] = 1.0,
        random_state: Optional[int] = None,
    ) -> np.ndarray:
        """Sample a ranking of (non-repetitive) actions from the Plackett-Luce ranking distribution.

        Note
        --------
        This `sample_action` method samples a **non-repetitive** ranking of actions for new data
        :math:`x \\in \\mathcal{X}` via the so-called "Gumbel Softmax trick" as follows.

        .. math::

            s (x,a) = \\hat{f}(x,a) / \\tau + \\gamma_{x,a}, \\quad \\gamma_{x,a} \\sim \\mathrm{Gumbel}(0,1)

        :math:`\\tau` is a temperature hyperparameter.
        :math:`\\hat{f}(x,a)` is the score returned by the `predict_proba` method.
        When `len_list > 1`, the scores at different positions are averaged to form :math:`\\hat{f}(x,a)`.
        :math:`\\gamma_{x,a}` is a random variable sampled from the Gumbel distribution.
        Sorting the actions by :math:`s (x,a)` for each context efficiently samples a ranking from
        the Plackett-Luce ranking distribution.

        Parameters
        ----------------
        context: array-like, shape (n_rounds_of_new_data, dim_context)
            Context vectors for new data.

        tau: int or float, default=1.0
            A positive temperature parameter that controls the randomness of the action choice
            by scaling the scores before the Gumbel noise is added.
            As :math:`\\tau \\rightarrow \\infty`, the algorithm will select arms uniformly at random.
            ``tau=0`` would divide by zero and is rejected.

        random_state: int, default=None
            Controls the random seed in sampling actions.

        Returns
        -----------
        sampled_action: array-like, shape (n_rounds_of_new_data, n_actions, len_list)
            Ranking of actions sampled from the Plackett-Luce ranking distribution via the Gumbel softmax trick.

        Raises
        -----------
        ValueError
            If ``context`` does not have ``dim_context`` columns or ``tau`` is not positive.

        """
        check_array(array=context, name="context", expected_dim=2)
        check_scalar(tau, name="tau", target_type=(int, float), min_val=0)
        # check_scalar's min_val is inclusive; tau == 0 would yield inf/nan scores.
        if tau <= 0:
            raise ValueError(f"`tau`= {tau}, must be > 0.0")
        if context.shape[1] != self.dim_context:
            raise ValueError(
                "Expected `context.shape[1] == self.dim_context`, but found it False"
            )

        n = context.shape[0]
        rows = np.arange(n)
        random_ = check_random_state(random_state)
        sampled_action = np.zeros((n, self.n_actions, self.len_list))
        scores = self.predict_proba(context=context).mean(2) / tau
        scores += random_.gumbel(size=scores.shape)
        ranking = np.argsort(-scores, axis=1)
        for p in np.arange(self.len_list):
            sampled_action[rows, ranking[:, p], p] = 1
        return sampled_action

    def predict_proba(
        self,
        context: np.ndarray,
    ) -> np.ndarray:
        """Return action choice probabilities for new data.

        Note
        --------
        The policy is a multi-layer perceptron (MLP) whose last layer is the softmax function,
        so it is a stochastic policy:

        .. math::

            \\pi_\\theta (a \\mid x) = \\frac{\\exp(f_\\theta(x, a))}{\\sum_{a' \\in \\mathcal{A}} \\exp(f_\\theta(x, a'))}

        where :math:`f_\\theta(x, a)` is the MLP with parameter :math:`\\theta`.

        Parameters
        ----------------
        context: array-like, shape (n_rounds_of_new_data, dim_context)
            Context vectors for new data.

        Returns
        -----------
        choice_prob: array-like, shape (n_rounds_of_new_data, n_actions, 1)
            The network output with a trailing length-1 axis appended.

        """
        check_array(array=context, name="context", expected_dim=2)
        if context.shape[1] != self.dim_context:
            raise ValueError(
                "Expected `context.shape[1] == self.dim_context`, but found it False"
            )

        self.nn_model.eval()
        inputs = torch.from_numpy(context).float()
        probs = self.nn_model(inputs).detach().numpy()
        return np.expand_dims(probs, axis=2)

@dataclass
class OurLearner:
    """Off-policy learner parameterized by a neural network.

    Parameters
    -----------
    n_actions: int
        Number of actions.
        
    dim_context: int
        Number of dimensions of context vectors.
        
    s_dim: int
        Number of dimensions of surrogate reward.

    off_policy_objective: str
        An OPE estimator used to estimate the policy gradient.
        Must be one of 'ours', 'sdr', 'bdr', 'sdr-both', or 'ours-gamma'.

    policy_reg_param: float, default=0.0
        A hyperparameter to control the policy regularization. :math:`\\lambda_{pol}`.

    var_reg_param: float, default=0.0
        A hyperparameter to control the variance regularization. :math:`\\lambda_{var}`.

    hidden_layer_size: Tuple[int, ...], default = (100,)
        The i-th element specifies the size of the i-th layer.

    activation: str, default='relu'
        Activation function.
        Must be one of the following:

        - 'identity', the identity function, :math:`f(x) = x`.
        - 'logistic', the sigmoid function, :math:`f(x) = \\frac{1}{1 + \\exp(-x)}`.
        - 'tanh', the hyperbolic tangent function, :math:`f(x) = \\frac{\\exp(x) - \\exp(-x)}{\\exp(x) + \\exp(-x)}`.
        - 'relu', the rectified linear unit function, :math:`f(x) = \\max(0, x)`.
        - 'elu', the exponential linear unit function, :math:`f(x) = x` if :math:`x > 0` else :math:`\\exp(x) - 1`.

    solver: str, default='adam'
        Optimizer of the neural network.
        Must be one of the followings:

        - 'sgd', Stochastic Gradient Descent.
        - 'adam', Adam (Kingma and Ba 2014).
        - 'adagrad', Adagrad (Duchi et al. 2011).

    alpha: float, default=0.0001
        L2 penalty (passed to the optimizer as ``weight_decay``).

    batch_size: Union[int, str], default="auto"
        Batch size for SGD, Adagrad, and Adam.
        If "auto", the minimum of 200 and the number of samples is used.
        If integer, must be positive.

    learning_rate_init: float, default=0.0001
        Initial learning rate for SGD, Adagrad, and Adam.

    max_iter: int, default=200
        Number of epochs for SGD, Adagrad, and Adam.

    shuffle: bool, default=True
        Whether to shuffle samples in SGD and Adam.

    random_state: Optional[int], default=None
        Controls the random seed.

    tol: float, default=1e-4
        Tolerance for training (must be > 0).
        When the training gradient does not improve by at least ``tol`` for ``n_iter_no_change`` consecutive iterations,
        training is stopped.

    momentum: float, default=0.9
        Momentum for SGD.
        Must be in the range of [0., 1.].

    nesterovs_momentum: bool, default=True
        Whether to use Nesterovs momentum.

    early_stopping: bool, default=False
        Whether to use early stopping for SGD, Adagrad, and Adam.
        If set to true, ``validation_fraction`` of the training data is used as validation data,
        and training is stopped when the validation gradient does not improve by at least ``tol`` for ``n_iter_no_change`` consecutive iterations.

    validation_fraction: float, default=0.1
        Fraction of validation data when early stopping is used.
        Must be in the range of (0., 1.].

    beta_1: float, default=0.9
        Coefficient used for computing running average of gradient for Adam.
        Must be in the range of [0., 1.].

    beta_2: float, default=0.999
        Coefficient used for computing running average of the square of gradient for Adam.
        Must be in the range of [0., 1.].

    epsilon: float, default=1e-8
        Term for numerical stability in Adam.

    n_iter_no_change: int, default=10
        Maximum number of not improving epochs when early stopping is used.

    q_func_estimator_hyperparams: Dict, default=None
        A set of hyperparameters to define the q function estimator, i.e., :math:`\\hat{q}(x,a)`.

    q_func_with_s_hyperparams: Dict, default=None
        A set of hyperparameters to define the q function estimator that also takes the surrogate reward, i.e., :math:`\\hat{q}(x,a,s)`.
    """
    n_actions: int
    len_list: int = 1   
    dim_context: Optional[int] = None
    s_dim: Optional[int] = None
    off_policy_objective: Optional[str] = None
    policy_reg_param: float = 0.0
    var_reg_param: float = 0.0
    hidden_layer_size: Tuple[int, ...] = (100,)
    activation: str = "relu"
    solver: str = "adam"
    alpha: float = 0.0001
    batch_size: Union[int, str] = "auto"
    learning_rate_init: float = 0.0001
    max_iter: int = 200
    shuffle: bool = True
    random_state: Optional[int] = None
    tol: float = 1e-4
    momentum: float = 0.9
    nesterovs_momentum: bool = True
    early_stopping: bool = False
    validation_fraction: float = 0.1
    beta_1: float = 0.9
    beta_2: float = 0.999
    epsilon: float = 1e-8
    n_iter_no_change: int = 10
    q_func_estimator_hyperparams: Optional[Dict] = None
    q_func_with_s_hyperparams: Optional[Dict] = None
    
    def __post_init__(self) -> None:
        """Validate the hyperparameters and build the policy network and reward models.

        Side effects
        ------------
        - When ``random_state`` is given, sets ``self.random_`` and seeds torch's global RNG.
        - ``self.nn_model``: an MLP with ``hidden_layer_size`` hidden layers. Its last layer
          is a softmax over ``n_actions`` outputs.
        - ``self.q_func_estimator`` (``QFuncEstimator``) and
          ``self.q_func_with_s_estimator`` (``QFuncEstimatorWithS``).
        - ``self.f_func_estimator`` (``QFuncEstimator`` built with the same kwargs as
          ``q_func_estimator``), only for the 'sdr-both' and 'ours-gamma' objectives.

        User-supplied hyperparameter dicts are copied before ``n_actions``,
        ``dim_context`` (and ``s_dim``) are injected. The caller's dict objects are
        therefore never mutated; the augmented copies are stored back on ``self``.

        Raises
        ------
        ValueError
            If a hyperparameter is outside its allowed values. scikit-learn's
            ``check_scalar`` may also raise ``TypeError`` for wrongly typed scalars.
        """
        check_scalar(self.n_actions, "n_actions", int, min_val=1)
        check_scalar(self.dim_context, "dim_context", int, min_val=1)
        check_scalar(self.s_dim, "s_dim", int, min_val=1)

        valid_objectives = ("ours", "sdr", "bdr", "sdr-both", "ours-gamma")
        if self.off_policy_objective not in valid_objectives:
            raise ValueError(
                f"`off_policy_objective` must be one of {valid_objectives}, "
                f"but {self.off_policy_objective} is given"
            )

        check_scalar(
            self.policy_reg_param,
            "policy_reg_param",
            (int, float),
            min_val=0.0,
        )
        check_scalar(
            self.var_reg_param,
            "var_reg_param",
            (int, float),
            min_val=0.0,
        )

        if not isinstance(self.hidden_layer_size, tuple) or any(
            not isinstance(h, int) or h <= 0 for h in self.hidden_layer_size
        ):
            raise ValueError(
                f"`hidden_layer_size` must be a tuple of positive integers, but {self.hidden_layer_size} is given"
            )

        if self.solver not in ("adagrad", "sgd", "adam"):
            raise ValueError(
                f"`solver` must be one of 'adam', 'adagrad', or 'sgd', but {self.solver} is given"
            )

        check_scalar(self.alpha, "alpha", float, min_val=0.0)

        if self.batch_size != "auto" and (
            not isinstance(self.batch_size, int) or self.batch_size <= 0
        ):
            raise ValueError(
                f"`batch_size` must be a positive integer or 'auto', but {self.batch_size} is given"
            )

        check_scalar(self.learning_rate_init, "learning_rate_init", float)
        if self.learning_rate_init <= 0.0:
            raise ValueError(
                f"`learning_rate_init`= {self.learning_rate_init}, must be > 0.0"
            )

        check_scalar(self.max_iter, "max_iter", int, min_val=1)

        if not isinstance(self.shuffle, bool):
            raise ValueError(f"`shuffle` must be a bool, but {self.shuffle} is given")

        check_scalar(self.tol, "tol", float)
        if self.tol <= 0.0:
            raise ValueError(f"`tol`= {self.tol}, must be > 0.0")

        check_scalar(self.momentum, "momentum", float, min_val=0.0, max_val=1.0)

        if not isinstance(self.nesterovs_momentum, bool):
            raise ValueError(
                f"`nesterovs_momentum` must be a bool, but {self.nesterovs_momentum} is given"
            )

        if not isinstance(self.early_stopping, bool):
            raise ValueError(
                f"`early_stopping` must be a bool, but {self.early_stopping} is given"
            )

        check_scalar(
            self.validation_fraction, "validation_fraction", float, max_val=1.0
        )
        if self.validation_fraction <= 0.0:
            raise ValueError(
                f"`validation_fraction`= {self.validation_fraction}, must be > 0.0"
            )

        # Both hyperparameter containers are unpacked with ** below, so both must be dicts.
        for param_name in ("q_func_estimator_hyperparams", "q_func_with_s_hyperparams"):
            params = getattr(self, param_name)
            if params is not None and not isinstance(params, dict):
                raise ValueError(
                    f"`{param_name}` must be a dict"
                    f", but {type(params)} is given"
                )

        check_scalar(self.beta_1, "beta_1", float, min_val=0.0, max_val=1.0)
        check_scalar(self.beta_2, "beta_2", float, min_val=0.0, max_val=1.0)
        check_scalar(self.epsilon, "epsilon", float, min_val=0.0)
        check_scalar(self.n_iter_no_change, "n_iter_no_change", int, min_val=1)

        if self.random_state is not None:
            self.random_ = check_random_state(self.random_state)
            torch.manual_seed(self.random_state)

        if self.activation == "identity":
            activation_layer = nn.Identity
        elif self.activation == "logistic":
            activation_layer = nn.Sigmoid
        elif self.activation == "tanh":
            activation_layer = nn.Tanh
        elif self.activation == "relu":
            activation_layer = nn.ReLU
        elif self.activation == "elu":
            activation_layer = nn.ELU
        else:
            raise ValueError(
                "`activation` must be one of 'identity', 'logistic', 'tanh', 'relu', or 'elu'"
                f", but {self.activation} is given"
            )

        layer_list = []
        input_size = self.dim_context
        for i, h in enumerate(self.hidden_layer_size):
            layer_list.append(("l{}".format(i), nn.Linear(input_size, h)))
            layer_list.append(("a{}".format(i), activation_layer()))
            input_size = h
        layer_list.append(("output", nn.Linear(input_size, self.n_actions)))
        # The network outputs action probabilities directly.
        layer_list.append(("softmax", nn.Softmax(dim=1)))
        self.nn_model = nn.Sequential(OrderedDict(layer_list))

        if self.q_func_estimator_hyperparams is not None:
            # Copy before injecting the dimensions so the caller's dict is left untouched.
            self.q_func_estimator_hyperparams = {
                **self.q_func_estimator_hyperparams,
                "n_actions": self.n_actions,
                "dim_context": self.dim_context,
            }
            q_func_kwargs = self.q_func_estimator_hyperparams
        else:
            q_func_kwargs = {"n_actions": self.n_actions, "dim_context": self.dim_context}
        self.q_func_estimator = QFuncEstimator(**q_func_kwargs)

        if self.q_func_with_s_hyperparams is not None:
            self.q_func_with_s_hyperparams = {
                **self.q_func_with_s_hyperparams,
                "n_actions": self.n_actions,
                "dim_context": self.dim_context,
                "s_dim": self.s_dim,
            }
            q_with_s_kwargs = self.q_func_with_s_hyperparams
        else:
            q_with_s_kwargs = {
                "n_actions": self.n_actions,
                "dim_context": self.dim_context,
                "s_dim": self.s_dim,
            }
        self.q_func_with_s_estimator = QFuncEstimatorWithS(**q_with_s_kwargs)

        if self.off_policy_objective in ("sdr-both", "ours-gamma"):
            # f(x, a) is modelled with the same estimator family/hyperparameters as q(x, a).
            self.f_func_estimator = QFuncEstimator(**q_func_kwargs)
    def evaluate(self, context, action, reward, pscore, action_dist):
        """Return the inverse-propensity-score (IPS) estimate of the policy value.

        Parameters
        -----------
        context: unused
            Accepted for interface compatibility only.
        action: array-like of int, shape (n_rounds,)
            Logged actions :math:`a_i`.
        reward: array-like, shape (n_rounds,)
            Logged rewards :math:`r_i`.
        pscore: array-like, shape (n_rounds,)
            Logging-policy propensities :math:`\\pi_b(a_i|x_i)`.
        action_dist: array-like, shape (n_rounds, n_actions, 1)
            Evaluation-policy probabilities. The trailing axis must have size 1.

        Returns
        -----------
        float
            :math:`\\frac{1}{n} \\sum_i r_i \\, \\pi(a_i|x_i) / \\pi_b(a_i|x_i)`.
        """
        pi_e = np.squeeze(action_dist, -1)
        chosen_prob = pi_e[np.arange(action.shape[0]), action]
        weights = chosen_prob / pscore
        return np.mean(weights * reward)
        

    def _create_train_data_for_opl(
        self,
        context: np.ndarray,
        action: np.ndarray,
        surrogate_reward: np.ndarray,
        reward: np.ndarray,
        obs_list: np.ndarray,
        pscore: np.ndarray,
        s_sum: Optional[np.ndarray] = None,
        **kwargs,
    ) -> Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.DataLoader]]:
        """Wrap logged bandit data into PyTorch data loaders for off-policy learning.

        Parameters
        -----------
        context: array-like, shape (n_rounds, dim_context)
            Context vectors observed for each data, i.e., :math:`x_i`.

        action: array-like, shape (n_rounds,)
            Actions sampled by the logging/behavior policy for each data in logged bandit data, i.e., :math:`a_i`.

        surrogate_reward: array-like
            Surrogate rewards observed for each data in logged bandit data, i.e., :math:`s_i`.
            Presumably shape (n_rounds, s_dim); confirm against the dataset classes.

        reward: array-like, shape (n_rounds,)
            Rewards observed for each data in logged bandit data, i.e., :math:`r_i`.

        obs_list: array-like, shape (n_rounds,)
            Indicator (1/0) of whether the reward of each round is observed.

        pscore: array-like, shape (n_rounds,)
            Action choice probabilities of the logging/behavior policy (propensity scores), i.e., :math:`\\pi_b(a_i|x_i)`.

        s_sum: array-like, default=None
            Optional extra per-round target (presumably shape (n_rounds,)). When given, the
            dataset is ``NNPolicyDataset_with_fs``, which also yields ``s_sum``.

        **kwargs
            Ignored.

        Returns
        --------
        (training_data_loader, validation_data_loader): Tuple[DataLoader, Optional[DataLoader]]
            Training and validation data loaders in PyTorch. The validation loader is
            None unless ``early_stopping`` is enabled.

        Raises
        --------
        ValueError
            If early stopping is enabled and the data cannot be split into non-empty
            training and validation parts.

        """
        n_samples = context.shape[0]
        if self.batch_size == "auto":
            batch_size_ = min(200, n_samples)
        else:
            check_scalar(self.batch_size, "batch_size", int, min_val=1)
            batch_size_ = self.batch_size

        context_t = torch.from_numpy(context.astype("float32")).float()
        action_t = torch.from_numpy(action).long()
        surrogate_t = torch.from_numpy(surrogate_reward).float()
        reward_t = torch.from_numpy(reward).float()
        obs_t = torch.from_numpy(obs_list).float()
        pscore_t = torch.from_numpy(pscore).float()

        if s_sum is None:
            dataset = NNPolicyDataset(
                context_t, action_t, surrogate_t, reward_t, obs_t, pscore_t
            )
        else:
            # Field order must match the unpacking in `fit`: (x, a, s, s_sum, r, o, p).
            dataset = NNPolicyDataset_with_fs(
                context_t,
                action_t,
                surrogate_t,
                torch.from_numpy(s_sum).float(),
                reward_t,
                obs_t,
                pscore_t,
            )

        if self.early_stopping:
            if n_samples <= 1:
                raise ValueError(
                    f"the number of samples is too small ({n_samples}) to create validation data"
                )

            validation_size = max(int(n_samples * self.validation_fraction), 1)
            training_size = n_samples - validation_size
            # validation_fraction may be 1.0, which would leave no training data at all.
            if training_size < 1:
                raise ValueError(
                    f"`validation_fraction`= {self.validation_fraction} leaves no training data "
                    f"for {n_samples} samples"
                )
            training_dataset, validation_dataset = torch.utils.data.random_split(
                dataset, [training_size, validation_size]
            )
            training_data_loader = torch.utils.data.DataLoader(
                training_dataset,
                batch_size=batch_size_,
                shuffle=self.shuffle,
            )
            validation_data_loader = torch.utils.data.DataLoader(
                validation_dataset,
                batch_size=batch_size_,
                shuffle=self.shuffle,
            )
            return training_data_loader, validation_data_loader

        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size_,
            shuffle=self.shuffle,
        )
        return data_loader, None

    def fit(
        self,
        context: np.ndarray,
        action: np.ndarray,
        surrogate_reward: np.ndarray,
        reward: np.ndarray,
        obs_list: np.ndarray,
        p_o: float,
        s_sum: Optional[np.ndarray] = None,
        pscore: Optional[np.ndarray] = None,
        beta: Optional[float] = None,
    ) -> None:
        """Fits an offline bandit policy on the given logged bandit data.

        Note
        ----------
        Given the training data :math:`\\mathcal{D}`, this policy maximizes the following objective function:

        .. math::

            \\hat{V}(\\pi_\\theta; \\mathcal{D}) - \\alpha \\Omega(\\theta)

        where :math:`\\hat{V}` is an OPE estimator and :math:`\\alpha \\Omega(\\theta)` is a regularization term.

        Parameters
        -----------
        context: array-like, shape (n_rounds, dim_context)
            Context vectors observed for each data, i.e., :math:`x_i`.

        action: array-like, shape (n_rounds,)
            Actions sampled by the logging/behavior policy for each data in logged bandit data, i.e., :math:`a_i`.

        reward: array-like, shape (n_rounds,)
            Rewards observed for each data in logged bandit data, i.e., :math:`r_i`.

        pscore: array-like, shape (n_rounds,), default=None
            Action choice probabilities of the logging/behavior policy (propensity scores), i.e., :math:`\\pi_b(a_i|x_i)`.

        """
        
        if context.shape[1] != self.dim_context:
            raise ValueError(
                "Expected `context.shape[1] == self.dim_context`, but found it False"
            )
        if pscore is None:
            pscore = np.ones_like(action) / self.n_actions
        if self.len_list == 1:
            position = np.zeros_like(action, dtype=int)
            
        # train 
        # if self.w_xs_estimation == "estimate_prob":
        #     x_a = np.column_stack([context, action])
        #     self.kde.fit(x_a)
        obs_context=context[obs_list==1]
        obs_action=action[obs_list==1]
        obs_reward=reward[obs_list==1]
        # train q function estimator
        if p_o!=0.0:
            if beta<1.0:
                self.q_func_estimator.fit(
                    context=obs_context,
                    action=obs_action,
                    reward=obs_reward,
                )
                if self.off_policy_objective != "bdr":
                    self.q_func_with_s_estimator.fit(
                        context=context,
                        action=action,
                        reward=reward,
                        surrogate_reward = surrogate_reward,
                        obs_list = obs_list,
                    )
        if self.off_policy_objective == "sdr-both" and beta>0.0:
            self.f_func_estimator.fit(
                context=context,
                action=action,
                reward=s_sum,
            )
        
                    
        # train q_function
        if self.solver == "sgd":
            optimizer = optim.SGD(
                self.nn_model.parameters(),
                lr=self.learning_rate_init,
                momentum=self.momentum,
                weight_decay=self.alpha,
                nesterov=self.nesterovs_momentum,
            )
        elif self.solver == "adagrad":
            optimizer = optim.Adagrad(
                self.nn_model.parameters(),
                lr=self.learning_rate_init,
                eps=self.epsilon,
                weight_decay=self.alpha,
            )
        elif self.solver == "adam":
            optimizer = optim.Adam(
                self.nn_model.parameters(),
                lr=self.learning_rate_init,
                betas=(self.beta_1, self.beta_2),
                eps=self.epsilon,
                weight_decay=self.alpha,
            )
        else:
            raise NotImplementedError(
                "`solver` must be one of 'adam', 'adagrad', or 'sgd'"
            )

        if self.off_policy_objective == "sdr-both":
            training_data_loader, validation_data_loader = self._create_train_data_for_opl(
                context=context, action=action, surrogate_reward=surrogate_reward, reward=reward, obs_list=obs_list, pscore=pscore, s_sum=s_sum
            )
        else:
            training_data_loader, validation_data_loader = self._create_train_data_for_opl(
                context=context, action=action, surrogate_reward=surrogate_reward, reward=reward, obs_list=obs_list, pscore=pscore
            )
        training_losses = []
        validation_losses = []
        validation_steps = []
        if self.off_policy_objective == "sdr-both":
            n_not_improving_training = 0
            previous_training_loss = None
            n_not_improving_validation = 0
            previous_validation_loss = None
            for iteration in tqdm(np.arange(self.max_iter), desc="policy learning"):
                self.nn_model.train()
                train_loss_accumulator = []
                val_loss_accumulator = []
                for x, a, s, s_sum, r, o, p in training_data_loader:
                    optimizer.zero_grad()
                    pi = self.nn_model(x).unsqueeze(-1)
                    policy_grad_arr = self._estimate_policy_gradient(
                        context=x,
                        reward=r,
                        action=a,
                        surrogate_reward=s,
                        s_sum=s_sum,
                        pscore=p,
                        obs_list=o,
                        p_o=p_o,
                        action_dist=pi,
                        beta=beta,
                    )
                    policy_constraint = self._estimate_policy_constraint(
                        action=a,
                        pscore=p,
                        action_dist=pi,
                    )
                    loss = -policy_grad_arr.mean()
                    loss += self.policy_reg_param * policy_constraint
                    loss += self.var_reg_param * torch.var(policy_grad_arr)
                    loss.backward()
                    optimizer.step()

                    loss_value = loss.item()
                    # training_losses.append(loss.item())
                    train_loss_accumulator.append(loss.item())
                    if previous_training_loss is not None:
                        if loss_value - previous_training_loss < self.tol:
                            n_not_improving_training += 1
                        else:
                            n_not_improving_training = 0
                    if n_not_improving_training >= self.n_iter_no_change:
                        break
                    previous_training_loss = loss_value
                average_train_loss = sum(train_loss_accumulator) / len(train_loss_accumulator)
                training_losses.append(average_train_loss)
                if self.early_stopping:
                    self.nn_model.eval()
                    for x, a, s, s_sum, r, o, p in validation_data_loader:
                        pi = self.nn_model(x).unsqueeze(-1)
                        policy_grad_arr = self._estimate_policy_gradient(
                            context=x,
                            reward=r,
                            action=a,
                            surrogate_reward=s,
                            s_sum=s_sum,
                            pscore=p,
                            obs_list=o,
                            p_o=p_o,
                            action_dist=pi,
                            beta=beta
                        )
                        policy_constraint = self._estimate_policy_constraint(
                            action=a,
                            pscore=p,
                            action_dist=pi,
                        )
                        loss = -policy_grad_arr.mean()
                        loss += self.policy_reg_param * policy_constraint
                        loss += self.var_reg_param * torch.var(policy_grad_arr)
                        loss_value = loss.item()
                        val_loss_accumulator.append(loss_value)
                        validation_steps.append(iteration)
                        if previous_validation_loss is not None:
                            if loss_value - previous_validation_loss < self.tol:
                                n_not_improving_validation += 1
                            else:
                                n_not_improving_validation = 0
                        if n_not_improving_validation > self.n_iter_no_change:
                            break
                        previous_validation_loss = loss_value
                    average_val_loss = sum(val_loss_accumulator) / len(val_loss_accumulator)
                    validation_losses.append(average_val_loss)
        else:
            n_not_improving_training = 0
            previous_training_loss = None
            n_not_improving_validation = 0
            previous_validation_loss = None
            for iteration in tqdm(np.arange(self.max_iter), desc="policy learning"):
                self.nn_model.train()
                train_loss_accumulator = []
                val_loss_accumulator = []
                for x, a, s, r, o, p in training_data_loader:
                    optimizer.zero_grad()
                    pi = self.nn_model(x).unsqueeze(-1)
                    policy_grad_arr = self._estimate_policy_gradient(
                        context=x,
                        reward=r,
                        action=a,
                        surrogate_reward=s,
                        pscore=p,
                        obs_list=o,
                        p_o=p_o,
                        action_dist=pi,
                    )
                    policy_constraint = self._estimate_policy_constraint(
                        action=a,
                        pscore=p,
                        action_dist=pi,
                    )
                    loss = -policy_grad_arr.mean()
                    loss += self.policy_reg_param * policy_constraint
                    loss += self.var_reg_param * torch.var(policy_grad_arr)
                    loss.backward()
                    optimizer.step()

                    loss_value = loss.item()
                    train_loss_accumulator.append(loss.item())
                    if previous_training_loss is not None:
                        if loss_value - previous_training_loss < self.tol:
                            n_not_improving_training += 1
                        else:
                            n_not_improving_training = 0
                    if n_not_improving_training >= self.n_iter_no_change:
                        break
                    previous_training_loss = loss_value
                average_train_loss = sum(train_loss_accumulator) / len(train_loss_accumulator)
                training_losses.append(average_train_loss)

                if self.early_stopping:
                    self.nn_model.eval()
                    for x, a, s, r, o, p in validation_data_loader:
                        pi = self.nn_model(x).unsqueeze(-1)
                        policy_grad_arr = self._estimate_policy_gradient(
                            context=x,
                            reward=r,
                            action=a,
                            surrogate_reward=s,
                            pscore=p,
                            obs_list=o,
                            p_o=p_o,
                            action_dist=pi,
                        )
                        policy_constraint = self._estimate_policy_constraint(
                            action=a,
                            pscore=p,
                            action_dist=pi,
                        )
                        loss = -policy_grad_arr.mean()
                        loss += self.policy_reg_param * policy_constraint
                        loss += self.var_reg_param * torch.var(policy_grad_arr)
                        loss_value = loss.item()
                        val_loss_accumulator.append(loss_value)
                        validation_steps.append(iteration)
                        if previous_validation_loss is not None:
                            if loss_value - previous_validation_loss < self.tol:
                                n_not_improving_validation += 1
                            else:
                                n_not_improving_validation = 0
                        if n_not_improving_validation > self.n_iter_no_change:
                            break
                        previous_validation_loss = loss_value
                    average_val_loss = sum(val_loss_accumulator) / len(val_loss_accumulator)
                    validation_losses.append(average_val_loss)
        # if self.early_stopping:
        #     plt.figure(figsize=(12, 6))
        #     plt.plot(range(len(training_losses)), training_losses, label='Training Gradient')
        #     plt.plot(range(len(validation_losses)), validation_losses, label='Validation Gradient')
        #     plt.xlabel('Epochs')
        #     plt.ylabel('Loss')
        #     plt.title('Learning Curve')
        #     plt.legend()
        #     plt.grid(True)
        #     plt.show()
        # else:
        #     plt.figure(figsize=(12, 6))
        #     plt.plot(range(len(training_losses)), training_losses, label='Training Gradient')
        #     plt.xlabel('Epochs')
        #     plt.ylabel('Loss')
        #     plt.title('Learning Curve')
        #     plt.legend()
        #     plt.grid(True)
        #     plt.show()
                        
    ##### Revision point: the method below is the one most in need of fixing
    def _estimate_policy_gradient(
        self,
        context: torch.Tensor,
        action: torch.Tensor,
        surrogate_reward: torch.Tensor,
        reward: torch.Tensor,
        obs_list: torch.Tensor,
        pscore: torch.Tensor,
        p_o: float,
        action_dist: torch.Tensor,
        s_sum: Optional[torch.Tensor] = None,
        beta: Optional[float] = None,
    ) -> torch.Tensor:
        """Estimate the policy gradient.

        Parameters
        -----------
        context: array-like, shape (batch_size, dim_context)
            Context vectors observed for each data, i.e., :math:`x_i`.

        action: array-like, shape (batch_size,)
            Actions sampled by the logging/behavior policy for each data in logged bandit data, i.e., :math:`a_i`.

        surrogate_reward: array-like, shape (batch_size,)
            Surrogate rewards observed for each data in logged bandit data, i.e., :math:`s_i`.
        
        reward: array-like, shape (batch_size,)
            Rewards observed for each data in logged bandit data, i.e., :math:`r_i`.

        pscore: array-like, shape (batch_size,), default=None
            Action choice probabilities of the logging/behavior policy (propensity scores), i.e., :math:`\\pi_b(a_i|x_i)`.

        action_dist: array-like, shape (batch_size, n_actions, len_list)
            Action choice probabilities of the evaluation policy (can be deterministic), i.e., :math:`\\pi_e(a_i|x_i)`.

        Returns
        ----------
        estimated_policy_grad_arr: array-like, shape (batch_size,)
            Rewards of each data estimated by an OPE estimator.

        """
        current_pi = action_dist[:, :, 0].detach()
        log_prob = torch.log(action_dist[:, :, 0])
        idx_tensor = torch.arange(action.shape[0], dtype=torch.long)
        # print(f"Max pi: {torch.mean(current_pi[idx_tensor], dim=0).max()}")
        # print(f"Argmax pi: {torch.argmax(torch.mean(current_pi[idx_tensor], dim=0))}")
        # obs_context = context[obs_list==1]
        # obs_action = action[obs_list==1]
        # obs_reward = reward[obs_list==1]
        # obs_surrogate_reward = surrogate_reward[obs_list==1]
        # obs_action_dist = action_dist[obs_list==1]
        # obs_pscore = pscore[obs_list==1]
        # obs_idx_tensor = torch.arange(obs_action.shape[0], dtype=torch.long)
        # obs_=obs_list[obs_list==1]

        if self.off_policy_objective == "ours":
            raise ValueError(
                "ours is not ready"
            )
            q_hat = self.q_func_estimator.predict(
                context=context,
            )
            q_hat_factual = q_hat[idx_tensor, action]
            iw = current_pi[idx_tensor, action] / pscore
            pw = obs_list/p_o
            estimated_policy_grad_arr = pw * iw * (reward - q_hat_factual)
            estimated_policy_grad_arr *= log_prob[idx_tensor, action]
            estimated_policy_grad_arr += torch.sum(q_hat * current_pi * log_prob, dim=1)
        elif self.off_policy_objective == "sdr":
            q_hat = self.q_func_estimator.predict(
                context=context,
            )
            q_hat_with_s = self.q_func_with_s_estimator.predict(
                context=context,
                surrogate_reward = surrogate_reward,
            )
            q_hat_factual = q_hat[idx_tensor, action]
            q_hat_factual_with_s = q_hat_with_s[idx_tensor, action]
            iw = current_pi[idx_tensor, action] / pscore
            if p_o!=0.0:
                pw = obs_list/p_o
                obs_estimated_policy_grad_arr = pw * iw * (reward - q_hat_factual_with_s)
                obs_estimated_policy_grad_arr *= log_prob[idx_tensor, action]
                estimated_policy_grad_arr = iw * (q_hat_factual_with_s - q_hat_factual)
                estimated_policy_grad_arr *= log_prob[idx_tensor, action]
                estimated_policy_grad_arr += obs_estimated_policy_grad_arr
                estimated_policy_grad_arr += torch.sum(q_hat * current_pi * log_prob, dim=1)
            else:
                raise ValueError("p_o is 0.0, so cannot calculate the sdr")
        
        elif self.off_policy_objective == "sdr-both":
            if p_o!=0.0 and beta<1 and beta>0:
                q_hat = self.q_func_estimator.predict(
                    context=context,
                )
                q_hat_with_s = self.q_func_with_s_estimator.predict(
                    context=context,
                    surrogate_reward = surrogate_reward,
                )
                q_hat_factual = q_hat[idx_tensor, action]
                
                q_hat_factual_with_s = q_hat_with_s[idx_tensor, action]
                iw = current_pi[idx_tensor, action] / pscore
                pw = obs_list/p_o

                obs_estimated_policy_grad_arr = pw * iw * (reward - q_hat_factual_with_s)
                obs_estimated_policy_grad_arr *= log_prob[idx_tensor, action]
                estimated_policy_grad_arr = iw * (q_hat_factual_with_s - q_hat_factual)
                estimated_policy_grad_arr *= log_prob[idx_tensor, action]
                estimated_policy_grad_arr += obs_estimated_policy_grad_arr
                estimated_policy_grad_arr += torch.sum(q_hat * current_pi * log_prob, dim=1)
                f_hat = self.f_func_estimator.predict(
                    context=context,
                )
                f_hat_factual = f_hat[idx_tensor, action]
                side_estimated_policy_grad_arr = iw * (s_sum-f_hat_factual)
                side_estimated_policy_grad_arr *= log_prob[idx_tensor, action]
                side_estimated_policy_grad_arr += torch.sum(f_hat * current_pi * log_prob, dim=1)
                estimated_policy_grad_arr = beta*side_estimated_policy_grad_arr + (1-beta)*estimated_policy_grad_arr
            elif p_o!=0.0 and beta==0:
                q_hat = self.q_func_estimator.predict(
                    context=context,
                )
                q_hat_with_s = self.q_func_with_s_estimator.predict(
                    context=context,
                    surrogate_reward = surrogate_reward,
                )
                q_hat_factual = q_hat[idx_tensor, action]
                q_hat_factual_with_s = q_hat_with_s[idx_tensor, action]
                iw = current_pi[idx_tensor, action] / pscore
                pw = obs_list/p_o
                obs_estimated_policy_grad_arr = pw * iw * (reward - q_hat_factual_with_s)
                obs_estimated_policy_grad_arr *= log_prob[idx_tensor, action]
                estimated_policy_grad_arr = iw * (q_hat_factual_with_s - q_hat_factual)
                estimated_policy_grad_arr *= log_prob[idx_tensor, action]
                estimated_policy_grad_arr += obs_estimated_policy_grad_arr
                estimated_policy_grad_arr += torch.sum(q_hat * current_pi * log_prob, dim=1)
            else:
                iw = current_pi[idx_tensor, action] / pscore
                f_hat = self.f_func_estimator.predict(
                    context=context,
                )
                f_hat_factual = f_hat[idx_tensor, action]
                side_estimated_policy_grad_arr = iw * (s_sum-f_hat_factual)
                side_estimated_policy_grad_arr *= log_prob[idx_tensor, action]
                side_estimated_policy_grad_arr += torch.sum(f_hat * current_pi * log_prob, dim=1)
                estimated_policy_grad_arr = side_estimated_policy_grad_arr
            
        elif self.off_policy_objective == "bdr":
            q_hat = self.q_func_estimator.predict(
                context=context,
            )
            q_hat_factual = q_hat[idx_tensor, action]
            iw = current_pi[idx_tensor, action] / pscore
            pw = obs_list/p_o
            estimated_policy_grad_arr = pw * iw * (reward - q_hat_factual)
            estimated_policy_grad_arr *= log_prob[idx_tensor, action]
            estimated_policy_grad_arr += torch.sum(q_hat * current_pi * log_prob, dim=1)
            
        return estimated_policy_grad_arr

    def _estimate_policy_constraint(
        self,
        action: torch.Tensor,
        pscore: torch.Tensor,
        action_dist: torch.Tensor,
    ) -> torch.Tensor:
        """Estimate the policy constraint term.

        Parameters
        -----------
        action: array-like, shape (n_rounds,)
            Actions sampled by the logging/behavior policy for each data in logged bandit data, i.e., :math:`a_i`.

        pscore: array-like, shape (n_rounds,), default=None
            Action choice probabilities of the logging/behavior policy (propensity scores), i.e., :math:`\\pi_b(a_i|x_i)`.

        action_dist: array-like, shape (n_rounds, n_actions, len_list)
            Action choice probabilities of the evaluation policy (can be deterministic), i.e., :math:`\\pi_e(a_i|x_i)`.

        """
        idx_tensor = torch.arange(action.shape[0], dtype=torch.long)
        iw = action_dist[idx_tensor, action, 0] / pscore

        return torch.log(iw.mean())

    def predict_dist(self, context: np.ndarray, temp: float = 1.0) -> np.ndarray:
        """Return a temperature-scaled softmax distribution over actions.

        The network output for each row of ``context`` is divided by ``temp``
        and passed through a softmax over the action axis.

        NOTE(review): ``predict_proba`` returns the raw network output as
        action probabilities, so this applies a softmax on top of those values
        rather than on logits — confirm this is intended.

        Parameters
        -----------
        context: array-like, shape (n_rounds_of_new_data, dim_context)
            Context vectors for new data.
        temp: float, default=1.0
            Temperature, must be > 0.
            Higher temperatures make the distribution more uniform,
            while lower temperatures make it more concentrated on the highest values.

        Returns
        -----------
        action_dist: array-like, shape (n_rounds_of_new_data, n_actions, 1)
            Softmax action probability distribution with temperature adjustment.

        Raises
        -----------
        ValueError
            If ``context`` does not have ``self.dim_context`` columns, or if
            ``temp`` is not positive (dividing by it would give inf/nan).

        """
        check_array(array=context, name="context", expected_dim=2)
        if context.shape[1] != self.dim_context:
            raise ValueError(
                "Expected `context.shape[1] == self.dim_context`, but found it False"
            )
        # `not temp > 0` also rejects NaN.
        if not temp > 0:
            raise ValueError(f"`temp` must be positive, but got {temp}")

        self.nn_model.eval()
        x = torch.from_numpy(context).float()
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            y = self.nn_model(x).detach().numpy()
        n = context.shape[0]

        # Apply temperature to the predicted scores
        y_temp = y / temp

        # Apply softmax to obtain action probability distribution
        action_dist = scipy.special.softmax(y_temp, axis=1)
        action_dist = action_dist.reshape(n, self.n_actions, 1)

        return action_dist

    
    def predict(self, context: np.ndarray) -> np.ndarray:
        """Return a deterministic (one-hot) action choice for each new context.

        For every row the action with the largest network output receives
        probability one and all others receive zero.

        Note
        --------
        Only position 0 is filled, so the result can contain duplicate items
        across rows. If a non-repetitive action set is needed, use the
        `sample_action` method.

        Parameters
        -----------
        context: array-like, shape (n_rounds_of_new_data, dim_context)
            Context vectors for new data.

        Returns
        -----------
        action_dist: array-like, shape (n_rounds_of_new_data, n_actions, 1)
            One-hot encoding of the greedy action per row.

        """
        check_array(array=context, name="context", expected_dim=2)
        if context.shape[1] != self.dim_context:
            raise ValueError(
                "Expected `context.shape[1] == self.dim_context`, but found it False"
            )

        self.nn_model.eval()
        as_float32 = context.astype('float32')
        inputs = torch.from_numpy(as_float32).float()
        scores = self.nn_model(inputs).detach().numpy()
        n_rounds = as_float32.shape[0]
        greedy = scores.argmax(axis=1)
        one_hot = np.zeros((n_rounds, self.n_actions, 1))
        one_hot[np.arange(n_rounds), greedy, 0] = 1

        return one_hot
    
    def predict_own(
        self,
        context: np.ndarray,
        action: np.ndarray,
        surrogate_reward: np.ndarray,
        reward: np.ndarray,
        obs_list: np.ndarray,
        pscore: np.ndarray,
        p_o: float,
        beta: float,
        s_sum: Optional[np.ndarray] = None,
        action_dist: Optional[np.ndarray] = None,
    ) -> float:
        """Estimate the average value of ``action_dist`` on logged data.

        The branch structure matches the ``"sdr-both"`` objective:

        - ``p_o != 0`` and ``0 < beta < 1``: ``beta * side + (1 - beta) * sdr``.
        - ``p_o != 0`` and ``beta == 0``: only the sdr value.
        - otherwise: only the side value.

        Here ``sdr = iw * (q_hat_with_s - q_hat) + pw * iw * (r - q_hat_with_s)
        + sum_a q_hat * pi``; ``side = iw * (s_sum - f_hat) + sum_a f_hat * pi``.
        The weights are ``iw = pi(a|x) / pscore`` and ``pw = obs_list / p_o``.
        Each term is evaluated at the logged action.

        Parameters
        -----------
        context, action, surrogate_reward, reward, obs_list, pscore: np.ndarray
            Logged data, one row per round.
        p_o: float
            Divisor of ``obs_list`` (presumably the observation probability).
        beta: float
            Mixing weight of the side-information value.
        s_sum: np.ndarray, default=None
            Target of the side-information value. Required unless
            ``p_o != 0 and beta == 0``.
        action_dist: np.ndarray, shape (n_rounds, n_actions, 1)
            Action choice probabilities of the evaluated policy. Required.

        Returns
        -----------
        policy_value: float
            Mean of the per-round estimates.

        Raises
        -----------
        ValueError
            If ``action_dist`` is missing, or ``s_sum`` is missing but needed.

        """
        # The None defaults previously crashed inside torch.tensor(None).
        if action_dist is None:
            raise ValueError("`action_dist` must be given to estimate the policy value")
        mix = p_o != 0.0 and 0 < beta < 1
        sdr_only = (not mix) and p_o != 0.0 and beta == 0
        if not sdr_only and s_sum is None:
            raise ValueError(
                "`s_sum` must be given unless p_o != 0 and beta == 0"
            )

        context = torch.tensor(context, dtype=torch.float32)
        action = torch.tensor(action, dtype=torch.long)
        surrogate_reward = torch.tensor(surrogate_reward, dtype=torch.float32)
        reward = torch.tensor(reward, dtype=torch.float32)
        obs_list = torch.tensor(obs_list, dtype=torch.float32)
        pscore = torch.tensor(pscore, dtype=torch.float32)
        if s_sum is not None:
            s_sum = torch.tensor(s_sum, dtype=torch.float32)
        action_dist = torch.squeeze(torch.tensor(action_dist, dtype=torch.float32), -1)

        idx_tensor = torch.arange(action.shape[0], dtype=torch.long)
        iw = action_dist[idx_tensor, action] / pscore

        def _sdr_value() -> torch.Tensor:
            # Surrogate-aware doubly robust value; only called when p_o != 0.
            q_hat = self.q_func_estimator.predict(context)
            q_hat_with_s = self.q_func_with_s_estimator.predict(context, surrogate_reward)
            q_hat_factual = q_hat[idx_tensor, action]
            q_hat_factual_with_s = q_hat_with_s[idx_tensor, action]
            pw = obs_list / p_o
            obs_policy_val = pw * iw * (reward - q_hat_factual_with_s)
            policy_val = iw * (q_hat_factual_with_s - q_hat_factual)
            policy_val = policy_val + obs_policy_val
            return policy_val + torch.sum(q_hat * action_dist, dim=1)

        def _side_value() -> torch.Tensor:
            # Doubly robust value on s_sum using f_func_estimator.
            f_hat = self.f_func_estimator.predict(context)
            f_hat_factual = f_hat[idx_tensor, action]
            side_policy_val = iw * (s_sum - f_hat_factual)
            return side_policy_val + torch.sum(f_hat * action_dist, dim=1)

        if mix:
            sdr_val = _sdr_value()
            side_val = _side_value()
            policy_val = beta * side_val + (1 - beta) * sdr_val
        elif sdr_only:
            policy_val = _sdr_value()
        else:
            policy_val = _side_value()

        return policy_val.mean().item()
        

    def sample_action(
        self,
        context: np.ndarray,
        tau: Union[int, float] = 1.0,
        random_state: Optional[int] = None,
    ) -> np.ndarray:
        """Sample a ranking of (non-repetitive) actions via the Gumbel trick.

        Note
        --------
        This `sample_action` method samples a **non-repetitive** ranking of actions for new data
        :math:`x \\in \\mathcal{X}` via the so-called "Gumbel Softmax trick" as follows.

        .. math::

            s (x,a) = \\hat{f}(x,a) / \\tau + \\gamma_{x,a}, \\quad \\gamma_{x,a} \\sim \\mathrm{Gumbel}(0,1)

        :math:`\\tau` is a temperature hyperparameter.
        :math:`\\hat{f}(x,a)` is the output of the `predict_proba` method averaged over its last
        (position) axis.
        :math:`\\gamma_{x,a}` is a random variable sampled from the Gumbel distribution.
        Actions are sorted by :math:`s (x,a)` for each context, and the first `len_list`
        of them form the sampled ranking.

        NOTE(review): the exact Gumbel-max / Plackett-Luce correspondence holds when the
        scores are log-probabilities; here `predict_proba` values are used directly — confirm
        this is intended.

        Parameters
        ----------------
        context: array-like, shape (n_rounds_of_new_data, dim_context)
            Context vectors for new data.

        tau: int or float, default=1.0
            A temperature parameter (must be > 0) that controls the randomness of the action
            choice by scaling the scores before adding Gumbel noise.
            As :math:`\\tau \\rightarrow \\infty`, the algorithm will select arms uniformly at random.

        random_state: int, default=None
            Controls the random seed in sampling actions.

        Returns
        -----------
        sampled_action: array-like, shape (n_rounds_of_new_data, n_actions, len_list)
            One-hot encoding of the sampled ranking: ``sampled_action[i, a, p] == 1`` when
            action ``a`` is placed at position ``p`` for row ``i``.

        Raises
        -----------
        ValueError
            If ``tau`` is not positive or ``context`` has the wrong number of columns.

        """
        check_array(array=context, name="context", expected_dim=2)
        check_scalar(tau, name="tau", target_type=(int, float), min_val=0)
        if tau == 0:
            # check_scalar accepts the boundary value, but scores / 0 gives inf/nan.
            raise ValueError("`tau`= 0, must be > 0.")
        if context.shape[1] != self.dim_context:
            raise ValueError(
                "Expected `context.shape[1] == self.dim_context`, but found it False"
            )

        n = context.shape[0]
        random_ = check_random_state(random_state)
        sampled_action = np.zeros((n, self.n_actions, self.len_list))
        scores = self.predict_proba(context=context).mean(2) / tau
        scores += random_.gumbel(size=scores.shape)
        ranking = np.argsort(-scores, axis=1)
        for p in np.arange(self.len_list):
            sampled_action[np.arange(n), ranking[:, p], p] = 1
        return sampled_action

    def predict_proba(
        self,
        context: np.ndarray,
    ) -> np.ndarray:
        """Return the policy's action choice probabilities for new contexts.

        Note
        --------
        The policy is a multi-layer perceptron (MLP) whose last layer is a softmax,
        i.e. the stochastic policy

        .. math::

            \\pi_\\theta (a \\mid x) = \\frac{\\exp(f_\\theta(x, a))}{\\sum_{a' \\in \\mathcal{A}} \\exp(f_\\theta(x, a'))}

        where :math:`f_\\theta(x, a)` is the MLP with parameter :math:`\\theta`.
        The network output is returned as-is with a trailing axis appended.

        Parameters
        ----------------
        context: array-like, shape (n_rounds_of_new_data, dim_context)
            Context vectors for new data.

        Returns
        -----------
        choice_prob: array-like, shape (n_rounds_of_new_data, n_actions, 1)
            Action choice probabilities produced by the trained network.

        """
        check_array(array=context, name="context", expected_dim=2)
        if context.shape[1] != self.dim_context:
            raise ValueError(
                "Expected `context.shape[1] == self.dim_context`, but found it False"
            )

        self.nn_model.eval()
        inputs = torch.from_numpy(context).float()
        probabilities = self.nn_model(inputs).detach().numpy()
        return np.expand_dims(probabilities, axis=2)



@dataclass
class NNPolicyDataset(torch.utils.data.Dataset):
    """PyTorch dataset of logged bandit feedback for NNPolicyLearner.

    Item ``i`` is the tuple
    ``(context[i], action[i], surrogate_reward[i], reward[i], obs_list[i], pscore[i])``.

    Raises
    ------
    ValueError
        If the fields do not all have the same length along axis 0.
    """

    context: np.ndarray
    action: np.ndarray
    surrogate_reward: np.ndarray
    reward: np.ndarray
    obs_list: np.ndarray
    pscore: np.ndarray

    def __post_init__(self):
        """Check that every field has as many rows as ``context``.

        Uses an explicit ``ValueError`` rather than ``assert`` so the check
        is not stripped when Python runs with ``-O``.
        """
        n_rounds = self.context.shape[0]
        lengths = {
            "action": self.action.shape[0],
            "surrogate_reward": self.surrogate_reward.shape[0],
            "reward": self.reward.shape[0],
            "obs_list": self.obs_list.shape[0],
            "pscore": self.pscore.shape[0],
        }
        mismatched = {name: size for name, size in lengths.items() if size != n_rounds}
        if mismatched:
            raise ValueError(
                f"All inputs must have context.shape[0] == {n_rounds} rows, "
                f"but got {mismatched}"
            )

    def __getitem__(self, index):
        return (
            self.context[index],
            self.action[index],
            self.surrogate_reward[index],
            self.reward[index],
            self.obs_list[index],
            self.pscore[index],
        )

    def __len__(self):
        return self.context.shape[0]

@dataclass
class NNPolicyDataset_with_fs(torch.utils.data.Dataset):
    """PyTorch dataset for NNPolicyLearner that also carries ``s_sum``.

    Item ``i`` is the tuple
    ``(context[i], action[i], surrogate_reward[i], s_sum[i], reward[i], obs_list[i], pscore[i])``.

    Raises
    ------
    ValueError
        If the fields do not all have the same length along axis 0.
    """

    context: np.ndarray
    action: np.ndarray
    surrogate_reward: np.ndarray
    s_sum: np.ndarray
    reward: np.ndarray
    obs_list: np.ndarray
    pscore: np.ndarray

    def __post_init__(self):
        """Check that every field has as many rows as ``context``.

        Uses an explicit ``ValueError`` rather than ``assert`` so the check
        is not stripped when Python runs with ``-O``.
        """
        n_rounds = self.context.shape[0]
        lengths = {
            "action": self.action.shape[0],
            "surrogate_reward": self.surrogate_reward.shape[0],
            "s_sum": self.s_sum.shape[0],
            "reward": self.reward.shape[0],
            "obs_list": self.obs_list.shape[0],
            "pscore": self.pscore.shape[0],
        }
        mismatched = {name: size for name, size in lengths.items() if size != n_rounds}
        if mismatched:
            raise ValueError(
                f"All inputs must have context.shape[0] == {n_rounds} rows, "
                f"but got {mismatched}"
            )

    def __getitem__(self, index):
        return (
            self.context[index],
            self.action[index],
            self.surrogate_reward[index],
            self.s_sum[index],
            self.reward[index],
            self.obs_list[index],
            self.pscore[index],
        )

    def __len__(self):
        return self.context.shape[0]

@dataclass
class QFuncEstimator:
    """Q-function estimator based on a neural network.

    Note
    --------
    The neural network is implemented in PyTorch. It maps a context vector to one
    estimated Q-value per action, i.e., its output has shape (n_rounds, n_actions).

    Parameters
    -----------
    n_actions: int
        Number of actions.

    dim_context: int
        Number of dimensions of context vectors.

    hidden_layer_size: Tuple[int, ...], default = (150,)
        The i-th element specifies the size of the i-th hidden layer.

    activation: str, default='relu'
        Activation function.
        Must be one of the followings:
        - 'identity', the identity function, :math:`f(x) = x`.
        - 'logistic', the sigmoid function, :math:`f(x) = \\frac{1}{1 + \\exp(-x)}`.
        - 'tanh', the hyperbolic tangent function, :math:`f(x) = \\frac{\\exp(x) - \\exp(-x)}{\\exp(x) + \\exp(-x)}`.
        - 'relu', the rectified linear unit function, :math:`f(x) = \\max(0, x)`.
        - 'elu', the exponential linear unit function.

    solver: str, default='adam'
        Optimizer of the neural network.
        Must be one of the followings:
        - 'sgd', Stochastic Gradient Descent.
        - 'adam', Adam (Kingma and Ba 2014).
        - 'adagrad', Adagrad (Duchi et al. 2011).

    alpha: float, default=0.0001
        L2 penalty (passed to the optimizer as ``weight_decay``).

    batch_size: Union[int, str], default="auto"
        Batch size for SGD, Adagrad, and Adam.
        If "auto", the minimum of 200 and the number of samples is used.
        If integer, must be positive.

    learning_rate_init: float, default=0.0001
        Initial learning rate for SGD, Adagrad, and Adam.

    max_iter: int, default=200
        Number of epochs for SGD, Adagrad, and Adam.

    shuffle: bool, default=True
        Whether the data loaders shuffle samples.

    random_state: Optional[int], default=None
        If given, seeds PyTorch's global RNG via ``torch.manual_seed``
        (this affects weight initialization and the train/validation split).

    tol: float, default=1e-4
        Tolerance for training.
        When the training loss of a mini-batch fails to decrease by at least `tol`
        relative to the previous mini-batch for `n_iter_no_change` consecutive
        mini-batches, the current epoch is cut short.

    momentum: float, default=0.9
        Momentum for SGD.
        Must be in the range of [0., 1.].

    nesterovs_momentum: bool, default=True
        Whether to use Nesterov momentum.

    early_stopping: bool, default=False
        If True, `validation_fraction` of the training data is held out as
        validation data and its loss is recorded every epoch. Per-batch validation
        losses are compared using `tol` / `n_iter_no_change` in the same way as the
        training loss. This only shortens the validation pass; it does not stop training.

    validation_fraction: float, default=0.1
        Fraction of validation data when early stopping is used.
        Must be in the range of (0., 1.]. The held-out set always contains at
        least one sample and leaves at least one sample for training.

    beta_1: float, default=0.9
        Coefficient used for computing running average of gradient for Adam.
        Must be in the range of [0., 1.].

    beta_2: float, default=0.999
        Coefficient used for computing running average of the square of gradient for Adam.
        Must be in the range of [0., 1.].

    epsilon: float, default=1e-8
        Term for numerical stability in Adam and Adagrad.

    n_iter_no_change: int, default=10
        Number of consecutive non-improving mini-batches tolerated (see `tol`).

    Attributes (set by ``fit``)
    ----------------------------
    training_losses_: list of float
        Average training loss of each epoch.

    validation_losses_: list of float
        Average validation loss of each epoch (empty unless `early_stopping`).

    References
    ------------
    Dong .C. Liu and Jorge Nocedal.
    "On the Limited Memory Method for Large Scale Optimization.", 1989.

    Diederik P. Kingma and Jimmy Ba.
    "Adam: A Method for Stochastic Optimization.", 2014.

    John Duchi, Elad Hazan, and Yoram Singer.
    "Adaptive Subgradient Methods for Online Learning and Stochastic Optimization", 2011.

    """

    n_actions: int
    dim_context: int
    hidden_layer_size: Tuple[int, ...] = (150, )
    activation: str = "relu"
    solver: str = "adam"
    alpha: float = 0.0001
    batch_size: Union[int, str] = "auto"
    learning_rate_init: float = 0.0001
    max_iter: int = 200
    shuffle: bool = True
    random_state: Optional[int] = None
    tol: float = 1e-4
    momentum: float = 0.9
    nesterovs_momentum: bool = True
    early_stopping: bool = False
    validation_fraction: float = 0.1
    beta_1: float = 0.9
    beta_2: float = 0.999
    epsilon: float = 1e-8
    n_iter_no_change: int = 10

    def __post_init__(self) -> None:
        """Validate hyperparameters and build the network ``self.nn_model``."""
        check_scalar(self.dim_context, "dim_context", int, min_val=1)

        if not isinstance(self.hidden_layer_size, tuple) or any(
            [not isinstance(h, int) or h <= 0 for h in self.hidden_layer_size]
        ):
            raise ValueError(
                f"`hidden_layer_size` must be a tuple of positive integers, but {self.hidden_layer_size} is given"
            )

        if self.solver not in ("adagrad", "sgd", "adam"):
            raise ValueError(
                f"`solver` must be one of 'adam', 'adagrad', or 'sgd', but {self.solver} is given"
            )

        check_scalar(self.alpha, "alpha", float, min_val=0.0)

        if self.batch_size != "auto" and (
            not isinstance(self.batch_size, int) or self.batch_size <= 0
        ):
            raise ValueError(
                f"`batch_size` must be a positive integer or 'auto', but {self.batch_size} is given"
            )

        check_scalar(self.learning_rate_init, "learning_rate_init", float)
        if self.learning_rate_init <= 0.0:
            raise ValueError(
                f"`learning_rate_init`= {self.learning_rate_init}, must be > 0.0"
            )

        check_scalar(self.max_iter, "max_iter", int, min_val=1)

        if not isinstance(self.shuffle, bool):
            raise ValueError(f"`shuffle` must be a bool, but {self.shuffle} is given")

        check_scalar(self.tol, "tol", float)
        if self.tol <= 0.0:
            raise ValueError(f"`tol`= {self.tol}, must be > 0.0")

        check_scalar(self.momentum, "momentum", float, min_val=0.0, max_val=1.0)

        if not isinstance(self.nesterovs_momentum, bool):
            raise ValueError(
                f"`nesterovs_momentum` must be a bool, but {self.nesterovs_momentum} is given"
            )

        if not isinstance(self.early_stopping, bool):
            raise ValueError(
                f"`early_stopping` must be a bool, but {self.early_stopping} is given"
            )

        check_scalar(
            self.validation_fraction, "validation_fraction", float, max_val=1.0
        )
        if self.validation_fraction <= 0.0:
            raise ValueError(
                f"`validation_fraction`= {self.validation_fraction}, must be > 0.0"
            )

        check_scalar(self.beta_1, "beta_1", float, min_val=0.0, max_val=1.0)
        check_scalar(self.beta_2, "beta_2", float, min_val=0.0, max_val=1.0)
        check_scalar(self.epsilon, "epsilon", float, min_val=0.0)
        check_scalar(self.n_iter_no_change, "n_iter_no_change", int, min_val=1)

        if self.random_state is not None:
            self.random_ = check_random_state(self.random_state)
            # seeds the global torch RNG, so weight init below is reproducible
            torch.manual_seed(self.random_state)

        if self.activation == "identity":
            activation_layer = nn.Identity
        elif self.activation == "logistic":
            activation_layer = nn.Sigmoid
        elif self.activation == "tanh":
            activation_layer = nn.Tanh
        elif self.activation == "relu":
            activation_layer = nn.ReLU
        elif self.activation == "elu":
            activation_layer = nn.ELU
        else:
            raise ValueError(
                "`activation` must be one of 'identity', 'logistic', 'tanh', 'relu', or 'elu'"
                f", but {self.activation} is given"
            )

        # dim_context -> hidden layers (Linear + activation) -> n_actions Q-values
        layer_list = []
        input_size = self.dim_context

        for i, h in enumerate(self.hidden_layer_size):
            layer_list.append(("l{}".format(i), nn.Linear(input_size, h)))
            layer_list.append(("a{}".format(i), activation_layer()))
            input_size = h
        layer_list.append(("output", nn.Linear(input_size, self.n_actions)))

        self.nn_model = nn.Sequential(OrderedDict(layer_list))

    def _create_train_data_for_q_func_estimation(
        self,
        context: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        **kwargs,
    ) -> Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.DataLoader]]:
        """Create training (and, with early stopping, validation) data loaders.

        Parameters
        -----------
        context: array-like, shape (n_rounds, dim_context)
            Context vectors observed for each data, i.e., :math:`x_i`.

        action: array-like, shape (n_rounds,)
            Actions sampled by the logging/behavior policy for each data in logged bandit data, i.e., :math:`a_i`.

        reward: array-like, shape (n_rounds,)
            Rewards observed for each data in logged bandit data, i.e., :math:`r_i`.

        Returns
        --------
        (training_data_loader, validation_data_loader): Tuple[DataLoader, Optional[DataLoader]]
            Training and validation data loaders in PyTorch.
            The validation loader is None unless `early_stopping` is True.

        Raises
        -------
        ValueError
            If early stopping is requested with fewer than two samples.

        """
        if self.batch_size == "auto":
            batch_size_ = min(200, context.shape[0])
        else:
            check_scalar(self.batch_size, "batch_size", int, min_val=1)
            batch_size_ = self.batch_size
        context = context.astype('float32')
        dataset = QFuncEstimatorDataset(
            torch.from_numpy(context).float(),
            torch.from_numpy(action).long(),
            torch.from_numpy(reward).float(),
        )

        if self.early_stopping:
            n_samples = context.shape[0]
            if n_samples <= 1:
                raise ValueError(
                    f"the number of samples is too small ({n_samples}) to create validation data"
                )

            # Keep at least one sample on each side of the split so that neither
            # loader is empty (validation_fraction=1.0 would otherwise leave no training data).
            validation_size = min(
                max(int(n_samples * self.validation_fraction), 1), n_samples - 1
            )
            training_size = n_samples - validation_size
            training_dataset, validation_dataset = torch.utils.data.random_split(
                dataset, [training_size, validation_size]
            )
            training_data_loader = torch.utils.data.DataLoader(
                training_dataset,
                batch_size=batch_size_,
                shuffle=self.shuffle,
            )
            validation_data_loader = torch.utils.data.DataLoader(
                validation_dataset,
                batch_size=batch_size_,
                shuffle=self.shuffle,
            )

            return training_data_loader, validation_data_loader

        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size_,
            shuffle=self.shuffle,
        )

        return data_loader, None

    def fit(
        self,
        context: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
    ) -> None:
        """Fit the Q-function network on the given logged bandit data.

        The network is trained for `max_iter` epochs. The loss is the mean squared
        error between the observed reward and the predicted Q-value of the logged action.
        Per-epoch average losses are stored in ``self.training_losses_`` and
        ``self.validation_losses_``. The latter stays empty unless `early_stopping`.

        Parameters
        -----------
        context: array-like, shape (n_rounds, dim_context)
            Context vectors observed for each data, i.e., :math:`x_i`.

        action: array-like, shape (n_rounds,)
            Actions sampled by the logging/behavior policy for each data in logged bandit data, i.e., :math:`a_i`.

        reward: array-like, shape (n_rounds,)
            Rewards observed for each data in logged bandit data, i.e., :math:`r_i`.

        Raises
        -------
        ValueError
            If ``context.shape[1]`` differs from `dim_context`.

        """
        if context.shape[1] != self.dim_context:
            raise ValueError(
                "Expected `context.shape[1] == self.dim_context`, but found it False"
            )

        if self.solver == "sgd":
            optimizer = optim.SGD(
                self.nn_model.parameters(),
                lr=self.learning_rate_init,
                momentum=self.momentum,
                weight_decay=self.alpha,
                nesterov=self.nesterovs_momentum,
            )
        elif self.solver == "adagrad":
            optimizer = optim.Adagrad(
                self.nn_model.parameters(),
                lr=self.learning_rate_init,
                eps=self.epsilon,
                weight_decay=self.alpha,
            )
        elif self.solver == "adam":
            optimizer = optim.Adam(
                self.nn_model.parameters(),
                lr=self.learning_rate_init,
                betas=(self.beta_1, self.beta_2),
                eps=self.epsilon,
                weight_decay=self.alpha,
            )
        else:
            raise NotImplementedError(
                "`solver` must be one of 'adam', 'adagrad', or 'sgd'"
            )

        (
            training_data_loader,
            validation_data_loader,
        ) = self._create_train_data_for_q_func_estimation(
            context,
            action,
            reward,
        )

        n_not_improving_training = 0
        previous_training_loss = None
        n_not_improving_validation = 0
        previous_validation_loss = None
        training_losses = []
        validation_losses = []
        for _ in tqdm(np.arange(self.max_iter), desc="q-func learning"):
            self.nn_model.train()
            training_loss = []
            for x, a, r in training_data_loader:
                optimizer.zero_grad()
                # predicted Q-value of the logged action for each row of the batch
                q_hat = self.nn_model(x)[torch.arange(a.shape[0], dtype=torch.long), a]
                loss = mse_loss(r, q_hat)
                loss.backward()
                optimizer.step()

                loss_value = loss.item()
                training_loss.append(loss_value)
                if previous_training_loss is not None:
                    # a batch is "not improving" when the loss dropped by less than `tol`
                    if previous_training_loss - loss_value < self.tol:
                        n_not_improving_training += 1
                    else:
                        n_not_improving_training = 0
                if n_not_improving_training >= self.n_iter_no_change:
                    # ends the current epoch only; the epoch loop keeps running
                    break
                previous_training_loss = loss_value

            training_losses.append(sum(training_loss) / len(training_loss))

            if self.early_stopping:
                self.nn_model.eval()
                validation_loss = []
                # evaluation only: skip building the autograd graph
                with torch.no_grad():
                    for x, a, r in validation_data_loader:
                        q_hat = self.nn_model(x)[
                            torch.arange(a.shape[0], dtype=torch.long), a
                        ]
                        loss = mse_loss(r, q_hat)
                        loss_value = loss.item()
                        validation_loss.append(loss_value)
                        if previous_validation_loss is not None:
                            if previous_validation_loss - loss_value < self.tol:
                                n_not_improving_validation += 1
                            else:
                                n_not_improving_validation = 0
                        if n_not_improving_validation >= self.n_iter_no_change:
                            break
                        previous_validation_loss = loss_value
                validation_losses.append(sum(validation_loss) / len(validation_loss))

        # exposed so learning curves can be inspected/plotted by the caller
        self.training_losses_ = training_losses
        self.validation_losses_ = validation_losses

    def predict(
        self,
        context: torch.Tensor,
    ) -> torch.Tensor:
        """Estimate the Q-value of every action for new contexts.

        Parameters
        -----------
        context: Tensor, shape (n_rounds_of_new_data, dim_context)
            Context vectors for new data.

        Returns
        -----------
        estimated_q_values: Tensor, shape (n_rounds_of_new_data, n_actions)
            Estimated expected reward of each action given each context.

        Raises
        -------
        ValueError
            If ``context.shape[1]`` differs from `dim_context`.

        """
        check_tensor(tensor=context, name="context", expected_dim=2)
        if context.shape[1] != self.dim_context:
            raise ValueError(
                "Expected `context.shape[1] == self.dim_context`, but found it False"
            )

        self.nn_model.eval()
        return self.nn_model(context)

@dataclass
class QFuncEstimatorDataset(torch.utils.data.Dataset):
    """PyTorch dataset for QFuncEstimator.

    Each item is ``(feature, action, reward)`` taken at the same row. The fields are
    annotated as arrays; the estimators in this module pass ``torch.Tensor`` objects.
    Anything exposing ``.shape[0]`` and integer indexing works.
    """

    feature: np.ndarray
    action: np.ndarray
    reward: np.ndarray

    def __post_init__(self):
        """Check that all fields have the same length along the first axis.

        Raises
        ------
        ValueError
            If the fields disagree on their number of rows.

        """
        # An explicit check (instead of ``assert``) keeps the validation active under ``python -O``.
        n_rows = {
            "feature": self.feature.shape[0],
            "action": self.action.shape[0],
            "reward": self.reward.shape[0],
        }
        if len(set(n_rows.values())) > 1:
            raise ValueError(
                f"all fields must have the same number of rows, but got {n_rows}"
            )

    def __getitem__(self, index):
        """Return ``(feature, action, reward)`` at ``index``."""
        return (
            self.feature[index],
            self.action[index],
            self.reward[index],
        )

    def __len__(self):
        """Number of rows (first dimension of ``feature``)."""
        return self.feature.shape[0]


@dataclass
class QFuncEstimatorWithS:
    """Q-function estimator based on a neural network that takes surrogate rewards as input.

    Note
    --------
    The neural network is implemented in PyTorch. It maps a surrogate-reward vector
    of width `s_dim` to one estimated Q-value per action. Its output has shape
    (n_rounds, n_actions). The context vectors are not used by the network.

    Parameters
    -----------
    n_actions: int
        Number of actions.

    dim_context: int
        Number of dimensions of context vectors. It is validated at construction,
        then overwritten with `s_dim`, because the network consumes surrogate rewards.

    s_dim: int
        Number of dimensions of the surrogate-reward vectors fed to the network.

    hidden_layer_size: Tuple[int, ...], default = (150,)
        The i-th element specifies the size of the i-th hidden layer.

    activation: str, default='relu'
        Activation function.
        Must be one of the followings:
        - 'identity', the identity function, :math:`f(x) = x`.
        - 'logistic', the sigmoid function, :math:`f(x) = \\frac{1}{1 + \\exp(-x)}`.
        - 'tanh', the hyperbolic tangent function, :math:`f(x) = \\frac{\\exp(x) - \\exp(-x)}{\\exp(x) + \\exp(-x)}`.
        - 'relu', the rectified linear unit function, :math:`f(x) = \\max(0, x)`.
        - 'elu', the exponential linear unit function.

    solver: str, default='adam'
        Optimizer of the neural network.
        Must be one of the followings:
        - 'sgd', Stochastic Gradient Descent.
        - 'adam', Adam (Kingma and Ba 2014).
        - 'adagrad', Adagrad (Duchi et al. 2011).

    alpha: float, default=0.0001
        L2 penalty (passed to the optimizer as ``weight_decay``).

    batch_size: Union[int, str], default="auto"
        Batch size for SGD, Adagrad, and Adam.
        If "auto", the minimum of 200 and the number of samples is used.
        If integer, must be positive.

    learning_rate_init: float, default=0.0001
        Initial learning rate for SGD, Adagrad, and Adam.

    max_iter: int, default=200
        Number of epochs for SGD, Adagrad, and Adam.

    shuffle: bool, default=True
        Whether the data loaders shuffle samples.

    random_state: Optional[int], default=None
        If given, seeds PyTorch's global RNG via ``torch.manual_seed``
        (this affects weight initialization and the train/validation split).

    tol: float, default=1e-4
        Tolerance for training.
        When the training loss of a mini-batch fails to decrease by at least `tol`
        relative to the previous mini-batch for `n_iter_no_change` consecutive
        mini-batches, the current epoch is cut short.

    momentum: float, default=0.9
        Momentum for SGD.
        Must be in the range of [0., 1.].

    nesterovs_momentum: bool, default=True
        Whether to use Nesterov momentum.

    early_stopping: bool, default=False
        If True, `validation_fraction` of the training data is held out as
        validation data and its loss is recorded every epoch. Per-batch validation
        losses are compared using `tol` / `n_iter_no_change` in the same way as the
        training loss. This only shortens the validation pass; it does not stop training.

    validation_fraction: float, default=0.1
        Fraction of validation data when early stopping is used.
        Must be in the range of (0., 1.]. The held-out set always contains at
        least one sample and leaves at least one sample for training.

    beta_1: float, default=0.9
        Coefficient used for computing running average of gradient for Adam.
        Must be in the range of [0., 1.].

    beta_2: float, default=0.999
        Coefficient used for computing running average of the square of gradient for Adam.
        Must be in the range of [0., 1.].

    epsilon: float, default=1e-8
        Term for numerical stability in Adam and Adagrad.

    n_iter_no_change: int, default=10
        Number of consecutive non-improving mini-batches tolerated (see `tol`).

    confidence_threshold: float, default=0.9
        Stored on the instance; not read by any method of this class.

    Attributes (set by ``fit``)
    ----------------------------
    training_losses_: list of float
        Average training loss of each epoch.

    validation_losses_: list of float
        Average validation loss of each epoch (empty unless `early_stopping`).

    References
    ------------
    Dong .C. Liu and Jorge Nocedal.
    "On the Limited Memory Method for Large Scale Optimization.", 1989.

    Diederik P. Kingma and Jimmy Ba.
    "Adam: A Method for Stochastic Optimization.", 2014.

    John Duchi, Elad Hazan, and Yoram Singer.
    "Adaptive Subgradient Methods for Online Learning and Stochastic Optimization", 2011.

    """

    n_actions: int
    dim_context: int
    s_dim: int
    hidden_layer_size: Tuple[int, ...] = (150, )
    activation: str = "relu"
    solver: str = "adam"
    alpha: float = 0.0001
    batch_size: Union[int, str] = "auto"
    learning_rate_init: float = 0.0001
    max_iter: int = 200
    shuffle: bool = True
    random_state: Optional[int] = None
    tol: float = 1e-4
    momentum: float = 0.9
    nesterovs_momentum: bool = True
    early_stopping: bool = False
    validation_fraction: float = 0.1
    beta_1: float = 0.9
    beta_2: float = 0.999
    epsilon: float = 1e-8
    n_iter_no_change: int = 10
    confidence_threshold: float = 0.9

    def __post_init__(self) -> None:
        """Validate hyperparameters and build the network ``self.nn_model``."""
        check_scalar(self.dim_context, "dim_context", int, min_val=1)
        # the network input width; validated like in QFuncEstimatorWithS_sol
        check_scalar(self.s_dim, "s_dim", int, min_val=1)

        if not isinstance(self.hidden_layer_size, tuple) or any(
            [not isinstance(h, int) or h <= 0 for h in self.hidden_layer_size]
        ):
            raise ValueError(
                f"`hidden_layer_size` must be a tuple of positive integers, but {self.hidden_layer_size} is given"
            )

        if self.solver not in ("adagrad", "sgd", "adam"):
            raise ValueError(
                f"`solver` must be one of 'adam', 'adagrad', or 'sgd', but {self.solver} is given"
            )

        check_scalar(self.alpha, "alpha", float, min_val=0.0)

        if self.batch_size != "auto" and (
            not isinstance(self.batch_size, int) or self.batch_size <= 0
        ):
            raise ValueError(
                f"`batch_size` must be a positive integer or 'auto', but {self.batch_size} is given"
            )

        check_scalar(self.learning_rate_init, "learning_rate_init", float)
        if self.learning_rate_init <= 0.0:
            raise ValueError(
                f"`learning_rate_init`= {self.learning_rate_init}, must be > 0.0"
            )

        check_scalar(self.max_iter, "max_iter", int, min_val=1)

        if not isinstance(self.shuffle, bool):
            raise ValueError(f"`shuffle` must be a bool, but {self.shuffle} is given")

        check_scalar(self.tol, "tol", float)
        if self.tol <= 0.0:
            raise ValueError(f"`tol`= {self.tol}, must be > 0.0")

        check_scalar(self.momentum, "momentum", float, min_val=0.0, max_val=1.0)

        if not isinstance(self.nesterovs_momentum, bool):
            raise ValueError(
                f"`nesterovs_momentum` must be a bool, but {self.nesterovs_momentum} is given"
            )

        if not isinstance(self.early_stopping, bool):
            raise ValueError(
                f"`early_stopping` must be a bool, but {self.early_stopping} is given"
            )

        check_scalar(
            self.validation_fraction, "validation_fraction", float, max_val=1.0
        )
        if self.validation_fraction <= 0.0:
            raise ValueError(
                f"`validation_fraction`= {self.validation_fraction}, must be > 0.0"
            )

        check_scalar(self.beta_1, "beta_1", float, min_val=0.0, max_val=1.0)
        check_scalar(self.beta_2, "beta_2", float, min_val=0.0, max_val=1.0)
        check_scalar(self.epsilon, "epsilon", float, min_val=0.0)
        check_scalar(self.n_iter_no_change, "n_iter_no_change", int, min_val=1)

        if self.random_state is not None:
            self.random_ = check_random_state(self.random_state)
            # seeds the global torch RNG, so weight init below is reproducible
            torch.manual_seed(self.random_state)

        if self.activation == "identity":
            activation_layer = nn.Identity
        elif self.activation == "logistic":
            activation_layer = nn.Sigmoid
        elif self.activation == "tanh":
            activation_layer = nn.Tanh
        elif self.activation == "relu":
            activation_layer = nn.ReLU
        elif self.activation == "elu":
            activation_layer = nn.ELU
        else:
            raise ValueError(
                "`activation` must be one of 'identity', 'logistic', 'tanh', 'relu', or 'elu'"
                f", but {self.activation} is given"
            )

        # s_dim -> hidden layers (Linear + activation) -> n_actions Q-values
        layer_list = []
        input_size = self.s_dim

        for i, h in enumerate(self.hidden_layer_size):
            layer_list.append(("l{}".format(i), nn.Linear(input_size, h)))
            layer_list.append(("a{}".format(i), activation_layer()))
            input_size = h
        layer_list.append(("output", nn.Linear(input_size, self.n_actions)))

        self.nn_model = nn.Sequential(OrderedDict(layer_list))
        # The network consumes surrogate rewards, so `dim_context` is overwritten
        # with their width. It is kept as an attribute for callers that read it.
        self.dim_context = self.s_dim

    def _create_train_data_for_q_func_estimation(
        self,
        context: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        **kwargs,
    ) -> Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.DataLoader]]:
        """Create training (and, with early stopping, validation) data loaders.

        Parameters
        -----------
        context: array-like, shape (n_rounds, n_features)
            Network input for each data. ``fit`` passes the surrogate rewards here.

        action: array-like, shape (n_rounds,)
            Actions sampled by the logging/behavior policy for each data in logged bandit data, i.e., :math:`a_i`.

        reward: array-like, shape (n_rounds,)
            Rewards observed for each data in logged bandit data, i.e., :math:`r_i`.

        Returns
        --------
        (training_data_loader, validation_data_loader): Tuple[DataLoader, Optional[DataLoader]]
            Training and validation data loaders in PyTorch.
            The validation loader is None unless `early_stopping` is True.

        Raises
        -------
        ValueError
            If early stopping is requested with fewer than two samples.

        """
        if self.batch_size == "auto":
            batch_size_ = min(200, context.shape[0])
        else:
            check_scalar(self.batch_size, "batch_size", int, min_val=1)
            batch_size_ = self.batch_size
        context = context.astype('float32')
        dataset = QFuncEstimatorDataset(
            torch.from_numpy(context).float(),
            torch.from_numpy(action).long(),
            torch.from_numpy(reward).float(),
        )

        if self.early_stopping:
            n_samples = context.shape[0]
            if n_samples <= 1:
                raise ValueError(
                    f"the number of samples is too small ({n_samples}) to create validation data"
                )

            # Keep at least one sample on each side of the split so that neither
            # loader is empty (validation_fraction=1.0 would otherwise leave no training data).
            validation_size = min(
                max(int(n_samples * self.validation_fraction), 1), n_samples - 1
            )
            training_size = n_samples - validation_size
            training_dataset, validation_dataset = torch.utils.data.random_split(
                dataset, [training_size, validation_size]
            )
            training_data_loader = torch.utils.data.DataLoader(
                training_dataset,
                batch_size=batch_size_,
                shuffle=self.shuffle,
            )
            validation_data_loader = torch.utils.data.DataLoader(
                validation_dataset,
                batch_size=batch_size_,
                shuffle=self.shuffle,
            )

            return training_data_loader, validation_data_loader

        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size_,
            shuffle=self.shuffle,
        )

        return data_loader, None

    def fit(
        self,
        context: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        surrogate_reward: np.ndarray,
        obs_list: np.ndarray,
    ) -> None:
        """Fit the Q-function network, using surrogate rewards as the network input.

        The network is trained for `max_iter` epochs. The loss is the mean squared
        error between the observed reward and the predicted Q-value of the logged action.
        Per-epoch average losses are stored in ``self.training_losses_`` and
        ``self.validation_losses_``. The latter stays empty unless `early_stopping`.

        Parameters
        -----------
        context: array-like, shape (n_rounds, dim_context)
            Context vectors observed for each data, i.e., :math:`x_i`.
            Not used by this estimator; accepted for interface compatibility.

        action: array-like, shape (n_rounds,)
            Actions sampled by the logging/behavior policy for each data in logged bandit data, i.e., :math:`a_i`.

        reward: array-like, shape (n_rounds,)
            Rewards observed for each data in logged bandit data, i.e., :math:`r_i`.

        surrogate_reward: array-like, shape (n_rounds, s_dim)
            Surrogate rewards; these are the features fed to the network.

        obs_list: array-like
            Not used by this estimator; accepted for interface compatibility.

        Raises
        -------
        ValueError
            If `surrogate_reward` is not 2-dimensional or its width differs from `s_dim`.

        """
        if surrogate_reward.ndim != 2:
            raise ValueError(
                f"`surrogate_reward` must be 2-dimensional, but its shape is {surrogate_reward.shape}"
            )
        if surrogate_reward.shape[1] != self.s_dim:
            raise ValueError(
                "Expected `surrogate_reward.shape[1] == self.s_dim`, but found it False"
            )

        if self.solver == "sgd":
            optimizer = optim.SGD(
                self.nn_model.parameters(),
                lr=self.learning_rate_init,
                momentum=self.momentum,
                weight_decay=self.alpha,
                nesterov=self.nesterovs_momentum,
            )
        elif self.solver == "adagrad":
            optimizer = optim.Adagrad(
                self.nn_model.parameters(),
                lr=self.learning_rate_init,
                eps=self.epsilon,
                weight_decay=self.alpha,
            )
        elif self.solver == "adam":
            optimizer = optim.Adam(
                self.nn_model.parameters(),
                lr=self.learning_rate_init,
                betas=(self.beta_1, self.beta_2),
                eps=self.epsilon,
                weight_decay=self.alpha,
            )
        else:
            raise NotImplementedError(
                "`solver` must be one of 'adam', 'adagrad', or 'sgd'"
            )

        (
            training_data_loader,
            validation_data_loader,
        ) = self._create_train_data_for_q_func_estimation(
            surrogate_reward,
            action,
            reward,
        )

        n_not_improving_training = 0
        previous_training_loss = None
        n_not_improving_validation = 0
        previous_validation_loss = None
        training_losses = []
        validation_losses = []
        for _ in tqdm(np.arange(self.max_iter), desc="q-func learning"):
            self.nn_model.train()
            training_loss = []
            for x, a, r in training_data_loader:
                optimizer.zero_grad()
                # predicted Q-value of the logged action for each row of the batch
                q_hat = self.nn_model(x)[torch.arange(a.shape[0], dtype=torch.long), a]
                loss = mse_loss(r, q_hat)
                loss.backward()
                optimizer.step()

                loss_value = loss.item()
                training_loss.append(loss_value)
                if previous_training_loss is not None:
                    # a batch is "not improving" when the loss dropped by less than `tol`
                    if previous_training_loss - loss_value < self.tol:
                        n_not_improving_training += 1
                    else:
                        n_not_improving_training = 0
                if n_not_improving_training >= self.n_iter_no_change:
                    # ends the current epoch only; the epoch loop keeps running
                    break
                previous_training_loss = loss_value

            training_losses.append(sum(training_loss) / len(training_loss))

            if self.early_stopping:
                self.nn_model.eval()
                validation_loss = []
                # evaluation only: skip building the autograd graph
                with torch.no_grad():
                    for x, a, r in validation_data_loader:
                        q_hat = self.nn_model(x)[
                            torch.arange(a.shape[0], dtype=torch.long), a
                        ]
                        loss = mse_loss(r, q_hat)
                        loss_value = loss.item()
                        validation_loss.append(loss_value)
                        if previous_validation_loss is not None:
                            if previous_validation_loss - loss_value < self.tol:
                                n_not_improving_validation += 1
                            else:
                                n_not_improving_validation = 0
                        if n_not_improving_validation >= self.n_iter_no_change:
                            break
                        previous_validation_loss = loss_value
                validation_losses.append(sum(validation_loss) / len(validation_loss))

        # exposed so learning curves can be inspected/plotted by the caller
        self.training_losses_ = training_losses
        self.validation_losses_ = validation_losses

    def predict(self, context: torch.Tensor, surrogate_reward: torch.Tensor) -> torch.Tensor:
        """Estimate the Q-value of every action from surrogate rewards.

        Parameters
        -----------
        context: Tensor
            Not used by this estimator; accepted for interface compatibility.

        surrogate_reward: Tensor, shape (n_rounds_of_new_data, s_dim)
            Surrogate rewards fed to the network.

        Returns
        -----------
        estimated_q_values: Tensor, shape (n_rounds_of_new_data, n_actions)
            Estimated expected reward of each action.

        Raises
        -------
        ValueError
            If ``surrogate_reward.shape[1]`` differs from `s_dim`.

        """
        check_tensor(tensor=surrogate_reward, name="surrogate_reward", expected_dim=2)
        if surrogate_reward.shape[1] != self.s_dim:
            raise ValueError(
                "Expected `surrogate_reward.shape[1] == self.s_dim`, but found it False"
            )

        self.nn_model.eval()
        return self.nn_model(surrogate_reward)

@dataclass
class QFuncEstimatorWithS_sol:
    """Q-function estimator based on a neural network using surrogate reward as input.

    The network maps a surrogate-reward vector of length ``s_dim`` to one output
    per action (``n_actions`` outputs). ``context`` is accepted by :meth:`fit` and
    :meth:`predict` for interface parity with the other Q-function estimators,
    but it is never fed to the network.

    Parameters
    ----------
    n_actions: int
        Number of actions, i.e., the width of the output layer.

    dim_context: int
        Dimension of the context vectors. Stored but not used by this estimator.

    s_dim: int
        Dimension of the surrogate-reward vectors, i.e., the width of the input layer.

    hidden_layer_size: Tuple[int, ...], default=(150, )
        Width of each hidden layer; every hidden Linear layer is followed by `activation`.

    activation: str, default="relu"
        One of 'identity', 'logistic', 'tanh', 'relu', or 'elu'.

    solver: str, default="adam"
        Optimizer: one of 'adam', 'adagrad', or 'sgd'.

    alpha: float, default=0.0001
        Passed to the optimizer as ``weight_decay``.

    batch_size: int or "auto", default="auto"
        Mini-batch size; "auto" means ``min(200, n_rounds)``.

    learning_rate_init: float, default=0.0001
        Learning rate of the optimizer.

    max_iter: int, default=200
        Maximum number of epochs.

    shuffle: bool, default=True
        Passed to the PyTorch data loaders.

    random_state: Optional[int], default=None
        When given, seeds ``torch`` before the network is initialized.

    tol: float, default=1e-4
        An epoch counts as "not improving" unless its mean loss is at least
        ``tol`` lower than the previous epoch's mean loss.

    momentum, nesterovs_momentum:
        Used by the 'sgd' solver only.

    early_stopping: bool, default=False
        If True, hold out ``validation_fraction`` of the data and also stop when
        the validation loss has not improved for ``n_iter_no_change`` epochs.

    validation_fraction: float, default=0.1
        Fraction of samples held out for validation when ``early_stopping`` is True.

    beta_1, beta_2:
        Used by the 'adam' solver only.

    epsilon: float, default=1e-8
        Passed as ``eps`` to the 'adam' and 'adagrad' optimizers.

    n_iter_no_change: int, default=10
        Number of consecutive non-improving epochs after which training stops.

    confidence_threshold: float, default=0.9
        Not used by this estimator.

    After :meth:`fit`, ``training_losses_`` and ``validation_losses_`` hold the
    mean loss of every epoch (the latter is empty unless ``early_stopping``).
    """

    n_actions: int
    dim_context: int
    s_dim: int
    hidden_layer_size: Tuple[int, ...] = (150, )
    activation: str = "relu"
    solver: str = "adam"
    alpha: float = 0.0001
    batch_size: Union[int, str] = "auto"
    learning_rate_init: float = 0.0001
    max_iter: int = 200
    shuffle: bool = True
    random_state: Optional[int] = None
    tol: float = 1e-4
    momentum: float = 0.9
    nesterovs_momentum: bool = True
    early_stopping: bool = False
    validation_fraction: float = 0.1
    beta_1: float = 0.9
    beta_2: float = 0.999
    epsilon: float = 1e-8
    n_iter_no_change: int = 10
    confidence_threshold: float = 0.9

    def __post_init__(self) -> None:
        """Validate hyperparameters and build the network."""
        check_scalar(self.s_dim, "s_dim", int, min_val=1)

        if not isinstance(self.hidden_layer_size, tuple) or any(
            [not isinstance(h, int) or h <= 0 for h in self.hidden_layer_size]
        ):
            raise ValueError(
                f"`hidden_layer_size` must be a tuple of positive integers, but {self.hidden_layer_size} is given"
            )

        if self.solver not in ("adagrad", "sgd", "adam"):
            raise ValueError(
                f"`solver` must be one of 'adam', 'adagrad', or 'sgd', but {self.solver} is given"
            )

        check_scalar(self.alpha, "alpha", float, min_val=0.0)

        if self.batch_size != "auto" and (
            not isinstance(self.batch_size, int) or self.batch_size <= 0
        ):
            raise ValueError(
                f"`batch_size` must be a positive integer or 'auto', but {self.batch_size} is given"
            )

        check_scalar(self.learning_rate_init, "learning_rate_init", float)
        if self.learning_rate_init <= 0.0:
            raise ValueError(
                f"`learning_rate_init`= {self.learning_rate_init}, must be > 0.0"
            )

        check_scalar(self.max_iter, "max_iter", int, min_val=1)

        if not isinstance(self.shuffle, bool):
            raise ValueError(f"`shuffle` must be a bool, but {self.shuffle} is given")

        check_scalar(self.tol, "tol", float)
        if self.tol <= 0.0:
            raise ValueError(f"`tol`= {self.tol}, must be > 0.0")

        check_scalar(self.momentum, "momentum", float, min_val=0.0, max_val=1.0)

        if not isinstance(self.nesterovs_momentum, bool):
            raise ValueError(
                f"`nesterovs_momentum` must be a bool, but {self.nesterovs_momentum} is given"
            )

        if not isinstance(self.early_stopping, bool):
            raise ValueError(
                f"`early_stopping` must be a bool, but {self.early_stopping} is given"
            )

        check_scalar(
            self.validation_fraction, "validation_fraction", float, max_val=1.0
        )
        if self.validation_fraction <= 0.0:
            raise ValueError(
                f"`validation_fraction`= {self.validation_fraction}, must be > 0.0"
            )

        check_scalar(self.beta_1, "beta_1", float, min_val=0.0, max_val=1.0)
        check_scalar(self.beta_2, "beta_2", float, min_val=0.0, max_val=1.0)
        check_scalar(self.epsilon, "epsilon", float, min_val=0.0)
        check_scalar(self.n_iter_no_change, "n_iter_no_change", int, min_val=1)

        # Seed before the layers are created so weight initialization is reproducible.
        if self.random_state is not None:
            self.random_ = check_random_state(self.random_state)
            torch.manual_seed(self.random_state)

        if self.activation == "identity":
            activation_layer = nn.Identity
        elif self.activation == "logistic":
            activation_layer = nn.Sigmoid
        elif self.activation == "tanh":
            activation_layer = nn.Tanh
        elif self.activation == "relu":
            activation_layer = nn.ReLU
        elif self.activation == "elu":
            activation_layer = nn.ELU
        else:
            raise ValueError(
                "`activation` must be one of 'identity', 'logistic', 'tanh', 'relu', or 'elu'"
                f", but {self.activation} is given"
            )

        layer_list = []
        input_size = self.s_dim

        for i, h in enumerate(self.hidden_layer_size):
            layer_list.append(("l{}".format(i), nn.Linear(input_size, h)))
            layer_list.append(("a{}".format(i), activation_layer()))
            input_size = h
        layer_list.append(("output", nn.Linear(input_size, self.n_actions)))

        self.nn_model = nn.Sequential(OrderedDict(layer_list))

    def _create_train_data_for_q_func_estimation(
        self,
        surrogate_reward: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        **kwargs,
    ) -> Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.DataLoader]]:
        """Create training data for off-policy learning.

        Parameters
        -----------
        surrogate_reward: array-like, shape (n_rounds, s_dim)
            Surrogate reward vectors observed for each data, i.e., :math:`s_i`.

        action: array-like, shape (n_rounds,)
            Actions sampled by the logging/behavior policy for each data in logged bandit data, i.e., :math:`a_i`.

        reward: array-like, shape (n_rounds,)
            Rewards observed for each data in logged bandit data, i.e., :math:`r_i`.

        Returns
        --------
        (training_data_loader, validation_data_loader): Tuple[DataLoader, Optional[DataLoader]]
            Training and validation data loaders in PyTorch. The validation
            loader is None unless ``early_stopping`` is True.

        Raises
        ------
        ValueError
            If ``early_stopping`` is True and there are fewer than two samples.
        """
        n_rounds = surrogate_reward.shape[0]
        if self.batch_size == "auto":
            batch_size_ = min(200, n_rounds)
        else:
            check_scalar(self.batch_size, "batch_size", int, min_val=1)
            batch_size_ = self.batch_size
        surrogate_reward = surrogate_reward.astype('float32')
        dataset = QFuncEstimatorDataset(
            torch.from_numpy(surrogate_reward).float(),
            torch.from_numpy(action).long(),
            torch.from_numpy(reward).float(),
        )

        if self.early_stopping:
            if n_rounds <= 1:
                raise ValueError(
                    f"the number of samples is too small ({surrogate_reward.shape[0]}) to create validation data"
                )

            # Keep at least one sample on each side of the split: validation_fraction
            # may be as large as 1.0, which would otherwise leave no training data.
            validation_size = min(
                max(int(n_rounds * self.validation_fraction), 1), n_rounds - 1
            )
            training_size = n_rounds - validation_size
            training_dataset, validation_dataset = torch.utils.data.random_split(
                dataset, [training_size, validation_size]
            )
            training_data_loader = torch.utils.data.DataLoader(
                training_dataset,
                batch_size=batch_size_,
                shuffle=self.shuffle,
            )
            validation_data_loader = torch.utils.data.DataLoader(
                validation_dataset,
                batch_size=batch_size_,
                shuffle=self.shuffle,
            )

            return training_data_loader, validation_data_loader

        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size_,
            shuffle=self.shuffle,
        )

        return data_loader, None

    def _make_optimizer(self) -> optim.Optimizer:
        """Instantiate the optimizer selected by ``solver`` over the network parameters."""
        if self.solver == "sgd":
            return optim.SGD(
                self.nn_model.parameters(),
                lr=self.learning_rate_init,
                momentum=self.momentum,
                weight_decay=self.alpha,
                nesterov=self.nesterovs_momentum,
            )
        if self.solver == "adagrad":
            return optim.Adagrad(
                self.nn_model.parameters(),
                lr=self.learning_rate_init,
                eps=self.epsilon,
                weight_decay=self.alpha,
            )
        if self.solver == "adam":
            return optim.Adam(
                self.nn_model.parameters(),
                lr=self.learning_rate_init,
                betas=(self.beta_1, self.beta_2),
                eps=self.epsilon,
                weight_decay=self.alpha,
            )
        raise NotImplementedError(
            "`solver` must be one of 'adam', 'adagrad', or 'sgd'"
        )

    def _train_one_epoch(self, data_loader, optimizer) -> float:
        """Run one pass of mini-batch updates over ``data_loader`` and return the mean batch loss."""
        self.nn_model.train()
        batch_losses = []
        for s, a, r in data_loader:
            optimizer.zero_grad()
            # Only the output of the logged action is regressed onto the observed reward.
            q_hat = self.nn_model(s)[torch.arange(a.shape[0], dtype=torch.long), a]
            loss = mse_loss(r, q_hat)
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())
        return sum(batch_losses) / len(batch_losses)

    def _evaluate(self, data_loader) -> float:
        """Return the mean batch MSE of the logged action's output over ``data_loader`` (no gradients)."""
        self.nn_model.eval()
        batch_losses = []
        with torch.no_grad():
            for s, a, r in data_loader:
                q_hat = self.nn_model(s)[torch.arange(a.shape[0], dtype=torch.long), a]
                batch_losses.append(mse_loss(r, q_hat).item())
        return sum(batch_losses) / len(batch_losses)

    def fit(
        self,
        context: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        surrogate_reward: np.ndarray,
        obs_list: np.ndarray,
    ) -> None:
        """Fit the network by regressing the logged action's output onto the reward.

        Parameters
        ----------
        context: array-like
            Not used by this estimator.

        action: array-like, shape (n_rounds,)
            Logged actions.

        reward: array-like, shape (n_rounds,)
            Logged rewards.

        surrogate_reward: array-like, shape (n_rounds, s_dim)
            Surrogate reward vectors, the network's input.

        obs_list: array-like
            Not used by this estimator: every row is used for training.
            NOTE(review): the sibling two-tower estimator trains only on
            ``obs_list == 1`` rows -- confirm that using all rows is intended here.

        Raises
        ------
        ValueError
            If ``surrogate_reward.shape[1] != self.s_dim``.
        """
        if surrogate_reward.shape[1] != self.s_dim:
            raise ValueError(
                "Expected `surrogate_reward.shape[1] == self.s_dim`, but found it False"
            )

        optimizer = self._make_optimizer()
        (
            training_data_loader,
            validation_data_loader,
        ) = self._create_train_data_for_q_func_estimation(
            surrogate_reward,
            action,
            reward,
        )

        training_losses = []
        validation_losses = []
        previous_training_loss = float("inf")
        n_not_improving_training = 0
        previous_validation_loss = float("inf")
        n_not_improving_validation = 0
        for _ in tqdm(np.arange(self.max_iter), desc="q-func learning"):
            average_training_loss = self._train_one_epoch(training_data_loader, optimizer)
            training_losses.append(average_training_loss)
            # Improvement means the loss dropped by at least `tol`; stopping is
            # decided per epoch so that `break` ends training, not just one epoch.
            if previous_training_loss - average_training_loss < self.tol:
                n_not_improving_training += 1
            else:
                n_not_improving_training = 0
            previous_training_loss = average_training_loss
            if n_not_improving_training >= self.n_iter_no_change:
                break

            # The validation loader exists only when early_stopping is True.
            if validation_data_loader is not None:
                average_validation_loss = self._evaluate(validation_data_loader)
                validation_losses.append(average_validation_loss)
                if previous_validation_loss - average_validation_loss < self.tol:
                    n_not_improving_validation += 1
                else:
                    n_not_improving_validation = 0
                previous_validation_loss = average_validation_loss
                if n_not_improving_validation >= self.n_iter_no_change:
                    break

        # Learning curves, kept for inspection or plotting after fitting.
        self.training_losses_ = training_losses
        self.validation_losses_ = validation_losses

    def predict(self, context: torch.Tensor, surrogate_reward: torch.Tensor) -> torch.Tensor:
        """Return the network output (one column per action) for ``surrogate_reward``.

        Parameters
        ----------
        context: Tensor
            Not used by this estimator.

        surrogate_reward: Tensor, shape (n_rounds_of_new_data, s_dim)
            Surrogate reward vectors for new data.

        Raises
        ------
        ValueError
            If ``surrogate_reward.shape[1] != self.s_dim``.
        """
        check_tensor(tensor=surrogate_reward, name="surrogate_reward", expected_dim=2)
        if surrogate_reward.shape[1] != self.s_dim:
            raise ValueError(
                "Expected `surrogate_reward.shape[1] == self.s_dim`, but found it False"
            )

        self.nn_model.eval()
        return self.nn_model(surrogate_reward)


@dataclass
class QFuncEstimatorWithS_new:
    """Two-tower neural Q-function estimator using both context and surrogate reward.

    ``context`` and ``surrogate_reward`` are encoded by separate MLPs
    (``model_context`` sized by ``hidden_layer_size`` and ``model_s`` sized by
    ``hidden_layer_size_s``). Their outputs are concatenated and fed to a head
    with one hidden layer of half the concatenated width, which produces one
    estimated reward per action (``n_actions`` outputs).

    :meth:`fit` trains only on rows with ``obs_list == 1``. ``solver`` selects
    the torch optimizer, ``alpha`` is passed as ``weight_decay``, ``tol`` and
    ``n_iter_no_change`` control stopping on the per-epoch mean loss, and
    ``confidence_threshold`` is used by :meth:`predict_and_filter`. When
    ``random_state`` is given, ``torch`` is seeded before the layers are built.

    After training, ``training_losses_`` and ``validation_losses_`` hold the
    mean loss of every epoch of the most recent :meth:`train_model` call.
    """

    n_actions: int
    dim_context: int
    s_dim: int
    hidden_layer_size: Tuple[int, ...] = (100, )
    hidden_layer_size_s: Tuple[int, ...] = (100, )
    activation: str = "relu"
    solver: str = "adam"
    alpha: float = 0.0001
    batch_size: Union[int, str] = "auto"
    learning_rate_init: float = 0.0001
    max_iter: int = 200
    shuffle: bool = True
    random_state: Optional[int] = None
    tol: float = 1e-4
    momentum: float = 0.9
    nesterovs_momentum: bool = True
    early_stopping: bool = False
    validation_fraction: float = 0.1
    beta_1: float = 0.9
    beta_2: float = 0.999
    epsilon: float = 1e-8
    n_iter_no_change: int = 10
    confidence_threshold: float = 0.9

    def __post_init__(self) -> None:
        """Check the tower sizes, seed torch if requested, and build the network.

        Raises
        ------
        ValueError
            If ``hidden_layer_size`` or ``hidden_layer_size_s`` is empty.
        """
        for name in ("hidden_layer_size", "hidden_layer_size_s"):
            if len(getattr(self, name)) == 0:
                raise ValueError(f"`{name}` must contain at least one layer size")
        # Seed before the layers are created so weight initialization (and the
        # later train/validation split) is reproducible.
        if self.random_state is not None:
            torch.manual_seed(self.random_state)
        self.build_model()

    def build_model(self):
        """Build the context tower, the surrogate-reward tower and the combined head.

        Every Linear layer of both towers is followed by the configured
        activation, so stacked Linear layers do not collapse into a single
        affine map. Layer names are unchanged from earlier versions, and the
        added activations have no parameters, so ``state_dict`` keys are the same.
        """
        activation_function = self.get_activation_function(self.activation)

        # Subnetwork for user context
        layers_context = [
            ('input_context', nn.Linear(self.dim_context, self.hidden_layer_size[0])),
            ('activation_context_0', activation_function()),
        ]
        for i in range(1, len(self.hidden_layer_size)):
            layers_context.append((f'hidden_context_{i}', nn.Linear(self.hidden_layer_size[i-1], self.hidden_layer_size[i])))
            layers_context.append((f'activation_context_{i}', activation_function()))

        # Subnetwork for surrogate rewards
        layers_s = [
            ('input_s', nn.Linear(self.s_dim, self.hidden_layer_size_s[0])),
            ('activation_s_0', activation_function()),
        ]
        for i in range(1, len(self.hidden_layer_size_s)):
            layers_s.append((f'hidden_s_{i}', nn.Linear(self.hidden_layer_size_s[i-1], self.hidden_layer_size_s[i])))
            layers_s.append((f'activation_s_{i}', activation_function()))

        total_last_layer_size = self.hidden_layer_size[-1] + self.hidden_layer_size_s[-1]
        combined_layers = [
            ('combined_layer', nn.Linear(total_last_layer_size, total_last_layer_size // 2)),
            ('combined_activation', activation_function()),
            ('output_layer', nn.Linear(total_last_layer_size // 2, self.n_actions))
        ]

        self.model_context = nn.Sequential(OrderedDict(layers_context))
        self.model_s = nn.Sequential(OrderedDict(layers_s))
        self.combined_model = nn.Sequential(OrderedDict(combined_layers))

        # Groups all three sub-networks so parameters()/train()/eval() reach them together.
        self.nn_model = nn.ModuleDict({
            'context_path': self.model_context,
            'surrogate_path': self.model_s,
            'combined': self.combined_model
        })

    def get_optimizer(self):
        """Return the optimizer selected by ``solver`` over all network parameters.

        Raises
        ------
        NotImplementedError
            If ``solver`` is not one of 'adam', 'adagrad', or 'sgd'.
        """
        if self.solver == "sgd":
            optimizer = optim.SGD(
                self.nn_model.parameters(),
                lr=self.learning_rate_init,
                momentum=self.momentum,
                weight_decay=self.alpha,
                nesterov=self.nesterovs_momentum,
            )
        elif self.solver == "adagrad":
            optimizer = optim.Adagrad(
                self.nn_model.parameters(),
                lr=self.learning_rate_init,
                eps=self.epsilon,
                weight_decay=self.alpha,
            )
        elif self.solver == "adam":
            optimizer = optim.Adam(
                self.nn_model.parameters(),
                lr=self.learning_rate_init,
                betas=(self.beta_1, self.beta_2),
                eps=self.epsilon,
                weight_decay=self.alpha,
            )
        else:
            raise NotImplementedError("`solver` must be one of 'adam', 'adagrad', or 'sgd'")
        return optimizer

    def forward(self, x, s):
        """Return per-action estimates for contexts ``x`` and surrogate rewards ``s``.

        The outputs of the two towers are concatenated along the last dimension
        and passed through the combined head.
        """
        context_output = self.model_context(x)
        surrogate_output = self.model_s(s)
        combined_input = torch.cat((context_output, surrogate_output), dim=-1)
        return self.combined_model(combined_input)

    def get_activation_function(self, name):
        """Map an activation name to its ``torch.nn`` module class.

        Raises
        ------
        ValueError
            If ``name`` is not a supported activation.
        """
        activations = {
            "identity": nn.Identity,
            "logistic": nn.Sigmoid,
            "tanh": nn.Tanh,
            "relu": nn.ReLU,
            "elu": nn.ELU
        }
        if name not in activations:
            raise ValueError(f"`activation` must be one of {list(activations.keys())}, but {name} is given")
        return activations[name]

    def _create_train_data_for_q_func_estimation_with_s(
        self,
        context: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        surrogate_reward: np.ndarray,
        **kwargs,
    ) -> Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.DataLoader]]:
        """Wrap the arrays in PyTorch data loaders.

        Returns
        -------
        (training_data_loader, validation_data_loader)
            The validation loader is None unless ``early_stopping`` is True.

        Raises
        ------
        ValueError
            If ``early_stopping`` is True and there are fewer than two samples.
        """
        n_rounds = context.shape[0]
        if self.batch_size == "auto":
            batch_size_ = min(200, n_rounds)
        else:
            check_scalar(self.batch_size, "batch_size", int, min_val=1)
            batch_size_ = self.batch_size
        context = context.astype('float32')
        dataset = QFuncEstimatorWithSDataset(
            torch.from_numpy(context).float(),
            torch.from_numpy(action).long(),
            torch.from_numpy(surrogate_reward).float(),
            torch.from_numpy(reward).float(),
        )

        if self.early_stopping:
            if n_rounds <= 1:
                raise ValueError(
                    f"the number of samples is too small ({context.shape[0]}) to create validation data"
                )

            # Keep at least one sample on each side of the split so the training
            # loader is never empty (an empty one would divide by zero in train_model).
            validation_size = min(
                max(int(n_rounds * self.validation_fraction), 1), n_rounds - 1
            )
            training_size = n_rounds - validation_size
            training_dataset, validation_dataset = torch.utils.data.random_split(
                dataset, [training_size, validation_size]
            )
            training_data_loader = torch.utils.data.DataLoader(
                training_dataset,
                batch_size=batch_size_,
                shuffle=self.shuffle,
            )
            validation_data_loader = torch.utils.data.DataLoader(
                validation_dataset,
                batch_size=batch_size_,
                shuffle=self.shuffle,
            )

            return training_data_loader, validation_data_loader

        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size_,
            shuffle=self.shuffle,
        )

        return data_loader, None

    def fit(self, context: np.ndarray, action: np.ndarray, reward: np.ndarray, surrogate_reward: np.ndarray, obs_list: np.ndarray):
        """Fit the network on the rows whose reward is observed.

        Parameters
        ----------
        context: array-like, shape (n_rounds, dim_context)
        action: array-like, shape (n_rounds,)
        reward: array-like, shape (n_rounds,)
            Only entries with ``obs_list == 1`` are used.
        surrogate_reward: array-like, shape (n_rounds, s_dim)
        obs_list: array-like, shape (n_rounds,)
            1 marks rows used for training; rows equal to 0 are treated as unobserved.

        Raises
        ------
        ValueError
            If no row has ``obs_list == 1``.
        """
        optimizer = self.get_optimizer()

        # Separate observed and unobserved data
        observed = obs_list == 1
        unobserved = obs_list == 0
        obs_context = context[observed]
        obs_action = action[observed]
        obs_surrogate_reward = surrogate_reward[observed]
        obs_reward = reward[observed]

        n_obs = obs_context.shape[0]
        if n_obs == 0:
            raise ValueError("no observed samples (`obs_list == 1`) to fit the Q-function on")

        # The unobserved split and p_o are consumed only by the disabled
        # pseudo-labelling step below.
        un_obs_context = context[unobserved]
        un_obs_action = action[unobserved]
        un_obs_surrogate_reward = surrogate_reward[unobserved]
        n_unobs = un_obs_context.shape[0]
        p_o = n_obs / (n_obs + n_unobs)

        # Create data loaders for the observed data
        obs_training_loader, obs_validation_loader = self._create_train_data_for_q_func_estimation_with_s(
            context=obs_context,
            action=obs_action,
            reward=obs_reward,
            surrogate_reward=obs_surrogate_reward)

        self.train_model(obs_training_loader, obs_validation_loader, optimizer)

        # Disabled experiment: pseudo-label confident unobserved rows and retrain.
        # filtered_context, filtered_action, filtered_surrogate_rewards, filtered_rewards = self.predict_and_filter(un_obs_context, un_obs_action, un_obs_surrogate_reward)

        # if filtered_context.size > 0 and p_o >=0.8:
        #     combined_context = np.vstack([obs_context, filtered_context])
        #     combined_action = np.concatenate([obs_action, filtered_action])
        #     combined_surrogate_rewards = np.vstack([obs_surrogate_reward, filtered_surrogate_rewards])
        #     combined_rewards = np.concatenate([obs_reward, filtered_rewards])

        #     combined_training_loader, _ = self._create_train_data_for_q_func_estimation_with_s(
        #         context=combined_context,
        #         action=combined_action,
        #         reward=combined_rewards,
        #         surrogate_reward=combined_surrogate_rewards)
        #     self.train_model(combined_training_loader, obs_validation_loader, optimizer)

    def train_model(self, training_data_loader, validation_data_loader, optimizer):
        """Train for up to ``max_iter`` epochs with per-epoch early stopping.

        An epoch counts as "not improving" unless its mean loss is at least
        ``tol`` below the previous epoch's. Training stops after
        ``n_iter_no_change`` consecutive non-improving epochs on the training
        loss, or, when ``early_stopping`` is True and a validation loader is
        given, on the validation loss. Per-epoch mean losses are stored in
        ``training_losses_`` and ``validation_losses_``.
        """
        previous_training_loss = float('inf')
        n_not_improving_training = 0
        previous_validation_loss = float('inf')
        n_not_improving_validation = 0
        training_losses = []
        validation_losses = []
        for _ in range(self.max_iter):
            # Re-enter training mode every epoch: the validation pass switches to eval mode.
            self.nn_model.train()
            total_loss = 0.0
            for x, a, s, r in training_data_loader:
                optimizer.zero_grad()
                q_hat = self.forward(x, s)
                # Only the logged action's estimate is regressed onto the reward.
                q_hat_selected = q_hat[torch.arange(a.shape[0]), a]
                loss = mse_loss(r, q_hat_selected)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            average_training_loss = total_loss / len(training_data_loader)
            training_losses.append(average_training_loss)
            if (previous_training_loss - average_training_loss) < self.tol:
                n_not_improving_training += 1
            else:
                n_not_improving_training = 0
            if n_not_improving_training >= self.n_iter_no_change:
                break
            previous_training_loss = average_training_loss

            # Validation phase
            if self.early_stopping and validation_data_loader is not None:
                self.nn_model.eval()
                validation_loss = 0.0
                with torch.no_grad():
                    for x, a, s, r in validation_data_loader:
                        q_hat = self.forward(x, s)
                        q_hat_selected = q_hat[torch.arange(a.shape[0]), a]
                        validation_loss += mse_loss(r, q_hat_selected).item()

                average_validation_loss = validation_loss / len(validation_data_loader)
                validation_losses.append(average_validation_loss)
                if (previous_validation_loss - average_validation_loss) < self.tol:
                    n_not_improving_validation += 1
                else:
                    n_not_improving_validation = 0
                # Without this update the comparison above is always against inf
                # and validation-based stopping could never trigger.
                previous_validation_loss = average_validation_loss
                if n_not_improving_validation >= self.n_iter_no_change:
                    break

        self.training_losses_ = training_losses
        self.validation_losses_ = validation_losses

    def predict_and_filter(self, context, action, surrogate_reward):
        """Pseudo-label rows whose estimates are confident enough.

        The pseudo-label of a row is the estimate for its *logged* action, which
        matches the target regressed in :meth:`train_model`. Rows are kept when
        :meth:`calculate_confidence` exceeds ``confidence_threshold``.

        Returns
        -------
        (context, action, surrogate_reward, pseudo_reward)
            The inputs restricted to the kept rows (numpy arrays), and their
            pseudo-labels as a torch Tensor.
        """
        x_tensor = torch.from_numpy(context).float()
        s_tensor = torch.from_numpy(surrogate_reward).float()
        a_tensor = torch.as_tensor(np.asarray(action), dtype=torch.long)
        self.nn_model.eval()
        with torch.no_grad():
            q_values = self.forward(x_tensor, s_tensor)
            predicted_rewards = q_values[torch.arange(a_tensor.shape[0]), a_tensor]
        confidence_scores = self.calculate_confidence(x_tensor, s_tensor)

        # Filter data based on confidence; a numpy mask indexes the numpy inputs.
        high_confidence = confidence_scores > self.confidence_threshold
        mask = high_confidence.numpy()
        return context[mask], action[mask], surrogate_reward[mask], predicted_rewards[high_confidence]

    def predict(self, context: torch.Tensor, surrogate_reward: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        context: Tensor, shape (n_rounds_of_new_data, dim_context)
            Context vectors for new data.
        surrogate_reward: Tensor, shape (n_rounds_of_new_data, s_dim)
            Surrogate reward vectors for new data.

        Returns
        -------
        estimated_expected_rewards: Tensor, shape (n_rounds_of_new_data, n_actions)
            Expected rewards given context and surrogate reward, one column per action.

        Raises
        ------
        ValueError
            If the second dimension of either input does not match the model.
        """
        check_tensor(tensor=context, name="context", expected_dim=2)
        check_tensor(tensor=surrogate_reward, name="surrogate_reward", expected_dim=2)

        if context.shape[1] != self.dim_context:
            raise ValueError("Expected `context.shape[1] == self.dim_context`, but found it False")
        if surrogate_reward.shape[1] != self.s_dim:
            raise ValueError("Expected `surrogate_reward.shape[1] == self.s_dim`, but found it False")

        self.nn_model.eval()

        return self.forward(context, surrogate_reward)

    def calculate_confidence(self, x, s, num_samples=100):
        """Return a per-row confidence score: the mean over actions of 1 / std of repeated forward passes.

        The network is run ``num_samples`` times in train mode, and the previous
        train/eval mode is restored afterwards.
        NOTE(review): the network built by build_model contains no dropout or
        other stochastic layers, so the repeated passes are identical, the std is
        0 and every row gets a confidence of about 1e8. Confirm this is intended.
        """
        was_training = self.nn_model.training
        self.nn_model.train()  # Intended to enable stochastic layers such as dropout
        try:
            with torch.no_grad():
                predictions = torch.stack(
                    [self.forward(x, s) for _ in range(num_samples)], dim=1
                )
        finally:
            self.nn_model.train(was_training)

        # Standard deviation along the samples dimension
        std_dev = torch.std(predictions, dim=1)
        if torch.any(std_dev == 0):
            std_dev += 1e-8  # To prevent division by zero

        confidence = 1 / std_dev  # Inverse of standard deviation as a measure of confidence
        return confidence.mean(dim=1)  # Averaging confidence scores across all actions

@dataclass
class QFuncEstimatorWithS_old:
    """Neural Q-function estimator over (context, surrogate reward) features.

    A fully connected network maps the concatenation of a context vector and a
    surrogate-reward vector to one estimated reward per discrete action.
    ``fit`` trains the network on the rounds flagged as observed in ``obs_list``.

    Parameters
    ----------
    n_actions: int
        Number of discrete actions (width of the output layer).
    dim_context: int
        Dimension of the context vectors.
    s_dim: int
        Dimension of the surrogate-reward vectors.
    hidden_layer_size, activation, solver, alpha, batch_size, learning_rate_init,
    max_iter, shuffle, random_state, tol, momentum, nesterovs_momentum,
    early_stopping, validation_fraction, beta_1, beta_2, epsilon, n_iter_no_change:
        Network / optimizer hyperparameters, named after their
        ``sklearn.neural_network.MLPRegressor`` counterparts.
        Training stops once the monitored epoch loss (validation loss when
        ``early_stopping`` is True, training loss otherwise) has failed to
        improve by at least ``tol`` for ``n_iter_no_change`` consecutive epochs.
    confidence_threshold: float, default=0.9
        Minimum confidence score for an unobserved round to be kept by
        ``predict_and_filter``.
    """

    n_actions: int
    dim_context: int
    s_dim: int
    hidden_layer_size: Tuple[int, ...] = (150, )
    activation: str = "relu"
    solver: str = "adam"
    alpha: float = 0.0001
    batch_size: Union[int, str] = "auto"
    learning_rate_init: float = 0.0001
    max_iter: int = 200
    shuffle: bool = True
    random_state: Optional[int] = None
    tol: float = 1e-4
    momentum: float = 0.9
    nesterovs_momentum: bool = True
    early_stopping: bool = False
    validation_fraction: float = 0.1
    beta_1: float = 0.9
    beta_2: float = 0.999
    epsilon: float = 1e-8
    n_iter_no_change: int = 10
    confidence_threshold: float = 0.9

    def __post_init__(self) -> None:
        """Validate hyperparameters, seed the RNGs, and build the network.

        Raises
        ------
        ValueError, TypeError
            If a hyperparameter has the wrong type or is out of range.
        """
        check_scalar(self.dim_context, "dim_context", int, min_val=1)
        check_scalar(self.s_dim, "s_dim", int, min_val=1)

        if not isinstance(self.hidden_layer_size, tuple) or any(
            [not isinstance(h, int) or h <= 0 for h in self.hidden_layer_size]
        ):
            raise ValueError(
                f"`hidden_layer_size` must be a tuple of positive integers, but {self.hidden_layer_size} is given"
            )

        if self.solver not in ("adagrad", "sgd", "adam"):
            raise ValueError(
                f"`solver` must be one of 'adam', 'adagrad', or 'sgd', but {self.solver} is given"
            )

        check_scalar(self.alpha, "alpha", float, min_val=0.0)

        if self.batch_size != "auto" and (
            not isinstance(self.batch_size, int) or self.batch_size <= 0
        ):
            raise ValueError(
                f"`batch_size` must be a positive integer or 'auto', but {self.batch_size} is given"
            )

        check_scalar(self.learning_rate_init, "learning_rate_init", float)
        if self.learning_rate_init <= 0.0:
            raise ValueError(
                f"`learning_rate_init`= {self.learning_rate_init}, must be > 0.0"
            )

        check_scalar(self.max_iter, "max_iter", int, min_val=1)

        if not isinstance(self.shuffle, bool):
            raise ValueError(f"`shuffle` must be a bool, but {self.shuffle} is given")

        check_scalar(self.tol, "tol", float)
        if self.tol <= 0.0:
            raise ValueError(f"`tol`= {self.tol}, must be > 0.0")

        check_scalar(self.momentum, "momentum", float, min_val=0.0, max_val=1.0)

        if not isinstance(self.nesterovs_momentum, bool):
            raise ValueError(
                f"`nesterovs_momentum` must be a bool, but {self.nesterovs_momentum} is given"
            )

        if not isinstance(self.early_stopping, bool):
            raise ValueError(
                f"`early_stopping` must be a bool, but {self.early_stopping} is given"
            )

        check_scalar(
            self.validation_fraction, "validation_fraction", float, max_val=1.0
        )
        if self.validation_fraction <= 0.0:
            raise ValueError(
                f"`validation_fraction`= {self.validation_fraction}, must be > 0.0"
            )

        check_scalar(self.beta_1, "beta_1", float, min_val=0.0, max_val=1.0)
        check_scalar(self.beta_2, "beta_2", float, min_val=0.0, max_val=1.0)
        check_scalar(self.epsilon, "epsilon", float, min_val=0.0)
        check_scalar(self.n_iter_no_change, "n_iter_no_change", int, min_val=1)

        if self.random_state is not None:
            self.random_ = check_random_state(self.random_state)
            torch.manual_seed(self.random_state)

        if self.activation == "identity":
            activation_layer = nn.Identity
        elif self.activation == "logistic":
            activation_layer = nn.Sigmoid
        elif self.activation == "tanh":
            activation_layer = nn.Tanh
        elif self.activation == "relu":
            activation_layer = nn.ReLU
        elif self.activation == "elu":
            activation_layer = nn.ELU
        else:
            raise ValueError(
                "`activation` must be one of 'identity', 'logistic', 'tanh', 'relu', or 'elu'"
                f", but {self.activation} is given"
            )

        layer_list = []
        self.input_size = self.dim_context + self.s_dim

        for i, h in enumerate(self.hidden_layer_size):
            layer_list.append(("l{}".format(i), nn.Linear(self.input_size, h)))
            # Instantiate with default arguments: passing the activation name
            # here would be interpreted as a constructor argument (e.g. ReLU's
            # `inplace`, ELU's `alpha`) or rejected outright (Sigmoid, Tanh).
            layer_list.append(("a{}".format(i), activation_layer()))
            self.input_size = h

        layer_list.append(("output", nn.Linear(self.input_size, self.n_actions)))
        self.nn_model = nn.Sequential(OrderedDict(layer_list))

    def _create_train_data_for_q_func_estimation_with_s(
        self,
        context: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        surrogate_reward: np.ndarray,
        **kwargs,
    ) -> Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.DataLoader]]:
        """Wrap numpy arrays into training (and, with early stopping, validation) loaders.

        Returns
        -------
        (training_loader, validation_loader)
            ``validation_loader`` is None unless ``early_stopping`` is True, in
            which case a random ``validation_fraction`` of the rows (at least one)
            is held out.

        Raises
        ------
        ValueError
            If early stopping is requested with fewer than two rows.
        """
        if self.batch_size == "auto":
            batch_size_ = min(200, context.shape[0])
        else:
            check_scalar(self.batch_size, "batch_size", int, min_val=1)
            batch_size_ = self.batch_size
        context = context.astype('float32')
        dataset = QFuncEstimatorWithSDataset(
            torch.from_numpy(context).float(),
            torch.from_numpy(action).long(),
            torch.from_numpy(surrogate_reward).float(),
            torch.from_numpy(reward).float(),
        )

        if self.early_stopping:
            if context.shape[0] <= 1:
                raise ValueError(
                    f"the number of samples is too small ({context.shape[0]}) to create validation data"
                )

            validation_size = max(int(context.shape[0] * self.validation_fraction), 1)
            training_size = context.shape[0] - validation_size
            training_dataset, validation_dataset = torch.utils.data.random_split(
                dataset, [training_size, validation_size]
            )
            training_data_loader = torch.utils.data.DataLoader(
                training_dataset,
                batch_size=batch_size_,
                shuffle=self.shuffle,
            )
            validation_data_loader = torch.utils.data.DataLoader(
                validation_dataset,
                batch_size=batch_size_,
                shuffle=self.shuffle,
            )

            return training_data_loader, validation_data_loader

        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size_,
            shuffle=self.shuffle,
        )

        return data_loader, None

    def _build_optimizer(self) -> torch.optim.Optimizer:
        """Create the torch optimizer selected by ``self.solver``."""
        if self.solver == "sgd":
            return optim.SGD(
                self.nn_model.parameters(),
                lr=self.learning_rate_init,
                momentum=self.momentum,
                weight_decay=self.alpha,
                nesterov=self.nesterovs_momentum,
            )
        if self.solver == "adagrad":
            return optim.Adagrad(
                self.nn_model.parameters(),
                lr=self.learning_rate_init,
                eps=self.epsilon,
                weight_decay=self.alpha,
            )
        if self.solver == "adam":
            return optim.Adam(
                self.nn_model.parameters(),
                lr=self.learning_rate_init,
                betas=(self.beta_1, self.beta_2),
                eps=self.epsilon,
                weight_decay=self.alpha,
            )
        raise NotImplementedError(
            "`solver` must be one of 'adam', 'adagrad', or 'sgd'"
        )

    def fit(
        self,
        context: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        surrogate_reward: np.ndarray,
        obs_list: np.ndarray,
    ) -> None:
        """Train the Q-function on the rounds whose reward is observed.

        Parameters
        ----------
        context: array-like, shape (n_rounds, dim_context)
            Context vectors.
        action: array-like, shape (n_rounds,)
            Index of the action taken in each round.
        reward: array-like
            Observed reward per round (presumably shape (n_rounds,)).
        surrogate_reward: array-like, shape (n_rounds, s_dim)
            Surrogate reward vectors.
        obs_list: array-like, shape (n_rounds,)
            Rounds with ``obs_list == 1`` are used for training; rounds with
            ``obs_list == 0`` are scored by ``predict_and_filter``.

        Raises
        ------
        ValueError
            If the feature dimensions do not match the estimator, or no round
            is flagged as observed.
        """
        if context.shape[1] != self.dim_context:
            raise ValueError(
                "Expected `context.shape[1] == self.dim_context`, but found it False"
            )
        if surrogate_reward.ndim != 2 or surrogate_reward.shape[1] != self.s_dim:
            raise ValueError(
                "Expected `surrogate_reward.shape[1] == self.s_dim`, but found it False"
            )

        optimizer = self._build_optimizer()

        obs_mask = obs_list == 1
        un_obs_mask = obs_list == 0
        if not np.any(obs_mask):
            raise ValueError(
                "`obs_list` flags no round as observed (== 1); there is nothing to fit on"
            )

        obs_context = context[obs_mask]
        obs_action = action[obs_mask]
        obs_surrogate_reward = surrogate_reward[obs_mask]
        obs_reward = reward[obs_mask]
        un_obs_context = context[un_obs_mask]
        un_obs_action = action[un_obs_mask]
        un_obs_surrogate_reward = surrogate_reward[un_obs_mask]

        (
            obs_training_data_loader,
            obs_validation_data_loader,
        ) = self._create_train_data_for_q_func_estimation_with_s(
            context=obs_context,
            action=obs_action,
            reward=obs_reward,
            surrogate_reward=obs_surrogate_reward,
        )

        self.train_model(
            training_data_loader=obs_training_data_loader,
            validation_data_loader=obs_validation_data_loader,
            optimizer=optimizer,
        )

        # NOTE: the filtered pseudo-labels are currently unused; they feed the
        # disabled self-training step below.
        filtered_context, filtered_action, filtered_surrogate_rewards, filtered_predicted_rewards = self.predict_and_filter(
            context=un_obs_context,
            action=un_obs_action,
            surrogate_reward=un_obs_surrogate_reward,
        )
        # if filtered_context.size > 0 and obs_context.shape[0] / (context.shape[0]) >= 0.8:
        #     context_with_fil = np.concatenate([obs_context, filtered_context], axis=0)
        #     action_with_fil = np.concatenate([obs_action, filtered_action], axis=0)
        #     surrogate_reward_with_fil = np.concatenate([obs_surrogate_reward, filtered_surrogate_rewards], axis=0)
        #     reward_with_fil = np.concatenate([obs_reward, filtered_predicted_rewards], axis=0)

        #     (
        #         high_conf_training_data_loader,
        #         unobs_validation_data_loader,
        #     ) = self._create_train_data_for_q_func_estimation_with_s(
        #         context=context_with_fil,
        #         action=action_with_fil,
        #         reward=reward_with_fil,
        #         surrogate_reward=surrogate_reward_with_fil,
        #     )

        #     self.train_model(
        #         training_data_loader=high_conf_training_data_loader,
        #         validation_data_loader=obs_validation_data_loader,
        #         optimizer=optimizer,)

    def predict_and_filter(self, context, action, surrogate_reward):
        """Predict rewards for the taken actions and keep only confident rounds.

        Returns
        -------
        (filtered_context, filtered_action, filtered_surrogate_reward, filtered_predicted_rewards)
            The first three are subsets of the inputs (indexed with a boolean
            numpy mask); the predicted rewards are a torch tensor.
        """
        self.nn_model.eval()

        x = torch.tensor(context, dtype=torch.float32)
        s = torch.tensor(surrogate_reward, dtype=torch.float32)
        a = torch.tensor(action, dtype=torch.long)
        x_s = torch.cat([x, s], dim=1)

        with torch.no_grad():
            all_q_values = self.nn_model(x_s)
            predicted_rewards = all_q_values.gather(1, a.view(-1, 1)).squeeze(1)

        confidence_scores = self._calculate_confidence(x_s)

        confident = confidence_scores > self.confidence_threshold
        # index the numpy inputs with a numpy mask rather than a torch tensor
        confident_np = confident.cpu().numpy()
        filtered_context = context[confident_np]
        filtered_action = action[confident_np]
        filtered_surrogate_reward = surrogate_reward[confident_np]
        filtered_predicted_rewards = predicted_rewards[confident]

        return filtered_context, filtered_action, filtered_surrogate_reward, filtered_predicted_rewards

    def _calculate_confidence(self, x_s, num_samples=100):
        """Return inverse-std confidence per row over repeated train-mode forward passes.

        NOTE(review): the network built in ``__post_init__`` has no stochastic
        layers, so all passes agree and every row receives the floor-based
        confidence 1e8 — the filter then keeps every round.

        Raises
        ------
        ValueError
            If ``num_samples`` is smaller than 2.
        """
        if num_samples < 2:
            raise ValueError(f"`num_samples` must be >= 2, but {num_samples} is given")

        was_training = self.nn_model.training
        self.nn_model.train()  # enable stochastic layers, if any
        try:
            with torch.no_grad():
                # (rows, num_samples, n_actions)
                predictions = torch.stack(
                    [self.nn_model(x_s) for _ in range(num_samples)], dim=1
                )
        finally:
            self.nn_model.train(was_training)  # do not leak train mode

        # floor the std so zero variance yields a large finite confidence, not inf
        std_dev = torch.std(predictions, dim=1).clamp_min(1e-8)
        confidence = 1 / std_dev
        return confidence.mean(dim=1)  # one score per row

    def _batch_loss(self, x, a, s, r) -> torch.Tensor:
        """Return the MSE between rewards and the Q-values of the taken actions."""
        x_s = torch.cat([x, s], dim=1)
        q_hat = self.nn_model(x_s)[torch.arange(a.shape[0], dtype=torch.long), a]
        return mse_loss(r, q_hat)

    def train_model(
        self,
        training_data_loader: torch.utils.data.DataLoader,
        validation_data_loader: Optional[torch.utils.data.DataLoader],
        optimizer: torch.optim.Optimizer,
    ) -> None:
        """Run up to ``max_iter`` epochs of MSE regression on the taken actions' Q-values.

        After each epoch the monitored loss is compared with the best loss so far:
        the mean validation loss when ``early_stopping`` is True and a validation
        loader is given, otherwise the mean training loss. Training stops once that
        loss has failed to improve by at least ``tol`` for ``n_iter_no_change``
        consecutive epochs. Per-epoch mean losses are stored in
        ``self.training_losses_`` and ``self.validation_losses_``.
        """
        monitor_validation = self.early_stopping and validation_data_loader is not None
        best_loss = np.inf
        n_not_improving = 0
        training_losses = []
        validation_losses = []
        for _ in tqdm(np.arange(self.max_iter), desc="q-func with s learning"):
            self.nn_model.train()
            batch_losses = []
            for x, a, s, r in training_data_loader:
                optimizer.zero_grad()
                loss = self._batch_loss(x, a, s, r)
                loss.backward()
                optimizer.step()
                batch_losses.append(loss.item())
            monitored_loss = sum(batch_losses) / len(batch_losses)
            training_losses.append(monitored_loss)

            if monitor_validation:
                self.nn_model.eval()
                with torch.no_grad():
                    val_batch_losses = [
                        self._batch_loss(x, a, s, r).item()
                        for x, a, s, r in validation_data_loader
                    ]
                monitored_loss = sum(val_batch_losses) / len(val_batch_losses)
                validation_losses.append(monitored_loss)

            # "Not improving" means not beating the best loss by at least `tol`.
            if monitored_loss > best_loss - self.tol:
                n_not_improving += 1
            else:
                n_not_improving = 0
            best_loss = min(best_loss, monitored_loss)
            if n_not_improving >= self.n_iter_no_change:
                break

        # kept for inspection / plotting learning curves
        self.training_losses_ = training_losses
        self.validation_losses_ = validation_losses

    def predict(self, context: torch.Tensor, surrogate_reward: torch.Tensor) -> torch.Tensor:
        """
        Estimate the reward of every action for new data, given surrogate rewards.

        Parameters
        ----------
        context: Tensor, shape (n_rounds_of_new_data, dim_context)
            Context vectors for new data.
        surrogate_reward: Tensor, shape (n_rounds_of_new_data, dim_s)
            Surrogate reward vectors for new data.

        Returns
        -------
        estimated_expected_rewards: Tensor, shape (n_rounds_of_new_data, n_actions)
            Network output for the concatenated (context, surrogate reward)
            features: one estimated reward per action. Gradient tracking is not
            disabled here.
        """
        check_tensor(tensor=context, name="context", expected_dim=2)
        check_tensor(tensor=surrogate_reward, name="surrogate_reward", expected_dim=2)

        if context.shape[1] != self.dim_context:
            raise ValueError("Expected `context.shape[1] == self.dim_context`, but found it False")
        if surrogate_reward.shape[1] != self.s_dim:
            raise ValueError("Expected `surrogate_reward.shape[1] == self.dim_s`, but found it False")

        self.nn_model.eval()

        x_s = torch.cat([context, surrogate_reward], dim=1)
        return self.nn_model(x_s)
                    
@dataclass
class QFuncEstimatorWithSDataset(torch.utils.data.Dataset):
    """PyTorch dataset for QFuncEstimatorWithS.

    Each item is the tuple ``(feature, action, surrogate_reward, reward)`` for
    one round. All four fields must have the same length along axis 0.
    """

    # Tensors in practice (built with torch.from_numpy by the estimator);
    # numpy arrays also work since only indexing and .shape[0] are used.
    feature: Union[np.ndarray, torch.Tensor]
    action: Union[np.ndarray, torch.Tensor]
    surrogate_reward: Union[np.ndarray, torch.Tensor]
    reward: Union[np.ndarray, torch.Tensor]

    def __post_init__(self):
        """Check that all fields describe the same number of rounds.

        Raises
        ------
        ValueError
            If the fields' first dimensions differ. (A real exception rather
            than ``assert``, so the check survives ``python -O``.)
        """
        lengths = {
            "feature": self.feature.shape[0],
            "action": self.action.shape[0],
            "surrogate_reward": self.surrogate_reward.shape[0],
            "reward": self.reward.shape[0],
        }
        if len(set(lengths.values())) != 1:
            raise ValueError(
                f"all inputs must have the same first dimension, but got {lengths}"
            )

    def __getitem__(self, index):
        return (
            self.feature[index],
            self.action[index],
            self.surrogate_reward[index],
            self.reward[index],
        )

    def __len__(self):
        return self.feature.shape[0]