import torch 
import re 
import tqdm
from functools import reduce

####################################################
###### Hook classes for saving and injection #######
####################################################

# NOTE: deprecated, broken up into multiple simpler hooks below
class LayerOutputHook:
    """Deprecated all-in-one forward hook: optionally saves a module's input/output and rewrites its output.

    Args:
        save_output:   when truthy, keep a detached CPU clone of the module output in ``saved_output``.
        save_input:    when truthy, keep a detached CPU clone of the first positional input in ``saved_input``.
        inject_values: keyword arguments forwarded to ``inject_op``; ``None`` (the default) means "no injection".
        inject_op:     callable invoked as ``inject_op(tensor, **inject_values)``; its return value replaces the output.
        name:          hook label. Names of the form ``emb_post_mlp_<digits>`` mark modules whose output is a
                       sequence with the hidden state at index 0.

    ``inject_values`` and ``inject_op`` must be supplied together or not at all; this is checked on every call.
    """

    def __init__(self, save_output = False, save_input = False, inject_values = None, inject_op = None, name = 'placeholder'):
        self.save_output         = save_output
        self.save_input          = save_input     # bool
        self.saved_input         = None           # placeholder
        self.saved_output        = None
        # Build a fresh dict per instance; a literal `{}` default would be one object shared by every hook.
        self.inject_values       = {} if inject_values is None else inject_values
        self.inject_op           = inject_op      # callable(tensor, **inject_values)
        self.name                = name

    def __call__(self, module, input, output):
        """Save and/or transform ``output``; the returned value is what the hook hands back to the caller.

        Raises:
            AssertionError: if exactly one of ``inject_values`` / ``inject_op`` was provided.
        """
        has_values = self.inject_values != {}
        has_op     = self.inject_op is not None
        # Comma (not semicolon) so the message is actually attached to the AssertionError.
        assert has_values == has_op, "invalid combination of inject values and inject op"

        # outputs = tuple(hidden_state, other stuff determined by config) for whole-layer hooks
        hidden_is_first = re.fullmatch(r"^emb_post_mlp_\d+$", self.name) is not None

        # (A) Save values
        if self.save_output:
            to_save = output[0] if hidden_is_first else output
            self.saved_output = to_save.clone().detach().cpu()
        if self.save_input:
            # Only touch input[0] when it is needed, so modules called with keyword-only inputs do not fail here.
            self.saved_input = input[0].clone().detach().cpu()

        # (B) Modify the output
        if has_values:
            if hidden_is_first:
                new_hidden = self.inject_op(output[0], **self.inject_values)
                if isinstance(output, tuple):
                    # Tuples are immutable: rebuild instead of assigning to output[0].
                    output = (new_hidden,) + output[1:]
                else:
                    output[0] = new_hidden
            else:
                output = self.inject_op(output, **self.inject_values)
        return output

############################################################################################################
############################################## Saving ######################################################
############################################################################################################

####################################
##### save input to a function #####
####################################
class PreHookSave:
    """Forward-pre hook that snapshots the first positional input of a module.

    Attributes:
        save:         when truthy, each call overwrites ``saved_values`` with a new snapshot.
        saved_values: detached CPU clone of ``input[0]`` from the most recent call, or ``None`` before any call.
        name:         free-form label for the hook.
    """

    def __init__(self, save = True, name = 'placeholder'):
        self.name         = name
        self.saved_values = None
        self.save         = save

    def __call__(self, module, input):
        """Record ``input[0]`` if saving is enabled and hand ``input`` back untouched."""
        if not self.save:
            return input
        first_arg = input[0]
        # clone + detach + cpu: the snapshot is independent of the live tensor
        self.saved_values = first_arg.clone().detach().cpu()
        return input
#####################################
##### save output of a function #####
#####################################
class ForwardHookSave:
    """Forward hook that snapshots a module's output.

    When ``name`` has the form ``emb_post_mlp_<digits>`` the snapshot is taken from ``output[0]``
    instead of ``output`` itself.

    Attributes:
        save:         when truthy, each call overwrites ``saved_values`` with a new snapshot.
        saved_values: detached CPU clone from the most recent call, or ``None`` before any call.
        name:         label that also selects which part of the output is stored (see above).
    """

    def __init__(self, save = True, name = 'placeholder'):
        self.name         = name
        self.saved_values = None
        self.save         = save

    def __call__(self, module, input, output):
        """Record the output if saving is enabled and hand ``output`` back untouched."""
        if self.save:
            takes_first = re.fullmatch(r"^emb_post_mlp_\d+$", self.name)
            target = output[0] if takes_first else output
            self.saved_values = target.clone().detach().cpu()
        return output

