from transformers import (
    LlamaTokenizer,
    LlamaForCausalLM, 
    TrainingArguments,
)
from accelerate import Accelerator
import torch
import json
# import torch.distributed as dist
from src.dataset import MMapIndexedDatasetWithContext  

import os
# Process-wide switches, set before sparsity_utils, utils and src.trainer are
# imported further down. WANDB_MODE=disabled turns off Weights & Biases
# logging; VAL_MODE and RWKV_FLOAT_MODE are presumably read by project
# code -- TODO confirm where.
os.environ.update({
    'VAL_MODE': '0',
    'RWKV_FLOAT_MODE': 'BF16',
    'WANDB_MODE': 'disabled',
})

# Used below to pick the device the student and teacher models are moved to.
accelerator = Accelerator()

# Flip to True when continuing a previous run; it only controls
# `model.relu_init_mode` below.
resume_training = False

from sparsity_utils import SrLlamaForCausalLM
import utils

# Parse the project's options, then force the two settings this script
# always wants regardless of what was parsed.
args = utils.parser_gen()
args.rotate = True
args.wandb = False

output_dir = args.proj_dir

# Student network: loaded in bfloat16 and then moved to this process's device.
model = SrLlamaForCausalLM.from_pretrained(args.model, torch_dtype=torch.bfloat16)
model = model.to(accelerator.device)

# Flags consumed by SrLlamaForCausalLM (see sparsity_utils for their exact
# meaning). relu_init_mode is on for a fresh run and off when resuming.
model.train_mode = True
model.relu_init_mode = not resume_training

# Must run after the flags above are set; presumably installs the
# activation wrappers -- TODO confirm in sparsity_utils.
model.act_wrapper()

model.train()

# Teacher for distillation: the stock LlamaForCausalLM loaded from the same
# checkpoint as the student, in bfloat16, on this process's device.
teacher_model = LlamaForCausalLM.from_pretrained(
    args.model,
    torch_dtype=torch.bfloat16,
).to(accelerator.device)

# The tokenizer comes from the same checkpoint as both models. The original
# line referenced `model_name`, which is never defined in this file and
# raised NameError before training could start.
tokenizer = LlamaTokenizer.from_pretrained(args.model)

# Path prefix of the training corpus handed to the dataset through Args.
data_file = "falcon-refinedweb/falcon-refinedweb_text_document"

# Sequence length: used as Args.ctx_len and as the trainer's max_seq_length.
ctx_len = 1024

# Factors of Args.epoch_steps (devices * epoch_count * accumulation steps);
# epoch_count also scales max_steps.
devices = 8
epoch_count = 250
gradient_accumulation_steps = 1

# Per-device micro batch size, shared by the dataset, DeepSpeed and
# TrainingArguments.
micro_bsz = 1

class Args:
    """Attribute bag handed to ``MMapIndexedDatasetWithContext``.

    Every field except ``vocab_size`` is copied from the module-level
    constants of the same name at construction time.
    """

    def __init__(self):
        # One "epoch" worth of steps: a product of the device count, the
        # per-epoch step count and the gradient accumulation factor.
        steps_per_epoch = devices * epoch_count * gradient_accumulation_steps

        self.data_file = data_file
        self.ctx_len = ctx_len
        # Hard-coded; presumably the Llama tokenizer vocabulary size --
        # TODO confirm against the tokenizer in use.
        self.vocab_size = 32000
        self.epoch_steps = steps_per_epoch
        self.micro_bsz = micro_bsz
        
# NOTE: this rebinds the module-level name `args`; the namespace returned by
# utils.parser_gen() is no longer reachable under that name from here on.
args = Args()
dataset = MMapIndexedDatasetWithContext(args)
dataset_length = len(dataset)

# Step budget handed to TrainingArguments as max_steps.
num_train_epochs = 100
max_steps = epoch_count * num_train_epochs

# DeepSpeed settings: ZeRO stage 2 with bf16 enabled and fp16 disabled.
# Batch size, accumulation steps and gradient clipping mirror the values
# given to TrainingArguments below so the two configurations agree.
deepspeed_config = {
    "train_micro_batch_size_per_gpu": micro_bsz,
    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": True,
        "allgather_bucket_size": 2e8,
        "reduce_scatter": True,
        "reduce_bucket_size": 2e8,
        "contiguous_gradients": True,
    },
    "gradient_accumulation_steps": gradient_accumulation_steps,
    "fp16": {
        "enabled": False,
    },
    "bf16": {
        "enabled": True,
    },
    "zero_allow_untested_optimizer": True,
    "gradient_clipping": 1.0,
}

# Keep a JSON copy on disk for inspection, written by the main process only.
# Previously every rank truncated and rewrote this file while other ranks
# could be reading it through TrainingArguments, which is a race. The config
# is now passed to TrainingArguments as a dict, so nothing reads the file.
if accelerator.is_main_process:
    with open("deepspeed_config.json", "w") as f:
        json.dump(deepspeed_config, f)

# Checkpointing through the Trainer is disabled (save_strategy="no").
# Per the TrainingArguments documentation, a positive max_steps overrides
# num_train_epochs, so max_steps is the effective training length.
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=micro_bsz,
    num_train_epochs=num_train_epochs,
    logging_dir="./log",
    logging_steps=1,
    deepspeed=deepspeed_config,  # a dict is accepted as well as a file path
    bf16=True,
    gradient_accumulation_steps=gradient_accumulation_steps,
    save_steps=0,
    save_strategy="no",
    max_steps=max_steps,
    learning_rate=1.5e-5,
    lr_scheduler_type="constant",
    max_grad_norm=1.0,
)

def formatting_func(example):
    """Return ``example`` unchanged.

    Passed to the trainer as its ``formatting_func``; this script applies
    no formatting, so the very same object is handed back.
    """
    return example

def custom_collate_fn(batch):
    """Merge a list of per-example dicts into one dict of stacked tensors.

    The key set is taken from the first example. For each key the values
    of all examples are combined with ``torch.stack``, so they must be
    tensors of identical shape.

    Args:
        batch: Non-empty sequence of mappings from key to tensor.

    Returns:
        A dict with the first example's keys, each mapped to a tensor
        whose leading dimension is ``len(batch)``.
    """
    collated = {}
    for key in batch[0]:
        per_example = [example[key] for example in batch]
        collated[key] = torch.stack(per_example)
    return collated

from src.trainer import TrainingMonitorCallback,DistillationTrainer

# Everything the distillation trainer needs: the student being trained, the
# teacher it learns from, the data pipeline and the run configuration.
# packing / formatting_func / max_seq_length look like SFTTrainer-style
# options -- TODO confirm in src/trainer.py.
trainer_kwargs = dict(
    model=model,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=dataset,
    data_collator=custom_collate_fn,
    tokenizer=tokenizer,
    formatting_func=formatting_func,
    packing=False,
    max_seq_length=ctx_len,
)
trainer = DistillationTrainer(**trainer_kwargs)

# The monitoring callback is given the run's output directory.
trainer.add_callback(TrainingMonitorCallback(output_dir=output_dir))

trainer.train()

# Drop the trainer and hand cached, unused CUDA memory back to the driver.
del trainer
torch.cuda.empty_cache()