import os
import functools
import re
import jax
from jax import numpy as jnp
import numpy
# import numpy
import pandas
import json
import tables

from ProtLig_GPCRclassA.mol2graph.read import read_fasta

from ProtLig_GPCRclassA.utils import smiles_to_jraph_and_serialize, serialize_ESM2_hidden_states
from ProtLig_GPCRclassA.base_loader import BaseDataset, BaseDataLoader


class PrecomputeESM2Dataset(BaseDataset):
    """Dataset yielding ``(id, spaced_sequence)`` pairs from a pandas DataFrame.

    For every requested row the sequence characters are joined with single
    spaces and the letters U, Z, O and B are replaced by X.

    Args:
        data: DataFrame containing at least ``seq_col`` and ``id_col``; rows are
            addressed positionally through ``.iloc``.
        seq_col: name of the column holding the sequence string.
        id_col: name of the column holding the record identifier.
        orient: unused; kept only for interface compatibility.
    """
    def __init__(self, data, seq_col, id_col,
                 orient='columns'):
        self.data = data
        self.id_col = id_col
        self.seq_col = seq_col

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Select the row(s) once and read both columns from the selection.
        record = self.data.iloc[numpy.asarray(index)]
        spaced = ' '.join(list(record[self.seq_col]))
        spaced = re.sub(r"[UZOB]", "X", spaced)
        return record[self.id_col], spaced


def collate_fn_seq_with_id(batch, tokenizer, n_partitions, max_length, add_position_ids = False):
    """Tokenize a batch of ``(id, sequence)`` pairs and optionally split it into partitions.

    Args:
        batch: iterable of ``(id, sequence)`` pairs.
        tokenizer: callable invoked as ``tokenizer(seqs, return_tensors='pt',
            padding='max_length', max_length=max_length, truncation=True)``;
            its result is converted to a ``dict`` whose values are sliced
            along their first axis when partitioning.
        n_partitions: if > 0, split each tokenizer output into ``n_partitions``
            consecutive chunks of ``len(batch) // n_partitions`` rows each.
        max_length: padding/truncation length passed to the tokenizer.
        add_position_ids: if True, the tokenizer output must already contain
            ``'position_ids'``; generating them here is not supported.

    Returns:
        ``(ids, seqs)``: ``ids`` is the tuple of all ids in the batch; ``seqs`` is
        the tokenizer output dict when ``n_partitions <= 0``, otherwise a list of
        ``n_partitions`` dicts with the same keys.

    Raises:
        NotImplementedError: if ``add_position_ids`` is True and the tokenizer
            output has no ``'position_ids'`` entry.
    """
    ids, sequences = zip(*batch)  # transpose [(id, seq), ...] -> (ids, seqs)

    seqs = dict(tokenizer(sequences, return_tensors='pt', padding='max_length',
                          max_length=max_length, truncation=True))
    if add_position_ids and 'position_ids' not in seqs:
        raise NotImplementedError(
            "add_position_ids=True is not supported: the tokenizer output has no 'position_ids' entry.")

    if n_partitions <= 0:
        return ids, seqs

    # NOTE: when len(ids) is not a multiple of n_partitions, the trailing
    # len(ids) % n_partitions rows are not part of any partition, while
    # ``ids`` still lists every id in the batch.
    size = len(ids) // n_partitions
    partitions = [{key: value[i * size:(i + 1) * size] for key, value in seqs.items()}
                  for i in range(n_partitions)]
    return ids, partitions


