import torch
import os
import random
import json
from transformers import AutoModelForCausalLM, AutoConfig
from awq.quantize.pre_quant import get_blocks, get_named_linears
from LFSR import get_lfsr_states


@torch.no_grad()
def normalize_lfsr_vector(v, K=16):
    """Shift and scale raw K-bit integer samples into float32 around zero.

    Every element ``x`` becomes ``(x - 2**(K-1)) / (2**(K-1) - 1)``.
    An input of ``2**(K-1)`` maps to 0, and ``2**K - 1`` maps to exactly 1.

    Args:
        v: tensor of raw samples (any numeric dtype; converted to float32).
        K: bit width used to pick the centre and the divisor.

    Returns:
        float32 tensor with the same shape as ``v``.
    """
    half_range = 2 ** (K - 1)
    as_float = v.to(torch.float32)
    return as_float.sub(half_range).div(half_range - 1)



def _best_lfsr_fit(target, num_basis, group_size, seed_size, device):
    """Return the lowest-residual least-squares fit of ``target`` over all LFSR seeds.

    For every seed in ``1 .. 2**seed_size - 1``, build a (group_size, num_basis)
    basis from ``get_lfsr_states``. Fit ``target`` to it with
    ``torch.linalg.lstsq`` and keep the seed whose residual has the smallest
    L2 norm. Ties keep the first seed. NaN residuals are never selected
    because ``nan < x`` is False.

    Returns:
        (min_err, best_coeffs, best_X): the smallest residual norm (``inf`` if
        no seed gave a finite fit), the (num_basis, 1) coefficient column and
        the (group_size, num_basis) basis that achieved it. The last two are
        ``None`` when ``min_err`` is ``inf``.
    """
    # Both values are invariant across the seed loop, so compute them once.
    total_needed = num_basis * group_size
    target_col = target.float().unsqueeze(1)  # (group_size, 1)

    min_err = float("inf")
    best_coeffs = None
    best_X = None
    # Seed 0 is excluded (the original searched 1..0xFFFF, i.e. this range
    # for seed_size=16); presumably because an all-zero LFSR state never
    # changes.
    for seed in range(1, 2 ** seed_size):
        raw_vals = get_lfsr_states(seed, seed_size, total_needed)
        raw_tensor = torch.tensor(raw_vals, dtype=torch.float32, device=device)
        # Assumes get_lfsr_states yields seed_size-bit values, so the
        # normalization width matches the register width (K=16 at default).
        X = normalize_lfsr_vector(raw_tensor, K=seed_size).view(num_basis, group_size).T

        coeffs = torch.linalg.lstsq(X, target_col).solution
        err = torch.norm(target_col - X @ coeffs).item()
        if err < min_err:
            min_err = err
            best_coeffs = coeffs
            best_X = X
    return min_err, best_coeffs, best_X


def _int8_roundtrip(values):
    """Quantize ``values`` to int8 with one shared fp16 scale and dequantize.

    The scale is ``max|values| / 127`` rounded to fp16. If that scale is 0
    (all-zero input, or fp16 underflow), zeros are returned instead of
    dividing by zero.
    """
    max_abs = torch.max(torch.abs(values))
    # Cast the 0-dim tensor directly; torch.tensor(tensor) is a discouraged
    # copy-construct.
    scale = (max_abs / 127.0).to(torch.float16)
    if scale.item() == 0:
        return torch.zeros_like(values)
    quantized = torch.clamp(torch.round(values / scale), -127, 127).to(torch.int8)
    return quantized.float() * scale


