from utils import *
from .bases import BaseSet
from scipy.io import mmread
from torchvision.transforms import ToTensor, ToPILImage


# Per-dataset configuration. Each dataset is stored in a sub-directory that
# carries the dataset's own name, so the mapping is generated from the names.
DATA_INFO = {
    dataset: {"dataset_location": dataset}
    for dataset in (
        "DDSM",
        "CheXpert",
        "ISIC2019",
        "APTOS2019",
        "Camelyon",
        "BreakHis",
        "MURA",
    )
}


class CheXpert(BaseSet):
    """CheXpert classification dataset built from the rows of a CSV file.

    The CSV is sorted by its ``Path`` column and cut by row into a training
    part and a held-out part (about 20% of the rows plus a small offset); the
    held-out part is halved into validation and test. Missing and uncertain
    (-1) labels are replaced by 0.
    """

    img_channels = 1
    is_multiclass = False
    task = 'classification'
    mean = (0.503,)
    std = (0.292,)
    knn_nhood = 200
    target_metric = 'roc_auc'
    int_to_labels = {0: 'No Finding',
                     1: 'Enlarged Cardiomediastinum',
                     2: 'Cardiomegaly',
                     3: 'Lung Opacity',
                     4: 'Lung Lesion',
                     5: 'Edema',
                     6: 'Consolidation',
                     7: 'Pneumonia',
                     8: 'Atelectasis',
                     9: 'Pneumothorax',
                     10: 'Pleural Effusion',
                     11: 'Pleural Other',
                     12: 'Fracture',
                     13: 'Support Devices'
                     }
    n_classes = len(int_to_labels)
    labels_to_int = {val: key for key, val in int_to_labels.items()}

    def __init__(self, dataset_params, mode='train'):
        """Load the sample list for ``mode`` and build the transforms.

        Args:
            dataset_params: dict of attributes copied onto the instance via
                ``attr_from_dict`` (must provide at least ``data_location``;
                ``self.name`` is presumably set there or by the base class).
            mode: one of 'train', 'val', 'eval' or 'test'.
        """
        self.attr_from_dict(dataset_params)
        self.dataset_location = DATA_INFO["CheXpert"]["dataset_location"]
        self.root_dir = os.path.join(self.data_location, self.dataset_location)
        self.mode = mode
        self.data = self.get_dataset()
        self.transform, self.resizing = self.get_transforms()

    def get_data_as_list(self, data_loc):
        """Return the samples of the current mode as a list of dicts.

        Args:
            data_loc: path of the CSV file describing the images.

        Returns:
            A list of ``{'img_path', 'label', 'dataset'}`` dicts. ``label`` is
            a multi-hot list, or a single class index when ``is_multiclass``
            is set (in which case only rows with exactly one positive label
            are kept).

        Raises:
            ValueError: if ``self.mode`` is not a known mode.
        """
        testval_fraction = 0.2
        datainfo = pd.read_csv(data_loc, engine='python')
        datainfo = datainfo.sort_values(by=['Path'])
        # the 15 is just to make a better patient split
        testval_size = int(len(datainfo) * testval_fraction) + 15

        # NOTE(review): ``.loc`` slices by index label and includes both end
        # points, so the row labelled ``testval_size`` ends up in the train
        # slice AND in the held-out slice (the same holds for the val/test
        # boundary below). Kept unchanged so existing splits stay
        # reproducible -- confirm whether this one-row overlap is acceptable.
        if self.mode == 'train':
            data = datainfo.loc[testval_size:]
        else:
            data = datainfo.loc[:testval_size]
            half = int(len(data) / 2)
            if self.mode in ['val', 'eval']:
                data = data.loc[:half]
            elif self.mode in ['test']:
                data = data.loc[half:]
            else:
                raise ValueError(f"mode {self.mode} not understood")

        # we use the U-Zeroes model i.e. we replace NaNs and -1s with 0s
        labels = data.iloc[:, -self.n_classes:].fillna(0).replace(-1, 0).values.tolist()
        # every label vector that has only zeros becomes 'No Finding'
        no_finding = self.labels_to_int['No Finding']
        for label in labels:
            if sum(label) == 0:
                label[no_finding] = 1

        img_paths = [os.path.join(self.root_dir, *img_path.split('/')[1:])
                     for img_path in data['Path'].values.tolist()]

        if self.is_multiclass:
            # Keep only rows with exactly one positive label and turn them into
            # class indices. Paths are filtered together with the labels so
            # that every image stays paired with its own label.
            single = [(img_path, label.index(1))
                      for img_path, label in zip(img_paths, labels)
                      if sum(label) == 1]
            img_paths = [img_path for img_path, _ in single]
            labels = [label for _, label in single]

        return [{'img_path': img_path, 'label': label, 'dataset': self.name}
                for img_path, label in zip(img_paths, labels)]

    def get_dataset(self):
        """Pick the CSV for the current mode and return its sample list.

        All known modes read ``train.csv``. Any other mode selects
        ``valid.csv``, but ``get_data_as_list`` then rejects that mode with a
        ``ValueError``.
        """
        if self.mode in ['train', 'val', 'eval', 'test']:
            self.df_path = os.path.join(self.root_dir, 'train.csv')
        else:
            self.df_path = os.path.join(self.root_dir, 'valid.csv')
        return self.get_data_as_list(self.df_path)
    
    
