import numpy as np
from sklearn.datasets import make_regression
from copy import deepcopy

class general_attack_regression:
    '''
        Poisoning attack for an n by p feature matrix and an n by q target matrix.

        A subset of training samples is picked at random, and a random subset
        of their feature columns (plus all of their targets) is overwritten
        with fixed poison values derived from the clean training data.

        Exactly one of ``ratio_attack`` / ``n_attack`` is needed; when both are
        given, ``n_attack`` takes precedence.
    '''
    def __init__(self, ratio_attack=None, n_attack=None):
        '''
            ratio_attack: fraction of the training samples to poison.
            n_attack: absolute number of training samples to poison.

            Raises Exception when neither argument is provided.
        '''
        if ratio_attack is None and n_attack is None:
            raise Exception("both attack ratio and number of poisoned sampled are not init")
        self.ratio_attack = ratio_attack
        self.n_attack = n_attack

    def attack(self, x_clean_trn, y_clean_trn, x_clean_tst, y_clean_tst, target, n_feature_attack=10, attack_method='max'):
        '''
            Build poisoned training samples and a poisoned copy of the test features.

            x_clean_trn: clean training features, indexed as a 2-D array (n, p).
            y_clean_trn: clean training targets; reduced with ``max(axis=0)``.
                A 1-D target therefore yields a single-column poisoned target.
            x_clean_tst: clean test features; copied, never modified in place.
            y_clean_tst: unused, kept for interface compatibility.
            target: unused, kept for interface compatibility.
            n_feature_attack: number of feature columns to poison.
            attack_method: only 'max' is implemented. It sets the attacked
                columns to their per-column maximum over the clean training
                set, and every poisoned target to the per-column target maximum.

            Returns (x_poison_trn, y_poison_trn, x_poison_tst, attacked_feature_idx):
                x_poison_trn: the selected training rows with attacked columns overwritten.
                y_poison_trn: array of shape (n_attack, q) filled with the poison target.
                x_poison_tst: copy of x_clean_tst with attacked columns overwritten.
                attacked_feature_idx: indices of the poisoned feature columns.

            Raises Exception when n_feature_attack exceeds the number of
            features, and NotImplementedError for an unsupported attack_method.
        '''
        n_trn = x_clean_trn.shape[0]
        # Resolve the poisoned-sample count. Previously the explicit-count
        # branch never bound ``n_attack``, so it crashed further down.
        if self.n_attack is None:
            n_attack = int(n_trn * self.ratio_attack)
        else:
            n_attack = self.n_attack

        n_feature = x_clean_trn.shape[1]
        if n_feature_attack > n_feature:
            raise Exception("n_feature_attack cannot be larger than original feature size")

        # Reject unsupported methods before drawing any random numbers.
        if attack_method != 'max':
            raise NotImplementedError

        # Feature indices are drawn before sample indices; keep this order so
        # seeded runs stay reproducible.
        attacked_feature_idx = np.random.choice(n_feature, n_feature_attack, replace=False)
        attacked_sample_idx = np.random.choice(n_trn, n_attack, replace=False)

        # Integer-array indexing returns a copy, so the clean data is untouched.
        x_poison_trn = x_clean_trn[attacked_sample_idx, :]

        poison_x_value = (x_clean_trn[:, attacked_feature_idx]).max(axis=0)
        x_poison_trn[:, attacked_feature_idx] = poison_x_value

        # Every poisoned sample gets the same target row, so only the row
        # count is needed here.
        poison_y_value = (y_clean_trn).max(axis=0)
        y_poison_trn = np.tile(poison_y_value, (n_attack, 1))

        x_poison_tst = deepcopy(x_clean_tst)
        x_poison_tst[:, attacked_feature_idx] = poison_x_value
        return x_poison_trn, y_poison_trn, x_poison_tst, attacked_feature_idx


















