import pandas as pd
import numpy as np
import torch
from collections import Counter
from torch.utils.data.sampler import WeightedRandomSampler
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from utils import StandardScaler


class MyDataset(Dataset):
    """Row-wise dataset over a numeric CSV laid out as ``[x..., t, yf, ycf, mu0, mu1]``.

    The last ``N_TAIL`` columns of each row are the treatment indicator, the
    factual and counterfactual outcomes and the two mu columns (the layout
    written by the ``*_processor`` functions in this module); all columns
    before them are covariates.
    """

    # Number of non-covariate columns at the end of each row (t, yf, ycf, mu0, mu1).
    N_TAIL = 5
    # Column index of the treatment indicator t.
    T_COL = -5

    def __init__(self, path, mask=0):
        """Load the whole CSV into memory as float32.

        Args:
            path: comma-separated numeric file without a header row.
            mask: fraction of the leading covariate columns to drop from every
                item; ``int(mask * n_covariates)`` columns are skipped.
        """
        # ndmin=2 keeps a single-row file 2-D so len/indexing stay row-based.
        self.data = np.loadtxt(path, delimiter=',', dtype=np.float32, ndmin=2)

        n_covariates = self.data.shape[-1] - self.N_TAIL
        self.x_dim_start = int(mask * n_covariates)
        # Number of covariate columns actually returned by __getitem__.
        self.x_dim = n_covariates - self.x_dim_start
        self.sample_num = self.data.shape[0]

        # Feature standardisation is currently disabled:
        # self.scaler = StandardScaler()
        # self.data = self.scaler.fit_transform(self.data)

    def __getitem__(self, index):
        """Return row ``index`` without its first ``x_dim_start`` covariate columns."""
        return self.data[index, self.x_dim_start:]

    def __len__(self):
        """Return the number of rows."""
        return len(self.data)

    def get_sampler(self, treat_weight=1):
        """Build a sampler that rebalances control (t=0) and treated (t=1) rows.

        Each row gets weight ``1 / size_of_its_treatment_group``, so both groups
        are drawn equally often in expectation. ``treat_weight`` multiplies the
        treated-group size, i.e. divides the treated rows' weight by it.

        Args:
            treat_weight: multiplier applied to the treated-group count.

        Returns:
            A ``WeightedRandomSampler`` drawing ``len(self)`` indices with
            replacement.
        """
        # Treatment lives at column -5; column -3 is ycf and must not be used.
        t = self.data[:, self.T_COL].astype(np.int64)
        count = Counter(t.tolist())
        class_count = np.array([count[0], count[1] * treat_weight])
        # An absent group yields an inf weight, but no row indexes it.
        weight = 1. / class_count
        samples_weight = torch.from_numpy(weight[t])
        sampler = WeightedRandomSampler(
            samples_weight,
            len(samples_weight),
            replacement=True)

        return sampler


def twins_processor(n_individual=8244, root='dataset/Twins'):
    """Split the biased Twins CSV into train / traineval / eval / test files.

    Reads ``{root}/{n_individual}/{n_individual}_biased.csv`` (one header row;
    columns ``t, yf, ycf, mu0, mu1`` followed by covariates) and reorders every
    row to ``[x..., t, yf, ycf, mu0, mu1]``. 37% of the rows are held out, and
    the held-out part is split again with 27% going to test. That gives
    roughly 63% / 27% / 10% for train / eval / test. Both splits are
    stratified on ``t`` with ``random_state=42``.

    Args:
        n_individual: dataset size; selects the sub-directory and file name.
        root: directory containing the ``{n_individual}`` sub-directory.

    Returns:
        None. Writes ``train.csv``, ``traineval.csv`` (same rows as train),
        ``eval.csv`` and ``test.csv`` next to the input file.
    """
    base = f'{root}/{n_individual}'
    data = np.loadtxt(f'{base}/{n_individual}_biased.csv', delimiter=',', skiprows=1)

    # Move the five leading columns (t, yf, ycf, mu0, mu1) behind the covariates.
    output = np.concatenate([data[:, 5:], data[:, :5]], axis=-1)

    # yf is used for training and mu only for testing, so splitting whole rows
    # (ground-truth mu included) is safe: mu is never used during training.
    train, eval_test = train_test_split(
        output, test_size=0.37, stratify=output[:, -5], random_state=42)
    evaluation, test = train_test_split(
        eval_test, test_size=0.27, stratify=eval_test[:, -5], random_state=42)

    splits = (
        ('train', train),
        ('traineval', train),  # identical to train for Twins
        ('eval', evaluation),
        ('test', test),
    )
    for name, split in splits:
        np.savetxt(f'{base}/{name}.csv', split, delimiter=',')

    return None

def _jobs_split(path):
    """Load one Jobs CSV and return its stratified (train, eval, test) splits.

    The CSV has one header row and columns ``t, yf, ycf, mu0, mu1`` followed
    by covariates; rows are reordered to ``[x..., t, yf, ycf, mu0, mu1]``.
    44% of the rows are held out, and the held-out part is split again with
    45.45% going to test. That gives roughly 56% / 24% / 20% for
    train / eval / test. Both splits are stratified on ``t`` with
    ``random_state=42``.
    """
    data = np.loadtxt(path, delimiter=',', skiprows=1)

    # Move the five leading columns (t, yf, ycf, mu0, mu1) behind the covariates.
    output = np.concatenate([data[:, 5:], data[:, :5]], axis=-1)

    # yf is used for training and mu only for testing, so splitting whole rows
    # (ground-truth mu included) is safe: mu is never used during training.
    train, eval_test = train_test_split(
        output, test_size=0.44, stratify=output[:, -5], random_state=42)
    evaluation, test = train_test_split(
        eval_test, test_size=0.4545, stratify=eval_test[:, -5], random_state=42)
    return train, evaluation, test


def jobs_processor(random_path='./randomized.csv',
                   nonrandom_path='./nonrandomized.csv',
                   out_dir='dataset/Jobs'):
    """Build the Jobs train / traineval / eval / test files.

    Both input CSVs are split the same way (see ``_jobs_split``). The outputs
    are:

    * ``train.csv``: the randomized and non-randomized train rows, concatenated.
    * ``traineval.csv``: the randomized train rows only.
    * ``eval.csv`` and ``test.csv``: taken from the randomized file only. The
      non-randomized eval/test splits are discarded.

    NOTE(review): the default input paths are relative to the current working
    directory, while the outputs go under ``out_dir``. Confirm this is intended.

    Args:
        random_path: CSV of the randomized sample.
        nonrandom_path: CSV of the non-randomized sample.
        out_dir: directory the four output CSVs are written to.

    Returns:
        None.
    """
    train_random, evaluation_random, test_random = _jobs_split(random_path)
    train_nonrandom, _, _ = _jobs_split(nonrandom_path)

    train_combined = np.concatenate([train_random, train_nonrandom], axis=0)

    splits = (
        ('train', train_combined),
        ('traineval', train_random),
        ('eval', evaluation_random),
        ('test', test_random),
    )
    for name, split in splits:
        np.savetxt(f'{out_dir}/{name}.csv', split, delimiter=',')

    return None






if __name__ == "__main__":
    # Regenerate the processed splits: Twins first, then Jobs.
    for _process in (twins_processor, jobs_processor):
        _process()


# 