import numpy as np
from tqdm import tqdm
from multiprocessing import Pool
import os
from .stats_count import count_top_stats

class TDAFeatureCalculator:
    """Compute topological (TDA) statistics from saved attention matrices.

    All settings come from a dict-like ``config`` read via ``config.get``.
    Keys used by this class:

    - ``num_of_workers``: process count for ``multiprocessing.Pool``
      (``None`` lets ``Pool`` choose its default).
    - ``batch_size``: number of files dispatched per ``Pool.map`` call;
      ``None`` means all files in a single batch.
    - ``att_mat_dir``: directory ``process_file`` loads ``.npy`` files from.
    - ``thresholds_array``, ``stats_name``, ``stats_cap``: forwarded as-is
      to ``count_top_stats``.
    - ``attention_file_prefix``, ``outdir``: file naming/location used by
      ``save_matrices``.
    - ``out_file_prefix``: path prefix of the final stats ``.npy`` file.
    """

    def __init__(self, config):
        self.config = config
        self.num_of_workers = config.get('num_of_workers')
        self.batch_size = config.get('batch_size')

    def process_file(self, args):
        """Load one attention-matrix file and compute its stats.

        Args:
            args: ``(filename, ntokens_array)`` packed into a single argument
                so the method can be used with ``Pool.map``. ``filename`` is
                expected to end in ``_<int>.<ext>`` (e.g.
                ``'att_deeploc_val_1.npy'``); that integer indexes
                ``ntokens_array``.

        Returns:
            The result of ``count_top_stats`` for a batch of one sample.
        """
        filename, ntokens_array = args
        # 'att_deeploc_val_1.npy' -> 1
        idx = int(filename.split("_")[-1].split(".")[0])
        # NOTE(review): save_matrices names files with ``file_index + 1``. If
        # these are the same files and ``file_index`` is a 0-based sample
        # index, ``ntokens_array[idx]`` is off by one -- confirm.
        # NOTE(review): allow_pickle=True can run code embedded in a crafted
        # file; only load trusted inputs.
        adj_matricies = np.load(
            os.path.join(self.config.get('att_mat_dir'), filename),
            allow_pickle=True,
        )
        # Add a leading batch axis of size 1 (count_top_stats presumably
        # expects batched inputs -- matches the original intent).
        adj_matricies = np.expand_dims(adj_matricies, axis=0)
        ntokens = np.expand_dims(np.array(ntokens_array[idx]), axis=0)
        return count_top_stats(
            adj_matricies,
            self.config.get('thresholds_array'),
            ntokens,
            self.config.get('stats_name'),
            self.config.get('stats_cap'),
        )

    def split_into_batches(self, list):
        """Split a sequence into consecutive chunks of ``self.batch_size``.

        The last chunk may be shorter; empty input yields ``[]``. When
        ``batch_size`` is ``None`` the whole sequence is one batch.

        Raises:
            ValueError: if ``batch_size`` is zero or negative (a negative
                step would otherwise silently produce no batches).
        """
        # ``list`` shadows the builtin; the parameter name is kept for
        # backward compatibility with keyword callers.
        items = list
        size = self.batch_size
        if size is None:
            size = max(len(items), 1)
        if size <= 0:
            raise ValueError(f"batch_size must be positive, got {size}")
        return [items[i:i + size] for i in range(0, len(items), size)]

    @staticmethod
    def split_matricies_and_lengths(adj_matricies, ntokens_array, num_of_workers):
        """Split matrices and token counts into ``num_of_workers`` aligned parts.

        Returns:
            A one-shot ``zip`` iterator of ``(matrices_part, ntokens_part)``.

        Raises:
            ValueError: if the two splits do not line up (e.g. the inputs
                have different lengths).
        """
        splitted_adj_matricies = np.array_split(adj_matricies, num_of_workers)
        splitted_ntokens = np.array_split(ntokens_array, num_of_workers)
        if any(len(m) != len(n) for m, n in zip(splitted_adj_matricies, splitted_ntokens)):
            raise ValueError("Split is not valid!")
        return zip(splitted_adj_matricies, splitted_ntokens)

    def save_matrices(self, adj_matrix, file_index):
        """Save ``adj_matrix`` as ``<attention_file_prefix>_<file_index + 1>.npy``.

        The file goes to ``config['outdir']``. An existing file is left
        untouched. Either way, the bare filename (not the full path) is
        returned.
        """
        filename = f"{self.config.get('attention_file_prefix')}_{int(file_index) + 1}.npy"
        filepath = os.path.join(self.config.get('outdir'), filename)
        if not os.path.isfile(filepath):
            np.save(filepath, adj_matrix)
        return filename

    def calculate_and_save(self, adj_filenames, ntokens_array, stats_file_prefix):
        """Compute stats for every file in parallel and save them as one array.

        Files are processed batch by batch. ``Pool.map`` preserves input
        order, so results are concatenated along axis 3 in the same order
        as ``adj_filenames``. The output goes to
        ``f"{config['out_file_prefix']}.npy"``.

        Args:
            adj_filenames: names of ``.npy`` files inside ``config['att_mat_dir']``.
            ntokens_array: indexable per-sample token counts (see ``process_file``).
            stats_file_prefix: currently unused; the output path comes from
                ``config['out_file_prefix']``. NOTE(review): confirm intent.

        Raises:
            ValueError: if ``adj_filenames`` is empty.
        """
        # len() rather than truthiness so NumPy arrays of names also work.
        if len(adj_filenames) == 0:
            raise ValueError("adj_filenames is empty: nothing to compute")
        batch_results_list = []
        with Pool(self.num_of_workers) as pool:
            batches = self.split_into_batches(adj_filenames)
            for batch in tqdm(batches, desc='Processing batches'):
                args_list = [(filename, ntokens_array) for filename in batch]
                batch_results_list.extend(pool.map(self.process_file, args_list))
        stats_tuple_lists_array = np.concatenate(batch_results_list, axis=3)
        print(f"stats_tuple_lists_array.shape: {stats_tuple_lists_array.shape}")

        np.save(f"{self.config.get('out_file_prefix')}.npy", stats_tuple_lists_array)

    def calculate_and_save_per_batch(self, adj_filenames, ntokens_array, stats_file_prefix):
        """Same as ``calculate_and_save``; the body was previously an exact copy.

        NOTE(review): despite the name, results are not saved per batch --
        everything is concatenated and written once. Confirm intent.
        """
        self.calculate_and_save(adj_filenames, ntokens_array, stats_file_prefix)




# Layout of the saved stats array: Layers x Heads x Features x Samples x Thresholds (e.g. 6 x 20 x 6 x 14000 x 6)