import os

from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
from baukit import TraceDict
import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm

import re
import sys

import pandas as pd
from sentence_transformers import SentenceTransformer

from sklearn.metrics.pairwise import cosine_similarity
from scipy import linalg 


from utils.data_utils import set_seed, build_hhrlhf_dataset, load_hhrlhf_template
from torch.utils.data import DataLoader
from transformers import StoppingCriteriaList, StoppingCriteria
import json


from functools import partial
from utils.inference import vanila_inference, StopOnTokens
import argparse


def get_single_activation(model, query):
    """Return the last-token MLP output of every decoder layer for one prompt.

    Args:
        model: causal LM whose MLP sub-modules are named
            ``model.layers.{i}.mlp`` and whose config exposes
            ``num_hidden_layers``.
        query: prompt text; it is tokenized with the module-level
            ``tokenizer``.

    Returns:
        numpy array with one row per layer holding the MLP output at the
        final token position (assumes each traced output is laid out as
        (1, seq_len, hidden) -- TODO confirm for non-standard MLP modules).
    """
    mlp_names = [f"model.layers.{i}.mlp" for i in range(model.config.num_hidden_layers)]
    input_ids = tokenizer(query, return_tensors="pt").input_ids.cuda()
    with torch.no_grad():
        with TraceDict(model, mlp_names) as ret:
            model(input_ids)
        # Index batch 0 / last token explicitly instead of squeezing: a blanket
        # squeeze() also drops the sequence axis for single-token prompts,
        # which made the final [:, -1, :] lookup fail.
        last_token = [ret[name].output[0, -1, :].detach().cpu() for name in mlp_names]
    return torch.stack(last_token, dim=0).numpy()

def get_insights_emb(model, pos_insights, neg_insights):
    """Collect the activation of every positive and every negative insight.

    Each text is passed to ``get_single_activation``; positives are processed
    first, then negatives, each behind its own progress bar. Any exception
    raised while embedding a text propagates to the caller.

    Returns:
        ``(p_embed, n_embed)``: two lists, in input order, with one activation
        per text.
    """
    p_embed = [get_single_activation(model, text) for text in tqdm(pos_insights)]
    n_embed = [get_single_activation(model, text) for text in tqdm(neg_insights)]
    return p_embed, n_embed

def get_insights(data, df_insights):
    """Group cleaned positive/negative insights under the questions they contain.

    Args:
        data: table with an ``instruction`` column holding the questions.
        df_insights: table with an ``insight`` text column and a ``label``
            column; label 1 rows are treated as positive, label 0 as negative.

    Returns:
        dict keyed by question position. Each value holds ``question``,
        the de-duplicated ``pos`` / ``neg`` insight lists (first-seen order),
        and ``pos_insight_idx`` / ``neg_insight_idx``: the positions those
        insights occupy once the per-question lists are concatenated in
        question order.

    Reads the module-level ``model_name`` to pick the marker that separates
    the prompt part of an insight from its trailing instructions.
    """
    pos_samples = df_insights[df_insights['label'] == 1]['insight'].tolist()
    neg_samples = df_insights[df_insights['label'] == 0]['insight'].tolist()

    print("LEN POS NEG SAMPLES", len(pos_samples), len(neg_samples))

    lowered_name = model_name.lower()
    if "mistral" in lowered_name or "gemma" in lowered_name:
        str_split = "You are a "
    else:
        str_split = "Your answer "

    def _clean(samples):
        # Keep the text before the marker plus the final "Assistant:" answer.
        cleaned = []
        for sample in samples:
            if not isinstance(sample, str):
                # e.g. NaN from an empty CSV cell: it cannot match any
                # question, so drop it here instead of failing on .lower().
                continue
            prompt_part = sample.split(str_split)[0].strip()
            answer_part = sample.split("Assistant:")[-1].strip()
            cleaned.append(prompt_part + "\n" + "Assistant:" + answer_part)
        skipped = len(samples) - len(cleaned)
        if skipped:
            print("SKIPPED NON-TEXT INSIGHTS", skipped)
        return cleaned

    pos_samples = _clean(pos_samples)
    neg_samples = _clean(neg_samples)

    q_all = data['instruction'].tolist()
    dict_q = {i: {'question': q, 'pos': [], 'neg': []} for i, q in enumerate(q_all)}

    # Lower-case every insight once rather than once per (insight, question) pair.
    pos_lowered = [(s, s.lower()) for s in pos_samples]
    neg_lowered = [(s, s.lower()) for s in neg_samples]

    pos_offset = 0
    neg_offset = 0
    for entry in dict_q.values():
        q_lower = entry['question'].lower()
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # insight indices are identical from run to run (set() order is not).
        entry['pos'] = list(dict.fromkeys(s for s, low in pos_lowered if q_lower in low))
        entry['neg'] = list(dict.fromkeys(s for s, low in neg_lowered if q_lower in low))
        entry['pos_insight_idx'] = list(range(pos_offset, pos_offset + len(entry['pos'])))
        entry['neg_insight_idx'] = list(range(neg_offset, neg_offset + len(entry['neg'])))
        pos_offset += len(entry['pos'])
        neg_offset += len(entry['neg'])
    return dict_q

