import pandas as pd
import numpy as np
import torch
import os
from torch.utils.data import TensorDataset, DataLoader

class Dataset(object):
    """Load one CSV file of a dataset, scale it, split it and build model-ready frames.

    The CSV is read from ``<rootPath>/dataset/<dataset>/<dataset><start_order>.csv``,
    its rows are shuffled, the ``yf``/``ycf``/``mu0``/``mu1`` columns are min-max
    scaled to [0, 1], and the rows are randomly split into train/val/test.
    The resulting attributes are:

    * ``train`` / ``val`` -- observed ``treatment`` and ``yf`` plus a balancing
      ``weight`` column, followed by the covariates ``x1..xK``.
    * ``train_treated`` / ``train_control`` / ``test_treated`` / ``test_control``
      -- every unit assigned treatment 1 (resp. 0) with the matching target
      outcome and unit weights; ``test_treatment`` holds the treatment each unit
      actually received.

    Side effect: writes ``n_units``, ``column_x``, ``treated_ratio``,
    ``control_ratio`` and ``n_unit`` into ``config``.
    """

    def __init__(self, config):
        self.config = config
        self.dataset = config['dataset']
        self.start_order = config['start_order']
        # Files 12 and 19 are never loaded; file 13 is used in their place.
        if self.start_order in (12, 19):
            self.start_order = 13
        self._load_feat(os.path.join(
            self.config['rootPath'],
            'dataset/{}/{}{}.csv'.format(self.dataset, self.dataset, self.start_order)))
        self.n_units = len(self.feat)
        self.config['n_units'] = self.n_units
        self.config['column_x'] = ['x' + str(index + 1) for index in range(self.config['n_covariate'])]
        self._split_dataset()
        self._add_weight()
        treatment = self.feat['treatment']
        self.config['treated_ratio'] = (treatment == 1).sum() / len(self)
        self.config['control_ratio'] = (treatment == 0).sum() / len(self)
        self.config['n_unit'] = len(self)

    @staticmethod
    def _balancing_weights(treatments):
        """Return ``t / (2u) + (1 - t) / (2(1 - u))`` for every unit.

        ``u`` is the fraction of units with ``t == 1``; for 0/1 treatments each
        arm therefore receives a total weight of ``len(treatments) / 2``. If all
        units fall in one arm, ``u`` is 0 or 1 and the weights are NaN (0 / 0).
        """
        u = np.count_nonzero(treatments == 1) / len(treatments)
        return treatments / (2 * u) + (1 - treatments) / (2 * (1 - u))

    def _outcome_columns(self, frame):
        """Return ``(treated_yf, control_yf, indicator_random, factual_outcome)`` for ``frame``.

        For ``Jobs`` the two target columns are placeholders of ones and the
        ``e`` and ``yf`` columns are carried through; for any other dataset the
        targets are ``mu1`` / ``mu0`` and the last two entries are ones.
        """
        n = len(frame)
        if self.config['dataset'] in ['Jobs']:
            return np.ones(n), np.ones(n), frame['e'].values, frame['yf'].values
        return frame['mu1'].values, frame['mu0'].values, np.ones(n), np.ones(n)

    @staticmethod
    def _assigned_frame(frame, x, treatment_value, yf, indicator_random, factual_outcome):
        """Return ``frame``'s units all assigned ``treatment_value`` (unit weights), with ``x`` appended."""
        n = len(frame)
        assigned = pd.DataFrame({
            'weight': np.ones(n),
            'treatment': np.full(n, float(treatment_value)),
            'yf': yf,
            'test_treatment': frame['treatment'].values,
            'indicator_random': indicator_random,
            'factual_outcome': factual_outcome,
        })
        return pd.concat([assigned, x], axis=1)

    def _factual_frame(self, frame):
        """Return the observed treatment/outcome of ``frame`` with balancing weights (no covariates)."""
        n = len(frame)
        treatments = frame['treatment'].values
        return pd.DataFrame({
            'weight': self._balancing_weights(treatments),
            'treatment': treatments,
            'yf': frame['yf'].values,
            'test_treatment': np.ones(n),
            'indicator_random': np.ones(n),
            'factual_outcome': np.ones(n),
        })

    def _add_weight(self):
        """Replace the raw ``train``/``val`` splits with weighted frames and build the assigned-treatment frames.

        Also stores the covariate frames ``train_x``/``val_x``/``test_x`` and the
        weighted frames without covariates as ``train_aug``/``val_aug``.
        """
        column_x = self.config['column_x']
        self.test_x = self.test[column_x]
        self.val_x = self.val[column_x]
        self.train_x = self.train[column_x]

        # Read target columns from the raw splits before ``self.train`` is replaced below.
        test_treat_yf, test_control_yf, test_indicator, test_factual = self._outcome_columns(self.test)
        train_treat_yf, train_control_yf, train_indicator, train_factual = self._outcome_columns(self.train)

        self.val_aug = self._factual_frame(self.val)
        self.train_aug = self._factual_frame(self.train)

        # All four assigned-treatment frames use unit weights (train_control
        # previously used zeros, unlike its three siblings).
        self.test_treated = self._assigned_frame(
            self.test, self.test_x, 1, test_treat_yf, test_indicator, test_factual)
        self.test_control = self._assigned_frame(
            self.test, self.test_x, 0, test_control_yf, test_indicator, test_factual)
        self.train_treated = self._assigned_frame(
            self.train, self.train_x, 1, train_treat_yf, train_indicator, train_factual)
        self.train_control = self._assigned_frame(
            self.train, self.train_x, 0, train_control_yf, train_indicator, train_factual)

        self.val = pd.concat([self.val_aug, self.val_x], axis=1)
        self.train = pd.concat([self.train_aug, self.train_x], axis=1)

    def _split_dataset(self):
        """Randomly partition ``self.feat`` into ``train``, ``val`` and ``test``.

        ``config['splits']`` is a ``'train/val/test'`` string of fractions. Train
        and val sizes are ``int(fraction * n_units)``, sampled without
        replacement; test receives every remaining row, so the third fraction is
        parsed but does not affect the split.
        """
        splits = self.config['splits'].strip().split('/')
        n_train, n_val, _ = float(splits[0]), float(splits[1]), float(splits[2])
        feat_index = set(range(0, self.n_units))

        train_index = list(np.random.choice(list(feat_index), int(n_train * self.n_units), replace=False))
        remaining = feat_index - set(train_index)
        val_index = list(np.random.choice(list(remaining), int(n_val * self.n_units), replace=False))
        test_index = list(remaining - set(val_index))

        self.train = self.feat.iloc[train_index].reset_index(drop=True)
        self.val = self.feat.iloc[val_index].reset_index(drop=True)
        self.test = self.feat.iloc[test_index].reset_index(drop=True)

    @staticmethod
    def _min_max(column):
        """Scale ``column`` to [0, 1]; a constant column maps to 0.0 instead of NaN (0 / 0)."""
        low = column.min()
        shifted = column - low
        span = column.max() - low
        if span == 0:
            return shifted.astype(float)
        return shifted / span

    def _load_feat(self, feat_path):
        """Read ``feat_path``, shuffle its rows and min-max scale ``yf`` (and ``ycf``/``mu0``/``mu1`` if present)."""
        df = pd.read_csv(feat_path).sample(frac=1).reset_index(drop=True)
        df['yf'] = self._min_max(df['yf'])
        for column in ('ycf', 'mu0', 'mu1'):
            if column in df.columns:
                df[column] = self._min_max(df[column])
        self.feat = df

    def __getitem__(self, index):
        """Return ``self.feat[index]`` (pandas column selection, not a row lookup)."""
        df = self.feat[index]
        return df

    def __len__(self):
        """Return the number of loaded units (rows of ``self.feat``)."""
        return len(self.feat)

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        """Summarise unit counts, covariates, treatments and split sizes."""
        treatment = self.feat['treatment']
        n_units = len(self)
        n_treated = int((treatment == 1).sum())
        n_control = int((treatment == 0).sum())
        info = [
            '[{}-{}]'.format(self.dataset, self.start_order),
            # Whole-number percentages; ``100 * round(r, 2)`` printed float noise
            # such as 56.99999999999999.
            'The number of units: {} ({} treated {}%, {} control {}%)'.format(
                n_units, n_treated, round(100 * n_treated / n_units),
                n_control, round(100 * n_control / n_units)),
            # Count the covariate columns actually used, not a fixed column offset.
            'The number of covariates: {}'.format(len(self.config['column_x'])),
            'The number of treatments: {}'.format(len(np.unique(treatment.values))),
            'The splits ratios {}: {}/{}/{}'.format(
                self.config['splits'], len(self.train), len(self.val), len(self.test)),
        ]
        return '\n'.join(info)


