import scipy
import numpy as np
from typing import AnyStr, List
from slingpy import AbstractDataSource
from scipy.spatial.distance import pdist, squareform

from slingpy.models.abstract_base_model import AbstractBaseModel
from genedisco.active_learning_methods.acquisition_functions.base_acquisition_function import \
    BaseBatchAcquisitionFunction


import copy
#from algorithms import generate_reward_augmented_matrix

import IPython

def generate_reward_augmented_matrix(kernel_matrix, rewards, lambda_det):
    """Scale a kernel matrix on both sides by exponentiated rewards.

    Each reward ``r`` is turned into the weight ``exp(r / (2 * lambda_det))``.
    The kernel is first multiplied by the weights along its last axis, then
    the transposed result is multiplied by the weights again and transposed
    back, so that entry ``(i, j)`` ends up as
    ``kernel_matrix[i, j] * w[j] * w[i]``.

    Args:
        kernel_matrix: Square array of pairwise kernel values.
        rewards: Array-like of rewards (assumed to hold one value per row of
            ``kernel_matrix`` -- TODO confirm with callers).
        lambda_det: Divisor controlling how strongly rewards are weighted;
            must be non-zero.

    Returns:
        The doubly-scaled kernel matrix (same shape as ``kernel_matrix``).
    """
    weights = np.exp(np.array(rewards) / (2 * lambda_det))
    # Same multiplication order as a column scaling followed by a row scaling.
    column_scaled = kernel_matrix * weights
    fully_scaled_t = column_scaled.T * weights
    return fully_scaled_t.T


def get_submatrices_fast(matrix, used_indices, unused_indices):
    """Build every one-point extension of the already-selected submatrix.

    For each candidate ``c`` in ``unused_indices`` this produces the
    ``(k + 1) x (k + 1)`` matrix obtained by restricting ``matrix`` to the
    rows/columns ``used_indices + [c]`` (in that order), where
    ``k = len(used_indices)``.

    Args:
        matrix: 2-D NumPy array to take the entries from.
        used_indices: Sequence of integer indices already selected.
        unused_indices: Sequence of integer candidate indices.

    Returns:
        A list with one float array per candidate, in the order of
        ``unused_indices``.
    """
    used = np.asarray(used_indices, dtype=np.intp)
    unused = np.asarray(unused_indices, dtype=np.intp)
    num_used = len(used)

    # Shared top-left block: the submatrix of the points selected so far.
    base = np.zeros((num_used + 1, num_used + 1))
    base[:num_used, :num_used] = matrix[np.ix_(used, used)]

    candidate_rows = matrix[np.ix_(unused, used)]  # new bottom row per candidate
    candidate_cols = matrix[np.ix_(used, unused)]  # new right column per candidate
    # Only the diagonal entries of the candidates are needed, so read them
    # directly instead of materialising the full unused-by-unused block.
    candidate_diag = matrix[unused, unused]

    submatrices = []
    for position in range(len(unused)):
        extended = base.copy()
        extended[num_used, :num_used] = candidate_rows[position, :]
        extended[:num_used, num_used] = candidate_cols[:, position]
        extended[num_used, num_used] = candidate_diag[position]
        submatrices.append(extended)
    return submatrices



def compute_determinants(matrix_list):
    """Return the determinant of every matrix in ``matrix_list``.

    Args:
        matrix_list: Iterable of square arrays accepted by ``np.linalg.det``.

    Returns:
        A list of determinants, in the same order as the input.
    """
    return list(map(np.linalg.det, matrix_list))





