from typing import Optional, Dict, Any
from pathlib import Path
import json

import torch
import numpy as np

from patching_gemma import logger

def prune_heads(prune_using_imp_scores, prune_k, 
          num_layers, num_heads, num_token_types,
          log_dir, is_split_by_token_type) -> Optional[Dict[int, Dict[str, Any]]]:
    """Select the attention components that survive iterative top-k pruning.

    Components are enumerated layer-major, then by head, then (only when
    ``is_split_by_token_type`` is truthy) by token type.  For every
    ``(path, k)`` pair, ``<path>/attn_scores.npy`` is loaded, flattened and
    matched positionally against that enumeration; of the components still
    alive, only the ``k`` highest-scoring ones are kept.  Stages are applied
    in order, so each stage can only shrink the surviving set.

    Args:
        prune_using_imp_scores: Iterable of directories that each contain an
            ``attn_scores.npy`` file, or ``None`` to disable pruning.
        prune_k: Strictly descending keep-counts, one per score directory.
            A ``k`` larger than the number of surviving components leaves the
            set untouched for that stage; ``k <= 0`` prunes everything.
        num_layers: Number of layers to enumerate.
        num_heads: Number of heads per layer.
        num_token_types: Number of token types per head (only used when
            ``is_split_by_token_type`` is truthy).
        log_dir: Directory (created if missing) that receives
            ``important_components_by_layer.json``.
        is_split_by_token_type: Whether components are
            ``(layer, head, token_type)`` triples instead of
            ``(layer, head)`` pairs.

    Returns:
        ``None`` when ``prune_using_imp_scores`` is ``None``.  Otherwise a
        dict ``{layer: {"attn": kept}}`` where ``kept`` is a set of head
        indices, or of ``(head, token_type)`` tuples in split mode.

    Raises:
        AssertionError: If the arguments are inconsistent or the number of
            surviving components differs from ``min(prune_k)``.
        ValueError: If a score file does not hold exactly one score per
            component.
    """
    if prune_using_imp_scores is None:
        return None

    assert prune_k is not None, "if you want to prune, provide prune_k also"
    assert len(prune_using_imp_scores) == len(prune_k), "provide as many k as the imp scores to prune according to"
    assert len(prune_k) < 2 or all(prune_k[i - 1] > prune_k[i] for i in range(1, len(prune_k))), f"prune_k must be descending array, not it is {prune_k}"

    if is_split_by_token_type:
        component_names = [("attn", layer, head, token_type) for layer in range(num_layers) for head in range(num_heads) for token_type in range(num_token_types)]
    else:
        component_names = [("attn", layer, head) for layer in range(num_layers) for head in range(num_heads)]

    # 1.0 marks a component that is still alive, 0.0 one that has been pruned.
    important_component_inds = torch.ones(len(component_names))
    for imp_score_iter, (imp_scores_path, k) in enumerate(zip(prune_using_imp_scores, prune_k)):
        attn_imp_scores_for_pruning = np.load(str(Path(imp_scores_path).joinpath("attn_scores.npy")))
        all_scores = attn_imp_scores_for_pruning.reshape(-1)
        if all_scores.size != len(component_names):
            # A mismatching file would otherwise index out of range or
            # silently rank only a subset of the components.
            raise ValueError(
                f"{imp_scores_path}: attn_scores.npy holds {all_scores.size} scores "
                f"but {len(component_names)} components are expected"
            )

        if k <= 0:
            # Keeping nothing: no cumulative count can ever equal k reliably.
            important_component_inds.zero_()
            continue
        if k > important_component_inds.sum().item():
            # Fewer than k components are alive, so there is nothing to drop.
            # Without this guard the equality mask below is all zeros, argmax
            # returns 0 and everything but the top-ranked entry is pruned.
            continue

        # Flat component indices ordered from highest to lowest score.
        all_scores_inds_sorted = torch.from_numpy(np.argsort(-all_scores))
        # Running count of alive components while walking down the ranking.
        alive_so_far = torch.cumsum(important_component_inds[all_scores_inds_sorted], dim=0)
        # First rank position at which exactly k alive components were seen.
        pruning_ind = torch.argmax((alive_so_far == k).to(torch.int32)).item()
        if pruning_ind != k - 1:
            logger.debug(f"pruning_ind={pruning_ind}, k={k} for iter={imp_score_iter}")
        important_component_inds[all_scores_inds_sorted[pruning_ind + 1:]] = 0

    if len(prune_k) > 0:
        assert min(prune_k) == important_component_inds.sum().item(), f"something went wrong during pruning, should have left {min(prune_k)} actually left {important_component_inds.sum()}"
    else:
        assert important_component_inds.sum().item() == len(component_names), f"something went wrong during pruning, should have left {len(component_names)} actually left {important_component_inds.sum()}"

    important_components = [
        component
        for component, is_alive in zip(component_names, important_component_inds.tolist())
        if is_alive
    ]

    # Group the survivors per layer in a single pass; every layer gets an
    # entry even when all of its components were pruned.
    important_components_by_layer = {layer: {"attn": set()} for layer in range(num_layers)}
    for component in important_components:
        kept = (component[2], component[3]) if is_split_by_token_type else component[2]
        important_components_by_layer[component[1]]["attn"].add(kept)

    Path(log_dir).mkdir(parents=True, exist_ok=True)
    with open(Path(log_dir).joinpath("important_components_by_layer.json"), "w", encoding="utf-8") as file:
        # Sets are not JSON-serialisable; sorting keeps the file deterministic.
        json.dump({layer: {"attn": sorted(important_components_by_layer[layer]["attn"])} for layer in range(num_layers)},
                  file, indent=4)
    return important_components_by_layer