from dataclasses import dataclass
import enum

import numpy as np
from obp.dataset import linear_behavior_policy
from obp.dataset import linear_reward_function
from obp.dataset import logistic_reward_function
from obp.dataset import logistic_polynomial_reward_function
from obp.dataset import polynomial_reward_function
from obp.dataset import SyntheticBanditDataset
from obp.types import BanditFeedback
from obp.utils import sample_action_fast
from obp.utils import check_array
from obp.utils import softmax
from scipy.stats import rankdata
from sklearn.preprocessing import OneHotEncoder
from sklearn.utils import check_random_state
from sklearn.utils import check_scalar
from sklearn.preprocessing import PolynomialFeatures
from scipy.stats import norm
from scipy.stats import truncnorm
import pandas as pd
import math
from typing import Optional, List
from sklearn.preprocessing import MinMaxScaler
import time


def sigmoid(x: np.ndarray) -> np.ndarray:
    """Return the element-wise logistic sigmoid ``1 / (1 + exp(-x))``.

    The value is evaluated as ``exp(-log(1 + exp(-x)))`` through
    ``np.logaddexp`` so that large negative inputs do not overflow inside
    ``exp`` (the naive formula emits a ``RuntimeWarning`` there). Results
    agree with the naive formula up to floating-point rounding.

    Parameters
    ----------
    x: array-like of numbers (anything supporting unary minus)

    Returns
    -------
    np.ndarray (or NumPy scalar) with values in [0, 1].
    """
    # logaddexp(0, -x) == log(exp(0) + exp(-x)) == log(1 + exp(-x)), computed stably.
    return np.exp(-np.logaddexp(0.0, -x))


class RewardType(enum.Enum):
    """Reward kinds recognised by the dataset code in this module."""

    BINARY = "binary"
    CONTINUOUS = "continuous"

    def __repr__(self) -> str:
        """Show the same text for ``repr()`` as for ``str()``."""
        return self.__str__()

def sample_random_uniform_coefficients(
    dim_surrogate_context: int,
    dim_context: int,
    dim_action: int,
    random_: np.random.RandomState,
    **kwargs,
) -> "Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]":
    """Draw the six coefficient arrays of the reward model from U[-1, 1).

    Parameters
    ----------
    dim_surrogate_context: int
        Length of the surrogate feature vector.
    dim_context: int
        Length of the context feature vector.
    dim_action: int
        Length of the action feature vector.
    random_: np.random.RandomState
        Random number generator the coefficients are drawn from.
    **kwargs
        Accepted and ignored.

    Returns
    -------
    tuple of six np.ndarray, in this order:
        context_coef_ (dim_context,),
        action_coef_ (dim_action,),
        surrogate_coef_ (dim_surrogate_context,),
        context_action_coef_ (dim_context, dim_action),
        context_surrogate_coef_ (dim_context, dim_surrogate_context),
        action_surrogate_coef_ (dim_action, dim_surrogate_context).
    """

    def draw(shape):
        return random_.uniform(-1, 1, size=shape)

    # The order of the draws determines which numbers of the random stream end
    # up in which array, so it must stay fixed for results to be reproducible.
    context_coef_ = draw(dim_context)
    action_coef_ = draw(dim_action)
    surrogate_coef_ = draw(dim_surrogate_context)
    context_action_coef_ = draw((dim_context, dim_action))
    context_surrogate_coef_ = draw((dim_context, dim_surrogate_context))
    action_surrogate_coef_ = draw((dim_action, dim_surrogate_context))
    return (
        context_coef_,
        action_coef_,
        surrogate_coef_,
        context_action_coef_,
        context_surrogate_coef_,
        action_surrogate_coef_,
    )

