import numpy
import pandas
import ast
import time
import copy
import jax
from jax import numpy as jnp

from ProtLig_GPCRclassA.mol2graph.jraph.convert import smiles_to_jraph
from ProtLig_GPCRclassA.mol2graph.exceptions import NoBondsError

from ProtLig_GPCRclassA.utils import create_line_graph, Sequence, Label, Concentration
from ProtLig_GPCRclassA.base_loader import BaseDataset


# ----------------------------------
# Precomputed amino acids embedding:
# ----------------------------------
class AminoConcentrationDatasetPrecompute(BaseDataset):
    """
    consider introducing mol_buffer to save already preprocessed graphs.
    """
    def __init__(self, data_csv, mols_csv, seq_id_col, mol_id_col, mol_col, parameter_col, n_ec50_copies, conc_value_col, conc_value_screen_col, label_col, weight_col = None, seqs_csv = None,
                atom_features = ['AtomicNum'], bond_features = ['BondType'], 
                oversampling_function = None,
                line_graph_max_size = None, # 10 * padding_n_node
                line_graph = True,
                class_alpha = None,
                sample_concentration_and_label = None,
                element_preprocess = None,
                **kwargs):
        """
        Store the dataset configuration and load the data with ``self.read()``.

        Parameters:
        -----------
        data_csv : str or pandas.DataFrame
            path to the csv with the measurements, or an already loaded frame.

        mols_csv : str
            path to the csv with molecules; indexed by ``mol_id_col`` when read.

        seq_id_col : str
            name of the column with protein IDs.

        mol_id_col : str
            name of the column with molecule IDs.

        mol_col : str
            name of the column with smiles.

        parameter_col : str
            name of the column holding the measurement type. The value 'ec50'
            selects the EC50 branch; anything else is treated as screening.

        n_ec50_copies : int
            how many times each EC50 row appears after augmentation in ``read``.

        conc_value_col : str
            name of the column with EC50 values ('n.d', '>value' or a number).

        conc_value_screen_col : str
            name of the column with screening concentrations.

        label_col : str
            name of the column with labels (in case of multilabel problem,
            labels should be in one column).

        weight_col : str or None
            name of the column with sample weights. The special value
            '_adjusted_class_weight' is stored as None, i.e. no column is read.

        seqs_csv : str or None
            optional path to a csv with per-sequence columns.

        atom_features : list
            list of atom features.

        bond_features : list
            list of bond features.

        oversampling_function, line_graph_max_size, line_graph, class_alpha,
        sample_concentration_and_label, element_preprocess
            stored unchanged on the instance.

        **kwargs
            sep (default ';'), IncludeHs, self_loops, include_conc_parameter_list,
            auxiliary_label_cols, auxiliary_weight_cols, mol_global_cols,
            seq_global_cols, legacy_version.

        Raises:
        -------
        NotImplementedError
            if ``legacy_version`` is truthy.

        Notes:
        ------
        sequence representations are retrieved in Collate.
        """
        # Data sources and csv separator.
        self.data_csv, self.mols_csv, self.seqs_csv = data_csv, mols_csv, seqs_csv
        self.sep = kwargs.get('sep', ';')

        # Integer code of every concentration parameter, plus the reverse lookup.
        self.conc_parameter_id_map = dict(ec50_nd = 0, ec50_greater_than = 1, ec50 = 2, screening = 3)
        self.conc_parameter_id_inverse_map = {code : name for name, code in self.conc_parameter_id_map.items()}

        # Column names.
        self.seq_id_col = seq_id_col
        self.mol_id_col = mol_id_col
        self.mol_col = mol_col
        self.parameter_col = parameter_col
        self.conc_value_col = conc_value_col
        self.conc_value_screen_col = conc_value_screen_col
        self.label_col = label_col
        # '_adjusted_class_weight' does not name a csv column, so nothing is read for it.
        self.weight_col = None if weight_col == '_adjusted_class_weight' else weight_col

        self.n_ec50_copies = n_ec50_copies
        self.class_alpha = class_alpha
        self.oversampling_function = oversampling_function

        # Graph construction options.
        self.IncludeHs = kwargs.get('IncludeHs', False)
        self.self_loops = kwargs.get('self_loops', False)
        self.atom_features = atom_features
        self.bond_features = bond_features
        self.line_graph = line_graph
        self.line_graph_max_size = line_graph_max_size

        # Optional restriction to a subset of concentration parameters.
        self.include_conc_parameter_list = kwargs.get('include_conc_parameter_list', None)

        # Additional per-row / per-molecule / per-sequence columns.
        self.auxiliary_label_cols = kwargs.get('auxiliary_label_cols', [])
        self.auxiliary_weight_cols = kwargs.get('auxiliary_weight_cols', [])
        self.mol_global_cols = kwargs.get('mol_global_cols', [])
        self.seq_global_cols = kwargs.get('seq_global_cols', [])

        if len(self.mol_global_cols) > 0:
            print(' -----'*8, '\nWARNING: ast.literal_eval is used in _read_graph!\n', '----- '*8)

        if kwargs.get('legacy_version', False):
            raise NotImplementedError('Use older git chekpoint for legacy version...')

        # Per-sample hooks applied in __getitem__.
        self.sample_concentration_and_label = sample_concentration_and_label
        self.element_preprocess = element_preprocess

        # EC50 statistics; filled in by read().
        self.mean_ec50 = None
        self.std_ec50 = None

        self.data = self.read()
        # self.df = self.data.copy()

    def _read_graph(self, x):
        """
        Convert one row of the molecules table into a graph.

        Parameters:
        -----------
        x : row-like
            must provide ``self.mol_col`` (smiles) and every column listed in
            ``self.mol_global_cols``.

        Returns:
        --------
        tuple or float
            ``(G, )`` on success; ``float('nan')`` when the conversion raises
            NoBondsError or AssertionError, so the row can be dropped later.

        Raises:
        -------
        NotImplementedError
            if ``self.line_graph`` is set and the conversion succeeded.
        """
        try:
            # Global (molecule-level) features are only passed when configured.
            global_features = {col : x[col] for col in self.mol_global_cols} if len(self.mol_global_cols) > 0 else None
            graph = smiles_to_jraph(x[self.mol_col],
                                    u = global_features,
                                    validate = False,
                                    IncludeHs = self.IncludeHs,
                                    atom_features = self.atom_features,
                                    bond_features = self.bond_features,
                                    self_loops = self.self_loops)
        except NoBondsError:
            return float('nan')
        except AssertionError:
            print(x[self.mol_col])
            print('WARNING: having this except is short term solution. Discard this data in data preprocessing!!\n--------')
            return float('nan')

        if not self.line_graph:
            return (graph, )
        raise NotImplementedError('Logic for line graph will change entirely.')

    @staticmethod
    def _cast(x):
        if isinstance(x, str):
            _x = x.strip().lower()
            if _x[0] == '[' and _x[-1] == ']': # expecting list..
                _x_content = _x[1:-1]
                if len(_x_content) > 0:
                    if ', ' in _x_content:
                        return numpy.array([float(ele) for ele in _x_content.split(', ')])
                    elif ',' in _x_content:
                        return numpy.array([float(ele) for ele in _x_content.split(',')])
                else:
                    return numpy.array([])
                # return numpy.array(ast.literal_eval(_x))
            else:
                return x
        elif numpy.isnan(x).any():
            return x
        elif isinstance(x, int):
            return numpy.array([x], dtype = numpy.int32)
        elif isinstance(x, float):
            return numpy.array([x], dtype = numpy.float32)
        else:
            return x

    @staticmethod
    def _get_nan_default(series):
        for x in series.items():
            idx, val = x
            if not numpy.isnan(val).any():
                return val
        raise ValueError('All values in column {} are NaN.'.format(series.name))

    @staticmethod
    def _get_nan_mask(x, default_val):
        if numpy.isnan(x).any():
            return {'value' : default_val, 'mask' : False}
        else:
            return {'value' : x, 'mask' : True}

    def _create_seq_dict(self, x):
        """
        Build the ``Sequence`` record of one row.

        The record holds the stringified sequence ID under '_seq_id' and one
        entry per column named in ``self.seq_global_cols``.
        """
        seq_id = str(x[self.seq_id_col])
        seq = Sequence()
        seq['_seq_id'] = seq_id
        for global_col in self.seq_global_cols:
            seq[global_col] = x[global_col]
        return seq
    
    def _create_conc_dict(self, x):
        """
        Build the ``Concentration`` record of one row.

        The record has two one-element arrays: 'value' (float32) and
        'parameter' (int32 code from ``self.conc_parameter_id_map``).

        * parameter 'ec50', value 'n.d'      -> value 0, code of 'ec50_nd'
        * parameter 'ec50', value with '>'   -> value after the first character,
                                                code of 'ec50_greater_than'
        * parameter 'ec50', any other value  -> that value, code of 'ec50'
        * any other parameter                -> screening concentration,
                                                code of 'screening'
        """
        conc_dict = Concentration()
        if x[self.parameter_col] == 'ec50':
            val = x[self.conc_value_col]
            if val == 'n.d':
                raw_value, kind = 0, 'ec50_nd'
            elif '>' in val:
                # Drops the first character, i.e. assumes the value starts with '>'.
                raw_value, kind = val[1:], 'ec50_greater_than'
            else:
                raw_value, kind = val, 'ec50'
        else:
            raw_value, kind = x[self.conc_value_screen_col], 'screening'

        conc_dict['value'] = numpy.array([raw_value], dtype = numpy.float32)
        conc_dict['parameter'] = numpy.array([self.conc_parameter_id_map[kind]], dtype = numpy.int32)
        return conc_dict

    def _create_label_dict(self, x):
        """
        Build the ``Label`` record of one row.

        Always contains '_main_label'. '_main_sample_weight' is added when a
        weight column is configured. Each auxiliary label contributes its value
        and a '<col>_mask' flag; each auxiliary weight contributes its value.
        Auxiliary cells are expected to be the {'value', 'mask'} dicts produced
        by ``_get_nan_mask``.
        """
        label_dict = Label()
        label_dict['_main_label'] = x[self.label_col]
        if self.weight_col is not None:
            label_dict['_main_sample_weight'] = x[self.weight_col]

        for aux_label in self.auxiliary_label_cols:
            entry = x[aux_label]
            label_dict[aux_label] = entry['value']
            label_dict[aux_label + '_mask'] = entry['mask']

        for aux_weight in self.auxiliary_weight_cols:
            label_dict[aux_weight] = x[aux_weight]['value']

        return label_dict

    def _get_mask_for_include_conc_parameter_list(self, x):
        """
        NOTE: Implementation is not efficient now. It is called twice in _create_conc_dict and here.
        """
        param = x[self.parameter_col]
        if param == 'ec50':
            val = x[self.conc_value_col]
            if val == 'n.d':
                conc_parameter = 'ec50_nd'
            elif '>' in val:
                conc_parameter = 'ec50_greater_than'
            else:
                conc_parameter = 'ec50'
        else:
            conc_parameter = 'screening'
        return conc_parameter in self.include_conc_parameter_list


    def read(self):
        """
        Load measurements, molecules and (optionally) sequences, and assemble
        the frame served by ``__getitem__``.

        Steps:
            1. take ``self.data_csv`` as a DataFrame or read it from csv;
            2. optionally inner-join per-sequence columns from ``self.seqs_csv``;
            3. read ``self.mols_csv``, build one graph per molecule with
               ``_read_graph`` and inner-join the molecules onto the measurements
               (molecules whose graph is NaN are dropped first);
            4. optionally keep only the rows whose concentration parameter is in
               ``self.include_conc_parameter_list``;
            5. wrap auxiliary label / weight cells into {'value', 'mask'} dicts;
            6. build the '_seq', '_conc' and '_label' records;
            7. store mean / std of numeric EC50 values of rows with label 1 in
               ``self.mean_ec50`` / ``self.std_ec50``;
            8. append ``self.n_ec50_copies - 1`` extra copies of the EC50 rows.

        Returns:
        --------
        pandas.DataFrame
            columns '_seq', '_graphs', '_conc', '_label'.

        Raises:
        -------
        ValueError
            if ``self.n_ec50_copies`` is smaller than 1, or if an auxiliary
            column contains only NaN (from ``_get_nan_default``).
        NotImplementedError
            if ``self.oversampling_function`` is set.
        """
        if isinstance(self.data_csv, pandas.DataFrame):
            df = self.data_csv # [[self.mol_col, self.seq_id_col, self.label_col] + self._weight_col]
        else:
            df = pandas.read_csv(self.data_csv, sep = self.sep) # , usecols = [self.mol_id_col, self.seq_id_col, self.label_col] + self._weight_col + self.auxiliary_label_cols)

        # NOTE(review): df_seq_ids is never used below.
        df_seq_ids = None
        # Molecule IDs that actually occur in the measurements.
        df_mol_ids = pandas.Index(df[self.mol_id_col].unique())

        if self.seqs_csv is not None:
            # raise NotImplementedError('Overall structure is here, but check when there is a usecase. Dont forget to change the columns that are picked at the end of self.read.')
            seqs = pandas.read_csv(self.seqs_csv, sep = self.sep, index_col = self.seq_id_col)
            seqs = seqs.applymap(self._cast)
            # Inner join: measurements without a matching sequence row are dropped.
            df = df.join(seqs, on = self.seq_id_col, how = 'inner', rsuffix = '__seq')

        # mols_csv can not be None:
        mols = pandas.read_csv(self.mols_csv, sep = self.sep, index_col = self.mol_id_col)
        # NOTE(review): drop_duplicates compares column values only, not the index,
        # so two different molecule IDs with identical columns collapse into one row.
        # Confirm that this is intended.
        mols = mols.drop_duplicates()
        mols = mols.loc[mols.index.intersection(df_mol_ids)] # delete mols not in df:

        # Every molecule column except the smiles column is normalized with _cast.
        _mol_other_cols = mols.columns.difference(set([self.mol_col]))
        mols[_mol_other_cols] = mols[_mol_other_cols].applymap(self._cast)
        mols['_graphs'] = mols.apply(self._read_graph, axis = 1) # self.mol_global_cols used here...
        # _read_graph returns NaN for molecules it could not convert.
        mols.dropna(subset = ['_graphs'], inplace = True)         

        # Inner join: measurements of dropped / unknown molecules disappear here.
        df = df.join(mols, on = self.mol_id_col, how = 'inner', rsuffix = '__mol')

        if self.include_conc_parameter_list is not None:
            # Remove data with given parameters:
            mask_for_include_conc_parameter_list = df.apply(self._get_mask_for_include_conc_parameter_list, axis = 1)
            df = df[mask_for_include_conc_parameter_list]

        # Process auxiliary labels:
        if len(self.auxiliary_label_cols) > 0:
            df[self.auxiliary_label_cols] = df[self.auxiliary_label_cols].applymap(self._cast)
            for col in self.auxiliary_label_cols:
                # NaN cells are replaced by the first non-NaN value of the column and flagged with mask = False.
                default_val = self._get_nan_default(df[col])
                df[col] = df[col].apply(lambda x: self._get_nan_mask(x, default_val = default_val))

        # Auxiliary weights get the same {'value', 'mask'} treatment.
        if len(self.auxiliary_weight_cols) > 0:
            df[self.auxiliary_weight_cols] = df[self.auxiliary_weight_cols].applymap(self._cast)
            for col in self.auxiliary_weight_cols:
                default_val = self._get_nan_default(df[col])
                df[col] = df[col].apply(lambda x: self._get_nan_mask(x, default_val = default_val))

        # Per-row records; subclasses override the _create_* methods.
        df['_seq'] = df.apply(lambda x: self._create_seq_dict(x), axis = 1)
        df['_conc'] = df.apply(lambda x: self._create_conc_dict(x), axis = 1)
        df['_label'] = df.apply(lambda x: self._create_label_dict(x), axis = 1)

        df_ec50 = df[df[self.parameter_col] == 'ec50']
        n_all_original = len(df)
        n_ec50_original = len(df_ec50)

        # Get average EC50:
        # Only rows with label 1 and a plain numeric value (neither 'n.d' nor '>...') are used.
        df_responsive_numeric_ec50 = df_ec50[df_ec50[self.label_col] == 1]
        df_responsive_numeric_ec50 = df_responsive_numeric_ec50[df_responsive_numeric_ec50[self.conc_value_col] != 'n.d']
        df_responsive_numeric_ec50 = df_responsive_numeric_ec50[~df_responsive_numeric_ec50[self.conc_value_col].str.contains('>')]
        
        self.std_ec50 = df_responsive_numeric_ec50[self.conc_value_col].astype(float).std()
        self.mean_ec50 = df_responsive_numeric_ec50[self.conc_value_col].astype(float).mean()

        # EC50 augmentation: the EC50 rows end up n_ec50_copies times in the frame.
        if self.n_ec50_copies > 1:
            df_ec50_augmented = pandas.concat([df_ec50]*(self.n_ec50_copies - 1))
        elif self.n_ec50_copies == 1:
            df_ec50_augmented = pandas.DataFrame([], columns = df_ec50.columns)
        else:
            raise ValueError('EC50 not enlarged. This may lead to a lot of unused data.')
        df = pandas.concat([df, df_ec50_augmented])
        n_all_augmented = len(df)
        n_ec50_augmented = n_ec50_original * self.n_ec50_copies
        # NOTE(review): both ratios divide by the row count and fail with
        # ZeroDivisionError when no rows survived the joins / filters.
        print('Extending number of EC50 pairs from {:.2f}% to {:.2f}%'.format(100.0*(n_ec50_original/n_all_original), 100.0*(n_ec50_augmented/n_all_augmented)))
        print('Total number of pairs after ec50 augmentation: {}'.format(n_all_augmented))

        # Only the assembled records are kept.
        df = df[['_seq', '_graphs', '_conc', '_label']]

        # print(df[(df['_label'].apply(lambda x: x['_main_label'] == 0))&(df['_conc'].apply(lambda x: x['value'] != 'n.d' and x['parameter'] == 'ec50'))])
        # raise Exception('test')

        if self.oversampling_function is not None:
            raise NotImplementedError('oversampling is not implemented...')
            # Unreachable until oversampling is implemented.
            df = self.oversampling_function(df, label_col = self.label_col)
        
        return df            

    def _get_class_weight_map(self, class_dist):
        n_samples = sum(class_dist.values())
        n_classes = len(class_dist)
        class_weight_map = {key : n_samples/(n_classes*class_dist[key]) for key in class_dist.keys()}
        return class_weight_map

    def adjusted_class_weight_col(self, ec50_lower_margin, ec50_upper_margin, ec50_lower_extreme, ec50_upper_extreme):
        """
        """
        df = self.data.copy()

        N_ec50_pos = sum(df['_conc'].apply(lambda x: x['parameter'][0] == self.conc_parameter_id_map['ec50']))
        N_ec50_neg = sum(df['_conc'].apply(lambda x: x['parameter'][0] == self.conc_parameter_id_map['ec50_nd']))

        N_ec50_greater_pos = sum(df['_conc'].apply(lambda x: x['parameter'][0] == self.conc_parameter_id_map['ec50_greater_than']))

        df_screening = df[df['_conc'].apply(lambda x: x['parameter'][0] == self.conc_parameter_id_map['screening'])]
        N_screening_pos = sum(df_screening['_label'].apply(lambda x: x['_main_label'] == 1))
        N_screening_neg = sum(df_screening['_label'].apply(lambda x: x['_main_label'] == 0))

        ec50_sample_range_size = (ec50_upper_extreme - ec50_upper_margin) + (ec50_lower_extreme - ec50_lower_margin)
        ec50_active_rate = (ec50_upper_extreme - ec50_upper_margin) / ec50_sample_range_size

        n_pos = N_screening_pos + ec50_active_rate*N_ec50_pos
        n_neg = N_screening_neg + (1.0 - ec50_active_rate)*N_ec50_pos + N_ec50_neg + N_ec50_greater_pos

        class_dist = {0 : n_neg, 1 : n_pos}
        class_weight_map = self._get_class_weight_map(class_dist)

        def _update_label_dict(x):
            label_dict = x.copy()
            label_dict['_main_sample_weight'] = class_weight_map[x['_main_label']]
            return label_dict
        df['_label'] = df['_label'].apply(_update_label_dict)

        self.data = df.copy()
        return df

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        # TODO: previously jax DeviceArray and this raised assertionError 
        # in pandas. See if this behaviour changes in new padnas
        index = numpy.asarray(index)
        sample = self.data.iloc[index]
        
        mol = sample['_graphs']
        seq = sample['_seq']
        conc = sample['_conc']
        label = sample['_label']

        # print(conc_raw)
        if self.sample_concentration_and_label is not None:
            conc, label = self.sample_concentration_and_label(conc, label)
        
        if mol is None or not mol == mol:
            print(sample)
            raise ValueError('Molecule is None or NaN for seq_id: {}, label: {}, mol: {}'.format(seq, label, mol))
        
        if self.element_preprocess is not None:
            seq, mol, label = self.element_preprocess(seq, mol, conc, label)

        return seq, mol, label
    

