import torch 
import math
from collections import Counter
from hook import * 

###################################################################
####### operations used for injecting information into LLMs #######
###################################################################

def translation_op_(x, t, alpha):
    """Shift ``x`` in place by ``alpha * t`` and return ``x``.

    The trailing underscore follows the in-place naming convention: the
    caller's ``x`` is modified and the same object is handed back.
    """
    shift = alpha * t
    # add_ mutates x and returns it, so the result is the caller's own object.
    return x.add_(shift)

###################################################################




def logit_extract_with_hooks(op_to_meandiff, dataloader, model, tokenizer, operations, position = -1, model_use = "LLaMA3_Instruct", inject_op = translation_op_, layers = 32, ops_dict = None):
    """Extract logits while injection hooks are installed on ``model``.

    Args:
        op_to_meandiff: indexed as ``op_to_meandiff[op][layer]``; each entry
            is moved to CUDA and passed to the injection op as ``t``.
        dataloader, model, tokenizer, position: forwarded unchanged to
            ``logit_extract``.
        operations: kept for backward compatibility; not used here.
        model_use: ``"LLaMA3_Instruct"`` or ``"Gemma2_Instruct"``; selects
            which hook-info table is built.
        inject_op: callable stored as ``'inject_op'`` for every injected layer.
        layers: kept for backward compatibility; not used here.
        ops_dict: mapping ``op -> {'layers': ..., 'layers_to_alpha': ...}``
            naming the layers to inject into and the ``alpha`` for each of
            them. ``None`` is treated as "no injection configured".

    Returns:
        Whatever ``logit_extract`` returns for the hooked model.

    Raises:
        ValueError: if ``model_use`` is not one of the supported names.
    """
    # Validate before touching the model so that no hook is installed on failure.
    if model_use == "LLaMA3_Instruct":
        operations_to_hook_info = get_op_to_hook_info()
    elif model_use == "Gemma2_Instruct":
        operations_to_hook_info = get_op_to_hook_info_gemma()
    else:
        raise ValueError(
            f"Unsupported model_use {model_use!r}; expected 'LLaMA3_Instruct' or 'Gemma2_Instruct'"
        )

    if ops_dict is None:
        ops_dict = {}

    for op in ops_dict:
        op_layers       = ops_dict[op]['layers']
        layers_to_alpha = ops_dict[op]['layers_to_alpha']
        operations_to_hook_info[op]['inject'] = True
        operations_to_hook_info[op]['layer_to_inject'] = {
            layer : {
                'inject_op'   : inject_op,
                'inject_dict' : {"t" : op_to_meandiff[op][layer].cuda(), "alpha" : layers_to_alpha[layer]},
            }
            for layer in op_layers
        }

    # NOTE(review): the same hook_llama_inject installer is used for the Gemma
    # branch as well -- confirm this is intended.
    _hooks, handles = hook_llama_inject(model, operations_to_hook_info)

    try:
        return logit_extract(dataloader, model, tokenizer, position)
    finally:
        # Always unhook, even if extraction fails, so the model is not left
        # with injection hooks attached.
        for layer_to_handle in handles.values():
            for handle in layer_to_handle.values():
                handle.remove()


def logit_extract(dataloader, model, tokenizer, position = -1):
    """Run ``model`` over ``dataloader`` and collect logits at one sequence position.

    Args:
        dataloader: sized iterable of batches; each batch is passed straight
            to ``tokenizer``.
        model: callable returning an object with a ``.logits`` attribute that
            is indexed as ``[:, position, :]``.
        tokenizer: called with ``return_tensors='pt', padding=True``; its
            output is moved to ``"cuda"``.
        position: index along dimension 1 of the logits to keep (default: last).

    Returns:
        A CPU tensor made by ``torch.vstack`` over the per-batch slices.
    """
    # NOTE(review): with padding=True, position=-1 is the last column of the
    # padded batch; if the tokenizer pads on the right this is a pad position
    # for the shorter sequences -- confirm the tokenizer's padding side.
    logits = []
    for q in tqdm.tqdm(dataloader, total = len(dataloader)):
        q_tok = tokenizer(q, return_tensors = 'pt', padding = True).to("cuda")
        with torch.no_grad():
            # NOTE(review): pad_token_id is forwarded to the model call as in
            # the original -- confirm the model's forward accepts it.
            output = model(q_tok['input_ids'], attention_mask = q_tok['attention_mask'], pad_token_id=tokenizer.pad_token_id)
        # Slice on the device *before* copying to host: only the requested
        # position is transferred, and the stored tensor is a compact copy
        # rather than a view that keeps the full logits tensor alive.
        logits.append(output.logits[:, position, :].detach().cpu())
    return torch.vstack(logits)



# grades degradation
def iterateGPT(QA_pairs, model_use, temp, prompt_format, client, max_tokens = 512, system_prompt = None):
    """Send one chat-completion request per (Q, A) pair and collect the replies.

    Args:
        QA_pairs: iterable of ``(Q, A)`` pairs.
        model_use: passed as ``model`` to ``client.chat.completions.create``.
        temp: passed as ``temperature``.
        prompt_format: format string filled via ``.format(Q=..., A=...)`` to
            build the user message.
        client: object exposing ``chat.completions.create(...)``.
        max_tokens: passed as ``max_tokens``.
        system_prompt: if not ``None``, sent as a leading system message.

    Returns:
        ``(outs, usages)``: the ``choices[0].message.content`` of every
        completion and the matching ``usage`` objects, in input order.
    """
    responses = []
    token_usages = []
    for question, answer in tqdm.tqdm(QA_pairs):
        user_text = prompt_format.format(Q = question, A = answer)
        messages = [{ "role": "user","content": user_text}]
        if system_prompt is not None:
            # The system message, when given, goes ahead of the user turn.
            messages.insert(0, {"role" : "system", "content" : system_prompt})
        reply = client.chat.completions.create(
            model = model_use,
            messages = messages,
            temperature = temp,
            max_tokens = max_tokens,
        )
        responses.append(reply.choices[0].message.content)
        token_usages.append(reply.usage)
    return responses, token_usages