def fixed_expected_reward_function(
    context: np.ndarray,
    action_context: np.ndarray,
    expected_surrogate: np.ndarray,
    lambda_: float,
    random_state: int,
    **kwargs,
) -> np.ndarray:
    """Compute an expected reward for every (round, action) pair.

    Each input gets a leading bias feature, random coefficients are drawn with
    ``sample_random_uniform_coefficients`` (seeded by ``random_state``), and the
    result is

        (1 - lambda_) * (context + action + context-action
                         + context-surrogate + action-surrogate terms)
        + lambda_ * (surrogate term)

    Parameters
    ----------
    context: 2-D array, one row per round.
    action_context: 2-D array, one row per action.
    expected_surrogate: 3-D array indexed as (round, action, surrogate feature).
    lambda_: weight of the pure surrogate term; ``1 - lambda_`` weights the rest.
    random_state: seed (or random state) handed to ``check_random_state``.
    **kwargs: accepted and ignored.

    Returns
    -------
    np.ndarray of shape (rounds, actions), taken from ``expected_surrogate``'s
    first two axes.
    """
    add_bias = PolynomialFeatures(degree=1)
    x_feat = add_bias.fit_transform(context)
    a_feat = add_bias.fit_transform(action_context)
    n_rows, n_cols = expected_surrogate.shape[0], expected_surrogate.shape[1]
    s_feat = np.concatenate([np.ones((n_rows, n_cols, 1)), expected_surrogate], axis=2)

    dim_context = x_feat.shape[1]
    dim_action_context = a_feat.shape[1]
    # Round and action counts are taken from the surrogate tensor.
    datasize, n_actions, dim_surrogate_context = s_feat.shape

    rng = check_random_state(random_state)
    w_x, w_a, w_s, w_xa, w_xs, w_as = sample_random_uniform_coefficients(
        dim_surrogate_context=dim_surrogate_context,
        dim_context=dim_context,
        dim_action=dim_action_context,
        random_=rng,
    )

    # Main effects, each expanded to a (datasize, n_actions) grid.
    context_values = np.tile(x_feat @ w_x, (n_actions, 1)).T
    action_values = np.tile(w_a @ a_feat.T, (datasize, 1))
    flat_surrogate = s_feat.reshape(datasize * n_actions, dim_surrogate_context)
    surrogate_values = (flat_surrogate @ w_s).reshape(datasize, n_actions)

    # Pairwise interaction terms (i: round, h: action).
    context_action_values = np.einsum('ij,jk,hk->ih', x_feat, w_xa, a_feat)
    context_surrogate_values = np.einsum('ij,jk,ihk->ih', x_feat, w_xs, s_feat)
    action_surrogate_values = np.einsum('hj,jk,ihk->ih', a_feat, w_as, s_feat)

    non_surrogate_sum = (
        context_values
        + action_values
        + context_action_values
        + context_surrogate_values
        + action_surrogate_values
    )
    return (1.0 - lambda_) * non_surrogate_sum + lambda_ * surrogate_values

### Function to calculate expected rewards q(x, a, s)
def expected_reward_function(
    context: np.ndarray,
    action_context: np.ndarray,
    surrogate_context: np.ndarray,
    lambda_: float,
    random_state: int,
    **kwargs,
) -> "Tuple[np.ndarray, np.ndarray]":
    """Compute the expected reward q(x, a, s) for each row of the inputs.

    The three inputs are aligned row by row (row ``i`` of each describes the
    same round). A bias feature is prepended to each, random coefficients are
    drawn with ``sample_random_uniform_coefficients`` (seeded by
    ``random_state``), and the result is

        (1 - lambda_) * (context + action + context-action
                         + context-surrogate + action-surrogate terms)
        + lambda_ * (surrogate term)

    Parameters
    ----------
    context: 2-D array, one row per round.
    action_context: 2-D array with the same number of rows as ``context``
        (features of the action attached to each round).
    surrogate_context: 2-D array with the same number of rows as ``context``.
    lambda_: weight of the pure surrogate term; ``1 - lambda_`` weights the rest.
    random_state: seed (or random state) handed to ``check_random_state``.
    **kwargs: accepted and ignored.

    Returns
    -------
    q_x_a_s: np.ndarray, 1-D, one expected reward per row.
    surrogate_coef_: np.ndarray, 1-D
        Coefficients of the surrogate term; entry 0 multiplies the bias feature.
    """
    poly = PolynomialFeatures(degree=1)
    context_ = poly.fit_transform(context)
    action_context_ = poly.fit_transform(action_context)
    surrogate_context_ = poly.fit_transform(surrogate_context)

    # Only the feature widths are needed to size the coefficient arrays.
    dim_context = context_.shape[1]
    dim_action_context = action_context_.shape[1]
    dim_surrogate_context = surrogate_context_.shape[1]
    random_ = check_random_state(random_state)

    (
        context_coef_,
        action_coef_,
        surrogate_coef_,
        context_action_coef_,
        context_surrogate_coef_,
        action_surrogate_coef_,
    ) = sample_random_uniform_coefficients(
        dim_surrogate_context=dim_surrogate_context,
        dim_context=dim_context,
        dim_action=dim_action_context,
        random_=random_,
    )

    # Main effects: one value per row.
    context_values = context_coef_ @ context_.T
    action_values = action_coef_ @ action_context_.T
    surrogate_values = surrogate_coef_ @ surrogate_context_.T

    # Pairwise interactions; the shared index ``i`` ties the rows together.
    context_action_values = np.einsum('ij,jk,ik->i', context_, context_action_coef_, action_context_)
    context_surrogate_values = np.einsum('ij,jk,ik->i', context_, context_surrogate_coef_, surrogate_context_)
    action_surrogate_values = np.einsum('ij,jk,ik->i', action_context_, action_surrogate_coef_, surrogate_context_)

    rev_lambda_ = 1.0 - lambda_
    rest_s = context_values+action_values+context_action_values+context_surrogate_values+action_surrogate_values
    non_s_rew = rev_lambda_*rest_s
    s_rew = lambda_*surrogate_values
    q_x_a_s = non_s_rew + s_rew

    return q_x_a_s, surrogate_coef_

