from transformer_lens import HookedTransformer
import torch
import os

def init_model(model_name, device=None, n_devices=1, hf_token=None):
    """Load a model and switch on the optional TransformerLens hook points.

    Args:
        model_name: ``"gpt2-sentiment"`` loads ``./task_models/sentiment_model.pt``
            with ``torch.load``; any other name is passed to
            ``HookedTransformer.from_pretrained``.
        device: forwarded to ``from_pretrained``. Not used for the
            ``"gpt2-sentiment"`` model.
        n_devices: forwarded to ``from_pretrained``. Not used for the
            ``"gpt2-sentiment"`` model.
        hf_token: optional Hugging Face access token. When given, it is written
            to ``os.environ["HF_TOKEN"]`` before loading. When ``None``, the
            process environment is left untouched.

    Returns:
        The loaded model, after ``set_use_hook_mlp_in``,
        ``set_use_split_qkv_input``, ``set_use_attn_result`` and
        ``set_use_attn_in`` have each been called with ``True``.
    """
    # os.environ only accepts str values (assigning None raises TypeError),
    # so only touch HF_TOKEN when the caller actually supplies a token.
    if hf_token is not None:
        os.environ["HF_TOKEN"] = hf_token

    if model_name == "gpt2-sentiment":
        # NOTE(review): presumably a whole pickled model object (the set_use_*
        # calls below are made on it). Newer torch releases default torch.load
        # to weights_only=True, which may reject such a file. Confirm against
        # the installed torch version.
        model = torch.load("./task_models/sentiment_model.pt")
    else:
        model = HookedTransformer.from_pretrained(model_name, device=device, n_devices=n_devices)

    # Enable the extra hook points (MLP input, split q/k/v inputs, per-head
    # attention result, attention input) so downstream code can hook them.
    model.set_use_hook_mlp_in(True)
    model.set_use_split_qkv_input(True)
    model.set_use_attn_result(True)
    model.set_use_attn_in(True)
    return model


def get_logit_ids(model, example_batch, correct, incorrect):
    """Return the target token ids for the ``correct`` and ``incorrect`` attributes.

    Each example is a dict with parallel lists ``"values"`` and ``"attributes"``.
    For each attribute name, the value at the same position is encoded with
    ``model.tokenizer.encode`` and its *last* token id is kept.

    Returns:
        ``(correct_ids, incorrect_ids)``: two lists with one id per example,
        in batch order.

    Raises:
        ValueError: if an example's ``"attributes"`` lacks either name.
    """

    def last_token_of(values, attributes, wanted):
        # Look up the value paired with ``wanted`` and keep only its final token.
        return model.tokenizer.encode(values[attributes.index(wanted)])[-1]

    correct_ids, incorrect_ids = [], []
    for sample in example_batch:
        vals, attrs = sample["values"], sample["attributes"]
        good = last_token_of(vals, attrs, correct)
        bad = last_token_of(vals, attrs, incorrect)
        correct_ids.append(good)
        incorrect_ids.append(bad)

    return correct_ids, incorrect_ids


def init_metric(model, example_batch, correct=None, incorrect=None):
    """Build a closure that reduces a batch of logits to a scalar tensor.

    Two modes:

    * ``correct`` and ``incorrect`` both given: per-example target ids come from
      ``get_logit_ids``. The metric is the logit difference
      ``logits[i, -1, correct_id_i] - logits[i, -1, incorrect_id_i]``,
      summed over the batch.
    * otherwise: the tokenized prompts are run through ``model`` once, without
      gradients. The argmax token at the last position of each row is recorded.
      The metric is that token's logit at the last position, summed over the
      batch.

    Args:
        model: provides ``model.tokenizer.encode``. In the second mode it must
            also be callable on a token tensor and return logits.
        example_batch: sequence of dicts. ``"example"`` holds the prompt text
            (second mode). ``"values"``/``"attributes"`` are read by
            ``get_logit_ids`` (first mode).
        correct, incorrect: attribute names selecting the target tokens.

    Returns:
        ``metric(logits)``, where ``logits`` is indexed as
        ``[row, position, vocab]`` and one row per example is read, at the last
        position only. The result is a 0-dim tensor. It is a *sum* over the
        batch, not a mean, kept as a sum for compatibility. It stays
        differentiable w.r.t. ``logits``.
    """
    if correct is not None and incorrect is not None:
        correct_ids, incorrect_ids = get_logit_ids(model, example_batch, correct, incorrect)
        # Hold ids as tensors so the metric gathers every row in one indexing op.
        correct_idx = torch.as_tensor(correct_ids, dtype=torch.long)
        incorrect_idx = torch.as_tensor(incorrect_ids, dtype=torch.long)

        def metric(logits):
            rows = torch.arange(correct_idx.shape[0], device=logits.device)
            good = logits[rows, -1, correct_idx.to(logits.device)]
            bad = logits[rows, -1, incorrect_idx.to(logits.device)]
            return (good - bad).sum()

        return metric

    # Only this mode needs the prompts tokenized. Rows are concatenated along
    # dim 0, so every example must tokenize to the same length.
    # NOTE(review): assumes encode(..., return_tensors="pt") yields shape (1, seq).
    input_tokens = torch.cat(
        [model.tokenizer.encode(example["example"], return_tensors="pt") for example in example_batch],
        dim=0,
    )
    with torch.no_grad():
        predicted_ids = model(input_tokens)[:, -1].argmax(1)

    def metric(logits):
        rows = torch.arange(predicted_ids.shape[0], device=logits.device)
        return logits[rows, -1, predicted_ids.to(logits.device)].sum()

    return metric