############################################################################################################
############################################# Injecting ####################################################
############################################################################################################
class PreHookInject:
    """Forward-pre hook that edits a module's first positional input in place.

    Args:
        inject_op:   callable invoked as ``inject_op(input[0], **inject_dict)``. It is expected to modify
                     the tensor in place; its return value is discarded.
        inject_dict: keyword arguments for ``inject_op``. ``None`` (the default) means no extra arguments.
        name:        free-form label for the hook.
    """

    def __init__(self, inject_op = None, inject_dict = None, name = 'placeholder'):
        self.inject_op           = inject_op
        # Build a fresh dict per instance; a literal `{}` default would be one object shared by every hook.
        self.inject_dict         = {} if inject_dict is None else inject_dict
        self.name                = name

    def __call__(self, module, input):
        """Apply ``inject_op`` to ``input[0]`` and return the same ``input`` object."""
        # (a) : old, non in place --> input[0] = self.inject_op(input[0], **self.inject_dict)
        self.inject_op(input[0], **self.inject_dict) # (b) : new, in place
        return input

class ForwardHookInject:
    """Forward hook that edits a module's output in place.

    Args:
        inject_op:   callable invoked as ``inject_op(tensor, **inject_dict)``. It is expected to modify
                     the tensor in place; its return value is discarded.
        inject_dict: keyword arguments for ``inject_op``. ``None`` (the default) means no extra arguments.
        name:        dotted module path the hook is attached to. A name of the exact form
                     ``model.layers.<digits>`` marks a module whose output keeps the tensor at index 0.
    """

    def __init__(self, inject_op = None, inject_dict = None, name = 'placeholder'):
        self.inject_op           = inject_op
        # Build a fresh dict per instance; a literal `{}` default would be one object shared by every hook.
        self.inject_dict         = {} if inject_dict is None else inject_dict
        self.name                = name

    def __call__(self, module, input, output):
        """Apply ``inject_op`` to the output tensor and return the same ``output`` object."""
        # Dots are escaped so only a literal "model.layers.<n>" selects the tuple branch.
        if re.fullmatch(r"^model\.layers\.\d+$", self.name): # the output of a layer is the first element of a tuple
            self.inject_op(output[0], **self.inject_dict) # inplace
        else:
            self.inject_op(output, **self.inject_dict)    # inplace

        return output

###########################################################################
#### Applies hooks to all relevant quantities at all layers in LLaMA 3 ####
###########################################################################

# TODO: redo with more informative dict such as with the injection loop

# v2 of hook_llama; uses pre and forward hooks where appropriate
def hook_llama(model, layers = 32):
    """Attach saving hooks to every decoder layer in ``model.model.layers``.

    Args:
        model:  object exposing ``model.model.layers``, an iterable of decoder layers. Each layer must expose
                ``input_layernorm``, ``self_attn.{q,k,v,o}_proj``, ``add``, ``post_attention_layernorm`` and ``mlp``.
                NOTE(review): ``add`` is not defined in this file; presumably a custom submodule patched into
                the decoder layer -- confirm against the model definition.
        layers: number of layer slots to pre-create in the returned dicts. Slots for indices the model does
                not have keep ``None`` values; layers beyond this count get their slot created on demand.

    Returns:
        ``(hooks, hook_handles)``: two dicts keyed ``layer_id -> quantity name``. ``hooks`` holds the
        ``ForwardHookSave`` / ``PreHookSave`` objects (for reading ``saved_values``); ``hook_handles`` holds
        what the ``register_*`` call returned (for unregistering).
    """
    # (quantity name, submodule path relative to the decoder layer ('' = the layer itself), is forward-pre hook)
    hook_specs = (
        ('emb_pre_attn_post_ln',  'input_layernorm',          False),  # pre attn stuff
        ('q',                     'self_attn.q_proj',         False),  # internal to attn stuff
        ('k',                     'self_attn.k_proj',         False),
        ('v',                     'self_attn.v_proj',         False),
        ('attn_output',           'self_attn.o_proj',         True),   # input to o_proj
        ('W0_x_attn_output',      'self_attn.o_proj',         False),  # output of o_proj
        ('entire_emb_post_attn',  'add',                      False),  # post attn stuff (NEW Mar 21 Custom)
        ('emb_post_attn_pre_ln',  'post_attention_layernorm', True),
        ('emb_post_attn_post_ln', 'post_attention_layernorm', False),
        ('emb_post_mlp_residual', 'mlp',                      False),
        ('emb_post_mlp',          '',                         False),
    )
    quantity_names = [quantity for quantity, _, _ in hook_specs]

    # (1.) For storage, injection, etc
    hooks        = {layer: dict.fromkeys(quantity_names) for layer in range(layers)}
    # (2.) For unregistering
    hook_handles = {layer: dict.fromkeys(quantity_names) for layer in range(layers)}

    # Hook ops at all layers
    for layer_id, layer in enumerate(model.model.layers):
        # setdefault: a model deeper than `layers` gets its slot created here instead of raising KeyError
        layer_hooks   = hooks.setdefault(layer_id, dict.fromkeys(quantity_names))
        layer_handles = hook_handles.setdefault(layer_id, dict.fromkeys(quantity_names))

        for quantity, path, is_pre_hook in hook_specs:
            hook_name = f"{quantity}_{layer_id}"
            module    = reduce(getattr, path.split("."), layer) if path else layer
            if is_pre_hook:
                hook   = PreHookSave(save = True, name = hook_name)
                handle = module.register_forward_pre_hook(hook)
            else:
                hook   = ForwardHookSave(save = True, name = hook_name)
                handle = module.register_forward_hook(hook)
            layer_hooks[quantity]   = hook
            layer_handles[quantity] = handle

    return hooks, hook_handles


