import pandas as pd
import numpy as np
from pathlib import Path
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
import copy
from sklearn.preprocessing import StandardScaler
from transformers import AutoTokenizer

# Names of the supported input modalities. MIMICDataset.__getitem__ tests
# `opt.modality` for these same two strings; the constant itself is not
# referenced in the visible part of this file (presumably used by callers).
MODALITY = ["text", "timeseries"]

def _ts_rows_for(ts_df, ids):
    """Return the rows of ``ts_df`` whose first index level is in ``ids``."""
    return ts_df.loc[pd.IndexSlice[ids, :], :]


def Give(opt, label_names, datapath):
    """Load the MIMIC tables and build the train/validation/test datasets.

    Reads ``meta_df.pkl`` and ``ts_df.pkl`` from ``datapath``, filters the
    samples, splits them randomly, standardises the time series with the
    training-split statistics and wraps each split in a ``MIMICDataset``.

    Args:
        opt: Options namespace. Reads ``opt.exclusive`` (keep only samples
            with exactly one positive label), ``opt.debug`` (keep only the
            first 2000 samples) and ``opt.tv_split_perc`` (fraction of samples
            used for training; the remainder is split evenly into validation
            and test). ``opt.label_names`` is overwritten with
            ``label_names``. Every attribute of ``opt`` is also forwarded to
            ``MIMICDataset`` as a keyword argument, so an ``opt.ts_format``
            attribute selects the time-series format.
        label_names: Names of the label columns in the metadata table.
        datapath: Directory containing ``meta_df.pkl`` and ``ts_df.pkl``.

    Returns:
        Dict with keys ``'training'``, ``'validation'``, ``'testing'``,
        ``'evaluation'`` and ``'evaluation_train'``. The last two are
        independent deep copies of the training dataset.
    """
    datapath = Path(datapath)
    # NOTE(security): unpickling can execute arbitrary code; only point this
    # at trusted data directories.
    meta_df = pd.read_pickle(datapath / 'meta_df.pkl')
    ts_df = pd.read_pickle(datapath / 'ts_df.pkl')

    opt.label_names = label_names

    if opt.exclusive:
        # Keep only samples that carry exactly one positive label.
        single_label = meta_df[opt.label_names].sum(axis=1) == 1
        meta_df = meta_df[single_label]
        ts_df = _ts_rows_for(ts_df, meta_df.index)

    if opt.debug:
        # Small subset for quick debugging runs.
        meta_df = meta_df.iloc[:2000]
        ts_df = _ts_rows_for(ts_df, meta_df.index)

    # filter out samples with no positive labels. Remove once multiproxynca loss is fixed.
    has_label = meta_df[opt.label_names].sum(axis=1) > 0
    meta_df = meta_df[has_label]
    ts_df = _ts_rows_for(ts_df, meta_df.index)

    # Random split: train vs. rest, then the rest half/half into val and test.
    train_idx, test_idx = train_test_split(
        np.arange(len(meta_df)), test_size=(1 - opt.tv_split_perc)
    )
    val_idx, test_idx = train_test_split(test_idx, test_size=0.5)

    train_meta = meta_df.iloc[train_idx]
    val_meta = meta_df.iloc[val_idx]
    test_meta = meta_df.iloc[test_idx]

    train_ts = _ts_rows_for(ts_df, list(train_meta.index))
    val_ts = _ts_rows_for(ts_df, list(val_meta.index))
    test_ts = _ts_rows_for(ts_df, list(test_meta.index))

    # Standardise every split with the training statistics only, so nothing
    # from validation/test leaks into the normalisation.
    train_mean, train_std = train_ts.mean(axis=0), train_ts.std(axis=0)
    # A feature that is constant in the training split has std == 0; dividing
    # by it would turn the whole column into NaN/inf. Use a unit scale instead
    # so such a column is only centred.
    train_std = train_std.mask(train_std == 0, 1.0)
    train_ts = (train_ts - train_mean) / train_std
    val_ts = (val_ts - train_mean) / train_std
    test_ts = (test_ts - train_mean) / train_std

    train_dataset = MIMICDataset(train_meta, train_ts, opt, **vars(opt))
    val_dataset = MIMICDataset(val_meta, val_ts, opt, **vars(opt))
    test_dataset = MIMICDataset(test_meta, test_ts, opt, **vars(opt))

    return {
        'training': train_dataset,
        'validation': val_dataset,
        'testing': test_dataset,
        'evaluation': copy.deepcopy(train_dataset),
        'evaluation_train': copy.deepcopy(train_dataset),
    }


