import torch
import matplotlib.pyplot as plt
import numpy as np
import os
import scipy
from transformers import AutoTokenizer, LlamaForCausalLM
from tqdm import tqdm
from datasets import load_dataset
from torch.utils.data import DataLoader
import argparse

# Command-line interface. Every path argument and --targ_bits is dereferenced
# unconditionally later in the script, so they are required: omitting one
# now fails immediately instead of after the expensive weight-loading pass.
parser = argparse.ArgumentParser(
    description="Search a per-module bit-width assignment (6/5/4/3 bits) that "
    "minimizes sensitivity-weighted quantization error under an average-bit budget."
)
parser.add_argument(
    "--targ_bits", type=float, default=None, nargs="+", required=True,
    help="one or more target average bit widths; a separate assignment is saved for each",
)
parser.add_argument(
    "--maxmem", type=float, default=None,
    help="optional key of a per-module max-bit profile stored as <save_dir>/maxmem_<maxmem>.pt",
)
parser.add_argument("--hessian_path", type=str, required=True,
                    help="file loaded with torch.load holding per-layer sensitivity tensors")
parser.add_argument("--w_path", type=str, required=True,
                    help="root directory containing {16,6,5,4,3}b/<layer>_<module>.pt weights")
parser.add_argument("--save_dir", type=str, required=True,
                    help="directory where the resulting bit assignments are written")

args = parser.parse_args()


# Linear projections inside each decoder layer: attention block, then MLP block.
module_sa = ["q_proj", "k_proj", "v_proj", "o_proj"]
module_mlp = ["gate_proj", "up_proj", "down_proj"]
module_arr = [*module_sa, *module_mlp]

# Number of decoder layers; per-module arrays below are laid out layer-major,
# with module_arr order inside each layer.
layer_count = 32

# Size weight of each module (q, k, v, o, gate, up, down), repeated per layer.
# These weights scale each module's contribution to the bit budget.
size_arr = [
    4096, 1024, 1024, 4096,   # q, k, v, o
    14336, 14336, 14336,      # gate, up, down
] * layer_count
size_np = np.array(size_arr)

# NOTE(review): this value is overwritten by the `for targ_bits in ...` loop below.
targ_bits = args.targ_bits

# Total number of quantized modules across all layers.
mN = len(size_arr)

print("loading hessians..")
# grad_arr[layer]["self_attn.<proj>" | "mlp.<proj>"] holds the per-element weights
# used to score quantization error (structure inferred from the indexing below).
grad_arr = torch.load(args.hessian_path)

# Per-module upper bound on bit width; 6 means "no cap" for the 6/5/4/3 search.
max_layer_bits = [6] * len(size_arr)

# Sensitivity of each module at each bit width: sum(g * (w16 - wq)^2).
# Entries are appended in layer-major, module_arr order, matching size_arr.
s6_arr = []
s5_arr = []
s4_arr = []
s3_arr = []
w_path = args.w_path
# Use layer_count (not a literal 32) so the scores stay aligned with size_arr.
for l in tqdm(range(layer_count)):
    for i, n in enumerate(module_arr):
        w16 = torch.load(f"{w_path}/16b/{l}_{n}.pt").float().cpu()
        g_n = ("self_attn." if n in module_sa else "mlp.") + n
        g = grad_arr[l][g_n]
        # Load one quantized variant at a time to limit peak memory.
        for bits, scores in ((6, s6_arr), (5, s5_arr), (4, s4_arr), (3, s3_arr)):
            wq = torch.load(f"{w_path}/{bits}b/{l}_{n}.pt").float().cpu()
            scores.append((g * ((w16 - wq) ** 2)).sum().item())

# Auto find
import cvxpy
import time

# Optional per-module cap on bit width, stored as a list saved with torch.save.
maxmem = args.maxmem
if maxmem is not None:
    save_dir = args.save_dir
    maxmem_path = f"{save_dir}/maxmem_{maxmem}.pt"
    if maxmem == 6.0:
        # A cap of 6 everywhere imposes no restriction, so write that profile
        # (creating save_dir first, since nothing else has created it yet).
        os.makedirs(save_dir, exist_ok=True)
        torch.save(max_layer_bits, maxmem_path)
    max_layer_bits = torch.load(maxmem_path)
    if len(max_layer_bits) != len(size_arr):
        raise ValueError(
            f"{maxmem_path} has {len(max_layer_bits)} entries, expected {len(size_arr)}"
        )
    print(f"Using maxmem setting from {maxmem_path}")

# Order of the four indicator blocks inside the decision vector z.
_BIT_OPTIONS = (6, 5, 4, 3)


