from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, LlamaTokenizerFast, BitsAndBytesConfig, GemmaTokenizerFast
from utils.data_utils import get_dataloaders
from utils import train_utils, loss_aware, lowrank_modeling, eval_utils
import argparse
import torch
import wandb
import os
import time 
import json
import torch 
import numpy as np

# Command-line interface. The parsed namespace ``args`` is used throughout this
# script and handed to the utils helpers; ``vars(args)`` is also logged to wandb
# as the run config.
parser = argparse.ArgumentParser(description="Transformer model training and evaluation")

parser.add_argument("--model_name", type=str, default="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                    help="The name or path of the pre-trained model to use.")

parser.add_argument("--batch_size", type=int, default=12,
                    help="Batch size for model")

parser.add_argument('--debug', action='store_true', default=False, help='Debug mode, faster execution')

parser.add_argument("--exp_name", type=str, default='test', help="Experiment name")

parser.add_argument("--cache_dir", type=str, default='train_cache/', help='Directory where distillation cache is stored')

parser.add_argument("--only_compress", type=str, default='', help='Layers to compress, comma separated')

parser.add_argument("--rank_pct", type=float, default=0.4, help='Percent of singular values to retain')

parser.add_argument("--param_ratio", type=float, default=None, help='param_ratio/compression to use per layer. Overrides rank_pct')

# '' (the default) disables activation/loss-aware calibration.
parser.add_argument("--act_aware", type=str, default='', help='Loss/activation aware SVD', choices=['', 'fisher', 'activation'])

parser.add_argument("--ignore_compression", action='store_true', default=False, help='Ignore compression, run the baseline')

parser.add_argument("--compression_path", type=str, default='', help='JSON that contains fixed compression rates to use per layer')

parser.add_argument("--load_act_cache", action='store_true', default=False, help='Load activation cache pkl file')

parser.add_argument("--seed", type=int, default=233, help='Set random seed')

parser.add_argument("--alpha", type=float, default=0.5, help='HyperParameter for ASVD')

parser.add_argument("--use_int8", action='store_true', default=False, help='Use Int8 Quantization')

args = parser.parse_args()

def count_parameters(model):
    """
    Return the model's parameter count, in billions.

    Every tensor yielded by ``model.parameters()`` contributes its element
    count (``numel()``); the total is divided by 1e9.
    """
    n_elements = 0
    for param in model.parameters():
        n_elements += param.numel()
    return n_elements / 1e9

class ModelWrapper:
    """
    Thin proxy around a model whose ``.to()`` is not supported.

    Used for the bitsandbytes int8 model in this script: calling
    ``wrapper.to(...)`` performs no move and returns the wrapped model, and
    every other attribute lookup is forwarded to the wrapped model.
    """

    def __init__(self, model):
        self.model = model

    def to(self, *args, **kwargs):
        # Deliberately a no-op for device/dtype moves. Note that this returns
        # the underlying model, not the wrapper.
        return self.model

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If ``model`` itself
        # is missing (instance built via __new__ without __init__, as copy and
        # pickle do), forwarding to self.model would re-enter this method
        # forever; raise AttributeError so hasattr/copy/pickle work.
        if name == 'model':
            raise AttributeError(name)
        # Forward all other attribute access to the underlying model
        return getattr(self.model, name)

# Debug runs keep wandb logging local (offline mode) instead of syncing.
if args.debug:
    os.environ["WANDB_MODE"] = "offline"

# Sequence length is hard-coded here rather than taken from the CLI. It is
# assigned before wandb.init below so it is included in the logged config.
args.max_length = 256

# set seed: NumPy, torch (CPU and every CUDA device), plus deterministic cuDNN.
for _seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(args.seed)
torch.backends.cudnn.deterministic = True

# wandb logging -- the whole argument namespace becomes the run config.
wandb_writer = wandb.init(
    project="learn-to-compress-lrd_final",
    name=args.exp_name,
    config=vars(args),
)

# load model -- tokenizer first.
# The tokenizer class and pad-token choice depend on the family named in
# --model_name; unrecognised families fall back to AutoTokenizer with no
# pad-token override.
_model_id = args.model_name
if 'Llama-2' in _model_id:
    tokenizer = LlamaTokenizerFast.from_pretrained(_model_id, cache_dir=args.cache_dir)
    # Reuse <unk> as padding; both the token string and its id are set.
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.pad_token_id = tokenizer.unk_token_id
    print('Loaded llama tokenizer')
elif 'Llama-3' in _model_id:
    # NOTE(review): confirm LlamaTokenizerFast is the intended class for
    # Llama-3 checkpoints (AutoTokenizer may be the safer choice).
    tokenizer = LlamaTokenizerFast.from_pretrained(_model_id, cache_dir=args.cache_dir)
    # Reuse EOS as padding.
    tokenizer.pad_token = tokenizer.eos_token
    print('Loaded llama 3 tokenizer')
elif 'gemma' in _model_id.lower():
    # Case-insensitive match, unlike the Llama checks above.
    tokenizer = GemmaTokenizerFast.from_pretrained(_model_id, cache_dir=args.cache_dir)
    tokenizer.pad_token = tokenizer.eos_token
else:
    tokenizer = AutoTokenizer.from_pretrained(_model_id, cache_dir=args.cache_dir)

# Fail fast: the int8 path cannot be compressed, so reject that flag
# combination before spending time loading weights. parser.error prints usage
# and exits, and unlike `assert` it is not stripped under `python -O`.
if args.use_int8 and not args.ignore_compression:
    parser.error('--use_int8 requires --ignore_compression '
                 '(if using int8 quantization, cannot do compression)')

start = time.time()
if args.use_int8:
    print('Loading model in 8bit')

    # Create a BitsAndBytesConfig object with the desired configuration
    quantization_config = BitsAndBytesConfig(load_in_8bit=True)

    # Load the model with the quantization configuration; device placement is
    # left to from_pretrained via device_map='auto'.
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        quantization_config=quantization_config,
        device_map='auto',
        cache_dir=args.cache_dir,
    )
    # If bits and bytes model, model.to() is not supported. The wrapper turns
    # .to() into a no-op that returns the underlying model.
    model = ModelWrapper(model)
