import numpy
import pandas
import ast
import time
import copy
import jax
from jax import numpy as jnp

from ProtLig_GPCRclassA.mol2graph.jraph.convert import smiles_to_jraph
from ProtLig_GPCRclassA.mol2graph.exceptions import NoBondsError

from ProtLig_GPCRclassA.utils import create_line_graph, Sequence, Label, Concentration
from ProtLig_GPCRclassA.base_loader import BaseDataset


# ----------------------------------
# Precomputed amino acids embedding:
# ----------------------------------
class AminoEC50RegressionDatasetPrecompute(BaseDataset):
    """
    EC50 regression dataset pairing protein sequences with molecular graphs.

    Rows of ``data_csv`` are inner-joined with molecules from ``mols_csv`` and,
    optionally, with per-sequence data from ``seqs_csv``. Only rows whose
    ``parameter_col`` equals ``'ec50'`` and whose ``conc_value_col`` does not
    contain ``'>'`` are kept. Items are ``(seq, mol, label)`` triples.

    consider introducing mol_buffer to save already preprocessed graphs.
    """
    def __init__(self, data_csv, mols_csv, seq_id_col, mol_id_col, mol_col, parameter_col, conc_value_col, conc_value_screen_col, label_col, weight_col = None, seqs_csv = None,
                atom_features = ['AtomicNum'], bond_features = ['BondType'],
                class_alpha = None,
                sample_concentration_and_label = None,
                element_preprocess = None,
                non_active_ec50_value = None,
                **kwargs):
        """
        Parameters:
        -----------
        data_csv : str or pandas.DataFrame
            path to csv (read with ``sep``) or an already loaded DataFrame.

        mols_csv : str
            path to csv with molecules, indexed by ``mol_id_col``. Required.

        seq_id_col : str
            name of the column with protein IDs.

        mol_id_col : str
            name of the column with molecule IDs (join key with ``mols_csv``).

        mol_col : str
            name of the column (in ``mols_csv``) with smiles.

        parameter_col : str
            column with the measured parameter; only rows equal to ``'ec50'``
            are kept.

        conc_value_col : str
            column with the EC50 value used as the main label. The value
            ``'n.d'`` is replaced by ``non_active_ec50_value``; values
            containing ``'>'`` are dropped.

        conc_value_screen_col : str
            stored on the instance; not used by this class.

        label_col : str
            column used to group samples when class weights are calculated
            (``weight_col == '_calculate_class_weight'``).

        weight_col : str or None
            column with per-sample weights, stored in the label dict as
            ``'_main_sample_weight'``. The special value
            ``'_calculate_class_weight'`` computes balanced weights
            ``n_samples / (n_classes * class_count)`` grouped by ``label_col``.

        seqs_csv : str or None
            optional csv indexed by ``seq_id_col``; inner-joined on it.

        atom_features : list
            list of atom features passed to ``smiles_to_jraph``.

        bond_features : list
            list of bond features passed to ``smiles_to_jraph``.

        class_alpha, sample_concentration_and_label :
            stored on the instance; not used by this class.

        element_preprocess : callable or None
            if given, ``__getitem__`` returns
            ``element_preprocess(seq, mol, label)``.

        non_active_ec50_value : float
            label used for ``'n.d'`` EC50 values. Must not be None.

        **kwargs
            sep : separator for every csv read (default ``';'``).
            IncludeHs, self_loops : passed to ``smiles_to_jraph``
                (default False).
            auxiliary_label_cols, auxiliary_weight_cols : extra label/weight
                columns. NaN cells are replaced by the first non-NaN value of
                the column and masked out.
            mol_global_cols : molecule columns passed as globals ``u`` to
                ``smiles_to_jraph``.
            seq_global_cols : columns copied into each sequence dict.

        Raises:
        -------
        ValueError
            if ``non_active_ec50_value`` is None.

        Notes:
        ------
        sequence representations are retrieved in Collate.
        """
        self.data_csv = data_csv
        self.mols_csv = mols_csv
        self.seqs_csv = seqs_csv
        self.sep = kwargs.get('sep', ';')

        self.conc_parameter_id_map = {'ec50_nd' : 0,
                                  'ec50_greater_than' : 1,
                                  'ec50' : 2,
                                  'screening' : 3}
        self.conc_parameter_id_inverse_map = {val : key for key, val in self.conc_parameter_id_map.items()}

        self.seq_id_col = seq_id_col
        self.mol_id_col = mol_id_col
        self.mol_col = mol_col
        self.parameter_col = parameter_col
        self.conc_value_col = conc_value_col
        self.conc_value_screen_col = conc_value_screen_col
        self.label_col = label_col
        self.weight_col = weight_col

        self.class_alpha = class_alpha

        self.IncludeHs = kwargs.get('IncludeHs', False)
        self.self_loops = kwargs.get('self_loops', False)

        # Additional information:
        self.auxiliary_label_cols = kwargs.get('auxiliary_label_cols', [])
        self.auxiliary_weight_cols = kwargs.get('auxiliary_weight_cols', [])
        self.mol_global_cols = kwargs.get('mol_global_cols', [])
        self.seq_global_cols = kwargs.get('seq_global_cols', [])

        # Adjusted class weight: the sentinel value means "compute the weights
        # from label_col in read()" rather than "read them from a column".
        self._calculate_class_weight = False
        if weight_col == '_calculate_class_weight':
            self.weight_col = None
            self._calculate_class_weight = True

        self.atom_features = atom_features
        self.bond_features = bond_features

        self.sample_concentration_and_label = sample_concentration_and_label
        self.element_preprocess = element_preprocess

        self.mean_ec50 = None
        self.std_ec50 = None

        if non_active_ec50_value is None:
            raise ValueError('non_active_ec50_value must be set.')
        self.non_active_ec50_value = non_active_ec50_value

        self.data = self.read()

    def _read_graph(self, x):
        """
        Convert the smiles in row ``x`` into a jraph graph.

        Returns a 1-tuple ``(G,)`` on success. Returns ``float('nan')`` when
        the molecule has no bonds (``NoBondsError``) or conversion fails an
        assertion, so that ``read`` can drop the row with ``dropna``.
        """
        try:
            if len(self.mol_global_cols) > 0:
                u = {col : x[col] for col in self.mol_global_cols}
            else:
                u = None
            G = smiles_to_jraph(x[self.mol_col], u = u, validate = False, IncludeHs = self.IncludeHs,
                            atom_features = self.atom_features, bond_features = self.bond_features,
                            self_loops = self.self_loops)
        except NoBondsError:
            return float('nan')
        except AssertionError:
            print(x[self.mol_col])
            print('WARNING: having this except is short term solution. Discard this data in data preprocessing!!\n--------')
            return float('nan')
        return (G, )

    @staticmethod
    def _cast(x):
        """
        Normalise a raw csv cell.

        - A string of the form ``"[a, b, ...]"`` (any spacing around commas,
          including a single element such as ``"[0.5]"``) becomes a float64
          numpy array. ``"[]"`` or ``"[ ]"`` becomes an empty array.
        - A bracketed string whose content is not numeric (for example a smiles
          such as ``"[NH4+]"``), and any other string, is returned unchanged.
        - NaN values are returned unchanged.
        - A Python ``int`` becomes a length-1 int32 array, and a ``float``
          becomes a length-1 float32 array.
        - Anything else is returned unchanged.
        """
        if isinstance(x, str):
            _x = x.strip()
            if not (_x.startswith('[') and _x.endswith(']')):
                return x
            _x_content = _x[1:-1].strip()
            if not _x_content:
                return numpy.array([])
            try:
                # float() tolerates surrounding whitespace, so splitting on ','
                # covers "1,2", "1, 2" and mixed spacing alike.
                return numpy.array([float(ele) for ele in _x_content.split(',')])
            except ValueError:
                return x
        if numpy.isnan(x).any():
            return x
        if isinstance(x, int):
            return numpy.array([x], dtype = numpy.int32)
        if isinstance(x, float):
            return numpy.array([x], dtype = numpy.float32)
        return x

    @staticmethod
    def _map_elements(frame, func):
        """Apply ``func`` element-wise to ``frame`` on any pandas version."""
        # DataFrame.applymap is deprecated since pandas 2.1 in favour of
        # DataFrame.map. Check the class, not the instance, because instance
        # attribute access would fall back to a column named 'map'.
        if hasattr(pandas.DataFrame, 'map'):
            return frame.map(func)
        return frame.applymap(func)

    @staticmethod
    def _get_nan_default(series):
        """Return the first value of ``series`` that contains no NaN.

        Raises ValueError if every value contains NaN (or the series is empty).
        """
        for _, val in series.items():
            if not numpy.isnan(val).any():
                return val
        raise ValueError('All values in column {} are NaN.'.format(series.name))

    @staticmethod
    def _get_nan_mask(x, default_val):
        """Wrap ``x`` as ``{'value', 'mask'}``; NaN values are replaced by ``default_val`` and masked out."""
        if numpy.isnan(x).any():
            return {'value' : default_val, 'mask' : False}
        return {'value' : x, 'mask' : True}

    def _apply_nan_masks(self, df, cols):
        """Cast ``cols`` element-wise, then wrap each cell with ``_get_nan_mask``."""
        cols = list(cols)
        if not cols:
            return df
        df[cols] = self._map_elements(df[cols], self._cast)
        for col in cols:
            default_val = self._get_nan_default(df[col])
            df[col] = df[col].apply(lambda x, _default = default_val: self._get_nan_mask(x, default_val = _default))
        return df

    def _create_seq_dict(self, x):
        """Build the ``Sequence`` for row ``x``: its string ID plus ``seq_global_cols``."""
        seq = Sequence()
        seq['_seq_id'] = str(x[self.seq_id_col])
        for col in self.seq_global_cols:
            seq[col] = x[col]
        return seq

    def _create_label_dict(self, x):
        """
        Build the ``Label`` for row ``x``.

        ``'_main_label'`` is the EC50 value as a float32 array, with ``'n.d'``
        replaced by ``non_active_ec50_value``. Optional entries are the sample
        weight, the auxiliary labels with their ``<col>_mask``, and the
        auxiliary weights.
        """
        label_dict = Label()
        val = x[self.conc_value_col]
        if val == 'n.d':
            label_dict['_main_label'] = numpy.array([self.non_active_ec50_value], dtype = numpy.float32)
        else:
            label_dict['_main_label'] = numpy.array([val], dtype = numpy.float32)
        # weights:
        if self.weight_col is not None:
            label_dict['_main_sample_weight'] = x[self.weight_col]
        for col in self.auxiliary_label_cols:
            val = x[col]
            label_dict[col] = val['value']
            label_dict[col + '_mask'] = val['mask']
        for col in self.auxiliary_weight_cols:
            label_dict[col] = x[col]['value']
        return label_dict

    def read(self):
        """
        Load, join and filter the data.

        Returns a DataFrame with columns ``['_seq', '_graphs', '_label']`` and
        also stores a copy of it in ``self.data``. Molecules whose graph cannot
        be built are dropped. So are rows without a matching molecule or
        sequence (inner joins), rows whose parameter is not ``'ec50'``, and rows
        whose EC50 value contains ``'>'``.
        """
        if isinstance(self.data_csv, pandas.DataFrame):
            df = self.data_csv
        else:
            df = pandas.read_csv(self.data_csv, sep = self.sep)

        df_mol_ids = pandas.Index(df[self.mol_id_col].unique())

        if self.seqs_csv is not None:
            seqs = pandas.read_csv(self.seqs_csv, sep = self.sep, index_col = self.seq_id_col)
            seqs = self._map_elements(seqs, self._cast)
            df = df.join(seqs, on = self.seq_id_col, how = 'inner', rsuffix = '__seq')

        # mols_csv can not be None:
        mols = pandas.read_csv(self.mols_csv, sep = self.sep, index_col = self.mol_id_col)
        mols = mols.drop_duplicates()
        # Keep only the molecules referenced in df. Copy so the columns added
        # below are set on an independent frame.
        mols = mols.loc[mols.index.intersection(df_mol_ids)].copy()

        _mol_other_cols = mols.columns.difference([self.mol_col])
        if len(_mol_other_cols) > 0:
            mols[_mol_other_cols] = self._map_elements(mols[_mol_other_cols], self._cast)
        mols['_graphs'] = mols.apply(self._read_graph, axis = 1) # self.mol_global_cols used here...
        mols = mols.dropna(subset = ['_graphs'])

        df = df.join(mols, on = self.mol_id_col, how = 'inner', rsuffix = '__mol')

        # Process auxiliary labels and weights:
        df = self._apply_nan_masks(df, self.auxiliary_label_cols)
        df = self._apply_nan_masks(df, self.auxiliary_weight_cols)

        df = df[df[self.parameter_col] == 'ec50']

        print('\n\n WARNING: Removing ec50_greater_than...')
        # Cast to str first: .str.contains fails on non-string columns and
        # yields NaN for missing values, which breaks the ~ negation.
        is_greater_than = df[self.conc_value_col].astype(str).str.contains('>', regex = False)
        # .copy() so the column assignments below do not act on a slice.
        df = df[~is_greater_than].copy()

        df['_seq'] = df.apply(self._create_seq_dict, axis = 1)
        df['_label'] = df.apply(self._create_label_dict, axis = 1)

        if self._calculate_class_weight:
            class_dist = df.groupby(self.label_col).count()['_label'].to_dict()
            class_weight_map = self._get_class_weight_map(class_dist)

            def _update_label_dict(x):
                label_dict = x['_label'].copy()
                label_dict['_main_sample_weight'] = class_weight_map[x[self.label_col]]
                return label_dict
            df['_label'] = df.apply(_update_label_dict, axis = 1)

        df = df[['_seq', '_graphs', '_label']]

        self.data = df.copy()
        return df

    def _get_class_weight_map(self, class_dist):
        """Balanced class weights: ``n_samples / (n_classes * count)`` per class."""
        n_samples = sum(class_dist.values())
        n_classes = len(class_dist)
        return {key : n_samples / (n_classes * count) for key, count in class_dist.items()}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """
        Return ``(seq, mol, label)`` for the row at position ``index``.

        If ``element_preprocess`` is set, its output is returned instead.
        Raises ValueError if the stored molecule is None or NaN.
        """
        # TODO: previously jax DeviceArray and this raised assertionError
        # in pandas. See if this behaviour changes in new pandas
        index = numpy.asarray(index)
        sample = self.data.iloc[index]

        mol = sample['_graphs']
        seq = sample['_seq']
        label = sample['_label']

        # `not mol == mol` is True only for NaN placeholders.
        if mol is None or not mol == mol:
            print(sample)
            raise ValueError('Molecule is None or NaN for seq_id: {}, label: {}, mol: {}'.format(seq, label, mol))

        if self.element_preprocess is not None:
            seq, mol, label = self.element_preprocess(seq, mol, label)

        return seq, mol, label








