import re
import random
import string

from scipy import stats
import os
import copy
import json
from sklearn.metrics.pairwise import cosine_similarity as cossim
import pandas as pd
import time
import numpy as np
# from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from tqdm import tqdm
import gc
# from utils.utils import upload_blob_from_memory
from datetime import datetime
from pathlib import Path
from baukit import Trace, TraceDict

# Select the torch device name and the matching memory-flush routine for this host.
if os.uname().sysname == 'Darwin':
    # macOS
    device_name = 'mps'


    def flush():
        """No-op on macOS."""
        return None
else:
    from accelerate.utils import release_memory
    import gc

    device_name = 'cuda'


    def flush():
        """Run the garbage collector, then empty the CUDA cache and reset its peak-memory stats."""
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

# Questionnaire maps: answer text -> numeric score.
phq9_map = {'Not at all': 0, 'Several days': 1, 'More than half the days': 2, 'Nearly every day': 3}
# For AMI, 'Completely untrue' maps to the highest score and 'Completely true' to 0.
ami_map = {'Completely untrue': 4, 'Mostly untrue': 3, 'Neither true nor untrue': 2, 'Quite true': 1,
           'Completely true': 0}
ami_inv_map = {v: k for k, v in ami_map.items()}
sds_map = {"A little of the time": 1, "Some of the time": 2, "Good part of the time": 3, "Most of the time": 4}
# Presumably the numbers of the reverse-scored SDS items - TODO confirm against the questionnaire.
sds_rev = [2, 5, 6, 11, 12, 14, 16, 17, 18, 20]
gad7_map = {'Not at all': 0, 'Several days': 1, 'More than half the days': 2, 'Nearly every day': 3}
gad7_inv_map = {v: k for k, v in gad7_map.items()}

# 'A' .. 'Z'
letters = list(string.ascii_uppercase)

# Lookup tables keyed by questionnaire name.
maps = {'phq9': phq9_map, 'ami': ami_map, 'sds': sds_map, 'gad7': gad7_map}
rev_lists = {'sds': sds_rev}
inv_lists = {'ami': ami_inv_map, 'gad7': gad7_inv_map}

# Open-question name -> name of the related closed question.
# Defined exactly once so that qs_maps['phq9'] and phq9_qs_map are the same object
# (the file previously re-defined an identical copy further down).
phq9_qs_map = {'lvl3_q1': 'phq9_q1', 'lvl3_q2': 'phq9_q2', 'lvl3_q3': 'phq9_q3', 'lvl3_q4': 'phq9_q4',
               'lvl3_q5': 'phq9_q5', 'lvl3_q6': 'phq9_q6', 'lvl3_q7': 'phq9_q7', 'lvl3_q8': 'phq9_q8',
               'lvl1_q1': 'lvl1_closed_q1', 'lvl2_q1': 'lvl2_closed_q1', 'lvl2_q2': 'lvl2_closed_q2',
               'lvl2_q3': 'lvl2_closed_q3', 'rep_lvl2_q1': 'lvl2_closed_q1'}
qs_maps = {'phq9': phq9_qs_map}
# Inverse of the questionnaire part of phq9_qs_map only (the 'lvl*_closed_*' entries are not inverted).
phq9_qs_inv_map = {'phq9_q1': 'lvl3_q1', 'phq9_q2': 'lvl3_q2', 'phq9_q3': 'lvl3_q3', 'phq9_q4': 'lvl3_q4',
                   'phq9_q5': 'lvl3_q5', 'phq9_q6': 'lvl3_q6', 'phq9_q7': 'lvl3_q7', 'phq9_q8': 'lvl3_q8'}

qs_inv_maps = {'phq9': phq9_qs_inv_map}

# Answer text -> score for the custom closed ("lvl1"/"lvl2") questions.
lvlx_closed_map = {
    'Very Good': 0, 'Good': 1, 'Bad': 2, 'Very Bad': 3
}

# Names (and order) of the prompt locations whose hidden states are stacked.
qa_loc_key_names = ['oq_qs', 'oq_Answer', 'oq_ans', 'cq_qs', 'cq_Answer', 'last']
qa_loc_key_names_gen = ['cq_qs', 'cq_Answer', 'last']


class SteerConfig:
    """Plain settings container for activation steering; every argument is stored as a same-named attribute."""

    def __init__(self, layer_ids, steering_vectors, device, n_tokens=5, multiplier=1.0, sample_text=False,
                 run_gen=False,
                 perturb_input_only=True, norm_vecs=True):
        # What is steered, and how strongly.
        self.layer_ids, self.steering_vectors = layer_ids, steering_vectors
        self.multiplier, self.device = multiplier, device

        # Text generation options: whether to generate at all, how many new tokens, sampling on/off.
        self.run_gen, self.n_tokens, self.sample_text = run_gen, n_tokens, sample_text

        # Hook behaviour switches.
        self.perturb_input_only, self.norm_vecs = perturb_input_only, norm_vecs


