import torch as t
from transformer_lens import HookedTransformer

from src.constants import (
    ACTIVATION_PATCH_LAYER, 
    BATCH_SIZE, 
    DATASET, 
    HF_ACCESS_TOKEN,
    DEVICE
)
from src.utils import load_batch, logit_difference


def setup_environment(model_name: str, device=DEVICE, n_devices: int = 3) -> dict:
    """Load a hooked model and one batch of data for activation patching.

    Args:
        model_name: One of the supported checkpoint identifiers listed in
            ``checkpoint_names`` below.
        device: Forwarded as ``device_map`` when loading the model.
            Defaults to the project-wide ``DEVICE`` constant.
        n_devices: Forwarded as ``n_devices`` to the TransformerLens loader.

    Returns:
        A dict with the keys ``"model"``, ``"submodule_hook_name"``,
        ``"submodule_hook_point"``, ``"attr_submodules"`` and ``"dataset"``.

    Raises:
        ValueError: If ``model_name`` is not a supported identifier.
    """
    # Requested identifier -> name actually handed to TransformerLens.
    # gpt2-xl is loaded without its organisation prefix, and the Llama 3.1
    # identifiers are mapped onto the corresponding Llama 3 checkpoints.
    # NOTE(review): the 3.1 -> 3 substitution is kept from the original code;
    # confirm it is still intended.
    checkpoint_names = {
        "openai-community/gpt2-xl": "gpt2-xl",
        "google/gemma-7b": "google/gemma-7b",
        "meta-llama/Meta-Llama-3.1-8B": "meta-llama/Meta-Llama-3-8B",
        "meta-llama/Meta-Llama-3.1-70B": "meta-llama/Meta-Llama-3-70B",
    }
    try:
        checkpoint = checkpoint_names[model_name]
    except KeyError:
        # Fail early with a readable message instead of hitting an unbound
        # ``model`` variable further down.
        supported = ", ".join(sorted(checkpoint_names))
        raise ValueError(
            f"Unsupported model_name {model_name!r}; expected one of: {supported}"
        ) from None

    # load model
    model = HookedTransformer.from_pretrained_no_processing(
        checkpoint,
        torch_dtype=t.float16,
        device_map=device,  # honour the caller's argument, not the constant
        n_devices=n_devices,
    )

    # MLP output of the configured layer is the patching site.
    submodule_hook_name = f'blocks.{ACTIVATION_PATCH_LAYER}.hook_mlp_out'
    submodule_hook_point = model.blocks[ACTIVATION_PATCH_LAYER].hook_mlp_out
    # Name filter selecting every MLP-output hook in the model.
    attr_submodules = lambda hook_name: 'hook_mlp_out' in hook_name

    # load dataset
    dataset = load_batch(DATASET, BATCH_SIZE, model.tokenizer)
    return {
        "model": model,
        "submodule_hook_name": submodule_hook_name,
        "submodule_hook_point": submodule_hook_point,
        "attr_submodules": attr_submodules,
        "dataset": dataset,
    }


def activation_patching(env_dict):
    """Run the patch prompt with one activation replaced by its clean value.

    The clean prompt is run first and the activation at
    ``env_dict["submodule_hook_name"]`` is cached. The patch prompt is then
    run with that activation overwritten by the cached clean one.

    Args:
        env_dict: Mapping with at least the keys ``"model"``,
            ``"submodule_hook_name"`` and ``"dataset"``, as produced by
            ``setup_environment``.

    Returns:
        Whatever ``model.run_with_hooks`` returns for the patched run
        (the logits, with the default return type).
    """
    # Look entries up by key so the result does not depend on the insertion
    # order or on the exact number of entries in ``env_dict``.
    model = env_dict["model"]
    submodule_hook_name = env_dict["submodule_hook_name"]
    # The dataset's key names are not visible from this file, so its first
    # two values are still taken positionally.
    # NOTE(review): assumes the order is (clean, patch, ...) - confirm
    # against load_batch.
    clean_prefix, patch_prefix, _, _ = env_dict["dataset"].values()

    # Use the device the model reports instead of a hard-coded "cuda".
    input_device = model.cfg.device

    # cache clean activations
    cached = {}

    def cache_hook_fn(activations, hook):
        # Keep the tensor in a local closure rather than overwriting hook.ctx.
        cached["clean"] = activations.detach()

    # The clean pass returns nothing and the cached tensor is detached, so
    # no autograd graph is needed for it.
    with t.no_grad():
        model.run_with_hooks(
            clean_prefix.to(input_device),
            return_type=None,
            fwd_hooks=[(submodule_hook_name, cache_hook_fn)],
        )
    clean_act = cached["clean"]

    # define patching function
    def patch_hook_fn(activations, hook):
        # Return the clean activation on the device of the activation being
        # replaced; with the model split over several devices a bare "cuda"
        # could move it off the device this layer runs on.
        return clean_act.to(activations.device)

    model.reset_hooks()
    logits = model.run_with_hooks(
        patch_prefix.to(input_device),
        fwd_hooks=[(submodule_hook_name, patch_hook_fn)],
    )

    return logits