def f_s_function(surrogate_rewards: np.ndarray, surrogate_coef_: np.ndarray) -> np.ndarray:
    """Scale the surrogate feature that has the largest non-bias coefficient.

    Parameters
    ----------
    surrogate_rewards: 2-D array, one row per sample, one column per surrogate feature.
    surrogate_coef_: 1-D array whose entry 0 belongs to the bias feature and
        whose remaining entries belong to the surrogate features.

    Returns
    -------
    np.ndarray, 1-D (one value per row): the selected column multiplied by
    its coefficient.
    """
    # degree=1 features are [bias, original columns], so column indices line up
    # with ``surrogate_coef_``.
    poly = PolynomialFeatures(degree=1)
    surrogate_rewards_ = poly.fit_transform(surrogate_rewards)
    # The bias coefficient is never a candidate.
    coef_excluding_first = surrogate_coef_[1:]
    max_value = np.max(coef_excluding_first)
    # +1 converts the position within the sliced array back to a column index.
    max_index = np.argmax(coef_excluding_first) + 1
    result = max_value * surrogate_rewards_[:, max_index]
    return result

def all_f_s_function(surrogate_rewards: np.ndarray, surrogate_coef_: np.ndarray) -> np.ndarray:
    """Scale, for every (round, action) pair, the surrogate feature with the largest non-bias coefficient.

    Parameters
    ----------
    surrogate_rewards: 3-D array indexed as (round, action, surrogate feature).
    surrogate_coef_: 1-D array whose entry 0 belongs to the bias feature and
        whose remaining entries belong to the surrogate features.

    Returns
    -------
    np.ndarray of shape ``surrogate_rewards.shape[:2]``: the selected feature
    multiplied by its coefficient.
    """
    # Prepend a bias feature so that feature indices line up with ``surrogate_coef_``.
    bias = np.ones((surrogate_rewards.shape[0], surrogate_rewards.shape[1], 1))
    surrogate_rewards_ = np.concatenate([bias, surrogate_rewards], axis=2)
    # The bias coefficient is never a candidate.
    coef_excluding_first = surrogate_coef_[1:]
    max_value = np.max(coef_excluding_first)
    # +1 converts the position within the sliced array back to a feature index.
    max_index = np.argmax(coef_excluding_first) + 1
    return max_value * surrogate_rewards_[:, :, max_index]

