from transformers import AutoTokenizer, AutoModelForSequenceClassification
#import torch
import numpy as np

# import torch.nn.functional as F
# from torch import Tensor
import pickle
import os
import argparse
import datetime
import numpy as np
import time
from tqdm import tqdm

from semantic_functions import get_labels_logits_and_reps_cross

#get_semantic_ids, get_semantic_matrix, get_semantic_conf_matrix_in_cluster, get_semantic_matrix_and_reps

# Command-line interface. Arguments without which the script cannot run are
# marked required so a missing flag is reported immediately, instead of as a
# TypeError/AttributeError or an invalid "cuda:None" device much later.
parser = argparse.ArgumentParser(
    description="Compare the responses stored in two result directories "
                "with an MNLI classifier.")
parser.add_argument("--dirnameA", type=str, required=True,
                    help="First result directory (must contain results.pickle).")
parser.add_argument("--dirnameB", type=str, required=True,
                    help="Second result directory (must contain results.pickle).")
parser.add_argument("--subset", type=int, required=True,
                    help="Subset size (selects subsample_indices/subset_<N>/), "
                         "or -1 to use every response.")
parser.add_argument("--trial", type=int,
                    help="Subsampling trial index; required when --subset > 0.")
parser.add_argument("--gpu", type=int, required=True,
                    help="CUDA device index.")
parser.add_argument("--type", type=str, default="greedy",
                    help='Computation to run; only "labels_logits_and_reps" '
                         'is currently implemented.')
args = parser.parse_args()

# Only -1 (all responses) and positive sizes are meaningful. Any other value
# previously failed later with a NameError on the subset indices.
if args.subset != -1 and args.subset <= 0:
    parser.error("--subset must be a positive integer or -1 (use all responses)")
# A positive subset reads subsample_indices/.../trial_<trial>.txt, so the trial
# index must be given (otherwise the script looks for "trial_None.txt").
if args.subset > 0 and args.trial is None:
    parser.error("--trial is required when --subset is positive")

start_time = time.time()
# Run timestamp string (formatted but not otherwise used in this script).
current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# Output folder: <c0>/<c1>/cross/<n1>-X-<n2>, where c0/c1 are the first two
# path components of dirnameA and n1/n2 are the third components of the two
# input dirs in sorted order, so the A/B and B/A runs share one folder.
path_components_A, path_components_B = (
    d.split(os.sep) for d in (args.dirnameA, args.dirnameB)
)
name1, name2 = sorted(comp[2] for comp in (path_components_A, path_components_B))
cross_name = f"{name1}-X-{name2}"
result_dir = os.path.join(*path_components_A[:2], 'cross', cross_name)
print(result_dir)

# Per-type / per-subset folder that receives one pickle per trial.
ids_and_conf_dir = os.path.join(
    result_dir, f"random_subsample_{args.type}", f"subset_{args.subset}"
)
os.makedirs(ids_and_conf_dir, exist_ok=True)
ids_and_conf_path = os.path.join(ids_and_conf_dir, f"trial_{args.trial}.pickle")
print(ids_and_conf_path)

print(f"Semantic ids and confidence levels, subset size = {args.subset}")

device = f"cuda:{args.gpu}"

# Load the MNLI classifier and its tokenizer; both share one local cache dir.
model_name = "microsoft/deberta-v2-xlarge-mnli"
hf_cache_dir = '../cache/models'
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=hf_cache_dir)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    output_hidden_states=False,
    cache_dir=hf_cache_dir,
).to(device)

# Class-index -> label-name mapping from the model config; it is handed to the
# cross-scoring helper below.
labels = model.config.id2label
print(labels)

# Load the responses. Each results.pickle is iterated below as a mapping from
# question to a list of responses (assumed from usage; confirm against the
# writer). NOTE: pickle is only safe here because these are locally produced
# files; never point these paths at untrusted data.
pickle_file_path = os.path.join(args.dirnameA, "results.pickle")
with open(pickle_file_path, 'rb') as file:
    dataA = pickle.load(file)
pickle_file_path = os.path.join(args.dirnameB, "results.pickle")
with open(pickle_file_path, 'rb') as file:
    dataB = pickle.load(file)

# Load subset indices (one integer per line) from the sibling
# subsample_indices/ folder of dirnameA.
if args.subset > 0:
    # normpath drops a trailing separator: os.path.dirname("a/b/c/") would
    # return "a/b/c" instead of the parent directory "a/b".
    answer_dir = os.path.dirname(os.path.normpath(args.dirnameA))
    subset_path = os.path.join(
        answer_dir,
        "subsample_indices",
        f"subset_{args.subset}",
        f"trial_{args.trial}.txt",
    )
    with open(subset_path, 'r') as f:
        # Skip blank lines (e.g. a trailing empty line) rather than crashing
        # on int('').
        indices = np.array([int(line.strip()) for line in f if line.strip()])

# Only one computation type is implemented. Check it before the loop so an
# unsupported --type fails immediately, and never writes an empty result file
# when dataA is empty.
if args.type != "labels_logits_and_reps":
    raise NotImplementedError(f"Unsupported --type: {args.type!r}")

# Must match the condition under which `indices` is loaded. The old `== -1`
# test made --subset 0 (or < -1) reach `indices` without it being defined.
use_subset = args.subset > 0

ids_confs = {}
for q, a_A in tqdm(dataA.items()):
    a_B = dataB[q]  # both result files are expected to cover the same questions
    if use_subset:
        strings_list_A = [a_A[i] for i in indices]
        strings_list_B = [a_B[i] for i in indices]
    else:
        strings_list_A = a_A
        strings_list_B = a_B

    # Score in both directions (A vs B and B vs A) and key each direction by
    # the name of the directory whose responses come first.
    r1 = get_labels_logits_and_reps_cross(tokenizer, model, labels, strings_list_A, strings_list_B, device)
    r2 = get_labels_logits_and_reps_cross(tokenizer, model, labels, strings_list_B, strings_list_A, device)
    # NOTE(review): if both names are equal, r2 overwrites r1 under one key.
    ids_confs[q] = {path_components_A[2]: r1, path_components_B[2]: r2}

# Persist all per-question results in a single pickle for this trial.
with open(ids_and_conf_path, 'wb') as out_file:
    pickle.dump(ids_confs, out_file)

print("Done!!")
# Wall-clock duration measured from the start_time taken at startup.
end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time} seconds")
