import numpy as np
import pandas as pd
import random
import itertools
from bilevel.utils import numeric_scaler

class SynthGenLinear:
    """Generate a synthetic regression dataset with group-specific labels.

    Each sample gets uniform features, one group membership per attribute of
    ``group_dict`` (drawn with the matching weights in ``prob_dict``), and one
    label per group built from optional linear and quadratic maps of the
    features plus Gaussian noise. The constructor runs the whole pipeline and
    leaves the result in ``self.df``.

    Keyword args (all required):
        samples: number of rows to generate.
        dim: number of features.
        group_dict: attribute name -> list of group names.
        prob_dict: attribute name -> list of sampling weights, one per group.
        feat_lo, feat_hi: bounds of the uniform feature distribution.
        add_linear_mapping: if truthy, add ``feat_dat @ wlin`` to the labels.
        w_lo, w_hi: bounds of the uniform linear weights ``wlin``.
        add_quad_mapping: if truthy, add ``x^T Smat[g] x`` to group g's label.
        S_lo, S_hi: bounds of the uniform quadratic-form matrices ``Smat``.
        label_noise_width: std. dev. of the additive Gaussian label noise.
        drop_sensitive: if truthy, omit the ``g_*`` group-indicator columns.
        fixed_seed: seed applied to the global numpy and ``random`` RNGs.
    """

    def __init__(self, **kwargs):
        self.samples = kwargs['samples']
        self.dim = kwargs['dim']
        self.group_dict = kwargs['group_dict']
        self.prob_dict = kwargs['prob_dict']
        # Flatten the per-attribute group lists into one ordered list of names.
        self.all_groupnames = list(itertools.chain.from_iterable(self.group_dict.values()))
        self.Ng = len(self.all_groupnames)
        self.feat_lo = kwargs['feat_lo']
        self.feat_hi = kwargs['feat_hi']
        self.add_linear_mapping = kwargs['add_linear_mapping']
        self.w_lo = kwargs['w_lo']
        self.w_hi = kwargs['w_hi']
        self.add_quad_mapping = kwargs['add_quad_mapping']
        self.S_lo = kwargs['S_lo']
        self.S_hi = kwargs['S_hi']
        self.label_noise_width = kwargs['label_noise_width']
        self.drop_sensitive = kwargs['drop_sensitive']
        self.fixed_seed = kwargs['fixed_seed']  # for reproducibility
        # NOTE: these seed the *global* RNGs, so they affect other code in the process too.
        np.random.seed(self.fixed_seed)  # global seed for all numpy randomness
        random.seed(self.fixed_seed)  # global seed for all python randomness

        # Order matters: each step reads attributes set by earlier steps and
        # consumes the seeded RNG streams in a fixed sequence.
        self.get_feat_uniform()
        self.get_A_t()
        self.get_labels()
        self.get_dataframe()
        self.aggregate_group_labels()

    def get_feat_uniform(self) -> np.ndarray:
        """Draw a ``(samples, dim)`` feature matrix from U(feat_lo, feat_hi)."""
        self.feat_dat = np.random.uniform(low=self.feat_lo, high=self.feat_hi, size=(self.samples, self.dim))
        return self.feat_dat

    def get_A_t(self) -> np.ndarray:
        """Sample one group per attribute for every sample, as one-hot columns.

        Returns a ``(samples, Ng)`` 0/1 matrix whose columns follow
        ``all_groupnames``. Each attribute contributes exactly one active
        column per row.

        Raises:
            ValueError: if ``prob_dict`` defines a different total number of
                groups than ``group_dict``.
        """
        def get_group_indicators(prob_list: list) -> np.ndarray:
            inds = np.eye(len(prob_list))  # indicators e.g for prob list of len 3, (1, 0, 0), (0, 1, 0), (0, 0, 1)
            return np.array(random.choices(population=inds, weights=prob_list, k=self.samples))

        # Column order must match all_groupnames, which follows group_dict.
        # When both dicts share the same keys, look the weights up by attribute
        # so that differing insertion orders cannot silently mislabel columns;
        # otherwise keep the historical positional pairing.
        if self.prob_dict.keys() == self.group_dict.keys():
            prob_lists = [self.prob_dict[attr] for attr in self.group_dict]
        else:
            prob_lists = list(self.prob_dict.values())
        self.A_t = np.hstack([get_group_indicators(probs) for probs in prob_lists])
        if self.A_t.shape[1] != self.Ng:
            raise ValueError(
                f"prob_dict defines {self.A_t.shape[1]} groups but group_dict defines {self.Ng}"
            )
        return self.A_t

    def get_labels(self) -> np.ndarray:
        """Build one noisy label per group for every sample.

        Returns a ``(samples, Ng)`` array made of the optional linear term
        ``feat_dat @ wlin``, the optional per-group quadratic form
        ``x^T Smat[g] x``, and i.i.d. Gaussian noise with std
        ``label_noise_width``.
        """
        def add_linear():
            self.wlin = np.random.uniform(low=self.w_lo, high=self.w_hi, size=(self.dim, self.Ng))
            self.labels_allg += np.matmul(self.feat_dat, self.wlin)

        def add_quad():
            self.Smat = np.random.uniform(low=self.S_lo, high=self.S_hi, size=(self.Ng, self.dim, self.dim))
            for g in range(self.Ng):
                # Row-wise quadratic form x_i^T S_g x_i.
                self.labels_allg[:, g] += (self.feat_dat.dot(self.Smat[g]) * self.feat_dat).sum(axis=1)

        self.labels_allg = np.zeros((self.samples, self.Ng))
        if self.add_linear_mapping:
            add_linear()
        if self.add_quad_mapping:
            add_quad()
        self.labels_allg += np.random.normal(scale=self.label_noise_width, size=(self.samples, self.Ng))  # adding gaussian noise
        return self.labels_allg

    def get_dataframe(self) -> pd.DataFrame:
        """Assemble features, group indicators (unless dropped) and labels into ``self.df``.

        Columns are ``x_<i>`` for features, ``g_<group>`` for indicators and
        ``y_<group>`` for per-group labels.
        """
        self.df_feat_names = ['x_' + str(i) for i in range(self.dim)]
        self.df_label_names = ['y_' + st for st in self.all_groupnames]
        if self.drop_sensitive:
            parts = (self.feat_dat, self.labels_allg)
            columns = self.df_feat_names + self.df_label_names
        else:
            self.group_ind = ['g_' + st for st in self.all_groupnames]
            parts = (self.feat_dat, self.A_t, self.labels_allg)
            columns = self.df_feat_names + self.group_ind + self.df_label_names
        self.df = pd.DataFrame(np.hstack(parts), columns=columns)
        return self.df

    @staticmethod
    def _label_columns(columns) -> list:
        """Return the label column names, i.e. those starting with ``'y_'``.

        A prefix test (not a substring test) keeps indicator columns such as
        ``'g_city_a'``, which contain ``'y_'`` mid-name, out of the label set.
        """
        return [col for col in columns if col.startswith('y_')]

    def aggregate_group_labels(self) -> None:
        """Add aggregate labels over each sample's active groups, then post-process.

        Adds ``y_mean_active``, ``y_min_active`` and ``y_max_active``
        (reductions over the labels of the groups each sample belongs to), and
        ``y_dperm_active``: the label of the first active group under a random
        "dominance" permutation of the groups. Finally, ``self.df`` is replaced
        by ``numeric_scaler(self.df, <all y_* columns>)``.
        """
        def set_dominance_permutation():
            self.dperm = np.random.permutation(self.Ng)

        # Mask inactive groups so that reductions see only each sample's own groups.
        self.masked_mult = np.ma.masked_array(self.A_t, mask=self.A_t == 0) * self.labels_allg
        self.mean_ar = np.ma.getdata(np.mean(self.masked_mult, axis=1))
        self.min_ar = np.ma.getdata(np.min(self.masked_mult, axis=1))
        self.max_ar = np.ma.getdata(np.max(self.masked_mult, axis=1))
        self.df['y_mean_active'] = self.mean_ar
        self.df['y_min_active'] = self.min_ar
        self.df['y_max_active'] = self.max_ar
        set_dominance_permutation()
        self.mm_dperm = self.masked_mult[:, self.dperm]  # masked multiplication permuted columns
        # getmaskarray (unlike getmask) always yields a full boolean array, even
        # when nothing is masked, so argmax(axis=1) is always well defined.
        active = ~np.ma.getmaskarray(self.mm_dperm)
        first_nomask_index = active.argmax(axis=1)  # first non-masked column per row is the dominant group
        self.df['y_dperm_active'] = self.mm_dperm[np.arange(self.samples), first_nomask_index]
        self.df_label_names = self._label_columns(self.df.columns)
        self.df = numeric_scaler(self.df, self.df_label_names)