import os, re
from tqdm import tqdm

import torch
from torch.utils.data import Dataset, DataLoader
import seaborn as sns
torch.set_grad_enabled(False)
import matplotlib
import matplotlib.pyplot as plt

# Select the torch device and the matplotlib backend for the current machine.
# `platform.system()` replaces `os.uname()`, which only exists on Unix and
# raises AttributeError on other platforms.
import platform

if platform.system() == 'Darwin':  # if on mac
    # Use the Metal backend when torch reports it as usable, otherwise the CPU.
    device_name = 'mps' if torch.backends.mps.is_available() else 'cpu'
    matplotlib.use('Qt5Agg')
    plt.ion()
else:
    # Non-interactive backend: figures are only written to disk.
    matplotlib.use('Agg')
    # Use CUDA when torch reports it as usable, otherwise the CPU, instead of
    # failing later on the first `.to('cuda')`.
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'

from _objects.model_configs import *
from _objects.configs import *
from _objects.plot_config import *

# Plot helper object; its attributes are filled in by the plotting cell.
pc = PlotConfig()

# Identifiers of this test case and of the figure it contributes to.
fig_no = 'Fig4'
testCase_dir = 'test2_ssae'
subPath = 'analysis_sampling'
test_path = f'{testCase_dir}/{subPath}/'

# Old repository layout: `files_path` is handed to Paths and `prev_path`
# prefixes the model / delta / data locations built later in this script.
prev_path = '../construct_detector/'  # old path
files_path = '../online_tasks/qs_structure/_analysis/_llm_based/files_logits/'  # old path

# New repository layout (not referenced in the visible part of this script).
prev_path_new = ''
files_path_new = f'{testCase_dir}/llm_sampling/files_logits/'  # new repo path

# Path bundle; the plot directory is overridden after construction.
paths = Paths(files_dir=files_path, plots_subdir='', plots_subsubdir='')
paths.plots_path = f'{testCase_dir}/_plots/perturbation/'

from _objects.sae_models import *

# Forwarded to ExpConfig as its `do_zscores` argument further down.
do_zscores = True
from test2_ssae.prepare_model_datasets import MetaDataset, MyDataset
from test2_ssae.perturb_utils import *

# --- Experiment selection ---
exp_name_data = 'SAE_Mistral_gemma_llama_exp_v1'  # experiment whose dataset file is loaded
exp_name = 'SAE_Mistral_gemma_llama_exp_v3_best'  # experiment whose saved models / deltas are read
model_name = 'gemma2-9b-it'
max_epochs = 1000

# --- Perturbation settings ---
q_idxs = np.arange(9)  # question indices iterated by the perturbation loop
q_score_change = 1
p_force = 1
computeDelta = False  # False -> get_perturb_delta is called with loadDelta=True

# --- Output switches ---
bools = Bools()
bools.saveFig = True
# %% Load model configs
# Sampling configuration for the selected LLM, questionnaire and instruction set.
sample_config = SampleLogitsConfig(
    paths,
    model_name=model_name,
    qs_name='phq9',
    instr_name='instr3',
    temp='',
    top_p='',
)

# Directories holding the saved SAE checkpoints and the perturbation deltas.
delta_dir = f'{prev_path}saved_delta/{exp_name}/{model_name}/'
model_dir = f'{prev_path}saved_models/{exp_name}/{model_name}/'