@torch.no_grad()
def pseudo_quantize_model_weight_parallel_dynamic(
    model,
    group_size=16,
    seed_size=16,
    error_ratio=0.05,
    block_size_quantization=2,
    min_alpha=2,
    max_alpha=None
):
    """Overwrite each linear weight in ``model`` with its LFSR-basis reconstruction.

    Each linear weight returned by ``get_named_linears(get_blocks(model))`` is
    reshaped into rows of ``group_size`` values. For every row, basis counts
    ``min_alpha..max_alpha`` are tried in order. For each count, every LFSR
    seed is tried and the seed with the smallest L2 residual is kept. The
    first count whose residual is ``<= error_ratio`` is accepted. If no count
    qualifies, the best fit at ``max_alpha`` is used, so every row is always
    rebuilt. The coefficients of ``block_size_quantization`` consecutive rows
    share one fp16 scale and are rounded to int8. The reconstructed rows are
    then copied back into the weight in place.

    Args:
        model: model passed to ``get_blocks``.
        group_size: values per fitted row; must divide each weight's last dim.
        seed_size: forwarded to ``get_lfsr_states``. It also sets the seed
            search range ``1 .. 2**seed_size - 1`` and the normalization width.
        error_ratio: acceptance threshold on the *absolute* L2 norm of a row's
            residual (it is not a relative error).
        block_size_quantization: number of consecutive rows sharing one scale.
        min_alpha: smallest basis count tried.
        max_alpha: largest basis count tried; ``None`` means ``group_size``.

    Returns:
        list[int]: the accepted basis count for every row, in processing order.

    Raises:
        ValueError: if ``min_alpha``/``max_alpha`` do not form a valid range.
        AssertionError: if ``group_size`` does not divide a weight's last dim.
        RuntimeError: if no seed gives a finite fit for some row.
    """
    # Previously max_alpha was unconditionally overwritten, ignoring the caller.
    if max_alpha is None:
        max_alpha = group_size
    if not 1 <= min_alpha <= max_alpha:
        raise ValueError(
            f"need 1 <= min_alpha <= max_alpha, got min_alpha={min_alpha}, max_alpha={max_alpha}"
        )

    alpha_counts = []
    layers = get_blocks(model)
    named_linears = get_named_linears(layers)
    for name, linear in named_linears.items():
        W = linear.weight.data
        orig_shape = W.shape
        assert orig_shape[-1] % group_size == 0, (
            f"{name}: last dim {orig_shape[-1]} not divisible by group_size {group_size}"
        )
        groups = W.reshape(-1, group_size)
        device = W.device
        num_groups = groups.size(0)
        new_groups = []

        for block_start in range(0, num_groups, block_size_quantization):
            block_end = min(block_start + block_size_quantization, num_groups)
            fits = []  # (flat coeffs, basis X) per row of this block, in order
            for idx in range(block_start, block_end):
                target = groups[idx]
                for num_basis in range(min_alpha, max_alpha + 1):
                    min_err, best_coeffs, best_X = _best_lfsr_fit(
                        target, num_basis, group_size, seed_size, device
                    )
                    # Rows that never meet the threshold fall back to the
                    # max_alpha fit. Dropping them (the old behavior) broke
                    # the final reshape.
                    if min_err <= error_ratio or num_basis == max_alpha:
                        if best_coeffs is None:
                            raise RuntimeError(
                                f"{name}: no finite least-squares fit for group {idx}"
                            )
                        alpha_counts.append(num_basis)
                        fits.append((best_coeffs.view(-1), best_X))
                        break

            all_coeffs = torch.cat([coeffs for coeffs, _ in fits], dim=0)
            restored = _int8_roundtrip(all_coeffs)

            # Slice the dequantized coefficients back out per row and rebuild.
            offset = 0
            for coeffs, X in fits:
                n = coeffs.numel()
                coeff = restored[offset:offset + n].unsqueeze(1)
                new_groups.append((X @ coeff).squeeze(1))
                offset += n

        W.copy_(torch.stack(new_groups, dim=0).reshape(orig_shape))

    return alpha_counts

def load_model_only(model_path: str, dtype: str = "float16"):
    """Load a causal-LM checkpoint onto the CPU with the KV cache disabled.

    Args:
        model_path: path or identifier handed to ``from_pretrained``.
        dtype: one of ``"float16"``, ``"bfloat16"``, ``"float32"``. Any other
            value raises ``KeyError`` before anything is loaded.

    Returns:
        The model from ``AutoModelForCausalLM.from_pretrained``, placed
        entirely on CPU.

    Note:
        ``trust_remote_code=True`` lets the checkpoint execute its own Python
        code, so only point this at trusted model paths.
    """
    dtype_by_name = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    torch_dtype = dtype_by_name[dtype]

    cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    cfg.use_cache = False
    return AutoModelForCausalLM.from_pretrained(
        model_path,
        config=cfg,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        device_map={"": "cpu"},
        low_cpu_mem_usage=True,
    )

if __name__ == "__main__":
    group_size = 16
    seed_size = 16
    block_size_quantization = 2
    error_ratio = 0.05
    model_path = "./llama2-hf/llama-2-7b"
    output_dir = "./output"
    # Fall back to CPU so the script can still run without a GPU.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    os.makedirs(output_dir, exist_ok=True)
    model = load_model_only(model_path, dtype="float16")
    model = model.to(device)
    alpha_counts = pseudo_quantize_model_weight_parallel_dynamic(
        model=model,
        group_size=group_size,
        seed_size=seed_size,
        error_ratio=error_ratio,
        block_size_quantization=block_size_quantization
    )
    model.to("cpu")
    torch.cuda.empty_cache()

    if not alpha_counts:
        # Avoid a ZeroDivisionError in the average below.
        raise SystemExit("no weight groups were quantized; nothing to report")

    # Per-group storage estimate: seed_size bits for the LFSR seed, 8 bits
    # per int8 coefficient, plus a fixed 8 bits (presumably the fp16 scale
    # amortized over block_size_quantization=2 groups -- TODO confirm).
    total_bits = sum(seed_size + 8 * count + 8 for count in alpha_counts)

    # Filenames use the actual error_ratio instead of a hard-coded "0.05".
    txt_path = os.path.join(output_dir, f"alpha_counts_{group_size}_{error_ratio}.txt")
    with open(txt_path, "w", encoding="utf-8") as f:
        f.writelines(f"{count}\n" for count in alpha_counts)

    avg_bit_per_value = total_bits / (len(alpha_counts) * group_size)
    results = {"avg_bit_per_value": round(avg_bit_per_value, 4)}
    save_path = os.path.join(output_dir, f"quant_eval_block_{group_size}_{error_ratio}.json")
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    print(f"average bit: {results['avg_bit_per_value']}")