class SteerHiddenState:
    """Runs a model with steering hooks attached to selected ``model.model.layers`` entries.

    Hooks are installed through baukit's ``TraceDict`` for the duration of a forward
    pass (and, optionally, a ``generate`` call).  The results of the most recent
    :meth:`steer` call are stored on the instance: ``logits``, ``hidden_states``,
    ``output_text`` and, when requested, their ``*_original`` (unsteered) counterparts.
    """

    def __init__(self, steer_config, model, tokenizer):
        """Store the configuration and resolve the modules that will be hooked.

        :param steer_config: object providing ``steering_vectors``, ``layer_ids``,
            ``perturb_input_only``, ``norm_vecs``, ``multiplier``, ``run_gen``,
            ``n_tokens`` and ``sample_text``.
        :param model: model exposing ``model.model.layers`` (indexed by layer id).
        :param tokenizer: tokenizer used by :meth:`steer` to encode and decode text.
        """
        # Indexed by module name ('model.layers.<id>') in create_hooks_dict.
        self.steering_vectors = steer_config.steering_vectors
        # Module name -> module object, one entry per configured layer id.
        self.modules_to_steer = {f'model.layers.{layer_id}': model.model.layers[layer_id] for layer_id in
                                 steer_config.layer_ids}
        self.steer_config = steer_config
        self.hooks_dict = None
        # Number of times a hook has applied a perturbation since the last reset in steer().
        self.counter = 0
        self.tokenizer = tokenizer

        # Populated by steer().
        self.input_ids = None
        self.attention_mask = None
        self.logits = None
        self.logits_original = None
        self.hidden_states = None
        self.hidden_states_original = None
        self.output_ids = None  # NOTE(review): never assigned anywhere in this class.
        self.output_text = None
        self.output_text_original = None

    def steering_hook_fn(self, steering_vector):
        """Return a hook closure that perturbs a layer output with ``steering_vector``."""
        def steering_hook(output):
            """
            Hook applied to the output of a layer; edits ``output[0]`` in place at the
            last sequence position (index -1) and returns the output tuple.

            The edit is applied while ``self.counter`` is below
            ``len(self.steering_vectors)``, or on every call when
            ``perturb_input_only`` is False.

            With ``norm_vecs`` set, the component of the steering vector along the
            current last-position activation is subtracted from the steering vector,
            and the activation is replaced by ``sign * activation + multiplier *
            remainder`` (skipped entirely when ``multiplier`` is 0).  Without
            ``norm_vecs``, ``multiplier * steering_vector`` is simply added.
            """
            hidden_state = output[0]

            # Presumably the counter limit equals one hook call per steered layer, i.e. only
            # the first forward pass (the prompt) is perturbed - TODO confirm.
            if self.counter < len(self.steering_vectors) or (not self.steer_config.perturb_input_only):
                self.counter += 1
                if self.steer_config.norm_vecs:
                    # NOTE(review): despite the name, the vector is NOT normalised here.
                    steering_vector_unit = steering_vector
                    # Norm of the last-position activation, kept as a trailing dim of size 1.
                    hidden_states_norm = hidden_state[:, -1, :].norm(dim=-1, keepdim=True)
                    # Sign of <steering_vector, activation>; torch.sign gives 0 for a zero dot product.
                    # NOTE(review): the broadcasting below looks like it assumes a single
                    # sequence per batch - confirm the expected shapes.
                    tmp_sign = torch.sign(steering_vector_unit @ hidden_state[:,-1,:].T)
                    # Projection of the steering vector onto the activation direction.
                    tmp_proj = hidden_state[:,-1,:] * (steering_vector_unit @ hidden_state[:,-1,:].T) / (hidden_states_norm ** 2)
                    # What is left of the steering vector after removing that projection.
                    perturbation_vector = steering_vector - tmp_proj

                    if self.steer_config.multiplier != 0:
                        # The activation itself is multiplied by the sign, so it is flipped
                        # when the dot product is negative.
                        hidden_state[:, -1, :] = tmp_sign*hidden_state[:, -1, :] +self.steer_config.multiplier * perturbation_vector
                else:
                    hidden_state[:, -1, :] += self.steer_config.multiplier * steering_vector

            return (hidden_state,) + output[1:]

        return steering_hook

    def create_hooks_dict(self):
        """Build ``self.hooks_dict``: module name -> hook closure bound to that module's steering vector."""
        self.hooks_dict = {
            module_name: self.steering_hook_fn(self.steering_vectors[module_name])
            for module_name in self.modules_to_steer.keys()
        }

    def steer(self, model, input, return_org=False):
        """Run ``model`` on ``input`` with steering hooks and store the results on ``self``.

        :param model: the model to run (called directly and via ``generate``).
        :param input: text passed to the tokenizer (padding is enabled).
        :param return_org: when True, also run the same input without hooks and fill
            ``logits_original`` / ``hidden_states_original`` (and
            ``output_text_original`` when ``run_gen`` is set).
        :return: None; results are read from the instance attributes.
        """
        self.create_hooks_dict()
        # NOTE(review): tensors go to the module-level device_name, not steer_config.device.
        input_ids = self.tokenizer(input, return_tensors='pt', padding=True, return_attention_mask=True).to(device_name)
        self.attention_mask = input_ids.attention_mask
        self.input_ids = input_ids.input_ids

        # Forward pass with activation steering.
        # NOTE(review): edit_output receives a dict of per-module hooks; confirm that the
        # installed baukit version accepts a mapping here.
        # NOTE(review): no torch.no_grad() in this method - presumably the caller wraps it.
        self.counter = 0
        with TraceDict(model, self.modules_to_steer, edit_output=self.hooks_dict):
            model_output = model(self.input_ids, output_hidden_states=True, attention_mask=self.attention_mask)
            # Logits at the last sequence position only.
            self.logits = model_output.logits[:, -1, :]
            self.hidden_states = model_output.hidden_states

        # Optional generation with the hooks attached.  The counter is reset so the hooks
        # perturb again from the start of generate(); with perturb_input_only the counter
        # gate in the hook then stops further edits.
        if self.steer_config.run_gen:
            self.counter = 0
            with TraceDict(model, self.modules_to_steer, edit_output=self.hooks_dict):
                model_output_gen = model.generate(self.input_ids, attention_mask=self.attention_mask,
                                                  max_new_tokens=self.steer_config.n_tokens,
                                                  do_sample=self.steer_config.sample_text)
                self.output_text = self.tokenizer.batch_decode(model_output_gen, skip_special_tokens=True)

        # Forward pass (and optional generation) without activation steering.
        if return_org:
            model_output_original = model(self.input_ids, output_hidden_states=True, attention_mask=self.attention_mask)
            self.logits_original = model_output_original.logits[:, -1, :]
            self.hidden_states_original = model_output_original.hidden_states

            # Unsteered generation with the same settings.
            if self.steer_config.run_gen:
                model_output_gen_original = model.generate(self.input_ids, attention_mask=self.attention_mask,
                                                           max_new_tokens=self.steer_config.n_tokens,
                                                           do_sample=self.steer_config.sample_text)
                self.output_text_original = self.tokenizer.batch_decode(model_output_gen_original,
                                                                        skip_special_tokens=True)