# Candidate layers are the second half of range(sample_config.L); `l` is the
# position of the chosen layer inside that candidate range.
layer_idx = 41  # best layer index
layer_idxs = np.arange(sample_config.L // 2, sample_config.L)
layer_pos = np.where(layer_idxs == layer_idx)[0]
l = int(layer_pos[0])

# Saved file names use the 1-based layer number.
layer_name = f'layer-{layer_idx + 1}'
print(layer_name)
# Find the checkpoint(s) saved for `layer_name` and recover the hyper-parameters
# encoded in the file name as "key-value" tokens joined by "^^".
# os.listdir returns entries in arbitrary order, so the matches are sorted to
# make `best_config[0]` reproducible when more than one file matches.
# NOTE(review): the match is a substring test, so e.g. 'layer-4' would also
# match 'layer-42' — harmless for the layer used here, confirm before reusing.
best_config = sorted(
    fname.replace('.pt', '').replace('best_', '')
    for fname in os.listdir(model_dir)
    if layer_name in fname
)
if not best_config:
    raise FileNotFoundError(f'No saved model matching "{layer_name}" found in {model_dir}')

# Text before the first "-" of a token is the key, everything after it the value.
best_config_dict = {
    key: value
    for key, _, value in (token.partition('-') for token in best_config[0].split('^^'))
}

best_hyper_dict = {k: v for k, v in best_config_dict.items() if k not in ('model', 'layer')}

# Values are read by position, so the token order of the file name matters.
hparams = list(best_config_dict.values())
if len(hparams) < 10:
    raise ValueError(
        f'Expected at least 10 "^^"-separated fields in "{best_config[0]}", found {len(hparams)}')
q_case = hparams[2]
batch_size = int(hparams[3])
optim_lr = float(hparams[4])
sparsity_coeff = float(hparams[5])
qs_coeff = int(hparams[6])
sae_factor = int(hparams[7])
tied_weights = hparams[8] == 'True'  # stored as the text 'True' / 'False'
sae_name = hparams[9]

# Get data and set config
exp_config = ExpConfig(sample_config, q_case=q_case, do_zscores=do_zscores, max_epochs=max_epochs)
dataset_fp = f'{prev_path}/data/{exp_name_data}/{exp_config.dataset_fname}.pt'
# weights_only=False: the file holds a pickled dataset object, not a plain state dict.
meta_dataset = torch.load(dataset_fp, weights_only=False)

# 70 / 15 / 15 split sizes over the subjects, each truncated by int().
n_subs = len(meta_dataset.subs)
n_train = int(0.7 * n_subs)
n_val = int(0.15 * n_subs)
n_test = int(0.15 * n_subs)

# d_h is the second dimension of the first entry stored in avg_features.
first_avg_features = list(meta_dataset.avg_features.values())[0]
exp_config.d_h = first_avg_features.shape[1]

sae_cfg = SAE_Config(
    exp_config,
    device=device_name,
    x_m=sae_factor,
    sparse_coeff=sparsity_coeff,
    tied_weights=tied_weights,
    qs_coeff=qs_coeff,
)

# Only the validation split is built here; the train and test splits are not.
# Presumably the 2nd/3rd arguments are start/stop subject indices — verify
# against MyDataset.
val_dataset = MyDataset(meta_dataset, n_train, n_train + n_val, l, device=device_name)

# A single batch covering the whole split. Despite the `test_*` names, the
# loader and the tensors below are drawn from the validation dataset.
test_dataloader = DataLoader(val_dataset, shuffle=True, batch_size=len(val_dataset.subs))
test_features, test_labels = next(iter(test_dataloader))
test_features, test_labels = test_features.to(device_name), test_labels.to(device_name)

# Rebuild the checkpoint file name of the selected layer configuration.
model_layer_config = best_config[0]
model_fp = f'{model_dir}best_{model_layer_config}.pt'

# Load SAE model
# NOTE(review): SAE4 is instantiated unconditionally although `sae_name` is
# parsed from the file name — confirm the checkpoint really is an SAE4.
model_sae = SAE4(sae_cfg)
# map_location remaps the stored tensors onto the device in use, so a checkpoint
# written on a CUDA machine can still be restored on an MPS or CPU-only machine.
model_sae.load_state_dict(torch.load(model_fp, map_location=device_name))

#%%
# For each question index: obtain its perturbation delta, run the SAE with and
# without it, and keep the batch-mean change of every latent question score.
store_diffs = []
for q_idx_to_perturb in tqdm(q_idxs):
    print(f'Q: {q_idx_to_perturb + 1}')

    deltaS_fp = (f'{delta_dir}deltaS_{model_layer_config}'
                 f'^^q_target_idx-{q_idx_to_perturb}^^s_change-{q_score_change}.pt')
    # Loaded from disk unless computeDelta is set.
    deltaS = get_perturb_delta(
        model_sae,
        deltaS_fp,
        q_idx_to_perturb=q_idx_to_perturb,
        score_change=q_score_change,
        loadDelta=(not computeDelta),
        saveDelta=True,
        device_name=device_name,
    )

    # Baseline latents, evaluated after get_perturb_delta as in the original
    # order (that helper receives the model, so it may touch its state).
    latent_qs, _, _ = model_sae.infer(test_features)
    # Latents with the delta applied.
    latent_qs_perturbed, _, h_rec_perturbed = model_sae.infer(
        test_features, delta_s=deltaS, p_force=p_force)

    # Mean over the batch axis, then reshape to long format:
    # columns q_idx / variable / value, plus the perturbed question index.
    mean_shift = (latent_qs_perturbed - latent_qs).to('cpu').mean(dim=0)
    shift_df = (
        pd.DataFrame(mean_shift, columns=['diff'])
        .reset_index(names='q_idx')
        .melt(id_vars='q_idx')
    )
    shift_df['q_idx_to_perturb'] = q_idx_to_perturb
    store_diffs.append(shift_df)

# One long table covering all perturbed questions.
store_diffs = pd.concat(store_diffs)
#%% Plot
plt.close('all')

# Single-panel figure; the lone Axes is wrapped in a 2-D array and the panel
# cursor (pc.i, pc.j) is set to its only cell.
pc.r, pc.c, pc.mlt = 1, 1, 2
pc.figsize = ((pc.c + 1.5) * pc.mlt, (pc.r + 0.75) * pc.mlt)
fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize)
pc.onerow = False
pc.axes = np.array([[axes]])
pc.i, pc.j = 0, 0

