from collections import Counter
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from sklearn.preprocessing import MinMaxScaler
import torch
import numpy as np
import pandas as pd
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.sampler import Sampler
from folktables import ACSDataSource, ACSIncome, ACSEmployment
import os
import random
import time

from FairBatchSampler import FairBatch

class CustomBatch(Sampler):
    """Sampler that yields whole batches of training indices chosen by per-sample loss.

    On every ``__iter__`` call the current model scores each sample reachable
    through ``trainloader``, samples are sorted by loss, and ``batch_num``
    batches are drawn without replacement. Draws are weighted by
    ``1 / (|position - target| + 1)`` around a target position in that loss
    ordering. Each yielded item is an array of ``batch_size`` dataset indices,
    so the sampler is meant to be passed as ``sampler=`` to a DataLoader that
    keeps its default ``batch_size=1``.
    """

    def __init__(self, model, x_tensor, y_tensor, z_tensor, trainloader, batch_size, seed=0):
        """Initialize the loss-ordered batch sampler.

        Args:
            model: torch module used to score samples, called as ``model(inputs.float())``.
            x_tensor: training features.
            y_tensor: training labels.
            z_tensor: training group ids.
            trainloader: DataLoader over the same dataset. It must iterate in
                dataset order (``shuffle=False``) so that the i-th loss belongs
                to dataset index i.
            batch_size: number of indices per yielded batch.
            seed: seed applied to the numpy, ``random`` and torch global RNGs.
        """
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)

        self.model = model
        self.x_data = x_tensor
        self.y_data = y_tensor
        self.z_data = z_tensor
        self.trainloader = trainloader

        self.N = len(z_tensor)

        self.batch_size = batch_size
        self.batch_num = int(len(self.y_data) / self.batch_size)

    def _model_device(self):
        """Return the device of the model's first parameter (CUDA if it has none)."""
        first_param = next(self.model.parameters(), None)
        return first_param.device if first_param is not None else torch.device('cuda')

    def get_loss_arr(self):
        """Return per-sample cross-entropy losses as a numpy array, in ``trainloader`` order.

        Scoring runs without gradient tracking. The model's previous train/eval
        mode is restored afterwards, so calling this at the start of an epoch
        does not leave the model in eval mode for training.
        """
        criterion = torch.nn.CrossEntropyLoss(reduction='none')
        device = self._model_device()
        was_training = self.model.training
        self.model.eval()
        losses = []
        try:
            with torch.no_grad():
                for inputs, labels, _groups in self.trainloader:
                    outputs = self.model(inputs.to(device).float())
                    loss = criterion(outputs, labels.to(device).long())
                    losses.append(loss.cpu().numpy())
        finally:
            self.model.train(was_training)

        return np.concatenate(losses) if losses else np.array([])

    def __iter__(self):
        """Yield ``batch_num`` index arrays; no index appears twice within one epoch."""
        loss_arr_order = np.argsort(self.get_loss_arr())

        for i in range(self.batch_num):
            # NOTE(review): curr_index is measured in the already-shrunk ordering
            # (chosen items are deleted below), so later batches target positions
            # beyond i*batch_size of the original ordering -- confirm this is intended.
            curr_index = i * self.batch_size
            index_diff = np.abs(np.arange(len(loss_arr_order)) - curr_index) + 1
            index_prob = 1 / index_diff
            index_prob = index_prob / np.sum(index_prob)

            picked = np.random.choice(len(loss_arr_order), size=self.batch_size,
                                      replace=False, p=index_prob)
            curr_batch = loss_arr_order[picked]
            # Remove the chosen samples so they cannot be drawn again this epoch.
            loss_arr_order = np.delete(loss_arr_order, picked)

            yield curr_batch

    def __len__(self):
        """Return the number of batches yielded per epoch (what ``__iter__`` produces)."""
        return self.batch_num


# def shuffle_recreate_dataloader(dataloader):
#     for data in dataloader:
#         inputs, labels, groups = data
#
#     tensorx_test = torch.from_numpy(x_test)
#     tensory_test = torch.from_numpy(y_test)
#     tensorgroup_test = torch.from_numpy(group_test)
#     test_dataset = TensorDataset(tensorx_test, tensory_test, tensorgroup_test)
#
#     testloader = DataLoader(test_dataset, batch_size=128, shuffle=False, drop_last=drop_last_bool)


