import os
import torch
import random
import numpy as np
import torchvision
import glob
import pandas as pd
import pickle

from torchvision import transforms, datasets
from PIL import Image
from torchvision.datasets.imagenet import load_meta_file
from torchvision.datasets.utils import verify_str_arg
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import Dataset

from utils.utils import  set_random_seed

# NOTE(review): this statement only references ``set_random_seed``; it does
# not call it, so it has no effect at import time. The splits drawn in
# construct_datasets (train_test_split without random_state) and the draws in
# poisonData (random.sample) therefore depend on whatever global RNG state the
# caller has set up. Confirm whether a call such as ``set_random_seed(<seed>)``
# was intended here.
set_random_seed


def _split_cache_path(args):
    """Return the pickle path used to persist the data split for this run."""
    # NOTE: keyed on ``args.dataset`` (not the ``dataset`` argument of
    # construct_datasets), together with ``args.modname`` and ``args.moddir``.
    return os.path.join(args.moddir, args.dataset + '_' + args.modname + '_data.pickle')


def _save_split(args, payload):
    """Pickle ``payload`` to the split cache file for ``args``."""
    with open(_split_cache_path(args), 'wb') as f:
        pickle.dump(payload, f)


def _load_split(args):
    """Load the split previously pickled for ``args``."""
    # pickle.load can execute arbitrary code: only point ``args.moddir`` at
    # cache files this module wrote itself.
    with open(_split_cache_path(args), 'rb') as f:
        return pickle.load(f)


def _shard_bounds(n, shards):
    """Return ``(start, stop)`` bounds splitting ``n`` items into contiguous shards.

    Every shard holds ``n // shards`` items except the last, which also absorbs
    the remainder.

    Raises:
        ValueError: If ``shards`` is smaller than 1.
    """
    if shards < 1:
        raise ValueError(f'shards must be a positive integer, got {shards}')
    size = n // shards
    bounds = [(m * size, (m + 1) * size) for m in range(shards - 1)]
    bounds.append(((shards - 1) * size, n))
    return bounds


def _build_diabetes(args, data_path, load, split):
    """Build the Diabetes train/validation datasets (validation is returned as the test set)."""
    if not load:
        data = pd.read_csv(os.path.join(data_path, 'diabetes_binary.csv'))
        X = data.drop(['Diabetes_binary'], axis=1).to_numpy()
        y = data['Diabetes_binary'].to_numpy()

        # The held-out test portion of this first split is discarded; the
        # validation split below is what gets returned as the test set.
        X_train, _, y_train, _ = train_test_split(X, y, train_size=args.trainsize)

        # Use 20 % of (Train + Validation) set as Validation set
        X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2)

        # Fit the scaler on training data only, then apply it to validation.
        scaler = MinMaxScaler()
        X_train = scaler.fit_transform(X_train)
        X_val = scaler.transform(X_val)

        _save_split(args, {'X_train': X_train, 'X_val': X_val, 'y_train': y_train, 'y_val': y_val})
    else:
        cache = _load_split(args)
        X_train, X_val, y_train, y_val = cache['X_train'], cache['X_val'], cache['y_train'], cache['y_val']

    classes = ['no diabetes', 'prediabetes or diabetes']

    testset = PackData(X_val, y_val)
    testset.classes = classes

    if not split:
        trainset = PackData(X_train, y_train)
        trainset.classes = classes
    else:
        trainset = {}
        for m, (start, stop) in enumerate(_shard_bounds(len(X_train), args.shards)):
            shard = PackData(X_train[start:stop], y_train[start:stop])
            shard.classes = classes
            trainset[m] = shard
    return trainset, testset


def _build_cifar10(args, data_path, load, split):
    """Build CIFAR-10 train/test subsets, optionally sharding the training subset."""
    data_mean = (0.4914, 0.4822, 0.4465)
    data_std = (0.2023, 0.1994, 0.2010)
    # Training-time augmentation (RandomCrop(32, padding=4), RandomHorizontalFlip)
    # is disabled, so train and test use the same preprocessing.
    transform_train = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(data_mean, data_std),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(data_mean, data_std),
    ])

    train_data = CIFAR10(root=data_path, train=True, download=True, transform=transform_train)
    test_data = CIFAR10(root=data_path, train=False, download=True, transform=transform_test)

    if not load:
        # Both subsets keep an ``args.trainsize`` share of their source split
        # (a fraction if float, a count if int); the remainder is discarded.
        train_idx, _ = train_test_split(np.arange(len(train_data)), train_size=args.trainsize)
        test_idx, _ = train_test_split(np.arange(len(test_data)), train_size=args.trainsize)
        _save_split(args, {'train_idx': train_idx, 'test_idx': test_idx})
    else:
        cache = _load_split(args)
        train_idx, test_idx = cache['train_idx'], cache['test_idx']

    testset = torch.utils.data.Subset(test_data, test_idx)
    testset.classes = test_data.classes

    if not split:
        trainset = torch.utils.data.Subset(train_data, train_idx)
        trainset.classes = train_data.classes
    else:
        trainset = {}
        for m, (start, stop) in enumerate(_shard_bounds(len(train_idx), args.shards)):
            shard = torch.utils.data.Subset(train_data, train_idx[start:stop])
            shard.classes = train_data.classes
            trainset[m] = shard
    return trainset, testset