# ----------------------------------------------------------
# Prediction dataset with Precomputed amino acids embedding:
# ----------------------------------------------------------
class AminoEC50RegressionDatasetPrecomputePredict(AminoEC50RegressionDatasetPrecompute):
    """
    Prediction counterpart of ``AminoEC50RegressionDatasetPrecompute``.

    Construction currently always raises ``NotImplementedError``. The code
    after that guard is kept, in consistent form, so the class can be enabled
    once it has been validated.

    consider introducing mol_buffer to save already preprocessed graphs.
    """
    def __init__(self, data_csv, mols_csv, seq_id_col, mol_id_col, mol_col, conc_value_col, seqs_csv = None,
                atom_features = ['AtomicNum'], bond_features = ['BondType'],
                line_graph_max_size = None, # 10 * padding_n_node
                line_graph = True,
                class_alpha = None,
                sample_concentration_and_label = None,
                element_preprocess = None,
                **kwargs):
        """
        Parameters are as in ``AminoEC50RegressionDatasetPrecompute``. The
        differences are:

        - there are no parameter, screen-value, label or weight columns;
        - ``conc_value_col`` is exposed per item as ``conc['value']``;
        - ``line_graph_max_size`` and ``line_graph`` are only stored;
        - ``non_active_ec50_value`` may be passed via ``**kwargs``. It is not
          used for prediction labels and defaults to NaN.

        Raises:
        -------
        NotImplementedError
            always (not implemented yet).
        """
        raise NotImplementedError('This is not implemented yet.')
        # Everything below is unreachable until the guard above is removed.
        # The parent accepts neither line_graph* nor oversampling arguments,
        # so they are stored here. They must be set before super().__init__,
        # because the parent constructor calls self.read().
        self.line_graph_max_size = line_graph_max_size
        self.line_graph = line_graph
        self.oversampling_function = None
        # The parent rejects None; prediction labels never use this value.
        kwargs.setdefault('non_active_ec50_value', float('nan'))
        super().__init__(data_csv = data_csv, mols_csv = mols_csv,
                seq_id_col = seq_id_col, mol_id_col = mol_id_col, mol_col = mol_col,
                parameter_col = None,
                conc_value_col = conc_value_col,
                conc_value_screen_col = None,
                label_col = None, weight_col = None,
                seqs_csv = seqs_csv,
                atom_features = atom_features, bond_features = bond_features,
                class_alpha = class_alpha,
                sample_concentration_and_label = sample_concentration_and_label,
                element_preprocess = element_preprocess,
                **kwargs)

    def _create_conc_dict(self, x):
        """Wrap the row's ``conc_value_col`` as ``Concentration`` with a float32 ``'value'`` array."""
        conc_dict = Concentration()
        conc_dict['value'] = numpy.array([x[self.conc_value_col]], dtype = numpy.float32)
        return conc_dict

    def _create_label_dict(self, x):
        """Placeholder label: ``'_main_label'`` is always -1 in prediction."""
        label_dict = Label()
        label_dict['_main_label'] = -1
        return label_dict

    def read(self):
        """
        Load and join prediction data.

        Returns a DataFrame with the id/value columns plus ``'_seq'``,
        ``'_graphs'``, ``'_conc'`` and ``'_label'``. No ec50 filtering is
        applied.
        """
        if getattr(self, 'oversampling_function', None) is not None:
            raise NotImplementedError('oversampling is not implemented...')

        def _cast_frame(frame):
            # DataFrame.applymap is deprecated since pandas 2.1 in favour of
            # DataFrame.map; check the class to avoid a column named 'map'.
            if hasattr(pandas.DataFrame, 'map'):
                return frame.map(self._cast)
            return frame.applymap(self._cast)

        if isinstance(self.data_csv, pandas.DataFrame):
            df = self.data_csv
        else:
            df = pandas.read_csv(self.data_csv, sep = self.sep)

        df_mol_ids = pandas.Index(df[self.mol_id_col].unique())

        if self.seqs_csv is not None:
            seqs = pandas.read_csv(self.seqs_csv, sep = self.sep, index_col = self.seq_id_col)
            seqs = _cast_frame(seqs)
            df = df.join(seqs, on = self.seq_id_col, how = 'inner', rsuffix = '__seq')

        # mols_csv can not be None:
        mols = pandas.read_csv(self.mols_csv, sep = self.sep, index_col = self.mol_id_col)
        mols = mols.drop_duplicates()
        mols = mols.loc[mols.index.intersection(df_mol_ids)].copy() # delete mols not in df

        _mol_other_cols = mols.columns.difference([self.mol_col])
        if len(_mol_other_cols) > 0:
            mols[_mol_other_cols] = _cast_frame(mols[_mol_other_cols])
        mols['_graphs'] = mols.apply(self._read_graph, axis = 1) # self.mol_global_cols used here...
        mols = mols.dropna(subset = ['_graphs'])

        df = df.join(mols, on = self.mol_id_col, how = 'inner', rsuffix = '__mol')

        df['_seq'] = df.apply(self._create_seq_dict, axis = 1)
        df['_conc'] = df.apply(self._create_conc_dict, axis = 1)
        df['_label'] = df.apply(self._create_label_dict, axis = 1)

        return df[[self.mol_id_col, self.seq_id_col, self.conc_value_col, '_seq', '_graphs', '_conc', '_label']]

    def __getitem__(self, index):
        """
        Return ``(seq, mol, label)`` for the row at position ``index``.

        Unlike the parent, ``element_preprocess`` is called as
        ``element_preprocess(seq, mol, conc['value'], label)``.

        Raises:
        -------
        ValueError
            if ``sample_concentration_and_label`` is set, or if the stored
            molecule is None or NaN.
        """
        # TODO: previously jax DeviceArray and this raised assertionError
        # in pandas. See if this behaviour changes in new pandas
        index = numpy.asarray(index)
        sample = self.data.iloc[index]

        mol = sample['_graphs']
        seq = sample['_seq']
        conc = sample['_conc']
        label = sample['_label']

        if self.sample_concentration_and_label is not None:
            raise ValueError('Concentration can not be sampled in prediction.')

        # `not mol == mol` is True only for NaN placeholders.
        if mol is None or not mol == mol:
            print(sample)
            raise ValueError('Molecule is None or NaN for seq_id: {}, label: {}, mol: {}'.format(seq, label, mol))

        if self.element_preprocess is not None:
            seq, mol, label = self.element_preprocess(seq, mol, conc['value'], label)

        return seq, mol, label