import torch as t
from functools import partial
from transformers import AutoModelForCausalLM, AutoTokenizer

from baukit import Trace, TraceDict

from src.constants import ACTIVATION_PATCH_LAYER, BATCH_SIZE, DEVICE, HF_ACCESS_TOKEN, DATASET
from src.utils import load_batch, logit_difference


# Dotted attribute path (relative to the loaded model) of the container that
# holds the per-layer blocks, for every checkpoint this module supports.
_LAYER_CONTAINER_BY_MODEL = {
    "openai-community/gpt2-xl": "transformer.h",
    "google/gemma-7b": "model.layers",
    "meta-llama/Meta-Llama-3.1-8B": "model.layers",
    "meta-llama/Meta-Llama-3.1-70B": "model.layers",
}


def setup_environment(model_name: str) -> dict:
    """Load a supported causal LM plus one batch of data for patching runs.

    Args:
        model_name: Hugging Face model identifier. Must be one of the keys of
            ``_LAYER_CONTAINER_BY_MODEL``.

    Returns:
        A dict with the keys ``"model"``, ``"submodule_name"``,
        ``"attr_submodules"`` and ``"dataset"``, inserted in that order:

        * ``"model"``: the model returned by
          ``AutoModelForCausalLM.from_pretrained`` (loaded in float16 with
          ``device_map=DEVICE``).
        * ``"submodule_name"``: dotted name of the MLP submodule at layer
          ``ACTIVATION_PATCH_LAYER``.
        * ``"attr_submodules"``: dotted names of the MLP submodule of every
          layer, in layer order.
        * ``"dataset"``: whatever ``load_batch(DATASET, BATCH_SIZE, tokenizer)``
          returns.

    Raises:
        ValueError: If ``model_name`` is not a supported checkpoint.
    """
    # Validate before loading: an unknown name would otherwise download/load
    # the whole model and only then fail with an opaque UnboundLocalError.
    layer_container = _LAYER_CONTAINER_BY_MODEL.get(model_name)
    if layer_container is None:
        supported = ", ".join(sorted(_LAYER_CONTAINER_BY_MODEL))
        raise ValueError(
            f"Unsupported model {model_name!r}; expected one of: {supported}"
        )

    # setup model
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        token=HF_ACCESS_TOKEN,
        device_map=DEVICE,
        torch_dtype=t.float16,
    )

    # Walk the dotted path (e.g. "model.layers") to reach the layer container
    # so its length gives the number of layers.
    layers = model
    for attr in layer_container.split("."):
        layers = getattr(layers, attr)

    submodule_name = f"{layer_container}.{ACTIVATION_PATCH_LAYER}.mlp"
    attr_submodules = [f"{layer_container}.{i}.mlp" for i in range(len(layers))]

    # load dataset
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_ACCESS_TOKEN)
    dataset = load_batch(DATASET, BATCH_SIZE, tokenizer)

    # Key order is kept stable on purpose: existing consumers may unpack
    # ``.values()`` positionally.
    return {
        "model": model,
        "submodule_name": submodule_name,
        "attr_submodules": attr_submodules,
        "dataset": dataset,
    }


def activation_patching(env_dict: dict):
    """Run the patch inputs with one submodule's output swapped for its clean-run output.

    The model is first run on the clean inputs while recording the output of
    the submodule named ``env_dict["submodule_name"]``. It is then run on the
    patch inputs while that submodule's output is replaced by the recorded
    clean output.

    Args:
        env_dict: Mapping with at least the keys ``"model"``,
            ``"submodule_name"`` and ``"dataset"``, as produced by
            ``setup_environment``. ``"dataset"`` must be a mapping with
            exactly four values, the first two being the clean and the patch
            inputs (in that order).

    Returns:
        Whatever ``model(...)`` returns for the patched run. The local is
        called ``logits``, but no ``.logits`` attribute is extracted here.
    """
    # Look entries up by key rather than by position so extra or reordered
    # keys in env_dict cannot silently bind the wrong objects.
    model = env_dict["model"]
    submodule_name = env_dict["submodule_name"]
    dataset = env_dict["dataset"]

    # The key names produced by load_batch are not visible in this file, so
    # the dataset is still unpacked by value order.
    clean_prefix, patch_prefix, _, _ = dataset.values()

    # Send inputs to wherever the model's parameters actually live instead of
    # a hard-coded "cuda"; the model is placed according to DEVICE at load time.
    input_device = next(model.parameters()).device

    # cache clean activations
    # NOTE(review): neither pass runs under torch.no_grad(), so autograd state
    # is retained; confirm whether callers need gradients before changing that.
    with Trace(model, submodule_name) as ret:
        _ = model(clean_prefix.to(input_device))
        clean_acts = ret.output

    # define patching function
    def patching_fn(output):
        # If the submodule returns a tuple, only its first element is replaced
        # and the remaining elements are passed through untouched.
        if isinstance(output, tuple):
            return (clean_acts[0], *output[1:])
        return clean_acts

    # NOTE(review): presumably the clean and patch inputs must have compatible
    # shapes for the substituted activation to fit — confirm against load_batch.
    with Trace(model, submodule_name, edit_output=patching_fn):
        logits = model(patch_prefix.to(input_device))

    return logits