class DDSM(BaseSet):
    """DDSM classification dataset, read either as patches or as raw images.

    The sample lists come from text files (``train.txt`` / ``val.txt`` /
    ``test.txt``) located in a mode-specific directory under ``root_dir``.
    """

    img_channels = 1
    is_multiclass = True
    task = 'classification'
    knn_nhood = 200
    target_metric = 'roc_auc'

    def __init__(self, dataset_params, mode='train', n_class=2, is_patch=False):
        """Load the sample list for ``mode`` and build the transforms.

        Args:
            dataset_params: dict of attributes copied onto the instance via
                ``attr_from_dict`` (must provide at least ``data_location``).
            mode: 'train', 'val', 'eval' or 'test'.
            n_class: number of classes; 3 selects Normal/Benign/Cancer, any
                other value selects the two-class naming and must then be 2.
            is_patch: read the patch lists instead of the raw image lists.
        """
        self.attr_from_dict(dataset_params)
        self.mode = mode
        self.n_class = n_class
        self.is_patch = is_patch
        self.export_labels_as_int()
        self.init_stats()
        self.n_classes = len(self.int_to_labels)
        assert self.n_classes == self.n_class, \
            f"n_class={self.n_class} is not supported (expected 2 or 3)"
        self.dataset_location = DATA_INFO["DDSM"]["dataset_location"]
        self.root_dir = os.path.join(self.data_location, self.dataset_location)
        self.label_mode = '{}class'.format(self.n_class)

        self.data = self.get_dataset()
        self.transform, self.resizing = self.get_transforms()

    def get_data_as_list(self, data_loc):
        """Parse the list file at ``data_loc`` into sample dicts.

        Patch lists hold ``<path> <label>`` per line; raw lists hold only a
        path whose file name starts with the class name.

        Returns:
            A list of ``{'img_path', 'label', 'dataset'}`` dicts.
        """
        data_list = []
        data = pd.read_csv(data_loc, sep=" ", header=None, engine='python')
        if self.is_patch:
            data.columns = ["img_path", "label"]
            for img_path, label in zip(data['img_path'], data['label']):
                # drop the leading path component before re-rooting the path
                img_path = os.path.join(*img_path.split("/")[1:])
                img_path = os.path.join(self.root_dir, 'ddsm_patches', img_path)
                data_list.append({'img_path': img_path, 'label': label, 'dataset': self.name})
        else:
            data.columns = ["img_path"]
            txt_to_lbl = {'normal': 0, 'benign': 1, 'cancer': 2}
            for img_path in data['img_path']:
                img_path = os.path.join(self.root_dir, 'ddsm_raw', img_path)
                # the class name is the file-name prefix before the first "_"
                label = txt_to_lbl[os.path.basename(img_path).split("_")[0]]
                # NOTE(review): in the 2-class setting 'benign' already maps to
                # 1, which int_to_labels names 'Cancer' -- confirm this grouping
                # is intended.
                if self.n_classes == 2 and label > 1:
                    label = 1
                if not self.is_multiclass:
                    label = [float(label)]
                data_list.append({'img_path': img_path, 'label': label, 'dataset': self.name})

        return data_list

    def get_dataset(self):
        """Resolve the list file of the current mode and return its samples.

        Raises:
            ValueError: if ``self.mode`` is not 'train', 'val', 'eval' or
                'test'. Without this check the directory itself would be
                handed to ``pd.read_csv``.
        """
        if self.is_patch:
            list_dir = os.path.join(self.root_dir, 'ddsm_labels', self.label_mode)
        else:
            list_dir = os.path.join(self.root_dir, 'ddsm_raw_image_lists')

        if self.mode == 'train':
            list_file = 'train.txt'
        elif self.mode in ['val', 'eval']:
            list_file = 'val.txt'
        elif self.mode == 'test':
            list_file = 'test.txt'
        else:
            raise ValueError(f"mode {self.mode} not understood")

        self.df_path = os.path.join(list_dir, list_file)
        return self.get_data_as_list(self.df_path)

    def init_stats(self):
        """Set the normalisation statistics for patch or raw images."""
        if self.is_patch:
            self.mean = (0.44,)
            self.std = (0.25,)
        else:
            self.mean = (0.286,)
            self.std = (0.267,)

    def export_labels_as_int(self):
        """Define ``int_to_labels`` / ``labels_to_int`` from ``n_class``."""
        if self.n_class == 3:
            self.int_to_labels = {
                0: 'Normal',
                1: 'Benign',
                2: 'Cancer'
            }
        else:
            self.int_to_labels = {
                0: 'Normal',
                1: 'Cancer'
            }
        self.labels_to_int = {val: key for key, val in self.int_to_labels.items()}
        
    
