import os, re
from tqdm import tqdm

from scipy import stats
import pandas as pd
import numpy as np
import sklearn.metrics
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from natsort import natsorted
import torch
from sympy.plotting.textplot import rescale
from torch import nn
from torch.utils.data import Dataset, DataLoader
import random
import torch.nn.functional as F
from torch import nn
import seaborn as sns
import gc
from sklearn.metrics import accuracy_score, balanced_accuracy_score
from pathlib import Path
from itertools import product
import json
from meta_utils.utils import flush, download_blob, set_seed, get_ram, get_vram

# download_blob('llm-bucket-res', 'data_sae.zip', 'data_sae.zip')

import matplotlib

# Pick the matplotlib backend and the torch device name from the host OS.
# platform.system() returns the same name as os.uname()[0] on POSIX hosts,
# but unlike os.uname() it also exists on Windows, so the non-mac branch is
# reachable there instead of failing with AttributeError.
import platform

if platform.system() == 'Darwin':  # if on mac
    matplotlib.use('Qt5Agg')
    device_name = 'mps'
    import matplotlib.pyplot as plt

    # Interactive mode so figures show up while the script keeps running.
    plt.ion()
else:
    # Non-interactive backend: figures are only written to disk.
    matplotlib.use('Agg')
    # matplotlib.use('TkAgg')
    matplotlib.get_backend()  # return value is not used
    device_name = 'cuda'
    import matplotlib.pyplot as plt

from sae_models import *

# cwd = os.getcwd()
# cwd_split = cwd.split('/')
# qs_structure_dir = f'../online_tasks/qs_structure/'
# qs_int_dir = f'../online_tasks/qs_structure/'
# files_path = qs_structure_dir + '_analysis/_llm_based/files/'
# qs_str_dat_dir = f'{qs_structure_dir}_analysis/_data/'
# states_path = f"{files_path}states/subjects"
# base_task = 'qs_structure'
# cwd_base = '/'.join(cwd_split[:np.argwhere([p == "online_tasks" for p in cwd_split])[0][0] + 1])
# new_path = f"{cwd_base}/{base_task}/{path_to_go}/"
# os.chdir(new_path)

from objects.model_configs import *
from meta_objects.configs import *
from meta_objects.plot_config import *

from prepare_model_datasets import MetaDataset, MyDataset

# Plot helper shared by every figure below, and the path object that is handed
# to SampleLogitsConfig inside the grid loop.
pc = PlotConfig()
paths = Paths(files_dir='files', plots_subdir='', plots_subsubdir='')
paths.plots_path = '_plots/'

# %% Setup configs and load data
# exp_name labels everything this run writes; exp_name_data is the folder the
# prepared datasets are loaded from.
exp_name = 'SAE_Mistral_gemma_llama_exp_v3'
exp_name_data = 'SAE_Mistral_gemma_llama_exp_v1'

# Forwarded to EarlyStopper(patience=..., min_delta=...) in the training loop.
es_patience = 30
es_delta = 0

torch_v = torch.__version__
seed_value = 80
# NOTE(review): set_seed comes from meta_utils.utils; presumably it seeds the
# Python / NumPy / torch RNGs -- confirm there. It runs before any model or
# DataLoader is created.
set_seed(seed_value)

save_best_model = False
savePlots = True

# True: predictions are kept as continuous scores and evaluated with Spearman
# correlation; False: predictions are arg-maxed and evaluated as classes.
do_zscores = True

# --- Hyper-parameter grid axes -------------------------------------------
# Alternative model lists kept for quick switching:
#   ['MistralOo', 'gemma2-9b-it', 'llama31-8b-it', 'gemma2-2b-it', 'llama32-3b-it']
#   ['MistralOo', 'gemma2-9b-it', 'llama31-8b-it']
#   ['gemma2-9b-it'] / ['gemma2-2b-it']
model_name_set = ['gemma2-9b-it', 'gemma2-2b-it', 'MistralOo']
q_case_set = ['9q']

