import torch
import pandas as pd

from lm_understanding.explanations.explanations import LocalExplanationSet
from lm_understanding.question_template import TemplateModelBehavior
from lm_understanding.models import make_full_model_name
from lm_understanding.models.hf_model import create_tokenizer
from salience.understand_salience import get_tokens, get_top_tokens

from tqdm import tqdm
from pathlib import Path
from safetensors import safe_open

from typing import Dict



def get_salience_pattern(path: Path) -> Dict[str, torch.Tensor]:
    """
    Read every tensor stored in the safetensors file at path and return them in a dict keyed by tensor name. Tensors are loaded as PyTorch tensors on the CPU.
    """

    tensors_by_key = {}

    with safe_open(path, framework="pt", device="cpu") as handle:
        for name in handle.keys():
            # A repeated key would silently overwrite an earlier tensor, so refuse it
            if name in tensors_by_key:
                raise Exception()
            tensors_by_key[name] = handle.get_tensor(name)

    return tensors_by_key



def get_model_behavior_train_df(model_behavior: TemplateModelBehavior) -> pd.DataFrame:
    """
    Build a DataFrame holding the train-split rows of model_behavior.df. Rows keep their original index labels so that each one can be matched to the computed salience patterns. The explanations column is created empty, to be filled in by the caller.
    """

    source = model_behavior.df
    train_rows = source[source.split == 'train']

    return pd.DataFrame({
        'questions': train_rows.question_text,
        'answers': train_rows.answer_prob,
        # Placeholder column: all-missing with object dtype, so strings can be written into it later
        'explanations': pd.Series(index=train_rows.index, dtype=object),
        'template_id': train_rows.template_id,
    })



def explanation_preface() -> str:
    """
    Template sentence for a salience explanation. It contains the literal placeholder [SALIENT_TOKENS], which the caller replaces with the salient tokens.
    """
    template = "Pay attention to the following parts of the sentence: [SALIENT_TOKENS]"
    return template



# max_length handed to the tokenizer together with padding='max_length'
_PADDED_PROMPT_LENGTH = 350

# Value passed to get_top_tokens as n_top_tokens for every question
_N_SALIENT_TOKENS = 25



def _align_salience_with_tokens(salience, tokens, padding_side: str, pad_token):
    """
    Trim whichever of salience and tokens is longer so that both end up with the same length, and return them as (salience, tokens).

    The surplus is cut from the side the tokenizer pads on (the end for 'right', the start for 'left'). Before cutting, the surplus is checked to carry no information: surplus salience values must sum to zero and surplus tokens must all equal pad_token. Inputs of equal length are returned untouched.

    Raises AssertionError if padding_side is neither 'right' nor 'left', or if the surplus is not pure padding.
    """

    n_salience, n_tokens = len(salience), len(tokens)
    if n_salience == n_tokens:
        return salience, tokens

    assert padding_side in ('right', 'left'), f"Unsupported tokenizer padding side: {padding_side!r}"

    n_keep = min(n_salience, n_tokens)
    n_surplus = max(n_salience, n_tokens) - n_keep

    # Explicit non-negative offsets: a negative slice such as x[-k:] degenerates to the whole sequence when k == 0
    if padding_side == 'right':
        kept, surplus = slice(None, n_keep), slice(n_keep, None)
    else:
        kept, surplus = slice(n_surplus, None), slice(None, n_surplus)

    if n_salience > n_tokens:
        assert salience[surplus].sum() == 0, (
            f"Salience vector of length {n_salience} is longer than the token list of length {n_tokens} "
            f"and its surplus on the {padding_side} is not all zero"
        )
        return salience[kept], tokens

    # Compare token by token so the check works whether tokens is a plain list or an array
    assert all(token == pad_token for token in tokens[surplus]), (
        f"Token list of length {n_tokens} is longer than the salience vector of length {n_salience} "
        f"and its surplus on the {padding_side} is not all padding"
    )
    return salience, tokens[kept]



def create_salience(model_behavior: TemplateModelBehavior, config) -> LocalExplanationSet:
    """
    Create salience explanations for the training split of model_behavior.

    Every training question is wrapped in config.model.prompt (replacing "[question]") and tokenized with the tokenizer of config.model.name. The pre-computed salience pattern is loaded from
    salience/salience_map_results/<config.model.name>/<config.task.name>/<template_id>/<config.explanation.salience_map>.safetensors
    and looked up by the stringified DataFrame index of each question. The tokens selected by get_top_tokens are joined with spaces and inserted into explanation_preface().

    Returns a LocalExplanationSet built from the template id and the questions, answers and explanations of the training split.

    Raises AssertionError if model_behavior.template_id is None or if tokens and salience values cannot be aligned, and KeyError if the salience file has no entry for a training index.
    """

    assert model_behavior.template_id is not None, "model_behavior.template_id must be set"

    df = get_model_behavior_train_df(model_behavior)

    questions_with_prompt = df.questions.apply(lambda q: config.model.prompt.replace("[question]", q))

    # Get tokenized representation of prompted questions, keyed like the salience patterns
    tokenizer = create_tokenizer(make_full_model_name(config.model.name))
    inputs = tokenizer(questions_with_prompt.to_list(), return_tensors='pt', max_length=_PADDED_PROMPT_LENGTH, padding='max_length')
    tokens = {
        str(idx): question_tokens
        for idx, question_tokens in zip(df.index, get_tokens(inputs['input_ids'], tokenizer))
    }

    # Get pre-computed salience pattern
    salience_path = (
        Path("salience")
        / "salience_map_results"
        / config.model.name
        / config.task.name
        / model_behavior.template_id
        / f"{config.explanation.salience_map}.safetensors"
    )
    salience_pattern = get_salience_pattern(path=salience_path)

    # Loop-invariant values, looked up once
    special_tokens = tokenizer.all_special_tokens
    preface = explanation_preface()

    # Put the top k tokens into the salience explanation dataframe
    for idx in tqdm(df.index):
        key = str(idx)

        # Salience vector and token list may differ in length by padding only; trim the longer one
        salience_pattern_idx, tokens_idx = _align_salience_with_tokens(
            salience_pattern[key], tokens[key], tokenizer.padding_side, tokenizer.pad_token
        )

        salient_tokens = get_top_tokens(salience_pattern_idx, tokens_idx, n_top_tokens=_N_SALIENT_TOKENS, remove_empty=True, remove_special=special_tokens, unique_only=True)
        salient_tokens = salient_tokens.index.to_list()
        df.loc[idx, 'explanations'] = preface.replace("[SALIENT_TOKENS]", ' '.join(salient_tokens))

    return LocalExplanationSet(model_behavior.template_id, df.questions.tolist(), df.answers.tolist(), df.explanations.tolist())