def find_subsequence_n(text, subtext, n=1):
    """Return the start index of the ``n``-th non-overlapping occurrence of ``subtext`` in ``text``.

    Occurrences are counted left to right; the search for the next one resumes right
    after the end of the previous one, so back-to-back occurrences are found
    (``find_subsequence_n("abcabc", "abc", 2) == 3``).

    :param text: string to search in.
    :param subtext: string to search for.
    :param n: 1-based index of the occurrence wanted.
    :return: index into ``text`` of the occurrence, or -1 when there are fewer than
        ``n`` occurrences (or ``n`` < 1).
    """
    position = -1
    search_from = 0
    for _ in range(n):
        position = text.find(subtext, search_from)
        if position == -1:
            break
        # Resume directly after this occurrence; always advance by at least one
        # character so an empty ``subtext`` cannot stall the search.
        search_from = position + max(len(subtext), 1)
    return position


def load_task_content(sample_config):
    """Load instructions, open questions and closed questions from the prompt files.

    Reads four UTF-8 text files located under ``sample_config.prompts_path``.
    Instruction and open-question files are split on ``'^^^'``; in every chunk the
    second line is the entry name.  Closed and questionnaire files are split on
    ``'\\n^^\\n'`` and their question lines are expected to look like ``'<id>. <text>'``.

    :param sample_config: provides ``prompts_path``, ``instr_name``, ``qs_name``,
        ``gen_fname`` and ``skip_sui``.
    :return: ``(instr_dict, open_qs_dict, closed_qs_dict)``; the last one merges the
        questionnaire questions (prefixed with ``'Problem: '``) with the custom
        closed questions and is sorted by key.
    """
    instructions_file = f"{sample_config.prompts_path}experiment/{sample_config.instr_name}.txt"
    openq_file = f"{sample_config.prompts_path}experiment/{sample_config.qs_name}_open_{sample_config.gen_fname}.txt"
    closed_qs_file = f"{sample_config.prompts_path}qs/custom/custom_{sample_config.gen_fname}.txt"
    qsn_qs_file = f"{sample_config.prompts_path}qs/{sample_config.qs_name}/{sample_config.qs_name}.txt"

    def _read(path):
        with open(path, encoding='utf-8') as handle:
            return handle.read()

    # Instructions: name on the second line, content is everything after it.
    instr_dict = {}
    for chunk in _read(instructions_file).split('^^^'):
        chunk_lines = chunk.split('\n')
        instr_dict[chunk_lines[1]] = '\n'.join(chunk_lines[2:])

    # Open questions: name on the second line, question text on the third.
    open_qs_dict = {}
    for chunk in _read(openq_file).split('^^^'):
        chunk_lines = chunk.split('\n')
        open_qs_dict[chunk_lines[1]] = chunk_lines[2]

    # Custom closed questions live in the second '^^' section, one per line.
    closed_lines = _read(closed_qs_file).split('\n^^\n')[1].split('\n')
    custom_closed = {line.split('.')[0]: line.split('. ')[1] for line in closed_lines}

    # Questionnaire questions live in the third '^^' section, one per line.
    qsn_lines = _read(qsn_qs_file).split('\n^^\n')[2].split('\n')
    if sample_config.skip_sui:
        qsn_lines.pop()  # drop the last questionnaire item

    inv_map = qs_inv_maps[sample_config.qs_name]
    qsn_qs_dict = {}
    for line in qsn_lines:
        open_name = inv_map[sample_config.qs_name + '_q' + line.split('.')[0]]
        qsn_qs_dict[open_name] = 'Problem: ' + line.split('. ')[1]

    # Custom closed questions win on key clashes; result is ordered by key.
    merged = {**qsn_qs_dict, **custom_closed}
    closed_qs_dict = dict(sorted(merged.items()))
    return instr_dict, open_qs_dict, closed_qs_dict


