# Load the model via the Hugging Face-style `from_pretrained` API
import argparse
# Command-line interface. The three path options are mandatory: without them
# the script would only fail after the (expensive) model load, so we let
# argparse reject the invocation up front instead.
parser = argparse.ArgumentParser(
    description='Generate responses with a LoRA-adapted Mantis model for one shard of prompts.',
)
parser.add_argument('-g', '--gpu', default=7,
                    help='value written to CUDA_VISIBLE_DEVICES (default: 7)')
parser.add_argument('-p', '--ckpt_path', required=True,
                    help='directory containing adapter_model.safetensors or adapter_model.bin')
parser.add_argument('-s', '--save_path', required=True,
                    help='output directory; results are appended to <save_path>/<rank>.jsonl')
parser.add_argument('-pro', '--prompt_path', required=True,
                    help='input directory; prompts are read from <prompt_path>/<rank>.jsonl')
parser.add_argument('-r', '--rank',
                    help='shard name used to build the prompt and output file names')
parser.add_argument('-n', '--num_per_path', default=6, type=int,
                    help='number of responses generated per prompt (default: 6)')
parser.add_argument('-tk', '--top_k', default=2, type=int,
                    help='top-k for sampling; values <= 1 switch to greedy decoding (default: 2)')
parser.add_argument('-t', '--temperature', default=1., type=float,
                    help='sampling temperature, only used when top_k > 1 (default: 1.0)')
parser.add_argument('--save_every', default=100, type=int,
                    help='append results to disk every N prompts (default: 100)')
# parser.add_argument('-e', '--exp_id',default='0.3.brkt')
args = parser.parse_args()

# import pdb;pdb.set_trace()
import os

# Select the visible GPU(s) *before* torch is imported, so the restriction is
# guaranteed to be in place whenever the CUDA runtime gets initialised.
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

import torch

# This script only runs inference; disable autograd globally.
torch.set_grad_enabled(False)
# Sanity check: number of GPUs visible after applying the restriction above.
print(torch.cuda.device_count())

device = "cuda"

import sys

# Extend the import search path with the current directory and up to three
# parent directories (presumably so the `mantis` package resolves regardless
# of which nesting level the script is launched from).
sys.path.extend(['.', '..', '../..', '../../..'])

from mantis.models.mllava import MLlavaProcessor, LlavaForConditionalGeneration
from mantis.models.mllava import chat_mllava

# The processor and the model weights come from the same pretrained checkpoint.
pretrained_id = "TIGER-Lab/Mantis-8B-siglip-llama3"
processor = MLlavaProcessor.from_pretrained(pretrained_id)
attn_implementation = "flash_attention_2"
model = LlavaForConditionalGeneration.from_pretrained(
    pretrained_id,
    device_map="cuda",
    torch_dtype=torch.bfloat16,
    attn_implementation=attn_implementation,
)

from peft import LoraConfig, get_peft_model

# LoRA settings: adapters are attached to the q_proj / v_proj modules whose
# qualified name contains "language_model" (see the target_modules regex).
# NOTE(review): these values presumably have to match the configuration used
# when the adapter checkpoint was trained — confirm against the training run.
lora_kwargs = dict(
    r=128,
    lora_alpha=256,
    target_modules=r'.*language_model.*\.(q_proj|v_proj)',
    lora_dropout=0.05,
    bias='none',
    task_type="CAUSAL_LM",
)
lora_config = LoraConfig(**lora_kwargs)

# Wrap the base model so that it exposes the LoRA parameters.
print("Adding LoRA adapters...")
model.enable_input_require_grads()
model = get_peft_model(model, lora_config)
print("Successfully added LoRA adapters")

import torch

import os
import glob
import json
import numpy as np
import sys
import jsonlines
from safetensors.torch import load_file
import torch

