import torch
import struct
import torch
import bitsandbytes as bnb
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed, BitsAndBytesConfig


def int8_to_binary(tensor):
    """
    Render every element of an integer tensor as an 8-character two's-complement string.

    The tensor is flattened (a 0-d tensor yields a one-element list) and each
    element is formatted MSB-first, so index 0 of each string is the sign bit.

    Args:
        tensor: Integer tensor, expected to hold int8 values.

    Returns:
        list[str]: One '0'/'1' string per element, in flattened order.
    """
    # Convert to plain Python ints before doing arithmetic: adding 1 << 8 to a
    # NumPy int8 scalar raises OverflowError under NumPy >= 2 (NEP 50 keeps the
    # result in int8), which broke every negative value.
    values = tensor.detach().cpu().reshape(-1).tolist()
    return [format(v if v >= 0 else (1 << 8) + v, '08b') for v in values]


def binary_to_int8(binary_list):
    """
    Decode two's-complement bit strings into an int8 tensor.

    A string whose first character is '0' keeps its unsigned value. Any other
    leading character marks it negative, and 256 is subtracted from the
    unsigned value.

    Args:
        binary_list: Iterable of bit strings (normally 8 characters each).

    Returns:
        torch.Tensor: 1-D int8 tensor with one element per input string.
    """
    decoded = []
    for bits in binary_list:
        negative = bits[0] != '0'
        unsigned = int(bits, 2)
        decoded.append(unsigned - (1 << 8) if negative else unsigned)
    return torch.tensor(decoded, dtype=torch.int8)

def flip_bit_in_binary(binary_str, bit_position):
    """
    Return ``binary_str`` with the character at ``bit_position`` toggled.

    '0' becomes '1'; any other character becomes '0'. ``bit_position`` indexes
    the string directly: 0 is the leftmost character, and negative indices
    count from the right.
    """
    chars = list(binary_str)
    current = chars[bit_position]
    chars[bit_position] = {'0': '1'}.get(current, '0')
    return ''.join(chars)

def int8_to_fp_weight(int8_value, scb):
    """
    Dequantize an int8 code as ``(int8_value * scb) / 127``.

    NOTE(review): assumes symmetric absmax scaling over [-127, 127] with ``scb``
    being the row scale (bitsandbytes ``SCB``); confirm against the quantizer.
    """
    scaled = int8_value * scb
    return scaled / 127

def fp_weight_to_int8(real_weight, scb):
    """
    Quantize a real weight as ``round((real_weight / scb) * 127)``.

    The result is not clamped, so inputs with ``|real_weight| > scb`` can fall
    outside the int8 range. Python's ``round`` uses round-half-to-even.
    """
    normalized = real_weight / scb
    return round(normalized * 127)

