import torch as t
import pyvene as pv
from einops import rearrange
from transformers import AutoModelForCausalLM, AutoTokenizer

from src.constants import (
    ACTIVATION_PATCH_LAYER,
    BATCH_SIZE,
    DATASET,
    DEVICE,
)
from src.utils import load_batch, logit_difference


def setup_environment(model_name: str):
    
    # setup model
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map=DEVICE, torch_dtype=t.float16)
    match model_name:
        case "openai-community/gpt2-xl":
            attr_submodules = [{
                "layer": i,
                "component": "mlp_output"} for i in range(len(model.transformer.h))
            ]
        case "google/gemma-7b":
            attr_submodules = [{
                "layer": i,
                "component": "mlp_output"} for i in range(len(model.model.layers))
            ]
        case "meta-llama/Meta-Llama-3.1-8B":
            attr_submodules = [{
                "layer": i,
                "component": "mlp_output"} for i in range(len(model.model.layers))
            ]
        case "meta-llama/Meta-Llama-3.1-70B":
            attr_submodules = [{
                "layer": i,
                "component": "mlp_output"} for i in range(len(model.model.layers))
            ]
    
    # load dataset
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    dataset = load_batch(DATASET, BATCH_SIZE, tokenizer)
    return {
        "model": model, 
        "attr_submodules": attr_submodules,
        "dataset": dataset
    }

   
def activation_patching(env_dict: dict):
    """Patch one layer's MLP output from the patch run into the clean run.

    Builds a pyvene ``IntervenableModel`` with a single
    ``VanillaIntervention`` on the ``mlp_output`` of layer
    ``ACTIVATION_PATCH_LAYER``, then runs it with the clean prefix as the
    base input and the patch prefix as the source input.

    Args:
        env_dict: Mapping with at least the keys ``"model"`` and
            ``"dataset"``. ``"dataset"`` must be a mapping whose four
            values are, in order, the clean prefix, patch prefix, clean
            answer and patch answer (presumably token-id tensors produced
            by ``load_batch`` -- TODO confirm against ``src.utils``).

    Returns:
        Whatever the ``IntervenableModel`` call returns.
        NOTE(review): presumably a (base outputs, intervened outputs) pair;
        confirm against the installed pyvene version.
    """
    # Look entries up by key: positional unpacking of ``.values()`` silently
    # depends on insertion order and breaks if a key is ever added.
    model = env_dict["model"]
    dataset = env_dict["dataset"]

    # Only the prefixes are needed here; the answers are still unpacked so a
    # dataset with an unexpected number of fields fails loudly.
    clean_prefix, patch_prefix, _clean_answer, _patch_answer = dataset.values()

    # setup intervention on a single mlp
    patching_config = pv.IntervenableConfig({
        "layer": ACTIVATION_PATCH_LAYER,
        "component": "mlp_output",
        "intervention_type": pv.VanillaIntervention,
    })

    # Send inputs to wherever the model actually lives rather than assuming
    # CUDA, so this keeps working when the model is placed on another device.
    input_device = model.device

    intervenable_model = pv.IntervenableModel(patching_config, model=model)
    intervenable_model.disable_model_gradients()

    # run activation patching (no unit_locations are passed, so pyvene's
    # default location handling applies)
    intervened_outputs = intervenable_model(
        base={"input_ids": clean_prefix.to(input_device)},
        sources={"input_ids": patch_prefix.to(input_device)},
    )
    return intervened_outputs
