from copy import copy
import pandas as pd
import pickle
import os
import torch
import numpy as np
from PIL import Image
from pathlib import Path
from sklearn.model_selection import train_test_split
from torchvision import transforms
from datasets.basic_dataset_scaffold import MultiClassDataset, MultiLabelDataset

# Input modalities associated with this dataset module.
# NOTE(review): not referenced in the visible code; presumably read by
# whatever selects/loads dataset modules -- confirm against the caller.
MODALITY = ["vision"]

def preprocess_chexpert(opt, datapath, only_frontal=True):
    """Load CheXpert metadata, clean it and split it into train/val/test.

    Reads ``map.csv`` and ``chexpert_demographics.csv`` from *datapath*,
    inner-joins them on a normalized subject id, maps a fixed set of
    missing/unknown markers to 0 (and stringified booleans to 0/1), casts the
    requested label columns to ``int`` and randomly splits the rows.

    Args:
        opt: Options object. Reads ``opt.label_names`` (label columns to cast
            and keep), ``opt.exclusive`` (drop rows with more than one of the
            requested labels set) and ``opt.tv_split_perc`` (fraction of rows
            used for training).
        datapath: Directory containing the two CSV files.
        only_frontal: If true, keep only rows whose ``Frontal/Lateral`` value
            is ``'Frontal'``.

    Returns:
        Tuple ``(train_df, val_df, test_df)``. Each frame is an independent
        copy, so callers may add or modify columns without triggering pandas'
        ``SettingWithCopyWarning``.

    Notes:
        * No ``random_state`` is passed to ``train_test_split``; the split is
          only reproducible if the caller seeds NumPy's global RNG beforehand.
        * The split is done per row, not per ``subject_id``, so rows of the
          same subject may end up in different splits.
          NOTE(review): confirm this is intended.
    """
    datapath = Path(datapath)
    df = pd.read_csv(datapath/"map.csv")

    # The subject id is taken from the grandparent directory of each image
    # path: drop its first 7 characters (presumably the 'patient' prefix) and
    # round-trip through int to strip leading zeros before comparing as str.
    df['subject_id'] = df['Path'].apply(lambda x: int(Path(x).parent.parent.name[7:])).astype(str)
    details = pd.read_csv(datapath/"chexpert_demographics.csv")[['PATIENT', 'PRIMARY_RACE', 'ETHNICITY']]
    details['subject_id'] = details['PATIENT'].apply(lambda x: x[7:]).astype(int).astype(str)

    df = pd.merge(df, details, on = 'subject_id', how = 'inner')

    # Set subject_id aside so the blanket value replacement below cannot
    # alter it; it is re-attached (index-aligned) right afterwards.
    copy_subjectid = df['subject_id']
    df.drop(columns = ['subject_id'], inplace=True)

    df.replace(
            [[None], None, -1, '[False]', '[True]', '[ True]', 'UNABLE TO OBTAIN', 'UNKNOWN'],
            [0, 0, 0, 0, 1, 1, 0, 0], inplace=True)

    df['subject_id'] = copy_subjectid

    if only_frontal:
        df = df[df['Frontal/Lateral'] == 'Frontal']

    df = df.astype({label_name: int for label_name in opt.label_names})
    df.rename({'PRIMARY_RACE': 'Primary Race', 'ETHNICITY': 'Ethnicity', 'Path': 'path'}, axis=1, inplace=True)
    # Fixed metadata/finding columns plus the requested labels; requested
    # labels that are already in the fixed list are de-duplicated below.
    df = df[['subject_id','path', 'Sex', 'Age', 'Ethnicity', 'Primary Race', 'Atelectasis', 'Pneumonia', 'Pleural Effusion', 'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation', 'Pneumothorax', 'No Finding'] + opt.label_names]
    df = df.loc[:,~df.columns.duplicated()]

    if opt.exclusive:
        # Keep rows with at most one of the requested labels set
        # (rows with none of them set are kept as well).
        df = df.loc[df[opt.label_names].sum(axis=1) <= 1]

    # Train gets opt.tv_split_perc of the rows; the remainder is halved into
    # validation and test (e.g. 8-1-1 when tv_split_perc is 0.8).
    train_idx, test_idx = train_test_split(np.arange(len(df)), test_size=(1-opt.tv_split_perc))
    val_idx, test_idx = train_test_split(test_idx, test_size=0.5)

    # Return explicit copies: callers add columns to these frames, which
    # would otherwise warn about setting values on a slice of ``df``.
    train_df = df.iloc[train_idx, :].copy()
    val_df = df.iloc[val_idx, :].copy()
    test_df = df.iloc[test_idx, :].copy()
    return train_df, val_df, test_df
        
