import os
import pandas as pd
import numpy as np
import torch
import os
from torch.utils.data import TensorDataset, DataLoader

class Dataset(object):
    """Load one causal-inference CSV and split it into train/val/test frames.

    The file read is ``<rootPath>/dataset/<name>/<name><start_order>.csv``,
    where ``<name>`` is ``config['dataset']`` with surrounding single quotes
    stripped. Columns are addressed both by name (``treatment``, ``mu0``,
    ``mu1``) and by position: every column from index 5 onward is counted as
    a covariate.

    Construction mutates the passed-in ``config`` dict, adding ``n_units``,
    ``n_unit``, ``treated_ratio``, ``control_ratio`` and ``n_covariate``.

    Frames available after construction (each with a fresh RangeIndex):

    * ``feat`` - every row, shuffled.
    * ``train`` - training rows with a leading ``weight`` column: treated rows
      get ``1/(2u)`` and control rows ``1/(2(1-u))``, where ``u`` is the
      treated fraction of the training split.
    * ``val`` - validation rows with a leading all-ones ``weight`` column.
    * ``test`` - test rows, without a ``weight`` column.
    * ``val_treated`` / ``val_control`` / ``test_treated`` / ``test_control`` -
      the val/test rows with their first two columns dropped and replaced by
      ``weight`` (ones), a constant ``treatment`` (1 or 0) and ``yf`` taken
      from ``mu1`` or ``mu0`` respectively.
    """

    def __init__(self, config):
        self.config = config
        self.dataset = config['dataset']
        self.start_order = config['start_order']
        name = self.dataset.strip("'")
        self._load_feat(os.path.join(self.config['rootPath'], 'dataset', name,
                                     '{}{}.csv'.format(name, self.start_order)))
        self.n_units = len(self.feat)
        self.config['n_units'] = self.n_units

        self._split_dataset()
        self._add_weight()

        n_treated, n_control = self._treatment_counts()
        self.config['treated_ratio'] = n_treated / len(self)
        self.config['control_ratio'] = n_control / len(self)
        self.config['n_unit'] = len(self)
        self.config['n_covariate'] = len(self.feat.columns[5:])

    def _treatment_counts(self):
        """Return ``(n_treated, n_control)`` counted over every loaded unit."""
        treatment = self.feat['treatment']
        return int((treatment == 1).sum()), int((treatment == 0).sum())

    @staticmethod
    def _counterfactual_frames(df):
        """Return ``(covariates, all_treated, all_control)`` frames for ``df``.

        ``covariates`` is ``df`` without its first two columns. The other two
        frames prepend ``weight`` (ones), a constant ``treatment`` (1 or 0) and
        ``yf`` copied from ``mu1`` (treated) or ``mu0`` (control).
        """
        n_rows = len(df)
        covariates = df.iloc[:, 2:]
        treated = pd.DataFrame({
            'weight': np.ones(n_rows),
            'treatment': np.ones(n_rows),
            'yf': df['mu1'].values,
        })
        control = pd.DataFrame({
            'weight': np.ones(n_rows),
            'treatment': np.zeros(n_rows),
            'yf': df['mu0'].values,
        })
        return (covariates,
                pd.concat([treated, covariates], axis=1),
                pd.concat([control, covariates], axis=1))

    def _add_weight(self):
        """Prepend ``weight`` columns and build the counterfactual val/test frames.

        Raises:
            ValueError: if the training split is empty or lacks either treated
                or control units; the balancing weights would otherwise be
                inf/NaN.
        """
        treatments = self.train['treatment'].values
        n_train = len(treatments)
        u = (treatments == 1).sum() / n_train if n_train else 0.0
        if not 0 < u < 1:
            raise ValueError(
                'Training split must contain both treated and control units to '
                'compute balancing weights (treated fraction = {})'.format(u))

        # Treated rows get 1/(2u) and control rows 1/(2(1-u)), so each arm's
        # weights sum to n_train / 2.
        w = pd.DataFrame({
            'weight': treatments / (2 * u) + (1 - treatments) / (2 * (1 - u))
        })
        self.train = pd.concat([w, self.train], axis=1)

        # Counterfactual frames are built from val before it gains a weight column.
        self.val_x, self.val_treated, self.val_control = self._counterfactual_frames(self.val)
        self.test_x, self.test_treated, self.test_control = self._counterfactual_frames(self.test)

        self.val = pd.concat([pd.DataFrame({'weight': np.ones(len(self.val))}), self.val], axis=1)

    def _split_dataset(self):
        """Randomly partition the rows of ``self.feat`` into train/val/test.

        ``config['splits']`` is a ``'train/val/test'`` string of fractions such
        as ``'0.6/0.2/0.2'``. Train and val receive ``int(fraction * n_units)``
        rows each; test receives every remaining row, so the third fraction is
        parsed but not otherwise used. When ``config['pertubation_fix']`` is
        truthy the draws come from a ``RandomState`` seeded with
        ``config['seed']``; otherwise they use NumPy's global RNG.
        """
        splits = self.config['splits'].strip().split('/')
        n_train, n_val, _n_test = float(splits[0]), float(splits[1]), float(splits[2])

        if self.config['pertubation_fix']:
            rng = np.random.RandomState(self.config['seed'])
        else:
            # Module-level np.random.choice draws from the global RandomState.
            rng = np.random

        feat_index = set(range(self.n_units))
        train_index = list(rng.choice(list(feat_index), int(n_train * self.n_units), replace=False))
        val_index = list(rng.choice(list(feat_index - set(train_index)),
                                    int(n_val * self.n_units), replace=False))
        test_index = list(feat_index - set(train_index) - set(val_index))

        self.train = self.feat.iloc[train_index].reset_index(drop=True)
        self.val = self.feat.iloc[val_index].reset_index(drop=True)
        self.test = self.feat.iloc[test_index].reset_index(drop=True)

    def _load_feat(self, feat_path):
        """Read ``feat_path`` into ``self.feat`` with its rows shuffled.

        When ``config['pertubation_fix']`` is truthy the shuffle is seeded with
        ``config['seed']``. Without that seed, the seeded split in
        ``_split_dataset`` would still select different rows on every run,
        because the row order it indexes into would change.
        """
        random_state = self.config['seed'] if self.config.get('pertubation_fix') else None
        df = pd.read_csv(feat_path)
        self.feat = df.sample(frac=1, random_state=random_state).reset_index(drop=True)

    def __getitem__(self, index):
        # Plain DataFrame ``[]`` indexing: a label selects a column (e.g.
        # dataset['treatment']); a slice selects rows. NOTE(review): this is not
        # a positional row lookup - confirm callers rely on column access.
        df = self.feat[index]
        return df

    def __len__(self):
        return len(self.feat)

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        n_total = len(self)
        n_treated, n_control = self._treatment_counts()
        # Format percentages directly. Scaling an already-rounded ratio
        # (100 * round(r, 2)) can print float noise such as 7.000000000000001.
        return '\n'.join([
            '[{}-{}]'.format(self.dataset, self.start_order),
            'The number of units: {} ({} treated {:.1f}%, {} control {:.1f}%)'.format(
                n_total, n_treated, 100 * n_treated / n_total,
                n_control, 100 * n_control / n_total),
            'The number of covariates: {}'.format(len(self.feat.columns[5:])),
            'The number of treatments: {}'.format(len(np.unique(self.feat['treatment'].values))),
            'The splits ratios {}: {}/{}/{}'.format(self.config['splits'],
                                                    len(self.train), len(self.val), len(self.test)),
        ])


