import argparse
import os

import torch
torch.set_grad_enabled(False)
from utils import *

# Cut-off for the vocabulary stage: identification is reported as successful
# only when the rounded log10 p-value is strictly below this value.
LOG10P_THRESHOLD = -23

parser = argparse.ArgumentParser()

# Locations of the two models being compared; both must be supplied.
for _model_flag in ("--model_A_dir", "--model_B_dir"):
    parser.add_argument(_model_flag, type=str, required=True)

# Directory that receives the generated PNG plots.
parser.add_argument("--output_dir", type=str, default="./out")

# Integer-valued options as (flag, default), registered in this order.
_INT_OPTIONS = (
    ("--solve_mlp", 1),      # non-zero: run the per-layer MLP stage
    ("--solve_attn", 1),     # non-zero: run the per-layer attention stage
    ("--solve_mla", 0),      # non-zero: run the per-layer MLA (o_proj) stage
    ("--num_layers", 32),    # number of layers each stage iterates over
    ("--plot_all", 1),       # non-zero: write per-layer plots
    ("--plot_full", 0),      # forwarded to plot_matrix as `plot_full`
    ("--head_size", 128),    # forwarded to reconstruct_permutation as `bs`
    ("--mlp_heuristic", 1),  # non-zero: use the heuristic assignment solver
)
for _flag, _default in _INT_OPTIONS:
    parser.add_argument(_flag, type=int, default=_default)

# Layer indices for which the MLP stage is allowed to write a plot.
plot_list = [*range(100)]

# try: python main_mdir.py --model_A_dir <your_model_A> --model_B_dir <your_model_B> --num_layers 10
args = parser.parse_args()

# Short display names used as plot axis labels. normpath drops a trailing
# separator first, so "path/to/model/" yields "model" rather than "".
A_name = os.path.basename(os.path.normpath(args.model_A_dir))
B_name = os.path.basename(os.path.normpath(args.model_B_dir))

# Make sure the plot destination exists before anything is written to it
# (harmless if plot_matrix already creates it).
os.makedirs(args.output_dir, exist_ok=True)

# Vocabulary stage. vocab() comes from utils; its outputs are used below as a
# matrix (UU), a trace score (tr), matched row/column indices and a log10
# p-value. NOTE(review): exact definitions live in utils -- confirm there.
UU, tr, row_ind, col_ind, logp, log10p = vocab(args.model_A_dir, args.model_B_dir)
# Clamp at 0 and round to an integer exponent for display and thresholding.
log10p = round(min(log10p, 0))
comment = f" (Trace {tr:.2f}, P value around 10^{log10p})"
print(comment, ", saving ...")
plot_matrix(UU, os.path.join(args.output_dir, "vocab.png"), 
            comments=comment, row=B_name, column=A_name, plot_full=args.plot_full)

if log10p < LOG10P_THRESHOLD:
    print("Identification successful!")
    # Replace UU by a 0/1 matrix with ones exactly at the matched positions.
    U = torch.zeros_like(UU)
    U[row_ind, col_ind] = 1
else:
    print("MDIR identification unsuccessful. Model A and Model B are probably not homologous, please try other methods.")
    # Later stages still run; they use the raw matrix instead.
    U = UU

# MLP stage: for every layer, align the up-projection weights of both models
# through U, solve an assignment problem on the result and report a trace
# score with an approximate p-value.
if args.solve_mlp:
    # The intermediate dimension is very large, hence the heuristic solver
    # option; choose the solver once for all layers.
    solve_assignment = (
        linear_assignment_max_heuristic if args.mlp_heuristic else linear_assignment_max
    )
    for layer in range(args.num_layers):
        W = {}
        ortho_sum = None
        for proj_type in ('up',):
            # Two candidate names handed to read_alias -- presumably it picks
            # whichever exists in the checkpoint (verify in utils).
            # (Alternative naming scheme: f"blocks.{layer}.ffn.key.weight".)
            tensor_name = [
                f"model.layers.{layer}.mlp.{proj_type}_proj.weight",
                f"model.layers.{layer}.mlp.shared_expert.{proj_type}_proj.weight",
            ]
            print(tensor_name)
            weight_A = read_alias(args.model_A_dir, tensor_name)
            weight_B = read_alias(args.model_B_dir, tensor_name)
            if proj_type == 'down':
                # Down projections are compared in transposed orientation.
                weight_A, weight_B = weight_A.T, weight_B.T
            aligned = polarize(weight_B @ U @ weight_A.T).T
            W[proj_type] = aligned
            if ortho_sum is None:
                ortho_sum = aligned.clone()
            else:
                ortho_sum = ortho_sum + aligned

        P, row_selections, col_selections = solve_assignment(ortho_sum)

        # Score: sum of the 'up' entries at the selected positions.
        matched_entries = W['up'][row_selections, col_selections]
        tr = float(matched_entries.sum())
        logp = - tr**2 / 2 + lognfactorial(max(P.shape))
        # Convert to base 10, clamp at 0 and round to an integer exponent.
        log10p = round(min(logp / math.log(10), 0))
        comment = f" (Trace {tr:.2f}, P value around 10^{log10p})"
        print(comment, ", saving ...")

        wants_plot = args.plot_all and layer in plot_list
        if wants_plot:
            plot_path = os.path.join(args.output_dir, f"model.layers.{layer}.mlp.png")
            plot_matrix(W['up'].T, plot_path,
                        comments=comment, row=B_name, column=A_name, plot_full=args.plot_full)