class ISIC2019(BaseSet):
    """ISIC 2019 classification dataset read from the training ground truth.

    The ground-truth rows are divided into train / val / test through the
    index lists returned by ``get_validation_ids``.
    """

    img_channels = 3
    is_multiclass = True
    task = 'classification'
    mean = [0.66776717, 0.52960888, 0.52434725]
    std = [0.22381877, 0.20363036, 0.21538623]
    knn_nhood = 200
    int_to_labels = {
        0: 'Melanoma',
        1: 'Melanocytic nevus',
        2: 'Basal cell carcinoma',
        3: 'Actinic keratosis',
        4: 'Benign keratosis',
        5: 'Dermatofibroma',
        6: 'Vascular lesion',
        7: 'Squamous cell carcinoma'
    }
    target_metric = 'recall'
    n_classes = len(int_to_labels)
    labels_to_int = {val: key for key, val in int_to_labels.items()}

    def __init__(self, dataset_params, mode='train'):
        """Load the sample list for ``mode`` and build the transforms.

        Args:
            dataset_params: dict of attributes copied onto the instance via
                ``attr_from_dict`` (must provide at least ``data_location``).
            mode: 'train', 'val' / 'eval'; any other value selects the test
                split.
        """
        self.attr_from_dict(dataset_params)
        self.dataset_location = DATA_INFO["ISIC2019"]["dataset_location"]
        self.root_dir = os.path.join(self.data_location, self.dataset_location)
        self.mode = mode
        self.data = self.get_data_as_list()
        self.transform, self.resizing = self.get_transforms()

    def get_data_as_list(self):
        """Return the samples of the current mode as a list of dicts.

        Returns:
            A list of ``{'img_path', 'label', 'dataset'}`` dicts where
            ``label`` is the index of the non-zero ground-truth column.

        Raises:
            ValueError: if some ground-truth row does not have exactly one
                non-zero class column (labels could not be paired with
                images reliably in that case).
        """
        datainfo = pd.read_csv(os.path.join(self.root_dir, 'ISIC_2019_Training_GroundTruth.csv'),
                               engine='python')
        img_names = datainfo.values[:, 0].tolist()

        # Column 0 is the image name, the remaining columns are class flags.
        rows, cols = datainfo.values[:, 1:].nonzero()
        if rows.tolist() != list(range(len(img_names))):
            # nonzero() yields one entry per non-zero cell; anything other than
            # exactly one entry per row would shift labels against the images.
            raise ValueError("ISIC2019 ground truth must have exactly one "
                             "non-zero class column per row")
        labellist = cols.tolist()

        img_names = [os.path.join(self.root_dir, 'train', imname + '.jpg') for imname in img_names]
        dataframe = pd.DataFrame(list(zip(img_names, labellist)),
                                 columns=['img_path', 'label'])

        # presumably get_validation_ids stores/reuses the split at json_path
        # -- confirm in the base class
        val_id_json = os.path.join(self.root_dir, 'val_ids.json')
        train_ids, test_val_ids = self.get_validation_ids(total_size=len(dataframe), val_size=0.2,
                                                          json_path=val_id_json,
                                                          dataset_name=self.name)
        # the held-out ids are halved into validation and test
        half = int(len(test_val_ids) / 2)
        val_ids = test_val_ids[:half]
        test_ids = test_val_ids[half:]

        if self.mode == 'train':
            data = dataframe.loc[train_ids, :]
        elif self.mode in ['val', 'eval']:
            data = dataframe.loc[val_ids, :]
        else:
            data = dataframe.loc[test_ids, :]

        labels = data['label'].values.tolist()
        img_paths = data['img_path'].values.tolist()
        return [{'img_path': img_path, 'label': label, 'dataset': self.name}
                for img_path, label in zip(img_paths, labels)]
    
    