def construct_datasets(args, dataset, data_path, load=False, split=False):
    """Build (or reload) the train/test datasets for ``dataset``.

    With ``load=False`` a fresh random split is drawn (using the global NumPy
    RNG, since no ``random_state`` is passed). It is pickled to
    ``<args.moddir>/<args.dataset>_<args.modname>_data.pickle``. With
    ``load=True`` that pickle is read back, so the same split is reused.

    Args:
        args: Namespace providing ``trainsize`` (the ``train_size`` given to
            ``train_test_split``), ``moddir``, ``dataset`` and ``modname``
            (cache-file location) and, when ``split`` is true, ``shards``.
        dataset: ``'Diabetes'`` or ``'CIFAR10'``.
        data_path: Directory holding ``diabetes_binary.csv``, or the CIFAR-10
            download root.
        load: Reuse the pickled split instead of drawing a new one.
        split: Return the training data as ``args.shards`` contiguous shards.

    Returns:
        tuple: ``(trainset, testset)``. ``trainset`` is a dataset or, when
        ``split`` is true, a dict mapping shard index ``0 .. args.shards - 1``
        to a dataset. Every returned dataset carries a ``classes`` attribute.

    Raises:
        ValueError: If ``dataset`` is not supported, or ``split`` is true and
            ``args.shards`` is smaller than 1.
    """
    if dataset == 'Diabetes':
        return _build_diabetes(args, data_path, load, split)
    if dataset == 'CIFAR10':
        return _build_cifar10(args, data_path, load, split)
    raise ValueError(f'Unsupported dataset: {dataset!r}')


class PackData(Dataset):
    """In-memory dataset over feature/label arrays that also yields each sample's index.

    Args:
        X: Array-like of features, stored as a ``float32`` tensor. With
            ``torch.as_tensor``, a NumPy input that is already ``float32`` may
            share memory with the stored tensor.
        Y: Array-like of class labels with the same length as ``X``, stored as
            an ``int64`` tensor.

    Raises:
        ValueError: If ``X`` and ``Y`` have different lengths.
    """

    def __init__(self, X, Y):
        # torch.as_tensor replaces the legacy FloatTensor/LongTensor
        # constructors, which interpret a bare int as a *size* and return
        # uninitialised memory.
        self.X = torch.as_tensor(X, dtype=torch.float32)
        self.Y = torch.as_tensor(Y, dtype=torch.long)
        if len(self.X) != len(self.Y):
            raise ValueError(
                f'X and Y must have the same length, got {len(self.X)} and {len(self.Y)}'
            )

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(features, label, idx)`` for the sample at ``idx``."""
        return self.X[idx], self.Y[idx], idx


class CIFAR10(datasets.CIFAR10):
    """torchvision CIFAR-10 whose samples also report their dataset index.

    Only item access differs from the parent class: each sample comes back as
    ``(image, target, index)``, so callers can tell which stored image a batch
    element came from.
    """

    def __getitem__(self, index):
        """Return ``(image, target, index)`` for the stored sample at ``index``.

        ``target`` is the sample's class index, passed through
        ``target_transform`` when one is set.
        """
        raw_image = self.data[index]
        label = self.targets[index]

        # Wrap the raw array in a PIL image, matching what other torchvision
        # datasets hand to their transforms.
        sample = Image.fromarray(raw_image)
        if self.transform is not None:
            sample = self.transform(sample)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return sample, label, index

    def get_target(self, index):
        """Return ``(target, index)`` without loading or transforming the image.

        ``target`` is the class index, passed through ``target_transform``
        when one is set.
        """
        label = self.targets[index]
        transform = self.target_transform
        return (label if transform is None else transform(label)), index


def targetData(dataset, targetclass, targetids):
    """Return the ids of selected samples that belong to ``targetclass``.

    Scans ``dataset`` in order and collects the id (element ``[2]``) of every
    sample whose label (element ``[1]``) equals ``targetclass``. It then picks
    the entries at positions ``targetids`` from that list.

    Args:
        dataset: Sized, indexable dataset whose items are ``(data, label, id, ...)``
            tuples, e.g. ``PackData`` or a ``Subset`` of this module's ``CIFAR10``.
        targetclass: Label to match.
        targetids: Positions within the list of matching samples to return.

    Returns:
        list: ``[ids[k] for k in targetids]``, in the order of ``targetids``.

    Raises:
        IndexError: If a value in ``targetids`` is out of range for the
            matching samples.
    """
    ids = []
    for i in range(len(dataset)):
        # Fetch each sample once: indexing may decode an image and run its
        # transform pipeline, so indexing twice per sample doubles that cost.
        sample = dataset[i]
        if sample[1] == targetclass:
            ids.append(sample[2])

    return [ids[k] for k in targetids]


def poisonData(dataset, poisonclass, npoison, targeted=True):
    """Randomly choose ``npoison`` samples to poison and return them sorted.

    Draws use the global ``random`` module state.

    Args:
        dataset: Sized, indexable dataset whose items are ``(data, label, id, ...)``
            tuples.
        poisonclass: Label to restrict the draw to when ``targeted`` is true.
        npoison: Number of samples to select.
        targeted: If true, draw from the ids (element ``[2]``) of samples whose
            label (element ``[1]``) equals ``poisonclass``. Otherwise draw from
            all positions ``0 .. len(dataset) - 1``.

    Returns:
        list: The selected ids (targeted) or positions (untargeted), ascending.

    Raises:
        ValueError: If ``npoison`` is negative or exceeds the number of candidates.
    """
    if targeted:
        population = []
        for i in range(len(dataset)):
            # Fetch each sample once: indexing may run a transform pipeline.
            sample = dataset[i]
            if sample[1] == poisonclass:
                population.append(sample[2])
    else:
        # NOTE(review): this draws positions, whereas the targeted branch draws
        # the ids each sample reports. They coincide for PackData but not for a
        # Subset (whose items report the wrapped dataset's index) — confirm
        # which one callers expect.
        population = range(len(dataset))

    if not 0 <= npoison <= len(population):
        raise ValueError(
            f'cannot select {npoison} samples from {len(population)} candidates'
        )

    return sorted(random.sample(population, npoison))