import argparse
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    GenerationConfig,
    DataCollatorWithPadding
)
from tqdm import tqdm
from trl import SFTTrainer

from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
import torch
from torch.nn import functional as F
import time
import pandas as pd
import numpy as np
import datasets
import os
import polars as pl
from sklearn.metrics import balanced_accuracy_score, accuracy_score, classification_report, confusion_matrix



def compute_metrics(evaluations):
    """Compute classification metrics for the ``Trainer`` evaluation loop.

    Args:
        evaluations: An ``EvalPrediction`` or a ``(predictions, labels)`` pair,
            where ``predictions`` holds per-class scores (argmax'd over the
            last axis) and ``labels`` holds the integer class ids.

    Returns:
        dict with ``"balanced_accuracy"`` and ``"accuracy"`` scores.
    """
    predictions, labels = evaluations
    # If extra model outputs are passed along with the logits, keep the logits only.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.argmax(predictions, axis=-1)
    # sklearn metrics take (y_true, y_pred); balanced accuracy is NOT symmetric,
    # so the argument order matters.
    return {
        "balanced_accuracy": balanced_accuracy_score(labels, predictions),
        "accuracy": accuracy_score(labels, predictions),
    }



class CustomTrainer(Trainer):
    """``Trainer`` whose loss is (optionally class-weighted) cross-entropy.

    Args:
        *args, **kwargs: Forwarded unchanged to ``Trainer``.
        class_weights: Optional per-class weights (a sequence or tensor with one
            entry per label), used as the ``weight`` argument of
            ``F.cross_entropy``. ``None`` gives plain unweighted cross-entropy.
    """

    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        if class_weights is None:
            self.class_weights = None
        else:
            # as_tensor accepts lists, arrays and tensors without the
            # copy-construct warning torch.tensor() emits for tensor inputs.
            self.class_weights = torch.as_tensor(class_weights, dtype=torch.float32).to(self.args.device)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the cross-entropy loss, plus the model outputs if requested.

        Args:
            model: The model being trained.
            inputs: Batch dict; its ``"labels"`` entry is popped and the rest is
                passed to ``model`` as keyword arguments.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                just ``loss``.
            num_items_in_batch: Accepted for compatibility with the
                ``Trainer.compute_loss`` signature; not used here.
        """
        labels = inputs.pop("labels").long()
        outputs = model(**inputs)
        # Compute the loss in float32: half-precision logits mixed with the
        # float32 weight tensor can raise a dtype mismatch, and fp32 is stabler.
        logits = outputs.get("logits").float()

        weight = None
        if self.class_weights is not None:
            # With device_map the logits may not live on args.device.
            weight = self.class_weights.to(logits.device)

        loss = F.cross_entropy(logits, labels.to(logits.device), weight=weight)
        return (loss, outputs) if return_outputs else loss


def parse_args():
    """Parse the command-line options of the fine-tuning script.

    Returns:
        argparse.Namespace with ``model``, ``dataset``, ``wandb_disabled``
        and ``export_path`` attributes.
    """
    parser = argparse.ArgumentParser(description="Finetuning the model")
    parser.add_argument("--model", type=str, default="meta-llama/Llama-3.1-8B", help="Base model")
    parser.add_argument("--dataset", type=str, default="data/finetune/sft/", help="Data path")
    parser.add_argument("--wandb_disabled", type=str, default="true", help="Wandb logging disabled")
    parser.add_argument("--export_path", type=str, default="finetuned/", help="Output directory for the finetuned model")
    return parser.parse_args()


def main():
    """Fine-tune a 4-bit quantised base model as a 2-label classifier with LoRA.

    Loads a saved ``datasets`` object from ``--dataset``, which needs
    ``train`` and ``dev`` splits with an ``input`` text column. It trains for
    one epoch, evaluates on ``dev`` and writes the final model and tokenizer
    to ``--export_path``.
    """
    args = parse_args()

    wandb_disabled = args.wandb_disabled.strip().lower() in {"1", "true", "yes", "on"}
    # Kept for backwards compatibility; `report_to` below is the
    # non-deprecated way to switch logging integrations off.
    os.environ["WANDB_DISABLED"] = args.wandb_disabled

    dataset = datasets.load_from_disk(args.dataset)

    # NOTE(review): add_eos_token / add_bos_token are only honoured by some
    # tokenizer classes; confirm they take effect for the chosen model.
    tokenizer = AutoTokenizer.from_pretrained(
        args.model,
        trust_remote_code=True,
        padding_side="left",
        add_eos_token=True,
        add_bos_token=True,
        use_fast=False,
    )
    # Reuse EOS as the padding token so batches can be padded.
    # NOTE(review): with pad == eos and add_eos_token=True, the appended EOS
    # may be treated as padding when the classifier looks for the last
    # non-pad token -- verify this is the intended behaviour.
    tokenizer.pad_token = tokenizer.eos_token

    # 4-bit NF4 quantisation to reduce GPU memory consumption.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=False,
    )
    # NOTE: trust_remote_code executes code shipped with the model repo;
    # only point --model at repositories you trust.
    original_model = AutoModelForSequenceClassification.from_pretrained(
        args.model,
        device_map={"": 0},  # place the whole model on GPU 0
        quantization_config=bnb_config,
        trust_remote_code=True,
        token=True,
        num_labels=2,
    )
    # Without this the model can't handle batch size > 1.
    original_model.config.pad_token_id = tokenizer.pad_token_id

    lora_config = LoraConfig(
        r=16,
        lora_alpha=8,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="SEQ_CLS",
    )
    model = prepare_model_for_kbit_training(original_model)
    model = get_peft_model(model, lora_config)

    def tokenize(batch):
        """Tokenize the ``input`` column, truncating to 512 tokens."""
        return tokenizer(batch["input"], truncation=True, max_length=512)

    tokenized_data = dataset.map(tokenize, batched=True, remove_columns=["input"])
    tokenized_data.set_format("torch")

    training_args = TrainingArguments(
        output_dir=args.export_path,
        learning_rate=1e-4,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=1,
        logging_steps=1,
        weight_decay=0.01,
        # Without an eval strategy the dev split and compute_metrics are never used.
        eval_strategy="epoch",
        # None keeps the library default integrations.
        report_to="none" if wandb_disabled else None,
    )

    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_data["train"],
        eval_dataset=tokenized_data["dev"],
        processing_class=tokenizer,
        data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
        compute_metrics=compute_metrics,
    )

    train_result = trainer.train()
    trainer.log_metrics("train", train_result.metrics)
    trainer.save_metrics("train", train_result.metrics)
    trainer.save_state()
    # Explicitly export the final model; periodic checkpoints are not a
    # reliable export artifact on their own.
    trainer.save_model(args.export_path)
    tokenizer.save_pretrained(args.export_path)


if __name__ == "__main__":
    main()
