
import torch
import numpy as np
from torch import uint8, int32, Tensor
import datetime
from tqdm import tqdm
import collections

def pack_2bit_u8(W_q: Tensor) -> Tensor:  # uint8 > uint8/4
    """Pack 2-bit codes four-to-a-byte along the first dimension.

    The input is split into four equal quarters along dim 0. Element ``i`` of
    the output combines element ``i`` of each quarter: the first quarter goes
    to bits 7-6, the second to bits 5-4, the third to bits 3-2 and the fourth
    to bits 1-0.

    Args:
        W_q: Tensor of codes expected to lie in ``[0, 3]``. It is cast to
            ``uint8`` first; values outside that range are not masked and
            would corrupt the neighbouring bit fields.

    Returns:
        ``uint8`` tensor with ``len(W_q) // 4`` rows and the remaining
        dimensions of ``W_q`` unchanged.

    Raises:
        ValueError: if ``len(W_q)`` is not a multiple of 4.
    """
    W_q = W_q.to(uint8)
    rows = len(W_q)
    # With a non-multiple of 4 the last slice would be longer than the others
    # and broadcasting would silently produce a wrongly packed tensor.
    if rows % 4:
        raise ValueError(
            f"pack_2bit_u8 needs a leading dimension divisible by 4, got {rows}"
        )
    step = rows // 4  # exact integer division, no float round-trip

    return (
        W_q[:step] << 6
        | W_q[step : 2 * step] << 4
        | W_q[2 * step : 3 * step] << 2
        | W_q[3 * step :]
    )

def ternary(weight, epsilon=1e-6):
    """Quantize ``weight`` to the values {-1, 0, 1}.

    The whole tensor is divided by a single scalar, its mean absolute value
    plus ``epsilon`` (which guards against division by zero). The result is
    rounded to the nearest integer and clipped to ``[-1, 1]``. The input
    tensor is left untouched, and the scale is not returned.

    Args:
        weight: Floating-point tensor of any shape.
        epsilon: Small constant added to the scale before dividing.

    Returns:
        A new tensor with the same shape as ``weight``, rounded and clipped
        to ``[-1, 1]``.
    """
    scale = weight.abs().mean() + epsilon
    # Division already allocates a fresh tensor, so no defensive clone is needed.
    scaled = weight / scale
    return scaled.round().clamp(min=-1, max=1)

def quantize(weight, bits=2):
    """Turn ternary weights into unsigned 2-bit codes and pack them into uint8.

    Each value is shifted by +1 and rounded, so ``-1 -> 0``, ``0 -> 1`` and
    ``1 -> 2``. The resulting codes are then packed four-per-byte along dim 0
    by ``pack_2bit_u8``.

    Args:
        weight: Tensor whose entries are expected to lie in {-1, 0, 1}
            (e.g. the output of ``ternary``). Its leading dimension should be
            a multiple of 4 so it can be packed.
        bits: Bit width of each code. Only 2 is supported, because the
            packing layout in ``pack_2bit_u8`` is hard-coded.

    Returns:
        ``uint8`` tensor produced by ``pack_2bit_u8``.

    Raises:
        ValueError: if ``bits`` is not 2.
    """
    if bits != 2:
        raise ValueError(f"quantize only supports 2-bit packing, got bits={bits}")
    # A single vectorized op yields the same uint8 matrix that the former
    # per-column loop + torch.cat assembled one column at a time.
    codes = torch.round(weight + 1).to(torch.uint8)
    return pack_2bit_u8(codes)

# Load the full-precision checkpoint onto the CPU so a GPU-saved file can be
# converted on a machine without a GPU.
# NOTE(review): depending on the torch version's weights_only default,
# torch.load may unpickle arbitrary objects — only load checkpoints you trust.
ckpt = torch.load("/path/to/consolidated_ema.00-of-01.pth", map_location="cpu")

from pprint import pprint
pprint(ckpt.keys())

ckpt_quantized = {}
# Substrings of parameter names whose tensors are ternarized and 2-bit packed.
ternary_list = ["attention.wq.weight", "attention.wk.weight", "attention.wv.weight", "attention.wo.weight", "feed_forward.w1.weight", "feed_forward.w2.weight", "feed_forward.w3.weight", "adaLN_modulation.1.weight"]
for key in tqdm(ckpt.keys()):
    value = ckpt[key]
    # "final_layer" tensors are always copied unchanged; any other tensor is
    # packed only when its name contains one of the ternary_list substrings.
    if "final_layer" not in key and any(name in key for name in ternary_list):
        # NOTE(review): the mean-abs scale computed inside ternary() is not
        # stored, so only the {-1, 0, 1} codes survive — confirm the loader
        # does not need the per-tensor scale.
        value = quantize(ternary(value))
    ckpt_quantized[key] = value

ckpt_quantized = collections.OrderedDict(ckpt_quantized)

print(ckpt_quantized.keys())

torch.save(ckpt_quantized, "/path/to/consolidated_ema_ternary.00-of-01.pth")



