# !/usr/bin/env python
# coding: utf-8

# Importing python packages
import numpy as np
import os.path
from sklearn import preprocessing
from sklearn.metrics import root_mean_squared_error
from sklearn.preprocessing import PolynomialFeatures

# Class for Strategic Bandit Environment with estimate of optimal perturbation 
class StrategicEnv:
    """Contextual bandit environment in which one arm may strategically inflate its features.

    The strategic arm (chosen at random at construction) scales its context-arm
    feature vector by ``(1 + delta)`` every round. ``delta`` is a precomputed
    perturbation (found by bisection so that the arm's latent reward reaches the
    round's best reward) plus Gaussian noise, clipped to ``[0, delta_max]``.
    """

    def __init__(self,
                 reward_function='linear',
                 contexts=1000,
                 arms=5,
                 context_dim=2,
                 arm_dim=2,
                 sigma=0.1,
                 delta_max=1.0,
                 delta_sigma=0.2,
                 strategic_nature="static",
                 poly="poly_false",
                 vary_arms="vary_arms",
                 strategic_arm="strategic_arm",
                 seed=1.0):
        """
        Strategic Bandit Environment with combined context and arm features.

        Args:
            reward_function: Latent reward function type (see ``get_reward``).
            contexts: Number of contexts (rounds).
            arms: Number of arms/agents.
            context_dim: Dimension of context features.
            arm_dim: Dimension of arm features.
            sigma: Standard deviation of the Gaussian reward noise.
            delta_max: Maximum perturbation allowed.
            delta_sigma: Standard deviation of the Gaussian noise added to the perturbation.
            strategic_nature: How the perturbation noise evolves over rounds:
                "static", "adaptive" or "best" (default: "static").
            poly: "poly_true" to use degree-2 polynomial features, "poly_false"
                otherwise (default: poly_false).
            vary_arms: "vary_arms" to draw new arms every round, "fix_arms"
                (or "fixed_arms") to reuse the same arms (default: vary_arms).
            strategic_arm: "strategic_arm" if one arm perturbs its features,
                "no_strategic_arm" otherwise (default: strategic_arm).
            seed: Random seed for reproducibility (default: 1.0). Integral values
                are converted to ``int`` before seeding NumPy; None seeds from the OS.
        """
        self.reward_function    = reward_function
        self.contexts           = contexts
        self.arms               = arms
        self.context_dim        = context_dim
        self.arm_dim            = arm_dim
        self.sigma              = sigma
        self.delta_max          = delta_max
        self.delta_sigma        = delta_sigma
        # Attribute keeps its historical spelling; it may be read elsewhere
        self.stratic_nature     = strategic_nature
        self.poly               = poly
        self.vary_arms          = vary_arms
        self.strategic_arm      = strategic_arm

        # np.random.seed only accepts integers (the default 1.0 would raise
        # TypeError), so convert; None keeps its "seed from the OS" meaning.
        np.random.seed(None if seed is None else int(seed))

        # Initialize context-arm feature dimension (updated once data exists)
        self.dim = context_dim + arm_dim

        # Instance name; also the cache-file key, so it uses `seed` exactly as
        # passed to keep previously cached file names unchanged.
        self.problem_name = 'strategic_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}'.format(
            reward_function,
            contexts,
            arms,
            context_dim,
            arm_dim,
            sigma,
            delta_max,
            delta_sigma,
            strategic_nature,
            poly,
            vary_arms,
            strategic_arm,
            seed
        )

        # Index of the current round
        self.iter = 0

        # Select a strategic agent/arm at random (after seeding, so reproducible)
        self.strategic_arm_index = np.random.randint(0, self.arms)

        # Load existing or generate dataset
        self.get_syn_problem()

    def _generate_context_arms(self):
        """Draw raw context-arm feature rows of shape (contexts * arms, feature_dim)."""
        # Contexts are drawn before arms; keep this order so seeded runs reproduce
        all_contexts = np.random.uniform(low=0, high=1, size=(self.contexts, self.context_dim))

        if self.vary_arms == "vary_arms":
            all_arms = np.random.uniform(low=0, high=1, size=(self.contexts*self.arms, self.arm_dim))
        elif self.vary_arms in ("fix_arms", "fixed_arms"):
            # The same arms are reused in every round
            fixed_arms = np.random.uniform(low=0, high=1, size=(self.arms, self.arm_dim))
            all_arms = np.tile(fixed_arms, (self.contexts, 1))
        else:
            raise ValueError('Unknown option for vary_arms: {}'.format(self.vary_arms))

        # Row i*arms + j pairs context i with arm j
        context_repeated = np.repeat(all_contexts, self.arms, axis=0)
        all_context_arms = np.concatenate([context_repeated, all_arms], axis=-1)

        if self.poly == "poly_true":
            poly = PolynomialFeatures(degree=2)
            hd_data = poly.fit_transform(all_context_arms)
            # Drop the leading bias column (constant 1)
            hd_data_no1 = np.delete(hd_data, 0, axis=1)
            # NOTE(review): the last column of a degree-2 expansion is the square
            # of the last raw feature, not a constant; it is dropped as before —
            # confirm this is intended.
            all_context_arms = np.delete(hd_data_no1, -1, axis=1)
        elif self.poly == "poly_false":
            pass
        else:
            raise ValueError('Unknown option for poly: {}'.format(self.poly))

        return all_context_arms

    def get_syn_problem(self):
        """Load the cached synthetic problem instance, or generate and cache a new one.

        Sets ``all_context_arms`` (contexts, arms, dim), ``all_rewards``
        (contexts, arms), ``max_rewards`` (contexts,), ``optimal_delta``
        (contexts,), ``theta`` (dim,) and ``dim``. New instances are saved to
        ``data/problems/syn_<problem_name>.npz``; the directory is created if missing.

        Raises:
            ValueError: for an unknown ``vary_arms``, ``poly`` or ``strategic_arm``
                option (only when a new instance has to be generated).
        """
        path_to_file = "data/problems/syn_{}.npz".format(self.problem_name)
        if os.path.exists(path_to_file):
            # `with` closes the file handle held by the NpzFile
            with np.load(path_to_file) as load_data:
                self.all_context_arms = load_data['all_context_arms']
                self.all_rewards = load_data['all_rewards']
                self.optimal_delta = load_data['optimal_delta']
                self.theta = load_data['theta']
            self.dim = len(self.all_context_arms[0, 0])
            # Not stored in the file; recompute so both code paths expose it
            self.max_rewards = np.max(self.all_rewards, axis=1)
        else:
            # Reshape into (contexts, arms, feature_dim)
            self.all_context_arms = self._generate_context_arms().reshape(
                self.contexts, self.arms, -1)

            # Feature dimension (larger than context_dim + arm_dim with poly features)
            self.dim = len(self.all_context_arms[0, 0])

            # Latent reward parameter, drawn from [0, 2)
            self.theta = np.random.uniform(low=0, high=2, size=(self.dim, ))

            # Noise-free rewards for every context-arm and the best per context
            self.all_rewards = self.get_reward(self.all_context_arms)
            self.max_rewards = np.max(self.all_rewards, axis=1)

            # Perturbation that lifts the strategic arm to each round's best reward
            self.optimal_delta = np.zeros((self.contexts, ))
            if self.strategic_arm == "no_strategic_arm":
                pass
            elif self.strategic_arm == "strategic_arm":
                for t in range(self.contexts):
                    strategic_arm_features = self.all_context_arms[t, self.strategic_arm_index]
                    self.optimal_delta[t] = self.find_optimal_perturbation(
                        strategic_arm_features, self.max_rewards[t])
            else:
                raise ValueError('Unknown option for strategic_arm: {}'.format(self.strategic_arm))

            # np.savez fails if the target directory does not exist
            os.makedirs(os.path.dirname(path_to_file), exist_ok=True)
            np.savez(
                path_to_file,
                all_context_arms=self.all_context_arms,
                all_rewards=self.all_rewards,
                optimal_delta=self.optimal_delta,
                theta=self.theta
            )

    def find_optimal_perturbation(self, arm_features, max_reward, iterations=100, tolerance=1e-8):
        """Find delta in [0, delta_max] at which reward((1 + delta) * arm_features) reaches max_reward.

        Uses bisection on delta, which assumes the reward is non-decreasing in
        delta. That holds for the linear/square/cubic rewards with the
        non-negative features and theta generated here, but not in general for
        the periodic or benchmark rewards.

        Args:
            arm_features: Feature vector of the strategic arm.
            max_reward: Target reward (the best reward in the round).
            iterations: Maximum number of bisection steps.
            tolerance: Stop when the reward gap or the bracket width drops below this.

        Returns:
            0 if no perturbation is needed, ``delta_max`` if even the maximum
            perturbation does not exceed ``max_reward``, else the bisection estimate.
        """
        delta_low = 0
        delta_high = self.delta_max

        # Minimum perturbation already enough
        if self.get_reward((1+delta_low)*arm_features) >= max_reward:
            return delta_low

        # Maximum perturbation not enough
        if self.get_reward((1+delta_high)*arm_features) <= max_reward:
            return delta_high

        for _ in range(iterations):
            delta_mid = 0.5 * (delta_low + delta_high)
            reward_mid = self.get_reward((1+delta_mid)*arm_features)

            if abs(reward_mid - max_reward) < tolerance or (delta_high - delta_low) < tolerance:
                return delta_mid

            if reward_mid < max_reward:
                delta_low = delta_mid
            else:
                delta_high = delta_mid

        # Convergence not reached: return the midpoint of the final bracket
        return 0.5 * (delta_low + delta_high)

    def get_reward(self, x):
        """Evaluate the latent (noise-free) reward of feature vector(s) ``x``.

        ``x`` may be a single feature vector of length ``dim`` or a batch whose
        last axis holds the features (e.g. (contexts, arms, dim)); the result has
        the shape of ``x`` without its last axis.

        Raises:
            ValueError: if ``reward_function`` is unknown.
        """
        if self.reward_function == 'linear':    # Problem Instance I
            return 5*x.dot(self.theta)

        elif self.reward_function == 'linear2':  # Problem Instance II
            return 2*x.dot(self.theta)

        elif self.reward_function == 'linear1':  # Problem Instance III
            return x.dot(self.theta)

        elif self.reward_function == 'square':
            return 10*(x.dot(self.theta))**2

        elif self.reward_function == 'cubic':
            return (x.dot(self.theta))**3

        elif self.reward_function == 'cosine':
            return np.cos(3*x.dot(self.theta))

        elif self.reward_function == 'sine':
            return 2*np.sin(x.dot(self.theta))

        elif self.reward_function == 'levy':
            # Index the feature axis explicitly so batched input is evaluated
            # per feature vector instead of slicing along the batch axis.
            w = 1 + (x - 1.0) / 4.0
            w_mid = w[..., 1:self.dim - 1]
            w_last = w[..., self.dim - 1]
            part1 = np.sin(np.pi * w[..., 0]) ** 2
            part2 = np.sum((w_mid - 1) ** 2 * (1 + 10 * np.sin(np.pi * w_mid + 1) ** 2), axis=-1)
            part3 = (w_last - 1) ** 2 * (1 + np.sin(2 * np.pi * w_last)**2)
            return part1 + part2 + part3

        elif self.reward_function == 'ackley':
            a, b, c = -2, 0.2, np.pi
            part1 = -a * np.exp(-b / np.sqrt(self.dim) * np.linalg.norm(x, axis=-1))
            part2 = -(np.exp(np.mean(np.cos(c * x), axis=-1)))
            return 10*(part1 + part2 + np.e + a)

        else:
            raise ValueError('Unknown reward function: {}'.format(self.reward_function))

    def perturb_features(self, arm_features):
        """Return ``arm_features`` scaled by (1 + delta) for the current round.

        delta = optimal_delta[iter] + N(0, s), clipped to [0, delta_max], where s is
        ``delta_sigma`` ("static"), ``delta_sigma / (iter + 1)`` ("adaptive") or
        ``delta_sigma * (1 - iter / contexts)`` ("best").

        Raises:
            ValueError: if the strategic nature is unknown.
        """
        optimal_delta = self.optimal_delta[self.iter]

        if self.stratic_nature == "static":
            delta_sigma = self.delta_sigma
        elif self.stratic_nature == "adaptive":
            delta_sigma = self.delta_sigma * (1.0/(self.iter+1))
        elif self.stratic_nature == "best":
            delta_sigma = self.delta_sigma * (1 - self.iter/self.contexts)
        else:
            raise ValueError('Unknown option for strategic_nature: {}'.format(self.stratic_nature))

        # Shape (1,) noise; broadcasts over the feature vector below
        noise = np.random.normal(0, delta_sigma, 1)
        delta = np.clip(optimal_delta + noise, 0, self.delta_max)
        return (1 + delta) * arm_features

    def get_context_arms(self):
        """Return a copy of this round's context-arm features, with the strategic arm perturbed."""
        context_arms = self.all_context_arms[self.iter].copy()

        if self.strategic_arm == "no_strategic_arm":
            return context_arms

        strategic_arm_features = context_arms[self.strategic_arm_index]
        context_arms[self.strategic_arm_index] = self.perturb_features(strategic_arm_features)
        return context_arms

    def get_noisy_reward(self, arm):
        """Return the latent reward of ``arm`` this round plus N(0, sigma) noise (shape (1,))."""
        reward = self.all_rewards[self.iter][arm]
        noise = np.random.normal(0, self.sigma, 1)
        return reward + noise

    def get_active_regret(self, active_arms, arm):
        """Return this round's regret of ``arm`` against ``active_arms`` and advance the round.

        Args:
            active_arms: Iterable of arm indices to compare against.
            arm: Index of the selected arm.

        Returns:
            max(reward of arm, rewards of active arms) - reward of arm (always >= 0).
        """
        rewards = self.all_rewards[self.iter]
        arm_reward = rewards[arm]

        max_reward = arm_reward
        for a in active_arms:
            max_reward = np.maximum(max_reward, rewards[a])

        self.iter += 1
        return max_reward - arm_reward

    def reset(self):
        """Restart at round 0 and shuffle rounds, keeping features, rewards and deltas aligned."""
        self.iter = 0

        p = np.random.permutation(self.contexts)
        self.all_context_arms = self.all_context_arms[p]
        self.all_rewards = self.all_rewards[p]
        self.optimal_delta = self.optimal_delta[p]

    def finish(self):
        """Return True once every context has been played."""
        return self.iter == self.contexts