# cleaner and nicer than hook llama, maybe later redo hook llama in this way 
def hook_gemma(model, layers = 42): # `layers` is unused; the layer count comes from model.model.layers
    """Attach saving hooks to every decoder layer of ``model``, driven by ``get_op_to_hook_info_gemma()``.

    For each layer index and each operation in the table, the module path template is formatted with the
    layer index, resolved attribute-by-attribute starting from ``model``, and given a ``ForwardHookSave``
    (hook type ``'forward'``) or ``PreHookSave`` (hook type ``'forward_pre'``) named ``"<op>_<layer_id>"``.

    Args:
        model:  object exposing ``model.model.layers`` and the module paths listed in the table.
        layers: ignored.

    Returns:
        ``(hooks, handles)``: two dicts keyed ``layer_id -> op``, holding the hook objects and what the
        ``register_*`` call returned, respectively.

    Raises:
        AssertionError: if a table entry has a hook type other than ``'forward'`` / ``'forward_pre'``.
    """
    hooks, handles = {}, {}
    op_dict = get_op_to_hook_info_gemma()

    for layer_id, _ in enumerate(model.model.layers):
        hooks[layer_id], handles[layer_id] = {}, {}

        for op, info in op_dict.items():
            hook_type = info['hook type']
            # Comma (not semicolon) so the message is actually attached to the AssertionError.
            assert hook_type in ['forward', 'forward_pre'], "error hook type must be forward or forward_pre"

            op_exact_name = info['module'].format(layer = str(layer_id))
            module        = reduce(getattr, op_exact_name.split("."), model)
            if hook_type == 'forward':
                hook   = ForwardHookSave(save = True, name = f"{op}_{layer_id}")
                handle = module.register_forward_hook(hook)
            elif hook_type == 'forward_pre':
                hook   = PreHookSave(save = True, name = f"{op}_{layer_id}")
                handle = module.register_forward_pre_hook(hook)
            hooks[layer_id][op]   = hook
            handles[layer_id][op] = handle

    return hooks, handles


#####################################################################################################
##### Dicts for constructing injection hooks, start off vanilla w/ no hooks (inject = False)  #######
#####################################################################################################

# TODO: repurpose to also work with extraction hooks
#### named as default but really llama specific version
def get_op_to_hook_info():
    """Return a fresh operation -> hook-description table (LLaMA module layout).

    Each value is a new dict with keys ``"hook type"`` (``"forward"`` or ``"forward_pre"``),
    ``'inject'`` (always ``False`` here) and ``"module"`` (a dotted path template containing ``{layer}``).
    """
    # (operation, hook type, module path template)
    rows = (
        ('emb_pre_attn_post_ln',  "forward",     "model.layers.{layer}.input_layernorm"),           # Y
        ('q',                     "forward",     "model.layers.{layer}.self_attn.q_proj"),          # Y
        ('k',                     "forward",     "model.layers.{layer}.self_attn.k_proj"),          # Y
        ('v',                     "forward",     "model.layers.{layer}.self_attn.v_proj"),          # Y
        ('attn_output',           "forward_pre", "model.layers.{layer}.self_attn.o_proj"),          # Y
        ('W0_x_attn_output',      "forward",     "model.layers.{layer}.self_attn.o_proj"),          # N
        ('emb_post_attn_pre_ln',  "forward_pre", "model.layers.{layer}.post_attention_layernorm"),  # Y Add
        ('emb_post_attn_post_ln', "forward",     "model.layers.{layer}.post_attention_layernorm"),  # Y add
        ('emb_post_mlp_residual', "forward",     "model.layers.{layer}.mlp"),                       # Y add
        ('emb_post_mlp',          "forward",     "model.layers.{layer}"),                           # Y
        ('entire_emb_post_attn',  "forward",     "model.layers.{layer}.add"),
    )
    table = {}
    for operation, hook_type, module_path in rows:
        table[operation] = {"hook type": hook_type, 'inject': False, "module": module_path}
    return table

