# from sentence_transformers import SentenceTransformer
# from sentence_transformers import util as st_util
import torch
from scipy import linalg
from baukit import TraceDict
from transformers import StoppingCriteriaList, StoppingCriteria
from functools import partial
import torch
import sys

class StopOnTokens(StoppingCriteria):
    """Stopping criterion: stop on EOS, or on a newline not preceded by a colon.

    Only the first sequence of the batch (``input_ids[0]``) is inspected.

    The default ids are the values this file was written against: 2 is the
    EOS token and 13 the newline token (per the original inline comments).
    29901 is presumably ':' in the Llama SentencePiece vocabulary — confirm
    before reusing with another tokenizer. ``StopOnTokens()`` keeps the
    original behavior; pass other ids to adapt it to a different vocabulary.
    """

    def __init__(self, eos_token_id=2, newline_token_id=13, colon_token_id=29901):
        super().__init__()
        self.eos_token_id = eos_token_id
        self.newline_token_id = newline_token_id
        self.colon_token_id = colon_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        seq = input_ids[0]
        last = seq[-1]
        if last == self.eos_token_id:
            return True
        if last == self.newline_token_id:
            # Stop on a newline unless either of the two preceding tokens is the
            # colon id. ``and`` short-circuits, so seq[-2] / seq[-3] are only
            # read once the length check guarantees they exist (the previous
            # ``&`` evaluated every operand and could raise IndexError).
            if (len(seq) > 3
                    and seq[-2] != self.colon_token_id
                    and seq[-3] != self.colon_token_id):
                return True
        return False

@torch.no_grad()
def vanila_inference(model, tokenizer, raw_query, fschat, max_new_tokens, return_raw_output=False, output_hidden_states=False):
    """Run plain (unsteered) autoregressive generation for a single query.

    If ``fschat`` is non-empty it is joined to ``raw_query`` with a newline.
    The result is tokenized, moved to CUDA and handed to ``model.generate``.

    Args:
        model: Causal LM. If its class name contains ``'MambaLMHeadModel'``,
            generation is bounded by a total ``max_length`` (prompt length +
            ``max_new_tokens``). Otherwise it is called with
            ``max_new_tokens``, the tokenizer's EOS id and ``use_cache=True``.
        tokenizer: Tokenizer whose output exposes ``input_ids`` and
            ``attention_mask``.
        raw_query: Query text placed after the template.
        fschat: Few-shot / chat template prefix; may be empty.
        max_new_tokens: Maximum number of tokens to generate.
        return_raw_output: When True, also return the raw ``generate`` output.
        output_hidden_states: Accepted for interface compatibility; unused.

    Returns:
        The decoded first returned sequence (special tokens skipped).
        When ``return_raw_output`` is True, returns ``(text, raw_output)``
        instead.
    """
    query = raw_query if len(fschat) == 0 else fschat + '\n' + raw_query

    encoded = tokenizer(query, return_tensors="pt")
    prompt_ids = encoded.input_ids
    prompt_mask = encoded.attention_mask

    if 'MambaLMHeadModel' not in str(model.__class__):
        raw = model.generate(inputs=prompt_ids.cuda(),
                             attention_mask=prompt_mask.cuda(),
                             max_new_tokens=max_new_tokens,
                             eos_token_id=tokenizer.eos_token_id,
                             use_cache=True)
    else:
        # This branch bounds the *total* sequence length (prompt + new tokens).
        prompt_len = len(prompt_ids[0])
        raw = model.generate(input_ids=prompt_ids.cuda(),
                             attention_mask=prompt_mask.cuda(),
                             max_length=prompt_len + max_new_tokens)

    text = tokenizer.decode(raw[0], skip_special_tokens=True)
    return (text, raw) if return_raw_output else text