class AbstractDataLoader(object):
    """Base loader that simply keeps the run configuration and the data it serves.

    :param config: run configuration, stored unchanged as ``self.config``.
    :param dataset: data source, stored unchanged as ``self.dataset``.
    """

    def __init__(self, config, dataset):
        self.dataset = dataset
        self.config = config



class TorchDataLoader(AbstractDataLoader):
    """Serve a prepared data frame as mini-batches from a PyTorch ``DataLoader``.

    ``dataset`` must provide the covariate columns listed in ``config['column_x']``
    plus ``treatment``, ``yf``, ``weight``, ``test_treatment``, ``indicator_random``
    and ``factual_outcome``. Each batch is the tuple
    ``(x, t, y, w, test_treatment, indicator_random, factual_outcome)``; every
    element except ``x`` is reshaped to ``(batch, 1)``, and ``t``,
    ``test_treatment`` and ``indicator_random`` are cast to int32 (the rest to
    float32).
    """

    def __init__(self, config, dataset, batch_size=1024, shuffle=True):
        super().__init__(config, dataset)

        x = self.dataset[self.config['column_x']].values
        if config['robustness'] and config['testing']:
            # Robustness test: perturb covariates with uniform noise drawn from
            # [-config['low'], config['high']).
            samples = np.random.uniform(low=-config['low'], high=config['high'], size=x.shape)
            x = x + samples

        t = self.dataset['treatment'].values
        y = self.dataset['yf'].values
        w = self.dataset['weight'].values
        test_treatment = self.dataset['test_treatment'].values
        indicator_random = self.dataset['indicator_random'].values
        factual_outcome = self.dataset['factual_outcome'].values

        self.x_all = torch.from_numpy(x).float()
        self.y_all = torch.from_numpy(y.reshape(-1, 1)).float()

        treated_mask = t == 1
        control_mask = t == 0
        self.x_treated = self.x_all[treated_mask]
        self.x_control = self.x_all[control_mask]
        self.y_treated = self.y_all[treated_mask]
        self.y_control = self.y_all[control_mask]

        # Reuse ``self.x_all`` instead of converting ``x`` a second time, which
        # allocated another float32 copy of the covariates.
        self.ds = TensorDataset(
            self.x_all,
            torch.from_numpy(t.reshape(-1, 1)).int(),
            torch.from_numpy(y.reshape(-1, 1)).float(),
            torch.from_numpy(w.reshape(-1, 1)).float(),
            torch.from_numpy(test_treatment.reshape(-1, 1)).int(),
            torch.from_numpy(indicator_random.reshape(-1, 1)).int(),
            torch.from_numpy(factual_outcome.reshape(-1, 1)).float(),
        )

        self.dl = DataLoader(self.ds, batch_size=batch_size, shuffle=shuffle)
        self.size = x.shape

    def get_X_size(self):
        """Return the shape of the covariate matrix."""
        return self.size

    def __len__(self):
        """Return the number of batches per epoch."""
        return len(self.dl)

    def __iter__(self):
        """Yield the batches of the underlying ``DataLoader``."""
        yield from self.dl


