import os
import pickle
import random

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

from .util import TabDataset


# NOTE(review): these two tables look like the schema of the UCI Adult
# (census income) dataset rather than the cover-type data handled below, and
# nothing in the visible part of this module reads them -- confirm whether
# other modules import them before removing.
col_names = [
    'age',
    'workclass',
    'fnlwgt',
    'education',
    'education-num',
    'marital-status',
    'occupation',
    'relationship',
    'race',
    'sex',
    'capital-gain',
    'capital-loss',
    'hours-per-week',
    'native-country',
    'label',
]

# Column name -> dtype.  Note that 'race' appears in ``col_names`` but has no
# entry here.
dtype_dict = {
    'age': np.float32,
    'workclass': 'string',
    'fnlwgt': np.float32,
    'education': 'string',
    'education-num': np.float32,
    'marital-status': 'string',
    'occupation': 'string',
    'relationship': 'string',
    'sex': 'string',
    'capital-gain': np.float32,
    'capital-loss': np.float32,
    'hours-per-week': np.float32,
    'native-country': 'string',
    'label': 'object',
}


class ContinualCoverData:
    """Cover-type table that is handed out task by task.

    The constructor loads ``proc_covtype.data`` from ``root``, shifts the
    ``cover_type`` label down by one, divides every continuous column by its
    own maximum, prefixes the two categorical columns with their column name,
    and fixes a random order (seed 1234) over the distinct values of the
    incremental column.  ``get_dataset`` then returns the rows whose
    incremental-column value falls into one contiguous slice of that order.
    """

    def __init__(self, root, model_config=None, env_config=None):
        """Load and preprocess the table.

        Args:
            root: Directory holding ``proc_covtype.data``.  A trailing path
                separator is optional.
            model_config: Stored on the instance unchanged; not read here.
            env_config: Object exposing ``incre_col``, the name of the column
                whose categories define the tasks.  Required in practice even
                though the default is ``None``.

        Raises:
            AttributeError: If ``env_config`` has no ``incre_col`` attribute
                (which includes ``env_config=None``).
            ValueError: If ``incre_col`` is not one of ``self.columns``.
        """
        self.model_config = model_config
        self.env_config = env_config
        self.root = root
        self.rng = np.random.RandomState(1234)

        # Join instead of concatenating so that ``root`` works with or without
        # a trailing separator.
        data_path = os.path.join(self.root, 'proc_covtype.data')
        df = pd.read_csv(data_path, header=0, sep=',', engine='python')

        # Shift the labels down by one.
        df['cover_type'] = df['cover_type'] - 1

        # Scale every continuous column by its own maximum.
        continuous_cols = ['elevation', 'aspect', 'slope', 'horizontal_distance_to_hydrology', 'vertical_distance_to_hydrology', 'horizontal_distance_to_roadways', 'hillshade_9am', 'hillshade_noon', 'hillshade_3pm', 'horizontal_distance_to_fire_points']
        for col in continuous_cols:
            df[col] = df[col] / df[col].max()

        # Prefix categorical values with their column name so that equal raw
        # values from different columns become distinct tokens.
        cate_cols = ['wilderness_area', 'soil_type']
        for col_name in cate_cols:
            df[col_name] = col_name + '_' + df[col_name].astype(str)

        # Fixed column order: continuous, then categorical, then the label.
        self.columns = continuous_cols + cate_cols + ['cover_type']
        # Keep an independent copy so that callers who modify ``self.df`` work
        # on a regular frame instead of a slice derived from the raw table.
        self.df = df[self.columns].copy()

        # Column whose categories are revealed incrementally across tasks.
        self.incre_col = self.env_config.incre_col
        self.incre_col_idx = self.columns.index(self.incre_col)

        # Seeded random order of the incremental column's distinct values;
        # tasks are contiguous slices of this list.
        self.dicts = list(self.rng.permutation(df[self.incre_col].unique()))

    def get_dataset(self, task_id=0, task_num=1):
        """Return the train/validation/test rows of one task.

        The ordered categories in ``self.dicts`` are cut into ``task_num``
        consecutive groups of ``len(self.dicts) // task_num`` entries; the
        last task additionally takes whatever remains.  If ``task_num``
        exceeds the number of categories, a message is printed and every task
        gets exactly one category (task ids beyond the number of categories
        then select no rows).

        Args:
            task_id: Zero-based index of the requested task.
            task_num: Total number of tasks.

        Returns:
            ``(dat_tr, dat_val, dat_te)`` as NumPy arrays with the columns of
            ``self.columns``.  The first ``ceil(2/3 * n)`` selected rows form
            the training split; the remaining rows are returned as both the
            validation and the test split.
        """
        assert task_id < task_num

        n_categories = len(self.dicts)
        task_size = n_categories // task_num
        if task_size < 1:
            print('too large task_num, use len(self.dicts)')
            task_size = 1
            task_num = n_categories

        start = task_id * task_size
        # The last task absorbs the categories left over by the division.
        end = n_categories if task_id == task_num - 1 else start + task_size

        task_cate = self.dicts[start:end]
        print('task categories:', task_cate)

        # Rows keep their order in the table; nothing is shuffled before the
        # split below.
        data = self.df[self.df[self.incre_col].isin(task_cate)].to_numpy()

        bd = int(np.ceil(len(data) * 2 / 3))
        dat_tr, dat_val, dat_te = data[:bd], data[bd:], data[bd:]

        return dat_tr, dat_val, dat_te
    
    
