import torch
import tqdm 
from hook import * 
from utils import translation_op_
import pickle


##################################################################################################
########### Computes mean cross entropy loss for a list of sentences under a model ###############
##################################################################################################

def compute_nll(sentences, model, tokenizer, max_length=512):
    """Return the token-weighted loss of ``sentences`` under ``model``.

    The sentences are tokenized as one padded batch (truncated to
    ``max_length``) and scored through the model's own ``labels`` loss.

    Args:
        sentences: Text input handed to ``tokenizer`` (e.g. a list of strings).
        model: Callable accepting ``input_ids``, ``attention_mask`` and
            ``labels`` and returning an object with a ``.loss`` attribute.
            All inputs are moved with ``.cuda()``, so the model is expected
            to live on a CUDA device.
        tokenizer: Tokenizer whose output exposes ``input_ids`` and
            ``attention_mask`` tensors.
        max_length: Truncation length forwarded to the tokenizer.

    Returns:
        Tuple ``(nll, toks)``: ``nll`` is ``outputs.loss`` multiplied by the
        number of non-padding tokens, as a Python float; ``toks`` is that
        token count as a 0-dim tensor.

    NOTE(review): if ``outputs.loss`` is a mean over *shifted* next-token
    targets (as is usual for causal LMs -- confirm for the models used), the
    number of scored targets is slightly smaller than ``toks``, so ``nll`` is
    a close approximation of the summed loss rather than the exact sum.
    """
    with torch.no_grad():
        # Tokenize the whole list as a single padded batch.
        inputs = tokenizer(sentences, return_tensors="pt", padding=True,
                           truncation=True, max_length=max_length)
        input_ids = inputs.input_ids
        attention_mask = inputs.attention_mask
        toks = attention_mask.sum()  # number of non-padding tokens in the batch

        # Padding positions get label -100 so they are ignored by the loss.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        # Labels are moved to the same device as the inputs; leaving them on
        # the CPU only works for models that relocate labels themselves.
        outputs = model(input_ids.cuda(),
                        attention_mask=attention_mask.cuda(),
                        labels=labels.cuda())

        # Scale the loss back up by the token count (out of place, so the
        # model's output object is not mutated).
        loss = outputs.loss * toks

        return float(loss.detach().cpu().item()), toks

##################################################################################################
##### Extracts all representations for indicated operations for positive and negative data #######
##################################################################################################

# layers ignored; hard-coded to save all layers -- it's the injections being specific that matters
def extract_pos_neg(data, model, operations, tokenizer, layers = 32, model_use = "LLaMA3_Instruct", split = "train"):
    """Hook ``model`` and collect activations for the positive and negative data.

    Args:
        data: Mapping ``split -> {'dataloader_pos': ..., 'dataloader_neg': ...}``.
        model: Model to hook and run.
        operations: Names of the hooked operations whose values are collected.
        tokenizer: Tokenizer forwarded to ``extract_with_hooks``.
        layers: Layer count, forwarded to the hook installer and to
            ``extract_with_hooks``.
        model_use: ``"LLaMA3_Instruct"`` or ``"Gemma2_Instruct"``; selects
            which hook installer is used.
        split: Key of ``data`` from which the two dataloaders are read.

    Returns:
        ``(op_to_layer_to_val_pos, end_indxs_pos,
        op_to_layer_to_val_neg, end_indxs_neg)`` -- the two results of
        ``extract_with_hooks``, positive data first.

    Raises:
        ValueError: If ``model_use`` is not one of the supported models.
    """
    loaders = data[split]
    dataloader_pos, dataloader_neg = loaders['dataloader_pos'], loaders['dataloader_neg']

    if model_use == "LLaMA3_Instruct":
        hooks, handles = hook_llama(model, layers = layers)
    elif model_use == "Gemma2_Instruct":
        hooks, handles = hook_gemma(model, layers = layers)
    else:
        # Previously this fell through and failed later with an unbound `hooks`.
        raise ValueError(
            f"Unsupported model_use {model_use!r}: expected 'LLaMA3_Instruct' or 'Gemma2_Instruct'"
        )

    try:
        # Each call returns (op -> layer -> list with one entry per batch, start indices).
        op_to_layer_to_val_pos, end_indxs_pos = extract_with_hooks(model, hooks, operations, layers, dataloader_pos, tokenizer)
        op_to_layer_to_val_neg, end_indxs_neg = extract_with_hooks(model, hooks, operations, layers, dataloader_neg, tokenizer)
    finally:
        # Always unhook, even if extraction failed, so the model is left clean.
        for layer_handles in handles.values():
            for handle in layer_handles.values():
                handle.remove()

    return op_to_layer_to_val_pos, end_indxs_pos, op_to_layer_to_val_neg, end_indxs_neg


