import re
import string

# from networkx.classes.filters import hide_edges
from scipy import stats
import os
import copy
import json
import pandas as pd
import time
import numpy as np
# from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from tqdm import tqdm
# import gc

from utils.utils import upload_blob_from_memory
from datetime import datetime
from pathlib import Path

import platform

# Pick the torch device and a matching memory-cleanup helper for this host.
# platform.system() is available on every OS, unlike os.uname(), which is missing on Windows.
if platform.system() != 'Darwin':  # if not on mac
    # NOTE(review): release_memory is not called in the visible code; the import is kept in
    # case other parts of the module rely on it being available here.
    from accelerate.utils import release_memory
    import gc

    device_name = 'cuda'


    def flush():
        """Run the garbage collector and, when CUDA is available, free the cached CUDA memory
        and reset the peak-memory counters."""
        gc.collect()
        # Without a usable GPU, reset_peak_memory_stats() would raise instead of doing nothing.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
else:
    device_name = 'mps'


    def flush():
        """No-op on macOS: there is no CUDA cache to release."""
        pass

# Questionnaire response maps: answer text -> item score.
# PHQ-9 and GAD-7 share the same four-point frequency scale (0-3).
phq9_map = {'Not at all': 0, 'Several days': 1, 'More than half the days': 2, 'Nearly every day': 3}
# AMI scores run opposite to agreement: 'Completely untrue' -> 4 ... 'Completely true' -> 0.
ami_map = {'Completely untrue': 4, 'Mostly untrue': 3, 'Neither true nor untrue': 2, 'Quite true': 1,
           'Completely true': 0}
ami_inv_map = dict(zip(ami_map.values(), ami_map.keys()))  # score -> answer text
# SDS frequency scale (1-4); sds_rev holds item numbers registered in rev_lists
# (presumably the reverse-scored items).
sds_map = {"A little of the time": 1, "Some of the time": 2, "Good part of the time": 3, "Most of the time": 4}
sds_rev = [2, 5, 6, 11, 12, 14, 16, 17, 18, 20]
gad7_map = dict(phq9_map)  # independent copy of the shared frequency scale
gad7_inv_map = dict(zip(gad7_map.values(), gad7_map.keys()))  # score -> answer text

# IDS-30 items: item key -> {response option text: item score (0-3)}.
# Options are stored in the order they are shown to the model. Insertion order matters:
# ids30_map_option and ids30_scales_let label the options O1, O2, ... by position.
ids30_map = {
    'ids30_q1': {
        "I never take longer than 30 minutes to fall asleep.": 0,
        "I take at least 30 minutes to fall asleep, less than half the time.": 1,
        "I take at least 30 minutes to fall asleep, more than half the time.": 2,
        "I take more than 60 minutes to fall asleep, more than half the time.": 3
    },
    'ids30_q2': {
        "I do not wake up at night.": 0,
        "I have a restless, light sleep with a few brief awakenings each night.": 1,
        "I wake up at least once a night, but I go back to sleep easily.": 2,
        "I awaken more than once a night and stay awake for 20 minutes or more, more than half the time.": 3
    },
    'ids30_q3': {
        "Most of the time, I awaken no more than 30 minutes before I need to get up.": 0,
        "More than half the time, I awaken more than 30 minutes before I need to get up.": 1,
        "I almost always awaken at least one hour or so before I need to, but I go back to sleep eventually.": 2,
        "I awaken at least one hour before I need to, and can't go back to sleep.": 3
    },
    'ids30_q4': {
        "I sleep no longer than 7-8 hours/night, without napping during the day.": 0,
        "I sleep no longer than 10 hours in a 24-hour period including naps.": 1,
        "I sleep no longer than 12 hours in a 24-hour period including naps.": 2,
        "I sleep longer than 12 hours in a 24-hour period including naps.": 3
    },
    'ids30_q5': {
        "I do not feel sad.": 0,
        "I feel sad less than half the time.": 1,
        "I feel sad more than half the time.": 2,
        "I feel sad nearly all of the time.": 3
    },
    'ids30_q6': {
        "I do not feel irritable.": 0,
        "I feel irritable less than half the time.": 1,
        "I feel irritable more than half the time.": 2,
        "I feel extremely irritable nearly all of the time.": 3
    },
    'ids30_q7': {
        "I do not feel anxious or tense.": 0,
        "I feel anxious (tense) less than half the time.": 1,
        "I feel anxious (tense) more than half the time.": 2,
        "I feel extremely anxious (tense) nearly all of the time.": 3
    },
    'ids30_q8': {
        "My mood brightens to a normal level which lasts for several hours when good events occur.": 0,
        "My mood brightens but I do not feel like my normal self when good events occur.": 1,
        "My mood brightens only somewhat to a rather limited range of desired events.": 2,
        "My mood does not brighten at all, even when very good or desired events occur in my life.": 3
    },
    'ids30_q9': {
        "There is no regular relationship between my mood and the time of day.": 0,
        "My mood often relates to the time of day because of environmental events (e.g., being alone, working).": 1,
        "In general, my mood is more related to the time of day than to environmental events.": 2,
        "My mood is clearly and predictably better or worse at a particular time each day.": 3
    },
    'ids30_q10': {
        "The mood (internal feelings) that I experience is very much a normal mood.": 0,
        "My mood is sad, but this sadness is pretty much like the sad mood I would feel if someone close to me died or left.": 1,
        "My mood is sad, but this sadness has a rather different quality to it than the sadness I would feel if someone close to me died or left.": 2,
        "My mood is sad, but this sadness is different from the type of sadness associated with grief or loss.": 3
    },
    # Items 11 (appetite) and 12 (weight) list decrease and increase options separately,
    # so each non-zero score appears twice and these items have seven options (O1-O7).
    'ids30_q11': {
        "There is no change in my usual appetite.": 0,
        "I eat somewhat less often or lesser amounts of food than usual.": 1,
        "I feel a need to eat more frequently than usual.": 1,
        "I eat much less than usual and only with personal effort.": 2,
        "I regularly eat more often and/or greater amounts of food than usual.": 2,
        "I rarely eat within a 24-hour period, and only with extreme personal effort or when others persuade me to eat.": 3,
        "I feel driven to overeat both at mealtime and between meals.": 3
    },
    'ids30_q12': {
        "I have not had a change in my weight.": 0,
        "I feel as if I've had a slight weight loss.": 1,
        "I feel as if I've had a slight weight gain.": 1,
        "I have lost 2 pounds or more.": 2,
        "I have gained 2 pounds or more.": 2,
        "I have lost 5 pounds or more.": 3,
        "I have gained 5 pounds or more.": 3
    },
    'ids30_q13': {
        "There is no change in my usual capacity to concentrate or make decisions.": 0,
        "I occasionally feel indecisive or find that my attention wanders.": 1,
        "Most of the time, I struggle to focus my attention or to make decisions.": 2,
        "I cannot concentrate well enough to read or cannot make even minor decisions.": 3
    },
    'ids30_q14': {
        "I see myself as equally worthwhile and deserving as other people.": 0,
        "I am more self-blaming than usual.": 1,
        "I largely believe that I cause problems for others.": 2,
        "I think almost constantly about major and minor defects in myself.": 3
    },
    'ids30_q15': {
        "I have an optimistic view of my future.": 0,
        "I am occasionally pessimistic about my future, but for the most part I believe things will get better.": 1,
        "I'm pretty certain that my immediate future (1-2 months) does not hold much promise of good things for me.": 2,
        "I see no hope of anything good happening to me anytime in the future.": 3
    },
    'ids30_q16': {
        "I do not think of suicide or death.": 0,
        "I feel that life is empty or wonder if it's worth living.": 1,
        "I think of suicide or death several times a week for several minutes.": 2,
        "I think of suicide or death several times a day in some detail, or I have made specific plans for suicide or have actually tried to take my life.": 3
    },
    'ids30_q17': {
        "There is no change from usual in how interested I am in other people or activities.": 0,
        "I notice that I am less interested in people or activities.": 1,
        "I find I have interest in only one or two of my formerly pursued activities.": 2,
        "I have virtually no interest in formerly pursued activities.": 3
    },
    'ids30_q18': {
        "There is no change in my usual level of energy.": 0,
        "I get tired more easily than usual.": 1,
        "I have to make a big effort to start or finish my usual daily activities (for example, shopping, homework, cooking or going to work).": 2,
        "I really cannot carry out most of my usual daily activities because I just don't have the energy.": 3
    },
    'ids30_q19': {
        "I enjoy pleasurable activities just as much as usual.": 0,
        "I do not feel my usual sense of enjoyment from pleasurable activities.": 1,
        "I rarely get a feeling of pleasure from any activity.": 2,
        "I am unable to get any pleasure or enjoyment from anything.": 3
    },
    'ids30_q20': {
        "I'm just as interested in sex as usual.": 0,
        "My interest in sex is somewhat less than usual or I do not get the same pleasure from sex as I used to.": 1,
        "I have little desire for or rarely derive pleasure from sex.": 2,
        "I have absolutely no interest in or derive no pleasure from sex.": 3
    },
    'ids30_q21': {
        "I think, speak, and move at my usual rate of speed.": 0,
        "I find that my thinking is slowed down or my voice sounds dull or flat.": 1,
        "It takes me several seconds to respond to most questions and I'm sure my thinking is slowed.": 2,
        "I am often unable to respond to questions without extreme effort.": 3
    },
    'ids30_q22': {
        "I do not feel restless.": 0,
        "I'm often fidgety, wring my hands, or need to shift how I am sitting.": 1,
        "I have impulses to move about and am quite restless.": 2,
        "At times, I am unable to stay seated and need to pace around.": 3
    },
    'ids30_q23': {
        "I don't have any feeling of heaviness in my arms or legs and don't have any aches or pains.": 0,
        "Sometimes I get headaches or pains in my stomach, back or joints but these pains are only sometimes present and they don't stop me from doing what I need to do.": 1,
        "I have these sorts of pains most of the time.": 2,
        "These pains are so bad they force me to stop what I am doing.": 3
    },
    'ids30_q24': {
        "I don't have any of these symptoms: heart pounding fast, blurred vision, sweating, hot and cold flashes, chest pain, heart turning over in my chest, ringing in my ears, or shaking.": 0,
        "I have some of these symptoms but they are mild and are present only sometimes.": 1,
        "I have several of these symptoms and they bother me quite a bit.": 2,
        "I have several of these symptoms and when they occur I have to stop doing whatever I am doing.": 3
    },
    'ids30_q25': {
        "I have no spells of panic or specific fears (phobia) (such as animals or heights).": 0,
        "I have mild panic episodes or fears that do not usually change my behavior or stop me from functioning.": 1,
        "I have significant panic episodes or fears that force me to change my behavior but do not stop me functioning.": 2,
        "I have panic episodes at least once a week or severe fears that stop me from carrying on my daily activities.": 3
    },
    'ids30_q26': {
        "There is no change in my usual bowel habits.": 0,
        "I have intermittent constipation or diarrhea which is mild.": 1,
        "I have diarrhea or constipation most of the time but it does not interfere with my day-to-day functioning.": 2,
        "I have constipation or diarrhea for which I take medicine or which interferes with my day-to-day activities.": 3
    },
    'ids30_q27': {
        "I have not felt easily rejected, slighted, criticized or hurt by others at all.": 0,
        "I have occasionally felt rejected, slighted, criticized or hurt by others.": 1,
        "I have often felt rejected, slighted, criticized or hurt by others, but these feelings have had only slight effects on my relationships or work.": 2,
        "I have often felt rejected, slighted, criticized or hurt by others and these feelings have impaired my relationships and work.": 3
    },
    'ids30_q28': {
        "I have not experienced the physical sensation of feeling weighted down and without physical energy.": 0,
        "I have occasionally experienced periods of feeling physically weighted down and without physical energy, but without a negative effect on work, school, or activity level.": 1,
        "I feel physically weighted down (without physical energy) more than half the time.": 2,
        "I feel physically weighted down (without physical energy) most of the time, several hours per day, several days per week.": 3
    }
}

# IDS-30 scores keyed by positional option label: item -> {'O1': score, 'O2': score, ...}.
ids30_map_option = {
    item: {f'O{idx}': score for idx, score in enumerate(options.values(), start=1)}
    for item, options in ids30_map.items()
}

letters = list(string.ascii_uppercase)

# Bullet-list renderings of each item's options: plain ('- text') and labelled ('- O1: text').
ids30_scales = {item: ''.join(f'- {text}\n' for text in options) for item, options in ids30_map.items()}
ids30_scales_let = {
    item: ''.join(f'- O{idx}: {text}\n' for idx, text in enumerate(options, start=1))
    for item, options in ids30_map.items()
}