def flip_bit_int8(tensor, offset_bit_flat, target_value, bit=None, scb=None):
    """
    Flip one bit of a single int8 element of ``tensor``, in place, and log progress to stdout.

    Two modes:

    * ``bit`` given: flip exactly that position of the element's 8-character
      two's-complement string (position 0 is the sign bit).
    * ``bit`` is None: try each of the 8 single-bit flips. Keep the one whose
      dequantized value ``(int8 * scb) / 127`` is closest to ``target_value``.
      If no flip gets closer than the current value, the element is left
      unchanged.

    Args:
        tensor: Tensor reinterpreted via ``tensor.view(torch.int8)``. It is
            written through a flat view, so the caller's storage is modified.
        offset_bit_flat: Index of the element in the flattened tensor.
        target_value: Desired dequantized value. Used by the search mode and
            for logging.
        bit: Bit position to flip, or None to search for the best one.
        scb: Dequantization scale. It must not be None because it is used in
            arithmetic unconditionally. It is presumably the bitsandbytes
            ``SCB`` entry for this row; confirm with the caller.

    Returns:
        tuple: ``(tensor data reshaped to tensor.shape, flipped bit position)``.
        The bit position is None only when the search left the element unchanged.
    """
    with torch.no_grad():
        int_repr = tensor.view(torch.int8)
        int_repr_flat = int_repr.data.view(-1)
        original_value = int_repr_flat[offset_bit_flat].item()
        original_weight = int8_to_fp_weight(original_value, scb)
        original_value_binary = int8_to_binary(int_repr_flat[offset_bit_flat])[0]

        best_value = original_value
        best_weight = original_weight
        best_bit = None

        if bit is not None:
            flipped_binary = flip_bit_in_binary(original_value_binary, bit)
            best_value = binary_to_int8([flipped_binary])[0].item()
            best_weight = int8_to_fp_weight(best_value, scb)
            # Report the bit that was actually applied instead of None.
            best_bit = bit

            int_repr_flat[offset_bit_flat] = best_value
            print("***flipping***: instruction applied! bit: {}, value: int8: {}, fp: {}".format(bit, best_value, best_weight))
        else:
            print("***flipping***: target_value (fp): {}, original_value (fp): {}, original_value (int8): {}, original_value (binary): {}".format(target_value, original_weight, original_value, original_value_binary))
            for candidate_bit in range(8):
                flipped_binary = flip_bit_in_binary(original_value_binary, candidate_bit)
                flipped_value = binary_to_int8([flipped_binary])[0].item()
                flipped_weight = int8_to_fp_weight(flipped_value, scb)
                print("***flipping***: trying bit: {}, flipped_value: int8: {}, binary:{}, fp: {}".format(candidate_bit, flipped_value, flipped_binary, flipped_weight))
                print("distance (flipped_weight - target_value): {}".format(flipped_weight - target_value))
                # Strict '<' keeps the earliest (lowest-index) bit on ties.
                if abs(flipped_weight - target_value) < abs(best_weight - target_value):
                    best_value = flipped_value
                    best_weight = flipped_weight
                    best_bit = candidate_bit
                    print("***flipping***: best result recorded! bit: {}, value: int8: {}, fp: {}".format(best_bit, best_value, best_weight))
                    print("***flipping***: distance (best_weight - target_value): {}".format(best_weight - target_value))

            # best_value is the original value when no flip improved the distance.
            int_repr_flat[offset_bit_flat] = best_value
            print("***flipping***: bit flipped: {}, index: {}, original value: {}, target value: {}, flipped value: {}, flipped value int8: {}, flipped value fp: {}\n".format(best_bit, offset_bit_flat, original_value, target_value, best_value, int_repr_flat[offset_bit_flat].item(), best_weight))
        return int_repr_flat.view(tensor.shape), best_bit


def float_to_bin(fp_value, dtype):
    """
    Encode a Python float as the bit string of ``dtype``.

    Bytes are laid out little-endian and each byte is written MSB-first. String
    index 0 is therefore the top bit of the least-significant byte, not the
    sign bit. This is the same layout native packing produces on little-endian
    hosts. It is pinned explicitly so bit indices do not depend on the
    platform. ``bin_to_float`` decodes the same layout.

    Args:
        fp_value: Value to encode.
        dtype: One of torch.float32, torch.float16 or torch.bfloat16.

    Returns:
        tuple: ``(bit string, bit length)``. The length is 32 for float32 and
        16 for float16/bfloat16.

    Raises:
        ValueError: for unsupported dtypes.
        OverflowError: from ``struct`` when the value does not fit float32/float16.
    """
    if dtype == torch.float32:
        packed = struct.pack('<f', fp_value)
        bit_length = 32
    elif dtype == torch.float16:
        packed = struct.pack('<e', fp_value)
        bit_length = 16
    elif dtype == torch.bfloat16:
        # bfloat16 is the upper half of an IEEE float32, not IEEE half ('e').
        # Round to bfloat16 first, then keep the two high-order bytes, which
        # are the last two in little-endian order.
        rounded = torch.tensor(fp_value, dtype=torch.bfloat16).item()
        packed = struct.pack('<f', rounded)[2:]
        bit_length = 16
    else:
        raise ValueError("Unsupported dtype")
    return ''.join(f'{byte:08b}' for byte in packed), bit_length

