# (c) 2023 anonymous authors, not to be distributed or used for commercial purposes.

import os
import argparse
import logging
from tqdm import tqdm
import json

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset
from datasets import logging as datasets_logging

import numpy as np
import pandas as pd
import seaborn as sns
import sys

import pickle

from huggingface_hub.hf_api import HfFolder

# ---------------------------------------------------------------------------
# Module-level setup. Everything below runs at import time, in this order:
# authentication, global torch settings, interactive prompts, model loading.
# ---------------------------------------------------------------------------

# Hand the Hugging Face access token from the HF_TOKEN environment variable to
# huggingface_hub. Indexing os.environ directly raises KeyError if it is unset.
HfFolder.save_token(os.environ["HF_TOKEN"])

# Autograd is switched off globally; this script only runs forward passes.
torch.set_grad_enabled(False)

import sys

# Make the parent of the current working directory importable so that the
# local `metrics` module can be found.
# NOTE(review): this depends on where the script is launched from -- confirm.
sys.path.append("..")
from metrics import perplexity, entropy

# Device used for the tokenized inputs: CUDA device 0 when available, else CPU.
DEVICE = torch.device(0 if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.deterministic = True

# Model identifiers are read interactively from stdin.
falcon_model_id = input("Falcon model ID: ")
falcon_instruct_model_id = input("Falcon instruct model ID: ")

# Only the base model's tokenizer is loaded; its encodings are fed to both
# models in score_samples.
# NOTE(review): presumably both models share a vocabulary -- confirm.
falcon_tokenizer = AutoTokenizer.from_pretrained(falcon_model_id)

# Maximum number of tokens kept per text; passed as the tokenizer's
# `max_length` (with truncation) in score_samples. int() raises ValueError on
# non-numeric input.
tokens_seen = int(input("Enter tokens_seen"))

# Base model, loaded in bfloat16 and placed entirely on one device.
# trust_remote_code=True lets the model repository supply Python code that is
# executed locally, so only trusted model IDs should be entered above.
falcon = AutoModelForCausalLM.from_pretrained(
    falcon_model_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    token=os.environ["HF_TOKEN"],
    device_map={"": "cuda" if torch.cuda.is_available() else "cpu"},
)

# Instruct model, loaded with the same settings as the base model.
falcon_instruct = AutoModelForCausalLM.from_pretrained(
    falcon_instruct_model_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    token=os.environ["HF_TOKEN"],
    device_map={"": "cuda" if torch.cuda.is_available() else "cpu"},
)

# Number of texts tokenized and scored per forward pass in score_samples.
batch_size = 32


def score_samples(base_dir, dir_iter):
    """Score every ``.txt`` file in ``base_dir/dir_iter`` and write a CSV.

    Each text is tokenized with the module-level ``falcon_tokenizer``
    (truncated to ``tokens_seen`` tokens) and passed through both the
    ``falcon`` and ``falcon_instruct`` models in batches of ``batch_size``.
    Per file, the output records:

    * ``token_counts_falcon`` -- number of non-padding tokens (sum of the
      attention mask),
    * ``falcon_ppl`` -- third value returned by
      ``perplexity(encoding, logits_fi)`` on the instruct model's logits,
    * ``falcon_cross_ppl`` -- value returned by ``entropy(...)`` on both
      models' logits,
    * ``falcon_score`` -- ``falcon_ppl / falcon_cross_ppl``.

    NOTE(review): the exact definitions of ``perplexity`` and ``entropy``
    live in the external ``metrics`` module and are not visible here.

    The result is written to ``{base_dir}/{dir_iter}-{tokens_seen}.csv``.

    Side effect: if the tokenizer has no pad token, its pad token is set to
    its EOS token (this mutates the shared module-level tokenizer).

    Args:
        base_dir: Directory that contains the sample directory and that
            receives the output CSV.
        dir_iter: Name of the sub-directory of ``base_dir`` holding the
            ``.txt`` files to score.

    Raises:
        SystemExit: With status 1 if ``base_dir/dir_iter`` does not exist.
    """
    dir_name = dir_iter

    print(dir_iter)
    print(dir_name)
    print(f"Tokens Used: {tokens_seen}")
    print(f"Batch Size: {batch_size}")

    sample_dir = f"{base_dir}/{dir_name}"
    try:
        # os.listdir returns entries in arbitrary order; sorting makes the
        # batch composition and the CSV row order reproducible across runs.
        entries = sorted(os.listdir(sample_dir))
    except FileNotFoundError:
        print("File not found")
        # Report failure with a non-zero status, and use sys.exit rather
        # than the interactive-only `exit` helper injected by `site`.
        sys.exit(1)

    # Only .txt files are scored; everything else in the directory is skipped.
    file_names = [name for name in entries if name.endswith(".txt")]
    texts = []
    for txt_file_name in file_names:
        with open(f"{sample_dir}/{txt_file_name}", "r") as f:
            texts.append(f.read())
    file_cat = [dir_name] * len(file_names)

    # Batched tokenization needs a pad token; fall back to EOS if none is set.
    if not falcon_tokenizer.pad_token:
        falcon_tokenizer.pad_token = falcon_tokenizer.eos_token

    falcon_pad_token_id = falcon_tokenizer.pad_token_id
    falcon_padding_side = falcon_tokenizer.padding_side

    # Padding is only required when several texts share one tensor.
    padding = "longest" if batch_size > 1 else False

    token_counts_falcon = np.array([])
    fal_num = np.array([])
    fal_denom = np.array([])

    print("Calculating falcon scores...")

    for chunk_start in tqdm(range(0, len(texts), batch_size), file=sys.stdout):
        batch = texts[chunk_start : chunk_start + batch_size]
        encoding = falcon_tokenizer(
            batch,
            return_tensors="pt",
            padding=padding,
            truncation=True,
            max_length=tokens_seen,
            return_token_type_ids=False,
        ).to(DEVICE)

        # Non-padding token count per text in this batch.
        token_counts_falcon = np.concatenate(
            (token_counts_falcon, encoding.attention_mask.sum(1).cpu().numpy())
        )

        logits_f = falcon(**encoding).logits
        logits_fi = falcon_instruct(**encoding).logits

        # Numerator of the score: third return value of `perplexity`.
        _, _, ppl_fi = perplexity(encoding, logits_fi)
        fal_num = np.concatenate((fal_num, ppl_fi))

        # Denominator of the score: `entropy` over both models' logits.
        c_ent = entropy(
            logits_f, logits_fi, encoding, falcon_pad_token_id, falcon_padding_side
        )
        fal_denom = np.concatenate((fal_denom, c_ent))

        # Drop the large tensors before the next batch to limit peak memory.
        del encoding, logits_f, logits_fi
        torch.cuda.empty_cache()

    df = pd.DataFrame(
        dict(
            file_name=file_names,
            file_cat=file_cat,
            token_counts_falcon=token_counts_falcon,
            falcon_ppl=fal_num,
            falcon_cross_ppl=fal_denom,
            falcon_score=fal_num / fal_denom,
        )
    )

    df.to_csv(f"{base_dir}/{dir_name}-{tokens_seen}.csv", index=False)


if __name__ == "__main__":
    # Command-line entry point: parse the dataset location and score it.
    parser = argparse.ArgumentParser()

    print("=" * 60, "START", "=" * 60)

    # Dataset arguments
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="/",
        help="Base directory containing the dataset folder; the CSV is written here.",
    )
    # If omitted, args.dataset_name is None and score_samples will look for a
    # sub-directory literally named "None".
    parser.add_argument(
        "--dataset_name",
        type=str,
        help="Name of the sub-directory of --dataset_path holding the .txt files.",
    )

    args = parser.parse_args()
    score_samples(args.dataset_path, args.dataset_name)

    # Closing banner; previously printed "START" a second time.
    print("=" * 60, "END", "=" * 60)