def load_celeba_partition(img_list, celeba_feat_dir, ydict, groupdict):
    """Load precomputed ``.npy`` features for the images of one CelebA split.

    Args:
        img_list: iterable of image file names; each must be a key of
            ``ydict`` and ``groupdict``.
        celeba_feat_dir: directory holding one ``<image stem>.npy`` per image.
        ydict: image name -> label.
        groupdict: image name -> group id.

    Returns:
        ``(x, y, group)`` numpy arrays built from the images whose feature file
        exists. Images without a feature file are skipped silently.
    """
    x, y, group = [], [], []
    for img in img_list:
        # Replace the image extension with .npy. splitext handles extensions of
        # any length; slicing off exactly three characters breaks on e.g. "*.jpeg".
        stem, _ext = os.path.splitext(img)
        feat_path = os.path.join(celeba_feat_dir, stem + '.npy')
        if not os.path.exists(feat_path):
            continue

        x.append(np.load(feat_path))
        y.append(ydict[img])
        group.append(groupdict[img])

    return np.array(x), np.array(y), np.array(group)

def load_celeba_dataset(celeba_dir='/mnt/LargeDisk/Data/celeba'):
    """Load precomputed CelebA features with a Smiling label and a Male-derived group.

    Args:
        celeba_dir: CelebA root containing ``list_attr_celeba.csv``,
            ``list_eval_partition.csv`` and ``img_align_celeba/feat_align_celeba``.
            Defaults to the original hard-coded location.

    Returns:
        ``(x_train, y_train, group_train, x_test, y_test, group_test,
        x_valid, y_valid, group_valid)``. Note that test comes before valid.
    """
    celeba_label_file = os.path.join(celeba_dir, 'list_attr_celeba.csv')
    celeba_partition_file = os.path.join(celeba_dir, 'list_eval_partition.csv')
    celeba_feat_dir = os.path.join(celeba_dir, 'img_align_celeba/feat_align_celeba')

    dflabel = pd.read_csv(celeba_label_file)
    # Label: True iff the Smiling attribute equals 1.
    ydict = {img_id: smiling_label == 1
             for img_id, smiling_label in zip(dflabel['image_id'], dflabel['Smiling'])}
    # Group: 1 - max(Male, 0). A Male value of 1 maps to group 0, and any value <= 0
    # maps to group 1 (presumably attributes are +1/-1 coded -- confirm against the CSV).
    groupdict = {img_id: 1 - max(male_label, 0)
                 for img_id, male_label in zip(dflabel['image_id'], dflabel['Male'])}

    dfpart = pd.read_csv(celeba_partition_file)
    img_list = dfpart['image_id']
    partition = dfpart['partition']
    # Partition codes as used here: 0 -> train, 1 -> valid, 2 -> test.
    train_img = img_list[partition == 0]
    valid_img = img_list[partition == 1]
    test_img = img_list[partition == 2]

    x_train, y_train, group_train = load_celeba_partition(train_img, celeba_feat_dir, ydict, groupdict)
    x_valid, y_valid, group_valid = load_celeba_partition(valid_img, celeba_feat_dir, ydict, groupdict)
    x_test, y_test, group_test = load_celeba_partition(test_img, celeba_feat_dir, ydict, groupdict)

    return x_train, y_train, group_train, x_test, y_test, group_test, x_valid, y_valid, group_valid

def balance_dataset_group_count(features, label, group, random_subsample=False):
    """Trim the data so groups 0 and 1 contribute the same number of rows.

    The per-group target is the size of the smaller group. The output holds the
    first ``target`` group-0 rows followed by the first ``target`` group-1 rows.
    With ``random_subsample=True`` the group split is not enforced. Instead the
    first ``2 * target`` rows are returned in their input order, giving a
    same-sized subset whose composition depends on that order.
    """
    zero_mask = group == 0
    one_mask = group == 1
    count_zero = np.count_nonzero(zero_mask)
    count_one = np.count_nonzero(one_mask)
    per_group = count_zero if count_zero < count_one else count_one

    if random_subsample:
        head = 2 * per_group
        return features[:head], label[:head], group[:head]

    def _pick(values):
        # Masks are computed once up front, so building outputs in any order is safe.
        return np.concatenate((values[zero_mask][:per_group], values[one_mask][:per_group]), axis=0)

    return _pick(features), _pick(label), _pick(group)