def get_interventions_dict(pos_emb, neg_emb):
    """Pick the layers to edit and return their (positive, negative) directions.

    For every layer (``n_layers`` is a module-level global) the per-sample
    rows of ``pos_emb`` and ``neg_emb`` are stacked and the first right
    singular vector of each stack is kept. Layers are then ranked by the
    summed cosine similarity of their positive direction to the positive
    directions of all layers, and the top ``n_layers_to_edit`` (also a
    module-level global) are retained.

    Returns:
        dict mapping ``"model.layers.{idx}.mlp"`` to ``(v_pos, v_neg)`` for
        the chosen layers, in ascending layer order.
    """
    directions = {}
    for layer in range(n_layers):
        pos_stack = np.vstack([emb[layer, :] for emb in pos_emb])
        neg_stack = np.vstack([emb[layer, :] for emb in neg_emb])

        vt_pos = linalg.svd(pos_stack, full_matrices=False)[2]
        vt_neg = linalg.svd(neg_stack, full_matrices=False)[2]

        directions[layer] = (vt_pos[0, :], vt_neg[0, :])

    # Rank layers by how similar their positive direction is to everyone else's.
    pos_dirs = np.vstack([directions[layer][0] for layer in range(n_layers)])
    similarity = cosine_similarity(pos_dirs, pos_dirs)
    ranking = np.argsort(np.sum(similarity, axis=0))[::-1]

    chosen_idxs = np.sort(ranking[:n_layers_to_edit])
    print("CHOSEN IDXS", chosen_idxs)

    return {f"model.layers.{layer_idx}.mlp": directions[layer_idx] for layer_idx in chosen_idxs}

def lt_modulated_proj(layer_output, layer_name, interventions):
    """Edit the last-token activation of a layer along its stored directions.

    The component of the last token along the negative direction is removed,
    then the component of the result along the positive direction is added on
    top (i.e. that component is doubled). All other positions are returned
    unchanged.

    Args:
        layer_output: tensor whose last axis is the hidden dimension; the last
            row after flattening the leading axes is the one edited (the last
            token when the batch size is 1).
        layer_name: key into ``interventions``.
        interventions: mapping from layer name to ``(v_pos, v_neg)`` direction
            vectors (array-likes of hidden size).

    Returns:
        A new tensor with the same shape, dtype and device as ``layer_output``;
        the input tensor itself is not modified.
    """
    v_pos, v_neg = interventions[layer_name]

    hidden = layer_output.shape[-1]
    # Work on a contiguous copy so the flattened view below writes through
    # and the traced module's own output tensor is left untouched.
    edited = layer_output.clone(memory_format=torch.contiguous_format)
    rows = edited.view(-1, hidden)
    x_test = rows[-1, :]

    # Follow the activation's dtype/device instead of a global model handle so
    # the hook also works for layers placed on a different device.
    v_pos = torch.as_tensor(v_pos, dtype=x_test.dtype, device=x_test.device).reshape(-1)
    v_neg = torch.as_tensor(v_neg, dtype=x_test.dtype, device=x_test.device).reshape(-1)
    v_pos = v_pos / torch.linalg.vector_norm(v_pos)
    v_neg = v_neg / torch.linalg.vector_norm(v_neg)

    # Remove the component along the negative direction.
    without_harm = x_test - torch.dot(x_test, v_neg) * v_neg
    # Add the scalar projection onto the positive direction. The scalar must
    # multiply v_pos; an element-wise product of the two vectors is not a
    # projection.
    boosted = without_harm + torch.dot(without_harm, v_pos) * v_pos

    rows[-1, :] = boosted
    return edited

