import torch
import utils.data_utils as data_utils
import utils.module_utils as module_utils
import utils.eval_utils as eval_utils
from utils.train_utils import train
from transformers import AutoTokenizer, AutoModelForCausalLM
import argparse
import os

def parse_args(argv=None):
    """Parse the command-line options for quantization + KD training.

    Args:
        argv: Optional list of argument strings. ``None`` makes argparse read
            ``sys.argv[1:]``, matching the original script behavior.

    Returns:
        An ``argparse.Namespace`` with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default="cuda:0",
                        help="device both models are moved to")
    parser.add_argument('--model_path', default="llama-2-7b",
                        help="checkpoint loaded for the tokenizer and both models")
    parser.add_argument('--data_path', default="gen_data.jsonl",
                        help="training data passed to data_utils.get_generated_dataset")
    parser.add_argument('--save_path', default="llama-2-7b_quant",
                        help="save_path forwarded to train()")
    parser.add_argument('--model_max_length', default=2048, type=int,
                        help="tokenizer max length and dataset max_len")
    parser.add_argument('--mode', default='residual',
                        help="mode forwarded to module_utils.replace_linear")
    parser.add_argument('--bit', default=2, type=int,
                        help="bit forwarded to module_utils.replace_linear")
    parser.add_argument('--group_size', default=-1, type=int,
                        help="groupsize forwarded to module_utils.replace_linear")
    parser.add_argument('--batch_size', default=8, type=int,
                        help="training batch size")
    parser.add_argument('--lr', default=1e-4, type=float,
                        help="learning rate forwarded to train()")
    parser.add_argument('--epochs', default=1, type=int,
                        help="num_epochs forwarded to train()")
    # NOTE: the flag name is misspelled ("accumulater") but kept as-is so
    # existing launch scripts keep working; it maps to accumulation_steps.
    parser.add_argument('--accumulater_step', default=1, type=int,
                        help="accumulation_steps forwarded to train()")
    parser.add_argument('--kd_loss_scale', default=0.01, type=float,
                        help="kd_loss_scale forwarded to train()")
    return parser.parse_args(argv)


def main(argv=None):
    """Load a model, quantize a copy of it, report baseline perplexity, and train.

    Steps:
      1. Load the tokenizer and two independent bf16 copies of the checkpoint.
      2. Build the training loader (generated data) and the validation loader
         (wiki validation split, batch size 1).
      3. Print the validation perplexity of the unmodified model.
      4. Replace the linear layers of the second copy via
         ``module_utils.replace_linear``.
      5. Call ``train`` with both models.

    Args:
        argv: Optional argument list forwarded to :func:`parse_args`.
    """
    args = parse_args(argv)

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_path,
        model_max_length=args.model_max_length,
        padding_side="right",
    )
    # Two separate copies of the same checkpoint: `model` stays full precision
    # and is passed first to train() (presumably the KD teacher; confirm in
    # utils.train_utils), while `q_model` gets its linear layers replaced.
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path, torch_dtype=torch.bfloat16
    )
    q_model = AutoModelForCausalLM.from_pretrained(
        args.model_path, torch_dtype=torch.bfloat16
    )

    train_set = data_utils.get_generated_dataset(
        path=args.data_path,
        tokenizer=tokenizer,
        max_len=args.model_max_length,
        batch_size=args.batch_size,
        return_loader=True,
    )
    valid_set = data_utils.get_wiki_dataset(
        tokenizer=tokenizer,
        max_len=args.model_max_length,
        batch_size=1,
        split='validation',
        return_loader=True,
    )

    model.eval()
    model.to(args.device)
    # Baseline evaluation only; no gradients are needed, so skip building
    # the autograd graph.
    with torch.no_grad():
        ppl = eval_utils.calculate_perplexity(model, valid_set, model.device)
    print(f'fp model validation Perplexity: {ppl:.4f}')

    q_model.to(args.device)
    module_utils.replace_linear(
        q_model,
        mode=args.mode,
        bit=args.bit,
        groupsize=args.group_size,
        info=False,
    )

    train(
        model,
        q_model,
        train_set,
        valid_set,
        args.device,
        save_path=args.save_path,
        lr=args.lr,
        num_epochs=args.epochs,
        accumulation_steps=args.accumulater_step,
        kd_loss_scale=args.kd_loss_scale,
    )


if __name__ == "__main__":
    main()