def balance_dataset_group_and_label_count(features, label, group, random_subsample=False):
    """Balance group counts separately within the positive and the negative class.

    Rows with ``label == 1`` and rows with ``label == 0`` are each passed
    through ``balance_dataset_group_count``. The balanced positive rows come
    first in the output. Rows whose label is neither 0 nor 1 are dropped.
    """
    per_class = [
        balance_dataset_group_count(features[sel], label[sel], group[sel],
                                    random_subsample=random_subsample)
        for sel in (label == 1, label == 0)
    ]
    feats, labs, grps = zip(*per_class)
    return (np.concatenate(feats, axis=0),
            np.concatenate(labs, axis=0),
            np.concatenate(grps, axis=0))

def balance_dataset_label_count_per_group(features, label, group, random_subsample=False):
    """Trim the data so labels 0 and 1 contribute the same number of rows.

    The per-label target is the size of the rarer label. The output holds the
    first ``target`` label-0 rows followed by the first ``target`` label-1 rows.
    With ``random_subsample=True`` the label split is not enforced. Instead the
    first ``2 * target`` rows are returned in their input order.
    """
    neg_mask = label == 0
    pos_mask = label == 1
    n_neg = np.count_nonzero(neg_mask)
    n_pos = np.count_nonzero(pos_mask)
    per_label = n_neg if n_neg < n_pos else n_pos

    if random_subsample:
        head = 2 * per_label
        return features[:head], label[:head], group[:head]

    def _pick(values):
        # Masks are fixed before any output is built, so label can be sliced like the rest.
        return np.concatenate((values[neg_mask][:per_label], values[pos_mask][:per_label]), axis=0)

    return _pick(features), _pick(label), _pick(group)

def balance_dataset_label_count(features, label, group, random_subsample=False):
    """Balance label counts separately inside group 1 and group 0.

    Each group's rows are passed through ``balance_dataset_label_count_per_group``.
    The group-1 output precedes the group-0 output. Rows whose group is neither
    0 nor 1 are dropped.
    """
    balanced = [
        balance_dataset_label_count_per_group(features[sel], label[sel], group[sel],
                                              random_subsample=random_subsample)
        for sel in (group == 1, group == 0)
    ]
    return tuple(np.concatenate(column, axis=0) for column in zip(*balanced))

def imbalance_dataset(features, label, group, pos_fraction=0.5, neg_fraction=0.25):
    """Create an artificially imbalanced copy of the data by subsampling group 1.

    Args:
        features, label, group: parallel numpy arrays.
        pos_fraction: fraction of group-1 positive rows (``label == 1``) to keep.
        neg_fraction: fraction of group-1 negative rows (``label == 0``) to keep.

    The kept rows are the first ``int(fraction * count)`` rows of each subset,
    in input order. Every group-0 row is kept.

    Returns:
        ``(features, label, group)``, ordered as group-1 positives, then
        group-1 negatives, then group 0.
    """
    in_group_one = group == 1
    one_pos = in_group_one & (label == 1)
    one_neg = in_group_one & (label == 0)
    in_group_zero = group == 0

    n_pos = int(pos_fraction * np.count_nonzero(one_pos))
    n_neg = int(neg_fraction * np.count_nonzero(one_neg))

    def _assemble(values):
        return np.concatenate(
            (values[one_pos][:n_pos], values[one_neg][:n_neg], values[in_group_zero]), axis=0)

    return _assemble(features), _assemble(label), _assemble(group)

