import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch
import os

import albumentations as A
from albumentations.pytorch.transforms import ToTensor

def get_augmentations_v1(image_size=256, is_test=True):
    '''
    Build the albumentations pipeline used for evaluation or for training.

    Both pipelines finish with ``ToTensor`` configured with the ImageNet
    mean/std statistics. The training pipeline additionally chains cutout,
    rotation/flip, colour jitter, noise, blur, shift-scale-rotate and
    distortion steps in front of it.

    image_size: accepted but not referenced anywhere in this function.
    is_test:    when truthy the evaluation pipeline is returned, otherwise
                the training pipeline.

    Adapted from:
    https://www.kaggle.com/vishnus/a-simple-pytorch-starter-code-single-fold-93
    '''
    norm_stats = {'mean':[0.485, 0.456, 0.406], 'std':[0.229, 0.224, 0.225]}

    def pick_one(*candidates):
        # Every "choose a single transform" group is itself applied with p=0.5.
        return A.OneOf(list(candidates), p=0.5)

    # Both pipelines are always constructed; only one of them is handed back.
    augmenting_pipeline = A.Compose([
        A.Cutout(p=0.5),
        A.RandomRotate90(p=0.5),
        A.Flip(p=0.5),
        # colour jitter
        pick_one(
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2),
            A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=50, val_shift_limit=50),
        ),
        # additive noise
        pick_one(
            A.IAAAdditiveGaussianNoise(),
            A.GaussNoise(),
        ),
        # blur
        pick_one(
            A.MotionBlur(p=0.2),
            A.MedianBlur(blur_limit=3, p=0.1),
            A.Blur(blur_limit=3, p=0.1),
        ),
        A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.5),
        # geometric distortion
        pick_one(
            A.OpticalDistortion(p=0.3),
            A.GridDistortion(p=0.1),
            A.IAAPiecewiseAffine(p=0.3),
        ),
        ToTensor(normalize=norm_stats),
    ])

    plain_pipeline = A.Compose([ToTensor(normalize=norm_stats)])

    return plain_pipeline if is_test else augmenting_pipeline