class PrecomputeESM2Loader(BaseDataLoader):
    """Data loader that tokenizes batches of ``(id, sequence)`` items.

    Batches are collated by ``collate_fn_seq_with_id`` with ``tokenizer``,
    ``n_partitions``, ``max_length`` and ``add_position_ids`` bound through
    ``functools.partial``. ``batch_size``, ``shuffle``, ``rng`` and ``drop_last``
    are forwarded unchanged to ``BaseDataLoader``.

    Raises:
        ValueError: if ``n_partitions > 0`` and ``batch_size`` is not a multiple
            of ``n_partitions``.
    """
    def __init__(self, dataset, tokenizer,
                    batch_size=1,
                    n_partitions = 0,
                    shuffle=False,
                    rng=None,
                    drop_last=False,
                    max_length=512,
                    add_position_ids = False):

        self.n_partitions = n_partitions
        # Explicit check instead of ``assert``, which is stripped under ``python -O``.
        if n_partitions > 0 and batch_size % n_partitions != 0:
            raise ValueError('batch_size ({}) must be divisible by n_partitions ({}).'.format(
                batch_size, n_partitions))

        collate_fn = functools.partial(collate_fn_seq_with_id,
                                       tokenizer = tokenizer,
                                       n_partitions = n_partitions,
                                       max_length = max_length,
                                       add_position_ids = add_position_ids)

        # Zero-argument super(): ``super(self.__class__, self)`` recurses forever
        # as soon as this class is subclassed.
        super().__init__(dataset,
                         batch_size = batch_size,
                         shuffle = shuffle,
                         rng = rng,
                         drop_last = drop_last,
                         collate_fn = collate_fn,
                         )



