from transformers import AutoModelForSequenceClassification, AutoTokenizer

from datasets import load_from_disk
import torch
import argparse
import polars as pl
from tqdm import tqdm
from utilities import check_path

from sklearn.metrics import (
    balanced_accuracy_score, 
    accuracy_score, 
    classification_report, 
    confusion_matrix,
    precision_recall_fscore_support,
)

def get_metrics_result(test_df:pl.DataFrame):
    """Print and return classification metrics for a frame of predictions.

    Args:
        test_df: Frame holding the ground truth in a ``"label"`` column and
            the model output in a ``"predictions"`` column.

    Returns:
        A tuple ``((acc, bacc), prfs, cm, report)`` where ``acc`` is the
        accuracy score, ``bacc`` the balanced accuracy score, ``prfs`` the
        result of ``precision_recall_fscore_support`` restricted to labels
        ``(0, 1)``, ``cm`` the confusion matrix and ``report`` the text
        classification report.
    """
    y_test = test_df.select("label").to_series()
    y_pred = test_df.select("predictions").to_series()

    # Each metric is computed once, printed in the original order, and kept
    # so it can be handed back to the caller instead of being discarded.
    report = classification_report(y_test, y_pred)
    print("Classification Report:", report)

    bacc = balanced_accuracy_score(y_test, y_pred)
    print("Balanced Accuracy Score:", bacc)

    acc = accuracy_score(y_test, y_pred)
    print("Accuracy Score:", acc)

    prfs = precision_recall_fscore_support(y_test, y_pred, labels=(0,1))
    print("PRFS:", prfs)

    # Returned only (not printed) so the console output stays unchanged.
    cm = confusion_matrix(y_test, y_pred)

    return (acc, bacc), prfs, cm, report


def generate_predictions(model, tokenizer, df_test, input_col="input"):
    """Classify every row of ``df_test`` and attach the predicted class ids.

    Each entry of ``input_col`` is tokenized on its own (truncated to 512
    tokens), passed through ``model`` without gradient tracking, and the
    index of the largest logit is recorded.

    Args:
        model: Callable returning a mapping with a ``"logits"`` entry when
            called with the tokenizer output as keyword arguments.
        tokenizer: Callable turning one input entry into a dict of tensors.
        df_test: Polars frame containing the column ``input_col``.
        input_col: Name of the column holding the model inputs.

    Returns:
        ``df_test`` with an additional integer ``"predictions"`` column.
    """
    sentences = df_test.select(input_col).to_series()

    # Resolve the target device once, from the model itself, so the inputs
    # always land where the weights are. Fall back to the previous
    # "cuda if available" guess when the model exposes no parameters.
    parameters = getattr(model, "parameters", None)
    first_param = next(parameters(), None) if callable(parameters) else None
    if first_param is not None:
        device = first_param.device
    else:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    all_outputs = []
    with torch.no_grad():
        for sentence in tqdm(sentences):
            inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=512)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = model(**inputs)
            # NOTE(review): argmax is taken over the flattened logits, which
            # presumes one sequence per call (batch of 1) — confirm the
            # input column holds plain strings.
            all_outputs.append(outputs["logits"].argmax().item())

    # Explicit dtype keeps the column integer-typed even for an empty frame.
    predictions = pl.Series("predictions", all_outputs, dtype=pl.Int64)
    return df_test.with_columns(predictions)

if __name__ == "__main__":
    # Evaluation entry point: load a fine-tuned sequence-classification model
    # and the "test" split of a dataset saved on disk, predict a label for
    # every row, print the metrics and optionally export the predictions.
    import os

    parser = argparse.ArgumentParser(description="Eval")
    # `model` is a required positional; argparse ignores `default` on those,
    # so none is given.
    parser.add_argument("model", type=str, help="Finetuned model")
    # nargs="?" makes the positional optional so that its default applies.
    parser.add_argument("dataset", type=str, nargs="?", default="data/finetune/sft/", help="Data path")
    parser.add_argument("--export", type=int, default=1, help="Export")
    parser.add_argument("--export_path", type=str, default="eval/finetuned/", help="Export path")

    args    = parser.parse_args()
    arg_model_path  = args.model
    arg_data_path   = args.dataset
    arg_export      = args.export
    arg_output_path = args.export_path

    # is_available must be *called*; the bare function object is always
    # truthy and would select "cuda" even on CPU-only machines.
    device      = "cuda" if torch.cuda.is_available() else "cpu"

    model       = AutoModelForSequenceClassification.from_pretrained(arg_model_path).to(device)
    tokenizer   = AutoTokenizer.from_pretrained(arg_model_path)
    df_test     = load_from_disk(arg_data_path)["test"].to_polars()

    # Reuse EOS as the padding token, but only when the tokenizer has an EOS
    # token; otherwise this would overwrite the pad token with None.
    if tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id   = tokenizer.eos_token_id

    df_result   = generate_predictions(model, tokenizer, df_test)
    # The metrics are printed by the helper; its return value is not needed
    # here, so it is not unpacked.
    get_metrics_result(df_result)
    if arg_export:
        check_path(arg_output_path)
        # os.path.join copes with export paths given without a trailing slash.
        df_result.write_json(os.path.join(arg_output_path, "eval_results.json"))
