import torch
from tqdm import tqdm
import argparse
import os

# Command-line interface.  --err_dir and --yerr_dir are interpolated into
# paths unconditionally further down, so they are required: leaving them
# out used to make the script read from / write to a literal "None/..."
# directory.  --x_path and --w_path are only read when a cached error-norm
# file is missing, so they stay optional.
parser = argparse.ArgumentParser(
    description=(
        "Turn a saved threshold array into per-layer, per-module "
        "(low precision, high precision, target) files."
    )
)
parser.add_argument(
    "--arr_path", type=str, required=True,
    help="torch-saved file holding the threshold array, either as a "
         "2-tuple (thresholds, max-mem dict) or as a 3-tuple whose first "
         "element is the threshold array",
)
parser.add_argument(
    "--x_path", type=str,
    help="directory with <layer>_<name>_x_t_seq512_0-40.pt input files "
         "(only read when an error-norm cache file is missing)",
)
parser.add_argument(
    "--w_path", type=str,
    help="directory with <bits>b/<layer>_<name>.pt weight files "
         "(only read when an error-norm cache file is missing)",
)
parser.add_argument(
    "--err_dir", type=str, required=True,
    help="directory containing <arr name>/corr_arr_<rsq_th>.pt",
)
parser.add_argument(
    "--yerr_dir", type=str, required=True,
    help="output directory for cached error norms and target files",
)

args = parser.parse_args()

# Number of layers the script iterates over.
layer_count = 32
# Precision levels, lowest to highest; also used as "<p>b" weight
# sub-directory names.
prec_arr = [3, 4, 5, 6]

# Module names, kept in two groups (presumably self-attention and MLP
# projections) and concatenated into the default module list.
module_sa = ["q_proj", "k_proj", "v_proj", "o_proj"]
module_mlp = ["gate_proj", "up_proj", "down_proj"]
module_arr = module_sa + module_mlp

arr_path = args.arr_path

# NOTE(review): weights_only=False unpickles arbitrary objects, so
# --arr_path must only ever point at files from a trusted source.
loaded_arr = torch.load(arr_path, weights_only=False)
try:
    # Preferred layout: (threshold array, max-memory dict).
    th_arr, max_mem_dict = loaded_arr
except ValueError:
    # Three-element layout: only the first element (the threshold array)
    # is used, and every (layer, module) pair is assigned the highest
    # precision as its maximum.  Only the unpacking error is caught here;
    # any other failure (missing file, corrupt pickle, Ctrl-C) propagates.
    th_arr, _, _ = loaded_arr
    max_mem_dict = {
        (layer, name): prec_arr[-1]
        for layer in range(layer_count)
        for name in module_arr
    }

# Short module-level aliases for the remaining command-line paths.
x_path, w_path = args.x_path, args.w_path
err_dir, yerr_dir = args.err_dir, args.yerr_dir
# Threshold value that is embedded in the corr_arr file name.
rsq_th = 0.9


# One cache sub-directory "<low>-<high>" for every pair of precisions with
# low taken from prec_arr (except its last entry) and high running up to
# the last entry of prec_arr.
top_prec = prec_arr[-1]
for low in prec_arr[:-1]:
    for high in range(low + 1, top_prec + 1):
        os.makedirs(f"{yerr_dir}/{low}-{high}/", exist_ok=True)

# Run name: the last "/"-separated component of arr_path (or the one before
# it when arr_path ends in "/"), cut at the first ".pt".
path_parts = arr_path.split("/")
mid_dir = path_parts[-1] or path_parts[-2]
mid_dir = mid_dir.partition(".pt")[0]

# Load the run-specific corr_arr and keep its first two fields per entry.
# Each entry must unpack into exactly seven fields.
# NOTE(review): this rebinds module_arr from a list of module-name strings
# to a list of (layer, name) tuples; the main loop below uses each entry as
# `name` in dict keys and file names -- confirm that this is intended.
corr_arr = torch.load(f"{err_dir}/{mid_dir}/corr_arr_{rsq_th}.pt", weights_only=False)
module_arr = [(l, n) for (l, n, _, _, _, _, _) in corr_arr]

# Publish a copy of corr_arr directly under err_dir.
torch.save(corr_arr, f"{err_dir}/corr_arr_{rsq_th}.pt")

# Point <yerr_dir>/err at the run-specific error directory.  This is done
# with os calls rather than shell strings so that paths containing spaces
# or shell metacharacters work and failures raise instead of being ignored.
err_link = f"{yerr_dir}/err"
if os.path.islink(err_link) or os.path.isfile(err_link):
    os.remove(err_link)
