import os
from omegaconf import OmegaConf, DictConfig
# Export the `env` section of configs/config.yaml into the process environment.
# NOTE(review): this runs before the remaining imports, presumably so that the
# libraries imported below see these variables at import time -- confirm
# before moving it.
env = OmegaConf.load("configs/config.yaml").env
for k, v in env.items():
    if v is None:
        # A null entry is skipped rather than exported as the string "None".
        continue
    # os.environ accepts only strings; YAML scalars such as `0` or `true` are
    # loaded as int/bool and would raise TypeError if assigned directly.
    os.environ[str(k)] = str(v)
    
import sys
import hydra
import logging
import warnings
import itertools
import ripser_count
import numpy as np
import pandas as pd
from tqdm import tqdm
from utils import *
from stats_count import *
from multiprocessing import Process, Queue
from transformers import AutoTokenizer, AutoConfig

# Process-wide setup, executed once when this module is imported.

# Silence every Python warning for the rest of the process.
warnings.filterwarnings('ignore')
# Reset tqdm's class-level state: drop any registered progress-bar instances
# and set the monitor interval to 0 (tqdm documents 0 as disabling its
# monitoring thread -- confirm against the installed tqdm version).
tqdm._instances.clear() 
tqdm.monitor_interval = 0
# Flush anything still buffered on stdout.
sys.stdout.flush()
np.random.seed(42) # Fixed seed for NumPy's global RNG, for reproducibility.
# Log INFO and above, printing only the bare message text.
logging.basicConfig(level=logging.INFO, format="%(message)s")

def get_only_barcodes(adj_matricies, ntokens_array, dim, lower_bound):
    """Compute barcodes for every (layer, head) pair of a stack of matrices.

    ``adj_matricies`` is indexed as ``[:, layer, head, :, :]``, so it must be
    5-dimensional with axis 1 enumerating layers and axis 2 enumerating heads.
    ``ntokens_array``, ``dim`` and ``lower_bound`` are forwarded unchanged to
    ``ripser_count.get_barcodes``.

    Returns a dict keyed by ``(layer, head)`` tuples whose values are whatever
    ``ripser_count.get_barcodes`` returns for that slice. Keys are inserted
    layer by layer, heads in ascending order within each layer.
    """
    n_layers = adj_matricies.shape[1]
    n_heads = adj_matricies.shape[2]
    return {
        (layer_idx, head_idx): ripser_count.get_barcodes(
            adj_matricies[:, layer_idx, head_idx, :, :],
            ntokens_array,
            dim,
            lower_bound,
            (layer_idx, head_idx),
        )
        for layer_idx in range(n_layers)
        for head_idx in range(n_heads)
    }

def _barcodes_from_subprocess(result_queue, args, poll_seconds=1.0):
    """Run ``get_only_barcodes(*args)`` in a child process and return its result.

    The child is started with ``subprocess_wrap`` (defined outside this file;
    presumably it puts the function's return value on ``result_queue`` --
    confirm against its definition). The parent polls the queue instead of
    blocking on it, so a child that dies without producing a result raises
    ``RuntimeError`` here rather than hanging the whole run forever.

    The child is always joined and closed before returning or raising, which
    releases the resources it held.
    """
    # Function-scope import keeps the stdlib `queue` module name out of the
    # module namespace, where `Queue` from multiprocessing is already in use.
    from queue import Empty

    worker = Process(
        target=subprocess_wrap,
        args=(result_queue, get_only_barcodes, args),
    )
    worker.start()
    try:
        while True:
            try:
                # Read before joining: a child cannot exit until the data it
                # queued has been flushed to the pipe.
                return result_queue.get(timeout=poll_seconds)
            except Empty:
                if worker.is_alive():
                    continue
                # The child may have queued its result and exited right after
                # the timed-out read; look once more before giving up.
                try:
                    return result_queue.get(timeout=poll_seconds)
                except Empty:
                    raise RuntimeError(
                        "Barcode subprocess exited with code "
                        f"{worker.exitcode} without returning a result"
                    ) from None
    finally:
        worker.join()
        worker.close()