batch_size_set = [32]  # alt: [32, 64]
optim_lr_set = [1e-3, 1e-4]  # alt: [1e-4, 1e-3], [1e-3], [1e-4]
sparsity_coeff_set = [0.05, 0.1, 0.2]  # alt: [0.05]
sae_factor_set = [1, 2, 4]  # alt: [1]

# Maps the label that appears in the grid to the boolean given to SAE_Config.
tied_weights_set = {'tied': True}
qs_coeff_set = [1]

sae_set = ['SAE4_z']  # alt: ['SAE2_z'], ['SAE3_z']
max_epochs = 1000

# Snapshot of the settings above, written next to the results for provenance.
exp_json_config = dict(
    name=exp_name,
    torch_v=torch_v,
    seed_value=seed_value,
    es_patience=es_patience,
    max_epochs=max_epochs,
    es_delta=es_delta,
    model_name_set=model_name_set,
    q_case_set=q_case_set,
    batch_size_set=batch_size_set,
    optim_lr_set=optim_lr_set,
    sparsity_coeff_set=sparsity_coeff_set,
    qs_coeff_set=qs_coeff_set,
    sae_factor_set=sae_factor_set,
    tied_weights_set=tied_weights_set,
    sae_set=sae_set,
)

log_dir = Path('exp_logs')
log_dir.mkdir(parents=True, exist_ok=True)
(log_dir / f'{exp_name}.json').write_text(json.dumps(exp_json_config))

