import os
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, Subset


class DenseDatasetSelected(Dataset):
    '''
    Tabular dataset read from a single CSV file with an 'Outcome' label column.

    Args:
      data_dir: path to the CSV file ('~' is expanded).
      feature_list: columns to keep as inputs, in that order. When None,
        every column other than 'Outcome' is used.

    Attributes:
      data_dir: the expanded path that was read.
      features: the input columns that were kept.
      X: inputs converted to a float32 numpy array.
      Y: the 'Outcome' column converted to an int64 numpy array.
    '''

    def __init__(self, data_dir, feature_list=None):
        super().__init__()

        # Read the raw table.
        self.data_dir = os.path.expanduser(data_dir)
        frame = pd.read_csv(self.data_dir)

        # Decide which columns are used as inputs.
        if feature_list is None:
            self.features = [col for col in frame.columns if col != 'Outcome']
        else:
            self.features = feature_list

        # Inputs never include the label column, even if it was requested.
        inputs = frame.drop(['Outcome'], axis=1)
        self.X = np.array(inputs[self.features]).astype('float32')
        self.Y = np.array(frame['Outcome']).astype('int64')

    def __len__(self):
        '''Number of rows in the table.'''
        return len(self.X)

    def __getitem__(self, index):
        '''Return the (inputs, label) pair stored at index.'''
        return self.X[index], self.Y[index]
    

class Flatten:
    '''Callable transform that collapses its input into a 1-D tensor.'''

    def __call__(self, pic):
        # Flatten across every dimension (these are torch.flatten's defaults).
        return torch.flatten(pic, start_dim=0, end_dim=-1)
    
    
class ColumnSelector(nn.Module):
    '''
    Module that keeps only the chosen positions along dimension 1.

    Args:
      inds: index object used to select along the second dimension of the
        input (stored as given on ``self.inds``).
    '''

    def __init__(self, inds):
        super().__init__()
        self.inds = inds

    def forward(self, x):
        '''Return x restricted to ``self.inds`` along dimension 1.'''
        kept = self.inds
        return x[:, kept]


class DenseDataset(Dataset):
    '''
    Named tabular dataset loaded from a file inside ``data_dir``.

    Supported names are 'fluid', 'intub', 'fib', 'spam' and 'diabetes'.

    Args:
      data_dir: directory holding the data files ('~' is expanded). A
        trailing path separator is optional.
      dataset: name of the dataset to load.

    Attributes:
      data_dir: the expanded directory.
      X: inputs. For the CSV-backed datasets this is a float32 numpy array.
      Y: labels. For the CSV-backed datasets this is an int64 numpy array.
      feature_groups: dict {starting index: feature count of the group}
        describing groups of adjacent features (all other features are
        singletons), or None when the dataset defines no groups.

    Raises:
      ValueError: if ``dataset`` is not one of the supported names.
    '''

    def __init__(self, data_dir, dataset):
        super().__init__()

        self.data_dir = os.path.expanduser(data_dir)

        # Groups below are written as {starting index: feature count of the
        # group}; every feature not covered by a group is a singleton.
        if dataset == 'fluid':
            # NOTE(review): this file carries a '.pkl' extension but has
            # always been parsed with read_csv; that behavior is preserved
            # here -- confirm the on-disk format.
            self._load_csv(
                'fluid.pkl',
                label_col='responder',
                drop_cols=['registryid', 'responder', 'outcome'],
            )
            self.feature_groups = {16: 24, 40: 3, 43: 2, 45: 24, 69: 3, 212: 12}

        elif dataset == 'intub':
            self._load_csv(
                'intub.csv',
                label_col='edfirstrespirationassistedd',
                drop_cols=['edfirstrespirationassistedd'],
            )
            self.feature_groups = {0: 4, 6: 2, 8: 2, 11: 2, 13: 7, 20: 32, 52: 3, 55: 16, 71: 16}

        elif dataset == 'fib':
            self._load_csv('fib.csv', label_col='fib_val', drop_cols=['fib_val'])
            self.feature_groups = {0: 4, 6: 2, 8: 2, 11: 2, 13: 7, 20: 32, 52: 3, 55: 16, 71: 16}

        elif dataset == 'spam':
            self._load_csv('spam.csv', label_col='Class', drop_cols=['Class'])
            self.feature_groups = None

        # TODO: MNIST should be loaded using PyTorch.

        elif dataset == 'diabetes':
            # SECURITY: unpickling executes arbitrary code embedded in the
            # file; only load pickles from a trusted source.
            data = pd.read_pickle(os.path.join(self.data_dir, 'diabetes.pkl'))
            # The unpickled object is expected to expose .features/.targets;
            # they are stored as-is, without any dtype conversion.
            self.X = data.features
            self.Y = data.targets
            self.feature_groups = None

        else:
            raise ValueError(f'unrecognized dataset: {dataset}')

    def _load_csv(self, file_name, label_col, drop_cols):
        '''Read a CSV from data_dir and set X (float32) and Y (int64).'''
        # os.path.join works with or without a trailing separator on data_dir.
        data = pd.read_csv(os.path.join(self.data_dir, file_name))
        self.X = np.array(data.drop(drop_cols, axis=1)).astype('float32')
        self.Y = np.array(data[label_col]).astype('int64')

    def __len__(self):
        '''Number of samples.'''
        return len(self.X)

    def __getitem__(self, index):
        '''Return the (inputs, label) pair stored at index.'''
        return self.X[index], self.Y[index]

    def __feature_groups__(self):
        '''Return the feature-group dict, or None if there are no groups.'''
        return self.feature_groups
        
        
