import torch
import numpy as np
from scipy.spatial.distance import pdist, squareform
from onlinedatasets.datasets import SVMDataset, get_dataset, get_batches, GrowingNumpyDataSet, DataSetUnsupervised, DataSet
from onlinedatasets.models import TorchRewardsModel, AutoEncoder, TorchRewardsModelMultilayer
import pandas as pd
import random

#import cvxpy
import itertools

import IPython
import ray

import timeit
import copy
#ray.init()

def get_submatrices(matrix, used_indices, unused_indices):
    """Build one principal submatrix per candidate index.

    For every ``candidate`` in ``unused_indices`` the rows and columns
    ``used_indices + [candidate]`` (in that order) are extracted from
    ``matrix``.

    Args:
        matrix: square 2-D array indexed by the given indices.
        used_indices: list of indices already selected.
        unused_indices: iterable of candidate indices.

    Returns:
        A list of submatrices, one per candidate, ordered like
        ``unused_indices``.
    """
    def _principal_block(candidate):
        # The candidate always becomes the last row/column of the block.
        selection = used_indices + [candidate]
        selected_rows = matrix[selection, :]
        return selected_rows[:, selection]

    return [_principal_block(candidate) for candidate in unused_indices]


def get_submatrices_fast(matrix, used_indices, unused_indices):
    """Build the augmented principal submatrices by reusing the shared block.

    Every returned matrix is ``matrix`` restricted to
    ``used_indices + [candidate]`` for one ``candidate`` of
    ``unused_indices``. The block belonging to ``used_indices`` is identical
    for all candidates, so it is extracted once and only the last row, last
    column and corner entry are filled in per candidate.

    Args:
        matrix: square 2-D array.
        used_indices: list of indices already selected.
        unused_indices: list of candidate indices.

    Returns:
        A list of float arrays of shape ``(len(used_indices) + 1,) * 2``,
        ordered like ``unused_indices``.
    """
    num_used = len(used_indices)

    used_rows = matrix[used_indices, :]
    used_block = used_rows[:, used_indices]
    border_cols = used_rows[:, unused_indices]
    border_rows = matrix[unused_indices, :][:, used_indices]
    # Only the diagonal entries of the candidates are needed; avoid copying
    # the full unused-by-unused submatrix on every call.
    corner_values = np.diag(matrix)[unused_indices]

    template = np.zeros((num_used + 1, num_used + 1))
    template[:num_used, :num_used] = used_block

    submatrices = []
    for position in range(len(unused_indices)):
        submatrix = template.copy()
        submatrix[num_used, :num_used] = border_rows[position, :]
        submatrix[:num_used, num_used] = border_cols[:, position]
        submatrix[num_used, num_used] = corner_values[position]
        submatrices.append(submatrix)
    return submatrices

def get_submatrices_with_determinants(matrix, used_indices, unused_indices):
    """Return the determinant of each augmented principal submatrix.

    For every ``candidate`` in ``unused_indices`` the determinant of
    ``matrix`` restricted to ``used_indices + [candidate]`` is computed.

    Returns:
        A list of determinants ordered like ``unused_indices``.
    """
    def _augmented_determinant(candidate):
        selection = used_indices + [candidate]
        principal_block = matrix[selection, :][:, selection]
        return np.linalg.det(principal_block)

    return [_augmented_determinant(candidate) for candidate in unused_indices]



def compute_determinants(matrix_list):
    """Return the determinant of every matrix in ``matrix_list``, in order."""
    determinants = []
    for matrix in matrix_list:
        determinants.append(np.linalg.det(matrix))
    return determinants


@ray.remote
def compute_determinants_remote(matrix_list):
    """Ray task that evaluates ``compute_determinants`` on one chunk.

    Args:
        matrix_list: list of square matrices.

    Returns:
        The list of determinants produced by ``compute_determinants``.
    """
    determinants = compute_determinants(matrix_list)
    return determinants



def greedy_max_determinant_submatrix_indices(matrix, submatrix_size, random_first_point = False, randomize = True,
    use_ray = False, num_processors = 30):
    """Greedily pick indices whose principal submatrix has a large determinant.

    Starting from a single index, the index that maximises the determinant of
    the principal submatrix formed with the already selected indices is added
    repeatedly until ``submatrix_size`` indices have been chosen. Exact ties
    between candidates are broken uniformly at random.

    Args:
        matrix: square 2-D array.
        submatrix_size: number of indices to select. At least one index is
            always selected.
        random_first_point: if true, the first index is drawn uniformly at
            random; otherwise it is an index with the largest diagonal entry.
        randomize: when the first index is chosen from the diagonal, break
            ties at random (true) or take the first maximiser (false).
        use_ray: compute the candidate determinants through
            ``compute_determinants_remote`` tasks instead of locally.
        num_processors: upper bound on the number of ray tasks per step.

    Returns:
        ``(used_indices, max_value, unused_indices)`` where ``max_value`` is
        the determinant achieved by the final selection (the diagonal entry
        of the first index when only one index is selected).

    Raises:
        ValueError: if ``submatrix_size`` exceeds the number of rows.
    """
    num_points = matrix.shape[0]
    if submatrix_size > num_points:
        raise ValueError(
            "submatrix_size {} exceeds the number of available indices {}".format(
                submatrix_size, num_points))

    unused_indices = list(range(num_points))
    used_indices = []

    if random_first_point:
        used_index = np.random.choice(list(np.arange(num_points)))
    else:
        diagonal = np.diag(matrix)
        if randomize:
            used_index = np.random.choice(np.flatnonzero(diagonal == diagonal.max()))
        else:
            used_index = np.argmax(diagonal)

    used_indices.append(used_index)
    unused_indices.remove(used_index)

    # Determinant of the 1x1 selection; keeps max_value defined on every
    # path, including submatrix_size <= 1.
    max_value = matrix[used_index, used_index]

    for _ in range(1, submatrix_size):
        submatrices = get_submatrices_fast(matrix, used_indices, unused_indices)

        if use_ray:
            # Ceil-divide so there are at most num_processors chunks and the
            # chunk size never drops to zero when candidates are scarce.
            chunk_size = max(1, -(-len(submatrices) // num_processors))
            chunks = [submatrices[start:start + chunk_size]
                      for start in range(0, len(submatrices), chunk_size)]
            pending = [compute_determinants_remote.remote(chunk) for chunk in chunks]

            determinants = []
            for chunk_determinants in ray.get(pending):
                determinants += chunk_determinants
        else:
            determinants = compute_determinants(submatrices)

        # Select the largest determinant. The previous implementation sorted
        # the (determinant, index) pairs in ascending order and reversed the
        # wrong list, so it picked the smallest determinant.
        max_value = max(determinants)
        argmax_indices = [index for (val, index) in zip(determinants, unused_indices)
                          if val == max_value]

        used_index = np.random.choice(argmax_indices)
        used_indices.append(used_index)
        unused_indices.remove(used_index)

    return used_indices, max_value, unused_indices


def generate_reward_augmented_matrix(kernel_matrix, rewards, lambda_det):
    """Rescale a kernel matrix by exponentiated rewards on both sides.

    With ``s = exp(rewards / (2 * lambda_det))`` and a 1-D ``rewards`` whose
    length matches the kernel, entry ``(i, j)`` of the result is
    ``kernel_matrix[i, j] * s[i] * s[j]``.

    Args:
        kernel_matrix: square 2-D array.
        rewards: array-like of rewards, one per row of ``kernel_matrix``.
        lambda_det: scale dividing the rewards before exponentiation.

    Returns:
        The rescaled matrix with the same shape as ``kernel_matrix``.
    """
    reward_scale = np.exp(np.array(rewards) / (2 * lambda_det))
    # First pass scales along the last axis, the second pass (on the
    # transpose) scales along the other axis.
    scaled_once = np.multiply(kernel_matrix, reward_scale)
    scaled_twice_transposed = np.multiply(scaled_once.T, reward_scale)
    return scaled_twice_transposed.T

    # modified_matrix = np.matmul(np.matmul(diag_reward_matrix, kernel_matrix), diag_reward_matrix)

    # return modified_matrix





def train_autoencoder(autoencoder, dataset, num_steps, batch_size, verbose = True, logging_frequency = 100):
    """Train ``autoencoder`` with Adam on batches drawn from ``dataset``.

    Every ``logging_frequency`` steps the loss on an oversized batch request
    is recorded (presumably the full dataset; confirm against
    ``get_batches``). The evaluation runs without building an autograd graph.

    Args:
        autoencoder: model exposing ``encoder_decoder``, ``get_loss`` and
            ``reconstruction``.
        dataset: dataset accepted by ``get_batches``.
        num_steps: number of optimisation steps.
        batch_size: training batch size.
        verbose: print the evaluation loss, the current batch and its
            reconstruction at every logging step. The evaluation loss is
            recorded regardless of this flag.
        logging_frequency: number of steps between evaluations.

    Returns:
        ``(autoencoder, test_loss_list)`` where ``test_loss_list`` holds the
        recorded evaluation losses as numpy values.
    """
    test_loss_list = []
    optimizer = torch.optim.Adam(autoencoder.encoder_decoder.parameters(), lr = 0.001)
    for i in range(num_steps):
        batch_X = get_batches(dataset, batch_size)
        optimizer.zero_grad()
        loss = autoencoder.get_loss(batch_X)

        if i % logging_frequency == 0:
            # Evaluation only: no gradients are needed here.
            with torch.no_grad():
                test_batch = get_batches(dataset, 10000000000000)
                test_loss = autoencoder.get_loss(test_batch)
                test_loss = test_loss.detach().cpu().numpy()
                if verbose:
                    print('test loss ' , test_loss, "iteration ", i)
                    print("batch X ", batch_X)
                    print("reconstructed ", autoencoder.reconstruction(batch_X))

            test_loss_list.append(test_loss)

        loss.backward()
        optimizer.step()

    return autoencoder, test_loss_list


def train_simple_regression(reward_model, dataset, num_steps, batch_size, verbose = False, logging_frequency = 100):
    """Fit ``reward_model`` with Adam using its regression loss.

    Args:
        reward_model: model exposing ``network`` and ``get_loss``.
        dataset: supervised dataset accepted by ``get_batches``.
        num_steps: number of optimisation steps.
        batch_size: training batch size.
        verbose: when true, every ``logging_frequency`` steps the loss on an
            oversized batch request is printed and recorded.
        logging_frequency: number of steps between evaluations.

    Returns:
        ``(reward_model, test_loss_list)``; the list is empty unless
        ``verbose`` is true.
    """
    optimizer = torch.optim.Adam(reward_model.network.parameters(), lr = 0.01)
    test_loss_list = []
    for i in range(num_steps):
        batch_X, batch_y = get_batches(dataset, batch_size)

        optimizer.zero_grad()
        loss = reward_model.get_loss(batch_X, batch_y)

        if verbose and i % logging_frequency == 0:
            print("iteration ", i)
            # Evaluation only: skip autograd bookkeeping, matching
            # train_simple_logistic_regression.
            with torch.no_grad():
                test_batch_X, test_batch_y = get_batches(dataset, 100000000000000000)
                test_loss = reward_model.get_loss(test_batch_X, test_batch_y)
                test_loss = test_loss.detach().cpu().numpy()
            print("test loss ", test_loss)

            test_loss_list.append(test_loss)

        loss.backward()
        optimizer.step()

    return reward_model, test_loss_list



### TODO: Implement this.
def train_simple_logistic_regression(reward_model, dataset, num_steps, batch_size, verbose = False, logging_frequency = 100):
    """Fit ``reward_model`` with Adam using its logistic loss.

    Args:
        reward_model: model exposing ``network`` and ``get_logistic_loss``.
        dataset: supervised dataset accepted by ``get_batches``.
        num_steps: number of optimisation steps.
        batch_size: training batch size.
        verbose: when true, every ``logging_frequency`` steps the loss on an
            oversized batch request is printed and recorded.
        logging_frequency: number of steps between evaluations.

    Returns:
        ``(reward_model, test_loss_list)``; the list is empty unless
        ``verbose`` is true.
    """
    test_loss_list = []
    optimizer = torch.optim.Adam(reward_model.network.parameters(), lr = 0.01)

    for step in range(num_steps):
        features, targets = get_batches(dataset, batch_size)
        optimizer.zero_grad()
        loss = reward_model.get_logistic_loss(features, targets)

        should_log = verbose and step % logging_frequency == 0
        if should_log:
            print("iteration ", step)
            with torch.no_grad():
                eval_features, eval_targets = get_batches(dataset, 100000000000000000)
                eval_loss = reward_model.get_logistic_loss(eval_features, eval_targets)
                eval_loss = eval_loss.detach().cpu().numpy()
                print("test loss ", eval_loss)
                test_loss_list.append(eval_loss)

        loss.backward()
        optimizer.step()

    return reward_model, test_loss_list




def train_simple_regression_full_minimization(reward_model, dataset, num_batches, batch_size, num_opt_steps, opt_batch_size = 10, l1 = False):
    """Train on a dataset that grows by one freshly drawn batch per round.

    Each round creates a new Adam optimizer, draws ``batch_size`` points from
    ``dataset``, appends them to an internal ``GrowingNumpyDataSet`` and then
    runs ``num_opt_steps`` optimisation steps on mini-batches of size
    ``opt_batch_size`` drawn from the accumulated data.

    Args:
        reward_model: model exposing ``network``, ``get_loss`` and
            ``get_loss_l1``.
        dataset: supervised dataset accepted by ``get_batches``.
        num_batches: number of rounds.
        batch_size: number of new points drawn per round.
        num_opt_steps: optimisation steps per round.
        opt_batch_size: mini-batch size for the optimisation steps.
        l1: use ``get_loss_l1`` instead of ``get_loss``.

    Returns:
        ``(reward_model, max_observed_rewards)`` where the list holds, per
        round, the running maximum over the drawn labels.
    """
    observed_data = GrowingNumpyDataSet()
    best_reward_so_far = -float("inf")
    max_observed_rewards = []

    for _ in range(num_batches):
        optimizer = torch.optim.Adam(reward_model.network.parameters(), lr = 0.01)

        fresh_X, fresh_y = get_batches(dataset, batch_size)
        observed_data.add_data(fresh_X, fresh_y)

        best_reward_so_far = max(best_reward_so_far, max(fresh_y))
        max_observed_rewards.append(best_reward_so_far)

        for _ in range(num_opt_steps):
            step_X, step_y = get_batches(observed_data, opt_batch_size)
            optimizer.zero_grad()
            loss = (reward_model.get_loss_l1(step_X, step_y) if l1
                    else reward_model.get_loss(step_X, step_y))
            loss.backward()
            optimizer.step()

    return reward_model, max_observed_rewards


### This class implements a dataset that 
### first fits a regression model to create the responses.
### This is to be used for diagnostics
class DataSetWithRegressionResponses:
    """Dataset whose responses come from a regression model fit to a base dataset.

    A ``TorchRewardsModelMultilayer`` is trained on ``base_dataset`` with
    ``train_simple_regression``; its predictions then replace the original
    responses. Intended for diagnostics.

    The instance is of type ``"from_file"`` when ``base_dataset`` has a
    ``labels`` attribute (predictions are precomputed for all points) and of
    type ``"generative"`` otherwise (predictions are computed per batch).
    """

    def __init__(self, base_dataset, MLP= True, representation_layer_sizes = None, num_steps = 2000, batch_size = 20):
        """Fit the regression model and prepare the synthetic responses.

        Args:
            base_dataset: dataset providing ``dimension`` and batches.
            MLP: unused; kept for backward compatibility.
            representation_layer_sizes: hidden layer sizes of the regression
                model; defaults to ``[10]``.
            num_steps: number of training steps for the regression model.
            batch_size: training batch size for the regression model.
        """
        # Avoid a shared mutable default argument.
        if representation_layer_sizes is None:
            representation_layer_sizes = [10]

        self.base_dataset = base_dataset

        self.reward_model = TorchRewardsModelMultilayer(dim = self.base_dataset.dimension,
            representation_layer_sizes = representation_layer_sizes, activation_type = "relu")

        self.reward_model = train_simple_regression(self.reward_model, self.base_dataset, num_steps, batch_size)[0]

        self.dataset = None
        self.labels = None
        self.type = "generative"

        if hasattr(base_dataset, "labels"):
            self.dataset = self.base_dataset.dataset
            self.fix_responses()
            self.type = "from_file"
            self.num_datapoints = self.base_dataset.num_datapoints

        self.dimension = self.base_dataset.dimension
        # Seed used for sampling; incremented after every "from_file" batch.
        self.random_state = 1

    def get_batch(self, batch_size):
        """Return ``(X, Y)`` where ``Y`` are the regression model's responses.

        Raises:
            ValueError: if ``self.type`` is neither ``"generative"`` nor
                ``"from_file"``.
        """
        if self.type == "generative":
            # The base dataset's own responses are discarded.
            (batch_X, _) = self.base_dataset.get_batch(batch_size)
            with torch.no_grad():
                batch_Y_new = self.reward_model.get_reward(batch_X)
            return (batch_X, batch_Y_new.detach().cpu().numpy())

        elif self.type == "from_file":
            if batch_size > self.num_datapoints:
                X = self.dataset.values
                Y = self.labels.values
            else:
                # The same seed is used for both frames so the sampled rows
                # stay aligned.
                X = self.dataset.sample(batch_size, random_state=self.random_state).values
                Y = self.labels.sample(batch_size, random_state=self.random_state).values
            self.random_state += 1

            return (X, Y)

        else:
            raise ValueError("This type of regression response dataset does not exist {}.".format(self.type))

    def fix_responses(self):
        """Recompute ``self.labels`` as the model's predictions on all base points."""
        with torch.no_grad():
            predictions = self.reward_model.get_reward(self.base_dataset.dataset.values)
        self.labels = pd.DataFrame(predictions.detach().cpu().numpy())



class RandomBatch:
    """Acquisition strategy without a model.

    Batch selection is delegated to ``get_random_reward_batch`` and fitting
    is a no-op.
    """

    def __init__(self):
        self.name = "RandomBatch"

    def get_name(self):
        """Return the strategy name."""
        return format(self.name)

    def get_batch(self, growing_dataset, complement_supervised_dataset, batch_size):
        """Select ``batch_size`` points from the complement dataset."""
        selection = get_random_reward_batch(complement_supervised_dataset, batch_size)
        return selection

    def fit_data(self, num_opt_steps, growing_dataset, complement_unsupervised_dataset, opt_batch_size):
        """Nothing to fit; always returns ``0``."""
        return 0




class EnsembleOptimism:
    """Acquisition strategy driven by an ensemble of reward models.

    Batches are selected with ``get_max_reward_batch_ensemble``; each
    ensemble member is refit from scratch on its own mini-batches, optionally
    with Gaussian noise added to the targets.
    """

    def __init__(self, reward_models, l1, lambda_reward_max, l2_regularizer, range_regularizer, verbose = False, noise_injection = False):
        """Store the ensemble and the training options.

        Args:
            reward_models: list of models exposing ``network``,
                ``reset_weights``, ``get_loss`` and ``get_loss_l1``.
            l1: use ``get_loss_l1`` instead of ``get_loss``.
            lambda_reward_max, l2_regularizer, range_regularizer: stored but
                not used by this class's methods.
            verbose: print the optimisation step index while fitting.
            noise_injection: add standard normal noise to the targets of
                every training mini-batch.
        """
        self.name = "EnsembleOptimism"
        self.reward_models = reward_models
        self.l1 = l1
        self.verbose = verbose
        self.l2_regularizer = l2_regularizer
        self.range_regularizer = range_regularizer
        self.lambda_reward_max = lambda_reward_max
        self.noise_injection = noise_injection
        if self.noise_injection:
            self.name = "EnsembleOptimismNoiseY"

    def get_name(self):
        """Return the strategy name."""
        return "{}".format(self.name)

    def get_batch(self, growing_dataset, complement_supervised_dataset, batch_size):
        """Select a batch; falls back to a random batch while no data is labelled."""
        if growing_dataset.get_size() == 0:
            return get_random_reward_batch(complement_supervised_dataset, batch_size)
        return get_max_reward_batch_ensemble(self.reward_models, complement_supervised_dataset, batch_size)

    def fit_data(self, num_opt_steps, growing_dataset, complement_unsupervised_dataset, opt_batch_size):
        """Reset and refit every ensemble member on ``growing_dataset``.

        Args:
            num_opt_steps: number of optimisation steps per member.
            growing_dataset: labelled data accepted by ``get_batches``.
            complement_unsupervised_dataset: unused by this strategy.
            opt_batch_size: mini-batch size.

        Returns:
            The mean of the members' losses at the last step, or ``None``
            when ``num_opt_steps`` is 0.
        """
        for reward_model in self.reward_models:
            reward_model.reset_weights()

        # Created after the reset so the optimizers reference the post-reset
        # parameters even if reset_weights replaces the parameter tensors.
        self.optimizers = [torch.optim.Adam(reward_model.network.parameters(), lr = 0.01) for reward_model in self.reward_models]

        model_fitting_loss = None

        for i in range(num_opt_steps):
            for optimizer in self.optimizers:
                optimizer.zero_grad()

            # Each member trains on its own independently drawn mini-batch.
            batches = [get_batches(growing_dataset, opt_batch_size) for _ in range(len(self.reward_models))]

            if self.verbose:
                print("opt step ", i)

            ensemble_losses = []

            for ((batch_X, batch_y), reward_model, optimizer) in zip(batches, self.reward_models, self.optimizers):
                if self.noise_injection:
                    # Copy first so the stored targets are not perturbed.
                    batch_y = copy.deepcopy(batch_y)
                    batch_y += np.random.normal(0,1,(len(batch_y),1))
                if self.l1:
                    loss = reward_model.get_loss_l1(batch_X, batch_y)
                else:
                    loss = reward_model.get_loss(batch_X, batch_y)

                ensemble_losses.append(loss.detach().cpu().numpy())
                loss.backward()
                optimizer.step()

            if i == num_opt_steps-1:
                model_fitting_loss = np.mean(ensemble_losses)

        return model_fitting_loss






class SequentialBatchOptimism:
    """Acquisition strategy that builds a batch one point at a time.

    For every slot of the batch the model is refit optimistically, the best
    candidate is selected, the model is refit pessimistically, and the
    midpoint of both predictions is added to a working copy of the labelled
    data as the selected point's response before the next slot is filled.
    """

    def __init__(self, reward_model, l1, lambda_reward_max, l2_regularizer, range_regularizer,
        num_opt_in_batch_steps = 5000, in_batch_opt_batch_size = 10, verbose = False):
        """Store the model and the training options.

        Args:
            reward_model: model exposing ``network``, ``reset_weights``,
                ``get_reward``, ``get_loss`` and ``get_loss_l1``.
            l1: use ``get_loss_l1`` instead of ``get_loss``.
            lambda_reward_max: weight of the optimism/pessimism term; must be
                non-zero.
            l2_regularizer, range_regularizer: stored but not used by this
                class's methods.
            num_opt_in_batch_steps: optimisation steps per refit inside
                ``get_batch``.
            in_batch_opt_batch_size: mini-batch size for those refits.
            verbose: print progress while fitting.

        Raises:
            ValueError: if ``lambda_reward_max`` is 0.
        """
        self.name = "SequentialBatchOptimism"
        self.reward_model = reward_model
        self.l1 = l1
        self.verbose = verbose
        self.l2_regularizer = l2_regularizer
        self.range_regularizer = range_regularizer
        self.lambda_reward_max = lambda_reward_max
        self.num_opt_in_batch_steps = num_opt_in_batch_steps
        self.in_batch_opt_batch_size = in_batch_opt_batch_size
        if self.lambda_reward_max == 0:
            raise ValueError("SequentialBatchOptimism is set with lambda_reward_max = 0")

        self.reward_models = []

    def get_name(self):
        """Return the strategy name including the optimism weight."""
        return "{} lambda_reg {}".format(self.name, self.lambda_reward_max)

    def get_batch(self, growing_dataset, complement_supervised_dataset, batch_size):
        """Select ``batch_size`` points sequentially.

        Falls back to ``get_random_reward_batch`` while ``growing_dataset``
        is empty. The inputs are deep-copied, so neither dataset is modified.

        Returns:
            ``(filtered_batch_X, filtered_batch_y, complement_batch_X,
            complement_batch_y)``: the selected points with their responses
            as returned by ``get_max_reward_batch``, and the complement left
            after the last selection. DataFrames are returned when
            ``complement_supervised_dataset.return_dataframe`` is true,
            numpy arrays otherwise.
        """
        if growing_dataset.get_size() == 0:
            return get_random_reward_batch(complement_supervised_dataset, batch_size)

        growing_dataset_copy = copy.deepcopy(growing_dataset)
        complement_supervised_dataset_copy = copy.deepcopy(complement_supervised_dataset)

        resulting_batch_index = []
        resulting_filtered_batch_y = []
        resulting_filtered_batch_X = []

        for _ in range(batch_size):
            # Oversized request: presumably returns the whole remaining
            # complement (confirm against get_batches).
            complement_batch_X, complement_batch_y = get_batches(complement_supervised_dataset_copy, 1000000000)
            complement_unsupervised_dataset = DataSetUnsupervised(pd.DataFrame(complement_batch_X))

            # Fit the optimistic model and pick its best candidate.
            self.fit_data(self.num_opt_in_batch_steps, growing_dataset_copy,
                complement_unsupervised_dataset, self.in_batch_opt_batch_size, optimism_up = True)

            filtered_batch_X, filtered_batch_y, complement_batch_X, complement_batch_y = get_max_reward_batch(self.reward_model, complement_supervised_dataset_copy, 1)

            optimistic_filtered_batch_y = self.reward_model.get_reward(np.array(filtered_batch_X))

            # Fit the pessimistic model and evaluate the same candidate.
            self.fit_data(self.num_opt_in_batch_steps, growing_dataset_copy,
                complement_unsupervised_dataset, self.in_batch_opt_batch_size, optimism_up = False)

            pessimistic_filtered_batch_y = self.reward_model.get_reward(np.array(filtered_batch_X))

            mean_value = float(((optimistic_filtered_batch_y + pessimistic_filtered_batch_y)/2).detach())

            # The returned batch keeps the response from get_max_reward_batch;
            # only the working copy receives the midpoint prediction.
            resulting_filtered_batch_y += list(filtered_batch_y[0])
            filtered_batch_y_copy = copy.deepcopy(filtered_batch_y)
            filtered_batch_y_copy[0] = mean_value

            if complement_supervised_dataset.return_dataframe:
                resulting_batch_index += list(filtered_batch_y.index)
                resulting_filtered_batch_X += [ filtered_batch_X.values[0,:] ]
            else:
                resulting_filtered_batch_X += [ filtered_batch_X[0,:] ]

            growing_dataset_copy.add_data(np.array(filtered_batch_X), np.array(filtered_batch_y_copy))

            complement_supervised_dataset_copy = DataSet(pd.DataFrame(complement_batch_X), pd.DataFrame(complement_batch_y),
                return_dataframe = complement_supervised_dataset.return_dataframe)

        if complement_supervised_dataset.return_dataframe:
            filtered_batch_y = pd.DataFrame(data = resulting_filtered_batch_y, index = resulting_batch_index)
            filtered_batch_X = pd.DataFrame(data = resulting_filtered_batch_X, index = resulting_batch_index)
        else:
            filtered_batch_y = np.expand_dims(np.array(resulting_filtered_batch_y),1)
            filtered_batch_X = np.array(resulting_filtered_batch_X)

        return filtered_batch_X, filtered_batch_y, complement_batch_X, complement_batch_y

    def fit_data(self, num_opt_steps, growing_dataset, complement_unsupervised_dataset, opt_batch_size, optimism_up = True):
        """Reset and refit the model with an optimism or pessimism term.

        The objective is the data-fitting loss minus (``optimism_up`` true)
        or plus (false) ``lambda_reward_max`` times the mean prediction on a
        mini-batch of unlabelled points.

        Returns:
            The data-fitting loss of the last step, without the
            optimism/pessimism term, or ``None`` when ``num_opt_steps`` is 0.
        """
        self.reward_model.reset_weights()
        # Created after the reset so the optimizer references the post-reset
        # parameters even if reset_weights replaces the parameter tensors.
        self.optimizer = torch.optim.Adam(self.reward_model.network.parameters(), lr = 0.01)

        optimism_sign = -1 if optimism_up else 1
        model_fitting_loss = None

        for i in range(num_opt_steps):
            self.optimizer.zero_grad()
            batch_X, batch_y = get_batches(growing_dataset, opt_batch_size)
            unsupervised_batch_X = get_batches(complement_unsupervised_dataset, opt_batch_size)

            if self.verbose:
                print("opt step ", i)
            if self.l1:
                loss = self.reward_model.get_loss_l1(batch_X, batch_y)
            else:
                loss = self.reward_model.get_loss(batch_X, batch_y)

            if i == num_opt_steps-1:
                # Copy: detach().numpy() shares memory with the loss tensor.
                model_fitting_loss = loss.detach().cpu().numpy().copy()

            predictions = self.reward_model.get_reward(unsupervised_batch_X)

            # Out-of-place update so the recorded fit loss is not modified.
            loss = loss + optimism_sign*self.lambda_reward_max*torch.mean(predictions)

            if self.verbose:
                print("loss " , loss)
            loss.backward()
            self.optimizer.step()

        return model_fitting_loss




class EnsembleSequentialBatchOptimism:
    """Sequential batch construction driven by an ensemble of reward models.

    For every slot of the batch the ensemble is refit, the best candidate is
    selected with ``get_max_reward_batch_ensemble``, and the midpoint between
    the largest and smallest ensemble prediction for that candidate is added
    to a working copy of the labelled data as its response.
    """

    def __init__(self, reward_models, l1, lambda_reward_max, l2_regularizer, range_regularizer,
        num_opt_in_batch_steps = 5000, in_batch_opt_batch_size = 10, verbose = False, noise_injection = False):
        """Store the ensemble and the training options.

        Args:
            reward_models: list of models exposing ``network``,
                ``reset_weights``, ``get_reward``, ``get_loss`` and
                ``get_loss_l1``.
            l1: use ``get_loss_l1`` instead of ``get_loss``.
            lambda_reward_max, l2_regularizer, range_regularizer: stored but
                not used by this class's methods.
            num_opt_in_batch_steps: optimisation steps per refit inside
                ``get_batch``.
            in_batch_opt_batch_size: mini-batch size for those refits.
            verbose: print the optimisation step index while fitting.
            noise_injection: add standard normal noise to the targets of
                every training mini-batch.
        """
        self.name = "EnsembleSequentialBatchOptimism"
        self.reward_models = reward_models
        self.l1 = l1
        self.verbose = verbose
        self.l2_regularizer = l2_regularizer
        self.range_regularizer = range_regularizer
        self.lambda_reward_max = lambda_reward_max
        self.num_opt_in_batch_steps = num_opt_in_batch_steps
        self.in_batch_opt_batch_size = in_batch_opt_batch_size
        self.noise_injection = noise_injection
        if self.noise_injection:
            self.name = "EnsembleSequentialBatchOptimismNoiseY"

    def get_name(self):
        """Return the strategy name."""
        return "{}".format(self.name)

    def get_batch(self, growing_dataset, complement_supervised_dataset, batch_size):
        """Select ``batch_size`` points sequentially.

        Falls back to ``get_random_reward_batch`` while ``growing_dataset``
        is empty. The inputs are deep-copied, so neither dataset is modified.

        Returns:
            ``(filtered_batch_X, filtered_batch_y, complement_batch_X,
            complement_batch_y)``: the selected points with their responses
            as returned by ``get_max_reward_batch_ensemble``, and the
            complement left after the last selection. DataFrames are returned
            when ``complement_supervised_dataset.return_dataframe`` is true,
            numpy arrays otherwise.
        """
        if growing_dataset.get_size() == 0:
            return get_random_reward_batch(complement_supervised_dataset, batch_size)

        growing_dataset_copy = copy.deepcopy(growing_dataset)
        complement_supervised_dataset_copy = copy.deepcopy(complement_supervised_dataset)

        resulting_batch_index = []
        resulting_filtered_batch_y = []
        resulting_filtered_batch_X = []

        for _ in range(batch_size):
            # Oversized request: presumably returns the whole remaining
            # complement (confirm against get_batches).
            complement_batch_X, complement_batch_y = get_batches(complement_supervised_dataset_copy, 1000000000)
            complement_unsupervised_dataset = DataSetUnsupervised(pd.DataFrame(complement_batch_X))

            self.fit_data(self.num_opt_in_batch_steps, growing_dataset_copy,
                complement_unsupervised_dataset, self.in_batch_opt_batch_size)

            filtered_batch_X, filtered_batch_y, complement_batch_X, complement_batch_y = get_max_reward_batch_ensemble(self.reward_models, complement_supervised_dataset_copy, 1)

            filtered_batch_y_ensemble_values = [reward_model.get_reward(np.array(filtered_batch_X)).detach().cpu().numpy() for reward_model in self.reward_models]

            # The ensemble spread stands in for optimistic/pessimistic fits.
            optimistic_filtered_batch_y = np.max(filtered_batch_y_ensemble_values)
            pessimistic_filtered_batch_y = np.min(filtered_batch_y_ensemble_values)

            mean_value = float(((optimistic_filtered_batch_y + pessimistic_filtered_batch_y)/2))

            # The returned batch keeps the response from the selection call;
            # only the working copy receives the midpoint prediction.
            resulting_filtered_batch_y += list(filtered_batch_y[0])
            filtered_batch_y_copy = copy.deepcopy(filtered_batch_y)
            filtered_batch_y_copy[0] = mean_value

            if complement_supervised_dataset.return_dataframe:
                resulting_batch_index += list(filtered_batch_y.index)
                resulting_filtered_batch_X += [ filtered_batch_X.values[0,:] ]
            else:
                resulting_filtered_batch_X += [ filtered_batch_X[0,:] ]

            growing_dataset_copy.add_data(np.array(filtered_batch_X), np.array(filtered_batch_y_copy))

            complement_supervised_dataset_copy = DataSet(pd.DataFrame(complement_batch_X), pd.DataFrame(complement_batch_y),
                return_dataframe = complement_supervised_dataset.return_dataframe)

        if complement_supervised_dataset.return_dataframe:
            filtered_batch_y = pd.DataFrame(data = resulting_filtered_batch_y, index = resulting_batch_index)
            filtered_batch_X = pd.DataFrame(data = resulting_filtered_batch_X, index = resulting_batch_index)
        else:
            filtered_batch_y = np.expand_dims(np.array(resulting_filtered_batch_y),1)
            filtered_batch_X = np.array(resulting_filtered_batch_X)

        return filtered_batch_X, filtered_batch_y, complement_batch_X, complement_batch_y

    def fit_data(self, num_opt_steps, growing_dataset, complement_unsupervised_dataset, opt_batch_size):
        """Reset and refit every ensemble member on ``growing_dataset``.

        Args:
            num_opt_steps: number of optimisation steps per member.
            growing_dataset: labelled data accepted by ``get_batches``.
            complement_unsupervised_dataset: unused by this method.
            opt_batch_size: mini-batch size.

        Returns:
            The mean of the members' losses at the last step, or ``None``
            when ``num_opt_steps`` is 0.
        """
        for reward_model in self.reward_models:
            reward_model.reset_weights()

        # Created after the reset so the optimizers reference the post-reset
        # parameters even if reset_weights replaces the parameter tensors.
        self.optimizers = [torch.optim.Adam(reward_model.network.parameters(), lr = 0.01) for reward_model in self.reward_models]

        model_fitting_loss = None

        for i in range(num_opt_steps):
            for optimizer in self.optimizers:
                optimizer.zero_grad()

            # Each member trains on its own independently drawn mini-batch.
            batches = [get_batches(growing_dataset, opt_batch_size) for _ in range(len(self.reward_models))]

            if self.verbose:
                print("opt step ", i)

            ensemble_losses = []

            for ((batch_X, batch_y), reward_model, optimizer) in zip(batches, self.reward_models, self.optimizers):
                if self.noise_injection:
                    # Copy first so the stored targets are not perturbed.
                    batch_y = copy.deepcopy(batch_y)
                    batch_y += np.random.normal(0,1,(len(batch_y),1))

                if self.l1:
                    loss = reward_model.get_loss_l1(batch_X, batch_y)
                else:
                    loss = reward_model.get_loss(batch_X, batch_y)

                ensemble_losses.append(loss.detach().cpu().numpy())
                loss.backward()
                optimizer.step()

            if i == num_opt_steps-1:
                model_fitting_loss = np.mean(ensemble_losses)

        return model_fitting_loss




class MeanOptimism:
    """Acquisition strategy with a mean-prediction optimism bonus.

    The model is fit on the labelled data while the mean prediction on
    unlabelled points, weighted by ``lambda_reward_max``, is subtracted from
    the loss. Batches are selected with ``get_max_reward_batch``.
    """

    def __init__(self, reward_model, l1, lambda_reward_max, l2_regularizer, range_regularizer, verbose = False):
        """Store the model and the training options.

        Args:
            reward_model: model exposing ``network``, ``reset_weights``,
                ``get_reward``, ``get_loss`` and ``get_loss_l1``.
            l1: use ``get_loss_l1`` instead of ``get_loss``.
            lambda_reward_max: weight of the optimism term.
            l2_regularizer, range_regularizer: stored but not used by this
                class's methods.
            verbose: print progress while fitting.
        """
        self.name = "MeanOptimism"
        self.reward_model = reward_model
        self.l1 = l1
        self.verbose = verbose
        self.l2_regularizer = l2_regularizer
        self.range_regularizer = range_regularizer
        self.lambda_reward_max = lambda_reward_max

    def get_batch(self, growing_dataset,  complement_supervised_dataset, batch_size):
        """Select a batch; falls back to a random batch while no data is labelled."""
        if growing_dataset.get_size() == 0:
            return get_random_reward_batch(complement_supervised_dataset, batch_size)

        return get_max_reward_batch(self.reward_model, complement_supervised_dataset, batch_size)

    def get_name(self):
        """Return the strategy name including the optimism weight."""
        return "{} lambda_reg {}".format(self.name, self.lambda_reward_max)

    def fit_data(self, num_opt_steps, growing_dataset, complement_unsupervised_dataset, opt_batch_size):
        """Reset and refit the model with the optimism term.

        Args:
            num_opt_steps: number of optimisation steps.
            growing_dataset: labelled data accepted by ``get_batches``.
            complement_unsupervised_dataset: unlabelled data accepted by
                ``get_batches``.
            opt_batch_size: mini-batch size for both datasets.

        Returns:
            The data-fitting loss of the last step, without the optimism
            term, or ``None`` when ``num_opt_steps`` is 0.
        """
        self.reward_model.reset_weights()
        # Created after the reset so the optimizer references the post-reset
        # parameters even if reset_weights replaces the parameter tensors.
        self.optimizer = torch.optim.Adam(self.reward_model.network.parameters(), lr = 0.01)

        model_fitting_loss = None

        for i in range(num_opt_steps):
            self.optimizer.zero_grad()
            batch_X, batch_y = get_batches(growing_dataset, opt_batch_size)
            unsupervised_batch_X = get_batches(complement_unsupervised_dataset, opt_batch_size)

            if self.verbose:
                print("opt step ", i)
            if self.l1:
                loss = self.reward_model.get_loss_l1(batch_X, batch_y)
            else:
                loss = self.reward_model.get_loss(batch_X, batch_y)

            if i == num_opt_steps-1:
                # Copy: detach().numpy() shares memory with the loss tensor.
                model_fitting_loss = loss.detach().cpu().numpy().copy()

            predictions = self.reward_model.get_reward(unsupervised_batch_X)
            # Out-of-place update so the recorded fit loss is not modified.
            loss = loss - self.lambda_reward_max*torch.mean(predictions)

            if self.verbose:
                print("loss " , loss)
            loss.backward()
            self.optimizer.step()

        return model_fitting_loss




class DeterminantOptimism(MeanOptimism):
    """``MeanOptimism`` variant that selects batches by determinant.

    Fitting is inherited from ``MeanOptimism``; batch selection is delegated
    to ``get_max_determinant_batch`` with the ``lambda_det`` weight.
    """

    def __init__(self, reward_model, l1, lambda_reward_max, l2_regularizer, range_regularizer, verbose = False, lambda_det = 1):
        super().__init__(
            reward_model = reward_model,
            l1 = l1,
            lambda_reward_max = lambda_reward_max,
            l2_regularizer = l2_regularizer,
            range_regularizer = range_regularizer,
            verbose = verbose,
        )
        # Override the name set by the base class.
        self.name = "DeterminantOptimism"
        self.lambda_det = lambda_det

    def get_name(self):
        """Return the strategy name including both regularisation weights."""
        template = "{} lambda_reg {} det_reg {}"
        return template.format(self.name, self.lambda_reward_max, self.lambda_det)

    def get_batch(self, growing_dataset, complement_supervised_dataset, batch_size):
        """Select a batch; falls back to a random batch while no data is labelled."""
        has_labelled_data = growing_dataset.get_size() != 0
        if has_labelled_data:
            return get_max_determinant_batch(self.reward_model, complement_supervised_dataset, batch_size, self.lambda_det)
        return get_random_reward_batch(complement_supervised_dataset, batch_size)












class MaxOptimism:
    """Reward-model fitter that adds a soft-max (log-sum-exp) optimism bonus.

    The training objective is the supervised fitting loss minus
    ``lambda_reward_max * (1 / max_optimism_eta) * logsumexp(max_optimism_eta * predictions)``
    evaluated on the unlabelled pool.

    ``l2_regularizer`` and ``range_regularizer`` are stored but not used by
    the methods of this class.
    """

    def __init__(self, reward_model, l1, lambda_reward_max, l2_regularizer, range_regularizer, verbose = False, max_optimism_eta = 20):
        self.name = "MaxOptimism"
        self.reward_model = reward_model
        self.l1 = l1
        self.verbose = verbose
        self.l2_regularizer = l2_regularizer
        self.range_regularizer = range_regularizer
        self.lambda_reward_max = lambda_reward_max
        self.max_optimism_eta = max_optimism_eta

    def get_name(self):
        """Return a label made of the name and ``lambda_reward_max``."""
        return "{} lambda_reg {}".format(self.name, self.lambda_reward_max)

    def get_batch(self, growing_dataset, complement_supervised_dataset, batch_size):
        """Split ``complement_supervised_dataset`` into a query batch and its remainder.

        A random batch is used while ``growing_dataset`` is empty; afterwards
        the rows with the highest predicted reward are selected.

        Returns ``(selected_X, selected_y, complement_X, complement_y)``.
        """
        if growing_dataset.get_size() == 0:
            return get_random_reward_batch(complement_supervised_dataset, batch_size)

        return get_max_reward_batch(self.reward_model, complement_supervised_dataset, batch_size)

    def fit_data(self, num_opt_steps, growing_dataset, complement_unsupervised_dataset, opt_batch_size):
        """Reset the reward model and train it for ``num_opt_steps`` Adam steps.

        Args:
            num_opt_steps: number of optimisation steps to run.
            growing_dataset: labelled data; sampled with ``get_batches`` in
                batches of ``opt_batch_size``.
            complement_unsupervised_dataset: unlabelled pool on which the
                optimism bonus is evaluated.
            opt_batch_size: batch size for the labelled data.

        Returns:
            The supervised fitting loss (without the optimism bonus) of the
            last step as a NumPy value, or ``None`` when ``num_opt_steps`` is 0.
        """
        # NOTE(review): the optimizer is created before reset_weights(); if
        # reset_weights() replaces the parameter tensors instead of
        # re-initialising them in place, the optimizer would hold stale
        # parameters - confirm against the reward model implementation.
        self.optimizer = torch.optim.Adam(self.reward_model.network.parameters(), lr = 0.01)

        self.reward_model.reset_weights()

        # Stays None when no optimisation step is executed.
        model_fitting_loss = None

        for i in range(num_opt_steps):
            self.optimizer.zero_grad()
            batch_X, batch_y = get_batches(growing_dataset, opt_batch_size)
            unsupervised_batch_X = get_batches(complement_unsupervised_dataset, 10000000)

            if self.verbose:
                # The global batch counter is not available in this method,
                # so only the local optimisation step is reported.
                print("opt step ", i)
            if self.l1:
                fitting_loss = self.reward_model.get_loss_l1(batch_X, batch_y)
            else:
                fitting_loss = self.reward_model.get_loss(batch_X, batch_y)

            if i == num_opt_steps-1:
                # clone() so the reported value owns its memory and cannot be
                # changed by any later operation on the loss tensor.
                model_fitting_loss = fitting_loss.detach().clone().numpy()

            predictions = self.reward_model.get_reward(unsupervised_batch_X)
            # logsumexp over every prediction; numerically stable, unlike
            # log(sum(exp(.))) which overflows once eta * prediction is large.
            scaled_predictions = (predictions*self.max_optimism_eta).reshape(-1)
            soft_max_reward = (1.0/self.max_optimism_eta)*torch.logsumexp(scaled_predictions, dim = 0)
            # Out-of-place subtraction keeps fitting_loss itself untouched.
            loss = fitting_loss - self.lambda_reward_max*soft_max_reward

            if self.verbose:
                print("loss " , loss)
            loss.backward()
            self.optimizer.step()

        return model_fitting_loss





class HingePNormOptimism:
    """Reward-model fitter that adds a hinge p-norm optimism bonus.

    The training objective is the supervised fitting loss minus
    ``lambda_reward_max * (sum(relu(predictions) ** power)) ** (1 / power)``
    evaluated on the unlabelled pool.

    ``l2_regularizer`` and ``range_regularizer`` are stored but not used by
    the methods of this class.
    """

    def __init__(self, reward_model, l1, lambda_reward_max, l2_regularizer, range_regularizer, verbose = False, power = 4):
        self.name = "HingePNormOptimism"
        self.reward_model = reward_model
        self.l1 = l1
        self.verbose = verbose
        self.l2_regularizer = l2_regularizer
        self.range_regularizer = range_regularizer
        self.lambda_reward_max = lambda_reward_max
        self.power = power

    def get_name(self):
        """Return a label made of the name and ``lambda_reward_max``."""
        return "{} lambda_reg {}".format(self.name, self.lambda_reward_max)

    def get_batch(self, growing_dataset, complement_supervised_dataset, batch_size):
        """Split ``complement_supervised_dataset`` into a query batch and its remainder.

        A random batch is used while ``growing_dataset`` is empty; afterwards
        the rows with the highest predicted reward are selected.

        Returns ``(selected_X, selected_y, complement_X, complement_y)``.
        """
        if growing_dataset.get_size() == 0:
            return get_random_reward_batch(complement_supervised_dataset, batch_size)

        return get_max_reward_batch(self.reward_model, complement_supervised_dataset, batch_size)

    def fit_data(self, num_opt_steps, growing_dataset, complement_unsupervised_dataset, opt_batch_size):
        """Reset the reward model and train it for ``num_opt_steps`` Adam steps.

        Args:
            num_opt_steps: number of optimisation steps to run.
            growing_dataset: labelled data; sampled with ``get_batches`` in
                batches of ``opt_batch_size``.
            complement_unsupervised_dataset: unlabelled pool on which the
                optimism bonus is evaluated.
            opt_batch_size: batch size for the labelled data.

        Returns:
            The supervised fitting loss (without the optimism bonus) of the
            last step as a NumPy value, or ``None`` when ``num_opt_steps`` is 0.
        """
        # NOTE(review): the optimizer is created before reset_weights(); if
        # reset_weights() replaces the parameter tensors instead of
        # re-initialising them in place, the optimizer would hold stale
        # parameters - confirm against the reward model implementation.
        self.optimizer = torch.optim.Adam(self.reward_model.network.parameters(), lr = 0.01)

        self.reward_model.reset_weights()

        # Stays None when no optimisation step is executed.
        model_fitting_loss = None

        for i in range(num_opt_steps):
            self.optimizer.zero_grad()
            batch_X, batch_y = get_batches(growing_dataset, opt_batch_size)
            unsupervised_batch_X = get_batches(complement_unsupervised_dataset, 10000000)

            if self.verbose:
                # The global batch counter is not available in this method,
                # so only the local optimisation step is reported.
                print("opt step ", i)
            if self.l1:
                fitting_loss = self.reward_model.get_loss_l1(batch_X, batch_y)
            else:
                fitting_loss = self.reward_model.get_loss(batch_X, batch_y)

            if i == num_opt_steps-1:
                # clone() so the reported value owns its memory and cannot be
                # changed by any later operation on the loss tensor.
                model_fitting_loss = fitting_loss.detach().clone().numpy()

            predictions = self.reward_model.get_reward(unsupervised_batch_X)
            # p-norm of the positive part of the predictions.
            # NOTE(review): when every prediction is <= 0 the inner sum is 0,
            # where the 1/power root has an unbounded derivative - confirm
            # this case cannot occur or add a small epsilon.
            hinge_p_norm = torch.pow(torch.sum(torch.pow(torch.relu(predictions), self.power)), 1.0/self.power)
            # Out-of-place subtraction keeps fitting_loss itself untouched.
            loss = fitting_loss - self.lambda_reward_max*hinge_p_norm

            if self.verbose:
                print("loss " , loss)
            loss.backward()
            self.optimizer.step()

        return model_fitting_loss


def get_max_reward_batch(reward_model, dataset, batch_size):
    """Split ``dataset`` into the ``batch_size`` highest-predicted rows and the rest.

    The whole dataset is fetched with ``get_batches``, scored with
    ``reward_model.get_reward`` and ordered by ``(predicted reward, row
    label)``. The top ``batch_size`` entries form the selected batch; every
    other row forms the complement. A ``batch_size`` of 0 selects nothing and
    a ``batch_size`` larger than the dataset selects everything.

    Rows are addressed by DataFrame index labels when
    ``dataset.return_dataframe`` is true and by integer position otherwise.

    Returns:
        ``(selected_X, selected_y, complement_X, complement_y)``, each ordered
        by increasing predicted reward.
    """
    batch_X, batch_y = get_batches(dataset, 1000000000)

    rewards = reward_model.get_reward(np.array(batch_X))
    list_rewards = list(rewards.detach().numpy())

    use_dataframe = dataset.return_dataframe
    if use_dataframe:
        batch_index = list(batch_X.index)
    else:
        batch_index = list(range(len(list_rewards)))

    def take_rows(row_labels):
        # Same rows from X and y, using the accessor matching the container.
        if use_dataframe:
            return batch_X.loc[row_labels], batch_y.loc[row_labels]
        return batch_X[row_labels, :], batch_y[row_labels, :]

    # Ascending by predicted reward, ties broken by row label.
    ranked = sorted(zip(list_rewards, batch_index))

    # An explicit split point avoids the `lst[-0:]` pitfall, where a
    # batch_size of 0 would select the whole list instead of nothing.
    split = max(len(ranked) - batch_size, 0)
    batch_indices = [label for (_, label) in ranked[split:]]
    complement_indices = [label for (_, label) in ranked[:split]]

    filtered_batch_X, filtered_batch_y = take_rows(batch_indices)
    complement_batch_X, complement_batch_y = take_rows(complement_indices)

    return filtered_batch_X, filtered_batch_y, complement_batch_X, complement_batch_y

# def get_kernel_function(kernel_type):
#     if kernel_type == "linear":
#         return 
#     if kernel_type == "gaussian":






def get_max_determinant_batch(reward_model, dataset, batch_size, lambda_det, kernel_type = "gaussian", num_greedy_trials = 5):
    """Split ``dataset`` into a batch chosen by greedy max-determinant search and the rest.

    A kernel matrix is built over all rows of the dataset:
    ``"linear"`` uses ``X X^T``; ``"gaussian"`` uses ``exp(-d^2)`` where ``d``
    is the Euclidean distance between rows scaled to unit norm.
    ``greedy_max_determinant_submatrix_indices`` is then run
    ``num_greedy_trials`` times (the first trial with
    ``random_first_point=False``, later ones with ``True``) and the trial with
    the largest reported determinant value wins.

    Rows are addressed by DataFrame index labels when
    ``dataset.return_dataframe`` is true and by integer position otherwise.

    Returns:
        ``(selected_X, selected_y, complement_X, complement_y)``.

    Raises:
        ValueError: if ``kernel_type`` is neither ``"linear"`` nor ``"gaussian"``.
    """
    batch_X, batch_y = get_batches(dataset, 1000000000)

    rewards = reward_model.get_reward(np.array(batch_X))

    frame_mode = dataset.return_dataframe
    features = batch_X.values if frame_mode else np.array(batch_X)

    if kernel_type == "gaussian":
        # Scale each row to unit Euclidean norm before measuring distances.
        # NOTE(review): a row with zero norm divides by zero here.
        inverse_row_norms = 1.0/np.linalg.norm(features, axis = 1)
        unit_rows = (features.transpose() * inverse_row_norms).transpose()
        distances = squareform(pdist(unit_rows, 'euclidean'))
        kernel_matrix = np.exp(-distances**2)
    elif kernel_type == "linear":
        kernel_matrix = features @ features.transpose()
    else:
        raise ValueError("Unknown kernel type {}".format(kernel_type))

    list_rewards = list(rewards.detach().numpy())
    if frame_mode:
        batch_index = list(batch_X.index)
    else:
        batch_index = list(range(len(list_rewards)))

    # NOTE(review): this matrix is computed but never used below - the greedy
    # search receives the plain kernel_matrix, so neither lambda_det nor the
    # predicted rewards influence which rows are selected. Confirm whether
    # reward_augmented_kernel was meant to be passed instead.
    reward_augmented_kernel = generate_reward_augmented_matrix(kernel_matrix, list_rewards, lambda_det)

    print("Computed the reward augmented kernel matrix - starting greedy submatrix selection")

    # Positions (into batch_index) of the best trial found so far.
    best_positions = []
    best_leftover_positions = []
    best_det_val = -float("inf")

    for trial in range(num_greedy_trials):
        positions, det_val, leftover_positions = greedy_max_determinant_submatrix_indices(kernel_matrix, batch_size, random_first_point = trial > 0)
        if det_val > best_det_val:
            best_positions = positions
            best_leftover_positions = leftover_positions
            best_det_val = det_val

    # Translate positions back to the row labels used by the containers.
    batch_indices = [batch_index[p] for p in best_positions]
    complement_indices = [batch_index[p] for p in best_leftover_positions]

    if frame_mode:
        filtered_batch_X = batch_X.loc[batch_indices]
        filtered_batch_y = batch_y.loc[batch_indices]
        complement_batch_X = batch_X.loc[complement_indices]
        complement_batch_y = batch_y.loc[complement_indices]
    else:
        filtered_batch_X = batch_X[batch_indices, :]
        filtered_batch_y = batch_y[batch_indices, :]
        complement_batch_X = batch_X[complement_indices, :]
        complement_batch_y = batch_y[complement_indices, :]

    return filtered_batch_X, filtered_batch_y, complement_batch_X, complement_batch_y





def get_max_reward_batch_ensemble(reward_models, dataset, batch_size):
    """Split ``dataset`` into the ``batch_size`` rows with the highest ensemble reward and the rest.

    Every model in ``reward_models`` scores the whole dataset; a row's score
    is the maximum of its predictions across the models. Rows are ordered by
    ``(score, row label)``; the top ``batch_size`` entries form the selected
    batch and every other row forms the complement. A ``batch_size`` of 0
    selects nothing and a ``batch_size`` larger than the dataset selects
    everything.

    Rows are addressed by DataFrame index labels when
    ``dataset.return_dataframe`` is true and by integer position otherwise.

    Returns:
        ``(selected_X, selected_y, complement_X, complement_y)``, each ordered
        by increasing score.
    """
    batch_X, batch_y = get_batches(dataset, 1000000000)

    features = np.array(batch_X)
    per_model_predictions = torch.stack([model.get_reward(features) for model in reward_models])
    # Optimistic score: best prediction over the ensemble (axis 0 = models).
    rewards = np.max(per_model_predictions.detach().numpy(), 0)
    list_rewards = list(rewards)

    use_dataframe = dataset.return_dataframe
    if use_dataframe:
        batch_index = list(batch_X.index)
    else:
        batch_index = list(range(len(list_rewards)))

    def take_rows(row_labels):
        # Same rows from X and y, using the accessor matching the container.
        if use_dataframe:
            return batch_X.loc[row_labels], batch_y.loc[row_labels]
        return batch_X[row_labels, :], batch_y[row_labels, :]

    # Ascending by score, ties broken by row label.
    ranked = sorted(zip(list_rewards, batch_index))

    # An explicit split point avoids the `lst[-0:]` pitfall, where a
    # batch_size of 0 would select the whole list instead of nothing.
    split = max(len(ranked) - batch_size, 0)
    batch_indices = [label for (_, label) in ranked[split:]]
    complement_indices = [label for (_, label) in ranked[:split]]

    filtered_batch_X, filtered_batch_y = take_rows(batch_indices)
    complement_batch_X, complement_batch_y = take_rows(complement_indices)

    return filtered_batch_X, filtered_batch_y, complement_batch_X, complement_batch_y





def get_random_reward_batch(dataset, batch_size):
    """Split ``dataset`` into ``batch_size`` uniformly shuffled rows and the rest.

    The whole dataset is fetched with ``get_batches``, its row labels are
    shuffled in place with ``random.shuffle``, and the first ``batch_size``
    labels become the selected batch while the remaining labels become the
    complement. No reward model is involved.

    Rows are addressed by DataFrame index labels when
    ``dataset.return_dataframe`` is true and by integer position otherwise.

    Returns:
        ``(selected_X, selected_y, complement_X, complement_y)``, in shuffled
        order.
    """
    batch_X, batch_y = get_batches(dataset, 1000000000)

    frame_mode = dataset.return_dataframe
    if frame_mode:
        shuffled_labels = list(batch_X.index)
    else:
        shuffled_labels = list(range(batch_X.shape[0]))
    random.shuffle(shuffled_labels)

    chosen = shuffled_labels[:batch_size]
    remaining = shuffled_labels[batch_size:]

    if frame_mode:
        return batch_X.loc[chosen], batch_y.loc[chosen], batch_X.loc[remaining], batch_y.loc[remaining]

    return batch_X[chosen, :], batch_y[chosen, :], batch_X[remaining, :], batch_y[remaining, :]









def train_pure_exploitation(reward_model, dataset, num_batches, num_opt_steps, batch_size, opt_batch_size = 10, lambda_reward_max = 0.01, 
    l2_regularizer = 0, range_regularizer = 0, verbose = False, l1= False):
    """Train ``reward_model`` by repeatedly querying its own top-predicted rows.

    For each of ``num_batches`` rounds: the ``batch_size`` rows with the
    highest predicted reward are moved from the remaining pool into the
    labelled set (via ``get_max_reward_batch``), then the model is trained
    for ``num_opt_steps`` Adam steps on batches of ``opt_batch_size`` drawn
    from the labelled set. The model is not reset between rounds.

    ``lambda_reward_max``, ``l2_regularizer`` and ``range_regularizer`` are
    accepted but not used in this function. ``l1`` selects
    ``reward_model.get_loss_l1`` instead of ``reward_model.get_loss``.

    Returns:
        ``(reward_model, max_observed_rewards)`` where
        ``max_observed_rewards[k]`` is the running maximum recorded after
        round ``k``.
    """
    pool_X, pool_y = get_batches(dataset, 1000000000)
    complement_supervised_dataset = DataSet(pd.DataFrame(pool_X), pd.DataFrame(pool_y))

    growing_dataset = GrowingNumpyDataSet()
    max_observed_rewards = []
    running_max = -float("inf")

    final_round = num_batches-1
    final_step = num_opt_steps-1
    compute_loss = reward_model.get_loss_l1 if l1 else reward_model.get_loss

    for j in range(num_batches):
        # A fresh Adam state every round; the model weights carry over.
        optimizer = torch.optim.Adam(reward_model.network.parameters(), lr = 0.01)
        picked_X, picked_y, rest_X, rest_y = get_max_reward_batch(reward_model, complement_supervised_dataset, batch_size)
        growing_dataset.add_data(picked_X, picked_y)

        # NOTE(review): the built-in max iterates picked_y; if picked_y is a
        # DataFrame this walks column labels rather than reward values -
        # confirm what get_max_reward_batch returns for this dataset.
        running_max = max(running_max, max(picked_y))
        max_observed_rewards.append(running_max)

        # NOTE(review): this unsupervised view is built but never read below.
        complement_unsupervised_dataset = DataSetUnsupervised(pd.DataFrame(rest_X))
        complement_supervised_dataset = DataSet(pd.DataFrame(rest_X), pd.DataFrame(rest_y))

        for i in range(num_opt_steps):
            batch_X, batch_y = get_batches(growing_dataset, opt_batch_size)
            optimizer.zero_grad()
            if verbose:
                print("Global batch num ", j, " opt step ", i)

            loss = compute_loss(batch_X, batch_y)

            # Always report the very last loss, even when not verbose.
            is_very_last_step = (j == final_round) and (i == final_step)
            if verbose or is_very_last_step:
                print("loss" , loss)
            loss.backward()
            optimizer.step()

    return reward_model, max_observed_rewards





def evaluate_rank_observed_reward(dataset, observed_reward):
    """Return the 1-based rank of ``observed_reward`` among the dataset labels.

    All values of ``dataset.labels`` are sorted in decreasing order and the
    position of the first entry equal to ``observed_reward`` is returned, so
    rank 1 is the largest label and tied values share the best rank.

    Args:
        dataset: object whose ``labels`` attribute exposes ``.values``
            (an array of label values).
        observed_reward: value to look up; compared with ``==``.

    Raises:
        ValueError: if ``observed_reward`` does not occur among the labels.
    """
    # reshape(-1) always yields a 1-D array; squeeze() would collapse a
    # single-label dataset to a 0-d array, which cannot be iterated.
    flat_labels = dataset.labels.values.reshape(-1)
    descending_labels = sorted(flat_labels, reverse = True)
    return descending_labels.index(observed_reward) + 1




### 
###
### Returns: max_rewards_estimated_value = maximum value of the model reward
###          estimated_index_real_value = real value of the maximum estimated reward's index
###          max_reward_value = true max reward value
###          estimated_value_of_max_reward = model value of the datapoint achieving true max reward
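###          resulting_rank = 1-based rank (1 = largest) of estimated_index_real_value among the true rewards sorted in decreasing order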
### 
def evaluate_reward_model(reward_model, dataset):
    """Compare the model's top-predicted datapoint with the true best datapoint.

    The whole dataset is fetched with ``get_batches`` and scored with
    ``reward_model.get_reward``; the predictions are then compared against
    ``dataset.labels.values``.

    NOTE(review): predictions are matched to labels by position, which
    assumes ``get_batches`` returns rows in the same order as
    ``dataset.labels`` - confirm.

    Returns a 5-tuple:
        max_rewards_estimated_value: largest predicted reward.
        estimated_index_real_value: true label (first column) of the row with
            the largest predicted reward.
        max_reward_value: largest true label.
        estimated_value_of_max_reward: prediction at the position of the
            largest true label.
        resulting_rank: 1-based position of ``estimated_index_real_value``
            among the true labels sorted in decreasing order.
    """
    batch_X, batch_y = get_batches(dataset, 1000000000)

    rewards = reward_model.get_reward(batch_X)
    true_labels = dataset.labels.values

    # Model's point of view: which row does it believe is best?
    predicted_best_position = torch.argmax(rewards)
    max_rewards_estimated_value = torch.max(rewards).detach().numpy()
    estimated_index_real_value = true_labels[predicted_best_position][0]

    # Ground truth: which row is actually best, and what did the model say?
    max_reward_value = np.max(true_labels)
    true_best_position = np.argmax(true_labels)
    estimated_value_of_max_reward = rewards.detach().numpy()[true_best_position]

    # Rank of the model's pick among the true labels, best first.
    labels_best_first = sorted(true_labels)
    labels_best_first.reverse()
    resulting_rank = labels_best_first.index(estimated_index_real_value) + 1

    return (
        max_rewards_estimated_value,
        estimated_index_real_value,
        max_reward_value,
        estimated_value_of_max_reward,
        resulting_rank,
    )