# Registries keyed by questionnaire name.
maps = {'phq9': phq9_map, 'ami': ami_map, 'sds': sds_map, 'gad7': gad7_map, 'ids30': ids30_map_option}
rev_lists = {'sds': sds_rev}
inv_lists = {'ami': ami_inv_map, 'gad7': gad7_inv_map}

# Task question id -> questionnaire / closed-question id it corresponds to.
# Level-3 questions lvl3_q1..lvl3_q8 map one-to-one onto PHQ-9 items 1-8.
phq9_qs_map = {f'lvl3_q{i}': f'phq9_q{i}' for i in range(1, 9)}
phq9_qs_map.update({
    'lvl1_q1': 'lvl1_closed_q1',
    'lvl2_q1': 'lvl2_closed_q1',
    'lvl2_q2': 'lvl2_closed_q2',
    'lvl2_q3': 'lvl2_closed_q3',
    'rep_lvl2_q1': 'lvl2_closed_q1',
})
qs_maps = {'phq9': phq9_qs_map}
# Inverse of the level-3 entries only: 'phq9_qN' -> 'lvl3_qN'.
phq9_qs_inv_map = {phq: task for task, phq in phq9_qs_map.items() if task.startswith('lvl3_')}

phq9_lvl_qs = {
    'lvl1': ['lvl1_q1'],
    'lvl2': ['lvl2_q1', 'lvl2_q2', 'lvl2_q3', 'rep_lvl2_q1'],
    'lvl3': [f'lvl3_q{i}' for i in range(1, 9)],
}

qs_inv_maps = {'phq9': phq9_qs_inv_map}

lvlx_closed_map = {'Very Good': 0, 'Good': 1, 'Bad': 2, 'Very Bad': 3}

qa_loc_key_names = ['oq_qs', 'oq_Answer', 'oq_ans', 'cq_qs', 'cq_Answer', 'last']
qa_loc_key_names_gen = qa_loc_key_names[3:]  # closed-question keys only: cq_qs, cq_Answer, last

# Specificity level per question for 'v3_s' subjects; level-3 questions alternate hb/hmb
# starting with hb on odd-numbered items.
spec_levels = {'v3_s': {
    **{f'lvl3_q{i}': ('hb' if i % 2 else 'hmb') for i in range(1, 9)},
    'lvl1_q1': 'hb',
    'lvl2_q1': 'hmb',
    'lvl2_q2': 'hb',
    'lvl2_q3': 'hmb',
    'rep_lvl2_q1': 'hmb',
}}


def _collect_gen_pt_files(states_dir):
    """Return one ``.pt`` file path per ``<sub>/<q>/<gq>`` directory below ``states_dir``.

    For each ``<gq>`` directory the first spec sub-directory and, inside it, the first file whose
    name contains ``.pt`` are taken, in ``os.listdir`` order (which is arbitrary).
    ``.DS_Store`` entries are ignored. Raises IndexError if a ``<gq>`` directory has no spec
    directory or a spec directory has no ``.pt`` file.
    """
    def _entries(path):
        return [d for d in os.listdir(path) if '.DS_Store' not in d]

    pt_files = []
    for sub in _entries(states_dir):
        sub_path = f"{states_dir}/{sub}"
        for q in _entries(sub_path):
            sub_q_path = f"{sub_path}/{q}/"
            for gq in _entries(sub_q_path):
                sub_gq_path = f"{sub_q_path}/{gq}/"
                spec = _entries(sub_gq_path)[0]
                pt_file = [f for f in os.listdir(f"{sub_gq_path}/{spec}") if '.pt' in f][0]
                pt_files.append(f"{sub_gq_path}{spec}/{pt_file}")
    return pt_files


def zip_and_bucket_gen_hs(sample_config, paths, prefix):
    """Zip the saved hidden-state ``.pt`` files of the generic questionnaire and upload the archive.

    Collects one ``.pt`` file per ``subjects/<sub>/<q>/<gq>`` directory under
    ``{paths.files_dir}states/{sample_config.gen_qs_name}/`` and zips them flat (``zip -j``) into
    ``{paths.sub_path}states_{prefix}/``. If that zip fails, the whole questionnaire states
    directory is zipped recursively into the same archive instead. The archive is then uploaded to
    the ``llm-bucket-res`` bucket under its path minus the first two ``/``-separated components.

    Best-effort: zip and upload failures are printed, not raised.

    Args:
        sample_config: must provide ``gen_qs_name``.
        paths: must provide ``files_dir`` and ``sub_path`` (used as string prefixes, so they
            are expected to end with '/').
        prefix: tag used in the output directory name and the archive name.
    """
    import subprocess

    states_root = f"{paths.files_dir}states/{sample_config.gen_qs_name}/"
    store_pt_files = _collect_gen_pt_files(f"{states_root}subjects/")

    sup_dir = f"{paths.sub_path}states_{prefix}/"
    Path(sup_dir).mkdir(parents=True, exist_ok=True)
    zip_ts = str(round(datetime.timestamp(datetime.now()) * 10000))
    zip_fname = sup_dir + sample_config.gen_qs_name + '_' + prefix + '^sub_hidden_states_' + zip_ts + '.zip'

    # Argument lists (no shell): paths with spaces or shell metacharacters are passed verbatim.
    try:
        zip_ok = subprocess.run(['zip', '-jv', zip_fname, *store_pt_files]).returncode == 0
    except OSError as err:  # e.g. zip not installed or argument list too long
        print('zip problem', err)
        zip_ok = False

    try:
        if not zip_ok:
            # Backup: archive the whole states directory recursively.
            print('zip problem resolve')
            subprocess.run(['zip', '-rv', zip_fname, states_root])
        bucket_fname = '/'.join(zip_fname.split('/')[2:])
        upload_blob_from_memory('llm-bucket-res', zip_fname, bucket_fname)
    except Exception as err:  # best-effort upload: report and carry on
        print('zip problem 2', err)


def find_subsequence_n(text, subtext, n=1):
    """Return the start index of the ``n``-th non-overlapping occurrence of ``subtext`` in ``text``.

    Occurrences are counted left to right, each search resuming immediately after the previous
    match (so back-to-back occurrences are all counted). ``n == -1`` selects the last such
    occurrence. Returns -1 when there are fewer than ``n`` occurrences, when there is none
    (``n == -1``), or when ``n`` is 0 or below -1.

    >>> find_subsequence_n("abab", "ab", 2)
    2
    >>> find_subsequence_n("abab", "ab", -1)
    2
    """
    if n == 0 or n < -1:
        return -1
    # Advance at least one character so an empty subtext cannot loop forever.
    step = max(len(subtext), 1)
    id_loc = -1
    count = 0
    start = 0
    while True:
        location = text.find(subtext, start)
        if location == -1:
            break
        id_loc = location
        count += 1
        if count == n:
            return id_loc
        start = location + step
    # Only n == -1 may end the scan successfully; for n >= 1 we ran out of occurrences.
    return id_loc if n == -1 else -1


def load_task_content(sample_config):
    """Load the instruction, open-question and closed-question prompt texts for the task.

    Reads four UTF-8 files under ``sample_config.prompts_path``:

    * ``experiment/{instr_name}.txt``: chunks separated by ``^^^``. In each chunk,
      ``split('\\n')[1]`` is the instruction name and the following lines its content.
    * ``experiment/{qs_name}_open_{gen_fname}.txt``: ``^^^`` chunks. Line index 1 is the
      open-question name, line index 2 its text (non-breaking spaces become spaces).
    * ``qs/custom/custom_{gen_fname}.txt``: sections separated by ``\\n^^\\n``. Section 1 holds one
      ``<id>. <question>`` per line.
    * ``qs/{qs_name}/{qs_name}.txt``: sections separated by ``\\n^^\\n``. Section 2 holds one
      ``<number>. <item>`` per line; the last item is dropped when ``skip_sui`` is set.

    Questionnaire items are re-keyed through ``qs_inv_maps[qs_name]`` (e.g. ``phq9_q1`` ->
    ``lvl3_q1``) and prefixed with ``'Problem: '``. They are then merged with the custom closed
    questions (custom entries win on key clashes) and sorted by key.

    Returns:
        tuple: ``(instr_dict, open_qs_dict, closed_qs_dict)``.

    Raises:
        KeyError: if an item number has no entry in ``qs_inv_maps[qs_name]``.
            NOTE(review): ``phq9_qs_inv_map`` only covers items 1-8, so a 9th PHQ-9 item with
            ``skip_sui`` False would fail here.
    """
    prompts_path = sample_config.prompts_path
    instructions_file = f"{prompts_path}experiment/{sample_config.instr_name}.txt"
    openq_file = f"{prompts_path}experiment/{sample_config.qs_name}_open_{sample_config.gen_fname}.txt"
    closed_qs_file = f"{prompts_path}qs/custom/custom_{sample_config.gen_fname}.txt"
    qsn_qs_file = f"{prompts_path}qs/{sample_config.qs_name}/{sample_config.qs_name}.txt"

    with open(instructions_file, encoding='utf-8') as f:
        instructions = f.read().split('^^^')
    instr_dict = {}
    for instr in instructions:
        instr_lines = instr.split('\n')
        instr_dict[instr_lines[1]] = '\n'.join(instr_lines[2:])

    with open(openq_file, encoding='utf-8') as f:
        open_qs = f.read().split('^^^')
    open_qs_dict = {}
    for open_q in open_qs:
        open_q_lines = open_q.split('\n')
        open_qs_dict[open_q_lines[1]] = open_q_lines[2].replace("\xa0", " ")

    with open(closed_qs_file, encoding='utf-8') as f:
        closed_qs_sections = f.read().split('\n^^\n')
    # Split on the first '. ' only so questions with several sentences are kept whole;
    # blank lines (e.g. a trailing newline) are skipped instead of raising IndexError.
    closed_qs_dict = {}
    for line in closed_qs_sections[1].split('\n'):
        if line.strip():
            closed_qs_dict[line.split('.')[0]] = line.split('. ', 1)[1]

    with open(qsn_qs_file, encoding='utf-8') as f:
        prompt_sections = f.read().split('\n^^\n')
    qsn_qs = [line for line in prompt_sections[2].split('\n') if line.strip()]
    if sample_config.skip_sui:
        qsn_qs.pop()  # drop the final item (presumably the suicide item, per skip_sui)

    inv_map = qs_inv_maps[sample_config.qs_name]
    qsn_qs_dict = {inv_map[sample_config.qs_name + '_q' + q.split('.')[0]]: 'Problem: ' + q.split('. ', 1)[1]
                   for q in qsn_qs}

    closed_qs_dict = dict(sorted((qsn_qs_dict | closed_qs_dict).items()))
    return instr_dict, open_qs_dict, closed_qs_dict


def load_gen_qsn_content(sample_config):
    """Load the items and instructions of the generic questionnaire ``sample_config.gen_qs_name``.

    Reads ``{prompts_path}qs/{gen_qs_name}/{gen_qs_name}.txt``, whose sections are separated by
    ``\\n^^\\n``:

    * section 0: the preamble;
    * section 1: the answer scale (ignored for ``ids30``, which has item-specific options);
    * section 2: one ``<number>. <item>`` line per item.

    Returns:
        tuple: ``(gen_qsn_qs_dict, gen_qsn_instr_dict)``.

        * ``gen_qsn_qs_dict`` maps ``'<gen_qs_name>_q<number>'`` to the item text, prefixed with
          ``'Item: '`` (sds, ids30) or ``'Problem: '`` (phq9, gad7).
        * ``gen_qsn_instr_dict`` is ``{'preamble': ..., 'scale': ...}``, with ``scale`` ``None``
          for ids30.

    Raises:
        ValueError: if ``gen_qs_name`` is not one of sds, ids30, phq9, gad7.
    """
    gen_qs_name = sample_config.gen_qs_name
    item_prefixes = {'sds': 'Item: ', 'ids30': 'Item: ', 'phq9': 'Problem: ', 'gad7': 'Problem: '}
    if gen_qs_name not in item_prefixes:
        raise ValueError(f"Unsupported generic questionnaire: {gen_qs_name!r}")

    gen_qsn_qs_file = f"{sample_config.prompts_path}qs/{gen_qs_name}/{gen_qs_name}.txt"
    with open(gen_qsn_qs_file, encoding='utf-8') as f:
        prompt_sections = f.read().split('\n^^\n')

    prefix = item_prefixes[gen_qs_name]
    gen_qsn_qs_dict = {}
    for line in prompt_sections[2].split('\n'):
        if not line.strip():
            continue  # tolerate blank lines, e.g. a trailing newline at end of file
        # Split on the first '. ' only, so items containing further sentences are kept whole.
        gen_qsn_qs_dict[gen_qs_name + '_q' + line.split('.')[0]] = prefix + line.split('. ', 1)[1]

    gen_qsn_scale = None if gen_qs_name == 'ids30' else prompt_sections[1]
    gen_qsn_instr_dict = {'preamble': prompt_sections[0], 'scale': gen_qsn_scale}
    return gen_qsn_qs_dict, gen_qsn_instr_dict