class AminoConcentrationDatasetMeasurementsSamplingPrecompute(AminoConcentrationDatasetPrecompute):
    def __init__(self, data_csv, mols_csv, seq_id_col, mol_id_col, mol_col, parameter_col, n_ec50_copies, conc_value_col, conc_value_screen_col, label_col, weight_col = None, seqs_csv = None,
                atom_features = ['AtomicNum'], bond_features = ['BondType'], 
                oversampling_function = None,
                line_graph_max_size = None, # 10 * padding_n_node
                line_graph = True,
                class_alpha = None,
                sample_concentration_and_label = None,
                element_preprocess = None,
                **kwargs):
        """
        Read the sampling-interval options from ``kwargs`` and delegate the
        rest to the parent constructor.

        Options taken from ``kwargs`` (each is None when absent):
        screening_lower_margin, screening_upper_margin, ec50_lower_margin,
        ec50_upper_margin, ec50_greater_than_lower_margin,
        sampling_region_lower_bound, sampling_region_upper_bound.

        They have to be set before the parent constructor runs, because the
        parent calls ``read()``, which uses ``_create_conc_dict``.
        """
        # Offset added beyond the sampling region bounds for open-ended intervals.
        self.conc_sets_eps = 5*1e-3

        for option in ('screening_lower_margin',
                       'screening_upper_margin',
                       'ec50_lower_margin',
                       'ec50_upper_margin',
                       'ec50_greater_than_lower_margin',
                       'sampling_region_lower_bound',
                       'sampling_region_upper_bound'):
            setattr(self, option, kwargs.get(option))

        super(AminoConcentrationDatasetMeasurementsSamplingPrecompute, self).__init__(
            data_csv = data_csv,
            mols_csv = mols_csv,
            seq_id_col = seq_id_col,
            mol_id_col = mol_id_col,
            mol_col = mol_col,
            parameter_col = parameter_col,
            n_ec50_copies = n_ec50_copies,
            conc_value_col = conc_value_col,
            conc_value_screen_col = conc_value_screen_col,
            label_col = label_col,
            weight_col = weight_col,
            seqs_csv = seqs_csv,
            atom_features = atom_features,
            bond_features = bond_features,
            oversampling_function = oversampling_function,
            line_graph_max_size = line_graph_max_size,
            line_graph = line_graph,
            class_alpha = class_alpha,
            sample_concentration_and_label = sample_concentration_and_label,
            element_preprocess = element_preprocess,
            **kwargs)
        
    def _create_conc_dict(self, x):
        """
        Build the ``Concentration`` record of one measurement as a pair of
        bounds 'C0' and 'C1' (one-element arrays).

        * ec50, 'n.d'      : C0 = upper bound,               C1 = upper bound + eps
        * ec50, '>value'   : C0 = value - greater_than_lower_margin,
                             C1 = upper bound + eps
        * ec50, numeric    : C0 = value - ec50_lower_margin,
                             C1 = value + ec50_upper_margin
        * screening, label 0 : C0 = value - screening_lower_margin,
                               C1 = upper bound + eps
        * screening, label 1 : C0 = lower bound - eps,
                               C1 = value + screening_upper_margin

        where "upper/lower bound" are ``self.sampling_region_*_bound`` and eps
        is ``self.conc_sets_eps``.

        Raises:
        -------
        ValueError
            if a screening row has a label other than 0 or 1. Previously this
            surfaced as an UnboundLocalError on C0.
        """
        conc_dict = Concentration()
        param = x[self.parameter_col]
        if param == 'ec50':
            val = x[self.conc_value_col]
            if val == 'n.d':
                C0 = self.sampling_region_upper_bound
                C1 = self.sampling_region_upper_bound + self.conc_sets_eps
            elif '>' in val:
                # Drops the first character, i.e. assumes the value starts with '>'.
                C0 = float(val[1:]) - self.ec50_greater_than_lower_margin
                C1 = self.sampling_region_upper_bound + self.conc_sets_eps
            else:
                C0 = float(val) - self.ec50_lower_margin
                C1 = float(val) + self.ec50_upper_margin
        else:
            val = x[self.conc_value_screen_col]
            label = x[self.label_col]
            if label == 0:
                C0 = float(val) - self.screening_lower_margin
                C1 = self.sampling_region_upper_bound + self.conc_sets_eps
            elif label == 1:
                C0 = self.sampling_region_lower_bound - self.conc_sets_eps
                C1 = float(val) + self.screening_upper_margin
            else:
                # Without this branch C0/C1 would be undefined below.
                raise ValueError('Screening measurement must have label 0 or 1, got {!r}.'.format(label))

        conc_dict['C0'] = numpy.expand_dims(C0, axis = 0)
        conc_dict['C1'] = numpy.expand_dims(C1, axis = 0)
        return conc_dict




