#!/usr/bin/env python
# coding: utf-8

# Importing python packages
import numpy as np
import os.path
from sklearn import preprocessing


# Class for Dueling Bandit Environment
class DuelingEnv:
    """Contextual dueling-bandit environment backed by a cached synthetic dataset.

    Every time step exposes ``arms`` context vectors of length ``dim`` drawn
    uniformly from [-1, 1]. The learner selects two arms and receives a binary
    preference sampled from a logistic model on the difference of the two
    latent rewards (see ``get_feedback``).

    The dataset and the parameter vector ``theta`` are cached on disk under
    ``data/problems/db_<problem_name>.npz`` and reused on later runs.

    Args:
        db_function: Name of the latent reward function (see ``get_reward``).
        dim: Dimension of each context-arm vector.
        arms: Number of arms per time step.
        size: Number of time steps in the dataset.
        noise: Multiplier applied to the reward difference inside the
            logistic link of ``get_feedback``.
        suboptimality_gap: Stored and included in the cache file name only;
            dataset generation does not use it.
        seed: Seed passed to ``np.random.seed`` (global NumPy RNG).
    """

    # Multipliers c for rewards of the form c * (x . theta) ** 2.
    _SQUARE_SCALES = {
        'square': 10,
        'square20': 20,
        'square30': 30,
        'square40': 40,
        'square50': 50,
    }

    # Multipliers c for rewards of the form c * cos(x . theta).
    _COSINE_SCALES = {
        'cosine3': 3,
        'cosine5': 5,
        'cosine10': 10,
        'cosine20': 20,
        'cosine25': 25,
    }

    def __init__(self, db_function, dim, arms, size, noise, suboptimality_gap, seed):
        self.db_function = db_function
        self.dim = dim
        self.arms = arms
        self.size = size
        self.noise = noise
        self.suboptimality_gap = suboptimality_gap
        # Seeds the *global* NumPy RNG; generation, feedback sampling and
        # shuffling all draw from it.
        np.random.seed(seed)

        # Instance name; it determines the cache file used by get_db_problem.
        self.problem_name = '{}_{}_{}_{}_{}_{}_{}'.format(
            db_function,
            dim,
            arms,
            size,
            noise,
            suboptimality_gap,
            seed
        )

        # Variables that keep track of the environment state
        self.t = 0                  # Sample counter of context-arms
        self.context_arms = None    # Context-arms for the current time step

        # Load existing or generate dataset
        self.get_db_problem()

    def get_db_problem(self):
        """Load the cached problem instance, or generate and cache a new one.

        Sets ``self.all_context_arms`` (one ``(arms, dim)`` array per time
        step) and ``self.theta`` (length ``dim``).
        """
        path_to_file = "data/problems/db_{}.npz".format(self.problem_name)
        if os.path.exists(path_to_file):
            # Close the archive once both arrays have been read into memory.
            with np.load(path_to_file) as load_data:
                self.all_context_arms = load_data['all_context_arms']
                self.theta = load_data['theta']
            return

        # Generate random theta
        self.theta = np.random.uniform(low=-1, high=1, size=(self.dim))

        # One independent draw per time step (keeps the RNG call sequence of
        # the original generation loop).
        self.all_context_arms = np.array([
            np.random.uniform(low=-1, high=1, size=(self.arms, self.dim))
            for _ in range(self.size)
        ])

        # np.savez does not create missing parent directories.
        os.makedirs(os.path.dirname(path_to_file), exist_ok=True)

        # Save problem instance and corresponding theta
        np.savez(path_to_file, all_context_arms=self.all_context_arms, theta=self.theta)

    def get_reward(self, x):
        """Return the latent reward of ``x`` under ``self.db_function``.

        Args:
            x: A single context vector (last axis of length ``dim``) or a
                stack of them; the reward is computed along the last axis.

        Returns:
            A scalar for a single vector, or one reward per vector.

        Raises:
            ValueError: If ``self.db_function`` is not a known function name.
        """
        name = self.db_function

        if name == 'linear':
            return x.dot(self.theta)

        if name in self._SQUARE_SCALES:
            return self._SQUARE_SCALES[name] * (x.dot(self.theta)) ** 2

        if name == 'cosine':
            return np.cos(3 * x.dot(self.theta))

        if name in self._COSINE_SCALES:
            return self._COSINE_SCALES[name] * np.cos(x.dot(self.theta))

        if name == 'sine':
            return np.sin(3 * x.dot(self.theta))

        if name == 'levy':
            # Index the coordinate (last) axis so that a stack of arms, as
            # passed by get_regret, is evaluated row by row.
            w = 1 + (x - 1.0) / 4.0
            last = self.dim - 1
            head = w[..., 0]
            middle = w[..., 1:last]
            tail = w[..., last]
            part1 = np.sin(np.pi * head) ** 2
            # NOTE(review): the middle sum starts at index 1, as in the
            # original code; confirm whether index 0 should be included.
            part2 = np.sum(
                (middle - 1) ** 2 * (1 + 10 * np.sin(np.pi * middle + 1) ** 2),
                axis=-1,
            )
            part3 = (tail - 1) ** 2 * (1 + np.sin(2 * np.pi * tail) ** 2)
            return part1 + part2 + part3

        if name == 'ackley':
            a, b, c = -2, 0.2, np.pi
            part1 = -a * np.exp(-b / np.sqrt(self.dim) * np.linalg.norm(x, axis=-1))
            part2 = -(np.exp(np.mean(np.cos(c * x), axis=-1)))
            return 10 * (part1 + part2 + np.e + a)

        raise ValueError('Unknown problem: {}'.format(self.db_function))

    def get_context_arms(self):
        """Return the context-arms of the current time step and advance ``t``."""
        self.context_arms = self.all_context_arms[self.t]
        self.t += 1
        return self.context_arms

    def get_feedback(self, arm1, arm2):
        """Sample a binary preference between two arms of the current step.

        The outcome is 1 with probability
        ``sigmoid(noise * (reward(arm1) - reward(arm2)))``.
        """
        latent_reward1 = self.get_reward(self.context_arms[arm1])
        latent_reward2 = self.get_reward(self.context_arms[arm2])
        y_prob = 1.0 / (1.0 + np.exp(-self.noise * (latent_reward1 - latent_reward2)))
        return np.random.binomial(1, y_prob)

    def get_regret(self, arm1, arm2):
        """Return ``(average_regret, weak_regret)`` for the selected pair.

        Average regret compares the best reward with the mean reward of the
        pair; weak regret compares it with the better arm of the pair.
        """
        rewards = self.get_reward(self.context_arms)
        max_reward = np.max(rewards)
        reward1 = rewards[arm1]
        reward2 = rewards[arm2]
        average_regret = max_reward - ((reward1 + reward2) / 2)
        weak_regret = max_reward - max(reward1, reward2)

        return average_regret, weak_regret

    def reset(self):
        """Rewind the step counter and shuffle the time steps in place."""
        self.t = 0

        # Shuffles along the first (time-step) axis only.
        np.random.shuffle(self.all_context_arms)

    def finish(self):
        """Return True once all ``size`` time steps have been served."""
        return self.t == self.size