class CoverDataset(TabDataset):
    """``TabDataset`` specialised to the cover-type column layout.

    Column indices 0-9 are passed as continuous and 10-11 as discrete, with
    the incremental column taken out of the discrete list and passed on its
    own.
    """

    def __init__(self, x, y, incre_col_idx):
        """Build the column index lists and delegate to ``TabDataset``.

        Args:
            x: Forwarded unchanged to ``TabDataset``.
            y: Forwarded unchanged to ``TabDataset``.
            incre_col_idx: Index of the incremental column; it must be 10 or
                11, otherwise ``list.remove`` raises ``ValueError``.
        """
        numeric_idx = [*range(10)]
        categorical_idx = [10, 11]
        # Raises ValueError when the incremental column is not categorical.
        categorical_idx.remove(incre_col_idx)
        super().__init__(x, y, numeric_idx, categorical_idx, [incre_col_idx])


class EnvConfig:
    """Minimal environment configuration used by ``main`` in this module."""

    # Name of the column whose categories define the incremental tasks.
    incre_col = 'wilderness_area'

def main():
    """Fit a logistic regression on the cover-type table and print feature scores.

    The categorical columns are one-hot encoded, a class-balanced multinomial
    logistic regression is fitted on all rows, and one line
    ``<feature> <score>`` is printed per input feature.  The score is the mean
    absolute coefficient of that feature across all classes.
    """
    env_config = EnvConfig()
    db = ContinualCoverData('../data/tabular_data/covertype/', env_config=env_config)
    df = db.df
    print(df)

    discrete_col = ['wilderness_area', 'soil_type']
    dummy_ls = [pd.get_dummies(df[col]) for col in discrete_col]
    # Build new frames instead of dropping in place: ``df`` is the loader's
    # own table and must not be altered by this analysis.
    df = pd.concat([df.drop(discrete_col, axis=1), *dummy_ls], axis=1)
    print(df.columns)

    labels = pd.to_numeric(df['cover_type'])
    df = df.drop(['cover_type'], axis=1).apply(pd.to_numeric)

    from sklearn.linear_model import LogisticRegression
    logmodel = LogisticRegression(multi_class='multinomial', class_weight='balanced', max_iter=500, random_state=0).fit(df, labels)

    # ``coef_`` holds one row of coefficients per class.  Flattening it and
    # zipping with the feature names would only ever pair the first class's
    # row, so aggregate over classes to get one score per feature.
    coef = np.atleast_2d(logmodel.coef_)
    importance = np.abs(coef).mean(axis=0)

    # feature importance
    for name, score in zip(df.columns, importance):
        print(name, score)

# Script entry point.
# NOTE(review): this module uses a relative import (``from .util import
# TabDataset``), so it presumably has to be launched as a module of its
# package (``python -m <package>.<module>``) rather than as a bare file --
# confirm.
if __name__ == '__main__':
    main()