from dataclasses import dataclass
import enum

import numpy as np
from obp.dataset import linear_behavior_policy
from obp.dataset import linear_reward_function
from obp.dataset import logistic_reward_function
from obp.dataset import logistic_polynomial_reward_function
from obp.dataset import polynomial_reward_function
from obp.dataset import SyntheticBanditDataset
from obp.types import BanditFeedback
from obp.utils import sample_action_fast
from obp.utils import check_array
from obp.utils import softmax
from scipy.stats import rankdata
from sklearn.preprocessing import OneHotEncoder
from sklearn.utils import check_random_state
from sklearn.utils import check_scalar
from sklearn.preprocessing import PolynomialFeatures
from scipy.stats import norm
from scipy.stats import truncnorm


def sigmoid(x: np.ndarray) -> np.ndarray:
    """Apply the logistic function ``1 / (1 + exp(-x))`` element-wise.

    Parameters
    -----------
    x: array-like or scalar
        Input value(s).

    Returns
    ----------
    np.ndarray
        Values in the unit interval, with the same shape as ``x``.

    """
    denominator = 1 + np.exp(-x)
    return 1.0 / denominator


class RewardType(enum.Enum):
    """Reward kinds accepted by the ``reward_type`` dataset option."""

    BINARY = "binary"
    CONTINUOUS = "continuous"

    def __repr__(self) -> str:
        # Show the same text in the REPL / containers as ``str()`` produces.
        return str(self)

def sample_random_uniform_coefficients(
    dim_surrogate_context: int,
    dim_context: int,
    dim_action: int,
    random_: np.random.RandomState,
    **kwargs,
) -> np.ndarray:
    """Draw all reward-model coefficients uniformly from [-1, 1).

    Parameters
    -----------
    dim_surrogate_context: int
        Length of the surrogate feature vector.

    dim_context: int
        Length of the context feature vector.

    dim_action: int
        Length of the action feature vector.

    random_: np.random.RandomState
        Generator the coefficients are drawn from; it is advanced by this call.

    **kwargs
        Accepted and ignored.

    Returns
    ----------
    tuple of np.ndarray
        ``(context_coef, action_coef, surrogate_coef, context_action_coef,
        context_surrogate_coef, action_surrogate_coef)``. The first three are
        1-D main-effect vectors; the last three are 2-D interaction matrices
        of shape ``(dim_context, dim_action)``,
        ``(dim_context, dim_surrogate_context)`` and
        ``(dim_action, dim_surrogate_context)``.
        Note that a tuple is returned even though the annotation names a
        single array.

    """
    # The draw order below fixes which random numbers land in which
    # coefficient, so it must not be rearranged.
    draw_shapes = (
        dim_context,
        dim_action,
        dim_surrogate_context,
        (dim_context, dim_action),
        (dim_context, dim_surrogate_context),
        (dim_action, dim_surrogate_context),
    )
    return tuple(random_.uniform(-1, 1, size=shape) for shape in draw_shapes)

