import os
from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np
try:
    import pandas as pd
except ModuleNotFoundError:
    print('pandas not installed, NIH dataset wont work')
import torch
import torchvision
from PIL import Image
try:
    from skmultilearn.model_selection import iterative_train_test_split
except ModuleNotFoundError:
    print('skmultilearn not installed, NIH dataset wont work')

from torch.utils.data import Dataset, DataLoader

from src.utils.sysutils import get_cores_count

# Names of the 14 pathology classes, in the order used for the one-hot
# encoding of a sample's findings ('No Finding' is deliberately not a class).
NIH_CXR_PRED_LABEL = [
    'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration',
    'Mass', 'Nodule', 'Pneumonia', 'Pneumothorax',
    'Consolidation', 'Edema', 'Emphysema', 'Fibrosis',
    'Pleural_Thickening', 'Hernia',
]

# Reverse lookup: class name -> position of that class in NIH_CXR_PRED_LABEL.
NIH_CXR_PRED_LABEL_TO_INDEX_MAP = {
    class_name: position for position, class_name in enumerate(NIH_CXR_PRED_LABEL)
}


def map_label_to_one_hot(labels):
    """Encode a '|'-separated findings string as a multi-hot vector.

    Args:
        labels: Findings joined with '|', e.g. ``'Mass|Nodule'``. The special
            value ``'No Finding'`` contributes nothing to the encoding.

    Returns:
        A float numpy array with one entry per class in NIH_CXR_PRED_LABEL,
        set to 1 at the position of every finding present and 0 elsewhere.

    Raises:
        KeyError: If a finding other than 'No Finding' is not a known class.
    """
    findings = [name for name in labels.split('|') if name != 'No Finding']
    encoding = np.zeros(len(NIH_CXR_PRED_LABEL))
    for name in findings:
        encoding[NIH_CXR_PRED_LABEL_TO_INDEX_MAP[name]] = 1
    return encoding


class NihCxrDataframeDataset(torch.utils.data.Dataset):
    """Torch dataset that serves (image, label) pairs described by a dataframe.

    Each row of ``df`` is read through its 'Image Index' entry (the image file
    name inside ``<dataset_dir>/images``) and its 'Finding Labels' entry (a
    '|'-separated findings string).
    """

    def __init__(self, df, dataset_dir, transform=None):
        """Store the sample table, the dataset location and the optional transform.

        Args:
            df: Table of samples, indexed positionally via ``df.iloc``.
            dataset_dir: Dataset root; images are looked up in its 'images' subfolder.
            transform: Optional callable applied to the loaded RGB image.
        """
        self.df = df
        self.dataset_dir = dataset_dir
        self.transform = transform
        self.path_to_images = os.path.join(dataset_dir, 'images')

    def __len__(self):
        """Return the number of samples, i.e. the length of the table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the sample at position ``idx``.

        Returns:
            A tuple ``(image, label)``: the RGB image (after ``transform`` when
            one is set) and the encoded findings as a ``torch.FloatTensor``.

        Raises:
            IOError: Re-raised after printing the offending path when the
                image cannot be opened or converted.
        """
        row = self.df.iloc[idx]
        image_path = os.path.join(self.path_to_images, row['Image Index'])
        try:
            image = Image.open(image_path).convert('RGB')
        except IOError:
            # Report which file is broken, then let the caller see the error.
            print(f"failed to load image {image_path}")
            raise

        target = torch.FloatTensor(map_label_to_one_hot(row['Finding Labels']))
        if self.transform:
            image = self.transform(image)
        return image, target


