#!/usr/bin/env python
# coding: utf-8

# Importing python packages
import numpy as np
import os.path
from sklearn import preprocessing


# Class for Combinatorial Dueling Bandit Environment
class CombDuelingEnv:
    """Combinatorial dueling-bandit environment over a cached synthetic dataset.

    At construction the global NumPy RNG is seeded, an instance name is built
    from all constructor arguments, and the dataset (context-arm features plus
    the latent parameter ``theta``) is either loaded from
    ``data/problems/db_<problem_name>.npz`` or generated and saved there.

    Parameters
    ----------
    db_value_function : str
        Name of the latent value function applied to arm features
        (see :meth:`get_value` for the accepted names).
    reward_function : str
        How arm values are combined into a super-arm reward; only ``'add'``
        is implemented.
    dim : int
        Feature dimension of every arm.
    total_arms : int
        Number of arms offered at each time step.
    super_arms : int
        Number of arms that make up one super arm.
    size : int
        Number of time steps (context-arm matrices) in the dataset.
    noise : float
        Multiplier applied to the latent reward difference inside the
        logistic link used by :meth:`get_feedback`.
    suboptimality_gap : float
        Stored and included in the instance name; the code in this class does
        not otherwise use it.
    seed : int
        Seed passed to ``np.random.seed`` (global RNG state).
    """

    # Multipliers for the ``scale * (x . theta) ** 2`` family.
    _SQUARE_SCALES = {
        'square': 10,
        'square20': 20,
        'square30': 30,
        'square40': 40,
        'square50': 50,
    }

    # Multipliers for the ``scale * cos(x . theta)`` family.
    # Plain 'cosine' is handled separately: it scales the argument instead.
    _COSINE_SCALES = {
        'cosine3': 3,
        'cosine5': 5,
        'cosine10': 10,
        'cosine20': 20,
        'cosine25': 25,
    }

    def __init__(self, db_value_function, reward_function, dim, total_arms, super_arms, size, noise, suboptimality_gap, seed):
        self.db_value_function = db_value_function
        self.reward_function = reward_function
        self.dim = dim
        self.total_arms = total_arms
        self.super_arms = super_arms
        self.size = size
        self.noise = noise
        self.suboptimality_gap = suboptimality_gap
        np.random.seed(seed)

        # Instance name; it determines the cache file used by get_db_problem
        self.problem_name = '{}_{}_{}_{}_{}_{}_{}_{}_{}'.format(
            db_value_function,
            reward_function,
            dim,
            total_arms,
            super_arms,
            size,
            noise,
            suboptimality_gap,
            seed
        )

        # Variables that keep track of the environment state
        self.t = 0                  # Index of the next context-arm matrix to serve
        self.context_arms = None    # Context-arms for the current time step

        # Load existing or generate dataset
        self.get_db_problem()

    def get_db_problem(self):
        """Load the dataset for this instance from disk, or generate and save it.

        Sets ``self.all_context_arms`` and ``self.theta``. When generating,
        both are drawn uniformly from [-1, 1) using the global NumPy RNG and
        written to ``data/problems/db_<problem_name>.npz``.
        """
        path_to_file = "data/problems/db_{}.npz".format(self.problem_name)
        if os.path.exists(path_to_file):
            # The context manager closes the underlying archive; indexing the
            # NpzFile reads each array fully into memory before that happens.
            with np.load(path_to_file) as load_data:
                self.all_context_arms = load_data['all_context_arms']
                self.theta = load_data['theta']
            return

        # Generate random theta
        self.theta = np.random.uniform(low=-1, high=1, size=(self.dim))

        # Generate one (total_arms, dim) context-arm matrix per time step.
        # Drawn step by step so the RNG consumption order stays unchanged.
        context_arms_list = [
            np.random.uniform(low=-1, high=1, size=(self.total_arms, self.dim))
            for _ in range(self.size)
        ]
        self.all_context_arms = np.array(context_arms_list)

        # Save problem instance and corresponding theta
        os.makedirs(os.path.dirname(path_to_file), exist_ok=True)
        np.savez(path_to_file, all_context_arms=self.all_context_arms, theta=self.theta)

    def get_value(self, x):
        """Return the latent value of arm features ``x``.

        ``x`` may be a single feature vector or a stack of vectors whose last
        axis holds the features; one value is returned per vector.

        Accepted ``db_value_function`` names: ``linear``, ``square``,
        ``square20/30/40/50``, ``cosine``, ``cosine3/5/10/20/25``, ``sine``,
        ``levy`` and ``ackley``.

        Raises
        ------
        ValueError
            If ``db_value_function`` is not one of the names above.
        """
        name = self.db_value_function

        if name == 'linear':
            return x.dot(self.theta)

        if name in self._SQUARE_SCALES:
            return self._SQUARE_SCALES[name] * (x.dot(self.theta)) ** 2

        if name == 'cosine':
            return np.cos(3 * x.dot(self.theta))

        if name in self._COSINE_SCALES:
            return self._COSINE_SCALES[name] * np.cos(x.dot(self.theta))

        if name == 'sine':
            return np.sin(3 * x.dot(self.theta))

        if name == 'levy':
            # Index the feature (last) axis so a whole (arms, dim) matrix
            # yields one value per arm instead of slicing across arms.
            w = 1 + (x - 1.0) / 4.0
            first = w[..., 0]
            last = w[..., self.dim - 1]
            # NOTE(review): the textbook Levy function sums this term over the
            # first d-1 coordinates; this slice skips coordinate 0. Kept as in
            # the original so single-vector results do not change -- confirm intent.
            middle = w[..., 1:self.dim - 1]
            part1 = np.sin(np.pi * first) ** 2
            part2 = np.sum((middle - 1) ** 2 * (1 + 10 * np.sin(np.pi * middle + 1) ** 2), axis=-1)
            part3 = (last - 1) ** 2 * (1 + np.sin(2 * np.pi * last) ** 2)
            return part1 + part2 + part3

        if name == 'ackley':
            a, b, c = -2, 0.2, np.pi
            part1 = -a * np.exp(-b / np.sqrt(self.dim) * np.linalg.norm(x, axis=-1))
            part2 = -(np.exp(np.mean(np.cos(c * x), axis=-1)))
            return 10 * (part1 + part2 + np.e + a)

        raise ValueError('Unknown problem: {}'.format(self.db_value_function))

    def get_context_arms(self):
        """Return the context-arm matrix for the next time step and advance ``t``.

        Raises
        ------
        IndexError
            If every stored time step has already been served.
        """
        if self.t >= len(self.all_context_arms):
            raise IndexError(
                'No context-arms left: t={} but the dataset has {} steps'.format(
                    self.t, len(self.all_context_arms)
                )
            )
        self.context_arms = self.all_context_arms[self.t]
        self.t += 1
        return self.context_arms

    def get_feedback(self, arm1, arm2):
        """Sample binary preference feedback for a duel between two arms.

        The probability of returning 1 is the logistic function of
        ``noise * (value(arm1) - value(arm2))``, evaluated on the current
        context-arms; the outcome is drawn from the global NumPy RNG.
        """
        latent_reward1 = self.get_value(self.context_arms[arm1])
        latent_reward2 = self.get_value(self.context_arms[arm2])
        y_prob = 1.0 / (1.0 + np.exp(-self.noise * (latent_reward1 - latent_reward2)))
        return np.random.binomial(1, y_prob)

    def get_reward(self, arm_set_values):
        """Combine the values of a super arm's arms into a single reward.

        Raises
        ------
        ValueError
            If the number of values differs from ``super_arms`` or the
            configured reward function is not ``'add'``.
        """
        if len(arm_set_values) != self.super_arms:
            raise ValueError('Super Arm size not right: {}'.format(len(arm_set_values)))

        if self.reward_function != 'add':
            raise ValueError('Unknown reward function: {}'.format(self.reward_function))

        # Add the value of each arm to form the reward of the super arm
        arm_set_reward = 0
        for value in arm_set_values:
            arm_set_reward = arm_set_reward + value
        return arm_set_reward

    def get_regret(self, arm_set1, arm_set2):
        """Return ``(average_regret, weak_regret)`` for two selected super arms.

        The benchmark is the summed value of the ``super_arms`` highest-valued
        arms in the current context. Average regret subtracts the mean of the
        two super-arm rewards from it; weak regret subtracts the larger one.

        Raises
        ------
        ValueError
            If the configured reward function is not ``'add'`` or a super arm
            has the wrong size.
        """
        if self.reward_function != 'add':
            raise ValueError('Unknown reward function: {}'.format(self.reward_function))

        values = self.get_value(self.context_arms)

        # Best achievable reward: sum of the top-k arm values
        topk_indices = np.argsort(values)[-self.super_arms:][::-1]
        max_reward = np.sum(values[topk_indices])

        reward1 = self.get_reward(values[arm_set1])
        reward2 = self.get_reward(values[arm_set2])

        average_regret = max_reward - ((reward1 + reward2) / 2)
        weak_regret = max_reward - max(reward1, reward2)
        return average_regret, weak_regret

    def reset(self):
        """Rewind the step counter and shuffle the time steps in place."""
        self.t = 0

        # Shuffles along the first (time-step) axis using the global NumPy RNG
        np.random.shuffle(self.all_context_arms)

    def finish(self):
        """Return True once ``size`` time steps have been served."""
        return self.t == self.size
    
    
# For sanity check
if __name__ == '__main__':
    # Smoke-test the environment: one round on the fresh dataset, then one
    # more round after a reset (which reshuffles the time steps).
    demo = CombDuelingEnv(
        db_value_function='square',
        reward_function='add',
        dim=5,
        total_arms=5,
        super_arms=2,
        size=100,
        noise=0.1,
        suboptimality_gap=1.0,
        seed=1,
    )

    # Each round: (reset before the round?, super-arm pairs to score)
    rounds = (
        (False, [([2, 3], [1, 4])]),
        (True, [([2, 3], [1, 4]), ([2, 3], [0, 4])]),
    )
    for reset_first, duels in rounds:
        if reset_first:
            demo.reset()
        print(demo.get_context_arms())
        print(demo.get_feedback(0, 1))
        for first_set, second_set in duels:
            print(demo.get_regret(first_set, second_set))

    del demo