import os
from omegaconf import OmegaConf, DictConfig
# Export the `env` section of the project config into os.environ. This runs
# before the remaining imports below, presumably so libraries imported later
# (e.g. transformers) see these variables at import time — TODO confirm which
# variables actually require this ordering.
env = OmegaConf.load("configs/config.yaml").env
for k, v in env.items():
    # os.environ accepts only str keys/values; YAML scalars such as `0` or
    # `true` load as int/bool and would otherwise raise TypeError here.
    os.environ[str(k)] = str(v)
    
import sys
import hydra
import logging
import warnings
import itertools

import numpy as np
import ripser_count
from utils import *
from tqdm import tqdm
from stats_count import *

from transformers import AutoConfig

# Process-wide setup performed once, at import time.
# Suppress all Python warnings for the rest of the process.
warnings.filterwarnings('ignore')
# Clear tqdm's private `_instances` collection (presumably the registry of
# live progress bars); private API — may break across tqdm versions.
tqdm._instances.clear()
# NOTE(review): per tqdm's documentation, 0 disables its background monitor
# thread — confirm for the pinned tqdm version.
tqdm.monitor_interval = 0
sys.stdout.flush()
np.random.seed(42) # For reproducibility. Only NumPy's global RNG is seeded here.
# Root logger at INFO level, emitting only the message text.
logging.basicConfig(level=logging.INFO, format="%(message)s")

# Names of the features requested from ripser_count.count_ripser_features in
# main(). The h0_/h1_ prefixes presumably denote homology dimension and the
# _t<x> suffixes thresholds; the exact meaning of each name is defined by
# ripser_count — confirm there before relying on it. The order here likely
# determines the order of the features in the saved array — TODO confirm.
ripser_feature_names=[
    'h0_s',
    'h0_e',
    'h0_t_d',
    'h0_n_d_m_t0.75',
    'h0_n_d_m_t0.5',
    'h0_n_d_l_t0.25',
    'h1_t_b',
    'h1_n_b_m_t0.25',
    'h1_n_b_l_t0.95',
    'h1_n_b_l_t0.70',
    'h1_s',
    'h1_e',
    'h1_v',
    'h1_nb'
]

def get_only_barcodes(adj_matricies, ntokens_array, dim, lower_bound):
    """Compute barcodes for every (layer, head) slice of ``adj_matricies``.

    ``adj_matricies`` is indexed as ``[:, layer, head, :, :]``, so it needs at
    least five dimensions. Axes 1 and 2 are enumerated as layers and heads.

    Returns:
        dict mapping ``(layer, head)`` to the result of
        ``ripser_count.get_barcodes`` for that slice. Keys are inserted with
        the layer index varying slowest. The ``(layer, head)`` tuple is also
        passed through to ``get_barcodes`` as its last argument.
    """
    n_layers = adj_matricies.shape[1]
    n_heads = adj_matricies.shape[2]
    return {
        (layer_idx, head_idx): ripser_count.get_barcodes(
            adj_matricies[:, layer_idx, head_idx, :, :],
            ntokens_array,
            dim,
            lower_bound,
            (layer_idx, head_idx),
        )
        for layer_idx in range(n_layers)
        for head_idx in range(n_heads)
    }

@hydra.main(config_path=None, config_name="config", version_base=None)
def main(cfg: DictConfig):
    """Compute ripser++ features from saved barcode part files and save them.

    Reads ``cfg.model.model_path``, ``cfg.output_dir`` and ``cfg.data.name``.
    Selects every file in ``{cfg.output_dir}/barcodes/`` whose name matches
    this model/dataset's attention-file stem up to ``_part``. Matching files
    are loaded as JSON in ascending part order. Each per-head entry is passed
    through ``reformat_barcodes`` and then
    ``ripser_count.count_ripser_features``. The per-part arrays are
    concatenated along axis 2 and written with ``np.save`` to
    ``{cfg.output_dir}/features/<stem>_ripser.npy``.

    Raises:
        ValueError: if the model config has no ``num_hidden_layers``.
        FileNotFoundError: if no matching barcode part files are found.
    """
    import json  # explicit: json is not imported at module level in this file

    model_path = cfg.model.model_path
    config = AutoConfig.from_pretrained(model_path)
    num_layers = getattr(config, "num_hidden_layers", None)
    if num_layers is None:
        # Without this check the failure is an opaque TypeError from range(None).
        raise ValueError(
            f"Model config for {model_path!r} has no 'num_hidden_layers'; "
            "cannot determine the number of layers."
        )

    layers_of_interest = list(range(num_layers))
    max_tokens_amount = 128

    r_file = f'{cfg.output_dir}/attentions/all_heads_{len(layers_of_interest)}_layers_MAX_LEN_{max_tokens_amount}_{cfg.data.name}'
    ripser_file = f'{cfg.output_dir}/features/all_heads_{len(layers_of_interest)}_layers_MAX_LEN_{max_tokens_amount}_{cfg.data.name}_ripser.npy'

    barcodes_dir = f'{cfg.output_dir}/barcodes/'
    target_stem = r_file.split('/')[-1]  # loop-invariant, computed once
    adj_filenames = [
        f'{barcodes_dir}{filename}'
        for filename in os.listdir(barcodes_dir)
        if filename.split('_part')[0] == target_stem
    ]
    if not adj_filenames:
        # Otherwise np.concatenate below fails with a cryptic message.
        raise FileNotFoundError(
            f"No barcode files matching '{target_stem}_part*' found in {barcodes_dir}"
        )
    # Order parts numerically. Assumes names end in '_part<k>of<n>...', so the
    # digits between 'part' and 'of' in the last '_'-separated field give <k>.
    adj_filenames.sort(key=lambda x: int(x.split('_')[-1].split('of')[0][4:].strip()))

    features_array = []
    for filename in tqdm(adj_filenames, desc='Calculating ripser++ features', file=sys.stdout):
        # Context manager closes each part file promptly; JSON text is UTF-8.
        with open(filename, encoding='utf-8') as f:
            barcodes = json.load(f)
        print(f"Barcodes loaded from: {filename}", flush=True)
        # Nested lists follow the JSON's key order: outer over top-level keys
        # (layers), inner over second-level keys (heads).
        features_part = [
            [
                ripser_count.count_ripser_features(
                    reformat_barcodes(head_barcodes), ripser_feature_names
                )
                for head_barcodes in layer_barcodes.values()
            ]
            for layer_barcodes in barcodes.values()
        ]
        features_array.append(np.asarray(features_part))

    # Axes 0 and 1 are layer and head (from the nesting above). Axis 2 is the
    # first axis of count_ripser_features' output, presumably samples — TODO
    # confirm. Parts are joined along it.
    features = np.concatenate(features_array, axis=2)
    # np.save does not create missing parent directories.
    os.makedirs(os.path.dirname(ripser_file), exist_ok=True)
    np.save(ripser_file, features)

if __name__ == "__main__":
    # main() is decorated with @hydra.main, which supplies `cfg`; hence the
    # call takes no arguments.
    main()