import random
import re
import json
import blobfile as bf
import pandas
import numpy as np
import torch
import os
import copy
import gc
from datasets import Dataset, DatasetDict
# from unsloth import FastLanguageModel
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
import gc
# from trl import SFTTrainer
from transformers import TrainingArguments, TextStreamer, AutoModelForCausalLM, set_seed, AutoTokenizer, BitsAndBytesConfig
import math
from trl import  SFTTrainer, DataCollatorForCompletionOnlyLM
from accelerate import PartialState,Accelerator
from datasets import load_dataset
import argparse
from peft import get_peft_model,PeftModel
from peft import AutoPeftModelForCausalLM
import wandb
# Module-level side effect: initialise Weights & Biases in 'disabled' mode as
# soon as this file is imported/run, before any trainer is constructed.
wandb.init(mode='disabled')




def model_optimization_accelerate():
    """Run one epoch of full-parameter supervised fine-tuning and save the result.

    The function is driven entirely by command-line arguments (all required,
    parsed from ``sys.argv`` with ``argparse``):

        --sft_data          JSON file loaded via ``load_dataset('json', ...)``.
                            Every record must provide a ``messages`` field that
                            the tokenizer's chat template can render.
        --model_load_path   Checkpoint handed to ``from_pretrained`` for both the
                            model and the tokenizer.
        --model_save_path   Trainer ``output_dir``; ``trainer.save_model()``
                            writes the final model here.
        --logging_path      Trainer ``logging_dir``.

    Returns:
        None. The outcome is the model written under ``--model_save_path``.
    """
    # NOTE(review): `accelerator` is not referenced again; presumably it is
    # created first so accelerate's process state exists before the model is
    # loaded -- confirm before removing.
    accelerator = Accelerator()  # noqa: F841

    parser = argparse.ArgumentParser()
    parser.add_argument('--sft_data', type=str, required=True)
    parser.add_argument('--model_load_path', type=str, required=True)
    parser.add_argument('--model_save_path', type=str, required=True)
    parser.add_argument('--logging_path', type=str, required=True)
    args = parser.parse_args()

    train_dataset = load_dataset(
        'json',
        data_files=args.sft_data,
        split='train',
        cache_dir="./cache",
    )
    print(train_dataset)

    # Printed for diagnostics only; every launched process reports its index.
    process_index = PartialState().process_index
    print(process_index)

    model = AutoModelForCausalLM.from_pretrained(
        args.model_load_path,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        use_cache=False,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_load_path)

    # Register a dedicated padding token, then resize the embedding matrix so
    # the model's vocabulary size matches the tokenizer's.
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    model.resize_token_embeddings(len(tokenizer))

    def template_dataset(examples):
        """Render one record's ``messages`` into a single chat-formatted string."""
        return {
            "text": tokenizer.apply_chat_template(examples["messages"], tokenize=False)
        }

    train_dataset = train_dataset.map(template_dataset, remove_columns=["messages"])
    print(train_dataset)
    print(train_dataset[-1])

    # NOTE(review): a DataCollatorForCompletionOnlyLM keyed on the response
    # template "<|start_header_id|>assistant<|end_header_id|>" used to be built
    # here but was never passed to the trainer, so the trainer's default
    # collator is what is actually in effect. The unused object was removed;
    # pass one via `data_collator=` if completion-only loss is wanted.

    bf16_supported = torch.cuda.is_bf16_supported()
    training_args = TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=2,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={'use_reentrant': False},
        # NOTE(review): with the plain 'constant' schedule, `warmup_steps`
        # appears to have no effect in transformers; 'constant_with_warmup' is
        # the variant that honours it -- confirm which one is intended.
        warmup_steps=50,
        learning_rate=3e-7,
        lr_scheduler_type='constant',
        fp16=not bf16_supported,
        bf16=bf16_supported,
        logging_steps=20,
        optim='adamw_hf',
        weight_decay=0.0,
        seed=42,
        output_dir=args.model_save_path,
        # Fix: the required --logging_path argument was previously parsed and
        # then ignored (logging_dir pointed at the model save path instead).
        logging_dir=args.logging_path,
        save_steps=5000,
        num_train_epochs=1,
        report_to='none',
    )

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        dataset_text_field='text',
        max_seq_length=8 * 1000,
        packing=False,  # TODO (carried over from the original; intent not recorded)
        args=training_args,
    )
    trainer.train()

    # With no explicit path this writes to the trainer's output_dir.
    trainer.save_model()

    # Best-effort memory release. The trainer was handed the model, so it must
    # be dropped together with the other references; collect garbage first and
    # only then ask the CUDA allocator to return its cached blocks.
    del trainer, model, tokenizer, train_dataset
    gc.collect()
    torch.cuda.empty_cache()




# Script entry point: run the fine-tuning job only when executed directly.
if __name__ == "__main__":
    model_optimization_accelerate()