import os, re
import copy
import einops
# from examples.Official_demo import anorm
from tqdm import tqdm
import string
from datetime import datetime

from scipy import stats
import spgl1
import pandas as pd
import numpy as np
import sklearn.metrics
# from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from natsort import natsorted
import torch
from sympy.plotting.textplot import rescale
from torch import nn
from torch.utils.data import Dataset, DataLoader
import random
import torch.nn.functional as F
from torch import nn
import seaborn as sns
import gc
# from sklearn.metrics import accuracy_score, balanced_accuracy_score
from pathlib import Path
from itertools import product
import json
# from m.utils import flush, download_blob, set_seed, get_ram, get_vram

# download_blob('llm-bucket-res', 'saved_delta.zip', 'saved_delta.zip')
# download_blob('llm-bucket-res', '_logits_data.zip', '_logits_data.zip')
# download_blob('llm-bucket-res', 'saved_models_compress.zip', 'saved_models_compress.zip')


import matplotlib

import platform

# Turn off autograd gradient tracking globally for the rest of the script.
torch.set_grad_enabled(False)

# from sae_new import train_config

# Select the matplotlib backend and the torch device name from the host OS.
# `platform.system()` is used instead of `os.uname()` because `os.uname` is
# only available on Unix, so the script would fail with AttributeError on
# Windows before reaching any analysis code.
# NOTE(review): the device name is chosen from the OS alone; availability of
# 'mps' / 'cuda' is not checked here.
if platform.system() == 'Darwin':  # if on mac
    matplotlib.use('Qt5Agg')
    device_name = 'mps'
    # pyplot is imported only after the backend has been selected.
    import matplotlib.pyplot as plt

    plt.ion()
else:
    matplotlib.use('Agg')
    # matplotlib.use('TkAgg')
    device_name = 'cuda'
    import matplotlib.pyplot as plt

from _objects.model_configs import *
from _objects.configs import *
from _objects.plot_config import *
from _objects.data_configs import maps
from _objects.data_configs import QsConfig

from _objects.sae_models import *
from test2_ssae.prepare_model_datasets import MetaDataset, MyDataset
from test2_ssae.perturb_utils import load_pert_logits

# Plot configuration object; the figure layout, font sizes and the active
# axes indices are attached to it further down in this script.
pc = PlotConfig()

# Identifiers and relative directories for this test case.
instr_name_str = 'instr3'
testCase_dir = 'test2_ssae'
subPath = 'analysis_sampling'
test_path = f"{testCase_dir}/{subPath}/"
# Prefix of the second PDF written when plots are saved.
fig_no = 'Fig5'

# `do_zscores` and `max_epochs` are forwarded to ExpConfig below; `ds_type`
# is embedded in the perturbed-logits path and in the saved plot filename.
do_zscores = True
ds_type = 'test'
max_epochs = 1000

# Input/output locations (relative to the working directory).
files_path_new = f'{testCase_dir}/logit_sampling/files_logits/paired/'  # new repo path
files_path_p = f'../construct_detector/logit_sampling/files_logits_perturbed_{ds_type}/paired/'  # old path
files_path_o = '../online_tasks/qs_structure/_analysis/_llm_based/files_logits/paired/'  # old path


# Root under which the stored dataset and saved models are looked up below.
prev_path = '../construct_detector/'  # old path
prev_path_new = ''
# Paths container built on the perturbed-logits directory; the remaining
# locations are attached as plain attributes.
paths = Paths(files_dir=files_path_p, plots_subdir='', plots_subsubdir='')
paths.original_logits_path = files_path_o
# Directory that receives the PDFs when `bools.savePlots` is true.
paths.plots_path = f'{testCase_dir}/_plots/logits_perturbation/'
paths.files_path_new = files_path_new
paths.files_path_o = files_path_o

# Run flags (save/load/plot switches are set on `bools` below) and the
# questionnaire configuration used to look up the number of answer labels.
bools = Bools()
qs_config = QsConfig()
# %%
# Experiment identifiers: `exp_name_data` selects the stored dataset folder,
# `exp_name` the saved-model folder, `to_sample_qs` the questionnaire whose
# labels and scores are used below.
exp_name_data = 'SAE_Mistral_gemma_llama_exp_v1'
exp_name = 'SAE_Mistral_gemma_llama_exp_v3_best'
model_name = 'gemma2-9b-it'
to_sample_qs = 'phq9'

# model = None
# end_idx = 1
# # steer_mlt_set = [-1.5, -1, -0.5, 0.5, 1, 1.5]
# steer_mlt_set = [-1, 1]
# bools.saveMe = True
# bools.loadMe = False
bools.saveMe = False
bools.loadMe = True
# bools.savePlots = True
bools.savePlots = False

