import numpy
import pandas
import ast

from ProtLig_GPCRclassA.mol2graph.jraph.convert import smiles_to_jraph
from ProtLig_GPCRclassA.mol2graph.exceptions import NoBondsError

from ProtLig_GPCRclassA.utils import create_line_graph, Sequence, Label
from ProtLig_GPCRclassA.base_loader import BaseDataset

# ----------------------------------
# Precomputed amino-acid embeddings:
# ----------------------------------
class AminoDatasetPrecompute(BaseDataset):
    """
    consider introducing mol_buffer to save already preprocessed graphs.
    """
    def __init__(self, data_csv, mols_csv, seq_id_col, mol_id_col, mol_col, label_col, weight_col = None, seqs_csv = None,
                atom_features = ['AtomicNum'], bond_features = ['BondType'], 
                oversampling_function = None,
                line_graph_max_size = None, # 10 * padding_n_node
                line_graph = True,
                class_alpha = None,
                element_preprocess = None,
                **kwargs):
        """
        Parameters:
        -----------
        data_csv : str
            path to csv

        seq_id_col : str
            name of the column with protein IDs

        mol_col : str
            name of the column with smiles
        
        label_col : str
            name of the column with labels (in case of multilabel problem, 
            labels should be in one column)

        atom_features : list
            list of atom features.

        bond_features : list
            list of bond features.

        **kwargs
            IncludeHs
            sep
            seq_sep

        Notes:
        ------
        sequence representations are retrived in Collate.
        """
        self.data_csv = data_csv
        self.mols_csv = mols_csv
        self.seqs_csv = seqs_csv
        self.sep = kwargs.get('sep', ';')

        # self.seq_json = seq_json
        # self.poc_csv = poc_csv

        self.seq_id_col = seq_id_col # 'Gene'
        self.mol_id_col = mol_id_col
        self.mol_col = mol_col # 'SMILES'
        self.label_col = label_col # 'Responsive'
        self.weight_col = weight_col
        # if weight_col is None:
        #     self._weight_col = []
        # else:
        #     self._weight_col = [weight_col]

        self.class_alpha = class_alpha
        
        self.oversampling_function = oversampling_function

        self.IncludeHs = kwargs.get('IncludeHs', False)
        self.self_loops = kwargs.get('self_loops', False)

        # Additional information:
        self.auxiliary_label_cols = kwargs.get('auxiliary_label_cols', [])
        self.auxiliary_weight_cols = kwargs.get('auxiliary_weight_cols', [])
        self.mol_global_cols = kwargs.get('mol_global_cols', [])
        self.seq_global_cols = kwargs.get('seq_global_cols', [])

        if len(self.mol_global_cols) > 0:
            print(' -----'*8, '\nWARNING: ast.literal_eval is used in _read_graph!\n', '----- '*8)        

        self.atom_features = atom_features
        self.bond_features = bond_features

        self.line_graph = line_graph
        self.line_graph_max_size = line_graph_max_size # 10 * mid_padding_n_node

        legacy_version = kwargs.get('legacy_version', False)
        if legacy_version:
            raise NotImplementedError('Use older git chekpoint for legacy version...')

        self.element_preprocess = element_preprocess

        self.data = self.read()


    def _read_graph(self, x):
        """
        Convert one row of the molecules table into a jraph graph.

        Parameters:
        -----------
        x : pandas.Series
            row of the molecules table; must contain ``self.mol_col`` (smiles)
            and every column listed in ``self.mol_global_cols``.

        Returns:
        --------
        tuple or float
            ``(G, )`` with the graph returned by ``smiles_to_jraph``, or
            ``float('nan')`` when the conversion raises NoBondsError or
            AssertionError (such rows are dropped by the caller).

        Raises:
        -------
        NotImplementedError
            if ``self.line_graph`` is set and the molecule was converted.
        """
        # Graph-level (global) features: every value is passed with at least
        # one dimension. numpy.ndim is used instead of ``val.shape`` so that
        # plain python scalars are accepted as well as numpy values.
        if len(self.mol_global_cols) > 0:
            u = {}
            for col in self.mol_global_cols:
                val = x[col]
                u[col] = numpy.expand_dims(val, axis = 0) if numpy.ndim(val) == 0 else val
        else:
            u = None

        try:
            G = smiles_to_jraph(x[self.mol_col], u = u, validate = False, IncludeHs = self.IncludeHs,
                            atom_features = self.atom_features, bond_features = self.bond_features,
                            self_loops = self.self_loops)
        except NoBondsError:
            return float('nan')
        except AssertionError:
            print(x[self.mol_col])
            print('WARNING: having this except is short term solution. Discard this data in data preprocessing!!\n--------')
            return float('nan')

        if self.line_graph:
            # Line graph construction is disabled until its logic is redesigned.
            raise NotImplementedError('Logic for line graph will change entirely.')
        return (G, )

    @staticmethod
    def _cast(x):
        if isinstance(x, str):
            _x = x.strip().lower()
            if _x[0] == '[' and _x[-1] == ']': # expecting list..
                _x_content = _x[1:-1]
                if len(_x_content) > 0:
                    if ', ' in _x_content:
                        return numpy.array([float(ele) for ele in _x_content.split(', ')])
                    elif ',' in _x_content:
                        return numpy.array([float(ele) for ele in _x_content.split(',')])
                else:
                    return numpy.array([])
                # return numpy.array(ast.literal_eval(_x))
            else:
                return x
        elif numpy.isnan(x).any():
            return x
        elif isinstance(x, int):
            return numpy.array([x], dtype = numpy.int32)
        elif isinstance(x, float):
            return numpy.array([x], dtype = numpy.float32)
        else:
            return x

    @staticmethod
    def _get_nan_default(series):
        for x in series.items():
            idx, val = x
            if not numpy.isnan(val).any():
                return val
        raise ValueError('All values in column {} are NaN.'.format(series.name))

    @staticmethod
    def _get_nan_mask(x, default_val):
        if numpy.isnan(x).any():
            return {'value' : default_val, 'mask' : False}
        else:
            return {'value' : x, 'mask' : True}

    def _create_seq_dict(self, x):
        """
        Build the ``Sequence`` record for one row of the data table.

        The record holds the stringified sequence ID under ``'_seq_id'`` and
        the row value of every column listed in ``self.seq_global_cols``.
        """
        record = Sequence()
        record['_seq_id'] = str(x[self.seq_id_col])
        for global_col in self.seq_global_cols:
            record[global_col] = x[global_col]
        return record

    def _create_label_dict(self, x):
        """
        Build the ``Label`` record for one row of the data table.

        Keys:
        - ``'_main_label'`` : value of ``self.label_col``
        - ``'_main_sample_weight'`` : value of ``self.weight_col`` (only when set)
        - ``<col>`` and ``<col>_mask`` for every auxiliary label column
        - ``<col>`` for every auxiliary weight column (its mask is not stored)

        Auxiliary columns are expected to hold ``{'value', 'mask'}`` dicts.
        """
        labels = Label()
        labels['_main_label'] = x[self.label_col]
        if self.weight_col is not None:
            labels['_main_sample_weight'] = x[self.weight_col]

        for aux_col in self.auxiliary_label_cols:
            entry = x[aux_col]
            labels[aux_col] = entry['value']
            labels[aux_col + '_mask'] = entry['mask']

        for weight_name in self.auxiliary_weight_cols:
            labels[weight_name] = x[weight_name]['value']

        return labels

    def read(self):
        """
        Load and assemble the dataset table.

        Steps:
        1. load the data table (``self.data_csv`` may be a path or a DataFrame),
        2. optionally join per-sequence columns from ``self.seqs_csv``,
        3. load molecules from ``self.mols_csv``, convert them to graphs and
           drop molecules for which the conversion failed,
        4. inner-join molecules onto the data table,
        5. turn auxiliary label / weight columns into ``{'value', 'mask'}`` dicts,
        6. build the ``'_seq'`` and ``'_label'`` records.

        Returns:
        --------
        pandas.DataFrame
            with columns ``[mol_id_col, seq_id_col, '_seq', '_graphs', '_label']``.

        Raises:
        -------
        NotImplementedError
            if ``self.oversampling_function`` is set.
        """
        if isinstance(self.data_csv, pandas.DataFrame):
            df = self.data_csv
        else:
            df = pandas.read_csv(self.data_csv, sep = self.sep)

        df_mol_ids = pandas.Index(df[self.mol_id_col].unique())

        # NOTE: DataFrame.applymap is deprecated in recent pandas in favour of
        # DataFrame.map; kept here for compatibility with older pandas.
        if self.seqs_csv is not None:
            seqs = pandas.read_csv(self.seqs_csv, sep = self.sep, index_col = self.seq_id_col)
            seqs = seqs.applymap(self._cast)
            df = df.join(seqs, on = self.seq_id_col, how = 'inner', rsuffix = '__seq')

        # mols_csv can not be None:
        mols = pandas.read_csv(self.mols_csv, sep = self.sep, index_col = self.mol_id_col)
        # Remove exact duplicate records only (same molecule ID AND same values).
        # DataFrame.drop_duplicates ignores the index, so with the molecule ID
        # as index it would drop distinct molecule IDs that share identical
        # column values (e.g. the same smiles), and the inner join below would
        # then silently discard all data rows of those molecules.
        mols = mols[~mols.reset_index().duplicated().to_numpy()]
        mols = mols.loc[mols.index.intersection(df_mol_ids)] # delete mols not in df

        _mol_other_cols = mols.columns.difference([self.mol_col])
        mols[_mol_other_cols] = mols[_mol_other_cols].applymap(self._cast)
        mols['_graphs'] = mols.apply(self._read_graph, axis = 1) # self.mol_global_cols used here...
        mols = mols.dropna(subset = ['_graphs']) # failed conversions are NaN

        df = df.join(mols, on = self.mol_id_col, how = 'inner', rsuffix = '__mol')

        # Auxiliary labels first, then auxiliary weights: cast the cells and
        # replace every value by a {'value', 'mask'} dict, where missing
        # entries get the first non-NaN value of the column as placeholder.
        for aux_cols in (self.auxiliary_label_cols, self.auxiliary_weight_cols):
            if len(aux_cols) > 0:
                df[aux_cols] = df[aux_cols].applymap(self._cast)
                for col in aux_cols:
                    default_val = self._get_nan_default(df[col])
                    df[col] = df[col].apply(
                        lambda val, default_val = default_val: self._get_nan_mask(val, default_val = default_val))

        df['_seq'] = df.apply(self._create_seq_dict, axis = 1)
        df['_label'] = df.apply(self._create_label_dict, axis = 1)
        df = df[[self.mol_id_col, self.seq_id_col, '_seq', '_graphs', '_label']]

        if self.oversampling_function is not None:
            raise NotImplementedError('oversampling is not implemented...')

        return df

    def infer_class_dist(self):
        labels = self.data.apply(lambda x: x['_label']['_main_label'], axis = 1)
        return labels.groupby(labels).count().to_dict()

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        # TODO: previously jax DeviceArray and this raised assertionError 
        # in pandas. See if this behaviour changes in new padnas
        index = numpy.asarray(index)
        sample = self.data.iloc[index]
        
        mol = sample['_graphs']
        seq = sample['_seq']
        label = sample['_label']

        if mol is None or not mol == mol:
            print(sample)
            raise ValueError('Molecule is None or NaN for seq_id: {}, label: {}, mol: {}'.format(seq, label, mol))

        if self.element_preprocess is not None:
            seq, mol, label = self.element_preprocess(seq, mol, label)
        
        return seq, mol, label