import pandas as pd
import pickle
import os
import torch
import numpy as np
from pathlib import Path
from sklearn.model_selection import train_test_split
from torchvision import transforms
from datasets.basic_dataset_scaffold import MultiClassDataset, MultiLabelDataset

# Modalities this dataset module declares; nothing in this file reads it, so it
# is presumably consumed by whatever code selects/loads dataset modules — TODO confirm.
MODALITY = ["vision"]
# Commented-out reference list of label columns (same names as the fixed label
# columns kept by preprocess_mimic); not used at runtime.
# DEFAULT_NAMES = ['Atelectasis', 'Pneumonia', 'Pleural Effusion', 'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation', 'Pneumothorax', 'No Finding']

def preprocess_mimic(opt, datapath, only_frontal=True):
    """Build the MIMIC-CXR metadata table and split it into train/val/test frames.

    Reads ``patients.csv.gz``, ``admissions.csv.gz``,
    ``mimic-cxr-2.0.0-negbio.csv.gz`` and ``mimic-cxr-2.0.0-metadata.csv.gz``
    from ``datapath``, joins them per image, derives ``age_decile``, ``path``
    and ``sex`` columns and binarises the requested label columns
    (``-1`` and missing values become ``0``).

    Args:
        opt: Options object. Reads ``label_names`` (list of label columns),
            ``source_path`` (root under which ``files/pXX/pSUBJECT/sSTUDY/DICOM.jpg``
            lives), ``exclusive`` (keep only rows with at most one positive
            requested label), ``tv_split_perc`` (fraction of rows used for
            training) and ``debug`` (truncate the splits).
        datapath: Directory containing the CSV files listed above.
        only_frontal: When true, keep only rows whose ``ViewPosition`` is
            ``AP`` or ``PA``.

    Returns:
        Tuple ``(train_df, val_df, test_df)`` of independent DataFrames, so
        callers can add columns without pandas chained-assignment warnings.
    """
    datapath = Path(datapath)
    label_names = list(opt.label_names)

    patients = pd.read_csv(datapath/'patients.csv.gz')
    admissions = pd.read_csv(datapath/'admissions.csv.gz')
    # One ethnicity per subject: the first admission row of each subject wins.
    ethnicities = admissions.drop_duplicates(subset=['subject_id']).set_index('subject_id')['ethnicity'].to_dict()
    patients['ethnicity'] = patients['subject_id'].map(ethnicities)
    labels = pd.read_csv(datapath/'mimic-cxr-2.0.0-negbio.csv.gz')
    meta = pd.read_csv(datapath/'mimic-cxr-2.0.0-metadata.csv.gz')

    # Inner joins: one row per image that has both patient info and study labels.
    df = meta.merge(patients, on='subject_id').merge(labels, on=['subject_id', 'study_id'])
    df['age_decile'] = pd.cut(df['anchor_age'], bins = list(range(0, 101, 10))).apply(lambda x: f'{x.left}-{x.right}').astype(str)
    df['frontal'] = df.ViewPosition.isin(['AP', 'PA'])

    # Build the image paths column-wise; a row-wise DataFrame.apply creates a
    # Series per row and is far slower on a table of this size.
    df['path'] = [
        os.path.join(
            opt.source_path,
            'files',
            f'p{str(subject_id)[:2]}',
            f'p{subject_id}',
            f's{study_id}',
            f'{dicom_id}.jpg',
        )
        for subject_id, study_id, dicom_id in zip(df['subject_id'], df['study_id'], df['dicom_id'])
    ]

    # Binarise the requested labels: -1 and missing both count as negative.
    df.replace({col: {-1: 0} for col in label_names}, inplace=True)
    df.fillna({col: 0 for col in label_names}, inplace=True)

    if only_frontal:
        df = df[df.frontal]

    df = df.astype({label_name: int for label_name in label_names})
    df = df.rename(columns={'gender': 'sex'})

    # Fixed identifier/demographic/label columns plus any requested labels,
    # de-duplicated while preserving first-occurrence order.
    base_columns = [
        'subject_id', 'study_id', 'path', 'sex', 'age_decile', 'ethnicity',
        'Atelectasis', 'Pneumonia', 'Pleural Effusion', 'Enlarged Cardiomediastinum',
        'Cardiomegaly', 'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation',
        'Pneumothorax', 'No Finding',
    ]
    df = df[list(dict.fromkeys(base_columns + label_names))]

    if opt.exclusive:
        # Drop rows where more than one requested label is positive.
        df = df.loc[df[label_names].sum(axis=1) <= 1]

    ## train - val - test split: tv_split_perc for train, the remainder is
    ## halved into val and test (8 - 1 - 1 when tv_split_perc is 0.8).
    # NOTE(review): the split is drawn over image rows with no explicit
    # random_state, so it depends on the global NumPy RNG state, and rows of
    # the same subject can land in different splits — confirm both are intended.
    train_idx, holdout_idx = train_test_split(np.arange(len(df)), test_size=(1-opt.tv_split_perc))
    val_idx, test_idx = train_test_split(holdout_idx, test_size=0.5)

    if opt.debug:
        # Small fixed-size subsets for quick debugging runs.
        train_idx, val_idx, test_idx = train_idx[:1024], val_idx[:512], test_idx[:512]

    # Explicit copies detach the splits from `df`, so adding columns to them
    # later does not raise SettingWithCopyWarning.
    train_df = df.iloc[train_idx, :].copy()
    val_df = df.iloc[val_idx, :].copy()
    test_df = df.iloc[test_idx, :].copy()

    return train_df, val_df, test_df
        