class Melanoma(Dataset):
    r'''
        Melanoma image-classification dataset backed by ``<root>/melanoma/train.csv``.

        The rows of the CSV are partitioned by position: the first 80% form the
        training split, the next 10% the validation split and the final 10% the
        test split. Each instance serves exactly one split, selected with
        ``is_test`` / ``is_valid``. Items are ``(image, target, idx)`` tuples.

        Reference:
           - https://www.kaggle.com/cdeotte/jpeg-melanoma-256x256
           - https://www.kaggle.com/vishnus/a-simple-pytorch-starter-code-single-fold-93
           - https://www.kaggle.com/haqishen/1st-place-soluiton-code-small-ver
    '''
    def __init__(self, root, test_size=0.2, is_test=False, is_valid=False, transforms=None):
        '''
        Load the CSV, pick one split and prepare image paths, labels and transforms.

        Args:
            root: directory containing ``melanoma/train.csv`` and the
                ``melanoma/train/<image_name>.jpg`` images.
            test_size: currently not read by the constructor; the 80/10/10
                positional split is fixed.
            is_test: serve the test split; also forwarded to the transform factory.
            is_valid: serve the validation split (only consulted when
                ``is_test`` is false).
            transforms: optional callable invoked as ``transforms(is_test=is_test)``
                returning the pipeline applied in ``__getitem__``; when falsy,
                ``get_augmentations_v1`` is used.

        Raises:
            AssertionError: if ``melanoma/train.csv`` does not exist under ``root``.
            ValueError: if the selected split has no negative (``target == 0``) rows.
        '''
        assert os.path.isfile(root + '/melanoma/train.csv'), 'There is no train.csv in %s!'%root
        self.data = pd.read_csv(root + '/melanoma/train.csv')
        print('self.data shape:', self.data.shape)
        data_len = len(self.data)

        # Positional 80/10/10 split (the tfrecord-based split below is not used here).
        train_end = int(data_len * 0.8)
        valid_end = int(data_len * 0.9)
        self.train_df = self.data[:train_end]
        print('self.train_df shape:', self.train_df.shape)
        self.valid_df = self.data[train_end:valid_end]
        self.test_df = self.data[valid_end:]
        print('self.valid_df shape: ', self.valid_df.shape, 'self.test_df shape: ', self.test_df.shape)
        self.is_test = is_test
        self.is_valid = is_valid

        # is_test takes precedence over is_valid.
        if is_test:
            self.df = self.test_df.copy()
        elif is_valid:
            self.df = self.valid_df.copy()
        else:
            self.df = self.train_df.copy()

        self._num_images = len(self.df)
        self.value_counts_dict = self.df.target.value_counts().to_dict()
        print('value_counts_dict: ', self.value_counts_dict)

        # A class that does not occur in this split is simply absent from the
        # value_counts dict, so read the counts with a default instead of
        # indexing (which raised a bare KeyError).
        self.pos_len = self.value_counts_dict.get(1, 0)
        self.neg_len = self.value_counts_dict.get(0, 0)
        if not self.neg_len:
            raise ValueError(
                'The selected split contains no negative samples; '
                'the imbalance ratio (positives / negatives) is undefined.'
            )
        self.imratio = self.pos_len / self.neg_len
        print ('Found %s image in total, %s postive images, %s negative images.'%(self._num_images, self.pos_len, self.neg_len))

        # get path
        dir_name = 'melanoma/train'
        self._images_list = [f"{root}/{dir_name}/{img}.jpg" for img in self.df.image_name]
        self._labels_list = self.df.target.values.tolist()
        self.targets = np.array(self._labels_list).astype(np.int32)

        # NOTE(review): the factory only receives is_test, so the validation
        # split (is_test=False, is_valid=True) gets the same pipeline as the
        # training split -- confirm this is intended.
        if not transforms:
            self.transforms = get_augmentations_v1(is_test=is_test)
        else:
            self.transforms = transforms(is_test=is_test)

    @property
    def class_counts(self):
        '''Mapping ``target value -> row count`` for the selected split.'''
        return self.value_counts_dict

    @property
    def imbalance_ratio(self):
        '''Number of positive rows divided by number of negative rows in the selected split.'''
        return self.imratio

    @property
    def num_classes(self):
        '''Always 1.'''
        return 1

    def get_train_val_split(self, df, test_size=0.2):
        '''
        Split ``df`` into train/validation frames by its ``tfrecord`` column.

        Rows with ``tfrecord == -1`` are dropped first. With ``n`` distinct
        tfrecord values left, rows whose tfrecord id is in
        ``range(n - int(n * test_size))`` go to the train frame and all other
        rows go to the validation frame.

        Args:
            df: DataFrame with a ``tfrecord`` column.
            test_size: fraction of the distinct tfrecord ids held out for validation.

        Returns:
            ``(train_df, valid_df)``, each produced with ``reset_index()`` so the
            previous index is kept as an ``index`` column.
        '''
        print ('test set split is %s'%test_size)
        #Remove Duplicates
        df = df[df.tfrecord != -1].reset_index(drop=True)
        #We are splitting data based on triple stratified kernel provided here https://www.kaggle.com/c/siim-isic-melanoma-classification/discussion/165526
        num_tfrecords = len(df.tfrecord.unique())
        print('num_tfrecords: ', num_tfrecords)
        # Compute the train count explicitly: slicing with [:-k] yields an empty
        # list when k == 0, which used to send every row to the validation frame.
        num_valid_records = int(num_tfrecords * test_size)
        train_tf_records = list(range(num_tfrecords - num_valid_records))
        print('train_tf_records: ', train_tf_records)
        # isin always yields a boolean mask (also for an empty frame).
        split_cond = df.tfrecord.isin(train_tf_records)
        print('split_cond: ', split_cond)
        train_df = df[split_cond].reset_index()
        valid_df = df[~split_cond].reset_index()
        return train_df, valid_df

    def __len__(self):
        '''Number of rows in the selected split.'''
        return self.df.shape[0]

    def __getitem__(self,idx):
        '''
        Return ``(image, target, idx)`` for position ``idx`` of the selected split.

        ``image`` is the ``"image"`` entry of the transform pipeline's output and
        ``target`` is a float32 tensor holding the single label.
        '''
        img_path = self._images_list[idx]
        # Materialise the pixels inside the context manager so the underlying
        # file handle is released deterministically.
        with Image.open(img_path) as image:
            pixels = np.array(image)
        image = self.transforms(image=pixels)["image"]
        target = torch.tensor([self._labels_list[idx]],dtype=torch.float32)
        return image, target, idx
    
# if __name__ == '__main__':
#     trainSet = Melanoma(root='./data/', is_test=False, test_size=0.2)
#     validSet = Melanoma(root='./data/', is_test=False, is_valid=True, test_size=0.2)
#     testSet = Melanoma(root='./data/', is_test=True, test_size=0.2)
#     batch_size = 128
#     trainloader = DataLoader(dataset=trainSet, batch_size=batch_size, shuffle=True, num_workers=1)
#     validloader = DataLoader(dataset=validSet, batch_size=batch_size, shuffle=False, num_workers=1)
#     testloader = DataLoader(dataset=testSet, batch_size=batch_size, shuffle=False, num_workers=1)
#
#     for idx, data in enumerate(trainloader):
#         if idx < 1:
#             # print('data shape: ', data.shape)
#             print('data :', data)
#             for item in data:
#                 print("item shape: ", item.shape)
#         else:
#             break


    