class APTOS2019(BaseSet):
    """APTOS 2019 classification dataset read from ``train.csv``.

    The CSV rows are divided into train / val / test through the index lists
    returned by ``get_validation_ids``.
    """

    img_channels = 3
    is_multiclass = True
    task = 'classification'
    mean = (0.415, 0.221, 0.073)
    std = (0.275, 0.150, 0.081)
    int_to_labels = {
        0: 'No DR',
        1: 'Mild',
        2: 'Moderate',
        3: 'Severe',
        4: 'Proliferative DR'
    }
    target_metric = 'quadratic_kappa'
    knn_nhood = 200
    n_classes = len(int_to_labels)
    labels_to_int = {val: key for key, val in int_to_labels.items()}

    def __init__(self, dataset_params, mode='train'):
        """Load the sample list for ``mode`` and build the transforms.

        Args:
            dataset_params: dict of attributes copied onto the instance via
                ``attr_from_dict`` (must provide at least ``data_location``).
            mode: 'train', 'val' / 'eval'; any other value selects the test
                split.
        """
        self.attr_from_dict(dataset_params)
        self.dataset_location = DATA_INFO["APTOS2019"]["dataset_location"]
        self.root_dir = os.path.join(self.data_location, self.dataset_location)
        self.mode = mode
        self.data = self.get_data_as_list()
        self.transform, self.resizing = self.get_transforms()

    def get_data_as_list(self):
        """Return the samples of the current mode as a list of dicts.

        Returns:
            A list of ``{'img_path', 'label', 'dataset'}`` dicts where
            ``label`` is the value of the ``diagnosis`` column.
        """
        csv_info = pd.read_csv(os.path.join(self.root_dir, 'train.csv'), engine='python')
        diagnoses = csv_info.diagnosis.tolist()
        image_dir = os.path.join(self.root_dir, 'train_images')
        image_files = [os.path.join(image_dir, code + '.png')
                       for code in csv_info.id_code.tolist()]
        frame = pd.DataFrame({'img_path': image_files, 'label': diagnoses})

        train_ids, test_val_ids = self.get_validation_ids(
            total_size=len(frame),
            val_size=0.3,
            json_path=os.path.join(self.root_dir, 'val_ids.json'),
            dataset_name=self.name,
        )

        # the held-out ids are halved: first half validation, second half test
        half = int(len(test_val_ids) / 2)
        if self.mode == 'train':
            selected_ids = train_ids
        elif self.mode in ['val', 'eval']:
            selected_ids = test_val_ids[:half]
        else:
            selected_ids = test_val_ids[half:]

        subset = frame.loc[selected_ids, :]
        samples = zip(subset['img_path'].values.tolist(),
                      subset['label'].values.tolist())
        return [{'img_path': path, 'label': diagnosis, 'dataset': self.name}
                for path, diagnosis in samples]

    
