import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from tqdm import tqdm

import torch as t
import numpy as np
import argparse
from cut.utils import load_model
import json

def prepare_data(data, batch_size=8):
    """
    Yield batches of (text, label) pairs, one pair per answer choice.

    Every row of `data` is expanded into four examples of the form
    "Question: <question>\\n\\nAnswer: <choice>", labelled 1 when the choice
    index equals row['answer'] and 0 otherwise.

    NOTE: a later definition of `prepare_data` in this file rebinds the name,
    so this per-choice variant is only reachable if that definition is removed
    or renamed.

    Args:
        data: iterable of rows; each row is indexed with the keys 'question',
            'choices' (at least four entries are read) and 'answer'.
        batch_size: number of (text, label) pairs per yielded batch.

    Yields:
        Lists of (text, label) tuples. Every batch holds exactly `batch_size`
        pairs except possibly the last one, which holds the remainder.

    Raises:
        ValueError: if `batch_size` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    batch = []
    for row in data:
        for i in range(4):
            question = f"Question: {row['question']}\n\nAnswer: {row['choices'][i]}"
            label = 1 if i == row['answer'] else 0
            batch.append((question, label))
            # Check after every pair: four pairs are added per row, so checking
            # once per row would never match a batch_size that is not a
            # multiple of four.
            if len(batch) == batch_size:
                yield batch
                batch = []
    # Flush the remainder instead of silently dropping the final examples.
    if batch:
        yield batch
            
def prepare_data(data, batch_size=8):
    """
    Yield batches of (prompt, one_hot_label) pairs, one pair per question.

    Each row is rendered as a four-option multiple choice prompt (options
    A-D) ending in "Answer:". The label is a list of four ints with a 1 at
    index row['answer'] and 0 elsewhere.

    Args:
        data: iterable of rows; each row is indexed with the keys 'question',
            'choices' (entries 0-3 are read) and 'answer' (used as an index
            into the four-element label).
        batch_size: number of (prompt, label) pairs per yielded batch.

    Yields:
        Lists of (prompt, label) tuples. Every batch holds exactly
        `batch_size` pairs except possibly the last one, which holds the
        remainder.

    Raises:
        ValueError: if `batch_size` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    batch = []
    for row in data:
        question = f"""\
The following are multiple choice questions (with answers).

{row['question']}
A. {row['choices'][0]}
B. {row['choices'][1]}
C. {row['choices'][2]}
D. {row['choices'][3]}
Answer:"""
        label = [0, 0, 0, 0]
        label[row['answer']] = 1
        batch.append((question, label))
        if len(batch) == batch_size:
            yield batch
            batch = []
    # Flush the remainder so the last (len(data) % batch_size) questions are
    # not silently dropped.
    if batch:
        yield batch
    

def _last_token_index(inputs):
    """Return, per example, the index of the last non-padding token.

    Returns None when the tokenizer output carries no 'attention_mask', in
    which case the caller falls back to the final sequence position.
    """
    if "attention_mask" not in inputs:
        return None
    mask = inputs["attention_mask"].long()
    positions = t.arange(1, mask.shape[1] + 1, device=mask.device)
    # Real tokens keep their 1-based position and padding becomes 0, so the
    # maximum sits on the last real token whichever side the padding is on.
    return (mask * positions).argmax(dim=1)


def get_activations(model, tokenizer, batches):
    """
    Collect the hidden state at the last real token of every text, for every
    entry of the model's `hidden_states` output.

    Args:
        model: callable as `model(**inputs, output_hidden_states=True)` and
            exposing `.device`; its output must provide `.hidden_states`, an
            iterable of tensors indexed as (batch, position, hidden).
            Presumably a Hugging Face causal LM -- confirm against load_model.
        tokenizer: callable as `tokenizer(texts, return_tensors="pt",
            padding=True)` returning an object with `.to(device)`.
        batches: iterable of batches, each a sequence of (text, label) items.

    Returns:
        (acts, labels) where `acts` is a float64 array of shape
        (num_examples, num_hidden_states, hidden_size) and `labels` is
        `np.array` of all labels in the same order.

    Raises:
        ValueError: if `batches` yields no examples.
    """
    labels = []
    acts = []
    with t.no_grad():
        for batch in tqdm(batches):
            texts = [x[0] for x in batch]
            labels.extend(x[1] for x in batch)
            inputs = tokenizer(texts, return_tensors="pt", padding=True).to(model.device)
            out = model(**inputs, output_hidden_states=True)

            # With right-side padding, position -1 of a shorter prompt is a
            # pad token; use the attention mask to pick the last real token.
            # Under left padding this selects position -1, as before.
            last_idx = _last_token_index(inputs)

            per_layer = []
            for hidden in out.hidden_states:
                if last_idx is None:
                    last = hidden[:, -1]
                else:
                    # Hidden states may live on different devices when the
                    # model is sharded, so move the indices to each tensor.
                    rows = t.arange(hidden.shape[0], device=hidden.device)
                    last = hidden[rows, last_idx.to(hidden.device)]
                per_layer.append(last.detach().cpu().float().numpy())

            # (batch, num_hidden_states, hidden). Stay in NumPy rather than
            # round-tripping through nested Python lists; float64 keeps the
            # dtype the list-based version produced.
            acts.append(np.stack(per_layer, axis=1).astype(np.float64))

    if not acts:
        raise ValueError("get_activations received no batches; nothing to stack")
    acts = np.concatenate(acts, axis=0)
    labels = np.array(labels)
    return acts, labels


if __name__ == '__main__':
    # Script entry point: load a model, run every question in the dataset
    # through it, and save the collected activations and labels as .npy files.
    parser = argparse.ArgumentParser(description='Evaluate a model on a multiple choice dataset')
    parser.add_argument('--model_name_or_path', type=str, default="saprmarks/unlearned_nonrand")
    parser.add_argument('--data_path', type=str, default='data/bio-questions.json')
    parser.add_argument('--batch_size', type=int, default=4, help='The batch size to use for evaluation')
    parser.add_argument('--save_path', type=str, default='data/bio-unlearned')
    args = parser.parse_args()

    # Create the output directory up front so a missing folder does not make
    # np.save fail only after all forward passes have been computed.
    save_dir = os.path.dirname(args.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    model, tokenizer = load_model(args.model_name_or_path)
    # Inference only: switch autograd off globally.
    t.set_grad_enabled(False)

    # JSON is UTF-8 by specification; do not depend on the platform's default
    # text encoding.
    with open(args.data_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    batches = prepare_data(data, args.batch_size)
    acts, labels = get_activations(model, tokenizer, batches)

    # --save_path is a prefix; two files are written next to each other.
    np.save(f'{args.save_path}-acts.npy', acts)
    np.save(f'{args.save_path}-labels.npy', labels)
    