def get_answer_with_intervention(model, tokenizer, prompt, max_new_tokens=1024, interventions=None, intervention_fn=None):
    """Generate a completion while editing the outputs of selected layers.

    Args:
        model: causal LM used for generation.
        tokenizer: tokenizer matching ``model``.
        prompt: full prompt text.
        max_new_tokens: generation budget passed to ``model.generate``.
        interventions: mapping from layer name to intervention data; when
            empty or ``None`` the model is run without any edit.
        intervention_fn: callable ``(layer_output, layer_name, interventions=...)``
            returning the replacement output. Required when ``interventions``
            is non-empty.

    Returns:
        The decoded first sequence returned by ``model.generate`` with special
        tokens skipped (callers in this file split it on the prompt, so it
        presumably still contains the prompt text -- confirm for
        encoder-decoder models).
    """
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded.input_ids.cuda()
    attention_mask = encoded.attention_mask.cuda()

    if not interventions:
        # Nothing to edit: trace no layers and pass outputs through untouched.
        def intervene(layer_output, layer_name):
            return layer_output

        layers_to_intervene = []
    else:
        intervene = partial(intervention_fn, interventions=interventions)
        layers_to_intervene = list(interventions.keys())

    with torch.inference_mode():
        with TraceDict(model, layers_to_intervene, edit_output=intervene):
            model_output = model.generate(
                inputs=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
            )
        outstr = tokenizer.decode(model_output[0], skip_special_tokens=True)
    torch.cuda.empty_cache()
    return outstr

def get_st_embeddings(emb_model, questions):
    """Return ``emb_model.encode(questions)`` unchanged."""
    return emb_model.encode(questions)


