import numpy as np
from .payoff_fns import RFPayoff, LDAPayoff, KNNPayoff, MLPPayoff, HSICWitness, OracleLDAPayoff, RidgeRegressorPayoff, OracleRegressorPayoff, KNNPayoff_2ST, LDA_2ST, CNNPayoff_2ST,MMDPayoff_2ST
from sklearn.metrics import pairwise_distances
TYPES_KERNEL = ['rbf', 'laplace']


def compute_hyperparam(data: np.ndarray,
                       kernel_type: TYPES_KERNEL = 'rbf', style='median') -> float:
    """
    Compute a kernel hyperparameter from the pairwise distances of the data.

    The returned value is ``1 / (2 * s)``, where ``s`` is the median
    (``style='median'``, the median heuristic) or the mean (``style='mean'``)
    of all off-diagonal pairwise distances:

    * ``kernel_type='rbf'``: squared distances from ``pairwise_distances``
      with its default metric;
    * ``kernel_type='laplace'``: distances with ``metric='l1'``.

    Parameters
    ----------
    data : np.ndarray
        Samples, one per row. A 1-D array is treated as a column of scalar
        samples (reshaped to ``(-1, 1)``).
    kernel_type : {'rbf', 'laplace'}
        Kernel the hyperparameter is intended for.
    style : {'median', 'mean'}
        Statistic of the off-diagonal distances to use.

    Returns
    -------
    float
        ``1 / (2 * statistic)``.

    Raises
    ------
    ValueError
        If ``kernel_type`` or ``style`` is not one of the supported values.
    """
    # Validate both options before the O(n^2) distance computation so that a
    # typo fails fast instead of silently returning None at the very end.
    if kernel_type not in ('rbf', 'laplace'):
        raise ValueError('Unknown kernel type')
    if style not in ('median', 'mean'):
        raise ValueError('Unknown style: {!r}'.format(style))

    # pairwise_distances expects a 2-D array: promote 1-D input to a column.
    samples = data.reshape(-1, 1) if data.ndim == 1 else data
    if kernel_type == 'rbf':
        dist = pairwise_distances(samples)**2
    else:
        dist = pairwise_distances(samples, metric='l1')

    # Drop the diagonal: the zero self-distances would bias the statistic.
    mask = np.ones_like(dist, dtype=bool)
    np.fill_diagonal(mask, 0)
    off_diagonal = dist[mask]

    if style == 'median':
        return 1/(2*np.median(off_diagonal))
    return 1/(2*np.mean(off_diagonal))