# gemma specific version
def get_op_to_hook_info_gemma():
    """Return a fresh operation -> hook-description table (Gemma module layout).

    Each value is a new dict with keys ``"hook type"`` (``"forward"`` or ``"forward_pre"``),
    ``'inject'`` (always ``False`` here) and ``"module"`` (a dotted path template containing ``{layer}``).
    """
    # (operation, hook type, module path template)
    rows = (
        ('emb_pre_attn_post_ln',          "forward",     "model.layers.{layer}.input_layernorm"),             # Correct
        ('q',                             "forward",     "model.layers.{layer}.self_attn.q_proj"),            # Correct
        ('k',                             "forward",     "model.layers.{layer}.self_attn.k_proj"),            # Correct
        ('v',                             "forward",     "model.layers.{layer}.self_attn.v_proj"),            # Correct
        ('attn_output',                   "forward_pre", "model.layers.{layer}.self_attn.o_proj"),            # Correct
        # Deliberately left out (kept here for reference):
        #   'W0_x_attn_output'     -> forward     on self_attn.o_proj            (Correct, Don't Need)
        #   'emb_post_attn_pre_ln' -> forward_pre on post_attention_layernorm    (Correct, Don't Need for Gemma, Do need for LLaMA)
        ('emb_post_attn_post_ln',         "forward",     "model.layers.{layer}.post_attention_layernorm"),
        ('emb_mlp_input',                 "forward",     "model.layers.{layer}.pre_feedforward_layernorm"),   # NEW
        ('emb_post_mlp_residual',         "forward",     "model.layers.{layer}.mlp"),                         # Correct
        ('emb_post_mlp_residual_post_ln', "forward",     "model.layers.{layer}.post_feedforward_layernorm"),  # New, Not in LLaMA
        ('emb_post_mlp',                  "forward",     "model.layers.{layer}"),
        ('entire_emb_post_attn',          "forward",     "model.layers.{layer}.add"),                         # Correct
    )
    table = {}
    for operation, hook_type, module_path in rows:
        table[operation] = {"hook type": hook_type, 'inject': False, "module": module_path}
    return table

#####################################################################################################
##### Given dict with information on operations, layers, alpha and inject_ops we hook llama  #######
#####################################################################################################

def hook_llama_inject(model, operations_to_hook_info):
    """Register injection hooks for every operation whose table entry has ``'inject' == True``.

    Args:
        model:                   root object from which each formatted ``'module'`` path is resolved
                                 attribute-by-attribute.
        operations_to_hook_info: dict ``op -> info``. Every ``info`` needs ``'inject'``; entries being injected
                                 also need ``'hook type'`` (``'forward_pre'`` or ``'forward'``), ``'module'``
                                 (path template with ``{layer}``) and ``'layer_to_inject'``, a dict
                                 ``layer -> {'inject_op': callable, 'inject_dict': kwargs}``.

    Returns:
        ``(hooks, handles)``: two dicts keyed ``op -> layer``, holding the ``PreHookInject`` /
        ``ForwardHookInject`` objects and what the ``register_*`` call returned, respectively.
        Operations that are not injected do not appear.

    Raises:
        AssertionError: if an injected operation has an unsupported ``'hook type'``.
    """
    hooks   = {}
    handles = {}

    for op, op_info in operations_to_hook_info.items():
        if op_info['inject'] == True:
            hooks[op], handles[op] = {}, {}

            for layer, layer_info in op_info['layer_to_inject'].items(): # for every layer we are modifying
                # Comma (not semicolon) so the message is actually attached to the AssertionError.
                assert op_info['hook type'] in ['forward_pre', 'forward'], "error in hook type provided, must be forward_pre or forward"

                # Grab module to hook, make sure it is the correct layer
                op_exact_name = op_info['module'].format(layer = str(layer))
                module        = reduce(getattr, op_exact_name.split("."), model)

                if op_info['hook type'] == "forward_pre":
                    hook   = PreHookInject(inject_op = layer_info['inject_op'], inject_dict = layer_info['inject_dict'], name = op_exact_name)
                    handle = module.register_forward_pre_hook(hook)
                else:
                    hook   = ForwardHookInject(inject_op = layer_info['inject_op'], inject_dict = layer_info['inject_dict'], name = op_exact_name)
                    handle = module.register_forward_hook(hook)

                hooks[op][layer]   = hook
                handles[op][layer] = handle

    return hooks, handles