def fixed_expected_reward_function(
    context: np.ndarray,
    action_context: np.ndarray,
    expected_surrogate: np.ndarray,
    lambda_: float,
    random_state: int,
    **kwargs,
) -> np.ndarray:
    """Compute expected rewards for every (round, action) pair.

    The surrogate features are taken from ``expected_surrogate``. Each input
    gets a leading constant feature, coefficients are drawn with
    ``sample_random_uniform_coefficients`` from a generator seeded by
    ``random_state``, and the result is

    ``(1 - lambda_) * (context + action + context*action
    + context*surrogate + action*surrogate) + lambda_ * surrogate``

    where each term is the corresponding linear / bilinear form.

    Parameters
    -----------
    context: array-like, shape (datasize, dim_context)
        Context vectors.

    action_context: array-like, shape (n_actions, dim_action_context)
        Feature vector of each action.

    expected_surrogate: array-like, shape (datasize, n_actions, dim_surrogate)
        Surrogate features for every (round, action) pair.

    lambda_: float
        Weight of the surrogate-only term; the remaining terms get ``1 - lambda_``.

    random_state: int
        Seed for the coefficient draw.

    **kwargs
        Accepted and ignored.

    Returns
    ----------
    np.ndarray, shape (datasize, n_actions)
        Expected reward of each action in each round.

    """
    # Degree-1 polynomial features are used to add the constant column.
    add_bias = PolynomialFeatures(degree=1)
    ctx = add_bias.fit_transform(context)
    act = add_bias.fit_transform(action_context)
    n_rows, n_cols = expected_surrogate.shape[0], expected_surrogate.shape[1]
    sur = np.concatenate([np.ones((n_rows, n_cols, 1)), expected_surrogate], axis=2)

    datasize, n_actions, dim_surrogate = sur.shape
    (
        w_ctx,
        w_act,
        w_sur,
        w_ctx_act,
        w_ctx_sur,
        w_act_sur,
    ) = sample_random_uniform_coefficients(
        dim_surrogate_context=dim_surrogate,
        dim_context=ctx.shape[1],
        dim_action=act.shape[1],
        random_=check_random_state(random_state),
    )

    # Main effects, each expanded to (datasize, n_actions).
    ctx_main = np.tile(ctx @ w_ctx, (n_actions, 1)).T
    act_main = np.tile(w_act @ act.T, (datasize, 1))
    sur_flat = sur.reshape(datasize * n_actions, dim_surrogate)
    sur_main = (sur_flat @ w_sur).reshape(datasize, n_actions)

    # Pairwise interactions; i indexes rounds, h indexes actions.
    ctx_act = np.einsum('ij,jk,hk->ih', ctx, w_ctx_act, act)
    ctx_sur = np.einsum('ij,jk,ihk->ih', ctx, w_ctx_sur, sur)
    act_sur = np.einsum('hj,jk,ihk->ih', act, w_act_sur, sur)

    non_surrogate = ctx_main + act_main + ctx_act + ctx_sur + act_sur
    return (1.0 - lambda_) * non_surrogate + lambda_ * sur_main

### Function to calculate expected rewards q(x, a, s)
def expected_reward_function(
    context: np.ndarray,
    action_context: np.ndarray,
    surrogate_context: np.ndarray,
    lambda_: float,
    random_state: int,
    **kwargs,
) -> np.ndarray:
    """Compute the expected reward of each logged round.

    All three inputs hold one row per round. Each gets a leading constant
    feature, coefficients are drawn with ``sample_random_uniform_coefficients``
    from a generator seeded by ``random_state``, and the result is

    ``(1 - lambda_) * (context + action + context*action
    + context*surrogate + action*surrogate) + lambda_ * surrogate``

    where each term is the corresponding linear / bilinear form.

    Parameters
    -----------
    context: array-like, shape (n_rounds, dim_context)
        Context vector of each round.

    action_context: array-like, shape (n_rounds, dim_action_context)
        Feature vector of the action taken in each round.

    surrogate_context: array-like, shape (n_rounds, dim_surrogate)
        Surrogate features of each round.

    lambda_: float
        Weight of the surrogate-only term; the remaining terms get ``1 - lambda_``.

    random_state: int
        Seed for the coefficient draw.

    **kwargs
        Accepted and ignored.

    Returns
    ----------
    q_x_a_s: np.ndarray, shape (n_rounds,)
        Expected reward of each round.

    surrogate_coef_: np.ndarray, shape (dim_surrogate + 1,)
        Sampled surrogate main-effect coefficients, constant term first.
        Note that a 2-tuple is returned even though the annotation names a
        single array.

    """
    # Degree-1 polynomial features are used to add the constant column.
    add_bias = PolynomialFeatures(degree=1)
    ctx = add_bias.fit_transform(context)
    act = add_bias.fit_transform(action_context)
    sur = add_bias.fit_transform(surrogate_context)

    (
        w_ctx,
        w_act,
        w_sur,
        w_ctx_act,
        w_ctx_sur,
        w_act_sur,
    ) = sample_random_uniform_coefficients(
        dim_surrogate_context=sur.shape[1],
        dim_context=ctx.shape[1],
        dim_action=act.shape[1],
        random_=check_random_state(random_state),
    )

    def rowwise_bilinear(left, weight, right):
        # left[i] @ weight @ right[i] for every round i.
        return np.einsum('ij,jk,ik->i', left, weight, right)

    ctx_main = w_ctx @ ctx.T
    act_main = w_act @ act.T
    sur_main = w_sur @ sur.T
    ctx_act = rowwise_bilinear(ctx, w_ctx_act, act)
    ctx_sur = rowwise_bilinear(ctx, w_ctx_sur, sur)
    act_sur = rowwise_bilinear(act, w_act_sur, sur)

    non_surrogate = ctx_main + act_main + ctx_act + ctx_sur + act_sur
    q_x_a_s = (1.0 - lambda_) * non_surrogate + lambda_ * sur_main

    return q_x_a_s, w_sur

