"""

[] Create a new prompt for logit sampling
    A
    B
    C
    D
    ...

[] System prompt
[] Permute the labels (EXTRA)
[] Get logits for the 5 labels (save logits for the whole vocab) and sample offline


[] Save logits for each subject and question
[] Save hidden state at every layer at selected locations
[] Draw 50 samples from the categorical distribution

[] Change hidden states before logits obtained

[] Double check special tokens for each model to be used

[] Multinomial sampling offline across allowed tokens


[] Apply to individual PHQ-9 questions
[] Apply to generalisation questions
[]

"""

# %%
import copy
from torch.distributions import Categorical
import re
import random

import matplotlib
from tqdm import tqdm
import os
import pandas as pd
import numpy as np
from transformers import AutoTokenizer
import torch.nn.functional as F

from baukit import Trace, TraceDict

# Disable Hugging Face tokenizers parallelism for this process.
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Relative locations used by the rest of the script.
llm_path = '_llm_based/'
path_to_go = '_analysis'
if os.uname()[0] == 'Darwin':  # local mac run: cd into <...>/online_tasks/qs_structure/_analysis/
    base_task = 'qs_structure'
    cwd = os.getcwd()
    cwd_split = cwd.split('/')
    # Keep every path component up to and including the first "online_tasks" directory.
    cwd_base = '/'.join(cwd_split[:np.argwhere([part == "online_tasks" for part in cwd_split])[0][0] + 1])
    new_path = f"{cwd_base}/{base_task}/{path_to_go}/"
    os.chdir(new_path)

    # Alternative layout (task-version sub-folder):
    # task_version = 'qs-structure-phq9-v2'
    # analysis_dir = '_analysis'
    # cwd_base = '/'.join(cwd_split[:np.argwhere([p == "online_tasks" for p in cwd_split])[0][0] + 1])
    # new_path = f"{cwd_base}/{base_task}/{task_version}/{analysis_dir}"
    # os.chdir(new_path)

# Runtime selection. macOS = local development (interactive Qt plots, 'mps' device, one model);
# any other OS = headless run on 'cuda' over the full model list.
# matplotlib.use() is called here, before matplotlib.pyplot is imported further down.
model = None  # loaded lazily inside the per-model loop
if os.uname()[0] != 'Darwin':
    matplotlib.use('Agg')
    device_name = 'cuda'
    model_name = 'MistralOo'
    model_names = ['MistralOo', 'gemma2-2b-it', 'llama32-3b-it', 'gemma2-9b-it', 'llama31-8b-it']
else:
    matplotlib.use('Qt5Agg')
    device_name = 'mps'  # or 'cpu'
    # other choices: 'GPT2', 'gemma2-2b-it', 'llama32-3b-it', 'llama31-8b-it', 'MistralOo'
    model_name = 'gemma2-9b-it'
    model_names = ['gemma2-2b-it']

# Instruction-prompt variant used for every model (an 'llama3'-specific override was once considered).
instr_name_str = 'instr3'

import matplotlib.pyplot as plt

from _llm_based.utils.logit_utils import setup_model, load_task_content, get_all_question_pairs, get_sub_locs, \
    sample_model_responses_batch, flush, maps, rev_lists, phq9_qs_inv_map
from _llm_based.utils.logit_utils import create_question_pairs_instr, SteerConfig, SteerHiddenState

# from utils.utils import upload_blob_from_memory, download_blob
# download_blob('llm-bucket-res', '_data.zip','_data.zip')
from objects.configs import *
from objects.plot_config import *
from _llm_based.objects.data_configs import *
from _llm_based.objects.model_configs import *

# Inference only: turn autograd off globally so forward passes build no graph.
torch.set_grad_enabled(False)

# %%
# Project configuration objects (presumably provided by the star-imported config modules).
paths = Paths(
    files_dir=f"{llm_path}files_logits",
    sub_path=llm_path,
    files_data_dir='_data',
    plots_subdir='',
    plots_subsubdir='',
)
bools = Bools()

pc = PlotConfig()
qs_config = QsConfig()

# Questionnaire to sample (used as a key into maps[...] and qs_config.qs_n_lab[...]).
to_sample_qs = 'phq9'

