import os
from omegaconf import OmegaConf, DictConfig
# Export the `env` section of the project config as process environment
# variables. This runs before the remaining imports below; presumably so that
# those libraries see the variables at import time -- TODO confirm.
env = OmegaConf.load("configs/config.yaml").env
for k, v in env.items():
    if v is None:
        # A null YAML entry carries no value to export; skip it instead of
        # writing the literal string "None" into the environment.
        continue
    # os.environ only accepts str values, while YAML scalars such as `0` or
    # `true` are loaded as int/bool -- convert explicitly to avoid a TypeError.
    os.environ[k] = str(v)

import sys
import hydra
import logging
import warnings
import numpy as np
import pandas as pd
from utils import *
from tqdm import tqdm
from stats_count import *
from multiprocessing import Pool
from transformers import AutoTokenizer, AutoConfig

# Process-wide runtime setup, executed once at import time.

# Silence every warning category for the whole process.
warnings.filterwarnings('ignore')
# Reset tqdm's internal registry of live progress bars (private attribute);
# presumably guards against stale bars from an earlier run -- TODO confirm.
tqdm._instances.clear()
# NOTE(review): 0 is understood to disable tqdm's background monitor thread --
# confirm against the installed tqdm version.
tqdm.monitor_interval = 0
# Push out anything already buffered on stdout before progress output starts.
sys.stdout.flush()
np.random.seed(42) # For reproducibility.
# Log INFO and above, printing only the bare message text.
logging.basicConfig(level=logging.INFO, format="%(message)s")

@hydra.main(config_path=None, config_name="config", version_base=None)
def main(cfg: DictConfig):
    """Compute per-head statistics from dumped attention matrices and save them.

    Steps:
      1. Load the tokenizer and model config from ``cfg.model.model_path``.
      2. Read ``{cfg.data.input_dir}/{cfg.data.name}.csv`` and compute the
         token length of every entry of its ``sentence`` column.
      3. Collect the attention dump files under ``{cfg.output_dir}/attentions/``
         that belong to this model/dataset, ordered by their part number.
      4. For every dump, split the work across ``cfg.model.num_of_workers``
         processes running ``count_top_stats`` and concatenate the partial
         results along axis 3.
      5. Save the concatenated array under ``{cfg.output_dir}/features/`` and
         print its shape.

    Args:
        cfg: Hydra configuration. Reads ``model.model_path``,
            ``model.stats_name``, ``model.num_of_workers``, ``data.input_dir``,
            ``data.name`` and ``output_dir``.

    Raises:
        ValueError: if the model config does not define ``num_hidden_layers``.
        FileNotFoundError: if no matching attention dump file is found.
    """
    thresholds_array = [0.025, 0.05, 0.1, 0.25, 0.5, 0.75]  # The set of thresholds
    thrs = len(thresholds_array)                             # ("t" in the paper)
    batch_size = 4    # batch size
    DUMP_SIZE = 100   # number of batches to be dumped
    MAX_LEN = 128     # token limit passed to get_token_length; also part of file names
    stats_cap = 500
    # Number of samples stored in one dump file (loop-invariant).
    samples_per_dump = batch_size * DUMP_SIZE

    model_path = cfg.model.model_path
    tokenizer = AutoTokenizer.from_pretrained(model_path, do_lower_case=True)
    config = AutoConfig.from_pretrained(model_path)
    num_layers = getattr(config, "num_hidden_layers", None)
    if num_layers is None:
        # Fail with a clear message instead of a TypeError from range(None).
        raise ValueError(
            f"Model config for '{model_path}' does not define 'num_hidden_layers'."
        )
    layers_of_interest = list(range(num_layers))
    stats_file = (
        f'{cfg.output_dir}/features/all_heads_{len(layers_of_interest)}_layers_'
        f'{cfg.model.stats_name}_lists_array_{thrs}_thrs_MAX_LEN_{MAX_LEN}_{cfg.data.name}.npy'
    )

    data = pd.read_csv(f'{cfg.data.input_dir}/{cfg.data.name}.csv').reset_index(drop=True)
    data['tokenizer_length'] = get_token_length(data['sentence'].values, tokenizer, MAX_LEN)
    ntokens_array = data['tokenizer_length'].values

    attentions_dir = f'{cfg.output_dir}/attentions/'
    r_file = f'{cfg.output_dir}/attentions/all_heads_{len(layers_of_interest)}_layers_MAX_LEN_{MAX_LEN}_{cfg.data.name}'
    candidate_paths = (
        f'{cfg.output_dir}/attentions/{filename}' for filename in os.listdir(attentions_dir)
    )
    adj_filenames = [path for path in candidate_paths if r_file in path]
    if not adj_filenames:
        # Without this check the run would only fail at the final
        # np.concatenate with "need at least one array to concatenate".
        raise FileNotFoundError(
            f"No attention dump files matching '{r_file}' found in '{attentions_dir}'."
        )

    def part_index(path):
        # Takes the last '_'-separated token, keeps the text before 'of', drops
        # its first four characters and parses the rest as the part number.
        # Presumably the token looks like 'part<N>of<M>...' -- TODO confirm.
        return int(path.split('_')[-1].split('of')[0][4:].strip())

    adj_filenames = sorted(adj_filenames, key=part_index)

    # Make sure the destination exists up front, so a missing directory cannot
    # discard the result of the whole computation at np.save time.
    os.makedirs(os.path.dirname(stats_file), exist_ok=True)

    stats_names = cfg.model.stats_name.split("_")  # loop-invariant
    stats_tuple_lists_array = []
    # The context manager guarantees the worker processes are released, also
    # when an exception is raised mid-loop.
    with Pool(cfg.model.num_of_workers) as pool:
        for i, filename in tqdm(enumerate(adj_filenames), total=len(adj_filenames),
                                desc="Feature calculation", file=sys.stdout):
            # NOTE(review): allow_pickle=True unpickles arbitrary objects; only
            # safe while these dumps come from this pipeline's own output.
            adj_matricies = np.load(filename, allow_pickle=True)
            # Token lengths of the samples stored in the i-th dump file.
            ntokens = ntokens_array[i * samples_per_dump:(i + 1) * samples_per_dump]
            splitted = split_matricies_and_lengths(adj_matricies, ntokens, cfg.model.num_of_workers)
            args = [
                (matrices, thresholds_array, part_ntokens, stats_names, stats_cap)
                for matrices, part_ntokens in splitted
            ]
            stats_tuple_lists_array_part = pool.starmap(count_top_stats, args)
            stats_tuple_lists_array.append(np.concatenate(stats_tuple_lists_array_part, axis=3))

    stats_tuple_lists_array = np.concatenate(stats_tuple_lists_array, axis=3)
    np.save(stats_file, stats_tuple_lists_array)
    print(stats_tuple_lists_array.shape)

# Script entry point: main() is called without arguments because the
# @hydra.main decorator builds the config and passes it in as `cfg`.
if __name__ == "__main__":
    main()