from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import json
import re
import argparse
from transformers import BitsAndBytesConfig
from tqdm import tqdm
import os

def _format_chat_template(tokenizer, instruction, prompt):
    """Render the instruction message and one user prompt as a chat prompt.

    Args:
        tokenizer: Object exposing ``apply_chat_template``; the argument is
            used directly rather than any module-level tokenizer.
        instruction: Message dict placed first in the conversation (the
            script passes a ``{"role": "system", "content": ...}`` entry).
        prompt: Text sent as the single ``user`` turn.

    Returns:
        The value of ``tokenizer.apply_chat_template`` called with
        ``tokenize=False`` and ``add_generation_prompt=True``, i.e. the
        untokenized conversation ending where the assistant should answer.
    """
    messages = [instruction, {"role": "user", "content": prompt}]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

def response_to(tokenizer, instruction, device, prompt):
    """Ask the chat model one question and return its decoded reply.

    Args:
        tokenizer: Tokenizer used to build the chat prompt, encode it and
            decode the generated tokens.
        instruction: Leading message dict forwarded to
            ``_format_chat_template``.
        device: Device the encoded inputs are moved to before generation.
        prompt: The user question.

    Returns:
        The decoded text of the first (and only) generated sequence, with
        special tokens skipped.

    Note:
        Generation uses the module-level ``model`` object; it is not passed
        in as an argument.
    """
    chat_text = _format_chat_template(tokenizer, instruction, prompt)
    encoded = tokenizer([chat_text], return_tensors="pt").to(device)
    prompt_ids = encoded.input_ids

    ## 256 new tokens is assumed to be enough for these arithmetic answers
    full_sequences = model.generate(prompt_ids, max_new_tokens=256)

    # Drop the leading prompt-length tokens from each output row so that only
    # the newly generated tokens get decoded.
    continuations = []
    for prompt_row, full_row in zip(prompt_ids, full_sequences):
        continuations.append(full_row[len(prompt_row):])

    decoded = tokenizer.batch_decode(continuations, skip_special_tokens=True)
    return decoded[0]

def _extract_height(reply):
    """Return the last height found in ``reply`` as ``"<digits>cm"``, or None.

    A number followed by ``cm`` (optionally separated by whitespace) is
    preferred; otherwise the last bare integer in the reply is used.
    """
    with_unit = re.findall(r'(\d+)\s*cm', reply)
    if with_unit:
        return with_unit[-1] + "cm"
    # Fall back to the last whole integer (all of its digits, not just one).
    bare_numbers = re.findall(r'\d+', reply)
    if bare_numbers:
        return bare_numbers[-1] + "cm"
    return None


def response_to_data(tokenizer, instruction, device, num_steps=2, k=1000, model_size="14"):
    """Answer the first ``k`` questions of one Heights dataset and save them.

    Reads ``Heights_{num_steps}_steps.json`` from the working directory,
    sends each entry's ``"x"`` field to the model through ``response_to``,
    and stores two new keys on the entry:

    * ``"raw yhat"``: the unprocessed model reply;
    * ``"yhat"``: the extracted answer such as ``"162cm"``, or ``None`` when
      the reply contains no number.

    The annotated entries are written to
    ``Heights_{num_steps}_steps_withAnswer_{model_size}B-Chat_k={k}.json``.
    If that file already exists the function returns immediately, so a
    rerun skips configurations that were already evaluated.

    Args:
        tokenizer: Tokenizer forwarded to ``response_to``.
        instruction: Leading message dict forwarded to ``response_to``.
        device: Device forwarded to ``response_to``.
        num_steps: Step count selecting which dataset file to read.
        k: Maximum number of dataset entries to evaluate.
        model_size: Model size tag used only in the output filename.

    Returns:
        None.
    """
    # One name for both the existence check and the final write, so the
    # skip-if-done check looks at the file this function actually produces.
    answers_path = f"Heights_{num_steps}_steps_withAnswer_{model_size}B-Chat_k={k}.json"
    if os.path.exists(answers_path):
        return

    with open(f"Heights_{num_steps}_steps.json", "r") as source_file:
        dataset = json.load(source_file)[:k]

    for entry in tqdm(dataset):
        reply = response_to(tokenizer, instruction, device, entry["x"])
        entry["raw yhat"] = reply
        entry["yhat"] = _extract_height(reply)

    with open(answers_path, 'w') as json_file:
        json.dump(dataset, json_file, indent=4)

