import os
import sys
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import pandas as pd
import torch
import transformers
import wandb
from datasets import load_dataset
from transformers import AutoTokenizer, EarlyStoppingCallback, Trainer, TrainingArguments
from transformers import AutoModelForCausalLM, AutoConfig, DataCollatorForLanguageModeling


@dataclass
class ModelArguments:
    """Command-line options that select the model.

    ``model_name_or_path`` is handed to both the tokenizer loader and the
    model construction step in ``train()``.
    """

    model_name_or_path: Optional[str] = "BSM/config/110M_model_config.json"
    


@dataclass
class DataArguments:
    """Command-line options describing where the training data lives."""

    # Directory that train() joins with "pretrain.csv" and "lucaone_valid.csv".
    # Defaults to None (i.e. effectively required), so it is annotated Optional.
    data_path: Optional[str] = field(default=None, metadata={"help": "Path to the training data."})
    # NOTE(review): not read anywhere in this file.
    round: int = field(default=1, metadata={"help": "Pretraining round number."})


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Project defaults layered over ``transformers.TrainingArguments``.

    Fields that also exist on the base class only override its defaults.
    The remaining fields are project-specific extras that ``HfArgumentParser``
    exposes as CLI flags.
    """

    logging_dir: Optional[str] = field(default="./logs")
    run_name: str = field(default="run")
    # This was previously declared twice (100, then 10). The later declaration
    # was the effective default, so 10 is kept.
    logging_steps: int = field(default=10)
    # Read by train() to cap the tokenizer's sequence length.
    model_max_length: int = field(default=1024, metadata={"help": "Maximum sequence length."})
    gradient_accumulation_steps: int = field(default=50)
    num_train_epochs: int = field(default=1)
    # NOTE(review): not a transformers.TrainingArguments field and not read in this
    # file. The Trainer batch size comes from per_device_train_batch_size.
    batch_size: int = field(default=4096)
    fp16: bool = field(default=True)
    save_strategy: str = field(default="steps")
    save_steps: int = field(default=200)
    eval_steps: int = field(default=200)
    # The original trailing comma turned this default into a 1-tuple.
    max_steps: int = field(default=1000000)
    # Newer transformers releases renamed evaluation_strategy to eval_strategy.
    # Both are set, so load_best_model_at_end sees matching save/eval strategies
    # on old and new versions alike. On versions that know only one of them,
    # the other is an inert extra field.
    evaluation_strategy: str = field(default="steps")
    eval_strategy: str = field(default="steps")
    weight_decay: float = field(default=0.01)
    learning_rate: float = field(default=2e-5)
    save_total_limit: int = field(default=3)
    load_best_model_at_end: bool = field(default=True)
    output_dir: str = field(default="./results/models")
    # NOTE(review): the next three fields are not read anywhere in this file.
    # Gradient checkpointing in transformers is gradient_checkpointing.
    checkpointing: bool = field(default=False)
    dataloader_pin_memory: bool = field(default=False)
    eval_and_save_results: bool = field(default=True)
    save_model: bool = field(default=True)
    seed: int = field(default=42)
    



class Dataset(torch.utils.data.Dataset):
    """Tokenized text corpus read from a single CSV file.

    Each non-empty row of the text column becomes one example. Texts are
    tokenized once, up front, and truncated to ``tokenizer.model_max_length``.
    No padding is applied here; the data collator pads each batch to its own
    longest sequence.

    Args:
        data_path: Path to the CSV file.
        tokenizer: Tokenizer used to encode the texts.
        text_column: CSV column holding the raw text. When omitted, the first
            column is used. NOTE(review): confirm this matches the real CSV
            layout (e.g. not a pandas index column).

    Raises:
        ValueError: if the chosen column contains no text.
    """

    def __init__(self,
                 data_path: str,
                 tokenizer: "transformers.PreTrainedTokenizer",
                 text_column: Optional[str] = None,
                 ):
        super().__init__()

        # split="train" yields a single datasets.Dataset instead of a DatasetDict;
        # a plain CSV load places every row in the "train" split.
        raw = load_dataset("csv", data_files=data_path, split="train")
        if text_column is None:
            text_column = raw.column_names[0]

        # Empty CSV cells load as None, which the tokenizer rejects.
        texts = [text for text in raw[text_column] if text is not None]
        if not texts:
            raise ValueError(f"No text found in column {text_column!r} of {data_path}")

        self.encodings = tokenizer(
            texts,
            max_length=tokenizer.model_max_length,
            truncation=True,
        )

    def __len__(self) -> int:
        """Return the number of tokenized examples."""
        return len(self.encodings["input_ids"])

    def __getitem__(self, idx: int) -> dict:
        """Return every tokenizer output (input_ids, attention_mask, ...) for one example."""
        return {key: values[idx] for key, values in self.encodings.items()}




def train():
    """Parse CLI arguments, build the tokenizer, datasets and model, then train.

    Training data is read from ``<data_path>/pretrain.csv`` and evaluation
    data from ``<data_path>/lucaone_valid.csv``. The model is built from its
    config, so it starts from random weights.

    Raises:
        ValueError: if ``--data_path`` was not supplied.
    """
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if data_args.data_path is None:
        # Fail before any expensive setup rather than on os.path.join(None, ...).
        raise ValueError("--data_path is required")

    # load tokenizer
    # NOTE(review): model_name_or_path defaults to a config JSON file, but
    # AutoTokenizer needs a directory or hub id with tokenizer files. Confirm
    # what is actually passed on the command line.
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=True,
        trust_remote_code=True,
    )
    # The LM collator pads each batch and needs a pad token. Fall back to EOS
    # when the tokenizer defines none.
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    wandb.init(project="BSM-pretraining", name=training_args.run_name)

    # define datasets
    train_dataset = Dataset(tokenizer=tokenizer,
                            data_path=os.path.join(data_args.data_path, "pretrain.csv"))
    val_dataset = Dataset(tokenizer=tokenizer,
                          data_path=os.path.join(data_args.data_path, "lucaone_valid.csv"))

    # load model
    # from_config expects a config object (not a path) and builds an untrained
    # model. A causal LM head has no num_labels, so none is passed.
    config = transformers.AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=True,
    )
    model = transformers.AutoModelForCausalLM.from_config(config, trust_remote_code=True)

    # mlm=False: labels are the input ids (next-token prediction), with pad positions ignored.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # define trainer
    trainer = transformers.Trainer(model=model,
                                   tokenizer=tokenizer,
                                   args=training_args,
                                   train_dataset=train_dataset,
                                   eval_dataset=val_dataset,
                                   data_collator=data_collator)
    trainer.train()



if __name__ == "__main__":
    # Script entry point: parse the command line and launch pretraining.
    train()