def setup_model(sample_config):
    """Load the tokenizer and the causal LM named in ``sample_config`` and set their pad tokens.

    :param sample_config: provides ``model_name_hf`` (passed to ``from_pretrained``),
        ``model_name`` (used for the model-specific pad-token handling) and ``revision``.
    :return: ``(model, tokenizer)`` with the model switched to eval mode.
    """
    gemma_models = ['gemma-2-2b-it', 'gemma2-9b-it']
    llama3_models = ['Llama-3.1-8B-Instruct', 'Llama-3.2-3B-Instruct']

    tokenizer = AutoTokenizer.from_pretrained(sample_config.model_name_hf, use_fast=True, trust_remote_code=True)
    is_gemma = sample_config.model_name in gemma_models
    is_llama3 = sample_config.model_name in llama3_models

    # Every model except the gemma ones pads with the EOS token.
    if (not is_gemma) or is_llama3:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    load_options = {
        'device_map': device_name,
        'trust_remote_code': True,
        'revision': sample_config.revision,
        'attn_implementation': 'eager',
        'torch_dtype': torch.bfloat16,
    }
    model = AutoModelForCausalLM.from_pretrained(sample_config.model_name_hf, **load_options)
    model.eval()

    generation_config = model.generation_config
    if not is_gemma:
        generation_config.pad_token_id = generation_config.eos_token_id
    if is_llama3:
        # For these models the tokenizer's EOS id overrides the generation config's own EOS value.
        generation_config.pad_token_id = tokenizer.eos_token_id

    return model, tokenizer


def format_messages(messages, toker, sample_config):
    """Turn a list of chat messages into the prompt format expected by the model.

    The chat template is chosen by matching ``sample_config.model_name_hf`` against a
    list of regular expressions.  Depending on the match:

    * ``'gpt'``: ``messages`` is returned unchanged (same object);
    * ``'llama-2-uncensored'``: messages are joined with ``### HUMAN:`` /
      ``### RESPONSE:`` headers;
    * no match: message contents are joined, each followed by a newline;
    * anything else: the matching ``.jinja`` file under
      ``sample_config._paths.sub_path + 'ch_temp/chat_templates/'`` is installed on
      ``toker`` (this overwrites ``toker.chat_template``) and applied with a
      generation prompt; ``<s>`` / ``</s>`` markers are then removed except where the
      template needs them.

    :param messages: list of dicts with ``'role'`` and ``'content'`` keys.
    :param toker: tokenizer providing ``apply_chat_template``; only touched when a
        template file is used.
    :param sample_config: provides ``model_name_hf`` and, for template files,
        ``_paths.sub_path``.
    :return: the formatted prompt string (or ``messages`` itself for ``'gpt'``).
    :raises KeyError: if a message has a role other than system/user/assistant in the
        two hand-written formats.
    """
    # Checked in order; a later match overrides an earlier one.
    template_patterns = (
        ('Mistral-.*-Instruct', 'mistral-instruct'),
        ('Orca', 'chatml'),
        ('.*gemma.*-it.*', 'gemma-it'),
        ('.*CausalLM.*', 'chatml'),
        ('vicuna', 'vicuna'),
        ('Llama-3.*-Instruct', 'llama-3-instruct'),
        ('Llama-2.*-chat', 'llama-2-chat'),
        ('llama2_.*_uncensored', 'llama-2-uncensored'),
        ('.*Qwen.*-Chat.*', 'chatml'),
        ('gpt-', 'gpt'),
    )
    template_name = None
    for pattern, candidate in template_patterns:
        if re.search(pattern, sample_config.model_name_hf):
            template_name = candidate

    if template_name == 'gpt':
        return messages

    if template_name == 'llama-2-uncensored':
        role_map = {'system': '### HUMAN:\n', 'user': '\n### HUMAN:\n', 'assistant': '\n### RESPONSE:\n'}
        return ''.join([role_map[msg['role']] + msg['content'] for msg in messages])

    if template_name is None:
        role_map = {'system': '\n', 'user': '\n', 'assistant': '\n'}
        return ''.join([msg['content'] + role_map[msg['role']] for msg in messages])

    template_path = sample_config._paths.sub_path + 'ch_temp/chat_templates/' + template_name + '.jinja'
    # Context manager so the file handle is closed; UTF-8 like every other file read in this module.
    with open(template_path, encoding='utf-8') as template_file:
        chat_template = template_file.read()
    # Collapse the template onto one line (drops 4-space indents and newlines).
    chat_template = chat_template.replace('    ', '').replace('\n', '')
    toker.chat_template = chat_template
    messages_formatted = toker.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    if template_name != 'mistral-instruct':
        messages_formatted = messages_formatted.replace('<s>', '')
        if template_name != 'vicuna':
            messages_formatted = messages_formatted.replace('</s>', '')

    return messages_formatted


