import copy
import re

import matplotlib
from tqdm import tqdm
import os
import pandas as pd
import numpy as np
from transformers import AutoTokenizer

# Process-wide environment flags; they are set before torch is imported below.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# `torch` was used here without ever being imported at this point in the file
# (the project star-imports only happen further down), which raised a
# NameError when the script was run top to bottom.
import torch

# This script only runs forward passes, so gradient tracking is disabled globally.
torch.set_grad_enabled(False)
# Directory names used to locate the analysis folder and the LLM sub-package.
path_to_go = '_analysis'
llm_path = '_llm_based/'

if os.uname()[0] == 'Darwin':  # local (macOS) run: move into the analysis directory
    base_task = 'qs_intervention'
    cwd = os.getcwd()
    cwd_split = cwd.split('/')
    # Keep every path component up to and including the first "online_tasks".
    root_idx = np.flatnonzero([p == "online_tasks" for p in cwd_split])[0]
    cwd_base = '/'.join(cwd_split[:root_idx + 1])
    new_path = f"{cwd_base}/{base_task}/{path_to_go}/"
    os.chdir(new_path)

    # task_version = 'qs-structure-phq9-v2'
    # analysis_dir = '_analysis'
    # cwd_base = '/'.join(cwd_split[:np.argwhere([p == "online_tasks" for p in cwd_split])[0][0] + 1])
    # new_path = f"{cwd_base}/{base_task}/{task_version}/{analysis_dir}"
    # os.chdir(new_path)

# `model` starts unset; it is loaded further down only while it is still None.
model = None

# Choose plotting backend, torch device and model by platform.
if os.uname()[0] != 'Darwin':
    # Non-mac run: non-interactive backend and the larger model on CUDA.
    matplotlib.use('Agg')
    device_name = 'cuda'
    # alternative: model_name = 'MistralOo'
    model_name = 'gemma2-9b-it'
else:
    # Local mac run: interactive Qt backend and the small model on MPS.
    matplotlib.use('Qt5Agg')
    # sys.path.append(new_path)
    device_name = 'mps'
    # device_name = 'cpu'
    # alternatives: model_name = 'GPT2' / model_name = 'MistralOo'
    model_name = 'gemma2-2b-it'

import matplotlib.pyplot as plt

# from _llm_based.utils.sampling_utils import setup_model, load_task_content, get_all_question_pairs, get_sub_locs, \
#     sample_model_responses_batch, flush, spec_levels

from _llm_based.utils.hs_utils import *

# from _llm_based.utils.sampling_utils import format_messages

from utils.utils import upload_blob_from_memory, download_blob
# download_blob('llm-bucket-res', '_data_int.zip','_data_int.zip')
from objects.configs import *
from objects.plot_config import *
from _llm_based.objects.data_configs import *
from _llm_based.objects.model_configs import *

# %%
# Project configuration objects (all defined in the project's config modules).
paths = Paths(files_dir=f"{llm_path}/files", sub_path=llm_path, plots_subdir='', plots_subsubdir='',
              files_data_dir='_data')
bools = Bools()

pc = PlotConfig()
qs_config = QsConfig()

# initialise LLM sampling config
to_sample_qs = 'phq9'
sample_config = SampleConfig(paths, model_name=model_name, qs_name=to_sample_qs, instr_name='none', temp='', top_p='')
sample_config.nSamples = 4
sample_config.batchSize = 1
sample_config.save_states = True
sample_config.hsT = 20

sample_config.gen_fname = 'gb'
do_bucket = False

# Upper bound used to slice the list of batches (`batches[:to_idx]`).
# It used to be -1, which does not mean "all": `batches[:-1]` silently drops
# the last batch, so the last subject of every text type was never processed.
# None keeps every batch; set a positive int to process only the first N.
to_idx = None
# %%
# Load model and tokenizer only while `model` is still unset, so re-running
# this cell with a model already in memory does not reload it.
if model is None:
    model_and_tokenizer = setup_model(sample_config)
    model, tokenizer = model_and_tokenizer
    model.to(device_name)

# # # prepare tokenizer
# tokenizer = AutoTokenizer.from_pretrained(sample_config.model_name_hf, use_fast=True)
# tokenizer.pad_token_id = tokenizer.eos_token_id

# %% Load data
task_versions = ['v2', 'v2b', 'v3']


def _load_processed_csv(file_name, versions):
    """Load one processed CSV for every task version and keep complete group-'D' rows.

    Args:
        file_name: CSV file name inside each version's ``processed`` folder.
        versions: Task-version suffixes used to build the folder names.

    Returns:
        The concatenated frame without the ``task_version`` column, restricted
        to rows whose ``group`` is ``'D'`` and that contain no missing values.
    """
    frames = [
        pd.read_csv(f"_data/qs-intervention-{version}/processed/{file_name}").drop(columns=['task_version'])
        for version in versions
    ]
    data = pd.concat(frames).reset_index(drop=True)
    data = data[data['group'] == 'D']
    return data[~data.isna().any(axis=1)]