def s_sum_function(surrogate_rewards: np.ndarray, alpha_noise: float, random_state: int) -> "Tuple[np.ndarray, int]":
    """Collapse the last axis of ``surrogate_rewards`` into a noisy weighted sum.

    The per-feature weights are 0.5 everywhere except for the fixed overrides
    at indices 1, 2 and 3 (0.1, -1.6 and 0.3), plus zero-mean Gaussian noise
    with standard deviation ``alpha_noise``.

    Parameters
    ----------
    surrogate_rewards: 2-D or 3-D array whose last axis holds at least four
        surrogate features.
    alpha_noise: standard deviation of the noise added to the weights.
    random_state: seed (or random state) handed to ``check_random_state``.

    Returns
    -------
    result: np.ndarray
        Weighted sum over the last axis (one dimension fewer than the input).
    max_noise: int
        Index of the feature that received the largest noisy weight.

    Raises
    ------
    ValueError
        If ``surrogate_rewards`` is not 2-D or 3-D.
    IndexError
        If the last axis has fewer than four entries.
    """
    # Validate before doing any work so a bad rank fails with the intended error.
    if surrogate_rewards.ndim not in (2, 3):
        raise ValueError("Invalid shape for surrogate_rewards. Expected 2 or 3 dimensions.")

    last_dim = surrogate_rewards.shape[-1]
    if last_dim < 4:
        # The fixed overrides below address indices 1..3.
        raise IndexError(
            f"surrogate_rewards needs at least 4 entries on its last axis, got {last_dim}"
        )

    random_ = check_random_state(random_state)

    base_noise = np.ones(last_dim) / 2
    base_noise[1] = 0.1
    base_noise[2] = -1.6
    base_noise[3] = 0.3

    random_noise = random_.normal(loc=0, scale=alpha_noise, size=last_dim)
    total_noise = base_noise + random_noise
    max_noise = np.argmax(total_noise)

    # Weights broadcast along the last axis, which is then summed away.
    result = np.sum(surrogate_rewards * total_noise, axis=-1)
    return result, max_noise



from dataclasses import dataclass
from obp.dataset import SyntheticBanditDataset
from obp.dataset import BaseBanditDataset
import numpy as np
from sklearn.preprocessing import OneHotEncoder
from typing import Callable, Tuple
from typing import Optional
from obp.dataset import linear_behavior_policy

