import argparse
import os
import torch
import collections
from transformers import AutoTokenizer, AutoModelForCausalLM
from model import GPT2Config, GPT2LMModel


# Command-line interface for beam decoding with a (LoRA-adapted) GPT-2 model.
parser = argparse.ArgumentParser(description='PyTorch GPT2 beam decoding')

# Each entry is (flag, keyword arguments for ``add_argument``). Options are
# registered in this order, which is also the order ``--help`` lists them.
_CLI_OPTIONS = (
    # Data and sequence lengths.
    ('--data', dict(type=str, default='../data/wikitext-103',
                    help='location of the data corpus')),
    ('--batch_size', dict(type=int, default=10, help='batch size')),
    ('--seq_len', dict(type=int, default=512,
                       help='number of tokens to predict')),
    ('--eval_len', dict(type=int, default=256, help='evaluation length')),
    ('--min_length', dict(type=int, default=0,
                          help='minimum generation length')),
    # Model architecture, weights and LoRA hyper-parameters.
    ('--model_card', dict(default='gpt2.sm',
                          choices=['gpt2.sm', 'gpt2.md', 'gpt2.lg'],
                          help='model names')),
    ('--init_checkpoint', dict(default=None, type=str,
                               help='initial checkpoint')),
    ('--lora_dim', dict(type=int, default=0, help='lora attn dimension')),
    ('--lora_alpha', dict(type=int, default=128, help='lora attn alpha')),
    ('--work_dir', dict(type=str,
                        default=os.getenv('PT_OUTPUT_DIR', 'gpt2_model'),
                        help='working folder')),
    # Decoding controls.
    ('--beam', dict(type=int, default=1, help='beam search size')),
    ('--length_penalty', dict(type=float, default=1.0,
                              help='length penalty')),
    ('--no_repeat_ngram_size', dict(type=int, default=4,
                                    help='no_repeat_ngram_size')),
    ('--repetition_penalty', dict(type=float, default=1.0,
                                  help='repetition_penalty')),
    # With action='append' and a non-empty default, argparse keeps the
    # default entries and appends any ids given on the command line after
    # them, so 50256 is always the first element of args.eos_token_id.
    ('--eos_token_id', dict(action='append', type=int, default=[50256],
                            help='eos token id')),
    ('--output_file', dict(type=str, default='beam_prediction.jsonl',
                           help='output file name')),
)

for _flag, _kwargs in _CLI_OPTIONS:
    parser.add_argument(_flag, **_kwargs)

args = parser.parse_args()


# Which LoRA adapters the model is built with: one flag per attention
# projection (Q, K, V, out_proj), plus the MLP and the LM head.
enable_lora_attn = [True, True, True, True] # Q,K,V,out_proj
enable_lora_mlp = True
enable_lora_head = False

# Transformer dimensions per --model_card. All cards share every other
# GPT2Config argument, so only these three numbers vary.
GPT2_MODEL_SIZES = {
    'gpt2.sm': dict(n_embd=768, n_layer=12, n_head=12),
    'gpt2.md': dict(n_embd=1024, n_layer=24, n_head=16),
    'gpt2.lg': dict(n_embd=1280, n_layer=36, n_head=20),
}

try:
    model_size = GPT2_MODEL_SIZES[args.model_card]
except KeyError:
    # argparse `choices` normally prevents this; fail with a clear message
    # instead of a NameError on an undefined `config` further down.
    raise ValueError(
        'unsupported --model_card {!r}; expected one of {}'.format(
            args.model_card, sorted(GPT2_MODEL_SIZES))
    ) from None

config = GPT2Config(
    **model_size,
    lora_dim=args.lora_dim,
    lora_alpha=args.lora_alpha,
    enable_lora_attn=enable_lora_attn,
    enable_lora_mlp=enable_lora_mlp,
    enable_lora_head=enable_lora_head,
)

lm_net = GPT2LMModel(config)

# --init_checkpoint defaults to None; torch.load(None) fails with an opaque
# error, so report the missing option through argparse instead (prints usage
# and exits with status 2).
if args.init_checkpoint is None:
    parser.error('--init_checkpoint is required')

print('loading model pretrained weight.')
# NOTE: torch.load unpickles the file -- only load checkpoints from trusted
# sources.
cp = torch.load(args.init_checkpoint, map_location=torch.device('cpu'))

# Some checkpoints wrap the weights under a 'model_state_dict' key.
if 'model_state_dict' in cp:
    cp = cp['model_state_dict']

# In the case of DDP all keys start with 'module.'; strip that first path
# component when loading into the bare model. Only the first key is
# inspected. next(iter(...), '') avoids an IndexError on an empty checkpoint.
first_key = next(iter(cp), '')
if first_key.split('.')[0] == 'module' and not hasattr(lm_net, 'module'):
    cp = collections.OrderedDict(
        (k.partition('.')[2], v) for k, v in cp.items()
    )

# Adaptive ranks: resize every LoRA-bearing module to the rank stored in the
# checkpoint, which may differ from --lora_dim. The rank is read from
# lora_B's second dimension (in case of mergedlinear, lora_A can be a
# multiple of lora dim).
for n, m in lm_net.named_modules():
    if not hasattr(m, 'lora_B'):
        continue
    lora_b_key = n + '.lora_B'
    if lora_b_key not in cp:
        # No LoRA weights saved for this module: assume lora dim = 0.
        m.change_lora_rank(0)
        continue
    saved_rank = cp[lora_b_key].shape[1]
    if saved_rank != args.lora_dim:
        m.change_lora_rank(saved_rank)

lm_net.load_state_dict(cp)
del cp


# Hugging Face checkpoint name used for both the tokenizer and the causal-LM
# model below (an alternative that was used here: "meta-llama/Meta-Llama-3-8B").
# NOTE(review): this name is fixed and does not follow --model_card -- confirm
# that loading "gpt2-medium" for every card is intended.
model_name_or_path = "gpt2-medium"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name_or_path),
    AutoModelForCausalLM.from_pretrained(model_name_or_path),
)