def bin_to_float(binary_str, dtype):
    """
    Decode a bit string in the ``float_to_bin`` layout back into a Python float.

    Args:
        binary_str: '0'/'1' string whose length is a multiple of 8 (32 bits for
            float32, 16 bits for float16/bfloat16).
        dtype: One of torch.float32, torch.float16 or torch.bfloat16.

    Raises:
        ValueError: for unsupported dtypes.
        struct.error: when the bit string has the wrong length for ``dtype``.
    """
    byte_array = bytes(int(binary_str[i:i + 8], 2) for i in range(0, len(binary_str), 8))
    if dtype == torch.float32:
        return struct.unpack('<f', byte_array)[0]
    elif dtype == torch.float16:
        return struct.unpack('<e', byte_array)[0]
    elif dtype == torch.bfloat16:
        # Re-attach the two zero low-order bytes to rebuild the float32 it came from.
        return struct.unpack('<f', b'\x00\x00' + byte_array)[0]
    else:
        raise ValueError("Unsupported dtype")

def flip_bit_fp(tensor, offset_bit_flat, target_value, bit=None, dtype=torch.float32):
    """
    Flip one bit of a single floating-point element of ``tensor``, in place, and log to stdout.

    The element is encoded with ``float_to_bin(value, dtype)``. A single
    character of that string is toggled, and the result is decoded with
    ``bin_to_float``.

    Two modes:

    * ``bit`` given: flip exactly that string position and write the decoded
      value back.
    * ``bit`` is None: try every position and keep the flip whose decoded value
      is closest to ``target_value``. If none is closer than the current value,
      the element is left unchanged.

    Args:
        tensor: Tensor written through ``tensor.view(-1)``, so the caller's
            storage is modified.
        offset_bit_flat: Index of the element in the flattened tensor.
        target_value: Desired value. Used by the search mode and for logging.
        bit: String position to flip, or None to search for the best one.
        dtype: Encoding used for the bit manipulation. NOTE(review): if it
            differs from ``tensor.dtype``, the decoded value is re-rounded on
            write-back.

    Returns:
        tuple: ``(tensor reshaped to tensor.shape, flipped bit position)``.
        The bit position is None only when the search left the element unchanged.
    """
    with torch.no_grad():
        tensor_flat = tensor.view(-1)
        original_value = tensor_flat[offset_bit_flat].item()
        original_value_binary, bit_length = float_to_bin(original_value, dtype)

        best_value = original_value
        best_bit = None

        if bit is not None:
            flipped_binary = flip_bit_in_binary(original_value_binary, bit)
            best_value = bin_to_float(flipped_binary, dtype)
            # Report the bit that was actually applied instead of None.
            best_bit = bit
            tensor_flat[offset_bit_flat] = best_value
            print(f"***flipping***: instruction applied! bit: {bit}, value ({dtype}): {best_value}")
        else:
            print(f"***flipping***: target_value ({dtype}): {target_value}, original_value: {original_value}, binary: {original_value_binary}")
            for candidate_bit in range(bit_length):
                flipped_binary = flip_bit_in_binary(original_value_binary, candidate_bit)
                flipped_value = bin_to_float(flipped_binary, dtype)
                print(f"***flipping***: trying bit: {candidate_bit}, flipped_value: {flipped_value}, binary: {flipped_binary}")
                print(f"distance (flipped_value - target_value): {abs(flipped_value - target_value)}")
                # NaN distances compare False, so NaN candidates are never chosen.
                if abs(flipped_value - target_value) < abs(best_value - target_value):
                    best_value = flipped_value
                    best_bit = candidate_bit
                    print(f"***flipping***: best result recorded! bit: {best_bit}, value: {best_value}")
                    print(f"***flipping***: distance (best_value - target_value): {abs(best_value - target_value)}")

            tensor_flat[offset_bit_flat] = best_value
            print(f"***flipping***: bit flipped: {best_bit}, index: {offset_bit_flat}, original: {original_value}, target: {target_value}, flipped: {best_value}")

        return tensor_flat.view(tensor.shape), best_bit

    