def map_one_hot_to_label(one_hot_encoding):
    """Decode a multi-hot vector back into a '|'-separated findings string.

    Args:
        one_hot_encoding: 1-D array-like with one entry per class in
            NIH_CXR_PRED_LABEL; an entry equal to 1 marks that class as present.

    Returns:
        ``'No Finding'`` when every entry is zero, otherwise the names of the
        classes whose entry equals 1, joined with '|' in class order.
    """
    encoding = np.asarray(one_hot_encoding)
    if np.count_nonzero(encoding) == 0:
        return 'No Finding'
    label_indices = np.where(encoding == 1)[0]
    # Look names up by the positions that are set; iterating over
    # range(len(label_indices)) would always yield the first k class names.
    return '|'.join(NIH_CXR_PRED_LABEL[j] for j in label_indices)


class NihCxr:
    """NIH chest X-ray dataset with patient-wise train/validation split.

    Reads ``Data_Entry_2017.csv``, ``train_val_list.txt`` and ``test_list.txt``
    from the dataset directory, splits the train/val images into a training and
    a validation set so that all images of one patient land in the same set,
    and exposes torch datasets, data loaders and class-balancing weights.
    """

    def __init__(self,
                 train_data_args,
                 val_data_args,
                 dataset_args,
                 split_ratio=0.9):
        """Load the CSV metadata and build the train/validation/test datasets.

        Args:
            train_data_args: Mapping with 'batch_size' and 'shuffle', used by
                the train and validation loaders.
            val_data_args: Mapping with 'batch_size' and 'shuffle', used by the
                test loader.
            dataset_args: Mapping; its optional 'path_to_dataset' entry
                overrides the default dataset directory ('./data/<class name>').
            split_ratio: Fraction of patients assigned to the training set.
        """
        self.cpu_count = get_cores_count()
        self.train_data_args = train_data_args
        self.val_data_args = val_data_args

        self.dataset_dir = dataset_args.get('path_to_dataset', './data/' + self.__class__.__name__)

        # RGB Order
        # use imagenet mean,std for normalization
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]

        self.demean = [-m / s for m, s in zip(mean, std)]
        self.destd = [1 / s for s in std]

        self.normalize_transform = torchvision.transforms.Compose(
            [torchvision.transforms.Resize(224),
             torchvision.transforms.CenterCrop(224),
             torchvision.transforms.ToTensor(),
             torchvision.transforms.Normalize(mean, std)])

        # Normalization transform does (x - mean) / std
        # To denormalize use mean* = (-mean/std) and std* = (1/std)
        self.denormalization_transform = torchvision.transforms.Normalize(self.demean, self.destd, inplace=False)

        self.df = pd.read_csv(os.path.join(self.dataset_dir, "Data_Entry_2017.csv"))
        self.train_val_df_indices = pd.read_csv(os.path.join(self.dataset_dir, 'train_val_list.txt'),
                                                header=None, names=['Image Index'])
        self.test_df_indices = pd.read_csv(os.path.join(self.dataset_dir, 'test_list.txt'),
                                           header=None, names=['Image Index'])

        # Split train data into training and cross validation dataset using split ratio
        self.split_ratio = split_ratio
        self.trainset, self.validationset, self.testset = self.preprocess_csv()

    @staticmethod
    def _cluster_labels_by_patient(train_val_df, patient_ids):
        """Build one multi-hot row per patient from the union of their findings."""
        label_to_column = {label: column for column, label in enumerate(NIH_CXR_PRED_LABEL)}
        # Extra last column records whether any image of the patient was 'No Finding'.
        label_to_column['No Finding'] = len(NIH_CXR_PRED_LABEL)

        # np.int was removed from NumPy; the builtin int is the same dtype.
        y = np.zeros(shape=(len(patient_ids), len(NIH_CXR_PRED_LABEL) + 1), dtype=int)
        # Join each patient's findings in a single pass instead of filtering
        # the whole frame once per patient.
        joined_labels = train_val_df.groupby('Patient Id')['Finding Labels'].agg('|'.join)
        for row, patient_id in enumerate(patient_ids):
            for label in set(joined_labels.loc[patient_id].split('|')):
                y[row, label_to_column[label]] = 1
        return y

    def preprocess_csv(self):
        """Split the metadata patient-wise and wrap each split in a torch dataset.

        Side effects: drops all but the first two columns of ``self.df`` in
        place, sets ``self.train_df``, ``self.val_df`` and ``self.test_df``, and
        writes ``ClusteredLabelsPatientWise.npy`` into the dataset directory
        when that cache file does not exist yet.

        Returns:
            Tuple ``(trainset, validationset, testset)`` of
            NihCxrDataframeDataset instances.
        """
        self.df.drop(self.df.columns[2:], axis=1, inplace=True)
        in_train_val = self.df['Image Index'].isin(self.train_val_df_indices['Image Index'])
        # Explicit copy: a column is added below and must not target a view of self.df.
        train_val_df = self.df[in_train_val].copy()

        # Since samples of each patient should go to one set, each patient is clustered into a single sample.
        # To keep labels properly stratified, a combined set of labels is generated for each patient.
        # [The fact that patients may have different numbers of images is ignored; on average this balances out.]
        # The patient id is the part of the image file name before the first '_'.
        train_val_df['Patient Id'] = train_val_df['Image Index'].str.split('_').str[0]
        patient_ids = train_val_df['Patient Id'].unique()

        cache_path = os.path.join(self.dataset_dir, 'ClusteredLabelsPatientWise.npy')
        if os.path.isfile(cache_path):
            # NOTE(review): the cache is reused without checking that it matches
            # the current patient list - delete the file if the CSVs change.
            y = np.load(cache_path)
        else:
            y = self._cluster_labels_by_patient(train_val_df, patient_ids)
            # Written once, after all patients have been processed.
            np.save(cache_path, y)

        x = np.reshape(patient_ids, (-1, 1))  # iterative_train_test_split requires x to be array of array
        x_train, y_train, x_val, y_val = iterative_train_test_split(x, y, test_size=1 - self.split_ratio)
        # ravel keeps a list even when only one patient ends up in the train split.
        train_patients = np.asarray(x_train).ravel().tolist()

        train_indices = train_val_df['Patient Id'].isin(train_patients)
        val_indices = ~train_indices

        self.train_df = train_val_df[train_indices]
        self.val_df = train_val_df[val_indices]
        self.test_df = self.df[self.df['Image Index'].isin(self.test_df_indices['Image Index'])]

        # Create torch datasets.
        trainset = NihCxrDataframeDataset(df=self.train_df,
                                          dataset_dir=self.dataset_dir,
                                          transform=self.normalize_transform)
        validationset = NihCxrDataframeDataset(df=self.val_df,
                                               dataset_dir=self.dataset_dir,
                                               transform=self.normalize_transform)
        testset = NihCxrDataframeDataset(df=self.test_df,
                                         dataset_dir=self.dataset_dir,
                                         transform=self.normalize_transform)
        return trainset, validationset, testset

    def _build_loader(self, dataset, loader_args) -> DataLoader:
        """Create a pinned-memory DataLoader configured from ``loader_args``."""
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=loader_args['batch_size'],
                                           shuffle=loader_args['shuffle'],
                                           pin_memory=True,
                                           num_workers=get_cores_count())

    @property
    def train_dataloader(self) -> DataLoader:
        """A new DataLoader over the training set, configured by ``train_data_args``."""
        return self._build_loader(self.trainset, self.train_data_args)

    @property
    def validation_dataloader(self) -> DataLoader:
        """A new DataLoader over the validation set, configured by ``train_data_args``."""
        # NOTE(review): this uses train_data_args (including its shuffle flag),
        # while test_dataloader uses val_data_args - confirm this is intended.
        return self._build_loader(self.validationset, self.train_data_args)

    @property
    def test_dataloader(self):
        """A new DataLoader over the test set, configured by ``val_data_args``."""
        return self._build_loader(self.testset, self.val_data_args)

    def imshow(self, img):
        """Denormalize a CHW image tensor and display it with matplotlib."""
        # clamp to get rid of numerical errors
        img = torch.clamp(self.denormalize(img), 0.0, 1.0)  # denormalize
        npimg = img.numpy()
        # matplotlib expects channels last.
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
        plt.show()

    def debug(self):
        """Show one training batch as an image grid and print its labels."""
        # get some random training images
        data_iter = iter(self.train_dataloader)
        # Built-in next(): DataLoader iterators do not provide a .next() method.
        images, labels = next(data_iter)

        # show images
        self.imshow(torchvision.utils.make_grid(images))

        # print labels
        for j in range(len(images)):
            pprint(map_one_hot_to_label(labels[j]))

    def denormalize(self, x):
        """Undo the normalization applied by ``normalize_transform``."""
        return self.denormalization_transform(x)

    @property
    def train_dataset_size(self):
        """Number of samples in the training set."""
        return len(self.trainset)

    @property
    def val_dataset_size(self):
        """Number of samples in the validation set."""
        return len(self.validationset)

    @property
    def test_dataset_size(self):
        """Number of samples in the test set."""
        return len(self.testset)

    @property
    def classes(self):
        """List of class names, in encoding order."""
        return NIH_CXR_PRED_LABEL

    def pos_neg_balance_weights(self):
        """Per-class ratio of negative to positive samples in the training set.

        A sample counts as positive for a class when its 'Finding Labels'
        string contains the class name. A class without any positive sample
        gets weight 1, mirroring ``pos_neg_balance_weights_in_batch``.

        Returns:
            ``torch.Tensor`` with one weight per class.
        """
        pos_neg_weights = []
        finding_labels = self.train_df['Finding Labels']
        num_samples = len(self.train_df)

        # ToDo - Confirm weighting strategy
        for label in self.classes:
            # Rows whose label string contains the class name are the positives.
            # regex=False: class names are literal text, not patterns.
            num_positives = int(finding_labels.str.contains(label, regex=False).sum())
            num_negatives = num_samples - num_positives
            if num_positives == 0:
                pos_neg_weights.append(1.0)
            else:
                pos_neg_weights.append(num_negatives / num_positives)

        return torch.Tensor(pos_neg_weights)

    def pos_neg_balance_weights_in_batch(self, labels):
        """Per-class ratio of negative to positive samples within one batch.

        Args:
            labels: 2-D batch of multi-hot labels, indexed as
                ``labels[:, class_index]`` (rows are samples).

        Returns:
            ``torch.Tensor`` with one weight per class; a class with no
            positive sample in the batch gets weight 1.
        """
        pos_neg_weights = []
        batch_size = labels.shape[0]

        for label_index, _ in enumerate(self.classes):
            # Plain floats keep the result independent of the input container type.
            num_positives = float(labels[:, label_index].sum())
            num_negatives = batch_size - num_positives
            if num_positives == 0:
                beta_p = 1.0  # ToDo Discuss with Ashkan or try with values: {1, num_negatives}
            else:
                beta_p = num_negatives / num_positives
            pos_neg_weights.append(beta_p)
        return torch.Tensor(pos_neg_weights)


if __name__ == '__main__':
    # Manual smoke test: build the dataset with default paths, show one batch
    # and print both the dataset-wide and the per-batch balancing weights.
    dataset_args = dict(
    )

    train_data_args = dict(
        batch_size=8,
        shuffle=True,
    )

    val_data_args = dict(
        batch_size=train_data_args['batch_size'] * 4,
        shuffle=False,
        validate_step_size=1,
    )
    dataset = NihCxr(train_data_args, val_data_args, dataset_args, split_ratio=7.0 / 8)
    dataset.debug()
    print(dataset.pos_neg_balance_weights())
    train_dataloader = dataset.train_dataloader
    # Built-in next(): DataLoader iterators do not provide a .next() method.
    data, labels = next(iter(train_dataloader))
    print(labels)
    print(dataset.pos_neg_balance_weights_in_batch(labels))