# Class for Strategic Bandit Environment with optimal perturbation via gradient ascent
class StrategicEnvGrad:
    """Strategic bandit environment whose strategic arm learns its perturbation online.

    Instead of a precomputed perturbation, the strategic arm adds a learned
    vector to its features. After each round it takes a step along a random
    non-negative unit direction, scaled by feedback (1 if it was selected,
    else 0). The learned vector is projected onto the ball of radius ``delta_max``.
    """

    def __init__(self,
                 reward_function='linear',
                 contexts=1000,
                 arms=5,
                 context_dim=2,
                 arm_dim=2,
                 sigma=0.1,
                 delta_max=1.0,
                 delta_sigma=0.2,
                 strategic_nature="static",
                 poly="poly_false",
                 vary_arms="fixed_arms",
                 strategic_arm="strategic_arm",
                 seed=1.0):
        """
        Strategic Bandit Environment with combined context and arm features.

        Args:
            reward_function: Latent reward function type (see ``get_reward``).
            contexts: Number of contexts (rounds).
            arms: Number of arms/agents.
            context_dim: Dimension of context features.
            arm_dim: Dimension of arm features.
            sigma: Standard deviation of the Gaussian reward noise.
            delta_max: Maximum norm of the learned perturbation direction.
            delta_sigma: Stored for interface parity; not used by this class.
            strategic_nature: Finite-difference scale and step size of the
                perturbation update: "static", "conservative", "exploration"
                or "aggressive" (default: "static").
            poly: "poly_true" to use degree-2 polynomial features, "poly_false"
                otherwise (default: poly_false).
            vary_arms: "vary_arms" to draw new arms every round, "fixed_arms"
                (or "fix_arms") to reuse the same arms (default: fixed_arms).
            strategic_arm: "strategic_arm" if one arm perturbs its features,
                "no_strategic_arm" otherwise (default: strategic_arm).
            seed: Random seed for reproducibility (default: 1.0). Integral values
                are converted to ``int`` before seeding NumPy; None seeds from the OS.
        """
        self.reward_function    = reward_function
        self.contexts           = contexts
        self.arms               = arms
        self.context_dim        = context_dim
        self.arm_dim            = arm_dim
        self.sigma              = sigma
        self.delta_max          = delta_max
        self.delta_sigma        = delta_sigma
        # Attribute keeps its historical spelling; it may be read elsewhere
        self.stratic_nature     = strategic_nature
        self.poly               = poly
        self.vary_arms          = vary_arms
        self.strategic_arm      = strategic_arm

        # np.random.seed only accepts integers (the default 1.0 would raise
        # TypeError), so convert; None keeps its "seed from the OS" meaning.
        np.random.seed(None if seed is None else int(seed))

        # Initialize context-arm feature dimension (updated once data exists)
        self.dim = context_dim + arm_dim

        # Instance name; also the cache-file key, so it uses `seed` exactly as
        # passed to keep previously cached file names unchanged.
        self.problem_name = 'strategic_grad_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}'.format(
            reward_function,
            contexts,
            arms,
            context_dim,
            arm_dim,
            sigma,
            delta_max,
            delta_sigma,
            strategic_nature,
            poly,
            vary_arms,
            strategic_arm,
            seed
        )

        # Index of the current round
        self.iter = 0

        # Select a strategic agent/arm at random (after seeding, so reproducible)
        self.strategic_arm_index = np.random.randint(0, self.arms)

        # Load existing or generate dataset
        self.get_syn_problem()

    def _generate_context_arms(self):
        """Draw raw context-arm feature rows of shape (contexts * arms, feature_dim)."""
        # Contexts are drawn before arms; keep this order so seeded runs reproduce
        all_contexts = np.random.uniform(low=0, high=1, size=(self.contexts, self.context_dim))

        if self.vary_arms == "vary_arms":
            all_arms = np.random.uniform(low=0, high=1, size=(self.contexts*self.arms, self.arm_dim))
        elif self.vary_arms in ("fix_arms", "fixed_arms"):
            # "fixed_arms" is this class's default and must be accepted
            fixed_arms = np.random.uniform(low=0, high=1, size=(self.arms, self.arm_dim))
            all_arms = np.tile(fixed_arms, (self.contexts, 1))
        else:
            raise ValueError('Unknown option for vary_arms: {}'.format(self.vary_arms))

        # Row i*arms + j pairs context i with arm j
        context_repeated = np.repeat(all_contexts, self.arms, axis=0)
        all_context_arms = np.concatenate([context_repeated, all_arms], axis=-1)

        if self.poly == "poly_true":
            poly = PolynomialFeatures(degree=2)
            hd_data = poly.fit_transform(all_context_arms)
            # Drop the leading bias column (constant 1)
            hd_data_no1 = np.delete(hd_data, 0, axis=1)
            # NOTE(review): the last column of a degree-2 expansion is the square
            # of the last raw feature, not a constant; it is dropped as before —
            # confirm this is intended.
            all_context_arms = np.delete(hd_data_no1, -1, axis=1)
        elif self.poly == "poly_false":
            pass
        else:
            raise ValueError('Unknown option for poly: {}'.format(self.poly))

        return all_context_arms

    def get_syn_problem(self):
        """Load the cached synthetic problem instance, or generate and cache a new one.

        Sets ``all_context_arms`` (contexts, arms, dim), ``all_rewards``
        (contexts, arms), ``max_rewards`` (contexts,), ``theta`` (dim,) and
        ``dim``. It also zeroes the learned perturbation state. New instances are
        saved to ``data/problems/syn_<problem_name>.npz``; the directory is
        created if missing.

        Raises:
            ValueError: for an unknown ``vary_arms`` or ``poly`` option (only when
                a new instance has to be generated).
        """
        path_to_file = "data/problems/syn_{}.npz".format(self.problem_name)
        if os.path.exists(path_to_file):
            # `with` closes the file handle held by the NpzFile
            with np.load(path_to_file) as load_data:
                self.all_context_arms = load_data['all_context_arms']
                self.all_rewards = load_data['all_rewards']
                self.theta = load_data['theta']
            self.dim = len(self.all_context_arms[0, 0])
            # Not stored in the file; recompute so both code paths expose it
            self.max_rewards = np.max(self.all_rewards, axis=1)
        else:
            # Reshape into (contexts, arms, feature_dim)
            self.all_context_arms = self._generate_context_arms().reshape(
                self.contexts, self.arms, -1)

            # Feature dimension (larger than context_dim + arm_dim with poly features)
            self.dim = len(self.all_context_arms[0, 0])

            # Latent reward parameter, drawn from [0, 2)
            self.theta = np.random.uniform(low=0, high=2, size=(self.dim, ))

            # Noise-free rewards for every context-arm and the best per context
            self.all_rewards = self.get_reward(self.all_context_arms)
            self.max_rewards = np.max(self.all_rewards, axis=1)

            # np.savez fails if the target directory does not exist
            os.makedirs(os.path.dirname(path_to_file), exist_ok=True)
            np.savez(
                path_to_file,
                all_context_arms=self.all_context_arms,
                all_rewards=self.all_rewards,
                theta=self.theta
            )

        # Perturbation added to the strategic arm next round, and the learned direction
        self.perturbation_vector = np.zeros((self.dim, ))
        self.delta = np.zeros((self.dim, ))

    def get_reward(self, x):
        """Evaluate the latent (noise-free) reward of feature vector(s) ``x``.

        ``x`` may be a single feature vector of length ``dim`` or a batch whose
        last axis holds the features (e.g. (contexts, arms, dim)); the result has
        the shape of ``x`` without its last axis.

        Raises:
            ValueError: if ``reward_function`` is unknown.
        """
        if self.reward_function == 'linear':    # Problem Instance I
            return 5*x.dot(self.theta)

        elif self.reward_function == 'linear2':  # Problem Instance II
            return 2*x.dot(self.theta)

        elif self.reward_function == 'linear1':  # Problem Instance III
            return x.dot(self.theta)

        elif self.reward_function == 'square':
            return 10*(x.dot(self.theta))**2

        elif self.reward_function == 'cubic':
            return (x.dot(self.theta))**3

        elif self.reward_function == 'cosine':
            return np.cos(3*x.dot(self.theta))

        elif self.reward_function == 'sine':
            return 2*np.sin(x.dot(self.theta))

        elif self.reward_function == 'levy':
            # Index the feature axis explicitly so batched input is evaluated
            # per feature vector instead of slicing along the batch axis.
            w = 1 + (x - 1.0) / 4.0
            w_mid = w[..., 1:self.dim - 1]
            w_last = w[..., self.dim - 1]
            part1 = np.sin(np.pi * w[..., 0]) ** 2
            part2 = np.sum((w_mid - 1) ** 2 * (1 + 10 * np.sin(np.pi * w_mid + 1) ** 2), axis=-1)
            part3 = (w_last - 1) ** 2 * (1 + np.sin(2 * np.pi * w_last)**2)
            return part1 + part2 + part3

        elif self.reward_function == 'ackley':
            a, b, c = -2, 0.2, np.pi
            part1 = -a * np.exp(-b / np.sqrt(self.dim) * np.linalg.norm(x, axis=-1))
            part2 = -(np.exp(np.mean(np.cos(c * x), axis=-1)))
            return 10*(part1 + part2 + np.e + a)

        else:
            raise ValueError('Unknown reward function: {}'.format(self.reward_function))

    def perturb_features(self, arm_feature, feedback):
        """Update the strategic arm's perturbation from this round's feedback.

        Draws a random non-negative unit direction ``u``. It then sets the
        perturbation for the next round to ``delta + delta_fd * u`` and updates
        ``delta += eta * feedback * u``. Finally ``delta`` is rescaled so its norm
        is at most ``delta_max``.

        Args:
            arm_feature: Unperturbed features of the strategic arm (sets delta_fd's scale).
            feedback: 1 if the strategic arm was selected this round, else 0.

        Raises:
            ValueError: if the strategic nature is unknown.
        """
        # Random non-negative unit direction
        perturb_vector = np.abs(np.random.randn(self.dim))
        perturb_vector /= np.linalg.norm(perturb_vector)

        # strategic_nature -> (finite-difference scale, step size eta)
        settings = {
            "static": (1e-2, 0.05),         # moderate changes (default)
            "conservative": (1e-2, 0.01),
            "exploration": (5*1e-2, 0.05),
            "aggressive": (2*1e-2, 0.1),
        }
        if self.stratic_nature not in settings:
            raise ValueError('Unknown option for strategic_nature: {}'.format(self.stratic_nature))
        fd_scale, eta = settings[self.stratic_nature]
        delta_fd = fd_scale * np.linalg.norm(arm_feature)

        # Probe perturbation applied to the strategic arm next round
        self.perturbation_vector = self.delta + delta_fd * perturb_vector

        # One-point gradient step on the direction
        self.delta += eta * feedback * perturb_vector

        # Project onto the ball of radius delta_max
        delta_norm = np.linalg.norm(self.delta)
        if delta_norm > self.delta_max:
            self.delta = (self.delta_max / delta_norm) * self.delta

    def get_context_arms(self):
        """Return a copy of this round's context-arm features with the perturbation added.

        Raises:
            ValueError: if the perturbation lowers the strategic arm's latent reward.
        """
        context_arms = self.all_context_arms[self.iter].copy()

        if self.strategic_arm == "no_strategic_arm":
            return context_arms

        reward_before = self.get_reward(context_arms[self.strategic_arm_index])
        context_arms[self.strategic_arm_index] += self.perturbation_vector

        # Sanity check: the strategic perturbation must not hurt the arm
        if reward_before > self.get_reward(context_arms[self.strategic_arm_index]):
            raise ValueError("Perturbation not working!")

        return context_arms

    def get_noisy_reward(self, arm):
        """Return the latent reward of ``arm`` this round plus N(0, sigma) noise (shape (1,))."""
        reward = self.all_rewards[self.iter][arm]
        noise = np.random.normal(0, self.sigma, 1)
        return reward + noise

    def get_active_regret(self, active_arms, arm):
        """Return this round's regret of ``arm`` against ``active_arms`` and advance the round.

        Also updates the strategic arm's perturbation, using feedback 1 if
        ``arm`` is the strategic arm and 0 otherwise.

        Args:
            active_arms: Iterable of arm indices to compare against.
            arm: Index of the selected arm.

        Returns:
            max(reward of arm, rewards of active arms) - reward of arm (always >= 0).
        """
        rewards = self.all_rewards[self.iter]
        arm_reward = rewards[arm]

        max_reward = arm_reward
        for a in active_arms:
            max_reward = np.maximum(max_reward, rewards[a])

        feedback = 1 if arm == self.strategic_arm_index else 0
        self.perturb_features(self.all_context_arms[self.iter][self.strategic_arm_index], feedback)

        self.iter += 1
        return max_reward - arm_reward

    def reset(self):
        """Restart at round 0, clear the learned perturbation, and shuffle rounds.

        Context-arm features and rewards are shuffled with the same permutation,
        so they stay aligned.
        """
        self.iter = 0
        self.perturbation_vector = np.zeros((self.dim, ))
        self.delta = np.zeros((self.dim, ))

        p = np.random.permutation(self.contexts)
        self.all_context_arms = self.all_context_arms[p]
        self.all_rewards = self.all_rewards[p]

    def finish(self):
        """Return True once every context has been played."""
        return self.iter == self.contexts


# For sanity check
if __name__ == '__main__':

    # Build a small environment and step through a few rounds
    env = StrategicEnv('linear', 200, 5, 2, 2, 0.1, 1.0, 0.2, "static", "poly_false", "vary_arms", "strategic_arm", 1)
    print(env.get_context_arms())
    print(env.get_noisy_reward(0))

    # get_active_regret(active_arms, arm): the active-arm list comes first,
    # then the selected arm; each call advances to the next round.
    print(env.get_active_regret([0, 1, 2, 3, 4], 0))
    print(env.get_active_regret([0, 1, 2, 4], 0))
    print(env.get_active_regret([0, 1, 2, 3], 0))
    print(env.get_active_regret([0, 1, 2, 3, 4], 0))

    del env