if __name__ == '__main__':
    # Script entry point: build per-question interventions from the insights
    # of each question's nearest neighbours, then generate a vanilla and an
    # intervened answer for every question and write both to one JSON file
    # per question.
    #
    # NOTE: helper functions in this file read names such as `tokenizer`,
    # `model_name`, `n_layers` and `n_layers_to_edit` as module-level globals,
    # so they must stay assigned at module scope here.
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, required=True)
    parser.add_argument("--model-name", type=str, required=True)
    parser.add_argument("--k", type=int, default=10)
    parser.add_argument("--n-layers", type=int, default=5)
    args = parser.parse_args()

    dataset_name = args.dataset
    top_k = args.k
    n_layers_to_edit = args.n_layers
    model_name = args.model_name

    model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True,
                                                 trust_remote_code=True,
                                                 torch_dtype=torch.float16,
                                                 device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    emb_model = SentenceTransformer("infgrad/stella_en_1.5B_v5", trust_remote_code=True).to("cuda")

    n_layers = model.config.num_hidden_layers

    template_path = 'data/hh-rlhf'
    SEED = 0
    set_seed(SEED)

    max_new_tokens = 1024
    # Every model whose name does not contain "llama" uses the mistral3 paths.
    family = 'llama31' if 'llama' in model_name.lower() else 'mistral3'
    outdir = f'{family}_{dataset_name}'
    print(model_name, outdir)

    template = load_hhrlhf_template(template_path)
    fschat = template['fschat']
    data = pd.read_csv(f'data/just_eval_bydataset/{dataset_name}.csv')

    insights_file = f'alignment/{family}_generated_data/{dataset_name}_insights.csv'
    df = pd.read_csv(insights_file)
    inference_fun = partial(vanila_inference, fschat=fschat, max_new_tokens=max_new_tokens)

    dict_q = get_insights(data, df)
    questions_all = [dict_q[idx]['question'] for idx in dict_q]
    q_st_embedding = get_st_embeddings(emb_model, questions_all)

    # Concatenate the per-question insight lists in question order; the
    # *_insight_idx entries of dict_q index into these flat arrays.
    pos_insights_all = []
    neg_insights_all = []
    for idx in dict_q:
        pos_insights_all.extend(dict_q[idx]['pos'])
        neg_insights_all.extend(dict_q[idx]['neg'])
    pos_insights_all = np.array(pos_insights_all)
    neg_insights_all = np.array(neg_insights_all)

    pos_emb_all, neg_emb_all = get_insights_emb(model, pos_insights_all, neg_insights_all)

    # torch.topk raises when k exceeds the number of candidates, so cap the
    # neighbourhood size at the number of questions.
    n_neighbors = min(top_k, len(questions_all))
    nn_dict = {}
    for q_idx in range(len(q_st_embedding)):
        similarity_scores = emb_model.similarity(q_st_embedding[q_idx, :], q_st_embedding)[0]
        _, indices = torch.topk(similarity_scores, k=n_neighbors)
        nn_dict[q_idx] = indices.detach().cpu().numpy().tolist()

    intervention_dict_by_query = {}
    n_no_interv = 0
    for q_idx in tqdm(range(len(q_st_embedding))):
        pos_insights_idxs = []
        neg_insights_idxs = []
        for neighbor in nn_dict[q_idx]:
            pos_insights_idxs.extend(dict_q[neighbor]['pos_insight_idx'])
            neg_insights_idxs.extend(dict_q[neighbor]['neg_insight_idx'])
        pos_emb_nn = [pos_emb_all[j] for j in pos_insights_idxs]
        neg_emb_nn = [neg_emb_all[j] for j in neg_insights_idxs]
        if not pos_emb_nn or not neg_emb_nn:
            # Both polarities are needed to build directions; without them the
            # question falls back to the vanilla answer below.
            intervention_dict_by_query[q_idx] = {}
            n_no_interv += 1
            continue
        intervention_dict_by_query[q_idx] = get_interventions_dict(pos_emb_nn, neg_emb_nn)
    print("QUERIES WITHOUT INTERVENTION", n_no_interv)

    exp_folder = "roboemb_mlp"
    os.makedirs(os.path.join(outdir, exp_folder), exist_ok=True)

    def extract_answer(text, prompt):
        # Keep what follows the prompt, cut at the next "Human: " / "Assistant: " turn.
        answer = text.split(prompt)[-1].split("Human: ")[0].strip()
        return answer.split("Assistant: ")[0].strip()

    for i, q in tqdm(enumerate(questions_all), total=len(questions_all)):
        print(f"########### {i} ##########")
        raw_query = f"{q}\nAssistant:"
        query = fschat + '\n' + raw_query
        print('query')
        print(raw_query)

        vanila_output = inference_fun(raw_query=raw_query, model=model, tokenizer=tokenizer)
        ans_query = extract_answer(vanila_output, query)
        print('VANILLA MODEL')
        print(ans_query)

        intervention_dict = intervention_dict_by_query[i]
        if len(intervention_dict) > 0:
            out = get_answer_with_intervention(model, tokenizer, query,
                                               max_new_tokens=max_new_tokens,
                                               interventions=intervention_dict,
                                               intervention_fn=lt_modulated_proj)
        else:
            out = ans_query
        print('OURS')
        out_ans = extract_answer(out, query)
        print(out_ans)

        tmp = {
            'question': raw_query,
            'vanila': ans_query,
            'embedding_intervention_output': out_ans,
        }

        # NOTE(review): the 'mistral@7b' tag is written for every model; it is
        # kept as-is so existing readers of these files keep working.
        with open('{}/{}/hh-rlhf_{}_res_{}.json'.format(outdir, exp_folder, 'mistral@7b', i), 'w') as f:
            json.dump(tmp, f)