def load_jsonl(filename):
    """Read a JSON Lines file and return its records as a list.

    Args:
        filename: Path to a UTF-8 encoded file with one JSON document per line.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records = []
    with open(filename, "r", encoding="utf-8") as f:
        # Stream the file line by line instead of materialising readlines().
        for line in f:
            line = line.strip()
            # Blank lines (e.g. a trailing newline at EOF) carry no record.
            if line:
                records.append(json.loads(line))
    return records


# Results for this shard are appended to <save_path>/<rank>.jsonl.
os.makedirs(args.save_path, exist_ok=True)
save_path = os.path.join(args.save_path, '%s.jsonl' % args.rank)

# Resume support: every record already present in the output file corresponds
# to one prompt that has been processed, so that many prompts are skipped.
start_idx = 0
if os.path.exists(save_path):
    with jsonlines.open(save_path) as f:
        start_idx = sum(1 for _ in f)

# Load the trained LoRA adapter weights from the checkpoint directory.
checkpoint_path = args.ckpt_path
files = os.listdir(checkpoint_path)
use_bin = 'adapter_model.bin' in files
use_safetensors = 'adapter_model.safetensors' in files
if use_bin and use_safetensors:
    print('wft? try safetensors')

# Prefer safetensors when both formats are present; only one file is read.
if use_safetensors:
    lora_path = os.path.join(checkpoint_path, 'adapter_model.safetensors')
    print('load lora from {}'.format(lora_path))
    prefix_state_dict = load_file(lora_path)
elif use_bin:
    lora_path = os.path.join(checkpoint_path, 'adapter_model.bin')
    print('load lora from {}'.format(lora_path))
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    prefix_state_dict = torch.load(lora_path, map_location='cpu')
else:
    raise FileNotFoundError(
        'no adapter_model.safetensors or adapter_model.bin found in {}'.format(checkpoint_path)
    )

# Insert the adapter name ("default") in front of "weight" so that the stored
# keys line up with the parameter names of the PEFT-wrapped model.
state_dict = {
    k.replace('weight', 'default.weight'): v
    for k, v in prefix_state_dict.items()
}

# strict=False is required because the adapter file only holds LoRA tensors.
# Keys that match nothing in the model would otherwise be dropped silently,
# so report them explicitly.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.unexpected_keys:
    print('warning: {} of {} adapter tensors did not match any model parameter, e.g. {}'.format(
        len(load_result.unexpected_keys), len(state_dict), load_result.unexpected_keys[0]))
model.eval()

# Read this shard's prompts from <prompt_path>/<rank>.jsonl.
prompts_path = os.path.join(args.prompt_path, '%s.jsonl' % args.rank)
with jsonlines.open(prompts_path) as f:
    prompts = list(f)

# Skip the prompts that already have a record in the output file.
print('starts from %s/%s' % (start_idx, len(prompts)))
prompts = prompts[start_idx:]
import tqdm

# Decoding configuration: top-k sampling when top_k > 1, greedy otherwise.
# Key order is kept stable so the printed dict reads the same in both modes.
use_sampling = args.top_k > 1
if use_sampling:
    decoding_options = {
        "top_k": args.top_k,
        "do_sample": True,
        "temperature": args.temperature,
    }
else:
    decoding_options = {"do_sample": False}

generate_kwargs = {
    "max_new_tokens": 75,
    **decoding_options,
    'num_return_sequences': 1,
}

print(generate_kwargs)

# Directory that the relative image paths in each prompt are resolved against.
IMAGE_ROOT = 'data/vln/imgs_90fov'


def _append_records(path, records):
    """Append each item of *records* to *path* as one JSON line."""
    with open(path, 'a', encoding='utf-8') as f:
        f.writelines(json.dumps(record) + '\n' for record in records)


# Records produced since the last flush: [instruction_id, [response, ...]].
infers = []
for step, item in enumerate(tqdm.tqdm(prompts), start=1):
    instr_id = item['instruction_id']
    images = [os.path.join(IMAGE_ROOT, name) for name in item['image']]
    text = item['text']

    # Each image needs exactly one '<image>' placeholder in the prompt text.
    # Raise explicitly rather than assert so the check survives `python -O`.
    num_placeholders = text.count('<image>')
    if len(images) != num_placeholders:
        raise ValueError(
            'instruction %s: %d images but %d <image> placeholders'
            % (instr_id, len(images), num_placeholders)
        )

    # Generate num_per_path responses for the same prompt; chat_mllava
    # returns a pair of which only the first element (the response) is kept.
    with torch.no_grad():
        responses = [
            chat_mllava(text, images, model, processor, **generate_kwargs)[0]
            for _ in range(args.num_per_path)
        ]
    infers.append([instr_id, responses])

    # Periodically append to disk so an interrupted run can be resumed.
    if step % args.save_every == 0:
        _append_records(save_path, infers)
        infers = []

# Flush whatever is left after the last full batch.
_append_records(save_path, infers)

# import jsonlines 
# with jsonlines.open(save_path, mode='w') as writer:
#     writer.write_all(infers)