class PrecomputeESM2:
    """Compute ESM2 hidden states for a sequence file and store them in an HDF5 table.

    ``precompute_and_save`` does the following:

    1. Reads ``data_file`` (``.fasta``/``.fa`` or ``;``-separated ``.csv``).
    2. Keeps the ``id_col``/``seq_col`` columns and drops rows whose id is
       duplicated.
    3. Runs ``model`` batch-wise.
    4. Appends one row per record to ``/amino/table`` in
       ``<save_dir>/<save_folder_name>/<dbname>``. Each row holds the id,
       ``hidden_states[j][-1, :, :]`` of the serialized model hidden states
       (cast to float32) and the attention mask (cast to int32).
    5. Writes the hyperparameters to ``hparams.json`` in the same folder.
    """
    def __init__(self, data_file, save_dir, save_folder_name = None, mode = 'a', id_col = 'UniProt ID', seq_col = 'seq', dbname = '', batch_size = 8,
                    model = None, tokenizer = None, max_length = 512,
                    hidden_states_shape = (31, 512, 1024)):
        """
        Args:
            data_file: path to the input ``.fasta``/``.fa``/``.csv`` file.
            save_dir: parent directory; output goes to ``save_dir/save_folder_name``,
                which is created if missing.
            save_folder_name: sub-folder name; defaults to ``'PrecomputeESM2'``.
            mode: PyTables file mode. ``'w'`` creates a new file, group and table;
                any other mode opens the existing ``/amino/table``.
            id_col: column name holding record ids.
            seq_col: column name holding sequences.
            dbname: HDF5 file name inside the output folder.
            batch_size: number of sequences per model call.
            model: called with the tokenizer output as keyword arguments plus
                ``output_attentions=False``, ``output_hidden_states=True`` and
                ``return_dict=True``.
            tokenizer: tokenizer handed to ``collate_fn_seq_with_id``.
            max_length: padding/truncation length; must be an element of
                ``hidden_states_shape``.
            hidden_states_shape: per-row shape of the ``hidden_states`` column.

        Raises:
            ValueError: if ``max_length`` is not an element of ``hidden_states_shape``.
        """
        # Validate before touching the filesystem so an invalid configuration
        # does not leave an empty output folder behind.
        if max_length not in hidden_states_shape:
            raise ValueError('max_length is not a part of hidden_states_shape.')

        if save_folder_name is None:
            save_folder_name = __class__.__name__

        self.data_file = data_file
        self.save_dir = os.path.join(save_dir, save_folder_name)
        os.makedirs(self.save_dir, exist_ok=True)

        self.id_col = id_col
        self.seq_col = seq_col
        self.batch_size = batch_size
        self.max_length = max_length
        # NOTE(review): ``_precompute_and_save`` stores ``hidden_states[j][-1, :, :]``
        # (one axis fewer than ``hidden_states[j]``), while this default is 3-D;
        # ``precompute`` passes a 2-D shape. Confirm the default is intended.
        self.hidden_states_shape = hidden_states_shape

        self.model = model
        self.tokenizer = tokenizer
        self.apply_model = self._make_apply_model(model)

        # Item size of the ``id`` StringCol; longer ids are rejected when writing.
        self.db_id_len = 64
        self.dbname = dbname
        self.mode = mode

    def serialize_hparams(self):
        """
        returns dictionary with all hyperparameters that will be saved. self.save_dir will be added
        to the dict in self.save_hparams.
        """
        return {'batch_size' : str(self.batch_size),
                'model' : self.model.__class__.__name__,
                'tokenizer' : self.tokenizer.__class__.__name__}

    def save_hparams(self):
        """Write ``serialize_hparams()`` plus ``data_file`` and ``save_dir`` to ``hparams.json``."""
        hparams = self.serialize_hparams()
        hparams.update({'data_file' : self.data_file,
                        'save_dir' : self.save_dir})
        with open(os.path.join(self.save_dir, 'hparams.json'), 'w') as outfile:
            json.dump(hparams, outfile)

    def create_h5file(self, expectedrows):
        """Create the HDF5 file with group ``/amino`` and an empty ``table`` in it.

        Sets ``self.h5file``, ``self.table`` and ``self.filters`` (blosc, level 1).
        """
        class PrecomputeESM2table(tables.IsDescription):
            id    = tables.StringCol(self.db_id_len)
            hidden_states = tables.Float32Col(shape = self.hidden_states_shape)
            attention_mask = tables.Int32Col(shape = (self.max_length,))

        h5file = tables.open_file(os.path.join(self.save_dir, self.dbname), mode = self.mode, title="ESM2")
        group = h5file.create_group("/", name = 'amino', title = 'AMINOgroup')
        self.filters = tables.Filters(complevel = 1, complib = 'blosc')
        self.table = h5file.create_table(group, name = 'table', description = PrecomputeESM2table, title = "table",
                                        filters = self.filters, expectedrows = expectedrows)
        self.h5file = h5file
        print(h5file)
        return None

    def load_h5file(self):
        """Open the existing HDF5 file and bind ``self.table`` to ``/amino/table``."""
        h5file = tables.open_file(os.path.join(self.save_dir, self.dbname), mode = self.mode, title="ESM2")
        try:
            self.table = h5file.root.amino.table
        except Exception:
            # Do not leak the open file handle when the expected node is missing.
            h5file.close()
            raise
        self.h5file = h5file
        print(h5file)
        return None

    def _make_apply_model(self, model):
        """Return a closure that runs ``model`` on a tokenized batch with hidden states enabled."""
        def apply_model(seq):
            output = model(**seq,
                                output_attentions = False,
                                output_hidden_states = True,
                                return_dict = True)
            return output
        return apply_model

    def _precompute_and_save(self, data):
        """Run the model over ``data`` and append one table row per record.

        Raises:
            ValueError: if an id is longer than ``self.db_id_len``.
        """
        dataset = PrecomputeESM2Dataset(data, seq_col = self.seq_col, id_col = self.id_col)
        loader = PrecomputeESM2Loader(dataset, tokenizer = self.tokenizer, batch_size = self.batch_size,
                    n_partitions = 0, shuffle=False, rng=None, drop_last=False, max_length = self.max_length, add_position_ids = False)

        row = self.table.row
        last_batch = len(loader) - 1  # loop-invariant, computed once

        for i, (ids, batch) in enumerate(loader):
            attn_mask = batch['attention_mask']

            output = self.apply_model(batch)
            hidden_states = serialize_ESM2_hidden_states(output.hidden_states)

            for j, record_id in enumerate(ids):
                if len(record_id) > self.db_id_len:
                    raise ValueError('ID "{}" is too long for db_id_len: {}'.format(record_id, self.db_id_len))
                row['id'] = record_id
                row['hidden_states'] = hidden_states[j][-1,:,:].astype(jnp.float32) # last layer
                row['attention_mask'] = attn_mask[j].detach().numpy().astype(jnp.int32)
                row.append()

            # Flushes after every batch from index 10 onwards and after the final batch.
            # NOTE(review): possibly meant as a periodic flush (e.g. ``i % 10 == 0``); confirm.
            if i >= 10 or i >= last_batch:
                self.table.flush()

        print('creating index...')
        if self.mode == 'w':
            # Create the index on the finished table to speed up id lookups.
            self.table.cols.id.create_index(optlevel=9, kind='full', filters = self.filters)
        else:
            self.table.cols.id.reindex()
        self.table.flush()
        return None

    def load_data(self):
        """Load ``self.data_file`` into a DataFrame with ``id_col`` and ``seq_col`` columns.

        Raises:
            ValueError: if the file extension is not ``.fasta``, ``.fa`` or ``.csv``
                (case-insensitive).
        """
        _, ext = os.path.splitext(self.data_file)
        ext = ext.lower()
        if ext in ('.fasta', '.fa'):
            df = read_fasta(self.data_file)
            df.name = self.seq_col
            df = df.to_frame()
            df.index.name = self.id_col
            df.reset_index(inplace = True)
        elif ext == '.csv':
            df = pandas.read_csv(self.data_file, sep = ';', index_col = None, header = 0, usecols = [self.id_col, self.seq_col])
        else:
            raise ValueError('Unsupported data file extension "{}": expected .fasta, .fa or .csv.'.format(ext))
        return df

    def precompute_and_save(self):
        """Load the data, write all records to the HDF5 table and save ``hparams.json``.

        The HDF5 file is closed even if processing fails.
        """
        data = self.load_data()

        data = data[[self.id_col, self.seq_col]]
        data = data[~data[self.id_col].duplicated()]

        print('Number of records to process:  {}'.format(len(data)))

        if self.mode == 'w':
            self.create_h5file(expectedrows = len(data))
        else:
            self.load_h5file()
        try:
            self._precompute_and_save(data)
        finally:
            self.h5file.close()

        self.save_hparams()
        return None


