import os
import pandas as pd
import numpy as np
import torch
import os
from torch.utils.data import TensorDataset, DataLoader
import pickle

class Dataset(object):
    """Causal-inference dataset loaded from disk and split into train/val/test.

    On construction the raw units are loaded into ``self.feat`` (columns
    ``treatment, yf, ycf, mu0, mu1`` followed by covariates), randomly split
    according to ``config['splits']`` and each split gets a leading ``weight``
    column. Summary statistics (``n_units``, ``treated_ratio``,
    ``control_ratio``, ``n_unit``, ``n_covariate``) are written back into
    ``config``.

    Config keys read: ``dataset``, ``start_order``, ``rootPath``, ``k`` and
    ``splits`` (a ``'train/val/test'`` ratio string such as ``'0.6/0.2/0.2'``).
    """

    def __init__(self, config):
        self.config = config
        self.dataset = config['dataset']
        self.start_order = config['start_order']
        self._load_feat(os.path.join(self.config['rootPath'], 'dataset/{}/{}{}.csv'.format(
            self.dataset, self.dataset, self.start_order)))
        self.n_units = len(self.feat)
        self.config['n_units'] = self.n_units
        self.k = config['k']
        self._split_dataset()
        self._add_weight()
        self.config['treated_ratio'] = sum(self.feat['treatment'] == 1) / len(self)
        self.config['control_ratio'] = sum(self.feat['treatment'] == 0) / len(self)
        self.config['n_unit'] = len(self)
        # Columns 0-4 are treatment, yf, ycf, mu0, mu1; the rest are covariates.
        self.config['n_covariate'] = len(self.feat.iloc[:, 5:].columns)

    def sigmoid(self, x):
        """Return the logistic function 1 / (1 + exp(-x))."""
        return 1 / (1 + np.exp(-x))

    @staticmethod
    def _balanced_weights(treatments):
        """Return per-unit weights that give both treatment arms equal total mass.

        With ``u`` the fraction of units where ``treatments == 1``, treated
        units get ``1 / (2u)`` and the others ``1 / (2(1-u))``, so the weights
        sum to the number of units when both arms are present. If an arm is
        absent its weight is never used, so it is set to 0 instead of letting
        ``0 / 0`` turn every weight into NaN.
        """
        treatments = np.asarray(treatments)
        is_treated = treatments == 1
        n = len(treatments)
        u = is_treated.sum() / n if n else 0.0
        w_treated = 1.0 / (2 * u) if u > 0 else 0.0
        w_control = 1.0 / (2 * (1 - u)) if u < 1 else 0.0
        return np.where(is_treated, w_treated, w_control)

    @staticmethod
    def _potential_outcome_frame(split, x, treated):
        """Return ``split`` with every unit assigned to a single arm.

        The frame has unit ``weight``, a constant ``treatment`` (1 if
        ``treated`` else 0) and ``yf`` taken from the split's ``mu1`` (treated)
        or ``mu0`` (control) column, followed by the columns of ``x``.
        """
        n = len(split)
        frame = pd.DataFrame({
            'weight': np.ones(n),
            'treatment': np.ones(n) if treated else np.zeros(n),
            'yf': split['mu1' if treated else 'mu0'].values,
        })
        return pd.concat([frame, x], axis=1)

    def _add_weight(self):
        """Prepend a ``weight`` column to every split and build outcome frames.

        * ``train`` / ``val``: arm-balanced weights computed from each split's
          own treated fraction.
        * ``test``: all weights are 1.
        * ``train_treated`` / ``train_control`` / ``test_treated`` /
          ``test_control``: every unit forced into one arm with ``yf`` set to
          ``mu1`` / ``mu0``, sharing the ``weight, treatment, yf, ycf, mu0,
          mu1, <covariates>`` column layout of the weighted splits.
        """
        # Columns from 'ycf' onward (ycf, mu0, mu1, covariates), captured
        # before the weight column is prepended.
        self.train_x = self.train.iloc[:, 2:]
        self.test_x = self.test.iloc[:, 2:]

        self.train_treated = self._potential_outcome_frame(self.train, self.train_x, treated=True)
        self.train_control = self._potential_outcome_frame(self.train, self.train_x, treated=False)
        self.test_treated = self._potential_outcome_frame(self.test, self.test_x, treated=True)
        self.test_control = self._potential_outcome_frame(self.test, self.test_x, treated=False)

        # The splits have a fresh RangeIndex (reset in _split_dataset), so the
        # weight frames align row-by-row in the concatenations below.
        train_w = pd.DataFrame({'weight': self._balanced_weights(self.train['treatment'].values)})
        self.train = pd.concat([train_w, self.train], axis=1)

        val_w = pd.DataFrame({'weight': self._balanced_weights(self.val['treatment'].values)})
        self.val = pd.concat([val_w, self.val], axis=1)

        test_w = pd.DataFrame({'weight': np.ones(len(self.test))})
        self.test = pd.concat([test_w, self.test], axis=1)

    def _split_dataset(self):
        """Randomly partition ``self.feat`` into ``train``, ``val`` and ``test``.

        ``config['splits']`` gives the train and val fractions; the test split
        receives every remaining unit, so its parsed ratio is not used. The
        draw uses NumPy's global RNG.
        """
        splits = self.config['splits'].strip().split('/')
        n_train, n_val, n_test = float(splits[0]), float(splits[1]), float(splits[2])
        feat_index = set(range(0, self.n_units))

        train_index = list(np.random.choice(list(feat_index), int(n_train * self.n_units), replace=False))
        val_index = list(np.random.choice(list(feat_index - set(train_index)), int(n_val * self.n_units), replace=False))
        test_index = list(feat_index - set(train_index) - set(val_index))

        self.train = self.feat.iloc[train_index].reset_index(drop=True)
        self.val = self.feat.iloc[val_index].reset_index(drop=True)
        self.test = self.feat.iloc[test_index].reset_index(drop=True)

    def _load_feat(self, feat_path):
        """Load the raw units into ``self.feat``.

        Only synthetic datasets (name containing ``'Syn'``) are supported. Their
        data are read from ``synthetic_data_{start_order - 1}`` in the same
        directory as ``feat_path``: a file holding four consecutive pickles
        (covariates, treatments, factual outcomes, counterfactual outcomes).

        Raises:
            ValueError: if the dataset is not a synthetic one.
        """
        if 'Syn' not in self.dataset:
            raise ValueError(
                "Unsupported dataset {!r}: only synthetic datasets (name containing "
                "'Syn') can be loaded".format(self.dataset))
        data_path = os.path.join(os.path.dirname(feat_path),
                                 'synthetic_data_{}'.format(self.start_order - 1))
        # NOTE(review): pickle can execute arbitrary code; only load trusted files.
        with open(data_path, "rb") as file:
            syn_x = pickle.load(file)
            t = pickle.load(file)
            yf = pickle.load(file)
            ycf = pickle.load(file)
        # For varying alpha in our paper
        # y0 = yf.copy()
        # y0[t == 1] = ycf[t == 1]
        # y1 = yf.copy()
        # y1[t == 0] = ycf[t == 0]
        # y1 *= 1.5
        # yf[t == 1] = y1[t == 1]
        # ycf[t == 0] = y1[t == 0]
        # NOTE(review): mu0/mu1 are filled with yf/ycf for every unit, which
        # matches the control/treated outcomes only for units with t == 0 —
        # confirm this is intended for treated units.
        final = pd.DataFrame(np.c_[t, yf, ycf, yf, ycf, syn_x])
        self.feat = final.rename(columns={0: 'treatment', 1: 'yf', 2: 'ycf', 3: 'mu0', 4: 'mu1'})

    def __getitem__(self, index):
        # Plain DataFrame indexing: selects column(s) of the full unit table.
        df = self.feat[index]
        return df

    def __len__(self):
        return len(self.feat)

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        n_units = len(self)
        n_treated = sum(self.feat['treatment'] == 1)
        n_control = sum(self.feat['treatment'] == 0)
        # Format the percentage explicitly: e.g. 100 * round(0.29, 2) is
        # 28.999999999999996 in binary floating point.
        pct_treated = '{:.1f}'.format(100 * round(n_treated / n_units, 2))
        pct_control = '{:.1f}'.format(100 * round(n_control / n_units, 2))
        info = ['[{}-{}]'.format(self.dataset, self.start_order)]
        info.append('The number of units: {} ({} treated {}%, {} control {}%)'.format(
            n_units, n_treated, pct_treated, n_control, pct_control))
        info.append('The number of covariates: {}'.format(len(self.feat.iloc[:, 5:].columns)))
        info.append('The number of treatments: {}'.format(len(np.unique(self.feat['treatment'].values))))
        info.append('The splits ratios {}: {}/{}/{}'.format(self.config['splits'],
                                                            len(self.train), len(self.val), len(self.test)))
        return '\n'.join(info)


