import os
import yaml
from argparse import ArgumentParser, Namespace
import sys
import torch
# Add the parent directory (project root) to sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
from sentence_transformers.training_args import SentenceTransformerTrainingArguments
from transformers import (
    AutoTokenizer, 
    AutoModelForSequenceClassification, 
    Trainer, 
    TrainingArguments,
    BitsAndBytesConfig
)
import evaluate
import numpy as np
from peft import PeftModel, LoraConfig, get_peft_model, TaskType

import utils
from thirdparty.tofu.data_module import TextDatasetQA, custom_data_collator as tofu_data_collator
from data_modules.base_data import load_tofu_train_dataset, load_arxiv_train_dataset, custom_data_collator_arxiv as arxiv_data_collator
from data_modules.data_module import UnwatermarkedTextDataset


def add_labels(example, idx, samples_per_class=400):
    """Attach a chunk-based integer class label to a single dataset example.

    The ``(example, idx)`` signature matches what ``datasets.Dataset.map``
    passes when called with ``with_indices=True``. The example at position
    ``idx`` receives label ``idx // samples_per_class``. Consecutive runs of
    ``samples_per_class`` examples therefore share one label.

    Args:
        example: Mutable mapping for one example; a ``'label'`` key is written.
        idx: Zero-based position of the example in the dataset.
        samples_per_class: Number of consecutive examples per label
            (defaults to 400, the original hard-coded chunk size).

    Returns:
        The same ``example`` object, mutated in place with ``'label'`` set.

    Raises:
        ValueError: If ``samples_per_class`` is not positive.
    """
    if samples_per_class <= 0:
        raise ValueError(f'samples_per_class must be positive, got {samples_per_class}')
    example['label'] = idx // samples_per_class
    return example


def compute_chunk_labels(num_samples: int, samples_per_class: int = 400):
    """Assign contiguous chunks of sample indices to integer class labels.

    Sample ``i`` receives label ``i // samples_per_class``. The first
    ``samples_per_class`` samples are class 0, the next chunk is class 1, and
    so on. The final chunk may be smaller than ``samples_per_class``.

    Args:
        num_samples: Number of samples in the dataset.
        samples_per_class: Number of consecutive samples mapped to each label.

    Returns:
        A ``(num_labels, labels)`` tuple. ``num_labels`` is
        ``ceil(num_samples / samples_per_class)``. ``labels`` is a list of
        length ``num_samples``.

    Raises:
        ValueError: If ``num_samples`` or ``samples_per_class`` is not positive.
    """
    if samples_per_class <= 0:
        raise ValueError(f'samples_per_class must be positive, got {samples_per_class}')
    if num_samples <= 0:
        # A zero-width classification head would otherwise fail much later.
        raise ValueError(f'cannot build labels for an empty dataset (num_samples={num_samples})')
    # Ceiling division: a trailing partial chunk still gets its own label.
    num_labels = (num_samples + samples_per_class - 1) // samples_per_class
    labels = [i // samples_per_class for i in range(num_samples)]
    return num_labels, labels


if __name__ == '__main__':
    # Fine-tune a LoRA-adapted sequence classifier whose labels are derived
    # purely from each training example's position (see compute_chunk_labels).
    parser = ArgumentParser()
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')
    parser.add_argument('--dataset_name', type=str, choices=['arxiv', 'tofu'], required=True,
                        help='Dataset name')
    parser.add_argument('--data_config_path', type=str, required=True,
                        help='Path to dataset and split config')
    parser.add_argument('--train_config_path', type=str, required=True,
                        help='Path to training config')
    parser.add_argument('--output_dir', type=str, default='results/',
                        help='Directory to save results and models')
    args = parser.parse_args()
    utils.set_seed(args.seed)

    # load data config; its keys are forwarded as kwargs to the dataset loader
    with open(args.data_config_path, 'r', encoding='utf-8') as f:
        data_config = Namespace(**yaml.safe_load(f))
    print('data_config:', vars(data_config))

    # load training config
    with open(args.train_config_path, 'r', encoding='utf-8') as f:
        config = Namespace(**yaml.safe_load(f))
    print('train_config:', vars(config))

    tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model_name_or_path)
    # Reuse EOS as the pad token; model.config.pad_token_id is set to match below.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # load dataset; both branches must end with a "text" column for tokenization
    if args.dataset_name == 'tofu':
        train_data, _, _ = load_tofu_train_dataset(**vars(data_config))
        train_data = train_data.rename_column("answer", "text")
        train_data = train_data.remove_columns(["question"])
    elif args.dataset_name == 'arxiv':
        train_data, _, _ = load_arxiv_train_dataset(**vars(data_config))
    else:
        raise NotImplementedError(f'Unsupported dataset: {args.dataset_name}')

    # Each consecutive run of 400 examples forms one class.
    num_labels, new_labels = compute_chunk_labels(len(train_data), samples_per_class=400)

    # Inject the position-derived labels as the "label" column used by the HF Trainer.
    train_data = train_data.add_column("label", new_labels)

    def tokenize_function(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=512)

    tokenized_dataset = train_data.map(tokenize_function, batched=True)

    model = AutoModelForSequenceClassification.from_pretrained(
        config.pretrained_model_name_or_path,
        num_labels=num_labels,  # This creates the classification head size
        ignore_mismatched_sizes=True,
        # quantization_config=bnb_config,
        device_map="auto"
    )
    model.config.pad_token_id = tokenizer.pad_token_id

    # config.lora is expected to hold extra LoraConfig kwargs (e.g. r, lora_alpha)
    # — NOTE(review): confirm against the training YAML schema.
    lora_config = LoraConfig(
        task_type=TaskType.SEQ_CLS,
        target_modules=["q_proj", "v_proj"],
        **config.lora,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    accuracy = evaluate.load("accuracy")

    # Trainer only calls compute_metrics during evaluation; no eval_dataset is
    # passed below, so this currently has no effect.
    def compute_metrics(eval_pred):
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)
        return accuracy.compute(predictions=predictions, references=labels)

    # Kept separate from the CLI `args` so the parsed arguments are not shadowed.
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=config.num_epochs,
        per_device_train_batch_size=config.train_batch_size,
        learning_rate=config.learning_rate,
        save_strategy='no',
        bf16=True,  # requires a bf16-capable device
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        # Newer transformers releases rename this argument to `processing_class`.
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    trainer.train()
    trainer.save_model(args.output_dir)
