import argparse
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.stats
import torch
from torch.optim import SGD
from tqdm import tqdm

# Seed torch's global RNG so the Gaussian projection matrices drawn later are
# reproducible across runs.
torch.random.manual_seed(0)

parser = argparse.ArgumentParser(
    description="Per module, regress the output-error norm between two "
    "precisions on the input norm, and fit a JL sketch of the weight "
    "difference for modules without a good linear fit."
)
parser.add_argument(
    "--arr_path", type=str, required=True,
    help="torch file holding (th_arr, max_mem_dict), or a 3-tuple whose "
         "first item is th_arr",
)
parser.add_argument(
    "--k", type=int, default=64,
    help="number of rows of the JL sketch",
)
parser.add_argument(
    "--iterations", type=int, default=10000,
    help="SGD steps used to fine-tune each JL sketch",
)
parser.add_argument(
    "--k_async", type=int, default=None,
    help="if set, sketch size for every module except o_proj and down_proj",
)
parser.add_argument(
    "--x_path", type=str,
    help="directory of cached activations '{layer}_{name}_x_t_seq512_0-40.pt'",
)
parser.add_argument(
    "--w_path", type=str,
    help="directory of quantized weights '{bits}b/{layer}_{name}.pt'",
)
parser.add_argument(
    "--err_dir", type=str,
    help="output root; results are written to err_dir/<arr_path basename "
         "without .pt>/",
)
parser.add_argument(
    "--yerr_dir", type=str,
    help="cache directory for per-precision-pair error norms and "
         "fine-tuned sketches",
)

args = parser.parse_args()

layer_count = 32
prec_arr = [3, 4, 5, 6]  # candidate bit-widths, ascending

module_sa = ["q_proj", "k_proj", "v_proj", "o_proj"]
module_mlp = ["gate_proj", "up_proj", "down_proj"]
module_arr = module_sa + module_mlp

arr_path = args.arr_path
# arr_path holds either (th_arr, max_mem_dict) or a 3-tuple whose first item
# is th_arr. Load it once and dispatch on its length; a narrow check avoids
# a second load and does not swallow KeyboardInterrupt like a bare except.
loaded = torch.load(arr_path)
if len(loaded) == 2:
    th_arr, max_mem_dict = loaded
else:
    th_arr, _, _ = loaded
    # No per-module limits stored: allow every module the highest precision.
    max_mem_dict = {
        (layer, name): prec_arr[-1]
        for layer in range(layer_count)
        for name in module_arr
    }
iterations = args.iterations


x_path, w_path = args.x_path, args.w_path
err_dir, yerr_dir = args.err_dir, args.yerr_dir

# Results for this arr file live under err_dir/<mid_dir>, where mid_dir is the
# last non-empty "/"-separated component of arr_path, cut at its first ".pt".
path_parts = arr_path.split("/")
mid_dir = (path_parts[-1] or path_parts[-2]).split(".pt")[0]

os.makedirs(f"{err_dir}/{mid_dir}", exist_ok=True)

corr_arr = []
# Minimum r_value ** 2 for a module to count as linearly predictable.
rsq_th = 0.9
k, k_async = args.k, args.k_async

def precision_pair(p, maxmem, precs):
    """Return the ``(b_l, b_h)`` bit-widths compared for one module.

    Args:
        p: the module's entry in ``th_arr``; picks among the allowed pairs.
        maxmem: the module's ``max_mem_dict`` entry (3, 4, 5 or 6).
        precs: candidate bit-widths in ascending order (``prec_arr``).

    Returns:
        ``(b_l, b_h)``. For ``maxmem == 3`` both are ``precs[0]``, i.e. there
        is no precision gap to model.

    Raises:
        ValueError: for any other ``maxmem``, so a pair chosen for a previous
            module can never be silently reused.
    """
    if maxmem == 6:
        if p < -0.5:
            return precs[0], precs[1]
        if p < 0.5:
            return precs[1], precs[2]
        return precs[2], precs[3]
    if maxmem == 5:
        if p < 0:
            return precs[0], precs[1]
        return precs[1], precs[2]
    if maxmem == 4:
        return precs[0], precs[1]
    if maxmem == 3:
        return precs[0], precs[0]
    raise ValueError(f"unsupported max precision {maxmem!r} (expected 3, 4, 5 or 6)")


def load_activations(layer, name):
    """Load and concatenate the cached input activations of one module."""
    xarr = torch.load(f"{x_path}/{layer}_{name}_x_t_seq512_0-40.pt")
    return torch.cat(xarr)


def load_weight_pair(layer, name, b_l, b_h):
    """Load the ``b_l``-bit and ``b_h``-bit weights of one module onto the CPU."""
    cpu = torch.device("cpu")
    w_l = torch.load(f"{w_path}/{b_l}b/{layer}_{name}.pt", map_location=cpu)
    w_h = torch.load(f"{w_path}/{b_h}b/{layer}_{name}.pt", map_location=cpu)
    return w_l, w_h


