import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from scipy.stats import gengamma

def linear_data_df(num_samples, c_x, c_t, c_yz, c_yt, s_x, s_t, s_y):
    """Sample a linear-Gaussian dataset with one latent confounder z.

    The generative model is
        z   ~ N(0, 1)
        x_i ~ N(c_x[i] * z, s_x[i])           for each covariate i
        t   ~ N(c_t * z, s_t)
        y   ~ N(c_yz * z + c_yt * t, s_y)
    where every ``c_*`` is a structural coefficient and every ``s_*`` is the
    standard deviation of the corresponding noise term.

    Args:
        num_samples: number of rows to draw.
        c_x, s_x: sequences of equal length ``x_dim`` (one entry per covariate).
        c_t, c_yz, c_yt, s_t, s_y: scalars or length-1 arrays.

    Returns:
        DataFrame with columns ``z, x0..x{x_dim-1}, t, y``.
    """
    n_features = len(c_x)
    reps = (num_samples, 1)

    def _tiled(param):
        # Repeat a per-column parameter once per row.
        return np.tile(param, reps)

    latent = np.random.standard_normal(reps)
    features = np.random.normal(_tiled(c_x) * latent, _tiled(s_x),
                                (num_samples, n_features))
    treatment = np.random.normal(_tiled(c_t) * latent, _tiled(s_t),
                                 (num_samples, 1))
    outcome_mean = _tiled(c_yz) * latent + _tiled(c_yt) * treatment
    outcome = np.random.normal(outcome_mean, _tiled(s_y), (num_samples, 1))

    columns = ['z', *('x{}'.format(i) for i in range(n_features)), 't', 'y']
    values = np.concatenate([latent, features, treatment, outcome], axis=1)
    return pd.DataFrame(values, columns=columns)

def _random_signed_coefficient(scale, size):
    """Draw ``size`` coefficients of magnitude ``g * scale + scale / 2`` with random signs.

    ``g`` is drawn from ``gengamma(0.3, 4)`` and each sign is +1 or -1 with
    probability 1/2, so every magnitude is at least ``scale / 2``.
    """
    magnitude = gengamma(0.3, 4).rvs(size) * scale + scale / 2
    sign = np.where(np.random.random(size) > 0.5, 1.0, -1.0)
    return magnitude * sign


def generate_linear_parameters(x_dim):
    """Randomly generate parameters for ``linear_data_df``.

    Noise standard deviations are drawn from ``gengamma(1, 5)``. Each
    structural coefficient is scaled by the noise std of the node it points
    into (``s_x`` for ``c_x``, ``s_t`` for ``c_t``, ``s_y`` for ``c_yt`` and
    ``c_yz``) and given a random sign.

    Args:
        x_dim: number of covariates.

    Returns:
        Tuple ``(c_x, c_t, c_yz, c_yt, s_x, s_t, s_y)`` in the argument order
        of ``linear_data_df``. ``c_x`` and ``s_x`` have shape ``(x_dim,)``; the
        rest have shape ``(1,)``.
    """
    s_x = gengamma(1, 5).rvs(x_dim)
    # One independent sign per covariate, so c_x always matches s_x in length.
    c_x = _random_signed_coefficient(s_x, x_dim)
    s_t = gengamma(1, 5).rvs(1)
    c_t = _random_signed_coefficient(s_t, 1)
    s_y = gengamma(1, 5).rvs(1)
    c_yt = _random_signed_coefficient(s_y, 1)
    c_yz = _random_signed_coefficient(s_y, 1)
    return c_x, c_t, c_yz, c_yt, s_x, s_t, s_y

