# %% Import and setup
import os, re
import pandas as pd
import numpy as np
from natsort import natsorted
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import random
import torch.nn.functional as F
from torch import nn
import seaborn as sns
import gc

# Device passed to MetaDataset and used as torch.load's map_location.
device_name = 'mps'

# Experiment tag; subject lists and saved datasets live under data/<exp_name>/.
exp_name = 'SAE_Mistral_gemma_llama_exp_v1'

# Input locations, relative to the working directory.
qs_structure_dir = '../online_tasks/qs_structure/'
qs_int_dir = qs_structure_dir
_analysis_dir = f'{qs_structure_dir}_analysis/'
files_path = f'{_analysis_dir}_llm_based/files_logits/paired/ABCD/subjects/'
qs_str_dat_dir = f'{_analysis_dir}_data/'
states_path = files_path + 'states/subjects'

from _objects.model_configs import *
from _objects.configs import *

# Project path configuration (class star-imported above); passed to SampleLogitsConfig for each model.
paths = Paths(files_dir='files', plots_subdir='', plots_subsubdir='')

# Task versions whose PHQ-9 / open-question rows are kept; common_tasks is a regex
# that hidden-state file names must match.
task_v = ['v4', 'v4_d', 'v4_dd', 'v4_ddd']
common_tasks = 'v4'

# Models to process, one dataset per model.
# model_name_set = ['MistralOo', 'gemma2-9b-it', 'llama31-8b-it']
model_name_set = ['MistralOo', 'gemma2-9b-it', 'llama31-8b-it', 'gemma2-2b-it', 'llama32-3b-it']
# model_name_set = ['gemma2-2b-it']

# Run toggles: flip by (un)commenting the alternative line.
do_small_data = False
# do_small_data = True
# createDataset = True
# saveDataset = True
createDataset = False
saveDataset = False
# do_zscores = False
do_zscores = True

q_case = '9q'  # one of: '9q', '2q', 'q2Only'
# tok_loc[1] is the token index MetaDataset reads from each hidden-state tensor.
tok_loc = ['oq_ans', 2]
if q_case == '9q':
    qs_list = ['lvl3_q1', 'lvl3_q2', 'lvl3_q3', 'lvl3_q4', 'lvl3_q5', 'lvl3_q6', 'lvl3_q7', 'lvl3_q8']
    phq9_q_names = ['phq9_q' + str(q + 1) + 's' for q in range(9)]
elif q_case == '2q':
    qs_list = ['lvl3_q2', 'lvl3_q4']
    phq9_q_names = ['phq9_q2s', 'phq9_q4s']
elif q_case == 'q2Only':
    qs_list = ['lvl3_q2']
    phq9_q_names = ['phq9_q2s']
else:
    # Fail here instead of with a NameError on qs_list / phq9_q_names later on.
    raise ValueError(f"unknown q_case {q_case!r}; expected '9q', '2q' or 'q2Only'")