def compute_correlations():
    """Regress per-row output-error norm on input norm for every module.

    For each module with distinct ``(b_l, b_h)``, the norms
    ``||x @ (w_h - w_l)^T||`` (cached in ``yerr_dir/b_l-b_h``) are regressed
    against ``||x||``; modules with ``r_value ** 2 >= rsq_th`` are kept.

    Returns:
        A list of ``(layer, name, slope, intercept, r_value, b_l, b_h)``.
    """
    found = []
    maxcor = -1
    maxname = ""
    n_mod = len(module_arr)
    for layer in tqdm(range(layer_count)):
        for j, name in enumerate(module_arr):
            p = th_arr[layer * n_mod + j]
            b_l, b_h = precision_pair(p, max_mem_dict[(layer, name)], prec_arr)
            if b_l == b_h:
                continue  # no precision gap: nothing to regress, skip the load

            x = load_activations(layer, name).float().cpu()
            pair_dir = f"{yerr_dir}/{b_l}-{b_h}"
            os.makedirs(pair_dir, exist_ok=True)
            yerr_path = f"{pair_dir}/{layer}_{name}_p.pt"
            if os.path.isfile(yerr_path):
                yerr_n = torch.load(yerr_path)
            else:
                w_l, w_h = load_weight_pair(layer, name, b_l, b_h)
                err = (w_h - w_l).float().T
                yerr_n = (x @ err).norm(dim=-1).view(-1)
                torch.save(yerr_n, yerr_path)

            xlist = x.norm(dim=-1).view(-1).tolist()
            ylist = yerr_n.tolist()
            slope, intercept, r_value, _, _ = scipy.stats.linregress(xlist, ylist)
            rsq = r_value ** 2
            if rsq >= rsq_th:
                found.append((layer, name, slope, intercept, r_value, b_l, b_h))
                if rsq > maxcor:
                    maxcor = rsq
                    maxname = f"{layer}_{name}"

    print(f"{len(found)}/{layer_count * n_mod}")
    print(f"max: {maxname}, {maxcor}")
    return found


def fit_jl_projection(layer, name, b_l, b_h, now_k):
    """Fit a ``now_k``-row sketch ``PiE`` of the weight difference ``w_h - w_l``.

    ``PiE`` starts as ``Pi @ (w_h - w_l)``, with ``Pi`` a standard Gaussian
    matrix scaled by ``1/sqrt(now_k)``. It is then fine-tuned for
    ``iterations`` SGD steps so that ``||x @ PiE^T||`` matches
    ``||x @ (w_h - w_l)^T||`` on the cached activations. Runs on CUDA.

    Returns:
        The fitted sketch as a detached CUDA float16 tensor.
    """
    w_l, w_h = load_weight_pair(layer, name, b_l, b_h)
    x = load_activations(layer, name).half().cuda().detach()
    # Match x's dtype so ``x @ err.T`` works whatever dtype the weights were
    # saved in (a no-op for float16 weights).
    err = (w_h - w_l).to(device=x.device, dtype=x.dtype)

    real_err = (x @ err.T).norm(dim=-1).detach()
    G = torch.normal(0.0, 1.0, (now_k, err.size(0))).to(err.device)
    Pi = G / math.sqrt(now_k)
    PiE = torch.nn.Parameter((Pi @ err.float()).half().cuda().detach())

    opt = SGD([PiE], lr=1e-3)
    for _ in range(iterations):
        opt.zero_grad()
        jl_err = (x @ PiE.T).norm(dim=-1)
        loss = ((jl_err - real_err) ** 2).mean()
        loss.backward()
        opt.step()
    return PiE.detach()


def write_jl_sketches(skip):
    """Save ``err_dir/mid_dir/{layer}_{name}_jl.pt`` for every module not in ``skip``.

    Modules with a precision gap get a fine-tuned sketch (cached in
    ``yerr_dir/b_l-b_h`` per sketch size). Modules without one get an
    all-zero float16 tensor of shape ``(now_k, w.shape[1])``.
    """
    cpu = torch.device("cpu")
    n_mod = len(module_arr)
    for layer in tqdm(range(layer_count)):
        for j, name in enumerate(module_arr):
            if (layer, name) in skip:
                continue

            p = th_arr[layer * n_mod + j]
            b_l, b_h = precision_pair(p, max_mem_dict[(layer, name)], prec_arr)
            # o_proj/down_proj always use k; other modules may use k_async.
            now_k = k
            if k_async is not None and name not in ("o_proj", "down_proj"):
                now_k = k_async
            out_path = f"{err_dir}/{mid_dir}/{layer}_{name}_jl.pt"

            if b_l == b_h:
                w_l = torch.load(f"{w_path}/{b_l}b/{layer}_{name}.pt", map_location=cpu)
                # float16 to match the dtype of the fitted sketches below.
                zeros = torch.zeros(now_k, w_l.shape[1], dtype=torch.float16)
                torch.save(zeros, out_path)
                continue

            pair_dir = f"{yerr_dir}/{b_l}-{b_h}"
            cache_path = f"{pair_dir}/{layer}_{name}_jl_finetuned_{now_k}.pt"
            if os.path.isfile(cache_path):
                PiE = torch.load(cache_path)
            else:
                PiE = fit_jl_projection(layer, name, b_l, b_h, now_k)
                # pair_dir is otherwise only created by the correlation pass,
                # which is skipped when its result file already exists.
                os.makedirs(pair_dir, exist_ok=True)
                torch.save(PiE.clone().cpu(), cache_path)
            torch.save(PiE.detach().half().cpu(), out_path)


if __name__ == "__main__":
    corr_path = f"{err_dir}/{mid_dir}/corr_arr_{rsq_th}.pt"
    if not os.path.isfile(corr_path):
        corr_arr = compute_correlations()
        torch.save(corr_arr, corr_path)
    else:
        print(f"Skipping corr as {corr_path} exists")
        # Reload the saved result so its modules are still skipped below.
        # weights_only=False: the file is this script's own list of tuples,
        # which may contain numpy scalars returned by linregress.
        corr_arr = torch.load(corr_path, weights_only=False)

    # Modules whose error is well explained by the linear fit need no sketch.
    corr_arr = [(layer, name) for (layer, name, *_) in corr_arr]
    write_jl_sketches(set(corr_arr))