def create_question_pairs_instr(sample_config):
    """Assemble the instruction strings and closed-question prompts for one configuration.

    Questionnaire questions and custom closed questions are read from their prompt
    files; each gets the scale instruction, a lettered answer scale (letters taken
    from ``sample_config.label_letters``) and a trailing ``'Answer: '`` appended before
    it is split into the dictionary value.

    :param sample_config: provides ``prompts_path``, ``qs_name``, ``gen_fname``,
        ``label_letters`` and ``skip_sui`` (plus whatever ``load_task_content`` reads).
    :return: tuple ``(intro_prompt, oq_instr, open_qs_dict, closed_questions_dict,
        cq_preamble, qsn_questions_dict, cq_instr, qsn_preamble)``.
    """
    instr_dict, open_qs_dict, _ = load_task_content(sample_config)
    intro_prompt = instr_dict['intro']
    scale_instr = instr_dict['scale_instr_logit']

    def _read(path):
        with open(path, encoding='utf-8') as handle:
            return handle.read()

    def _lettered_scale(raw_scale):
        # '- opt1- opt2' -> 'A - opt1B - opt2' (options keep their own line breaks).
        options = [option for option in raw_scale.split('- ') if option != '']
        return ''.join([f'{ltr} - {tp}' for tp, ltr in zip(options, sample_config.label_letters)])

    def _with_scale(question, scale_text):
        return '' + question + '\n\n' + scale_instr + "\n" + scale_text + '\n\nAnswer: '

    # --- questionnaire questions ---
    qs_prompt = _read(f"{sample_config.prompts_path}qs/{sample_config.qs_name}/{sample_config.qs_name}.txt")
    qsn_preamble = '\n' + instr_dict[sample_config.qs_name + '_preamble'] + '\n'
    prompt_sections = qs_prompt.split('\n^^\n')
    raw_qsn_questions = prompt_sections[2].split('\n')
    qsn_scale = _lettered_scale(prompt_sections[1])

    qsn_questions = [_with_scale(question, qsn_scale) for question in raw_qsn_questions]
    if sample_config.skip_sui:
        qsn_questions.pop()  # drop the last questionnaire item
    qsn_questions_dict = {}
    for question in qsn_questions:
        number = question.split('.')[0]
        qsn_questions_dict[sample_config.qs_name + '_q' + number] = 'Problem: ' + question.split('. ')[1]

    # --- custom closed questions ---
    cq_instr = '\n\n' + instr_dict['closed_instr'] + '\n'
    cq_preamble = '\n\n' + instr_dict['closed_preamble'] + '\n'

    custom_sections = _read(f"{sample_config.prompts_path}qs/custom/custom_{sample_config.gen_fname}.txt").split(
        '\n^^\n')
    raw_closed_questions = custom_sections[1].split('\n')
    custom_scale = _lettered_scale(custom_sections[0])

    closed_questions_dict = {}
    for question in raw_closed_questions:
        full_question = _with_scale(question, custom_scale)
        closed_questions_dict[full_question.split('.')[0]] = 'Statement: \n' + full_question.split('. ')[1]
    # The repeated level-2 question reuses the prompt of the first level-2 question.
    closed_questions_dict['rep_lvl2_q1'] = closed_questions_dict['lvl2_q1']

    oq_instr = instr_dict['openq_instr'] + '\n'
    return (intro_prompt, oq_instr, open_qs_dict, closed_questions_dict, cq_preamble, qsn_questions_dict,
            cq_instr, qsn_preamble)


def get_all_question_pairs(openq_data_long, sample_config, toker):
    """Build, for every subject, the chat messages pairing each open answer with its closed question.

    Each conversation is: optional system intro, the open question (user), the
    subject's recorded answer (assistant), then the related closed question (user).
    Rows whose related question is a 'lvl' question are skipped for subjects whose id
    contains ``'v1'``.

    :param openq_data_long: DataFrame with ``'sub'``, ``'q_name'`` and ``'value'`` columns.
    :param sample_config: configuration passed on to ``create_question_pairs_instr``
        and ``format_messages``; ``qs_name`` selects the question-name map.
    :param toker: tokenizer handed to ``format_messages``.
    :return: ``(question_pairs, question_pairs_formatted, question_pairs_short,
        question_pairs_short_formatted)``; the first two map subject -> question name
        -> messages / formatted prompt, the last two are always empty dicts.
    """
    question_pairs = {}
    question_pairs_formatted = {}
    question_pairs_short = {}
    question_pairs_short_formatted = {}

    (intro_prompt, oq_instr, open_qs_dict, closed_questions_dict, cq_preamble,
     qsn_questions_dict, cq_instr, qsn_preamble) = create_question_pairs_instr(sample_config)

    for sub in openq_data_long['sub'].unique():
        sub_rows = openq_data_long[openq_data_long['sub'] == sub].reset_index(drop=True)

        messages_by_question = {}
        formatted_by_question = {}
        for _, row in sub_rows.iterrows():
            q_name = row['q_name']
            q_name_rel = qs_maps[sample_config.qs_name][q_name]
            is_level_question = 'lvl' in q_name_rel
            if is_level_question and 'v1' in sub:
                continue

            openq_text = 'Question:\n' + open_qs_dict[q_name] + '\n\nAnswer:'
            openq_ans = '\n' + row['value'] + '\n\n'

            conversation = []
            if intro_prompt != '':
                conversation.append({'role': 'system', 'content': intro_prompt})
            conversation.append({'role': 'user', 'content': oq_instr + openq_text})
            conversation.append({'role': 'assistant', 'content': openq_ans})

            # Level questions use the custom closed prompt; the rest use the questionnaire prompt.
            if is_level_question:
                closed_content = cq_preamble + closed_questions_dict[q_name]
            else:
                closed_content = cq_instr + qsn_preamble + qsn_questions_dict[q_name_rel]
            conversation.append({'role': 'user', 'content': closed_content})

            messages_by_question[q_name] = conversation
            formatted_by_question[q_name] = format_messages(conversation, toker, sample_config)

        question_pairs[sub] = messages_by_question
        question_pairs_formatted[sub] = formatted_by_question

    return question_pairs, question_pairs_formatted, question_pairs_short, question_pairs_short_formatted