class SeqIndTester(object):
    """
    Sequential tester (independence testing, per the class name) based on a
    betting / wealth process.

    Each call to :meth:`process_pair` turns the next pair of observations into
    a payoff, scales it by a betting fraction and multiplies it into the
    wealth. ``null_rejected`` is set once the (mixed) wealth reaches
    ``1 / significance_level``.

    Configuration is done by assigning the public attributes after
    construction; see the comments in ``__init__``.
    """

    def __init__(self):
        # payoff family: 'classification' / 'regression' (predictive payoff)
        # or 'HSIC' (kernel witness payoff); default: 'classification'
        self.payoff_style = 'classification'
        # predictive model used by the predictive payoff: 'LDA', 'oracle LDA',
        # 'RF', 'MLP', 'Ridge', 'Oracle Reg' or 'kNN'
        self.pred_model = 'LDA'
        # payoff object, created when the first pair is processed
        self.payoff_obj = None
        # HSIC witness function object, created on the first HSIC pair
        self.wf = None
        # kernel settings forwarded to the HSIC witness function
        self.kernel_type = 'rbf'
        self.kernel_param_x = 1
        self.kernel_param_y = 1
        # betting scheme: 'ONS', 'aGRAPA' or 'mixing'
        self.bet_scheme = 'ONS'
        # options forwarded to the payoff objects
        self.payoff_strategy = 'accuracy'
        self.scaling_strategy = 'second'
        self.knn_comp = 'old'
        self.knn_reg = False
        # wealth process value; becomes a list (one entry per grid value)
        # under the 'mixing' scheme
        self.wealth = 1
        # set once a candidate wealth is negative; wealth is frozen afterwards
        self.wealth_flag = False
        # history of scaled payoffs (filled by the predictive payoff only)
        self.payoff_hist = list()
        # 1-based index of the pair being processed
        self.num_proc_pairs = 1
        # average wealth over the grid ('mixing' scheme only)
        self.mixed_wealth = None
        # for testing
        self.significance_level = 0.05
        self.null_rejected = False
        # betting state: under 'ONS' run_mean holds the previous payoff, under
        # 'aGRAPA' run_mean / run_second_moment are lists of past payoffs and
        # squared payoffs
        self.run_mean = 0
        self.run_second_moment = 0
        # current betting fraction and accumulated squared gradients (ONS)
        self.opt_lmbd = 0
        self.grad_sq_sum = 1
        self.lmbd_hist = list()
        # upper bound for the betting fraction
        self.truncation_level = 0.5
        # parameter forwarded to the oracle payoff objects
        self.oracle_beta = 0
        # mixture: grid of betting fractions and its size
        self.grid_of_lmbd = None
        self.lmbd_grid_size = 19

    def initialize_mixture_method(self):
        """
        Initialize mixture method: build the grid of betting fractions and
        one wealth value / one failure flag per grid point.
        """
        self.grid_of_lmbd = np.linspace(0.05, 0.95, self.lmbd_grid_size)
        # wealth now is tracked for each value of lmbd, resulting wealth is average for uniform prior
        self.wealth = [1 for _ in range(self.lmbd_grid_size)]
        self.wealth_flag = [False for _ in range(self.lmbd_grid_size)]

    def _check_bet_scheme(self):
        """Raise ValueError unless ``bet_scheme`` has a betting rule."""
        if self.bet_scheme not in ('aGRAPA', 'ONS', 'mixing'):
            raise ValueError(
                'Unknown betting scheme: {!r}'.format(self.bet_scheme))

    def _init_payoff_obj(self):
        """Create and configure the payoff object selected by ``pred_model``."""
        if self.pred_model == 'LDA':
            self.payoff_obj = LDAPayoff()
            self.payoff_obj.bet_strategy = self.payoff_strategy
        elif self.pred_model == 'oracle LDA':
            self.payoff_obj = OracleLDAPayoff()
            self.payoff_obj.true_beta = self.oracle_beta
        elif self.pred_model == 'RF':
            self.payoff_obj = RFPayoff()
        elif self.pred_model == 'MLP':
            self.payoff_obj = MLPPayoff()
            self.payoff_obj.bet_strategy = self.payoff_strategy
        elif self.pred_model == 'Ridge':
            self.payoff_obj = RidgeRegressorPayoff()
            self.payoff_obj.scaling_scheme = self.scaling_strategy
        elif self.pred_model == 'Oracle Reg':
            self.payoff_obj = OracleRegressorPayoff()
            self.payoff_obj.scaling_scheme = self.scaling_strategy
            self.payoff_obj.beta = self.oracle_beta
        elif self.pred_model == 'kNN':
            self.payoff_obj = KNNPayoff()
            self.payoff_obj.proc_type = self.knn_comp
            self.payoff_obj.bet_strategy = self.payoff_strategy
            self.payoff_obj.regularized = self.knn_reg
        # An unrecognised model name is only acceptable when the caller has
        # supplied a payoff object directly; otherwise fail with a clear error
        # instead of an AttributeError on None further down.
        if self.payoff_obj is None:
            raise ValueError(
                'Unknown predictive model: {!r}'.format(self.pred_model))

    def _bet_on_payoff(self, cand_payoff):
        """
        Scale ``cand_payoff`` by the current betting fraction and update the
        betting state of the active scheme.

        Returns the scaled payoff for 'aGRAPA' / 'ONS'; for 'mixing' the raw
        payoff is returned because the grid fractions are applied when the
        wealth is updated.
        """
        if self.bet_scheme == 'aGRAPA':
            if self.num_proc_pairs == 1:
                # starting values for the running moments
                self.run_mean = [1e-3]
                self.run_second_moment = [1]
            # fraction = mean / second moment of past payoffs, clipped to
            # [0, truncation_level]; computed before the new payoff is added
            self.opt_lmbd = min(max(np.mean(
                self.run_mean)/np.mean(self.run_second_moment), 0), self.truncation_level)
            payoff_fn = self.opt_lmbd * cand_payoff
            self.run_mean += [cand_payoff]
            self.run_second_moment += [cand_payoff**2]
            return payoff_fn
        if self.bet_scheme == 'ONS':
            if self.num_proc_pairs != 1:
                # gradient step based on the previous payoff (run_mean),
                # clipped to [0, truncation_level]
                grad = self.run_mean/(1+self.run_mean*self.opt_lmbd)
                self.grad_sq_sum += grad**2
                self.opt_lmbd = max(0, min(
                    self.truncation_level, self.opt_lmbd + 2/(2-np.log(3))*grad/self.grad_sq_sum))
            payoff_fn = self.opt_lmbd * cand_payoff
            # remember the current payoff for the next update
            self.run_mean = np.copy(cand_payoff)
            return payoff_fn
        if self.bet_scheme == 'mixing':
            return cand_payoff
        raise ValueError(
            'Unknown betting scheme: {!r}'.format(self.bet_scheme))

    def compute_predictive_payoff(self, next_pair_x, next_pair_y):
        """
        Evaluate the predictive payoff on the next pair and scale it by the
        betting fraction.

        On the first pair the payoff object (and, for 'mixing', the grid) is
        created. For 'aGRAPA' / 'ONS' the scaled payoff and the fraction are
        appended to ``payoff_hist`` / ``lmbd_hist``.

        Returns the scaled payoff ('aGRAPA', 'ONS') or the raw payoff
        ('mixing').

        Raises ValueError for an unsupported ``bet_scheme`` or for an unknown
        ``pred_model`` when no payoff object is available.
        """
        # fail before any model / betting state is touched
        self._check_bet_scheme()
        if self.num_proc_pairs == 1:
            self._init_payoff_obj()
            if self.bet_scheme == 'mixing':
                self.initialize_mixture_method()
        cand_payoff = self.payoff_obj.evaluate_payoff(next_pair_x, next_pair_y)
        payoff_fn = self._bet_on_payoff(cand_payoff)
        if self.bet_scheme != 'mixing':
            self.payoff_hist += [payoff_fn]
            self.lmbd_hist += [self.opt_lmbd]

        self.num_proc_pairs += 1

        return payoff_fn

    def compute_hsic_payoff(self, next_pair_x, next_pair_y, prev_data_x, prev_data_y):
        """
        Evaluate the HSIC witness payoff on the next pair and scale it by the
        betting fraction.

        The witness function is evaluated on the two original points and on
        the two points obtained by swapping the y-components; the candidate
        payoff is half the difference of the two sums. Afterwards the
        normalization constant of the witness function is updated.

        For 'aGRAPA' / 'ONS' ``lmbd_hist`` holds only the latest fraction
        (it is overwritten, not appended to) and ``payoff_hist`` is untouched.

        Returns the scaled payoff ('aGRAPA', 'ONS') or the raw payoff
        ('mixing'). Raises ValueError for an unsupported ``bet_scheme``.
        """
        # fail before any witness / betting state is touched
        self._check_bet_scheme()
        if self.wf is None:
            self.wf = HSICWitness()
            self.wf.kernel_type = self.kernel_type
            self.wf.kernel_param_x = self.kernel_param_x
            self.wf.kernel_param_y = self.kernel_param_y
            self.wf.initialize_norm_const(prev_data_x, prev_data_y)
            if self.bet_scheme == 'mixing':
                self.initialize_mixture_method()

        # original pairing: (x2, y2) and (x1, y1)
        w1 = self.wf.evaluate_wf(
            next_pair_x[1:2], next_pair_y[1:2], prev_data_x, prev_data_y)
        w2 = self.wf.evaluate_wf(
            next_pair_x[0:1], next_pair_y[0:1], prev_data_x, prev_data_y)
        # swapped pairing: (x1, y2) and (x2, y1)
        w3 = self.wf.evaluate_wf(
            next_pair_x[0:1], next_pair_y[1:2], prev_data_x, prev_data_y)
        w4 = self.wf.evaluate_wf(
            next_pair_x[1:2], next_pair_y[0:1], prev_data_x, prev_data_y)

        cand_payoff = 1/2 * (w1+w2-w3-w4)
        payoff_fn = self._bet_on_payoff(cand_payoff)
        if self.bet_scheme != 'mixing':
            self.lmbd_hist = [self.opt_lmbd]
        self.num_proc_pairs += 1
        # update normalization constant
        self.wf.update_norm_const(
            next_pair_x, next_pair_y, prev_data_x, prev_data_y)
        return payoff_fn

    def _update_wealth(self, payoff_fn):
        """Multiply the payoff into the wealth and update the rejection flag."""
        if self.bet_scheme == 'aGRAPA' or self.bet_scheme == 'ONS':
            cand_wealth = self.wealth * (1+payoff_fn)
            if cand_wealth >= 0 and self.wealth_flag is False:
                self.wealth = cand_wealth
                if self.wealth >= 1/self.significance_level:
                    self.null_rejected = True
            else:
                # negative wealth: freeze the process from now on
                self.wealth_flag = True
        elif self.bet_scheme == 'mixing':
            # update wealth for each value of lmbd
            cand_wealth = [self.wealth[cur_ind] * (1+cur_lmbd*payoff_fn)
                           for cur_ind, cur_lmbd in enumerate(self.grid_of_lmbd)]
            for cur_ind in range(self.lmbd_grid_size):
                if cand_wealth[cur_ind] >= 0 and self.wealth_flag[cur_ind] is False:
                    self.wealth[cur_ind] = cand_wealth[cur_ind]
                else:
                    # one failing grid point zeroes the whole wealth vector
                    self.wealth_flag[cur_ind] = True
                    self.wealth = [0 for i in range(self.lmbd_grid_size)]
                    break
            # uniform average over the grid
            self.mixed_wealth = np.mean(self.wealth)
            if self.mixed_wealth >= 1/self.significance_level:
                self.null_rejected = True

    def process_pair(self, next_pair_x, next_pair_y, prev_data_x=None, prev_data_y=None):
        """
        Function to call to process next pair of datapoints:
        computes the payoff for the configured ``payoff_style`` and updates
        the wealth process.

        ``prev_data_x`` / ``prev_data_y`` are only used by the 'HSIC' payoff.
        Nothing is returned; results are available through ``wealth``,
        ``mixed_wealth`` and ``null_rejected``.

        Raises ValueError for an unknown ``payoff_style`` or an unsupported
        ``bet_scheme``.
        """
        if self.payoff_style == 'classification' or self.payoff_style == 'regression':
            payoff_fn = self.compute_predictive_payoff(
                next_pair_x, next_pair_y)
        elif self.payoff_style == 'HSIC':
            payoff_fn = self.compute_hsic_payoff(
                next_pair_x, next_pair_y, prev_data_x, prev_data_y)
        else:
            raise ValueError(
                'Unknown version of payoff function')
        # update wealth process value
        self._update_wealth(payoff_fn)


