from typing import Optional, Dict, Any
from pathlib import Path
import json
from collections import defaultdict, deque

import torch
import numpy as np

from patching_gemma import logger

def prune_edges_inside_attn(prune_using_imp_scores, prune_k,
                            edges_to_prune_for_sure,
                            num_layers, num_heads, num_token_types,
                            log_dir, is_split_by_token_type) -> Optional[Dict[str, torch.Tensor]]:
    """Iteratively keep only the top-``k`` scored "inside_attn" edges.

    The function starts from a mask in which every edge is kept. For each
    ``(scores_dir, k)`` pair it loads ``<scores_dir>/inside_attn.npy``, ranks
    all edges by descending score, and drops every edge ranked after the
    ``k``-th edge that is *still kept* from earlier rounds. The final mask is
    written as a JSON list of index tuples to
    ``<log_dir>/important_edges/inside_attn.json`` and returned.

    Args:
        prune_using_imp_scores: Sequence of directories, each containing an
            ``inside_attn.npy`` file of non-negative scores whose total size
            equals ``num_layers * num_heads * num_token_types ** 2``.
            The array is indexed with four indices, presumably
            (layer, head, token type, token type) -- TODO confirm axis order.
        prune_k: Strictly descending sequence with one ``k`` per entry of
            ``prune_using_imp_scores``.
        edges_to_prune_for_sure: Iterable of 4-element index tuples whose
            score is forced to 0 before ranking, or ``None`` for no forced
            edges. NOTE(review): a zeroed edge can still survive if ``k`` is
            large enough to reach the zero-scored edges.
        num_layers: Size of the first mask axis.
        num_heads: Size of the second mask axis.
        num_token_types: Size of the third and fourth mask axes.
        log_dir: Directory under which ``important_edges/inside_attn.json``
            is written (created if missing).
        is_split_by_token_type: Must be truthy; the other mode is not
            implemented.

    Returns:
        ``None`` if ``prune_k`` is empty; otherwise ``{"inside_attn": mask}``
        where ``mask`` is a tensor of shape
        ``(num_layers, num_heads, num_token_types, num_token_types)`` holding
        1 for kept edges and 0 for pruned ones.

    Raises:
        AssertionError: On inconsistent arguments, negative scores, a score
            array of the wrong size, or if the number of surviving edges
            differs from ``min(prune_k)``.
        NotImplementedError: If ``is_split_by_token_type`` is falsy.
    """
    assert prune_using_imp_scores is not None
    assert prune_k is not None, "if you want to prune, provide prune_k also"
    assert len(prune_using_imp_scores) == len(prune_k), "provide as many k as the imp scores to prune according to"
    if len(prune_k) == 0:
        return None
    assert all(prev > cur for prev, cur in zip(prune_k, prune_k[1:])), f"prune_k must be a strictly descending array, but it is {prune_k}"
    if not is_split_by_token_type:
        raise NotImplementedError("No implementation for pruning without split by token type")

    edge_shape = (num_layers, num_heads, num_token_types, num_token_types)
    # Flat 0/1 mask of surviving edges; built through torch so its dtype
    # follows torch's default floating dtype, as the returned tensor always did.
    keep_mask = torch.ones(edge_shape).reshape(-1).numpy()

    for imp_score_iter, (imp_scores_path, k) in enumerate(zip(prune_using_imp_scores, prune_k)):
        scores = np.load(str(Path(imp_scores_path) / "inside_attn.npy"))
        if edges_to_prune_for_sure is not None:
            for edge in edges_to_prune_for_sure:
                scores[edge[0], edge[1], edge[2], edge[3]] = 0.
        flat_scores = scores.reshape(-1)
        logger.info(f"On edge pruning iter {imp_score_iter} we have {(flat_scores == 0).sum()} edges with zero importance scores ({(flat_scores != 0).sum()} with non-zero)")
        assert (flat_scores >= 0).all()
        assert flat_scores.shape == keep_mask.shape

        # Stable sort so that ties (e.g. the many zero scores) are broken by
        # flat index and the selection is reproducible across platforms.
        order = np.argsort(-flat_scores, kind="stable")
        # reached_k[i] is True once the i best-ranked positions contain at
        # least k edges that are still kept.
        reached_k = np.cumsum(keep_mask[order], axis=0) >= k
        if k <= 0:
            # Keeping the "top 0" edges means pruning everything.
            num_ranked_kept = 0
        elif reached_k.any():
            num_ranked_kept = int(np.argmax(reached_k)) + 1
        else:
            # Fewer than k edges are left: np.argmax on an all-False array
            # would return 0 and wipe out almost everything, so prune nothing.
            num_ranked_kept = order.shape[0]
        pruning_ind = num_ranked_kept - 1
        if pruning_ind != k - 1:
            logger.debug(f"pruning_ind={pruning_ind}, k={k} for iter={imp_score_iter}")
        keep_mask[order[num_ranked_kept:]] = 0

        kept = keep_mask.astype(bool)
        logger.info(f"After edge pruning iter {imp_score_iter} we have {(flat_scores[kept] == 0).sum()} edges with 0 importance (out of {flat_scores.shape[0]} total, {keep_mask.sum()} important)")
        logger.info(f"After edge pruning iter {imp_score_iter} we have {(flat_scores[kept] != 0).sum()} edges with non-zero importance (out of {flat_scores.shape[0]} total, {keep_mask.sum()} important)")

    assert min(prune_k) == keep_mask.sum(), f"something went wrong during pruning, should have left {min(prune_k)} actually left {keep_mask.sum()}"
    important_edge_inds = torch.tensor(keep_mask.reshape(edge_shape))

    out_dir = Path(log_dir) / "important_edges"
    out_dir.mkdir(parents=True, exist_ok=True)
    with open(out_dir / "inside_attn.json", "w", encoding="utf-8") as file:
        json.dump((important_edge_inds == 1).nonzero(as_tuple=False).tolist(), file, indent=4)
    return {"inside_attn": important_edge_inds}