# %% Create classes and functions
def prep_data(qs_list, read_list=True):
    """Collect hidden-state file paths, PHQ-9 data and the subject list for the current model.

    Reads module-level configuration: ``files_path``, ``sample_config``, ``common_tasks``,
    ``do_small_data``, ``qs_str_dat_dir``, ``phq9_q_names``, ``task_v``, ``do_zscores``,
    ``exp_name`` and ``q_case``.

    Args:
        qs_list: open-question names (e.g. ``'lvl3_q1'``) to keep. An empty list keeps
            every question found among the ``.pt`` files.
        read_list: if True, the subject order is read from
            ``data/{exp_name}/sub_list_{q_case}_random.txt``. Otherwise the filtered
            subjects are shuffled and that file is (over)written with them.

    Returns:
        ``(phq9_data, pt_file_paths, subs, qs_list)``:

        - ``phq9_data``: DataFrame with columns ``['sub', 'task_version'] + phq9_q_names``,
          restricted to ``task_v``. The item columns are z-scored when ``do_zscores`` is set.
        - ``pt_file_paths``: full paths ``{files_path}{sub}/{q}/{model}/{file}``.
        - ``subs``: the subject order, from the list file or freshly shuffled.
        - ``qs_list``: the questions present in the filtered open-question data.
    """
    # Hidden-state files are named '<sub>^^<question>^^<label_pert>^^<model>...'.
    pt_files = []
    for _root, _dirs, files in os.walk(files_path):
        pt_files.extend(
            file for file in files
            if sample_config.model_name_rp in file and '.pt' in file and re.search(common_tasks, file)
        )

    if len(qs_list) > 0:
        pt_files = [f for f in pt_files if any("^^" + q in f for q in qs_list)]

    subs = natsorted({f.split('^^')[0] for f in pt_files})
    if do_small_data:
        # Debug mode: keep only the first 30 subjects in natural-sort order.
        subs = subs[:30]
        small_subs = set(subs)
        pt_files = [f for f in pt_files if f.split('^^')[0] in small_subs]

    qs = natsorted({f.split('^^')[1] for f in pt_files})
    pt_file_paths = []
    for pt_file in pt_files:
        sub, q, _label_pert, model = pt_file.split('^^')
        model = model.split('_hidden_states')[0]
        pt_file_paths.append(f"{files_path}{sub}/{q}/{model}/{pt_file}")

    # % Prep phq and openq
    phq9_data = pd.read_csv(f"{qs_str_dat_dir}/phq9_data.csv")[['sub', 'task_version'] + phq9_q_names]
    phq9_data = phq9_data[phq9_data['task_version'].isin(task_v)].reset_index(drop=True)
    # Subjects with any missing PHQ-9 value; identified before z-scoring.
    nan_phq9_subs = phq9_data[phq9_data.isna().any(axis=1)]['sub'].to_list()
    if do_zscores:
        items = phq9_data[phq9_q_names]
        phq9_data[phq9_q_names] = (items - items.mean(axis=0)) / items.std(axis=0)

    min_words = 30
    openq_data_long = pd.read_csv(f'{qs_str_dat_dir}/openq_data_long.csv')
    openq_data_long = openq_data_long[openq_data_long['question'].isin(qs)]
    openq_data_long = openq_data_long[openq_data_long['sub'].isin(subs)]
    openq_data_long = openq_data_long[openq_data_long['task_version'].isin(task_v)].reset_index(drop=True)
    # Text clean-up on every string cell: tighten whitespace before '.' and ',',
    # then collapse newline runs into a single space. (The last pattern used to be
    # r'\n+,', which could never match: the preceding replace already rewrites any
    # whitespace run, newlines included, that precedes a comma.)
    openq_data_long = openq_data_long.replace(r'\s+\.', '.', regex=True)
    openq_data_long = openq_data_long.replace(r'\s+,', ', ', regex=True)
    openq_data_long = openq_data_long.replace(r'\n+', ' ', regex=True)
    openq_data_long['response'] = openq_data_long['response'].astype(str)
    openq_data_long['n_words'] = openq_data_long['response'].apply(lambda x: len(x.split()))

    # Keep responses of at least min_words words from subjects with complete PHQ-9 data.
    openq_data_long = openq_data_long[
        (openq_data_long['n_words'] >= min_words) & (~openq_data_long['sub'].isin(nan_phq9_subs))]
    # Then keep only subjects that still have exactly len(qs_list) such responses.
    q_counts = openq_data_long.groupby('sub')['question'].count()
    sub_with_full_set = q_counts.index[q_counts == len(qs_list)].tolist()
    openq_data_long = openq_data_long[openq_data_long['sub'].isin(sub_with_full_set)]
    print('n sub full set', len(sub_with_full_set))
    print('n sub', len(openq_data_long['sub'].unique()))

    subs = openq_data_long['sub'].unique().tolist()
    qs_list = openq_data_long['question'].unique().tolist()

    sub_list_path = f'data/{exp_name}/sub_list_{q_case}_random.txt'
    if read_list:
        # Reuse a previously saved subject order (presumably so that downstream
        # index-based splits stay reproducible). strip() guards against a trailing
        # newline that would otherwise corrupt the last subject id.
        with open(sub_list_path, 'r') as f:
            subs = f.read().strip().split('^')
    else:
        random.shuffle(subs)
        os.makedirs(os.path.dirname(sub_list_path), exist_ok=True)
        with open(sub_list_path, 'w') as f:
            f.write('^'.join(subs))

    return phq9_data, pt_file_paths, subs, qs_list