class AminoConcentrationDatasetPairsSamplingPrecompute(AminoConcentrationDatasetPrecompute):
    def __init__(self, data_csv, mols_csv, seq_id_col, mol_id_col, mol_col, parameter_col, n_ec50_copies, conc_value_col, conc_value_screen_col, label_col, weight_col = None, seqs_csv = None,
                atom_features = ['AtomicNum'], bond_features = ['BondType'], 
                oversampling_function = None,
                line_graph_max_size = None, # 10 * padding_n_node
                line_graph = True,
                class_alpha = None,
                sample_concentration_and_label = None,
                element_preprocess = None,
                **kwargs):
        """
        Read the sampling-interval options from ``kwargs`` and delegate the
        rest to the parent constructor.

        Options taken from ``kwargs`` (each is None when absent):
        screening_lower_margin, screening_upper_margin, ec50_lower_margin,
        ec50_upper_margin, ec50_greater_than_lower_margin,
        sampling_region_lower_bound, sampling_region_upper_bound.

        They have to be set before the parent constructor runs, because the
        parent calls ``read()``, which uses them.
        """
        # Offset added beyond the sampling region bounds for padding values.
        self.conc_sets_eps = 5*1e-3

        for option in ('screening_lower_margin',
                       'screening_upper_margin',
                       'ec50_lower_margin',
                       'ec50_upper_margin',
                       'ec50_greater_than_lower_margin',
                       'sampling_region_lower_bound',
                       'sampling_region_upper_bound'):
            setattr(self, option, kwargs.get(option))

        super(AminoConcentrationDatasetPairsSamplingPrecompute, self).__init__(
            data_csv = data_csv,
            mols_csv = mols_csv,
            seq_id_col = seq_id_col,
            mol_id_col = mol_id_col,
            mol_col = mol_col,
            parameter_col = parameter_col,
            n_ec50_copies = n_ec50_copies,
            conc_value_col = conc_value_col,
            conc_value_screen_col = conc_value_screen_col,
            label_col = label_col,
            weight_col = weight_col,
            seqs_csv = seqs_csv,
            atom_features = atom_features,
            bond_features = bond_features,
            oversampling_function = oversampling_function,
            line_graph_max_size = line_graph_max_size,
            line_graph = line_graph,
            class_alpha = class_alpha,
            sample_concentration_and_label = sample_concentration_and_label,
            element_preprocess = element_preprocess,
            **kwargs)

    def _construct_conc_sets_per_pair(self, df, padding_size = None):
        df_ec50 = df[df[self.parameter_col] == 'ec50']
        df_screening = df[df[self.parameter_col] != 'ec50']

        df_ec50_greater_than = df_ec50[df_ec50[self.conc_value_col].str.contains('>')]
        df_ec50_nd = df_ec50[df_ec50[self.conc_value_col] == 'n.d']
        df_ec50_active  = df_ec50.loc[df_ec50.index.difference(df_ec50_greater_than.index).difference(df_ec50_nd.index)]

        df_screening_non_active = df_screening[df_screening[self.label_col] == 0]
        df_screening_active = df_screening[df_screening[self.label_col] == 1]

        array_ec50_greater_than = numpy.array([])
        array_ec50_nd = numpy.array([])
        array_df_ec50_active_C0 = numpy.array([])
        array_df_ec50_active_C1 = numpy.array([])
        
        array_screening_non_active = numpy.array([])
        array_screening_active = numpy.array([])

        if not df_ec50_greater_than.empty:
            array_ec50_greater_than = df_ec50_greater_than[self.conc_value_col].apply(lambda x: float(x[1:])) - self.ec50_greater_than_lower_margin
        if not df_ec50_nd.empty:
            array_ec50_nd = numpy.ones(len(df_ec50_nd)) * self.sampling_region_upper_bound
        if not df_ec50_active.empty:
            array_df_ec50_active_C0 = df_ec50_active[self.conc_value_col].astype(float) - self.ec50_lower_margin
            array_df_ec50_active_C1 = df_ec50_active[self.conc_value_col].astype(float) + self.ec50_upper_margin
        
        if not df_screening_non_active.empty:
            array_screening_non_active = df_screening_non_active[self.conc_value_screen_col].astype(float) - self.screening_lower_margin
        if not df_screening_active.empty:
            array_screening_active = df_screening_active[self.conc_value_screen_col].astype(float) + self.screening_upper_margin

        C0 = numpy.concatenate([array_ec50_greater_than,
                                array_ec50_nd,
                                array_df_ec50_active_C0,
                                array_screening_non_active])
        
        C1 = numpy.concatenate([array_df_ec50_active_C1,
                                array_screening_active])

        C0_padding = numpy.ones(padding_size - len(C0)) * (self.sampling_region_lower_bound - self.conc_sets_eps)
        C1_padding = numpy.ones(padding_size - len(C1)) * (self.sampling_region_upper_bound + self.conc_sets_eps)

        C0 = numpy.concatenate([C0, C0_padding])
        C1 = numpy.concatenate([C1, C1_padding])

        return pandas.Series({'C0' : C0, 'C1' : C1})
    
    def _create_conc_dict(self, x):
        """
        Wrap the precomputed 'C0' / 'C1' arrays of a row into a
        ``Concentration`` record, adding a leading axis of size one to each.
        """
        conc_dict = Concentration()
        for bound in ('C0', 'C1'):
            conc_dict[bound] = numpy.expand_dims(x[bound], axis = 0)
        return conc_dict
    
    def _create_label_dict(self, x):
        """
        Build a placeholder ``Label`` record.

        The main label is the dummy value -1 and, when a weight column is
        configured, the main sample weight is the dummy value -1.0; nothing is
        read from ``x``.

        Raises:
        -------
        NotImplementedError
            if any auxiliary label or weight column is configured.
        """
        label_dict = Label()
        label_dict['_main_label'] = -1 # dummy label
        if self.weight_col is not None:
            label_dict['_main_sample_weight'] = -1.0 # dummy weight

        # Auxiliary columns are rejected as soon as the first one is seen.
        for _ in self.auxiliary_label_cols:
            raise NotImplementedError('There is groupby so it is not clear how to treat auxiliary columns.')
        for _ in self.auxiliary_weight_cols:
            raise NotImplementedError('There is groupby so it is not clear how to treat auxiliary columns.')

        return label_dict

    def read(self):
        """
        Load the measurements and reduce them to one row per
        (molecule, sequence) pair.

        Steps:
            1. take ``self.data_csv`` as a DataFrame or read it from csv;
            2. group by pair and build the padded 'C0' / 'C1' arrays with
               ``_construct_conc_sets_per_pair``;
            3. optionally inner-join per-sequence columns from ``self.seqs_csv``;
            4. read ``self.mols_csv``, build one graph per molecule and
               inner-join the molecules onto the pairs;
            5. build the '_seq', '_conc' and '_label' records.

        ``self.n_ec50_copies`` is not used here (see the TODO below).

        Returns:
        --------
        pandas.DataFrame
            columns '_seq', '_graphs', '_conc', '_label'.

        Raises:
        -------
        NotImplementedError
            if auxiliary label / weight columns or an oversampling function
            are configured.
        """
        if isinstance(self.data_csv, pandas.DataFrame):
            df = self.data_csv # [[self.mol_col, self.seq_id_col, self.label_col] + self._weight_col]
        else:
            df = pandas.read_csv(self.data_csv, sep = self.sep) # , usecols = [self.mol_id_col, self.seq_id_col, self.label_col] + self._weight_col + self.auxiliary_label_cols)

        # -----------------------------------------------------------------------------
        # TODO: Data augmentation would not work at this implementation. We only 
        #     look at whether there is at least one example at a given concentration.
        # -----------------------------------------------------------------------------
        # df_ec50 = df[df[self.parameter_col] == 'ec50']
        # n_all_original = len(df)
        # n_ec50_original = len(df_ec50)

        # if self.n_ec50_copies > 1:
        #     df_ec50_augmented = pandas.concat([df_ec50]*(self.n_ec50_copies - 1))
        # elif self.n_ec50_copies == 1:
        #     df_ec50_augmented = pandas.DataFrame([], columns = df_ec50.columns)
        # else:
        #     raise ValueError('EC50 not enlarged. This may lead to a lot of unused data.')
        # df = pandas.concat([df, df_ec50_augmented])
        # n_all_augmented = len(df)
        # n_ec50_augmented = n_ec50_original * self.n_ec50_copies
        # print('Extending number of EC50 measurements from {:.2f}% to {:.2f}%'.format(100.0*(n_ec50_original/n_all_original), 100.0*(n_ec50_augmented/n_all_augmented)))
        # print('Total number of pairs after ec50 augmentation: {}'.format(n_all_augmented))
        # -----------------------------------------------------------------------------

        # Largest number of measurements (non-null parameter_col) of any pair; every
        # pair is padded to this length so that all C0 / C1 arrays have equal size.
        padding_size = df.groupby([self.seq_id_col, self.mol_id_col]).count()[self.parameter_col].max()
        # One row per pair with columns 'C0' and 'C1'; the pair IDs come back as
        # columns through reset_index.
        df = df.groupby([self.mol_id_col, self.seq_id_col]).apply(lambda x: self._construct_conc_sets_per_pair(x, padding_size = padding_size))
        df = df.reset_index()

        # NOTE(review): df_seq_ids is never used below.
        df_seq_ids = None
        # Molecule IDs that actually occur in the measurements.
        df_mol_ids = pandas.Index(df[self.mol_id_col].unique())

        if self.seqs_csv is not None:
            # raise NotImplementedError('Overall structure is here, but check when there is a usecase. Dont forget to change the columns that are picked at the end of self.read.')
            seqs = pandas.read_csv(self.seqs_csv, sep = self.sep, index_col = self.seq_id_col)
            seqs = seqs.applymap(self._cast)
            # Inner join: pairs without a matching sequence row are dropped.
            df = df.join(seqs, on = self.seq_id_col, how = 'inner', rsuffix = '__seq')

        # mols_csv can not be None:
        mols = pandas.read_csv(self.mols_csv, sep = self.sep, index_col = self.mol_id_col)
        # NOTE(review): drop_duplicates compares column values only, not the index,
        # so two different molecule IDs with identical columns collapse into one row.
        # Confirm that this is intended.
        mols = mols.drop_duplicates()
        mols = mols.loc[mols.index.intersection(df_mol_ids)] # delete mols not in df:

        # Every molecule column except the smiles column is normalized with _cast.
        _mol_other_cols = mols.columns.difference(set([self.mol_col]))
        mols[_mol_other_cols] = mols[_mol_other_cols].applymap(self._cast)
        mols['_graphs'] = mols.apply(self._read_graph, axis = 1) # self.mol_global_cols used here...
        # _read_graph returns NaN for molecules it could not convert.
        mols.dropna(subset = ['_graphs'], inplace = True)         

        # Inner join: pairs of dropped / unknown molecules disappear here.
        df = df.join(mols, on = self.mol_id_col, how = 'inner', rsuffix = '__mol')

        # Process auxiliary labels:
        if len(self.auxiliary_label_cols) > 0:
            raise NotImplementedError('There is groupby so it is not clear how to treat auxiliary columns.')
            # Unreachable: kept from the parent implementation.
            df[self.auxiliary_label_cols] = df[self.auxiliary_label_cols].applymap(self._cast)
            for col in self.auxiliary_label_cols:
                default_val = self._get_nan_default(df[col])
                df[col] = df[col].apply(lambda x: self._get_nan_mask(x, default_val = default_val))

        if len(self.auxiliary_weight_cols) > 0:
            raise NotImplementedError('There is groupby so it is not clear how to treat auxiliary columns.')
            # Unreachable: kept from the parent implementation.
            df[self.auxiliary_weight_cols] = df[self.auxiliary_weight_cols].applymap(self._cast)
            for col in self.auxiliary_weight_cols:
                default_val = self._get_nan_default(df[col])
                df[col] = df[col].apply(lambda x: self._get_nan_mask(x, default_val = default_val))

        # Per-pair records; _create_conc_dict reads the 'C0' / 'C1' columns built above.
        df['_seq'] = df.apply(lambda x: self._create_seq_dict(x), axis = 1)
        df['_conc'] = df.apply(lambda x: self._create_conc_dict(x), axis = 1)
        df['_label'] = df.apply(lambda x: self._create_label_dict(x), axis = 1)

        # Only the assembled records are kept.
        df = df[['_seq', '_graphs', '_conc', '_label']]

        # print(df[(df['_label'].apply(lambda x: x['_main_label'] == 0))&(df['_conc'].apply(lambda x: x['value'] != 'n.d' and x['parameter'] == 'ec50'))])
        # raise Exception('test')

        if self.oversampling_function is not None:
            raise NotImplementedError('oversampling is not implemented...')
            # Unreachable until oversampling is implemented.
            df = self.oversampling_function(df, label_col = self.label_col)
        
        return df  