def get_openq_data(openq_data, tasks_incl, sample_config, context_qs_name_list):
    """Reshape open-question answers into long format for the selected context questions.

    Args:
        openq_data: wide DataFrame with ``sub`` and ``task_version`` columns plus one column per
            open question (column names as in the open-question prompt file).
        tasks_incl: task versions to keep.
        sample_config: passed to ``load_task_content`` to obtain the open-question names.
        context_qs_name_list: open-question names to keep in the result.

    Returns:
        DataFrame with columns ``sub``, ``q_name``, ``value`` and ``spec_level``, sorted by
        ``sub`` then ``q_name``. ``value`` is ``''`` for non-string answers. ``spec_level`` is
        ``'gen'``, or for subjects whose id contains ``v3_s`` the level from
        ``spec_levels['v3_s']``.
    """
    _, open_qs_dict, _ = load_task_content(sample_config)
    oq_names = list(open_qs_dict)

    # Keep the requested task versions and tidy whitespace around punctuation.
    # NOTE(review): every '\n+,' is already consumed by the r'\s+,' pattern before it, so the
    # last replacement appears never to fire — confirm whether r'\n+' was intended.
    openq_data = (openq_data[openq_data['task_version'].isin(tasks_incl)]
                  .replace(r'\s+\.', '.', regex=True)
                  .replace(r'\s+,', ', ', regex=True)
                  .replace(r'\n+,', ' ', regex=True))
    text_cols = openq_data.select_dtypes('object')
    openq_data[text_cols.columns] = text_cols.apply(lambda col: col.str.strip())

    openq_data_long = (pd.melt(openq_data, id_vars=['sub'], value_vars=oq_names, var_name='q_name')
                       .sort_values(by=['sub', 'q_name'])
                       .reset_index(drop=True))

    # Non-string answers (e.g. NaN for unanswered questions) become empty strings.
    missing = openq_data_long['value'].apply(lambda v: 0 if type(v) != str else len(v.split(' '))) == 0
    openq_data_long.loc[missing, 'value'] = ''

    # Specificity level: 'gen' by default, per-question level for version-3 ('v3_s') subjects.
    is_v3_s = openq_data_long['sub'].str.contains('v3_s')
    openq_data_long['spec_level'] = 'gen'
    openq_data_long.loc[is_v3_s, 'spec_level'] = \
        openq_data_long.loc[is_v3_s, 'q_name'].replace(spec_levels['v3_s'])

    return openq_data_long[openq_data_long['q_name'].isin(context_qs_name_list)]