# Minimum absolute effect size for a histogram to get a directional colour.
d_thr = 0.3

# The questionnaire and instruction names come from the variables defined
# above rather than repeated literals, so they cannot drift apart from the
# label/score lookup that follows.
sample_config = SampleLogitsConfig(paths, model_name=model_name, qs_name=to_sample_qs,
                                   instr_name=instr_name_str, temp='', top_p='')
sample_config.gen_fname = 'gb'
sample_config.permute_labels = False
sample_config.label_letters = None

# One uppercase letter per answer option, paired positionally with the scores.
label_letters = list(string.ascii_uppercase[:qs_config.qs_n_lab[to_sample_qs]])
q_scores = list(maps[to_sample_qs].values())
# q_responses = list(maps[to_sample_qs].keys())

if sample_config.permute_labels:
    # Shuffle letters and scores together so each pairing is preserved.
    # NOTE(review): no seed is set in this script, so the permutation is not
    # reproducible between runs.
    tmp_zip = list(zip(q_scores, label_letters))
    random.shuffle(tmp_zip)
    # Convert back to lists so both variables keep the same type as in the
    # non-permuted case.
    q_scores, label_letters = (list(column) for column in zip(*tmp_zip))
sample_config.label_letters = label_letters
sample_config.q_scores = q_scores
sample_config.label_scores = dict(zip(label_letters, q_scores))

model_dir = f'{prev_path}saved_models/{exp_name}/{model_name}/'

exp_config = ExpConfig(sample_config, q_case='9q', do_zscores=do_zscores, max_epochs=max_epochs)

# Get data
# NOTE: weights_only=False lets torch.load unpickle arbitrary Python objects;
# only point this at dataset files from a trusted source.
_dataset_file = f'{prev_path}data/{exp_name_data}/{exp_config.dataset_fname}.pt'
meta_dataset = torch.load(_dataset_file, weights_only=False)

# 70 / 15 / 15 fractions of the subject count, each truncated by int().
_n_subjects = len(meta_dataset.subs)
n_train = int(0.7 * _n_subjects)
n_val = int(0.15 * _n_subjects)
n_test = int(0.15 * _n_subjects)

# Second dimension of the first stored feature array.
_first_features = list(meta_dataset.avg_features.values())[0]
exp_config.d_h = _first_features.shape[1]
# sae_cfg = SAE_Config(exp_config, device=device_name, x_m=sae_factor, sparse_coeff=sparsity_coeff,
#                      tied_weights=tied_weights, qs_coeff=qs_coeff)

# Both directories are keyed by the concatenated label letters.
_label_key = ''.join(sample_config.label_letters)
paths.responses_path = f"{files_path_p}{_label_key}/subjects"
paths.save_path = f"{paths.files_path_new}{_label_key}/"
Path(paths.save_path).mkdir(parents=True, exist_ok=True)

scores_logits_wide, scores_logits = load_pert_logits(paths, sample_config, bools, ds_type)
# %% Plot logit exp scores
# Columns of the figure: questions whose name contains 'lvl3', natural order.
q_names = [name for name in natsorted(scores_logits['question'].unique()) if 'lvl3' in name]

# Rows of the figure: steering labels, excluding the 0.0 baseline, sorted by
# numeric multiplier and rebuilt from the float value.
_raw_steer_labels = natsorted(scores_logits['hs_steer_mlt'].unique())
_steer_values = sorted(float(label.split('_')[-1])
                       for label in _raw_steer_labels if label != 'steer_mlt_0.0')
steer_mlts = [f'steer_mlt_{value}' for value in _steer_values]

# Marker / scatter styling constants.
ms = 110
alpha = 0.7
s_ec = '#ffedcb'
s_lw = 0.8
s_size = 2
s_out_size = 6

# Box / violin styling constants.
b_c = '#e19f20'
b_width = 0.4
b_lw = 1.5
v_width = 1.15
v_cols = {'lvl1': '#1f77b4', 'lvl2': '#009ac8', 'lvl3': '#00b9c2'}
lvl_name = 'lvl3'

# Start from a clean slate, then build a (strengths x questions) grid.
plt.close('all')
pc.r = len(steer_mlts)
pc.c = len(q_names)
pc.mlt = 1.8
pc.figsize = ((pc.c + 1.4) * pc.mlt, (pc.r + 1.4) * pc.mlt)
fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize, sharex=False, sharey=True)

# Reset the active-cell indices and hand the axes grid to the plot config.
pc.i = 0
pc.j = 0
pc.onerow = False
pc.axes = axes