def precompute(hparams = None):
    """Precompute ESM2 hidden states for a sequence file and store them in an HDF5 database.

    Loads the tokenizer, config and model named by ``MODEL_NAME`` through
    ``transformers``. Then runs ``PrecomputeESM2.precompute_and_save`` with
    autograd disabled.

    Args:
        hparams: optional mapping whose entries override the defaults below.
            Recognised keys are ``DATA_FILE``, ``SAVE_DIR``, ``MODE``, ``DBNAME``,
            ``ID_COL``, ``SEQ_COL``, ``BATCH_SIZE``, ``HIDDEN_STATES_SHAPE``,
            ``MAX_LENGTH`` and ``MODEL_NAME``. ``HIDDEN_STATES_SHAPE`` must
            contain ``MAX_LENGTH`` and match what the chosen model produces.
    """
    import torch
    from transformers import EsmTokenizer, EsmConfig, EsmModel

    defaults = {'DATA_FILE' : os.path.join('amino_GNN','Data','chemosimdb','mixDiscard_20220608-144844','seqs.csv'),
                'SAVE_DIR' : os.path.join('amino_GNN','Data','chemosimdb','mixDiscard_20220608-144844'),
                'MODE' : 'w',
                'DBNAME' : 'ESM2.h5',
                'ID_COL' : 'seq_id',
                'SEQ_COL' : 'mutated_Sequence',
                'BATCH_SIZE' : 8,
                'HIDDEN_STATES_SHAPE' : (1024, 1280),
                'MAX_LENGTH' : 1024,
                'MODEL_NAME' : 'facebook/esm2_t33_650M_UR50D'}
    # Caller-supplied values take precedence over the defaults.
    settings = {**defaults, **(hparams or {})}

    model_name = settings['MODEL_NAME']
    tokenizer = EsmTokenizer.from_pretrained(model_name)
    config = EsmConfig.from_pretrained(model_name)
    model = EsmModel.from_pretrained(model_name, config = config, add_pooling_layer = False)

    precompute_esm2 = PrecomputeESM2(data_file = settings['DATA_FILE'],
                                     save_dir = settings['SAVE_DIR'],
                                     mode = settings['MODE'],
                                     dbname = settings['DBNAME'],
                                     id_col = settings['ID_COL'],
                                     seq_col = settings['SEQ_COL'],
                                     batch_size = settings['BATCH_SIZE'],
                                     model = model,
                                     tokenizer = tokenizer,
                                     hidden_states_shape = settings['HIDDEN_STATES_SHAPE'],
                                     max_length = settings['MAX_LENGTH'],
                                     )
    # Inference only: skipping autograd bookkeeping lowers memory use.
    with torch.no_grad():
        precompute_esm2.precompute_and_save()