def setup_model(sample_config):
    """Load the Hugging Face tokenizer and causal LM named by ``sample_config`` in eval mode.

    The model is loaded from ``model_name_hf`` at ``revision`` onto ``device_name``, in bfloat16
    with eager attention. For every model except ``gemma-2-2b-it``, the EOS token id is also
    used as the padding token id by the tokenizer and the generation config.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(sample_config.model_name_hf, use_fast=True, trust_remote_code=True)
    pad_with_eos = sample_config.model_name not in ['gemma-2-2b-it']
    if pad_with_eos:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    model = AutoModelForCausalLM.from_pretrained(
        sample_config.model_name_hf,
        device_map=device_name,
        trust_remote_code=True,
        revision=sample_config.revision,
        attn_implementation='eager',
        torch_dtype=torch.bfloat16,
    )
    model.eval()
    if pad_with_eos:
        model.generation_config.pad_token_id = model.generation_config.eos_token_id
    return model, tokenizer


def format_messages(messages, toker, sample_config):
    """Render a chat ``messages`` list in the prompt format of ``sample_config.model_name_hf``.

    The template is chosen by regex-matching the model name against an ordered rule list; the
    last matching rule wins. Then:

    * ``gpt``: the message list is returned unchanged.
    * ``llama-2-uncensored``: messages are joined with ``### HUMAN:`` / ``### RESPONSE:`` headers.
    * no match: message contents are joined, each followed by a newline.
    * otherwise: ``{sub_path}ch_temp/chat_templates/<template>.jinja`` is read, its indentation
      and newlines are stripped, and it is installed as ``toker.chat_template``. The tokenizer
      renders it with ``add_generation_prompt=True``. ``<s>`` is removed except for
      mistral-instruct, and ``</s>`` is removed except for mistral-instruct and vicuna.

    Args:
        messages: list of ``{'role': ..., 'content': ...}`` dicts.
        toker: tokenizer providing ``apply_chat_template`` (only used for jinja templates).
        sample_config: provides ``model_name_hf`` and, for jinja templates, ``_paths.sub_path``.

    Returns:
        str, or the original ``messages`` list for ``gpt`` models.
    """
    # (regex, template) rules; checked in order and the LAST match wins, because several
    # patterns can match the same model name.
    template_rules = (
        ('Mistral-.*-Instruct', 'mistral-instruct'),
        ('Orca', 'chatml'),
        ('.*gemma.*-it.*', 'gemma-it'),
        ('.*CausalLM.*', 'chatml'),
        ('vicuna', 'vicuna'),
        ('Llama-.*-chat', 'llama-2-chat'),
        ('llama2_.*_uncensored', 'llama-2-uncensored'),
        ('.*Qwen.*-Chat.*', 'chatml'),
        ('gpt-', 'gpt'),
    )
    template_name = None
    for pattern, name in template_rules:
        if re.search(pattern, sample_config.model_name_hf):
            template_name = name

    if template_name == 'gpt':
        return messages
    if template_name == 'llama-2-uncensored':
        role_map = {'system': '### HUMAN:\n', 'user': '\n### HUMAN:\n', 'assistant': '\n### RESPONSE:\n'}
        return ''.join(role_map[msg['role']] + msg['content'] for msg in messages)
    if template_name is None:
        role_map = {'system': '\n', 'user': '\n', 'assistant': '\n'}
        return ''.join(msg['content'] + role_map[msg['role']] for msg in messages)

    template_path = sample_config._paths.sub_path + 'ch_temp/chat_templates/' + template_name + '.jinja'
    # Context manager closes the file handle (the old open(...).read() leaked it).
    with open(template_path, encoding='utf-8') as f:
        chat_template = f.read()
    # Strip the indentation and line breaks used for readability in the .jinja source.
    toker.chat_template = chat_template.replace('    ', '').replace('\n', '')
    messages_formatted = toker.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    if template_name != 'mistral-instruct':
        messages_formatted = messages_formatted.replace('<s>', '')
        if template_name != 'vicuna':
            messages_formatted = messages_formatted.replace('</s>', '')
    return messages_formatted


def create_question_pairs_instr_perm(sample_config):
    """Assemble the prompt pieces for a conversation: intro, open-question instruction, questions.

    Questionnaire items come from ``load_gen_qsn_content``. Each item is rendered as
    ``'\\n' + item + '\\n\\n' + scale instruction + '\\n' + answer options`` followed by an
    ``Answer:`` cue. For ``ids30``:

    * the instruction is ``scale_instr_ids30``, with ``'X'`` replaced by the item's option labels
      (``O1, O2, ...``);
    * the options are the item's labelled list from ``ids30_scales_let``.

    For other questionnaires, the instruction is ``scale_instr`` and the options are the shared
    scale text.

    Returns:
        tuple: ``(intro_prompt, oq_instr, open_qs_dict, qsn_questions_dict, qsn_preamble)``.
    """
    instr_dict, open_qs_dict, _ = load_task_content(sample_config)
    intro_prompt = instr_dict['intro']
    is_ids30 = sample_config.gen_qs_name == 'ids30'
    scale_instr = instr_dict['scale_instr_ids30'] if is_ids30 else instr_dict['scale_instr']

    qsn_questions, gen_instr_dict = load_gen_qsn_content(sample_config)
    # The questionnaire file is already parsed by load_task_content; no need to re-read it here.
    qsn_preamble = '\n' + gen_instr_dict['preamble'] + '\n'

    if is_ids30:
        qsn_questions_dict = {}
        for k, v in qsn_questions.items():
            option_labels = ', '.join(f"O{o + 1}" for o in range(len(ids30_map[k])))
            # NOTE(review): str.replace substitutes every 'X' in the instruction, not only a
            # placeholder — keep the instruction text free of other capital X characters.
            qsn_questions_dict[k] = ('\n' + v + '\n\n' + scale_instr.replace('X', option_labels) + '\n'
                                     + ids30_scales_let[k] + '\nAnswer:')
    else:
        qsn_questions_dict = {k: '\n' + v + '\n\n' + scale_instr + '\n' + gen_instr_dict['scale'] + '\n\nAnswer:'
                              for k, v in qsn_questions.items()}

    oq_instr = instr_dict['openq_instr'] + '\n'
    return intro_prompt, oq_instr, open_qs_dict, qsn_questions_dict, qsn_preamble


def find_loc(input_text, to_find_text, tokenizer, n=1):
    """Map an occurrence of a substring of ``input_text`` onto token positions.

    ``find_subsequence_n`` supplies the character offset of the occurrence selected by ``n``
    (callers in this module pass ``1`` and ``-1``; the exact selection rule lives in that helper).
    The text is then tokenized with an offset mapping, and every token whose character span lies
    entirely within the matched substring is treated as belonging to the match.

    Args:
        input_text: Full text to search and tokenize.
        to_find_text: Substring to locate.
        tokenizer: Tokenizer supporting ``return_offsets_mapping`` (e.g. a HF fast tokenizer).
        n: Occurrence selector forwarded to ``find_subsequence_n``.

    Returns:
        ``(loc_start, loc_end, last_token_loc, last_token)``: index of the first matched token,
        one past the last matched token, index of the last matched token, and that token decoded.
        ``(-1, -1, -1, None)`` when the substring is not found or no token fits entirely inside it.
    """
    not_found = (-1, -1, -1, None)

    loc_start_text = find_subsequence_n(input_text, to_find_text, n)
    if loc_start_text == -1:
        # Nothing to map; do not decode an unrelated token.
        return not_found
    loc_end_text = loc_start_text + len(to_find_text)

    encoding = tokenizer(input_text, return_tensors="pt", padding=True, return_offsets_mapping=True)
    offsets = encoding.offset_mapping[0]  # character (start, end) per token of the single input
    inside = (offsets[:, 0] >= loc_start_text) & (offsets[:, 1] <= loc_end_text)
    token_locations = torch.nonzero(inside, as_tuple=True)[0]
    if token_locations.numel() == 0:
        # Every token overlapping the match also extends beyond it (the match begins/ends
        # mid-token), so no token lies fully inside; report it like a miss instead of crashing.
        return not_found

    loc_start = int(token_locations[0])
    last_token_loc = int(token_locations[-1])
    last_token = tokenizer.decode(encoding.input_ids[0][last_token_loc])
    return loc_start, last_token_loc + 1, last_token_loc, last_token


def get_sub_conv_loc_perm(i, qs_name, cq_qs, cq_Answer, context_qs_name_list, sub_context_text, openq_data_long,
                          loc_keys, last_token_locs, tokenizer, sample_config):
    """Record last-token positions of conversation segments into ``last_token_locs``.

    Keys have the form ``'<segment>^<question name>'`` and are appended to ``loc_keys`` in the
    order they are recorded (the dict's insertion order follows the same sequence).

    When ``i == 0``, three segments are located for every name in ``context_qs_name_list``:
      * ``oq_qs``     - the open question text with trailing non-alphanumerics removed,
      * ``oq_Answer`` - the open question text followed by a blank line and ``Answer``,
      * ``oq_ans``    - the subject's stored answer (first row of ``openq_data_long`` matching
                        ``sub == sample_config.subj`` and ``q_name``), trailing non-alphanumerics
                        removed; ``-1`` when there is no such row.
    On every call the closed-question text (``cq_qs``), its answer cue (``cq_Answer``) and the
    index of the final token of ``sub_context_text`` (``cq_last``) are recorded.

    Segment positions come from ``find_loc`` (default occurrence, ``n=1``); only its
    last-token index is stored.

    Returns:
        The same ``(last_token_locs, loc_keys)`` objects that were passed in, mutated in place.
    """
    tail_punct = '[^a-zA-Z0-9]+$'
    _, open_qs_dict, _ = load_task_content(sample_config)

    def record(key, segment):
        # Register the key before searching, then keep only the last-token index from find_loc.
        loc_keys.append(key)
        last_token_locs[key] = find_loc(sub_context_text, segment, tokenizer)[2]

    if i == 0:
        for context_qs in context_qs_name_list:
            is_subject = openq_data_long['sub'] == sample_config.subj
            is_question = openq_data_long['q_name'] == context_qs
            subject_answers = openq_data_long[is_subject & is_question].reset_index(drop=True)

            question_text = open_qs_dict[context_qs]
            question_core = re.sub(tail_punct, '', question_text)
            question_cue = question_text + '\n\n' + 'Answer'

            record('oq_qs^' + context_qs, question_core)
            record('oq_Answer^' + context_qs, question_cue)

            answer_key = 'oq_ans^' + context_qs
            loc_keys.append(answer_key)
            if len(subject_answers) == 0:
                last_token_locs[answer_key] = -1
            else:
                answer_text = re.sub(tail_punct, '', subject_answers['value'][0])
                last_token_locs[answer_key] = find_loc(sub_context_text, answer_text, tokenizer)[2]

    record('cq_qs^' + qs_name, cq_qs)
    record('cq_Answer^' + qs_name, cq_Answer)

    # Position of the very last prompt token, i.e. the token right before generation starts.
    last_key = 'cq_last^' + qs_name
    loc_keys.append(last_key)
    full_encoding = tokenizer(sub_context_text, return_tensors="pt", padding=True,
                              return_offsets_mapping=True)
    last_token_locs[last_key] = full_encoding.input_ids.shape[1] - 1

    return last_token_locs, loc_keys


def get_sub_conv_ans_loc_perm(qs_name, cq_ans, sub_context_text, last_token_locs, loc_keys, tokenizer):
    """Record where the model's answer ``cq_ans`` to closed question ``qs_name`` ends.

    The key ``'cq_ans^' + qs_name`` is appended to ``loc_keys``, and its last-token index (from
    ``find_loc`` with ``n=-1``, presumably the final occurrence of ``cq_ans``; see
    ``find_subsequence_n``) is stored in ``last_token_locs``.

    Returns:
        The same ``(last_token_locs, loc_keys)`` objects, mutated in place.
    """
    answer_key = 'cq_ans^' + qs_name
    loc_keys.append(answer_key)
    _, _, answer_last_token, _ = find_loc(sub_context_text, cq_ans, tokenizer, n=-1)
    last_token_locs[answer_key] = answer_last_token
    return last_token_locs, loc_keys


def _split_generated_ids(inputs_ids, sequences, tokenizer):
    """Cut prompts off generated sequences and derive raw/clean ids and texts.

    A single trailing special token (EOS or padding) is dropped from each continuation. A real
    final token, which is present when generation stopped at ``max_new_tokens``, is kept.

    Returns:
        ``(outputs_ids, outputs_ids_clean, outputs_texts, outputs_texts_clean)``.
    """
    special_ids = set(tokenizer.all_special_ids)
    special_ids_tensor = torch.tensor(sorted(special_ids), dtype=torch.long, device=sequences.device)
    prompt_len = inputs_ids.shape[1]  # every row of the padded batch shares this length

    outputs_ids = []
    for sequence in sequences:
        generated = sequence[prompt_len:]
        if generated.numel() > 0 and int(generated[-1]) in special_ids:
            generated = generated[:-1]
        outputs_ids.append(generated)

    outputs_texts = tokenizer.batch_decode(outputs_ids, skip_special_tokens=False)
    outputs_texts_clean = tokenizer.batch_decode(outputs_ids, skip_special_tokens=True)
    # Vectorised removal of special ids; keeps integer dtype even when the result is empty.
    outputs_ids_clean = [ids[~torch.isin(ids, special_ids_tensor)].to(device_name) for ids in outputs_ids]
    return outputs_ids, outputs_ids_clean, outputs_texts, outputs_texts_clean


def get_model_response_batch(inputs, model, tokenizer, sample_config, last_token_locs=None, loc_keys=None):
    """Sample one response per prompt and optionally collect prompt hidden states.

    Args:
        inputs: List of formatted prompt strings (tokenized together with padding).
        model: Causal LM exposing ``generate``.
        tokenizer: Matching tokenizer (must have a pad token for batched padding).
        sample_config: Reads ``max_new_tok``, ``save_states``, ``qs_i`` and ``layer_list``; also
            passed to ``process_model_response_batch_perm`` when states are saved.
        last_token_locs: With ``qs_i == 0``, a dict ``key -> token index`` applied to the first
            batch element. With ``qs_i > 0``, a list of such dicts, one per batch element.
        loc_keys: Only checked for ``None``; states are gathered only when both it and
            ``last_token_locs`` are given.

    Returns:
        ``(inputs_ids, outputs_ids, outputs_ids_clean, outputs_texts, outputs_texts_clean,
        hidden_states)``. ``hidden_states`` is ``None`` unless states were gathered. For
        ``qs_i == 0`` it is a CPU tensor stacked as ``[layer][location]`` over
        ``sample_config.layer_list`` and ``last_token_locs`` order. For ``qs_i > 0`` it is a
        list with one ``{'l_<layer>': {key: vector}}`` dict per response that parsed successfully.
    """
    encoded = tokenizer(inputs, return_tensors="pt", padding=True)
    inputs_ids = encoded.input_ids.to(device_name)
    generate_kwargs = {'do_sample': True, 'max_new_tokens': sample_config.max_new_tok}
    attention_mask = encoded.get('attention_mask')
    if attention_mask is not None:
        # Without an explicit mask, shorter (padded) prompts in the batch attend to pad tokens;
        # HF cannot infer the mask itself when pad_token == eos_token.
        generate_kwargs['attention_mask'] = attention_mask.to(device_name)
    hidden_states = None

    with torch.no_grad():
        if sample_config.save_states:
            outputs = model.generate(inputs_ids, return_dict_in_generate=True, output_hidden_states=True,
                                     output_scores=False, output_attentions=False, **generate_kwargs)
            sequences = outputs.sequences
        else:
            outputs = model.generate(inputs_ids, **generate_kwargs)
            sequences = outputs

        outputs_ids, outputs_ids_clean, outputs_texts, outputs_texts_clean = _split_generated_ids(
            inputs_ids, sequences, tokenizer)

        if sample_config.save_states:
            _, _, batch_model_bool_checks = process_model_response_batch_perm(outputs_texts_clean, sample_config)
            if last_token_locs is not None and loc_keys is not None:
                # Generation step 0 is the prompt forward pass: one tensor per layer, each
                # indexed [batch, position, hidden] (HF generate output layout).
                prompt_layers = outputs.hidden_states[0]
                if sample_config.qs_i > 0:
                    per_item = torch.stack(prompt_layers, dim=1)  # [batch, layer, position, hidden]
                    hidden_states = []
                    for passed, item_states, item_locs in zip(batch_model_bool_checks, per_item, last_token_locs):
                        if not passed:
                            continue  # only successfully parsed responses get an entry
                        hidden_states.append({
                            'l_' + str(layer): {k: item_states[layer, v, :] for k, v in item_locs.items()}
                            for layer in sample_config.layer_list})
                else:
                    # Only the first batch element is used here; avoid stacking the whole batch.
                    first = torch.stack([layer_states[0] for layer_states in prompt_layers])
                    hidden_states = torch.stack([
                        torch.stack([first[layer, v, :] for v in last_token_locs.values()])
                        for layer in sample_config.layer_list]).detach().to('cpu')

    return inputs_ids, outputs_ids, outputs_ids_clean, outputs_texts, outputs_texts_clean, hidden_states


def process_model_response_batch_perm(model_outputs, sample_config):
    """Parse a batch of free-text answers into questionnaire option labels and scores.

    The option vocabulary is ``maps[sample_config.gen_qs_name]``. For ``ids30`` it is nested
    one level deeper and indexed by ``sample_config.which_gen_q``.

    For each response:
      * leading blank lines are skipped;
      * consecutive non-blank lines are then scanned until one contains at least one option
        label (case-insensitive);
      * a blank line after content ends the scan.

    Before matching, a line is stripped of everything except letters and spaces. For ``ids30``,
    digits are kept as well. When several labels match, the longest wins, and among equally
    long labels the later one in map order wins.

    If ``sample_config.qs_index`` is listed in ``rev_lists[gen_qs_name]``, the score is reversed
    as ``min + max - score`` over the map's values.

    Args:
        model_outputs: Iterable of decoded response strings.
        sample_config: Reads ``qs_index``, ``gen_qs_name`` and (for ``ids30``) ``which_gen_q``.

    Returns:
        Three parallel lists:
          * the matched option label (the raw response when nothing matched);
          * the score (``-1`` when nothing matched);
          * a success flag.
    """
    which_q = sample_config.qs_index
    qs_name = sample_config.gen_qs_name
    if qs_name == 'ids30':
        map_dict = maps[qs_name][sample_config.which_gen_q]
        strip_pattern = re.compile(r'[^A-Za-z0-9 ]+')  # remove non-alphanumerics (keep spaces)
    else:
        map_dict = maps[qs_name]
        strip_pattern = re.compile(r'[^A-Za-z ]+')  # remove non-letters (keep spaces)

    model_outputs_proc = []
    model_outputs_scores = []
    model_outputs_bool_checks = []

    for model_output in model_outputs:
        matches = []
        seen_content = False
        for line in model_output.split('\n'):
            line_ans = strip_pattern.sub('', line.strip())
            if not line_ans:
                if seen_content:
                    break  # a blank line after content ends the answer block
                continue  # tolerate blank lines before the answer (e.g. a leading "\n")
            seen_content = True
            lowered = line_ans.lower()
            matches = [key for key in map_dict if key.lower() in lowered]
            if matches:
                break

        if matches:
            # sorted() is stable, so ties in length resolve deterministically to the later label.
            model_output_proc = sorted(matches, key=len)[-1]
            tmp_score = map_dict[model_output_proc]
            if qs_name in rev_lists and which_q in rev_lists[qs_name]:
                tmp_score = min(map_dict.values()) + max(map_dict.values()) - tmp_score
            bool_check = True
        else:
            model_output_proc, tmp_score, bool_check = model_output, -1, False

        model_outputs_proc.append(model_output_proc)
        model_outputs_scores.append(tmp_score)
        model_outputs_bool_checks.append(bool_check)

    return model_outputs_proc, model_outputs_scores, model_outputs_bool_checks


def sample_conv_batch_perm(questions, context_qs_name_list, sample_qs_name_list, sample_qs_idx_list, batch_messages,
                           batch_messages_formatted,
                           openq_data_long, model,
                           tokenizer, sample_ts, sample_config):
    tmp_save_states = False
    qsn_questions, gen_instr_dict = load_gen_qsn_content(sample_config)

    batch_times = [[]] * len(batch_messages)
    batch_bools = [[]] * len(batch_messages)
    batch_scores = [[]] * len(batch_messages)
    # batch_hidden_states = []
    batch_output_logs = [[]] * len(batch_messages)
    batch_q_inds = [[]] * len(batch_messages)
    batch_proc_outputs = [[]] * len(batch_messages)
    batch_raw_outputs = [[]] * len(batch_messages)

    # Sample responses to each question
    for i, (qs_name, qs_index, qs) in enumerate(zip(sample_qs_name_list, sample_qs_idx_list, questions)):
        sample_config.which_gen_q = qs_name
        sample_config.qs_index = qs_index
        sample_config.qs_i = i

        cq_qs = re.sub('[^a-zA-Z0-9]+$', '', qsn_questions[qs_name])
        cq_Answer = qs[:-1]

        # if i == 0:
        # if i == 1:
        # if i == 2:
        #     break
        # print('len msgs:', len(batch_messages))
        if len(batch_messages) > 0:
            batch_q_inds = [q_inds + [i + 1] for q_inds in batch_q_inds]

            # sample responses to first question
            if i == 0:
                # last_token_batch_locs, loc_batch_keys = [], []
                # get locations for the contexrt
                # loc_keys_context = []
                # last_token_locs_context = {}

                # if i=0 then do only one example from batch, else to example by example in batch
                last_token_locs_context, loc_keys_context = get_sub_conv_loc_perm(i, qs_name, cq_qs, cq_Answer,
                                                                                  context_qs_name_list,
                                                                                  batch_messages_formatted[0],
                                                                                  openq_data_long, [], {}, tokenizer,
                                                                                  sample_config)

                # sample responses to the first question and get hidden states for the context and first question
                start = time.time()
                _, batch_model_outputs_ids, batch_model_outputs_ids_clean, _, batch_model_outputs_clean, batch_model_hidden_states = get_model_response_batch(
                    batch_messages_formatted, model, tokenizer, sample_config, last_token_locs=last_token_locs_context,
                    loc_keys=loc_keys_context)
                elapsed = round(time.time() - start, 4)
                flush()
                # print(last_token_locs_context)

                # Save hidden states for the first iteration
                if sample_config.save_states:
                    layer_keys = ['l_' + str(l) for l in sample_config.layer_list]
                    key_info = ', '.join(layer_keys) + '\n' + ', '.join(loc_keys_context)
                    key_info_fname = f"{sample_config.states_path}{sample_config.subj}^^presample^^{sample_config.context_name}^^{sample_config.gen_sample_name}^^{sample_config.model_name_rhp}_states_info_s-ts-{sample_ts}.txt"
                    tensor_fname = f"{sample_config.states_path}{sample_config.subj}^^presample^^{sample_config.context_name}^^{sample_config.gen_sample_name}^^{sample_config.model_name_rhp}_hidden_states_s-ts-{sample_ts}.pt"
                    if sample_config.save_context_states:
                        torch.save(batch_model_hidden_states.clone(), tensor_fname)
                        with open(key_info_fname, 'w') as hs_info:
                            hs_info.write(key_info)
                    del batch_model_hidden_states
                    flush()
                if sample_config.save_states:
                    tmp_save_states = True
                    sample_config.save_states = False

            # sample responses to remaining quesitons
            elif i > 0 and (not sample_config.do_only_first_states):
                # append questions if Q>1
                batch_messages = [batch_message + [{'role': 'user', 'content': qs}] for batch_message in
                                  batch_messages]
                batch_messages_formatted = [format_messages(batch_message, tokenizer, sample_config) for
                                            batch_message_formatted, batch_message in
                                            zip(batch_messages_formatted, batch_messages)]

                for b, (sub_context_text, last_token_locs, loc_keys) in enumerate(
                        zip(batch_messages_formatted, last_token_batch_locs, loc_batch_keys)):
                    # print(last_token_locs, loc_keys)
                    # get locations for closed question when i>0 for each element in batcg
                    last_token_locs, loc_keys = get_sub_conv_loc_perm(i, qs_name, cq_qs, cq_Answer,
                                                                      context_qs_name_list,
                                                                      sub_context_text, openq_data_long, loc_keys,
                                                                      last_token_locs,
                                                                      tokenizer, sample_config)

                # Sample responses to next question
                start = time.time()
                # _, batch_model_outputs_ids, batch_model_outputs_ids_clean, _, batch_model_outputs_clean, batch_model_hidden_states = get_model_response_batch(
                #     batch_messages_formatted, model, tokenizer, sample_config, last_token_locs=last_token_batch_locs,
                #     loc_keys=loc_batch_keys)
                _, batch_model_outputs_ids, batch_model_outputs_ids_clean, _, batch_model_outputs_clean, batch_model_hidden_states = get_model_response_batch(
                    batch_messages_formatted, model, tokenizer, sample_config, last_token_locs=None,
                    loc_keys=None)
                elapsed = round(time.time() - start, 4)

                # if len(batch_hidden_states) == 0:
                #     batch_hidden_states = [batch_model_hidden_state for batch_model_hidden_state in
                #                            batch_model_hidden_states]
                # else:
                #     for batch_hidden_state, batch_model_hidden_state in zip(batch_hidden_states, batch_model_hidden_states):
                #         for layer in batch_hidden_state.keys():
                #             batch_hidden_state[layer] = batch_hidden_state[layer] | batch_model_hidden_state[layer]

                flush()

            # process responses
            if not sample_config.do_only_first_states:
                batch_model_outputs_proc, batch_model_scores, batch_model_bool_checks = process_model_response_batch_perm(
                    batch_model_outputs_clean, sample_config)
                # print(batch_model_outputs_clean, batch_model_outputs_proc)

                # save outputs, socres, times, etc
                batch_output_logs = [
                    output_log + ['---------------q: ' + str(qs_name) + '--------------'] + [model_output_clean]
                    for
                    output_log, model_output_clean in zip(batch_output_logs, batch_model_outputs_clean)]
                batch_raw_outputs = [batch_raw_output + [model_output_clean] for batch_raw_output, model_output_clean in
                                     zip(batch_raw_outputs, batch_model_outputs_clean)]

                batch_scores = [batch_score + [batch_model_score] for batch_score, batch_model_score in
                                zip(batch_scores, batch_model_scores)]
                batch_bools = [batch_bool + [batch_model_bool] for batch_bool, batch_model_bool in
                               zip(batch_bools, batch_model_bool_checks)]
                batch_times = [batch_time + [elapsed] for batch_time in batch_times]
                batch_proc_outputs = [batch_proc_output + [batch_model_output_proc] for
                                      batch_proc_output, batch_model_output_proc in
                                      zip(batch_proc_outputs, batch_model_outputs_proc)]

                # append responses
                batch_messages = [batch_message + [{'role': 'assistant', 'content': batch_model_output_proc}] for
                                  batch_message, batch_model_output_proc
                                  in zip(batch_messages, batch_model_outputs_proc)]
                batch_messages_formatted = [format_messages(batch_message, tokenizer, sample_config) for
                                            batch_message_formatted, batch_message in
                                            zip(batch_messages_formatted, batch_messages)]

                # print('bools: ', batch_bools)

                # save fails
                for s, (batch_bool, batch_q_ind, batch_score, batch_time, batch_output_log, batch_message,
                        batch_message_formatted, batch_output_log, batch_proc_output, batch_raw_output) in enumerate(
                    zip(batch_bools, batch_q_inds, batch_scores, batch_times, batch_output_logs, batch_messages,
                        batch_messages_formatted, batch_output_logs, batch_proc_outputs, batch_raw_outputs)):

                    if not batch_bool[-1]:
                        # print('failed:', batch_raw_output,batch_q_ind)
                        tmp_dict = {'index': i, 'sub': sample_config.subj, 'sample_ts': sample_ts + str(s),
                                    'context_qs': sample_config.qs_name, 'context_name': sample_config.context_name,
                                    'gen_qs': sample_config.gen_qs_name, 'question_gen': sample_config.which_gen_q,
                                    'score': batch_score, 'response': batch_proc_output,
                                    'raw_response': batch_raw_output,
                                    'model': sample_config.model_name, 'temp': sample_config.temp,
                                    'top_p': sample_config.top_p,
                                    'status': batch_bool, 'time': batch_time, 'instr_name': sample_config.instr_name,
                                    'openq_fname': sample_config.openq_fname,
                                    'gen_sample_name': sample_config.gen_sample_name,
                                    'nSamples': sample_config.remBatchSize}
                        tmp_output_pd = pd.DataFrame(tmp_dict, index=range(len(tmp_dict['score'])))

                        batch_output_log = '\n^^^\n' + '\n'.join(batch_output_log) + '\n^^^\n'
                        if type(batch_message_formatted) != list:
                            txt_content = batch_message_formatted  # + '\n\n' + json.dumps(batch_message, indent=1)
                            # txt_content = batch_message_formatted + '\n\n' + json.dumps(batch_message, indent=1)
                        else:
                            txt_content = json.dumps(batch_message, indent=1)
                        txt_content = batch_output_log + txt_content

                        # save output to txt - FAILED
                        if sample_config.save_fail_outputs:
                            with open(sample_config.outputs_path + 'fails/output_model_s-ts-' + sample_ts + str(
                                    s) + '.txt',
                                      'w',
                                      encoding='utf-8') as output_file:
                                # output_file.write(txt_content + '\n' + tmp_output_pd.to_string())
                                output_file.write(txt_content)
                        # save responses csv - FAILED
                        tmp_output_pd.to_csv(
                            sample_config.responses_path + 'fails/response_model_s-ts-' + sample_ts + str(s) + '.csv',
                            index=False)

                # discard fails
                batch_messages = [batch_message for batch_model_bool_check, batch_message in
                                  zip(batch_model_bool_checks, batch_messages) if batch_model_bool_check]
                batch_messages_formatted = [batch_message_formatted for batch_model_bool_check, batch_message_formatted in
                                            zip(batch_model_bool_checks, batch_messages_formatted) if
                                            batch_model_bool_check]
                batch_times = [batch_time for batch_model_bool_check, batch_time in
                               zip(batch_model_bool_checks, batch_times) if
                               batch_model_bool_check]
                batch_scores = [batch_score for batch_model_bool_check, batch_score in
                                zip(batch_model_bool_checks, batch_scores) if batch_model_bool_check]
                batch_output_logs = [batch_output_log for batch_model_bool_check, batch_output_log in
                                     zip(batch_model_bool_checks, batch_output_logs) if batch_model_bool_check]
                batch_q_inds = [batch_q_ind for batch_model_bool_check, batch_q_ind in
                                zip(batch_model_bool_checks, batch_q_inds) if batch_model_bool_check]
                batch_proc_outputs = [batch_proc_output for batch_model_bool_check, batch_proc_output in
                                      zip(batch_model_bool_checks, batch_proc_outputs) if batch_model_bool_check]
                batch_raw_outputs = [batch_raw_output for batch_model_bool_check, batch_raw_output in
                                     zip(batch_model_bool_checks, batch_raw_outputs) if batch_model_bool_check]

                if i > 0:
                    last_token_batch_locs = [last_token_batch_loc for batch_model_bool_check, last_token_batch_loc in
                                             zip(batch_model_bool_checks, last_token_batch_locs) if batch_model_bool_check]
                    loc_batch_keys = [loc_batch_key for batch_model_bool_check, loc_batch_key in
                                      zip(batch_model_bool_checks, loc_batch_keys) if batch_model_bool_check]

                batch_bools = [batch_bool for batch_model_bool_check, batch_bool in
                               zip(batch_model_bool_checks, batch_bools) if
                               batch_model_bool_check]
                # print(batch_bools,'after discard')

                # get locations for the responses to closed question
                if i == 0:
                    last_token_batch_locs, loc_batch_keys = [], []
                    for sub_context_text, proc_outputs in zip(batch_messages_formatted, batch_proc_outputs):
                        cq_ans = proc_outputs[-1]
                        loc_keys = []
                        last_token_locs = {}

                        last_token_locs, loc_keys = get_sub_conv_ans_loc_perm(qs_name, cq_ans, sub_context_text,
                                                                              last_token_locs, loc_keys,
                                                                              tokenizer)
                        # loc_keys.append(cq_ans_loc_key)
                        last_token_batch_locs.append(last_token_locs)
                        loc_batch_keys.append(loc_keys)
                else:
                    for b, (sub_context_text, proc_outputs, last_token_locs, loc_keys) in enumerate(
                            zip(batch_messages_formatted, batch_proc_outputs, last_token_batch_locs, loc_batch_keys)):
                        # print(last_token_locs, loc_keys)
                        cq_ans = proc_outputs[-1]

                        last_token_locs, loc_keys = get_sub_conv_ans_loc_perm(qs_name, cq_ans, sub_context_text,
                                                                              last_token_locs, loc_keys,
                                                                              tokenizer)

    # % Get hidden states for continued conversation for all batches in all locations
    if not sample_config.do_only_first_states:
        batch_hidden_states = []
        if tmp_save_states and sample_config.save_sample_states:
            key_info_fname = f"{sample_config.states_path}{sample_config.subj}^^sample^^{sample_config.context_name}^^{sample_config.gen_sample_name}^^{sample_config.model_name_rhp}_states_info_s-ts-{sample_ts}.txt"
            # try:
            layer_keys = ['l_' + str(l) for l in sample_config.layer_list]
            key_info = ', '.join(layer_keys) + '\n' + ', '.join(loc_batch_keys[0])
            with open(key_info_fname, 'w') as hs_info:
                hs_info.write(key_info)

            sample_config.save_states = True
            # start = time.time()
            inputs_ids = tokenizer(batch_messages_formatted, return_tensors="pt", padding=True).input_ids.to(
                device_name)
            flush()
            outputs = model(inputs_ids, output_hidden_states=True, output_attentions=False)
            hidden_states_ls = torch.stack(outputs.hidden_states, axis=1)
            for h, (hidden_states_ls_item, last_token_locs_el) in enumerate(
                    zip(hidden_states_ls, last_token_batch_locs)):
                # print(h, last_token_locs_el)
                if hidden_states_ls_item is not None:
                    batch_hidden_states.append({})
                    for layer in sample_config.layer_list:
                        batch_hidden_states[h]['l_' + str(layer)] = {}
                        for k, v in last_token_locs_el.items():
                            if hidden_states_ls_item is not None:
                                batch_hidden_states[h]['l_' + str(layer)][k] = hidden_states_ls_item[layer, v, :]
                            else:
                                batch_hidden_states[h]['l_' + str(layer)][k] = None
            del hidden_states_ls, hidden_states_ls_item
            flush()
            batch_hidden_states = [torch.stack(
                [torch.stack([batch_hidden_state['l_' + str(layer)][loc] for loc in loc_batch_keys[0]]) for layer in
                 sample_config.layer_list]) for batch_hidden_state in batch_hidden_states]
            # except:
            #     print(f"error with states {key_info_fname}")
        else:
            batch_hidden_states = [[]] * len(batch_messages)

        # save all good responses
        if len(batch_bools) > 0:
            # print('good batch bools',batch_bools)
            # print('hidden states list',batch_hidden_states)
            # print('all tio save',(batch_bools, batch_q_inds, batch_scores, batch_proc_outputs, batch_raw_outputs, batch_times,
            #         batch_output_logs, batch_messages, batch_messages_formatted, batch_hidden_states))
            for s, (
                    batch_bool, batch_q_ind, batch_score, batch_proc_output, batch_raw_output, batch_time, batch_output_log,
                    batch_message,
                    batch_message_formatted, batch_hidden_state) in enumerate(
                zip(batch_bools, batch_q_inds, batch_scores, batch_proc_outputs, batch_raw_outputs, batch_times,
                    batch_output_logs, batch_messages, batch_messages_formatted, batch_hidden_states)):
                if all(batch_bool):
                    # print('all good')
                    tmp_dict = {'index': range(len(sample_qs_name_list)), 'sub': sample_config.subj,
                                'sample_ts': sample_ts + str(s), 'context_qs': sample_config.qs_name,
                                'context_name': sample_config.context_name, 'gen_qs': sample_config.gen_qs_name,
                                'question_gen': sample_qs_name_list, 'score': batch_score, 'response': batch_proc_output,
                                'raw_response': batch_raw_output, 'model': sample_config.model_name,
                                'temp': sample_config.temp,
                                'top_p': sample_config.top_p, 'status': batch_bool, 'time': batch_time,
                                'instr_name': sample_config.instr_name, 'openq_fname': sample_config.openq_fname,
                                'gen_sample_name': sample_config.gen_sample_name, 'nSamples': sample_config.remBatchSize}
                    tmp_output_pd = pd.DataFrame(tmp_dict, index=range(len(tmp_dict['score'])))

                    # store_pds.append(tmp_output_pd)
                    batch_output_log = '\n^^^\n' + '\n'.join(batch_output_log) + '\n^^^\n'
                    if type(batch_message_formatted) != list:
                        txt_content = batch_message_formatted  # + '\n\n' + json.dumps(batch_message, indent=1)
                        # txt_content = batch_message_formatted + '\n\n' + json.dumps(batch_message, indent=1)
                    else:
                        txt_content = json.dumps(batch_message, indent=1)
                    txt_content = batch_output_log + txt_content
                    # print(tmp_output_pd, '\n')

                    # save output to txt
                    if sample_config.save_ok_outputs:
                        with open(sample_config.outputs_path + 'output_model_s-ts-' + sample_ts + str(s) + '.txt', 'w',
                                  encoding='utf-8') as output_file:
                            output_file.write(txt_content)
                    # save responses csv
                    tmp_output_pd.to_csv(
                        sample_config.responses_path + 'response_model_s-ts-' + sample_ts + str(s) + '.csv', index=False)

                    # # save hidden states
                    if sample_config.save_states and sample_config.save_sample_states:
                        tensor_fname = f"{sample_config.states_path}{sample_config.subj}^^sample^^{sample_config.context_name}^^{sample_config.gen_sample_name}^^{sample_config.model_name_rhp}_hidden_states_s-ts-{sample_ts}{s}.pt"
                        torch.save(batch_hidden_state.clone(), tensor_fname)
                        # flush()
                        # del batch_hidden_state

        del batch_hidden_states
        flush()
    sample_config.save_states=True


def get_sub_context_pairs_perm(openq_data_long, context_qs_name_list, sample_qs_name_list, tokenizer, sample_config):
    """Build the open-question context conversation for ``sample_config.subj``.

    For every name in ``context_qs_name_list`` the subject's answer (column ``value`` of the
    row whose ``q_name`` matches) is added as a user/assistant pair, in list order. When the
    intro prompt returned by ``create_question_pairs_instr_perm`` is non-empty, it is added as
    a system turn *before every pair* (not only once). The conversation ends with a user turn
    holding the preamble plus the first question of ``sample_qs_name_list``.

    Returns:
        (messages, messages formatted by ``format_messages``, list of preamble + question
        strings for every name in ``sample_qs_name_list``).
    """
    (intro_prompt, oq_instr, open_qs_dict,
     qsn_gen_questions_dict, qsn_preamble) = create_question_pairs_instr_perm(sample_config)

    subj_rows = openq_data_long[openq_data_long['sub'] == sample_config.subj].reset_index(drop=True)
    with_system_turn = intro_prompt != ''

    messages = []
    for q_name in context_qs_name_list:
        q_rows = subj_rows[subj_rows['q_name'] == q_name]
        question_block = ''.join(['Question:\n', open_qs_dict[q_name], '\n\nAnswer:'])
        answer_block = '\n' + str(q_rows['value'].values[0]) + '\n\n'

        pair = [{'role': 'system', 'content': intro_prompt}] if with_system_turn else []
        pair += [{'role': 'user', 'content': oq_instr + question_block},
                 {'role': 'assistant', 'content': answer_block}]
        messages.extend(pair)

    first_gen_question = qsn_gen_questions_dict[sample_qs_name_list[0]]
    messages.append({'role': 'user', 'content': '\n' + qsn_preamble + first_gen_question})
    messages_formatted = format_messages(messages, tokenizer, sample_config)

    gen_questions = [qsn_preamble + qsn_gen_questions_dict[name] for name in sample_qs_name_list]
    return messages, messages_formatted, gen_questions


# //////////////////////////////////  OLD STUFF


def sample_model_responses_batch_gen(sub_qs_pairs, sub_q_token_locs, model, tokenizer, sample_ts, sample_config):
    """Sample one batch of generalisation-question responses and write them to disk (legacy pipeline).

    The model is run once on ``sub_qs_pairs``. Each cleaned output is parsed with
    ``process_model_response_batch_gen``, and one txt log plus one single-row csv is written
    per sample. Failed parses go to the ``fails/`` sub-directories of ``outputs_path`` /
    ``responses_path``; successful ones go to the directories themselves.

    Sample ``s`` is always identified by ``sample_ts + str(s)``, where ``s`` is its position in
    the batch. Files therefore stay aligned with the rows of the hidden-state tensor even when
    other samples in the batch failed. The previous version re-numbered successful samples
    after dropping failures.

    If ``sample_config.save_states`` is set and at least one sample parsed, the hidden states
    returned by ``get_model_response_batch`` are saved with ``torch.save``.

    Args:
        sub_qs_pairs: batch of prompts forwarded to ``get_model_response_batch``.
        sub_q_token_locs: token locations forwarded to ``get_model_response_batch``.
        model: the causal LM.
        tokenizer: the tokenizer for ``model``.
        sample_ts: timestamp string used as the file/sample id prefix.
        sample_config: run configuration (paths, question names, sampling params, flags).
    """
    start = time.time()
    # NOTE(review): this legacy caller unpacks 6 values, while sample_messages_conv_batch unpacks 7
    # from the same helper -- confirm the current get_model_response_batch signature before reuse.
    _, _, _, _, batch_model_outputs_clean, model_hidden_states = get_model_response_batch(
        sub_qs_pairs, model, tokenizer, sample_config, sub_q_token_locs)
    batch_time = round(time.time() - start, 4)

    batch_proc_outputs, batch_scores, batch_bools = process_model_response_batch_gen(batch_model_outputs_clean,
                                                                                     sample_config)

    log_header = ('---------------context q: ' + sample_config.which_q + '--------------\n'
                  + '---------------gen q: ' + sample_config.which_gen_q + '--------------\n')
    batch_output_logs = [log_header + model_output_clean for model_output_clean in batch_model_outputs_clean]

    for s, (batch_bool, batch_score, batch_proc_output, batch_output_log) in enumerate(
            zip(batch_bools, batch_scores, batch_proc_outputs, batch_output_logs)):
        sample_id = sample_ts + str(s)  # position in the batch, independent of other samples' status
        tmp_output_pd = pd.DataFrame({'sub': sample_config.subj, 'sample_ts': sample_id,
                                      'question': sample_config.which_q, 'question_gen': sample_config.which_gen_q,
                                      'score': batch_score, 'response': batch_proc_output,
                                      'model': sample_config.model_name, 'temp': sample_config.temp,
                                      'top_p': sample_config.top_p, 'status': batch_bool, 'time': batch_time,
                                      'instr_name': sample_config.instr_name, 'qs': sample_config.qs_name,
                                      'gen_qs': sample_config.gen_qs_name, 'nSamples': sample_config.remBatchSize},
                                     index=[0])

        if batch_bool:
            sub_dir = ''
            txt_content = '\n^^^\n' + batch_output_log + '\n^^^\n'
        else:
            # Failed parses keep the raw log plus the table for debugging.
            sub_dir = 'fails/'
            txt_content = batch_output_log + '\n\n^^^^^^^^^\n\n' + tmp_output_pd.to_string()

        with open(sample_config.outputs_path + sub_dir + 'output_model_s-ts-' + sample_id + '.txt', 'w',
                  encoding='utf-8') as output_file:
            output_file.write(txt_content)
        tmp_output_pd.to_csv(sample_config.responses_path + sub_dir + 'response_model_s-ts-' + sample_id + '.csv',
                             index=False)

    if sample_config.save_states and any(batch_bools):
        tensor_fname = f"{sample_config.states_path}{sample_config.subj}^^{sample_config.which_q}^^{sample_config.which_gen_q}^^{sample_config.model_name_rhp}_hidden_states_s-ts-{sample_ts}.pt"
        torch.save(model_hidden_states.clone(), tensor_fname)

    del model_hidden_states
    flush()


def process_model_response_batch_gen(model_outputs, sample_config):
    """Parse raw model outputs for a generalisation question into scale labels and scores.

    Each output is scanned line by line from the top. Every line is stripped and reduced to
    the characters kept by the cleaning pattern: letters and spaces, plus digits for
    ``ids30``. Scanning stops at the first line that is empty after cleaning, or as soon as
    a line contains at least one scale label.

    Labels are matched case-insensitively as substrings, after the same cleaning is applied
    to the labels themselves. Without that, labels containing punctuation (e.g. the ``ids30``
    options, which contain commas and full stops) could never match a cleaned line. When
    several labels match, the longest one wins.

    A matched label is scored through the questionnaire map. The score is reversed as
    ``min + max - score`` when ``sample_config.gen_qs_name`` is in ``rev_lists`` and the
    question number parsed from ``which_gen_q`` (the text after the last ``_q``) is listed
    there.

    Args:
        model_outputs: iterable of raw output strings.
        sample_config: needs ``gen_qs_name`` and ``which_gen_q``.

    Returns:
        (processed outputs, scores, parse flags), one entry per output. An unparsed output is
        returned unchanged, with score ``-1`` and flag ``False``.
    """
    model_outputs = list(model_outputs)
    if not model_outputs:
        return [], [], []

    gen_qs_name = sample_config.gen_qs_name
    if gen_qs_name == 'ids30':
        # ids30 has one option map per question, and its options contain digits ("30 minutes").
        map_dict = maps[gen_qs_name][sample_config.which_gen_q]
        drop_chars = re.compile(r'[^A-Za-z0-9 ]+')
    else:
        map_dict = maps[gen_qs_name]
        drop_chars = re.compile(r'[^A-Za-z ]+')

    def clean(text):
        return drop_chars.sub('', text.strip()).lower()

    # Labels are compared in the same cleaned form as the lines. Labels that clean to '' are
    # dropped, since an empty pattern would match every line.
    cleaned_labels = [(label, clean(label)) for label in map_dict]
    cleaned_labels = [(label, norm) for label, norm in cleaned_labels if norm]

    model_outputs_proc = []
    model_outputs_scores = []
    model_outputs_bool_checks = []
    for model_output in model_outputs:
        model_output_proc = model_output
        tmp_score = -1
        bool_check = False

        ans = []
        for line in model_output.split('\n'):
            line_ans = clean(line)
            if line_ans == '' or ans:
                break
            ans = [label for label, norm in cleaned_labels if norm in line_ans]

        if ans:
            # Longest match wins, so a label that is a substring of another cannot shadow it.
            model_output_proc = ans[np.argsort([len(a) for a in ans])[-1]]
            tmp_score = map_dict[model_output_proc]
            bool_check = True
            if gen_qs_name in rev_lists:
                q_idx = int(re.sub(".*_q", "", sample_config.which_gen_q))
                if q_idx in rev_lists[gen_qs_name]:
                    tmp_score = min(map_dict.values()) + max(map_dict.values()) - tmp_score

        model_outputs_proc.append(model_output_proc)
        model_outputs_scores.append(tmp_score)
        model_outputs_bool_checks.append(bool_check)

    return model_outputs_proc, model_outputs_scores, model_outputs_bool_checks


def initialise_conv(sample_config, tokenizer, openq_data, q_data):
    """Build per-subject chat histories for the questionnaire task.

    For every subject in ``openq_data['sub']`` two conversations are produced:

    * initial: a system prompt (``welcome`` + ``instr1``), then every open question (keys
      sorted as strings) with the subject's answer, then the questionnaire instruction
      followed by the first closed question as a user turn;
    * continued: the initial conversation plus the subject's closed-question answers taken
      from ``q_data`` (questions 2.. are added as user turns), followed by the repeated open
      questions and the subject's answers to them.

    The closed questions are read from ``<prompts_path>qs/<qs_name>/<qs_name>.txt``. Its
    ``'\\n^^\\n'``-separated sections hold the answer scale (index 1) and one question per
    line (index 2).

    Returns:
        (sub_conv_all, sub_conv_formatted_all, sub_conv_all_continued,
         sub_conv_formatted_all_continued, questions): dicts keyed by subject holding the
        message lists and their ``format_messages`` output, plus the closed-question prompts.
    """
    instr_dict, open_qs_dict, open_qs_rep_dict = load_task_content(sample_config)
    qs_name = sample_config.qs_name
    qs_instr = f"{instr_dict[qs_name + '_instr']}"

    qs_prompt_file = f"{sample_config.prompts_path}qs/{qs_name}/{qs_name}.txt"
    with open(qs_prompt_file, encoding='utf-8') as f:
        prompt_sections = f.read().split('\n^^\n')

    question_lines = prompt_sections[2].split('\n')
    scale_text = prompt_sections[1]
    questions = [f"{q}\nAnswer on the scale:\n{scale_text}\nAnswer: " for q in question_lines]
    first_prompt = '\n' + qs_instr + '\n\n' + questions[0]

    def user(text):
        return {'role': 'user', 'content': text}

    def assistant(text):
        return {'role': 'assistant', 'content': text}

    sub_conv_all, sub_conv_formatted_all = {}, {}
    sub_conv_all_continued, sub_conv_formatted_all_continued = {}, {}

    for sub in openq_data['sub'].unique():
        sub_mask = openq_data['sub'] == sub
        sub_rows = openq_data[sub_mask]

        conv = [{'role': 'system', 'content': f"{instr_dict['welcome']}\n{instr_dict['instr1']}"}]
        # Keys are sorted as strings, so the order is numeric only while question numbers stay < 10.
        # The [:1000] cap is the same for every model (the former gpt2 branch also used 1000).
        for oq_key in sorted(open_qs_dict.keys())[:1000]:
            conv.append(user(f"{instr_dict['q_instr']}\n{open_qs_dict[oq_key]}"))
            conv.append(assistant(sub_rows[oq_key].values[0]))
        conv.append(user(first_prompt))

        sub_conv_formatted_all[sub] = format_messages(conv, tokenizer, sample_config)
        # Snapshot the initial conversation before it is extended below.
        sub_conv_all[sub] = copy.deepcopy(conv)

        # Closed-question answers; the first question is already the last user turn.
        q_sub_mask = q_data['sub'] == sub
        for q_num, qs in enumerate(questions, start=1):
            qs_response = q_data.loc[q_sub_mask, f"{qs_name}_q{q_num}"].values[0]
            if q_num > 1:
                conv.append(user(qs))
            conv.append(assistant(qs_response))

        for repq_key in sorted(open_qs_rep_dict.keys()):
            conv.append(user(f"{instr_dict['q_instr']}\n{open_qs_rep_dict[repq_key]}"))
            conv.append(assistant(openq_data.loc[sub_mask, repq_key].values[0]))

        sub_conv_formatted_all_continued[sub] = format_messages(conv, tokenizer, sample_config)
        sub_conv_all_continued[sub] = conv

    return sub_conv_all, sub_conv_formatted_all, sub_conv_all_continued, sub_conv_formatted_all_continued, questions


def _plain_token_ids(tokenizer, text):
    """Tokenise ``text`` without special tokens and return its ids as a Python list."""
    return tokenizer(text, return_tensors="pt", padding=True, add_special_tokens=False).input_ids[0].tolist()


def _decode_located_span(tokenizer, input_ids, target_ids, *occurrence):
    """Locate ``target_ids`` in ``input_ids`` and decode the located span and its last token.

    ``occurrence`` is forwarded unchanged to ``find_subsequence_n``. The callers pass nothing,
    ``1`` or ``2``; its meaning is defined by that helper. The helper is assumed to return the
    start index of the match -- TODO confirm its not-found value.

    Returns:
        (decoded span, decoded last token of the span).
    """
    start = find_subsequence_n(input_ids, target_ids, *occurrence)
    end = start + len(target_ids)
    return tokenizer.decode(input_ids[start:end]), tokenizer.decode(input_ids[end - 1])


def check_text_locations(sample_config, tokenizer, sub_conv_formatted_all_continued, openq_data):
    """Check that every open question and subject answer round-trips through the tokenised conversation.

    Subjects are visited in numeric order. Keys are assumed to look like ``sub<N>``. For each
    subject, the continued conversation is tokenised, and each open question and answer, plus
    each repeated open question and answer, is located with ``find_subsequence_n``. The
    located tokens are decoded and compared with the original text, and mismatches are
    printed and collected.

    Side effect: ``sample_config.subj`` is set to each subject in turn and left at the last one.

    Returns:
        A DataFrame with columns ``sub``, ``q_name``, ``type`` ('q', 'rep_q' or 'r') and, for
        response mismatches, ``r``; or ``None`` when everything matched.
    """
    _, oqs, oqs_rep = load_task_content(sample_config)
    collect_unmatched = []
    sub_numbers = sorted(int(k[3:]) for k in sub_conv_formatted_all_continued.keys())
    for sub in ['sub' + str(s) for s in sub_numbers]:
        sample_config.subj = sub
        sub_rows = openq_data[openq_data['sub'] == sub]

        input_ids = tokenizer(sub_conv_formatted_all_continued[sub], return_tensors="pt", padding=True,
                              add_special_tokens=True).input_ids[0].tolist()

        for q_name in sorted(oqs.keys()):
            oq = oqs[q_name]
            sub_response = sub_rows[q_name].values[0]

            dec_q, dec_eq = _decode_located_span(tokenizer, input_ids, _plain_token_ids(tokenizer, oq))
            if dec_q != oq:
                collect_unmatched.append({'sub': sub, 'q_name': q_name, 'type': 'q'})
                print(sub)
                print(q_name)
                print('Q: ' + oq)
                print('\t Q decoded: ' + dec_q)
                print('\t Q end: ' + dec_eq)
                print('\t Match?: ' + str(dec_q == oq))

            dec_r, dec_er = _decode_located_span(tokenizer, input_ids, _plain_token_ids(tokenizer, sub_response))
            if dec_r != sub_response:
                collect_unmatched.append({'sub': sub, 'q_name': q_name, 'type': 'r', 'r': sub_response})
                if dec_q == oq:
                    # Header not printed yet by the question check above.
                    print(sub)
                    print(q_name)
                print('R: ' + sub_response)
                print('\t R decoded: ' + dec_r)
                print('\t R end : ' + dec_er)
                print('\t Match?: ' + str(dec_r == sub_response))

        for rep_q_name in sorted(oqs_rep.keys()):
            rep_q = oqs_rep[rep_q_name]
            rep_q_response = sub_rows[rep_q_name].values[0]

            dec_rep_q, dec_rep_eq = _decode_located_span(tokenizer, input_ids,
                                                         _plain_token_ids(tokenizer, rep_q), 2)
            if dec_rep_q != rep_q:
                collect_unmatched.append({'sub': sub, 'q_name': rep_q_name, 'type': 'rep_q'})
                print(sub)
                print(rep_q_name)
                print('Q: ' + rep_q)
                print('\t Q decoded: ' + dec_rep_q)
                print('\t Q end: ' + dec_rep_eq)
                print('\t Match?: ' + str(dec_rep_q == rep_q))

            dec_rep_r, dec_rep_er = _decode_located_span(tokenizer, input_ids,
                                                         _plain_token_ids(tokenizer, rep_q_response), 1)
            if dec_rep_r != rep_q_response:
                collect_unmatched.append({'sub': sub, 'q_name': rep_q_name, 'type': 'r', 'r': rep_q_response})
                if dec_rep_q == rep_q:
                    # Header not printed yet by the question check above; print the question
                    # name (not the response, which is printed just below).
                    print(sub)
                    print(rep_q_name)
                print('R: ' + rep_q_response)
                print('\t R decoded: ' + dec_rep_r)
                print('\t R end : ' + dec_rep_er)
                print('\t Match?: ' + str(dec_rep_r == rep_q_response))

    collect_unmatched = pd.DataFrame(collect_unmatched)
    if collect_unmatched.empty:
        collect_unmatched = None
    return collect_unmatched


def get_sub_oq_locs(sub, input_text, tokenizer, openq_data, sample_config):
    """Return the last-token location of every open question and of ``sub``'s answer to it.

    Locations come from the third value returned by ``find_loc``. The result maps
    ``'<q_name>'`` to the question's location and ``'<q_name>_ans'`` to the answer's location.
    Questions are processed in sorted key order.
    """
    _, oqs, _ = load_task_content(sample_config)
    locs = {}
    for q_name in sorted(oqs):
        answer_text = openq_data.loc[openq_data['sub'] == sub, q_name].values[0]
        _, _, answer_loc = find_loc(input_text, answer_text, tokenizer)
        _, _, question_loc = find_loc(input_text, oqs[q_name], tokenizer)
        locs[q_name + '_ans'] = answer_loc
        locs[q_name] = question_loc
    return locs


def process_model_response_conv_batch(model_outputs, sample_config, which_q):
    """Parse raw model outputs for closed questionnaire question ``which_q`` into labels and scores.

    Each output is scanned line by line from the top. Every line is stripped and reduced to
    letters and spaces. Scanning stops at the first line that is empty after cleaning, or as
    soon as a line contains at least one scale label of ``maps[sample_config.qs_name]``.

    Labels are matched case-insensitively as substrings, after the same cleaning is applied
    to them, so labels containing punctuation can still match. When several labels match,
    the longest wins. The first match was taken previously, which let e.g. "Agree" shadow
    "Strongly agree" / "Disagree". Longest match is consistent with
    ``process_model_response_batch_gen``.

    The score is reversed as ``min + max - score`` when ``which_q`` is listed in
    ``rev_lists[sample_config.qs_name]``.

    Returns:
        (processed outputs, scores, parse flags), one entry per output. An unparsed output is
        returned unchanged, with score ``-1`` and flag ``False``.
    """
    model_outputs = list(model_outputs)
    if not model_outputs:
        return [], [], []

    map_dict = maps[sample_config.qs_name]
    drop_chars = re.compile(r'[^A-Za-z ]+')

    def clean(text):
        return drop_chars.sub('', text.strip()).lower()

    # Labels that clean to '' are dropped, since an empty pattern would match every line.
    cleaned_labels = [(label, clean(label)) for label in map_dict]
    cleaned_labels = [(label, norm) for label, norm in cleaned_labels if norm]

    model_outputs_proc = []
    model_outputs_scores = []
    model_outputs_bool_checks = []
    for model_output in model_outputs:
        model_output_proc = model_output
        tmp_score = -1
        bool_check = False

        ans = []
        for line in model_output.split('\n'):
            line_ans = clean(line)
            if line_ans == '' or ans:
                break
            ans = [label for label, norm in cleaned_labels if norm in line_ans]

        if ans:
            model_output_proc = max(ans, key=len)
            tmp_score = map_dict[model_output_proc]
            bool_check = True
            if sample_config.qs_name in rev_lists and which_q in rev_lists[sample_config.qs_name]:
                tmp_score = min(map_dict.values()) + max(map_dict.values()) - tmp_score

        model_outputs_proc.append(model_output_proc)
        model_outputs_scores.append(tmp_score)
        model_outputs_bool_checks.append(bool_check)

    return model_outputs_proc, model_outputs_scores, model_outputs_bool_checks


def _keep_passed(items, passed):
    """Return the elements of ``items`` whose matching flag in ``passed`` is truthy (stops at the shorter list)."""
    return [item for item, ok in zip(items, passed) if ok]


def _build_sample_frames(sample_id, sample_config, status, q_inds, scores, times):
    """Build one sample's (responses, timing) DataFrames; dict key order is the CSV column order."""
    time_dict = {'sample_ts': sample_id, 'model': sample_config.model_name,
                 'temp': sample_config.temp, 'top_p': sample_config.top_p, 'status': status,
                 'time': times, 'instr_name': sample_config.instr_name,
                 'qs': sample_config.qs_name, 'question': q_inds, 'score': scores,
                 'nSamples': sample_config.remBatchSize}
    response_dict = {'sample_ts': sample_id, 'question': q_inds, 'score': scores,
                     'model': sample_config.model_name, 'temp': sample_config.temp,
                     'top_p': sample_config.top_p, 'status': status,
                     'instr_name': sample_config.instr_name, 'qs': sample_config.qs_name,
                     'nSamples': sample_config.remBatchSize}
    # scalar entries are broadcast over one row per asked question
    response_pd = pd.DataFrame(response_dict, index=range(len(scores)))
    time_pd = pd.DataFrame(time_dict, index=range(len(times)))
    return response_pd, time_pd


