import copy
import re

import matplotlib
from tqdm import tqdm
import os
import pandas as pd
import numpy as np
from transformers import AutoTokenizer
# import torch._dynamo
# torch._dynamo.config.suppress_errors = True
# torch._dynamo.disable()
# os.environ['TORCH_LOGS'] = "+dynamo"
# os.environ['TORCHDYNAMO_VERBOSE'] = "1"

# from utils.sampling_utils import create_question_pairs_instr

# Turn off tokenizer-level parallelism before any tokenizer is created.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Directory names used to build the working directory and the LLM sub-path.
path_to_go = '_analysis'
llm_path = '_llm_based/'

if os.uname().sysname == 'Darwin':  # local mac: move into the analysis folder
    base_task = 'qs_structure'
    cwd = os.getcwd()
    cwd_split = cwd.split('/')
    # Keep every path component up to and including the first "online_tasks".
    # np.argwhere(...)[0][0] raises IndexError when that component is absent.
    is_root_part = [part == "online_tasks" for part in cwd_split]
    root_depth = np.argwhere(is_root_part)[0][0] + 1
    cwd_base = '/'.join(cwd_split[:root_depth])
    new_path = f"{cwd_base}/{base_task}/{path_to_go}/"
    os.chdir(new_path)

    # task_version = 'qs-structure-phq9-v2'
    # analysis_dir = '_analysis'
    # cwd_base = '/'.join(cwd_split[:np.argwhere([p == "online_tasks" for p in cwd_split])[0][0] + 1])
    # new_path = f"{cwd_base}/{base_task}/{task_version}/{analysis_dir}"
    # os.chdir(new_path)

# Start without a model; the `if model is None` check further down loads it.
model = None

if os.uname().sysname == 'Darwin':  # local mac
    matplotlib.use('Qt5Agg')
    device_name = 'mps'  # alternative: 'cpu'
    # Other checkpoints tried locally: 'GPT2', 'gemma2-2b-it', 'llama32-3b-it', 'MistralOo'
    model_name = 'gemma2-9b-it'
else:
    matplotlib.use('Agg')
    device_name = 'cuda'
    # Other checkpoints tried here: 'MistralOo', 'gemma2-2b-it'
    model_name = 'gemma2-9b-it'

# llama3* and gemma2-9b* checkpoints use the 'instr3' instruction set; everything else uses 'instr2'.
uses_instr3 = any(tag in model_name for tag in ('llama3', 'gemma2-9b'))
instr_name_str = 'instr3' if uses_instr3 else 'instr2'

import matplotlib.pyplot as plt

from _llm_based.utils.sampling_utils import setup_model, load_task_content, get_all_question_pairs, get_sub_locs, \
    sample_model_responses_batch, flush, spec_levels

from utils.utils import upload_blob_from_memory, download_blob
# download_blob('llm-bucket-res', '_data.zip','_data.zip')
from objects.configs import *
from objects.plot_config import *
from _llm_based.objects.data_configs import *
from _llm_based.objects.model_configs import *

# %%
# Project configuration objects (classes come from the star-imported config modules).
paths = Paths(
    files_dir=f"{llm_path}/files",
    sub_path=llm_path,
    plots_subdir='',
    plots_subsubdir='',
    files_data_dir='_data',
)
bools = Bools()

pc = PlotConfig()
qs_config = QsConfig()

# Initialise the LLM sampling config for the questionnaire to sample.
to_sample_qs = 'phq9'
sample_config = SampleConfig(
    paths,
    model_name=model_name,
    qs_name=to_sample_qs,
    instr_name=instr_name_str,  # previously hard-coded to 'instr2'
    temp='',
    top_p='',
)
# Small initial values; the sampling cell below overrides nSamples and batchSize.
sample_config.nSamples = 4
sample_config.batchSize = 1
sample_config.save_states = True
sample_config.gen_fname = 'gb'

do_bucket = False
# %%
# Load the model and tokenizer only once, so re-running this cell does not reload them.
if model is None:
    model, tokenizer = setup_model(sample_config)

# prepare tokenizer
# tokenizer = AutoTokenizer.from_pretrained(sample_config.model_name_hf, use_fast=True)
# tokenizer.pad_token_id = tokenizer.eos_token_id
# %% Load and prepare data
# spec_types = {'v4_dd':'gb','v4_d':'gb','v4': 'gb', 'v3_s': 'spec_hh', 'v2': 'gen', 'v2a': 'gen', 'v2b': 'gen', 'v1': 'gen'}
# Specificity label associated with each task version.
spec_types = {
    'v4_ddd': 'gb',
    'v4_dd': 'gb',
    'v4_d': 'gb',
    'v4': 'gb',
    'v3_s': 'spec_hh',
    'v2': 'gen',
    'v2a': 'gen',
    'v2b': 'gen',
    'v1': 'gen',
}

# Task content: instructions and open questions (the third return value is unused here).
instr_dict, open_qs_dict, _ = load_task_content(sample_config)
oq_names = list(open_qs_dict)