class Seq_C_2ST(object):
    """
    Sequential tester (two-sample testing, per the ``_2ST`` payoff classes it
    uses) based on a betting / wealth process.

    Each call to :meth:`process_pair` turns the next pair of observations into
    a payoff, scales it by a betting fraction and multiplies it into the
    wealth. ``null_rejected`` is set once the wealth reaches
    ``1 / significance_level``.

    Only the 'fixed' and 'ONS' betting schemes have a betting rule in this
    class. Configuration is done by assigning the public attributes after
    construction; see the comments in ``__init__``.
    """

    def __init__(self):
        # payoff family accepted by process_pair: 'classification',
        # 'regression' or 'kernel'; default: 'classification'
        self.payoff_style = 'classification'
        # payoff model: 'kNN', 'LDA', 'CNN' or 'MMD'
        self.pred_model = 'LDA'
        # payoff object, created when the first pair is processed
        self.payoff_obj = None
        # the attributes below are not read by this class's methods
        self.wf = None
        self.kernel_type = 'rbf'
        self.kernel_param_x = 1
        self.kernel_param_y = 1
        # betting scheme: 'ONS' or 'fixed'
        self.bet_scheme = 'ONS'
        # options forwarded to the kNN payoff object
        self.payoff_strategy = 'accuracy'
        self.scaling_strategy = 'second'
        self.knn_comp = 'old'
        self.knn_reg = False
        # wealth process value
        self.wealth = 1
        # set once a candidate wealth is negative; wealth is frozen afterwards
        self.wealth_flag = False
        # payoff history; see the note in process_pair about its layout
        self.payoff_hist = list()
        # 1-based index of the pair being processed
        self.num_proc_pairs = 1
        self.mixed_wealth = None
        # for testing
        self.significance_level = 0.05
        self.null_rejected = False
        # under 'ONS' run_mean holds the previous payoff
        self.run_mean = 0
        self.run_second_moment = 0
        # current betting fraction and accumulated squared gradients (ONS)
        self.opt_lmbd = 0
        self.ons = True
        self.grad_sq_sum = 1
        self.lmbd_hist = list()
        # upper bound for the ONS betting fraction
        self.truncation_level = 0.5
        self.oracle_beta = 0
        # mixture
        self.grid_of_lmbd = None
        self.lmbd_grid_size = 19
        # options forwarded to the LDA payoff object
        self.lda_mean_pos = 0
        self.lda_mean_neg = 0
        self.lda_oracle = False

    def compute_predictive_payoff(self, next_Z, next_W):
        """
        Evaluate the payoff on the next pair and scale it by the betting
        fraction.

        On the first pair the payoff object selected by ``pred_model`` is
        created. The raw payoff is appended to ``payoff_hist``.

        Returns the payoff multiplied by the current betting fraction
        (constant 1 for 'fixed', adaptively updated for 'ONS').

        Raises ValueError if ``bet_scheme`` is not 'fixed' or 'ONS', or if
        ``pred_model`` is unknown and no payoff object is available.
        """
        # Only these two schemes have a betting rule below. Fail before the
        # payoff model sees the data rather than with an UnboundLocalError
        # at the return statement.
        if self.bet_scheme not in ('fixed', 'ONS'):
            raise ValueError(
                'Unsupported betting scheme: {!r}'.format(self.bet_scheme))
        first_pair = self.num_proc_pairs == 1
        if first_pair:
            if self.pred_model == 'kNN':
                self.payoff_obj = KNNPayoff_2ST()
                self.payoff_obj.proc_type = self.knn_comp
                self.payoff_obj.bet_strategy = self.payoff_strategy
                self.payoff_obj.regularized = self.knn_reg
            elif self.pred_model == 'LDA':
                self.payoff_obj = LDA_2ST()
                self.payoff_obj.mean_pos = self.lda_mean_pos
                self.payoff_obj.mean_neg = self.lda_mean_neg
                self.payoff_obj.oracle = self.lda_oracle
            elif self.pred_model == 'CNN':
                self.payoff_obj = CNNPayoff_2ST()
            elif self.pred_model == 'MMD':
                self.payoff_obj = MMDPayoff_2ST()
            # An unrecognised model name is only acceptable when the caller
            # has supplied a payoff object directly.
            if self.payoff_obj is None:
                raise ValueError(
                    'Unknown predictive model: {!r}'.format(self.pred_model))
            if self.bet_scheme == 'fixed':
                self.opt_lmbd = 1
        cand_payoff = self.payoff_obj.evaluate_payoff(next_Z, next_W)
        self.payoff_hist += [cand_payoff]
        if self.bet_scheme == 'ONS' and not first_pair:
            # gradient step based on the previous payoff (run_mean),
            # clipped to [0, truncation_level]
            grad = self.run_mean/(1+self.run_mean*self.opt_lmbd)
            self.grad_sq_sum += grad**2
            self.opt_lmbd = max(0, min(
                self.truncation_level, self.opt_lmbd + 2/(2-np.log(3))*grad/self.grad_sq_sum))
        payoff_fn = self.opt_lmbd * cand_payoff
        if self.bet_scheme == 'ONS':
            # remember the current payoff for the next update
            self.run_mean = np.copy(cand_payoff)

        self.num_proc_pairs += 1

        return payoff_fn

    def process_pair(self, next_Z, next_W, prev_data_x=None, prev_data_y=None):
        """
        Function to call to process next pair of datapoints:
        computes the scaled payoff and updates the wealth process.

        ``prev_data_x`` / ``prev_data_y`` are accepted but not used.
        Nothing is returned; results are available through ``wealth`` and
        ``null_rejected``.

        Raises ValueError for an unknown ``payoff_style`` or an unsupported
        ``bet_scheme``.
        """
        if self.payoff_style == 'classification' or self.payoff_style == 'regression' or self.payoff_style == 'kernel':
            payoff_fn = self.compute_predictive_payoff(next_Z, next_W)
            # NOTE(review): compute_predictive_payoff has already appended the
            # raw payoff, so payoff_hist holds two entries per pair
            # (raw, scaled) - confirm this layout is what consumers expect.
            self.payoff_hist += [payoff_fn]
        else:
            raise ValueError(
                'Unknown version of payoff function')
        # update wealth process value

        if self.bet_scheme == 'aGRAPA' or self.bet_scheme == 'ONS' or self.bet_scheme == 'fixed':
            cand_wealth = self.wealth * (1+payoff_fn)
            if cand_wealth >= 0 and self.wealth_flag is False:
                self.wealth = cand_wealth
                if self.wealth >= 1/self.significance_level:
                    self.null_rejected = True
            else:
                # negative wealth: freeze the process from now on
                self.wealth_flag = True
        elif self.bet_scheme == 'mixing':
            # NOTE(review): compute_predictive_payoff rejects 'mixing' (and
            # 'aGRAPA'), and nothing in this class builds grid_of_lmbd, so
            # this branch is only reachable if a subclass overrides the
            # payoff computation and initialises the grid itself.
            # update wealth for each value of lmbd
            cand_wealth = [self.wealth[cur_ind] * (1+cur_lmbd*payoff_fn)
                           for cur_ind, cur_lmbd in enumerate(self.grid_of_lmbd)]
            for cur_ind in range(self.lmbd_grid_size):
                if cand_wealth[cur_ind] >= 0 and self.wealth_flag[cur_ind] is False:
                    self.wealth[cur_ind] = cand_wealth[cur_ind]
                else:
                    # one failing grid point zeroes the whole wealth vector
                    self.wealth_flag[cur_ind] = True
                    self.wealth = [0 for i in range(self.lmbd_grid_size)]
                    break
            self.mixed_wealth = np.mean(self.wealth)
            if self.mixed_wealth >= 1/self.significance_level:
                self.null_rejected = True
