import pdb
import json
import math
import random
import argparse
from collections import defaultdict


import numpy as np
import torch

import sys

sys.path.append("../pp_experiment")
from utils import get_model_and_tokenizer, load_dataloader, fix_random_seed, get_mean_activations, get_circuit, get_head_significance_score, compute_pair_drop_values
from run_patching import build_parser, post_arg_parse_fix, get_model_and_dataset

# Torch device for this script: CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Head-group name for each loop index used in minimality_main. The index is the
# position in the head-group list iterated there, so this order must match it.
idx_to_group = dict(
    enumerate(("struct_reader", "pos_transmitter", "pos_detector", "value_fetcher"))
)
# Relative position passed as ``rel_pos`` to the scoring helpers for each group index.
idx_to_pos = dict(enumerate((-1, 0, 2, 0)))
# Accumulator filled by minimality_main: group name -> list of two-int head
# identifiers parsed from the significance-score keys.
minimal_circuit = defaultdict(list)


# Architectures whose stringified head keys carry the layer index in "."-field 2;
# every other architecture uses field 4 (see _parse_head_key).
_SHORT_KEY_ARCHITECTURES = ("LlamaForCausalLM", "Gemma2ForCausalLM")


def _resolve_output_dir(args):
    """Return the directory where the per-group pair-drop JSON files are written.

    The original code read ``args.ouptput_dir`` (misspelled). The correctly spelled
    ``output_dir`` is tried first and the misspelled attribute is kept as a fallback.
    Which one ``build_parser`` actually defines is not visible from this file.

    Raises:
        AttributeError: if neither attribute is set on ``args``.
    """
    output_dir = getattr(args, "output_dir", None) or getattr(args, "ouptput_dir", None)
    if not output_dir:
        raise AttributeError(
            "minimality_main needs an output directory: set args.output_dir"
        )
    return output_dir


def _resolve_model_name(args, model):
    """Return a filename-safe model name used to name the final circuit JSON file."""
    # NOTE(review): assumes build_parser may define args.model_name; otherwise fall
    # back to the config's name_or_path, then to the architecture name.
    name = (
        getattr(args, "model_name", None)
        or getattr(model.config, "name_or_path", None)
        or model.config.architectures[0]
    )
    # Hub ids like "org/model" would otherwise introduce a sub-directory in the path.
    return str(name).rstrip("/").rsplit("/", 1)[-1]


def _rank_pairs(data):
    """For every key k_1, rank each k_2 by ``data[k_2][k_2] - data[k_1][k_2]``, descending."""
    ranked = defaultdict(list)
    for k_1, row in data.items():
        for k_2 in row:
            ranked[k_1].append((k_2, data[k_2][k_2] - row[k_2]))
    for pairs in ranked.values():
        pairs.sort(key=(lambda x: x[1]), reverse=True)
    return ranked


def _parse_head_key(key, layer_field):
    """Turn a stringified significance key into a two-int ``[layer, head]`` list.

    The first value is the ``layer_field``-th "."-separated field of ``key``. The second
    is the second ","-separated field with its first and last characters stripped.
    Presumably keys look like ``"('<module path>', <head>)"``; TODO confirm.
    """
    return [int(key.split(".")[layer_field]), int(key.split(",")[1][1:-1])]