# A symlink target is resolved relative to the link's own directory, so a
# relative --err_dir would produce a dangling link; store it absolute.
# If err_link is an existing real directory this raises FileExistsError
# instead of silently creating the link inside that directory.
os.symlink(os.path.abspath(f"{err_dir}/{mid_dir}"), err_link)

def _select_precisions(th, maxmem):
    """Map a raw threshold to ``(b_l, b_h, quantile)`` for one module.

    ``maxmem`` selects how many adjacent precision pairs from ``prec_arr``
    are available; ``th`` picks one of those pairs and is shifted/flipped
    into the value that is later passed to ``Tensor.quantile``.  For
    ``maxmem == 3`` both precisions are ``prec_arr[0]`` and ``th`` is
    returned unchanged.

    Raises:
        ValueError: if ``maxmem`` is not one of 3, 4, 5 or 6.  Previously
            such a value silently reused the precisions left over from the
            preceding iteration.
    """
    if maxmem == 6:
        if th < -0.5:
            return prec_arr[0], prec_arr[1], -th - 0.5
        if th < 0.5:
            return prec_arr[1], prec_arr[2], -th + 0.5
        return prec_arr[2], prec_arr[3], -th + 1.5
    if maxmem == 5:
        if th < 0:
            return prec_arr[0], prec_arr[1], -th
        return prec_arr[1], prec_arr[2], 1 - th
    if maxmem == 4:
        return prec_arr[0], prec_arr[1], 1 - th
    if maxmem == 3:
        return prec_arr[0], prec_arr[0], th
    raise ValueError(f"unsupported max precision {maxmem!r}; expected 3, 4, 5 or 6")


def _load_or_compute_err_norms(layer, name, b_l, b_h):
    """Return the flattened output-error norms for one module and pair.

    The result is cached at ``<yerr_dir>/<b_l>-<b_h>/<layer>_<name>_p.pt``.
    On a cache miss the inputs are loaded from ``x_path`` and the two
    weight versions from ``w_path``; the 2-norm along the last dimension
    of ``x @ (w_h - w_l).T`` is computed, flattened, saved and returned.
    """
    cache_path = f"{yerr_dir}/{b_l}-{b_h}/{layer}_{name}_p.pt"
    if os.path.isfile(cache_path):
        return torch.load(cache_path, weights_only=False)

    xarr = torch.load(f"{x_path}/{layer}_{name}_x_t_seq512_0-40.pt", weights_only=False)
    # The [:, :, :] slice copies nothing but fails if the concatenated
    # tensor has fewer than three dimensions.
    x = torch.concat(xarr)[:, :, :].float()

    cpu = torch.device("cpu")
    w_l = torch.load(f"{w_path}/{b_l}b/{layer}_{name}.pt", map_location=cpu, weights_only=False)
    w_h = torch.load(f"{w_path}/{b_h}b/{layer}_{name}.pt", map_location=cpu, weights_only=False)
    e = (w_h - w_l).float()

    p = (x @ e.T).norm(dim=-1).view(-1).float()
    torch.save(p, cache_path)
    return p


# For every layer and every module_arr entry, write a
# (low precision, high precision, target) tuple to
# <yerr_dir>/<layer>_<name>_targ.pt.
# NOTE(review): module_arr is rebound earlier in this file to (layer, name)
# tuples taken from corr_arr, yet each entry is used here as `name` (in the
# max_mem_dict key and in file names) -- confirm that this is intended.
n_modules = len(module_arr)
for layer in tqdm(range(layer_count)):
    for name in module_arr:
        # Flat position of this (layer, module) pair in th_arr.
        i = layer * n_modules + module_arr.index(name)
        maxmem = max_mem_dict[(layer, name)]
        b_l, b_h, th = _select_precisions(th_arr[i], maxmem)

        if b_l != b_h:
            with torch.no_grad():
                p = _load_or_compute_err_norms(layer, name, b_l, b_h)

            targ = p.quantile(th)
            # Near-extreme quantiles are saturated to +/- infinity.
            if th > 0.99:
                targ[()] = torch.inf
            if th < 0.01:
                targ[()] = -torch.inf
        else:
            # Single-precision case: the threshold itself is stored.
            targ = torch.tensor(th)

        torch.save((b_l, b_h, targ), f"{yerr_dir}/{layer}_{name}_targ.pt")

torch.save(max_mem_dict, f"{yerr_dir}/max_mem_dict.pt")
# `exit` is only injected by the site module; raising SystemExit directly
# ends the script with status 0 even without it.
raise SystemExit(0)
