import os
import sys
import torch
import wandb
import datasets
import json
import pickle
import transformers
from omegaconf import OmegaConf
from accelerate import Accelerator
from trainer import Trainer
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR
from accelerate import Accelerator
from models import Union_Model
from utils import *
from StepRunners import AlignRunner
from peft import LoraConfig, get_peft_model
from utils import setup_seed


# Permit TensorFloat-32 kernels for cuDNN operations and CUDA matmuls.
# Both switches are flipped at import time, i.e. before main() builds
# any model.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True


def main():
    """Run alignment training for the base model named on the command line.

    Usage: ``python <script> <base_model_name>`` where ``<base_model_name>``
    is one of ``llama3``, ``gemma2`` or ``mistral``.  The name selects
    ``./configs/<base_model_name>_align_config.yaml``; every other setting
    used below (seed, model paths, batch size, optimizer, ...) is read from
    that config.

    Prints a message and returns without training when the name is missing
    or is not one of the supported values.
    """
    supported_models = ("llama3", "gemma2", "mistral")

    # Guard against a missing argument: indexing sys.argv[1] unconditionally
    # would abort with a bare IndexError instead of a readable message.
    if len(sys.argv) < 2:
        print("Missing base model name! Expected one of: %s" % ", ".join(supported_models))
        return

    base_model_name = sys.argv[1]
    if base_model_name not in supported_models:
        print("Wrong Base Model Name!")
        return
    config = OmegaConf.load(f"./configs/{base_model_name}_align_config.yaml")

    setup_seed(config.seed)
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation,
        cpu=False,
    )

    # Only the local main process opens a wandb run.
    if accelerator.is_local_main_process:
        wandb.init(project=config.project_name)

    # Policy and reference models are loaded with identical options.
    # NOTE: trust_remote_code=True lets the checkpoint's own code run; the
    # paths come from the project config.
    load_kwargs = dict(
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
    )
    policy_model = transformers.AutoModelForCausalLM.from_pretrained(
        config.policy_model_path, **load_kwargs
    )
    ref_model = transformers.AutoModelForCausalLM.from_pretrained(
        config.ref_model_path, **load_kwargs
    )

    # model_setting is a project helper (star-imported from utils); each
    # model is passed to it together with the gradient-checkpointing flag.
    model_setting(policy_model, config.use_grad_ckpt)
    model_setting(ref_model, config.use_grad_ckpt)

    policy_tokenizer = transformers.AutoTokenizer.from_pretrained(
        config.policy_tokenizer_name
    )
    # Reuse the EOS token as the padding token.
    policy_tokenizer.pad_token_id = policy_tokenizer.eos_token_id

    union_model = Union_Model(policy_model, ref_model, policy_tokenizer, accelerator)
    if accelerator.is_local_main_process:
        wandb.watch(union_model)

    train_set = datasets.load_dataset(config.data_path, split="train")
    collator = Align_Collator(policy_tokenizer, config)
    train_loader = DataLoader(
        train_set,
        batch_size=config.batch_size,
        collate_fn=collator.collate_fn,
        shuffle=True,
        drop_last=True,
    )

    # Configure the Trainer class to use the alignment step runner and its
    # checkpoint saver (set on the class, before the instance is created).
    Trainer.StepRunner = AlignRunner
    Trainer.save_ckpt = AlignRunner.save_ckpt

    # The optimizer class is looked up by name in torch.optim.
    # NOTE(review): everything Union_Model.parameters() yields is optimized;
    # confirm the reference model is excluded there if it must stay fixed.
    optimizer_class = getattr(torch.optim, config.optimizer)
    optimizer = optimizer_class(union_model.parameters(), lr=config.learning_rate)

    # Warmup and total step counts are hard-coded here rather than read
    # from the config.
    lr_scheduler = transformers.get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=500,
        num_training_steps=8000,
    )

    trainer = Trainer(
        union_model,
        accelerator=accelerator,
        optimizer=optimizer,
        tokenizer=policy_tokenizer,
        lr_scheduler=lr_scheduler,
        config=config,
    )
    trainer.fit(
        dataloader=train_loader,
        epochs=config.epoch,
        ckpt_path=config.ckpt_path,
    )
    
# Script entry point: start training only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()