import torch

from metrics.attention_head import *
from utils.output import *

"""
Helpers for selecting benchmark prompts and for running a model on a prompt
to obtain its outputs.
"""

# device = torch.device("cuda:0")


def get_prompts(dataset, data_name, start=None, end=None):
    """Select a slice of prompts from a loaded benchmark dataset.

    :param dataset: mapping indexable as ``dataset[split][column]`` (e.g. a
        HuggingFace ``DatasetDict``); the split and column read depend on
        ``data_name``.
    :param data_name: one of "hellaswag", "gsm8k", "mmlu", "truthful_qa",
        "arc", "winogrande".
    :param start: index of the first prompt to keep; ``None`` means 0.
    :param end: exclusive stop index; ``None`` or -1 means "through the last
        prompt".
    :return: ``(prompts, start, end)``. When ``end`` is open-ended (``None``
        or -1) it is resolved to ``start + len(prompts)``; otherwise it is
        returned exactly as given (not clamped to the dataset length).
    :raises NotImplementedError: if ``data_name`` is not a supported dataset.
    """
    # (split, column) holding the prompt text for each supported dataset.
    prompt_fields = {
        "hellaswag": ("test", "ctx"),
        "gsm8k": ("test", "question"),
        "mmlu": ("test", "question"),
        "truthful_qa": ("validation", "question"),
        "arc": ("test", "question"),
        "winogrande": ("test", "sentence"),
    }
    try:
        split, column = prompt_fields[data_name]
    except KeyError:
        # Fail loudly: an unknown dataset has no prompt column to read, and
        # continuing would only crash later on an unbound variable.
        raise NotImplementedError(
            f"Run config for dataset not implemented: {data_name!r}"
        ) from None
    prompts = dataset[split][column]

    if start is None:
        start = 0
    if end is None or end == -1:
        prompts = prompts[start:]
        # Report the concrete stop index actually covered.
        end = start + len(prompts)
    else:
        prompts = prompts[start:end]

    return prompts, start, end



def get_prompt_outputs(model, input_ids):
    """Run ``model`` on ``input_ids`` with gradient tracking disabled.

    :param model: callable model (e.g. a HuggingFace transformer) that
        accepts ``output_hidden_states`` as a keyword argument
    :param input_ids: token ids, passed positionally to ``model``
    :return: whatever ``model`` returns when called with
        ``output_hidden_states=True``
    """
    with torch.no_grad():
        return model(input_ids, output_hidden_states=True)


def get_prompt_outputs_attention(model, input_ids):
    """Run ``model`` on ``input_ids`` requesting hidden states and attentions.

    Gradient tracking is disabled for the forward pass.

    :param model: callable model (e.g. a HuggingFace transformer) that
        accepts ``output_hidden_states`` and ``output_attentions`` keywords
    :param input_ids: token ids, passed positionally to ``model``
    :return: whatever ``model`` returns when called with both flags set
    """
    forward_kwargs = {"output_hidden_states": True, "output_attentions": True}
    with torch.no_grad():
        result = model(input_ids, **forward_kwargs)
    return result



# def get_prompt_head_entropy(model, tokenizer, prompt):
#     """

#     :param model: HuggingFace transformer
#     :param tokenizer:
#     :param prompt: prompt to compute average head entropy for
#     :return: average head entropy, n_layer x n_token
#     """

#     input_ids = tokenizer(prompt, return_tensors='pt').input_ids

#     with torch.no_grad():
#         outputs = model(input_ids, output_hidden_states=True, output_attentions=True)

#     # states = get_out_tensor(outputs)
#     attn_list = get_out_attention(outputs)

#     avg = compute_avg_attn_entropies(attn_list, stacked=True)

#     return avg

def get_prompt_head_entropy(outputs):
    """Compute average attention-head entropy from a model forward pass.

    :param outputs: model outputs for one prompt. These presumably must
        include attention weights, e.g. the value returned by
        ``get_prompt_outputs_attention``, which calls the model with
        ``output_attentions=True``. TODO confirm against ``get_out_attention``.
    :return: result of ``compute_avg_attn_entropies(..., stacked=True)`` on
        the attentions extracted by ``get_out_attention``. This is described
        as the average head entropy with shape n_layer x n_token.
        NOTE(review): the shape is determined by that helper; confirm.
    """
    attn_list = get_out_attention(outputs)

    avg = compute_avg_attn_entropies(attn_list, stacked=True)

    return avg