# Class for GLM Bandit Environment (single-arm binary feedback)
class GLMEnv:
    """Contextual bandit environment with binary feedback on a single arm.

    Every time step exposes ``arms`` context vectors of length ``dim`` drawn
    uniformly from [-1, 1]. The learner selects one arm and receives a binary
    outcome sampled from a logistic model on that arm's latent reward (see
    ``get_feedback``).

    The dataset and the parameter vector ``theta`` are cached on disk under
    ``data/problems/glm_<problem_name>.npz`` and reused on later runs.

    Args:
        glm_function: Name of the latent reward function (see ``get_reward``).
        dim: Dimension of each context-arm vector.
        arms: Number of arms per time step.
        size: Number of time steps in the dataset.
        noise: Multiplier applied to the latent reward inside the logistic
            link of ``get_feedback``.
        suboptimality_gap: Stored and included in the cache file name only;
            dataset generation does not use it.
        seed: Seed passed to ``np.random.seed`` (global NumPy RNG).
    """

    # Multipliers c for rewards of the form c * (x . theta) ** 2.
    _SQUARE_SCALES = {
        'square': 10,
        'square20': 20,
        'square30': 30,
        'square40': 40,
        'square50': 50,
    }

    # Multipliers c for rewards of the form c * cos(x . theta).
    _COSINE_SCALES = {
        'cosine3': 3,
        'cosine5': 5,
        'cosine10': 10,
        'cosine20': 20,
        'cosine25': 25,
    }

    def __init__(self, glm_function, dim, arms, size, noise, suboptimality_gap, seed):
        self.glm_function = glm_function
        self.dim = dim
        self.arms = arms
        self.size = size
        self.noise = noise
        self.suboptimality_gap = suboptimality_gap
        # Seeds the *global* NumPy RNG; generation, feedback sampling and
        # shuffling all draw from it.
        np.random.seed(seed)

        # Instance name; it determines the cache file used by get_glm_problem.
        self.problem_name = '{}_{}_{}_{}_{}_{}_{}'.format(
            glm_function,
            dim,
            arms,
            size,
            noise,
            suboptimality_gap,
            seed
        )

        # Variables that keep track of the environment state
        self.t = 0                  # Sample counter of context-arms
        self.context_arms = None    # Context-arms for the current time step

        # Load existing or generate dataset
        self.get_glm_problem()

    def get_glm_problem(self):
        """Load the cached problem instance, or generate and cache a new one.

        Sets ``self.all_context_arms`` (one ``(arms, dim)`` array per time
        step) and ``self.theta`` (length ``dim``).
        """
        path_to_file = "data/problems/glm_{}.npz".format(self.problem_name)
        if os.path.exists(path_to_file):
            # Close the archive once both arrays have been read into memory.
            with np.load(path_to_file) as load_data:
                self.all_context_arms = load_data['all_context_arms']
                self.theta = load_data['theta']
            return

        # Generate random theta
        self.theta = np.random.uniform(low=-1, high=1, size=(self.dim))

        # One independent draw per time step (keeps the RNG call sequence of
        # the original generation loop).
        self.all_context_arms = np.array([
            np.random.uniform(low=-1, high=1, size=(self.arms, self.dim))
            for _ in range(self.size)
        ])

        # np.savez does not create missing parent directories.
        os.makedirs(os.path.dirname(path_to_file), exist_ok=True)

        # Save problem instance and corresponding theta
        np.savez(path_to_file, all_context_arms=self.all_context_arms, theta=self.theta)

    def get_reward(self, x):
        """Return the latent reward of ``x`` under ``self.glm_function``.

        Args:
            x: A single context vector or a stack of them; every supported
                function is applied to ``x.dot(theta)``.

        Returns:
            A scalar for a single vector, or one reward per vector.

        Raises:
            ValueError: If ``self.glm_function`` is not a known function name.
        """
        name = self.glm_function

        if name == 'linear':
            return x.dot(self.theta)

        if name in self._SQUARE_SCALES:
            return self._SQUARE_SCALES[name] * (x.dot(self.theta)) ** 2

        if name == 'cosine':
            return np.cos(3 * x.dot(self.theta))

        if name in self._COSINE_SCALES:
            return self._COSINE_SCALES[name] * np.cos(x.dot(self.theta))

        if name == 'sine':
            return np.sin(3 * x.dot(self.theta))

        raise ValueError('Unknown problem: {}'.format(self.glm_function))

    def get_context_arms(self):
        """Return the context-arms of the current time step and advance ``t``."""
        self.context_arms = self.all_context_arms[self.t]
        self.t += 1
        return self.context_arms

    def get_feedback(self, arm):
        """Sample a binary outcome for the selected arm of the current step.

        The outcome is 1 with probability ``sigmoid(noise * reward(arm))``.
        """
        latent_reward = self.get_reward(self.context_arms[arm])
        y_prob = 1.0 / (1.0 + np.exp(-self.noise * latent_reward))
        return np.random.binomial(1, y_prob)

    def get_regret(self, arm):
        """Return the gap between the best reward of this step and ``arm``'s."""
        rewards = self.get_reward(self.context_arms)
        return np.max(rewards) - rewards[arm]

    def reset(self):
        """Rewind the step counter and shuffle the time steps in place."""
        self.t = 0

        # Shuffles along the first (time-step) axis only.
        np.random.shuffle(self.all_context_arms)

    def finish(self):
        """Return True once all ``size`` time steps have been served."""
        return self.t == self.size




# For sanity check
if __name__ == '__main__':
    # Smoke-test DuelingEnv: show one step of contexts, feedback and regret,
    # then reset (which reshuffles the data) and show one more step.
    env = DuelingEnv('square', 5, 5, 100, 0.1, 1.0, 1)
    for round_index in range(2):
        if round_index > 0:
            env.reset()
        print(env.get_context_arms())
        print(env.get_feedback(0, 1))
        print(env.get_regret(2, 1))
    del env