import torch
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel, AutoModelForMaskedLM
from transformers import BertTokenizer, BertModel
import os
import sys
import pickle

from llm import load_llm, get_paths_from_string, get_left_pad, get_add_token
from data.dataset import BooIQDataset, CommonsenseQADataset, WinoGrandeDataset, NQOpenDataset, HaluEvalDataset, HateSpeechDataset, SquadDataset, CNN_DM_Dataset, XSUM_Dataset
from utils import compute_rouge, gpt_explanation_prompts, explanation_prompts, gpt_state_prompts, random_prompts
from torchmetrics.text.rouge import ROUGEScore
import together
import gc

# run on the GPU when CUDA is available, otherwise on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# seed the global PyTorch and NumPy random number generators with 0
torch.manual_seed(0)
np.random.seed(0)

class BoolExplanationDataset(torch.utils.data.Dataset):
    '''
    Datasets of elicited explanations for boolean questions with answers of "yes" or "no"
    Data: BooIQ, HaluEval, ToxicEval

    For each split ("train" / "test") the constructor exposes numpy arrays:
        <split>_data       - P("yes") for every (candidate answer, explanation prompt) input
        <split>_log_probs  - softmax over the [no, yes] answer logits, reshaped to (-1, 2)
                             (despite the name these are probabilities, not log-probabilities)
        <split>_labels     - 1 where argmax of <split>_log_probs equals the stored task label, else 0
        <split>_pre_confs  - P("yes") after the pre-confidence prompt
        <split>_post_confs - P("yes") after the post-confidence prompt, for answers ["yes", "no"]
        <split>_logits     - last-position logits after the bare question

    The arrays are cached as .npy files under
    ./data/<base_dataset>_outputs/<model_type>[_gpt|_gpt_state|_random]/ and the LLM is
    only loaded when the cache of either split is missing.
    '''

    # file stems of the arrays cached per split as "<split>_<stem>.npy"
    _ARRAY_NAMES = ("explanations", "labels", "log_probs", "pre_confs", "post_confs", "logits")

    def __init__(self, base_dataset, model_type, random=False, gpt_exp=False, gpt_state=False):
        '''
        base_dataset: "BooIQ", "HaluEval" or "ToxicEval" (only checked when outputs must be generated)
        model_type: model name handed to load_llm / get_left_pad / get_add_token
        random, gpt_exp, gpt_state: select the explanation prompt set and the cache folder
            suffix; precedence is gpt_exp, then gpt_state, then random

        Raises ValueError("Dataset not found") when outputs must be generated for an
        unknown base_dataset.
        '''

        self.base_dataset = base_dataset
        self.model = None
        self.model_type = model_type

        folder_path = "./data/" + base_dataset + "_outputs/" + model_type
        if gpt_exp:
            folder_path += "_gpt"
        elif gpt_state:
            folder_path += "_gpt_state"
        elif random:
            folder_path += "_random"

        train_subset = 5000
        test_subset = 1000

        # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
        os.makedirs(folder_path, exist_ok=True)

        path = folder_path + "/train_explanations.npy"
        test_path = folder_path + "/test_explanations.npy"
        if not (os.path.exists(path) and os.path.exists(test_path)):

            print("No data found at " + path)

            # validate the dataset name before paying for the model load
            dataset_classes = {
                "BooIQ": BooIQDataset,
                "HaluEval": HaluEvalDataset,
                "ToxicEval": HateSpeechDataset,
            }
            if base_dataset not in dataset_classes:
                raise ValueError("Dataset not found")
            dataset_cls = dataset_classes[base_dataset]

            self.model, self.tokenizer = load_llm(model_type)
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

            self.train_dataset = dataset_cls(split="train", tokenizer=self.tokenizer)
            self.test_dataset = dataset_cls(split="test", tokenizer=self.tokenizer)

            self.train_dataset.questions = self.train_dataset.questions[:train_subset]
            self.train_dataset.answers = self.train_dataset.answers[:train_subset]
            self.test_dataset.questions = self.test_dataset.questions[:test_subset]
            self.test_dataset.answers = self.test_dataset.answers[:test_subset]

            self.left_pad = get_left_pad(model_type)
            self.add_token = get_add_token(model_type)

            # current prompts to generate simple responses...
            self.pre_conf_prompt = "Will you answer this question correctly? [/INST]"
            self.post_conf_prompt = "[INST] Did you answer this question correctly? [/INST]"

            if gpt_exp:
                self.explanation_prompts = gpt_explanation_prompts()
            elif gpt_state:
                self.explanation_prompts = gpt_state_prompts()
            elif random:
                self.explanation_prompts = random_prompts()
            else:
                self.explanation_prompts = explanation_prompts()

        (self.train_data, self.train_labels, self.train_log_probs, self.train_pre_confs,
            self.train_post_confs, self.train_logits) = self._load_or_process(folder_path, "train")
        (self.test_data, self.test_labels, self.test_log_probs, self.test_pre_confs,
            self.test_post_confs, self.test_logits) = self._load_or_process(folder_path, "test")

        # drop the reference to the LLM once every output exists
        if self.model is not None:
            self.model = None
            gc.collect()

        # reshape log probs
        self.train_log_probs = self.train_log_probs.reshape(-1, 2)
        self.test_log_probs = self.test_log_probs.reshape(-1, 2)

        # convert labels from downstream task label to if model was correct
        self.train_labels = (np.argmax(self.train_log_probs, axis=1) == self.train_labels).astype(int)
        self.test_labels = (np.argmax(self.test_log_probs, axis=1) == self.test_labels).astype(int)

    def _load_or_process(self, folder_path, split):
        '''
        Return (data, labels, log_probs, pre_confs, post_confs, logits) for split, loading
        the cached .npy files when "<split>_explanations.npy" exists and otherwise
        computing them with process_data and saving them.
        '''
        prefix = folder_path + "/" + split + "_"
        if os.path.exists(prefix + "explanations.npy"):
            return tuple(np.load(prefix + name + ".npy") for name in self._ARRAY_NAMES)

        data, log_probs, labels, pre_confs, post_confs, logits = self.process_data(split)
        arrays = (data, labels, log_probs, pre_confs, post_confs, logits)
        # the explanations file is what marks a split as cached, so it is written last:
        # an interrupted save then triggers a recompute instead of a failing np.load
        for name, array in reversed(list(zip(self._ARRAY_NAMES, arrays))):
            np.save(prefix + name + ".npy", array)
        return arrays

    @staticmethod
    def _two_way_softmax(logits, first_token_id, second_token_id):
        '''
        Softmax over the logits of two vocabulary entries.

        logits is indexed along its last axis by the two index tensors; the result has the
        same leading axes and a final axis of size 2 ordered [first, second]. Half
        precision inputs are cast to float32 first (bfloat16 cannot be converted to numpy).
        '''
        pair = torch.cat([logits[..., first_token_id], logits[..., second_token_id]], dim=-1)
        if pair.dtype in (torch.float16, torch.bfloat16):
            pair = pair.float()
        return torch.nn.functional.softmax(pair, dim=-1)

    def _last_token_logits(self, text):
        '''
        Run the model on a single unpadded text and return the logits at the last
        sequence position as a CPU tensor (cast to float32 when the model emits half precision).
        '''
        input_ids = self.tokenizer.encode(text, return_tensors="pt").to(device)
        with torch.no_grad():
            logits = self.model(input_ids, return_dict=True).logits[0, -1, :].cpu()
        if logits.dtype in (torch.float16, torch.bfloat16):
            logits = logits.float()
        return logits

    def process_data(self, split):
        '''
        Query the model for every question of the given split ("train", anything else
        means test) and return numpy arrays
        (all_data, all_log_probs, all_labels, all_pre_confs, all_post_confs, all_logits).
        '''

        base_dataset = self.train_dataset if split == "train" else self.test_dataset
        base_questions = base_dataset.questions

        all_data = []
        all_log_probs = []
        all_pre_confs = []
        all_post_confs = []
        all_logits = []
        all_labels = base_dataset.answers

        # get indices of yes and no tokens -> possible answers to question
        yes_token = "yes"
        no_token = "no"

        # when add_token is set the encoding's first position is skipped
        # (presumably a BOS token the tokenizer prepends - confirm against get_add_token)
        token_pos = 1 if self.add_token else 0
        yes_token_id = self.tokenizer.encode(yes_token, return_tensors="pt")[:, token_pos]
        no_token_id = self.tokenizer.encode(no_token, return_tensors="pt")[:, token_pos]

        num_prompts = len(self.explanation_prompts)
        num_batch = 28

        # loop through questions
        for q in tqdm(base_questions, total=len(base_questions)):

            # get last token logits after question
            all_logits.append(self._last_token_logits(q))

            # get pre confidence score - append pre conf prompt to question
            # q[:-7] drops the last 7 characters of the question
            # (presumably a trailing "[/INST]" - confirm against the dataset prompt format)
            pre_logits = self._last_token_logits(q[:-7] + self.pre_conf_prompt)
            all_pre_confs.append(self._two_way_softmax(pre_logits, yes_token_id, no_token_id)[0].item())

            # get post confidence score - append post conf prompt after adding answer to question
            post_confs = []
            for ans in (yes_token, no_token):
                ans_logits = self._last_token_logits(q + " " + ans + " " + self.post_conf_prompt)
                post_confs.append(self._two_way_softmax(ans_logits, yes_token_id, no_token_id)[0].item())
            all_post_confs.append(post_confs)

            # get distribution over answers, ordered [no, yes]; kept at shape (1, 2)
            # so the saved array layout matches previously cached files
            answer_logits = self._last_token_logits(q + " ")
            all_log_probs.append(self._two_way_softmax(answer_logits, no_token_id, yes_token_id).unsqueeze(0))

            # get explanation responses by looping through by size num_batch
            # layout: each batch fills a contiguous slice ordered
            # [prompts after answer "yes"..., prompts after answer "no"...]
            prob_dist = np.zeros((2 * num_prompts,))
            for i in range(0, num_prompts, num_batch):
                batch_prompts = self.explanation_prompts[i:i + num_batch]
                exp_inputs = [q + " " + ans + " " + exp for ans in ["yes", "no"] for exp in batch_prompts]
                token_dict = self.tokenizer(exp_inputs, padding=True, return_tensors="pt")
                input_ids = token_dict.input_ids.to(device)

                # NOTE(review): the model is called without attention_mask on a padded batch;
                # with left padding the pad tokens are presumably attended to - confirm intended
                with torch.no_grad():
                    logits = self.model(input_ids, return_dict=True).logits

                if self.left_pad:
                    logits = logits[:, -1, :]
                else:
                    # index of the last attended token of every row
                    last_token_id = token_dict.attention_mask.sum(1) - 1
                    logits = logits[range(logits.shape[0]), last_token_id, :]

                # get probability of yes (w.r.t. distribution [no, yes])
                yes_prob = self._two_way_softmax(logits, no_token_id, yes_token_id)[:, 1]
                prob_dist[2 * i: 2 * i + len(exp_inputs)] = yes_prob.cpu().numpy()

                # del from memory
                del input_ids
                del logits
                gc.collect()

            all_data.append(prob_dist)

        all_data = np.array(all_data)
        all_log_probs = np.array(all_log_probs)
        all_pre_confs = np.array(all_pre_confs)
        all_post_confs = np.array(all_post_confs)
        all_logits = np.array(all_logits)
        all_labels = np.array(all_labels)
        return all_data, all_log_probs, all_labels, all_pre_confs, all_post_confs, all_logits

    def _active_split(self):
        '''
        Return the (data, labels) pair served by __len__ / __getitem__: the attributes
        `data` and `labels` when both have been set on the instance, otherwise the train split.
        '''
        data = getattr(self, "data", None)
        labels = getattr(self, "labels", None)
        if data is None or labels is None:
            # nothing in this class assigns self.data / self.labels, so default to train
            return self.train_data, self.train_labels
        return data, labels

    def __len__(self):
        '''Number of examples in the active split (train unless data/labels were set).'''
        return len(self._active_split()[0])

    def __getitem__(self, idx):
        '''Return (features, correctness label) of example idx in the active split.'''
        data, labels = self._active_split()
        return data[idx], labels[idx]