else:
    # Full-model path: weights are loaded in float16.
    model = AutoModelForCausalLM.from_pretrained(args.model_name, cache_dir=args.cache_dir, torch_dtype=torch.float16)

    if torch.cuda.is_available():
        print('Cuda available, using it')
        model = model.cuda()

print(f'Model loaded in {time.time()-start: 0.2f} seconds')

# calculate activation-aware weight matrix
# svd_info stays an empty dict unless --act_aware is set; either way it is
# passed to the SVD replacement step later in this script.
svd_info = {}
if args.act_aware:
    # Calibration sample counts (train == test): 10 in debug mode, else 256.
    args.num_train_samples = args.num_test_samples = 10 if args.debug else 256
    args.layer_type = 'test'
    args.distill_mode = 'hs_last'

    # Only the third loader returned by get_dataloaders is used here.
    _, _, calib_loader = get_dataloaders(tokenizer, args, dataset_name="wikitext2")

    if args.act_aware == 'fisher':
        svd_info = loss_aware.calib_fisher_info(model, calib_loader, args=args)
    elif args.act_aware == 'activation':
        svd_info = loss_aware.calib_input_distribution(model, calib_loader, method='abs_mean', args=args)
    else:
        # Not reachable from the CLI (argparse `choices` limits --act_aware),
        # kept as a guard.
        raise NotImplementedError(f'Activation aware {args.act_aware} not supported')

# Optional fixed per-layer compression rates, read from a JSON file. When no
# path is given, an empty dict is passed on to the SVD step.
compression_dict = {}
if args.compression_path:
    with open(args.compression_path) as config_file:
        compression_dict.update(json.load(config_file))
    print("Loaded compression config")
    # A supplied config file must not be empty.
    assert compression_dict


# Baseline size, in billions of parameters, before any layer is replaced.
num_params_old = count_parameters(model)

# perform SVD/ASVD on layers and edit model in-place
if not args.ignore_compression:
    # Move to GPU (when present) before the replacement, and again afterwards,
    # presumably because replaced layers may be created off-GPU -- confirm.
    if torch.cuda.is_available():
        model = model.cuda()
    lowrank_modeling.replace_linear_with_svd_naiive(model, args, svd_info=svd_info, compression_dict=compression_dict)
    if torch.cuda.is_available():
        model = model.cuda()

# calculate compression
# Count with the same helper used for num_params_old, so both figures share
# the same unit (billions) and counting method and the ratio is meaningful.
num_params_new = count_parameters(model)
compression_stats = {
    "compression_stats/new_params_billion": num_params_new,
    "compression_stats/old_params_billion": num_params_old,
    "compression_stats/compression_ratio": num_params_new / num_params_old,
}
print(f"\n\n--Compression Stats---\n{json.dumps(compression_stats, indent=4)}")

# evaluate
harness_metrics = eval_utils.evaluate_with_harness_full(model, tokenizer, model.device, debug=args.debug, batch_size=args.batch_size)
# Prefix every metric key with 'final_' before logging.
harness_metrics = {'final_' + k: v for k, v in harness_metrics.items()}
wandb.log({**harness_metrics, **compression_stats, 'step': 0})
print(f'Metrics: {harness_metrics}')
