import transformers
from datasets import load_dataset
from llm_modeling_revised import LLMCompressor, QGCArguments, ModelArguments, TrainingArguments
# from bert_training_utils import instruct_ft_tokenize_function, DataCollatorForDynamicPadding
import torch
from peft import (
    LoraConfig,
)

import numpy as np
import random
from safetensors.torch import load_file

from PWC_dataset_restatement import TrainDataset as PwC_TrainDataset
from QA_dataset import InferDataset, TrainDataset

from icecream import ic as pprint
from transformers import Trainer, EarlyStoppingCallback
import os

def seed_everything(TORCH_SEED):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, exports ``PYTHONHASHSEED``, seeds NumPy,
    seeds torch (CPU and all CUDA devices), and switches cuDNN to
    deterministic mode with autotuning (``benchmark``) turned off.

    Args:
        TORCH_SEED: Seed value handed unchanged to every seeding function.
    """
    random.seed(TORCH_SEED)
    os.environ.update(PYTHONHASHSEED=str(TORCH_SEED))

    # The array/tensor libraries all take the seed as a single positional arg.
    for apply_seed in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        apply_seed(TORCH_SEED)

    # Trade cuDNN speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Module-level default device: CUDA when torch reports it available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def load_checkpoint(
    model,
    weight_path,
    ft_flag
):
    """Load safetensors weights from ``weight_path`` into ``model`` in place.

    Loading is non-strict, so keys that exist on only one side do not raise.
    Because of that, checkpoint keys the model did not accept are reported
    explicitly; otherwise a checkpoint with mismatched key names would load
    nothing and training would silently start from the initial weights.

    Args:
        model: Object whose ``load_state_dict(state_dict, strict=False)`` is
            called with the loaded tensors.
        weight_path: Path passed to ``safetensors.torch.load_file``.
        ft_flag: When truthy, every entry whose key starts with
            ``'llm.model'`` is dropped before loading.

    Returns:
        Whatever ``model.load_state_dict`` returns (for a ``torch.nn.Module``
        this carries ``missing_keys`` and ``unexpected_keys``).
    """
    # state_dict = torch.load(checkpoint_path, map_location='cpu')
    state_dict = load_file(weight_path)
    if ft_flag:
        state_dict = {k: v for k, v in state_dict.items() if not k.startswith('llm.model')}

    load_result = model.load_state_dict(state_dict, strict=False)

    # Missing keys are expected here (the 'llm.model' entries may have been
    # filtered out above), so only keys the model rejected are reported.
    unexpected = list(getattr(load_result, "unexpected_keys", None) or ())
    if unexpected:
        print(
            f"[load_checkpoint] {len(unexpected)} of {len(state_dict)} checkpoint "
            f"tensors matched no model parameter and were ignored, e.g. {unexpected[:5]}"
        )
    return load_result

def main():
    """Parse command-line arguments, build the model and dataset, then train.

    Flow:
      1. Seed the RNGs with 42 and parse ``QGCArguments``, ``ModelArguments``
         and ``TrainingArguments`` from the command line.
      2. Build two LoRA configs from the model arguments and hand them to
         ``LLMCompressor``.
      3. Build the training dataset and a ``Trainer`` that uses the dataset's
         own ``collate_fn``.
      4. If ``training_args.fine_tune`` is set, optionally warm-start the model
         from ``training_args.resume_from_checkpoint`` via ``load_checkpoint``
         and train from step 0; otherwise let the ``Trainer`` resume from that
         checkpoint.
      5. Save a DeepSpeed checkpoint (only when a DeepSpeed engine exists), the
         model, and log the training metrics.
    """
    seed_everything(42)
    parser = transformers.HfArgumentParser((QGCArguments, ModelArguments, TrainingArguments))
    args, model_args, training_args = parser.parse_args_into_dataclasses()

    lora_config = LoraConfig(
        r=model_args.lora_r,
        lora_alpha=model_args.lora_alpha,
        lora_dropout=model_args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        # Explicit target list currently disabled:
        # target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    )

    # Second adapter config, restricted to the attention q/k/v projections.
    proj_lora_config = LoraConfig(
        r=model_args.lora_r,
        lora_alpha=model_args.lora_alpha,
        lora_dropout=model_args.lora_dropout,
        bias="none",
        target_modules=["q_proj", "k_proj", "v_proj"]
    )

    model = LLMCompressor(
        model_args,
        training_args,
        lora_config,
        proj_lora_config
    )

    # The same dataset configuration is used for every max_doc_tokens value.
    train_dataset = PwC_TrainDataset(
        filepath=args.data_path,
        model=model,
        cmp_tokenizer=model.tokenizer,
        llm_tokenizer=model.llm_tokenizer,
        max_doc_tokens=args.max_doc_tokens,
        instruction_name=args.instruction_name,
        lm_ratio=training_args.lm_ratio,
        leave_len=training_args.leave_len,
        prefix_type=training_args.prefix_type
    )

    # No eval_dataset or EarlyStoppingCallback is wired in: evaluation is
    # currently disabled for this script.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=train_dataset.collate_fn,
    )

    checkpoint = training_args.resume_from_checkpoint

    # Only announce a checkpoint when one was actually supplied.
    if checkpoint is not None:
        print(f"Loaded from the checkpoint: {checkpoint}")

    if training_args.fine_tune:
        # Fine-tuning: warm-start the weights only, then train from step 0
        # (the Trainer is deliberately not asked to resume its own state).
        if checkpoint is not None:
            load_checkpoint(
                model,
                checkpoint,
                training_args.fine_tune
            )
        train_result = trainer.train()
    else:
        train_result = trainer.train(resume_from_checkpoint=checkpoint)

    # trainer.deepspeed is None unless DeepSpeed is enabled; guarding here
    # keeps a non-DeepSpeed run from crashing before save_model() is reached.
    ds_engine = getattr(trainer, "deepspeed", None)
    if ds_engine is not None:
        ds_engine.save_checkpoint("test_result")
    trainer.save_model()
    trainer.log_metrics("train", train_result.metrics)

# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()