def _load_acs_splits(dataset_type, protected_class, survey_year, states, train_valid_test_split):
    """Load a folktables ACS task and split it into train/test/valid numpy arrays.

    Returns ``(x_train, y_train, group_train, x_test, y_test, group_test,
    x_valid, y_valid, group_valid)``, the same order as ``load_celeba_dataset``.

    Raises:
        ValueError: unknown ``dataset_type`` or ``protected_class``, or split
            fractions that do not sum to 1.
    """
    task_classes = {'acsincome': ACSIncome, 'acsemployment': ACSEmployment}
    if dataset_type not in task_classes:
        raise ValueError(f"Unsupported ACS dataset_type: {dataset_type!r}")
    if protected_class not in ('sex', 'race'):
        raise ValueError(f"Unsupported protected_class: {protected_class!r}")
    # Compare with a tolerance: a float sum of valid fractions can miss 1.0 by an ulp.
    if not np.isclose(sum(train_valid_test_split), 1.0):
        raise ValueError(f"train_valid_test_split must sum to 1, got {train_valid_test_split!r}")
    valid_split, test_split = train_valid_test_split[1], train_valid_test_split[2]

    data_source = ACSDataSource(survey_year=survey_year, horizon='1-Year', survey='person')
    acs_data = data_source.get_data(states=states, download=True)

    if dataset_type == 'acsemployment':
        print("Dataset Type : ACS Employment")
    task_class = task_classes[dataset_type]

    # NOTE: this overwrites a private attribute of the module-level folktables task
    # object, so the setting persists for any later use of that object in this process.
    if protected_class == 'sex':
        task_class._group = 'SEX'
        features, label, group = task_class.df_to_numpy(acs_data)
        # Shift the SEX codes down by one so the groups start at 0
        # (presumably 1/2 -> 0/1; confirm against the ACS codebook).
        group = group - 1
    else:  # 'race'
        task_class._group = 'RAC1P'
        features, label, group = task_class.df_to_numpy(acs_data)
        # Binarize: RAC1P code 1 -> group 0, every code above 1 -> group 1.
        group[group > 1] = 2
        group = group - 1

    x_train, x_test, y_train, y_test, group_train, group_test = train_test_split(
        features, label, group, test_size=test_split, random_state=0)
    # valid_split is a fraction of the whole dataset, so rescale it to the remaining train part.
    x_train, x_valid, y_train, y_valid, group_train, group_valid = train_test_split(
        x_train, y_train, group_train, test_size=valid_split / (1 - test_split), random_state=0)

    return (x_train, y_train, group_train, x_test, y_test, group_test,
            x_valid, y_valid, group_valid)


def _to_tensor_dataset(x, y, group):
    """Wrap numpy features, labels and groups in a TensorDataset (torch.from_numpy shares memory)."""
    return TensorDataset(torch.from_numpy(x), torch.from_numpy(y), torch.from_numpy(group))