# ----------------------------------------------------------
# Prediction dataset with Precomputed amino acids embedding:
# ----------------------------------------------------------
class AminoConcentrationDatasetPrecomputePredict(AminoConcentrationDatasetPrecompute):
    """
    consider introducing mol_buffer to save already preprocessed graphs.
    """
    def __init__(self, data_csv, mols_csv, seq_id_col, mol_id_col, mol_col, conc_value_col, seqs_csv = None,
                atom_features = ['AtomicNum'], bond_features = ['BondType'], 
                line_graph_max_size = None, # 10 * padding_n_node
                line_graph = True,
                class_alpha = None,
                sample_concentration_and_label = None,
                element_preprocess = None,
                **kwargs):
        """
        Build the prediction variant of the dataset.

        Every argument is handed to the parent constructor unchanged. The
        parent arguments that this signature does not expose
        (parameter_col, n_ec50_copies, conc_value_screen_col, label_col,
        weight_col and oversampling_function) are always passed as None.

        Parameters:
        -----------
        data_csv, mols_csv, seqs_csv
            forwarded to the parent constructor.

        seq_id_col, mol_id_col, mol_col, conc_value_col : str
            column names, forwarded to the parent constructor.

        atom_features, bond_features : list
            forwarded to the parent constructor. The list defaults are
            shared between calls; this method only forwards them.

        **kwargs
            forwarded to the parent constructor.
        """
        # Parent arguments that are pinned to None for this class.
        pinned_to_none = dict.fromkeys((
            'parameter_col',
            'n_ec50_copies',
            'conc_value_screen_col',
            'label_col',
            'weight_col',
            'oversampling_function',
        ))

        super().__init__(
            data_csv = data_csv,
            mols_csv = mols_csv,
            seqs_csv = seqs_csv,
            seq_id_col = seq_id_col,
            mol_id_col = mol_id_col,
            mol_col = mol_col,
            conc_value_col = conc_value_col,
            atom_features = atom_features,
            bond_features = bond_features,
            line_graph_max_size = line_graph_max_size, # 10 * padding_n_node
            line_graph = line_graph,
            class_alpha = class_alpha,
            sample_concentration_and_label = sample_concentration_and_label,
            element_preprocess = element_preprocess,
            **pinned_to_none,
            **kwargs)
    
    def _create_conc_dict(self, x):
        """
        Pack the concentration of one row into a ``Concentration`` object.

        Parameters:
        -----------
        x : row indexable by column name
            must contain the column named by ``self.conc_value_col``.

        Returns:
        --------
        Concentration
            its 'value' entry is a float32 numpy array built from a
            one-element list holding the cell value.
        """
        packed = Concentration()
        raw_value = x[self.conc_value_col]
        packed['value'] = numpy.array([raw_value], dtype = numpy.float32)
        return packed

    def _create_label_dict(self, x):
        """
        Return a ``Label`` whose '_main_label' entry is the constant -1.

        Parameters:
        -----------
        x : row
            not read by this method.
        """
        placeholder = Label()
        placeholder['_main_label'] = -1
        return placeholder

    def read(self):
        """
        Load the input table and attach sequence, graph, concentration and
        label objects to every row.

        Steps:
        ------
        1. ``self.data_csv`` is used directly when it is a DataFrame,
           otherwise it is read with ``pandas.read_csv`` (``sep = self.sep``).
        2. If ``self.seqs_csv`` is given, it is read with ``self.seq_id_col``
           as index, every cell is passed through ``self._cast`` and the
           result is inner-joined on ``self.seq_id_col`` (suffix '__seq').
        3. ``self.mols_csv`` is read with ``self.mol_id_col`` as index,
           duplicated rows are dropped, only ids present in the input table
           are kept, all columns except ``self.mol_col`` are passed through
           ``self._cast``, '_graphs' is produced by ``self._read_graph`` and
           rows without a graph are removed. The result is inner-joined on
           ``self.mol_id_col`` (suffix '__mol').
        4. If ``self.conc_value_col`` is None, it is set to '_CONC_' on the
           instance and that column is filled with -12.0.
           NOTE(review): the meaning/unit of -12.0 is not visible here - confirm.
        5. '_seq', '_conc' and '_label' are built row by row.

        Returns:
        --------
        pandas.DataFrame
            with columns [mol_id_col, seq_id_col, conc_value_col,
            '_seq', '_graphs', '_conc', '_label'].

        Raises:
        -------
        NotImplementedError
            if ``self.oversampling_function`` is not None.
        """
        if isinstance(self.data_csv, pandas.DataFrame):
            table = self.data_csv
        else:
            table = pandas.read_csv(self.data_csv, sep = self.sep)

        # Collected before any join so that all requested molecules are looked up.
        requested_mol_ids = pandas.Index(table[self.mol_id_col].unique())

        if self.seqs_csv is not None:
            seq_table = pandas.read_csv(self.seqs_csv, sep = self.sep, index_col = self.seq_id_col).applymap(self._cast)
            table = table.join(seq_table, on = self.seq_id_col, how = 'inner', rsuffix = '__seq')

        # mols_csv can not be None:
        mol_table = pandas.read_csv(self.mols_csv, sep = self.sep, index_col = self.mol_id_col).drop_duplicates()
        # Discard molecules that never occur in the input table.
        mol_table = mol_table.loc[mol_table.index.intersection(requested_mol_ids)]

        non_smiles_cols = mol_table.columns.difference({self.mol_col})
        mol_table[non_smiles_cols] = mol_table[non_smiles_cols].applymap(self._cast)
        mol_table['_graphs'] = mol_table.apply(self._read_graph, axis = 1) # self.mol_global_cols used here...
        mol_table.dropna(subset = ['_graphs'], inplace = True)

        table = table.join(mol_table, on = self.mol_id_col, how = 'inner', rsuffix = '__mol')

        if self.conc_value_col is None:
            self.conc_value_col = '_CONC_'
            table[self.conc_value_col] = -12.0

        row_builders = (
            ('_seq', self._create_seq_dict),
            ('_conc', self._create_conc_dict),
            ('_label', self._create_label_dict),
        )
        for target_col, builder in row_builders:
            table[target_col] = table.apply(builder, axis = 1)

        kept_cols = [self.mol_id_col, self.seq_id_col, self.conc_value_col, '_seq', '_graphs', '_conc', '_label']
        table = table[kept_cols]

        if self.oversampling_function is not None:
            raise NotImplementedError('oversampling is not implemented...')

        return table
    
    def __getitem__(self, index):
        """
        Fetch one prediction sample by position.

        Parameters:
        -----------
        index
            position in ``self.data``; converted with ``numpy.asarray``
            before being handed to ``iloc``.

        Returns:
        --------
        tuple
            (seq, mol, label). When ``self.element_preprocess`` is set, the
            three values are the ones it returns for
            (seq, mol, conc['value'], label).

        Raises:
        -------
        ValueError
            if ``self.sample_concentration_and_label`` is set, or if the
            molecule graph is None or NaN.
        """
        # TODO: a jax DeviceArray index previously raised an AssertionError
        # in pandas. Check whether this still happens with newer pandas.
        row = self.data.iloc[numpy.asarray(index)]

        mol, seq, conc, label = (row[key] for key in ('_graphs', '_seq', '_conc', '_label'))

        if self.sample_concentration_and_label is not None:
            raise ValueError('Concentration can not be sampled in prediction.')

        # A value that is not equal to itself is NaN.
        graph_is_missing = mol is None or not mol == mol
        if graph_is_missing:
            print(row)
            raise ValueError('Molecule is None or NaN for seq_id: {}, label: {}, mol: {}'.format(seq, label, mol))

        if self.element_preprocess is None:
            return seq, mol, label

        seq, mol, label = self.element_preprocess(seq, mol, conc['value'], label)
        return seq, mol, label