class MCQExplanationDataset(torch.utils.data.Dataset):
    '''
    Datasets for multiple choice questions with explanations
    Data: CommonsenseQA

    For each split ("train" / "test") the constructor exposes numpy arrays:
        <split>_data       - P("yes") for every (answer option, explanation prompt) input
        <split>_log_probs  - softmax over the answer-option logits
                             (despite the name these are probabilities, not log-probabilities)
        <split>_labels     - 1 where argmax of <split>_log_probs equals the stored task label, else 0
        <split>_pre_confs  - P("yes") after the pre-confidence prompt
        <split>_post_confs - P("yes") after the post-confidence prompt, one value per answer option
        <split>_logits     - last-position logits after the bare question

    The arrays are cached as .npy files under
    ./data/<base_dataset>_outputs/<model_type>[_gpt|_gpt_state|_random]/ and the LLM is
    only loaded when the cache of either split is missing.
    '''

    # file stems of the arrays cached per split as "<split>_<stem>.npy"
    _ARRAY_NAMES = ("explanations", "labels", "log_probs", "pre_confs", "post_confs", "logits")

    def __init__(self, base_dataset, model_type, random=False, gpt_exp=False, gpt_state=False):
        '''
        base_dataset: "CommonsenseQA" (only checked when outputs must be generated)
        model_type: model name handed to load_llm / get_left_pad / get_add_token;
            "llama-70b" limits the train subset to 1000 questions instead of 5000
        random, gpt_exp, gpt_state: select the explanation prompt set and the cache folder
            suffix; precedence is gpt_exp, then gpt_state, then random

        Raises ValueError("Dataset not found") when outputs must be generated for an
        unknown base_dataset.
        '''

        self.base_dataset = base_dataset
        self.model = None

        folder_path = "./data/" + base_dataset + "_outputs/" + model_type
        if gpt_exp:
            folder_path += "_gpt"
        elif gpt_state:
            folder_path += "_gpt_state"
        elif random:
            folder_path += "_random"

        self.random = random

        # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
        os.makedirs(folder_path, exist_ok=True)

        train_subset = 1000 if model_type == "llama-70b" else 5000
        test_subset = 1000

        self.model_type = model_type
        self.left_pad = get_left_pad(model_type)
        self.add_token = get_add_token(model_type)
        self.gpt_exp = gpt_exp

        path = folder_path + "/train_explanations.npy"
        test_path = folder_path + "/test_explanations.npy"
        if not (os.path.exists(path) and os.path.exists(test_path)):

            print("No data found at " + path)

            # validate the dataset name before paying for the model load
            if base_dataset != "CommonsenseQA":
                raise ValueError("Dataset not found")

            self.model, self.tokenizer = load_llm(model_type)
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

            self.train_dataset = CommonsenseQADataset(split="train", tokenizer=self.tokenizer)
            self.test_dataset = CommonsenseQADataset(split="test", tokenizer=self.tokenizer)
            self.options = self.train_dataset.options # abcde

            self.train_dataset.questions = self.train_dataset.questions[:train_subset]
            self.train_dataset.answers = self.train_dataset.answers[:train_subset]
            self.test_dataset.questions = self.test_dataset.questions[:test_subset]
            self.test_dataset.answers = self.test_dataset.answers[:test_subset]

            # current prompts to generate simple responses...
            self.pre_conf_prompt = "Will you answer this question correctly? [/INST]"
            self.post_conf_prompt = "[INST] Did you answer this question correctly? [/INST]"

            if gpt_exp:
                self.explanation_prompts = gpt_explanation_prompts()
            elif gpt_state:
                self.explanation_prompts = gpt_state_prompts()
            elif self.random:
                self.explanation_prompts = random_prompts()
            else:
                self.explanation_prompts = explanation_prompts()

        (self.train_data, self.train_labels, self.train_log_probs, self.train_pre_confs,
            self.train_post_confs, self.train_logits) = self._load_or_process(folder_path, "train")
        (self.test_data, self.test_labels, self.test_log_probs, self.test_pre_confs,
            self.test_post_confs, self.test_logits) = self._load_or_process(folder_path, "test")

        # drop the reference to the LLM once every output exists
        if self.model is not None:
            self.model = None
            gc.collect()

        # convert labels from downstream task label to if model was correct
        self.train_labels = (np.argmax(self.train_log_probs, axis=1) == self.train_labels).astype(int)
        self.test_labels = (np.argmax(self.test_log_probs, axis=1) == self.test_labels).astype(int)

    def _load_or_process(self, folder_path, split):
        '''
        Return (data, labels, log_probs, pre_confs, post_confs, logits) for split, loading
        the cached .npy files when "<split>_explanations.npy" exists and otherwise
        computing them with process_data and saving them.
        '''
        prefix = folder_path + "/" + split + "_"
        if os.path.exists(prefix + "explanations.npy"):
            return tuple(np.load(prefix + name + ".npy") for name in self._ARRAY_NAMES)

        data, log_probs, labels, pre_confs, post_confs, logits = self.process_data(split)
        arrays = (data, labels, log_probs, pre_confs, post_confs, logits)
        # the explanations file is what marks a split as cached, so it is written last:
        # an interrupted save then triggers a recompute instead of a failing np.load
        for name, array in reversed(list(zip(self._ARRAY_NAMES, arrays))):
            np.save(prefix + name + ".npy", array)
        return arrays

    @staticmethod
    def _two_way_softmax(logits, first_token_id, second_token_id):
        '''
        Softmax over the logits of two vocabulary entries.

        logits is indexed along its last axis by the two index tensors; the result has the
        same leading axes and a final axis of size 2 ordered [first, second]. Half
        precision inputs are cast to float32 first (bfloat16 cannot be converted to numpy).
        The batch axis is never squeezed, so a batch of one row keeps its shape.
        '''
        pair = torch.cat([logits[..., first_token_id], logits[..., second_token_id]], dim=-1)
        if pair.dtype in (torch.float16, torch.bfloat16):
            pair = pair.float()
        return torch.nn.functional.softmax(pair, dim=-1)

    def _last_token_rows(self, logits, attention_mask):
        '''
        Reduce batched logits (batch, position, vocab) to one row per input: the final
        position when self.left_pad is set, otherwise the last attended position
        according to attention_mask.
        '''
        if self.left_pad:
            return logits[:, -1, :]
        last_token = attention_mask.sum(1) - 1
        return logits[range(logits.shape[0]), last_token, :]

    def process_data(self, split="train"):
        '''
        Query the model for every question of the given split ("train", anything else
        means test) and return numpy arrays
        (all_data, all_log_probs, all_labels, all_pre_confs, all_post_confs, all_logits).
        '''

        base_dataset = self.train_dataset if split == "train" else self.test_dataset
        base_questions = base_dataset.questions

        all_data = []
        all_log_probs = []
        all_pre_confs = []
        all_post_confs = []
        all_logits = []
        all_labels = base_dataset.answers

        # when add_token is set the encoding's first position is skipped
        # (presumably a BOS token the tokenizer prepends - confirm against get_add_token)
        token_pos = 1 if self.add_token else 0

        # get answer ids in vocab
        answer_tokens = self.options
        answer_token_ids = [
            self.tokenizer.encode(token, return_tensors="pt")[:, token_pos] for token in answer_tokens
        ]

        # get indices of yes and no tokens -> answers to the confidence / explanation prompts
        yes_token_id = self.tokenizer.encode("yes", return_tensors="pt")[:, token_pos]
        no_token_id = self.tokenizer.encode("no", return_tensors="pt")[:, token_pos]

        # loop through questions
        for q in tqdm(base_questions, total=len(base_questions)):

            # last-position logits after the bare question
            input_ids = self.tokenizer.encode(q, return_tensors="pt").to(device)
            with torch.no_grad():
                log_probs = self.model(input_ids, return_dict=True).logits[0, -1, :].cpu()

            # convert to float if half
            if log_probs.dtype == torch.float16 or log_probs.dtype == torch.bfloat16:
                log_probs = log_probs.float()

            # store full set of last layer logits
            all_logits.append(log_probs)

            # store model probability prediction over the answer options [A, B, C, D, E]
            option_logits = torch.cat([log_probs[token_id] for token_id in answer_token_ids], dim=0)
            all_log_probs.append(torch.nn.functional.softmax(option_logits, dim=0))

            # compute model pre confidence
            # q[:-7] drops the last 7 characters of the question
            # (presumably a trailing "[/INST]" - confirm against the dataset prompt format)
            pre_conf_input = q[:-7] + self.pre_conf_prompt
            tokenized_pre_conf = self.tokenizer(pre_conf_input, padding=True, return_tensors="pt")
            with torch.no_grad():
                pre_conf_logits = self.model(
                    tokenized_pre_conf.input_ids.to(device), return_dict=True
                ).logits[0, -1, :].cpu()
            pre_conf_dist = self._two_way_softmax(pre_conf_logits, yes_token_id, no_token_id)
            all_pre_confs.append(pre_conf_dist[0].numpy())

            # compute model post confidence, one input per answer option
            # NOTE(review): padded batches are run without attention_mask here and below;
            # with left padding the pad tokens are presumably attended to - confirm intended
            post_conf_input = [q + token + " " + self.post_conf_prompt for token in answer_tokens]
            tokenized_post_conf = self.tokenizer(post_conf_input, padding=True, return_tensors="pt")
            with torch.no_grad():
                post_conf_logits = self.model(
                    tokenized_post_conf.input_ids.to(device), return_dict=True
                ).logits
            # select the last position before moving to the CPU / casting, so only
            # (batch, vocab) values are converted instead of the full sequence
            post_conf_logits = self._last_token_rows(post_conf_logits, tokenized_post_conf.attention_mask).cpu()
            post_conf_dist = self._two_way_softmax(post_conf_logits, yes_token_id, no_token_id)
            all_post_confs.append(post_conf_dist[:, 0].numpy())

            # stack and process in parallel: every option combined with every explanation prompt
            exp_input = [q + token + " " + exp for token in answer_tokens for exp in self.explanation_prompts]
            tokenized = self.tokenizer(exp_input, padding=True, return_tensors="pt")
            input_ids = tokenized.input_ids.to(device)

            if self.gpt_exp:
                # process in batches
                batch_size = 120
                exp_dist = np.zeros((len(exp_input),))
                for i in range(0, len(input_ids), batch_size):
                    with torch.no_grad():
                        batch_logits = self.model(input_ids[i:i+batch_size], return_dict=True).logits
                    batch_logits = self._last_token_rows(batch_logits, tokenized.attention_mask[i:i+batch_size])
                    yes_prob = self._two_way_softmax(batch_logits, yes_token_id, no_token_id)[:, 0]
                    exp_dist[i:i+batch_size] = yes_prob.cpu().numpy()

            else:
                # single forward pass over all explanation inputs
                with torch.no_grad():
                    exp_logits = self.model(input_ids, return_dict=True).logits
                exp_logits = self._last_token_rows(exp_logits, tokenized.attention_mask).cpu()
                exp_dist = self._two_way_softmax(exp_logits, yes_token_id, no_token_id)[:, 0].numpy()

            all_data.append(exp_dist)

        all_data = np.array(all_data)
        all_log_probs = np.array(all_log_probs)
        all_pre_confs = np.array(all_pre_confs)
        all_post_confs = np.array(all_post_confs)
        all_logits = np.array(all_logits)
        all_labels = np.array(all_labels)
        print("Finished processing explanations")
        return all_data, all_log_probs, all_labels, all_pre_confs, all_post_confs, all_logits