class SklearnDataLoader(AbstractDataLoader):
    """Expose a prepared data frame as plain arrays (no PyTorch tensors).

    ``get_data`` returns ``(x, t, y, w, e, factual_outcome, test_treatment)``.
    ``w`` is reshaped to a column vector ``(n, 1)``, ``test_treatment`` is kept
    as a pandas Series, and the other entries are the frame's ``.values`` arrays.
    """

    def __init__(self, config, dataset):
        super().__init__(config, dataset)
        frame = self.dataset
        self.x = frame[self.config['column_x']].values
        self.t = frame['treatment'].values
        self.y = frame['yf'].values
        self.w = frame['weight'].values.reshape(-1, 1)
        self.e = frame['indicator_random'].values
        self.factual_outcome = frame['factual_outcome'].values
        # NOTE(review): unlike the fields above this stays a Series (no ``.values``);
        # kept as-is in case callers rely on Series behaviour.
        self.test_treatment = frame['test_treatment']
        self.size = self.x.shape

    def get_X_size(self):
        """Return the shape of the covariate matrix."""
        return self.size

    def __len__(self):
        """Return the number of rows of the covariate matrix."""
        return self.x.shape[0]

    def get_data(self):
        """Return the stored arrays as one tuple in a fixed order."""
        fields = (
            self.x,
            self.t,
            self.y,
            self.w,
            self.e,
            self.factual_outcome,
            self.test_treatment,
        )
        return fields