class MIMICDataset(Dataset):
    """Dataset over MIMIC samples that pairs clinical text with a time series.

    Besides the usual ``__getitem__``/``__len__`` it builds the bookkeeping
    attributes ``image_dict``, ``image_list``, ``image_paths``,
    ``avail_classes``, ``n_files`` and ``is_init`` (presumably consumed by the
    surrounding metric-learning samplers/evaluators - confirm against
    callers).

    Args:
        meta_df: One row per sample. ``reset_index()`` must yield an ``ID``
            column (i.e. the index is named ``ID``). Must contain a ``TEXT``
            column and one column per entry of ``opt.label_names``, where the
            value ``1`` marks a positive label.
        ts_df: Time-series table whose rows are selected with
            ``pd.IndexSlice[ID, :]``, i.e. a MultiIndex whose first level is
            the sample ID.
        opt: Options namespace. Reads ``label_names``, ``exclusive``,
            ``multimodal`` and (when ``multimodal`` is false) ``modality``.
        ts_format: How to encode the time series.
            ``"concat"``: flatten the series into one long 1-D vector.
            ``"ts"`` (or any other value): keep the 2-D array as sliced.
        **kwargs: Ignored; allows callers to pass ``**vars(opt)``.

    The text is always tokenised with the ``bert-base-uncased`` tokenizer,
    padded/truncated to 512 tokens.
    """

    def __init__(self, meta_df, ts_df, opt, ts_format = 'ts', **kwargs):
        super().__init__()
        self.meta_df = meta_df
        self.ts_df = ts_df
        self.ts_format = ts_format
        self.opt = opt
        self.image_dict = self._build_image_dict()
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

        self.init_setup()

    def _build_image_dict(self):
        """Map each class index to its ``[ID, dataset_position]`` entries.

        The position is the row position inside ``self.meta_df`` - the same
        index ``__getitem__`` accepts - so it must be computed *before*
        filtering by label. Numbering the rows after the filter would yield
        positions relative to the class subset, which point at unrelated
        samples.
        """
        # Columns: 'ID' and 'index' (0..len-1 position within meta_df).
        positions = self.meta_df.reset_index()[['ID']].reset_index()[['ID', 'index']]
        return {
            class_id: positions[(self.meta_df[name] == 1).to_numpy()].values.tolist()
            for class_id, name in enumerate(self.opt.label_names)
        }

    def __getitem__(self, idx):
        """Return the sample at row position ``idx`` of ``meta_df`` as a dict.

        Keys: ``'labels'`` and ``'idx'`` always; ``'x'`` (float32 time series)
        when the time series is requested; ``'input_ids'``,
        ``'token_type_ids'`` and ``'attention_mask'`` (1-D tensors of length
        512) when the text is requested. ``opt.multimodal`` requests both,
        otherwise ``opt.modality`` selects one ("timeseries" wins over
        "text").

        Raises:
            ValueError: if neither modality is configured.
        """
        row = self.meta_df.iloc[idx]

        if self.opt.multimodal:
            use_ts, use_text = True, True
        elif "timeseries" in self.opt.modality:
            use_ts, use_text = True, False
        elif "text" in self.opt.modality:
            use_ts, use_text = False, True
        else:
            # Previously fell through and returned None, which only failed
            # later (and obscurely) during batch collation.
            raise ValueError(
                "opt.modality must contain 'timeseries' or 'text' "
                "when opt.multimodal is false"
            )

        label_flags = row[self.opt.label_names].values.astype(int)
        # Exclusive mode: indices of the positive labels (exactly one when the
        # data was filtered accordingly). Otherwise: the multi-hot vector.
        labels = label_flags.nonzero()[0] if self.opt.exclusive else label_flags

        sample = {'labels': labels}

        # Only compute what is returned: the time-series slice and the
        # 512-token encoding are the expensive parts of loading a sample.
        if use_ts:
            ts = self.ts_df.loc[pd.IndexSlice[row.name, :], :].values.astype('float32')
            if self.ts_format == 'concat':
                ts = ts.flatten()
            sample['x'] = ts

        if use_text:
            text_enc = self.tokenizer.encode_plus(
                row['TEXT'], return_tensors="pt", max_length=512,
                padding='max_length', truncation=True,
            )
            # Drop the leading batch dimension of size 1.
            sample['input_ids'] = text_enc['input_ids'][0, :]
            sample['token_type_ids'] = text_enc['token_type_ids'][0, :]
            sample['attention_mask'] = text_enc['attention_mask'][0, :]

        sample['idx'] = idx
        return sample

    def init_setup(self):
        """Derive ``n_files``, ``avail_classes`` and ``image_list`` from ``image_dict``.

        ``image_list`` holds one ``[ID, label]`` entry per sample that has at
        least one positive label, ordered by dataset position. ``label`` is a
        single class index in exclusive mode and an array of class indices
        otherwise.
        """
        self.n_files = np.sum([len(entries) for entries in self.image_dict.values()])
        self.avail_classes = sorted(self.image_dict)

        # Classes without samples are skipped so that empty frames cannot
        # change the column dtypes during concatenation.
        frames = [
            pd.DataFrame(entries, columns=['path', 'idx']).assign(label=key)
            for key, entries in self.image_dict.items()
            if entries
        ]

        if frames:
            # Stable sort: a multi-label sample occurs once per class with the
            # same idx; keeping the concat order makes its label list
            # deterministic (ascending class index).
            df = pd.concat(frames).sort_values(by='idx', ascending=True, kind='mergesort')
            idx_mapping = df.drop_duplicates(subset=['path']).set_index('path')['idx']
            agg_list = (
                df.groupby('path')
                .agg({'label': lambda x: list(x)})
                .loc[idx_mapping.index]  # restore dataset order after groupby
                .reset_index()
                .to_numpy()
                .tolist()
            )
        else:
            agg_list = []

        if self.opt.exclusive:
            for entry in agg_list:
                assert len(entry[-1]) == 1 # one label per image
                entry[-1] = entry[-1][0]
        else:
            for entry in agg_list:
                entry[-1] = np.array(entry[-1])

        self.image_list = agg_list
        self.image_paths = self.image_list
        self.is_init = True

    def __len__(self):
        """Number of samples (rows of ``meta_df``)."""
        return len(self.meta_df)