def _prepare_text_data(data, source_col, text_name, min_word):
    """Clean one free-text column and keep responses with at least ``min_word`` words.

    Args:
        data: Frame holding ``sub``, ``condition``, ``group`` and ``source_col``.
        source_col: Name of the raw text column in ``data``.
        text_name: Short label; the text column is renamed to ``{text_name}_text``.
        min_word: Minimum number of space-separated tokens a response must have.

    Returns:
        A new frame with the renamed text column, an ``n_words`` count, the
        previous index kept as an ``index`` column, and a stripped ``text`` column.
    """
    text_col = f'{text_name}_text'
    text_data = data[['sub', 'condition', 'group', source_col]]
    # Remove whitespace before punctuation.
    text_data = text_data.replace(r'\s+\.', '.', regex=True)
    text_data = text_data.replace(r'\s+,', ', ', regex=True)
    # Collapse line breaks into a space. The pattern used to be r'\n+,', which
    # could never match: the replacement above already consumes every newline
    # that precedes a comma, so newlines were left in the text.
    text_data = text_data.replace(r'\n+', ' ', regex=True)
    text_data = text_data.rename(columns={source_col: text_col})
    # Word count is a plain split on single spaces.
    text_data['n_words'] = text_data[text_col].str.split(' ').apply(len)
    text_data = text_data[text_data['n_words'] >= min_word].reset_index()
    text_data['text'] = text_data[text_col].str.strip()
    return text_data


int_data = _load_processed_csv('int_data.csv', task_versions)

min_word = 50
act_data = _prepare_text_data(int_data, 'responses_act_0', 'act', min_word)

# %% Mood, energy, pospert
openq_data = _load_processed_csv('openq_data.csv', task_versions)

min_word = 15
mood_data = _prepare_text_data(openq_data, 'oq_mood_text', 'mood', min_word)
energy_data = _prepare_text_data(openq_data, 'oq_energy_text', 'energy', min_word)
pospert_data = _prepare_text_data(openq_data, 'oq_pospert_text', 'pospert', min_word)

texts_dict = {'act': act_data, 'pospert': pospert_data, 'mood': mood_data, 'energy': energy_data}

# For every text type: format the texts as model messages, run a forward pass
# per batch and save one hidden-state file per subject.
for text_name, text_data in tqdm(texts_dict.items()):
    formatted_texts = texts_to_msg(text_data, sample_config, tokenizer, rem_tag=True)
    # sub_formatted_message = list(formatted_texts.values())[0]
    # re.sub(r"<end_of_turn>\n<start_of_turn>model\n", '', sub_formatted_message)

    sample_config.batchSize = 1
    sample_config.tmpRemBatches = int(np.ceil((len(formatted_texts) / sample_config.batchSize)))

    # Split the (subject, text) pairs into batches. The item list is built
    # once here; it was previously rebuilt for every single batch.
    batch_size = sample_config.batchSize
    text_items = list(formatted_texts.items())
    texts_batches = [text_items[b * batch_size:(b + 1) * batch_size]
                     for b in range(sample_config.tmpRemBatches)]
    # NOTE: a negative `to_idx` removes batches from the END of the list
    # (e.g. -1 skips the last batch); None keeps all of them.
    texts_batches = texts_batches[:to_idx]

    # forward pass and save hidden states
    for text_batch in tqdm(texts_batches):  # go through batches
        sub_list = [sub_text[0] for sub_text in text_batch]
        sub_texts = [sub_text[1] for sub_text in text_batch]
        hs_ts = forward_pass_whs(sub_texts, model, tokenizer, sample_config)

        ## Save hidden states
        # NOTE(review): plain string concatenation, so this assumes
        # `paths.files_dir` ends with a path separator - confirm against Paths.
        paths.hs_path = f'{paths.files_dir}text_hs/{text_name}/'
        Path(paths.hs_path).mkdir(parents=True, exist_ok=True)

        for sub, hs in zip(sub_list, hs_ts):
            hs_fname = f'{paths.hs_path}{sub}^^{sample_config.model_name_rp}^^hsT-{sample_config.hsT}^^{text_name}_text-hs.pt'
            torch.save(hs.clone(), hs_fname)
        # Drop the batch's hidden states before the next forward pass.
        del hs_ts
        flush()

# %% Git and shutdown
# try:
#     os.system("git status; git add .; git commit -am 'l4'; git push");
# except:
#     print('git issue')

# try:
#     os.system("sudo shutdown -h now")
# except:
#     print('shutdown issue')