def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))``, elementwise for arrays."""
    denominator = np.exp(-x) + 1
    return 1 / denominator

def linear_data_binary_ty_df(num_samples, c_x, s_x, t_a, t_b, y_a0, y_b0, y_a1, y_b1):
    """Sample a dataset with a Gaussian confounder and binary treatment/outcome.

    The generative model is
        z   ~ N(0, 1)
        x_i ~ N(c_x[i] * z, s_x[i])                 for each covariate i
        t   ~ Bernoulli(sigmoid(t_a * z + t_b))
        y   ~ Bernoulli(sigmoid(y_a1 * z + y_b1))   if t == 1
              Bernoulli(sigmoid(y_a0 * z + y_b0))   if t == 0

    Args:
        num_samples: number of rows to draw.
        c_x, s_x: sequences of equal length ``x_dim``; structural coefficients
            and noise standard deviations for the covariates.
        t_a, t_b: slope and intercept of the treatment logit.
        y_a0, y_b0: slope and intercept of the outcome logit under t = 0.
        y_a1, y_b1: slope and intercept of the outcome logit under t = 1.

    Returns:
        DataFrame with columns ``z, x0..x{x_dim-1}, t, y``. All columns are
        float64 after concatenation; ``t`` and ``y`` hold 0.0 or 1.0.
    """
    # Number of covariates must come from c_x so any x_dim is supported.
    x_dim = len(c_x)
    z = np.random.standard_normal((num_samples, 1))
    x = np.random.normal(np.tile(c_x, (num_samples, 1)) * z,
                         np.tile(s_x, (num_samples, 1)), (num_samples, x_dim))
    t = (np.random.random((num_samples, 1)) < sigmoid(t_a * z + t_b)).astype(int)
    # Both potential outcomes are drawn for every row (y1 first, then y0);
    # the treatment indicator selects which one is observed.
    y1 = (np.random.random((num_samples, 1)) < sigmoid(y_a1 * z + y_b1)).astype(int)
    y0 = (np.random.random((num_samples, 1)) < sigmoid(y_a0 * z + y_b0)).astype(int)
    y = y1 * t + y0 * (1 - t)
    df = pd.DataFrame(np.concatenate([z, x, t, y], axis=1),
                      columns=['z'] + ['x{}'.format(i) for i in range(x_dim)] + ['t', 'y'])
    return df

# Define pytorch datasets and loaders
class LinearDataset(Dataset):
    """Map-style dataset over a frame laid out as ``[z..., x..., t, y]``.

    The first ``z_dim`` columns are treated as latent and skipped. The
    remaining columns, up to but excluding the last two, become the covariate
    matrix ``X``. ``'t'`` and ``'y'`` are looked up by name. This is the
    layout produced by ``linear_data_df`` when ``z_dim=1``.

    Each item is a dict ``{'X': covariate row, 't': (1,) array, 'y': (1,) array}``.
    """

    def __init__(self, data: pd.DataFrame, z_dim=1):
        n_rows, n_columns = data.shape
        self.length = n_rows
        # Covariates sit between the leading latent block and the trailing
        # (t, y) pair; positional slicing assumes 't' and 'y' are the last two columns.
        first_x, stop_x = z_dim, n_columns - 2
        self.t = data[['t']].to_numpy()
        self.X = data.iloc[:, first_x:stop_x].to_numpy()
        self.y = data[['y']].to_numpy()

    def __getitem__(self, idx):
        return dict(X=self.X[idx], t=self.t[idx], y=self.y[idx])

    def __len__(self):
        return self.length

class LinearDataLoader(DataLoader):
    """Split a dataset into train/validation index sets and build a loader for each.

    Although this class subclasses ``DataLoader``, it never calls
    ``DataLoader.__init__``. It works as a factory and is not meant to be
    iterated directly. Use ``train_loader``, ``test_loader`` or
    ``get_loaders`` to obtain real ``DataLoader`` objects.

    Args:
        dataset: map-style dataset whose items are dicts containing at least
            the keys ``'X'``, ``'t'`` and ``'y'`` (e.g. ``LinearDataset``).
        validation_split: fraction of samples, in ``[0, 1]``, reserved for
            validation. The count is rounded down.
        shuffle: if True, indices are shuffled with ``np.random`` before the
            split; otherwise the first ``split`` indices form the validation set.

    Raises:
        ValueError: if ``validation_split`` is outside ``[0, 1]``.
    """

    def __init__(self, dataset, validation_split=0.2, shuffle=True):
        if not 0 <= validation_split <= 1:
            raise ValueError(
                'validation_split must be in [0, 1], got {}'.format(validation_split))
        dataset_size = len(dataset)
        indices = list(range(dataset_size))
        split = int(np.floor(validation_split * dataset_size))
        if shuffle:
            np.random.shuffle(indices)
        train_indices, valid_indices = indices[split:], indices[:split]

        self.dataset = dataset
        # SubsetRandomSampler yields its indices in random order, so batches are
        # never in index order. `shuffle` only controls which indices land in
        # which split.
        self.train_sampler = SubsetRandomSampler(train_indices)
        self.valid_sampler = SubsetRandomSampler(valid_indices)

    def collate_fn(self, batch):
        """Merge a non-empty list of sample dicts into one dict of batched values.

        Every key of the first sample becomes a list of per-sample values. The
        ``'t'``, ``'X'`` and ``'y'`` entries are then stacked along a new
        leading batch axis and converted to float32 tensors. Any other keys
        stay as Python lists.

        Raises:
            KeyError: if a sample lacks a key present in the first sample, or
                the samples have no ``'t'``/``'X'``/``'y'`` entry.
        """
        keys = list(batch[0].keys())
        processed_batch = {key: [sample[key] for sample in batch] for key in keys}
        # Stack into a single ndarray and convert once. Building a tensor from a
        # Python list of ndarrays is the slow path PyTorch warns about.
        for key in ('t', 'X', 'y'):
            processed_batch[key] = torch.as_tensor(
                np.stack(processed_batch[key]), dtype=torch.float32)
        return processed_batch

    def train_loader(self, batch_size, num_workers=0, *, pin_memory=True, drop_last=True):
        """Return a ``DataLoader`` over the training indices.

        ``pin_memory`` and ``drop_last`` are forwarded to ``DataLoader``. Their
        defaults match the previous hard-coded behavior.
        """
        return DataLoader(
            dataset=self.dataset,
            batch_size=batch_size,
            collate_fn=self.collate_fn,
            sampler=self.train_sampler,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=drop_last,
        )

    def test_loader(self, batch_size, num_workers=0, *, pin_memory=True, drop_last=True):
        """Return a ``DataLoader`` over the validation indices.

        With the default ``drop_last=True``, a validation set smaller than
        ``batch_size`` yields no batches at all. Pass ``drop_last=False`` to
        evaluate on every validation sample.
        """
        return DataLoader(
            dataset=self.dataset,
            batch_size=batch_size,
            collate_fn=self.collate_fn,
            sampler=self.valid_sampler,
            num_workers=num_workers,
            pin_memory=pin_memory,
            shuffle=False,
            drop_last=drop_last,
        )

    def get_loaders(self, batch_size, num_workers=0):
        """Return ``(train_loader, test_loader)`` built with default options."""
        train_loader = self.train_loader(batch_size, num_workers)
        test_loader = self.test_loader(batch_size, num_workers)
        return train_loader, test_loader