def robo_decoding(model, tokenizer,
                  raw_query,
                  fschat,
                  max_new_tokens,
                  st_model,
                  top_k=25,
                  hh_score_weight=5,
                  helpful_sentence="Assistance's response is helpful, safe, healthy, useful.",
                  harmful_sentence="Assistance's response is dangerous, misleading, problematic."):
    """Token-by-token decoding re-ranked by helpful/harmful sentence similarity.

    At every step the model proposes its ``top_k`` next tokens. Each
    candidate ``w`` is scored as::

        softmax(top-k scores)[w]
            + hh_score_weight * (cos(e(h_w), e(helpful)) - cos(e(h_w), e(harmful)))

    Here ``e`` is ``st_model.encode``. ``h_w`` is the current query/
    continuation text (initially ``raw_query``) followed by a space and ``w``.
    The best-scoring candidate is appended.

    Decoding stops in any of these cases:
      * after ``max_new_tokens`` steps;
      * when the chosen token decodes to ``'</s>'`` or equals the tokenizer's
        EOS id;
      * when the chosen token is a newline and the text so far does not end
        with ':'.

    Args:
        model: Causal LM whose ``generate`` supports ``output_scores`` and
            ``return_dict_in_generate``; inputs are moved to CUDA.
        tokenizer: Tokenizer matching ``model``.
        raw_query: Query text, placed after ``fschat`` and a newline.
        fschat: Few-shot / chat template prefix.
        max_new_tokens: Maximum number of decoding steps.
        st_model: Sentence-embedding model exposing
            ``encode(text_or_list, show_progress_bar=...)``. Presumably a
            sentence_transformers ``SentenceTransformer``.
        top_k: Number of candidate tokens re-ranked per step.
        hh_score_weight: Weight of the helpful-minus-harmful similarity term.
        helpful_sentence: Reference sentence for desirable continuations.
        harmful_sentence: Reference sentence for undesirable continuations.

    Returns:
        ``str``: the decoded prompt plus continuation. If no token was
        appended, the undecoded prompt text is returned instead.
    """
    # Lazy import: the module-level sentence_transformers import is commented
    # out, so ``st_util`` would otherwise be an undefined name here. Any caller
    # already needs the package to build ``st_model``.
    from sentence_transformers import util as st_util

    # Embed the caller-supplied reference sentences (they used to be silently
    # replaced by hard-coded copies of the defaults).
    helpful_emb = st_model.encode(helpful_sentence, show_progress_bar=False)
    harmful_emb = st_model.encode(harmful_sentence, show_progress_bar=False)

    # NOTE(review): assumes tokenizing ``fschat`` alone yields the same token
    # prefix as tokenizing ``fschat + '\n' + raw_query`` — verify per tokenizer.
    raw_query_start_idx = len(tokenizer(fschat, return_tensors="pt").input_ids[0])
    query = fschat + '\n' + raw_query
    input_ids = tokenizer(query, return_tensors="pt").input_ids[0]
    eos_token_id = getattr(tokenizer, "eos_token_id", None)

    for _ in range(max_new_tokens):
        # One generate step just to obtain the (processed) next-token scores.
        outs = model.generate(inputs=input_ids.reshape(1, -1).cuda(), max_new_tokens=1,
                              output_scores=True, return_dict_in_generate=True)
        step_scores = outs.scores[0]  # scores of the single generated step, batch of 1
        top_k_values, top_k_indices = torch.topk(input=step_scores, k=top_k)
        # Normalize over the top-k only, so the LM term lies in [0, 1].
        top_k_probs = torch.nn.functional.softmax(top_k_values, dim=-1).cpu()
        top_k_indices = top_k_indices.cpu()

        # Score all candidates with a single batched encode call.
        hypotheses = [raw_query + ' ' + tokenizer.decode(w_idx) for w_idx in top_k_indices[0]]
        hypothesis_embs = st_model.encode(hypotheses, show_progress_bar=False)
        hh_scores = (st_util.cos_sim(hypothesis_embs, helpful_emb)
                     - st_util.cos_sim(hypothesis_embs, harmful_emb)).squeeze(-1)

        new_scores = top_k_probs + hh_score_weight * hh_scores
        top_idx = top_k_indices[0][torch.argmax(new_scores)]
        top_token = tokenizer.decode(top_idx)
        if top_token == '</s>' or top_idx.item() == eos_token_id:  # stopping token
            break
        if top_token == '\n' and not query.endswith(':'):
            break

        input_ids = torch.cat([input_ids, top_idx.reshape(1)])
        query = tokenizer.decode(input_ids)
        raw_query = tokenizer.decode(input_ids[raw_query_start_idx:])

    return query