def data_split(dataset, val_portion=0.2, test_portion=0.2, random_state=0):
    '''
    Randomly partition a dataset into train, validation and test subsets.

    Args:
      dataset: object supporting len() and integer indexing.
      val_portion: fraction of samples for validation (count is truncated).
      test_portion: fraction of samples for testing (count is truncated).
      random_state: seed for the numpy random generator driving the shuffle.

    Returns:
      Tuple (train_dataset, val_dataset, test_dataset) of Subset objects.
      The first shuffled indices go to test, the next to validation, and
      whatever remains to train.
    '''
    total = len(dataset)

    # Seeded in-place shuffle of all sample indices.
    order = np.arange(total)
    np.random.default_rng(random_state).shuffle(order)

    # Cut points inside the shuffled order: [test | val | train].
    test_end = int(test_portion * total)
    val_end = test_end + int(val_portion * total)

    return (
        Subset(dataset, order[val_end:]),
        Subset(dataset, order[test_end:val_end]),
        Subset(dataset, order[:test_end]),
    )

def get_groups_dict_mask(feature_groups, num_feature):
    '''
    Expand a compact group description into per-group index lists and a mask.

    Args:
      feature_groups: dict {starting index: feature count of the group}.
        Features not covered by any group become singleton groups. None is
        treated as "no groups", i.e. every feature is a singleton.
      num_feature: total number of features.

    Returns:
      Tuple (feature_groups_dict, feature_groups_mask):
        feature_groups_dict maps group number (0, 1, ...) to the list of
          feature indices in that group, in increasing feature order.
        feature_groups_mask is a (num_feature, number of groups) float array
          with 1 where the feature belongs to the group and 0 elsewhere.

    Raises:
      ValueError: if a group that is reached has a size smaller than 1
        (the walk over the features could never advance past it).
      IndexError: if a group extends beyond num_feature.
    '''
    # Datasets without grouping store None; treat that as all singletons.
    group_sizes = {} if feature_groups is None else feature_groups

    feature_groups_dict = {}
    start = 0
    while start < num_feature:
        # Dict lookup instead of scanning a list of group starts.
        size = group_sizes[start] if start in group_sizes else 1
        if size < 1:
            raise ValueError(
                f'feature group starting at {start} has invalid size: {size}')
        feature_groups_dict[len(feature_groups_dict)] = list(range(start, start + size))
        # Skip past the whole group so its members are not revisited.
        start += size

    feature_groups_mask = np.zeros((num_feature, len(feature_groups_dict)))
    for group_id, members in feature_groups_dict.items():
        feature_groups_mask[members, group_id] = 1
    return feature_groups_dict, feature_groups_mask
    
def get_xy(dataset):
    '''
    Collect every (x, y) pair of a dataset into two stacked containers.

    The container type follows the first x: numpy arrays give numpy arrays;
    torch tensors give tensors, with labels built by torch.tensor when the
    first label is a Python int/float and by torch.stack otherwise.

    Raises:
      ValueError: if the first x is neither a numpy array nor a torch tensor.
    '''
    xs, ys = zip(*dataset)
    first = xs[0]

    if isinstance(first, np.ndarray):
        return np.array(xs), np.array(ys)

    if not isinstance(first, torch.Tensor):
        raise ValueError(f'not sure how to concatenate data type: {type(first)}')

    inputs = torch.stack(xs)
    scalar_labels = isinstance(ys[0], (int, float))
    labels = torch.tensor(ys) if scalar_labels else torch.stack(ys)
    return inputs, labels


def get_x(dataset):
    '''
    Collect the inputs of every (x, y) pair in a dataset into one container.

    Numpy inputs are combined with np.array, torch inputs with torch.stack.

    Raises:
      ValueError: if the first x is neither a numpy array nor a torch tensor.
    '''
    xs, _ = zip(*dataset)
    first = xs[0]

    if isinstance(first, np.ndarray):
        return np.array(xs)
    if isinstance(first, torch.Tensor):
        return torch.stack(xs)
    raise ValueError(f'not sure how to concatenate data type: {type(first)}')


def get_y(dataset):
    '''
    Collect the labels of every (x, y) pair in a dataset into one container.

    The container type is chosen from the first x (not the first y): numpy
    inputs give a numpy label array; torch inputs give a tensor, built by
    torch.tensor when the first label is a Python int/float and by
    torch.stack otherwise.

    Raises:
      ValueError: if the first x is neither a numpy array nor a torch tensor.
    '''
    xs, ys = zip(*dataset)
    first = xs[0]

    if isinstance(first, np.ndarray):
        return np.array(ys)

    if not isinstance(first, torch.Tensor):
        raise ValueError(f'not sure how to concatenate data type: {type(first)}')

    if isinstance(ys[0], (int, float)):
        return torch.tensor(ys)
    return torch.stack(ys)
