#!/usr/bin/env python
# coding: utf-8

# Importing python packages
import numpy as np
import os.path
from sklearn import preprocessing
from sklearn.metrics import mean_squared_error


# Class for Dueling Bandit Environment with combined context and arm features
class SyntheticEnv:
    """Synthetic contextual dueling-bandit environment.

    Every (context, arm) pair is a ``dim``-dimensional feature vector drawn
    uniformly from [-1, 1]. A latent reward function, selected by
    ``db_function``, maps feature vectors to scalar rewards. Pairwise
    preference feedback is sampled from a logistic model on the difference of
    the two latent rewards.

    Generated instances are cached in ``data/problems/syn_<problem_name>.npz``
    and reloaded on later runs with the same constructor arguments.
    """

    def __init__(self, db_function, contetxs, arms, dim, seed):
        """Create a new problem instance or load a cached one.

        Args:
            db_function: Name of the latent reward function: 'linear',
                'square', 'cubic', 'cosine', 'sine', 'levy' or 'ackley'.
            contetxs: Number of contexts. The misspelled parameter name is
                kept so existing keyword callers keep working.
            arms: Number of arms per context.
            dim: Dimension of each context-arm feature vector.
            seed: Seed applied to NumPy's *global* random state.
        """
        self.db_function = db_function
        self.contexts = contetxs
        self.arms = arms
        self.dim = dim
        # Seeds NumPy's global RNG. This drives instance generation below and
        # the preference sampling in get_preference.
        np.random.seed(seed)

        # Instance name; also used as the cache-file key.
        self.problem_name = '{}_{}_{}_{}_{}'.format(
            db_function,
            contetxs,
            arms,
            dim,
            seed
        )

        # Filled in by get_syn_problem().
        self.all_context_arms = None  # (contexts, arms, dim) feature vectors
        self.max_rewards = None       # (contexts,) best latent reward per context
        self.theta = None             # (dim,) latent parameter vector

        # Load existing or generate dataset
        self.get_syn_problem()

    def get_syn_problem(self):
        """Load the cached problem instance, or generate and cache a new one.

        Sets ``self.theta``, ``self.all_context_arms`` and
        ``self.max_rewards``. A newly generated instance is written to
        ``data/problems/syn_<problem_name>.npz``. The directory is created if
        it does not exist yet.
        """
        path_to_file = "data/problems/syn_{}.npz".format(self.problem_name)
        if os.path.exists(path_to_file):
            # The context manager closes the underlying zip file handle once
            # the arrays have been read into memory.
            # NOTE(review): this branch draws nothing from the global RNG,
            # while the generation branch does. Later random draws (e.g. in
            # get_preference) therefore differ between a first run and a
            # cached run with the same seed. Confirm this is intended.
            with np.load(path_to_file) as load_data:
                self.all_context_arms = load_data['all_context_arms']
                self.max_rewards = load_data['max_rewards']
                self.theta = load_data['theta']
            return

        # Generate random theta (drawn before the features; draw order matters
        # for reproducibility under a fixed seed).
        self.theta = np.random.uniform(low=-1, high=1, size=(self.dim, ))

        # Generate all context-arm feature vectors as one flat batch and score them.
        flat_context_arms = np.random.uniform(
            low=-1, high=1, size=(self.contexts*self.arms, self.dim)
        )
        rewards = self.get_reward(flat_context_arms)

        # Reshape to (contexts, arms, dim) features and (contexts, arms) rewards.
        self.all_context_arms = flat_context_arms.reshape(self.contexts, self.arms, -1)
        rewards = rewards.reshape(self.contexts, self.arms)

        # Best achievable latent reward for every context.
        self.max_rewards = np.max(rewards, axis=1)

        # Save problem instance and corresponding theta; np.savez does not
        # create missing parent directories itself.
        os.makedirs(os.path.dirname(path_to_file), exist_ok=True)
        np.savez(
            path_to_file,
            all_context_arms=self.all_context_arms,
            max_rewards=self.max_rewards,
            theta=self.theta
        )

    def get_reward(self, x):
        """Evaluate the latent reward function on feature vector(s) ``x``.

        ``x`` may be a single ``(dim,)`` vector or a batch whose last axis has
        length ``dim``. Every branch reduces over the last axis, so a batch
        gives one reward per vector.

        Raises:
            ValueError: If ``self.db_function`` is not a known function name.
        """
        if self.db_function == 'linear':
            return 5*x.dot(self.theta)

        elif self.db_function == 'square':
            return 10*(x.dot(self.theta))**2

        elif self.db_function == 'cubic':
            return (x.dot(self.theta))**3

        elif self.db_function == 'cosine':
            return np.cos(3*x.dot(self.theta))

        elif self.db_function == 'sine':
            return 2*np.sin(x.dot(self.theta))

        elif self.db_function == 'levy':
            # Levy function with w_i = 1 + (x_i - 1) / 4. The feature axis is
            # indexed with ``...`` so batches are handled vector by vector.
            # A plain ``w[0]`` would select the first *vector* of a batch.
            w = 1 + (x - 1.0) / 4.0
            w_first = w[..., 0]
            w_mid = w[..., 1:self.dim - 1]
            w_last = w[..., self.dim - 1]
            part1 = np.sin(np.pi * w_first) ** 2
            part2 = np.sum((w_mid - 1) ** 2 * (1 + 10 * np.sin(np.pi * w_mid + 1) ** 2), axis=-1)
            part3 = (w_last - 1) ** 2 * (1 + np.sin(2 * np.pi * w_last)**2)
            return part1 + part2 + part3

        elif self.db_function == 'ackley':
            # Ackley-style function with custom constants (a=-2, c=pi), scaled by 10.
            a, b, c = -2, 0.2, np.pi
            part1 = -a * np.exp(-b / np.sqrt(self.dim) * np.linalg.norm(x, axis=-1))
            part2 = -(np.exp(np.mean(np.cos(c * x), axis=-1)))
            return 10*(part1 + part2 + np.e + a)

        else:
            raise ValueError('Unknown problem: {}'.format(self.db_function))

    def get_preference(self, xt_1, xt_2):
        """Sample binary preference feedback for a pair of feature vectors.

        Returns ``y`` drawn from Bernoulli(sigmoid(f(xt_1) - f(xt_2))), where
        ``f`` is the latent reward. ``y = 1`` becomes more likely the more
        ``xt_1`` out-rewards ``xt_2``. Draws from NumPy's global RNG.
        """
        latent_reward1 = self.get_reward(xt_1)
        latent_reward2 = self.get_reward(xt_2)
        y_prob = 1.0/(1.0 + np.exp(-1*(latent_reward1 - latent_reward2)))
        y = np.random.binomial(1, y_prob)

        return y

    def get_suboptimality_gap(self, arms):
        """Compare the policy's chosen arms with the best arm in each context.

        Args:
            arms: Indexable of arm indices; ``arms[c]`` is the arm chosen for
                context ``c``, for every ``c`` in ``range(self.contexts)``.

        Returns:
            Tuple ``(max_gap, rmse, mae)`` of the per-context gaps
            ``max_rewards - reward(chosen arm)``. ``max_gap`` is a Python float.
        """
        # Feature vectors of the chosen arm in every context.
        context_arms = np.array([self.all_context_arms[c, arms[c]] for c in range(self.contexts)])
        chosen_rewards = self.get_reward(context_arms)
        gaps = self.max_rewards - chosen_rewards

        # Root mean squared error between best and achieved rewards.
        rmse = np.sqrt(mean_squared_error(self.max_rewards, chosen_rewards))

        # Mean absolute error between best and achieved rewards.
        mae = np.mean(np.abs(gaps))

        return float(np.max(gaps)), rmse, mae

    def get_regret(self, c_t, xt_1, xt_2):
        """Compute the regret of playing the pair (xt_1, xt_2) in context ``c_t``.

        Returns:
            Tuple ``(average_regret, weak_regret)``. These are the best reward in
            the context minus the mean, and minus the max, of the two arms'
            latent rewards.
        """
        context_arms = self.all_context_arms[c_t]
        max_reward = np.max(self.get_reward(context_arms))
        reward1 = self.get_reward(xt_1)
        reward2 = self.get_reward(xt_2)
        average_regret = max_reward - ((reward1 + reward2)/2)
        weak_regret = max_reward - max(reward1, reward2)

        return average_regret, weak_regret


# Sanity check: build a small 'cosine' instance of SyntheticEnv and print the
# shapes of its feature tensor, per-context best rewards and theta.
if __name__ == '__main__':
    sanity_env = SyntheticEnv('cosine', 5, 100, 5, 1)
    for arr in (sanity_env.all_context_arms, sanity_env.max_rewards, sanity_env.theta):
        print(arr.shape)

    del sanity_env