import torch
import pandas as pd
import numpy as np
from ucimlrepo import fetch_ucirepo 


class BreastCancerDataset(torch.utils.data.Dataset):
    """Tabular dataset read from a CSV file that contains an ``id`` column.

    The ``id`` column is dropped and ``label_column`` is split off as the
    target. Every remaining column is z-score standardized using statistics
    computed over this CSV. Labels are mapped through ``class_dict`` when an
    item is fetched.

    NOTE(review): mean/std come from this instance's own data, so separate
    train/test CSVs are each normalized with their own statistics.

    Args:
        csv_file: Path or file-like object accepted by ``pandas.read_csv``.
        class_dict: Mapping from raw label values to the values returned by
            ``__getitem__``.
        label_column: Name of the target column (default ``'label'``).
    """

    def __init__(self, csv_file, class_dict, label_column='label'):
        self.df = pd.read_csv(csv_file)
        # Drop the 'id' column so it is not used as a feature.
        self.df = self.df.drop('id', axis=1)
        self.label = self.df[label_column]
        self.data = self.df.drop(label_column, axis=1)
        self.class_dict = class_dict
        self.preprocess()

    def preprocess(self):
        """Standardize each feature column and convert features/labels to numpy.

        Features become a ``float32`` array. A column whose standard deviation
        is zero or undefined is only centred, so it becomes all zeros rather
        than NaN.
        """
        mean = self.data.mean()
        std = self.data.std()
        # A constant column has std == 0, and a single-row frame has NaN std
        # (pandas uses ddof=1). Dividing by either would make the column NaN.
        std = std.where(std > 0, 1.0)
        self.data = np.array((self.data - mean) / std).astype(np.float32)
        self.label = np.array(self.label)

    def num_attributes(self):
        """Return the number of feature columns."""
        return self.data.shape[1]

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        """Return ``(features, class_dict[raw_label])`` for row ``index``."""
        features = self.data[index]
        label = self.class_dict[self.label[index]]
        return features, label


class DiabetesDataset(torch.utils.data.Dataset):
    """Tabular dataset read from a CSV file with a single label column.

    ``label_column`` is split off as the target. Every other column is z-score
    standardized using statistics computed over this CSV. Labels are mapped
    through ``class_dict`` when an item is fetched.

    NOTE(review): mean/std come from this instance's own data, so separate
    train/test CSVs are each normalized with their own statistics.

    Args:
        csv_file: Path or file-like object accepted by ``pandas.read_csv``.
        class_dict: Mapping from raw label values to the values returned by
            ``__getitem__``.
        label_column: Name of the target column (default ``'label'``).
    """

    def __init__(self, csv_file, class_dict, label_column='label'):
        self.df = pd.read_csv(csv_file)
        self.label = self.df[label_column]
        self.data = self.df.drop(label_column, axis=1)
        self.class_dict = class_dict
        self.preprocess()

    def preprocess(self):
        """Standardize each feature column and convert features/labels to numpy.

        Features become a ``float32`` array. A column whose standard deviation
        is zero or undefined is only centred, so it becomes all zeros rather
        than NaN.
        """
        mean = self.data.mean()
        std = self.data.std()
        # A constant column has std == 0, and a single-row frame has NaN std
        # (pandas uses ddof=1). Dividing by either would make the column NaN.
        std = std.where(std > 0, 1.0)
        self.data = np.array((self.data - mean) / std).astype(np.float32)
        self.label = np.array(self.label)

    def num_attributes(self):
        """Return the number of feature columns."""
        return self.data.shape[1]

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        """Return ``(features, class_dict[raw_label])`` for row ``index``."""
        features = self.data[index]
        label = self.class_dict[self.label[index]]
        return features, label


class CovertypeDataset(torch.utils.data.Dataset):
    """UCI Covertype data (``fetch_ucirepo(id=31)``) with standardized features.

    Features are z-score standardized over the fetched data. Targets are
    mapped through ``class_dict`` when an item is fetched.

    Args:
        class_dict: Mapping from raw target values to the values returned by
            ``__getitem__``.
        label_column: Currently unused. Targets are taken from
            ``covertype.data.targets``. Kept for interface compatibility.
    """

    def __init__(self, class_dict, label_column='Cover_Type'):
        covertype = fetch_ucirepo(id=31)
        # Features and targets come back as separate pandas DataFrames.
        self.data = covertype.data.features
        self.label = covertype.data.targets
        self.class_dict = class_dict
        self.preprocess()

    def preprocess(self):
        """Standardize each feature column and convert features/labels to numpy.

        Features become a ``float32`` array. A column whose standard deviation
        is zero or undefined is only centred, so it becomes all zeros rather
        than NaN.
        """
        mean = self.data.mean()
        std = self.data.std()
        # A constant column (e.g. an indicator that never fires in this data)
        # has std == 0, and dividing by it would make the column NaN.
        std = std.where(std > 0, 1.0)
        self.data = np.array((self.data - mean) / std).astype(np.float32)
        # targets is a DataFrame, so this array is 2-D (rows x target columns).
        self.label = np.array(self.label)

    def num_attributes(self):
        """Return the number of feature columns."""
        return self.data.shape[1]

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return ``(features, class_dict[raw_target])`` for row ``index``."""
        features = self.data[index]
        # self.label is 2-D; take the first target column of this row.
        label = self.class_dict[self.label[index][0]]
        return features, label


class WineQualityDataset(torch.utils.data.Dataset):
    """UCI Wine Quality data (``fetch_ucirepo(id=186)``) with standardized features.

    Features are z-score standardized over the fetched data. Targets are
    returned as-is and are not mapped through any class dictionary.

    Args:
        label_column: Currently unused. Targets are taken from
            ``wine_quality.data.targets``. Kept for interface compatibility.
    """

    def __init__(self, label_column='quality'):
        wine_quality = fetch_ucirepo(id=186)
        # Features and targets come back as separate pandas DataFrames.
        self.data = wine_quality.data.features
        self.label = wine_quality.data.targets
        self.preprocess()

    def preprocess(self):
        """Standardize each feature column and convert features/labels to numpy.

        Features become a ``float32`` array. A column whose standard deviation
        is zero or undefined is only centred, so it becomes all zeros rather
        than NaN.
        """
        mean = self.data.mean()
        std = self.data.std()
        # A constant column has std == 0 (NaN for a single row). Dividing by
        # either would make the whole column NaN.
        std = std.where(std > 0, 1.0)
        self.data = np.array((self.data - mean) / std).astype(np.float32)
        # targets is a DataFrame, so this array is 2-D (rows x target columns).
        self.label = np.array(self.label)

    def num_attributes(self):
        """Return the number of feature columns."""
        return self.data.shape[1]

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return ``(features, target_row)``; the target row is a 1-D array."""
        features = self.data[index]
        label = self.label[index]
        return features, label
