import argparse
import os

import polars as pl
import torch
from datasets import load_from_disk
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from utilities import check_path, get_metrics_result


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Eval")
    parser.add_argument("model", type=str, default="", help="Finetuned model")
    parser.add_argument("dataset", type=str, default="data/finetune/sft/", help="Data path")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
    parser.add_argument("--dict", type=str, default="data/new/23/dict.json", help="Vocab path")
    parser.add_argument("--export", type=int, default=1, help="Export")
    parser.add_argument("--export_path", type=str, default="eval/finetuned/", help="Export path")
    
    args    = parser.parse_args()
    arg_model_path  = args.model
    arg_data_path   = args.dataset
    arg_batch_size  = args.batch_size
    arg_dict_path   = args.dict
    arg_export      = args.export
    arg_output_path = args.export_path


    device      = "cuda" if torch.cuda.is_available else "cpu"
    df_dict     = pl.read_json(arg_dict_path)
    model       = AutoModelForSequenceClassification.from_pretrained(arg_model_path, num_labels=df_dict.height).to(device)
    tokenizer   = AutoTokenizer.from_pretrained(arg_model_path)
    dataset     = load_from_disk(arg_data_path)["test"]
    
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id   = tokenizer.eos_token_id

    tokenized_data = dataset.map(lambda row: tokenizer(row["input"], truncation=True, max_length=512), batched=True, remove_columns=["input"]).batch(batch_size=arg_batch_size)
    tokenized_data.set_format("torch")
    y_pred  = []
    for batch in tqdm(tokenized_data, desc="Evaluatin'"):
        output  = model(batch["input_ids"].to(model.device),attention_mask=batch["attention_mask"].to(model.device))["logits"].argmax(dim=1)
        y_pred.extend(output.tolist())
    y_ref   = list(dataset["label"])

    df_result   = pl.DataFrame([y_ref, y_pred], schema=["ref", "pred"])
    (acc,bacc), prfs, cm, report    = get_metrics_result(df_result, verbose=True)
    if arg_export:
        check_path(arg_output_path)
        df_result.write_json(f"{arg_output_path}eval_results.json")