def greedy_max_determinant_submatrix_indices(matrix, submatrix_size, random_first_point = False, randomize = True,
    use_ray = False, num_processors = 30):
    """Greedily pick indices whose principal submatrix has a large determinant.

    Starting from a single index, the index that maximises the determinant of
    the principal submatrix formed with the indices chosen so far is added
    repeatedly. Exact ties are broken uniformly at random via ``np.random``.

    Args:
        matrix: Square 2-D NumPy array.
        submatrix_size: Number of indices to select. Values below 1 still
            yield a single selected index.
        random_first_point: If True the first index is drawn uniformly at
            random; otherwise it is an index with the largest diagonal entry.
        randomize: Only used when ``random_first_point`` is False. If True,
            ties for the largest diagonal entry are broken at random;
            otherwise the first maximiser is taken.
        use_ray: Currently unused; kept for interface compatibility.
        num_processors: Currently unused; kept for interface compatibility.

    Returns:
        A tuple ``(used_indices, max_value, unused_indices)`` where
        ``used_indices`` are the selected indices in selection order,
        ``max_value`` is the determinant of the final selected submatrix
        (the diagonal entry when only one index is selected) and
        ``unused_indices`` are the remaining indices in ascending order.

    Raises:
        ValueError: If ``submatrix_size`` exceeds the number of rows.
    """
    num_points = matrix.shape[0]
    if submatrix_size > num_points:
        raise ValueError(
            "Cannot select {} indices from a matrix with {} rows".format(submatrix_size, num_points))

    diagonal = np.diag(matrix)
    if random_first_point:
        used_index = np.random.choice(list(np.arange(num_points)))
    elif randomize:
        used_index = np.random.choice(np.flatnonzero(diagonal == diagonal.max()))
    else:
        used_index = np.argmax(diagonal)
    used_index = int(used_index)

    used_indices = [used_index]
    unused_indices = list(range(num_points))
    unused_indices.remove(used_index)

    # The determinant of a 1x1 submatrix is its single (diagonal) entry; this
    # keeps max_value defined even when no further points are added.
    max_value = diagonal[used_index]

    for i in range(1, submatrix_size):
        print("Submatrix size now ", i+1)

        submatrices = get_submatrices_fast(matrix, used_indices, unused_indices)
        determinants = compute_determinants(submatrices)

        # Take the LARGEST determinant among the candidates (the former
        # sort/reverse combination ended up selecting the smallest one).
        max_value = max(determinants)
        argmax_indices = [index for val, index in zip(determinants, unused_indices) if val == max_value]
        used_index = int(np.random.choice(argmax_indices))
        used_indices.append(used_index)
        unused_indices.remove(used_index)

    return used_indices, max_value, unused_indices









def get_max_determinant_batch_genedisco(representation_dataset, pred_mean, batch_size, lambda_det, kernel_type = "gaussian", num_greedy_trials = 5):
    """Pick a batch of rows by greedy determinant maximisation.

    A kernel matrix is built from ``representation_dataset``, scaled by the
    exponentiated predicted rewards, and then searched with several greedy
    runs (the first starts from the largest diagonal entry, the others from a
    random point). The run with the largest final determinant wins.

    Args:
        representation_dataset: 2-D NumPy array with one row per candidate.
            For the gaussian kernel every row is divided by its Euclidean
            norm, so rows must be non-zero.
        pred_mean: Predicted reward per candidate, forwarded to
            ``generate_reward_augmented_matrix``.
        batch_size: Number of candidates to select.
        lambda_det: Reward scaling parameter forwarded to
            ``generate_reward_augmented_matrix``.
        kernel_type: ``"linear"`` (Gram matrix of the rows) or ``"gaussian"``
            (``exp(-d**2)`` on Euclidean distances between normalised rows).
        num_greedy_trials: Number of greedy restarts.

    Returns:
        A tuple ``(selected_indices, max_det_val, complement_indices)`` of the
        best greedy run; indices are row positions in
        ``representation_dataset``. With ``num_greedy_trials == 0`` the result
        is ``([], -inf, [])``.

    Raises:
        ValueError: If ``batch_size`` exceeds the number of rows or
            ``kernel_type`` is unknown.
    """
    num_points = len(representation_dataset)
    if batch_size > num_points:
        raise ValueError(
            "Requested batch size {} exceeds the {} available points".format(batch_size, num_points))

    ### Build Kernel matrix
    if kernel_type == "linear":
        kernel_matrix = np.matmul(representation_dataset, representation_dataset.transpose())
    elif kernel_type == "gaussian":
        # Project every row onto the unit sphere before measuring distances.
        normalized_batch_X_numpy = np.multiply(representation_dataset.transpose(), 1.0/np.linalg.norm(representation_dataset, axis = 1)).transpose()

        pairwise_dists = squareform(pdist(normalized_batch_X_numpy, 'euclidean'))
        kernel_matrix = np.exp(-pairwise_dists**2)
    else:
        raise ValueError("Unknown kernel type {}".format(kernel_type))

    reward_augmented_kernel = generate_reward_augmented_matrix(kernel_matrix, pred_mean, lambda_det)

    print("Computed the reward augmented kernel matrix - starting greedy submatrix selection")
    final_unmapped_indices_to_select = []
    final_unmapped_complement_indices = []
    final_max_det_val = -float("inf")

    for i in range(num_greedy_trials):
        # Search the reward-augmented kernel; searching the plain kernel would
        # make pred_mean and lambda_det irrelevant to the selection.
        unmapped_indices_to_select, max_det_val, unmapped_unused_indices = greedy_max_determinant_submatrix_indices(
            reward_augmented_kernel, batch_size, random_first_point = i > 0)
        # Always keep the first run so that a selection exists even if the
        # determinant is not comparable (e.g. NaN).
        if i == 0 or final_max_det_val < max_det_val:
            final_unmapped_indices_to_select = unmapped_indices_to_select
            final_max_det_val = max_det_val
            final_unmapped_complement_indices = unmapped_unused_indices

    return final_unmapped_indices_to_select, final_max_det_val, final_unmapped_complement_indices