for model_name in tqdm(model_names):
    # Per-model pass: build the logit-sampling config, load the model, then for every (subject,
    # question) prompt save the answer-letter logits, nSamples answers drawn offline from their
    # softmax, and the hidden states at the selected token positions.
    print(f'Model: {model_name}')
    # sample_config = SampleConfig(paths, model_name=model_name, qs_name=to_sample_qs, instr_name='instr2', temp='', top_p='')
    sample_config = SampleLogitsConfig(paths, model_name=model_name, qs_name=to_sample_qs, instr_name=instr_name_str,
                                       temp='',
                                       top_p='')
    # Upper slice bound for the subject and question lists below; None processes all of them.
    # (The previous value -1 silently skipped the last subject and the last question of every subject.)
    end_idx = None
    sample_config.nSamples = 4
    sample_config.batchSize = 1
    sample_config.save_states = True
    sample_config.gen_fname = 'gb'
    sample_config.permute_labels = False
    sample_config.label_letters = None
    do_bucket = False

    # One answer letter (A, B, C, ...) per response option of the questionnaire.
    label_letters = list(string.ascii_uppercase)[:qs_config.qs_n_lab[to_sample_qs]]
    q_scores = list(maps[to_sample_qs].values())
    # letter -> response text, pairing the letters with the keys of maps[to_sample_qs] in order
    sample_config.q_responses = {l: k for l, (k, v) in zip(label_letters, maps[to_sample_qs].items())}
    # q_responses = list(maps[to_sample_qs].keys())

    if sample_config.permute_labels:
        # NOTE(review): letters and scores are shuffled as pairs, so each letter keeps its score and
        # response and only their order changes -- confirm this is the intended permutation.
        tmp_zip = list(zip(q_scores, label_letters))
        random.shuffle(tmp_zip)
        q_scores, label_letters = zip(*tmp_zip)
    sample_config.label_letters = label_letters
    sample_config.q_scores = q_scores
    sample_config.label_scores = dict(zip(label_letters, q_scores))
    sample_config.qs_labels = [sample_config.q_responses[l] for l in sample_config.label_letters]
    # %%
    if model is None:
        model, tokenizer = setup_model(sample_config)
        # model = model.model

    # prepare tokenizer
    # tokenizer = AutoTokenizer.from_pretrained(sample_config.model_name_hf, use_fast=True)
    # tokenizer.pad_token_id = tokenizer.eos_token_id
    # toker = tokenizer
    # %%


    # # %% Steering stuff
    # def get_steering_vecs(layer_ids_to_steer):
    #     love_id = toker("Love", return_tensors='pt').input_ids.to(device_name)
    #     hate_id = toker("Hate", return_tensors='pt').input_ids.to(device_name)
    #     steering_vectors = {}  # Dictionary to hold a steering vector for each layer
    #     print("--- Creating Steering Vectors ---")
    #     for layer_id in layer_ids_to_steer:
    #         module_to_steer = model.model.layers[layer_id]
    #         with Trace(module_to_steer) as ret:
    #             # Pass "Love" through the model
    #             _ = model(love_id)
    #             act_love = ret.output[0][0, -1, :]
    #
    #             # Pass "Hate" through the model
    #             _ = model(hate_id)
    #             act_hate = ret.output[0][0, -1, :]
    #
    #         # The steering vector is the difference.
    #         steering_vec = (act_love - act_hate).to(device_name)
    #
    #         # Store the vector with a key corresponding to the layer name
    #         layer_name = f'model.layers.{layer_id}'
    #         steering_vectors[layer_name] = steering_vec
    #
    #         print(f"  - Vector created for layer {layer_id} with shape: {steering_vec.shape}")
    #     return steering_vectors
    #
    #
    # layer_ids = [i - 1 for i in sample_config.layer_list[-4:]]
    # # layer_ids = list(np.arange(16, 26, 2)) + [25]
    # # steering_vectors = {f'model.layers.{layer_id}': [] for layer_id in layer_ids}
    # steering_vectors = get_steering_vecs(layer_ids)
    # inputs = ["I think dogs are", "Cats are"]
    # # steer_config = SteerConfig(layer_ids=layer_ids, steering_vectors=steering_vectors, device='mps', multiplier=-1.5, run_gen=True,
    # #                            sample_text=False, n_tokens=50, perturb_input_only=True,norm_vecs=True)
    # # steer_config = SteerConfig(layer_ids=layer_ids, steering_vectors=steering_vectors, device='mps', multiplier=-100.5,
    # #                            run_gen=True, sample_text=False, perturb_input_only=True, norm_vecs=True,n_tokens=10)
    # steer_config = SteerConfig(layer_ids=layer_ids, steering_vectors=steering_vectors, device='mps', multiplier=-100.5,
    #                            run_gen=False, sample_text=False, perturb_input_only=True, norm_vecs=True, n_tokens=10)
    # steer_hs = SteerHiddenState(steer_config, model, tokenizer)
    # steer_hs.steer(model, inputs, return_org=True)
    #
    # # print(f'Original: {[steer_hs.hidden_states_original[l][:, -1, :].max().item() for l in [10, 12, 14] + layer_ids]}')
    # # print(f'Steered: {[steer_hs.hidden_states[l][:, -1, :].max().item() for l in [10, 12, 14] + layer_ids]}')
    # # print(f'Original gen: {[steer_hs.hidden_states_gen_original[0][l][:, -1, :].max().item() for l in [10, 12, 14] + layer_ids]}')
    # # print(f'Steered gen: {[steer_hs.gen_hid[0][l][:, -1, :].max().item() for l in [10, 12, 14] + layer_ids]}')
    # # print('\n')
    # # for i in range(len(steer_hs.gen_hid)):
    # #     print(f'Gen token {i}')
    # #     print(
    # #         f'\tOriginal gen: {[steer_hs.hidden_states_gen_original[i][l][:, -1, :].max().item() for l in [10, 12, 14] + layer_ids]}')
    # #     print(f'\tSteered gen: {[steer_hs.gen_hid[i][l][:, -1, :].max().item() for l in [10, 12, 14] + layer_ids]}')
    # #
    # # steer_hs.hidden_states_original[1][:,-1,:]-steering_vectors['model.layers.16']
    # # steer_hs.hidden_states[0][:,-1,:]-steering_vectors['model.layers.16'].unsqueeze(0)
    # %%
    # %% Load data
    instr_dict, open_qs_dict, _ = load_task_content(sample_config)
    # instr_dict, open_qs_dict, open_spec_qs_dict, closed_qs_dict, _ = load_task_content(sample_config)
    oq_names = list(open_qs_dict.keys())  # + list(open_qs_rep_dict.keys())
    phq9_data = pd.read_csv(f"{paths.files_data_dir}phq9_data.csv")
    phq9_names = [c for c in phq9_data.columns if 'phq9_q' in c]
    openq_data = pd.read_csv(f"{paths.files_data_dir}openq_data.csv")

    # subset tasks
    task_versions = ['v4', 'v4_d', 'v4_dd', 'v4_ddd']
    openq_data = openq_data[openq_data['task_version'].isin(task_versions)]
    phq9_data = phq9_data[phq9_data['task_version'].isin(task_versions)]

    # intro_prompt, oq_instr, open_qs_dict, closed_questions_dict, cq_preamble, qsn_questions_dict, cq_instr, qsn_preamble = create_question_pairs_instr(sample_config)
    # preprocess text: drop whitespace before '.' / ',' and strip every string column
    openq_data = openq_data.replace(r'\s+\.', '.', regex=True)
    openq_data = openq_data.replace(r'\s+,', ', ', regex=True)
    # NOTE(review): every '\n+,' was already rewritten by the '\s+,' rule above, so this line changes
    # nothing; r'\n+' (newlines -> space) may have been intended. Left unchanged because fixing it
    # would alter the prompts fed to the model.
    openq_data = openq_data.replace(r'\n+,', ' ', regex=True)
    df_obj = openq_data.select_dtypes('object')
    openq_data[df_obj.columns] = df_obj.apply(lambda x: x.str.strip())
    # wide -> long: one row per (sub, open question)
    openq_data_long = pd.melt(openq_data, id_vars=['sub'], value_vars=oq_names, var_name='q_name').sort_values(
        by=['sub', 'q_name']).reset_index(drop=True)
    phq9_data_long = pd.melt(phq9_data, id_vars=['sub'], value_vars=phq9_names, var_name='q_name')

    # Drop missing answers: non-strings (e.g. NaN) and empty/whitespace-only strings. str.split() with
    # no separator returns [] for such strings; the former split(' ') returned [''] (length 1), so
    # empty answers were never removed.
    empty_resp_idx = openq_data_long['value'].apply(lambda x: len(x.split()) if isinstance(x, str) else 0) == 0
    openq_data_long = openq_data_long.loc[~empty_resp_idx, :]
    # openq_data_long['value'] = openq_data_long['value'].str.replace(r'[^a-zA-Z0-9]+$', '', regex=True)

    lvl1_closed_data = pd.read_csv(f"{paths.files_data_dir}lvl1_closed_data.csv")
    lvl2_closed_data = pd.read_csv(f"{paths.files_data_dir}lvl2_closed_data.csv")
    del df_obj, empty_resp_idx, instr_dict, oq_names, phq9_names

    ## include specificity level for version 3
    # openq_data_long['spec_level'] = 'gen'
    # openq_data_long.loc[openq_data_long['sub'].str.contains('v3_s'), 'spec_level'] = openq_data_long.loc[
    #     openq_data_long['sub'].str.contains('v3_s'), 'q_name']
    #
    # openq_data_long.loc[openq_data_long['sub'].str.contains('v3_s'), 'spec_level'] = \
    #     openq_data_long[openq_data_long['sub'].str.contains('v3_s')]['spec_level'].replace(spec_levels['v3_s'])
    intro_prompt, oq_instr, open_qs_dict, closed_questions_dict, cq_preamble, qsn_questions_dict, cq_instr, qsn_preamble = create_question_pairs_instr(
        sample_config)

    instr_dict, open_qs_dict, closed_qs_dict = load_task_content(sample_config)
    question_pairs, question_pairs_formatted, _, _ = get_all_question_pairs(openq_data_long, sample_config, tokenizer)
    # %%

    sample_config.nSamples = 50
    sample_config.batchSize = 1  # 20
    extra_save = False

    # Token id of each answer letter, computed once per model. add_special_tokens=False yields only the
    # letter token(s). The former `input_ids[:, 1:]` assumed every tokenizer prepends a BOS token, which
    # is not true for all models (GPT-2 adds none), and would then select nothing.
    label_ids = tokenizer(list(sample_config.label_letters), add_special_tokens=False, padding=True,
                          return_tensors='pt').input_ids.to(device_name)
    if label_ids.shape[1] != 1:
        raise ValueError(f'{model_name}: every label letter must be a single token, got {label_ids.tolist()}')

    for sub in tqdm(list(question_pairs.keys())[0:end_idx]):
        sample_config.subj = sub
        sub_texts_to_find, sub_last_token_locs = get_sub_locs(openq_data_long, question_pairs_formatted, tokenizer,
                                                              sample_config)
        for k, v in sub_last_token_locs.items():
            for k2, v2 in v.items():
                if v2 == -1:
                    # NOTE(review): the -1 is still used as a position below, where it selects the last token.
                    print(sub, k, k2, 'missing location')

        for which_q in tqdm(list(question_pairs[sub].keys())[0:end_idx]):
            print(sub, which_q)
            sample_config.save_states = True
            sample_config.which_q = which_q
            sample_config.get_remaining_samples()
            while not sample_config.currentFiles:
                # Timestamp taken per pass (not once per question), so a repeated pass writes new files
                # instead of overwriting the previous pass's files.
                sample_ts = str(round(datetime.timestamp(datetime.now()) * 10000))

                # prepare subject data and locations
                sub_q_token_locs = sub_last_token_locs[sample_config.which_q]
                sub_qs_pairs = [question_pairs_formatted[sample_config.subj][sample_config.which_q]]

                inputs_ids = tokenizer(sub_qs_pairs, padding=True, return_tensors='pt', return_attention_mask=True).to(
                    device_name)

                model_output = model(inputs_ids.input_ids, attention_mask=inputs_ids.attention_mask, output_hidden_states=True)
                # Hidden states of all layers stacked on a new leading dim; assuming HF's
                # (batch, seq, hidden) layout this is (layers, batch, seq, hidden).
                hs_ts = torch.stack(model_output.hidden_states)
                # Keep the selected token positions plus the final prompt token; squeeze(1) drops the
                # batch dim (a single prompt per pass).
                hs_ts_sub = torch.stack([hs_ts[:, :, loc, :] for loc in list(sub_q_token_locs.values()) + [hs_ts.shape[2] - 1]],
                                        dim=2).squeeze(1)
                del hs_ts

                # next-token logits after the prompt, restricted to the answer-letter tokens
                logits = model_output.logits[:, -1, :]
                label_logits = logits[:, label_ids[:, 0]].detach().cpu().type(torch.float32)
                label_probs = F.softmax(label_logits, dim=-1).numpy()[0, :]
                # draw nSamples answers offline from the categorical distribution over the letters
                sampled_labels = np.random.choice(sample_config.label_letters, size=sample_config.nSamples, p=label_probs)
                sampled_scores = [sample_config.label_scores[sl] for sl in sampled_labels]
                sampled_responses = [sample_config.q_responses[sl] for sl in sampled_labels]

                # one row per sampled answer
                tmp_dict = {'sub': sample_config.subj, 'sample_ts': [f'{sample_ts}_{s}' for s in range(sample_config.nSamples)],
                            'question': sample_config.which_q,
                            'score': sampled_scores, 'response': sampled_responses, 'model': sample_config.model_name,
                            'instr_name': sample_config.instr_name, 'qs': sample_config.qs_name,
                            'label_perm': ''.join(sample_config.label_letters),
                            'nSamples': sample_config.nSamples}
                tmp_pd = pd.DataFrame(tmp_dict)

                # one row per answer letter; logits passed as a numpy array rather than a torch tensor
                tmp_logits_dict = {'sub': sample_config.subj, 'sample_ts': sample_ts, 'question': sample_config.which_q,
                                   'logits': label_logits[0].numpy(), 'probs': label_probs,
                                   'label_letters': sample_config.label_letters, 'qs_labels': sample_config.qs_labels,
                                   'label_scores': sample_config.q_scores, 'model': sample_config.model_name,
                                   'instr_name': sample_config.instr_name, 'qs': sample_config.qs_name,
                                   'label_perm': ''.join(sample_config.label_letters),
                                   'nSamples': sample_config.nSamples}
                tmp_logits_pd = pd.DataFrame(tmp_logits_dict)

                tmp_fname = f"{sample_config.sample_path}{sample_config.subj}^^{sample_config.which_q}^^{''.join(sample_config.label_letters)}^^{sample_config.model_name_rp}"
                # NOTE(review): the hidden-states name carries an extra '0' after the timestamp, unlike the
                # two CSV names; kept as-is because file lookups elsewhere may rely on it -- confirm.
                hs_fname = f"{tmp_fname}_hidden_states_s-ts-{sample_ts}0.pt"
                logits_fname = f"{tmp_fname}_logits_s-ts-{sample_ts}.csv"
                responses_fname = f"{tmp_fname}_responses_s-ts-{sample_ts}.csv"

                tmp_pd.to_csv(f"{responses_fname}", index=False)
                tmp_logits_pd.to_csv(f"{logits_fname}", index=False)
                torch.save(hs_ts_sub.clone(), hs_fname)
                flush()
                sample_config.get_remaining_samples()

    del model, tokenizer
    flush()
    model = None

# Power the machine off once every model has been processed. This is skipped on macOS (the local
# development setup in this script), so a local run does not shut down the workstation.
# os.system does not raise when the command fails; it returns the exit status. The status is
# therefore checked explicitly, because the former bare `except` could never report a failed shutdown.
if os.uname()[0] != 'Darwin':
    if os.system("sudo shutdown -h now") != 0:
        print('shutdown issue')