class WinoGrandeExplanationDataset(torch.utils.data.Dataset):
    '''
    Dataset of elicited explanations for WinoGrande questions with two candidate options.

    For every question the LLM is queried for
      * its next-token logits and the softmax over the letter tokens "A" / "B",
      * a pre-answer and a post-answer "yes"/"no" confidence,
      * P("yes") for every explanation prompt, once per candidate option.

    All arrays are cached as .npy files below ./data/WinoGrande_outputs/ and are
    reloaded on later runs; the LLM is only loaded when a cache file is missing.
    After construction, train_labels / test_labels hold 1 where the model's
    arg-max option equals the stored label and 0 otherwise.
    '''

    def __init__(self, model_type, random=False, gpt_exp=False, gpt_state=False):
        '''
        Args:
            model_type: identifier forwarded to load_llm / get_left_pad / get_add_token;
                also used as the name of the cache folder.
            random: use random_prompts() as explanation prompts.
            gpt_exp: use gpt_explanation_prompts() as explanation prompts.
            gpt_state: use gpt_state_prompts() as explanation prompts.
        '''

        self.model = None

        self.model_type = model_type
        self.left_pad = get_left_pad(model_type)
        self.add_token = get_add_token(model_type)
        self.gpt_exp = gpt_exp
        self.gpt_state = gpt_state
        self.random = random

        # subset data
        train_data = 5000
        test_data = 1000

        # NOTE(review): the random variant is cached in a sub-directory ("/random")
        # while the other variants use a name suffix, and `random` is tested first
        # here but last when the prompts are selected below. Kept as-is so that
        # existing caches stay valid; only matters when several flags are set.
        folder_path = "./data/WinoGrande_outputs/" + model_type
        if random:
            folder_path += "/random"
        elif gpt_exp:
            folder_path += "_gpt"
        elif gpt_state:
            folder_path += "_gpt_state"

        os.makedirs(folder_path, exist_ok=True)

        path = folder_path + "/train_explanations.npy"
        cache_complete = os.path.exists(path) and os.path.exists(folder_path + "/test_explanations.npy")

        if not cache_complete:

            print("No data found at " + path)
            self.model, self.tokenizer = load_llm(model_type)
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

            self.train_dataset = WinoGrandeDataset(split="train", tokenizer=self.tokenizer)
            self.test_dataset = WinoGrandeDataset(split="validation", tokenizer=self.tokenizer) # as test set has no label?

            # keep only the first train_data / test_data examples of each split
            for dataset, limit in ((self.train_dataset, train_data), (self.test_dataset, test_data)):
                dataset.questions = dataset.questions[:limit]
                dataset.options1 = dataset.options1[:limit]
                dataset.options2 = dataset.options2[:limit]
                dataset.answers = dataset.answers[:limit]

            # current prompts to generate simple responses...
            self.pre_conf_prompt = "Will you answer this question correctly? [/INST]"
            self.post_conf_prompt = "[INST] Did you answer this question correctly? [/INST]"

            if gpt_exp:
                self.explanation_prompts = gpt_explanation_prompts()
            elif gpt_state:
                self.explanation_prompts = gpt_state_prompts()
            elif random:
                self.explanation_prompts = random_prompts()
            else:
                self.explanation_prompts = explanation_prompts()

        self.train_data, self.train_labels, self.train_log_probs, self.train_pre_confs, \
            self.train_post_confs, self.train_logits = self._load_or_compute(folder_path, "train")

        self.test_data, self.test_labels, self.test_log_probs, self.test_pre_confs, \
            self.test_post_confs, self.test_logits = self._load_or_compute(folder_path, "test")

        # drop the model; the attribute is kept (as None) so later reads do not fail
        if self.model is not None:
            self.model = None
            gc.collect()

        print(self.train_log_probs.shape, self.test_log_probs.shape)

        # convert labels from downstream task label to if model was correct
        # (assumes the stored labels index the options in the same order as the
        # columns of *_log_probs -- TODO confirm against WinoGrandeDataset)
        model_preds = np.argmax(self.train_log_probs, axis=1)
        self.train_labels = (model_preds == self.train_labels).astype(int)

        model_preds = np.argmax(self.test_log_probs, axis=1)
        self.test_labels = (model_preds == self.test_labels).astype(int)

    def _load_or_compute(self, folder_path, split):
        '''
        Return the arrays of `split` ("train" or "test"), loading them from
        folder_path when "<split>_explanations.npy" exists and otherwise computing
        them with process_data and saving them there.

        Returns:
            (data, labels, log_probs, pre_confs, post_confs, logits)
        '''
        prefix = folder_path + "/" + split + "_"

        if os.path.exists(prefix + "explanations.npy"):
            data = np.load(prefix + "explanations.npy")
            labels = np.load(prefix + "labels.npy")
            log_probs = np.load(prefix + "log_probs.npy")
            pre_confs = np.load(prefix + "pre_confs.npy")
            post_confs = np.load(prefix + "post_confs.npy")
            logits = np.load(prefix + "logits.npy")
        else:
            data, log_probs, labels, pre_confs, post_confs, logits = self.process_data(split)
            np.save(prefix + "explanations.npy", data)
            np.save(prefix + "labels.npy", labels)
            np.save(prefix + "log_probs.npy", log_probs)
            np.save(prefix + "pre_confs.npy", pre_confs)
            np.save(prefix + "post_confs.npy", post_confs)
            np.save(prefix + "logits.npy", logits)

        return data, labels, log_probs, pre_confs, post_confs, logits

    def _first_token_id(self, text):
        '''
        Return the vocabulary id (int) used to score `text`.

        When add_token is set, position 0 of the encoding is skipped and position 1
        is used (presumably the tokenizer prepends a special token -- TODO confirm).
        '''
        ids = self.tokenizer.encode(text, return_tensors="pt")
        return int(ids[0, 1 if self.add_token else 0])

    def _next_token_logits(self, text):
        '''Return the float32 CPU logits at the last position of a single prompt.'''
        input_ids = self.tokenizer.encode(text, return_tensors="pt").to(device)
        with torch.no_grad():
            logits = self.model(input_ids, return_dict=True).logits[0, -1, :]
        # cast before leaving torch: numpy has no bfloat16
        return logits.float().cpu()

    def _batched_next_token_logits(self, texts):
        '''
        Return float32 CPU logits of shape (len(texts), vocab) taken at the last
        non-padded position of every prompt in `texts`.
        '''
        tokenized = self.tokenizer(texts, padding=True, return_tensors="pt")
        input_ids = tokenized.input_ids.to(device)

        # NOTE(review): the attention mask is not passed to the model (unchanged
        # from the original); with left padding the pad tokens are then attended
        # to -- confirm this is intended.
        with torch.no_grad():
            logits = self.model(input_ids, return_dict=True).logits

        if self.left_pad:
            # padding sits in front, so the last column is the last real token
            last = logits[:, -1, :]
        else:
            # padding sits behind, so locate the last real token through the mask
            rows = torch.arange(logits.shape[0], device=logits.device)
            last_token = (tokenized.attention_mask.sum(1) - 1).to(logits.device)
            last = logits[rows, last_token, :]

        return last.float().cpu()

    @staticmethod
    def _yes_probability(logits, yes_token_id, no_token_id):
        '''
        Return P(yes) under a softmax restricted to the two given token ids.

        `logits` may be a single vocabulary vector (returns a 0-d tensor) or a
        (batch, vocab) matrix (returns a (batch,) tensor, also for batch == 1).
        The computation is done in float32.
        '''
        logits = logits.float()
        pair = torch.stack([logits[..., yes_token_id], logits[..., no_token_id]], dim=-1)
        return torch.nn.functional.softmax(pair, dim=-1)[..., 0]

    def process_data(self, split="train"):
        '''
        Query the model for every question of the chosen split.

        Args:
            split: "train" selects self.train_dataset, anything else self.test_dataset.

        Returns:
            all_data: (N, 2 * P) P("yes") per explanation prompt, first for
                option 1 and then for option 2 (P = number of explanation prompts).
            all_log_probs: (N, 2) softmax over the "A" / "B" token logits.
            all_labels: the dataset's answers as an array.
            all_pre_confs: (N,) P("yes") for the pre-answer confidence prompt.
            all_post_confs: (N, 2) P("yes") for the post-answer confidence prompt,
                one column per option.
            all_logits: (N, vocab) float32 next-token logits for the question.
        '''

        base_dataset = self.train_dataset if split == "train" else self.test_dataset
        base_questions = base_dataset.questions

        # get answer ids in vocab for yes / no questions to explanations and pre/post conf
        yes_token_id = self._first_token_id("yes")
        no_token_id = self._first_token_id("no")

        # the answer distribution is read off the letter tokens "A" and "B";
        # these ids do not depend on the question, so encode them once
        option1_token_id = self._first_token_id("A")
        option2_token_id = self._first_token_id("B")

        all_data = []
        all_log_probs = []
        all_pre_confs = []
        all_post_confs = []
        all_logits = []

        # loop through questions
        for q_ind, q in tqdm(enumerate(base_questions), total=len(base_questions)):

            options = (base_dataset.options1[q_ind], base_dataset.options2[q_ind])

            # store full set of logits and the distribution over the two answers
            answer_logits = self._next_token_logits(q)
            all_logits.append(answer_logits.numpy())

            option_logits = torch.stack([answer_logits[option1_token_id], answer_logits[option2_token_id]])
            all_log_probs.append(torch.nn.functional.softmax(option_logits, dim=0).numpy())

            # computing model pre confidences
            # q[:-7] drops the last 7 characters of the question before the prompt
            # is appended (presumably a trailing "[/INST]" tag -- TODO confirm)
            pre_logits = self._next_token_logits(q[:-7] + self.pre_conf_prompt)
            all_pre_confs.append(self._yes_probability(pre_logits, yes_token_id, no_token_id).numpy())

            # computing model post confidences, once per candidate option
            post_conf = []
            for option in options:
                post_logits = self._next_token_logits(q + option + " " + self.post_conf_prompt)
                post_conf.append(self._yes_probability(post_logits, yes_token_id, no_token_id).numpy())
            all_post_confs.append(np.array(post_conf))

            # append every explanation prompt to each option and score them in one batch
            exp_probs = []
            for option in options:
                exp_inputs = [q + option + " " + exp for exp in self.explanation_prompts]
                exp_logits = self._batched_next_token_logits(exp_inputs)
                exp_probs.append(self._yes_probability(exp_logits, yes_token_id, no_token_id).numpy())

            # add to data
            all_data.append(np.concatenate(exp_probs))

        # convert to numpy arrays
        all_data = np.array(all_data)
        all_log_probs = np.array(all_log_probs)
        all_pre_confs = np.array(all_pre_confs)
        all_post_confs = np.array(all_post_confs)
        all_logits = np.array(all_logits)
        all_labels = np.array(base_dataset.answers)
        return all_data, all_log_probs, all_labels, all_pre_confs, all_post_confs, all_logits