def _class_image_dict(metadata, conversion):
    """Map each class index to the list of image paths whose label column is set."""
    return {
        class_idx: metadata['path'].loc[metadata[label].values.astype(bool)].values.tolist()
        for class_idx, label in conversion.items()
    }


def _with_negated_label(metadata, label_name):
    """Return a new frame with a complementary ``Not <label_name>`` column appended."""
    negated = np.logical_not(metadata[label_name].values).astype(int)
    return metadata.assign(**{"Not {}".format(label_name): negated})


def Give(opt, label_names, datapath):
    """Build the CheXpert training/validation/testing/evaluation datasets.

    Args:
        opt: Options object. Reads ``cache_dir``, ``dataset``, ``exclusive``,
            ``use_tv_split``, ``augmentation`` and ``not_pretrained`` (plus
            whatever ``preprocess_chexpert`` reads). It is mutated:
            ``opt.label_names`` is set, and ``opt.cache_dir`` is extended with
            ``opt.dataset`` (so calling this twice with the same ``opt`` nests
            the cache directory).
        label_names: Label columns to use. A single label ``X`` is expanded to
            the pair ``["Not X", "X"]``.
        datapath: Directory passed on to ``preprocess_chexpert``.

    Returns:
        Dict with keys ``'training'``, ``'validation'`` (``None`` when
        ``opt.use_tv_split`` is false; the validation rows are then merged into
        training), ``'testing'``, ``'evaluation'`` (training rows with the
        evaluation transform) and ``'evaluation_train'`` (training rows with
        the training transform). Datasets are ``MultiClassDataset`` when
        ``opt.exclusive`` is set, otherwise ``MultiLabelDataset``.
    """
    import warnings  # local import keeps this block self-contained

    opt.label_names = label_names

    opt.cache_dir = os.path.join(opt.cache_dir, opt.dataset)
    os.makedirs(opt.cache_dir, exist_ok=True)

    train_metadata, val_metadata, test_metadata = preprocess_chexpert(opt, datapath)

    if len(opt.label_names) == 1:
        # Turn a single label into a two-class problem by adding its complement.
        label_name = opt.label_names[0]
        train_metadata = _with_negated_label(train_metadata, label_name)
        val_metadata = _with_negated_label(val_metadata, label_name)
        test_metadata = _with_negated_label(test_metadata, label_name)
        opt.label_names = ["Not {}".format(label_name), label_name]
        if not opt.exclusive:
            warnings.warn(
                "A single label was expanded into the complementary pair {}, "
                "but opt.exclusive is not set; the two labels will be treated "
                "as independent.".format(opt.label_names),
                stacklevel=2,
            )

    conversion = dict(enumerate(opt.label_names))

    image_size = 224
    transform_list = []
    if opt.augmentation:
        transform_list.extend([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15)
        ])
    transform_list.extend([
            transforms.CenterCrop([image_size, image_size]),
            transforms.ToTensor()
            ])
    if not opt.not_pretrained:
        # Mean/std commonly used with ImageNet-pretrained torchvision models.
        transform_list.append(transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]))

    train_transform = transforms.Compose(transform_list)
    # NOTE(review): unlike train_transform, this applies neither CenterCrop nor
    # Normalize, so evaluation inputs may be scaled differently from training
    # inputs. Left unchanged here -- confirm whether this is intended.
    eval_transform = transforms.Compose([
        transforms.ToTensor()
    ])

    def image_dict_for(metadata):
        # Only the multi-class dataset takes a per-class image dictionary.
        return _class_image_dict(metadata, conversion) if opt.exclusive else None

    def build(metadata, transform, image_dict=None):
        common = dict(metadata=metadata, label_names=opt.label_names, conversion=conversion,
                      transform=transform, image_size=image_size, cache=True, cache_dir=opt.cache_dir)
        if opt.exclusive:
            return MultiClassDataset(image_dict=image_dict, **common)
        return MultiLabelDataset(**common)

    if opt.use_tv_split:
        val_dataset = build(val_metadata, eval_transform, image_dict_for(val_metadata))
    else:
        # No separate validation set: fold the validation rows into training.
        train_metadata = pd.concat([train_metadata, val_metadata])
        val_dataset = None

    # The three training-row datasets share one image dictionary.
    train_image_dict = image_dict_for(train_metadata)
    train_dataset = build(train_metadata, train_transform, train_image_dict)
    eval_dataset = build(train_metadata, eval_transform, train_image_dict)
    eval_train_dataset = build(train_metadata, train_transform, train_image_dict)

    test_dataset = build(test_metadata, eval_transform, image_dict_for(test_metadata))

    return {'training':train_dataset, 'validation':val_dataset, 'testing':test_dataset, 'evaluation':eval_dataset, 'evaluation_train':eval_train_dataset}
