import transformers
from datasets import load_dataset
from llm_modeling_revised import LLMCompressor, QGCArguments, ModelArguments, TrainingArguments
# from bert_training_utils import instruct_ft_tokenize_function, DataCollatorForDynamicPadding
import torch
from peft import (
    LoraConfig, get_peft_model
)

import numpy as np
import random
from safetensors.torch import load_file

from PWC_dataset_restatement import TrainDataset as PwC_TrainDataset
from QA_dataset import InferDataset, TrainDataset

from icecream import ic as pprint
from transformers import Trainer, EarlyStoppingCallback
import os

def seed_everything(TORCH_SEED):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices)
    with the same value, exports it as ``PYTHONHASHSEED``, and switches cuDNN
    to deterministic mode with autotuning (benchmark) disabled.

    Args:
        TORCH_SEED: Seed value handed unchanged to each seeding function.
    """
    random.seed(TORCH_SEED)
    os.environ['PYTHONHASHSEED'] = str(TORCH_SEED)

    # The remaining generators all take the seed as their single argument.
    for apply_seed in (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        apply_seed(TORCH_SEED)

    # Favor repeatable cuDNN kernels over autotuned ones.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Module-level device handle: CUDA when it is available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def load_checkpoint(
    model,
    weight_path,
    ft_flag
):
    """Load weights from a safetensors file into ``model`` non-strictly.

    Args:
        model: Object exposing ``load_state_dict``; presumably a
            ``torch.nn.Module`` such as the compressor built in ``main`` --
            TODO confirm against callers.
        weight_path: Path of the safetensors file read with ``load_file``.
        ft_flag: When truthy, every entry whose key starts with
            ``'llm.model'`` is dropped before loading, so the matching
            parameters of ``model`` keep their current values.

    Returns:
        The value returned by ``model.load_state_dict(..., strict=False)``.
        For a ``torch.nn.Module`` this reports ``missing_keys`` and
        ``unexpected_keys``, which would otherwise be ignored silently
        because loading is non-strict.
    """
    state_dict = load_file(weight_path)
    if ft_flag:
        state_dict = {
            name: tensor
            for name, tensor in state_dict.items()
            if not name.startswith('llm.model')
        }

    # strict=False never raises on key mismatches, so hand the report back to
    # the caller instead of discarding it.
    return model.load_state_dict(state_dict, strict=False)

def main():
    """Parse CLI arguments, build the compressor and dataset, and train.

    Steps, in order:
      1. Seed all RNGs with 42.
      2. Parse ``QGCArguments``, ``ModelArguments`` and ``TrainingArguments``
         from the command line.
      3. Build a LoRA config from the model arguments and construct the
         ``LLMCompressor``.
      4. Build the PwC training dataset using the compressor's tokenizers.
      5. Train with ``Trainer``, save the model, and log the train metrics.

    No evaluation dataset and no early-stopping callback are configured.
    """
    seed_everything(42)

    arg_parser = transformers.HfArgumentParser(
        (QGCArguments, ModelArguments, TrainingArguments)
    )
    qgc_args, model_args, training_args = arg_parser.parse_args_into_dataclasses()

    # LoRA hyper-parameters come from the CLI; target_modules is left unset.
    peft_config = LoraConfig(
        task_type="CAUSAL_LM",
        bias="none",
        r=model_args.lora_r,
        lora_alpha=model_args.lora_alpha,
        lora_dropout=model_args.lora_dropout,
    )

    compressor = LLMCompressor(
        model_args,
        training_args,
        peft_config,
        qgc_args.max_doc_tokens,
    )

    # The dataset receives the model itself plus both of its tokenizers.
    dataset_kwargs = {
        "filepath": qgc_args.data_path,
        "model": compressor,
        "cmp_tokenizer": compressor.tokenizer,
        "llm_tokenizer": compressor.llm_tokenizer,
        "max_doc_tokens": qgc_args.max_doc_tokens,
        "instruction_name": qgc_args.instruction_name,
        "lm_ratio": training_args.lm_ratio,
        "leave_len": training_args.leave_len,
        "prefix_type": training_args.prefix_type,
    }
    pwc_train_set = PwC_TrainDataset(**dataset_kwargs)

    # Batches are assembled by the dataset's own collate function.
    trainer = Trainer(
        model=compressor,
        args=training_args,
        train_dataset=pwc_train_set,
        data_collator=pwc_train_set.collate_fn,
    )

    outcome = trainer.train()

    trainer.save_model()
    trainer.log_metrics("train", outcome.metrics)

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()