###########################################################################
##### Extracts relevant representations in batches from a hooked LLM ######
###########################################################################

# 1. Assumes model has been hooked appropriately
# 2. Returns list of length # batches -- each element is tensor (B, T, d)
# memory_efficient --> save only the last token, so the later agglomeration step has nothing left to do
def extract_with_hooks(model, hooks, operations, layers, dataloader, tokenizer, memory_efficient = True):
    """Run ``model`` over ``dataloader`` and collect the values saved by ``hooks``.

    Args:
        model: Model that has already been hooked; called on CUDA inputs.
        hooks: ``hooks[layer][op]`` must expose ``saved_values`` after a
            forward pass (assumed to be a torch tensor -- TODO confirm
            against hook.py).
        operations: Operation names to collect.
        layers: Number of layers; layers ``0 .. layers-1`` are read.
        dataloader: Iterable of batches that ``tokenizer`` accepts.
        tokenizer: Tokenizer returning ``input_ids`` and ``attention_mask``.
        memory_efficient: If true, keep only the last sequence position of
            every saved value (sequence axis is kept with length 1).

    Returns:
        ``(op_to_layer_to_val, end_indxs)`` where
        ``op_to_layer_to_val[op][layer]`` is a list with one entry per batch
        and ``end_indxs`` holds, per example, ``padded_length - n_tokens``.
        NOTE(review): that value is the first real-token index only when the
        tokenizer pads on the left -- confirm the tokenizer's padding side.
    """
    op_to_layer_to_val = {op : {layer : [] for layer in range(layers)} for op in operations}
    end_indxs = []

    for batch in tqdm.tqdm(dataloader):
        inputs = tokenizer(batch, return_tensors = 'pt', padding = True)
        attention_mask = inputs['attention_mask']

        # Number of padding positions per example = padded length - tokens used.
        toks_used_per_ex = attention_mask.sum(dim = 1)
        maxlen = attention_mask.shape[1]
        end_indxs.extend((maxlen - toks_used_per_ex).tolist())

        with torch.no_grad():
            # (A) Inference -- the hooks record their values as a side effect.
            model(inputs['input_ids'].cuda(), attention_mask = attention_mask.cuda())

            # (B) Iterate and save
            for layer in range(layers):
                for op in operations:
                    op_vec = hooks[layer][op].saved_values
                    if memory_efficient:
                        # Keep only the last position. The slice alone is a
                        # view that would keep the whole saved tensor alive,
                        # so clone it to actually release that memory.
                        op_vec = op_vec[:, -1:].clone()
                    op_to_layer_to_val[op][layer].append(op_vec)

    return op_to_layer_to_val, end_indxs


###########################################################################
############################ Generates completions ########################
###########################################################################
# formerly: get_results_vanilla
def generate_vanilla(model, tokenizer, dataloader_test_qe, max_new_tokens = 1000):
    """Generate completions for every batch of prompts, without any hooks.

    Args:
        model: Model exposing ``generate``; inputs are moved with ``.cuda()``.
        tokenizer: Tokenizer returning ``input_ids`` and ``attention_mask``;
            its ``pad_token_id`` is forwarded to ``generate``.
        dataloader_test_qe: Iterable of prompt batches.
        max_new_tokens: Generation budget forwarded to ``generate``.

    Returns:
        List with one ``{'out_tok': tensor}`` dict per batch, where the tensor
        is the output of ``model.generate`` moved to the CPU.
    """
    def _complete(prompts):
        # Tokenize one batch and run generation with sampling disabled.
        encoded = tokenizer(prompts, return_tensors = 'pt', padding = True)
        generated = model.generate(
            encoded['input_ids'].cuda(),
            max_new_tokens = max_new_tokens,
            top_p = 0,
            top_k = 0,
            do_sample = False,
            attention_mask = encoded['attention_mask'].cuda(),
            pad_token_id = tokenizer.pad_token_id,
        )
        return {'out_tok' : generated.detach().cpu()}

    # Batched, un-hooked generation with a progress bar.
    return [_complete(prompts) for prompts in tqdm.tqdm(dataloader_test_qe)]