class OpenEndedExplanationDataset(torch.utils.data.Dataset):
    '''
    Dataset of elicited explanations for open-ended questions (NQOpen).

    For every question the LLM greedily generates an answer, and the following is
    recorded: whether that answer is one of the reference answers, the next-token
    logits for the question, a pre-answer and a post-answer "yes"/"no" confidence,
    the summed log-probability of the generated answer, and P("yes") for every
    explanation prompt appended to question + generated answer.

    All arrays are cached as .npy files below ./data/NQOpen_outputs/ and are
    reloaded on later runs; the LLM is only loaded when a cache file is missing.
    '''

    def __init__(self, model_type, gpt_exp=False, gpt_state=False, random=False):
        '''
        Args:
            model_type: identifier forwarded to load_llm / get_left_pad / get_add_token;
                also used as the name of the cache folder.
            gpt_exp: use gpt_explanation_prompts() (cache suffix "_gpt").
            gpt_state: use gpt_state_prompts() (cache suffix "_gpt_state").
            random: use random_prompts() (cache suffix "_random").
        '''

        # defined on every path so that reading self.model never fails
        self.model = None

        self.model_type = model_type
        self.left_pad = get_left_pad(model_type)
        self.add_token = get_add_token(model_type)

        # check if path exists
        folder_path = "./data/NQOpen_outputs/" + model_type
        if gpt_exp:
            folder_path += "_gpt"
        elif gpt_state:
            folder_path += "_gpt_state"
        elif random:
            folder_path += "_random"

        path = folder_path + "/train_explanations.npy"

        if not os.path.exists(path) or not os.path.exists(folder_path + "/test_explanations.npy"):
            print("No data found at " + path)

            self.model, self.tokenizer = load_llm(model_type)
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

            self.train_dataset = NQOpenDataset(split="train", tokenizer=self.tokenizer) # load for in-context examples
            self.test_dataset = NQOpenDataset(split="validation", tokenizer=self.tokenizer)

            # the first num_context training questions (with their first reference
            # answer) are prepended to every prompt as in-context examples
            num_context = 2
            self.context_examples = ""
            for i in range(num_context):
                self.context_examples += self.train_dataset.questions[i] + " " + self.train_dataset.answers[i][0] + "\n"

            # don't double use context examples, then subset train_data
            num_train = 5000
            self.train_dataset.questions = self.train_dataset.questions[num_context:][:num_train]
            self.train_dataset.answers = self.train_dataset.answers[num_context:][:num_train]

            # subset test_data
            num_test = 1000
            self.test_dataset.questions = self.test_dataset.questions[:num_test]
            self.test_dataset.answers = self.test_dataset.answers[:num_test]

            # current prompts to generate simple responses...
            self.pre_conf_prompt = "Will you answer this question correctly? [/INST]"
            self.post_conf_prompt = "[INST] Did you answer this question correctly? [/INST]"

            if gpt_exp:
                self.explanation_prompts = gpt_explanation_prompts()
            elif gpt_state:
                self.explanation_prompts = gpt_state_prompts()
            elif random:
                self.explanation_prompts = random_prompts()
            else:
                self.explanation_prompts = explanation_prompts()

            # check if path exists, otherwise make
            if not os.path.exists(folder_path):
                print("Making dir", folder_path)
                os.makedirs(folder_path)

        else:
            print("Loading data")

        # each split is loaded from its cache when present and computed otherwise
        self.train_data, self.train_labels, self.train_log_probs, \
            self.train_pre_confs, self.train_post_confs, self.train_logits = self._load_or_compute(folder_path, "train")

        self.test_data, self.test_labels, self.test_log_probs, \
            self.test_pre_confs, self.test_post_confs, self.test_logits = self._load_or_compute(folder_path, "test")

        # drop the model; the attribute is kept (as None) so later reads do not fail
        if self.model is not None:
            self.model = None
            gc.collect()

    def _load_or_compute(self, folder_path, split):
        '''
        Return the arrays of `split` ("train" or "test"), loading them from
        folder_path when "<split>_explanations.npy" exists and otherwise computing
        them with process_data and saving them there.

        Returns:
            (data, labels, log_probs, pre_confs, post_confs, logits)
        '''
        prefix = folder_path + "/" + split + "_"

        if os.path.exists(prefix + "explanations.npy"):
            data = np.load(prefix + "explanations.npy")
            labels = np.load(prefix + "labels.npy")
            log_probs = np.load(prefix + "log_probs.npy")
            pre_confs = np.load(prefix + "pre_confs.npy")
            post_confs = np.load(prefix + "post_confs.npy")
            logits = np.load(prefix + "logits.npy")
        else:
            data, labels, log_probs, pre_confs, post_confs, logits = self.process_data(split)

            # save result
            np.save(prefix + "explanations.npy", data)
            np.save(prefix + "labels.npy", labels)
            np.save(prefix + "log_probs.npy", log_probs)
            np.save(prefix + "pre_confs.npy", pre_confs)
            np.save(prefix + "post_confs.npy", post_confs)
            np.save(prefix + "logits.npy", logits)

        return data, labels, log_probs, pre_confs, post_confs, logits

    def _first_token_id(self, text):
        '''
        Return the vocabulary id (int) used to score `text`.

        When add_token is set, position 0 of the encoding is skipped and position 1
        is used (presumably the tokenizer prepends a special token -- TODO confirm).
        '''
        ids = self.tokenizer.encode(text, return_tensors="pt")
        return int(ids[0, 1 if self.add_token else 0])

    def _next_token_logits(self, input_ids):
        '''Return the float32 CPU logits at the last position of one tokenized prompt.'''
        with torch.no_grad():
            logits = self.model(input_ids, return_dict=True).logits[0, -1, :]
        # cast before leaving torch: numpy has no bfloat16
        return logits.float().cpu()

    def _batched_next_token_logits(self, texts):
        '''
        Return float32 CPU logits of shape (len(texts), vocab) taken at the last
        non-padded position of every prompt in `texts`.
        '''
        token_dict = self.tokenizer(texts, padding=True, return_tensors="pt")
        input_ids = token_dict.input_ids.to(device)

        # NOTE(review): the attention mask is not passed to the model (unchanged
        # from the original); with left padding the pad tokens are then attended
        # to -- confirm this is intended.
        with torch.no_grad():
            logits = self.model(input_ids, return_dict=True).logits

        if self.left_pad:
            # padding sits in front, so the last column is the last real token
            last = logits[:, -1, :]
        else:
            # padding sits behind, so locate the last real token through the mask
            rows = torch.arange(logits.shape[0], device=logits.device)
            last_token_id = (token_dict.attention_mask.sum(1) - 1).to(logits.device)
            last = logits[rows, last_token_id, :]

        return last.float().cpu()

    @staticmethod
    def _yes_probability(logits, yes_token_id, no_token_id):
        '''
        Return P(yes) under a softmax restricted to the two given token ids.

        `logits` may be a single vocabulary vector (returns a 0-d tensor) or a
        (batch, vocab) matrix (returns a (batch,) tensor, also for batch == 1).
        The computation is done in float32.
        '''
        logits = logits.float()
        pair = torch.stack([logits[..., yes_token_id], logits[..., no_token_id]], dim=-1)
        return torch.nn.functional.softmax(pair, dim=-1)[..., 0]

    def process_data(self, split):
        '''
        Query the model for every question of the chosen split.

        Args:
            split: "train" selects self.train_dataset, anything else self.test_dataset.

        Returns:
            all_data: (N, P) P("yes") for each of the P explanation prompts.
            all_labels: (N,) 1 when the stripped greedy generation is exactly one of
                the reference answers, else 0.
            model_log_probs: (N,) summed log-probability of the generated answer tokens.
            pre_confs: (N, 1) P("yes") for the pre-answer confidence prompt.
            post_confs: (N, 1) P("yes") for the post-answer confidence prompt.
            all_logits: (N, vocab) float32 next-token logits for the question.
        '''

        # get ids of yes and no token - used later
        yes_token_id = self._first_token_id("yes")
        no_token_id = self._first_token_id("no")

        all_data = []
        all_labels = []
        model_log_probs = []
        pre_confs = []
        post_confs = []
        all_logits = []

        base_dataset = self.train_dataset if split == "train" else self.test_dataset

        # loop through questions
        for q_ind, q in tqdm(enumerate(base_dataset.questions), total=len(base_dataset.questions)):

            answers = base_dataset.answers[q_ind] # list of potential open ended answers

            # the longest tokenized reference answer bounds the generation length
            max_len = self.tokenizer(answers, padding=True, return_tensors="pt").input_ids.shape[1]

            input_ids = self.tokenizer.encode(self.context_examples + q, return_tensors="pt").to(device)
            q_len = input_ids.shape[1]

            # get highest probability generation from model
            with torch.no_grad():
                generated = self.model.generate(input_ids, max_length=q_len + max_len, num_return_sequences=1, do_sample=False)
            output = self.tokenizer.decode(generated[0, q_len:], skip_special_tokens=True).strip()

            # check if output is in the answers
            all_labels.append(1 if output in answers else 0)

            # get last layer logits
            all_logits.append(self._next_token_logits(input_ids).numpy())

            # get pre confidence
            # q[:-7] drops the last 7 characters of the question before the prompt
            # is appended (presumably a trailing "[/INST]" tag -- TODO confirm)
            inputs = self.context_examples + q[:-7] + self.pre_conf_prompt
            pre_logits = self._next_token_logits(self.tokenizer(inputs, return_tensors="pt").input_ids.to(device))
            pre_confs.append(self._yes_probability(pre_logits, yes_token_id, no_token_id).numpy().flatten())

            # get post confidence from its generated answer
            # NOTE(review): question, answer and prompt are joined without
            # separators here, unlike the " "-joined prompt used below -- kept
            # unchanged so that results stay comparable; confirm it is intended.
            inputs = self.context_examples + q + output + self.post_conf_prompt
            post_logits = self._next_token_logits(self.tokenizer(inputs, return_tensors="pt").input_ids.to(device))
            post_confs.append(self._yes_probability(post_logits, yes_token_id, no_token_id).numpy().flatten())

            # get model probabilities of generated answer
            inputs = self.context_examples + q + " " + output
            answer_ids = self.tokenizer(inputs, padding=True, return_tensors="pt").input_ids

            with torch.no_grad():
                seq_logits = self.model(answer_ids.to(device), return_dict=True).logits[0]

            # row t of the logits predicts token t + 1, hence the shift by one;
            # log_softmax avoids the -inf that log(softmax(.)) gives on underflow
            step_log_probs = torch.nn.functional.log_softmax(seq_logits[q_len - 1: -1, :].float(), dim=1).cpu()
            output_tokens = answer_ids[0, q_len:]
            log_probs = step_log_probs.gather(1, output_tokens.unsqueeze(1)).sum().item()
            model_log_probs.append(log_probs)

            # release the sequence logits before the (larger) explanation batch
            del seq_logits, step_log_probs

            # get explanation responses
            exp_inputs = [inputs + " " + exp for exp in self.explanation_prompts]
            exp_logits = self._batched_next_token_logits(exp_inputs)

            # get probability of yes (w.r.t. distribution [yes, no])
            prob_dist = self._yes_probability(exp_logits, yes_token_id, no_token_id).numpy()

            # store results
            all_data.append(prob_dist.reshape(-1, len(self.explanation_prompts)))
            del exp_logits
            gc.collect()

        if all_data:
            all_data = np.concatenate(all_data, axis=0)
        else:
            # np.concatenate rejects an empty list; keep the (N, P) layout for N == 0
            all_data = np.empty((0, len(self.explanation_prompts)))
        all_labels = np.array(all_labels)
        model_log_probs = np.array(model_log_probs)
        pre_confs = np.array(pre_confs)
        post_confs = np.array(post_confs)
        all_logits = np.array(all_logits)

        return all_data, all_labels, model_log_probs, pre_confs, post_confs, all_logits
         