# Raw participant data.
openq_data = pd.read_csv(f"{paths.files_data_dir}openq_data.csv")
phq9_data = pd.read_csv(f"{paths.files_data_dir}phq9_data.csv")
phq9_names = list(filter(lambda col: 'phq9_q' in col, phq9_data.columns))

# Restrict both tables to the task versions being sampled.
# Earlier runs used subsets such as ['v4'], ['v4_d', 'v4_dd'] or ['v4_ddd'].
task_versions = ['v4', 'v4_d', 'v4_dd', 'v4_ddd']
phq9_data = phq9_data.loc[phq9_data['task_version'].isin(task_versions)]
openq_data = openq_data.loc[openq_data['task_version'].isin(task_versions)]

# intro_prompt, oq_instr, open_qs_dict, closed_questions_dict, cq_preamble, qsn_questions_dict, cq_instr, qsn_preamble = create_question_pairs_instr(sample_config)
# Preprocess text: tidy whitespace around punctuation in every string cell.
openq_data = openq_data.replace(r'\s+\.', '.', regex=True)  # "word ." -> "word."
openq_data = openq_data.replace(r'\s+,', ', ', regex=True)  # "word ," -> "word, "
# NOTE(review): this pattern appears to be a no-op -- `\s+,` above already
# consumes every newline that directly precedes a comma. If the intent was to
# flatten newlines, the pattern should probably be r'\n+'; confirm before changing.
openq_data = openq_data.replace(r'\n+,', ' ', regex=True)

# Strip leading/trailing whitespace from all object (string) columns.
df_obj = openq_data.select_dtypes('object')
openq_data[df_obj.columns] = df_obj.apply(lambda x: x.str.strip())

# Long format: one row per (subject, question).
openq_data_long = pd.melt(openq_data, id_vars=['sub'], value_vars=oq_names, var_name='q_name').sort_values(
    by=['sub', 'q_name']).reset_index(drop=True)
phq9_data_long = pd.melt(phq9_data, id_vars=['sub'], value_vars=phq9_names, var_name='q_name')

# Drop rows without a usable response. A response is empty when it is not a
# string (e.g. NaN from a blank CSV cell) or contains no words. The previous
# check used `len(x.split(' ')) == 0`, which is never true for a string
# (''.split(' ') == ['']), so empty and whitespace-only answers slipped through.
empty_resp_idx = openq_data_long['value'].apply(lambda x: not isinstance(x, str) or not x.split())
openq_data_long = openq_data_long.loc[~empty_resp_idx, :]
# openq_data_long['value'] = openq_data_long['value'].str.replace(r'[^a-zA-Z0-9]+$', '', regex=True)

lvl1_closed_data = pd.read_csv(f"{paths.files_data_dir}lvl1_closed_data.csv")
lvl2_closed_data = pd.read_csv(f"{paths.files_data_dir}lvl2_closed_data.csv")

# Remove intermediates that are not needed after this point.
del df_obj, empty_resp_idx, instr_dict, oq_names, phq9_names

## include specificity level for version 3
# Every response defaults to the generic level; rows whose subject id contains
# 'v3_s' take their level from the question name, mapped through spec_levels['v3_s'].
openq_data_long['spec_level'] = 'gen'
is_v3s = openq_data_long['sub'].str.contains('v3_s')

# Step 1: seed the level with the question name for the v3_s rows.
openq_data_long.loc[is_v3s, 'spec_level'] = openq_data_long.loc[is_v3s, 'q_name']

# Step 2: translate those question names into specificity levels.
v3s_level_map = spec_levels['v3_s']
openq_data_long.loc[is_v3s, 'spec_level'] = openq_data_long.loc[is_v3s, 'spec_level'].replace(v3s_level_map)

# Build the per-subject question pairs (raw and formatted); the last two return values are unused.
question_pairs, question_pairs_formatted, _, _ = get_all_question_pairs(
    openq_data_long,
    sample_config,
    tokenizer,
)

# # subset the questions
# q_subset = ['lvl1_q1', 'lvl2_q1', 'lvl2_q2', 'lvl2_q3']
#
# question_pairs = {sub: {q_name: q_content for q_name, q_content in q_dict.items() if q_name in q_subset} for sub, q_dict
#                   in question_pairs.items() if 'v3_s' in sub}
# question_pairs = {sub: q_dict for sub, q_dict in question_pairs.items() if len(q_dict) > 0}
# question_pairs_formatted = {sub: {q_name: q_content for q_name, q_content in q_dict.items() if q_name in q_subset} for
#                             sub, q_dict
#                             in question_pairs_formatted.items() if 'v3_s' in sub}
# question_pairs_formatted = {sub: q_dict for sub, q_dict in question_pairs_formatted.items() if len(q_dict) > 0}
# %%
# Sampling budget for the main run (previously tried: 50 samples, batches of 20).
sample_config.nSamples = 25
sample_config.batchSize = 15
extra_save = False  # when True, hidden states keep being saved after the first batch

