import os
import pdb
import random

from typing import Dict, List, Tuple, Any, Optional, Union, Callable
import argparse

from tqdm import tqdm
import numpy as np
import torch

import sys
sys.path.append("../pp_experiment")
from utils import compute_topk_components


def get_circuit_top_heads(circuit_root_path: str, n_groupA=20, n_groupB=20, n_groupC=30, n_groupD=20):
    """Select components for each of the four circuit groups under a directory.

    For every group letter ``X`` in ``A``..``D`` this reads
    ``{circuit_root_path}/pp_groupX.npy``, wraps the array in a tensor and
    hands it to ``compute_topk_components`` with ``largest=False`` and that
    group's ``k``.

    Args:
        circuit_root_path: Directory containing the ``pp_group*.npy`` files.
        n_groupA: ``k`` passed for group A.
        n_groupB: ``k`` passed for group B.
        n_groupC: ``k`` passed for group C.
        n_groupD: ``k`` passed for group D.

    Returns:
        A dict keyed by ``"A"``, ``"B"``, ``"C"``, ``"D"`` (in that order)
        whose values are whatever ``compute_topk_components`` returns.
    """
    # NOTE(review): largest=False presumably picks the k lowest-scoring
    # entries -- confirm against utils.compute_topk_components.
    requested = (("A", n_groupA), ("B", n_groupB), ("C", n_groupC), ("D", n_groupD))

    selected = {}
    for letter, k in requested:
        scores = torch.tensor(np.load(f"{circuit_root_path}/pp_group{letter}.npy"))
        selected[letter] = compute_topk_components(scores, k=k, largest=False)
    return selected


def compare(path1: str, path2: str, n_groupA=20, n_groupB=10, n_groupC=30, n_groupD=30, print_heads: bool = True):
    """Print, per group, how many selected heads two circuit directories share.

    Both directories are loaded through ``get_circuit_top_heads`` with the
    same per-group sizes. For each group one summary line is printed with
    the count of entries present in both selections, the size of the first
    selection, and the corresponding percentage.

    Args:
        path1: Directory of the first circuit (reported as "C1").
        path2: Directory of the second circuit (reported as "C2").
        n_groupA: Number of components requested for group A.
        n_groupB: Number of components requested for group B.
        n_groupC: Number of components requested for group C.
        n_groupD: Number of components requested for group D.
        print_heads: When true, also print the entries found in only one
            of the two selections.

    Returns:
        None. All results are written to stdout.
    """
    c1 = get_circuit_top_heads(path1, n_groupA, n_groupB, n_groupC, n_groupD)
    c2 = get_circuit_top_heads(path2, n_groupA, n_groupB, n_groupC, n_groupD)

    for group, heads1 in c1.items():
        heads2 = c2[group]
        # Membership is tested directly on the sequences rather than on sets:
        # the element type comes from compute_topk_components and may be
        # unhashable (e.g. [layer, head] lists) -- TODO confirm.
        common_heads = [h for h in heads1 if h in heads2]

        total = len(heads1)
        # An empty selection (e.g. a group size of 0) would otherwise raise
        # ZeroDivisionError; report 0.0% overlap in that case.
        overlap_pct = len(common_heads) / total * 100 if total else 0.0
        print(f"Group={group}, common_heads={len(common_heads)}/{total} ({overlap_pct:.1f}%)")

        if print_heads:
            # Only build the difference lists when they are actually shown.
            c1_only_heads = [h for h in heads1 if h not in heads2]
            c2_only_heads = [h for h in heads2 if h not in heads1]
            print(f"C1 only heads: {c1_only_heads}")
            print(f"C2 only heads: {c2_only_heads}")
        
if __name__ == '__main__':
    # The two result directories being compared; they differ only in the
    # metric subdirectory (logp_notLastObj vs. logp_lastObjOnly).
    not_last_obj_dir = "../outputs/nnsight_patch_1put/codellama-13b/logp_notLastObj/n200"
    last_obj_only_dir = "../outputs/nnsight_patch_1put/codellama-13b/logp_lastObjOnly/n200"
    # Unused alternative results directory:
    #   "../outputs/nnsight_patch_noop/codellama-13b/logp/n200"

    # Per-group selection sizes for this run. A smaller setting tried
    # earlier was A=20, B=10, C=30, D=30.
    group_sizes = {"n_groupA": 120, "n_groupB": 60, "n_groupC": 80, "n_groupD": 80}

    # Summary lines only; the per-head listings are suppressed.
    compare(not_last_obj_dir, last_obj_only_dir, print_heads=False, **group_sizes)

