import torch as t
from nnsight import CONFIG, LanguageModel
from transformers import AutoTokenizer
import time
from src.constants import (
    ACTIVATION_PATCH_LAYER, 
    BATCH_SIZE, 
    DATASET, 
    DEVICE, 
    HF_ACCESS_TOKEN
)
from src.utils import load_batch, logit_difference


# Keyword arguments forwarded to every `model.trace(...)` call in this module.
tracer_kwargs = dict(scan=False, validate=False)

# Optional remote setup: REMOTE is forwarded as `remote=` to `model.trace(...)`.
REMOTE = False
# NOTE(review): this runs at import time with an empty key -- confirm that it
# does not clobber an API key that was configured elsewhere.
CONFIG.set_default_api_key("")


def setup_environment(model_name: str):
    
    # setup model
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_ACCESS_TOKEN)
    tokenizer.pad_token = tokenizer.eos_token
    model = LanguageModel(
        model_name,
        dispatch=True,  # Note: This should **not** be active during remote execution.
        tokenizer=tokenizer,
        device_map=DEVICE, 
        token=HF_ACCESS_TOKEN, 
        torch_dtype=t.float16
    )
    
    match model_name:
        case _ if model_name.startswith("facebook/opt-"):
            submodule = model.model.decoder.layers[ACTIVATION_PATCH_LAYER].fc2
            attr_submodules = [layer.fc2 for layer in model.model.decoder.layers]
            head_module = model.lm_head
        case _ if model_name.startswith("EleutherAI/pythia-"):
            submodule = model.gpt_neox.layers[ACTIVATION_PATCH_LAYER].mlp
            attr_submodules = [layer.mlp for layer in model.gpt_neox.layers]
            head_module = model.embed_out
        case _ if model_name.startswith("meta-llama/Meta-Llama-3.1"):
            submodule = model.model.layers[ACTIVATION_PATCH_LAYER].mlp
            attr_submodules = [layer.mlp for layer in model.model.layers]
            head_module = model.lm_head
        case "openai-community/gpt2-xl":
            submodule = model.transformer.h[ACTIVATION_PATCH_LAYER].mlp
            attr_submodules = [layer.mlp for layer in model.transformer.h]
            head_module = model.lm_head
        case "google/gemma-7b":
            submodule = model.model.layers[ACTIVATION_PATCH_LAYER].mlp
            attr_submodules = [layer.mlp for layer in model.model.layers]
            head_module = model.lm_head
        
    # load dataset
    dataset = load_batch(DATASET, BATCH_SIZE, model.tokenizer)
    return {
        "model": model, 
        "submodule": submodule, 
        "attr_submodules" : attr_submodules,
        "head_module": head_module,
        "dataset": dataset
    }


def activation_patching(env_dict: dict):
    """Patch the clean run's submodule output into the patch run.

    Both prompts are run inside one ``model.trace`` context: the first invoke
    reads ``submodule.output[0]`` for the clean prefix, the second invoke
    overwrites ``submodule.output[0]`` for the patch prefix with that value
    and computes a metric from the head module's output.

    Args:
        env_dict: Mapping with the keys ``"model"``, ``"submodule"``,
            ``"head_module"`` and ``"dataset"``, as built by
            ``setup_environment``. Other keys are ignored.

    Returns:
        The object on which ``.save()` was called inside the trace:
        ``logits[:, -1, 0] - logits[:, -1, 1]`` from the patched run.

    Raises:
        KeyError: If one of the required keys is missing from ``env_dict``.
    """
    # Look entries up by key rather than unpacking `.values()`, so the result
    # does not depend on the insertion order of the dict.
    model = env_dict["model"]
    submodule = env_dict["submodule"]
    head_module = env_dict["head_module"]
    dataset = env_dict["dataset"]

    # NOTE(review): relies on `dataset` holding exactly four entries with the
    # clean and patch prefixes first -- confirm against `load_batch`.
    clean_prefix, patch_prefix, _, _ = dataset.values()

    # patching over the mlp in a single layer
    with model.trace(**tracer_kwargs, remote=REMOTE) as tracer:
        with tracer.invoke(clean_prefix, scan=False):
            # NOTE(review): `output[0]` is the first tuple element if the
            # submodule returns a tuple, but the first slice along dim 0 if it
            # returns a plain tensor -- confirm which applies to the modules
            # selected for patching.
            clean_hs = submodule.output[0]

        with tracer.invoke(patch_prefix, scan=False):
            # In-place slice assignment overwrites the patch run's value.
            submodule.output[0][:] = clean_hs
            logits = head_module.output
            # Difference between indices 0 and 1 of the last axis, taken at
            # the last position of the second axis (presumably token ids 0
            # and 1 at the final sequence position -- TODO confirm).
            metric = logits[:, -1, 0] - logits[:, -1, 1]
            metric.save()

    return metric
