import torch

import os
import itertools
import argparse
import pickle as pkl
import random
import torch
import math
import json
import string
import logging
import numpy as np
import pandas as pd
import pdb
from tqdm import tqdm
from collections import Counter, defaultdict

from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from transformers import GPT2Tokenizer, AutoTokenizer, AutoModelForSequenceClassification
from transformers import RobertaTokenizer, RobertaModel, GPTJForCausalLM
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers import AutoModelForCausalLM
from transformers import pipeline
import fire
import datasets

def main(
        dataset="gsm8k",
        gpu=0,
        roberta_name="roberta-large",
        if_qwa: bool = False,
        if_train: bool = False,
        **kwargs
        ):
    """Compute BERTScore-style question similarities and save them as a ``.npy`` matrix.

    Loads the RoBERTa checkpoint ``roberta_name`` from a local directory, reads
    ``./data/<dataset>/<dataset>_test.jsonl`` and ``..._train.jsonl``, scores the
    questions, and writes the score matrix under ``./data/<dataset>/``.

    Args:
        dataset: Name of the dataset folder / file prefix under ``./data``.
        gpu: Passed to ``torch.device``; an int selects that CUDA device.
        roberta_name: Checkpoint directory name; also embedded in the output file name.
        if_qwa: If true, the embedded text is "Question: ...\\nAnswer: ..." instead of
            the bare question (see ``read_jsonl``).
        if_train: If true, score every train question against every train question;
            otherwise score every test question against every train question.
        **kwargs: Accepted and ignored.
    """
    print("if_qwa is True" if if_qwa else "if_qwa is False")
    print("if_train is True" if if_train else "if_train is False")

    roberta_model_name = f"/home/amax/exp/huggingface/sentence_transformers/{roberta_name}"

    device = torch.device(gpu)
    test_file = f"./data/{dataset}/{dataset}_test.jsonl"
    train_file = f"./data/{dataset}/{dataset}_train.jsonl"
    tokenizer = RobertaTokenizer.from_pretrained(roberta_model_name)
    model = RobertaModel.from_pretrained(roberta_model_name).to(device)
    model.eval()

    # Output file name encodes the mode (test-vs-train or train-vs-train) and whether
    # answers were appended to the questions.
    if if_train:
        if if_qwa:
            output_path = f"./data/{dataset}/{dataset}_train_self_bert_score_qwa_{roberta_name}.npy"
        else:
            output_path = f"./data/{dataset}/{dataset}_train_self_bert_score_{roberta_name}.npy"
    else:
        if if_qwa:
            output_path = f"./data/{dataset}/{dataset}_bert_score_qwa_{roberta_name}.npy"
        else:
            output_path = f"./data/{dataset}/{dataset}_bert_score_{roberta_name}.npy"

    test_dataset = read_jsonl(test_file, if_qwa)
    print(f"Reading {test_file}, total number of entries {len(test_dataset)}")
    train_dataset = read_jsonl(train_file, if_qwa)
    print(f"Reading {train_file}, total number of entries {len(train_dataset)}")

    # Rows of the score matrix: train questions in train mode, test questions otherwise.
    queries = train_dataset if if_train else test_dataset
    calculate_bertscores(queries, train_dataset, model, tokenizer, device, output_path)


def read_jsonl(fname, if_qwa=False):
    """Load a JSON Lines file into a list of dicts.

    Args:
        fname: Path to a UTF-8 encoded ``.jsonl`` file holding one JSON object per line.
        if_qwa: If true, overwrite each record's ``"question"`` with
            ``"Question: <question>\\nAnswer: <answer>"`` so the answer text is embedded
            along with the question. Every record must then have ``"question"`` and
            ``"answer"`` keys (a missing key raises ``KeyError``).

    Returns:
        The decoded records, in file order. Blank or whitespace-only lines are skipped.
    """
    datas = []
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(fname, "r", encoding="utf-8") as f:
        for line in f:
            # A trailing newline or stray blank line is not a record; json.loads
            # would raise on it.
            if not line.strip():
                continue
            datas.append(json.loads(line))
    if if_qwa:
        for data in datas:
            data["question"] = f'Question: {data["question"]}\nAnswer: {data["answer"]}'
    return datas