def _save_sample_outputs(sample_id, sample_config, status, q_inds, scores, times, output_log, messages,
                         messages_formatted, failed):
    """Write one sample's transcript (.txt), responses (.csv) and timings (.csv).

    Failed samples are written under the ``fails/`` subdirectory of each output path, and their transcript
    additionally gets the responses table appended.
    """
    response_pd, time_pd = _build_sample_frames(sample_id, sample_config, status, q_inds, scores, times)

    txt_content = '\n^^^\n' + '\n'.join(output_log) + '\n^^^\n'
    if not isinstance(messages_formatted, list):
        # a rendered prompt string: store it in front of the raw message list
        txt_content += messages_formatted + '\n\n'
    txt_content += json.dumps(messages, indent=1)

    subdir = ''
    if failed:
        subdir = 'fails/'
        txt_content += '\n' + response_pd.to_string()

    with open(sample_config.outputs_path + subdir + 'output_model_s-ts-' + sample_id + '.txt', 'w',
              encoding='utf-8') as output_file:
        output_file.write(txt_content)
    response_pd.to_csv(sample_config.responses_path + subdir + 'response_model_s-ts-' + sample_id + '.csv',
                       index=False)
    time_pd.to_csv(sample_config.time_path + subdir + 'time_s_ts-' + sample_id + '.csv', index=False)


