from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, LlamaTokenizerFast, AutoConfig
from utils.data_utils import get_dataloaders
from utils import distill_utils, train_utils, eval_utils
from utils import lowrank_modeling
from modeling import modeling_llama
import argparse
from tqdm import tqdm
from transformers import AdamW
import torch
import wandb
import pdb
import os
import pickle
import time 
import json 
import torch 
import pandas as pd 
from os.path import join 
import wandb 

# Command-line interface. The parsed namespace is later passed to wandb as the
# run config (``config=vars(args)``), so every option below is recorded.
parser = argparse.ArgumentParser(description="Transformer model training and evaluation")

parser.add_argument("--model_name", type=str, default="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                    help="The name or path of the pre-trained model to use.")

# An empty --model_path loads the base --model_name checkpoint; a non-empty one
# selects the low-rank loader, which reads lowrank_config.json from this dir.
parser.add_argument("--model_path", type=str, default="",
                    help="Path to a low-rank model directory (must contain lowrank_config.json); "
                         "if empty, the base --model_name is loaded")

parser.add_argument("--batch_size", type=int, default=2,
                    help="Batch size for model")

parser.add_argument('--debug', action='store_true', default=False, help='Debug mode, faster execution')

parser.add_argument("--exp_name", type=str, default='test', help="Experiment name")

parser.add_argument("--cache_dir", type=str, default='train_cache/', help='Directory where distillation cache is stored')

parser.add_argument("--ignore_compress", action='store_true', default=False, help='If true, do not use any SVD')

# NOTE(review): this flag is not consulted in this script; the low-rank loader
# is chosen by a non-empty --model_path instead. Confirm whether it is still needed.
parser.add_argument("--lowrank_model", action='store_true', default=False, help='If true, expect that the model path is of a lowrank model')

parser.add_argument("--rank_pct", type=float, default=0.15, help='Percent of singular values to retain')

parser.add_argument("--only_compress", type=str, default='', help='Layers to compress, comma separated')

parser.add_argument("--num_train_samples", type=int, default=64,
                    help="The number of samples to use for the training dataset.")

parser.add_argument("--max_length", type=int, default=256, help="Maximum number of input tokens")

args = parser.parse_args()

def count_parameters(model):
    """Return the total number of parameters in ``model``, expressed in billions.

    Every tensor yielded by ``model.parameters()`` contributes its element
    count; no filtering by ``requires_grad`` is applied.

    Args:
        model: Any object exposing ``parameters()`` that yields objects with a
            ``numel()`` method (e.g. a ``torch.nn.Module``).

    Returns:
        float: The parameter count divided by 1e9.
    """
    n_elements = 0
    for param in model.parameters():
        n_elements += param.numel()
    return n_elements / 1e9


# Echo the compression settings at startup.
print(args.rank_pct, args.ignore_compress)

# Fixed settings that are not exposed on the command line. They are attached to
# the parsed namespace (in this order) so they travel with ``vars(args)``;
# ``num_test_samples`` simply mirrors ``--num_train_samples``.
vars(args).update(
    layer_type='test',
    num_test_samples=args.num_train_samples,
    distill_mode='hs_last',
    fix_length=True,
    seed=233,
)

# In debug mode, keep wandb from syncing runs to the server.
if args.debug:
    os.environ["WANDB_MODE"] = "offline"

# Llama-2 checkpoints use the fast Llama tokenizer with <unk> reused as the
# padding token (presumably because no dedicated pad token is defined — verify);
# every other model name goes through the generic AutoTokenizer.
_is_llama2 = 'Llama-2' in args.model_name
tokenizer = (LlamaTokenizerFast if _is_llama2 else AutoTokenizer).from_pretrained(
    args.model_name, cache_dir=args.cache_dir
)
if _is_llama2:
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.pad_token_id = tokenizer.unk_token_id
    print('Loaded llama tokenizer')

start = time.time()

if args.model_path:
    # Low-rank checkpoint: start from the base model's config, attach the
    # low-rank settings stored next to the weights, then load the custom
    # modeling_llama class in float16.
    # NOTE(review): presumably modeling_llama reads config.lowrank_config when
    # building its layers — confirm against modeling_llama.
    config = AutoConfig.from_pretrained(args.model_name, cache_dir=args.cache_dir)
    with open(join(args.model_path, 'lowrank_config.json')) as lowrank_file:
        lowrank_config = json.load(lowrank_file)
    config.lowrank_config = lowrank_config
    model = modeling_llama.LlamaForCausalLM.from_pretrained(
        args.model_path,
        config=config,
        torch_dtype=torch.float16,
    )
    print('\nLoaded lowrank model\n')
else:
    # No low-rank path given: load the stock causal-LM checkpoint.
    model = AutoModelForCausalLM.from_pretrained(args.model_name, cache_dir=args.cache_dir)

total_params = count_parameters(model)
print(f'\nTotal parameters in model (billion): {total_params: 0.4f}\n')

# Move the model to the GPU exactly once, and only if one is present. The
# original moved it twice and then called model.cuda() unconditionally, which
# crashed on CPU-only machines; without CUDA the model now stays on CPU.
# NOTE(review): the low-rank path loads weights in float16; some ops may not
# support float16 on CPU — confirm CPU evaluation works if that case matters.
if torch.cuda.is_available():
    print('Cuda available, using it')
    model = model.cuda()
print(f'Model loaded in {time.time()-start: 0.2f} seconds')

wandb_writer = wandb.init(project="learn-to-compress-lrd2", name=args.exp_name, config=vars(args))

# model.device reflects wherever the model was placed above.
# NOTE(review): batch_size and debug are hard-coded here rather than taken from
# args.batch_size / args.debug — confirm this is intentional.
harness_metrics = eval_utils.evaluate_with_harness_full(model, tokenizer, model.device, debug=False, batch_size=16)
harness_metrics = {'final_' + k: v for k, v in harness_metrics.items()}

# Log through the run handle returned by wandb.init, then close the run
# explicitly so metrics are flushed before the script exits.
wandb_writer.log({**harness_metrics, 'step': 0})
wandb_writer.finish()