def minimality_main(args):
    """
    Compute minimality scores for every head group and save the minimal circuit.

    For each head group (structure reader, position transmitter, position detector,
    value fetcher), in that order, this:
      1. computes pairwise drop values with ``compute_pair_drop_values`` and dumps
         them to ``<output_dir>/<group>.json``;
      2. ranks, for every key k_1, the keys k_2 by
         ``data[k_2][k_2] - data[k_1][k_2]`` (descending);
      3. scores head significance on the top ``args.percentage`` fraction of the
         group and dumps it to ``<results_path>/<group>_<k>_significance.json``,
         where ``k = ceil(percentage * len(group))``;
      4. keeps heads whose ``score[0] / score[1] - 1`` is at least
         ``args.minimality_threshold``.
    The kept heads, grouped by name in the module-level ``minimal_circuit``, are
    written to ``<results_path>/<model_name>_circuit.json``.

    Args:
        args (argparse.Namespace): parsed CLI arguments. Reads ``circuit_root_path``,
            ``percentage``, ``minimality_threshold`` and ``n_groupA``..``n_groupD``.
            It also reads the output location: ``output_dir`` (or the legacy misspelled
            ``ouptput_dir``), plus ``results_path`` if present, which otherwise
            defaults to the output dir. The whole namespace is forwarded to
            ``get_model_and_dataset`` and ``get_mean_activations``.

    Raises:
        AttributeError: if no output directory is configured on ``args``.
    """
    import os  # local import: only needed for path handling / directory creation

    # Resolve and create output locations before the expensive work so a bad
    # configuration fails immediately instead of after the first head group.
    output_dir = _resolve_output_dir(args)
    # NOTE(review): results_path was previously an undefined name; honour
    # args.results_path if build_parser defines it, else share output_dir.
    results_path = getattr(args, "results_path", None) or output_dir
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(results_path, exist_ok=True)

    percentage = args.percentage
    minimality_threshold = args.minimality_threshold

    dataloader, dataset, model = get_model_and_dataset(args)

    # mean activation data also needs to be loaded filtered by operation orders
    mean_activations, modules = get_mean_activations(model=model, args=args)

    (
        circuit_components,
        value_fetcher_heads,
        pos_trans_heads,
        pos_detect_heads,
        struct_read_heads,
    ) = get_circuit(
        model=model,
        circuit_root_path=args.circuit_root_path,
        n_value_fetcher=args.n_groupA,
        n_pos_trans=args.n_groupB,
        n_pos_detect=args.n_groupC,
        n_struct_read=args.n_groupD,
    )

    print(f"Value Fetcher Heads: {len(value_fetcher_heads)}")
    print(f"Position Transmitter Heads: {len(pos_trans_heads)}")
    print(f"Position Detector Heads: {len(pos_detect_heads)}")
    print(f"Structure Reader Heads: {len(struct_read_heads)}")

    model_name = _resolve_model_name(args, model)
    # Where the layer index sits in a head key depends on the architecture.
    # (The original called str.isin, which does not exist; use membership.)
    layer_field = 2 if model.config.architectures[0] in _SHORT_KEY_ARCHITECTURES else 4

    print("Started Computing Minimality Scores...")
    # Order must match idx_to_group / idx_to_pos.
    head_groups = [struct_read_heads, pos_trans_heads, pos_detect_heads, value_fetcher_heads]
    for idx, head_group in enumerate(head_groups):
        group_name = idx_to_group[idx]
        rel_pos = idx_to_pos[idx]
        print(f"{group_name} Heads Started...")
        data = compute_pair_drop_values(
            model=model,
            heads=head_group,
            circuit_components=circuit_components,
            dataloader=dataloader,
            modules=modules,
            mean_activations=mean_activations,
            rel_pos=rel_pos,
        )
        with open(os.path.join(output_dir, f"{group_name}.json"), "w+", encoding="utf-8") as f:
            json.dump(data, f)

        ranked = _rank_pairs(data)

        res = get_head_significance_score(
            model=model,
            heads=head_group,
            ranked=ranked,
            percentage=percentage,
            circuit_components=circuit_components,
            dataloader=dataloader,
            modules=modules,
            mean_activations=mean_activations,
            rel_pos=rel_pos,
        )
        # Stringify keys: json.dump only accepts str/int/float/bool/None keys.
        new_res = {str(key): value for key, value in res.items()}

        top_k = math.ceil(percentage * len(head_group))
        significance_path = os.path.join(results_path, f"{group_name}_{top_k}_significance.json")
        with open(significance_path, "w+", encoding="utf-8") as f:
            json.dump(new_res, f)

        print(f"{group_name} Heads Completed...")

        # Keep heads whose relative score is at least the minimality threshold.
        for key, score in new_res.items():
            if score[0] / score[1] - 1 >= minimality_threshold:
                minimal_circuit[group_name].append(_parse_head_key(key, layer_field))

    with open(os.path.join(results_path, f"{model_name}_circuit.json"), "w", encoding="utf-8") as f:
        json.dump(minimal_circuit, f)

    print("Minimal Circuit Computed...")


def add_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register the minimality-specific command-line options on ``parser``.

    Options added (with defaults):
        --circuit_root_path (str, "../outputs/nnsight_patch_no_op/gemma-2-2b/n200"):
            directory holding the circuit component files passed to ``get_circuit``.
        --percentage (float, 0.3): fraction of each head group used when scoring
            head significance.
        --minimality_threshold (float, 0.01): a head is kept in the minimal circuit
            when its relative score ``score[0] / score[1] - 1`` is at least this value.

    Args:
        parser: the parser to extend (mutated in place).

    Returns:
        The same parser, to allow chaining.
    """
    parser.add_argument(
        '--circuit_root_path',
        help='where circuit info dir lives',
        type=str,
        default="../outputs/nnsight_patch_no_op/gemma-2-2b/n200",
    )
    parser.add_argument(
        '--percentage',
        help='top percentage of heads to form each head group subset',
        type=float,
        default=0.3,
    )
    parser.add_argument(
        '--minimality_threshold',
        help='minimum relative score (score[0] / score[1] - 1) for a head to be kept in the minimal circuit',
        type=float,
        default=0.01,
    )
    return parser

if __name__ == "__main__":
    # Shared patching CLI from run_patching, extended with the minimality options.
    parser = add_args(build_parser())
    args = parser.parse_args()
    print(f"ARGS: {args}")

    # Project helper applied to the parsed namespace before use (presumably
    # fills/normalises derived arguments; see run_patching).
    post_arg_parse_fix(args)
    fix_random_seed(args.seed)

    minimality_main(args)