# The trailing [0:] slices are resume knobs: raise the start index to skip finished entries.
for sub in tqdm(list(question_pairs)[0:]):
    sample_config.subj = sub
    sub_texts_to_find, sub_last_token_locs = get_sub_locs(
        openq_data_long, question_pairs_formatted, tokenizer, sample_config)

    # Report every entry whose location is the -1 sentinel.
    for q_key, locs_by_key in sub_last_token_locs.items():
        for loc_key, loc in locs_by_key.items():
            if loc == -1:
                print(sub, q_key, loc_key, 'missing location')

    for which_q in tqdm(list(question_pairs[sub])[0:]):
        print(sub, which_q)
        # State saving is switched back on at the start of every question.
        sample_config.save_states = True
        sample_config.which_q = which_q

        sub_q_token_locs = sub_last_token_locs[sample_config.which_q]

        # Refresh the remaining-sample counters (remBatches, remBatchSize, remSamples) on the config.
        sample_config.get_remaining_samples()
        sampling_started = False
        while sample_config.remBatches > 0:
            # Timestamp prefix passed to the sampler for this batch.
            sample_ts = f"{round(datetime.now().timestamp() * 10000)}_"
            prompt = question_pairs_formatted[sample_config.subj][sample_config.which_q]
            sub_qs_pairs = [prompt] * sample_config.remBatchSize

            # Once some samples already exist, stop saving states unless extra_save is set.
            if not (extra_save or sample_config.remSamples >= sample_config.nSamples):
                sample_config.save_states = False

            # start sampling
            sample_model_responses_batch(
                sub_qs_pairs, sub_q_token_locs, model, tokenizer, sample_ts, sample_config)
            sample_config.get_remaining_samples()
            flush()
            sampling_started = True

        # if do_bucket:#and sampling_started and sample_config.remSamples==0:
        #     try:
        #         tmp_sub_path = sample_config.responses_path.split('files/')[1][:-1]
        #         os.system(f"gcloud storage rsync -r --no-clobber {sample_config.responses_path} gs://llm-bucket-res/files_l4/{tmp_sub_path}")
        #     except:
        #         print('issue uploading responses')
        #     try:
        #         tmp_sub_path = sample_config.outputs_path.split('files/')[1][:-1]
        #         os.system(f"gcloud storage rsync -r --no-clobber {sample_config.outputs_path} gs://llm-bucket-res/files_l4/{tmp_sub_path}")
        #     except:
        #         print('issue uploading outputs')
        #     try:
        #         tmp_sub_path = sample_config.states_path.split('files/')[1][:-1]
        #         os.system(f"gcloud storage rsync -r --no-clobber {sample_config.states_path} gs://llm-bucket-res/files_l4/{tmp_sub_path}")
        #     except:
        #         print("issue uploading states")

# %% Save and upload hidden states
# states_path = 'files/states/subjects'
# subs = [d for d in os.listdir(paths.subj_states_dir) if '.DS_Store' not in d]
# store_pt_files = []
# # store_pt_file_names = []
# for sub in subs:
#     sub_path = f"{paths.subj_states_dir}/{sub}"
#     qs = [d for d in os.listdir(sub_path) if '.DS_Store' not in d]
#     for q in qs:
#         sub_q_path = f"{sub_path}/{q}/"
#         spec = [d for d in os.listdir(sub_q_path) if '.DS_Store' not in d][0]
#         spec_list = spec.split('^')
#         pt_file = [f for f in os.listdir(f"{sub_q_path}/{spec}") if '.pt' in f][0]
#         fname = f"{sub}^^{q}^^{spec}_" + pt_file
#         # store_pt_file_names.append(fname)
#         pt_file = f"{sub_q_path}{spec}/" + pt_file
#         store_pt_files.append(pt_file)
#
# sup_dir = f"{paths.sub_path}states/"
#
# Path(sup_dir).mkdir(parents=True, exist_ok=True)
# zip_ts = str(round(datetime.timestamp(datetime.now()) * 10000))
# files_to_zip = " ".join(store_pt_files)
# zip_fname = sup_dir + 'sub_hidden_states_' + zip_ts + '.zip'
# os.system("zip -j " + zip_fname + ' ' + files_to_zip)
# upload_blob_from_memory('llm-bucket-res', os.getcwd() + '/' + zip_fname, zip_fname)
# upload_blob_from_memory('llm-bucket-res', os.getcwd() + '/' + 'state_pts/states.zip', 'state_pts/states.zip')

# %% Git and shutdown
# try:
#     os.system("git status; git add .; git commit -am 'l4'; git push");
# except:
#     print('git issue')

# Power the machine off once sampling has finished.
# os.system() does not raise when the command fails; it returns the exit
# status, so check that instead of relying on an exception handler that could
# never fire.
# NOTE(review): this runs unconditionally, including on the Darwin development
# machine -- confirm that is intended.
shutdown_status = os.system("sudo shutdown -h now")
if shutdown_status != 0:
    print('shutdown issue')

# %% Clean directories
# Trim the last three '/'-separated components from each output location to get
# the directory roots for responses (rp), states (sp) and outputs (op).
rp, sp, op = (
    '/'.join(full_path.split('/')[:-3])
    for full_path in (
        sample_config.responses_path,
        sample_config.states_path,
        sample_config.outputs_path,
    )
)

# os.system(f"rm -r {gp} {r1p} {r2p} {tp} {op} {wp} {lp} {pp}")