def f_s_function(surrogate_rewards: np.ndarray, surrogate_coef_: np.ndarray) -> np.ndarray:
    """Sum each round's surrogate rewards after signing them by their coefficient.

    Every surrogate dimension is multiplied by +1 when its coefficient is
    non-negative and by -1 otherwise, and the signed values are summed per round.

    Parameters
    -----------
    surrogate_rewards: array-like, shape (n_rounds, dim_surrogate)
        Surrogate rewards of each round.

    surrogate_coef_: array-like, shape (dim_surrogate + 1,)
        Surrogate coefficients; the first entry (constant term) is skipped.

    Returns
    ----------
    np.ndarray, shape (n_rounds,)
        Signed sum of the surrogate rewards of each round (floating point).

    Raises
    ----------
    ValueError
        If ``surrogate_rewards`` is not 2-dimensional.

    """
    values = np.asarray(surrogate_rewards)
    if values.ndim != 2:
        raise ValueError(
            f"Expected 2-dimensional surrogate_rewards, but got {values.ndim} dimension(s)."
        )
    # The previous implementation prepended a bias column via
    # PolynomialFeatures and sliced it off again; the rewards are used
    # directly instead. Non-finite values are therefore no longer rejected
    # here. Float signs keep the result floating point for integer input.
    signs = np.where(np.asarray(surrogate_coef_)[1:] >= 0, 1.0, -1.0)
    return np.sum(signs * values, axis=1)

def all_f_s_function(surrogate_rewards: np.ndarray, surrogate_coef_: np.ndarray) -> np.ndarray:
    """Sum the surrogate rewards of every (round, action) pair after signing them.

    Every surrogate dimension is multiplied by +1 when its coefficient is
    non-negative and by -1 otherwise, and the signed values are summed over
    the last axis.

    Parameters
    -----------
    surrogate_rewards: array-like, shape (n_rounds, n_actions, dim_surrogate)
        Surrogate rewards for every (round, action) pair.

    surrogate_coef_: array-like, shape (dim_surrogate + 1,)
        Surrogate coefficients; the first entry (constant term) is skipped.

    Returns
    ----------
    np.ndarray, shape (n_rounds, n_actions)
        Signed sum of the surrogate rewards of each pair (floating point).

    Raises
    ----------
    ValueError
        If ``surrogate_rewards`` is not 3-dimensional.

    """
    values = np.asarray(surrogate_rewards)
    if values.ndim != 3:
        raise ValueError(
            f"Expected 3-dimensional surrogate_rewards, but got {values.ndim} dimension(s)."
        )
    # No bias column is needed: it was only ever concatenated and then sliced
    # off again. Float signs keep the result floating point, as the float
    # bias column used to.
    signs = np.where(np.asarray(surrogate_coef_)[1:] >= 0, 1.0, -1.0)
    # ``signs`` broadcasts along the trailing surrogate axis.
    return np.sum(signs * values, axis=2)