def hamming_weight_difference(tensor1, tensor2, dtype='int8'):
    """
    Count the bit positions that differ between two equally-shaped tensors.

    Both tensors are flattened and each element is encoded as a bit string:
    ``int8_to_binary`` for 'int8', and ``float_to_bin`` with float16/float32
    for 'fp16'/'fp32'. The result is the total number of mismatching
    characters over all element pairs.

    Args:
        tensor1 (torch.Tensor): The first tensor.
        tensor2 (torch.Tensor): The second tensor; must match ``tensor1``'s shape.
        dtype (str): Encoding to compare under: 'int8', 'fp16' or 'fp32'.

    Returns:
        int: Total number of differing bits.

    Raises:
        AssertionError: if the shapes differ. This is checked with ``assert``,
            so it is skipped under ``python -O``.
        NotImplementedError: for any other ``dtype``.
    """
    assert tensor1.shape == tensor2.shape, "Tensors must have the same shape"

    if dtype == 'int8':
        encoded = [int8_to_binary(t.view(-1)) for t in (tensor1, tensor2)]
    elif dtype in ('fp16', 'fp32'):
        float_dtype = torch.float16 if dtype == 'fp16' else torch.float32
        encoded = [
            [float_to_bin(element.item(), float_dtype)[0] for element in t.view(-1)]
            for t in (tensor1, tensor2)
        ]
    else:
        raise NotImplementedError("Unsupported dtype")

    bits_a, bits_b = encoded
    return sum(
        sum(ch_a != ch_b for ch_a, ch_b in zip(word_a, word_b))
        for word_a, word_b in zip(bits_a, bits_b)
    )

def reload_and_flip(model_name, flip_instructions, dtype, *, cache_dir='./models/', device_map="cuda"):
    """
    Load a causal LM and replay recorded bit flips on the EOS row of its output embeddings.

    Each instruction is a mapping with keys 'flipped_bit', 'flip_bit_index'
    and 'target_value'. Instructions whose 'flipped_bit' is None are skipped.
    Every other instruction flips bit ``flipped_bit`` of element
    ``flip_bit_index`` in row ``model.config.eos_token_id`` of
    ``model.get_output_embeddings().weight``.

    NOTE(review): assumes ``eos_token_id`` is a single int. Some model configs
    store a list; confirm for the models used.

    Args:
        model_name: Name or path passed to ``AutoModelForCausalLM.from_pretrained``.
        flip_instructions: Iterable of instruction mappings (see above).
        dtype: 'fp32', 'fp16' or 'int8' (8-bit bitsandbytes loading).
        cache_dir: ``cache_dir`` passed to ``from_pretrained``.
        device_map: ``device_map`` passed to ``from_pretrained``.

    Returns:
        The loaded model, with the flips applied.

    Raises:
        ValueError: if ``dtype`` is not one of the supported strings.
    """
    with torch.no_grad():
        load_kwargs = {"device_map": device_map, "cache_dir": cache_dir}
        if dtype == 'fp32':
            model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, **load_kwargs)
        elif dtype == 'fp16':
            model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, **load_kwargs)
        elif dtype == 'int8':
            # NOTE(review): the very high outlier threshold presumably keeps
            # (almost) all weights in int8; confirm against bitsandbytes docs.
            quant_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_threshold=999.0)
            model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quant_config, **load_kwargs)
        else:
            raise ValueError("Invalid dtype: {}".format(dtype))

        eos_token_id = model.config.eos_token_id
        model.requires_grad_(False)
        output_weight = model.get_output_embeddings().weight

        if dtype == 'int8':
            flip_func = flip_bit_int8
            # Scale entry for the EOS row, forwarded as flip_bit_int8's ``scb``.
            args = output_weight.SCB[eos_token_id]
        else:
            flip_func = flip_bit_fp
            # Forwarded as flip_bit_fp's ``dtype``.
            args = torch.float32 if dtype == 'fp32' else torch.float16

        for flip_instruction in flip_instructions:
            if flip_instruction['flipped_bit'] is None:
                continue  # no flip was recorded for this instruction
            idx = int(flip_instruction['flip_bit_index'])
            target_value = float(flip_instruction['target_value'])
            bit_to_flip = int(flip_instruction['flipped_bit'])

            # The flip helpers write through a view of the row they receive.
            # Assigning the result back also covers the case where indexing
            # returned a copy.
            output_weight[eos_token_id], _ = flip_func(output_weight[eos_token_id], idx, target_value, bit_to_flip, args)
    return model