def get_folktables_dataset(train_valid_test_split=(0.7, 0.1, 0.2),
                           protected_class='sex', survey_year='2018', states=None,
                           shuffle_seed=0, fairbatch=False, model=None, dataloadershuffle=True,
                           dataset_type='acsincome', batch_size=128, custombatch=False,
                           drop_last=True, eval_drop_last=True):
    """Build train/valid/test DataLoaders for an ACS (folktables) task or CelebA features.

    Every loader yields ``(features, label, group)`` tuples. Features are
    min-max scaled using statistics fitted on the training split only.

    Args:
        train_valid_test_split: train/valid/test fractions; must sum to 1.
            Used only for the ACS datasets (CelebA uses its own partition file).
        protected_class: ``'sex'`` or ``'race'`` (ACS only).
        survey_year: ACS survey year.
        states: ACS state list; ``None`` means ``['CA']``.
        shuffle_seed: seed for the one-off shuffle of the training arrays.
        fairbatch: train with the ``FairBatch`` sampler (takes precedence over
            ``custombatch``).
        model: network passed to the FairBatch / CustomBatch samplers.
        dataloadershuffle: reshuffle each epoch in the plain training loader.
        dataset_type: ``'acsincome'``, ``'acsemployment'`` or ``'celeba'``.
        batch_size: batch size for every loader and sampler.
        custombatch: train with the loss-ordered ``CustomBatch`` sampler.
        drop_last: drop the incomplete final training batch.
        eval_drop_last: drop the incomplete final valid/test batch. The default
            (True) keeps the previous behaviour, under which up to
            ``batch_size - 1`` valid/test samples are never evaluated. Pass
            False to score every sample.

    Returns:
        ``(trainloader, validloader, testloader)``.

    Raises:
        ValueError: unsupported ``dataset_type`` or ``protected_class``, or a
            split that does not sum to 1.
    """
    states = ['CA'] if states is None else states

    if 'acs' in dataset_type:
        (x_train, y_train, group_train, x_test, y_test, group_test,
         x_valid, y_valid, group_valid) = _load_acs_splits(
            dataset_type, protected_class, survey_year, states, train_valid_test_split)
    elif dataset_type == 'celeba':
        print("Dataset Type : CelebA")
        (x_train, y_train, group_train, x_test, y_test, group_test,
         x_valid, y_valid, group_valid) = load_celeba_dataset()
    else:
        raise ValueError(f"Unsupported dataset_type: {dataset_type!r}")

    x_train, y_train, group_train = shuffle(x_train, y_train, group_train, random_state=shuffle_seed)

    # Optional re-balancing / imbalancing experiments (helpers defined in this module):
    # x_train, y_train, group_train = balance_dataset_label_count(x_train, y_train, group_train, random_subsample=False)
    # x_train, y_train, group_train = shuffle(x_train, y_train, group_train, random_state=shuffle_seed)
    # x_train, y_train, group_train = balance_dataset_group_count(x_train, y_train, group_train, random_subsample=False)
    # x_test, y_test, group_test = balance_dataset_label_count(x_test, y_test, group_test, random_subsample=False)
    # x_test, y_test, group_test = shuffle(x_test, y_test, group_test, random_state=shuffle_seed)
    # x_test, y_test, group_test = balance_dataset_group_count(x_test, y_test, group_test, random_subsample=False)
    # x_train, y_train, group_train = balance_dataset_group_and_label_count(x_train, y_train, group_train, random_subsample=False)

    # x_train, y_train, group_train = imbalance_dataset(x_train, y_train, group_train)
    # x_test, y_test, group_test = imbalance_dataset(x_test, y_test, group_test)
    # x_valid, y_valid, group_valid = imbalance_dataset(x_valid, y_valid, group_valid)

    # Optional ordering experiment (group-1 negatives, then group 0, then group-1 positives):
    # x_o_pos, y_o_pos, group_o_pos = x_train[group_train==1][y_train[group_train==1]==1], y_train[group_train==1][y_train[group_train==1]==1], group_train[group_train==1][y_train[group_train==1]==1]
    # x_o_neg, y_o_neg, group_o_neg = x_train[group_train==1][y_train[group_train==1]==0], y_train[group_train==1][y_train[group_train==1]==0], group_train[group_train==1][y_train[group_train==1]==0]
    # x_z, y_z, group_z = x_train[group_train==0], y_train[group_train==0], group_train[group_train==0]
    #
    # x_train = np.concatenate((x_o_neg, x_z, x_o_pos), axis=0)
    # y_train = np.concatenate((y_o_neg, y_z, y_o_pos), axis=0)
    # group_train = np.concatenate((group_o_neg, group_z, group_o_pos), axis=0)

    # TODO: a single global MinMaxScaler; some features may deserve separate treatment.
    datascaler = MinMaxScaler()
    datascaler.fit(x_train)

    x_train = datascaler.transform(x_train)
    x_valid = datascaler.transform(x_valid)
    x_test = datascaler.transform(x_test)

    train_dataset = _to_tensor_dataset(x_train, y_train, group_train)
    valid_dataset = _to_tensor_dataset(x_valid, y_valid, group_valid)
    test_dataset = _to_tensor_dataset(x_test, y_test, group_test)
    tensorx_train, tensory_train, tensorgroup_train = train_dataset.tensors

    if fairbatch:
        sampler = FairBatch(model, tensorx_train.cuda().float(), tensory_train.cuda().long(),
                            tensorgroup_train.cuda(), batch_size=batch_size,
                            alpha=0.005, target_fairness='eqodds', replacement=False, seed=0)
        # Presumably the sampler yields whole index batches (as CustomBatch does),
        # so the DataLoader keeps its default batch_size=1.
        trainloader = DataLoader(train_dataset, sampler=sampler)
    elif custombatch:
        # CustomBatch scores samples through an unshuffled loader so losses line up with indices.
        scoring_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, drop_last=drop_last)
        sampler = CustomBatch(model, tensorx_train.cuda().float(), tensory_train.cuda().long(),
                              tensorgroup_train.cuda(), scoring_loader, batch_size=batch_size, seed=0)
        trainloader = DataLoader(train_dataset, sampler=sampler)
    else:
        trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=dataloadershuffle,
                                 drop_last=drop_last)

    validloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, drop_last=eval_drop_last)
    testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=eval_drop_last)

    return trainloader, validloader, testloader