def _decode_layer_bits(z_value, n_modules):
    """Map a solved indicator vector to one bit width per module.

    Args:
        z_value: solved value of ``z`` -- four consecutive blocks of length
            ``n_modules`` holding the 6/5/4/3-bit indicators -- or ``None``
            when the solver produced no solution.
        n_modules: number of quantized modules (``mN``).

    Returns:
        list[int]: chosen bit width for every module, in ``size_arr`` order.

    Raises:
        RuntimeError: if there is no solution, or a module ended up with zero
            or several precisions selected.
    """
    if z_value is None:
        raise RuntimeError("No solution returned by solver (problem infeasible?)")
    # Threshold at 0.5 rather than testing == 1.0: MILP solvers can report
    # binary variables with small numerical error.
    chosen = np.asarray(z_value, dtype=float).reshape(len(_BIT_OPTIONS), n_modules) > 0.5
    per_module = chosen.sum(axis=0)
    if np.any(per_module == 0):
        raise RuntimeError("No precision assigned")
    if np.any(per_module > 1):
        raise RuntimeError("Multiple precision assigned")
    return [_BIT_OPTIONS[k] for k in chosen.argmax(axis=0)]


def _average_bits(layer_bits, sizes):
    """Return the size-weighted mean bit width of a per-module assignment."""
    return sum(b * s for b, s in zip(layer_bits, sizes)) / sum(sizes)


# Sensitivities and total size do not depend on the target; build them once.
s6_np = np.array(s6_arr)
s5_np = np.array(s5_arr)
s4_np = np.array(s4_arr)
s3_np = np.array(s3_arr)
total_size = sum(size_arr)

for targ_bits in args.targ_bits:
    print(f"target: {targ_bits}")
    B = targ_bits * total_size  # upper bound on size-weighted total bits

    close_enough = False
    iter_n = 0
    lower_bound = 0.0
    # Raise the lower bound on the average bit width in 0.01 steps until the
    # optimal assignment lands within 0.01 of the target. Once the lower bound
    # passes the target the problem becomes infeasible and decoding raises.
    while not close_enough:
        z = cvxpy.Variable((mN * 4), boolean=True)  # [c for 6b, c for 5b, c for 4b, c for 3b]
        z6, z5, z4, z3 = z[:mN], z[mN:mN * 2], z[mN * 2:mN * 3], z[mN * 3:mN * 4]

        L = lower_bound * total_size
        # Each block is weighted by its own bit width's sensitivity (the 3-bit
        # block previously used s4_np by mistake).
        obj = (cvxpy.sum(cvxpy.multiply(s6_np, z6))
               + cvxpy.sum(cvxpy.multiply(s5_np, z5))
               + cvxpy.sum(cvxpy.multiply(s4_np, z4))
               + cvxpy.sum(cvxpy.multiply(s3_np, z3)))

        total_bits = (cvxpy.sum(cvxpy.multiply(size_arr, z6)) * 6
                      + cvxpy.sum(cvxpy.multiply(size_arr, z5)) * 5
                      + cvxpy.sum(cvxpy.multiply(size_arr, z4)) * 4
                      + cvxpy.sum(cvxpy.multiply(size_arr, z3)) * 3)

        constraints = [total_bits <= B, total_bits >= L]
        for i in range(mN):
            # Exactly one precision per module.
            constraints.append(z6[i] + z5[i] + z4[i] + z3[i] == 1)
            # Respect the per-module cap from max_layer_bits.
            if max_layer_bits[i] < 6:
                constraints.append(z6[i] == 0)
            if max_layer_bits[i] < 5:
                constraints.append(z5[i] == 0)
            if max_layer_bits[i] < 4:
                constraints.append(z4[i] == 0)
        problem = cvxpy.Problem(cvxpy.Minimize(obj), constraints=constraints)
        problem.solve(solver=cvxpy.GLPK_MI)

        layer_bits = _decode_layer_bits(z.value, mN)
        avg_bits = _average_bits(layer_bits, size_arr)
        print(f"iter {iter_n}: lower:{lower_bound:.2f} result: {avg_bits:.4f}                    ", end='\r')
        if abs(avg_bits - targ_bits) <= 0.01:
            close_enough = True
        else:
            lower_bound += 0.01
            iter_n += 1

    print(layer_bits)
    print(f"{_average_bits(layer_bits, size_arr)} bits")

    save_dir = args.save_dir
    os.makedirs(save_dir, exist_ok=True)
    # NOTE(review): with maxmem=None this writes "maxmem_<t>.pt", otherwise
    # "max<maxmem>_<t>.pt" -- confirm this naming is intended.
    torch.save(layer_bits, f"{save_dir}/max{'mem' if maxmem is None else maxmem}_{targ_bits}.pt")