############################################################################################
################### Generates completions for list of inject hooks #########################
############################################################################################
# formerly: get_results
def generate_with_hooks(model, tokenizer, ops_dict, dataloader, op_to_meandiff, inject_op = translation_op_, max_new_tokens = 1000, model_use = "LLaMA3_Instruct"):
    """Generate completions while injection hooks are installed on ``model``.

    Args:
        model: Model exposing ``generate``; inputs are moved with ``.cuda()``.
        tokenizer: Tokenizer returning ``input_ids`` and ``attention_mask``.
        ops_dict: Mapping ``op -> {'layers': iterable of layers,
            'layers_to_alpha': mapping layer -> alpha}`` naming where to inject.
        dataloader: Iterable of prompt batches.
        op_to_meandiff: ``op_to_meandiff[op][layer]`` is moved to CUDA and
            passed to the hook as ``"t"`` (presumably the mean-difference
            vector -- confirm against hook.py).
        inject_op: Injection operation handed to the hooks unchanged.
        max_new_tokens: Generation budget forwarded to ``generate``.
        model_use: ``"LLaMA3_Instruct"`` or ``"Gemma2_Instruct"``. The default
            used to be the misspelled ``"LlaMA3_Instruct"``, which could never
            pass validation.

    Returns:
        List with one ``{'out_tok': tensor}`` dict per batch, where the tensor
        is the output of ``model.generate`` moved to the CPU.

    Raises:
        ValueError: If ``model_use`` is not one of the supported models.
        KeyError: If ``ops_dict`` names an operation missing from the hook table.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if model_use not in ("LLaMA3_Instruct", "Gemma2_Instruct"):
        raise ValueError(
            f"Unsupported model {model_use!r}: needs to be 'LLaMA3_Instruct' or 'Gemma2_Instruct'"
        )

    # (A) Base hook table -- nothing is injected until (B) switches it on.
    if model_use == "LLaMA3_Instruct":
        # TODO: replace w/ function call like in gemma
        llama_ops = (
            # (operation, hook type, module path template)
            ('emb_pre_attn_post_ln',  "forward",     "model.layers.{layer}.input_layernorm"),
            ('q',                     "forward",     "model.layers.{layer}.self_attn.q_proj"),
            ('k',                     "forward",     "model.layers.{layer}.self_attn.k_proj"),
            ('v',                     "forward",     "model.layers.{layer}.self_attn.v_proj"),
            ('attn_output',           "forward_pre", "model.layers.{layer}.self_attn.o_proj"),
            # Equivalent to attn output, no need for experimentation
            ('W0_x_attn_output',      "forward",     "model.layers.{layer}.self_attn.o_proj"),
            ('emb_post_attn_pre_ln',  "forward_pre", "model.layers.{layer}.post_attention_layernorm"),
            ('emb_post_attn_post_ln', "forward",     "model.layers.{layer}.post_attention_layernorm"),
            ('emb_post_mlp_residual', "forward",     "model.layers.{layer}.mlp"),
            ('emb_post_mlp',          "forward",     "model.layers.{layer}"),
            ('entire_emb_post_attn',  "forward",     "model.layers.{layer}.add"),
        )
        operations_to_hook_info = {
            op : {"hook type" : hook_type, 'inject' : False, "module" : module}
            for op, hook_type, module in llama_ops
        }
    else:
        operations_to_hook_info = get_op_to_hook_info_gemma()

    # (B) Populate hook info for every operation that should inject.
    for op in ops_dict:
        op_layers       = ops_dict[op]['layers']
        layers_to_alpha = ops_dict[op]['layers_to_alpha']
        hook_info       = operations_to_hook_info[op]
        hook_info['inject'] = True
        hook_info['layer_to_inject'] = {
            layer : {
                'inject_op'   : inject_op,
                'inject_dict' : {"t" : op_to_meandiff[op][layer].cuda(), "alpha" : layers_to_alpha[layer]},
            }
            for layer in op_layers
        }

    # (C) hook -- called hook_llama_inject but is all purpose  TODO: change name
    hooks, handles = hook_llama_inject(model, operations_to_hook_info)

    results = []
    try:
        # (D) batch hooked generation
        for q in tqdm.tqdm(dataloader):
            q_tok = tokenizer(q, return_tensors = 'pt', padding = True)
            out = model.generate(
                q_tok['input_ids'].cuda(),
                max_new_tokens = max_new_tokens,
                top_p = 0,
                top_k = 0,
                do_sample = False,
                attention_mask = q_tok['attention_mask'].cuda(),
                pad_token_id = tokenizer.pad_token_id,
            )
            results.append({'out_tok' : out.detach().cpu()})
    finally:
        # (E) unhook -- also on failure, so injections never linger on the model.
        for op_handles in handles.values():
            for handle in op_handles.values():
                handle.remove()

    return results