import os, re
import copy
import einops
# from examples.Official_demo import anorm
from tqdm import tqdm
import string
from datetime import datetime

from scipy import stats
import spgl1
import pandas as pd
import numpy as np
import sklearn.metrics
# from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from natsort import natsorted
import torch
from sympy.plotting.textplot import rescale
from torch import nn
from torch.utils.data import Dataset, DataLoader
import random
import torch.nn.functional as F
from torch import nn
import seaborn as sns
import gc
# from sklearn.metrics import accuracy_score, balanced_accuracy_score
from pathlib import Path
from itertools import product
import json
from meta_utils.utils import flush, download_blob, set_seed, get_ram, get_vram
from perturb_utils import *

# dtype that the SAE-reconstructed steering vectors are cast to further down
# in this script before they are handed to the steering hooks.
my_dtype = torch.bfloat16

# Optional one-off artifact downloads, kept commented out for reference:
# download_blob('llm-bucket-res', 'saved_delta.zip', 'saved_delta.zip')
# download_blob('llm-bucket-res', '_logits_data.zip', '_logits_data.zip')
# download_blob('llm-bucket-res', 'saved_models_compress.zip', 'saved_models_compress.zip')

import matplotlib

# This script never back-propagates, so autograd is switched off globally.
torch.set_grad_enabled(False)

# from sae_new import train_config

# Choose the matplotlib backend and the torch device string from the host OS.
# The backend has to be selected before pyplot is imported.
if os.uname().sysname != 'Darwin':
    # Not macOS: non-interactive Agg backend, CUDA device.
    device_name = 'cuda'
    matplotlib.use('Agg')
    # matplotlib.use('TkAgg')
    matplotlib.get_backend()
else:
    # macOS: interactive Qt backend, Metal (mps) device.
    device_name = 'mps'
    matplotlib.use('Qt5Agg')

import matplotlib.pyplot as plt

if device_name == 'mps':
    # Interactive plotting is only enabled on the macOS branch.
    plt.ion()

from objects.data_configs import *
from objects.model_configs import *

from meta_objects.configs import *
from meta_objects.plot_config import *

# from prepare_dataset import MetaDataset, MyDataset, do_zscores

# Provisional plot/path configuration. Both `pc` and `paths` are rebound in the
# experiment-configuration section further down in this file, so these values
# are only in effect while the imports below execute.
pc = PlotConfig()
paths = Paths(files_dir='files', plots_subdir='', plots_subsubdir='')
paths.plots_path = '_plots/'

from sae_models import *

# Forwarded to ExpConfig(...) when the per-model experiment config is built.
do_zscores = True
from prepare_model_datasets import MetaDataset, MyDataset

from transformers import AutoTokenizer
from baukit import Trace, TraceDict

# NOTE(review): not referenced anywhere else in the visible part of this
# script (SampleLogitsConfig receives the literal 'instr3') -- confirm whether
# an imported helper relies on it before removing.
instr_name_str = 'instr3'
from logit_sampling.utils.logit_utils import setup_model, load_task_content, get_all_question_pairs, get_sub_locs, \
    sample_model_responses_batch, flush, maps, rev_lists, phq9_qs_inv_map, lvlx_closed_map
from logit_sampling.utils.logit_utils import create_question_pairs_instr, SteerConfig, SteerHiddenState

max_epochs = 1000
llm_path = 'logit_sampling/'
# %%
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
exp_name_data = 'SAE_Mistral_gemma_llama_exp_v1'  # experiment whose dataset is loaded
exp_name = 'SAE_Mistral_gemma_llama_exp_v3_best'  # experiment whose saved SAEs are loaded

# Other model names used previously: 'gemma2-2b-it', 'MistralOo'.
model_name_set = ['gemma2-9b-it']

# Perturbation settings shared by every subject / question.
q_score_change = p_force = 1
to_sample_qs = 'phq9'

# Data split to run on: 'val' or 'test'.
ds_type = 'test'

paths = Paths(
    files_dir=f"{llm_path}files_logits_perturbed_{ds_type}",
    sub_path=llm_path,
    plots_subdir='',
    plots_subsubdir='',
    files_data_dir=f'{llm_path}_data',
)
bools = Bools()

pc = PlotConfig()
qs_config = QsConfig()

# `model` is loaded lazily inside the main loop; the two *_end_idx values are
# optional slice bounds (None means "no limit").
model = end_idx = subs_end_idx = None