# % Create dataset classes
class MetaDataset:
    """Per-subject averaged hidden-state features and PHQ-9 labels for one model.

    Relies on module-level names: ``prep_data``, ``phq9_q_names``, ``files_path``,
    ``sample_config``, ``layer_idx_start`` and ``tok_loc``.

    Attributes:
        avg_features: dict ``sub -> tensor``. It holds the row-wise L2-normalised mean
            of the subject's row-wise L2-normalised per-question embeddings.
        avg_labels: dict ``sub -> tensor`` of the subject's PHQ-9 item values.
        subs: subjects, in ``prep_data`` order, that have both a feature and a label.
            Subjects without a PHQ-9 row or without any hidden-state file are skipped.
    """

    def __init__(self, device_name, qs_list):
        x_avg, y_avg = {}, {}
        phq9_data, pt_file_paths, subs, qs_list = prep_data(qs_list)
        kept_subs = []
        for sub in subs:
            phq9_rows = phq9_data[phq9_data['sub'] == sub][phq9_q_names].values
            if len(phq9_rows) == 0:
                print(f'MetaDataset: skipping {sub} (no PHQ-9 row)')
                continue
            # First matching row is used if the subject appears more than once.
            sub_phq9 = torch.tensor(list(phq9_rows[0]))

            sub_pts = []
            for q in qs_list:
                # Anchor on files_path so that one subject id cannot match inside
                # another (e.g. '1/lvl3_q1/...' inside '11/lvl3_q1/...'). Paths are
                # built as f"{files_path}{sub}/{q}/{model}/{file}" by prep_data.
                prefix = f'{files_path}{sub}/{q}/{sample_config.model_name_rp}'
                matches = [p for p in pt_file_paths if p.startswith(prefix)]
                if len(matches) != 1:
                    # Missing or ambiguous file: this question is left out of the average.
                    continue
                # Drop index 0 (the embedding layer, per the original comment) and the
                # next layer_idx_start layers, and keep token position tok_loc[1].
                # Assumes each file holds a (layer, token, hidden) tensor; TODO confirm.
                # NOTE(review): .type(torch.FloatTensor) appears to yield a CPU float32
                # tensor, so device_name only affects torch.load's map_location.
                sub_q_pt = torch.load(
                    matches[0], weights_only=True, map_location=torch.device(device_name)
                ).type(torch.FloatTensor)[1 + layer_idx_start:, tok_loc[1], :]

                # Scale each row to unit L2 norm.
                sub_q_pt /= torch.linalg.norm(sub_q_pt, dim=1, keepdim=True)
                sub_pts.append(sub_q_pt)

            if not sub_pts:
                # torch.stack would raise on an empty list.
                print(f'MetaDataset: skipping {sub} (no hidden-state files found)')
                continue

            # Average across questions, then re-normalise each row.
            sub_avg_pt = torch.stack(sub_pts).mean(dim=0)
            sub_avg_pt /= torch.linalg.norm(sub_avg_pt, dim=1, keepdim=True)

            x_avg[sub] = sub_avg_pt
            y_avg[sub] = sub_phq9
            kept_subs.append(sub)

        self.avg_features = x_avg
        self.avg_labels = y_avg
        # Only subjects present in both dicts, so index-based slicing never hits a KeyError.
        self.subs = kept_subs


class MyDataset(Dataset):
    """Torch ``Dataset`` over a contiguous slice of a meta-dataset's subjects.

    Each item is ``(feature, label)``. ``feature`` is row ``layer_index`` of the
    subject's entry in ``dataset.avg_features``; ``label`` is the subject's entry in
    ``dataset.avg_labels``. Both are cast to float32. Items are shuffled once, at
    construction, using the global ``random`` state.

    Args:
        dataset: object with ``subs`` (sequence), ``avg_features`` and ``avg_labels``
            (dicts keyed by subject).
        i_start, i_end: slice bounds into ``dataset.subs``.
        layer_index: index selecting one row of each subject's feature tensor.
        device: kept for backward compatibility; casting no longer depends on it.
    """

    def __init__(self, dataset, i_start, i_end, layer_index, device):
        subs = list(dataset.subs[i_start:i_end])
        # Features were always float32, but labels were cast only on 'mps'. Cast both
        # unconditionally so a batch never mixes float32 inputs with float64 targets.
        features = [dataset.avg_features[sub][layer_index].type(torch.float32) for sub in subs]
        labels = [dataset.avg_labels[sub].type(torch.float32) for sub in subs]

        # Shuffle features, labels and subject ids together.
        tmp_pairs = list(zip(features, labels, subs))
        random.shuffle(tmp_pairs)
        if tmp_pairs:
            self.features, self.labels, self.subs = zip(*tmp_pairs)
        else:
            # zip(*[]) yields nothing, so the 3-way unpack above would fail.
            self.features, self.labels, self.subs = (), (), ()

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]


# %% Create datasets
for model_name in model_name_set:
    # Per-model configuration (project class). Its model_name_rp, model_name_sshort
    # and L are read here and inside prep_data / MetaDataset.
    sample_config = SampleLogitsConfig(
        paths,
        model_name=model_name,
        qs_name='phq9',
        instr_name='instr3',
        temp='',
        top_p='',
    )

    # Saved file name: '<q_case>[_zs]_<short model name>'.
    zs_tag = '_zs' if do_zscores else ''
    dataset_fname = f'{q_case}{zs_tag}_{sample_config.model_name_sshort}'

    # Module-level value read by MetaDataset, which keeps hidden-state rows from
    # index 1 + layer_idx_start onward.
    layer_idx_start = sample_config.L // 2

    # %
    if not createDataset:
        continue
    meta_dataset = MetaDataset(device_name, qs_list)
    if saveDataset:
        torch.save(meta_dataset, f'data/{exp_name}/{dataset_fname}.pt')