class BreakHis(BaseSet):
    """BreakHis classification dataset with a patient-level split.

    Images are discovered on disk under ``<root>/<class>/SOB/<type>/<patient>``
    and whole patients are assigned to train / val / test, separately for the
    benign and the malignant class.
    """

    img_channels = 3
    is_multiclass = True
    task = 'classification'
    mean = (0.788, 0.627, 0.767)
    std = (0.129, 0.178, 0.114)
    int_to_labels = {0: "benign",
                     1: "malignant"}
    target_metric = 'accuracy'
    knn_nhood = 200
    n_classes = len(int_to_labels)
    labels_to_int = {val: key for key, val in int_to_labels.items()}

    def __init__(self, dataset_params, mode='train'):
        """Load the sample list for ``mode`` and build the transforms.

        Args:
            dataset_params: dict of attributes copied onto the instance via
                ``attr_from_dict`` (must provide at least ``data_location``).
            mode: 'train', 'val' / 'eval'; any other value selects the test
                split.
        """
        self.attr_from_dict(dataset_params)
        self.dataset_location = DATA_INFO["BreakHis"]["dataset_location"]
        self.root_dir = os.path.join(self.data_location, self.dataset_location)
        self.mode = mode
        self.data = self.get_data_as_list()
        self.transform, self.resizing = self.get_transforms()

    def get_data_as_list(self):
        """Return the samples of the current mode as a list of dicts.

        Returns:
            A list of ``{'img_path', 'label', 'dataset'}`` dicts.

        Raises:
            ValueError: if a patient with images belongs to no split.
        """
        # patient folder name -> {"images": [...], "class": ..., "type": ...}
        patient_dict = {}
        for classif in ["benign", "malignant"]:
            sev_dir = os.path.join(self.root_dir, classif, "SOB")
            for htype_path in glob(os.path.join(sev_dir, "*")):
                htype = os.path.basename(htype_path)
                for patient_path in glob(os.path.join(htype_path, "*")):
                    patient = os.path.basename(patient_path)
                    if patient not in patient_dict:
                        patient_dict[patient] = {"images": []}
                    patient_dict[patient]["class"] = classif
                    patient_dict[patient]["type"] = htype
                    # one sub-folder per entry below the patient folder
                    for mag in glob(os.path.join(patient_path, "*")):
                        patient_dict[patient]["images"] += glob(os.path.join(mag, "*.png"))

        # glob() returns entries in arbitrary order, so the patient lists are
        # sorted: the split ids below are positions in these lists and must
        # point at the same patients on every run and every machine.
        benign_patients = sorted(p for p, v in patient_dict.items() if v['class'] == "benign")
        malignant_patients = sorted(p for p, v in patient_dict.items() if v['class'] == "malignant")

        val_id_json_benign = os.path.join(self.root_dir, 'val_ids_benign.json')
        val_id_json_malignant = os.path.join(self.root_dir, 'val_ids_malignant.json')
        train_ids_benign, test_val_ids_benign = self.get_validation_ids(
            total_size=len(benign_patients), val_size=0.2,
            json_path=val_id_json_benign,
            dataset_name=self.name + " : benign")
        train_ids_malignant, test_val_ids_malignant = self.get_validation_ids(
            total_size=len(malignant_patients), val_size=0.2,
            json_path=val_id_json_malignant,
            dataset_name=self.name + " : malignant")

        # the held-out ids of each class are halved into validation and test
        half_benign = int(len(test_val_ids_benign) / 2)
        half_malignant = int(len(test_val_ids_malignant) / 2)
        val_ids_benign = test_val_ids_benign[:half_benign]
        test_ids_benign = test_val_ids_benign[half_benign:]
        val_ids_malignant = test_val_ids_malignant[:half_malignant]
        test_ids_malignant = test_val_ids_malignant[half_malignant:]

        # sets give constant-time membership tests in the loop below
        train_patients = {benign_patients[i] for i in train_ids_benign} \
            | {malignant_patients[i] for i in train_ids_malignant}
        val_patients = {benign_patients[i] for i in val_ids_benign} \
            | {malignant_patients[i] for i in val_ids_malignant}
        test_patients = {benign_patients[i] for i in test_ids_benign} \
            | {malignant_patients[i] for i in test_ids_malignant}

        train_data = []
        val_data = []
        test_data = []
        for patient, info in patient_dict.items():
            if not info['images']:
                # nothing to assign for a patient without images
                continue
            if patient in train_patients:
                bucket = train_data
            elif patient in val_patients:
                bucket = val_data
            elif patient in test_patients:
                bucket = test_data
            else:
                raise ValueError("Patient id not in any set")
            label = self.labels_to_int[info["class"]]
            bucket.extend({'img_path': img_path, 'label': label, 'dataset': self.name}
                          for img_path in info['images'])

        if self.mode == 'train':
            data_list = train_data
        elif self.mode in ['val', 'eval']:
            data_list = val_data
        else:
            data_list = test_data

        return data_list
    