class AbstractDataLoader(object):
    """Base class for loaders: stores the run config and the wrapped data."""

    def __init__(self, config, dataset):
        # Subclasses read both attributes after calling super().__init__.
        self.config, self.dataset = config, dataset


class TorchDataLoader(AbstractDataLoader):
    """Serve ``(x, t, y, w)`` mini-batches of a unit frame as torch tensors.

    The frame must provide ``weight``, ``treatment`` and ``yf`` columns; the
    covariates are taken as every column from position 6 onward (assumes the
    ``weight, treatment, yf, ycf, mu0, mu1, <covariates>`` layout — confirm
    against callers).
    """

    def __init__(self, config, dataset, batch_size=1024, shuffle=True):
        super().__init__(config, dataset)
        frame = self.dataset
        covariates = frame.iloc[:, 6:].values
        # Robustness evaluation: perturb covariates with U(-0.1, 0.1) noise.
        perturb = config['robustness'] and config['testing']
        if perturb:
            covariates = covariates + np.random.uniform(-0.1, 0.1, covariates.shape)

        treatment = frame['treatment'].values
        outcome = frame['yf'].values
        weight = frame['weight'].values

        self.x_all = torch.from_numpy(covariates).float()
        self.y_all = torch.from_numpy(outcome.reshape(-1, 1)).float()

        in_treated = treatment == 1
        in_control = treatment == 0
        self.x_treated, self.x_control = self.x_all[in_treated], self.x_all[in_control]
        self.y_treated, self.y_control = self.y_all[in_treated], self.y_all[in_control]

        self.ds = TensorDataset(
            torch.from_numpy(covariates).float(),
            torch.from_numpy(treatment.reshape(-1, 1)).int(),
            torch.from_numpy(outcome.reshape(-1, 1)).float(),
            torch.from_numpy(weight.reshape(-1, 1)).float(),
        )
        self.dl = DataLoader(self.ds, batch_size=batch_size, shuffle=shuffle)
        self.size = covariates.shape

    def sigmoid(self, x):
        """Return the logistic function of ``x``."""
        return 1 / (np.exp(-x) + 1)

    def get_X_size(self):
        """Return the shape of the covariate matrix."""
        return self.size

    def __len__(self):
        # Number of mini-batches, not number of units.
        return len(self.dl)

    def __iter__(self):
        yield from self.dl


class SklearnDataLoader(AbstractDataLoader):
    """Expose a unit frame as plain NumPy arrays for scikit-learn style models.

    Covariates are every column from position 6 onward; ``treatment``, ``yf``
    and ``weight`` are read by name (weights reshaped to a column vector).
    """

    def __init__(self, config, dataset):
        super().__init__(config, dataset)
        frame = self.dataset
        self.x = frame.iloc[:, 6:].values
        self.t, self.y = frame['treatment'].values, frame['yf'].values
        self.w = frame['weight'].values.reshape(-1, 1)
        self.size = self.x.shape

    def get_X_size(self):
        """Return the shape of the covariate matrix."""
        return self.size

    def __len__(self):
        # Number of units (rows of the covariate matrix).
        return self.x.shape[0]

    def get_data(self):
        """Return the ``(x, t, y, w)`` arrays."""
        return (self.x, self.t, self.y, self.w)