from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import numpy as np

import torch
import torch.nn.functional as F
from torch import Tensor
import pickle
import os
import sys
import argparse
import datetime
import numpy as np
import time

from semantic_functions import get_labels_logits_and_reps

# Command-line interface. --dirname/--subset/--trial are interpolated into
# every input and output path below, so they are required up front rather
# than failing later with a TypeError or a "subset_None" path.
parser = argparse.ArgumentParser()
parser.add_argument("--dirname", type=str, required=True,
                    help="Directory containing results.pickle; outputs are written beneath it.")
parser.add_argument("--subset", type=int, required=True,
                    help="Subset id; selects subsample_indices/subset_<subset>/ and names the output folder.")
parser.add_argument("--trial", type=int, required=True,
                    help="Trial id; selects trial_<trial>.txt and names the output pickle.")
parser.add_argument("--cat", type=int, default=0,
                    help="Set to 1 to prepend each results key to its selected strings.")
parser.add_argument("--max_length", type=int, default=-1,
                    help="max_length forwarded to get_labels_logits_and_reps; -1 forwards None.")
parser.add_argument("--gpu", type=int,
                    help="CUDA device index. If omitted, the default CUDA device is used when "
                         "available, otherwise the CPU.")
parser.add_argument("--type", type=str, default="greedy",
                    help="Computation to run; only 'labels_logits_and_reps' is implemented, "
                         "any other value (including the default) raises NotImplementedError.")
args = parser.parse_args()

# Wall-clock reference for the elapsed-time report at the end of the script.
start_time = time.time()

# Timestamp of this run (not referenced elsewhere in the visible script).
current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# An explicit --gpu keeps the original "cuda:<n>" behaviour. Without it the
# original produced the invalid device string "cuda:None"; fall back to a
# usable device instead.
if args.gpu is not None:
    device = f"cuda:{args.gpu}"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the tokenizer and the sequence-classification model, both using
# ../cache/models as the cache directory, and move the model to `device`.
model_name = "microsoft/deberta-v2-xlarge-mnli"
hf_cache_dir = '../cache/models'

tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=hf_cache_dir)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    cache_dir=hf_cache_dir,
    output_hidden_states=False,
)
model = model.to(device)

# Class-id -> label-name mapping taken from the model config; printed so the
# label order is visible in the run log.
labels = model.config.id2label
print(labels)

# Load the responses: a mapping whose keys and indexable values are consumed
# by the main loop below.
# NOTE: pickle.load can execute arbitrary code while deserialising; only
# point --dirname at results.pickle files from a trusted source.
pickle_file_path = os.path.join(args.dirname, "results.pickle")
with open(pickle_file_path, 'rb') as results_file:
    data = pickle.load(results_file)

# Load subset indices (one integer per line) from the directory that
# os.path.dirname reports for --dirname.
# NOTE(review): os.path.dirname("a/b/") is "a/b" while os.path.dirname("a/b")
# is "a", so a trailing slash on --dirname changes where the indices are
# looked up -- confirm which layout is intended.
answer_dir = os.path.dirname(args.dirname)
subset_path = os.path.join(answer_dir, f"subsample_indices/subset_{args.subset}/trial_{args.trial}.txt")
with open(subset_path, 'r') as indices_file:
    # Skip blank lines (e.g. a trailing empty line) instead of crashing on
    # int(''); dtype=int keeps the array integral even when it is empty.
    indices = np.array(
        [int(line) for line in indices_file if line.strip()],
        dtype=int,
    )

# -1 is the CLI sentinel for "pass max_length=None to the helper"; the value
# does not depend on the loop, so resolve it once.
max_length = None if args.max_length == -1 else args.max_length
prepend_key = args.cat == 1

# Maps every key of `data` to whatever get_labels_logits_and_reps returns for
# that key's selected strings.
ids_confs = {}
for key, candidates in data.items():
    # Keep only the entries chosen by the subsample; with --cat 1 the key is
    # prepended to each of them.
    strings_list = (
        [key + candidates[i] for i in indices]
        if prepend_key
        else [candidates[i] for i in indices]
    )

    # Only one computation type is implemented.
    if args.type != "labels_logits_and_reps":
        raise NotImplementedError

    ids_confs[key] = get_labels_logits_and_reps(
        tokenizer, model, labels, strings_list, device, max_length=max_length
    )

# Build the output folder name: "random_subsample_<type>", followed by
# "_cat_<cat>" when --cat is 1 and "_len_<max_length>" when a max length was
# given (in that order).
run_tag = f"random_subsample_{args.type}"
if args.cat == 1:
    run_tag += f"_cat_{args.cat}"
if args.max_length != -1:
    run_tag += f"_len_{args.max_length}"

ids_and_conf_dir = os.path.join(args.dirname, run_tag, f"subset_{args.subset}")
os.makedirs(ids_and_conf_dir, exist_ok=True)

# One pickle per trial holding the per-key results collected above.
ids_and_conf_path = os.path.join(ids_and_conf_dir, f"trial_{args.trial}.pickle")
with open(ids_and_conf_path, 'wb') as out_file:
    pickle.dump(ids_confs, out_file)

# Report completion and the wall-clock duration measured from start_time.
print("Done!!")
elapsed_time = time.time() - start_time
print(f"Elapsed time: {elapsed_time} seconds")