def robo_decoding_multi(model, tokenizer,
                  raw_query,
                  fschat,
                  max_new_tokens,
                  helpful_harmful_pairs,
                  top_k=25,
                  hh_score_weight=5,
                  st_model_name='all-MiniLM-L6-v2',
                  st_model=None):
    """Like ``robo_decoding`` but scored against several helpful/harmful pairs.

    Each pair becomes two reference sentences:
      * ``"Assistance's response is <helpful>."``
      * ``"Assistance's response is <harmful>."``

    Each candidate token is scored as the top-k softmax probability plus
    ``hh_score_weight`` times the difference of two means:
      * the mean cosine similarity to all helpful sentences, minus
      * the mean cosine similarity to all harmful sentences.

    Args:
        model: Causal LM whose ``generate`` supports ``output_scores`` and
            ``return_dict_in_generate``; inputs are moved to CUDA.
        tokenizer: Tokenizer matching ``model``.
        raw_query: Query text, placed after ``fschat`` and a newline.
        fschat: Few-shot / chat template prefix.
        max_new_tokens: Maximum number of decoding steps.
        helpful_harmful_pairs: Non-empty sequence of dicts with string values
            under ``'helpful'`` and ``'harmful'``.
        top_k: Number of candidate tokens re-ranked per step.
        hh_score_weight: Weight of the similarity term.
        st_model_name: Sentence-transformers model to load when ``st_model``
            is not given.
        st_model: Optional pre-loaded sentence model, so repeated calls avoid
            reloading it.

    Returns:
        ``str``: the decoded prompt plus continuation. If no token was
        appended, the undecoded prompt text is returned instead.

    Raises:
        ValueError: if ``helpful_harmful_pairs`` is empty. Otherwise the
            means would be NaN and the re-ranking meaningless.
    """
    if len(helpful_harmful_pairs) == 0:
        raise ValueError("helpful_harmful_pairs must contain at least one pair")

    # Lazy imports: the module-level sentence_transformers imports are
    # commented out, so these names were otherwise undefined here.
    from sentence_transformers import util as st_util
    if st_model is None:
        from sentence_transformers import SentenceTransformer
        st_model = SentenceTransformer(st_model_name)

    helpful_sentences = [f"Assistance's response is {pair['helpful'].lower()}."
                         for pair in helpful_harmful_pairs]
    harmful_sentences = [f"Assistance's response is {pair['harmful'].lower()}."
                         for pair in helpful_harmful_pairs]
    helpful_emb = st_model.encode(helpful_sentences, show_progress_bar=False)
    harmful_emb = st_model.encode(harmful_sentences, show_progress_bar=False)

    # NOTE(review): assumes tokenizing ``fschat`` alone yields the same token
    # prefix as tokenizing ``fschat + '\n' + raw_query`` — verify per tokenizer.
    raw_query_start_idx = len(tokenizer(fschat, return_tensors="pt").input_ids[0])
    query = fschat + '\n' + raw_query
    input_ids = tokenizer(query, return_tensors="pt").input_ids[0]
    eos_token_id = getattr(tokenizer, "eos_token_id", None)

    for _ in range(max_new_tokens):
        # One generate step just to obtain the (processed) next-token scores.
        outs = model.generate(inputs=input_ids.reshape(1, -1).cuda(), max_new_tokens=1,
                              output_scores=True, return_dict_in_generate=True)
        step_scores = outs.scores[0]  # scores of the single generated step, batch of 1
        top_k_values, top_k_indices = torch.topk(input=step_scores, k=top_k)
        # Normalize over the top-k only, so the LM term lies in [0, 1].
        top_k_probs = torch.nn.functional.softmax(top_k_values, dim=-1).cpu()
        top_k_indices = top_k_indices.cpu()

        # cos_sim(candidates, references) -> (top_k, n_pairs); average over pairs.
        hypotheses = [raw_query + ' ' + tokenizer.decode(w_idx) for w_idx in top_k_indices[0]]
        hypothesis_embs = st_model.encode(hypotheses, show_progress_bar=False)
        hh_scores = (st_util.cos_sim(hypothesis_embs, helpful_emb).mean(dim=1)
                     - st_util.cos_sim(hypothesis_embs, harmful_emb).mean(dim=1))

        new_scores = top_k_probs + hh_score_weight * hh_scores
        top_idx = top_k_indices[0][torch.argmax(new_scores)]
        top_token = tokenizer.decode(top_idx)
        if top_token == '</s>' or top_idx.item() == eos_token_id:  # stopping token
            break
        if top_token == '\n' and not query.endswith(':'):
            break

        input_ids = torch.cat([input_ids, top_idx.reshape(1)])
        query = tokenizer.decode(input_ids)
        raw_query = tokenizer.decode(input_ids[raw_query_start_idx:])

    return query