class SquadExplanationDataset(torch.utils.data.Dataset):
    '''
    Dataset of elicited explanations for SQuAD questions.

    For each question the LLM generates a greedy answer, which is compared with the
    reference answer, and a set of follow-up ("explanation") prompts is appended to the
    question/answer text to record the model's probability of replying "yes" rather
    than "no". Results are cached as .npy files under ./data/squad_outputs/ and are
    loaded from there on later runs; the LLM is only loaded when a cache file is missing.

    Args:
        model_type: model identifier forwarded to load_llm, get_left_pad and get_add_token
        gpt_exp: use gpt_explanation_prompts() (cache folder suffix "_gpt_exp")
        gpt_state: use gpt_state_prompts() (cache folder suffix "_gpt_state")
        random: use random_prompts() (cache folder suffix "_random")
        The three flags are checked in this order and the first one set wins; with none
        set, explanation_prompts() is used.

    Attributes set for split in {train, test}:
        <split>_data: P(yes) per explanation prompt, shape (num_questions, num_prompts)
        <split>_labels: 1 if the reference answer occurs in the generated answer, else 0
        <split>_log_probs: summed token log-probability of the generated answer
        <split>_pre_confs / <split>_post_confs: P(yes) for the confidence prompt asked
            before / after answering, shape (num_questions, 1)
        <split>_logits: last-position logits of the model on the bare question
    '''

    # File-name stems of the per-split cache files, in the order process_data returns them.
    _CACHE_FILES = ("explanations", "labels", "log_probs", "pre_confs", "post_confs", "logits")

    def __init__(self, model_type, gpt_exp=False, gpt_state=False, random=False):
        self.model_type = model_type
        self.left_pad = get_left_pad(model_type)
        self.add_token = get_add_token(model_type)

        # NOTE: the suffixes are part of the on-disk cache layout; do not rename them
        # or previously generated results will no longer be found.
        folder_path = "./data/squad_outputs/" + model_type
        if gpt_exp:
            folder_path += "_gpt_exp"
        elif gpt_state:
            folder_path += "_gpt_state"
        elif random:
            folder_path += "_random"

        path = folder_path + "/train_explanations.npy"

        # check if folder path exists
        if not os.path.exists(folder_path):
            print("Making dir", folder_path)
            os.makedirs(folder_path)

        has_train = os.path.exists(path)
        has_test = os.path.exists(folder_path + "/test_explanations.npy")

        if has_train and has_test:
            # everything is cached: no model needed
            print("Loading data")
            self.train_data, self.train_labels, self.train_log_probs, \
                self.train_pre_confs, self.train_post_confs, self.train_logits = self._load_split(folder_path, "train")
            self.test_data, self.test_labels, self.test_log_probs, \
                self.test_pre_confs, self.test_post_confs, self.test_logits = self._load_split(folder_path, "test")
            return

        print("No data found at " + path)

        self.model, self.tokenizer = load_llm(model_type)
        # reuse EOS as the padding token so batched tokenization with padding=True works
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        self.train_dataset = SquadDataset(split="train", tokenizer=self.tokenizer)
        self.test_dataset = SquadDataset(split="validation", tokenizer=self.tokenizer)

        # subset train_data
        num_subset = 5000
        self.train_dataset.questions = self.train_dataset.questions[:num_subset]
        self.train_dataset.answers = self.train_dataset.answers[:num_subset]

        # subset test_data
        num_subset = 1000
        self.test_dataset.questions = self.test_dataset.questions[:num_subset]
        self.test_dataset.answers = self.test_dataset.answers[:num_subset]

        # current prompts to generate simple responses...
        self.pre_conf_prompt = "Will you answer this question correctly? [/INST]"
        self.post_conf_prompt = "[INST] Did you answer this question correctly? [/INST]"

        if gpt_exp:
            self.explanation_prompts = gpt_explanation_prompts()
        elif gpt_state:
            self.explanation_prompts = gpt_state_prompts()
        elif random:
            self.explanation_prompts = random_prompts()
        else:
            self.explanation_prompts = explanation_prompts()

        # a split that is already cached is loaded; a missing one is computed and saved
        if has_train:
            train_results = self._load_split(folder_path, "train")
        else:
            train_results = self.process_data("train")
            self._save_split(folder_path, "train", train_results)
        self.train_data, self.train_labels, self.train_log_probs, \
            self.train_pre_confs, self.train_post_confs, self.train_logits = train_results

        if has_test:
            test_results = self._load_split(folder_path, "test")
        else:
            test_results = self.process_data("test")
            self._save_split(folder_path, "test", test_results)
        self.test_data, self.test_labels, self.test_log_probs, \
            self.test_pre_confs, self.test_post_confs, self.test_logits = test_results

        # delete model
        if self.model is not None:
            del self.model
            gc.collect()

    @classmethod
    def _load_split(cls, folder_path, split):
        '''Load the six cached arrays of a split ("train"/"test"), in process_data order.'''
        return tuple(
            np.load(folder_path + "/" + split + "_" + name + ".npy")
            for name in cls._CACHE_FILES
        )

    @classmethod
    def _save_split(cls, folder_path, split, arrays):
        '''Save the six arrays returned by process_data for a split to its cache files.'''
        for name, values in zip(cls._CACHE_FILES, arrays):
            np.save(folder_path + "/" + split + "_" + name + ".npy", values)

    @staticmethod
    def _yes_probability(logits, yes_token_id, no_token_id):
        '''
        Probability of "yes" under a softmax restricted to the [yes, no] logits.

        logits: tensor of shape (batch, vocab)
        yes_token_id / no_token_id: vocabulary index (int or one-element tensor)
        Returns a tensor of shape (batch,); unlike a bare squeeze() this also works
        when batch == 1.
        '''
        yes_logits = logits[:, yes_token_id].reshape(-1)
        no_logits = logits[:, no_token_id].reshape(-1)
        pair = torch.stack([yes_logits, no_logits], dim=1)
        return torch.nn.functional.softmax(pair, dim=1)[:, 0]

    def _prompt_yes_confidence(self, text, yes_token_id, no_token_id):
        '''Run the model on a single prompt and return P(yes) at its last position as an array of shape (1,).'''
        prompt_ids = self.tokenizer(text, return_tensors="pt").input_ids.to(device)
        with torch.no_grad():
            logits = self.model(prompt_ids, return_dict=True).logits
        return self._yes_probability(logits[:, -1, :], yes_token_id, no_token_id).cpu().numpy()

    def process_data(self, split):
        '''
        Query the model on every question of a split.

        split: "train" selects self.train_dataset, anything else self.test_dataset.
        Returns (all_data, all_labels, model_log_probs, pre_confs, post_confs, all_logits)
        as numpy arrays with one row/entry per question.
        '''

        # get ids of yes and no token - used later
        # when add_token is set, position 0 of the encoding is skipped
        # (presumably a BOS token added by the tokenizer - see get_add_token)
        token_pos = 1 if self.add_token else 0
        yes_token_id = self.tokenizer.encode("yes", return_tensors="pt")[:, token_pos]
        no_token_id = self.tokenizer.encode("no", return_tensors="pt")[:, token_pos]

        all_data = []
        all_labels = []
        model_log_probs = []
        pre_confs = []
        post_confs = []
        all_logits = []

        if split == "train":
            base_dataset = self.train_dataset
        else:
            base_dataset = self.test_dataset

        # loop through questions
        for q_ind, q in tqdm(enumerate(base_dataset.questions), total=len(base_dataset.questions)):

            answers = base_dataset.answers[q_ind] # answer subsequence
            # the tokenized length of the reference answer bounds the generation length
            max_len = self.tokenizer(answers, return_tensors="pt").input_ids.shape[1]

            input_ids = self.tokenizer.encode(q, return_tensors="pt").to(device)
            q_len = input_ids.shape[1]

            # get highest probability generation from model
            with torch.no_grad():
                generated = self.model.generate(input_ids, max_length=q_len + max_len, num_return_sequences=1, do_sample=False)
            output = self.tokenizer.decode(generated[0, q_len:], skip_special_tokens=True).strip()

            # check if output matches answer
            if answers.strip().lower() in output.lower(): # handle case like "the" or added punctuation
                all_labels.append(1)
            else:
                all_labels.append(0)

            # get last layer logits
            with torch.no_grad():
                logits = self.model(input_ids, return_dict=True).logits
            all_logits.append(logits[0, -1, :].cpu().numpy())

            # get pre confidence
            # q[:-7] drops the last 7 characters of the question prompt (presumably its
            # trailing "[/INST]" tag, which pre_conf_prompt re-adds - confirm against SquadDataset)
            pre_confs.append(self._prompt_yes_confidence(q[:-7] + self.pre_conf_prompt, yes_token_id, no_token_id))

            # get post confidence from its generated answer
            post_confs.append(self._prompt_yes_confidence(q + output + self.post_conf_prompt, yes_token_id, no_token_id))

            # get model probabilities of generated answer
            inputs = q + " " + output
            token_dict = self.tokenizer(inputs, padding=True, return_tensors="pt")
            input_ids = token_dict.input_ids.to(device)

            with torch.no_grad():
                logits = self.model(input_ids, return_dict=True).logits[0]
                # logits at position t predict token t + 1, hence the shift by one.
                # NOTE: assumes the first q_len tokens of the joint tokenization are the question.
                # log_softmax is used instead of log(softmax) to avoid underflow to -inf.
                token_log_probs = torch.nn.functional.log_softmax(logits[q_len - 1: -1, :], dim=1)
                output_tokens = input_ids[0, q_len:].to(token_log_probs.device)
                log_probs = token_log_probs.gather(1, output_tokens.unsqueeze(1)).sum().item()

            model_log_probs.append(log_probs)

            # del from memory
            del input_ids
            del logits
            gc.collect()

            # get explanation responses
            exp_inputs = [inputs + " " + exp for exp in self.explanation_prompts]
            token_dict = self.tokenizer(exp_inputs, padding=True, return_tensors="pt")
            input_ids = token_dict.input_ids.to(device)
            attention_mask = token_dict.attention_mask.to(device)

            with torch.no_grad():
                # the batch is padded, so pass the mask to keep pad tokens out of attention
                logits = self.model(input_ids, attention_mask=attention_mask, return_dict=True).logits

            # get probability of yes (w.r.t. distribution [yes, no])
            if self.left_pad:
                logits = logits[:, -1, :]
            else:
                # right padding: pick each row's last non-pad position
                last_token_id = token_dict.attention_mask.sum(1) - 1
                logits = logits[range(logits.shape[0]), last_token_id, :]

            prob_dist = self._yes_probability(logits, yes_token_id, no_token_id).cpu().numpy()
            prob_dist = prob_dist.reshape(-1, len(self.explanation_prompts))

            # store results
            all_data.append(prob_dist)
            del input_ids
            del logits
            gc.collect()

        all_data = np.concatenate(all_data, axis=0)
        all_labels = np.array(all_labels)
        model_log_probs = np.array(model_log_probs)
        pre_confs = np.array(pre_confs)
        post_confs = np.array(post_confs)
        all_logits = np.array(all_logits)

        return all_data, all_labels, model_log_probs, pre_confs, post_confs, all_logits