@dataclass
class BanditDatasetWithSurrogate(BaseBanditDataset):
    """Build logged bandit feedback with surrogate rewards from tabular data.

    Rewards and surrogate features are read from data frames handed to
    ``obtain_batch_bandit_feedback``; actions are sampled from a softmax
    behavior policy over linear logits.

    Parameters
    ----------
    n_actions: int
        Number of actions (at least 2). Action contexts are one-hot vectors.
    reward_type: str, default="binary"
        Either "binary" or "continuous".
    beta: float, default=1.0
        Multiplier applied to the behavior-policy logits before the softmax.
    alpha_noise: float, default=0.01
        Standard deviation of the weight noise passed to ``s_sum_function``.
    random_state: int, default=12345
        Seed for all sampling; must not be None.
    p_o: float, default=0.5
        Fraction of rounds flagged as observed (must lie in [0, 1]).
    reward_std: float, default=1.0
        Standard deviation used by ``sample_reward_given_expected_reward``.
    s_noise: float, default=1.0
        Standard deviation used by
        ``sample_surogate_reward_given_expected_surrogate_reward``.
    """

    n_actions: int
    reward_type: str = RewardType.BINARY.value
    beta: float = 1.0
    alpha_noise: float = 0.01
    random_state: int = 12345
    p_o: float = 0.5
    # The two fields below were previously read by the sampling helpers without
    # ever being defined (AttributeError); they are appended last so existing
    # positional constructor calls keep their meaning.
    # NOTE(review): the defaults of 1.0 are a guess — confirm the intended scales.
    reward_std: float = 1.0
    s_noise: float = 1.0

    def __post_init__(self):
        """Validate the configuration and set up derived attributes."""
        check_scalar(self.n_actions, "n_actions", int, min_val=2)
        check_scalar(self.beta, "beta", (int, float))
        check_scalar(self.alpha_noise, "alpha_noise", (int, float), min_val=0)
        check_scalar(self.p_o, "p_o", (int, float), min_val=0, max_val=1)
        check_scalar(self.reward_std, "reward_std", (int, float), min_val=0)
        check_scalar(self.s_noise, "s_noise", (int, float), min_val=0)
        # ``RewardType(...)`` itself raises ValueError for unknown values, so
        # the descriptive message has to be attached here to be reachable.
        try:
            reward_type = RewardType(self.reward_type)
        except ValueError as err:
            raise ValueError(
                f"`reward_type` must be either '{RewardType.BINARY.value}' or '{RewardType.CONTINUOUS.value}', "
                f"but {self.reward_type} is given."
            ) from err
        self.surrogate_reward_min = 0
        self.surrogate_reward_max = 1e10
        if reward_type == RewardType.CONTINUOUS:
            self.reward_min = 0
            self.reward_max = 1e10
        self.action_context = np.eye(self.n_actions, dtype=int)

        if self.random_state is None:
            raise ValueError("`random_state` must be given")
        self.random_ = check_random_state(self.random_state)

    def sample_contextfree_expected_reward(self) -> np.ndarray:
        """Sample expected reward for each action from the uniform distribution."""
        return self.random_.uniform(size=self.n_actions)

    def sample_surogate_reward_given_expected_surrogate_reward(
        self,
        expected_surrogate_reward_factual: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Draw Gaussian surrogate rewards around the given means.

        Returns
        -------
        surrogate_rewards: np.ndarray
            Samples from N(expected_surrogate_reward_factual, s_noise).
        p_s_xa: np.ndarray
            Density of each sample under that same normal distribution.
        """
        surrogate_rewards = self.random_.normal(
            loc=expected_surrogate_reward_factual, scale=self.s_noise
        )
        p_s_xa = norm.pdf(
            x=surrogate_rewards, loc=expected_surrogate_reward_factual, scale=self.s_noise
        )
        return surrogate_rewards, p_s_xa

    def sample_reward_given_expected_reward(
        self,
        expected_reward_factual: np.ndarray,
    ) -> np.ndarray:
        """Draw Gaussian rewards with mean ``expected_reward_factual`` and std ``reward_std``."""
        return self.random_.normal(loc=expected_reward_factual, scale=self.reward_std)

    def obtain_batch_bandit_feedback(
        self,
        expected_rewards_df: pd.DataFrame,
        user_features_df: pd.DataFrame,
        test: bool,
        n_rounds: int = 10000,
    ) -> BanditFeedback:
        """Assemble logged bandit feedback for a set of rounds (users).

        Parameters
        ----------
        expected_rewards_df: pd.DataFrame
            One row per (user, video) pair with columns ``user_id``,
            ``video_id``, ``watch_ratio``, ``days_since_upload`` and
            ``video_duration``. For each round the first ``n_actions`` rows
            (ordered by ``video_id``) act as the actions.
        user_features_df: pd.DataFrame
            Has a ``user_id`` column; all other columns become the context.
            Presumably ``user_id`` is the first column (``dim_context`` is
            computed as the number of columns after the first) — verify
            against the caller.
        test: bool
            If False, ``n_rounds`` users are sampled with replacement;
            otherwise every unique user is used exactly once.
        n_rounds: int, default=10000
            Number of rounds when ``test`` is False (ignored otherwise).

        Returns
        -------
        dict
            Contexts, sampled actions, behavior-policy probabilities, rewards,
            surrogate features and the subsets flagged as observed.

        Raises
        ------
        ValueError
            If the merged data does not contain exactly ``n_actions`` rows for
            every round.
        """
        random_user_ids = expected_rewards_df['user_id'].unique()
        if not test:
            round_user_ids = self.random_.choice(random_user_ids, size=n_rounds, replace=True)
        else:
            round_user_ids = random_user_ids
        n_rounds = len(round_user_ids)

        # Context matrix: the feature row of the user behind each round.
        columns = user_features_df.columns[1:]
        only_userid_df = pd.DataFrame({'user_id': round_user_ids})
        merged_only_userid_df = pd.merge(only_userid_df, user_features_df, on='user_id', how='left')
        contexts = merged_only_userid_df.drop('user_id', axis=1).values.astype('float64')

        # Behavior policy: softmax over linear logits, then one action per round.
        pi_b_logits = linear_behavior_policy(
            context=contexts,
            action_context=self.action_context,
            random_state=self.random_state-100,
        )
        pi_b = softmax(self.beta * pi_b_logits)
        actions = sample_action_fast(pi_b, random_state=self.random_state)

        # `watch_ratio` appears twice on purpose: the two copies are binarized
        # with different thresholds further down.
        feature_cols = ['watch_ratio', 'watch_ratio', 'days_since_upload', 'video_duration']
        s_dim = len(feature_cols)

        # `order` preserves the round sequence through the merge.
        round_user_ids_df = pd.DataFrame({'user_id': round_user_ids, 'order': range(n_rounds)})
        merged_df = pd.merge(round_user_ids_df, expected_rewards_df, on='user_id').sort_values(by=['order', 'video_id'])
        merged_df = merged_df.groupby('order').head(self.n_actions)

        if merged_df.empty:
            all_f_x_a = np.zeros((n_rounds, self.n_actions, s_dim))
            all_q_x_a_f = np.zeros((n_rounds, self.n_actions))
        else:
            expected_rows = n_rounds * self.n_actions
            if len(merged_df) != expected_rows:
                raise ValueError(
                    f"Expected {self.n_actions} rows for each of the {n_rounds} rounds "
                    f"({expected_rows} rows in total), but found {len(merged_df)}"
                )
            # copy=True guarantees a writable array for the in-place edits below.
            all_f_x_a = merged_df[feature_cols].to_numpy(copy=True).reshape(n_rounds, self.n_actions, s_dim)
            all_f_x_a[:, :, 1] = np.where(all_f_x_a[:, :, 1] >= 0.4, 0, -1)
            all_f_x_a[:, :, 0] = np.where(all_f_x_a[:, :, 0] >= 1.1, 1, 0)
            all_q_x_a_f = merged_df["watch_ratio"].to_numpy().reshape(n_rounds, self.n_actions)

        # Factual quantities: pick the entry of the action taken in each round.
        f_x_a = all_f_x_a[np.arange(actions.shape[0]), actions]
        surrogate_rewards = f_x_a
        q_x_a_f = all_q_x_a_f[np.arange(actions.shape[0]), actions]
        rewards = q_x_a_f
        pscores = pi_b[np.arange(n_rounds), actions]

        # NOTE(review): the seed is hard-coded rather than taken from
        # `self.random_state` — confirm whether that is intended.
        f_sum, max_noise = s_sum_function(surrogate_rewards=all_f_x_a, alpha_noise=self.alpha_noise, random_state=12345)
        s_sum = f_sum[np.arange(actions.shape[0]), actions]

        f_s = surrogate_rewards[:, 1]
        expected_f_s = all_f_x_a[:, :, 1]

        # Flag a random `p_o` fraction of the rounds as observed and slice out
        # the observed-only arrays (kept for easier downstream implementation).
        obs_list = np.zeros(n_rounds)
        num_ones = int(self.p_o*n_rounds)
        ones_indices = self.random_.choice(n_rounds, num_ones, replace=False)
        obs_list[ones_indices] = 1
        observed = obs_list == 1
        obs_pi_b = pi_b[observed]

        return dict(
            n_rounds=n_rounds,
            n_actions=self.n_actions,
            contexts=contexts,
            action_context=self.action_context,
            actions=actions,
            all_f_x_a=all_f_x_a,
            f_x_a=f_x_a,
            q_x_a_f=q_x_a_f,
            all_q_x_a_f=all_q_x_a_f,
            surrogate_rewards=surrogate_rewards,
            rewards=rewards,
            pi_b=pi_b[:, :, np.newaxis],
            pscores=pscores,
            obs_list=obs_list,
            obs_contexts=contexts[observed],
            obs_actions=actions[observed],
            obs_f_x_a=f_x_a[observed],
            obs_q_x_a_f=q_x_a_f[observed],
            obs_rewards=rewards[observed],
            obs_surrogate_rewards=surrogate_rewards[observed],
            obs_pi_b=obs_pi_b[:, :, np.newaxis],
            obs_pscores=pscores[observed],
            f_s=f_s,
            expected_f_s=expected_f_s,
            s_sum=s_sum,
            f_sum=f_sum,
            dim_context=len(columns),
            s_dim=s_dim,
            obs_s_sum=s_sum[observed],
            user_ids=random_user_ids,
            round_user_ids=round_user_ids,
        )

    def calc_ground_truth_policy_value(
        self, expected_reward: np.ndarray, action_dist: np.ndarray
    ) -> float:
        """Calculate the normalized policy value of the given action distribution.

        The raw value of the policy is rescaled so that the uniform policy
        scores 0 and the per-round best action scores 1:
        ``(V(pi) - V(uniform)) / (V(best) - V(uniform))``.

        Parameters
        -----------
        expected_reward: array-like, shape (n_rounds, n_actions)
            Expected reward given context (:math:`x`) and action (:math:`a`), i.e., :math:`q(x,a):=\\mathbb{E}[r|x,a]`.
            This is often the `expected_reward` of the test set of logged bandit data.

        action_dist: array-like, shape (n_rounds, n_actions, len_list)
            Action choice probabilities of the evaluation policy (can be deterministic), i.e., :math:`\\pi_e(a_i|x_i)`.
            Only the first slot (``[:, :, 0]``) is used.

        Returns
        ----------
        policy_value: float
            The normalized policy value of the given action distribution.

        Raises
        ------
        ValueError
            If the shapes disagree, an intermediate value is NaN/inf, or the
            best-action value equals the uniform-policy value.

        """
        check_array(array=expected_reward, name="expected_reward", expected_dim=2)
        check_array(array=action_dist, name="action_dist", expected_dim=3)
        if expected_reward.shape[0] != action_dist.shape[0]:
            raise ValueError(
                "Expected `expected_reward.shape[0] = action_dist.shape[0]`, but found it False"
            )
        if expected_reward.shape[1] != action_dist.shape[1]:
            raise ValueError(
                "Expected `expected_reward.shape[1] = action_dist.shape[1]`, but found it False"
            )
        max_rewards_per_round = np.max(expected_reward, axis=1)
        average_max_reward = np.mean(max_rewards_per_round)
        expected_reward_given_act_dist = np.average(expected_reward, weights=action_dist[:, :, 0], axis=1).mean()
        expected_reward_uniform_policy = np.mean(expected_reward)

        def validate(value, name):
            if math.isnan(value) or math.isinf(value):
                raise ValueError(f"{name} contains invalid value {value}")

        validate(expected_reward_given_act_dist, "expected_reward_given_act_dist")
        validate(expected_reward_uniform_policy, "expected_reward_uniform_policy")
        validate(average_max_reward, "average_max_reward")
        denominator = average_max_reward - expected_reward_uniform_policy
        if denominator == 0:
            # The normalization is undefined when no policy can beat uniform.
            raise ValueError(f"average max {average_max_reward} == expected reward uniform policy {expected_reward_uniform_policy}")
        return (expected_reward_given_act_dist - expected_reward_uniform_policy) / denominator

    def calc_raw_val(
        self, expected_reward: np.ndarray, action_dist: np.ndarray
    ) -> float:
        """Calculate the raw (un-normalized) policy value of the given action distribution.

        Parameters
        -----------
        expected_reward: array-like, shape (n_rounds, n_actions)
            Expected reward given context (:math:`x`) and action (:math:`a`), i.e., :math:`q(x,a):=\\mathbb{E}[r|x,a]`.
            This is often the `expected_reward` of the test set of logged bandit data.

        action_dist: array-like, shape (n_rounds, n_actions, len_list)
            Action choice probabilities of the evaluation policy (can be deterministic), i.e., :math:`\\pi_e(a_i|x_i)`.
            Only the first slot (``[:, :, 0]``) is used.

        Returns
        ----------
        policy_value: float
            Mean over rounds of the expected reward under ``action_dist``.

        """
        expected_reward_given_act_dist = np.average(expected_reward, weights=action_dist[:, :, 0], axis=1).mean()
        return expected_reward_given_act_dist

    def calc_full_policy_value(
        self, expected_reward: np.ndarray, expected_surrogate_reward: np.ndarray, action_dist: np.ndarray, beta: float
    ) -> float:
        """Calculate the normalized policy value on a reward/surrogate blend.

        The blended target ``(1 - beta) * expected_reward + beta *
        expected_surrogate_reward`` is scored exactly like
        ``calc_ground_truth_policy_value`` scores ``expected_reward``,
        including its input validation and its error on a degenerate
        normalization.

        Parameters
        -----------
        expected_reward: array-like, shape (n_rounds, n_actions)
            Expected reward given context (:math:`x`) and action (:math:`a`), i.e., :math:`q(x,a):=\\mathbb{E}[r|x,a]`.
            This is often the `expected_reward` of the test set of logged bandit data.

        expected_surrogate_reward: array-like, shape (n_rounds, n_actions)
            Expected surrogate reward for each round and action.

        action_dist: array-like, shape (n_rounds, n_actions, len_list)
            Action choice probabilities of the evaluation policy (can be deterministic), i.e., :math:`\\pi_e(a_i|x_i)`.

        beta: float
            Mixing weight: 0 uses only the reward, 1 only the surrogate reward.

        Returns
        ----------
        policy_value: float
            The normalized policy value of the given action distribution on the blended target.

        Raises
        ------
        ValueError
            Under the same conditions as ``calc_ground_truth_policy_value``.

        """
        full_expected_reward = (1-beta)*expected_reward + (beta*expected_surrogate_reward)
        return self.calc_ground_truth_policy_value(
            expected_reward=full_expected_reward, action_dist=action_dist
        )
    