# Text-size settings applied through the plot config helpers.
pc.ax_ts(13.5, 13*1.1/12)
pc.l_fs(8, 0.85)
ts = 12
pc.xyt_ls(ts, ts)
pc.ax_ls(14)
pc.kde_lw = 2

# Panel-label spec: entries 0 and 1 are used as the text x/y position in
# axes coordinates, entry 2 as its font size (see the plotting loop).
pc.p_lab_spec[0] = 0.05
pc.p_lab_spec[1] = 1.2
pc.p_lab_spec[2] = 12
pc.dpi_val = 300
# axes = np.array([axes])

# Draw one histogram per (steering strength, question) cell. The loop
# variables are written straight into `pc.i` / `pc.j`, which `pc.ax` and
# `pc.p_labs` are read against below.
for pc.j, q_name in enumerate(q_names):
    q_s_resp_wide = scores_logits_wide[scores_logits_wide['question'] == q_name]
    q_name_short = re.sub('lvl.*_', '', q_name).upper()
    for pc.i, steer_mlt in enumerate(steer_mlts):

        if steer_mlt == 'steer_mlt_0.0':
            continue
        steer_mlt_val = float(steer_mlt.split('_')[-1])
        # True for negative perturbations; selects the one-sided alternative.
        val_comp = steer_mlt_val < 0

        tmp_diffs = q_s_resp_wide[f'diff_{steer_mlt}'].values
        # NOTE(review): np.std defaults to ddof=0 (population SD) whereas the
        # t-test below uses n-1; confirm which SD the reported effect size
        # is meant to use.
        tmp_std = np.std(tmp_diffs)
        d_eff = (tmp_diffs.mean() - 0) / tmp_std
        test_type = 'less' if val_comp else 'greater'
        tv, p_res = stats.ttest_1samp(tmp_diffs, popmean=0, alternative=test_type)

        # Colour coding: gray = not significant, blue = significant but the
        # effect size does not exceed `d_thr`, green/red = significant with a
        # large enough effect (negative / positive perturbation). The negated
        # comparisons keep NaN p-values or effect sizes in the gray/blue
        # branches.
        if not (p_res < 0.05):
            hist_col = 'tab:gray'
        elif not (np.abs(d_eff) > d_thr):
            hist_col = 'tab:blue'
        else:
            hist_col = 'tab:green' if val_comp else "tab:red"

        sns.histplot(data=q_s_resp_wide, x=f'diff_{steer_mlt}', ax=pc.ax, kde=True, stat='density',
                     color=hist_col, line_kws={'linewidth': pc.kde_lw})
        pc.ax.axvline(x=0, lw=s_lw * 3, color='k')
        pc.ax.set_xlim([-4, 4])
        # pc.ax.set_title(f'\nPerturbed - Original')
        # pc.ax.set_title(f'\nSAE PHQ-9 Q{pc.j+1}\nPerturbed - Original')
        pc.ax.set_xlabel(f'PHQ9 {q_name_short}')
        # pc.ax.set_title(f"{steer_mlt}; d:{d_eff:.2f}\ntv:{tv:.3f},p-v:{p_res:.3e}")
        direction_txt = 'Positive' if steer_mlt_val >= 0 else 'Negative'
        # The row heading is attached to the title of column index 3 only.
        # NOTE(review): the column index is hard-coded; confirm it is still
        # the intended column if the number of questions changes.
        pert_text = f"\n{direction_txt} perturbation strength: {steer_mlt_val}\n\n" if pc.j == 3 else ''

        pc.ax.set_title(f"{pert_text}Effect size: {d_eff:.2f}\np: {p_res:.2e}")
        pc.ax.text(pc.p_lab_spec[0] + 0.05, pc.p_lab_spec[1], pc.p_labs[pc.i, pc.j], transform=pc.ax.transAxes,
                   fontweight='bold',
                   va='top', ha='right',
                   fontsize=pc.p_lab_spec[2])

        # q_s_resps = logits_responses[(logits_responses['question']==q_name) & (logits_responses['hs_steer_mlt']==steer_mlt)]

plt.suptitle(f'Expected score difference after perturbation: {sample_config.model_name_plot}')
plt.tight_layout()
if bools.savePlots:
    # savefig does not create missing directories, so make sure the target
    # exists before writing.
    Path(paths.plots_path).mkdir(parents=True, exist_ok=True)
    plt.savefig(f'{paths.plots_path}logits_exp_score_pert_{ds_type}_{sample_config.model_name}.pdf',
                dpi=pc.dpi_val)
    plt.savefig(
        f"{paths.plots_path}{fig_no}_test2_logits_perturbation.pdf", dpi=pc.dpi_val)