def find_loc(input_text, to_find_text, tokenizer, n=1):
    """Locate the tokens covering the ``n``-th occurrence of ``to_find_text`` in ``input_text``.

    The text is tokenized with offset mapping, and the tokens whose character span
    lies entirely inside the occurrence are selected.

    :param input_text: full prompt text.
    :param to_find_text: substring to locate.
    :param tokenizer: tokenizer supporting ``return_offsets_mapping=True`` and ``decode``.
    :param n: which occurrence of ``to_find_text`` to use (1-based).
    :return: ``(loc_start, loc_end, last_token_loc, last_token)``: index of the first
        matching token, index one past the last matching token, index of the last
        matching token, and that token decoded.  All three indices are -1 when the
        text is not found or when no token lies entirely inside its span; in that
        case ``last_token`` is the decoded final token of the input (index -1).
    """
    encoded = tokenizer(input_text, return_tensors="pt", padding=True,
                        return_offsets_mapping=True)
    loc_start = -1
    loc_end = -1
    last_token_loc = -1

    char_start = find_subsequence_n(input_text, to_find_text, n)
    if char_start != -1:
        char_end = char_start + len(to_find_text)
        token_locations = np.where(
            (encoded.offset_mapping[0, :, 0] >= char_start) & (
                    encoded.offset_mapping[0, :, 1] <= char_end))[0]
        # A token that only partially overlaps the span is not selected, so the result
        # can be empty even though the text was found; treat that as "not found"
        # instead of failing on token_locations[0].
        if token_locations.size > 0:
            loc_start = int(token_locations[0])
            last_token_loc = int(token_locations[-1])
            loc_end = last_token_loc + 1

    last_token = tokenizer.decode(encoded.input_ids[0][last_token_loc])

    return loc_start, loc_end, last_token_loc, last_token


def get_sub_locs(openq_data_long, question_pairs_formatted, tokenizer, sample_config):
    """Find, for each prompt of subject ``sample_config.subj``, the last-token index of five landmarks.

    The landmarks are the open question (``'oq_qs'``), the first ``'Answer'``
    (``'oq_Answer'``), the subject's answer (``'oq_ans'``), the closed question
    (``'cq_qs'``) and the second ``'Answer'`` (``'cq_Answer'``).  Trailing
    non-alphanumeric characters are stripped from the three texts before searching.
    A message is printed for every landmark that cannot be located.

    :param openq_data_long: DataFrame with ``'sub'``, ``'q_name'`` and ``'value'`` columns.
    :param question_pairs_formatted: subject -> question name -> formatted prompt text.
    :param tokenizer: tokenizer handed to ``find_loc``.
    :param sample_config: provides ``subj`` plus whatever ``load_task_content`` reads.
    :return: ``(texts_to_find, last_token_locs)``, both question name -> landmark ->
        searched text / last-token index (-1 when not found).
    """
    _, open_qs_dict, closed_qs_dict = load_task_content(sample_config)
    trailing_non_alnum = '[^a-zA-Z0-9]+$'

    texts_to_find = {}
    last_token_locs = {}
    sub_question_data = question_pairs_formatted[sample_config.subj]
    for q_name, sub_text in sub_question_data.items():
        oq_qs = re.sub(trailing_non_alnum, '', open_qs_dict[q_name])
        cq_qs = re.sub(trailing_non_alnum, '', closed_qs_dict[q_name])
        answer_rows = openq_data_long[
            (openq_data_long['sub'] == sample_config.subj) & (openq_data_long['q_name'] == q_name)]
        oq_ans = re.sub(trailing_non_alnum, '', answer_rows['value'].values[0])

        # landmark -> (text to search for, which occurrence of it)
        landmarks = {
            'oq_qs': (oq_qs, 1),
            'oq_Answer': ('Answer', 1),
            'oq_ans': (oq_ans, 1),
            'cq_qs': (cq_qs, 1),
            'cq_Answer': ('Answer', 2),
        }
        texts_to_find[q_name] = {landmark: text for landmark, (text, _) in landmarks.items()}

        found = {}
        for landmark, (text, occurrence) in landmarks.items():
            _, _, last_loc, _ = find_loc(sub_text, text, tokenizer, n=occurrence)
            if last_loc == -1:
                print(f"{sample_config.subj}, {q_name} loc_t_{landmark} not found")
            found[landmark] = last_loc
        last_token_locs[q_name] = found

    return texts_to_find, last_token_locs