@torch.no_grad()
def calculate_bertscores(test_dataset, train_dataset, model, tokenizer, device, dir):
    """Score every test question against every train question and save the matrix.

    Token embeddings from ``calculate_embeddings`` are L2-normalised per token. The
    score of a pair is the mean, over the test question's tokens, of each token's best
    cosine similarity to any train-question token (BERTScore's greedy matching, with
    the test question as the candidate, hence the ``p_bert`` name).

    Args:
        test_dataset: Records with a ``"question"`` key; one row of the result each.
        train_dataset: Records with a ``"question"`` key; one column of the result each.
            If it is the very same object as ``test_dataset`` (train-vs-train mode),
            the embeddings are computed once and reused.
        model, tokenizer, device: Forwarded to ``calculate_embeddings``.
        dir: Output path passed to ``np.save``.

    Side effects:
        Writes a float array of shape ``(len(test_dataset), len(train_dataset))`` to
        ``dir``. Pairs whose score cannot be computed (e.g. a train question with no
        non-special tokens, which makes ``np.max`` raise) are reported and left at 0.
    """
    embedding_tests = _embed_normalized(test_dataset, model, tokenizer, device)
    if train_dataset is test_dataset:
        # Train-vs-train mode: avoid a second, identical pass over the encoder.
        embedding_trains = embedding_tests
    else:
        embedding_trains = _embed_normalized(train_dataset, model, tokenizer, device)

    results = np.zeros((len(test_dataset), len(train_dataset)))
    # shape[1:] equals results[0].shape but does not fail on an empty test set.
    print(f"results shape: {results.shape}, {results.shape[1:]}")
    for idx_test, embedding_test in enumerate(tqdm(embedding_tests)):
        for idx_train, embedding_train in enumerate(embedding_trains):
            try:
                p_bert = _greedy_match_score(embedding_test, embedding_train)
            except ValueError as err:
                print(f"Sample {idx_test + 1} vs train {idx_train + 1} failed ({err}), skipping...")
                continue
            results[idx_test, idx_train] = p_bert
            if idx_test == 0:
                print(f"Sample {idx_test + 1} p_bert: {p_bert}")
    np.save(dir, results)
    print(f"Results saved to {dir}")


def _embed_normalized(dataset, model, tokenizer, device):
    """Return one numpy array per record: its question's token embeddings, each row L2-normalised."""
    embeddings = []
    for data in tqdm(dataset):
        embedding = calculate_embeddings(data["question"], model, tokenizer, device)
        embedding = embedding.detach().cpu().numpy()
        embedding = embedding / np.linalg.norm(embedding, axis=1, keepdims=True)
        embeddings.append(embedding)
    return embeddings


def _greedy_match_score(embedding_a, embedding_b):
    """Mean over rows of ``embedding_a`` of the max dot product with any row of ``embedding_b``.

    Raises ``ValueError`` when ``embedding_b`` has no rows (max over an empty axis)
    or when the embedding widths do not match.
    """
    inner_product = embedding_a @ embedding_b.T
    return np.mean(np.max(inner_product, axis=1))

def calculate_embeddings(text, model, tokenizer, device, seq_len=512, layer=17):
    """Return per-token hidden states of ``text`` taken from one encoder layer.

    Args:
        text: Input string.
        model: Encoder called with ``output_hidden_states=True`` whose output exposes
            ``hidden_states`` (e.g. the Hugging Face ``RobertaModel`` used in ``main``).
            It must already live on ``device``.
        tokenizer: Tokenizer matching ``model``.
        device: Device the tokenized inputs are moved to.
        seq_len: Maximum number of tokens, special tokens included, fed to the model.
        layer: Index into ``outputs.hidden_states`` (index 0 is the embedding output).
            The default 17 matches the bert_score package's default layer for
            roberta-large.

    Returns:
        Tensor with one row per token. The first and last positions are dropped;
        presumably these are the special tokens the tokenizer adds (``<s>``/``</s>``
        for RoBERTa). Assumes a single string, i.e. batch size 1, which the squeeze
        of dim 0 relies on.
    """
    with torch.no_grad():
        # Let the tokenizer truncate: it keeps the closing special token and trims
        # every returned field consistently. Slicing input_ids by hand dropped
        # "</s>", so the [1:-1] below then removed a real content token instead.
        inputs = tokenizer(
            text, return_tensors="pt", truncation=True, max_length=seq_len
        ).to(device)
        outputs = model(**inputs, output_hidden_states=True)
    embeddings = outputs.hidden_states[layer][:, 1:-1, :]
    embeddings = torch.squeeze(embeddings, 0)
    # Only meaningful (and only valid) for CUDA tensors; sync the tensor's own device
    # rather than whichever device happens to be current.
    if embeddings.is_cuda:
        torch.cuda.synchronize(embeddings.device)
    return embeddings

if __name__ == '__main__':
    # Command-line entry point: python-fire turns main()'s parameters into CLI flags.
    fire.Fire(component=main)