def get_max_determinant_batch(reward_model, dataset, batch_size, lambda_det, kernel_type = "gaussian", num_greedy_trials = 5):
    """Split a dataset into a determinant-selected batch and its complement.

    The whole dataset is loaded via ``get_batches``, rewards are predicted
    with ``reward_model``, a kernel matrix over the features is scaled by the
    exponentiated rewards and a greedy determinant search (with restarts)
    chooses ``batch_size`` points.

    NOTE(review): ``get_batches`` is neither defined nor imported in the
    visible part of this file -- confirm where it comes from.

    Args:
        reward_model: Object exposing ``get_reward(array)``; the result must
            support ``.detach().numpy()``.
        dataset: Data set object passed to ``get_batches``; its
            ``return_dataframe`` flag decides whether features/labels are
            handled as DataFrames (``.values`` / ``.loc``) or as arrays.
        batch_size: Number of points to select.
        lambda_det: Reward scaling parameter forwarded to
            ``generate_reward_augmented_matrix``.
        kernel_type: ``"linear"`` or ``"gaussian"``.
        num_greedy_trials: Number of greedy restarts.

    Returns:
        ``(filtered_batch_X, filtered_batch_y, complement_batch_X,
        complement_batch_y)``: features and labels of the selected points and
        of all remaining points.

    Raises:
        ValueError: If ``kernel_type`` is unknown.
    """
    # Request a batch large enough to (presumably) hold the entire dataset.
    batch_X, batch_y = get_batches( dataset, 1000000000)

    rewards = reward_model.get_reward(np.array(batch_X))

    if dataset.return_dataframe:
        batch_X_numpy = batch_X.values
    else:
        batch_X_numpy = np.array(batch_X)

    ### Build the kernel matrix (linear Gram matrix or gaussian on normalised rows).
    if kernel_type == "linear":
        kernel_matrix = np.matmul(batch_X_numpy, batch_X_numpy.transpose())
    elif kernel_type == "gaussian":
        normalized_batch_X_numpy = np.multiply(batch_X_numpy.transpose(), 1.0/np.linalg.norm(batch_X_numpy, axis = 1)).transpose()

        pairwise_dists = squareform(pdist(normalized_batch_X_numpy, 'euclidean'))
        kernel_matrix = np.exp(-pairwise_dists**2)
    else:
        raise ValueError("Unknown kernel type {}".format(kernel_type))

    list_rewards = list(rewards.detach().numpy())
    ### Labels used to map positional indices back onto the dataset.
    if dataset.return_dataframe:
        batch_index = list(batch_X.index)
    else:
        batch_index = list(range(len(list_rewards)))

    reward_augmented_kernel = generate_reward_augmented_matrix(kernel_matrix, list_rewards, lambda_det)

    print("Computed the reward augmented kernel matrix - starting greedy submatrix selection")
    final_unmapped_indices_to_select = []
    final_unmapped_complement_indices = []
    final_max_det_val = -float("inf")

    for i in range(num_greedy_trials):
        # Search the reward-augmented kernel; searching the plain kernel would
        # make the rewards and lambda_det irrelevant to the selection.
        unmapped_indices_to_select, max_det_val, unmapped_unused_indices = greedy_max_determinant_submatrix_indices(
            reward_augmented_kernel, batch_size, random_first_point = i > 0)
        # Always keep the first run so that a selection exists even if the
        # determinant is not comparable (e.g. NaN).
        if i == 0 or final_max_det_val < max_det_val:
            final_unmapped_indices_to_select = unmapped_indices_to_select
            final_max_det_val = max_det_val
            final_unmapped_complement_indices = unmapped_unused_indices

    batch_indices = [batch_index[i] for i in final_unmapped_indices_to_select]

    if dataset.return_dataframe:
        filtered_batch_X = batch_X.loc[batch_indices]
        filtered_batch_y = batch_y.loc[batch_indices]
    else:
        filtered_batch_X = batch_X[batch_indices, :]
        filtered_batch_y = batch_y[batch_indices,:]

    complement_indices = [batch_index[i] for i in final_unmapped_complement_indices]
    if dataset.return_dataframe:
        complement_batch_X = batch_X.loc[complement_indices]
        complement_batch_y = batch_y.loc[complement_indices]
    else:
        complement_batch_X = batch_X[complement_indices, :]
        complement_batch_y = batch_y[complement_indices, :]

    return filtered_batch_X, filtered_batch_y, complement_batch_X, complement_batch_y