def s_sum_function(
    surrogate_rewards: np.ndarray,
    surrogate_coef_: np.ndarray,
    alpha_noise: float,
    random_state: int,
) -> np.ndarray:
    """Combine surrogate rewards linearly using noise-perturbed coefficients.

    Gaussian noise with standard deviation ``alpha_noise`` is added to the
    coefficients (constant term excluded), and the surrogate rewards are
    projected onto the perturbed coefficients.

    Parameters
    -----------
    surrogate_rewards: np.ndarray
        Shape (n_rounds, dim_surrogate) or (n_rounds, n_actions, dim_surrogate).

    surrogate_coef_: np.ndarray, shape (dim_surrogate + 1,)
        Surrogate coefficients; the first entry is skipped.

    alpha_noise: float
        Standard deviation of the noise added to the coefficients.

    random_state: int
        Seed for the noise draw.

    Returns
    ----------
    np.ndarray
        Shape (n_rounds,) for 2-D input, (n_rounds, n_actions) for 3-D input.

    Raises
    ----------
    ValueError
        If ``surrogate_rewards`` has neither 2 nor 3 dimensions.

    """
    rng = check_random_state(random_state)
    slopes = surrogate_coef_[1:]
    perturbed = slopes + rng.normal(0, alpha_noise, slopes.shape)

    n_dims = surrogate_rewards.ndim
    if n_dims == 2:
        return perturbed @ surrogate_rewards.T
    if n_dims == 3:
        # Flatten the (round, action) axes for a single matrix-vector product.
        n_rounds, n_actions, width = surrogate_rewards.shape
        flattened = surrogate_rewards.reshape(n_rounds * n_actions, width)
        return (flattened @ perturbed).reshape(n_rounds, n_actions)
    raise ValueError("Invalid shape for surrogate_rewards. Expected 2 or 3 dimensions.")


from dataclasses import dataclass
from obp.dataset import SyntheticBanditDataset
from obp.dataset import BaseBanditDataset
import numpy as np
from sklearn.preprocessing import OneHotEncoder
from typing import Callable, Tuple
from typing import Optional
from obp.dataset import linear_behavior_policy