# Cartesian product of all axes; each element is a 9-tuple in the order the
# axes are listed here (iterating the tied_weights_set dict yields its keys).
hparam_sets = list(product(
    model_name_set,
    q_case_set,
    batch_size_set,
    optim_lr_set,
    sparsity_coeff_set,
    qs_coeff_set,
    sae_factor_set,
    tied_weights_set,
    sae_set,
))
# hparam_sets = [hparam_sets[0]]
# Convenience binding for running the cells below interactively.
hparams = hparam_sets[0]
# %%
for hparams in tqdm(hparam_sets):
    # Each grid point is a 9-tuple in the order the axes were given to product().
    (model_name, q_case, batch_size, optim_lr, sparsity_coeff, qs_coeff,
     sae_factor, tied_key, sae_name) = hparams
    # The grid carries the label; the config needs the boolean behind it.
    tied_weights = tied_weights_set[tied_key]

    # Output folders, one set per experiment and LLM.
    loss_plot_dir = f'_plots/loss/{exp_name}/{model_name}/'
    loss_dir = f'loss/{exp_name}/{model_name}/'
    metric_dir = f'metrics/{exp_name}/{model_name}/'
    model_dir = f'saved_models/{exp_name}/{model_name}/'
    prelim_plot_dir = f'_plots/prelim/{exp_name}/{model_name}/'
    for out_dir in (model_dir, loss_plot_dir, loss_dir, metric_dir, prelim_plot_dir):
        Path(out_dir).mkdir(parents=True, exist_ok=True)

    sample_config = SampleLogitsConfig(paths, model_name=model_name, qs_name='phq9', instr_name='instr3', temp='',
                                       top_p='')
    exp_config = ExpConfig(sample_config, q_case=q_case, do_zscores=do_zscores, max_epochs=max_epochs)

    # Get data
    # weights_only=False unpickles a full Python object, so only load dataset
    # files produced by this project.
    meta_dataset = torch.load(f'data/{exp_name_data}/{exp_config.dataset_fname}.pt', weights_only=False)

    # 70 % / 15 % of the subjects for train / validation; the test dataset below
    # takes whatever comes after n_train + n_val.
    n_subs = len(meta_dataset.subs)
    n_train = int(0.7 * n_subs)
    n_val = int(0.15 * n_subs)

    # Second half of the layer range [L // 2, L).
    layer_idxs = np.arange(sample_config.L // 2, sample_config.L)

    # Second dimension of the first stored feature array is used as d_h.
    exp_config.d_h = next(iter(meta_dataset.avg_features.values())).shape[1]
    sae_cfg = SAE_Config(exp_config, device=device_name, x_m=sae_factor, sparse_coeff=sparsity_coeff,
                         tied_weights=tied_weights, qs_coeff=qs_coeff)

    # enumerate(tqdm(...)) rather than tqdm(enumerate(...)): tqdm can then read
    # len(layer_idxs) and show a total / ETA for the layer loop.
    for layer_pos, layer_idx in enumerate(tqdm(layer_idxs)):
        layer_name = f'layer-{layer_idx + 1}'

        print(
            f'model_name: {model_name}, layer:{layer_idx + 1}, q_case: {q_case}, '
            f'batch_size: {batch_size}, optim_lr: {optim_lr}, '
            f'sparsity_coeff: {sparsity_coeff}, qs_coeff:{qs_coeff}, '
            f'sae_factor: {sae_factor}, tied_weights: {tied_weights}, sae_name: {sae_name}')

        # Run metadata; it is also appended as constant columns to the loss and
        # metric tables written further down.
        train_config_dict = {'model': model_name, 'layer': layer_idx + 1, 'q_case': q_case, 'batch_size': batch_size,
                             'optim_lr': optim_lr, 'sparsity_coeff': sparsity_coeff, 'qs_coeff': qs_coeff,
                             'sae_factor': sae_factor, 'tied_weights': tied_weights, 'sae_name': sae_name}
        # File-name key 'model-..^^layer-..^^...'; built from the dict so the
        # two cannot drift apart.
        train_config = '^^'.join(f'{key}-{value}' for key, value in train_config_dict.items())

        # NOTE(review): MyDataset arguments are presumably (meta dataset, start,
        # stop, position within layer_idxs) -- confirm in prepare_model_datasets.
        train_dataset = MyDataset(meta_dataset, None, n_train, layer_pos, device=device_name)
        val_dataset = MyDataset(meta_dataset, n_train, n_train + n_val, layer_pos, device=device_name)
        # Built but not consumed anywhere in this script.
        test_dataset = MyDataset(meta_dataset, n_train + n_val, None, layer_pos, device=device_name)

        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

        loss_plot_fp = f'{loss_plot_dir}loss_{train_config}.pdf'
        loss_fp = f'{loss_dir}loss_{train_config}.csv'
        preds_fp = f'{metric_dir}preds_{train_config}.csv'
        metric_fp = f'{metric_dir}metrics_{train_config}.csv'
        model_fp = f'{model_dir}best_{train_config}.pt'
        heatmap_fp = f'{prelim_plot_dir}heatmap_{train_config}.pdf'
        scatter_fp = f'{prelim_plot_dir}scatter_{train_config}.pdf'

        # A grid point counts as finished once every artefact exists. model_fp is
        # left out on purpose because it is only written when save_best_model is set.
        expected_outputs = (loss_plot_fp, loss_fp, heatmap_fp, scatter_fp, metric_fp, preds_fp)
        if all(Path(out_fp).exists() for out_fp in expected_outputs):
            print('\t already done')
        else:
            print('\tTraining')
            # %% Train
            early_stopper = EarlyStopper(patience=es_patience, min_delta=es_delta)
            # Choose the SAE class from the variant tag in sae_name. The highest
            # variant is checked first so a name matching several tags resolves as
            # before; an unknown tag now fails here instead of raising NameError
            # (or silently reusing the previous grid point's model) later on.
            if 'SAE4' in sae_name:
                model = SAE4(sae_cfg)
            elif 'SAE3' in sae_name:
                model = SAE3(sae_cfg)
            elif 'SAE2' in sae_name:
                model = SAE2(sae_cfg)
            else:
                raise ValueError(f'Unknown SAE variant in sae_name: {sae_name!r}')

            # Keys read from the per-batch loss dict returned by the model.
            loss_terms = ('L_rec', 'L_sparse', 'L_qs', 'L_sev')
            gather_train_loss = []
            gather_val_loss = []
            gather_train_loss_dict = {k: [] for k in loss_terms}
            gather_val_loss_dict = {k: [] for k in loss_terms}
            optimizer = torch.optim.Adam(model.parameters(), lr=optim_lr)
            best_epoch_bool = []
            for epoch in tqdm(range(exp_config.max_epochs)):
                # Training
                model.train()
                train_total_loss = 0
                train_loss_dict_store = dict.fromkeys(loss_terms, 0)
                for (train_features, train_labels) in train_dataloader:
                    train_features = train_features.type(torch.float32).to(device_name)
                    train_labels = train_labels.type(torch.float32).to(device_name)

                    # Reset gradients before each pass
                    optimizer.zero_grad()

                    # The model returns (total loss, dict of loss terms, latent_qs, h_rec).
                    train_loss, train_loss_dict, latent_qs, h_rec = model(train_features, train_labels)
                    train_loss.backward()

                    optimizer.step()
                    train_total_loss += train_loss.item()
                    for k in loss_terms:
                        train_loss_dict_store[k] += train_loss_dict[k].item()

                # Epoch value = mean over batches.
                n_train_batches = len(train_dataloader)
                train_total_loss = train_total_loss / n_train_batches
                gather_train_loss.append(train_total_loss)
                for k in loss_terms:
                    gather_train_loss_dict[k].append(train_loss_dict_store[k] / n_train_batches)

                # Validation
                model.eval()
                val_total_loss = 0
                val_loss_dict_store = dict.fromkeys(loss_terms, 0)
                with torch.no_grad():
                    for (val_features, val_labels) in val_dataloader:
                        val_features = val_features.type(torch.float32).to(device_name)
                        val_labels = val_labels.type(torch.float32).to(device_name)

                        val_loss, val_loss_dict, latent_qs, h_rec = model(val_features, val_labels)
                        val_total_loss += val_loss.item()
                        for k in loss_terms:
                            val_loss_dict_store[k] += val_loss_dict[k].item()

                n_val_batches = len(val_dataloader)
                val_total_loss = val_total_loss / n_val_batches
                gather_val_loss.append(val_total_loss)
                for k in loss_terms:
                    gather_val_loss_dict[k].append(val_loss_dict_store[k] / n_val_batches)

                print(f"epoch : {epoch + 1}/{max_epochs}, train loss = {train_total_loss:.6f}")
                print(f"\t\t\t\tval loss = {val_total_loss:.6f}")
                best_epoch_bool.append(False)
                early_stopper(val_total_loss, model, epoch)

                if early_stopper.early_stop:
                    print("Early stopping triggered")
                    break

            # Build and save the loss history no matter how the loop ended. This
            # used to happen only inside the early-stop branch, so a run that
            # reached max_epochs left loss_df undefined (or stale from the previous
            # grid point) and never wrote loss_fp.
            # NOTE(review): assumes EarlyStopper keeps best_epoch up to date on
            # every call, not only when it triggers -- confirm in sae_models.
            if best_epoch_bool:
                best_epoch_bool[early_stopper.best_epoch] = True
            gather_train_loss_dict = {k + '_train': v for k, v in gather_train_loss_dict.items()}
            gather_val_loss_dict = {k + '_val': v for k, v in gather_val_loss_dict.items()}
            loss_dict = {'train_loss': gather_train_loss, 'val_loss': gather_val_loss,
                         'is_best': best_epoch_bool} | gather_train_loss_dict | gather_val_loss_dict | train_config_dict
            loss_df = pd.DataFrame(loss_dict).reset_index(names='epoch')
            loss_df.to_csv(loss_fp, index=False)

            # Roll back to the best weights before any evaluation / plotting.
            if early_stopper.best_model_state is not None:
                model.load_state_dict(early_stopper.best_model_state)
                if save_best_model:
                    print("Saving the best model from memory to disk.")
                    torch.save(early_stopper.best_model_state, model_fp)

            # %% Plot loss
            # One panel per loss: the total first, then each individual term.
            # Column names follow the '<term>_train' / '<term>_val' scheme used
            # when loss_df was assembled.
            item_loss_names = ['L_rec', 'L_sparse', 'L_qs', 'L_sev']
            # item_loss_names = ['L_rec', 'L_sparse', 'L_qs']
            loss_names_t = ['train_loss'] + [i + '_train' for i in item_loss_names]
            loss_names_v = ['val_loss'] + [i + '_val' for i in item_loss_names]
            item_loss_names = ['L_total'] + item_loss_names
            plt.close('all')
            pc.r, pc.c, pc.mlt = 1, len(item_loss_names), 2

            pc.figsize = ((pc.c + 4.15) * pc.mlt, (pc.r + 0.75) * pc.mlt)
            fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize, sharey=False)
            # plt.subplots(1, n) returns a 1-D array; wrap it so pc sees a 2-D grid.
            axes = np.array([axes])
            pc.axes = axes
            # pc.onerow = True
            pc.onerow = False
            pc.i = 0
            pc.j = 0
            pc.ax_ts(10, 1.1)
            pc.l_fs(8, 0.85)
            pc.xyt_ls(16, 16)
            pc.ax_ls(16)
            pc.kde_lw = 3
            pc.p_lab_spec[2] = 14
            pc.p_lab_spec[0] = -0.05
            pc.p_lab_spec[1] = 1.05
            pc.dpi_val = 300

            # pc.j is the loop target, so pc.ax presumably resolves to the panel
            # in column pc.j on each pass -- confirm in PlotConfig.
            for pc.j, (t_lab, val_lab) in enumerate(zip(loss_names_t, loss_names_v)):
                t_loss_df = loss_df[t_lab]
                v_loss_df = loss_df[val_lab]
                pc.ax.plot(t_loss_df, label=t_lab)
                pc.ax.plot(v_loss_df, label=val_lab)
                if pc.j == 0:
                    # Marker reaches just above the taller of the two curves. The
                    # previous `t_loss_df + v_loss_df` added the Series element-wise,
                    # so the height came from the sum of both losses.
                    curve_peak = max(t_loss_df.max(), v_loss_df.max())
                    pc.ax.vlines(x=early_stopper.best_epoch, ymin=0, ymax=curve_peak * 1.1, color='r',
                                 linestyle='--',
                                 label=f'Best epoch {early_stopper.best_epoch}')
                pc.ax.set_xlabel('epoch')
                pc.ax.set_ylabel('loss')
                pc.ax.set_title(item_loss_names[pc.j])
                pc.ax.legend()
            plt.tight_layout()
            if savePlots:
                plt.savefig(loss_plot_fp, dpi=300)

            # %% Plot heatmap
            plt.close('all')
            # Column labels for the questionnaire items covered by this q_case.
            if q_case == '9q':
                phq9_q_names = [f'phq9_q{q + 1}s' for q in range(9)]
            elif q_case == '2q':
                phq9_q_names = ['phq9_q2s', 'phq9_q4s']
            elif q_case == 'q2Only':
                phq9_q_names = ['phq9_q2s']

            # Two side-by-side panels: predicted scores and true scores.
            pc.r = 1
            pc.c = 2
            pc.mlt = 2.75
            fig_w = (pc.c + 4.25) * pc.mlt
            fig_h = (pc.r + 3.25) * pc.mlt
            pc.figsize = (fig_w, fig_h)

            # Font / label styling handed to the plot helper.
            pc.ax_ts(16, 1.1)
            pc.l_fs(12, 0.85)
            ts = 14
            pc.xyt_ls(ts, 10)
            pc.ax_ls(16)
            pc.kde_lw = 3
            pc.p_lab_spec[2], pc.p_lab_spec[0], pc.p_lab_spec[1] = 12, -0.1, 1.1
            pc.dpi_val = 300

            fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize)
            pc.i = pc.j = 0
            pc.onerow = True
            # Wrap the 1-D axes array returned for a single row into a 2-D grid.
            axes = pc.axes = np.array([axes])

            # No fixed colour limits: each heatmap scales to its own data.
            vminval = vmaxval = None


            def rescale_np(ref, x):
                """Min-max rescale each column of ``x`` onto the column range of ``ref``.

                Column minima and maxima are taken along axis 0. A small epsilon in
                the denominator keeps a constant column of ``x`` from dividing by
                zero; such a column maps to the minimum of the matching ``ref`` column.
                """
                ref_lo, ref_hi = ref.min(axis=0), ref.max(axis=0)
                x_lo, x_hi = x.min(axis=0), x.max(axis=0)

                target_span = ref_hi - ref_lo
                source_span = x_hi - x_lo + 1e-8
                return target_span * (x - x_lo) / source_span + ref_lo


            # Score the whole validation split in one forward pass. This is pure
            # inference, so run it in eval mode and under no_grad: the outputs are
            # unchanged, but no autograd graph is built for the full split.
            # NOTE(review): unlike the training loop, x / y are not cast to float32
            # here -- presumably the stored tensors already are; confirm.
            model.eval()
            x = torch.stack(val_dataloader.dataset.features).to(device_name)
            y = torch.stack(val_dataloader.dataset.labels).to(device_name)
            with torch.no_grad():
                # Index 2 of the model output is what the training loop unpacks
                # as latent_qs.
                preds = model(x, y)[2].to('cpu').detach()

            if do_zscores:
                # Continuous scores: use the raw outputs as NumPy arrays.
                ytrue = y.to('cpu').numpy()
                preds = preds.numpy()
            else:
                # Class scores: take the most probable class along the last dim.
                preds = torch.argmax(F.softmax(preds, dim=-1), dim=-1).to('cpu')
                ytrue = y.to('cpu')

            # One row per validation item ('sub' is the row index), one column per
            # question, tagged with where the values came from.
            ytruedf = pd.DataFrame(ytrue, columns=phq9_q_names).reset_index().rename(columns={'index': 'sub'})
            ytpreddf = pd.DataFrame(preds, columns=phq9_q_names).reset_index().rename(columns={'index': 'sub'})
            ytruedf['source'] = 'ppt'
            ytpreddf['source'] = 'ae'

            sns.heatmap(preds, ax=pc.ax, vmin=vminval, vmax=vmaxval, cmap='Oranges')
            pc.ax.set_title('Predicted scores')
            pc.ax.set_xlabel('Question')
            pc.ax.set_ylabel('Subject')

            # Move to the second panel for the ground truth.
            pc.j = 1
            sns.heatmap(ytrue, ax=pc.ax, vmin=vminval, vmax=vmaxval, cmap='Oranges')
            pc.ax.set_title('True scores')
            pc.ax.set_xlabel('Question')
            pc.ax.set_ylabel('Subject')
            plt.tight_layout()
            if savePlots:
                plt.savefig(f'{heatmap_fp}', dpi=300)

            # %% Plot scatter
            # Long format: one row per (sub, source, question) with its value;
            # predictions ('ae') are stacked on top of the ground truth ('ppt').
            ytpreddf_m = pd.melt(ytpreddf, id_vars=['sub', 'source'])
            ytruedf_m = pd.melt(ytruedf, id_vars=['sub', 'source'])
            y_all = pd.concat([ytpreddf_m, ytruedf_m], axis=0)
            y_all.to_csv(preds_fp, index=False)

            # One column per question. The class-label mode gets a second row
            # that holds the confusion matrices.
            n_rows = 1 if do_zscores else 2
            if q_case == '9q':
                pc.r, pc.c, pc.mlt = n_rows, 9, 2
            elif q_case == '2q':
                pc.r, pc.c, pc.mlt = n_rows, 2, 2
            elif q_case == 'q2Only':
                pc.r, pc.c, pc.mlt = n_rows, 1, 2

            fig_w = (pc.c + 1.15) * pc.mlt
            fig_h = (pc.r + 0.75) * pc.mlt
            pc.figsize = (fig_w, fig_h)
            fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize, sharey=True, sharex=True)
            pc.axes = axes
            pc.onerow = False
            pc.i = pc.j = 0
            pc.ax_ts(10, 1.1)
            pc.l_fs(12, 0.85)
            pc.xyt_ls(16, 16)
            pc.ax_ls(16)
            pc.kde_lw = 3
            pc.p_lab_spec[2], pc.p_lab_spec[0], pc.p_lab_spec[1] = 14, -0.05, 1.05
            pc.dpi_val = 300

            # Styling constants. Only s_ec, s_lw and s_size are used by the
            # scatter call below; the rest are not referenced in this script.
            ms, alpha = 110, 0.7
            s_ec, b_c = '#ffedcb', '#e19f20'
            s_lw, s_size, s_out_size = 0.8, 2, 6
            b_width, v_width, b_lw = 0.4, 1.15, 1.5
            v_cols = {'lvl1': '#1f77b4', 'lvl2': '#009ac8', 'lvl3': '#00b9c2'}
            lvl_name = 'lvl3'

            pred_stats = []

            for pc.j, q_name in enumerate(phq9_q_names):
                # Wide table for this question: one row per sub, columns 'ae' / 'ppt'.
                ys = y_all.loc[y_all['variable'] == q_name].pivot(index='sub', columns='source', values='value')
                if do_zscores:
                    r, p = stats.spearmanr(ys['ppt'], ys['ae'])
                    # Bonferroni-style correction over the questions, capped at 1.
                    p = min(p * len(phq9_q_names), 1)
                    ttitle = f"{q_name}\nr (spear): {r:.3f}\n p-val: {p:.2e}"
                    stat_dict = {'q_name': q_name, 'r': r, 'p': p}
                else:
                    y_all_acc = balanced_accuracy_score(ys['ppt'], ys['ae'])
                    cm = confusion_matrix(ys['ppt'], ys['ae'])
                    disp = ConfusionMatrixDisplay(cm)
                    # The confusion matrix goes in the second row.
                    pc.i = 1
                    disp.plot(ax=pc.ax, colorbar=False)
                    print(f'accurracy for {q_name}: {y_all_acc}')
                    ttitle = f'{q_name} - b.acc: {y_all_acc:.2f}'
                    stat_dict = {}

                # Every metrics row carries the full run configuration.
                pred_stats.append(stat_dict | train_config_dict)

                # Scatter of true vs. predicted always goes in the first row.
                pc.i = 0
                sns.scatterplot(data=ys, x='ppt', y='ae', ax=pc.ax, edgecolor='black', color=s_ec,
                                linewidth=s_lw, size=s_size, legend=False)
                pc.ax.set_title(ttitle)
                pc.ax.set_aspect('equal', adjustable='box')

            plt.tight_layout()
            if savePlots:
                plt.savefig(scatter_fp, dpi=300)
            # Figures stay open here; they are closed at the next plt.close('all').
            preds_stats_df = pd.DataFrame(pred_stats)
            preds_stats_df.to_csv(metric_fp, index=False)

            # Release model / optimizer state before the next grid point.
            flush(model, optimizer, early_stopper, device_name)
            if device_name != 'mps':
                del train_loss_dict, train_loss
# os.system("git add .; git commit -am 't4'; git push")
# os.system('sudo shutdown -h now')
