import os
import re
import argparse
import logging
import itertools
import yaml

from multiprocessing import Pool
import multiprocessing as mp

import numpy as np
import pandas as pd
from scipy import sparse as sp
from tqdm import tqdm

from transformers import AutoTokenizer, AutoModel
from transformers import T5Tokenizer, T5EncoderModel

from Bio import SeqIO

import sys
sys.path.insert(0,os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from src.attention import AttentionCalculator
from bio_tda.bio_tda import get_features_from_attention_matrix_v1, get_features_from_attention_matrix_v21, get_features_from_attention_matrix_v31, get_features_from_attention_matrix_v4


class Config:
    """Thin read-only wrapper around a YAML configuration file.

    The parsed document is kept in ``self.config`` and values are looked up
    with :meth:`get`, which mirrors ``dict.get``.
    """

    def __init__(self, config_file):
        """Parse ``config_file`` as YAML.

        Args:
            config_file: Path to the YAML file to read.
        """
        with open(config_file, 'r', encoding='utf-8') as file:
            # ``yaml.safe_load`` returns None for an empty document; fall back
            # to an empty mapping so that ``get`` keeps working.
            self.config = yaml.safe_load(file) or {}

    def get(self, key, default=None):
        """Return the value stored under ``key``, or ``default`` if absent."""
        return self.config.get(key, default)

class DataPreprocessor:
    """Load the binding train/test files into a single labelled DataFrame."""

    def __init__(self, config):
        self.config = config

    @staticmethod
    def _read_records(path, extra=()):
        """Parse one FASTA-formatted file into ``[id, sequence, label, *extra]`` rows.

        Each record's payload is split at its midpoint: the first half is taken
        as the sequence and the second half as the label string. For an
        odd-length payload the label half is one character longer.
        """
        rows = []
        for record in SeqIO.parse(path, 'fasta'):
            data = record.seq
            half = len(data) // 2
            rows.append(
                [str(record.description), str(data[:half]), str(data[half:]), *extra]
            )
        return rows

    def load_data(self, testset_name):
        """Read ``<testset_name>_train.txt`` and ``<testset_name>_test.txt``.

        Args:
            testset_name: Dataset prefix; files are read from ``datasets/binding/``.

        Returns:
            A DataFrame with columns ``id``, ``sequence``, ``label`` and
            ``split``. The first 90% of the training records (in file order)
            are tagged ``'train'``, the remainder ``'val'``, and every test
            record ``'test'``. The train and test frames are concatenated
            without resetting the index.
        """
        # Build each frame once from the collected rows; passing the column
        # names explicitly also keeps an empty input file from failing.
        train = pd.DataFrame(
            self._read_records(f'datasets/binding/{testset_name}_train.txt'),
            columns=['id', 'sequence', 'label'],
        )
        n_rows = train.shape[0]
        n_train = int(n_rows * 0.9)
        train['split'] = ['train'] * n_train + ['val'] * (n_rows - n_train)

        test = pd.DataFrame(
            self._read_records(
                f'datasets/binding/{testset_name}_test.txt', extra=('test',)
            ),
            columns=['id', 'sequence', 'label', 'split'],
        )

        return pd.concat([train, test])



class ModelLoader:
    """Load the pretrained model/tokenizer named by ``config['model_name']``."""

    def __init__(self, config):
        self.config = config

    def _freeze_and_place(self, model):
        """Disable gradients, switch to eval mode and move to the configured device."""
        for _, param in model.named_parameters():
            param.requires_grad = False
        model.eval()
        return model.to(self.config.get('device'))

    def load_model_tokenizer(self):
        """Return ``(model, tokenizer)`` for the configured model.

        The class pair is chosen from the organisation prefix of the model
        name: ``facebook/...`` uses ``AutoTokenizer``/``AutoModel`` and
        ``Rostlab/...`` uses ``T5Tokenizer``/``T5EncoderModel``. The model is
        loaded with ``output_attentions=True``.

        Raises:
            ValueError: If the prefix is neither ``facebook`` nor ``Rostlab``.
        """
        model_name = self.config.get('model_name')
        family = model_name.split('/')[0]
        if family == 'facebook':
            tokenizer_cls, model_cls = AutoTokenizer, AutoModel
        elif family == 'Rostlab':
            tokenizer_cls, model_cls = T5Tokenizer, T5EncoderModel
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on ``model``.
            raise ValueError(
                f"Unsupported model family {family!r} in model_name {model_name!r}; "
                "expected a 'facebook/...' or 'Rostlab/...' model."
            )

        tokenizer = tokenizer_cls.from_pretrained(model_name)
        model = model_cls.from_pretrained(model_name, output_attentions=True)
        return self._freeze_and_place(model), tokenizer

    def load_model(self):
        """Return the configured model loaded through ``AutoModel``, frozen and on device."""
        model = AutoModel.from_pretrained(
            self.config.get('model_name'), output_attentions=True
        )
        return self._freeze_and_place(model)

    def load_tokenizer(self):
        """Return the configured tokenizer loaded through ``AutoTokenizer``."""
        return AutoTokenizer.from_pretrained(self.config.get('model_name'))
    

def setup_logger(config):
    """Return this module's logger, writing to the console and optionally a file.

    Args:
        config: Object with a ``get(key)`` method. If ``get('log_file')`` is
            truthy, records are also written to that file.

    Returns:
        The ``logging.Logger`` named after this module, set to ``INFO``.

    The function is idempotent: handlers left over from an earlier call are
    removed and closed first, so calling it repeatedly does not duplicate
    every log line or leak open file handles.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    # Drop handlers installed by a previous call before adding fresh ones.
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()

    # Console handler is always present; the file handler is optional.
    handlers = [logging.StreamHandler()]
    log_file = config.get('log_file')
    if log_file:
        handlers.append(logging.FileHandler(log_file))

    for handler in handlers:
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger


# Number of sequences accumulated before an intermediate file is written.
_CHUNK_SIZE = 1000


def _is_chunk_end(count, total):
    """Return True if a checkpoint is written after ``count`` of ``total`` sequences."""
    return count % _CHUNK_SIZE == 0 or count == total


def _chunk_ends(total):
    """Return, in order, every 1-based count for which ``_is_chunk_end`` holds."""
    ends = list(range(_CHUNK_SIZE, total + 1, _CHUNK_SIZE))
    if total > 0 and (not ends or ends[-1] != total):
        ends.append(total)
    return ends


def _select_subset(df, subset):
    """Return the rows of ``df`` whose ``split`` equals ``subset``.

    Raises:
        ValueError: If no row matches. Without this check the run would fail
            much later with an opaque ``np.concatenate`` error.
    """
    selected = df[df['split'] == subset]
    if selected.empty:
        raise ValueError(f"No sequences found for subset {subset!r}.")
    return selected


def _calculate_sparse_attns(config, df, logger):
    """Run the ``calculate_sparse_attns`` task.

    For every sequence of the configured subset this stores the attention
    maps as one sparse ``.npz`` file, checkpoints the embeddings every
    ``_CHUNK_SIZE`` sequences, and finally writes the concatenated labels and
    the merged last-layer embeddings into the ``features`` folder.
    """
    subset = config.get('subset')
    testset_name = config.get('testset_name')
    model_name = config.get('model_name').split('/')
    model_family, model_short = model_name[0], model_name[1]

    df = _select_subset(df, subset)
    sequences = df['sequence'].values
    labels = df['label'].values
    total = len(sequences)

    model, tokenizer = ModelLoader(config).load_model_tokenizer()
    attention_calculator = AttentionCalculator(config, model, tokenizer, logger)
    print(model_name)

    folder = f"data/binding/{testset_name}/{model_short}"
    attn_folder = f"{folder}/attns_with_cls"
    embs_folder = f"{folder}/embs"
    feat_folder = f"{folder}/features"
    # makedirs creates the parent folders as well.
    for base in (attn_folder, embs_folder, feat_folder):
        os.makedirs(f"{base}/{subset}/", exist_ok=True)

    def embs_path(kind, count):
        return f"{embs_folder}/{subset}/{subset}_{kind}_embs_{testset_name}_{model_short}_{count}.npy"

    attn_stem = f"{subset}_{testset_name}_attns_{model_short}"

    logger.info(f"Calculating attention weights and saving for {total}..")
    last_embs = []
    all_embs = []
    np_labels = []
    for i, seq in tqdm(enumerate(sequences), desc="Processing batches"):
        max_protein_len = len(seq)
        batch = [seq]
        if model_family == 'Rostlab':
            # Rostlab input: rare residues mapped to X, residues separated by
            # spaces, and the "<AA2fold>" prefix prepended.
            batch = [" ".join(list(re.sub(r"[UZOB]", "X", s))) for s in batch]
            batch = ["<AA2fold>" + " " + s for s in batch]
        attn, np_embs, np_emb = attention_calculator.grab_attention_weights_and_emb_tokens_sparse(batch, max_protein_len)
        # Example shapes noted by the original author:
        # attn (6, 20, 73, 73), np_embs (7, 71, 320).

        # Stack every (layer, head) map vertically so that a single 2-D sparse
        # array holds all attention maps of this sequence.
        stacked = np.concatenate([
            attn[layer, head]
            for layer, head in itertools.product(range(attn.shape[0]), range(attn.shape[1]))
        ])
        attn_file = f"{attn_folder}/{subset}/{attn_stem}_{i+1}_out_of_{total}.npz"
        sp.save_npz(attn_file, sp.coo_array(stacked), compressed=True)

        last_embs.append(np_emb)
        all_embs.append(np_embs)
        # The label string carries one numeric character per position.
        np_labels.append(np.array([int(float(x)) for x in labels[i]]))

        if _is_chunk_end(i + 1, total):
            np.save(embs_path('last', i + 1), np.concatenate(last_embs, axis=0), allow_pickle=True)
            logger.info("Last layer embeddings saved.")

            np.save(embs_path('all', i + 1), np.concatenate(all_embs, axis=1), allow_pickle=True)
            logger.info("All embeddings saved.")

            # Start a new chunk so memory use stays bounded.
            last_embs = []
            all_embs = []
    logger.info("Attns saved.")

    labels_file = f"{feat_folder}/y_{subset}_{testset_name}.npy"
    np.save(labels_file, np.concatenate(np_labels), allow_pickle=True)
    logger.info("Labels saved.")

    # Merge the per-chunk last-layer embeddings into one feature file.
    last_emb_chunks = [
        np.load(embs_path('last', count), allow_pickle=True)
        for count in tqdm(_chunk_ends(total))
    ]
    last_emb_result = np.concatenate(last_emb_chunks, axis=0)
    last_emb_file = f"{feat_folder}/{subset}_{testset_name}_last_embs_{model_short}.npy"
    np.save(last_emb_file, last_emb_result)
    print(last_emb_result.shape)


def _calculate_tda_features(config, df, logger):
    """Run the ``calculate_tda_features_from_sparce_matrix`` task.

    Reads the sparse attention files written by the attention task, computes
    per-head (or per-layer, when ``sum`` is set) features with the configured
    ``method`` in a worker pool, checkpoints them every ``_CHUNK_SIZE``
    sequences and finally merges the checkpoints into one ``.npy`` file.

    Raises:
        ValueError: If ``method`` is not one of 1, 2, 3, 4, or the subset is empty.
    """
    subset = config.get('subset')
    testset_name = config.get('testset_name')
    model_short = config.get('model_name').split('/')[1]
    threshold = config.get('threshold')
    num_layers, num_heads = config.get('num_layers'), config.get('num_heads')
    all_heads = num_layers * num_heads
    method = config.get('method')
    attn_kind = config.get('attn')
    sum_heads = config.get('sum')
    sum_filename = '_sum' if sum_heads else ''
    with_vert = 'with_vert' if config.get('with_vert') else 'without_vert'

    if method not in (1, 2, 3, 4):
        # Previously an unknown method failed with an UnboundLocalError.
        raise ValueError(f"Unsupported method {method!r}; expected 1, 2, 3 or 4.")
    uses_threshold = method in (1, 2)
    feature_fn = {
        1: get_features_from_attention_matrix_v1,
        2: get_features_from_attention_matrix_v21,
        3: get_features_from_attention_matrix_v31,
        4: get_features_from_attention_matrix_v4,
    }[method]

    df = _select_subset(df, subset)
    sequences = df['sequence'].values
    total = len(sequences)

    # NOTE(review): neither the model nor the AttentionCalculator is used by
    # this task. They are still constructed to preserve the original side
    # effects; confirm they are unnecessary and drop them to save load time
    # and device memory.
    model, tokenizer = ModelLoader(config).load_model_tokenizer()
    attention_calculator = AttentionCalculator(config, model, tokenizer, logger)

    folder = f"data/binding/{testset_name}/{model_short}/features"
    attn_folder = f"data/binding/{testset_name}/{model_short}/attns_with_cls"
    if uses_threshold:
        feat_folder = f"{folder}/{subset}/feat_method_{method}_thr_{threshold}_{with_vert}_{attn_kind}"
        out_file = f"{subset}_{testset_name}{sum_filename}_method_{method}_{attn_kind}_thr_{threshold}_{with_vert}_{model_short}"
    else:
        feat_folder = f"{folder}/{subset}/feat_method_{method}_{attn_kind}"
        out_file = f"{subset}_{testset_name}{sum_filename}_method_{method}_{attn_kind}_{model_short}"
    os.makedirs(feat_folder, exist_ok=True)

    attn_stem = f"{subset}_{testset_name}_{attn_kind}_{model_short}"
    seq_feats_sp_batch_list = []

    logger.info(f"Calculating topological features for {total}..")
    # One pool for the whole run. The original created a new, never-closed
    # Pool for every sequence, which leaked worker processes and paid the
    # process start-up cost on each iteration.
    with Pool(config.get('num_workers')) as pool:
        for i, _ in tqdm(enumerate(sequences), desc="Processing batches"):
            attn_file = f"{attn_folder}/{subset}/{attn_stem}_{i+1}_out_of_{total}.npz"
            attn = sp.load_npz(attn_file)
            attn = attn.astype(np.float32).todense()

            # The file holds all head matrices stacked vertically; each block
            # is square, so the column count is the per-head row count.
            num_tokens = attn.shape[1]
            matrices = []
            layer_sum = None
            for head in range(all_heads):
                head_feats = attn[head * num_tokens:(head + 1) * num_tokens, :]
                if attn_kind == 'attns':
                    # Drop the first and last token rows/columns.
                    head_feats = head_feats[1:-1, 1:-1]
                if not sum_heads:
                    matrices.append(head_feats)
                    continue
                # Accumulate the heads of one layer and emit the sum once
                # the last head of that layer has been added.
                if head % num_heads == 0:
                    layer_sum = head_feats
                else:
                    layer_sum = layer_sum + head_feats
                if (head + 1) % num_heads == 0:
                    matrices.append(layer_sum)

            if not uses_threshold:
                args = [(m,) for m in matrices]
            elif with_vert == 'with_vert':
                args = [(m, threshold) for m in matrices]
            else:
                args = [(m, threshold, True) for m in matrices]
            stats_tuple_lists = pool.starmap(feature_fn, args)

            if uses_threshold:
                # Methods 1 and 2 return (features, diagrams); only the
                # features are kept, minus their first and last rows.
                feats = [np.array(f_all)[1:-1, :] for f_all, _dgm_all in stats_tuple_lists]
            else:
                feats = [np.array(f_all) for f_all in stats_tuple_lists]

            seq_feats = np.concatenate(feats, axis=1)
            seq_feats_sp_batch_list.append(sp.coo_array(seq_feats))

            if _is_chunk_end(i + 1, total):
                seq_feats_sp_batch = sp.vstack(seq_feats_sp_batch_list)
                logger.info("Saving batch...")
                stats_filename = f"{feat_folder}/{out_file}_{i+1}.npz"
                logger.info(f"Saving {stats_filename}: {seq_feats_sp_batch.shape}")
                sp.save_npz(stats_filename, seq_feats_sp_batch, compressed=True)
                logger.info("File saved!")
                seq_feats_sp_batch_list = []

    # Merge the per-chunk sparse files into one dense float16 array.
    chunks = []
    for count in tqdm(_chunk_ends(total)):
        filename = f"{feat_folder}/{out_file}_{count}.npz"
        # Use a context manager so the underlying file handle is closed.
        with np.load(filename, allow_pickle=True) as features:
            M, N = features['shape']
            data = features['data']
            row = features['row']
            col = features['col']
        dense = sp.coo_array((data.astype(np.float32), (row, col)), shape=(M, N)).toarray()
        chunks.append(np.array(dense, dtype=np.float16))

    result = np.concatenate(chunks, axis=0)
    stats_file = f"{folder}/{out_file}"
    np.save(stats_file, result)
    print(result.shape)


def main(config_file):
    """Load the dataset and run the task selected in the YAML configuration.

    Args:
        config_file: Path to the YAML configuration file. Its ``task`` entry
            selects ``calculate_sparse_attns`` or
            ``calculate_tda_features_from_sparce_matrix``; any other value is
            logged as a warning and nothing else is done.
    """
    config = Config(config_file)
    logger = setup_logger(config)
    preprocessor = DataPreprocessor(config)

    df = preprocessor.load_data(config.get('testset_name'))

    task = config.get('task')
    if task == 'calculate_sparse_attns':
        _calculate_sparse_attns(config, df, logger)
    elif task == 'calculate_tda_features_from_sparce_matrix':
        _calculate_tda_features(config, df, logger)
    else:
        logger.warning("Unknown task %r; nothing to do.", task)
        



if __name__ == "__main__":
    # Command-line entry point: the only option is the YAML config path.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--config',
        default='scripts/data_scripts/configs/config_data_binding.yaml',
    )
    options = cli.parse_args()

    # Select the 'spawn' start method before any worker pool is created.
    mp.set_start_method('spawn')

    main(options.config)

# configs/config_token.yaml
# data_path: datasets/conservtion/
# model_name: facebook/esm2_t33_650M_UR50D
# device: cuda:0
# task: calculate_sparse_attns
# batch: 32
# log_file: log.txt
# num_workers: 4
