import os
import torch
import pandas as pd
from PIL import Image
import numpy as np
from wilds.datasets.wilds_dataset import WILDSDataset
from wilds.common.grouper import CombinatorialGrouper
from wilds.common.metrics.all_metrics import Accuracy
from pathlib import Path
from sklearn.model_selection import train_test_split

class CXRDataset(WILDSDataset):
    """
    Chest X-ray dataset combining MIMIC-CXR-JPG and CheXpert for binary
    pneumonia classification.

    Each example is an image path with:
        - y: 1 if the 'Pneumonia' label is 1, otherwise 0 (missing, 0 and -1
             are all mapped to 0).
        - a: the source dataset, 0 for MIMIC-CXR and 1 for CheXpert.

    The combined table is subsampled so that roughly 90% of the y == 0 rows
    come from MIMIC-CXR and roughly 90% of the y == 1 rows come from CheXpert,
    and is then split into train (0), val (1) and test (2), stratified on y.
    """
    _dataset_name = 'cxr'

    # Default on-disk locations, used when the caller does not pass explicit
    # directories to __init__.
    _DEFAULT_MIMIC_DIR = '/scratch/hdd001/projects/ml4h/projects/mimic_access_required/MIMIC-CXR-JPG'
    _DEFAULT_CHEXPERT_DIR = '/scratch/hdd001/projects/ml4h/projects/CheXpert_small/CheXpert-v1.0-small'

    # Columns shared by both sources that are kept when the tables are combined.
    _COLS_TO_TAKE = ['path', 'Cardiomegaly', 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum',
                     'Fracture', 'Lung Lesion', 'Lung Opacity', 'No Finding',
                     'Pleural Effusion', 'Pleural Other', 'Pneumonia', 'Pneumothorax', 'a']

    def __init__(self, root_dir='',
            split_scheme='official', test_pct = 0.1, val_pct = 0.05, data_seed = None,
            mimic_dir = None, chexpert_dir = None):
        """
        Args:
            - root_dir (str): Accepted for interface compatibility; not used.
            - split_scheme (str): Stored and forwarded to WILDSDataset.__init__.
            - test_pct (float): Fraction of all rows assigned to the test split.
            - val_pct (float): Fraction of all rows assigned to the val split.
            - data_seed (int or None): If not None, NumPy's global RNG is seeded
                               with this value while the dataset is built and the
                               previous RNG state is restored afterwards, even if
                               construction fails. Note that the subsampling and
                               splitting calls below pass a fixed random_state=42.
            - mimic_dir (str or path-like, optional): Directory of MIMIC-CXR-JPG.
                               Defaults to the original hard-coded location.
            - chexpert_dir (str or path-like, optional): Directory of
                               CheXpert-v1.0-small. Defaults to the original
                               hard-coded location.
        """
        mimic_dir = self._DEFAULT_MIMIC_DIR if mimic_dir is None else os.fspath(mimic_dir)
        chexpert_dir = self._DEFAULT_CHEXPERT_DIR if chexpert_dir is None else os.fspath(chexpert_dir)
        self._data_dir = chexpert_dir

        saved_rng_state = None
        if data_seed is not None:
            saved_rng_state = np.random.get_state()
            np.random.seed(data_seed)

        try:
            df = pd.concat((self._load_mimic(mimic_dir)[self._COLS_TO_TAKE],
                            self._load_chexpert(chexpert_dir)[self._COLS_TO_TAKE]
                            ), ignore_index = True)
            df['y'] = df['Pneumonia'].fillna(0).astype(int).map({-1:0, 1:1, 0:0})

            # 90% of healthy patients are from MIMIC-CXR
            healthy_90 = self._keep_majority_fraction(df[df.y == 0], majority_source = 0)
            # 90% of sick patients are from CheXpert
            sick_90 = self._keep_majority_fraction(df[df.y == 1], majority_source = 1)

            df = pd.concat((healthy_90, sick_90), ignore_index = True)

            train_val_idx, test_idx = train_test_split(
                df.index, test_size=test_pct, random_state=42, stratify=df['y'])
            # val_pct is a fraction of the full data, so rescale it to a
            # fraction of what remains after removing the test split.
            train_idx, val_idx = train_test_split(
                train_val_idx, test_size=val_pct/(1-test_pct), random_state=42,
                stratify=df.loc[train_val_idx, 'y'])

            df['split'] = 0
            df.loc[val_idx, 'split'] = 1
            df.loc[test_idx, 'split'] = 2

            df = df[['path', 'y', 'split', 'a']]
            self._y_array = torch.LongTensor(df['y'].values)
            self._y_size = 1
            self._n_classes = 2

            self._metadata_array = torch.stack(
                (torch.LongTensor(df['a'].values), self._y_array),
                dim=1
            )
            self._metadata_fields = ['a', 'y']
            self._metadata_map = {
                'a': ['MIMIC-CXR', 'CheXpert'],
                'y': ['No Pneumonia', 'Pneumonia']
            }

            self._original_resolution = (224, 224) # as images are different sizes, we resize everything to 224 x 224

            self._eval_grouper = CombinatorialGrouper(
                dataset=self,
                groupby_fields=(['a', 'y']))

            self._split_scheme = split_scheme
            # Stored as a plain NumPy array rather than a pandas Series.
            self._split_array = df['split'].values

            self.df = df
        finally:
            # Always put the global RNG back, including when construction fails.
            if saved_rng_state is not None:
                np.random.set_state(saved_rng_state)

        super().__init__(self._data_dir, split_scheme)

    @staticmethod
    def _load_mimic(mimic_dir):
        """Load MIMIC-CXR metadata joined with its NegBio labels, adding 'path' and 'a' = 0."""
        labels_mimic = pd.read_csv(f'{mimic_dir}/mimic-cxr-2.0.0-negbio.csv.gz')
        meta_mimic = pd.read_csv(f'{mimic_dir}/mimic-cxr-2.0.0-metadata.csv.gz')

        df_mimic = meta_mimic.merge(labels_mimic, on=['subject_id', 'study_id'])

        # Images live under files/p<first two chars of subject id>/p<subject>/s<study>/<dicom>.jpg
        df_mimic['path'] = [
            os.path.join(
                mimic_dir,
                'files', f'p{str(subject_id)[:2]}', f'p{subject_id}', f's{study_id}', f'{dicom_id}.jpg'
            )
            for subject_id, study_id, dicom_id in zip(
                df_mimic['subject_id'], df_mimic['study_id'], df_mimic['dicom_id'])
        ]
        df_mimic['a'] = 0
        return df_mimic

    @staticmethod
    def _load_chexpert(chexpert_dir):
        """Load the CheXpert train and valid CSVs as one table, adding 'path' and 'a' = 1."""
        df_cxp = pd.concat([pd.read_csv(f'{chexpert_dir}/train.csv'),
                        pd.read_csv(f'{chexpert_dir}/valid.csv')],
                        ignore_index = True)

        # Drop everything up to and including the first '/' of the CSV's 'Path'
        # value and re-root the remainder at chexpert_dir.
        df_cxp['path'] = df_cxp['Path'].astype(str).apply(lambda x: os.path.join(chexpert_dir, x[x.index('/')+1:]))
        df_cxp['a'] = 1
        return df_cxp

    @staticmethod
    def _keep_majority_fraction(group, majority_source, majority_frac = 0.9):
        """
        Keep every row of `group` whose 'a' equals `majority_source` and
        subsample the remaining rows so that the majority source makes up
        about `majority_frac` of the result.
        """
        majority = group[group.a == majority_source]
        minority = group[group.a != majority_source]
        n_majority = len(majority)
        n_to_keep = int(n_majority/majority_frac) - n_majority
        return pd.concat(
            (majority, minority.sample(n = n_to_keep, random_state = 42, replace = False)),
            ignore_index = True)

    def get_input(self, idx):
        """
        Load the image at position idx of self.df.
        Args:
            - idx (int): Positional index into self.df
        Output:
            - x (PIL.Image): RGB image resized to self._original_resolution
        """
        img_filename = self.df.iloc[idx]['path']
        # The context manager closes the underlying file; the converted and
        # resized image returned here is a separate in-memory object.
        with Image.open(img_filename) as img:
            return img.convert('RGB').resize(self._original_resolution)

    def eval(self, y_pred, y_true, metadata, prediction_fn=None):
        """
        Computes all evaluation metrics.
        Args:
            - y_pred (Tensor): Predictions from a model. By default, they are predicted labels (LongTensor).
                               But they can also be other model outputs such that prediction_fn(y_pred)
                               are predicted labels.
            - y_true (LongTensor): Ground-truth labels
            - metadata (Tensor): Metadata
            - prediction_fn (function): A function that turns y_pred into predicted labels 
        Output:
            - results (dictionary): Dictionary of evaluation metrics
            - results_str (str): String summarizing the evaluation metrics
        """
        metric = Accuracy(prediction_fn=prediction_fn)
        return self.standard_group_eval(
            metric,
            self._eval_grouper,
            y_pred, y_true, metadata)