# Font sizes / styling set through the PlotConfig helper.
pc.ax_ts(15, 1.1)
pc.l_fs(12, 0.85)
ts = 15
pc.xyt_ls(ts, ts)
pc.ax_ls(18)
pc.kde_lw = 3
pc.p_lab_spec[2] = 14
pc.p_lab_spec[0] = -0.1
pc.p_lab_spec[1] = 1.1
pc.dpi_val = 300
pc.ms = 100
# The three names below are not referenced in the visible part of this script.
ax_space = 5
hrat = 3.75
pc.annot_fs = 8
sc_lw = 3

# Rows: perturbed question; columns: question whose latent score changed.
store_diffs_wide = store_diffs.pivot(index='q_idx_to_perturb', columns='q_idx', values='value')
# Bind the heatmap to the Axes created above instead of relying on the
# implicit "current axes".
sns.heatmap(store_diffs_wide, annot=True, fmt='.1f', cbar=False, ax=axes)

# A heatmap puts the frame's columns on x and its index on y, so x shows the
# changed question and y the perturbed one.
pc.ax.set_xlabel('Q change')
pc.ax.set_ylabel('Q to perturb')
pc.ax.text(pc.p_lab_spec[0] + 0.05, pc.p_lab_spec[1], pc.p_labs[pc.i, pc.j], transform=pc.ax.transAxes,
           fontweight='bold',
           va='top', ha='right',
           fontsize=pc.p_lab_spec[2])

# Ticks sit at cell centres; positions and labels both come from the axis they
# describe, so a non-square matrix or a subset of questions is labelled correctly.
pc.ax.set_xticks(np.arange(len(store_diffs_wide.columns)) + 0.5)
pc.ax.set_xticklabels([f'Q{q + 1}' for q in store_diffs_wide.columns], rotation='horizontal')
pc.ax.set_yticks(np.arange(len(store_diffs_wide.index)) + 0.5)
pc.ax.set_yticklabels([f'Q{q + 1}' for q in store_diffs_wide.index], rotation='horizontal')
pc.ax.set_title('Perturbation confusion matrix')
plt.tight_layout()


if bools.saveFig:
    # savefig does not create missing directories, so make sure the target exists.
    os.makedirs(paths.plots_path, exist_ok=True)
    # The same figure is written twice: once under a descriptive name and once
    # under the figure-numbered name.
    for fig_fname in ('best_layer_latent_perturbation.pdf',
                      f'{fig_no}_p3_test2_latent_perturbation.pdf'):
        plt.savefig(f"{paths.plots_path}{fig_fname}", dpi=300)

