import os
from omegaconf import OmegaConf, DictConfig
# Export the `env` section of the project config as environment variables.
# This runs before the remaining imports below -- presumably so that libraries
# imported later (torch, transformers, ...) already see these variables when
# they are loaded (TODO confirm this is the intent).
env = OmegaConf.load("configs/config.yaml").env
for k, v in env.items():
    # os.environ accepts only strings, while YAML scalars may be parsed as
    # int / float / bool. A null value is treated as "not set" and skipped
    # instead of being exported as the literal string "None".
    if v is None:
        continue
    os.environ[str(k)] = str(v)

import sys
import torch
import hydra
import logging
import warnings
from utils import *
import numpy as np
import pandas as pd
from tqdm import tqdm
from math import ceil
from stats_count import *
from grab_weights import grab_attention_weights
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel, AutoConfig

# --- Process-wide runtime setup, executed once when the module is loaded ---

# Root logger: bare messages, INFO level and above.
logging.basicConfig(format="%(message)s", level=logging.INFO)

# Install a blanket "ignore" filter for warnings.
warnings.filterwarnings('ignore')

# Fixed NumPy seed, for reproducibility.
np.random.seed(42)

# Push out anything already buffered on stdout before progress bars are drawn.
sys.stdout.flush()

# Reset tqdm's class-level state: set the monitor interval to 0 (per the tqdm
# docs this disables its monitor thread) and forget any registered bars.
tqdm.monitor_interval = 0
tqdm._instances.clear()

def _part_number(batch_index: int, dump_size: int) -> int:
    """Return the 1-based number of the dump file that ``batch_index`` falls into.

    Batches ``0 .. dump_size - 1`` map to part 1, the next ``dump_size``
    batches to part 2, and so on. This keeps the trailing (partial) part from
    colliding with the last full part.
    """
    return batch_index // dump_size + 1


def _save_attention_chunk(batches: list, save_filename: str) -> None:
    """Merge buffered per-batch attention arrays and write them to ``save_filename``.

    The arrays are concatenated along axis 1, then axes 0 and 1 are swapped.
    NOTE(review): this assumes ``grab_attention_weights`` returns arrays laid
    out as layer X sample X head X n_token X n_token -- confirm against that
    helper.
    """
    print(f'Saving: shape {batches[0].shape}')
    merged = np.concatenate(batches, axis=1)
    print("Concatenated")
    merged = np.swapaxes(merged, axis1=0, axis2=1)  # sample X layer X head X n_token X n_token
    print(f"Saving weights to : {save_filename}")
    np.save(save_filename, merged)


@hydra.main(config_path=None, config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    """Extract attention weights for every sentence of a dataset and dump them to disk.

    Reads ``{cfg.data.input_dir}/{cfg.data.name}.csv`` (column ``sentence``),
    runs the model at ``cfg.model.model_path`` over it in small batches via
    ``grab_attention_weights`` and writes the results, ``DUMP_SIZE`` batches
    per file, to
    ``{cfg.output_dir}/attentions/all_heads_<L>_layers_MAX_LEN_<T>_<name>_part<k>of<n>.npy``.

    A full part whose file already exists is left untouched; the trailing
    partial part is always (re)written.

    Raises:
        ValueError: if the model config does not expose ``num_hidden_layers``.
    """
    model_path = cfg.model.model_path
    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
    tokenizer = AutoTokenizer.from_pretrained(model_path, do_lower_case=True)
    model = AutoModel.from_pretrained(model_path, output_attentions=True, attn_implementation="eager")
    config = AutoConfig.from_pretrained(model_path)
    num_layers = getattr(config, "num_hidden_layers", None)
    if num_layers is None:
        # Fail with a clear message instead of an opaque TypeError from range(None).
        raise ValueError(
            f"Config of model '{model_path}' has no 'num_hidden_layers' attribute; "
            "cannot determine the number of layers."
        )

    model = model.to(device)
    os.makedirs(cfg.output_dir + '/attentions', exist_ok=True)
    os.makedirs(cfg.output_dir + '/features', exist_ok=True)

    max_tokens_amount = 128  # The number of tokens to which the tokenized text is truncated / padded.

    # Layers for which attention matrices and features on them are calculated.
    # All layers of the model by default. In this function only the number of
    # layers is used, as part of the output file name.
    layers_of_interest = list(range(num_layers))

    r_file = (
        f'{cfg.output_dir}/attentions/all_heads_{len(layers_of_interest)}'
        f'_layers_MAX_LEN_{max_tokens_amount}_{cfg.data.name}'
    )

    data = pd.read_csv(f'{cfg.data.input_dir}/{cfg.data.name}.csv').reset_index(drop=True)
    sentences = data['sentence'].values

    # Extraction
    batch_size = 4  # batch size
    DUMP_SIZE = 100  # number of batches to be dumped
    number_of_batches = ceil(len(sentences) / batch_size)
    if number_of_batches == 0:
        # np.array_split rejects zero sections; there is nothing to extract anyway.
        print("No sentences found in the input data; nothing to do.")
        return
    number_of_files = ceil(number_of_batches / DUMP_SIZE)
    batched_sentences = np.array_split(sentences, number_of_batches)

    def part_filename(batch_index: int) -> str:
        # File that holds the part containing batch `batch_index`.
        return f'{r_file}_part{_part_number(batch_index, DUMP_SIZE)}of{number_of_files}.npy'

    adj_matricies = []
    progress = tqdm(batched_sentences, total=number_of_batches, desc="Weights calc", file=sys.stdout)
    for i, batch in enumerate(progress):
        attention_w = grab_attention_weights(model, tokenizer, batch, max_tokens_amount, device)
        adj_matricies.append(attention_w)
        if (i + 1) % DUMP_SIZE == 0:  # dumping
            save_filename = part_filename(i)
            if os.path.exists(save_filename):
                print(f"Skipping existing file : {save_filename}")
            else:
                _save_attention_chunk(adj_matricies, save_filename)
            # Always start a fresh buffer at a part boundary, even when the
            # dump was skipped; otherwise the next part would also contain
            # this part's batches and memory would keep growing.
            adj_matricies = []

    if adj_matricies:
        # Trailing part with fewer than DUMP_SIZE batches.
        _save_attention_chunk(adj_matricies, part_filename(number_of_batches - 1))

    print("Results saved.")
    
    
# Script entry point: run the extraction only when this file is executed
# directly, not when it is imported. main() is called without arguments here;
# its `cfg` parameter is expected to be provided by the hydra.main decorator.
if __name__ == "__main__":
    main()