@hydra.main(config_path=None, config_name="config", version_base=None)
def main(cfg: DictConfig):
    """Compute barcodes for every dumped attention file and save them as JSON.

    Reads ``cfg.model.model_path``, ``cfg.output_dir``, ``cfg.data.name`` and
    ``cfg.data.input_dir``. For each matching file under
    ``<output_dir>/attentions`` the matrices are loaded, split into
    ``number_of_splits`` parts, each part is processed by
    ``get_only_barcodes`` in its own subprocess, and the merged result is
    written to ``<output_dir>/barcodes/<same stem>_<part>.json``.

    Raises:
        ValueError: if the model config has no ``num_hidden_layers``.
        RuntimeError: if a worker subprocess dies without returning barcodes.
    """
    # Imported explicitly so this function does not depend on a star import
    # happening to re-export it.
    from collections import defaultdict

    model_path = cfg.model.model_path
    config = AutoConfig.from_pretrained(model_path)
    num_layers = getattr(config, "num_hidden_layers", None)
    if num_layers is None:
        raise ValueError(
            f"Model config for {model_path!r} does not define num_hidden_layers"
        )

    MAX_LEN = 128
    attentions_dir = f'{cfg.output_dir}/attentions'
    os.makedirs(cfg.output_dir + '/barcodes', exist_ok=True)

    # Input and output files share the same stem; only the directory differs.
    file_stem = f'all_heads_{num_layers}_layers_MAX_LEN_{MAX_LEN}_{cfg.data.name}'
    r_file = f'{attentions_dir}/{file_stem}'
    barcodes_file = f'{cfg.output_dir}/barcodes/{file_stem}'

    adj_filenames = [
        f'{attentions_dir}/{filename}'
        for filename in os.listdir(f'{attentions_dir}/')
        if r_file in f'{attentions_dir}/{filename}'
    ]
    # Sort numerically on the last '_'-separated piece of the path: the text
    # before 'of', minus its first four characters, parsed as an int.
    # NOTE(review): presumably a 'part<N>of<M>' suffix -- confirm against the
    # script that writes the attention dumps.
    adj_filenames = sorted(
        adj_filenames,
        key=lambda x: int(x.split('_')[-1].split('of')[0][4:].strip()),
    )
    if not adj_filenames:
        logging.warning("No attention files matching %s were found", r_file)

    tokenizer = AutoTokenizer.from_pretrained(model_path, do_lower_case=True)
    data = pd.read_csv(f'{cfg.data.input_dir}/{cfg.data.name}.csv').reset_index(drop=True)
    data['tokenizer_length'] = get_token_length(data['sentence'].values, tokenizer, MAX_LEN)
    ntokens_array = data['tokenizer_length'].values

    dim = 1
    lower_bound = 1e-3
    batch_size = 4  # batch size
    DUMP_SIZE = 100  # number of batches to be dumped
    # NOTE(review): the slice below assumes the i-th sorted file holds exactly
    # this many consecutive samples -- confirm against the dump script.
    samples_per_file = batch_size * DUMP_SIZE

    queue = Queue()
    number_of_splits = 2
    for i, filename in enumerate(tqdm(adj_filenames, total=len(adj_filenames), desc='Calculating barcodes', file=sys.stdout)):
        barcodes = defaultdict(list)
        # NOTE(security): allow_pickle=True can execute arbitrary code from a
        # crafted file; only load dumps produced by this pipeline.
        adj_matricies = np.load(filename, allow_pickle=True)
        print(f"Matricies loaded from: {filename}")
        ntokens = ntokens_array[i * samples_per_file:(i + 1) * samples_per_file]
        splitted = split_matricies_and_lengths(adj_matricies, ntokens, number_of_splits)
        # Distinct loop names so the per-file `ntokens` is not rebound.
        for matricies_part, ntokens_part in tqdm(splitted, leave=False, file=sys.stdout):
            barcodes_part = _barcodes_from_subprocess(
                queue,
                (matricies_part, ntokens_part, dim, lower_bound),
            )
            barcodes = unite_barcodes(barcodes, barcodes_part)
        # Reuse the input file's trailing part tag in the output name.
        part = filename.split('_')[-1].split('.')[0]
        save_barcodes(barcodes, f'{barcodes_file}_{part}.json')
        
# Run the Hydra-decorated entry point only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()