# Steering multipliers swept for each split. Integer entries are deliberately
# left as ints because the multiplier is formatted into output file names.
if ds_type == 'val':
    steer_mlt_set = [-1.5, -1, -0.5, -0.25, 0.25, 0.5, 1, 1.5]
elif ds_type == 'test':
    # Full sweep would be [-1.5, -1, -0.5, -0.25, 0.25, 0.5, 1, 1.5].
    steer_mlt_set = [-0.25, 1.5]
else:
    steer_mlt_set = [0.0, -1.0, 1.0]
# steer_mlt_set = [-1.5, -1, -0.5, -0.25] + [0.0] + [0.25, 0.5, 1, 1.5]

def _parse_config_name(config_name):
    """Split a ``key-value^^key-value`` config string into an ordered dict.

    Each ``^^``-separated field is cut at its first ``-``; everything after
    that dash (further dashes included) is kept verbatim as the value.
    """
    parsed = {}
    for field in config_name.split('^^'):
        key, _, value = field.partition('-')
        parsed[key] = value
    return parsed


# Main driver: for every model, subject and question, build SAE-based steering
# vectors from a first (clean) forward pass, steer the LLM on a second pass for
# each multiplier in `steer_mlt_set`, and write sampled responses and label
# logits to CSV.
for model_name in model_name_set:
    sample_config = SampleLogitsConfig(paths, model_name=model_name, qs_name='phq9', instr_name='instr3', temp='',
                                       top_p='')

    ### Prepare for sampling logits
    sample_config.nSamples = 50
    sample_config.batchSize = 1  # 20
    sample_config.save_states = True
    sample_config.gen_fname = 'gb'
    sample_config.permute_labels = False
    sample_config.label_letters = None
    label_letters = list(string.ascii_uppercase)[:qs_config.qs_n_lab[to_sample_qs]]
    q_scores = list(maps[to_sample_qs].values())

    if sample_config.permute_labels:
        # Shuffle scores and letters together so they stay paired.
        tmp_zip = list(zip(q_scores, label_letters))
        random.shuffle(tmp_zip)
        q_scores, label_letters = zip(*tmp_zip)
    sample_config.label_letters = label_letters
    sample_config.q_scores = q_scores
    sample_config.label_scores = dict(zip(label_letters, q_scores))

    ### Load LLM
    if model is None:
        model, tokenizer = setup_model(sample_config)

    label_ids = tokenizer(sample_config.label_letters, padding=True, return_tensors='pt',
                          return_attention_mask=True).to(device_name)
    # Drop the first token id of every tokenized label
    # (presumably the BOS token -- TODO confirm for each tokenizer).
    label_ids = label_ids.input_ids[:, 1:]

    ### SAE setup: experiment directories and hyper-parameters
    loss_plot_dir = f'_plots/loss/{exp_name}/{model_name}/'
    loss_dir = f'loss/{exp_name}/{model_name}/'
    metric_dir = f'metrics/{exp_name}/{model_name}/'
    model_dir = f'saved_models/{exp_name}/{model_name}/'
    delta_dir = f'saved_delta/{exp_name}/{model_name}/'
    prelim_plot_dir = f'_plots/prelim/{exp_name}/{model_name}/'
    Path(delta_dir).mkdir(parents=True, exist_ok=True)

    # Get the necessary info to load the dataset. The saved-model directory is
    # listed once here; the per-layer lookups below reuse this list.
    best_configs = [t.replace('.pt', '').replace('best_', '') for t in os.listdir(model_dir)]
    if not best_configs:
        raise FileNotFoundError(f'no saved SAE models found in {model_dir}')
    best_config_dict = _parse_config_name(best_configs[0])
    best_hyper_dict = {k: v for k, v in best_config_dict.items() if k not in ['model', 'layer']}
    hparams = list(best_config_dict.values())
    q_case = hparams[2]
    exp_config = ExpConfig(sample_config, q_case=q_case, do_zscores=do_zscores, max_epochs=max_epochs)

    # Get data
    meta_dataset = torch.load(f'data/{exp_name_data}/{exp_config.dataset_fname}.pt', weights_only=False)
    n_train = int(0.7 * len(meta_dataset.subs))
    n_val = int(0.15 * len(meta_dataset.subs))
    n_test = int(0.15 * len(meta_dataset.subs))

    # Layers to steer: the second half of the model's layer indices.
    layer_idxs = np.arange(sample_config.L // 2, sample_config.L)

    ## Select the subjects of the requested split
    if ds_type == 'test':
        test_dataset_subs = list(MyDataset(meta_dataset, n_train + n_val, None, 0, device=device_name).subs)
    elif ds_type == 'val':
        test_dataset_subs = list(MyDataset(meta_dataset, n_train, n_train + n_val, 0, device=device_name).subs)
    else:
        # Previously this fell through and failed later with a NameError.
        raise ValueError(f"ds_type must be 'val' or 'test', got {ds_type!r}")

    ## Load open-ended data
    instr_dict, open_qs_dict, _ = load_task_content(sample_config)
    oq_names = [n for n in open_qs_dict.keys() if 'lvl3' in n]
    phq9_data = pd.read_csv(f"{paths.files_data_dir}phq9_data.csv")
    phq9_names = [c for c in phq9_data.columns if 'phq9_q' in c]
    openq_data = pd.read_csv(f"{paths.files_data_dir}openq_data.csv")

    # Subset tasks
    task_versions = ['v4', 'v4_d', 'v4_dd', 'v4_ddd']
    openq_data = openq_data[openq_data['task_version'].isin(task_versions)]
    phq9_data = phq9_data[phq9_data['task_version'].isin(task_versions)]

    # Preprocess text: tidy whitespace around punctuation, then strip.
    openq_data = openq_data.replace(r'\s+\.', '.', regex=True)
    openq_data = openq_data.replace(r'\s+,', ', ', regex=True)
    openq_data = openq_data.replace(r'\n+,', ' ', regex=True)
    df_obj = openq_data.select_dtypes('object')
    openq_data[df_obj.columns] = df_obj.apply(lambda x: x.str.strip())
    openq_data_long = pd.melt(openq_data, id_vars=['sub'], value_vars=oq_names, var_name='q_name').sort_values(
        by=['sub', 'q_name']).reset_index(drop=True)
    phq9_data_long = pd.melt(phq9_data, id_vars=['sub'], value_vars=phq9_names, var_name='q_name')

    # Drop rows without a response.
    # NOTE(review): ''.split(' ') has length 1, so this only flags non-string
    # values (e.g. NaN); empty strings are kept. Left unchanged so the subject
    # and question set stays identical to earlier runs -- confirm the intent.
    empty_resp_idx = openq_data_long['value'].apply(
        lambda x: len(x.split(' ')) if isinstance(x, str) else 0) == 0
    openq_data_long = openq_data_long.loc[~empty_resp_idx, :]

    # Keep only the subjects of the selected split
    openq_data_long = openq_data_long[openq_data_long['sub'].isin(test_dataset_subs)]

    lvl1_closed_data = pd.read_csv(f"{paths.files_data_dir}lvl1_closed_data.csv")
    lvl2_closed_data = pd.read_csv(f"{paths.files_data_dir}lvl2_closed_data.csv")
    del df_obj, empty_resp_idx, instr_dict, oq_names, phq9_names

    intro_prompt, oq_instr, open_qs_dict, closed_questions_dict, cq_preamble, qsn_questions_dict, cq_instr, qsn_preamble = create_question_pairs_instr(
        sample_config)

    instr_dict, open_qs_dict, closed_qs_dict = load_task_content(sample_config)
    question_pairs, question_pairs_formatted, _, _ = get_all_question_pairs(openq_data_long, sample_config, tokenizer)

    for sub in tqdm(list(question_pairs.keys())[0:end_idx]):
        sample_config.subj = sub
        sub_texts_to_find, sub_last_token_locs = get_sub_locs(openq_data_long, question_pairs_formatted, tokenizer,
                                                              sample_config)
        # Report every token location that is marked as not found (-1).
        for k, v in sub_last_token_locs.items():
            for k2, v2 in v.items():
                if v2 == -1:
                    print(sub, k, k2, 'missing location')

        # tqdm wraps the list (not the enumerate object) so the bar knows its total.
        for q_idx_to_perturb, which_q in enumerate(tqdm(list(question_pairs[sub].keys())[0:end_idx])):
            # Set labels for responses
            sample_config.q_responses = {l: k for l, (k, v) in zip(label_letters, maps[to_sample_qs].items())}
            if 'lvl3' not in which_q:
                sample_config.q_responses = {l: k for l, (k, v) in zip(label_letters, lvlx_closed_map.items())}

            sample_config.qs_labels = [sample_config.q_responses[l] for l in sample_config.label_letters]
            # NOTE(review): the timestamp is taken once per question, so a second
            # pass of the `while` loop below writes to the same file names --
            # confirm this matches how get_remaining_samples counts files.
            sample_ts = str(round(datetime.timestamp(datetime.now()) * 10000))
            sample_config.save_states = False
            sample_config.which_q = which_q

            # One perturbation-config string per steering multiplier; used for
            # the "already sampled?" check and again for the output file names.
            perturbation_config_list = []
            for str_mlt in steer_mlt_set:
                perturbation_config_dict = {'latent_q_score_change': q_score_change,
                                            'latent_p_force': p_force,
                                            'q_to_perturb': f'{q_idx_to_perturb + 1}',
                                            'hs_steer_mlt': str_mlt}
                perturbation_config_list.append(
                    '^^'.join([f'{k}-{v}' for k, v in perturbation_config_dict.items()]))

            # Check if files are already there (based on number)
            sample_config.get_remaining_samples(printMe=True, pert_config_len=len(perturbation_config_list))

            # Bind the names released after the loop up front, so the `del`
            # below is valid even when the loop body never runs.
            model_sae = None
            hs_ts_sub = None
            hs_ts_sub_l = None

            # Start the sampling and perturbing for this subject and question
            while not sample_config.currentFiles:
                sub_q_token_locs = sub_last_token_locs[sample_config.which_q]
                sub_qs_pairs = [question_pairs_formatted[sample_config.subj][sample_config.which_q]]

                inputs_ids = tokenizer(sub_qs_pairs, padding=True, return_tensors='pt',
                                       return_attention_mask=True).to(device_name)

                model_output = model(inputs_ids.input_ids, attention_mask=inputs_ids.attention_mask,
                                     output_hidden_states=True)

                # 1st pass: original hidden states. Drop the first stacked entry,
                # take batch item 0 at the open-question answer location and at
                # the last position, keep the selected layers and average the
                # two positions.
                hs_ts = torch.stack(model_output.hidden_states)
                del model_output
                hs_ts_sub = hs_ts[1:, 0, [sub_q_token_locs['oq_ans'], -1], :][layer_idxs].mean(dim=1)
                del hs_ts

                # Get layer-wise model perturbation vectors
                steering_vectors = {}
                for layer_idx, hs_ts_sub_l in zip(layer_idxs, hs_ts_sub):
                    layer_name = f'layer-{layer_idx + 1}'
                    # NOTE(review): substring match -- 'layer-2' would also match
                    # 'layer-20'; harmless only while every selected layer number
                    # has the same number of digits.
                    layer_configs = [c for c in best_configs if layer_name in c]
                    if not layer_configs:
                        raise FileNotFoundError(f'no saved SAE model for {layer_name} in {model_dir}')
                    model_layer_config = layer_configs[0]

                    best_config_dict = _parse_config_name(model_layer_config)
                    best_hyper_dict = {k: v for k, v in best_config_dict.items() if k not in ['model', 'layer']}
                    hparams = list(best_config_dict.values())
                    # Positional fields of the saved-model file name.
                    sparsity_coeff = float(hparams[5])
                    qs_coeff = int(hparams[6])
                    sae_factor = int(hparams[7])
                    tied_weights = hparams[8] == 'True'
                    sae_name = hparams[9]

                    exp_config.d_h = list(meta_dataset.avg_features.values())[0].shape[1]
                    sae_cfg = SAE_Config(exp_config, device=device_name, x_m=sae_factor, sparse_coeff=sparsity_coeff,
                                         tied_weights=tied_weights, qs_coeff=qs_coeff)
                    layer_module_name = f'model.layers.{layer_idx}'
                    model_fp = f'{model_dir}best_{model_layer_config}.pt'
                    deltaS_fp = f'{delta_dir}deltaS_{model_layer_config}^^q_target_idx-{q_idx_to_perturb}^^s_change-{q_score_change}.pt'

                    # Load SAE model. Reset first so an unrecognised name cannot
                    # silently reuse the previous layer's instance.
                    model_sae = None
                    if 'SAE2' in sae_name:
                        model_sae = SAE2_a(sae_cfg)
                    if 'SAE3' in sae_name:
                        model_sae = SAE3(sae_cfg)
                    if 'SAE4' in sae_name:
                        model_sae = SAE4(sae_cfg)
                    if model_sae is None:
                        raise ValueError(f'unsupported SAE architecture {sae_name!r} in {model_fp}')
                    model_sae.load_state_dict(torch.load(model_fp, map_location=torch.device(device_name)))

                    deltaS = get_perturb_delta(model_sae, deltaS_fp, q_idx_to_perturb=q_idx_to_perturb,
                                               score_change=q_score_change,
                                               device_name=device_name)

                    # Get perturbed states
                    _, _, h_rec_perturbed = model_sae.infer(hs_ts_sub_l.unsqueeze(0), delta_s=deltaS,
                                                            p_force=p_force)

                    steering_vectors[layer_module_name] = h_rec_perturbed.type(my_dtype)

                # 2nd pass: steer hidden states with the steering vectors
                for str_mlt, perturbation_config in zip(steer_mlt_set, perturbation_config_list):
                    # Use the platform-selected device rather than a fixed 'mps'.
                    steer_config = SteerConfig(layer_ids=layer_idxs, steering_vectors=steering_vectors,
                                               device=device_name,
                                               multiplier=str_mlt, run_gen=False, sample_text=False,
                                               perturb_input_only=True, norm_vecs=True, n_tokens=1)
                    steer_hs = SteerHiddenState(steer_config, model, tokenizer)
                    steer_hs.steer(model, sub_qs_pairs, return_org=False)

                    logits_perturbed = steer_hs.logits
                    label_logits_perturbed = logits_perturbed[:, label_ids].squeeze(-1).detach().cpu().type(
                        torch.float32)
                    label_probs_perturbed = F.softmax(label_logits_perturbed, dim=-1).numpy()[0, :]

                    # Sample answer letters from the steered label distribution.
                    sampled_perturbed_labels = np.random.choice(sample_config.label_letters,
                                                                size=sample_config.nSamples,
                                                                p=label_probs_perturbed)
                    sampled_perturbed_scores = [sample_config.label_scores[sl] for sl in sampled_perturbed_labels]
                    sampled_perturbed_responses = [sample_config.q_responses[sl] for sl in sampled_perturbed_labels]

                    # Save the sampled responses and the label logits after steering
                    tmp_dict = {'sub': sample_config.subj,
                                'sample_ts': [f'{sample_ts}_{s}' for s in range(sample_config.nSamples)],
                                'question': sample_config.which_q,
                                'score': sampled_perturbed_scores, 'response': sampled_perturbed_responses,
                                'model': sample_config.model_name,
                                'latent_q_score_change': q_score_change, 'latent_p_force': p_force,
                                'q_to_perturb': f'{q_idx_to_perturb + 1}', 'hs_steer_mlt': str_mlt,
                                'instr_name': sample_config.instr_name, 'qs': sample_config.qs_name,
                                'label_perm': ''.join(sample_config.label_letters),
                                'nSamples': sample_config.nSamples} | best_hyper_dict
                    tmp_pd = pd.DataFrame(tmp_dict)

                    tmp_logits_dict = {'sub': sample_config.subj, 'sample_ts': sample_ts,
                                       'question': sample_config.which_q,
                                       'logits': label_logits_perturbed[0], 'probs': label_probs_perturbed,
                                       'latent_q_score_change': q_score_change, 'latent_p_force': p_force,
                                       'q_to_perturb': f'{q_idx_to_perturb + 1}', 'hs_steer_mlt': str_mlt,
                                       'label_letters': sample_config.label_letters,
                                       'qs_labels': sample_config.qs_labels,
                                       'label_scores': sample_config.q_scores, 'model': sample_config.model_name,
                                       'instr_name': sample_config.instr_name, 'qs': sample_config.qs_name,
                                       'label_perm': ''.join(sample_config.label_letters),
                                       'nSamples': sample_config.nSamples} | best_hyper_dict
                    tmp_logits_pd = pd.DataFrame(tmp_logits_dict)

                    tmp_fname = f"{sample_config.sample_path}{sample_config.subj}^^{sample_config.which_q}^^{''.join(sample_config.label_letters)}^^{sample_config.model_name_rp}^^{perturbation_config}"
                    logits_fname = f"{tmp_fname}_logits_s-ts-{sample_ts}.csv"
                    responses_fname = f"{tmp_fname}_responses_s-ts-{sample_ts}.csv"

                    tmp_pd.to_csv(f"{responses_fname}", index=False)
                    tmp_logits_pd.to_csv(f"{logits_fname}", index=False)

                # Re-count the files on disk; this updates the loop condition.
                sample_config.get_remaining_samples(printMe=False, pert_config_len=len(perturbation_config_list))

            # Release per-question tensors before moving on.
            del model_sae, hs_ts_sub, hs_ts_sub_l
            flush()
        flush()
    # Rebind instead of `del` so the `if model is None` check at the top of the
    # next iteration triggers a fresh load instead of raising NameError.
    model = None
    flush()