def get_interventions_dict(q_concept, layer_name="lm_head"):
    """Build the ``interventions`` mapping used with ``TraceDict`` editing.

    Args:
        q_concept: Concept vectors to attach to the layer, passed through
            unchanged.
        layer_name: Module name to intervene on; defaults to ``"lm_head"``.

    Returns:
        A one-entry dict ``{layer_name: q_concept}``.
    """
    return {layer_name: q_concept}

def get_answer_with_intervention(model, tokenizer, prompt, max_new_tokens=1024, interventions=None, intervention_fn=None):
    """Generate a completion for ``prompt`` while editing traced module outputs.

    The modules named by ``interventions`` are wrapped with baukit's
    ``TraceDict``. During ``model.generate``, each of their outputs is replaced
    by ``intervention_fn``'s return value. Generation is bounded by
    ``max_new_tokens`` and by ``StopOnTokens``.

    Args:
        model: Causal LM; the prompt ids are moved to CUDA.
        tokenizer: Tokenizer used to encode ``prompt`` and decode the output.
        prompt: Full prompt text.
        max_new_tokens: Maximum number of tokens to generate.
        interventions: Mapping from module name to the data that
            ``intervention_fn`` needs for that module. ``None`` or an empty
            mapping disables intervention.
        intervention_fn: Callable ``(output, layer_name, interventions=...)``
            returning the (possibly edited) module output. Required when
            ``interventions`` is non-empty.

    Returns:
        Decoded text of the first generated sequence, special tokens skipped.

    Raises:
        TypeError: ``interventions`` is non-empty but ``intervention_fn`` is
            not callable.
    """
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()

    # --- intervention setup --- #
    if not interventions:
        def intervene(head_output, layer_name):
            # Identity edit: leave the traced output untouched.
            return head_output
        layers_to_intervene = []
    else:
        if not callable(intervention_fn):
            raise TypeError("intervention_fn must be callable when interventions are given")
        intervene = partial(intervention_fn, interventions=interventions)
        layers_to_intervene = list(interventions.keys())

    with torch.inference_mode():
        with TraceDict(model, layers_to_intervene, edit_output=intervene):
            model_output = model.generate(inputs=input_ids,
                                          max_new_tokens=max_new_tokens,
                                          stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
                                          use_cache=True)
        outstr = tokenizer.decode(model_output[0], skip_special_tokens=True)
    # Release cached GPU blocks between calls.
    torch.cuda.empty_cache()
    return outstr