def get_model_response_batch(inputs, model, tokenizer, sample_config, last_token_locs=None):
    """Sample continuations for a batch of prompts and optionally collect prompt hidden states.

    :param inputs: prompt text(s) passed to the tokenizer (padding enabled).
    :param model: model whose ``generate`` is called with ``do_sample=True``.
    :param tokenizer: tokenizer used for encoding and decoding.
    :param sample_config: provides ``save_states``, ``max_new_tok`` and, when hidden
        states are collected, ``layer_list`` and ``gen_qs_name``.
    :param last_token_locs: optional mapping landmark name -> token index.  Used only
        when ``sample_config.save_states`` is set.
    :return: ``(inputs_ids, outputs_ids, outputs_ids_clean, outputs_texts,
        outputs_texts_clean, hidden_states)``.  ``outputs_ids`` holds, per sequence,
        the generated ids after the prompt with the final id dropped;
        ``hidden_states`` is ``None`` unless states were collected, otherwise a CPU
        tensor stacked over layers and then landmarks, taken from the first sequence
        of the batch at the first generation step.
    """
    inputs_ids = tokenizer(inputs, return_tensors="pt", padding=True).input_ids.to(device_name)
    hidden_states = None
    keep_states = sample_config.save_states

    generate_kwargs = {'do_sample': True, 'max_new_tokens': sample_config.max_new_tok}
    if keep_states:
        generate_kwargs.update(return_dict_in_generate=True, output_hidden_states=True,
                               output_scores=False, output_attentions=False)

    with torch.no_grad():
        outputs = model.generate(inputs_ids, **generate_kwargs)

        if keep_states and last_token_locs is not None:
            # Hidden states of the first generation step, stacked over layers; [0] keeps
            # the first sequence of the batch only.
            first_step = torch.stack(outputs.hidden_states[0], dim=1)[0]
            per_layer = {}
            for layer in sample_config.layer_list:
                picked = {name: first_step[layer, position, :] for name, position in last_token_locs.items()}
                picked['last'] = first_step[layer, -1, :]
                per_layer['l_' + str(layer)] = picked

            # Which landmarks (and in what order) end up in the stacked tensor.
            if sample_config.gen_qs_name is None:
                key_names = qa_loc_key_names
            else:
                key_names = qa_loc_key_names_gen
            hidden_states = torch.stack(
                [torch.stack([per_layer['l_' + str(layer)][name] for name in key_names])
                 for layer in sample_config.layer_list]).detach().to('cpu')

    # With return_dict_in_generate the token ids live in .sequences.
    sequences = outputs.sequences if keep_states else outputs
    # Strip the prompt and drop the final generated id of each sequence.
    outputs_ids = [generated[prompt.shape[0]:-1] for prompt, generated in zip(inputs_ids, sequences)]

    outputs_texts = tokenizer.batch_decode(outputs_ids, skip_special_tokens=False)
    outputs_texts_clean = tokenizer.batch_decode(outputs_ids, skip_special_tokens=True)

    special_ids = tokenizer.all_special_ids
    outputs_ids_clean = []
    for generated in outputs_ids:
        kept = [token_id for token_id in generated if token_id not in special_ids]
        outputs_ids_clean.append(torch.tensor(kept).to(device_name))

    return inputs_ids, outputs_ids, outputs_ids_clean, outputs_texts, outputs_texts_clean, hidden_states


def process_model_response_batch(model_outputs, sample_config):
    """Match each raw model answer against the allowed answer options.

    For every output string, the leading lines are scanned one by one. Each
    line is stripped of surrounding whitespace and of every character that is
    not an ASCII letter or a space, then compared (case-insensitively, by
    substring) with the keys of the active answer map. Scanning stops at the
    first line that is empty after cleaning, or right after the first line
    that produced at least one match. When several options match the same
    line, the longest option text is kept.

    Args:
        model_outputs: iterable of decoded model answers (strings).
        sample_config: object providing ``which_q`` and, for level-3
            questions, ``qs_name`` (a key of the module-level ``maps``).

    Returns:
        A tuple of three parallel lists:
        - the matched option text, or the untouched raw output if no match;
        - the score of the matched option, or -1 if no match;
        - True where an option was matched, False otherwise.
    """
    # 'rep' questions count as level 2; otherwise the level is the last
    # character of the text before the first underscore (e.g. 'lvl3_q1' -> 3).
    which_q = sample_config.which_q
    q_level = 2 if 'rep' in which_q else int(which_q.split('_')[0][-1])

    processed = []
    scores = []
    matched_flags = []

    for raw_output in model_outputs:
        # Levels below 3 share the closed-answer map; level 3 uses the
        # questionnaire-specific one.
        options = lvlx_closed_map if q_level < 3 else maps[sample_config.qs_name]

        candidates = []
        for raw_line in raw_output.split('\n'):
            cleaned = re.sub(r'(^\s+|\s+$)', '', raw_line)  # remove extra spaces at the beginning/end
            cleaned = re.sub(r'[^A-Za-z ]+', '', cleaned)  # remove non-letters (and non-spaces)

            # Give up on a blank line, or stop once a previous line matched.
            if cleaned == '' or len(candidates) != 0:
                break

            lowered = cleaned.lower()
            candidates = [key for key in options.keys() if key.lower() in lowered]

        if len(candidates) == 0:
            # Nothing recognised: keep the raw text and flag the failure.
            processed.append(raw_output)
            scores.append(-1)
            matched_flags.append(False)
            continue

        # Prefer the longest matching option text.
        longest_pos = np.argsort([len(candidate) for candidate in candidates])[-1]
        answer = candidates[longest_pos]
        processed.append(answer)
        scores.append(options[answer])
        matched_flags.append(True)

    return processed, scores, matched_flags