def Give(opt, label_names, datapath):
    """Create the train/val/test/evaluation datasets for MIMIC-CXR.

    Side effects on ``opt``: ``opt.label_names`` is set to ``label_names``
    (and replaced by ``["Not <label>", <label>]`` when a single label is
    given); unless ``opt.no_cache`` is set, ``opt.cache_dir`` is rewritten to
    ``<cache_dir>/<dataset>`` and that directory is created.
    NOTE(review): because ``opt.cache_dir`` is rewritten in place, calling this
    twice with the same ``opt`` nests the dataset name again — confirm callers
    only call it once per ``opt``.

    Args:
        opt: Options object. Reads ``no_cache``, ``cache_dir``, ``dataset``,
            ``augmentation``, ``not_pretrained``, ``exclusive`` and
            ``use_tv_split`` here, plus whatever ``preprocess_mimic`` reads.
        label_names: List of label column names to train on.
        datapath: Directory with the MIMIC CSV files, passed to
            ``preprocess_mimic``.

    Returns:
        Dict with keys ``'training'``, ``'validation'`` (``None`` when
        ``opt.use_tv_split`` is false, in which case the validation rows are
        merged into training), ``'testing'``, ``'evaluation'`` (training rows
        with the evaluation transform) and ``'evaluation_train'`` (training
        rows with the training transform).
    """
    opt.label_names = label_names

    if not opt.no_cache:
        opt.cache_dir = os.path.join(opt.cache_dir, opt.dataset)
        os.makedirs(opt.cache_dir, exist_ok=True)

    train_metadata, val_metadata, test_metadata = preprocess_mimic(opt, datapath)

    if len(opt.label_names) == 1:
        # A single label becomes a two-class problem: "Not <label>" vs "<label>".
        label_name = opt.label_names[0]
        negative_name = "Not {}".format(label_name)

        def add_negative_column(frame):
            # Work on a copy so the assignment never targets a slice of
            # another DataFrame (avoids SettingWithCopyWarning).
            frame = frame.copy()
            frame[negative_name] = np.logical_not(frame[label_name].values).astype(int)
            return frame

        train_metadata = add_negative_column(train_metadata)
        val_metadata = add_negative_column(val_metadata)
        test_metadata = add_negative_column(test_metadata)
        opt.label_names = [negative_name, label_name]
        ## TODO: raise warning if not exclusive

    # class index -> label column name
    conversion = dict(enumerate(opt.label_names))

    image_size = 224
    # Deterministic steps shared by training and evaluation. The evaluation
    # pipeline previously applied only ToTensor(), so with a pretrained model
    # evaluation inputs were not normalised (or cropped) like training inputs.
    deterministic_steps = [
        transforms.CenterCrop([image_size, image_size]),
        transforms.ToTensor(),
    ]
    if not opt.not_pretrained:
        # ImageNet channel statistics expected by torchvision pretrained weights.
        deterministic_steps.append(transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]))

    augmentation_steps = []
    if opt.augmentation:
        augmentation_steps = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
        ]

    train_transform = transforms.Compose(augmentation_steps + deterministic_steps)
    eval_transform = transforms.Compose(deterministic_steps)

    shared_kwargs = dict(
        label_names=opt.label_names,
        conversion=conversion,
        image_size=image_size,
        cache=not opt.no_cache,
        cache_dir=opt.cache_dir,
    )

    def image_dict_for(metadata):
        # Only the multi-class (exclusive) dataset takes a per-class path dict.
        return _class_image_dict(metadata, conversion) if opt.exclusive else None

    def make_dataset(metadata, transform, image_dict):
        if opt.exclusive:
            return MultiClassDataset(metadata=metadata, image_dict=image_dict, transform=transform, **shared_kwargs)
        return MultiLabelDataset(metadata=metadata, transform=transform, **shared_kwargs)

    if opt.use_tv_split:
        val_dataset = make_dataset(val_metadata, eval_transform, image_dict_for(val_metadata))
    else:
        # No separate validation set: fold the validation rows into training.
        train_metadata = pd.concat([train_metadata, val_metadata])
        val_dataset = None

    # The three training-row datasets share one image dict, as before.
    train_image_dict = image_dict_for(train_metadata)
    train_dataset = make_dataset(train_metadata, train_transform, train_image_dict)
    eval_dataset = make_dataset(train_metadata, eval_transform, train_image_dict)
    eval_train_dataset = make_dataset(train_metadata, train_transform, train_image_dict)

    test_dataset = make_dataset(test_metadata, eval_transform, image_dict_for(test_metadata))

    return {'training':train_dataset, 'validation':val_dataset, 'testing':test_dataset, 'evaluation':eval_dataset, 'evaluation_train':eval_train_dataset}


def _class_image_dict(metadata, conversion):
    """Map each class index to the list of image paths whose label column is non-zero."""
    paths = metadata['path']
    return {
        class_idx: paths.loc[metadata[column].values.astype(bool)].values.tolist()
        for class_idx, column in conversion.items()
    }
