import json
import os
from shutil import copy

import pandas
from ProtLig_GPCRclassA.base_cross_validation import BaseCVPostProcess

class CVPP_LengthDiscard(BaseCVPostProcess):
    """
    Discard sequences that are too short/long.

    Sequences are read from a CSV file, split by length into a kept set and an
    excluded set, and each set is written to its own ';'-separated CSV file in
    ``self.working_dir``.

    NOTE(review): ``self.data_dir``, ``self.working_dir``, ``self.sep`` and
    ``self.save_hparams`` are not defined in this class; they are presumably
    provided by ``BaseCVPostProcess`` - confirm against the base class.
    """
    def __init__(self, data_dir, lower_bound = 296, upper_bound = 'Inf', seqs_csv = None, seq_id_col = 'seq_id', seq_col = 'mutated_sequence', auxiliary_data_path = None):
        """
        Both bounds are inclusive: sequences that have length equal to
        lower_bound or to upper_bound are kept.

        Parameters:
        -----------
        data_dir :
            forwarded to the base-class constructor.

        lower_bound : int or 'Inf'
            minimal length of a kept sequence. 'Inf' disables the lower bound.

        upper_bound : int or 'Inf'
            maximal length of a kept sequence. 'Inf' disables the upper bound.

        seqs_csv : str or None
            path to the CSV file with sequences. If None, 'seqs.csv' in
            self.data_dir is used.

        seq_id_col : str
            column used as the index of the sequences table. The same column
            name is looked up in the auxiliary stats data.

        seq_col : str
            column holding the sequences whose length is measured.

        auxiliary_data_path : dict or None
            if given, it is indexed with the key 'stats_data' to obtain the
            path of a ';'-separated CSV file. Used only to print how many
            pairs are in each group.
        """
        name = 'discard_by_length'
        super(CVPP_LengthDiscard, self).__init__(name, data_dir)

        # The data name is built from the raw arguments, so an unbounded side
        # shows up as 'Inf' in output file names instead of its numeric stand-in.
        self._data_name = 'lower{}_upper{}'.format(lower_bound, upper_bound)

        # 'Inf' is a sentinel for "no bound" and is replaced by a number so the
        # comparisons in _postprocess work on plain numbers.
        if lower_bound == 'Inf':
            lower_bound = -1    # no length is below -1
        if upper_bound == 'Inf':
            upper_bound = 10e6  # i.e. 1e7; stands in for "unbounded"

        self.lower_bound = lower_bound
        self.upper_bound = upper_bound

        if seqs_csv is None:
            seqs_csv = os.path.join(self.data_dir, 'seqs.csv')

        self.seqs_csv = seqs_csv
        self.seqs_csv_name = os.path.basename(seqs_csv)
        self.seq_id_col = seq_id_col
        self.seq_col = seq_col

        self.auxiliary_data_path = auxiliary_data_path # To get how many pairs are in each group.

    def serialize_hparams(self):
        """
        returns dictionary with all hyperparameters that will be saved. self.working_dir will be added
        to the dict in self.save_hparams.

        The bounds are reported after the 'Inf' replacement done in __init__,
        i.e. as numbers. auxiliary_data_path is not part of the dictionary.
        """
        return {'lower_bound' : self.lower_bound,
                'upper_bound' : self.upper_bound,
                'seqs_csv' : self.seqs_csv,
                'seq_id_col' : self.seq_id_col,
                'seq_col' : self.seq_col}

    def load_seqs(self):
        """
        Read self.seqs_csv into a DataFrame indexed by self.seq_id_col.

        The first row of the file is used as the header and self.sep as the
        separator.
        """
        return pandas.read_csv(self.seqs_csv, sep = self.sep, index_col = self.seq_id_col, header = 0)

    def load_auxiliary(self):
        """
        Read the auxiliary data.

        Returns:
        --------
        dict with a single key 'stats_data' holding the DataFrame read from
        the ';'-separated file at self.auxiliary_data_path['stats_data'].
        """
        stats_data = pandas.read_csv(self.auxiliary_data_path['stats_data'], sep = ';')
        return {'stats_data' : stats_data}

    def _postprocess(self, seqs):
        """
        Split seqs into excluded and kept sequences according to their length.

        A sequence is excluded if its length is lower than self.lower_bound or
        greater than self.upper_bound. The input DataFrame is not modified.

        Parameters:
        -----------
        seqs : pandas.DataFrame
            table with sequences in column self.seq_col.

        Returns:
        --------
        dict with two DataFrames (copies): 'exclude_' + self._data_name with
        the excluded rows and self._data_name with the remaining rows.
        """
        # Lengths live in a standalone Series instead of a temporary column, so
        # the caller's DataFrame is left untouched.
        lengths = seqs[self.seq_col].apply(len)
        out_of_range = (lengths < self.lower_bound) | (lengths > self.upper_bound)
        exclude_idx = seqs.index[out_of_range.values]

        # Kept rows are selected through the set difference of index labels.
        # NOTE: pandas Index.difference returns unique labels, sorted when possible.
        keep_idx = seqs.index.difference(exclude_idx)

        seqs_datas = {}
        seqs_datas['exclude_' + self._data_name] = seqs.loc[exclude_idx].copy()
        seqs_datas[self._data_name] = seqs.loc[keep_idx].copy()

        for _name, _data in seqs_datas.items():
            print('Num of unique sequences in {}: {}'.format(_name, len(_data.index)))
        return seqs_datas

    def _print_num_of_affected_pairs(self, seqs_datas):
        """
        Print, for each group in seqs_datas, the number of rows of the
        auxiliary stats data whose self.seq_id_col value is in the index of
        that group.
        """
        stats_data = self.load_auxiliary()['stats_data']
        pair_seq_ids = stats_data[self.seq_id_col]

        for name, group in seqs_datas.items():
            _data = stats_data[pair_seq_ids.isin(group.index)]
            print('Num of pairs in {}: {}'.format(name, len(_data)))

        return None

    def postprocess(self):
        """
        Run the whole post-processing step.

        Loads the sequences, splits them by length, optionally prints the
        number of affected pairs (only if auxiliary_data_path was given),
        writes each group to 'seqs_<group name>.csv' in self.working_dir and
        finally calls self.save_hparams().
        """
        seqs_datas = self._postprocess(self.load_seqs())
        if self.auxiliary_data_path is not None:
            self._print_num_of_affected_pairs(seqs_datas)

        for name, data in seqs_datas.items():
            out_path = os.path.join(self.working_dir, 'seqs_' + name + '.csv')
            data.to_csv(out_path, sep=';', index = True, header = True)

        self.save_hparams()
        return None