class MURA(BaseSet):
    """MURA classification dataset with a patient-level train/val split.

    Training and validation samples both come from ``train_image_paths.csv``;
    the test samples are the images listed in ``valid_image_paths.csv``.
    """

    img_channels = 1
    is_multiclass = True
    task = 'classification'
    mean = (0.206,)
    std = (0.177,)
    int_to_labels = {0: "Normal",
                     1: "Abnormal"}
    target_metric = 'cohen_kappa'
    knn_nhood = 200
    n_classes = len(int_to_labels)
    labels_to_int = {val: key for key, val in int_to_labels.items()}

    def __init__(self, dataset_params, mode='train'):
        """Load the sample list for ``mode`` and build the transforms.

        Args:
            dataset_params: dict of attributes copied onto the instance via
                ``attr_from_dict`` (must provide at least ``data_location``).
            mode: 'train', 'val' / 'eval'; any other value selects the test
                split.
        """
        self.attr_from_dict(dataset_params)
        self.dataset_location = DATA_INFO["MURA"]["dataset_location"]
        self.root_dir = os.path.join(self.data_location, self.dataset_location)
        self.mode = mode
        self.data = self.get_data_as_list()
        self.transform, self.resizing = self.get_transforms()

    def get_data_as_list(self):
        """Return the samples of the current mode as a list of dicts.

        Returns:
            A list of ``{'img_path', 'label', 'dataset'}`` dicts where
            ``label`` is 0 for "negative" and 1 for "positive" studies.
        """
        rel_dir = "MURA-v1.1/"
        label_mapping = {"negative": 0, "positive": 1}
        # NOTE(review): read_csv is called with its default header handling,
        # so the first line of each CSV is consumed as the header row --
        # confirm the files really start with a header line.
        train_paths = pd.read_csv(os.path.join(self.root_dir, 'train_image_paths.csv'), engine='python')
        test_paths = pd.read_csv(os.path.join(self.root_dir, 'valid_image_paths.csv'), engine='python')
        # get relative img paths
        train_paths = train_paths.iloc[:, 0].tolist()
        test_paths = test_paths.iloc[:, 0].tolist()
        # convert to absolute paths
        train_paths = [os.path.join(self.root_dir, rec.split(rel_dir)[-1]) for rec in train_paths]
        test_paths = [os.path.join(self.root_dir, rec.split(rel_dir)[-1]) for rec in test_paths]
        # the label is the suffix after the last "_" of the image's parent folder
        train_labels = [os.path.basename(os.path.dirname(rec)).split("_")[-1] for rec in train_paths]
        test_labels = [os.path.basename(os.path.dirname(rec)).split("_")[-1] for rec in test_paths]
        # convert labels to int
        train_labels = [label_mapping[rec] for rec in train_labels]
        test_labels = [label_mapping[rec] for rec in test_labels]
        assert len(train_labels) == len(train_paths)
        assert len(test_labels) == len(test_paths)

        # group the training images by patient (third path component from the end)
        train_dict = {}
        for p, l in zip(train_paths, train_labels):
            pid = p.split('/')[-3]
            entry = train_dict.setdefault(pid, {"path": [], "label": []})
            entry["path"].append(p)
            entry["label"].append(l)

        # train/val split on sorted patient ids, so that the split indices
        # always refer to the same patients
        pids = sorted(train_dict)
        val_id_json = os.path.join(self.root_dir, 'val_ids.json')
        train_ids, val_ids = self.get_validation_ids(total_size=len(pids), val_size=0.1,
                                                     json_path=val_id_json,
                                                     dataset_name=self.name)
        train_pids = [pids[pid] for pid in train_ids]
        val_pids = [pids[pid] for pid in val_ids]
        train_paths = []
        train_labels = []
        val_paths = []
        val_labels = []
        for pid in train_pids:
            train_paths += train_dict[pid]["path"]
            train_labels += train_dict[pid]["label"]
        for pid in val_pids:
            val_paths += train_dict[pid]["path"]
            val_labels += train_dict[pid]["label"]

        if self.mode == 'train':
            paths, labels = train_paths, train_labels
        elif self.mode in ['val', 'eval']:
            paths, labels = val_paths, val_labels
        else:
            paths, labels = test_paths, test_labels

        return [{'img_path': img_path, 'label': label, 'dataset': self.name}
                for img_path, label in zip(paths, labels)]
    
    