def lt_modulated_roboshot(layer_output, layer_name, interventions):
    """TraceDict ``edit_output`` hook that removes concept directions.

    Only single-position outputs (``layer_output.shape[1] == 1``, i.e. the
    incremental decoding steps) are edited. Multi-position outputs, such as
    the prompt pass, are returned unchanged.

    For each concept vector ``v`` in ``interventions[layer_name]``, every row
    ``x`` is updated in turn as ``x <- x - cos(x, v) * v``.

    NOTE(review): this scales ``v`` by the cosine similarity rather than by
    the projection coefficient ``x.v / |v|^2``. The two coincide only when
    ``|x| == |v|``. It is kept as-is to preserve the method's behavior.

    Args:
        layer_output: Tensor whose dim 1 is the sequence axis, presumably
            ``(batch, seq, features)`` — TODO confirm.
        layer_name: Key into ``interventions``.
        interventions: Mapping ``layer_name -> iterable of vectors``
            (array-likes of the feature size).

    Returns:
        The edited tensor, with the same shape, dtype and device as
        ``layer_output``.
    """
    if layer_output.shape[1] > 1:
        return layer_output
    out_dtype = layer_output.dtype
    # (batch, features): keeps a leading batch axis even for batch == 1, so the
    # per-row cosine below broadcasts correctly for any batch size.
    hidden = layer_output.reshape(layer_output.shape[0], -1)
    for vec in interventions[layer_name]:
        direction = torch.as_tensor(vec, dtype=torch.float32,
                                    device=layer_output.device).reshape(1, -1)
        cos = torch.nn.functional.cosine_similarity(hidden, direction)  # (batch,)
        hidden = hidden - cos.unsqueeze(-1) * direction
    return hidden.to(out_dtype).reshape(layer_output.shape)

def roboalign_lm_head(model, tokenizer,
                    fschat,
                    raw_query,
                    max_new_tokens,
                    harmful_embeddings,
                    helpful_embeddings=None,
                    qr=False
                    ):
    """Generate with concept directions removed from the ``lm_head`` output.

    Steps:
      1. Build concept vectors: ``harmful_embeddings``, or
         ``harmful_embeddings - helpful_embeddings`` when helpful embeddings
         are provided.
      2. If ``qr`` is True, replace them with the orthonormal rows obtained
         from an economic QR decomposition of their transpose.
      3. Decode ``fschat`` + newline + ``raw_query`` with
         ``lt_modulated_roboshot`` editing the ``lm_head`` output.

    Args:
        model: Causal LM with an ``lm_head`` submodule.
        tokenizer: Tokenizer matching ``model``.
        fschat: Few-shot / chat template prefix.
        raw_query: Query text placed after ``fschat`` and a newline.
        max_new_tokens: Maximum number of tokens to generate.
        harmful_embeddings: Array of concept vectors (``.T`` is used when
            ``qr`` is True). Presumably shaped ``(n_concepts, lm_head
            output size)`` — TODO confirm.
        helpful_embeddings: Optional array of the same shape; ``None`` or
            empty means harmful embeddings are used alone.
        qr: Whether to orthonormalize the concept vectors first.

    Returns:
        The decoded generation returned by ``get_answer_with_intervention``.
    """
    if helpful_embeddings is None or len(helpful_embeddings) == 0:
        concept = harmful_embeddings
    else:
        concept = harmful_embeddings - helpful_embeddings
    if qr:
        # Columns of Q span the concept vectors; transpose back to one per row.
        q_concept, _ = linalg.qr(concept.T, mode='economic')
        concept = q_concept.T
    interventions = get_interventions_dict(concept)
    query = fschat + '\n' + raw_query
    return get_answer_with_intervention(model, tokenizer, query, max_new_tokens, interventions, intervention_fn=lt_modulated_roboshot)