class DeterminantsOptimismAcquisition(BaseBatchAcquisitionFunction):
    """Batch acquisition combining a determinant search with an optimistic score.

    The indices returned are those with the highest
    ``pred_mean + beta_optimism * pred_uncertainties``. A determinant-based
    batch search is also run on the chosen representation.

    NOTE(review): the result of the determinant search is currently not used
    for the returned selection -- confirm whether it should be.
    """

    def __init__(self, representation = "linear"):
        """Store which representation feeds the determinant search.

        Args:
            representation: ``"linear"`` to use ``model.get_embedding`` or
                ``"raw"`` to use the data source itself.
        """
        self.representation = representation
        super(DeterminantsOptimismAcquisition, self).__init__()


    def __call__(self,
                 dataset_x: AbstractDataSource,
                 select_size: int,
                 available_indices: List[AnyStr],
                 last_selected_indices: List[AnyStr] = None,
                 model: AbstractBaseModel = None,
                 beta_optimism: float = 0.1,
                 ) -> List:
        """Select ``select_size`` items from ``available_indices``.

        Args:
            dataset_x: Data source holding the candidate items.
            select_size: Number of items to select.
            available_indices: Identifiers of the items that may be selected.
            last_selected_indices: Not used by this acquisition function.
            model: Model providing ``predict(..., return_std_and_margin=True)``
                and, for the ``"linear"`` representation, ``get_embedding``.
            beta_optimism: Weight of the predictive uncertainty in the score.

        Returns:
            The selected identifiers, ordered from highest to lowest score.

        Raises:
            ValueError: If no model is given, the representation is unknown,
                or ``select_size`` exceeds the number of available items.
            TypeError: If the model prediction does not have three components.
        """
        if model is None:
            raise ValueError("A model is required to compute the acquisition scores.")

        if self.representation not in ("linear", "raw"):
            raise ValueError("Representation must be one of 'linear', 'raw'")

        lambda_det = 1

        # Restrict the data source once and reuse it for every model call.
        avail_dataset_x = dataset_x.subset(available_indices)

        if self.representation == "linear":
            representation_dataset = model.get_embedding(avail_dataset_x).numpy()
        else:
            # NOTE(review): np.squeeze is applied to the data source object
            # itself; confirm this yields the intended 2-D array.
            representation_dataset = np.squeeze(avail_dataset_x, axis=1)

        model_predictions = model.predict(avail_dataset_x, return_std_and_margin=True)

        if len(model_predictions) != 3:
            raise TypeError("The provided model does not output uncertainty.")

        pred_mean, pred_uncertainties, _ = model_predictions

        # Validate before the expensive determinant search, which cannot
        # select more points than are available.
        if len(pred_mean) < select_size:
            raise ValueError("The number of query samples exceeds "
                             "the size of the available data.")

        # NOTE(review): the outcome of this search is discarded; the selection
        # below relies on the optimistic score only.
        get_max_determinant_batch_genedisco(representation_dataset, pred_mean, select_size, lambda_det, kernel_type = "gaussian", num_greedy_trials = 5)

        # Rank by optimistic score (mean plus weighted uncertainty), best first.
        numerical_selected_indices = np.flip(
            np.argsort(pred_mean + beta_optimism*pred_uncertainties)
        )[:select_size]
        selected_indices = [available_indices[i] for i in numerical_selected_indices]

        return selected_indices