def sample_messages_conv_batch(questions, batch_messages, batch_messages_formatted, model, tokenizer, sample_ts,
                               sample_config, last_token_locs):
    """Ask ``questions`` one per turn to a batch of conversations, dropping samples whose answer fails to parse.

    The first question is assumed to already be part of ``batch_messages`` (only questions after the first are
    appended here as user turns). For every question, responses are generated with ``get_model_response_batch``
    (``last_token_locs`` is passed only for the first question), parsed with ``process_model_response_conv_batch``,
    and the processed answer is appended as an assistant turn.

    A sample whose answer fails to parse is written immediately to the ``fails/`` subdirectories and removed from
    the batch. Samples that answered every question are written to the main output directories at the end.
    Output ids are ``sample_ts + str(original_index)``. The original position in the input batch is used so ids
    stay unique after failed samples have been removed.

    Returns:
        batch_hidden_states: a list that starts with one empty dict per input sample. If
            ``sample_config.save_states`` is set, it is followed by one dict per successfully parsed
            (sample, question) pair, mapping each hidden-state key to ``{'q<i>': value}`` (``i`` 0-based).
        model_hidden_states_oq: the last value of that name returned by ``get_model_response_batch``, or None
            when no generation ran.
    """
    n_samples = len(batch_messages)
    # original position of each still-active sample; used for unique output file ids
    batch_sample_ids = list(range(n_samples))
    batch_times = [[] for _ in range(n_samples)]
    batch_bools = [[] for _ in range(n_samples)]
    batch_scores = [[] for _ in range(n_samples)]
    batch_output_logs = [[] for _ in range(n_samples)]
    batch_q_inds = [[] for _ in range(n_samples)]
    batch_proc_outputs = [[] for _ in range(n_samples)]
    # independent placeholder dicts (not N references to one shared dict)
    batch_hidden_states = [{} for _ in range(n_samples)]
    model_hidden_states_oq = None  # stays None if no generation happens

    for i, qs in tqdm(enumerate(questions)):
        if not batch_messages:
            break  # every sample has failed; nothing left to ask
        q_num = i + 1
        batch_q_inds = [q_inds + [q_num] for q_inds in batch_q_inds]

        # the first question is already in the initial messages; append later ones as new user turns
        if i > 0:
            batch_messages = [msgs + [{'role': 'user', 'content': qs}] for msgs in batch_messages]
            batch_messages_formatted = [format_messages(msgs, tokenizer, sample_config) for msgs in batch_messages]

        # sample responses
        start = time.time()
        (_, _, _, _, batch_model_outputs_clean,
         batch_model_hidden_states, model_hidden_states_oq) = get_model_response_batch(
            batch_messages_formatted, model, tokenizer, sample_config, last_token_locs if i == 0 else None)
        elapsed = round(time.time() - start, 4)

        # process responses
        batch_model_outputs_proc, batch_model_scores, batch_model_bool_checks = process_model_response_conv_batch(
            batch_model_outputs_clean, sample_config, q_num)

        # record outputs, scores, statuses and times per sample
        batch_output_logs = [log + ['---------------q: ' + str(q_num) + '--------------', out]
                             for log, out in zip(batch_output_logs, batch_model_outputs_clean)]
        batch_scores = [scores + [score] for scores, score in zip(batch_scores, batch_model_scores)]
        batch_bools = [bools + [ok] for bools, ok in zip(batch_bools, batch_model_bool_checks)]
        batch_times = [times + [elapsed] for times in batch_times]
        batch_proc_outputs = [procs + [proc] for procs, proc in zip(batch_proc_outputs, batch_model_outputs_proc)]

        # append the processed answers as assistant turns
        batch_messages = [msgs + [{'role': 'assistant', 'content': proc}]
                          for msgs, proc in zip(batch_messages, batch_model_outputs_proc)]
        batch_messages_formatted = [format_messages(msgs, tokenizer, sample_config) for msgs in batch_messages]

        # save samples that failed on this question
        for sample_id, bools, q_inds, scores, times, output_log, messages, messages_formatted in zip(
                batch_sample_ids, batch_bools, batch_q_inds, batch_scores, batch_times, batch_output_logs,
                batch_messages, batch_messages_formatted):
            if not bools[-1]:
                _save_sample_outputs(sample_ts + str(sample_id), sample_config, bools, q_inds, scores, times,
                                     output_log, messages, messages_formatted, failed=True)

        # keep hidden states only for successfully parsed answers
        if sample_config.save_states:
            for hidden_state, ok in zip(batch_model_hidden_states, batch_model_bool_checks):
                if ok:
                    batch_hidden_states.append({k: {'q' + str(i): v} for k, v in hidden_state.items()})

        # discard failed samples from every per-sample list
        batch_sample_ids = _keep_passed(batch_sample_ids, batch_model_bool_checks)
        batch_messages = _keep_passed(batch_messages, batch_model_bool_checks)
        batch_messages_formatted = _keep_passed(batch_messages_formatted, batch_model_bool_checks)
        batch_times = _keep_passed(batch_times, batch_model_bool_checks)
        batch_scores = _keep_passed(batch_scores, batch_model_bool_checks)
        batch_output_logs = _keep_passed(batch_output_logs, batch_model_bool_checks)
        batch_q_inds = _keep_passed(batch_q_inds, batch_model_bool_checks)
        batch_proc_outputs = _keep_passed(batch_proc_outputs, batch_model_bool_checks)
        batch_bools = _keep_passed(batch_bools, batch_model_bool_checks)

    # save time, responses and outputs of samples that answered every question
    for sample_id, bools, q_inds, scores, times, output_log, messages, messages_formatted in zip(
            batch_sample_ids, batch_bools, batch_q_inds, batch_scores, batch_times, batch_output_logs,
            batch_messages, batch_messages_formatted):
        if all(bools):
            print('all good')
            _save_sample_outputs(sample_ts + str(sample_id), sample_config, bools, q_inds, scores, times,
                                 output_log, messages, messages_formatted, failed=False)

    flush()
    return batch_hidden_states, model_hidden_states_oq