# Attention stage: for every layer, recover the permutations relating the
# v/k/q/o projections of both models and warn when k, q or o disagree with
# the permutations obtained from v.
if args.solve_attn:
    for layer in range(args.num_layers):
        # 'v' is processed first because the k/q/o checks compare against the
        # permutations it yields.
        # (Alternative naming scheme: 'value', 'key', 'receptance', 'output'
        #  with f"blocks.{layer}.att.{proj_type}.weight".)
        for proj_type in ('v', 'k', 'q', 'o'):
            tensor_name = f"model.layers.{layer}.self_attn.{proj_type}_proj.weight"
            weight_A = read_tensor(args.model_A_dir, tensor_name)
            weight_B = read_tensor(args.model_B_dir, tensor_name)
            if proj_type in ('o', 'output'):
                # Output projections are compared in transposed orientation.
                weight_A, weight_B = weight_A.T, weight_B.T
            W = polarize(weight_A @ U.T @ weight_B.T)

            if proj_type in ('v', 'value'):
                # Reference permutations for this layer.
                perm_nkvhead, perm_headsize = reconstruct_permutation(W, bs=args.head_size)
                report = {
                    "layer": layer,
                    "tensor_name": tensor_name,
                    "perm_1": perm_nkvhead,
                    "perm_2": perm_headsize,
                }
                print(report)
                if args.plot_all:
                    plot_path = os.path.join(args.output_dir, f"{tensor_name}.png")
                    plot_matrix(W.T, plot_path,
                                row=B_name, column=A_name, plot_full=args.plot_full)
                continue

            perm_1, perm_2 = reconstruct_permutation(W, bs=args.head_size)

            # Work out which pair of permutations this projection should show.
            if proj_type == 'k':
                expected_1 = perm_nkvhead
                expected_2 = list(range(args.head_size))
            elif proj_type in ('q', 'o'):
                expand_ratio = W.shape[0] // (len(perm_nkvhead) * len(perm_headsize))
                expected_1 = tensorprod_permlist(perm_nkvhead, list(range(expand_ratio)))
                if proj_type == 'o':
                    expected_2 = perm_headsize
                else:
                    expected_2 = list(range(args.head_size))
            else:
                # No consistency check is defined for other projection types.
                continue

            if not (perm_1 == expected_1 and perm_2 == expected_2):
                print(f"warning: layer {layer} matrix {proj_type}")

# MLA stage: for every layer, align the (transposed) o_proj weights of both
# models through U, solve an assignment problem on the result and report a
# trace score with an approximate p-value. No per-head permutation
# consistency checks are performed in this stage.
if args.solve_mla:
    for layer in range(args.num_layers):
        tensor_name = f"model.layers.{layer}.self_attn.o_proj.weight"
        # The output projection is compared in transposed orientation.
        weight_A = read_tensor(args.model_A_dir, tensor_name).T
        weight_B = read_tensor(args.model_B_dir, tensor_name).T
        W = polarize(weight_A @ U.T @ weight_B.T)

        P, row_selections, col_selections = linear_assignment_max(W)
        matched_entries = W[row_selections, col_selections]
        tr = float(matched_entries.sum())
        logp = - tr**2 / 2 + lognfactorial(max(P.shape))
        # Convert to base 10, clamp at 0 and round to an integer exponent.
        log10p = round(min(logp / math.log(10), 0))
        comment = f" (Trace {tr:.2f}, P value around 10^{log10p})"
        print(comment, ", saving ...")

        if not args.plot_all:
            continue
        plot_path = os.path.join(args.output_dir, f"{tensor_name}.png")
        plot_matrix(W.T, plot_path,
                    comments=comment, row=B_name, column=A_name, plot_full=args.plot_full)
