
import pandas as pd
import torch


# Dataset transformations
class Drop:
    """Callable transform that removes a fixed set of columns from a sample."""

    def __init__(self, var_names):
        """Store the column label(s) to remove.

        Args:
            var_names: A single label or list-like of labels, forwarded
                unchanged to ``sample.drop``.
        """
        self.var_names = var_names

    def __call__(self, sample):
        """Return the result of dropping ``var_names`` along the column axis.

        Args:
            sample: Object exposing a pandas-style ``drop`` method with a
                column axis (``axis=1``).

        Returns:
            Whatever ``sample.drop`` returns for the stored labels.
        """
        unwanted = self.var_names
        trimmed = sample.drop(labels=unwanted, axis=1)
        return trimmed


class Recode:
    """Callable transform that recodes the values of one column.

    The mapping is given as ``{new_value: [old_value, ...]}`` and is inverted
    on every call before being handed to ``Series.replace``.
    """

    def __init__(self, var_name, dictionary):
        """Store the target column and the recoding specification.

        Args:
            var_name: Label of the column to recode.
            dictionary: Mapping from each new value to an iterable of the old
                values it should replace.
        """
        self.var_name = var_name
        self.dictionary = dictionary

    def __call__(self, sample):
        """Recode ``sample[var_name]`` in place and return ``sample``.

        Args:
            sample: Frame-like object supporting column get/set by label.

        Returns:
            The same ``sample`` object, with the column overwritten.
        """
        # Invert {new: [old, ...]} into {old: new}; if an old value is listed
        # under several new values, the last one encountered wins.
        old_to_new = {
            old: new
            for new, olds in self.dictionary.items()
            for old in olds
        }

        was_categorical = isinstance(
            sample[self.var_name].dtype, pd.CategoricalDtype
        )
        recoded = sample[self.var_name].replace(old_to_new)
        if was_categorical:
            # Re-apply the categorical dtype after replacing values.
            recoded = recoded.astype('category')

        sample[self.var_name] = recoded
        return sample


class Dummify:
    """Callable transform that one-hot encodes selected columns.

    Columns with more than two levels get one indicator per level; columns
    with two or fewer levels are encoded with ``drop_first=True``.
    """

    def __init__(self, var_names):
        """Store the labels of the columns to encode.

        Args:
            var_names: Iterable of column labels. Labels that are absent from
                the sample at call time are skipped.
        """
        self.var_names = var_names

    def __call__(self, sample):
        """Return ``sample`` with each listed column replaced by dummies.

        Args:
            sample: ``pandas.DataFrame`` to encode.

        Returns:
            The frame produced by successive ``pd.get_dummies`` calls, one per
            listed column that is present.
        """
        for name in self.var_names:
            if name not in sample.columns:
                continue

            column = sample[name]
            if isinstance(column.dtype, pd.CategoricalDtype):
                # Count declared categories (including unobserved ones), since
                # get_dummies emits one indicator per declared category.
                n_levels = len(column.cat.categories)
            else:
                # Non-categorical columns have no ``.cat`` accessor; count the
                # distinct non-missing values instead of raising.
                n_levels = column.nunique()

            # NOTE: with a single level, drop_first leaves no indicator column
            # at all, so the variable disappears from the output.
            sample = pd.get_dummies(
                sample,
                prefix=[name],
                columns=[name],
                drop_first=n_levels <= 2,
            )
        return sample
    
    
class QuantileBinning:
    """Callable transform that discretizes a column with ``pd.qcut``."""

    def __init__(self, var_name, quantile):
        """Store the target column and the quantile specification.

        Args:
            var_name: Label of the column to discretize.
            quantile: Value forwarded as the ``q`` argument of ``pd.qcut``.
        """
        self.var_name = var_name
        self.quantile = quantile

    def __call__(self, sample):
        """Overwrite ``sample[var_name]`` with its quantile bins.

        Args:
            sample: Frame-like object supporting column get/set by label.

        Returns:
            The same ``sample`` object, modified in place.
        """
        target = self.var_name
        binned = pd.qcut(sample[target], q=self.quantile)
        sample[target] = binned
        return sample


class Binning:
    """Callable transform that discretizes a column with ``pd.cut``."""

    def __init__(self, var_name, bins):
        """Store the target column and the bin specification.

        Args:
            var_name: Label of the column to discretize.
            bins: Value forwarded as the ``bins`` argument of ``pd.cut``.
        """
        self.var_name = var_name
        self.bins = bins

    def __call__(self, sample):
        """Overwrite ``sample[var_name]`` with its bin assignment.

        Args:
            sample: Frame-like object supporting column get/set by label.

        Returns:
            The same ``sample`` object, modified in place.
        """
        target = self.var_name
        # include_lowest=True makes the first interval closed on the left.
        cut_options = {'bins': self.bins, 'include_lowest': True}
        sample[target] = pd.cut(sample[target], **cut_options)
        return sample
    

class ToTensor:
    """Callable transform that converts a pandas object to a torch tensor."""

    def __init__(self, **kwargs):
        """Store keyword arguments for ``torch.tensor``.

        Args:
            **kwargs: Forwarded verbatim to ``torch.tensor`` on every call
                (for example ``dtype`` or ``device``).
        """
        self.kwargs = kwargs

    def __call__(self, sample):
        """Build a tensor from ``sample.values`` and squeeze unit dimensions.

        Args:
            sample: Object exposing ``.values`` (e.g. a ``pandas.DataFrame``
                or ``pandas.Series``).

        Returns:
            ``torch.Tensor`` with all size-1 dimensions removed.
        """
        values = sample.values
        if values.dtype == object and isinstance(sample, pd.DataFrame):
            # pandas refuses to mix bool and numeric columns in one array and
            # falls back to object dtype, which torch.tensor cannot convert.
            # This is what a frame of bool dummies plus numeric columns looks
            # like, so cast the bool columns to uint8 and retry.
            bool_columns = sample.select_dtypes(include='bool').columns
            if len(bool_columns) > 0:
                casts = {label: 'uint8' for label in bool_columns}
                values = sample.astype(casts).values
        return torch.tensor(values, **self.kwargs).squeeze()
    

class RandomLinearMap:
    """Callable transform that applies a stored linear map to a sample."""

    def __init__(self, linearmap):
        """Store the map to apply.

        Args:
            linearmap: Object whose ``.T`` attribute is passed to
                ``torch.matmul`` (presumably a 2-D ``torch.Tensor`` of shape
                (out_features, in_features) -- confirm against callers).
        """
        self.linearmap = linearmap

    def __call__(self, sample):
        """Return ``sample`` multiplied by the transpose of the stored map.

        Args:
            sample: Left operand for ``torch.matmul``.

        Returns:
            The result of ``torch.matmul(sample, linearmap.T)``.
        """
        transposed = self.linearmap.T
        projected = torch.matmul(sample, transposed)
        return projected