class AbstractDataLoader:
    """Common base for loaders: remembers the run config and the wrapped dataset.

    Subclasses call this initialiser first and then derive their own
    structures from ``self.dataset``.
    """

    def __init__(self, config, dataset):
        self.config, self.dataset = config, dataset


class TorchDataLoader(AbstractDataLoader):
    """Expose a prepared DataFrame as a batched ``torch`` ``DataLoader``.

    Reads the ``treatment``, ``yf`` and ``weight`` columns by name and takes
    every column from positional index 6 onward as covariates. Iterating
    yields batches ``(x, t, y, w)``: float covariates, int32 treatment, float
    outcome and float weight, the last three shaped ``(batch, 1)``.

    The full tensors ``x_all`` / ``y_all`` and their treated (``t == 1``) and
    control (``t == 0``) subsets are also kept as attributes.
    """

    def __init__(self, config, dataset, batch_size=1024, shuffle=True):
        super().__init__(config, dataset)

        frame = self.dataset
        covariates = frame.iloc[:, 6:].values
        treatment = frame['treatment'].values
        outcome = frame['yf'].values
        weight = frame['weight'].values

        def column(arr):
            # (n,) -> (n, 1) so every sample carries a length-1 vector.
            return arr.reshape(-1, 1)

        self.x_all = torch.from_numpy(covariates).float()
        self.y_all = torch.from_numpy(column(outcome)).float()

        is_treated = treatment == 1
        is_control = treatment == 0
        self.x_treated, self.y_treated = self.x_all[is_treated], self.y_all[is_treated]
        self.x_control, self.y_control = self.x_all[is_control], self.y_all[is_control]

        # Tuple positions in every batch: 0 = x, 1 = t, 2 = y, 3 = w.
        self.ds = TensorDataset(
            torch.from_numpy(covariates).float(),
            torch.from_numpy(column(treatment)).int(),
            torch.from_numpy(column(outcome)).float(),
            torch.from_numpy(column(weight)).float(),
        )
        self.dl = DataLoader(self.ds, batch_size=batch_size, shuffle=shuffle)
        self.size = covariates.shape

    def get_X_size(self):
        """Return the ``(n_rows, n_covariates)`` shape of the covariate matrix."""
        return self.size

    def __len__(self):
        # Number of batches, not number of rows.
        return len(self.dl)

    def __iter__(self):
        yield from self.dl


class SklearnDataLoader(AbstractDataLoader):
    """Expose a prepared DataFrame as plain NumPy arrays for sklearn-style models.

    Covariates are every column from positional index 6 onward; ``treatment``,
    ``yf`` and ``weight`` are read by name. ``t`` and ``y`` stay 1-D while
    ``w`` is reshaped to ``(n, 1)``.
    """

    def __init__(self, config, dataset):
        super().__init__(config, dataset)
        frame = self.dataset
        self.x = frame.iloc[:, 6:].values
        self.t, self.y = frame['treatment'].values, frame['yf'].values
        # Weights are kept as a column vector, unlike t and y.
        self.w = frame['weight'].values.reshape(-1, 1)
        self.size = self.x.shape

    def get_X_size(self):
        """Return the ``(n_rows, n_covariates)`` shape of ``x``."""
        return self.size

    def __len__(self):
        return len(self.x)

    def get_data(self):
        """Return ``(x, t, y, w)`` as NumPy arrays."""
        return (self.x, self.t, self.y, self.w)