# #%%
#
#     plt.close('all')
#     pc.r, pc.c, pc.mlt = 2, 9, 2.5
#     pc.i = 0
#
#     pc.ax_ts(12, 1.25)
#     pc.l_fs(8, 0.85)
#     ts = 8
#     pc.xyt_ls(ts, 10)
#     pc.ax_ls(12)
#     pc.kde_lw = 3
#     pc.p_lab_spec[2] = 12
#     pc.p_lab_spec[0] = -0.1
#     pc.p_lab_spec[1] = 1.1
#     pc.dpi_val = 300
#     pc.figsize = ((pc.c + 0.5) * pc.mlt, (pc.r + .1) * pc.mlt)
#     fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize)
#     # axes = np.array([axes])
#     pc.axes = axes
#
#     ms = 110
#     alpha = 0.7
#     s_ec = '#ffedcb'
#     b_c = '#e19f20'
#     s_lw = 0.8
#     s_size = 2
#     s_out_size = 6
#     b_width = 0.4
#     v_width = 1.15
#     b_lw = 1.5
#     v_cols = {'lvl1': '#1f77b4', 'lvl2': '#009ac8', 'lvl3': '#00b9c2'}
#     lvl_name = 'lvl3'
#
#     for pc.j, (latent_qs_q, latent_qs_perturbed_q) in enumerate(zip(latent_qs.T, latent_qs_perturbed.T)):
#         pc.i = 0
#         ys = pd.DataFrame()
#         ys['ae_org'] = latent_qs_q.tolist()
#         ys['ae_per'] = latent_qs_perturbed_q.tolist()
#         ys['ae_diff'] = ys['ae_per'] - ys['ae_org']
#         sns.scatterplot(data=ys, x='ae_org', y='ae_per', ax=pc.ax, edgecolor='black', color=s_ec,
#                         linewidth=s_lw,
#                         size=s_size, legend=False)
#         pc.ax.set_xlim([-3, 3])
#         pc.ax.set_ylim([-3, 3])
#         pc.ax.plot([-3, 3], [-3, 3], transform=pc.ax.transAxes)
#         pc.ax.set_aspect('equal', adjustable='box')
#         pc.ax.set_title(f'SAE PHQ-9 Q{pc.j + 1}')
#
#         pc.ax.set_xlabel('Original score')
#         pc.ax.set_ylabel('Perturbed score')
#
#         pc.i = 1
#         sns.histplot(data=ys, x='ae_diff', ax=pc.ax, kde=True, stat='probability')
#         pc.ax.set_xlim([-1.25, 1.75])
#         pc.ax.set_title(f'\nPerturbed - Original')
#         # pc.ax.set_title(f'\nSAE PHQ-9 Q{pc.j+1}\nPerturbed - Original')
#         pc.ax.set_xlabel('Score difference')
#         # pc.ax.set_ylabel('Perturbed score')
#
#     plt.suptitle(
#         f'Model: {model_name} - layer: {layer_idx + 1}\nq-score-delta: {q_score_change}; p-force: {p_force}\nQ-target{q_idx_to_perturb + 1}')
#     plt.tight_layout()
#     plt.savefig(
#         f'{paths.plots_path}sae_latent_perturbation_layer-{layer_idx + 1}_q_target_idx-{q_idx_to_perturb}.pdf')
#
# # %%
# # # model_name_set = ['gemma2-9b-it']
# # model_name_set = ['gemma2-9b-it', 'gemma2-2b-it', 'MistralOo']
# # # model_name_set = ['gemma2-2b-it']
# # q_idxs = [1]
# # q_idx_to_perturb = 1  # Q2
# q_score_change = 1
# p_force = 1
# to_sample_qs = 'phq9'
# paths = Paths(files_dir=f"{llm_path}files_logits_perturbed", sub_path=llm_path, plots_subdir='', plots_subsubdir='',
#               files_data_dir=f'{llm_path}_data')
#
# # computeDelta = True
# computeDelta = False
#
# for model_name in tqdm(model_name_set):
#     print(model_name)
#     sample_config = SampleLogitsConfig(paths, model_name=model_name, qs_name='phq9', instr_name='instr3', temp='',
#                                        top_p='')
#
#     ### SAE set up stuff experiement and hyperparams
#     loss_plot_dir = f'_plots/loss/{exp_name}/{model_name}/'
#     loss_dir = f'loss/{exp_name}/{model_name}/'
#     metric_dir = f'metrics/{exp_name}/{model_name}/'
#     model_dir = f'saved_models/{exp_name}/{model_name}/'
#     delta_dir = f'saved_delta/{exp_name}/{model_name}/'
#     prelim_plot_dir = f'_plots/prelim/{exp_name}/{model_name}/'
#     latents_plot_dir = f'_plots/sae_latents/{exp_name}/{model_name}/'
#     Path(delta_dir).mkdir(parents=True, exist_ok=True)
#     Path(latents_plot_dir).mkdir(parents=True, exist_ok=True)
#
#     layer_idxs = np.arange(sample_config.L // 2, sample_config.L)
#     # Layerwise model and perturbation
#     for l, layer_idx in tqdm(enumerate(layer_idxs)):
#         layer_name = f'layer-{layer_idx + 1}'
#         print(layer_name)
#         best_config = [m for m in os.listdir(model_dir) if layer_name in m]
#         best_config = [t.replace('.pt', '').replace('best_', '') for t in best_config]
#         best_config_dict = [{s.split('-')[0]: '-'.join(s.split('-')[1:]) for s in t.split('^^')} for t in best_config][
#             0]
#         best_hyper_dict = {k: v for k, v in best_config_dict.items() if k not in ['model', 'layer']}
#         hparams = list(best_config_dict.values())
#         q_case = hparams[2]
#         batch_size = int(hparams[3])
#         optim_lr = float(hparams[4])
#         sparsity_coeff = float(hparams[5])
#         qs_coeff = int(hparams[6])
#         sae_factor = int(hparams[7])
#         tied_weights = hparams[8] == 'True'
#         sae_name = hparams[9]
#
#         exp_config = ExpConfig(sample_config, q_case=q_case, do_zscores=do_zscores, max_epochs=max_epochs)
#
#         # Get data
#         meta_dataset = torch.load(f'data/{exp_name_data}/{exp_config.dataset_fname}.pt', weights_only=False)
#         n_train = int(0.7 * len(meta_dataset.subs))
#         n_val = int(0.15 * len(meta_dataset.subs))
#         n_test = int(0.15 * len(meta_dataset.subs))
#
#         exp_config.d_h = list(meta_dataset.avg_features.values())[0].shape[1]
#         sae_cfg = SAE_Config(exp_config, device=device_name, x_m=sae_factor, sparse_coeff=sparsity_coeff,
#                              tied_weights=tied_weights, qs_coeff=qs_coeff)
#
#         # train_dataset = MyDataset(meta_dataset, None, n_train, l, device=device_name)
#         val_dataset = MyDataset(meta_dataset, n_train, n_train + n_val, l, device=device_name)
#         # test_dataset = MyDataset(meta_dataset, n_train + n_val, None, l, device=device_name)
#
#         # test_dataloader = DataLoader(test_dataset, batch_size=len(test_dataset.subs), shuffle=True)
#         test_dataloader = DataLoader(val_dataset, batch_size=len(val_dataset.subs), shuffle=True)
#         test_features, test_labels = next(iter(test_dataloader))
#         test_features = test_features.to(device_name)
#         test_labels = test_labels.to(device_name)
#
#         layer_name = f'layer-{layer_idx + 1}'
#         model_layer_config = best_config[0]
#         # model_layer_config = [c for c in best_configs if layer_name in c][0]
#         # best_model_config = [{s.split('-')[0]: '-'.join(s.split('-')[1:]) for s in t.split('^^')} for t in best_config_dict]
#         model_fp = f'{model_dir}best_{model_layer_config}.pt'
#
#         # Load SAE model
#         if 'SAE2' in sae_name:
#             model_sae = SAE2_a(sae_cfg)
#         if 'SAE3' in sae_name:
#             model_sae = SAE3(sae_cfg)
#         if 'SAE4' in sae_name:
#             model_sae = SAE4(sae_cfg)
#
#         model_sae.load_state_dict(torch.load(model_fp))
#
#         for q_idx_to_perturb in tqdm(q_idxs):
#             print(f'Q: {q_idx_to_perturb + 1}')
#             deltaS_fp = f'{delta_dir}deltaS_{model_layer_config}^^q_target_idx-{q_idx_to_perturb}^^s_change-{q_score_change}.pt'
#             # print(deltaS_fp)
#
#             deltaS = get_perturb_delta(model_sae, deltaS_fp, q_idx_to_perturb=q_idx_to_perturb,
#                                        score_change=q_score_change, loadDelta=(not computeDelta), saveDelta=True,
#                                        device_name=device_name)
#             # deltaS = get_perturb_delta(model_sae, deltaS_fp, q_idx_to_perturb=q_idx_to_perturb, score_change=q_score_change,
#             #                            device_name=device_name)
#
#             # Forward pass get original states
#             latent_qs, _, _ = model_sae.infer(test_features)
#             # Get perturbed states
#             latent_qs_perturbed, _, h_rec_perturbed = model_sae.infer(test_features, delta_s=deltaS,
#                                                                       p_force=p_force * 1)
#
#             plt.close('all')
#             pc.r, pc.c, pc.mlt = 2, 9, 2.5
#             pc.i = 0
#
#             pc.ax_ts(12, 1.25)
#             pc.l_fs(8, 0.85)
#             ts = 8
#             pc.xyt_ls(ts, 10)
#             pc.ax_ls(12)
#             pc.kde_lw = 3
#             pc.p_lab_spec[2] = 12
#             pc.p_lab_spec[0] = -0.1
#             pc.p_lab_spec[1] = 1.1
#             pc.dpi_val = 300
#             pc.figsize = ((pc.c + 0.5) * pc.mlt, (pc.r + .1) * pc.mlt)
#             fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize)
#             # axes = np.array([axes])
#             pc.axes = axes
#
#             ms = 110
#             alpha = 0.7
#             s_ec = '#ffedcb'
#             b_c = '#e19f20'
#             s_lw = 0.8
#             s_size = 2
#             s_out_size = 6
#             b_width = 0.4
#             v_width = 1.15
#             b_lw = 1.5
#             v_cols = {'lvl1': '#1f77b4', 'lvl2': '#009ac8', 'lvl3': '#00b9c2'}
#             lvl_name = 'lvl3'
#
#             for pc.j, (latent_qs_q, latent_qs_perturbed_q) in enumerate(zip(latent_qs.T, latent_qs_perturbed.T)):
#                 pc.i = 0
#                 ys = pd.DataFrame()
#                 ys['ae_org'] = latent_qs_q.tolist()
#                 ys['ae_per'] = latent_qs_perturbed_q.tolist()
#                 ys['ae_diff'] = ys['ae_per'] - ys['ae_org']
#                 sns.scatterplot(data=ys, x='ae_org', y='ae_per', ax=pc.ax, edgecolor='black', color=s_ec,
#                                 linewidth=s_lw,
#                                 size=s_size, legend=False)
#                 pc.ax.set_xlim([-3, 3])
#                 pc.ax.set_ylim([-3, 3])
#                 pc.ax.plot([-3, 3], [-3, 3], transform=pc.ax.transAxes)
#                 pc.ax.set_aspect('equal', adjustable='box')
#                 pc.ax.set_title(f'SAE PHQ-9 Q{pc.j + 1}')
#
#                 pc.ax.set_xlabel('Original score')
#                 pc.ax.set_ylabel('Perturbed score')
#
#                 pc.i = 1
#                 sns.histplot(data=ys, x='ae_diff', ax=pc.ax, kde=True, stat='probability')
#                 pc.ax.set_xlim([-1.25, 1.75])
#                 pc.ax.set_title(f'\nPerturbed - Original')
#                 # pc.ax.set_title(f'\nSAE PHQ-9 Q{pc.j+1}\nPerturbed - Original')
#                 pc.ax.set_xlabel('Score difference')
#                 # pc.ax.set_ylabel('Perturbed score')
#
#             plt.suptitle(
#                 f'Model: {model_name} - layer: {layer_idx + 1}\nq-score-delta: {q_score_change}; p-force: {p_force}\nQ-target{q_idx_to_perturb + 1}')
#             plt.tight_layout()
#             plt.savefig(
#                 f'{latents_plot_dir}sae_latent_perturbation_layer-{layer_idx + 1}_q_target_idx-{q_idx_to_perturb}.pdf')