def get_init_msgs(sample_config, tokenizer):
    """Build the opening chat messages for a questionnaire session.

    The system prompt is read from ``{prefix_path}prefix_{prefix_name}.txt`` and the questionnaire prompt from
    ``prompts/qs/{qs_name}/{prompt_type}/{prompt_type}.txt``.

    For ``prompt_type == 'prompt_conv'`` the questionnaire file is split on ``'\\n^^\\n'``. Part 0 is the preamble,
    part 1 the answer scale, and part 2 one question per line. The scale is appended to every question when
    ``prefix_name`` contains ``'conv_scale'``; otherwise it is shown once in the opening turn. The first question
    is embedded in the opening user turn, and the full question list is returned for the later turns. For any
    other prompt type, the whole file is sent in one turn and ``questions`` is None.

    When ``context_name`` is not ``"none"``, the opening turns are: a user turn with the context instruction, an
    assistant turn with the context text, then the user turn with the first prompt.

    Returns:
        tuple: ``(init_msg, init_msg_formatted, questions)`` — the message list, its ``format_messages``
        rendering, and the list of question prompts (or None).
    """
    def read_text(path):
        with open(path, encoding='utf-8') as handle:
            return handle.read()

    begin_study = 'The instruction phase has now finished and the study starts now.\n'
    begin_qs = 'Now, please read and then answer the following statements one at a time.\n\n'

    # system prompt: instructions about the task
    messages = [{'role': 'system',
                 'content': read_text(f"{sample_config.prefix_path}prefix_{sample_config.prefix_name}.txt")}]

    prompt_type = sample_config.prompt_type
    user_prompt = read_text('/'.join(['prompts', 'qs', sample_config.qs_name, prompt_type, prompt_type + '.txt']))

    if prompt_type != 'prompt_conv':
        first_prompt = '\n\n' + user_prompt
        questions = None
    else:
        # conversational prompting: questions are asked one at a time
        parts = user_prompt.split('\n^^\n')
        raw_questions = parts[2].split('\n')
        preamble, scale = parts[0], parts[1]

        if 'conv_scale' in sample_config.prefix_name:
            # scale repeated after every question
            questions = [f"{q}\nAnswer on the scale:\n{scale}\nAnswer: " for q in raw_questions]
            first_prompt = f"\n{begin_qs}{preamble}\n\n{questions[0]}"
        else:
            # scale shown once, in the opening turn only
            questions = [f"{q}\n" for q in raw_questions]
            first_prompt = (f"\n{begin_qs}{preamble}\n\nRemember to answer on the scale below:\n"
                            f"{scale}\n\n{questions[0]}")

    if sample_config.context_name == "none":
        messages.append({'role': 'user', 'content': begin_study + first_prompt})
    else:
        context_prompt_name = sample_config.context_prompt_name
        context_prompt = read_text(
            '/'.join(['prompts', 'context', context_prompt_name, context_prompt_name + '.txt']))
        messages.append({'role': 'user', 'content': begin_study + context_prompt + '\n'})

        context = read_text('/'.join(['prompts', 'context', f"{sample_config.context_name}_context",
                                      f"{sample_config.context_full_name}.txt"]))
        messages.extend([{'role': 'assistant', 'content': context},
                         {'role': 'user', 'content': first_prompt}])

    return messages, format_messages(messages, tokenizer, sample_config), questions