class SummarizationDataset(torch.utils.data.Dataset):
    '''
    Dataset of elicited explanations for summarization prompts (CNN/DailyMail or XSUM).

    For each prompt the LLM generates a greedy summary, which is scored against the
    reference with ROUGE-L, and a set of follow-up ("explanation") prompts is appended
    to the prompt/summary text to record the model's probability of replying "yes"
    rather than "no". Results are cached as .npy files under ./data/CNN_outputs/ or
    ./data/XSUM_outputs/ and are loaded from there on later runs; the LLM is only
    loaded when a cache file is missing.

    Args:
        model_type: model identifier forwarded to load_llm, get_left_pad and get_add_token
        dataset: "cnn" or "xsum"
        gpt_explanations: use gpt_explanation_prompts() (cache folder suffix "_gpt")
        gpt_state: use gpt_state_prompts() (cache folder suffix "_gpt_state");
            ignored when gpt_explanations is set

    Raises:
        ValueError: if dataset is neither "cnn" nor "xsum"

    Attributes set for split in {train, test}:
        <split>_data: P(yes) per explanation prompt, shape (num_examples, num_prompts)
        <split>_labels: ROUGE-L F-measure of the generated summary
        <split>_log_probs: summed token log-probability of the generated summary
        <split>_pre_confs / <split>_post_confs: P(yes) for the confidence prompt asked
            before / after answering, shape (num_examples, 1)
        <split>_logits: last-position logits of the model on the bare prompt
    '''

    # File-name stems of the per-split cache files, in the order process_data returns them.
    _CACHE_FILES = ("explanations", "labels", "log_probs", "pre_confs", "post_confs", "logits")

    def __init__(self, model_type, dataset="cnn", gpt_explanations=False, gpt_state=False):

        # check if path exists
        if dataset == "cnn":
            folder_path = "./data/CNN_outputs/" + model_type
        elif dataset == "xsum":
            folder_path = "./data/XSUM_outputs/" + model_type
        else:
            # previously fell through and failed later with an UnboundLocalError
            raise ValueError("Unknown dataset " + repr(dataset) + "; expected \"cnn\" or \"xsum\"")
        if gpt_explanations:
            folder_path += "_gpt"
        elif gpt_state:
            folder_path += "_gpt_state"
        path = folder_path + "/train_explanations.npy"

        # check if folder path exists
        if not os.path.exists(folder_path):
            print("Making dir", folder_path)
            os.makedirs(folder_path)

        has_train = os.path.exists(path)
        has_test = os.path.exists(folder_path + "/test_explanations.npy")

        if has_train and has_test:
            # everything is cached: no model needed
            print("Loading data")
            self.train_data, self.train_labels, self.train_log_probs, \
                self.train_pre_confs, self.train_post_confs, self.train_logits = self._load_split(folder_path, "train")
            self.test_data, self.test_labels, self.test_log_probs, \
                self.test_pre_confs, self.test_post_confs, self.test_logits = self._load_split(folder_path, "test")
            return

        print("No data found at " + path)

        self.model, self.tokenizer = load_llm(model_type)
        # reuse EOS as the padding token so batched tokenization with padding=True works
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        if dataset == "cnn":
            self.train_dataset = CNN_DM_Dataset(split="train", tokenizer=self.tokenizer)
            self.test_dataset = CNN_DM_Dataset(split="test", tokenizer=self.tokenizer)
        else:
            self.train_dataset = XSUM_Dataset(split="train", tokenizer=self.tokenizer)
            self.test_dataset = XSUM_Dataset(split="test", tokenizer=self.tokenizer)

        # subset train_data
        num_subset = min(5000, len(self.train_dataset.questions))
        self.train_dataset.questions = self.train_dataset.questions[:num_subset]
        self.train_dataset.answers = self.train_dataset.answers[:num_subset]

        # subset test_data
        num_subset = min(1000, len(self.test_dataset.questions))
        self.test_dataset.questions = self.test_dataset.questions[:num_subset]
        self.test_dataset.answers = self.test_dataset.answers[:num_subset]

        self.model_type = model_type
        self.left_pad = get_left_pad(model_type)
        self.add_token = get_add_token(model_type)

        # current prompts to generate simple responses...
        self.pre_conf_prompt = "Will you answer this question correctly? [/INST]"
        self.post_conf_prompt = "[INST] Did you answer this question correctly? [/INST]"

        if gpt_explanations:
            self.explanation_prompts = gpt_explanation_prompts()
        elif gpt_state:
            self.explanation_prompts = gpt_state_prompts()
        else:
            self.explanation_prompts = explanation_prompts()

        # a split that is already cached is loaded; a missing one is computed and saved
        if has_train:
            train_results = self._load_split(folder_path, "train")
        else:
            train_results = self.process_data("train")
            self._save_split(folder_path, "train", train_results)
        self.train_data, self.train_labels, self.train_log_probs, \
            self.train_pre_confs, self.train_post_confs, self.train_logits = train_results

        if has_test:
            test_results = self._load_split(folder_path, "test")
        else:
            test_results = self.process_data("test")
            self._save_split(folder_path, "test", test_results)
        self.test_data, self.test_labels, self.test_log_probs, \
            self.test_pre_confs, self.test_post_confs, self.test_logits = test_results

        # delete model
        if self.model is not None:
            del self.model
            gc.collect()

    @classmethod
    def _load_split(cls, folder_path, split):
        '''Load the six cached arrays of a split ("train"/"test"), in process_data order.'''
        return tuple(
            np.load(folder_path + "/" + split + "_" + name + ".npy")
            for name in cls._CACHE_FILES
        )

    @classmethod
    def _save_split(cls, folder_path, split, arrays):
        '''Save the six arrays returned by process_data for a split to its cache files.'''
        for name, values in zip(cls._CACHE_FILES, arrays):
            np.save(folder_path + "/" + split + "_" + name + ".npy", values)

    @staticmethod
    def _yes_probability(logits, yes_token_id, no_token_id):
        '''
        Probability of "yes" under a softmax restricted to the [yes, no] logits.

        logits: tensor of shape (batch, vocab)
        yes_token_id / no_token_id: vocabulary index (int or one-element tensor)
        Returns a tensor of shape (batch,), also when batch == 1.
        '''
        yes_logits = logits[:, yes_token_id].reshape(-1)
        no_logits = logits[:, no_token_id].reshape(-1)
        pair = torch.stack([yes_logits, no_logits], dim=1)
        return torch.nn.functional.softmax(pair, dim=1)[:, 0]

    def _prompt_yes_confidence(self, text, yes_token_id, no_token_id):
        '''Run the model on a single prompt and return P(yes) at its last position as an array of shape (1,).'''
        prompt_ids = self.tokenizer(text, return_tensors="pt").input_ids.to(device)
        with torch.no_grad():
            logits = self.model(prompt_ids, return_dict=True).logits
        return self._yes_probability(logits[:, -1, :], yes_token_id, no_token_id).cpu().numpy()

    def process_data(self, split):
        '''
        Query the model on every prompt of a split.

        split: "train" selects self.train_dataset, anything else self.test_dataset.
        Returns (all_data, all_labels, model_log_probs, pre_confs, post_confs, all_logits)
        as numpy arrays with one row/entry per prompt.
        '''

        # get ids of yes and no token - used later
        # when add_token is set, position 0 of the encoding is skipped
        # (presumably a BOS token added by the tokenizer - see get_add_token)
        token_pos = 1 if self.add_token else 0
        yes_token_id = self.tokenizer.encode("yes", return_tensors="pt")[:, token_pos]
        no_token_id = self.tokenizer.encode("no", return_tensors="pt")[:, token_pos]

        all_data = []
        all_labels = []
        model_log_probs = []
        pre_confs = []
        post_confs = []
        all_logits = []

        if split == "train":
            base_dataset = self.train_dataset
        else:
            base_dataset = self.test_dataset

        # explanation prompts are scored in small batches to bound memory use
        num_batch = 4

        # loop through questions
        for q_ind, q in tqdm(enumerate(base_dataset.questions), total=len(base_dataset.questions)):

            answers = base_dataset.answers[q_ind]
            input_ids = self.tokenizer.encode(q, return_tensors="pt").to(device)
            q_len = input_ids.shape[1]

            # get highest probability generation from model (at most 100 new tokens)
            with torch.no_grad():
                generated = self.model.generate(input_ids, max_length=min(q_len + 100, 4096), num_return_sequences=1, do_sample=False)
            output = self.tokenizer.decode(generated[0, q_len:], skip_special_tokens=True).strip()

            # label is the ROUGE-L F-measure against the reference summary;
            # a fresh metric per example so no state carries over between examples
            rouge = ROUGEScore()
            rouge_score = rouge(output, answers)["rougeL_fmeasure"]
            all_labels.append(rouge_score)

            # get last layer logits
            with torch.no_grad():
                logits = self.model(input_ids, return_dict=True).logits
            all_logits.append(logits[0, -1, :].cpu().numpy())

            # get pre confidence
            # q[:-7] drops the last 7 characters of the prompt (presumably its trailing
            # "[/INST]" tag, which pre_conf_prompt re-adds - confirm against the dataset class)
            pre_confs.append(self._prompt_yes_confidence(q[:-7] + self.pre_conf_prompt, yes_token_id, no_token_id))

            # get post confidence from its generated answer
            post_confs.append(self._prompt_yes_confidence(q + output + self.post_conf_prompt, yes_token_id, no_token_id))

            # get model probabilities of generated answer
            inputs = q + " " + output
            token_dict = self.tokenizer(inputs, padding=True, return_tensors="pt")
            input_ids = token_dict.input_ids.to(device)

            with torch.no_grad():
                logits = self.model(input_ids, return_dict=True).logits[0]
                # logits at position t predict token t + 1, hence the shift by one.
                # NOTE: assumes the first q_len tokens of the joint tokenization are the prompt.
                # log_softmax is used instead of log(softmax) to avoid underflow to -inf.
                token_log_probs = torch.nn.functional.log_softmax(logits[q_len - 1: -1, :], dim=1)
                output_tokens = input_ids[0, q_len:].to(token_log_probs.device)
                log_probs = token_log_probs.gather(1, output_tokens.unsqueeze(1)).sum().item()

            model_log_probs.append(log_probs)

            # del from memory
            del input_ids
            del logits
            del output_tokens
            gc.collect()

            # get explanation responses by looping through by size num_batch
            prob_dist = np.zeros((len(self.explanation_prompts),))
            for i in range(0, len(self.explanation_prompts), num_batch):
                exp_inputs = [inputs + " " + exp for exp in self.explanation_prompts[i:i+num_batch]]
                token_dict = self.tokenizer(exp_inputs, padding=True, return_tensors="pt")
                input_ids = token_dict.input_ids.to(device)
                attention_mask = token_dict.attention_mask.to(device)

                with torch.no_grad():
                    # the batch is padded, so pass the mask to keep pad tokens out of attention
                    logits = self.model(input_ids, attention_mask=attention_mask, return_dict=True).logits

                if self.left_pad:
                    logits = logits[:, -1, :]
                else:
                    # right padding: pick each row's last non-pad position
                    last_token_id = token_dict.attention_mask.sum(1) - 1
                    logits = logits[range(logits.shape[0]), last_token_id, :]

                # get probability of yes (w.r.t. distribution [yes, no]).
                # Fix: this used to be softmax over the whole vocabulary followed by [:, 0],
                # i.e. the probability of token id 0 rather than of "yes".
                prob_dist[i:i+num_batch] = self._yes_probability(logits, yes_token_id, no_token_id).cpu().numpy()

                # del from memory
                del input_ids
                del logits
                gc.collect()
            all_data.append(prob_dist)

        all_data = np.array(all_data)
        all_labels = np.array(all_labels)
        model_log_probs = np.array(model_log_probs)
        pre_confs = np.array(pre_confs)
        post_confs = np.array(post_confs)
        all_logits = np.array(all_logits)

        return all_data, all_labels, model_log_probs, pre_confs, post_confs, all_logits


if __name__ == "__main__":

    # Smoke test: build (or load from cache) the BooIQ explanation dataset for
    # llama7b, then report its size and show its first example.
    dataset = BoolExplanationDataset(base_dataset="BooIQ", model_type="llama7b")

    num_items = len(dataset)
    print("Length of dataset: ", num_items)

    first_item = dataset[0]
    print("First item: ", first_item)