class Camelyon(BaseSet):
    """Camelyon patch classification dataset backed by h5 containers.

    On first use the labels are exported from the h5 label container to a
    JSON file and the images from the h5 data container to PNG files; later
    runs read the exported files only.
    """

    img_channels = 3
    is_multiclass = True
    task = 'classification'
    knn_nhood = 200
    mean = [0.69991202, 0.53839318, 0.69108667]
    std = [0.2308955 , 0.27435492, 0.20865005]
    int_to_labels = {0: 'Normal',
                    1: 'Tumor'}
    target_metric = 'roc_auc'
    n_classes = len(int_to_labels)
    labels_to_int = {val: key for key, val in int_to_labels.items()}

    def __init__(self, dataset_params, mode='train'):
        """Load the sample list for ``mode`` and build the transforms.

        Args:
            dataset_params: dict of attributes copied onto the instance via
                ``attr_from_dict`` (must provide at least ``data_location``).
            mode: 'train', 'val', 'eval' or 'test'.
        """
        self.attr_from_dict(dataset_params)
        self.dataset_location = DATA_INFO["Camelyon"]["dataset_location"]
        self.root_dir = os.path.join(self.data_location, self.dataset_location)
        self.mode = mode
        self.data = self.get_data_as_list()
        self.transform, self.resizing = self.get_transforms()

    def get_data_as_list(self):
        """Return the samples of the current mode as a list of dicts.

        Returns:
            A list of ``{'img_path', 'label', 'dataset'}`` dicts, one per row
            of the split's meta CSV.

        Raises:
            ValueError: if the mode is unknown, if PNG files are missing and
                the h5 data container does not exist, or if PNG files are
                still missing after the export.
        """
        prefix = "camelyonpatch_level_2_split_"
        if self.mode == "train":
            mod = "train"
        elif self.mode in ["val", "eval"]:
            mod = "valid"
        elif self.mode == "test":
            mod = "test"
        else:
            raise ValueError(f"Mode {self.mode} not understood")

        h5_data = os.path.join(self.root_dir, f"{prefix}{mod}_x.h5")
        h5_labels = os.path.join(self.root_dir, f"{prefix}{mod}_y.h5")
        meta = pd.read_csv(os.path.join(self.root_dir, f"{prefix}{mod}_meta.csv"), engine='python')
        png_data_path = os.path.join(self.root_dir, f"{mod}_png_files")
        labels_data_path = os.path.join(self.root_dir, f"{mod}_labels.json")
        check_dir(png_data_path)

        n_samples = len(meta)
        img_paths = [os.path.join(png_data_path, f"img_{img_i}.png") for img_i in range(n_samples)]

        # getting the labels: reuse the JSON export when it is complete,
        # otherwise (re)create it from the h5 container
        labels = load_json(labels_data_path) if os.path.exists(labels_data_path) else None
        if labels is None or len(labels) != n_samples:
            # only the process accepted by is_rank0 writes the file; all
            # processes meet at synchronize() before reading it
            if is_rank0(torch.cuda.current_device()):
                print("Extracting labels from h5")
                with h5py.File(h5_labels, 'r') as label_file:
                    labels = [int(lbl.flatten()[0]) for lbl in label_file['y']]
                save_json(labels, labels_data_path)
            synchronize()
            labels = load_json(labels_data_path)

        # getting the data: export the PNG files if any of them is missing
        if not all(os.path.exists(p) for p in img_paths):
            if not os.path.exists(h5_data):
                raise ValueError("Data: empty container or incorect path")
            if is_rank0(torch.cuda.current_device()):
                print("Converting h5 container to PNG data.. This might take a while")
                # the context manager closes the container once exported
                with h5py.File(h5_data, 'r') as data_file:
                    images = data_file['x']
                    for idx in tqdm(range(len(images))):
                        # files exported by an earlier run are left untouched
                        if not os.path.exists(img_paths[idx]):
                            Image.fromarray(images[idx]).save(img_paths[idx])
            synchronize()
            # fail loudly instead of retrying without bound when the export
            # did not produce every expected file
            if not all(os.path.exists(p) for p in img_paths):
                raise ValueError("Data: PNG export from the h5 container is incomplete")

        return [{'img_path': img_path, 'label': label, 'dataset': self.name}
                for img_path, label in zip(img_paths, labels)]
    
