import os
from omegaconf import OmegaConf, DictConfig
# Export the `env` section of configs/config.yaml as environment variables.
# This runs before the remaining imports of the file — presumably so that
# libraries reading environment variables at import time see these values.
env = OmegaConf.load("configs/config.yaml").env
for k, v in env.items():
    if v is None:
        # A YAML `null` cannot be a meaningful env value; leave the var unset
        # rather than exporting the literal string "None".
        continue
    # os.environ only accepts str; YAML scalars like `0`, `true` or `1.5` are
    # parsed as int/bool/float and would otherwise raise TypeError.
    os.environ[str(k)] = str(v)
    
import sys
import hydra
import logging
import warnings
import numpy as np
import pandas as pd
from utils import *
from tqdm import tqdm
from stats_count import *
from multiprocessing import Pool
from grab_weights import text_preprocessing
from transformers import AutoTokenizer, AutoConfig

# Process-wide runtime setup, executed once at import time.
warnings.filterwarnings(action='ignore')  # silence every warning category
tqdm._instances.clear()  # forget progress-bar instances registered so far
tqdm.monitor_interval = 0  # per tqdm docs, 0 turns off its monitor thread
sys.stdout.flush()
np.random.seed(seed=42)  # fixed seed for NumPy's global RNG (reproducibility)
logging.basicConfig(format="%(message)s", level=logging.INFO)  # bare-message INFO logs

def get_list_of_ids(sentences, MAX_LEN, tokenizer):
    """Tokenize ``sentences`` into a fixed-width array of token ids.

    Every sentence is passed through ``text_preprocessing`` and then encoded
    with special tokens, truncated and padded to exactly ``MAX_LEN`` tokens.

    Args:
        sentences: Sequence (list or 1-D array) of raw sentence strings.
        MAX_LEN: Target length every encoded sequence is truncated/padded to.
        tokenizer: Tokenizer exposing ``batch_encode_plus`` (Hugging Face API).

    Returns:
        ``np.ndarray`` of shape ``(len(sentences), MAX_LEN)`` with input ids.
        An empty input yields an empty ``(0, MAX_LEN)`` integer array.
    """
    if len(sentences) == 0:
        # Empty chunks appear when data is split into more chunks than rows;
        # fast HF tokenizers may raise on an empty batch, so short-circuit.
        return np.empty((0, MAX_LEN), dtype=np.int64)
    inputs = tokenizer.batch_encode_plus(
        [text_preprocessing(s) for s in sentences],
        add_special_tokens=True,
        max_length=MAX_LEN,    # Max length to truncate/pad
        padding="max_length",  # Pad every sentence to max length
        truncation=True,
    )
    return np.array(inputs['input_ids'])

def _part_number(filename):
    """Return N parsed from an attention-part path like ``..._part<N>of<M>...``.

    The text after the last ``_`` is expected to start with ``part<N>of``.
    """
    return int(filename.split('_')[-1].split('of')[0][4:].strip())


@hydra.main(config_path=None, config_name="config", version_base=None)
def main(cfg: DictConfig):
    """Compute template features for every saved attention part.

    Loads the ``all_heads_*`` attention parts for ``cfg.data.name`` from
    ``{cfg.output_dir}/attentions/`` in part order. It tokenizes the matching
    sentences from ``{cfg.data.input_dir}/{cfg.data.name}.csv``. It runs
    ``calculate_features_t`` on worker chunks in a process pool. The result,
    concatenated along axis 3, is saved to
    ``{cfg.output_dir}/features/{attention_name}_template.npy``.

    Raises:
        ValueError: if the model config has no ``num_hidden_layers`` or the CSV
            has fewer sentences than the attention parts cover.
        FileNotFoundError: if no attention parts match the expected name.
    """
    model_path = cfg.model.model_path
    num_workers = cfg.model.num_of_workers
    MAX_LEN = 128
    tokenizer = AutoTokenizer.from_pretrained(model_path, do_lower_case=True)
    config = AutoConfig.from_pretrained(model_path)
    num_layers = getattr(config, "num_hidden_layers", None)
    if num_layers is None:
        raise ValueError(f"Model config at {model_path!r} has no 'num_hidden_layers'")
    layers_of_interest = list(range(num_layers))
    attention_dir = f'{cfg.output_dir}/attentions/'
    # MAX_LEN must match the value the attention parts were generated with.
    attention_name = f'all_heads_{len(layers_of_interest)}_layers_MAX_LEN_{MAX_LEN}_{cfg.data.name}'

    adj_filenames = sorted(
        (
            attention_dir + filename
            for filename in os.listdir(attention_dir)
            if filename.split("_part")[0] == attention_name
        ),
        key=_part_number,
    )
    if not adj_filenames:
        raise FileNotFoundError(
            f"No attention parts named {attention_name!r}_part* in {attention_dir}"
        )
    feature_list = ['self', 'beginning', 'prev', 'next', 'comma', 'dot']

    data = pd.read_csv(f'{cfg.data.input_dir}/{cfg.data.name}.csv').reset_index(drop=True)
    all_sentences = data['sentence'].values

    features_parts = []
    offset = 0  # index of the first sentence covered by the current part
    with Pool(num_workers) as pool:
        for i, filename in tqdm(enumerate(adj_filenames), total=len(adj_filenames),
                                desc='Features calc', file=sys.stdout):
            print(f"Attentions loaded from: {filename}", flush=True)
            # NOTE(review): allow_pickle=True can run code from a crafted file;
            # only load trusted, self-generated attention parts.
            adj_matricies = np.load(filename, allow_pickle=True)
            batch_size = adj_matricies.shape[0]
            # Parts may differ in size (e.g. a shorter last part), so advance a
            # running offset instead of assuming i * batch_size.
            sentences = all_sentences[offset:offset + batch_size]
            if len(sentences) != batch_size:
                raise ValueError(
                    f"{filename} holds {batch_size} rows but only {len(sentences)} "
                    f"sentences remain in the CSV starting at row {offset}"
                )
            offset += batch_size

            # Drop empty chunks produced when there are more workers than rows.
            splitted_indexes = [
                idx for idx in np.array_split(np.arange(batch_size), num_workers) if idx.size
            ]
            splitted_list_of_ids = [
                get_list_of_ids(sentences[indx], MAX_LEN, tokenizer)
                for indx in tqdm(splitted_indexes, desc=f"Calculating token ids on iter {i} from {len(adj_filenames)}", file=sys.stdout)
            ]
            splitted_adj_matricies = [adj_matricies[indx] for indx in splitted_indexes]

            args = [
                (m, feature_list, list_of_ids)
                for m, list_of_ids in zip(splitted_adj_matricies, splitted_list_of_ids)
            ]
            # Chunks are returned in submission order, so extending keeps the
            # sentence order; one final concatenate equals the original
            # per-file-then-global concatenation along axis 3.
            features_parts.extend(pool.starmap(calculate_features_t, args))

    features_array = np.concatenate(features_parts, axis=3)
    features_dir = f'{cfg.output_dir}/features'
    os.makedirs(features_dir, exist_ok=True)
    np.save(f'{features_dir}/{attention_name}_template.npy', features_array)
    
if __name__ == "__main__":
    # main() is called with no arguments: the @hydra.main decorator on main is
    # responsible for building and passing the `cfg` DictConfig.
    main()