def _write_sample_response_csv(sample_config, sample_id, score, response, status, batch_time, subdir=''):
    """Write a one-row CSV describing a single sampled response.

    The file is written to
    ``sample_config.responses_path + subdir + 'response_model_s-ts-' + sample_id + '.csv'``.

    Args:
        sample_config: object providing ``subj``, ``which_q``, ``model_name``,
            ``temp``, ``top_p``, ``instr_name``, ``qs_name``, ``remBatchSize``
            and ``responses_path``.
        sample_id: identifier stored in the ``sample_ts`` column and used in
            the file name.
        score: score assigned to the response (-1 for unmatched responses).
        response: processed response text.
        status: whether the response was matched to an answer option.
        batch_time: generation time of the whole batch, in seconds.
        subdir: path fragment inserted before the file name (e.g. ``'fails/'``).
    """
    row = {'sub': sample_config.subj, 'sample_ts': sample_id, 'question': sample_config.which_q,
           'score': score, 'response': response, 'model': sample_config.model_name,
           'temp': sample_config.temp, 'top_p': sample_config.top_p, 'status': status,
           'time': batch_time, 'instr_name': sample_config.instr_name, 'qs': sample_config.qs_name,
           'nSamples': sample_config.remBatchSize}
    pd.DataFrame(row, index=[0]).to_csv(
        sample_config.responses_path + subdir + 'response_model_s-ts-' + sample_id + '.csv', index=False)


def sample_model_responses_batch(sub_qs_pairs, sub_q_token_locs, model, tokenizer, sample_ts, sample_config):
    """Sample one batch of model answers, score them and persist the results.

    Steps:
      1. Generate a batch of answers with ``get_model_response_batch`` and time it.
      2. Score the cleaned answers with ``process_model_response_batch``.
      3. Write one CSV per unmatched answer under ``responses_path + 'fails/'``.
      4. Write one CSV per matched answer directly under ``responses_path``.
      5. If ``sample_config.save_states`` is set and at least one answer
         matched, save the batch hidden states to ``states_path``.
      6. Drop the local hidden-state reference and call the module-level ``flush()``.

    Args:
        sub_qs_pairs: prompts for the batch, forwarded to ``get_model_response_batch``.
        sub_q_token_locs: token locations, forwarded to ``get_model_response_batch``.
        model: generation model, forwarded unchanged.
        tokenizer: tokenizer, forwarded unchanged.
        sample_ts: string prefix of the per-sample identifier; the sample index
            is appended to it.
        sample_config: sampling configuration (see ``_write_sample_response_csv``
            for the record fields; ``save_states``, ``states_path`` and
            ``model_name_rhp`` are read for the hidden-state dump).

    Returns:
        None. All results are written to disk.

    NOTE(review): failed samples are indexed by their position in the batch,
    while successful samples are renumbered from 0 after failures are removed.
    The hidden-state tensor is saved as returned by the generator, without any
    filtering in this function, so when a batch contains failures the index of
    a success CSV may not correspond to that sample's position in the saved
    tensor -- confirm against the downstream reader before changing either.
    """
    start = time.time()
    _, _, _, _, batch_model_outputs_clean, model_hidden_states = get_model_response_batch(
        sub_qs_pairs, model, tokenizer, sample_config, sub_q_token_locs)
    batch_time = round(time.time() - start, 4)

    # process responses
    batch_proc_outputs, batch_scores, batch_bools = process_model_response_batch(batch_model_outputs_clean,
                                                                                 sample_config)
    results = list(zip(batch_bools, batch_scores, batch_proc_outputs))

    # First pass: record every failed sample under its original batch index.
    for s, (matched, score, response) in enumerate(results):
        if not matched:
            _write_sample_response_csv(sample_config, sample_ts + str(s), score, response, matched,
                                       batch_time, subdir='fails/')

    # Second pass: record the successful samples, numbered consecutively from 0.
    successes = [result for result in results if result[0]]
    for s, (matched, score, response) in enumerate(successes):
        _write_sample_response_csv(sample_config, sample_ts + str(s), score, response, matched, batch_time)

    # One tensor file per batch (fixed '0' suffix), only if something succeeded.
    if sample_config.save_states and successes:
        tensor_fname = f"{sample_config.states_path}{sample_config.subj}^^{sample_config.which_q}^^{sample_config.model_name_rhp}_hidden_states_s-ts-{sample_ts}0.pt"
        # clone() so the saved file holds its own copy of the data.
        torch.save(model_hidden_states.clone(), tensor_fname)

    del model_hidden_states
    flush()