if __name__ == "__main__":
    # Command-line options: step counts to evaluate, GPU index, size tag of
    # the Qwen1.5 chat checkpoint, and number of questions per dataset.
    parser = argparse.ArgumentParser(description='main')
    parser.add_argument('--steps', default=[1,2,3,4,5,6], type=int, nargs='+', help='List of integers')
    parser.add_argument('--gpu', default=0, type=int)
    parser.add_argument('--model_size', default="14", type=str)
    parser.add_argument('--k', default=1000, type=int)
    args = parser.parse_args()
    model_name = f"Qwen/Qwen1.5-{args.model_size}B-Chat"
    device = f"cuda:{args.gpu}"
    # `model` must remain a module-level name: response_to() reads it as a
    # global instead of receiving it as an argument.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map=device,
        # 4-bit loading (previous version does not request it), passed as a
        # BitsAndBytesConfig rather than the bare `load_in_4bit` keyword.
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    # Few-shot examples appended verbatim to the system message; they cover
    # chains of one to five height comparisons.
    few_shots = """{
        "Jordan is 166cm tall. Grace is 4cm shorter than Jordan. How tall is Grace?":"162cm",
        "Diana is 157cm tall. Joyce is 13cm taller than Diana. How tall is Joyce?":"170cm",
        "Lee is 171cm tall. Gary is 3cm shorter than Lee. How tall is Gary?":"168cm",
        "Howard is 178cm tall. Travis is 1cm taller than Howard. Samuel is 6cm taller than Travis. How tall is Samuel?":"185cm",
        "Henry is 197cm tall. Alexander is 37cm shorter than Henry. Brenda is 9cm shorter than Alexander. How tall is Brenda?":"151cm",
        "Thomas is 154cm tall. Linda is 20cm taller than Thomas. John is 25cm taller than Linda. How tall is John?":"199cm",
        "Sophie is 189cm tall. Lauren is 33cm shorter than Sophie. Sarah is 39cm taller than Lauren. Douglas is 12cm shorter than Sarah. How tall is Douglas?":"183cm",
        "Benjamin is 199cm tall. Aiden is 33cm shorter than Benjamin. Michael is 3cm shorter than Aiden. Tiffany is 3cm taller than Michael. How tall is Tiffany?":"166cm",
        "Angela is 197cm tall. Jordan is as tall as Angela. Jonathan is 1cm taller than Jordan. Susan is 47cm shorter than Jonathan. How tall is Susan?":"151cm",
        "Jessica is 166cm tall. Spencer is 28cm taller than Jessica. John is 28cm shorter than Spencer. Amanda is 7cm taller than John. Lee is 14cm shorter than Amanda. How tall is Lee?":"159cm",
        "Mia is 190cm tall. Patrick is 4cm taller than Mia. Matthew is 6cm shorter than Patrick. Phyllis is 34cm shorter than Matthew. Alice is 16cm taller than Phyllis. How tall is Alice?":"170cm",
        "Holly is 188cm tall. Luna is 11cm shorter than Holly. Jasmine is 6cm shorter than Luna. Willie is 11cm shorter than Jasmine. Douglas is 28cm taller than Willie. How tall is Douglas?":"188cm",
    }"""
    instruction ={"role": "system", "content": "Answer each question using one integer followed by 'cm', e.g. '171cm'. Examples:" + few_shots}
    # Evaluate one dataset per requested step count.
    for s in args.steps:
        print(model_name, "steps = ", s)
        response_to_data(tokenizer, instruction, device, num_steps = s, k=args.k, model_size=args.model_size)

## run in terminal : 
## for size in 57 72; do python3 QwenHeights.py --steps 1 2 3 4 5 6 --gpu 0 --model_size ${size}; done;

## equivalently :
## python3 QwenHeights.py --steps 1 2 3 4 5 6 --gpu 0 --model_size 57; python3 QwenHeights.py --steps 1 2 3 4 5 6 --gpu 0 --model_size 72;
    