@dataclass
class SyntheticBanditDatasetWithSurrogate(BaseBanditDataset):
    """Synthetic logged bandit data with multi-dimensional surrogate rewards.

    Each round consists of a Gaussian context, an action sampled from a
    softmax behavior policy, a noisy ``s_dim``-dimensional surrogate reward,
    and a reward whose expectation depends on context, action and surrogate.
    A subset of rounds is additionally flagged as "observed".

    Parameters
    -----------
    n_actions: int
        Number of actions (at least 2). Actions are one-hot encoded.

    dim_context: int, default=1
        Number of context dimensions.

    reward_type: str, default='binary'
        Either 'binary' or 'continuous'.

    reward_std: float, default=1.0
        Standard deviation of the Gaussian reward noise (continuous rewards).

    beta: float, default=1.0
        Multiplier applied to the behavior policy logits before the softmax.

    alpha_noise: float, default=0.01
        Standard deviation of the noise added to the surrogate coefficients
        when computing ``f_s`` and ``expected_f_s``.

    random_state: int, default=12345
        Seed controlling all random draws. Must not be None.

    lambda_: float, default=0.5
        Weight of the surrogate-only term in the expected reward.

    s_noise: float, default=0.4
        Standard deviation of the Gaussian surrogate reward noise.

    s_dim: int, default=10
        Number of surrogate reward dimensions; values below 1 are raised to 1.

    p_o: float, default=0.5
        Fraction of rounds flagged as observed.

    """

    n_actions: int
    dim_context: int = 1
    reward_type: str = RewardType.BINARY.value
    reward_std: float = 1.0
    beta: float = 1.0
    alpha_noise: float = 0.01
    random_state: int = 12345
    lambda_: float = 0.5
    s_noise: float = 0.4
    s_dim: int = 10
    p_o: float = 0.5

    def __post_init__(self):
        """Validate the configuration and set up derived attributes."""
        check_scalar(self.n_actions, "n_actions", int, min_val=2)
        check_scalar(self.dim_context, "dim_context", int, min_val=1)
        check_scalar(self.beta, "beta", (int, float))
        check_scalar(self.alpha_noise, "alpha_noise", (int, float), min_val=0)
        # ``RewardType(...)`` itself raises ValueError for an unknown value,
        # so the lookup is wrapped to surface the descriptive message.
        try:
            RewardType(self.reward_type)
        except ValueError:
            raise ValueError(
                f"`reward_type` must be either '{RewardType.BINARY.value}' or '{RewardType.CONTINUOUS.value}',"
                f"but {self.reward_type} is given.'"
            ) from None
        check_scalar(self.reward_std, "reward_std", (int, float), min_val=0)
        # NOTE: these bounds are stored but not used by the sampling code in
        # this class (no truncation is applied).
        self.surrogate_reward_min = 0
        self.surrogate_reward_max = 1e10
        if RewardType(self.reward_type) == RewardType.CONTINUOUS:
            self.reward_min = 0
            self.reward_max = 1e10
        # One-hot action representation.
        self.action_context = np.eye(self.n_actions, dtype=int)

        if self.random_state is None:
            raise ValueError("`random_state` must be given")
        self.random_ = check_random_state(self.random_state)

        self.s_dim = max(1, self.s_dim)  # Ensure at least one dimension

    def sample_contextfree_expected_reward(self) -> np.ndarray:
        """Sample expected reward for each action from the uniform distribution."""
        return self.random_.uniform(size=self.n_actions)

    def sample_surogate_reward_given_expected_surrogate_reward(
        self,
        expected_surrogate_reward_factual: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Sample Gaussian surrogate rewards around their expected values.

        Parameters
        -----------
        expected_surrogate_reward_factual: np.ndarray
            Expected surrogate rewards of the chosen actions.

        Returns
        ----------
        surrogate_rewards: np.ndarray
            Samples with mean ``expected_surrogate_reward_factual`` and
            standard deviation ``s_noise``; same shape as the input.

        p_s_xa: np.ndarray
            Element-wise Gaussian density of each sampled value.

        """
        surrogate_rewards = self.random_.normal(
            loc=expected_surrogate_reward_factual, scale=self.s_noise
        )
        p_s_xa = norm.pdf(
            x=surrogate_rewards,
            loc=expected_surrogate_reward_factual,
            scale=self.s_noise,
        )
        return surrogate_rewards, p_s_xa

    def sample_reward_given_expected_reward(
        self,
        expected_reward_factual: np.ndarray,
    ) -> np.ndarray:
        """Sample Gaussian rewards with the given means and std ``reward_std``."""
        return self.random_.normal(loc=expected_reward_factual, scale=self.reward_std)

    def obtain_batch_bandit_feedback(self, n_rounds: int) -> BanditFeedback:
        """Generate ``n_rounds`` rounds of synthetic logged data.

        Parameters
        -----------
        n_rounds: int
            Number of rounds to generate (at least 1).

        Returns
        ----------
        dict
            Generated data. Besides the size information it contains:
            contexts, actions and behavior policy (``pi_b`` with a trailing
            axis of length 1, ``pscores``); expected and sampled surrogate
            rewards (``all_f_x_a``, ``f_x_a``, ``surrogate_rewards``,
            ``p_s_xa``); expected rewards evaluated at the expected surrogate
            (``all_q_x_a_f``, ``q_x_a_f``) and at the chosen action
            (``q_x_a_s``); sampled ``rewards``; surrogate summaries
            (``s_sum``, ``f_sum``: signed sums; ``f_s``, ``expected_f_s``:
            noisy-coefficient projections); the observation indicator
            ``obs_list`` and ``obs_*`` copies restricted to observed rounds.

        """
        check_scalar(n_rounds, "n_rounds", int, min_val=1)
        round_idx = np.arange(n_rounds)
        contexts = self.random_.normal(size=(n_rounds, self.dim_context))

        # Behavior policy pi_0 and logged actions.
        # NOTE(review): the seed offset makes the seed negative when
        # random_state < 100 - confirm linear_behavior_policy accepts that.
        pi_b_logits = linear_behavior_policy(
            context=contexts,
            action_context=self.action_context,
            random_state=self.random_state - 100,
        )
        pi_b = softmax(self.beta * pi_b_logits)
        actions = sample_action_fast(pi_b, random_state=self.random_state)

        # Expected surrogate rewards: one linear function per dimension,
        # each with its own seed. No truncated-normal correction is applied.
        all_f_x_a = np.zeros((n_rounds, self.n_actions, self.s_dim))
        for d in range(self.s_dim):
            all_f_x_a[:, :, d] = linear_reward_function(
                context=contexts,
                action_context=self.action_context,
                random_state=self.random_state + d + 1,
            )
        f_x_a = all_f_x_a[round_idx, actions]

        surrogate_rewards, p_s_xa = self.sample_surogate_reward_given_expected_surrogate_reward(
            expected_surrogate_reward_factual=f_x_a,
        )

        # Expected rewards of every action, evaluated at the expected surrogate.
        all_q_x_a_f = fixed_expected_reward_function(
            context=contexts,
            action_context=self.action_context,
            expected_surrogate=all_f_x_a,
            lambda_=self.lambda_,
            random_state=self.random_state,
        )
        q_x_a_f = all_q_x_a_f[round_idx, actions]

        # Expected rewards of the chosen actions.
        # NOTE(review): the expected (not the sampled) surrogate is passed
        # here - confirm that this is intended.
        q_x_a_s, surrogate_coef_ = expected_reward_function(
            context=contexts,
            action_context=self.action_context[actions],
            surrogate_context=f_x_a,
            lambda_=self.lambda_,
            random_state=self.random_state,
        )

        if RewardType(self.reward_type) == RewardType.BINARY:
            # NOTE(review): only q_x_a_s is squashed; q_x_a_f / all_q_x_a_f
            # stay on the linear scale - confirm that this is intended.
            q_x_a_s = sigmoid(q_x_a_s)
            # Draw from the generator; ``self.random_state`` is the integer
            # seed and has no ``binomial`` method.
            rewards = self.random_.binomial(n=1, p=q_x_a_s)
        else:
            rewards = self.sample_reward_given_expected_reward(
                expected_reward_factual=q_x_a_s,
            )

        pscores = pi_b[round_idx, actions]

        # Surrogate summaries F(s) for SurIPS.
        s_sum = f_s_function(
            surrogate_rewards=surrogate_rewards,
            surrogate_coef_=surrogate_coef_,
        )
        f_sum = all_f_s_function(
            surrogate_rewards=all_f_x_a,
            surrogate_coef_=surrogate_coef_,
        )
        # Both calls use the same seed and therefore the same coefficient noise.
        f_s = s_sum_function(
            surrogate_rewards=surrogate_rewards,
            surrogate_coef_=surrogate_coef_,
            alpha_noise=self.alpha_noise,
            random_state=self.random_state,
        )
        expected_f_s = s_sum_function(
            surrogate_rewards=all_f_x_a,
            surrogate_coef_=surrogate_coef_,
            alpha_noise=self.alpha_noise,
            random_state=self.random_state,
        )

        # Flag floor(p_o * n_rounds) randomly chosen rounds as observed.
        obs_list = np.zeros(n_rounds)
        num_ones = int(self.p_o * n_rounds) if self.p_o > 0 else 0
        ones_indices = self.random_.choice(n_rounds, num_ones, replace=False)
        obs_list[ones_indices] = 1
        observed = obs_list == 1

        return dict(
            n_rounds=n_rounds,
            n_actions=self.n_actions,
            contexts=contexts,
            action_context=self.action_context,
            actions=actions,
            all_f_x_a=all_f_x_a,
            f_x_a=f_x_a,
            q_x_a_f=q_x_a_f,
            all_q_x_a_f=all_q_x_a_f,
            q_x_a_s=q_x_a_s,
            p_s_xa=p_s_xa,
            surrogate_rewards=surrogate_rewards,
            rewards=rewards,
            pi_b=pi_b[:, :, np.newaxis],
            pscores=pscores,
            obs_list=obs_list,
            obs_contexts=contexts[observed],
            obs_actions=actions[observed],
            obs_f_x_a=f_x_a[observed],
            obs_q_x_a_s=q_x_a_s[observed],
            obs_q_x_a_f=q_x_a_f[observed],
            obs_rewards=rewards[observed],
            obs_p_s_xa=p_s_xa[observed],
            obs_surrogate_rewards=surrogate_rewards[observed],
            obs_pi_b=pi_b[observed][:, :, np.newaxis],
            obs_pscores=pscores[observed],
            f_s=f_s,
            expected_f_s=expected_f_s,
            s_sum=s_sum,
            f_sum=f_sum,
            s_dim=self.s_dim,
            dim_context=self.dim_context,
            surrogate_coef_=surrogate_coef_,
        )

    @staticmethod
    def _normalized_policy_value(
        expected_reward: np.ndarray, action_dist: np.ndarray
    ) -> float:
        """Rescale a policy's value so that uniform maps to 0 and per-round best to 1."""
        average_max_reward = np.mean(np.max(expected_reward, axis=1))
        expected_reward_given_act_dist = np.average(
            expected_reward, weights=action_dist[:, :, 0], axis=1
        ).mean()
        expected_reward_uniform_policy = np.mean(expected_reward)
        # The denominator is zero when every action has the same expected
        # reward; the result is then not finite.
        return (expected_reward_given_act_dist - expected_reward_uniform_policy) / (
            average_max_reward - expected_reward_uniform_policy
        )

    def calc_ground_truth_policy_value(
        self, expected_reward: np.ndarray, action_dist: np.ndarray
    ) -> float:
        """Calculate the normalized policy value of given action distribution on the given expected_reward.

        The value is rescaled so that the uniform random policy scores 0 and
        always choosing the best action of each round scores 1.

        Parameters
        -----------
        expected_reward: array-like, shape (n_rounds, n_actions)
            Expected reward given context (:math:`x`) and action (:math:`a`), i.e., :math:`q(x,a):=\\mathbb{E}[r|x,a]`.
            This is often the `expected_reward` of the test set of logged bandit data.

        action_dist: array-like, shape (n_rounds, n_actions, len_list)
            Action choice probabilities of the evaluation policy (can be deterministic), i.e., :math:`\\pi_e(a_i|x_i)`.
            Only the first slice along the last axis is used.

        Returns
        ----------
        policy_value: float
            The normalized policy value of the given action distribution.

        Raises
        ----------
        ValueError
            If the first two dimensions of the inputs do not match.

        """
        check_array(array=expected_reward, name="expected_reward", expected_dim=2)
        check_array(array=action_dist, name="action_dist", expected_dim=3)
        if expected_reward.shape[0] != action_dist.shape[0]:
            raise ValueError(
                "Expected `expected_reward.shape[0] = action_dist.shape[0]`, but found it False"
            )
        if expected_reward.shape[1] != action_dist.shape[1]:
            raise ValueError(
                "Expected `expected_reward.shape[1] = action_dist.shape[1]`, but found it False"
            )
        return self._normalized_policy_value(expected_reward, action_dist)

    def calc_full_policy_value(
        self, expected_reward: np.ndarray, expected_surrogate_reward: np.ndarray, action_dist: np.ndarray, beta: float
    ) -> float:
        """Calculate the normalized policy value on a mix of reward and surrogate reward.

        The mixed expected reward is
        ``(1 - beta) * expected_reward + beta * expected_surrogate_reward``;
        the value is rescaled so that the uniform random policy scores 0 and
        always choosing the best action of each round scores 1.

        Parameters
        -----------
        expected_reward: array-like, shape (n_rounds, n_actions)
            Expected reward given context (:math:`x`) and action (:math:`a`), i.e., :math:`q(x,a):=\\mathbb{E}[r|x,a]`.
            This is often the `expected_reward` of the test set of logged bandit data.

        expected_surrogate_reward: array-like, shape (n_rounds, n_actions)
            Expected surrogate reward summary for each round and action.

        action_dist: array-like, shape (n_rounds, n_actions, len_list)
            Action choice probabilities of the evaluation policy (can be deterministic), i.e., :math:`\\pi_e(a_i|x_i)`.
            Only the first slice along the last axis is used.

        beta: float
            Weight of the surrogate reward in the mix; it decides whether to
            prioritize the reward or the surrogate reward. This is unrelated
            to the `beta` attribute of the dataset.

        Returns
        ----------
        policy_value: float
            The normalized policy value of the given action distribution.

        """
        full_expected_reward = (1 - beta) * expected_reward + (beta * expected_surrogate_reward)
        return self._normalized_policy_value(full_expected_reward, action_dist)



    #     Parameters
    #     -----------
    #     f_s: array-like, shape (n_rounds, n_actions)
    #         Expected reward given context (:math:`x`) and action (:math:`a`), i.e., :math:`q(x,a):=\\mathbb{E}[r|x,a]`.
    #         This is often the `expected_reward` of the test set of logged bandit data.

    #     action_dist: array-like, shape (n_rounds, n_actions, len_list)
    #         Action choice probabilities of the evaluation policy (can be deterministic), i.e., :math:`\\pi_e(a_i|x_i)`.

    #     Returns
    #     ----------
    #     policy_value: float
    #         The policy value of the given action distribution on the given logged bandit data.

    #     """
    #     check_array(array=f_s, name="f_s", expected_dim=2)
    #     check_array(array=action_dist, name="action_dist", expected_dim=3)
    #     if f_s.shape[0] != action_dist.shape[0]:
    #         raise ValueError(
    #             "Expected `f_s.shape[0] = action_dist.shape[0]`, but found it False"
    #         )
    #     if f_s.shape[1] != action_dist.shape[1]:
    #         raise ValueError(
    #             "Expected `f_s.shape[1] = action_dist.shape[1]`, but found it False"
    #         )
    #     max_rewards_per_round = np.max(f_s, axis=1)
    #     average_max_reward = np.mean(max_rewards_per_round)
    #     expected_reward_given_act_dist = np.average(f_s, weights=action_dist[:, :, 0], axis=1).mean()
    #     expected_reward_uniform_policy = np.mean(f_s)
    #     return (expected_reward_given_act_dist-expected_reward_uniform_policy)/(average_max_reward-expected_reward_uniform_policy)

    # def calc_ground_truth_policy_value(
    #     self, expected_reward: np.ndarray, action_dist: np.ndarray
    # ) -> float:
    #     """Calculate the policy value of given action distribution on the given expected_reward.

    #     Parameters
    #     -----------
    #     expected_reward: array-like, shape (n_rounds, n_actions)
    #         Expected reward given context (:math:`x`) and action (:math:`a`), i.e., :math:`q(x,a):=\\mathbb{E}[r|x,a]`.
    #         This is often the `expected_reward` of the test set of logged bandit data.

    #     action_dist: array-like, shape (n_rounds, n_actions, len_list)
    #         Action choice probabilities of the evaluation policy (can be deterministic), i.e., :math:`\\pi_e(a_i|x_i)`.

    #     Returns
    #     ----------
    #     policy_value: float
    #         The policy value of the given action distribution on the given logged bandit data.

    #     """
    #     check_array(array=expected_reward, name="expected_reward", expected_dim=2)
    #     check_array(array=action_dist, name="action_dist", expected_dim=3)
    #     if expected_reward.shape[0] != action_dist.shape[0]:
    #         raise ValueError(
    #             "Expected `expected_reward.shape[0] = action_dist.shape[0]`, but found it False"
    #         )
    #     if expected_reward.shape[1] != action_dist.shape[1]:
    #         raise ValueError(
    #             "Expected `expected_reward.shape[1] = action_dist.shape[1]`, but found it False"
    #         )

    #     return np.average(expected_reward, weights=action_dist[:, :, 0], axis=1).mean()
